imate
C++/CUDA Reference
c_dense_affine_matrix_function.h
Go to the documentation of this file.
1 /*
2  * SPDX-FileCopyrightText: Copyright 2021, Siavash Ameli <sameli@berkeley.edu>
3  * SPDX-License-Identifier: BSD-3-Clause
4  * SPDX-FileType: SOURCE
5  *
6  * This program is free software: you can redistribute it and/or modify it
7  * under the terms of the license found in the LICENSE.txt file in the root
8  * directory of this source tree.
9  */
10 
11 
12 #ifndef _C_LINEAR_OPERATOR_C_DENSE_AFFINE_MATRIX_FUNCTION_H_
13 #define _C_LINEAR_OPERATOR_C_DENSE_AFFINE_MATRIX_FUNCTION_H_
14 
15 
16 // =======
17 // Headers
18 // =======
19 
20 #include "../_definitions/types.h" // LongIndexType, FlagType
21 #include "./c_affine_matrix_function.h" // cAffineMatrixFunction
22 #include "./c_dense_matrix.h" // cDenseMatrix
23 
24 
25 // ==============================
26 // c Dense Affine Matrix Function
27 // ==============================
28 
29 template <typename DataType>
31 {
32  public:
33 
34  // Member methods
36  const DataType* A_,
37  const FlagType A_is_row_major_,
38  const LongIndexType num_rows_,
39  const LongIndexType num_colums_);
40 
42  const DataType* A_,
43  const FlagType A_is_row_major_,
44  const LongIndexType num_rows_,
45  const LongIndexType num_columns_,
46  const DataType* B_,
47  const FlagType B_is_row_major_);
48 
50 
51  virtual void dot(
52  const DataType* vector,
53  DataType* product);
54 
55  virtual void transpose_dot(
56  const DataType* vector,
57  DataType* product);
58 
59  protected:
60 
61  // Member data
64 };
65 
66 #endif // _C_LINEAR_OPERATOR_C_DENSE_AFFINE_MATRIX_FUNCTION_H_
Base class for affine matrix functions of one parameter.
cDenseAffineMatrixFunction(const DataType *A_, const FlagType A_is_row_major_, const LongIndexType num_rows_, const LongIndexType num_columns_)
Constructor. Matrix B is assumed to be the identity matrix.
virtual void dot(const DataType *vector, DataType *product)
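Computes the matrix-vector product \( \boldsymbol{y} = (\mathbf{A} + t\,\mathbf{B})\,\boldsymbol{x} \), where \( \boldsymbol{x} \) is `vector` and \( \boldsymbol{y} \) is `product`.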
virtual void transpose_dot(const DataType *vector, DataType *product)
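Computes the transposed matrix-vector product \( \boldsymbol{y} = (\mathbf{A} + t\,\mathbf{B})^{\intercal}\,\boldsymbol{x} \), where \( \boldsymbol{x} \) is `vector` and \( \boldsymbol{y} \) is `product`.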
int LongIndexType
Definition: types.h:60
int FlagType
Definition: types.h:68