19#include "../_cu_definitions/cu_types.h"
21#include "../_definitions/debugging.h"
53template <
typename DataType>
55 const DataType* A_data_,
61 const int num_gpu_devices_):
68 A(A_data_, A_indices_, A_index_pointer_, num_rows_, num_columns_,
69 A_is_symmetric_, num_gpu_devices_)
123template <
typename DataType>
125 const DataType* A_data_,
131 const DataType* B_data_,
135 const int num_gpu_devices_):
142 A(A_data_, A_indices_, A_index_pointer_, num_rows_, num_columns_,
143 A_is_symmetric_, num_gpu_devices_),
144 B(B_data_, B_indices_, B_index_pointer_, num_rows_, num_columns_,
145 B_is_symmetric_, num_gpu_devices_)
148 if (this->
B.is_identity_matrix())
166template <
typename DataType>
188template <
typename DataType>
194 this->A.set_symmetry(1);
195 this->B.set_symmetry(1);
199 this->A.set_symmetry(0);
200 this->B.set_symmetry(0);
225template <
typename DataType>
227 const DataType* vector,
231 this->A.dot(vector, product);
235 if (this->B_is_identity)
238 ASSERT((this->parameters != NULL),
"Parameter is not set.");
242 (this->num_rows < this->num_columns) ? \
243 this->num_rows : this->num_columns;
246 this->_add_scaled_vector(vector, min_vector_size,
247 this->parameters[0], product);
252 ASSERT((this->parameters != NULL),
"Parameter is not set.");
255 this->B.dot_plus(vector, this->parameters[0], product);
280template <
typename DataType>
282 const DataType* vector,
286 this->A.transpose_dot(vector, product);
290 if (this->B_is_identity)
293 ASSERT((this->parameters != NULL),
"Parameter is not set.");
297 (this->num_rows < this->num_columns) ? \
298 this->num_rows : this->num_columns;
301 this->_add_scaled_vector(vector, min_vector_size,
302 this->parameters[0], product);
307 ASSERT((this->parameters != NULL),
"Parameter is not set.");
310 this->B.transpose_dot_plus(vector, this->parameters[0], product);
319#if defined(USE_CUDA_FP8_E5M2) && (USE_CUDA_FP8_E5M2 == 1)
323#if defined(USE_CUDA_FP8_E4M3) && (USE_CUDA_FP8_E4M3 == 1)
327#if defined(USE_CUDA_FP16) && (USE_CUDA_FP16 == 1)
331#if defined(USE_CUDA_BF16) && (USE_CUDA_BF16 == 1)
335#if defined(USE_CUDA_FP32) && (USE_CUDA_FP32 == 1)
339#if defined(USE_CUDA_FP64) && (USE_CUDA_FP64 == 1)
Base class for cLinearOperator and cuLinearOperator . This class is not templated so that both cpp an...
FlagType eigenvalue_relation_known
Container for CSR affine matrix functions of one parameter.
virtual void set_symmetry(const FlagType symmetric)
Specify whether the matrices are symmetric or non-symmetric.
cuCSRAffineMatrixFunction(const DataType *A_data_, const LongIndexType *A_indices_, const LongIndexType *A_index_pointer_, const LongIndexType num_rows_, const LongIndexType num_columns_, const FlagType A_is_symmetric_, const int num_gpu_devices_)
Default constructor.
virtual void dot(const DataType *vector, DataType *product)
Matrix vector product.
virtual void transpose_dot(const DataType *vector, DataType *product)
Matrix vector product written in place.
cuCSRMatrix< DataType > B
virtual ~cuCSRAffineMatrixFunction()
Destructor.
Base class for linear operators. This class serves as interface for all derived classes.
void initialize_cusparse_handle()
Creates a cusparseHandle_t object, if not created already.
#define ASSERT(condition, message)