imate
C++/CUDA Reference
Loading...
Searching...
No Matches
cu_dense_affine_matrix_function.cu
Go to the documentation of this file.
1/*
2 * SPDX-FileCopyrightText: Copyright 2021, Siavash Ameli <sameli@berkeley.edu>
3 * SPDX-License-Identifier: BSD-3-Clause
4 * SPDX-FileType: SOURCE
5 *
6 * This program is free software: you can redistribute it and/or modify it
7 * under the terms of the license found in the LICENSE.txt file in the root
8 * directory of this source tree.
9 */
10
11
12// =======
13// Headers
14// =======
15
17#include <cstddef> // NULL
18#include <cassert> // assert
19#include "../_cu_definitions/cu_types.h" // __nv_fp8_e5m2, __nv_fp8_e4m3,
20 // __half, __nv_bfloat16
21#include "../_definitions/debugging.h" // ASSERT
22
23
24// =============
25// constructor 1
26// =============
27
51
// Constructor taking a single matrix A; B is treated as the identity.
//
// Arguments, as visible in the parameter list below:
//   A_                pointer to the entries of A, forwarded to member A
//   num_rows_         forwarded to cLinearOperatorBase and to member A
//   num_columns_      forwarded to cLinearOperatorBase and to member A
//   A_is_row_major_   flag forwarded to member A
//   A_is_symmetric_   flag forwarded to member A
//   num_gpu_devices_  forwarded to cuLinearOperator<DataType> and to member A
//
// NOTE(review): the line that names the constructor (between the template
// header and the first parameter) is not present in this listing.
52template <typename DataType>
54 const DataType* A_,
55 const LongIndexType num_rows_,
56 const LongIndexType num_columns_,
57 const FlagType A_is_row_major_,
58 const FlagType A_is_symmetric_,
59 const int num_gpu_devices_):
60
61 // Base class constructor
   // The shape goes to cLinearOperatorBase and the device count goes to
   // cuLinearOperator<DataType>.
62 cLinearOperatorBase(num_rows_, num_columns_),
63 cuLinearOperator<DataType>(num_gpu_devices_),
64
65 // Initializer list
   // Member A receives every argument of this constructor.
