imate
C++/CUDA Reference
Loading...
Searching...
No Matches
c_matrix_operations.h
Go to the documentation of this file.
/*
 * SPDX-FileCopyrightText: Copyright 2021, Siavash Ameli <sameli@berkeley.edu>
 * SPDX-License-Identifier: BSD-3-Clause
 * SPDX-FileType: SOURCE
 *
 * This program is free software: you can redistribute it and/or modify it
 * under the terms of the license found in the LICENSE.txt file in the root
 * directory of this source tree.
 */
10
11
12#ifndef _C_BASIC_ALGEBRA_C_MATRIX_OPERATIONS_H_
13#define _C_BASIC_ALGEBRA_C_MATRIX_OPERATIONS_H_
14
// =======
// Headers
// =======
18
19#include "../_definitions/types.h" // IndexType, LongIndexType, FlagType
20
// Restrict qualifier
// RESTRICT expands to the compiler-specific spelling of the C99 `restrict`
// qualifier. It tells the optimizer that, inside the function, an object
// modified through the qualified pointer is not also reached through some
// unrelated pointer. MSVC and the Intel compiler spell it `__restrict`;
// GCC, Clang and Clang-CUDA spell it `__restrict__`. On any other compiler
// the macro expands to nothing, so the hint is simply dropped.
#if defined(_MSC_VER) || defined(__INTEL_COMPILER)
    #define RESTRICT __restrict
#elif defined(__CUDA__) || defined(__GNUC__) || defined(__clang__)
    #define RESTRICT __restrict__
#else
    #define RESTRICT
#endif
31
32
// ===================
// c Matrix Operations
// ===================
36
65
66template <typename DataType>
68{
69 public:
70
71 // dense matvec
72 static void dense_matvec(
73 const DataType* RESTRICT A,
74 const DataType* RESTRICT b,
75 const LongIndexType num_rows,
76 const LongIndexType num_columns,
77 const FlagType A_is_row_major,
78 const FlagType A_is_symmetric,
79 DataType* RESTRICT c);
80
81 // dense matvec plus
82 static void dense_matvec_plus(
83 const DataType* RESTRICT A,
84 const DataType* RESTRICT b,
85 const DataType alpha,
86 const LongIndexType num_rows,
87 const LongIndexType num_columns,
88 const FlagType A_is_row_major,
89 const FlagType A_is_symmetric,
90 DataType* RESTRICT c);
91
92 // dense transposed matvec
93 static void dense_transposed_matvec(
94 const DataType* RESTRICT A,
95 const DataType* RESTRICT b,
96 const LongIndexType num_rows,
97 const LongIndexType num_columns,
98 const FlagType A_is_row_major,
99 const FlagType A_is_symmetric,
100 DataType* RESTRICT c);
101
102 // dense transposed matvec plus
104 const DataType* RESTRICT A,
105 const DataType* RESTRICT b,
106 const DataType alpha,
107 const LongIndexType num_rows,
108 const LongIndexType num_columns,
109 const FlagType A_is_row_major,
110 const FlagType A_is_symmetric,
111 DataType* RESTRICT c);
112
113 // CSR matvec
114 static void csr_matvec(
115 const DataType* RESTRICT A_data,
116 const LongIndexType* RESTRICT A_column_indices,
117 const LongIndexType* RESTRICT A_index_pointer,
118 const DataType* RESTRICT b,
119 const LongIndexType num_rows,
120 DataType* RESTRICT c);
121
122 // CSR matvec plus
123 static void csr_matvec_plus(
124 const DataType* RESTRICT A_data,
125 const LongIndexType* RESTRICT A_column_indices,
126 const LongIndexType* RESTRICT A_index_pointer,
127 const DataType* RESTRICT b,
128 const DataType alpha,
129 const LongIndexType num_rows,
130 DataType* RESTRICT c);
131
132 // CSR transposed matvec
133 static void csr_transposed_matvec(
134 const DataType* RESTRICT A_data,
135 const LongIndexType* RESTRICT A_column_indices,
136 const LongIndexType* RESTRICT A_index_pointer,
137 const FlagType A_is_symmetric,
138 const DataType* RESTRICT b,
139 const LongIndexType num_rows,
140 const LongIndexType num_columns,
141 DataType* RESTRICT c);
142
143 // CSR transposed matvec plus
144 static void csr_transposed_matvec_plus(
145 const DataType* RESTRICT A_data,
146 const LongIndexType* RESTRICT A_column_indices,
147 const LongIndexType* RESTRICT A_index_pointer,
148 const FlagType A_is_symmetric,
149 const DataType* RESTRICT b,
150 const DataType alpha,
151 const LongIndexType num_rows,
152 DataType* RESTRICT c);
153
154 // CSC matvec
155 static void csc_matvec(
156 const DataType* RESTRICT A_data,
157 const LongIndexType* RESTRICT A_row_indices,
158 const LongIndexType* RESTRICT A_index_pointer,
159 const FlagType A_is_symmetric,
160 const DataType* RESTRICT b,
161 const LongIndexType num_rows,
162 const LongIndexType num_columns,
163 DataType* RESTRICT c);
164
165 // CSC matvec plus
166 static void csc_matvec_plus(
167 const DataType* RESTRICT A_data,
168 const LongIndexType* RESTRICT A_row_indices,
169 const LongIndexType* RESTRICT A_index_pointer,
170 const FlagType A_is_symmetric,
171 const DataType* RESTRICT b,
172 const DataType alpha,
173 const LongIndexType num_columns,
174 DataType* RESTRICT c);
175
176 // CSC transposed matvec
177 static void csc_transposed_matvec(
178 const DataType* RESTRICT A_data,
179 const LongIndexType* RESTRICT A_row_indices,
180 const LongIndexType* RESTRICT A_index_pointer,
181 const DataType* RESTRICT b,
182 const LongIndexType num_columns,
183 DataType* RESTRICT c);
184
185 // CSC transposed matvec plus
186 static void csc_transposed_matvec_plus(
187 const DataType* RESTRICT A_data,
188 const LongIndexType* RESTRICT A_row_indices,
189 const LongIndexType* RESTRICT A_index_pointer,
190 const DataType* RESTRICT b,
191 const DataType alpha,
192 const LongIndexType num_columns,
193 DataType* RESTRICT c);
194
195 // Create Band Matrix
196 static void create_band_matrix(
197 DataType* RESTRICT A,
198 const LongIndexType num_rows,
199 const LongIndexType num_columns,
200 const FlagType A_is_row_major,
201 const DataType* RESTRICT diagonals,
202 const DataType* RESTRICT supdiagonals,
203 const IndexType non_zero_size,
204 const FlagType tridiagonal);
205};
206
207#endif // _C_BASIC_ALGEBRA_C_MATRIX_OPERATIONS_H_
#define RESTRICT
A static class for matrix-vector operations, which are similar to the level-2 operations of the BLAS ...
static void csc_transposed_matvec(const DataType *RESTRICT A_data, const LongIndexType *RESTRICT A_row_indices, const LongIndexType *RESTRICT A_index_pointer, const DataType *RESTRICT b, const LongIndexType num_columns, DataType *RESTRICT c)
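Computes $\boldsymbol{c} = \mathbf{A}^{\intercal} \boldsymbol{b}$ where $\mathbf{A}$ is a compressed sparse column (CSC) matrix and $\boldsymbol{b}$ is a dense vector. The output $\boldsymbol{c}$ is a dense vector.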
static void dense_matvec(const DataType *RESTRICT A, const DataType *RESTRICT b, const LongIndexType num_rows, const LongIndexType num_columns, const FlagType A_is_row_major, const FlagType A_is_symmetric, DataType *RESTRICT c)
Computes the matrix-vector multiplication $\boldsymbol{c} = \mathbf{A} \boldsymbol{b}$ where $\mathbf{A}$ is a dense matrix.
static void csr_transposed_matvec(const DataType *RESTRICT A_data, const LongIndexType *RESTRICT A_column_indices, const LongIndexType *RESTRICT A_index_pointer, const FlagType A_is_symmetric, const DataType *RESTRICT b, const LongIndexType num_rows, const LongIndexType num_columns, DataType *RESTRICT c)
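Computes $\boldsymbol{c} = \mathbf{A}^{\intercal} \boldsymbol{b}$ where $\mathbf{A}$ is a compressed sparse row (CSR) matrix and $\boldsymbol{b}$ is a dense vector. The output $\boldsymbol{c}$ is a dense vector.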
static void dense_matvec_plus(const DataType *RESTRICT A, const DataType *RESTRICT b, const DataType alpha, const LongIndexType num_rows, const LongIndexType num_columns, const FlagType A_is_row_major, const FlagType A_is_symmetric, DataType *RESTRICT c)
Computes the operation $\boldsymbol{c} = \boldsymbol{c} + \alpha \mathbf{A} \boldsymbol{b}$ where $\mathbf{A}$ is a dense matrix.
static void create_band_matrix(DataType *RESTRICT A, const LongIndexType num_rows, const LongIndexType num_columns, const FlagType A_is_row_major, const DataType *RESTRICT diagonals, const DataType *RESTRICT supdiagonals, const IndexType non_zero_size, const FlagType tridiagonal)
Creates bi-diagonal or symmetric tri-diagonal matrix from the diagonal array (diagonals) and off-diag...
static void csc_transposed_matvec_plus(const DataType *RESTRICT A_data, const LongIndexType *RESTRICT A_row_indices, const LongIndexType *RESTRICT A_index_pointer, const DataType *RESTRICT b, const DataType alpha, const LongIndexType num_columns, DataType *RESTRICT c)
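Computes $\boldsymbol{c} = \boldsymbol{c} + \alpha \mathbf{A}^{\intercal} \boldsymbol{b}$ where $\mathbf{A}$ is a compressed sparse column (CSC) matrix and $\boldsymbol{b}$ is a dense vector. The output $\boldsymbol{c}$ is a dense vector.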
static void csc_matvec_plus(const DataType *RESTRICT A_data, const LongIndexType *RESTRICT A_row_indices, const LongIndexType *RESTRICT A_index_pointer, const FlagType A_is_symmetric, const DataType *RESTRICT b, const DataType alpha, const LongIndexType num_columns, DataType *RESTRICT c)
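Computes $\boldsymbol{c} = \boldsymbol{c} + \alpha \mathbf{A} \boldsymbol{b}$ where $\mathbf{A}$ is a compressed sparse column (CSC) matrix and $\boldsymbol{b}$ is a dense vector. The output $\boldsymbol{c}$ is a dense vector.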
static void dense_transposed_matvec(const DataType *RESTRICT A, const DataType *RESTRICT b, const LongIndexType num_rows, const LongIndexType num_columns, const FlagType A_is_row_major, const FlagType A_is_symmetric, DataType *RESTRICT c)
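Computes the matrix-vector multiplication $\boldsymbol{c} = \mathbf{A}^{\intercal} \boldsymbol{b}$ where $\mathbf{A}$ is dense, and $\mathbf{A}^{\intercal}$ is the transpose of the matrix $\mathbf{A}$.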
static void csr_transposed_matvec_plus(const DataType *RESTRICT A_data, const LongIndexType *RESTRICT A_column_indices, const LongIndexType *RESTRICT A_index_pointer, const FlagType A_is_symmetric, const DataType *RESTRICT b, const DataType alpha, const LongIndexType num_rows, DataType *RESTRICT c)
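Computes $\boldsymbol{c} = \boldsymbol{c} + \alpha \mathbf{A}^{\intercal} \boldsymbol{b}$ where $\mathbf{A}$ is a compressed sparse row (CSR) matrix and $\boldsymbol{b}$ is a dense vector. The output $\boldsymbol{c}$ is a dense vector.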
static void csr_matvec(const DataType *RESTRICT A_data, const LongIndexType *RESTRICT A_column_indices, const LongIndexType *RESTRICT A_index_pointer, const DataType *RESTRICT b, const LongIndexType num_rows, DataType *RESTRICT c)
Computes $\boldsymbol{c} = \mathbf{A} \boldsymbol{b}$ where $\mathbf{A}$ is a compressed sparse row (CSR) matrix and $\boldsymbol{b}$ is a dense vector. The output $\boldsymbol{c}$ is a dense vector.
static void dense_transposed_matvec_plus(const DataType *RESTRICT A, const DataType *RESTRICT b, const DataType alpha, const LongIndexType num_rows, const LongIndexType num_columns, const FlagType A_is_row_major, const FlagType A_is_symmetric, DataType *RESTRICT c)
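Computes $\boldsymbol{c} = \boldsymbol{c} + \alpha \mathbf{A}^{\intercal} \boldsymbol{b}$ where $\mathbf{A}$ is dense, and $\mathbf{A}^{\intercal}$ is the transpose of the matrix $\mathbf{A}$.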
static void csc_matvec(const DataType *RESTRICT A_data, const LongIndexType *RESTRICT A_row_indices, const LongIndexType *RESTRICT A_index_pointer, const FlagType A_is_symmetric, const DataType *RESTRICT b, const LongIndexType num_rows, const LongIndexType num_columns, DataType *RESTRICT c)
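Computes $\boldsymbol{c} = \mathbf{A} \boldsymbol{b}$ where $\mathbf{A}$ is a compressed sparse column (CSC) matrix and $\boldsymbol{b}$ is a dense vector. The output $\boldsymbol{c}$ is a dense vector.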
static void csr_matvec_plus(const DataType *RESTRICT A_data, const LongIndexType *RESTRICT A_column_indices, const LongIndexType *RESTRICT A_index_pointer, const DataType *RESTRICT b, const DataType alpha, const LongIndexType num_rows, DataType *RESTRICT c)
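Computes $\boldsymbol{c} = \boldsymbol{c} + \alpha \mathbf{A} \boldsymbol{b}$ where $\mathbf{A}$ is a compressed sparse row (CSR) matrix and $\boldsymbol{b}$ is a dense vector. The output $\boldsymbol{c}$ is a dense vector.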
int LongIndexType
Definition types.h:60
int FlagType
Definition types.h:68
int IndexType
Definition types.h:65