imate
C++/CUDA Reference
Loading...
Searching...
No Matches
c_csr_matrix.cpp
Go to the documentation of this file.
1/*
2 * SPDX-FileCopyrightText: Copyright 2021, Siavash Ameli <sameli@berkeley.edu>
3 * SPDX-License-Identifier: BSD-3-Clause
4 * SPDX-FileType: SOURCE
5 *
6 * This program is free software: you can redistribute it and/or modify it
7 * under the terms of the license found in the LICENSE.txt file in the root
8 * directory of this source tree.
9 */
10
11
12// =======
13// Headers
14// =======
15
16#include "./c_csr_matrix.h"
17#include <cstddef> // NULL
18#include "../_definitions/definitions.h" // USE_OPENMP
19#if defined(USE_OPENMP) && (USE_OPENMP == 1)
20 #include <omp.h> // omp_in_parallel
21#endif
22#include "../_c_arithmetics/c_arithmetics.h" // c_arithmetics
23#include "../_c_basic_algebra/c_matrix_operations.h" // cMatrixOperations
24
25
26// =============
27// constructor 1
28// =============
29
32
33template <typename DataType>
35 A_data(NULL),
36 A_indices(NULL),
37 A_index_pointer(NULL)
38{
39}
40
41
42// =============
43// constructor 2
44// =============
45
66
67template <typename DataType>
69 const DataType* A_data_,
70 const LongIndexType* A_indices_,
71 const LongIndexType* A_index_pointer_,
72 const LongIndexType num_rows_,
73 const LongIndexType num_columns_,
74 const FlagType A_is_symmetric_):
75
76 // Base class constructors
77 cLinearOperatorBase(num_rows_, num_columns_),
78 cMatrix<DataType>(A_is_symmetric_),
79
80 // Initializer list
81 A_data(A_data_),
82 A_indices(A_indices_),
83 A_index_pointer(A_index_pointer_)
84{
85}
86
87
88// ==========
89// destructor
90// ==========
91
94
95template <typename DataType>
99
100
101// ==================
102// is identity matrix
103// ==================
104
113
114template <typename DataType>
116{
117 FlagType matrix_is_identity = 1;
118 LongIndexType index_pointer;
119 LongIndexType column;
120 DataType matrix_element;
121 const DataType diagonal = 1.0;
122 const DataType off_diagonal = 0.0;
123
124 // Check matrix element-wise
125 #if defined(USE_OPENMP) && (USE_OPENMP == 1)
126 #pragma omp parallel for \
127 schedule(static) \
128 if (!omp_in_parallel()) \
129 default(none) \
130 shared(matrix_is_identity, diagonal, off_diagonal) \
131 private(index_pointer, column, matrix_element)
132 #endif
133 for (LongIndexType row=0; row < this->num_rows; ++row)
134 {
135 if (matrix_is_identity)
136 {
137 for (index_pointer=this->A_index_pointer[row];
138 index_pointer < this->A_index_pointer[row+1];
139 ++index_pointer)
140 {
141 column = this->A_indices[index_pointer];
142
143 if (!((this->A_is_symmetric) && (column >= row)))
144 {
145 matrix_element = this->A_data[index_pointer];
146
147 if (((row == column) && \
148 (!c_arithmetics::is_equal(matrix_element,
149 diagonal))) || \
150 ((row != column) && \
151 (!c_arithmetics::is_equal(matrix_element,
152 off_diagonal))))
153 {
154 #if defined(USE_OPENMP) && (USE_OPENMP == 1)
155 #pragma omp atomic write
156 #endif
157 matrix_is_identity = 0;
158
159 break;
160 }
161 }
162 }
163 }
164 }
165
166 return matrix_is_identity;
167}
168
169
170// =======
171// get nnz
172// =======
173
181
182template <typename DataType>
184{
185 return this->A_index_pointer[this->num_rows];
186}
187
188
189// ===
190// dot
191// ===
192
209
210template <typename DataType>
212 const DataType* vector,
213 DataType* product)
214{
216 this->A_data,
217 this->A_indices,
218 this->A_index_pointer,
219 vector,
220 this->num_rows,
221 product);
222}
223
224
225// ========
226// dot plus
227// ========
228
246
247template <typename DataType>
249 const DataType* vector,
250 const DataType alpha,
251 DataType* product)
252{
254 this->A_data,
255 this->A_indices,
256 this->A_index_pointer,
257 vector,
258 alpha,
259 this->num_rows,
260 product);
261}
262
263
264// =============
265// transpose dot
266// =============
267
284
285template <typename DataType>
287 const DataType* vector,
288 DataType* product)
289{
291 this->A_data,
292 this->A_indices,
293 this->A_index_pointer,
294 this->A_is_symmetric,
295 vector,
296 this->num_rows,
297 this->num_columns,
298 product);
299}
300
301
302// ==================
303// transpose dot plus
304// ==================
305
324
325template <typename DataType>
327 const DataType* vector,
328 const DataType alpha,
329 DataType* product)
330{
332 this->A_data,
333 this->A_indices,
334 this->A_index_pointer,
335 this->A_is_symmetric,
336 vector,
337 alpha,
338 this->num_rows,
339 product);
340}
341
342
343// ===============================
344// Explicit template instantiation
345// ===============================
346
347template class cCSRMatrix<float>;
348template class cCSRMatrix<double>;
349template class cCSRMatrix<long double>;
Container for CSR matrices.
virtual void transpose_dot_plus(const DataType *vector, const DataType alpha, DataType *product)
Transposed-matrix vector product written in place.
virtual void transpose_dot(const DataType *vector, DataType *product)
Transposed-matrix vector product.
virtual void dot(const DataType *vector, DataType *product)
Matrix vector product.
LongIndexType get_nnz() const
Returns the number of non-zero elements of the sparse matrix.
virtual FlagType is_identity_matrix() const
Checks whether the matrix is identity.
virtual ~cCSRMatrix()
Destructor.
cCSRMatrix()
Default constructor.
virtual void dot_plus(const DataType *vector, const DataType alpha, DataType *product)
Matrix vector product written in place.
Base class for cLinearOperator and cuLinearOperator. This class is not templated so that both cpp an...
static void csr_transposed_matvec(const DataType *RESTRICT A_data, const LongIndexType *RESTRICT A_column_indices, const LongIndexType *RESTRICT A_index_pointer, const FlagType A_is_symmetric, const DataType *RESTRICT b, const LongIndexType num_rows, const LongIndexType num_columns, DataType *RESTRICT c)
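Computes $\boldsymbol{c} = \mathbf{A}^{\intercal} \boldsymbol{b}$ where $\mathbf{A}$ is a compressed sparse row (CSR) matrix and $\boldsymbol{b}$ is a dense vector. The output $\boldsymbol{c}$ is a dense vector.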
static void csr_transposed_matvec_plus(const DataType *RESTRICT A_data, const LongIndexType *RESTRICT A_column_indices, const LongIndexType *RESTRICT A_index_pointer, const FlagType A_is_symmetric, const DataType *RESTRICT b, const DataType alpha, const LongIndexType num_rows, DataType *RESTRICT c)
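Computes $\boldsymbol{c} = \boldsymbol{c} + \alpha \mathbf{A}^{\intercal} \boldsymbol{b}$ where $\mathbf{A}$ is a compressed sparse row (CSR) matrix and $\boldsymbol{b}$ is a dense vector. The output $\boldsymbol{c}$ is a dense vector.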
static void csr_matvec(const DataType *RESTRICT A_data, const LongIndexType *RESTRICT A_column_indices, const LongIndexType *RESTRICT A_index_pointer, const DataType *RESTRICT b, const LongIndexType num_rows, DataType *RESTRICT c)
Computes $\boldsymbol{c} = \mathbf{A} \boldsymbol{b}$ where $\mathbf{A}$ is a compressed sparse row (CSR) matrix and $\boldsymbol{b}$ is a dense vector. The output $\boldsymbol{c}$ is a dense vector.
static void csr_matvec_plus(const DataType *RESTRICT A_data, const LongIndexType *RESTRICT A_column_indices, const LongIndexType *RESTRICT A_index_pointer, const DataType *RESTRICT b, const DataType alpha, const LongIndexType num_rows, DataType *RESTRICT c)
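Computes $\boldsymbol{c} = \boldsymbol{c} + \alpha \mathbf{A} \boldsymbol{b}$ where $\mathbf{A}$ is a compressed sparse row (CSR) matrix and $\boldsymbol{b}$ is a dense vector. The output $\boldsymbol{c}$ is a dense vector.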
Base class for constant matrices.
Definition c_matrix.h:45
bool is_equal(DataType x, DataType y)
Check if two floating point numbers are equal within a tolerance.
Definition _c_is_equal.h:49
int LongIndexType
Definition types.h:60
int FlagType
Definition types.h:68