imate
C++/CUDA Reference
Loading...
Searching...
No Matches
cu_csc_affine_matrix_function.cu
Go to the documentation of this file.
1/*
2 * SPDX-FileCopyrightText: Copyright 2021, Siavash Ameli <sameli@berkeley.edu>
3 * SPDX-License-Identifier: BSD-3-Clause
4 * SPDX-FileType: SOURCE
5 *
6 * This program is free software: you can redistribute it and/or modify it
7 * under the terms of the license found in the LICENSE.txt file in the root
8 * directory of this source tree.
9 */
10
11
12// =======
13// Headers
14// =======
15
17#include <cstddef> // NULL
18#include <cassert> // assert
19#include "../_cu_definitions/cu_types.h" // __nv_fp8_e5m2, __nv_fp8_e4m3,
20 // __half, __nv_bfloat16
21#include "../_definitions/debugging.h" // ASSERT
22
23
24// =============
25// constructor 1
26// =============
27
52
// Constructor for the case where only A is given. B is treated as the
// identity, so the operator represents A + t I, with t set later through the
// parameters. The three arrays of A, its dimensions and its symmetry flag are
// forwarded to the member A. The number of GPU devices is forwarded to the
// cuLinearOperator base and to A.
// NOTE(review): in this rendered listing the signature head (original line
// 54) and body lines 75 and 78 are not visible. Judging from the comments
// around them, they presumably record that the eigenvalue relation is known
// and initialize the GPU/cuSPARSE state (see initialize_cusparse_handle).
// Confirm against the original source.
53template <typename DataType>
55 const DataType* A_data_,
56 const LongIndexType* A_indices_,
57 const LongIndexType* A_index_pointer_,
58 const LongIndexType num_rows_,
59 const LongIndexType num_columns_,
60 const FlagType A_is_symmetric_,
61 const int num_gpu_devices_):
62
63 // Base class constructor
64 cLinearOperatorBase(num_rows_, num_columns_),
65 cuLinearOperator<DataType>(num_gpu_devices_),
66
67 // Initializer list
68 A(A_data_, A_indices_, A_index_pointer_, num_rows_, num_columns_,
69 A_is_symmetric_, num_gpu_devices_)
70{
71 // This constructor is called assuming B is identity
72 this->B_is_identity = true;
73
74 // When B is identity, the eigenvalues of A+tB are known for any t
// NOTE(review): the statement for the comment above (original line 75) is
// hidden in this listing.
76
77 // Set gpu device
// NOTE(review): the statement for the comment above (original line 78) is
// hidden in this listing.
79}
80
81
82// =============
83// constructor 2
84// =============
85
122
// Constructor for the general case A + t B. Both matrices share the
// dimensions num_rows_ x num_columns_ and the GPU device count. After
// construction, B is inspected: if it is the identity matrix, the
// B_is_identity flag is raised, so dot/transpose_dot can take the cheaper
// identity path.
// NOTE(review): when B is not the identity, B_is_identity is not assigned
// here. Presumably its default (set elsewhere, e.g. in a base class) is
// false. Confirm.
// NOTE(review): the signature head (original line 124) and body lines 151
// and 155 are not visible in this listing.
123template <typename DataType>
125 const DataType* A_data_,
126 const LongIndexType* A_indices_,
127 const LongIndexType* A_index_pointer_,
128 const LongIndexType num_rows_,
129 const LongIndexType num_columns_,
130 const FlagType A_is_symmetric_,
131 const DataType* B_data_,
132 const LongIndexType* B_indices_,
133 const LongIndexType* B_index_pointer_,
134 const FlagType B_is_symmetric_,
135 const int num_gpu_devices_):
136
137 // Base class constructor
138 cLinearOperatorBase(num_rows_, num_columns_),
139 cuLinearOperator<DataType>(num_gpu_devices_),
140
141 // Initializer list
142 A(A_data_, A_indices_, A_index_pointer_, num_rows_, num_columns_,
143 A_is_symmetric_, num_gpu_devices_),
144 B(B_data_, B_indices_, B_index_pointer_, num_rows_, num_columns_,
145 B_is_symmetric_, num_gpu_devices_)
146{
147 // Matrix B is assumed to be non-zero. Check if it is identity or generic
148 if (this->B.is_identity_matrix())
149 {
150 this->B_is_identity = true;
// NOTE(review): original line 151 is hidden here. By analogy with
// constructor 1 it presumably marks the eigenvalue relation as known.
// Confirm.
152 }
153
154 // Set gpu device
// NOTE(review): the statement for the comment above (original line 155) is
// hidden in this listing.
156}
157
158
159// ==========
160// destructor
161// ==========
162
165
// Destructor.
// NOTE(review): its declaration and body (original lines 163-164 and
// 167-169) are not visible in this listing, so what it releases cannot be
// stated here.
166template <typename DataType>
170
171
172// ============
173// set symmetry
174// ============
175
187
188template <typename DataType>
190 const FlagType symmetric)
191{
192 if (symmetric == 1)
193 {
194 this->A.set_symmetry(1);
195 this->B.set_symmetry(1);
196 }
197 else
198 {
199 this->A.set_symmetry(0);
200 this->B.set_symmetry(0);
201 }
202}
203
204
205// ===
206// dot
207// ===
208
224
225template <typename DataType>
227 const DataType* vector,
228 DataType* product)
229{
230 // Matrix A times vector
231 this->A.dot(vector, product);
232 LongIndexType min_vector_size;
233
234 // Matrix B times vector to be added to the product
235 if (this->B_is_identity)
236 {
237 // Check parameter is set
238 ASSERT((this->parameters != NULL), "Parameter is not set.");
239
240 // Find minimum of the number of rows and columns
241 min_vector_size = \
242 (this->num_rows < this->num_columns) ? \
243 this->num_rows : this->num_columns;
244
245 // Adding input vector to product
246 this->_add_scaled_vector(vector, min_vector_size,
247 this->parameters[0], product);
248 }
249 else
250 {
251 // Check parameter is set
252 ASSERT((this->parameters != NULL), "Parameter is not set.");
253
254 // Adding parameter times B times input vector to the product
255 this->B.dot_plus(vector, this->parameters[0], product);
256 }
257}
258
259
260// =============
261// transpose dot
262// =============
263
279
280template <typename DataType>
282 const DataType* vector,
283 DataType* product)
284{
285 // Matrix A times vector
286 this->A.transpose_dot(vector, product);
287 LongIndexType min_vector_size;
288
289 // Matrix B times vector to be added to the product
290 if (this->B_is_identity)
291 {
292 // Check parameter is set
293 ASSERT((this->parameters != NULL), "Parameter is not set.");
294
295 // Find minimum of the number of rows and columns
296 min_vector_size = \
297 (this->num_rows < this->num_columns) ? \
298 this->num_rows : this->num_columns;
299
300 // Adding input vector to product
301 this->_add_scaled_vector(vector, min_vector_size,
302 this->parameters[0], product);
303 }
304 else
305 {
306 // Check parameter is set
307 ASSERT((this->parameters != NULL), "Parameter is not set.");
308
309 // Adding "parameter * B * input vector" to the product
310 this->B.transpose_dot_plus(vector, this->parameters[0], product);
311 }
312}
313
314
315// ===============================
316// Explicit template instantiation
317// ===============================
318
// Each guard below is enabled only when the corresponding USE_CUDA_* macro is
// defined and equal to 1.
// NOTE(review): the instantiation statements inside each guard (original
// lines 320, 324, 328, 332, 336 and 340) are hidden in this listing. They
// presumably instantiate cuCSCAffineMatrixFunction for the matching type
// (e.g. __nv_fp8_e5m2, __nv_fp8_e4m3, __half, __nv_bfloat16, float, double,
// per the cu_types.h include comment). Confirm.
319#if defined(USE_CUDA_FP8_E5M2) && (USE_CUDA_FP8_E5M2 == 1)
321#endif
322
323#if defined(USE_CUDA_FP8_E4M3) && (USE_CUDA_FP8_E4M3 == 1)
325#endif
326
327#if defined(USE_CUDA_FP16) && (USE_CUDA_FP16 == 1)
329#endif
330
331#if defined(USE_CUDA_BF16) && (USE_CUDA_BF16 == 1)
333#endif
334
335#if defined(USE_CUDA_FP32) && (USE_CUDA_FP32 == 1)
337#endif
338
339#if defined(USE_CUDA_FP64) && (USE_CUDA_FP64 == 1)
341#endif
Base class for cLinearOperator and cuLinearOperator . This class is not templated so that both cpp an...
Container for CSC affine matrix functions of one parameter.
virtual void transpose_dot(const DataType *vector, DataType *product)
Transposed matrix-vector product, written in place.
virtual void dot(const DataType *vector, DataType *product)
Matrix vector product.
virtual void set_symmetry(const FlagType symmetric)
Specify whether the matrices are symmetric or non-symmetric.
cuCSCAffineMatrixFunction(const DataType *A_data_, const LongIndexType *A_indices_, const LongIndexType *A_index_pointer_, const LongIndexType num_rows_, const LongIndexType num_columns_, const FlagType A_is_symmetric_, const int num_gpu_devices_)
Constructor for the case where B is the identity matrix.
Base class for linear operators. This class serves as interface for all derived classes.
void initialize_cusparse_handle()
Creates a cusparseHandle_t object, if not created already.
#define ASSERT(condition, message)
Definition debugging.h:20
int LongIndexType
Definition types.h:60
int FlagType
Definition types.h:68