imate
C++/CUDA Reference
c_csr_matrix.cpp
Go to the documentation of this file.
1 /*
2  * SPDX-FileCopyrightText: Copyright 2021, Siavash Ameli <sameli@berkeley.edu>
3  * SPDX-License-Identifier: BSD-3-Clause
4  * SPDX-FileType: SOURCE
5  *
6  * This program is free software: you can redistribute it and/or modify it
7  * under the terms of the license found in the LICENSE.txt file in the root
8  * directory of this source tree.
9  */
10 
11 
12 // =======
13 // Headers
14 // =======
15 
16 #include "./c_csr_matrix.h"
17 #include <cstddef> // NULL
18 #include "../_c_basic_algebra/c_matrix_operations.h" // cMatrixOperations
19 
20 
21 // =============
22 // constructor 1
23 // =============
24 
25 template <typename DataType>
27  A_data(NULL),
28  A_indices(NULL),
29  A_index_pointer(NULL)
30 {
31 }
32 
33 
34 // =============
35 // constructor 2
36 // =============
37 
38 template <typename DataType>
40  const DataType* A_data_,
41  const LongIndexType* A_indices_,
42  const LongIndexType* A_index_pointer_,
43  const LongIndexType num_rows_,
44  const LongIndexType num_columns_):
45 
46  // Base class constructor
47  cLinearOperator<DataType>(num_rows_, num_columns_),
48 
49  // Initializer list
50  A_data(A_data_),
51  A_indices(A_indices_),
52  A_index_pointer(A_index_pointer_)
53 {
54 }
55 
56 
57 // ==========
58 // destructor
59 // ==========
60 
61 template <typename DataType>
63 {
64 }
65 
66 
67 // ==================
68 // is identity matrix
69 // ==================
70 
79 
80 template <typename DataType>
82 {
83  FlagType matrix_is_identity = 1;
84  LongIndexType index_pointer;
85  LongIndexType column;
86 
87  // Check matrix element-wise
88  for (LongIndexType row=0; row < this->num_rows; ++row)
89  {
90  for (index_pointer=this->A_index_pointer[row];
91  index_pointer < this->A_index_pointer[row+1];
92  ++index_pointer)
93  {
94  column = this->A_indices[index_pointer];
95 
96  if ((row == column) && \
97  (this->A_data[index_pointer] != 1.0))
98  {
99  matrix_is_identity = 0;
100  return matrix_is_identity;
101  }
102  else if (this->A_data[index_pointer] != 0.0)
103  {
104  matrix_is_identity = 0;
105  return matrix_is_identity;
106  }
107  }
108  }
109 
110  return matrix_is_identity;
111 }
112 
113 
114 // =======
115 // get nnz
116 // =======
117 
125 
126 template <typename DataType>
128 {
129  return this->A_index_pointer[this->num_rows];
130 }
131 
132 
133 // ===
134 // dot
135 // ===
136 
137 template <typename DataType>
139  const DataType* vector,
140  DataType* product)
141 {
143  this->A_data,
144  this->A_indices,
145  this->A_index_pointer,
146  vector,
147  this->num_rows,
148  product);
149 }
150 
151 
152 // ========
153 // dot plus
154 // ========
155 
156 template <typename DataType>
158  const DataType* vector,
159  const DataType alpha,
160  DataType* product)
161 {
163  this->A_data,
164  this->A_indices,
165  this->A_index_pointer,
166  vector,
167  alpha,
168  this->num_rows,
169  product);
170 }
171 
172 
173 // =============
174 // transpose dot
175 // =============
176 
177 template <typename DataType>
179  const DataType* vector,
180  DataType* product)
181 {
183  this->A_data,
184  this->A_indices,
185  this->A_index_pointer,
186  vector,
187  this->num_rows,
188  this->num_columns,
189  product);
190 }
191 
192 
193 // ==================
194 // transpose dot plus
195 // ==================
196 
197 template <typename DataType>
199  const DataType* vector,
200  const DataType alpha,
201  DataType* product)
202 {
204  this->A_data,
205  this->A_indices,
206  this->A_index_pointer,
207  vector,
208  alpha,
209  this->num_rows,
210  product);
211 }
212 
213 
214 // ===============================
215 // Explicit template instantiation
216 // ===============================
217 
218 template class cCSRMatrix<float>;
219 template class cCSRMatrix<double>;
220 template class cCSRMatrix<long double>;
virtual void transpose_dot_plus(const DataType *vector, const DataType alpha, DataType *product)
virtual void transpose_dot(const DataType *vector, DataType *product)
virtual void dot(const DataType *vector, DataType *product)
LongIndexType get_nnz() const
Returns the number of non-zero elements of the sparse matrix.
virtual FlagType is_identity_matrix() const
Checks whether the matrix is identity.
virtual ~cCSRMatrix()
virtual void dot_plus(const DataType *vector, const DataType alpha, DataType *product)
Base class for linear operators. This class serves as interface for all derived classes.
static void csr_transposed_matvec(const DataType *A_data, const LongIndexType *A_column_indices, const LongIndexType *A_index_pointer, const DataType *b, const LongIndexType num_rows, const LongIndexType num_columns, DataType *c)
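Computes $\boldsymbol{c} = \mathbf{A}^{\intercal} \boldsymbol{b}$ where $\mathbf{A}$ is a compressed sparse row (CSR) matrix and $\boldsymbol{b}$ is a dense vector. The output $\boldsymbol{c}$ is a dense vector.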
static void csr_matvec(const DataType *A_data, const LongIndexType *A_column_indices, const LongIndexType *A_index_pointer, const DataType *b, const LongIndexType num_rows, DataType *c)
Computes $\boldsymbol{c} = \mathbf{A} \boldsymbol{b}$ where $\mathbf{A}$ is a compressed sparse row (CSR) matrix and $\boldsymbol{b}$ is a dense vector. The output $\boldsymbol{c}$ is a dense vector.
static void csr_matvec_plus(const DataType *A_data, const LongIndexType *A_column_indices, const LongIndexType *A_index_pointer, const DataType *b, const DataType alpha, const LongIndexType num_rows, DataType *c)
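Computes $\boldsymbol{c} = \boldsymbol{c} + \alpha \mathbf{A} \boldsymbol{b}$ where $\mathbf{A}$ is a compressed sparse row (CSR) matrix and $\boldsymbol{b}$ is a dense vector. The output $\boldsymbol{c}$ is a dense vector.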
static void csr_transposed_matvec_plus(const DataType *A_data, const LongIndexType *A_column_indices, const LongIndexType *A_index_pointer, const DataType *b, const DataType alpha, const LongIndexType num_rows, DataType *c)
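Computes $\boldsymbol{c} = \boldsymbol{c} + \alpha \mathbf{A}^{\intercal} \boldsymbol{b}$ where $\mathbf{A}$ is a compressed sparse row (CSR) matrix and $\boldsymbol{b}$ is a dense vector. The output $\boldsymbol{c}$ is a dense vector.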
int LongIndexType
Definition: types.h:60
int FlagType
Definition: types.h:68