imate
C++/CUDA Reference
Loading...
Searching...
No Matches
c_dense_affine_matrix_function.cpp
Go to the documentation of this file.
1/*
2 * SPDX-FileCopyrightText: Copyright 2021, Siavash Ameli <sameli@berkeley.edu>
3 * SPDX-License-Identifier: BSD-3-Clause
4 * SPDX-FileType: SOURCE
5 *
6 * This program is free software: you can redistribute it and/or modify it
7 * under the terms of the license found in the LICENSE.txt file in the root
8 * directory of this source tree.
9 */
10
11
12// =======
13// Headers
14// =======
15
17#include <cassert> // assert
18#include <cstddef> // NULL
19
20
21// =============
22// constructor 1
23// =============
24
46
// Constructor for the affine function A + t * B in the case where B is the
// identity: only the dense matrix A is supplied, wrapped in member A, and the
// base class is given the same row/column counts.
// NOTE(review): source line 48 (this constructor's declarator) is missing
// from this rendered listing.
47template <typename DataType>
49 const DataType* A_,
50 const LongIndexType num_rows_,
51 const LongIndexType num_columns_,
52 const FlagType A_is_row_major_,
53 const FlagType A_is_symmetric_):
54
55 // Base class constructor
56 cLinearOperatorBase(num_rows_, num_columns_),
57
58 // Initializer list
59 A(A_, num_rows_, num_columns_, A_is_row_major_, A_is_symmetric_)
60{
61 // This constructor is called assuming B is identity
62 this->B_is_identity = true;
63
64 // When B is identity, the eigenvalues of A+tB are known for any t
// NOTE(review): source line 65, the statement introduced by the comment
// above, is missing from this rendered listing; its content is not shown.
66}
67
68
69// =============
70// constructor 2
71// =============
72
107
// Constructor for the affine function A + t * B with an explicit dense B.
// A and B are built with the same num_rows_ x num_columns_ shape. Each has
// its own row-major and symmetry flags.
// NOTE(review): source line 109 (this constructor's declarator) is missing
// from this rendered listing.
108template <typename DataType>
110 const DataType* A_,
111 const LongIndexType num_rows_,
112 const LongIndexType num_columns_,
113 const FlagType A_is_row_major_,
114 const FlagType A_is_symmetric_,
115 const DataType* B_,
116 const FlagType B_is_row_major_,
117 const FlagType B_is_symmetric_):
118
119 // Base class constructor
120 cLinearOperatorBase(num_rows_, num_columns_),
121
122 // Initializer list
123 A(A_, num_rows_, num_columns_, A_is_row_major_, A_is_symmetric_),
124 B(B_, num_rows_, num_columns_, B_is_row_major_, B_is_symmetric_)
125{
126 // Matrix B is assumed to be non-zero. Check if it is identity or generic
// Only the identity case is handled here. Otherwise B_is_identity is left
// untouched (presumably false by default in the class/base declaration --
// confirm).
127 if (this->B.is_identity_matrix())
128 {
129 this->B_is_identity = true;
// NOTE(review): source line 130 is missing from this rendered listing;
// its content is not shown.
131 }
132}
133
134
135// ==========
136// destructor
137// ==========
138
141
// Destructor of cDenseAffineMatrixFunction.
// NOTE(review): the destructor's definition (source lines 143-145) is missing
// from this rendered listing; only its template header survives.
142template <typename DataType>
146
147
148// ============
149// set symmetry
150// ============
151
163
164template <typename DataType>
166 const FlagType symmetric)
167{
168 if (symmetric == 1)
169 {
170 this->A.set_symmetry(1);
171 this->B.set_symmetry(1);
172 }
173 else
174 {
175 this->A.set_symmetry(0);
176 this->B.set_symmetry(0);
177 }
178}
179
180
181// ===
182// dot
183// ===
184
199
200template <typename DataType>
202 const DataType* vector,
203 DataType* product)
204{
205 // Matrix A times vector
206 this->A.dot(vector, product);
207 LongIndexType min_vector_size;
208
209 // Matrix B times vector to be added to the product
210 if (this->B_is_identity)
211 {
212 // Check parameter is set
213 assert((this->parameters != NULL) && "Parameter is not set.");
214
215 // Find minimum of the number of rows and columns
216 min_vector_size = \
217 (this->num_rows < this->num_columns) ? \
218 this->num_rows : this->num_columns;
219
220 // Adding input vector to product
221 this->_add_scaled_vector(vector, min_vector_size,
222 this->parameters[0], product);
223 }
224 else
225 {
226 // Check parameter is set
227 assert((this->parameters != NULL) && "Parameter is not set.");
228
229 // Adding parameter times B times input vector to the product
230 this->B.dot_plus(vector, this->parameters[0], product);
231 }
232}
233
234
235// =============
236// transpose dot
237// =============
238
252
253template <typename DataType>
255 const DataType* vector,
256 DataType* product)
257{
258 // Matrix A times vector
259 this->A.transpose_dot(vector, product);
260 LongIndexType min_vector_size;
261
262 // Matrix B times vector to be added to the product
263 if (this->B_is_identity)
264 {
265 // Check parameter is set
266 assert((this->parameters != NULL) && "Parameter is not set.");
267
268 // Find minimum of the number of rows and columns
269 min_vector_size = \
270 (this->num_rows < this->num_columns) ? \
271 this->num_rows : this->num_columns;
272
273 // Adding input vector to product
274 this->_add_scaled_vector(vector, min_vector_size,
275 this->parameters[0], product);
276 }
277 else
278 {
279 // Check parameter is set
280 assert((this->parameters != NULL) && "Parameter is not set.");
281
282 // Adding "parameter * B * input vector" to the product
283 this->B.transpose_dot_plus(vector, this->parameters[0], product);
284 }
285}
286
287
288// ===============================
289// Explicit template instantiation
290// ===============================
291
Container for dense affine matrix functions of one parameter.
cDenseAffineMatrixFunction(const DataType *A_, const LongIndexType num_rows_, const LongIndexType num_columns_, const FlagType A_is_row_major_, const FlagType A_is_symmetric_)
Constructor for the case where B is the identity matrix (only A is given).
virtual void dot(const DataType *vector, DataType *product)
Matrix vector product.
virtual void transpose_dot(const DataType *vector, DataType *product)
Transposed matrix-vector product, written into the output array.
virtual void set_symmetry(const FlagType symmetric)
Specify whether the matrices are symmetric or non-symmetric.
Base class for cLinearOperator and cuLinearOperator. This class is not templated so that both cpp an...
virtual void set_symmetry(FlagType symmetric)=0
int LongIndexType
Definition types.h:60
int FlagType
Definition types.h:68