imate
C++/CUDA Reference
Loading...
Searching...
No Matches
cu_affine_matrix_function.cu
Go to the documentation of this file.
1/*
2 * SPDX-FileCopyrightText: Copyright 2021, Siavash Ameli <sameli@berkeley.edu>
3 * SPDX-License-Identifier: BSD-3-Clause
4 * SPDX-FileType: SOURCE
5 *
6 * This program is free software: you can redistribute it and/or modify it
7 * under the terms of the license found in the LICENSE.txt file in the root
8 * directory of this source tree.
9 */
10
11
12// =======
13// Headers
14// =======
15
17#include <cassert> // assert
18#include "../_cu_definitions/cu_types.h" // __nv_fp8_e5m2, __nv_fp8_e4m3,
19 // __half, __nv_bfloat16
20#include "../_definitions/debugging.h" // ASSERT
21#include "../_cu_basic_algebra/cu_vector_operations.h" // cuVectorOperations
22#include "../_cuda_utilities/cuda_api.h" // CudaAPI
23#include "../_cu_arithmetics/cu_arithmetics.h" // cu_arithmetics
24
25
26// ===========
27// constructor
28// ===========
29
32
33template <typename DataType>
35 B_is_identity(false)
36{
37 // This class has one parameter that is t in A+tB
38 this->num_parameters = 1;
39}
40
41
42// ==========
43// destructor
44// ==========
45
48
49template <typename DataType>
53
54
55// ==============
56// get eigenvalue
57// ==============
58
95
96template <typename DataType>
98 const DataType* known_parameters,
99 const DataType known_eigenvalue,
100 const DataType* inquiry_parameters) const
101{
102 ASSERT((this->eigenvalue_relation_known == 1),
103 "An eigenvalue relation is not known. This function should be "
104 "called only when the matrix B is a scalar multiple of the "
105 "identity matrix");
106
107 // Shift the eigenvalue by the parameter
108 DataType inquiry_eigenvalue = \
109 cu_arithmetics::add<DataType>(
111 known_eigenvalue,
112 known_parameters[0]),
113 inquiry_parameters[0]
114 );
115
116 return inquiry_eigenvalue;
117}
118
119
120// =================
121// add scaled vector
122// =================
123
138
139template <typename DataType>
141 const DataType* input_vector,
142 const LongIndexType vector_size,
143 const DataType scale,
144 DataType* output_vector) const
145{
146 // Get device id
147 int device_id = CudaAPI<DataType>::get_device();
148
149 // Negative of scale (negate in double type, as some types such as
150 // __nv_fp8_exmx do not yet support this operation)
151 DataType neg_scale = cu_arithmetics::cast<double, DataType>(
153
154 // Subtract the input vector scaled by the negated scale, which amounts to adding it.
156 this->cublas_handle[device_id], input_vector, vector_size,
157 neg_scale, output_vector);
158}
159
160
161// ===============================
162// Explicit template instantiation
163// ===============================
164
165#if defined(USE_CUDA_FP8_E5M2) && (USE_CUDA_FP8_E5M2 == 1)
167#endif
168
169#if defined(USE_CUDA_FP8_E4M3) && (USE_CUDA_FP8_E4M3 == 1)
171#endif
172
173#if defined(USE_CUDA_FP16) && (USE_CUDA_FP16 == 1)
174 template class cuAffineMatrixFunction<__half>;
175#endif
176
177#if defined(USE_CUDA_BF16) && (USE_CUDA_BF16 == 1)
179#endif
180
181#if defined(USE_CUDA_FP32) && (USE_CUDA_FP32 == 1)
182 template class cuAffineMatrixFunction<float>;
183#endif
184
185#if defined(USE_CUDA_FP64) && (USE_CUDA_FP64 == 1)
186 template class cuAffineMatrixFunction<double>;
187#endif
static int get_device()
Gets the current device in multi-gpu applications.
Definition cuda_api.cu:209
Base class for affine matrix functions of one parameter.
virtual ~cuAffineMatrixFunction()
Virtual destructor.
DataType get_eigenvalue(const DataType *known_parameters, const DataType known_eigenvalue, const DataType *inquiry_parameters) const
This function defines an analytic relationship between a given set of parameters and the correspondin...
void _add_scaled_vector(const DataType *input_vector, const LongIndexType vector_size, const DataType scale, DataType *output_vector) const
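Performs the operation $\boldsymbol{c} = \boldsymbol{c} + \alpha \boldsymbol{b}$, where $\boldsymbol{b}$ is an input vector scaled by $\alpha$ and $\boldsymbol{c}$ is the output vector.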
static void subtract_scaled_vector(cublasHandle_t cublas_handle, const DataType *RESTRICT input_vector, const LongIndexType vector_size, const DataType scale, DataType *RESTRICT output_vector)
Subtracts the scaled input vector from the output vector.
#define ASSERT(condition, message)
Definition debugging.h:20
__host__ __device__ DataType abs(const DataType x)
Absolute value of a floating point number.
int LongIndexType
Definition types.h:60