19#include "../_cu_definitions/cu_types.h"
21#include "../_cu_arithmetics/cu_arithmetics.h"
24#include "../_definitions/definitions.h"
57template <
typename DataType>
59 cublasHandle_t cublas_handle,
67 cublasOperation_t trans;
96 cublasStatus_t status = cublas_api::cublasXgemv<DataType>(
97 cublas_handle, trans, m, n, &alpha, A, lda, b, incb, &beta, c,
100 assert(status == CUBLAS_STATUS_SUCCESS);
135template <
typename DataType>
137 cublasHandle_t cublas_handle,
140 const DataType alpha,
152 cublasOperation_t trans;
178 cublasStatus_t status = cublas_api::cublasXgemv<DataType>(
179 cublas_handle, trans, m, n, &alpha, A, lda, b, incb, &beta, c,
182 assert(status == CUBLAS_STATUS_SUCCESS);
216template <
typename DataType>
218 cublasHandle_t cublas_handle,
226 cublasOperation_t trans;
253 cublasStatus_t status = cublas_api::cublasXgemv<DataType>(
254 cublas_handle, trans, m, n, &alpha, A, lda, b, incb, &beta, c,
257 assert(status == CUBLAS_STATUS_SUCCESS);
293template <
typename DataType>
295 cublasHandle_t cublas_handle,
298 const DataType alpha,
310 cublasOperation_t trans;
336 cublasStatus_t status = cublas_api::cublasXgemv<DataType>(
337 cublas_handle, trans, m, n, &alpha, A, lda, b, incb, &beta, c,
340 assert(status == CUBLAS_STATUS_SUCCESS);
375template <
typename DataType>
377 cusparseHandle_t cusparse_handle,
385 throw std::runtime_error(
"Function not implemented.");
424template <
typename DataType>
426 cusparseHandle_t cusparse_handle,
431 const DataType alpha,
435 throw std::runtime_error(
"Function not implemented.");
472template <
typename DataType>
474 cusparseHandle_t cusparse_handle,
483 throw std::runtime_error(
"Function not implemented.");
524template <
typename DataType>
526 cusparseHandle_t cusparse_handle,
531 const DataType alpha,
536 throw std::runtime_error(
"Function not implemented.");
573template <
typename DataType>
575 cusparseHandle_t cusparse_handle,
584 throw std::runtime_error(
"Function not implemented.");
625template <
typename DataType>
627 cusparseHandle_t cusparse_handle,
632 const DataType alpha,
637 throw std::runtime_error(
"Function not implemented.");
673template <
typename DataType>
675 cusparseHandle_t cusparse_handle,
683 throw std::runtime_error(
"Function not implemented.");
722template <
typename DataType>
724 cusparseHandle_t cusparse_handle,
729 const DataType alpha,
733 throw std::runtime_error(
"Function not implemented.");
782template <
typename DataType>
784 cusparseHandle_t cublas_handle,
786 const DataType*
RESTRICT supdiagonals,
791 throw std::runtime_error(
"Function not implemented.");
799#if defined(USE_CUDA_FP8_E5M2) && (USE_CUDA_FP8_E5M2 == 1)
803#if defined(USE_CUDA_FP8_E4M3) && (USE_CUDA_FP8_E4M3 == 1)
807#if defined(USE_CUDA_FP16) && (USE_CUDA_FP16 == 1)
811#if defined(USE_CUDA_BF16) && (USE_CUDA_BF16 == 1)
815#if defined(USE_CUDA_FP32) && (USE_CUDA_FP32 == 1)
819#if defined(USE_CUDA_FP64) && (USE_CUDA_FP64 == 1)
A static class for matrix-vector operations, which are similar to the level-2 operations of the BLAS ...
static void csr_transposed_matvec_plus(cusparseHandle_t cusparse_handle, const DataType *RESTRICT A_data, const LongIndexType *RESTRICT A_column_indices, const LongIndexType *RESTRICT A_index_pointer, const DataType *RESTRICT b, const DataType alpha, const LongIndexType num_rows, const LongIndexType num_columns, DataType *RESTRICT c)
Computes $c = c + \alpha A^{\intercal} b$ where $A$ is a compressed sparse row (CSR) matrix and $b$ is a dense vector. The output $c$ is a dense vector.
static void create_band_matrix(cusparseHandle_t cublas_handle, const DataType *RESTRICT diagonals, const DataType *RESTRICT supdiagonals, const IndexType non_zero_size, const FlagType tridiagonal, DataType **matrix)
Creates a bi-diagonal or symmetric tri-diagonal matrix from the diagonal array (diagonals) and the off-diagonal array (supdiagonals)...
static void csc_matvec_plus(cusparseHandle_t cusparse_handle, const DataType *RESTRICT A_data, const LongIndexType *RESTRICT A_row_indices, const LongIndexType *RESTRICT A_index_pointer, const DataType *RESTRICT b, const DataType alpha, const LongIndexType num_rows, const LongIndexType num_columns, DataType *RESTRICT c)
Computes $c = c + \alpha A b$ where $A$ is a compressed sparse column (CSC) matrix and $b$ is a dense vector. The output $c$ is a dense vector.
static void dense_matvec(cublasHandle_t cublas_handle, const DataType *RESTRICT A, const DataType *RESTRICT b, const LongIndexType num_rows, const LongIndexType num_columns, const FlagType A_is_row_major, DataType *RESTRICT c)
Computes the matrix-vector multiplication $c = A b$ where $A$ is a dense matrix.
static void dense_transposed_matvec(cublasHandle_t cublas_handle, const DataType *RESTRICT A, const DataType *RESTRICT b, const LongIndexType num_rows, const LongIndexType num_columns, const FlagType A_is_row_major, DataType *RESTRICT c)
Computes the matrix-vector multiplication $c = A^{\intercal} b$ where $A$ is a dense matrix and $A^{\intercal}$ is the transpose of $A$.
static void csc_transposed_matvec(cusparseHandle_t cusparse_handle, const DataType *RESTRICT A_data, const LongIndexType *RESTRICT A_row_indices, const LongIndexType *RESTRICT A_index_pointer, const DataType *RESTRICT b, const LongIndexType num_columns, DataType *RESTRICT c)
Computes $c = A^{\intercal} b$ where $A$ is a compressed sparse column (CSC) matrix and $b$ is a dense vector. The output $c$ is a dense vector.
static void csr_matvec(cusparseHandle_t cusparse_handle, const DataType *RESTRICT A_data, const LongIndexType *RESTRICT A_column_indices, const LongIndexType *RESTRICT A_index_pointer, const DataType *RESTRICT b, const LongIndexType num_rows, DataType *RESTRICT c)
Computes $c = A b$ where $A$ is a compressed sparse row (CSR) matrix and $b$ is a dense vector. The output $c$ is a dense vector.
static void csc_transposed_matvec_plus(cusparseHandle_t cusparse_handle, const DataType *RESTRICT A_data, const LongIndexType *RESTRICT A_row_indices, const LongIndexType *RESTRICT A_index_pointer, const DataType *RESTRICT b, const DataType alpha, const LongIndexType num_columns, DataType *RESTRICT c)
Computes $c = c + \alpha A^{\intercal} b$ where $A$ is a compressed sparse column (CSC) matrix and $b$ is a dense vector. The output $c$ is a dense vector.
static void csr_matvec_plus(cusparseHandle_t cusparse_handle, const DataType *RESTRICT A_data, const LongIndexType *RESTRICT A_column_indices, const LongIndexType *RESTRICT A_index_pointer, const DataType *RESTRICT b, const DataType alpha, const LongIndexType num_rows, DataType *RESTRICT c)
Computes $c = c + \alpha A b$ where $A$ is a compressed sparse row (CSR) matrix and $b$ is a dense vector. The output $c$ is a dense vector.
static void csr_transposed_matvec(cusparseHandle_t cusparse_handle, const DataType *RESTRICT A_data, const LongIndexType *RESTRICT A_column_indices, const LongIndexType *RESTRICT A_index_pointer, const DataType *RESTRICT b, const LongIndexType num_rows, const LongIndexType num_columns, DataType *RESTRICT c)
Computes $c = A^{\intercal} b$ where $A$ is a compressed sparse row (CSR) matrix and $b$ is a dense vector. The output $c$ is a dense vector.
static void csc_matvec(cusparseHandle_t cusparse_handle, const DataType *RESTRICT A_data, const LongIndexType *RESTRICT A_row_indices, const LongIndexType *RESTRICT A_index_pointer, const DataType *RESTRICT b, const LongIndexType num_rows, const LongIndexType num_columns, DataType *RESTRICT c)
Computes $c = A b$ where $A$ is a compressed sparse column (CSC) matrix and $b$ is a dense vector. The output $c$ is a dense vector.
static void dense_transposed_matvec_plus(cublasHandle_t cublas_handle, const DataType *RESTRICT A, const DataType *RESTRICT b, const DataType alpha, const LongIndexType num_rows, const LongIndexType num_columns, const FlagType A_is_row_major, DataType *RESTRICT c)
Computes $c = c + \alpha A^{\intercal} b$ where $A$ is a dense matrix and $A^{\intercal}$ is the transpose of $A$.
static void dense_matvec_plus(cublasHandle_t cublas_handle, const DataType *RESTRICT A, const DataType *RESTRICT b, const DataType alpha, const LongIndexType num_rows, const LongIndexType num_columns, const FlagType A_is_row_major, DataType *RESTRICT c)
Computes the operation $c = c + \alpha A b$ where $A$ is a dense matrix.
__host__ __device__ DataType abs(const DataType x)
Absolute value of a floating point number.
bool is_equal(DataType x, DataType y)
Check if two floating point numbers are equal within a tolerance.