12#ifndef _CU_LINEAR_OPERATOR_CU_CSR_MATRIX_H_
13#define _CU_LINEAR_OPERATOR_CU_CSR_MATRIX_H_
20#include "../_definitions/types.h"
21#include "../_c_linear_operator/c_csr_matrix.h"
43template <
typename DataType>
52 const DataType* A_data_,
58 const int num_gpu_devices_);
67 const DataType* device_vector,
68 DataType* device_product);
71 const DataType* device_vector,
73 DataType* device_product);
76 const DataType* device_vector,
77 DataType* device_product);
80 const DataType* device_vector,
82 DataType* device_product);
91 cusparseOperation_t cusparse_operation,
94 cusparseDnVecDescr_t& cusparse_input_vector,
95 cusparseDnVecDescr_t& cusparse_output_vector,
96 cusparseSpMVAlg_t algorithm);
Container for CSR matrices.
size_t * device_buffer_num_bytes
const LongIndexType * A_index_pointer
cuCSRMatrix()
Default constructor.
virtual ~cuCSRMatrix()
Destructor.
virtual void transpose_dot(const DataType *device_vector, DataType *device_product)
Transposed-matrix vector product.
cusparseSpMatDescr_t * cusparse_matrix_A
virtual void dot_plus(const DataType *device_vector, const DataType alpha, DataType *device_product)
Matrix vector product written in place.
LongIndexType ** device_A_index_pointer
LongIndexType get_nnz() const
Returns the number of non-zero elements of the sparse matrix.
DataType ** device_A_data
virtual void dot(const DataType *device_vector, DataType *device_product)
Matrix vector product.
virtual void transpose_dot_plus(const DataType *device_vector, const DataType alpha, DataType *device_product)
Transposed-matrix vector product written in place.
void allocate_buffer(const int device_id, cusparseOperation_t cusparse_operation, const DataType alpha, const DataType beta, cusparseDnVecDescr_t &cusparse_input_vector, cusparseDnVecDescr_t &cusparse_output_vector, cusparseSpMVAlg_t algorithm)
Allocates an external buffer for matrix-vector multiplication using cusparseSpMV function.
const LongIndexType * A_indices
virtual void copy_host_to_device()
Copies the member data from the host memory to the device memory.
LongIndexType ** device_A_indices
virtual FlagType is_identity_matrix() const
Checks whether the matrix is identity.
Base class for constant matrices.