imate
C++/CUDA Reference
cu_dense_matrix.h
Go to the documentation of this file.
1 /*
2  * SPDX-FileCopyrightText: Copyright 2021, Siavash Ameli <sameli@berkeley.edu>
3  * SPDX-License-Identifier: BSD-3-Clause
4  * SPDX-FileType: SOURCE
5  *
6  * This program is free software: you can redistribute it and/or modify it
7  * under the terms of the license found in the LICENSE.txt file in the root
8  * directory of this source tree.
9  */
10 
11 
12 #ifndef _CU_LINEAR_OPERATOR_CU_DENSE_MATRIX_H_
13 #define _CU_LINEAR_OPERATOR_CU_DENSE_MATRIX_H_
14 
15 
16 // =======
17 // Headers
18 // =======
19 
20 #include "../_definitions/types.h" // FlagType, LongIndexType
21 #include "../_c_linear_operator/c_dense_matrix.h" // cDenseMatrix
22 #include "./cu_matrix.h" // cuMatrix
23 
24 
25 // ===============
26 // cu Dense Matrix
27 // ===============
28 
29 template <typename DataType>
31  public cuMatrix<DataType>,
32  public cDenseMatrix<DataType>
33 {
34  public:
35 
36  // Member methods
37  cuDenseMatrix();
38 
40  const DataType* A_,
41  const LongIndexType num_rows_,
42  const LongIndexType num_columns_,
43  const FlagType A_is_row_major_,
44  const int num_gpu_devices_);
45 
46  virtual ~cuDenseMatrix();
47 
48  virtual void dot(
49  const DataType* device_vector,
50  DataType* device_product);
51 
52  virtual void dot_plus(
53  const DataType* device_vector,
54  const DataType alpha,
55  DataType* device_product);
56 
57  virtual void transpose_dot(
58  const DataType* device_vector,
59  DataType* device_product);
60 
61  virtual void transpose_dot_plus(
62  const DataType* device_vector,
63  const DataType alpha,
64  DataType* device_product);
65 
66  protected:
67 
68  // Member methods
69  virtual void copy_host_to_device();
70 
71  // Member data
72  DataType** device_A;
73 };
74 
75 #endif // _CU_LINEAR_OPERATOR_CU_DENSE_MATRIX_H_
virtual void transpose_dot_plus(const DataType *device_vector, const DataType alpha, DataType *device_product)
virtual void transpose_dot(const DataType *device_vector, DataType *device_product)
DataType ** device_A
virtual void dot_plus(const DataType *device_vector, const DataType alpha, DataType *device_product)
virtual void copy_host_to_device()
Copies the member data from the host memory to the device memory.
virtual ~cuDenseMatrix()
virtual void dot(const DataType *device_vector, DataType *device_product)
cuMatrix: base class for constant matrices.
Definition: cu_matrix.h:41
LongIndexType: integer type alias for int.
Definition: types.h:60
FlagType: integer type alias for int.
Definition: types.h:68