19#include "../_cu_definitions/cu_types.h"
21#include "../_definitions/debugging.h"
52template <
typename DataType>
59 const int num_gpu_devices_):
66 A(A_, num_rows_, num_columns_, A_is_row_major_, A_is_symmetric_,
121template <
typename DataType>
131 const int num_gpu_devices_):
138 A(A_, num_rows_, num_columns_, A_is_row_major_, A_is_symmetric_,
140 B(B_, num_rows_, num_columns_, B_is_row_major_, B_is_symmetric_,
144 if (this->
B.is_identity_matrix())
162template <
typename DataType>
184template <
typename DataType>
190 this->A.set_symmetry(1);
191 this->B.set_symmetry(1);
195 this->A.set_symmetry(0);
196 this->B.set_symmetry(0);
221template <
typename DataType>
223 const DataType* vector,
227 this->A.dot(vector, product);
231 if (this->B_is_identity)
234 ASSERT((this->parameters != NULL),
"Parameter is not set.");
238 (this->num_rows < this->num_columns) ? \
239 this->num_rows : this->num_columns;
242 this->_add_scaled_vector(vector, min_vector_size,
243 this->parameters[0], product);
248 ASSERT((this->parameters != NULL),
"Parameter is not set.");
251 this->B.dot_plus(vector, this->parameters[0], product);
276template <
typename DataType>
278 const DataType* vector,
282 this->A.transpose_dot(vector, product);
286 if (this->B_is_identity)
289 ASSERT((this->parameters != NULL),
"Parameter is not set.");
293 (this->num_rows < this->num_columns) ? \
294 this->num_rows : this->num_columns;
297 this->_add_scaled_vector(vector, min_vector_size,
298 this->parameters[0], product);
303 ASSERT((this->parameters != NULL),
"Parameter is not set.");
306 this->B.transpose_dot_plus(vector, this->parameters[0], product);
315#if defined(USE_CUDA_FP8_E5M2) && (USE_CUDA_FP8_E5M2 == 1)
319#if defined(USE_CUDA_FP8_E4M3) && (USE_CUDA_FP8_E4M3 == 1)
323#if defined(USE_CUDA_FP16) && (USE_CUDA_FP16 == 1)
327#if defined(USE_CUDA_BF16) && (USE_CUDA_BF16 == 1)
331#if defined(USE_CUDA_FP32) && (USE_CUDA_FP32 == 1)
335#if defined(USE_CUDA_FP64) && (USE_CUDA_FP64 == 1)
Base class for cLinearOperator and cuLinearOperator . This class is not templated so that both cpp an...
FlagType eigenvalue_relation_known
Container for dense affine matrix functions of one parameter.
virtual void dot(const DataType *vector, DataType *product)
Matrix vector product.
cuDenseAffineMatrixFunction(const DataType *A_, const LongIndexType num_rows_, const LongIndexType num_columns_, const FlagType A_is_row_major_, const FlagType A_is_symmetric_, const int num_gpu_devices_)
Default constructor.
virtual void set_symmetry(const FlagType symmetric)
Specify whether the matrices are symmetric or non-symmetric.
virtual ~cuDenseAffineMatrixFunction()
Destructor.
cuDenseMatrix< DataType > B
virtual void transpose_dot(const DataType *vector, DataType *product)
Matrix vector product written in place.
Base class for linear operators. This class serves as interface for all derived classes.
void initialize_cublas_handle()
Creates a cublasHandle_t object, if not created already.
#define ASSERT(condition, message)