imate
C++/CUDA Reference
Loading...
Searching...
No Matches
cu_matrix_operations.h
Go to the documentation of this file.
1/*
2 * SPDX-FileCopyrightText: Copyright 2021, Siavash Ameli <sameli@berkeley.edu>
3 * SPDX-License-Identifier: BSD-3-Clause
4 * SPDX-FileType: SOURCE
5 *
6 * This program is free software: you can redistribute it and/or modify it
7 * under the terms of the license found in the LICENSE.txt file in the root
8 * directory of this source tree.
9 */
10
11
12#ifndef _CU_BASIC_ALGEBRA_CU_MATRIX_OPERATIONS_H_
13#define _CU_BASIC_ALGEBRA_CU_MATRIX_OPERATIONS_H_
14
15// =======
16// Headers
17// =======
18
19#include <cusparse.h> // cusparseHandle_t
20#include "../_definitions/types.h" // IndexType, LongIndexType, FlagType
21
22// Avoid CUBLAS enumeration value not handled in switch [-Wswitch-enum] warning
// Each branch below includes the same <cublas_v2.h>; only the
// compiler-specific pragmas wrapped around the include differ.
23#ifdef _MSC_VER
24 #pragma warning(push, 0) // Suppress all warnings from the following include
25 #include <cublas_v2.h>
26 #pragma warning(pop) // Restore previous warning level
27#elif defined(__INTEL_LLVM_COMPILER) || defined(__INTEL_COMPILER)
28 #pragma warning(push, 0)
29 #include <cublas_v2.h>
30 #pragma warning(pop)
31#elif defined(__GNUC__) || defined(__clang__)
32 #pragma GCC diagnostic push
33 #pragma GCC diagnostic ignored "-Wswitch-enum"
34 #include <cublas_v2.h>
35 #pragma GCC diagnostic pop
36#else
37 #include <cublas_v2.h>
38#endif
39
40// Restrict qualifier
// RESTRICT expands to the compiler-specific spelling of the no-alias pointer
// qualifier: __restrict when _MSC_VER or __INTEL_COMPILER is defined,
// __restrict__ when __CUDA__, __GNUC__, or __clang__ is defined, and nothing
// otherwise, so the declarations still compile on unrecognized compilers.
// NOTE(review): the third branch tests __CUDA__; confirm whether __CUDACC__
// (or __NVCC__) was intended for builds driven by nvcc.
41#if defined(_MSC_VER)
42 #define RESTRICT __restrict
43#elif defined(__INTEL_COMPILER)
44 #define RESTRICT __restrict
45#elif defined(__CUDA__) || defined(__GNUC__) || defined(__clang__)
46 #define RESTRICT __restrict__
47#else
48 #define RESTRICT
49#endif
50
51
52// =================
53// Matrix Operations
54// =================
55
84
// Class template whose members are all public static functions, templated on
// DataType. Only declarations appear here; no definitions are visible in this
// header. The dense routines take a cublasHandle_t, the sparse (CSR / CSC)
// routines take a cusparseHandle_t. Pointer arguments are RESTRICT-qualified.
// NOTE(review): the class-head line (listing line 86, between "85template"
// and "87{") is absent from this listing; restore it from the original
// header before compiling.
85template <typename DataType>
87{
88 public:
89
90 // dense matvec
 // Dense matrix-vector product written to c (presumably c = A b; confirm
 // against the implementation). A has num_rows x num_columns entries;
 // A_is_row_major presumably selects row- versus column-major storage of A.
91 static void dense_matvec(
92 cublasHandle_t cublas_handle,
93 const DataType* RESTRICT A,
94 const DataType* RESTRICT b,
95 const LongIndexType num_rows,
96 const LongIndexType num_columns,
97 const FlagType A_is_row_major,
98 DataType* RESTRICT c);
99
100 // dense matvec plus
 // Same arguments as dense_matvec plus a scalar alpha (presumably
 // c = c + alpha A b; confirm against the implementation).
101 static void dense_matvec_plus(
102 cublasHandle_t cublas_handle,
103 const DataType* RESTRICT A,
104 const DataType* RESTRICT b,
105 const DataType alpha,
106 const LongIndexType num_rows,
107 const LongIndexType num_columns,
108 const FlagType A_is_row_major,
109 DataType* RESTRICT c);
110
111 // dense transposed matvec
 // Dense product using the transpose of A (presumably c = A^T b; confirm
 // against the implementation).
112 static void dense_transposed_matvec(
113 cublasHandle_t cublas_handle,
114 const DataType* RESTRICT A,
115 const DataType* RESTRICT b,
116 const LongIndexType num_rows,
117 const LongIndexType num_columns,
118 const FlagType A_is_row_major,
119 DataType* RESTRICT c);
120
121 // dense transposed matvec plus
 // NOTE(review): the first line of this declaration (listing line 122,
 // presumably `static void dense_transposed_matvec_plus(`) is missing from
 // this listing; the parameters below belong to it. Presumably computes
 // c = c + alpha A^T b; confirm against the implementation.
123 cublasHandle_t cublas_handle,
124 const DataType* RESTRICT A,
125 const DataType* RESTRICT b,
126 const DataType alpha,
127 const LongIndexType num_rows,
128 const LongIndexType num_columns,
129 const FlagType A_is_row_major,
130 DataType* RESTRICT c);
131
132 // CSR matvec
 // Sparse matrix-vector product for a CSR matrix given as three arrays
 // (A_data, A_column_indices, A_index_pointer) and a dense vector b;
 // the result goes to the dense vector c (presumably c = A b). Only
 // num_rows is passed.
133 static void csr_matvec(
134 cusparseHandle_t cusparse_handle,
135 const DataType* RESTRICT A_data,
136 const LongIndexType* RESTRICT A_column_indices,
137 const LongIndexType* RESTRICT A_index_pointer,
138 const DataType* RESTRICT b,
139 const LongIndexType num_rows,
140 DataType* RESTRICT c);
141
142 // CSR matvec plus
 // Same arguments as csr_matvec plus a scalar alpha (presumably
 // c = c + alpha A b; confirm against the implementation).
143 static void csr_matvec_plus(
144 cusparseHandle_t cusparse_handle,
145 const DataType* RESTRICT A_data,
146 const LongIndexType* RESTRICT A_column_indices,
147 const LongIndexType* RESTRICT A_index_pointer,
148 const DataType* RESTRICT b,
149 const DataType alpha,
150 const LongIndexType num_rows,
151 DataType* RESTRICT c);
152
153 // CSR transposed matvec
 // CSR product using the transpose of A (presumably c = A^T b). Unlike
 // csr_matvec, this takes both num_rows and num_columns.
154 static void csr_transposed_matvec(
155 cusparseHandle_t cusparse_handle,
156 const DataType* RESTRICT A_data,
157 const LongIndexType* RESTRICT A_column_indices,
158 const LongIndexType* RESTRICT A_index_pointer,
159 const DataType* RESTRICT b,
160 const LongIndexType num_rows,
161 const LongIndexType num_columns,
162 DataType* RESTRICT c);
163
164 // CSR transposed matvec plus
 // Same arguments as csr_transposed_matvec plus a scalar alpha (presumably
 // c = c + alpha A^T b; confirm against the implementation).
165 static void csr_transposed_matvec_plus(
166 cusparseHandle_t cusparse_handle,
167 const DataType* RESTRICT A_data,
168 const LongIndexType* RESTRICT A_column_indices,
169 const LongIndexType* RESTRICT A_index_pointer,
170 const DataType* RESTRICT b,
171 const DataType alpha,
172 const LongIndexType num_rows,
173 const LongIndexType num_columns,
174 DataType* RESTRICT c);
175
176 // CSC matvec
 // Sparse matrix-vector product for a CSC matrix given as three arrays
 // (A_data, A_row_indices, A_index_pointer) and a dense vector b; the
 // result goes to the dense vector c (presumably c = A b). Takes both
 // num_rows and num_columns.
177 static void csc_matvec(
178 cusparseHandle_t cusparse_handle,
179 const DataType* RESTRICT A_data,
180 const LongIndexType* RESTRICT A_row_indices,
181 const LongIndexType* RESTRICT A_index_pointer,
182 const DataType* RESTRICT b,
183 const LongIndexType num_rows,
184 const LongIndexType num_columns,
185 DataType* RESTRICT c);
186
187 // CSC matvec plus
 // Same arguments as csc_matvec plus a scalar alpha (presumably
 // c = c + alpha A b; confirm against the implementation).
188 static void csc_matvec_plus(
189 cusparseHandle_t cusparse_handle,
190 const DataType* RESTRICT A_data,
191 const LongIndexType* RESTRICT A_row_indices,
192 const LongIndexType* RESTRICT A_index_pointer,
193 const DataType* RESTRICT b,
194 const DataType alpha,
195 const LongIndexType num_rows,
196 const LongIndexType num_columns,
197 DataType* RESTRICT c);
198
199 // CSC transposed matvec
 // CSC product using the transpose of A (presumably c = A^T b). Only
 // num_columns is passed.
200 static void csc_transposed_matvec(
201 cusparseHandle_t cusparse_handle,
202 const DataType* RESTRICT A_data,
203 const LongIndexType* RESTRICT A_row_indices,
204 const LongIndexType* RESTRICT A_index_pointer,
205 const DataType* RESTRICT b,
206 const LongIndexType num_columns,
207 DataType* RESTRICT c);
208
209 // CSC transposed matvec plus
 // Same arguments as csc_transposed_matvec plus a scalar alpha (presumably
 // c = c + alpha A^T b; confirm against the implementation).
210 static void csc_transposed_matvec_plus(
211 cusparseHandle_t cusparse_handle,
212 const DataType* RESTRICT A_data,
213 const LongIndexType* RESTRICT A_row_indices,
214 const LongIndexType* RESTRICT A_index_pointer,
215 const DataType* RESTRICT b,
216 const DataType alpha,
217 const LongIndexType num_columns,
218 DataType* RESTRICT c);
219
220 // Create Band Matrix
 // Takes the arrays `diagonals` and `supdiagonals`, a size, and a
 // `tridiagonal` flag; the result is presumably written through `matrix`
 // (the only non-const pointer argument). TODO confirm who allocates it.
 // NOTE(review): the first parameter is typed cusparseHandle_t but named
 // cublas_handle; confirm which library handle is intended.
221 static void create_band_matrix(
222 cusparseHandle_t cublas_handle,
223 const DataType* RESTRICT diagonals,
224 const DataType* RESTRICT supdiagonals,
225 const IndexType non_zero_size,
226 const FlagType tridiagonal,
227 DataType** matrix);
228};
229
230#endif // _CU_BASIC_ALGEBRA_CU_MATRIX_OPERATIONS_H_
A static class for matrix-vector operations, which are similar to the level-2 operations of the BLAS ...
static void csr_transposed_matvec_plus(cusparseHandle_t cusparse_handle, const DataType *RESTRICT A_data, const LongIndexType *RESTRICT A_column_indices, const LongIndexType *RESTRICT A_index_pointer, const DataType *RESTRICT b, const DataType alpha, const LongIndexType num_rows, const LongIndexType num_columns, DataType *RESTRICT c)
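Computes $\boldsymbol{c} = \boldsymbol{c} + \alpha \mathbf{A}^{\intercal} \boldsymbol{b}$ where $\mathbf{A}$ is a compressed sparse row (CSR) matrix and $\boldsymbol{b}$ is a dense vector. The output $\boldsymbol{c}$ is a dense...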
static void create_band_matrix(cusparseHandle_t cublas_handle, const DataType *RESTRICT diagonals, const DataType *RESTRICT supdiagonals, const IndexType non_zero_size, const FlagType tridiagonal, DataType **matrix)
Creates bi-diagonal or symmetric tri-diagonal matrix from the diagonal array (diagonals) and off-diag...
static void csc_matvec_plus(cusparseHandle_t cusparse_handle, const DataType *RESTRICT A_data, const LongIndexType *RESTRICT A_row_indices, const LongIndexType *RESTRICT A_index_pointer, const DataType *RESTRICT b, const DataType alpha, const LongIndexType num_rows, const LongIndexType num_columns, DataType *RESTRICT c)
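Computes $\boldsymbol{c} = \boldsymbol{c} + \alpha \mathbf{A} \boldsymbol{b}$ where $\mathbf{A}$ is a compressed sparse column (CSC) matrix and $\boldsymbol{b}$ is a dense vector. The output $\boldsymbol{c}$ is a de...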
static void dense_matvec(cublasHandle_t cublas_handle, const DataType *RESTRICT A, const DataType *RESTRICT b, const LongIndexType num_rows, const LongIndexType num_columns, const FlagType A_is_row_major, DataType *RESTRICT c)
Computes the matrix vector multiplication $\boldsymbol{c} = \mathbf{A} \boldsymbol{b}$ where $\mathbf{A}$ is a dense matrix.
static void dense_transposed_matvec(cublasHandle_t cublas_handle, const DataType *RESTRICT A, const DataType *RESTRICT b, const LongIndexType num_rows, const LongIndexType num_columns, const FlagType A_is_row_major, DataType *RESTRICT c)
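Computes the matrix vector multiplication $\boldsymbol{c} = \mathbf{A}^{\intercal} \boldsymbol{b}$ where $\mathbf{A}$ is dense, and $\mathbf{A}^{\intercal}$ is the transpose of the matrix $\mathbf{A}$.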
static void csc_transposed_matvec(cusparseHandle_t cusparse_handle, const DataType *RESTRICT A_data, const LongIndexType *RESTRICT A_row_indices, const LongIndexType *RESTRICT A_index_pointer, const DataType *RESTRICT b, const LongIndexType num_columns, DataType *RESTRICT c)
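Computes $\boldsymbol{c} = \mathbf{A}^{\intercal} \boldsymbol{b}$ where $\mathbf{A}$ is a compressed sparse column (CSC) matrix and $\boldsymbol{b}$ is a dense vector. The output $\boldsymbol{c}$ is a de...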
static void csr_matvec(cusparseHandle_t cusparse_handle, const DataType *RESTRICT A_data, const LongIndexType *RESTRICT A_column_indices, const LongIndexType *RESTRICT A_index_pointer, const DataType *RESTRICT b, const LongIndexType num_rows, DataType *RESTRICT c)
Computes $\boldsymbol{c} = \mathbf{A} \boldsymbol{b}$ where $\mathbf{A}$ is a compressed sparse row (CSR) matrix and $\boldsymbol{b}$ is a dense vector. The output $\boldsymbol{c}$ is a dense...
static void csc_transposed_matvec_plus(cusparseHandle_t cusparse_handle, const DataType *RESTRICT A_data, const LongIndexType *RESTRICT A_row_indices, const LongIndexType *RESTRICT A_index_pointer, const DataType *RESTRICT b, const DataType alpha, const LongIndexType num_columns, DataType *RESTRICT c)
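Computes $\boldsymbol{c} = \boldsymbol{c} + \alpha \mathbf{A}^{\intercal} \boldsymbol{b}$ where $\mathbf{A}$ is a compressed sparse column (CSC) matrix and $\boldsymbol{b}$ is a dense vector. The output $\boldsymbol{c}$ is a de...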
static void csr_matvec_plus(cusparseHandle_t cusparse_handle, const DataType *RESTRICT A_data, const LongIndexType *RESTRICT A_column_indices, const LongIndexType *RESTRICT A_index_pointer, const DataType *RESTRICT b, const DataType alpha, const LongIndexType num_rows, DataType *RESTRICT c)
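Computes $\boldsymbol{c} = \boldsymbol{c} + \alpha \mathbf{A} \boldsymbol{b}$ where $\mathbf{A}$ is a compressed sparse row (CSR) matrix and $\boldsymbol{b}$ is a dense vector. The output $\boldsymbol{c}$ is a dense...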
static void csr_transposed_matvec(cusparseHandle_t cusparse_handle, const DataType *RESTRICT A_data, const LongIndexType *RESTRICT A_column_indices, const LongIndexType *RESTRICT A_index_pointer, const DataType *RESTRICT b, const LongIndexType num_rows, const LongIndexType num_columns, DataType *RESTRICT c)
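Computes $\boldsymbol{c} = \mathbf{A}^{\intercal} \boldsymbol{b}$ where $\mathbf{A}$ is a compressed sparse row (CSR) matrix and $\boldsymbol{b}$ is a dense vector. The output $\boldsymbol{c}$ is a dense...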
static void csc_matvec(cusparseHandle_t cusparse_handle, const DataType *RESTRICT A_data, const LongIndexType *RESTRICT A_row_indices, const LongIndexType *RESTRICT A_index_pointer, const DataType *RESTRICT b, const LongIndexType num_rows, const LongIndexType num_columns, DataType *RESTRICT c)
Computes $\boldsymbol{c} = \mathbf{A} \boldsymbol{b}$ where $\mathbf{A}$ is a compressed sparse column (CSC) matrix and $\boldsymbol{b}$ is a dense vector. The output $\boldsymbol{c}$ is a de...
static void dense_transposed_matvec_plus(cublasHandle_t cublas_handle, const DataType *RESTRICT A, const DataType *RESTRICT b, const DataType alpha, const LongIndexType num_rows, const LongIndexType num_columns, const FlagType A_is_row_major, DataType *RESTRICT c)
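Computes $\boldsymbol{c} = \boldsymbol{c} + \alpha \mathbf{A}^{\intercal} \boldsymbol{b}$ where $\mathbf{A}$ is dense, and $\mathbf{A}^{\intercal}$ is the transpose of the matrix $\mathbf{A}$.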
static void dense_matvec_plus(cublasHandle_t cublas_handle, const DataType *RESTRICT A, const DataType *RESTRICT b, const DataType alpha, const LongIndexType num_rows, const LongIndexType num_columns, const FlagType A_is_row_major, DataType *RESTRICT c)
Computes the operation $\boldsymbol{c} = \boldsymbol{c} + \alpha \mathbf{A} \boldsymbol{b}$ where $\mathbf{A}$ is a dense matrix.
#define RESTRICT
int LongIndexType
Definition types.h:60
int FlagType
Definition types.h:68
int IndexType
Definition types.h:65