imate
C++/CUDA Reference
cu_matrix_operations.h
Go to the documentation of this file.
1 /*
2  * SPDX-FileCopyrightText: Copyright 2021, Siavash Ameli <sameli@berkeley.edu>
3  * SPDX-License-Identifier: BSD-3-Clause
4  * SPDX-FileType: SOURCE
5  *
6  * This program is free software: you can redistribute it and/or modify it
7  * under the terms of the license found in the LICENSE.txt file in the root
8  * directory of this source tree.
9  */
10 
11 
12 #ifndef _CU_BASIC_ALGEBRA_CU_MATRIX_OPERATIONS_H_
13 #define _CU_BASIC_ALGEBRA_CU_MATRIX_OPERATIONS_H_
14 
15 // =======
16 // Headers
17 // =======
18 
19 #include <cublas_v2.h> // cublasHandle_t
20 #include <cusparse.h> // cusparseHandle_t
21 #include "../_definitions/types.h" // IndexType, LongIndexType, FlagType
22 
23 
24 // =================
25 // Matrix Operations
26 // =================
27 
56 
57 template <typename DataType>
59 {
60  public:
61 
62  // dense matvec
63  static void dense_matvec(
64  cublasHandle_t cublas_handle,
65  const DataType* A,
66  const DataType* b,
67  const LongIndexType num_rows,
68  const LongIndexType num_columns,
69  const FlagType A_is_row_major,
70  DataType* c);
71 
72  // dense matvec plus
73  static void dense_matvec_plus(
74  cublasHandle_t cublas_handle,
75  const DataType* A,
76  const DataType* b,
77  const DataType alpha,
78  const LongIndexType num_rows,
79  const LongIndexType num_columns,
80  const FlagType A_is_row_major,
81  DataType* c);
82 
83  // dense transposed matvec
84  static void dense_transposed_matvec(
85  cublasHandle_t cublas_handle,
86  const DataType* A,
87  const DataType* b,
88  const LongIndexType num_rows,
89  const LongIndexType num_columns,
90  const FlagType A_is_row_major,
91  DataType* c);
92 
93  // dense transposed matvec plus
94  static void dense_transposed_matvec_plus(
95  cublasHandle_t cublas_handle,
96  const DataType* A,
97  const DataType* b,
98  const DataType alpha,
99  const LongIndexType num_rows,
100  const LongIndexType num_columns,
101  const FlagType A_is_row_major,
102  DataType* c);
103 
104  // CSR matvec
105  static void csr_matvec(
106  cusparseHandle_t cusparse_handle,
107  const DataType* A_data,
108  const LongIndexType* A_column_indices,
109  const LongIndexType* A_index_pointer,
110  const DataType* b,
111  const LongIndexType num_rows,
112  DataType* c);
113 
114  // CSR matvec plus
115  static void csr_matvec_plus(
116  cusparseHandle_t cusparse_handle,
117  const DataType* A_data,
118  const LongIndexType* A_column_indices,
119  const LongIndexType* A_index_pointer,
120  const DataType* b,
121  const DataType alpha,
122  const LongIndexType num_rows,
123  DataType* c);
124 
125  // CSR transposed matvec
126  static void csr_transposed_matvec(
127  cusparseHandle_t cusparse_handle,
128  const DataType* A_data,
129  const LongIndexType* A_column_indices,
130  const LongIndexType* A_index_pointer,
131  const DataType* b,
132  const LongIndexType num_rows,
133  const LongIndexType num_columns,
134  DataType* c);
135 
136  // CSR transposed matvec plus
137  static void csr_transposed_matvec_plus(
138  cusparseHandle_t cusparse_handle,
139  const DataType* A_data,
140  const LongIndexType* A_column_indices,
141  const LongIndexType* A_index_pointer,
142  const DataType* b,
143  const DataType alpha,
144  const LongIndexType num_rows,
145  const LongIndexType num_columns,
146  DataType* c);
147 
148  // CSC matvec
149  static void csc_matvec(
150  cusparseHandle_t cusparse_handle,
151  const DataType* A_data,
152  const LongIndexType* A_row_indices,
153  const LongIndexType* A_index_pointer,
154  const DataType* b,
155  const LongIndexType num_rows,
156  const LongIndexType num_columns,
157  DataType* c);
158 
159  // CSC matvec plus
160  static void csc_matvec_plus(
161  cusparseHandle_t cusparse_handle,
162  const DataType* A_data,
163  const LongIndexType* A_row_indices,
164  const LongIndexType* A_index_pointer,
165  const DataType* b,
166  const DataType alpha,
167  const LongIndexType num_rows,
168  const LongIndexType num_columns,
169  DataType* c);
170 
171  // CSC transposed matvec
172  static void csc_transposed_matvec(
173  cusparseHandle_t cusparse_handle,
174  const DataType* A_data,
175  const LongIndexType* A_row_indices,
176  const LongIndexType* A_index_pointer,
177  const DataType* b,
178  const LongIndexType num_columns,
179  DataType* c);
180 
181  // CSC transposed matvec plus
182  static void csc_transposed_matvec_plus(
183  cusparseHandle_t cusparse_handle,
184  const DataType* A_data,
185  const LongIndexType* A_row_indices,
186  const LongIndexType* A_index_pointer,
187  const DataType* b,
188  const DataType alpha,
189  const LongIndexType num_columns,
190  DataType* c);
191 
192  // Create Band Matrix
193  static void create_band_matrix(
194  cusparseHandle_t cublas_handle,
195  const DataType* diagonals,
196  const DataType* supdiagonals,
197  const IndexType non_zero_size,
198  const FlagType tridiagonal,
199  DataType** matrix);
200 };
201 
202 #endif // _CU_BASIC_ALGEBRA_CU_MATRIX_OPERATIONS_H_
A static class for matrix-vector operations, which are similar to the level-2 operations of the BLAS library.
static void csc_transposed_matvec(cusparseHandle_t cusparse_handle, const DataType *A_data, const LongIndexType *A_row_indices, const LongIndexType *A_index_pointer, const DataType *b, const LongIndexType num_columns, DataType *c)
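Computes $\boldsymbol{c} = \mathbf{A}^{\intercal} \boldsymbol{b}$ where $\mathbf{A}$ is a compressed sparse column (CSC) matrix and $\boldsymbol{b}$ is a dense vector. The output $\boldsymbol{c}$ is a dense vector.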
static void csr_matvec(cusparseHandle_t cusparse_handle, const DataType *A_data, const LongIndexType *A_column_indices, const LongIndexType *A_index_pointer, const DataType *b, const LongIndexType num_rows, DataType *c)
Computes $\boldsymbol{c} = \mathbf{A} \boldsymbol{b}$ where $\mathbf{A}$ is a compressed sparse row (CSR) matrix and $\boldsymbol{b}$ is a dense vector. The output $\boldsymbol{c}$ is a dense vector.
static void dense_matvec(cublasHandle_t cublas_handle, const DataType *A, const DataType *b, const LongIndexType num_rows, const LongIndexType num_columns, const FlagType A_is_row_major, DataType *c)
Computes the matrix-vector multiplication $\boldsymbol{c} = \mathbf{A} \boldsymbol{b}$ where $\mathbf{A}$ is a dense matrix.
static void csc_matvec_plus(cusparseHandle_t cusparse_handle, const DataType *A_data, const LongIndexType *A_row_indices, const LongIndexType *A_index_pointer, const DataType *b, const DataType alpha, const LongIndexType num_rows, const LongIndexType num_columns, DataType *c)
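Computes $\boldsymbol{c} = \boldsymbol{c} + \alpha \mathbf{A} \boldsymbol{b}$ where $\mathbf{A}$ is a compressed sparse column (CSC) matrix and $\boldsymbol{b}$ is a dense vector. The output $\boldsymbol{c}$ is a dense vector.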
static void dense_transposed_matvec_plus(cublasHandle_t cublas_handle, const DataType *A, const DataType *b, const DataType alpha, const LongIndexType num_rows, const LongIndexType num_columns, const FlagType A_is_row_major, DataType *c)
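Computes $\boldsymbol{c} = \boldsymbol{c} + \alpha \mathbf{A}^{\intercal} \boldsymbol{b}$ where $\mathbf{A}$ is dense, and $\mathbf{A}^{\intercal}$ is the transpose of the matrix $\mathbf{A}$.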
static void csc_transposed_matvec_plus(cusparseHandle_t cusparse_handle, const DataType *A_data, const LongIndexType *A_row_indices, const LongIndexType *A_index_pointer, const DataType *b, const DataType alpha, const LongIndexType num_columns, DataType *c)
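Computes $\boldsymbol{c} = \boldsymbol{c} + \alpha \mathbf{A}^{\intercal} \boldsymbol{b}$ where $\mathbf{A}$ is a compressed sparse column (CSC) matrix and $\boldsymbol{b}$ is a dense vector. The output $\boldsymbol{c}$ is a dense vector.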
static void dense_transposed_matvec(cublasHandle_t cublas_handle, const DataType *A, const DataType *b, const LongIndexType num_rows, const LongIndexType num_columns, const FlagType A_is_row_major, DataType *c)
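Computes the matrix-vector multiplication $\boldsymbol{c} = \mathbf{A}^{\intercal} \boldsymbol{b}$ where $\mathbf{A}$ is dense, and $\mathbf{A}^{\intercal}$ is the transpose of the matrix $\mathbf{A}$.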
static void csr_transposed_matvec_plus(cusparseHandle_t cusparse_handle, const DataType *A_data, const LongIndexType *A_column_indices, const LongIndexType *A_index_pointer, const DataType *b, const DataType alpha, const LongIndexType num_rows, const LongIndexType num_columns, DataType *c)
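Computes $\boldsymbol{c} = \boldsymbol{c} + \alpha \mathbf{A}^{\intercal} \boldsymbol{b}$ where $\mathbf{A}$ is a compressed sparse row (CSR) matrix and $\boldsymbol{b}$ is a dense vector. The output $\boldsymbol{c}$ is a dense vector.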
static void create_band_matrix(cusparseHandle_t cublas_handle, const DataType *diagonals, const DataType *supdiagonals, const IndexType non_zero_size, const FlagType tridiagonal, DataType **matrix)
Creates a bi-diagonal or symmetric tri-diagonal matrix from the diagonal array (diagonals) and the off-diagonal array (supdiagonals).
static void csr_matvec_plus(cusparseHandle_t cusparse_handle, const DataType *A_data, const LongIndexType *A_column_indices, const LongIndexType *A_index_pointer, const DataType *b, const DataType alpha, const LongIndexType num_rows, DataType *c)
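Computes $\boldsymbol{c} = \boldsymbol{c} + \alpha \mathbf{A} \boldsymbol{b}$ where $\mathbf{A}$ is a compressed sparse row (CSR) matrix and $\boldsymbol{b}$ is a dense vector. The output $\boldsymbol{c}$ is a dense vector.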
static void csr_transposed_matvec(cusparseHandle_t cusparse_handle, const DataType *A_data, const LongIndexType *A_column_indices, const LongIndexType *A_index_pointer, const DataType *b, const LongIndexType num_rows, const LongIndexType num_columns, DataType *c)
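Computes $\boldsymbol{c} = \mathbf{A}^{\intercal} \boldsymbol{b}$ where $\mathbf{A}$ is a compressed sparse row (CSR) matrix and $\boldsymbol{b}$ is a dense vector. The output $\boldsymbol{c}$ is a dense vector.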
static void dense_matvec_plus(cublasHandle_t cublas_handle, const DataType *A, const DataType *b, const DataType alpha, const LongIndexType num_rows, const LongIndexType num_columns, const FlagType A_is_row_major, DataType *c)
Computes the operation $\boldsymbol{c} = \boldsymbol{c} + \alpha \mathbf{A} \boldsymbol{b}$ where $\mathbf{A}$ is a dense matrix.
static void csc_matvec(cusparseHandle_t cusparse_handle, const DataType *A_data, const LongIndexType *A_row_indices, const LongIndexType *A_index_pointer, const DataType *b, const LongIndexType num_rows, const LongIndexType num_columns, DataType *c)
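Computes $\boldsymbol{c} = \mathbf{A} \boldsymbol{b}$ where $\mathbf{A}$ is a compressed sparse column (CSC) matrix and $\boldsymbol{b}$ is a dense vector. The output $\boldsymbol{c}$ is a dense vector.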
int LongIndexType
Definition: types.h:60
int FlagType
Definition: types.h:68
int IndexType
Definition: types.h:65