imate
C++/CUDA Reference
Loading...
Searching...
No Matches
cu_matrix_operations.cu
Go to the documentation of this file.
1/*
2 * SPDX-FileCopyrightText: Copyright 2021, Siavash Ameli <sameli@berkeley.edu>
3 * SPDX-License-Identifier: BSD-3-Clause
4 * SPDX-FileType: SOURCE
5 *
6 * This program is free software: you can redistribute it and/or modify it
7 * under the terms of the license found in the LICENSE.txt file in the root
8 * directory of this source tree.
9 */
10
11
12// =======
13// Headers
14// =======
15
17#include <cassert> // assert
18#include <omp.h> // omp_in_parallel
19#include "../_cu_definitions/cu_types.h" // __nv_fp8_e5m2, __nv_fp8_e4m3,
20 // __half, __nv_bfloat16
21#include "../_cu_arithmetics/cu_arithmetics.h" // cu_arithmetics
22#include "./cublas_api.h" // cublas_api
23#include "./cusparse_api.h" // cusparse_api
24#include "../_definitions/definitions.h" // LARGE_ARRAY_SIZE
25#include <stdexcept> // std::invalid_argument
26
27
28// ============
29// dense matvec
30// ============
31
56
/// \brief Dense matrix-vector product \f$ c = A b \f$ computed with cuBLAS
///        \c gemv.
///
/// \param[in]  cublas_handle  cuBLAS handle forwarded to \c gemv.
/// \param[in]  A              Dense matrix with \c num_rows rows and
///                            \c num_columns columns. The leading dimension
///                            passed to cuBLAS equals the stored row count,
///                            so \c A is assumed to have no padding.
/// \param[in]  b              Input vector of length \c num_columns.
/// \param[in]  num_rows       Number of rows of \c A.
/// \param[in]  num_columns    Number of columns of \c A.
/// \param[in]  A_is_row_major Non-zero when \c A is stored row-major, zero
///                            when it is stored column-major.
/// \param[out] c              Output vector of length \c num_rows. It is
///                            overwritten, because beta is zero.
///
/// \note A cuBLAS status other than \c CUBLAS_STATUS_SUCCESS is only caught
///       by \c assert, i.e. not at all in \c NDEBUG builds.
/// \note NOTE(review): alpha/beta are passed as host scalars, which assumes
///       the handle is in host pointer mode. Confirm.
template <typename DataType>
void cuMatrixOperations<DataType>::dense_matvec(
        cublasHandle_t cublas_handle,
        const DataType* RESTRICT A,
        const DataType* RESTRICT b,
        const LongIndexType num_rows,
        const LongIndexType num_columns,
        const FlagType A_is_row_major,
        DataType* RESTRICT c)
{
    // Scalars of c = one * op(A) * b + nil * c.
    DataType one = cu_arithmetics::cast<float, DataType>(1.0f);
    DataType nil = cu_arithmetics::cast<float, DataType>(0.0f);

    // cuBLAS reads memory column-major (Fortran order). A row-major A is
    // therefore seen by cuBLAS as A^T, which CUBLAS_OP_T turns back into A.
    const cublasOperation_t op = A_is_row_major ? CUBLAS_OP_T : CUBLAS_OP_N;

    // Shape of the stored array as cuBLAS interprets it.
    const int stored_rows = A_is_row_major ? num_columns : num_rows;
    const int stored_cols = A_is_row_major ? num_rows : num_columns;
    const int leading_dim = stored_rows;
    const int unit_stride = 1;

    const cublasStatus_t gemv_status = cublas_api::cublasXgemv<DataType>(
        cublas_handle, op, stored_rows, stored_cols, &one, A, leading_dim,
        b, unit_stride, &nil, c, unit_stride);

    assert(gemv_status == CUBLAS_STATUS_SUCCESS);
}
102
103
104// =================
105// dense matvec plus
106// =================
107
134
/// \brief Dense accumulation \f$ c \leftarrow c + \alpha A b \f$ computed
///        with cuBLAS \c gemv.
///
/// \param[in]     cublas_handle  cuBLAS handle forwarded to \c gemv.
/// \param[in]     A              Dense matrix with \c num_rows rows and
///                               \c num_columns columns (leading dimension
///                               equals the stored row count).
/// \param[in]     b              Input vector of length \c num_columns.
/// \param[in]     alpha          Scale applied to \f$ A b \f$. If it compares
///                               equal to zero via
///                               \c cu_arithmetics::is_equal, the call
///                               returns without touching \c c.
/// \param[in]     num_rows       Number of rows of \c A.
/// \param[in]     num_columns    Number of columns of \c A.
/// \param[in]     A_is_row_major Non-zero when \c A is row-major.
/// \param[in,out] c              Vector of length \c num_rows that is
///                               accumulated into (beta is one).
///
/// \note cuBLAS failures are only caught by \c assert.
template <typename DataType>
void cuMatrixOperations<DataType>::dense_matvec_plus(
        cublasHandle_t cublas_handle,
        const DataType* RESTRICT A,
        const DataType* RESTRICT b,
        const DataType alpha,
        const LongIndexType num_rows,
        const LongIndexType num_columns,
        const FlagType A_is_row_major,
        DataType* RESTRICT c)
{
    // A (numerically) zero alpha adds nothing, so skip the cuBLAS call.
    const bool alpha_is_zero = cu_arithmetics::is_equal(
        alpha, cu_arithmetics::cast<float, DataType>(0.0f));
    if (alpha_is_zero)
    {
        return;
    }

    // cuBLAS is column-major. Pick the operation and the shape cuBLAS sees.
    cublasOperation_t op;
    int rows_seen;
    int cols_seen;
    if (!A_is_row_major)
    {
        // Column-major A matches cuBLAS storage directly.
        op = CUBLAS_OP_N;
        rows_seen = num_rows;
        cols_seen = num_columns;
    }
    else
    {
        // Row-major A appears to cuBLAS as A^T; transpose it back.
        op = CUBLAS_OP_T;
        rows_seen = num_columns;
        cols_seen = num_rows;
    }

    // beta = 1 keeps the existing contents of c (accumulation).
    DataType keep_c = cu_arithmetics::cast<float, DataType>(1.0f);
    const int stride = 1;

    const cublasStatus_t gemv_status = cublas_api::cublasXgemv<DataType>(
        cublas_handle, op, rows_seen, cols_seen, &alpha, A, rows_seen, b,
        stride, &keep_c, c, stride);

    assert(gemv_status == CUBLAS_STATUS_SUCCESS);
}
184
185
186// =======================
187// dense transposed matvec
188// =======================
189
215
/// \brief Dense transposed product \f$ c = A^{\intercal} b \f$ computed
///        with cuBLAS \c gemv.
///
/// \param[in]  cublas_handle  cuBLAS handle forwarded to \c gemv.
/// \param[in]  A              Dense matrix with \c num_rows rows and
///                            \c num_columns columns (leading dimension
///                            equals the stored row count).
/// \param[in]  b              Input vector of length \c num_rows.
/// \param[in]  num_rows       Number of rows of \c A.
/// \param[in]  num_columns    Number of columns of \c A.
/// \param[in]  A_is_row_major Non-zero when \c A is row-major.
/// \param[out] c              Output vector of length \c num_columns;
///                            overwritten (beta is zero).
///
/// \note cuBLAS failures are only caught by \c assert.
template <typename DataType>
void cuMatrixOperations<DataType>::dense_transposed_matvec(
        cublasHandle_t cublas_handle,
        const DataType* RESTRICT A,
        const DataType* RESTRICT b,
        const LongIndexType num_rows,
        const LongIndexType num_columns,
        const FlagType A_is_row_major,
        DataType* RESTRICT c)
{
    // Default: column-major A, which cuBLAS must transpose to form A^T b.
    cublasOperation_t op = CUBLAS_OP_T;
    int rows = num_rows;
    int cols = num_columns;

    if (A_is_row_major)
    {
        // A row-major A already looks like A^T to column-major cuBLAS, so
        // no transpose is requested.
        op = CUBLAS_OP_N;
        rows = num_columns;
        cols = num_rows;
    }

    DataType one = cu_arithmetics::cast<float, DataType>(1.0f);
    DataType nil = cu_arithmetics::cast<float, DataType>(0.0f);
    const int step = 1;

    const cublasStatus_t result = cublas_api::cublasXgemv<DataType>(
        cublas_handle, op, rows, cols, &one, A, rows, b, step, &nil, c,
        step);

    assert(result == CUBLAS_STATUS_SUCCESS);
}
259
260
261// ============================
262// dense transposed matvec plus
263// ============================
264
292
/// \brief Dense transposed accumulation
///        \f$ c \leftarrow c + \alpha A^{\intercal} b \f$ computed with cuBLAS
///        \c gemv.
///
/// \param[in]     cublas_handle  cuBLAS handle forwarded to \c gemv.
/// \param[in]     A              Dense matrix with \c num_rows rows and
///                               \c num_columns columns (leading dimension
///                               equals the stored row count).
/// \param[in]     b              Input vector of length \c num_rows.
/// \param[in]     alpha          Scale applied to \f$ A^{\intercal} b \f$. If
///                               it compares equal to zero via
///                               \c cu_arithmetics::is_equal, the call
///                               returns without touching \c c.
/// \param[in]     num_rows       Number of rows of \c A.
/// \param[in]     num_columns    Number of columns of \c A.
/// \param[in]     A_is_row_major Non-zero when \c A is row-major.
/// \param[in,out] c              Vector of length \c num_columns that is
///                               accumulated into.
///
/// \note cuBLAS failures are only caught by \c assert.
template <typename DataType>
void cuMatrixOperations<DataType>::dense_transposed_matvec_plus(
        cublasHandle_t cublas_handle,
        const DataType* RESTRICT A,
        const DataType* RESTRICT b,
        const DataType alpha,
        const LongIndexType num_rows,
        const LongIndexType num_columns,
        const FlagType A_is_row_major,
        DataType* RESTRICT c)
{
    // Nothing to add when alpha is (numerically) zero; c is left unchanged.
    DataType zero = cu_arithmetics::cast<float, DataType>(0.0f);
    if (cu_arithmetics::is_equal(alpha, zero))
    {
        return;
    }

    cublasOperation_t trans;
    int m;
    int n;
    int lda;

    // gemv computes c = alpha * op(A) * b + beta * c. beta must be one so
    // that the product is added to the existing c rather than overwriting
    // it (same as dense_matvec_plus).
    DataType beta = cu_arithmetics::cast<float, DataType>(1.0f);
    int incb = 1;
    int incc = 1;

    // Since cublas accepts column major (Fortran) ordering, use non-transpose
    // for row_major matrix: a row-major A already appears to cuBLAS as A^T.
    if (A_is_row_major)
    {
        trans = CUBLAS_OP_N;
        m = num_columns;
        n = num_rows;
    }
    else
    {
        trans = CUBLAS_OP_T;
        m = num_rows;
        n = num_columns;
    }

    // Leading dimension equals the stored row count (no padding assumed).
    lda = m;

    // Calling cublas
    cublasStatus_t status = cublas_api::cublasXgemv<DataType>(
        cublas_handle, trans, m, n, &alpha, A, lda, b, incb, &beta, c,
        incc);

    assert(status == CUBLAS_STATUS_SUCCESS);
}
342
343
344// ==========
345// csr matvec
346// ==========
347
374
/// \brief CSR matrix-vector product (per its name, \f$ c = A b \f$) on the
///        CUDA backend. Not implemented yet.
///
/// \throws std::runtime_error Always, with message
///         "Function not implemented.".
template <typename DataType>
void cuMatrixOperations<DataType>::csr_matvec(
        cusparseHandle_t cusparse_handle,
        const DataType* RESTRICT A_data,
        const LongIndexType* RESTRICT A_column_indices,
        const LongIndexType* RESTRICT A_index_pointer,
        const DataType* RESTRICT b,
        const LongIndexType num_rows,
        DataType* RESTRICT c)
{
    // Arguments are accepted only to match the class declaration.
    (void) cusparse_handle;
    (void) A_data;
    (void) A_column_indices;
    (void) A_index_pointer;
    (void) b;
    (void) num_rows;
    (void) c;

    const char* const reason = "Function not implemented.";
    throw std::runtime_error(reason);
}
387
388
389// ===============
390// csr matvec plus
391// ===============
392
423
/// \brief CSR accumulation (per its name, \f$ c \leftarrow c + \alpha A b \f$)
///        on the CUDA backend. Not implemented yet.
///
/// \throws std::runtime_error Always, with message
///         "Function not implemented.".
template <typename DataType>
void cuMatrixOperations<DataType>::csr_matvec_plus(
        cusparseHandle_t cusparse_handle,
        const DataType* RESTRICT A_data,
        const LongIndexType* RESTRICT A_column_indices,
        const LongIndexType* RESTRICT A_index_pointer,
        const DataType* RESTRICT b,
        const DataType alpha,
        const LongIndexType num_rows,
        DataType* RESTRICT c)
{
    // Arguments are accepted only to match the class declaration.
    (void) cusparse_handle;
    (void) A_data;
    (void) A_column_indices;
    (void) A_index_pointer;
    (void) b;
    (void) alpha;
    (void) num_rows;
    (void) c;

    const char* const reason = "Function not implemented.";
    throw std::runtime_error(reason);
}
437
438
439// =====================
440// csr transposed matvec
441// =====================
442
471
/// \brief CSR transposed product (per its name, \f$ c = A^{\intercal} b \f$)
///        on the CUDA backend. Not implemented yet.
///
/// \throws std::runtime_error Always, with message
///         "Function not implemented.".
template <typename DataType>
void cuMatrixOperations<DataType>::csr_transposed_matvec(
        cusparseHandle_t cusparse_handle,
        const DataType* RESTRICT A_data,
        const LongIndexType* RESTRICT A_column_indices,
        const LongIndexType* RESTRICT A_index_pointer,
        const DataType* RESTRICT b,
        const LongIndexType num_rows,
        const LongIndexType num_columns,
        DataType* RESTRICT c)
{
    // Arguments are accepted only to match the class declaration.
    (void) cusparse_handle;
    (void) A_data;
    (void) A_column_indices;
    (void) A_index_pointer;
    (void) b;
    (void) num_rows;
    (void) num_columns;
    (void) c;

    const char* const reason = "Function not implemented.";
    throw std::runtime_error(reason);
}
485
486
487// ==========================
488// csr transposed matvec plus
489// ==========================
490
523
/// \brief CSR transposed accumulation (per its name,
///        \f$ c \leftarrow c + \alpha A^{\intercal} b \f$) on the CUDA backend.
///        Not implemented yet.
///
/// \throws std::runtime_error Always, with message
///         "Function not implemented.".
template <typename DataType>
void cuMatrixOperations<DataType>::csr_transposed_matvec_plus(
        cusparseHandle_t cusparse_handle,
        const DataType* RESTRICT A_data,
        const LongIndexType* RESTRICT A_column_indices,
        const LongIndexType* RESTRICT A_index_pointer,
        const DataType* RESTRICT b,
        const DataType alpha,
        const LongIndexType num_rows,
        const LongIndexType num_columns,
        DataType* RESTRICT c)
{
    // Arguments are accepted only to match the class declaration.
    (void) cusparse_handle;
    (void) A_data;
    (void) A_column_indices;
    (void) A_index_pointer;
    (void) b;
    (void) alpha;
    (void) num_rows;
    (void) num_columns;
    (void) c;

    const char* const reason = "Function not implemented.";
    throw std::runtime_error(reason);
}
538
539
540// ==========
541// csc matvec
542// ==========
543
572
/// \brief CSC matrix-vector product (per its name, \f$ c = A b \f$) on the
///        CUDA backend. Not implemented yet.
///
/// \throws std::runtime_error Always, with message
///         "Function not implemented.".
template <typename DataType>
void cuMatrixOperations<DataType>::csc_matvec(
        cusparseHandle_t cusparse_handle,
        const DataType* RESTRICT A_data,
        const LongIndexType* RESTRICT A_row_indices,
        const LongIndexType* RESTRICT A_index_pointer,
        const DataType* RESTRICT b,
        const LongIndexType num_rows,
        const LongIndexType num_columns,
        DataType* RESTRICT c)
{
    // Arguments are accepted only to match the class declaration.
    (void) cusparse_handle;
    (void) A_data;
    (void) A_row_indices;
    (void) A_index_pointer;
    (void) b;
    (void) num_rows;
    (void) num_columns;
    (void) c;

    const char* const reason = "Function not implemented.";
    throw std::runtime_error(reason);
}
586
587
588// ===============
589// csc matvec plus
590// ===============
591
624
/// \brief CSC accumulation (per its name, \f$ c \leftarrow c + \alpha A b \f$)
///        on the CUDA backend. Not implemented yet.
///
/// \throws std::runtime_error Always, with message
///         "Function not implemented.".
template <typename DataType>
void cuMatrixOperations<DataType>::csc_matvec_plus(
        cusparseHandle_t cusparse_handle,
        const DataType* RESTRICT A_data,
        const LongIndexType* RESTRICT A_row_indices,
        const LongIndexType* RESTRICT A_index_pointer,
        const DataType* RESTRICT b,
        const DataType alpha,
        const LongIndexType num_rows,
        const LongIndexType num_columns,
        DataType* RESTRICT c)
{
    // Arguments are accepted only to match the class declaration.
    (void) cusparse_handle;
    (void) A_data;
    (void) A_row_indices;
    (void) A_index_pointer;
    (void) b;
    (void) alpha;
    (void) num_rows;
    (void) num_columns;
    (void) c;

    const char* const reason = "Function not implemented.";
    throw std::runtime_error(reason);
}
639
640
641// =====================
642// csc transposed matvec
643// =====================
644
672
/// \brief CSC transposed product (per its name, \f$ c = A^{\intercal} b \f$)
///        on the CUDA backend. Not implemented yet.
///
/// \throws std::runtime_error Always, with message
///         "Function not implemented.".
template <typename DataType>
void cuMatrixOperations<DataType>::csc_transposed_matvec(
        cusparseHandle_t cusparse_handle,
        const DataType* RESTRICT A_data,
        const LongIndexType* RESTRICT A_row_indices,
        const LongIndexType* RESTRICT A_index_pointer,
        const DataType* RESTRICT b,
        const LongIndexType num_columns,
        DataType* RESTRICT c)
{
    // Arguments are accepted only to match the class declaration.
    (void) cusparse_handle;
    (void) A_data;
    (void) A_row_indices;
    (void) A_index_pointer;
    (void) b;
    (void) num_columns;
    (void) c;

    const char* const reason = "Function not implemented.";
    throw std::runtime_error(reason);
}
685
686
687// ==========================
688// csc transposed matvec plus
689// ==========================
690
721
/// \brief CSC transposed accumulation (per its name,
///        \f$ c \leftarrow c + \alpha A^{\intercal} b \f$) on the CUDA backend.
///        Not implemented yet.
///
/// \throws std::runtime_error Always, with message
///         "Function not implemented.".
template <typename DataType>
void cuMatrixOperations<DataType>::csc_transposed_matvec_plus(
        cusparseHandle_t cusparse_handle,
        const DataType* RESTRICT A_data,
        const LongIndexType* RESTRICT A_row_indices,
        const LongIndexType* RESTRICT A_index_pointer,
        const DataType* RESTRICT b,
        const DataType alpha,
        const LongIndexType num_columns,
        DataType* RESTRICT c)
{
    // Arguments are accepted only to match the class declaration.
    (void) cusparse_handle;
    (void) A_data;
    (void) A_row_indices;
    (void) A_index_pointer;
    (void) b;
    (void) alpha;
    (void) num_columns;
    (void) c;

    const char* const reason = "Function not implemented.";
    throw std::runtime_error(reason);
}
735
736
737// ==================
738// create band matrix
739// ==================
740
781
/// \brief Builds a bi-diagonal or symmetric tri-diagonal matrix from
///        \c diagonals and \c supdiagonals on the CUDA backend. Not
///        implemented yet.
///
/// \note NOTE(review): the handle parameter is named \c cublas_handle but is
///       typed \c cusparseHandle_t (the class declaration uses the same name).
///
/// \throws std::runtime_error Always, with message
///         "Function not implemented.".
template <typename DataType>
void cuMatrixOperations<DataType>::create_band_matrix(
        cusparseHandle_t cublas_handle,
        const DataType* RESTRICT diagonals,
        const DataType* RESTRICT supdiagonals,
        const IndexType non_zero_size,
        const FlagType tridiagonal,
        DataType** RESTRICT matrix)
{
    // Arguments are accepted only to match the class declaration.
    (void) cublas_handle;
    (void) diagonals;
    (void) supdiagonals;
    (void) non_zero_size;
    (void) tridiagonal;
    (void) matrix;

    const char* const reason = "Function not implemented.";
    throw std::runtime_error(reason);
}
793
794
795// ===============================
796// Explicit template instantiation
797// ===============================
798
// Each floating-point specialization of cuMatrixOperations is compiled in
// only when its USE_CUDA_* macro is defined and equal to 1.
// NOTE(review): in this listing the FP8 E5M2, FP8 E4M3 and BF16 guards are
// empty (lines 800, 804 and 812 are absent). Presumably they instantiate
// cuMatrixOperations for __nv_fp8_e5m2, __nv_fp8_e4m3 and __nv_bfloat16.
// Confirm against the repository.
799#if defined(USE_CUDA_FP8_E5M2) && (USE_CUDA_FP8_E5M2 == 1)
801#endif
802
803#if defined(USE_CUDA_FP8_E4M3) && (USE_CUDA_FP8_E4M3 == 1)
805#endif
806
807#if defined(USE_CUDA_FP16) && (USE_CUDA_FP16 == 1)
808 template class cuMatrixOperations<__half>;
809#endif
810
811#if defined(USE_CUDA_BF16) && (USE_CUDA_BF16 == 1)
813#endif
814
815#if defined(USE_CUDA_FP32) && (USE_CUDA_FP32 == 1)
816 template class cuMatrixOperations<float>;
817#endif
818
819#if defined(USE_CUDA_FP64) && (USE_CUDA_FP64 == 1)
820 template class cuMatrixOperations<double>;
821#endif
#define RESTRICT
A static class for matrix-vector operations, which are similar to the level-2 operations of the BLAS ...
static void csr_transposed_matvec_plus(cusparseHandle_t cusparse_handle, const DataType *RESTRICT A_data, const LongIndexType *RESTRICT A_column_indices, const LongIndexType *RESTRICT A_index_pointer, const DataType *RESTRICT b, const DataType alpha, const LongIndexType num_rows, const LongIndexType num_columns, DataType *RESTRICT c)
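Computes $\boldsymbol{c} \leftarrow \boldsymbol{c} + \alpha \mathbf{A}^{\intercal} \boldsymbol{b}$ where $\mathbf{A}$ is a compressed sparse row (CSR) matrix and $\boldsymbol{b}$ is a dense vector. The output $\boldsymbol{c}$ is a dense vector.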
static void create_band_matrix(cusparseHandle_t cublas_handle, const DataType *RESTRICT diagonals, const DataType *RESTRICT supdiagonals, const IndexType non_zero_size, const FlagType tridiagonal, DataType **matrix)
Creates a bi-diagonal or symmetric tri-diagonal matrix from the diagonal array (diagonals) and the off-diagonal array (supdiagonals).
static void csc_matvec_plus(cusparseHandle_t cusparse_handle, const DataType *RESTRICT A_data, const LongIndexType *RESTRICT A_row_indices, const LongIndexType *RESTRICT A_index_pointer, const DataType *RESTRICT b, const DataType alpha, const LongIndexType num_rows, const LongIndexType num_columns, DataType *RESTRICT c)
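Computes $\boldsymbol{c} \leftarrow \boldsymbol{c} + \alpha \mathbf{A} \boldsymbol{b}$ where $\mathbf{A}$ is a compressed sparse column (CSC) matrix and $\boldsymbol{b}$ is a dense vector. The output $\boldsymbol{c}$ is a dense vector.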
static void dense_matvec(cublasHandle_t cublas_handle, const DataType *RESTRICT A, const DataType *RESTRICT b, const LongIndexType num_rows, const LongIndexType num_columns, const FlagType A_is_row_major, DataType *RESTRICT c)
Computes the matrix-vector multiplication $\boldsymbol{c} = \mathbf{A} \boldsymbol{b}$ where $\mathbf{A}$ is a dense matrix.
static void dense_transposed_matvec(cublasHandle_t cublas_handle, const DataType *RESTRICT A, const DataType *RESTRICT b, const LongIndexType num_rows, const LongIndexType num_columns, const FlagType A_is_row_major, DataType *RESTRICT c)
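Computes the matrix-vector multiplication $\boldsymbol{c} = \mathbf{A}^{\intercal} \boldsymbol{b}$ where $\mathbf{A}$ is dense, and $\mathbf{A}^{\intercal}$ is the transpose of the matrix $\mathbf{A}$.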
static void csc_transposed_matvec(cusparseHandle_t cusparse_handle, const DataType *RESTRICT A_data, const LongIndexType *RESTRICT A_row_indices, const LongIndexType *RESTRICT A_index_pointer, const DataType *RESTRICT b, const LongIndexType num_columns, DataType *RESTRICT c)
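Computes $\boldsymbol{c} = \mathbf{A}^{\intercal} \boldsymbol{b}$ where $\mathbf{A}$ is a compressed sparse column (CSC) matrix and $\boldsymbol{b}$ is a dense vector. The output $\boldsymbol{c}$ is a dense vector.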
static void csr_matvec(cusparseHandle_t cusparse_handle, const DataType *RESTRICT A_data, const LongIndexType *RESTRICT A_column_indices, const LongIndexType *RESTRICT A_index_pointer, const DataType *RESTRICT b, const LongIndexType num_rows, DataType *RESTRICT c)
Computes $\boldsymbol{c} = \mathbf{A} \boldsymbol{b}$ where $\mathbf{A}$ is a compressed sparse row (CSR) matrix and $\boldsymbol{b}$ is a dense vector. The output $\boldsymbol{c}$ is a dense vector.
static void csc_transposed_matvec_plus(cusparseHandle_t cusparse_handle, const DataType *RESTRICT A_data, const LongIndexType *RESTRICT A_row_indices, const LongIndexType *RESTRICT A_index_pointer, const DataType *RESTRICT b, const DataType alpha, const LongIndexType num_columns, DataType *RESTRICT c)
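Computes $\boldsymbol{c} \leftarrow \boldsymbol{c} + \alpha \mathbf{A}^{\intercal} \boldsymbol{b}$ where $\mathbf{A}$ is a compressed sparse column (CSC) matrix and $\boldsymbol{b}$ is a dense vector. The output $\boldsymbol{c}$ is a dense vector.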
static void csr_matvec_plus(cusparseHandle_t cusparse_handle, const DataType *RESTRICT A_data, const LongIndexType *RESTRICT A_column_indices, const LongIndexType *RESTRICT A_index_pointer, const DataType *RESTRICT b, const DataType alpha, const LongIndexType num_rows, DataType *RESTRICT c)
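Computes $\boldsymbol{c} \leftarrow \boldsymbol{c} + \alpha \mathbf{A} \boldsymbol{b}$ where $\mathbf{A}$ is a compressed sparse row (CSR) matrix and $\boldsymbol{b}$ is a dense vector. The output $\boldsymbol{c}$ is a dense vector.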
static void csr_transposed_matvec(cusparseHandle_t cusparse_handle, const DataType *RESTRICT A_data, const LongIndexType *RESTRICT A_column_indices, const LongIndexType *RESTRICT A_index_pointer, const DataType *RESTRICT b, const LongIndexType num_rows, const LongIndexType num_columns, DataType *RESTRICT c)
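Computes $\boldsymbol{c} = \mathbf{A}^{\intercal} \boldsymbol{b}$ where $\mathbf{A}$ is a compressed sparse row (CSR) matrix and $\boldsymbol{b}$ is a dense vector. The output $\boldsymbol{c}$ is a dense vector.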
static void csc_matvec(cusparseHandle_t cusparse_handle, const DataType *RESTRICT A_data, const LongIndexType *RESTRICT A_row_indices, const LongIndexType *RESTRICT A_index_pointer, const DataType *RESTRICT b, const LongIndexType num_rows, const LongIndexType num_columns, DataType *RESTRICT c)
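Computes $\boldsymbol{c} = \mathbf{A} \boldsymbol{b}$ where $\mathbf{A}$ is a compressed sparse column (CSC) matrix and $\boldsymbol{b}$ is a dense vector. The output $\boldsymbol{c}$ is a dense vector.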
static void dense_transposed_matvec_plus(cublasHandle_t cublas_handle, const DataType *RESTRICT A, const DataType *RESTRICT b, const DataType alpha, const LongIndexType num_rows, const LongIndexType num_columns, const FlagType A_is_row_major, DataType *RESTRICT c)
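Computes $\boldsymbol{c} \leftarrow \boldsymbol{c} + \alpha \mathbf{A}^{\intercal} \boldsymbol{b}$ where $\mathbf{A}$ is dense, and $\mathbf{A}^{\intercal}$ is the transpose of the matrix $\mathbf{A}$.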
static void dense_matvec_plus(cublasHandle_t cublas_handle, const DataType *RESTRICT A, const DataType *RESTRICT b, const DataType alpha, const LongIndexType num_rows, const LongIndexType num_columns, const FlagType A_is_row_major, DataType *RESTRICT c)
Computes the operation $\boldsymbol{c} \leftarrow \boldsymbol{c} + \alpha \mathbf{A} \boldsymbol{b}$ where $\mathbf{A}$ is a dense matrix.
__host__ __device__ DataType abs(const DataType x)
Absolute value of a floating point number.
bool is_equal(DataType x, DataType y)
Check if two floating point numbers are equal within a tolerance.
int LongIndexType
Definition types.h:60
int FlagType
Definition types.h:68
int IndexType
Definition types.h:65