imate
C++/CUDA Reference
c_matrix_operations.h
Go to the documentation of this file.
1 /*
2  * SPDX-FileCopyrightText: Copyright 2021, Siavash Ameli <sameli@berkeley.edu>
3  * SPDX-License-Identifier: BSD-3-Clause
4  * SPDX-FileType: SOURCE
5  *
6  * This program is free software: you can redistribute it and/or modify it
7  * under the terms of the license found in the LICENSE.txt file in the root
8  * directory of this source tree.
9  */
10 
11 
12 #ifndef _C_BASIC_ALGEBRA_C_MATRIX_OPERATIONS_H_
13 #define _C_BASIC_ALGEBRA_C_MATRIX_OPERATIONS_H_
14 
15 // =======
16 // Headers
17 // =======
18 
19 #include "../_definitions/types.h" // IndexType, LongIndexType, FlagType
20 
21 
22 // ===================
23 // c Matrix Operations
24 // ===================
25 
54 
55 template <typename DataType>
57 {
58  public:
59 
60  // dense matvec
61  static void dense_matvec(
62  const DataType* A,
63  const DataType* b,
64  const LongIndexType num_rows,
65  const LongIndexType num_columns,
66  const FlagType A_is_row_major,
67  DataType* c);
68 
69  // dense matvec plus
70  static void dense_matvec_plus(
71  const DataType* A,
72  const DataType* b,
73  const DataType alpha,
74  const LongIndexType num_rows,
75  const LongIndexType num_columns,
76  const FlagType A_is_row_major,
77  DataType* c);
78 
79  // dense transposed matvec
80  static void dense_transposed_matvec(
81  const DataType* A,
82  const DataType* b,
83  const LongIndexType num_rows,
84  const LongIndexType num_columns,
85  const FlagType A_is_row_major,
86  DataType* c);
87 
88  // dense transposed matvec plus
89  static void dense_transposed_matvec_plus(
90  const DataType* A,
91  const DataType* b,
92  const DataType alpha,
93  const LongIndexType num_rows,
94  const LongIndexType num_columns,
95  const FlagType A_is_row_major,
96  DataType* c);
97 
98  // CSR matvec
99  static void csr_matvec(
100  const DataType* A_data,
101  const LongIndexType* A_column_indices,
102  const LongIndexType* A_index_pointer,
103  const DataType* b,
104  const LongIndexType num_rows,
105  DataType* c);
106 
107  // CSR matvec plus
108  static void csr_matvec_plus(
109  const DataType* A_data,
110  const LongIndexType* A_column_indices,
111  const LongIndexType* A_index_pointer,
112  const DataType* b,
113  const DataType alpha,
114  const LongIndexType num_rows,
115  DataType* c);
116 
117  // CSR transposed matvec
118  static void csr_transposed_matvec(
119  const DataType* A_data,
120  const LongIndexType* A_column_indices,
121  const LongIndexType* A_index_pointer,
122  const DataType* b,
123  const LongIndexType num_rows,
124  const LongIndexType num_columns,
125  DataType* c);
126 
127  // CSR transposed matvec plus
128  static void csr_transposed_matvec_plus(
129  const DataType* A_data,
130  const LongIndexType* A_column_indices,
131  const LongIndexType* A_index_pointer,
132  const DataType* b,
133  const DataType alpha,
134  const LongIndexType num_rows,
135  DataType* c);
136 
137  // CSC matvec
138  static void csc_matvec(
139  const DataType* A_data,
140  const LongIndexType* A_row_indices,
141  const LongIndexType* A_index_pointer,
142  const DataType* b,
143  const LongIndexType num_rows,
144  const LongIndexType num_columns,
145  DataType* c);
146 
147  // CSC matvec plus
148  static void csc_matvec_plus(
149  const DataType* A_data,
150  const LongIndexType* A_row_indices,
151  const LongIndexType* A_index_pointer,
152  const DataType* b,
153  const DataType alpha,
154  const LongIndexType num_columns,
155  DataType* c);
156 
157  // CSC transposed matvec
158  static void csc_transposed_matvec(
159  const DataType* A_data,
160  const LongIndexType* A_row_indices,
161  const LongIndexType* A_index_pointer,
162  const DataType* b,
163  const LongIndexType num_columns,
164  DataType* c);
165 
166  // CSC transposed matvec plus
167  static void csc_transposed_matvec_plus(
168  const DataType* A_data,
169  const LongIndexType* A_row_indices,
170  const LongIndexType* A_index_pointer,
171  const DataType* b,
172  const DataType alpha,
173  const LongIndexType num_columns,
174  DataType* c);
175 
176  // Create Band Matrix
177  static void create_band_matrix(
178  const DataType* diagonals,
179  const DataType* supdiagonals,
180  const IndexType non_zero_size,
181  const FlagType tridiagonal,
182  DataType** matrix);
183 };
184 
185 #endif // _C_BASIC_ALGEBRA_C_MATRIX_OPERATIONS_H_
A static class for matrix-vector operations, which are similar to the level-2 operations of the BLAS library.
static void csr_transposed_matvec(const DataType *A_data, const LongIndexType *A_column_indices, const LongIndexType *A_index_pointer, const DataType *b, const LongIndexType num_rows, const LongIndexType num_columns, DataType *c)
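Computes $\boldsymbol{c} = \mathbf{A}^{\intercal} \boldsymbol{b}$ where $\mathbf{A}$ is a compressed sparse row (CSR) matrix and $\boldsymbol{b}$ is a dense vector. The output $\boldsymbol{c}$ is a dense vector.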
static void csr_matvec(const DataType *A_data, const LongIndexType *A_column_indices, const LongIndexType *A_index_pointer, const DataType *b, const LongIndexType num_rows, DataType *c)
Computes $\boldsymbol{c} = \mathbf{A} \boldsymbol{b}$ where $\mathbf{A}$ is a compressed sparse row (CSR) matrix and $\boldsymbol{b}$ is a dense vector. The output $\boldsymbol{c}$ is a dense vector.
static void dense_transposed_matvec_plus(const DataType *A, const DataType *b, const DataType alpha, const LongIndexType num_rows, const LongIndexType num_columns, const FlagType A_is_row_major, DataType *c)
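Computes $\boldsymbol{c} = \boldsymbol{c} + \alpha \mathbf{A}^{\intercal} \boldsymbol{b}$ where $\mathbf{A}$ is dense, and $\mathbf{A}^{\intercal}$ is the transpose of the matrix $\mathbf{A}$.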
static void csr_matvec_plus(const DataType *A_data, const LongIndexType *A_column_indices, const LongIndexType *A_index_pointer, const DataType *b, const DataType alpha, const LongIndexType num_rows, DataType *c)
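Computes $\boldsymbol{c} = \boldsymbol{c} + \alpha \mathbf{A} \boldsymbol{b}$ where $\mathbf{A}$ is a compressed sparse row (CSR) matrix and $\boldsymbol{b}$ is a dense vector. The output $\boldsymbol{c}$ is a dense vector.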
static void csc_transposed_matvec_plus(const DataType *A_data, const LongIndexType *A_row_indices, const LongIndexType *A_index_pointer, const DataType *b, const DataType alpha, const LongIndexType num_columns, DataType *c)
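Computes $\boldsymbol{c} = \boldsymbol{c} + \alpha \mathbf{A}^{\intercal} \boldsymbol{b}$ where $\mathbf{A}$ is a compressed sparse column (CSC) matrix and $\boldsymbol{b}$ is a dense vector. The output $\boldsymbol{c}$ is a dense vector.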
static void dense_transposed_matvec(const DataType *A, const DataType *b, const LongIndexType num_rows, const LongIndexType num_columns, const FlagType A_is_row_major, DataType *c)
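Computes the matrix-vector multiplication $\boldsymbol{c} = \mathbf{A}^{\intercal} \boldsymbol{b}$ where $\mathbf{A}$ is dense, and $\mathbf{A}^{\intercal}$ is the transpose of the matrix $\mathbf{A}$.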
static void csc_transposed_matvec(const DataType *A_data, const LongIndexType *A_row_indices, const LongIndexType *A_index_pointer, const DataType *b, const LongIndexType num_columns, DataType *c)
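Computes $\boldsymbol{c} = \mathbf{A}^{\intercal} \boldsymbol{b}$ where $\mathbf{A}$ is a compressed sparse column (CSC) matrix and $\boldsymbol{b}$ is a dense vector. The output $\boldsymbol{c}$ is a dense vector.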
static void dense_matvec_plus(const DataType *A, const DataType *b, const DataType alpha, const LongIndexType num_rows, const LongIndexType num_columns, const FlagType A_is_row_major, DataType *c)
Computes the operation $\boldsymbol{c} = \boldsymbol{c} + \alpha \mathbf{A} \boldsymbol{b}$ where $\mathbf{A}$ is a dense matrix.
static void csc_matvec(const DataType *A_data, const LongIndexType *A_row_indices, const LongIndexType *A_index_pointer, const DataType *b, const LongIndexType num_rows, const LongIndexType num_columns, DataType *c)
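Computes $\boldsymbol{c} = \mathbf{A} \boldsymbol{b}$ where $\mathbf{A}$ is a compressed sparse column (CSC) matrix and $\boldsymbol{b}$ is a dense vector. The output $\boldsymbol{c}$ is a dense vector.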
static void csr_transposed_matvec_plus(const DataType *A_data, const LongIndexType *A_column_indices, const LongIndexType *A_index_pointer, const DataType *b, const DataType alpha, const LongIndexType num_rows, DataType *c)
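Computes $\boldsymbol{c} = \boldsymbol{c} + \alpha \mathbf{A}^{\intercal} \boldsymbol{b}$ where $\mathbf{A}$ is a compressed sparse row (CSR) matrix and $\boldsymbol{b}$ is a dense vector. The output $\boldsymbol{c}$ is a dense vector.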
static void dense_matvec(const DataType *A, const DataType *b, const LongIndexType num_rows, const LongIndexType num_columns, const FlagType A_is_row_major, DataType *c)
Computes the matrix-vector multiplication $\boldsymbol{c} = \mathbf{A} \boldsymbol{b}$ where $\mathbf{A}$ is a dense matrix.
static void csc_matvec_plus(const DataType *A_data, const LongIndexType *A_row_indices, const LongIndexType *A_index_pointer, const DataType *b, const DataType alpha, const LongIndexType num_columns, DataType *c)
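Computes $\boldsymbol{c} = \boldsymbol{c} + \alpha \mathbf{A} \boldsymbol{b}$ where $\mathbf{A}$ is a compressed sparse column (CSC) matrix and $\boldsymbol{b}$ is a dense vector. The output $\boldsymbol{c}$ is a dense vector.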
static void create_band_matrix(const DataType *diagonals, const DataType *supdiagonals, const IndexType non_zero_size, const FlagType tridiagonal, DataType **matrix)
Creates a bi-diagonal or symmetric tri-diagonal matrix from the diagonal array (diagonals) and the off-diagonal array (supdiagonals).
int LongIndexType
Definition: types.h:60
int FlagType
Definition: types.h:68
int IndexType
Definition: types.h:65