imate
C++/CUDA Reference
c_csr_matrix.h
Go to the documentation of this file.
1 /*
2  * SPDX-FileCopyrightText: Copyright 2021, Siavash Ameli <sameli@berkeley.edu>
3  * SPDX-License-Identifier: BSD-3-Clause
4  * SPDX-FileType: SOURCE
5  *
6  * This program is free software: you can redistribute it and/or modify it
7  * under the terms of the license found in the LICENSE.txt file in the root
8  * directory of this source tree.
9  */
10 
11 
12 #ifndef _C_LINEAR_OPERATOR_C_CSR_MATRIX_H_
13 #define _C_LINEAR_OPERATOR_C_CSR_MATRIX_H_
14 
15 
16 // =======
17 // Headers
18 // =======
19 
20 #include "../_definitions/types.h" // FlagType, LongIndexType
21 #include "./c_matrix.h" // cMatrix
22 
23 
24 // ============
25 // c CSR Matrix
26 // ============
27 
28 template <typename DataType>
29 class cCSRMatrix : public cMatrix<DataType>
30 {
31  public:
32 
33  // Member methods
34  cCSRMatrix();
35 
36  cCSRMatrix(
37  const DataType* A_data_,
38  const LongIndexType* A_indices_,
39  const LongIndexType* A_index_pointer_,
40  const LongIndexType num_rows_,
41  const LongIndexType num_columns_);
42 
43  virtual ~cCSRMatrix();
44 
45  virtual FlagType is_identity_matrix() const;
46 
47  LongIndexType get_nnz() const;
48 
49  virtual void dot(
50  const DataType* vector,
51  DataType* product);
52 
53  virtual void dot_plus(
54  const DataType* vector,
55  const DataType alpha,
56  DataType* product);
57 
58  virtual void transpose_dot(
59  const DataType* vector,
60  DataType* product);
61 
62  virtual void transpose_dot_plus(
63  const DataType* vector,
64  const DataType alpha,
65  DataType* product);
66 
67  protected:
68 
69  // Member data
70  const DataType* A_data;
73 };
74 
75 #endif // _C_LINEAR_OPERATOR_C_CSR_MATRIX_H_
Symbol reference (tooltips from the generated documentation, in original order):
- virtual void transpose_dot_plus(const DataType *vector, const DataType alpha, DataType *product)
- const LongIndexType * A_index_pointer — Definition: c_csr_matrix.h:72
- virtual void transpose_dot(const DataType *vector, DataType *product)
- virtual void dot(const DataType *vector, DataType *product)
- LongIndexType get_nnz() const — Returns the number of non-zero elements of the sparse matrix.
- const DataType * A_data — Definition: c_csr_matrix.h:70
- virtual FlagType is_identity_matrix() const — Checks whether the matrix is the identity matrix.
- virtual ~cCSRMatrix()
- const LongIndexType * A_indices — Definition: c_csr_matrix.h:71
- virtual void dot_plus(const DataType *vector, const DataType alpha, DataType *product)
- cMatrix — Base class for constant matrices. Definition: c_matrix.h:41
- int LongIndexType — Definition: types.h:60
- int FlagType — Definition: types.h:68