imate
C++/CUDA Reference
Loading...
Searching...
No Matches
cu_csc_matrix.h
Go to the documentation of this file.
/*
 * SPDX-FileCopyrightText: Copyright 2021, Siavash Ameli <sameli@berkeley.edu>
 * SPDX-License-Identifier: BSD-3-Clause
 * SPDX-FileType: SOURCE
 *
 * This program is free software: you can redistribute it and/or modify it
 * under the terms of the license found in the LICENSE.txt file in the root
 * directory of this source tree.
 */


12#ifndef _CU_LINEAR_OPERATOR_CU_CSC_MATRIX_H_
13#define _CU_LINEAR_OPERATOR_CU_CSC_MATRIX_H_
14
15
16// =======
17// Headers
18// =======
19
20#include "../_definitions/types.h" // FlagType, LongIndexType
21#include "../_c_linear_operator/c_csc_matrix.h" // cCSCMatrix
22#include "./cu_matrix.h" // cuMatrix
23
24
25// =============
26// cu CSC Matrix
27// =============
28
42
43template <typename DataType>
44class cuCSCMatrix : public cuMatrix<DataType>
45{
46 public:
47
48 // Member methods
50
52 const DataType* A_data_,
53 const LongIndexType* A_indices_,
54 const LongIndexType* A_index_pointer_,
55 const LongIndexType num_rows_,
56 const LongIndexType num_columns_,
57 const FlagType A_is_symmetric_,
58 const int num_gpu_devices_);
59
60 virtual ~cuCSCMatrix();
61
62 virtual FlagType is_identity_matrix() const;
63
64 LongIndexType get_nnz() const;
65
66 virtual void dot(
67 const DataType* device_vector,
68 DataType* device_product);
69
70 virtual void dot_plus(
71 const DataType* device_vector,
72 const DataType alpha,
73 DataType* device_product);
74
75 virtual void transpose_dot(
76 const DataType* device_vector,
77 DataType* device_product);
78
79 virtual void transpose_dot_plus(
80 const DataType* device_vector,
81 const DataType alpha,
82 DataType* device_product);
83
84 protected:
85
86 // Member methods
87 virtual void copy_host_to_device();
88
89 void allocate_buffer(
90 const int device_id,
91 cusparseOperation_t cusparse_operation,
92 const DataType alpha,
93 const DataType beta,
94 cusparseDnVecDescr_t& cusparse_input_vector,
95 cusparseDnVecDescr_t& cusparse_output_vector,
96 cusparseSpMVAlg_t algorithm);
97
98 // Member data
99 const DataType* A_data;
102 DataType** device_A_data;
107 cusparseSpMatDescr_t* cusparse_matrix_A;
108};
109
110#endif // _CU_LINEAR_OPERATOR_CU_CSC_MATRIX_H_
Container for CSC matrices.
virtual void transpose_dot(const DataType *device_vector, DataType *device_product)
Transposed-matrix vector product.
virtual void transpose_dot_plus(const DataType *device_vector, const DataType alpha, DataType *device_product)
Transposed-matrix vector product written in place.
virtual void dot(const DataType *device_vector, DataType *device_product)
Matrix vector product.
LongIndexType get_nnz() const
Returns the number of non-zero elements of the sparse matrix.
LongIndexType ** device_A_index_pointer
virtual FlagType is_identity_matrix() const
Checks whether the matrix is identity.
size_t * device_buffer_num_bytes
cuCSCMatrix()
Default constructor.
const LongIndexType * A_index_pointer
cusparseSpMatDescr_t * cusparse_matrix_A
const DataType * A_data
virtual void dot_plus(const DataType *device_vector, const DataType alpha, DataType *device_product)
Matrix vector product written in place.
void ** device_buffer
DataType ** device_A_data
virtual void copy_host_to_device()
Copies the member data from the host memory to the device memory.
const LongIndexType * A_indices
LongIndexType ** device_A_indices
void allocate_buffer(const int device_id, cusparseOperation_t cusparse_operation, const DataType alpha, const DataType beta, cusparseDnVecDescr_t &cusparse_input_vector, cusparseDnVecDescr_t &cusparse_output_vector, cusparseSpMVAlg_t algorithm)
Allocates an external buffer for matrix-vector multiplication using cusparseSpMV function.
virtual ~cuCSCMatrix()
Destructor.
Base class for constant matrices.
Definition cu_matrix.h:45
int LongIndexType
Definition types.h:60
int FlagType
Definition types.h:68