imate
C++/CUDA Reference
c_dense_matrix.cpp
Go to the documentation of this file.
1 /*
2  * SPDX-FileCopyrightText: Copyright 2021, Siavash Ameli <sameli@berkeley.edu>
3  * SPDX-License-Identifier: BSD-3-Clause
4  * SPDX-FileType: SOURCE
5  *
6  * This program is free software: you can redistribute it and/or modify it
7  * under the terms of the license found in the LICENSE.txt file in the root
8  * directory of this source tree.
9  */
10 
11 
12 // =======
13 // Headers
14 // =======
15 
16 #include "./c_dense_matrix.h"
17 #include <cstddef> // NULL
18 #include "../_c_basic_algebra/c_matrix_operations.h" // cMatrixOperations
19 
20 
21 // =============
22 // constructor 1
23 // =============
24 
25 template <typename DataType>
27  A(NULL),
28  A_is_row_major(0)
29 {
30 }
31 
32 
33 // =============
34 // constructor 2
35 // =============
36 
37 template <typename DataType>
39  const DataType* A_,
40  const LongIndexType num_rows_,
41  const LongIndexType num_columns_,
42  const FlagType A_is_row_major_):
43 
44  // Base class constructor
45  cLinearOperator<DataType>(num_rows_, num_columns_),
46 
47  // Initializer list
48  A(A_),
49  A_is_row_major(A_is_row_major_)
50 {
51 }
52 
53 
54 // ==========
55 // destructor
56 // ==========
57 
58 
59 template <typename DataType>
61 {
62 }
63 
64 
65 // ==================
66 // is identity matrix
67 // ==================
68 
77 
78 template <typename DataType>
80 {
81  FlagType matrix_is_identity = 1;
82  DataType matrix_element;
83 
84  // Check matrix element-wise
85  for (LongIndexType row=0; row < this->num_rows; ++row)
86  {
87  for (LongIndexType column=0; column < this-> num_columns; ++column)
88  {
89  // Get an element of the matrix
90  if (this->A_is_row_major)
91  {
92  matrix_element = this->A[row * this->num_columns + column];
93  }
94  else
95  {
96  matrix_element = this->A[column * this->num_rows + row];
97  }
98 
99  // Check the value of element with identity matrix
100  if ((row == column) && (matrix_element != 1.0))
101  {
102  matrix_is_identity = 0;
103  return matrix_is_identity;
104  }
105  else if (matrix_element != 0.0)
106  {
107  matrix_is_identity = 0;
108  return matrix_is_identity;
109  }
110  }
111  }
112 
113  return matrix_is_identity;
114 }
115 
116 
117 // ===
118 // dot
119 // ===
120 
121 template <typename DataType>
123  const DataType* vector,
124  DataType* product)
125 {
127  this->A,
128  vector,
129  this->num_rows,
130  this->num_columns,
131  this->A_is_row_major,
132  product);
133 }
134 
135 
136 // ========
137 // dot plus
138 // ========
139 
140 template <typename DataType>
142  const DataType* vector,
143  const DataType alpha,
144  DataType* product)
145 {
147  this->A,
148  vector,
149  alpha,
150  this->num_rows,
151  this->num_columns,
152  this->A_is_row_major,
153  product);
154 }
155 
156 
157 // =============
158 // transpose dot
159 // =============
160 
161 template <typename DataType>
163  const DataType* vector,
164  DataType* product)
165 {
167  this->A,
168  vector,
169  this->num_rows,
170  this->num_columns,
171  this->A_is_row_major,
172  product);
173 }
174 
175 
176 // ==================
177 // transpose dot plus
178 // ==================
179 
180 template <typename DataType>
182  const DataType* vector,
183  const DataType alpha,
184  DataType* product)
185 {
187  this->A,
188  vector,
189  alpha,
190  this->num_rows,
191  this->num_columns,
192  this->A_is_row_major,
193  product);
194 }
195 
196 
197 // ===============================
198 // Explicit template instantiation
199 // ===============================
200 
201 template class cDenseMatrix<float>;
202 template class cDenseMatrix<double>;
203 template class cDenseMatrix<long double>;
virtual ~cDenseMatrix()
virtual void transpose_dot_plus(const DataType *vector, const DataType alpha, DataType *product)
virtual FlagType is_identity_matrix() const
Checks whether the matrix is the identity matrix.
virtual void dot_plus(const DataType *vector, const DataType alpha, DataType *product)
virtual void transpose_dot(const DataType *vector, DataType *product)
virtual void dot(const DataType *vector, DataType *product)
Base class for linear operators. This class serves as interface for all derived classes.
static void dense_transposed_matvec_plus(const DataType *A, const DataType *b, const DataType alpha, const LongIndexType num_rows, const LongIndexType num_columns, const FlagType A_is_row_major, DataType *c)
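Computes \( \boldsymbol{c} = \boldsymbol{c} + \alpha \mathbf{A}^{\intercal} \boldsymbol{b} \) where \( \mathbf{A} \) is dense, and \( \mathbf{A}^{\intercal} \) is the transpose of the matrix \( \mathbf{A} \).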
static void dense_transposed_matvec(const DataType *A, const DataType *b, const LongIndexType num_rows, const LongIndexType num_columns, const FlagType A_is_row_major, DataType *c)
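Computes the matrix-vector multiplication \( \boldsymbol{c} = \mathbf{A}^{\intercal} \boldsymbol{b} \) where \( \mathbf{A} \) is dense, and \( \mathbf{A}^{\intercal} \) is the transpose of the matrix \( \mathbf{A} \).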
static void dense_matvec_plus(const DataType *A, const DataType *b, const DataType alpha, const LongIndexType num_rows, const LongIndexType num_columns, const FlagType A_is_row_major, DataType *c)
Computes the operation \( \boldsymbol{c} = \boldsymbol{c} + \alpha \mathbf{A} \boldsymbol{b} \) where \( \mathbf{A} \) is a dense matrix.
static void dense_matvec(const DataType *A, const DataType *b, const LongIndexType num_rows, const LongIndexType num_columns, const FlagType A_is_row_major, DataType *c)
Computes the matrix-vector multiplication \( \boldsymbol{c} = \mathbf{A} \boldsymbol{b} \) where \( \mathbf{A} \) is a dense matrix.
int LongIndexType
Definition: types.h:60
int FlagType
Definition: types.h:68