imate
C++/CUDA Reference
Loading...
Searching...
No Matches
cu_matrix.cu
Go to the documentation of this file.
1/*
2 * SPDX-FileCopyrightText: Copyright 2021, Siavash Ameli <sameli@berkeley.edu>
3 * SPDX-License-Identifier: BSD-3-Clause
4 * SPDX-FileType: SOURCE
5 *
6 * This program is free software: you can redistribute it and/or modify it
7 * under the terms of the license found in the LICENSE.txt file in the root
8 * directory of this source tree.
9 */
10
11
12// =======
13// Headers
14// =======
15
16#include "./cu_matrix.h"
17#include <cassert> // assert
18#include "../_cu_definitions/cu_types.h" // __nv_fp8_e5m2, __nv_fp8_e4m3,
19 // __half, __nv_bfloat16
20
21
22// =============
23// constructor 1
24// =============
25
28
// Default constructor. Marks the matrix as non-symmetric by initializing
// A_is_symmetric to 0; the flag can be changed afterwards via set_symmetry().
29template <typename DataType>
31
32 // Initializer list
33 A_is_symmetric(0)
34{
35}
36
37
38// =============
39// constructor 2
40// =============
41
47
// Constructor taking an explicit symmetry flag. The argument
// A_is_symmetric_ is copied into A_is_symmetric as-is; unlike
// set_symmetry(), it is not normalized to 0/1 here.
48template <typename DataType>
50
51 // Initializer list
52 A_is_symmetric(A_is_symmetric_)
53{
54}
55
56
57// ==========
58// destructor
59// ==========
60
63
64template <typename DataType>
68
69
70// ============
71// set symmetry
72// ============
73
82
// Sets the symmetry flag of the matrix.
//
// Parameters:
//   symmetric - FlagType (an int) flag. Only the exact value 1 marks the
//               matrix as symmetric (A_is_symmetric = 1); any other value,
//               including other non-zero values, stores 0 (non-symmetric).
83template <typename DataType>
85{
86 if (symmetric == 1)
87 {
88 this->A_is_symmetric = 1;
89 }
90 else
91 {
92 this->A_is_symmetric = 0;
93 }
94}
95
96
97// ==============
98// get eigenvalue
99// ==============
100
116
// Placeholder implementation of the base class's pure virtual
// get_eigenvalue(). It is not meant to be called on this class.
//
// In debug builds the assert below fires with the given message. When
// NDEBUG is defined, the assert is compiled out and the function silently
// returns 0 instead, so callers in release builds get no error signal.
//
// Parameters: all three are ignored here (their meaning is defined by the
// base-class declaration).
// Returns: 0, converted to DataType.
117template <typename DataType>
119 const DataType* known_parameters,
120 const DataType known_eigenvalue,
121 const DataType* inquiry_parameters) const
122{
// The string literal is always truthy, so the condition is false; the
// literal only serves to show the message in the assertion-failure output.
123 assert((false) && "This function should not be called within this class");
124
125 // Cast unused parameters to void to silence -Wunused-parameter warnings
126 (void) known_parameters;
127 (void) known_eigenvalue;
128 (void) inquiry_parameters;
129
130 return 0;
131}
132
133
134// ===============================
135// Explicit template instantiation
136// ===============================
137
// Each instantiation below is compiled only when its matching USE_CUDA_*
// macro is defined and equal to 1. The set of DataTypes available for
// cuMatrix (FP8 E5M2, FP8 E4M3, FP16, BF16, FP32, FP64) is therefore
// chosen at build time.
138#if defined(USE_CUDA_FP8_E5M2) && (USE_CUDA_FP8_E5M2 == 1)
139 template class cuMatrix<__nv_fp8_e5m2>;
140#endif
141
142#if defined(USE_CUDA_FP8_E4M3) && (USE_CUDA_FP8_E4M3 == 1)
143 template class cuMatrix<__nv_fp8_e4m3>;
144#endif
145
146#if defined(USE_CUDA_FP16) && (USE_CUDA_FP16 == 1)
147 template class cuMatrix<__half>;
148#endif
149
150#if defined(USE_CUDA_BF16) && (USE_CUDA_BF16 == 1)
151 template class cuMatrix<__nv_bfloat16>;
152#endif
153
154#if defined(USE_CUDA_FP32) && (USE_CUDA_FP32 == 1)
155 template class cuMatrix<float>;
156#endif
157
158#if defined(USE_CUDA_FP64) && (USE_CUDA_FP64 == 1)
159 template class cuMatrix<double>;
160#endif
Base class for constant matrices.
Definition cu_matrix.h:45
cuMatrix()
Default constructor.
Definition cu_matrix.cu:30
virtual ~cuMatrix()
Destructor.
Definition cu_matrix.cu:65
DataType get_eigenvalue(const DataType *known_parameters, const DataType known_eigenvalue, const DataType *inquiry_parameters) const
This virtual function implements the pure virtual function declared in the base class. ...
Definition cu_matrix.cu:118
virtual void set_symmetry(const FlagType symmetric)
Specify whether the matrix is symmetric or non-symmetric.
Definition cu_matrix.cu:84
int FlagType
Definition types.h:68