imate
C++/CUDA Reference
Loading...
Searching...
No Matches
cublas_api.h
Go to the documentation of this file.
1/*
2 * SPDX-FileCopyrightText: Copyright 2021, Siavash Ameli <sameli@berkeley.edu>
3 * SPDX-License-Identifier: BSD-3-Clause
4 * SPDX-FileType: SOURCE
5 *
6 * This program is free software: you can redistribute it and/or modify it
7 * under the terms of the license found in the LICENSE.txt file in the root
8 * directory of this source tree.
9 */
10
11#ifndef _CU_BASIC_ALGEBRA_CUBLAS_API_H_
12#define _CU_BASIC_ALGEBRA_CUBLAS_API_H_
13
14
15// =======
16// Headers
17// =======
18
19// Avoid the "enumeration value not handled in switch" [-Wswitch-enum] warning raised by the cuBLAS header
20#ifdef _MSC_VER
21 #pragma warning(push, 0) // Suppress all warnings from the following include
22 #include <cublas_v2.h>
23 #pragma warning(pop) // Restore previous warning level
24#elif defined(__INTEL_LLVM_COMPILER) || defined(__INTEL_COMPILER)
25 #pragma warning(push, 0)
26 #include <cublas_v2.h>
27 #pragma warning(pop)
28#elif defined(__GNUC__) || defined(__clang__)
29 #pragma GCC diagnostic push
30 #pragma GCC diagnostic ignored "-Wswitch-enum"
31 #include <cublas_v2.h>
32 #pragma GCC diagnostic pop
33#else
34 #include <cublas_v2.h> // cublasSgemv, cublasDgemv, cublasScopy,
35 // cublasDcopy, cublasSaxpy, cublasDaxpy,
36 // cublasSdot, cublasDdot, cublasSnrm2,
37 // cublasDnrm2, cublasSscal, cublasDscal
38 // cublasHandle_t, cublasStatus_t
39#endif
40
41// Restrict qualifier
// Define RESTRICT as the no-alias pointer qualifier spelled the way the
// detected compiler expects it. The branches are tested in order, so a
// compiler that defines several of these macros takes the first match.
42#if defined(_MSC_VER)
 // Microsoft compiler: single-trailing-underscore-free spelling.
43 #define RESTRICT __restrict
44#elif defined(__INTEL_COMPILER)
 // Classic Intel compiler: same spelling as the MSVC branch.
45 #define RESTRICT __restrict
46#elif defined(__CUDA__) || defined(__GNUC__) || defined(__clang__)
 // GNU-compatible front ends (and translation units where __CUDA__ is defined).
47 #define RESTRICT __restrict__
48#else
 // Unknown compiler: RESTRICT expands to nothing, so the declarations still
 // compile, just without the no-alias qualifier.
