imate
C++/CUDA Reference
Loading...
Searching...
No Matches
_cu_mul.h
Go to the documentation of this file.
1/*
2 * SPDX-FileCopyrightText: Copyright 2021, Siavash Ameli <sameli@berkeley.edu>
3 * SPDX-License-Identifier: BSD-3-Clause
4 * SPDX-FileType: SOURCE
5 *
6 * This program is free software: you can redistribute it and/or modify it
7 * under the terms of the license found in the LICENSE.txt file in the root
8 * directory of this source tree.
9 */
10
11#ifndef _CU_ARITHMETICS_CU_MUL_H_
12#define _CU_ARITHMETICS_CU_MUL_H_
13
14// =======
15// Headers
16// =======
17
18#include "../_cu_definitions/cu_types.h" // __nv_fp8_e5m2, __nv_fp8_e4m3,
19 // __half, __nv_bfloat16, __hmul
20#include <cassert> // assert
21
22
23// =============
24// cu arithmetic
25// =============
26
39
namespace cu_arithmetics
{
    // ===
    // mul
    // ===

    /// \brief Multiplies two floating point numbers.
    ///
    /// \details The primary template is only declared, never defined. Each
    ///          supported type has an explicit specialization below, and each
    ///          one is enabled by its own \c USE_CUDA_* configuration macro.
    ///          Calling \c mul with a type whose specialization is absent or
    ///          disabled therefore fails to link.
    ///
    /// \tparam DataType  Floating point type of the operands.
    /// \param  x         First operand.
    /// \param  y         Second operand.
    /// \return           The product of \c x and \c y.

    template <typename DataType>
    __host__ __device__ DataType mul(
            const DataType x,
            const DataType y);


    // ===
    // mul (__nv_fp8_e5m2)
    // ===

    /// \brief Multiplies two \c __nv_fp8_e5m2 numbers (not implemented).
    ///
    /// \details Placeholder only. It fails the assertion when assertions are
    ///          enabled. When \c NDEBUG is defined it returns NaN.
    ///
    /// \note Explicit specializations are not implicitly \c inline. The
    ///       \c inline keyword lets this header be included from several
    ///       translation units without multiple-definition link errors.

    #if defined(USE_CUDA_FP8_E5M2) && (USE_CUDA_FP8_E5M2 == 1)
    template<>
    __host__ __device__ inline __nv_fp8_e5m2 mul<__nv_fp8_e5m2>(
            const __nv_fp8_e5m2 x,
            const __nv_fp8_e5m2 y)
    {
        // Operands are unused until this is implemented.
        (void) x;
        (void) y;

        // Not implemented
        assert(false);

        return __nv_fp8_e5m2(NAN);
    }
    #endif


    // ===
    // mul (__nv_fp8_e4m3)
    // ===

    /// \brief Multiplies two \c __nv_fp8_e4m3 numbers (not implemented).
    ///
    /// \details Placeholder only. It fails the assertion when assertions are
    ///          enabled. When \c NDEBUG is defined it returns NaN.

    #if defined(USE_CUDA_FP8_E4M3) && (USE_CUDA_FP8_E4M3 == 1)
    template<>
    __host__ __device__ inline __nv_fp8_e4m3 mul<__nv_fp8_e4m3>(
            const __nv_fp8_e4m3 x,
            const __nv_fp8_e4m3 y)
    {
        // Operands are unused until this is implemented.
        (void) x;
        (void) y;

        // Not implemented
        assert(false);

        return __nv_fp8_e4m3(NAN);
    }
    #endif


    // ===
    // mul (__half)
    // ===

    /// \brief Multiplies two \c __half numbers with the \c __hmul
    ///        intrinsic, in round-to-nearest-even mode.
    ///
    /// \param  x  First operand.
    /// \param  y  Second operand.
    /// \return    The product of \c x and \c y as \c __half.

    #if defined(USE_CUDA_FP16) && (USE_CUDA_FP16 == 1)
    template<>
    __host__ __device__ inline __half mul<__half>(
            const __half x,
            const __half y)
    {
        return __hmul(x, y);
    }
    #endif


    // ===
    // mul (__nv_bfloat16)
    // ===

    /// \brief Multiplies two \c __nv_bfloat16 numbers with the \c __hmul
    ///        intrinsic, in round-to-nearest-even mode.
    ///
    /// \param  x  First operand.
    /// \param  y  Second operand.
    /// \return    The product of \c x and \c y as \c __nv_bfloat16.

