imate
C++/CUDA Reference
Loading...
Searching...
No Matches
cu_dense_matrix.h
Go to the documentation of this file.
1/*
2 * SPDX-FileCopyrightText: Copyright 2021, Siavash Ameli <sameli@berkeley.edu>
3 * SPDX-License-Identifier: BSD-3-Clause
4 * SPDX-FileType: SOURCE
5 *
6 * This program is free software: you can redistribute it and/or modify it
7 * under the terms of the license found in the LICENSE.txt file in the root
8 * directory of this source tree.
9 */
10
11
12#ifndef _CU_LINEAR_OPERATOR_CU_DENSE_MATRIX_H_
13#define _CU_LINEAR_OPERATOR_CU_DENSE_MATRIX_H_
14
15
16// =======
17// Headers
18// =======
19
20#include "../_definitions/types.h" // FlagType, LongIndexType
21#include "./cu_matrix.h" // cuMatrix
22
23
24// ===============
25// cu Dense Matrix
26// ===============
27
41
42template <typename DataType>
43class cuDenseMatrix : public cuMatrix<DataType>
44{
45 public:
46
47 // Member methods
49
51 const DataType* A_,
52 const LongIndexType num_rows_,
53 const LongIndexType num_columns_,
54 const FlagType A_is_row_major_,
55 const FlagType A_is_symmetric_,
56 const int num_gpu_devices_);
57
58 virtual ~cuDenseMatrix();
59
60 virtual FlagType is_identity_matrix() const;
61
62 virtual void dot(
63 const DataType* device_vector,
64 DataType* device_product);
65
66 virtual void dot_plus(
67 const DataType* device_vector,
68 const DataType alpha,
69 DataType* device_product);
70
71 virtual void transpose_dot(
72 const DataType* device_vector,
73 DataType* device_product);
74
75 virtual void transpose_dot_plus(
76 const DataType* device_vector,
77 const DataType alpha,
78 DataType* device_product);
79
80 protected:
81
82 // Member methods
83 virtual void copy_host_to_device();
84
85 // Member data
86 DataType** device_A;
87 const DataType* A;
89};
90
91#endif // _CU_LINEAR_OPERATOR_CU_DENSE_MATRIX_H_
Container for dense matrices.
virtual void transpose_dot_plus(const DataType *device_vector, const DataType alpha, DataType *device_product)
Transposed matrix-vector product, written in place.
virtual void transpose_dot(const DataType *device_vector, DataType *device_product)
Transposed matrix-vector product.
virtual FlagType is_identity_matrix() const
Checks whether the matrix is identity.
DataType ** device_A
virtual void dot_plus(const DataType *device_vector, const DataType alpha, DataType *device_product)
Matrix-vector product, written in place.
cuDenseMatrix()
Default constructor.
virtual void copy_host_to_device()
Copies the member data from the host memory to the device memory.
virtual ~cuDenseMatrix()
Destructor. This function removes data from GPU devices.
const DataType * A
const FlagType A_is_row_major
virtual void dot(const DataType *device_vector, DataType *device_product)
Matrix-vector product.
Base class for constant matrices.
Definition cu_matrix.h:45
int LongIndexType
Definition types.h:60
int FlagType
Definition types.h:68