49 #define RESTRICT
50#endif
51
52
53// ==========
54// cublas api
55// ==========
56
60
// Declarations of function templates, generic over DataType, that wrap cuBLAS
// routines. Only declarations appear in this header; NOTE(review): the
// definitions/specializations presumably live in cublas_api.cu - confirm.
// Every template returns a cublasStatus_t and declares all of its pointer
// parameters with the RESTRICT macro defined earlier in this header.
61namespace cublas_api
62{
63 // cublasXgemv
 // DataType-generic counterpart of cublasSgemv / cublasDgemv (both named in
 // this header's include comment). NOTE(review): presumably the arguments are
 // forwarded unchanged to the matching cuBLAS routine - verify in the
 // definition. The scalars alpha and beta are taken by pointer, not by value.
64 template <typename DataType>
65 cublasStatus_t cublasXgemv(
66 cublasHandle_t handle,
67 cublasOperation_t trans,
68 int m,
69 int n,
70 const DataType* RESTRICT alpha,
71 const DataType* RESTRICT A,
72 int lda,
73 const DataType* RESTRICT x,
74 int incx,
75 const DataType* RESTRICT beta,
76 DataType* RESTRICT y,
77 int incy);
78
79 // cublasXcopy
 // DataType-generic counterpart of cublasScopy / cublasDcopy. x is read-only
 // (const) and y is the only writable pointer in the signature.
 // NOTE(review): assumed to forward directly to cuBLAS - confirm.
80 template <typename DataType>
81 cublasStatus_t cublasXcopy(
82 cublasHandle_t handle,
83 int n,
84 const DataType* RESTRICT x,
85 int incx,
86 DataType* RESTRICT y,
87 int incy);
88
89 // cublasXaxpy
 // DataType-generic counterpart of cublasSaxpy / cublasDaxpy. alpha is taken
 // by pointer; y is the only writable pointer in the signature.
 // NOTE(review): assumed to forward directly to cuBLAS - confirm.
90 template <typename DataType>
91 cublasStatus_t cublasXaxpy(
92 cublasHandle_t handle,
93 int n,
94 const DataType* RESTRICT alpha,
95 const DataType* RESTRICT x,
96 int incx,
97 DataType* RESTRICT y,
98 int incy);
99
100 // cublasXdot
 // DataType-generic counterpart of cublasSdot / cublasDdot. Both input
 // vectors are const; the value is delivered through the result pointer while
 // the function's return value is the cuBLAS status.
 // NOTE(review): assumed to forward directly to cuBLAS - confirm.
101 template <typename DataType>
102 cublasStatus_t cublasXdot(
103 cublasHandle_t handle,
104 int n,
105 const DataType* RESTRICT x,
106 int incx,
107 const DataType* RESTRICT y,
108 int incy,
109 DataType* RESTRICT result);
110
111 // cublasXnrm2
 // DataType-generic counterpart of cublasSnrm2 / cublasDnrm2. The output is
 // written through result, which has the same element type as the input x.
 // NOTE(review): assumed to forward directly to cuBLAS - confirm.
112 template <typename DataType>
113 cublasStatus_t cublasXnrm2(
114 cublasHandle_t handle,
115 int n,
116 const DataType* RESTRICT x,
117 int incx,
118 DataType* RESTRICT result);
119
120 // cublasXscal
 // DataType-generic counterpart of cublasSscal / cublasDscal. x is non-const,
 // i.e. it is the vector this routine is allowed to modify; alpha is taken by
 // pointer. NOTE(review): assumed to forward directly to cuBLAS - confirm.
121 template <typename DataType>
122 cublasStatus_t cublasXscal(
123 cublasHandle_t handle,
124 int n,
125 const DataType* RESTRICT alpha,
126 DataType* RESTRICT x,
127 int incx);
128
129} // namespace cublas_api
130
131
132#endif // _CU_BASIC_ALGEBRA_CUBLAS_API_H_
#define RESTRICT
A collection of templates that wrap cuBLAS functions.
Definition cublas_api.cu:34
cublasStatus_t cublasXaxpy(cublasHandle_t handle, int n, const DataType *RESTRICT alpha, const DataType *RESTRICT x, int incx, DataType *RESTRICT y, int incy)
cublasStatus_t cublasXnrm2(cublasHandle_t handle, int n, const DataType *RESTRICT x, int incx, DataType *RESTRICT result)
cublasStatus_t cublasXscal(cublasHandle_t handle, int n, const DataType *RESTRICT alpha, DataType *RESTRICT x, int incx)
cublasStatus_t cublasXcopy(cublasHandle_t handle, int n, const DataType *RESTRICT x, int incx, DataType *RESTRICT y, int incy)
cublasStatus_t cublasXgemv(cublasHandle_t handle, cublasOperation_t trans, int m, int n, const DataType *RESTRICT alpha, const DataType *RESTRICT A, int lda, const DataType *RESTRICT x, int incx, const DataType *RESTRICT beta, DataType *RESTRICT y, int incy)
cublasStatus_t cublasXdot(cublasHandle_t handle, int n, const DataType *RESTRICT x, int incx, const DataType *RESTRICT y, int incy, DataType *RESTRICT result)