    #if defined(USE_CUDA_BF16) && (USE_CUDA_BF16 == 1)
    template<>
    __host__ __device__ inline __nv_bfloat16 mul<__nv_bfloat16>(
            const __nv_bfloat16 x,
            const __nv_bfloat16 y)
    {
        return __hmul(x, y);
    }
    #endif


    // ===
    // mul (float)
    // ===

    /// \brief Multiplies two \c float numbers in round-to-nearest-even mode.
    ///
    /// \param  x  First operand.
    /// \param  y  Second operand.
    /// \return    The product of \c x and \c y.

    #if defined(USE_CUDA_FP32) && (USE_CUDA_FP32 == 1)
    template<>
    __host__ __device__ inline float mul<float>(
            const float x,
            const float y)
    {
        return x * y;
    }
    #endif


    // ===
    // mul (double)
    // ===

    /// \brief Multiplies two \c double numbers in round-to-nearest-even mode.
    ///
    /// \param  x  First operand.
    /// \param  y  Second operand.
    /// \return    The product of \c x and \c y.

    #if defined(USE_CUDA_FP64) && (USE_CUDA_FP64 == 1)
    template<>
    __host__ __device__ inline double mul<double>(
            const double x,
            const double y)
    {
        return x * y;
    }
    #endif


    // ===
    // mul
    // ===

    /// \brief Multiplies three floating point numbers.
    ///
    /// \details Only declared here. The specializations below evaluate
    ///          <tt>(x * y) * z</tt>, rounding after each multiplication.
    ///          There are no FP8 specializations of this overload.
    ///
    /// \tparam DataType  Floating point type of the operands.
    /// \param  x         First operand.
    /// \param  y         Second operand.
    /// \param  z         Third operand.
    /// \return           The product of \c x, \c y and \c z.

    template <typename DataType>
    __host__ __device__ DataType mul(
            const DataType x,
            const DataType y,
            const DataType z);


    // ===
    // mul (__half)
    // ===

    /// \brief Multiplies three \c __half numbers as
    ///        <tt>__hmul(__hmul(x, y), z)</tt>.

    #if defined(USE_CUDA_FP16) && (USE_CUDA_FP16 == 1)
    template<>
    __host__ __device__ inline __half mul<__half>(
            const __half x,
            const __half y,
            const __half z)
    {
        return __hmul(__hmul(x, y), z);
    }
    #endif


    // ===
    // mul (__nv_bfloat16)
    // ===

    /// \brief Multiplies three \c __nv_bfloat16 numbers as
    ///        <tt>__hmul(__hmul(x, y), z)</tt>.

    #if defined(USE_CUDA_BF16) && (USE_CUDA_BF16 == 1)
    template<>
    __host__ __device__ inline __nv_bfloat16 mul<__nv_bfloat16>(
            const __nv_bfloat16 x,
            const __nv_bfloat16 y,
            const __nv_bfloat16 z)
    {
        return __hmul(__hmul(x, y), z);
    }
    #endif


    // ===
    // mul (float)
    // ===

    /// \brief Multiplies three \c float numbers as <tt>(x * y) * z</tt>.

    #if defined(USE_CUDA_FP32) && (USE_CUDA_FP32 == 1)
    template<>
    __host__ __device__ inline float mul<float>(
            const float x,
            const float y,
            const float z)
    {
        return x * y * z;
    }
    #endif


    // ===
    // mul (double)
    // ===

    /// \brief Multiplies three \c double numbers as <tt>(x * y) * z</tt>.

    #if defined(USE_CUDA_FP64) && (USE_CUDA_FP64 == 1)
    template<>
    __host__ __device__ inline double mul<double>(
            const double x,
            const double y,
            const double z)
    {
        return x * y * z;
    }
    #endif

}  // namespace cu_arithmetics
377
378#endif // _CU_ARITHMETICS_CU_MUL_H_
Cast from float to __half and __nv_bfloat16 types and vice-versa, and float to double and vice-versa.
Definition _cu_abs.h:43
__host__ __device__ DataType mul(const DataType x, const DataType y)
Multiply two floating point numbers in round-to-nearest-even mode.
__host__ __device__ DataType abs(const DataType x)
Absolute value of a floating point number.
__host__ __device__ double mul< double >(const double x, const double y)
Multiply two double type numbers in round-to-nearest-even mode.
Definition _cu_mul.h:223
__host__ __device__ float mul< float >(const float x, const float y)
Multiply two float type numbers in round-to-nearest-even mode.
Definition _cu_mul.h:196