imate
C++/CUDA Reference
Loading...
Searching...
No Matches
c_dense_matrix.cpp
Go to the documentation of this file.
1/*
2 * SPDX-FileCopyrightText: Copyright 2021, Siavash Ameli <sameli@berkeley.edu>
3 * SPDX-License-Identifier: BSD-3-Clause
4 * SPDX-FileType: SOURCE
5 *
6 * This program is free software: you can redistribute it and/or modify it
7 * under the terms of the license found in the LICENSE.txt file in the root
8 * directory of this source tree.
9 */
10
11
12// =======
13// Headers
14// =======
15
16#include "./c_dense_matrix.h"
17#include <cstddef> // NULL
18#include "../_definitions/definitions.h" // USE_OPENMP
19#if defined(USE_OPENMP) && (USE_OPENMP == 1)
20 #include <omp.h> // omp_in_parallel
21#endif
22#include "../_c_arithmetics/c_arithmetics.h" // c_arithmetics
23#include "../_c_basic_algebra/c_matrix_operations.h" // cMatrixOperations
24
25
26// =============
27// constructor 1
28// =============
29
32
33template <typename DataType>
35
36 // Initializer list
37 A(NULL),
38 A_is_row_major(0)
39{
40}
41
42
43// =============
44// constructor 2
45// =============
46
66
67template <typename DataType>
69 const DataType* A_,
70 const LongIndexType num_rows_,
71 const LongIndexType num_columns_,
72 const FlagType A_is_row_major_,
73 const FlagType A_is_symmetric_):
74
75 // Base class constructors
76 cLinearOperatorBase(num_rows_, num_columns_),
77 cMatrix<DataType>(A_is_symmetric_),
78
79 // Initializer list
80 A(A_),
81 A_is_row_major(A_is_row_major_)
82{
83}
84
85
86// ==========
87// destructor
88// ==========
89
92
93template <typename DataType>
97
98
99// ==================
100// is identity matrix
101// ==================
102
111
112template <typename DataType>
114{
115 FlagType matrix_is_identity = 1;
116 DataType matrix_element;
117 const DataType diagonal = 1.0;
118 const DataType off_diagonal = 0.0;
119
120 // Check matrix element-wise
121 if (this->A_is_row_major)
122 {
123 // Row-major matrix
124 LongIndexType column;
125 LongIndexType num_checking_columns;
126
127 #if defined(USE_OPENMP) && (USE_OPENMP == 1)
128 #pragma omp parallel for \
129 schedule(static) \
130 if (!omp_in_parallel()) \
131 default(none) \
132 shared(matrix_is_identity, diagonal, off_diagonal) \
133 private(column, num_checking_columns, matrix_element)
134 #endif
135 for (LongIndexType row=0; row < this->num_rows; ++row)
136 {
137 if (matrix_is_identity)
138 {
139 if (this->A_is_symmetric)
140 {
141 // Check only half of the columns up to diagonal element
142 num_checking_columns = row + 1;
143 }
144 else
145 {
146 num_checking_columns = this->num_columns;
147 }
148
149 for (column=0; column < num_checking_columns; ++column)
150 {
151 // Get an element of the matrix
152 matrix_element = this->A[row * this->num_columns + column];
153
154 // Check the value of element with identity matrix
155 if (((row == column) && \
156 (!c_arithmetics::is_equal(matrix_element,
157 diagonal))) || \
158 ((row != column) && \
159 (!c_arithmetics::is_equal(matrix_element,
160 off_diagonal))))
161 {
162 #if defined(USE_OPENMP) && (USE_OPENMP == 1)
163 #pragma omp atomic write
164 #endif
165 matrix_is_identity = 0;
166
167 break;
168 }
169 }
170 }
171 }
172 }
173 else
174 {
175 // Column-major matrix
176 LongIndexType row;
177 LongIndexType num_checking_rows;
178
179 #if defined(USE_OPENMP) && (USE_OPENMP == 1)
180 #pragma omp parallel for \
181 schedule(static) \
182 if (!omp_in_parallel()) \
183 default(none) \
184 shared(matrix_is_identity, diagonal, off_diagonal) \
185 private(row, num_checking_rows, matrix_element)
186 #endif
187 for (LongIndexType column=0; column < this-> num_columns; ++column)
188 {
189 if (matrix_is_identity)
190 {
191 if (this->A_is_symmetric)
192 {
193 // Check only half of the rows up to diagonal element
194 num_checking_rows = column + 1;
195 }
196 else
197 {
198 num_checking_rows = this->num_rows;
199 }
200
201 for (row=0; row < num_checking_rows; ++row)
202 {
203 // Get an element of the matrix
204 matrix_element = this->A[column * this->num_rows + row];
205
206 // Check the value of element with identity matrix
207 if (((row == column) && \
208 (!c_arithmetics::is_equal(matrix_element,
209 diagonal))) || \
210 ((row != column) && \
211 (!c_arithmetics::is_equal(matrix_element,
212 off_diagonal))))
213 {
214 #if defined(USE_OPENMP) && (USE_OPENMP == 1)
215 #pragma omp atomic write
216 #endif
217 matrix_is_identity = 0;
218
219 break;
220 }
221 }
222 }
223 }
224 }
225
226 return matrix_is_identity;
227}
228
229
230// ===
231// dot
232// ===
233
250
251template <typename DataType>
253 const DataType* vector,
254 DataType* product)
255{
257 this->A,
258 vector,
259 this->num_rows,
260 this->num_columns,
261 this->A_is_row_major,
262 this->A_is_symmetric,
263 product);
264}
265
266
267// ========
268// dot plus
269// ========
270
288
289template <typename DataType>
291 const DataType* vector,
292 const DataType alpha,
293 DataType* product)
294{
296 this->A,
297 vector,
298 alpha,
299 this->num_rows,
300 this->num_columns,
301 this->A_is_row_major,
302 this->A_is_symmetric,
303 product);
304}
305
306
307// =============
308// transpose dot
309// =============
310
327
328template <typename DataType>
330 const DataType* vector,
331 DataType* product)
332{
334 this->A,
335 vector,
336 this->num_rows,
337 this->num_columns,
338 this->A_is_row_major,
339 this->A_is_symmetric,
340 product);
341}
342
343
344// ==================
345// transpose dot plus
346// ==================
347
366
367template <typename DataType>
369 const DataType* vector,
370 const DataType alpha,
371 DataType* product)
372{
374 this->A,
375 vector,
376 alpha,
377 this->num_rows,
378 this->num_columns,
379 this->A_is_row_major,
380 this->A_is_symmetric,
381 product);
382}
383
384
385// ===============================
386// Explicit template instantiation
387// ===============================
388
389template class cDenseMatrix<float>;
390template class cDenseMatrix<double>;
391template class cDenseMatrix<long double>;
Container for dense matrices.
virtual ~cDenseMatrix()
Destructor.
virtual void transpose_dot_plus(const DataType *vector, const DataType alpha, DataType *product)
Transposed-matrix vector product written in place.
virtual FlagType is_identity_matrix() const
Checks whether the matrix is identity.
virtual void dot_plus(const DataType *vector, const DataType alpha, DataType *product)
Matrix vector product written in place.
virtual void transpose_dot(const DataType *vector, DataType *product)
Transposed-matrix vector product.
cDenseMatrix()
Default constructor.
virtual void dot(const DataType *vector, DataType *product)
Matrix vector product.
Base class for cLinearOperator and cuLinearOperator. This class is not templated so that both cpp an...
static void dense_matvec(const DataType *RESTRICT A, const DataType *RESTRICT b, const LongIndexType num_rows, const LongIndexType num_columns, const FlagType A_is_row_major, const FlagType A_is_symmetric, DataType *RESTRICT c)
Computes the matrix-vector multiplication $\boldsymbol{c} = \mathbf{A} \boldsymbol{b}$, where $\mathbf{A}$ is a dense matrix.
static void dense_matvec_plus(const DataType *RESTRICT A, const DataType *RESTRICT b, const DataType alpha, const LongIndexType num_rows, const LongIndexType num_columns, const FlagType A_is_row_major, const FlagType A_is_symmetric, DataType *RESTRICT c)
Computes the operation $\boldsymbol{c} = \boldsymbol{c} + \alpha \mathbf{A} \boldsymbol{b}$, where $\mathbf{A}$ is a dense matrix.
static void dense_transposed_matvec(const DataType *RESTRICT A, const DataType *RESTRICT b, const LongIndexType num_rows, const LongIndexType num_columns, const FlagType A_is_row_major, const FlagType A_is_symmetric, DataType *RESTRICT c)
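Computes the matrix-vector multiplication $\boldsymbol{c} = \mathbf{A}^{\intercal} \boldsymbol{b}$, where $\mathbf{A}$ is dense and $\mathbf{A}^{\intercal}$ is the transpose of the matrix $\mathbf{A}$.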
static void dense_transposed_matvec_plus(const DataType *RESTRICT A, const DataType *RESTRICT b, const DataType alpha, const LongIndexType num_rows, const LongIndexType num_columns, const FlagType A_is_row_major, const FlagType A_is_symmetric, DataType *RESTRICT c)
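Computes $\boldsymbol{c} = \boldsymbol{c} + \alpha \mathbf{A}^{\intercal} \boldsymbol{b}$, where $\mathbf{A}$ is dense and $\mathbf{A}^{\intercal}$ is the transpose of the matrix $\mathbf{A}$.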
Base class for constant matrices.
Definition c_matrix.h:45
bool is_equal(DataType x, DataType y)
Check if two floating point numbers are equal within a tolerance.
Definition _c_is_equal.h:49
int LongIndexType
Definition types.h:60
int FlagType
Definition types.h:68