11#ifndef _CU_ARITHMETICS_CU_CAST_H_
12#define _CU_ARITHMETICS_CU_CAST_H_
19#include "../_cu_definitions/cu_types.h"
63 template <
typename InputDataType,
typename OutputDataType>
81 #if defined(USE_CUDA_FP8_E5M2) && (USE_CUDA_FP8_E5M2 == 1)
105 #if defined(USE_CUDA_FP8_E5M2) && (USE_CUDA_FP8_E5M2 == 1)
129 #if defined(USE_CUDA_FP8_E5M2) && (USE_CUDA_FP8_E5M2 == 1)
153 #if defined(USE_CUDA_FP8_E5M2) && (USE_CUDA_FP8_E5M2 == 1)
177 #if defined(USE_CUDA_FP8_E4M3) && (USE_CUDA_FP8_E4M3 == 1)
201 #if defined(USE_CUDA_FP8_E4M3) && (USE_CUDA_FP8_E4M3 == 1)
225 #if defined(USE_CUDA_FP8_E4M3) && (USE_CUDA_FP8_E4M3 == 1)
249 #if defined(USE_CUDA_FP8_E4M3) && (USE_CUDA_FP8_E4M3 == 1)
273 #if defined(USE_CUDA_FP16) && (USE_CUDA_FP16 == 1)
296 #if defined(USE_CUDA_FP16) && (USE_CUDA_FP16 == 1)
319 #if defined(USE_CUDA_FP16) && (USE_CUDA_FP16 == 1)
342 #if defined(USE_CUDA_FP16) && (USE_CUDA_FP16 == 1)
365 #if defined(USE_CUDA_BF16) && (USE_CUDA_BF16 == 1)
389 #if defined(USE_CUDA_BF16) && (USE_CUDA_BF16 == 1)
413 #if defined(USE_CUDA_BF16) && (USE_CUDA_BF16 == 1)
437 #if defined(USE_CUDA_BF16) && (USE_CUDA_BF16 == 1)
487 return static_cast<double>(
x);
531 return static_cast<float>(
x);
550 #if defined(USE_CUDA_FP8_E5M2) && (USE_CUDA_FP8_E5M2 == 1)
575 #if defined(USE_CUDA_FP8_E4M3) && (USE_CUDA_FP8_E4M3 == 1)
600 #if defined(USE_CUDA_FP16) && (USE_CUDA_FP16 == 1)
624 #if defined(USE_CUDA_BF16) && (USE_CUDA_BF16 == 1)
653 return static_cast<float>(
x);
676 return static_cast<int>(
x);
699 return static_cast<double>(
x);
722 return static_cast<int>(
x);
741 #if defined(USE_CUDA_FP8_E5M2) && (USE_CUDA_FP8_E5M2 == 1)
744 const unsigned int x)
766 #if defined(USE_CUDA_FP8_E4M3) && (USE_CUDA_FP8_E4M3 == 1)
769 const unsigned int x)
791 #if defined(USE_CUDA_FP16) && (USE_CUDA_FP16 == 1)
794 const unsigned int x)
816 #if defined(USE_CUDA_BF16) && (USE_CUDA_BF16 == 1)
819 const unsigned int x)
843 const unsigned int x)
845 return static_cast<float>(
x);
868 return static_cast<unsigned int>(
x);
889 const unsigned int x)
891 return static_cast<double>(
x);
914 return static_cast<unsigned int>(
x);
933 #if defined(USE_CUDA_FP8_E5M2) && (USE_CUDA_FP8_E5M2 == 1)
937 const long long int x)
959 #if defined(USE_CUDA_FP8_E4M3) && (USE_CUDA_FP8_E4M3 == 1)
963 const long long int x)
985 #if defined(USE_CUDA_FP16) && (USE_CUDA_FP16 == 1)
988 const long long int x)
1010 #if defined(USE_CUDA_BF16) && (USE_CUDA_BF16 == 1)
1014 const long long int x)
1038 const long long int x)
1040 return static_cast<float>(
x);
1063 return static_cast<long long int>(
x);
1084 const long long int x)
1086 return static_cast<double>(
x);
1109 return static_cast<long long int>(
x);
1128 #if defined(USE_CUDA_FP8_E5M2) && (USE_CUDA_FP8_E5M2 == 1)
1132 const unsigned long long int x)
1154 #if defined(USE_CUDA_FP8_E4M3) && (USE_CUDA_FP8_E4M3 == 1)
1158 const unsigned long long int x)
1180 #if defined(USE_CUDA_FP16) && (USE_CUDA_FP16 == 1)
1183 const unsigned long long int x)
1205 #if defined(USE_CUDA_BF16) && (USE_CUDA_BF16 == 1)
1209 const unsigned long long int x)
1233 const unsigned long long int x)
1235 return static_cast<float>(
x);
1256 float,
unsigned long long int>(
1259 return static_cast<unsigned long long int>(
x);
1280 const unsigned long long int x)
1282 return static_cast<double>(
x);
1303 double,
unsigned long long int>(
1306 return static_cast<unsigned long long int>(
x);
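Casts between floating-point types (float, double, and, when enabled, __half, __nv_bfloat16, and the FP8 E5M2/E4M3 types) and between floating-point and integer types (int, unsigned int, long long int, unsigned long long int).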
__host__ __device__ DataType abs(const DataType x)
Absolute value of a floating point number.
__host__ __device__ double cast< float, double >(const float x)
Cast float type to double type.
__host__ __device__ long long int cast< double, long long int >(const double x)
Cast double type to long long int type in round-to-nearest-even mode.
__host__ __device__ int cast< float, int >(const float x)
Cast float type to int type in round-to-nearest-even mode.
__host__ __device__ float cast< int, float >(const int x)
Cast int type to float type in round-to-nearest-even mode.
__host__ __device__ double cast< unsigned int, double >(const unsigned int x)
Cast unsigned int type to double type in round-to-nearest-even mode.
__host__ __device__ double cast< unsigned long long int, double >(const unsigned long long int x)
Cast unsigned long long int type to double type in round-to-nearest-even mode.
__host__ __device__ unsigned int cast< double, unsigned int >(const double x)
Cast double type to unsigned int type in round-to-nearest-even mode.
__host__ __device__ double cast< double, double >(const double x)
Cast double type to double type (no action needed).
__host__ __device__ float cast< long long int, float >(const long long int x)
Cast long long int type to float type in round-to-nearest-even mode.
__host__ __device__ double cast< int, double >(const int x)
Cast int type to double type in round-to-nearest-even mode.
__host__ __device__ OutputDataType cast(const InputDataType x)
Cast a floating point type to another floating point type.
__host__ __device__ float cast< float, float >(const float x)
Cast float type to float type (no action needed).
__host__ __device__ unsigned int cast< float, unsigned int >(const float x)
Cast float type to unsigned int type in round-to-nearest-even mode.
__host__ __device__ long long int cast< float, long long int >(const float x)
Cast float type to long long int type in round-to-nearest-even mode.
__host__ __device__ float cast< unsigned long long int, float >(const unsigned long long int x)
Cast unsigned long long int type to float type in round-to-nearest-even mode.
__host__ __device__ float cast< unsigned int, float >(const unsigned int x)
Cast unsigned int type to float type in round-to-nearest-even mode.
__host__ __device__ int cast< double, int >(const double x)
Cast double type to int type in round-to-nearest-even mode.
__host__ __device__ double cast< long long int, double >(const long long int x)
Cast long long int type to double type in round-to-nearest-even mode.
__host__ __device__ float cast< double, float >(const double x)
Cast double type to float type.