66 A(A_, num_rows_, num_columns_, A_is_row_major_, A_is_symmetric_,
67 num_gpu_devices_)
68{
69 // This constructor is called assuming B is identity
70 this->B_is_identity = true;
71
72 // When B is identity, the eigenvalues of A+tB are known for any t
   // NOTE(review): the statement introduced by the comment above is not
   // visible in this listing.
74
75 // Set gpu device
   // NOTE(review): the statement introduced by the comment above is not
   // visible in this listing.
77}
78
79
80// =============
81// constructor 2
82// =============
83
120
// Constructor taking both matrices A and B.
//
// Arguments, as visible in the parameter list below:
//   A_                pointer to the entries of A, forwarded to member A
//   num_rows_         forwarded to cLinearOperatorBase and to members A and B
//   num_columns_      forwarded to cLinearOperatorBase and to members A and B
//   A_is_row_major_   flag forwarded to member A
//   A_is_symmetric_   flag forwarded to member A
//   B_                pointer to the entries of B, forwarded to member B
//   B_is_row_major_   flag forwarded to member B
//   B_is_symmetric_   flag forwarded to member B
//   num_gpu_devices_  forwarded to cuLinearOperator<DataType> and to members
//                     A and B
//
// A and B are constructed with the same num_rows_ and num_columns_.
//
// NOTE(review): the line that names the constructor (between the template
// header and the first parameter) is not present in this listing.
121template <typename DataType>
123 const DataType* A_,
124 const LongIndexType num_rows_,
125 const LongIndexType num_columns_,
126 const FlagType A_is_row_major_,
127 const FlagType A_is_symmetric_,
128 const DataType* B_,
129 const FlagType B_is_row_major_,
130 const FlagType B_is_symmetric_,
131 const int num_gpu_devices_):
132
133 // Base class constructor
134 cLinearOperatorBase(num_rows_, num_columns_),
135 cuLinearOperator<DataType>(num_gpu_devices_),
136
137 // Initializer list
138 A(A_, num_rows_, num_columns_, A_is_row_major_, A_is_symmetric_,
139 num_gpu_devices_),
140 B(B_, num_rows_, num_columns_, B_is_row_major_, B_is_symmetric_,
141 num_gpu_devices_)
142{
143 // Matrix B is assumed to be non-zero. Check if it is identity or generic
    // NOTE(review): only the identity case assigns B_is_identity here; no
    // visible branch sets it to false, so it is presumably initialised
    // elsewhere -- confirm.
144 if (this->B.is_identity_matrix())
145 {
146 this->B_is_identity = true;
    // NOTE(review): one further statement of this branch is not visible in
    // this listing.
148 }
149
150 // Set gpu device
    // NOTE(review): the statement introduced by the comment above is not
    // visible in this listing.
152}
153
154
155// ==========
156// destructor
157// ==========
158
161
162template <typename DataType>
166
167
168// ============
169// set symmetry
170// ============
171
183
// Sets the symmetry flag of both member matrices A and B.
//
//   symmetric  when equal to 1, A.set_symmetry(1) and B.set_symmetry(1) are
//              called; for every other value both are called with 0.
//
// NOTE(review): the line that names this function (between the template
// header and the parameter) is not present in this listing.
184template <typename DataType>
186 const FlagType symmetric)
187{
188 if (symmetric == 1)
189 {
190 this->A.set_symmetry(1);
191 this->B.set_symmetry(1);
192 }
193 else
194 {
    // Any value other than 1 is treated as non-symmetric.
195 this->A.set_symmetry(0);
196 this->B.set_symmetry(0);
197 }
198}
199
200
201// ===
202// dot
203// ===
204
220
// Applies the operator to a vector in two steps: first A.dot(vector, product)
// is called, then a contribution weighted by parameters[0] is added into
// product.
//
//   vector   input array, passed unchanged to A and to the B-related step
//   product  output array, written by A.dot and then updated by the second
//            step
//
// The second step depends on B_is_identity:
//   - true:  _add_scaled_vector is called on vector with length
//            min(num_rows, num_columns) and scale parameters[0]
//   - false: B.dot_plus(vector, parameters[0], product) is called
//
// Both branches ASSERT that parameters is not NULL before reading it.
//
// NOTE(review): the line that names this function (between the template
// header and the first parameter) is not present in this listing.
221template <typename DataType>
223 const DataType* vector,
224 DataType* product)
225{
226 // Matrix A times vector
227 this->A.dot(vector, product);
    // Assigned and used only in the identity branch below.
228 LongIndexType min_vector_size;
229
230 // Matrix B times vector to be added to the product
231 if (this->B_is_identity)
232 {
233 // Check parameter is set
234 ASSERT((this->parameters != NULL), "Parameter is not set.");
235
236 // Find minimum of the number of rows and columns
237 min_vector_size = \
238 (this->num_rows < this->num_columns) ? \
239 this->num_rows : this->num_columns;
240
241 // Adding input vector to product
    // NOTE(review): presumably product += parameters[0] * vector over the
    // first min_vector_size entries -- confirm against _add_scaled_vector.
242 this->_add_scaled_vector(vector, min_vector_size,
243 this->parameters[0], product);
244 }
245 else
246 {
247 // Check parameter is set
248 ASSERT((this->parameters != NULL), "Parameter is not set.");
249
250 // Adding parameter times B times input vector to the product
251 this->B.dot_plus(vector, this->parameters[0], product);
252 }
253}
254
255
256// =============
257// transpose dot
258// =============
259
275
// Transposed counterpart of the matrix-vector product: first
// A.transpose_dot(vector, product) is called, then a contribution weighted by
// parameters[0] is added into product.
//
//   vector   input array, passed unchanged to A and to the B-related step
//   product  output array, written by A.transpose_dot and then updated by the
//            second step
//
// The second step depends on B_is_identity:
//   - true:  _add_scaled_vector is called on vector with length
//            min(num_rows, num_columns) and scale parameters[0]
//   - false: B.transpose_dot_plus(vector, parameters[0], product) is called
//
// Both branches ASSERT that parameters is not NULL before reading it.
//
// NOTE(review): the line that names this function (between the template
// header and the first parameter) is not present in this listing.
276template <typename DataType>
278 const DataType* vector,
279 DataType* product)
280{
281 // Matrix A times vector
282 this->A.transpose_dot(vector, product);
    // Assigned and used only in the identity branch below.
283 LongIndexType min_vector_size;
284
285 // Matrix B times vector to be added to the product
286 if (this->B_is_identity)
287 {
288 // Check parameter is set
289 ASSERT((this->parameters != NULL), "Parameter is not set.");
290
291 // Find minimum of the number of rows and columns
292 min_vector_size = \
293 (this->num_rows < this->num_columns) ? \
294 this->num_rows : this->num_columns;
295
296 // Adding input vector to product
    // NOTE(review): presumably product += parameters[0] * vector over the
    // first min_vector_size entries -- confirm against _add_scaled_vector.
297 this->_add_scaled_vector(vector, min_vector_size,
298 this->parameters[0], product);
299 }
300 else
301 {
302 // Check parameter is set
303 ASSERT((this->parameters != NULL), "Parameter is not set.");
304
305 // Adding "parameter * B * input vector" to the product
306 this->B.transpose_dot_plus(vector, this->parameters[0], product);
307 }
308}
309
310
311// ===============================
312// Explicit template instantiation
313// ===============================
314
// Each guard below is active only when its USE_CUDA_* macro is defined and
// equal to 1.
// NOTE(review): the statement enclosed by each #if/#endif pair is not visible
// in this listing; presumably one explicit instantiation per data type --
// confirm against the original file.
315#if defined(USE_CUDA_FP8_E5M2) && (USE_CUDA_FP8_E5M2 == 1)
317#endif
318
319#if defined(USE_CUDA_FP8_E4M3) && (USE_CUDA_FP8_E4M3 == 1)
321#endif
322
323#if defined(USE_CUDA_FP16) && (USE_CUDA_FP16 == 1)
325#endif
326
327#if defined(USE_CUDA_BF16) && (USE_CUDA_BF16 == 1)
329#endif
330
331#if defined(USE_CUDA_FP32) && (USE_CUDA_FP32 == 1)
333#endif
334
335#if defined(USE_CUDA_FP64) && (USE_CUDA_FP64 == 1)
337#endif
Base class for cLinearOperator and cuLinearOperator . This class is not templated so that both cpp an...
Container for dense affine matrix functions of one parameter.
virtual void dot(const DataType *vector, DataType *product)
Matrix vector product.
cuDenseAffineMatrixFunction(const DataType *A_, const LongIndexType num_rows_, const LongIndexType num_columns_, const FlagType A_is_row_major_, const FlagType A_is_symmetric_, const int num_gpu_devices_)
Default constructor.
virtual void set_symmetry(const FlagType symmetric)
Specify whether the matrices are symmetric or non-symmetric.
virtual void transpose_dot(const DataType *vector, DataType *product)
Matrix vector product written in place.
Base class for linear operators. This class serves as interface for all derived classes.
void initialize_cublas_handle()
Creates a cublasHandle_t object, if not created already.
#define ASSERT(condition, message)
Definition debugging.h:20
int LongIndexType
Definition types.h:60
int FlagType
Definition types.h:68