imate
C++/CUDA Reference
Loading...
Searching...
No Matches
c_csc_affine_matrix_function.cpp
Go to the documentation of this file.
1/*
2 * SPDX-FileCopyrightText: Copyright 2021, Siavash Ameli <sameli@berkeley.edu>
3 * SPDX-License-Identifier: BSD-3-Clause
4 * SPDX-FileType: SOURCE
5 *
6 * This program is free software: you can redistribute it and/or modify it
7 * under the terms of the license found in the LICENSE.txt file in the root
8 * directory of this source tree.
9 */
10
11
12// =======
13// Headers
14// =======
15
17#include <cassert> // assert
18#include <cstddef> // NULL
19
20
21// =============
22// constructor 1
23// =============
24
47
// -----------------------------------------------------------------------
// Constructor for the affine function A + t I: only matrix A is supplied,
// and B is treated as the identity (B_is_identity is set to true below).
//
// The CSC arrays of A (A_data_, A_indices_, A_index_pointer_), together
// with num_rows_, num_columns_ and A_is_symmetric_, are forwarded unchanged
// to the constructor of member A. num_rows_ and num_columns_ are also
// passed to the cLinearOperatorBase constructor. Member B does not appear
// in the initializer list, so it is default-constructed.
//
// NOTE(review): this rendering omits listing line 49 (the qualified
// constructor name) and listing line 68 (see the note in the body);
// confirm against the original source file.
// -----------------------------------------------------------------------
48template <typename DataType>
50 const DataType* A_data_,
51 const LongIndexType* A_indices_,
52 const LongIndexType* A_index_pointer_,
53 const LongIndexType num_rows_,
54 const LongIndexType num_columns_,
55 const FlagType A_is_symmetric_):
56
57 // Base class constructor
58 cLinearOperatorBase(num_rows_, num_columns_),
59
60 // Initializer list
61 A(A_data_, A_indices_, A_index_pointer_, num_rows_, num_columns_,
62 A_is_symmetric_)
63{
64 // This constructor is called assuming B is identity
65 this->B_is_identity = true;
66
67 // When B is identity, the eigenvalues of A+tB are known for any t
// NOTE(review): the statement acting on the comment above (listing
// line 68) is missing from this rendering; confirm against the source.
69}
70
71
72// =============
73// constructor 2
74// =============
75
110
// -----------------------------------------------------------------------
// Constructor for the affine function A + t B with an explicit matrix B.
//
// The CSC arrays of A and of B are forwarded unchanged to the constructors
// of members A and B, respectively. Both matrices are built with the same
// num_rows_ and num_columns_, and each receives its own symmetry flag.
// After construction, B is tested with is_identity_matrix(). If it is the
// identity, B_is_identity is set to true so that the products can take the
// identity path instead of a sparse product with B.
//
// NOTE(review): B_is_identity is only ever set to true here. Presumably
// its false default comes from the class declaration; confirm.
// NOTE(review): this rendering omits listing line 112 (the qualified
// constructor name) and listing line 137 (see the note in the body).
// -----------------------------------------------------------------------
111template <typename DataType>
113 const DataType* A_data_,
114 const LongIndexType* A_indices_,
115 const LongIndexType* A_index_pointer_,
116 const LongIndexType num_rows_,
117 const LongIndexType num_columns_,
118 const FlagType A_is_symmetric_,
119 const DataType* B_data_,
120 const LongIndexType* B_indices_,
121 const LongIndexType* B_index_pointer_,
122 const FlagType B_is_symmetric_):
123
124 // Base class constructor
125 cLinearOperatorBase(num_rows_, num_columns_),
126
127 // Initializer list
128 A(A_data_, A_indices_, A_index_pointer_, num_rows_, num_columns_,
129 A_is_symmetric_),
130 B(B_data_, B_indices_, B_index_pointer_, num_rows_, num_columns_,
131 B_is_symmetric_)
132{
133 // Matrix B is assumed to be non-zero. Check if it is identity or generic
134 if (this->B.is_identity_matrix())
135 {
136 this->B_is_identity = true;
// NOTE(review): a second statement in this branch (listing line 137)
// is missing from this rendering; confirm against the source.
138 }
139}
140
141
142// ==========
143// destructor
144// ==========
145
148
149template <typename DataType>
153
154
155// ============
156// set symmetry
157// ============
158
170
171template <typename DataType>
173 const FlagType symmetric)
174{
175 if (symmetric == 1)
176 {
177 this->A.set_symmetry(1);
178 this->B.set_symmetry(1);
179 }
180 else
181 {
182 this->A.set_symmetry(0);
183 this->B.set_symmetry(0);
184 }
185}
186
187
188// ===
189// dot
190// ===
191
206
207template <typename DataType>
209 const DataType* vector,
210 DataType* product)
211{
212 // Matrix A times vector
213 this->A.dot(vector, product);
214 LongIndexType min_vector_size;
215
216 // Matrix B times vector to be added to the product
217 if (this->B_is_identity)
218 {
219 // Check parameter is set
220 assert((this->parameters != NULL) && "Parameter is not set.");
221
222 // Find minimum of the number of rows and columns
223 min_vector_size = \
224 (this->num_rows < this->num_columns) ? \
225 this->num_rows : this->num_columns;
226
227 // Adding input vector to product
228 this->_add_scaled_vector(vector, min_vector_size,
229 this->parameters[0], product);
230 }
231 else
232 {
233 // Check parameter is set
234 assert((this->parameters != NULL) && "Parameter is not set.");
235
236 // Adding parameter times B times input vector to the product
237 this->B.dot_plus(vector, this->parameters[0], product);
238 }
239}
240
241
242// =============
243// transpose dot
244// =============
245
259
260template <typename DataType>
262 const DataType* vector,
263 DataType* product)
264{
265 // Matrix A times vector
266 this->A.transpose_dot(vector, product);
267 LongIndexType min_vector_size;
268
269 // Matrix B times vector to be added to the product
270 if (this->B_is_identity)
271 {
272 // Check parameter is set
273 assert((this->parameters != NULL) && "Parameter is not set.");
274
275 // Find minimum of the number of rows and columns
276 min_vector_size = \
277 (this->num_rows < this->num_columns) ? \
278 this->num_rows : this->num_columns;
279
280 // Adding input vector to product
281 this->_add_scaled_vector(vector, min_vector_size,
282 this->parameters[0], product);
283 }
284 else
285 {
286 // Check parameter is set
287 assert((this->parameters != NULL) && "Parameter is not set.");
288
289 // Adding "parameter * B * input vector" to the product
290 this->B.transpose_dot_plus(vector, this->parameters[0], product);
291 }
292}
293
294
295// ===============================
296// Explicit template instantiation
297// ===============================
298
Container for CSC affine matrix functions of one parameter.
cCSCAffineMatrixFunction(const DataType *A_data_, const LongIndexType *A_indices_, const LongIndexType *A_index_pointer_, const LongIndexType num_rows_, const LongIndexType num_columns_, const FlagType A_is_symmetric_)
Constructor for the case where B is the identity matrix (only A is supplied).
virtual void set_symmetry(const FlagType symmetric)
Specify whether the matrices are symmetric or non-symmetric.
virtual void transpose_dot(const DataType *vector, DataType *product)
Transposed matrix-vector product, written in place.
virtual void dot(const DataType *vector, DataType *product)
Matrix vector product.
Base class for cLinearOperator and cuLinearOperator . This class is not templated so that both cpp an...
virtual void set_symmetry(FlagType symmetric)=0
int LongIndexType
Definition types.h:60
int FlagType
Definition types.h:68