imate
C++/CUDA Reference
Loading...
Searching...
No Matches
cblas_api.cpp
Go to the documentation of this file.
1/*
2 * SPDX-FileCopyrightText: Copyright 2021, Siavash Ameli <sameli@berkeley.edu>
3 * SPDX-License-Identifier: BSD-3-Clause
4 * SPDX-FileType: SOURCE
5 *
6 * This program is free software: you can redistribute it and/or modify it
7 * under the terms of the license found in the LICENSE.txt file in the root
8 * directory of this source tree.
9 */
10
11
12// =======
13// Headers
14// =======
15
16#include "./cblas_api.h"
17
18# if defined(USE_ANY_CBLAS) && (USE_ANY_CBLAS == 1)
19
20#include <cstdlib> // abort
21#include <iostream> // std::cerr
22
23
24// ===============
25// cblas interface
26// ===============
27
33
34namespace cblas_api
35{
36
37 // =====
38 // xgemv (float)
39 // =====
40
43
44 template<>
45 void xgemv<float>(
46 const CBLAS_LAYOUT layout,
47 const CBLAS_TRANSPOSE TransA,
48 const int M,
49 const int N,
50 const float alpha,
51 const float* RESTRICT A,
52 const int lda,
53 const float* RESTRICT X,
54 const int incX,
55 const float beta,
56 float* RESTRICT Y,
57 const int incY)
58 {
59 cblas_sgemv(layout, TransA, M, N, alpha, A, lda, X, incX, beta, Y,
60 incY);
61 }
62
63
64 // =====
65 // xgemv (double)
66 // =====
67
70
71 template<>
72 void xgemv<double>(
73 const CBLAS_LAYOUT layout,
74 const CBLAS_TRANSPOSE TransA,
75 const int M,
76 const int N,
77 const double alpha,
78 const double* RESTRICT A,
79 const int lda,
80 const double* RESTRICT X,
81 const int incX,
82 const double beta,
83 double* RESTRICT Y,
84 const int incY)
85 {
86 cblas_dgemv(layout, TransA, M, N, alpha, A, lda, X, incX, beta, Y,
87 incY);
88 }
89
90
91 // =====
92 // xgemv (long double)
93 // =====
94
97
98 template<>
99 void xgemv<long double>(
100 const CBLAS_LAYOUT layout,
101 const CBLAS_TRANSPOSE TransA,
102 const int M,
103 const int N,
104 const long double alpha,
105 const long double* RESTRICT A,
106 const int lda,
107 const long double* RESTRICT X,
108 const int incX,
109 const long double beta,
110 long double* RESTRICT Y,
111 const int incY)
112 {
113 // Mark unused variables to avoid compiler warnings
114 // (-Wno-unused-parameter)
115 (void) layout;
116 (void) TransA;
117 (void) M;
118 (void) N;
119 (void) alpha;
120 (void) A;
121 (void) lda;
122 (void) X;
123 (void) incX;
124 (void) beta;
125 (void) Y;
126 (void) incY;
127
128 std::cerr << "Error: cblas_?gemv for long double type is not "
129 << "implemented. To use long double type, set USE_CBLAS "
130 << "and USE_MKL to 0 and recompile the package."
131 << std::endl;
132 abort();
133 }
134
135
136 // =====
137 // xsymv (float)
138 // =====
139
142
143 template<>
144 void xsymv<float>(
145 const CBLAS_LAYOUT layout,
146 const CBLAS_UPLO Uplo,
147 const int N,
148 const float alpha,
149 const float* RESTRICT A,
150 const int lda,
151 const float* RESTRICT X,
152 const int incX,
153 const float beta,
154 float* RESTRICT Y,
155 const int incY)
156 {
157 cblas_ssymv(layout, Uplo, N, alpha, A, lda, X, incX, beta, Y, incY);
158 }
159
160
161 // =====
162 // xsymv (double)
163 // =====
164
167
168 template<>
169 void xsymv<double>(
170 const CBLAS_LAYOUT layout,
171 const CBLAS_UPLO Uplo,
172 const int N,
173 const double alpha,
174 const double* RESTRICT A,
175 const int lda,
176 const double* RESTRICT X,
177 const int incX,
178 const double beta,
179 double* RESTRICT Y,
180 const int incY)
181 {
182 cblas_dsymv(layout, Uplo, N, alpha, A, lda, X, incX, beta, Y, incY);
183 }
184
185
186 // =====
187 // xsymv (long double)
188 // =====
189
192
193 template<>
194 void xsymv<long double>(
195 const CBLAS_LAYOUT layout,
196 const CBLAS_UPLO Uplo,
197 const int N,
198 const long double alpha,
199 const long double* RESTRICT A,
200 const int lda,
201 const long double* RESTRICT X,
202 const int incX,
203 const long double beta,
204 long double* RESTRICT Y,
205 const int incY)
206 {
207 // Mark unused variables to avoid compiler warnings
208 // (-Wno-unused-parameter)
209 (void) layout;
210 (void) Uplo;
211 (void) N;
212 (void) alpha;
213 (void) A;
214 (void) lda;
215 (void) X;
216 (void) incX;
217 (void) beta;
218 (void) Y;
219 (void) incY;
220
221 std::cerr << "Error: cblas_?symv for long double type is not "
222 << "implemented. To use long double type, set USE_CBLAS "
223 << "and USE_MKL to 0 and recompile the package."
224 << std::endl;
225 abort();
226 }
227
228
229 // =====
230 // xcopy (float)
231 // =====
232
235
236 template <>
237 void xcopy<float>(
238 const int N,
239 const float* RESTRICT X,
240 const int incX,
241 float* RESTRICT Y,
242 const int incY)
243 {
244 cblas_scopy(N, X, incX, Y, incY);
245 }
246
247
248 // =====
249 // xcopy (double)
250 // =====
251
254
255 template <>
256 void xcopy<double>(
257 const int N,
258 const double* RESTRICT X,
259 const int incX,
260 double* RESTRICT Y,
261 const int incY)
262 {
263 cblas_dcopy(N, X, incX, Y, incY);
264 }
265
266
267 // =====
268 // xcopy (long double)
269 // =====
270
273
274 template <>
275 void xcopy<long double>(
276 const int N,
277 const long double* RESTRICT X,
278 const int incX,
279 long double* RESTRICT Y,
280 const int incY)
281 {
282 // Mark unused variables to avoid compiler warnings
283 // (-Wno-unused-parameter)
284 (void) N;
285 (void) X;
286 (void) incX;
287 (void) Y;
288 (void) incY;
289
290 std::cerr << "Error: cblas_?copy for long double type is not "
291 << "implemented. To use long double type, set USE_CBLAS "
292 << "and USE_MKL to 0 and recompile the package."
293 << std::endl;
294 abort();
295 }
296
297
298 // =====
299 // xaxpy (float)
300 // =====
301
304
305 template <>
306 void xaxpy<float>(
307 const int N,
308 const float alpha,
309 const float* RESTRICT X,
310 const int incX,
311 float* RESTRICT Y,
312 const int incY)
313 {
314 cblas_saxpy(N, alpha, X, incX, Y, incY);
315 }
316
317
318 // =====
319 // xaxpy (double)
320 // =====
321
324
325 template <>
326 void xaxpy<double>(
327 const int N,
328 const double alpha,
329 const double* RESTRICT X,
330 const int incX,
331 double* RESTRICT Y,
332 const int incY)
333 {
334 cblas_daxpy(N, alpha, X, incX, Y, incY);
335 }
336
337
338 // =====
339 // xaxpy (long double)
340 // =====
341
344
345 template <>
346 void xaxpy<long double>(
347 const int N,
348 const long double alpha,
349 const long double* RESTRICT X,
350 const int incX,
351 long double* RESTRICT Y,
352 const int incY)
353 {
354 // Mark unused variables to avoid compiler warnings
355 // (-Wno-unused-parameter)
356 (void) N;
357 (void) alpha;
358 (void) X;
359 (void) incX;
360 (void) Y;
361 (void) incY;
362
363 std::cerr << "Error: cblas_?axpy for long double type is not "
364 << "implemented. To use long double type, set USE_CBLAS "
365 << "and USE_MKL to 0 and recompile the package."
366 << std::endl;
367 abort();
368 }
369
370
371 // ====
372 // xdot (float)
373 // ====
374
377
378 template <>
379 float xdot<float>(
380 const int N,
381 const float* RESTRICT X,
382 const int incX,
383 const float* RESTRICT Y,
384 const int incY)
385 {
386 return cblas_sdot(N, X, incX, Y, incY);
387 }
388
389
390 // ====
391 // xdot (double)
392 // ====
393
396
397 template <>
398 double xdot<double>(
399 const int N,
400 const double* RESTRICT X,
401 const int incX,
402 const double* RESTRICT Y,
403 const int incY)
404 {
405 return cblas_ddot(N, X, incX, Y, incY);
406 }
407
408
409 // ====
410 // xdot (long double)
411 // ====
412
415
416 template <>
417 long double xdot<long double>(
418 const int N,
419 const long double* RESTRICT X,
420 const int incX,
421 const long double* RESTRICT Y,
422 const int incY)
423 {
424 // Mark unused variables to avoid compiler warnings
425 // (-Wno-unused-parameter)
426 (void) N;
427 (void) X;
428 (void) incX;
429 (void) Y;
430 (void) incY;
431
432 std::cerr << "Error: cblas_?dot for long double type is not "
433 << "implemented. To use long double type, set USE_CBLAS "
434 << "and USE_MKL to 0 and recompile the package."
435 << std::endl;
436 abort();
437 }
438
439
440 // =====
441 // xnrm2 (float)
442 // =====
443
446
447 template <>
448 float xnrm2<float>(
449 const int N,
450 const float* RESTRICT X,
451 const int incX)
452 {
453 return cblas_snrm2(N, X, incX);
454 }
455
456
457 // =====
458 // xnrm2 (double)
459 // =====
460
463
464 template <>
465 double xnrm2<double>(
466 const int N,
467 const double* RESTRICT X,
468 const int incX)
469 {
470 return cblas_dnrm2(N, X, incX);
471 }
472
473
474 // =====
475 // xnrm2 (long double)
476 // =====
477
480
481 template <>
482 long double xnrm2<long double>(
483 const int N,
484 const long double* RESTRICT X,
485 const int incX)
486 {
487 // Mark unused variables to avoid compiler warnings
488 // (-Wno-unused-parameter)
489 (void) N;
490 (void) X;
491 (void) incX;
492
493 std::cerr << "Error: cblas_?nrm2 for long double type is not "
494 << "implemented. To use long double type, set USE_CBLAS "
495 << "i and USE_MKL to 0 and recompile the package."
496 << std::endl;
497 abort();
498 }
499
500
501 // =====
502 // xscal (float)
503 // =====
504
507
508 template <>
509 void xscal<float>(
510 const int N,
511 const float alpha,
512 float* RESTRICT X,
513 const int incX)
514 {
515 cblas_sscal(N, alpha, X, incX);
516 }
517
518
519 // =====
520 // xscal (double)
521 // =====
522
525
526 template <>
527 void xscal<double>(
528 const int N,
529 const double alpha,
530 double* RESTRICT X,
531 const int incX)
532 {
533 cblas_dscal(N, alpha, X, incX);
534 }
535
536
537 // =====
538 // xscal (long double)
539 // =====
540
543
544 template <>
545 void xscal<long double>(
546 const int N,
547 const long double alpha,
548 long double* RESTRICT X,
549 const int incX)
550 {
551 // Mark unused variables to avoid compiler warnings
552 // (-Wno-unused-parameter)
553 (void) N;
554 (void) alpha;
555 (void) X;
556 (void) incX;
557
558 std::cerr << "Error: cblas_?scal for long double type is not "
559 << "implemented. To use long double type, set USE_CBLAS "
560 << "and USE_MKL to 0 and recompile the package."
561 << std::endl;
562 abort();
563 }
564} // namespace cblas_api
565
566#endif // USE_ANY_CBLAS
// NOTE(review): this definition comes after every use of RESTRICT in this
// file (and after the closing #endif), so it cannot affect the code above.
// It may be extraction residue or a misplaced fallback -- confirm. If
// ./cblas_api.h already defines RESTRICT with a different replacement, this
// redefinition will draw a compiler diagnostic.
#define RESTRICT