19#include "../_cu_definitions/cu_types.h"
21#include "../_cu_arithmetics/cu_arithmetics.h"
40template <
typename DataType>
42 cublasHandle_t cublas_handle,
43 const DataType*
RESTRICT input_vector,
51 cublas_handle, vector_size, input_vector, incx, output_vector,
54 assert(status == CUBLAS_STATUS_SUCCESS);
75template <
typename DataType>
77 cublasHandle_t cublas_handle,
78 const DataType*
RESTRICT input_vector,
83 cublasStatus_t status;
89 incx, output_vector, incy);
91 assert(status == CUBLAS_STATUS_SUCCESS);
97 assert(status == CUBLAS_STATUS_SUCCESS);
127template <
typename DataType>
129 cublasHandle_t cublas_handle,
130 const DataType*
RESTRICT input_vector,
132 const DataType scale,
144 DataType neg_scale = -scale;
146 cublas_handle, vector_size, &neg_scale, input_vector, incx,
147 output_vector, incy);
149 assert(status == CUBLAS_STATUS_SUCCESS);
168template <
typename DataType>
170 cublasHandle_t cublas_handle,
180 cublas_handle, vector_size, vector1, incx, vector2, incy,
183 assert(status == CUBLAS_STATUS_SUCCESS);
203template <
typename DataType>
205 cublasHandle_t cublas_handle,
213 cublas_handle, vector_size, vector, incx, &norm);
215 assert(status == CUBLAS_STATUS_SUCCESS);
236template <
typename DataType>
238 cublasHandle_t cublas_handle,
244 cublas_handle, vector, vector_size);
252 cublas_handle, vector_size, &scale, vector, incx);
254 assert(status == CUBLAS_STATUS_SUCCESS);
277template <
typename DataType>
279 cublasHandle_t cublas_handle,
286 cublas_handle, vector, vector_size);
304#if defined(USE_CUDA_FP8_E5M2) && (USE_CUDA_FP8_E5M2 == 1)
308#if defined(USE_CUDA_FP8_E4M3) && (USE_CUDA_FP8_E4M3 == 1)
312#if defined(USE_CUDA_FP16) && (USE_CUDA_FP16 == 1)
316#if defined(USE_CUDA_BF16) && (USE_CUDA_BF16 == 1)
320#if defined(USE_CUDA_FP32) && (USE_CUDA_FP32 == 1)
324#if defined(USE_CUDA_FP64) && (USE_CUDA_FP64 == 1)
A static class for vector operations, similar to level-1 operations of the BLAS library....
static DataType normalize_vector_in_place(cublasHandle_t cublas_handle, DataType *RESTRICT vector, const LongIndexType vector_size)
Normalizes a vector based on Euclidean 2-norm. The result is written in-place.
static void copy_scaled_vector(cublasHandle_t cublas_handle, const DataType *RESTRICT input_vector, const LongIndexType vector_size, const DataType scale, DataType *RESTRICT output_vector)
Scales a vector and stores to a new vector.
static void subtract_scaled_vector(cublasHandle_t cublas_handle, const DataType *RESTRICT input_vector, const LongIndexType vector_size, const DataType scale, DataType *RESTRICT output_vector)
Subtracts the scaled input vector from the output vector.
static DataType normalize_vector_and_copy(cublasHandle_t cublas_handle, const DataType *RESTRICT vector, const LongIndexType vector_size, DataType *RESTRICT output_vector)
Normalizes a vector based on Euclidean 2-norm. The result is written into another vector.
static void copy_vector(cublasHandle_t cublas_handle, const DataType *RESTRICT input_vector, const LongIndexType vector_size, DataType *RESTRICT output_vector)
Copies a vector to a new vector. The result is written into the output vector.
static DataType inner_product(cublasHandle_t cublas_handle, const DataType *RESTRICT vector1, const DataType *RESTRICT vector2, const LongIndexType vector_size)
Computes Euclidean inner product of two vectors.
static DataType euclidean_norm(cublasHandle_t cublas_handle, const DataType *RESTRICT vector, const LongIndexType vector_size)
Computes the Euclidean 2-norm of a 1D array.
__host__ __device__ DataType abs(const DataType x)
Absolute value of a floating point number.
__host__ __device__ DataType div(const DataType x, const DataType y)
Divide two floating point numbers in round-to-nearest-even mode.
bool is_equal(DataType x, DataType y)
Check if two floating point numbers are equal within a tolerance.
cublasStatus_t cublasXaxpy(cublasHandle_t handle, int n, const DataType *RESTRICT alpha, const DataType *RESTRICT x, int incx, DataType *RESTRICT y, int incy)
cublasStatus_t cublasXnrm2(cublasHandle_t handle, int n, const DataType *RESTRICT x, int incx, DataType *RESTRICT result)
cublasStatus_t cublasXscal(cublasHandle_t handle, int n, const DataType *RESTRICT alpha, DataType *RESTRICT x, int incx)
cublasStatus_t cublasXcopy(cublasHandle_t handle, int n, const DataType *RESTRICT x, int incx, DataType *RESTRICT y, int incy)
cublasStatus_t cublasXdot(cublasHandle_t handle, int n, const DataType *RESTRICT x, int incx, const DataType *RESTRICT y, int incy, DataType *RESTRICT result)