imate
C++/CUDA Reference
Loading...
Searching...
No Matches
cublas_impl_kernels.cu
Go to the documentation of this file.
1/*
2 * SPDX-FileCopyrightText: Copyright 2021, Siavash Ameli <sameli@berkeley.edu>
3 * SPDX-License-Identifier: BSD-3-Clause
4 * SPDX-FileType: SOURCE
5 *
6 * This program is free software: you can redistribute it and/or modify it
7 * under the terms of the license found in the LICENSE.txt file in the root
8 * directory of this source tree.
9 */
10
11
12// =======
13// Headers
14// =======
15
17#include <cuda_runtime.h>
18#include "../_cu_definitions/cu_types.h" // __nv_fp8_e5m2, __nv_fp8_e4m3,
19 // __half, __nv_bfloat16
20#include "./atomic_add.h" // atomicAdd (for double precision)
21#include "../_cu_arithmetics/cu_arithmetics.h" // cu_arithmetics
22
23
24// ===================
25// cublas impl kernels
26// ===================
27
29{
30 // ==================
31 // cublasTgemv kernel
32 // ==================
33
76
// Matrix-vector product kernel: y <- alpha * op(A) * x + beta * y.
//
// Each thread accumulates one output element y[i * incy], i in [0, m), as
// the dot product of n entries of A with x, where the (i, j) entry of op(A)
// is read from
//   - A[i + j * lda]   when trans is false,
//   - A[i * lda + j]   when trans is true.
// In both cases y has m entries and x has n entries.
//
// Launch requirements:
//   - 1D grid whose total thread count is at least m.
//   - blockDim.x must equal block_size: x is staged through a shared tile
//     of block_size entries, each written by the thread of the same index.
//
// Products are accumulated in ComputeType and converted to DataType only
// when y is written. All offsets are formed in 64-bit signed arithmetic so
// that i * lda, j * lda, j * incx and i * incy cannot wrap for large arrays.
template <
    typename DataType, typename ComputeType, unsigned int block_size>
__global__ void cublasTgemv_kernel(
        const bool trans,
        const int m,
        const int n,
        const DataType alpha,
        const DataType* RESTRICT A,
        const int lda,
        const DataType* RESTRICT x,
        const int incx,
        const DataType beta,
        DataType* RESTRICT y,
        const int incy)
{
    // 64-bit signed type used for every array offset
    typedef long long index_t;

    // Each thread is dedicated to compute one element of y
    const index_t i = static_cast<index_t>(blockIdx.x) * blockDim.x + \
        threadIdx.x;

    // Shared tile caching x only. A is not cached since each of its elements
    // is read once, whereas every thread of the block reads the same x.
    __shared__ DataType x_shared[block_size];

    // Running dot product owned by this thread (not shared with the block)
    ComputeType sum = static_cast<ComputeType>(0.0f);

    // Number of tiles of block_size entries needed to cover x. This is zero
    // when n is not positive, in which case the loop below is skipped.
    const index_t num_tiles = \
        (static_cast<index_t>(n) + block_size - 1) / block_size;

    for (index_t tile = 0; tile < num_tiles; ++tile)
    {
        // Index in x of the first entry of this tile
        const index_t tile_start = tile * static_cast<index_t>(block_size);

        // Entry of x that this thread stages into the shared tile
        const index_t j = tile_start + threadIdx.x;

        if (j < n)
        {
            // Read x from global memory to shared memory
            x_shared[threadIdx.x] = x[j * incx];
        }
        else
        {
            // Past the end of x: pad the tile with zeros
            x_shared[threadIdx.x] = \
                cu_arithmetics::cast<ComputeType, DataType>(0.0f);
        }

        // The whole tile must be written before any thread reads it
        __syncthreads();

        // Multiply the cached tile with the matching entries of A
        #pragma unroll
        for (unsigned int e = 0; e < block_size; ++e)
        {
            // Index in x (and in the row of op(A)) of the e-th tile entry.
            // This differs from j above, which is tied to threadIdx.x.
            const index_t e_j = tile_start + e;

            // The grid and the last tile may overhang the array sizes, so
            // both indices are checked before A is touched.
            if ((i < m) && (e_j < n))
            {
                const index_t a_index = trans ?
                    (i * lda + e_j) : (i + e_j * lda);

                sum += cu_arithmetics::cast<DataType, ComputeType>(
                        A[a_index]) * \
                    cu_arithmetics::cast<DataType, ComputeType>(
                        x_shared[e]);
            }
        }

        // Every thread must be done reading the tile before the next
        // iteration overwrites it
        __syncthreads();
    }

    // Update output vector only if thread does not exceed matrix size
    if (i < m)
    {
        const index_t y_index = i * incy;

        y[y_index] = \
            cu_arithmetics::add<DataType>(
                cu_arithmetics::mul<DataType>(
                    alpha,
                    cu_arithmetics::cast<ComputeType, DataType>(sum)
                ),
                cu_arithmetics::mul<DataType>(
                    beta,
                    y[y_index]
                )
            );
    }
}
196
197
198 // ==================
199 // cublasTcopy kernel
200 // ==================
201
221
// Strided vector copy kernel: y[i * incy] <- x[i * incx] for i in [0, n).
//
// One thread handles one element, so the 1D launch must provide at least n
// threads in total; threads with i >= n do nothing.
//
// The global index and both offsets are formed in 64-bit arithmetic, so
// neither blockIdx.x * blockDim.x nor i * inc can wrap for large n or large
// increments.
template <typename DataType>
__global__ void cublasTcopy_kernel(
        const int n,
        const DataType* RESTRICT x,
        const int incx,
        DataType* RESTRICT y,
        const int incy)
{
    const long long i = \
        static_cast<long long>(blockIdx.x) * blockDim.x + threadIdx.x;

    if (i < n)
    {
        y[i * incy] = x[i * incx];
    }
}
237
238
239 // ==================
240 // cublasTaxpy kernel
241 // ==================
242
265
// Scaled vector addition kernel:
// y[i * incy] <- alpha * x[i * incx] + y[i * incy] for i in [0, n).
//
// One thread handles one element, so the 1D launch must provide at least n
// threads in total; threads with i >= n do nothing. The arithmetic is done
// in DataType through cu_arithmetics::mul and cu_arithmetics::add.
//
// The global index and both offsets are formed in 64-bit arithmetic, so
// neither blockIdx.x * blockDim.x nor i * inc can wrap for large n or large
// increments.
template <typename DataType>
__global__ void cublasTaxpy_kernel(
        const int n,
        const DataType alpha,
        const DataType* RESTRICT x,
        const int incx,
        DataType* RESTRICT y,
        const int incy)
{
    const long long i = \
        static_cast<long long>(blockIdx.x) * blockDim.x + threadIdx.x;

    if (i < n)
    {
        const long long y_index = i * incy;

        y[y_index] = \
            cu_arithmetics::add<DataType>(
                cu_arithmetics::mul<DataType>(alpha, x[i * incx]),
                y[y_index]
            );
    }
}
286
287
288 // =================
289 // cublasTdot kernel
290 // =================
291
313
// Dot product kernel: adds sum_i x[i * incx] * y[i * incy], i in [0, n),
// into *result.
//
// Each thread accumulates a private partial sum over the indices
// i, i + T, i + 2T, ... where T = blockDim.x * gridDim.x, so any 1D grid
// size covers all n elements. The per-thread sums of a block are then
// combined in shared memory and thread 0 adds the block total to *result
// with atomicAdd.
//
// Requirements visible from this code:
//   - blockDim.x must not exceed block_size (one shared slot per thread).
//   - blockDim.x must be a power of two, otherwise the halving reduction
//     skips some slots.
//   - The kernel only adds to *result and never clears it, so *result must
//     hold the desired starting value (normally zero) before the launch.
//
// Products are formed and accumulated in ComputeType. The element index and
// the stride are 64-bit, so i += T and i * inc cannot wrap for large n.
template <
    typename DataType, typename ComputeType, unsigned int block_size>
__global__ void cublasTdot_kernel(
        const int n,
        const DataType* RESTRICT x,
        const int incx,
        const DataType* RESTRICT y,
        const int incy,
        ComputeType* RESTRICT result)
{
    // One slot per thread of the block (not per block of the grid)
    __shared__ ComputeType partial_sum[block_size];

    const unsigned int tid = threadIdx.x;

    // Total number of threads in the grid, i.e. the step between two
    // consecutive elements handled by the same thread
    const long long grid_stride = \
        static_cast<long long>(blockDim.x) * gridDim.x;

    ComputeType sum = static_cast<ComputeType>(0.0f);
    for (long long i = static_cast<long long>(blockIdx.x) * blockDim.x + tid;
         i < n;
         i += grid_stride)
    {
        sum += cu_arithmetics::cast<DataType, ComputeType>(x[i * incx]) * \
               cu_arithmetics::cast<DataType, ComputeType>(y[i * incy]);
    }

    partial_sum[tid] = sum;

    // All partial sums must be stored before the reduction reads them
    __syncthreads();

    // Reduction in shared memory: halve the number of active slots each pass
    for (unsigned int stride = blockDim.x / 2; stride > 0; stride >>= 1)
    {
        if (tid < stride)
        {
            partial_sum[tid] += partial_sum[tid + stride];
        }

        // Outside the branch so that every thread of the block reaches it
        __syncthreads();
    }

    // Add the total of this block to the global result
    if (tid == 0)
    {
        atomicAdd(result, partial_sum[0]);
    }
}
360
361
362 // ==================
363 // cublasTnrm2 kernel
364 // ==================
365
382
// Sum-of-squares kernel: adds sum_i x[i * incx]^2, i in [0, n), into
// *result. No square root is taken in this kernel.
//
// Each thread accumulates a private partial sum over the indices
// i, i + T, i + 2T, ... where T = blockDim.x * gridDim.x, so any 1D grid
// size covers all n elements. The per-thread sums of a block are then
// combined in shared memory and thread 0 adds the block total to *result
// with atomicAdd.
//
// Requirements visible from this code:
//   - blockDim.x must not exceed block_size (one shared slot per thread).
//   - blockDim.x must be a power of two, otherwise the halving reduction
//     skips some slots.
//   - The kernel only adds to *result and never clears it, so *result must
//     hold the desired starting value (normally zero) before the launch.
//
// Squares are formed and accumulated in ComputeType. The element index and
// the stride are 64-bit, so i += T and i * incx cannot wrap for large n.
template <
    typename DataType, typename ComputeType, unsigned int block_size>
__global__ void cublasTnrm2_kernel(
        const int n,
        const DataType* RESTRICT x,
        const int incx,
        ComputeType* RESTRICT result)
{
    // One slot per thread of the block (not per block of the grid)
    __shared__ ComputeType partial_sum[block_size];

    const unsigned int tid = threadIdx.x;

    // Total number of threads in the grid, i.e. the step between two
    // consecutive elements handled by the same thread
    const long long grid_stride = \
        static_cast<long long>(blockDim.x) * gridDim.x;

    ComputeType sum = static_cast<ComputeType>(0.0f);
    for (long long i = static_cast<long long>(blockIdx.x) * blockDim.x + tid;
         i < n;
         i += grid_stride)
    {
        const ComputeType val = \
            cu_arithmetics::cast<DataType, ComputeType>(x[i * incx]);
        sum += val * val;
    }

    partial_sum[tid] = sum;

    // All partial sums must be stored before the reduction reads them
    __syncthreads();

    // Reduction in shared memory: halve the number of active slots each pass
    for (unsigned int stride = blockDim.x / 2; stride > 0; stride >>= 1)
    {
        if (tid < stride)
        {
            partial_sum[tid] += partial_sum[tid + stride];
        }

        // Outside the branch so that every thread of the block reaches it
        __syncthreads();
    }

    // Add the total of this block to the global result
    if (tid == 0)
    {
        atomicAdd(result, partial_sum[0]);
    }
}
427
428
429 // ==================
430 // cublasTscal kernel
431 // ==================
432
451
// In-place vector scaling kernel: x[i * incx] <- x[i * incx] * alpha for
// i in [0, n).
//
// One thread handles one element, so the 1D launch must provide at least n
// threads in total; threads with i >= n do nothing. The product is formed
// in DataType through cu_arithmetics::mul.
//
// The global index and the offset are formed in 64-bit arithmetic, so
// neither blockIdx.x * blockDim.x nor i * incx can wrap for large n or a
// large increment.
template <typename DataType>
__global__ void cublasTscal_kernel(
        const int n,
        const DataType alpha,
        DataType* RESTRICT x,
        const int incx)
{
    const long long i = \
        static_cast<long long>(blockIdx.x) * blockDim.x + threadIdx.x;

    if (i < n)
    {
        const long long x_index = i * incx;

        x[x_index] = cu_arithmetics::mul<DataType>(x[x_index], alpha);
    }
}
466
467} // namespace cublas_impl_kernels
468
469
470// ===============================
471// Explicit template instantiation
472// ===============================
473
// Explicit instantiations of cublasTgemv_kernel. Every instantiation uses
// block_size = 640. The float and double versions are emitted only when
// USE_CUBLAS is not enabled.

// cublasTgemv kernel (__nv_fp8_e5m2)
#if defined(USE_CUDA_FP8_E5M2) && (USE_CUDA_FP8_E5M2 == 1)
    template
    __global__ void cublas_impl_kernels::cublasTgemv_kernel<
        __nv_fp8_e5m2, float, 640>(
            const bool trans,
            const int m,
            const int n,
            const __nv_fp8_e5m2 alpha,
            const __nv_fp8_e5m2* RESTRICT A,
            const int lda,
            const __nv_fp8_e5m2* RESTRICT x,
            const int incx,
            const __nv_fp8_e5m2 beta,
            __nv_fp8_e5m2* RESTRICT y,
            const int incy);
#endif

// cublasTgemv kernel (__nv_fp8_e4m3)
// NOTE(review): this guard spells the macro with a lower-case suffix
// (USE_CUDA_FP8_e4m3) while the e5m2 guard is upper-case
// (USE_CUDA_FP8_E5M2). Confirm which spelling the build actually defines; if
// it is USE_CUDA_FP8_E4M3, this instantiation is never emitted.
#if defined(USE_CUDA_FP8_e4m3) && (USE_CUDA_FP8_e4m3 == 1)
    template
    __global__ void cublas_impl_kernels::cublasTgemv_kernel<
        __nv_fp8_e4m3, float, 640>(
            const bool trans,
            const int m,
            const int n,
            const __nv_fp8_e4m3 alpha,
            const __nv_fp8_e4m3* RESTRICT A,
            const int lda,
            const __nv_fp8_e4m3* RESTRICT x,
            const int incx,
            const __nv_fp8_e4m3 beta,
            __nv_fp8_e4m3* RESTRICT y,
            const int incy);
#endif

// cublasTgemv kernel (__half)
#if defined(USE_CUDA_FP16) && (USE_CUDA_FP16 == 1)
    template
    __global__ void cublas_impl_kernels::cublasTgemv_kernel<
        __half, float, 640>(
            const bool trans,
            const int m,
            const int n,
            const __half alpha,
            const __half* RESTRICT A,
            const int lda,
            const __half* RESTRICT x,
            const int incx,
            const __half beta,
            __half* RESTRICT y,
            const int incy);
#endif

// cublasTgemv kernel (__nv_bfloat16)
#if defined(USE_CUDA_BF16) && (USE_CUDA_BF16 == 1)
    template
    __global__ void cublas_impl_kernels::cublasTgemv_kernel<
        __nv_bfloat16, float, 640>(
            const bool trans,
            const int m,
            const int n,
            const __nv_bfloat16 alpha,
            const __nv_bfloat16* RESTRICT A,
            const int lda,
            const __nv_bfloat16* RESTRICT x,
            const int incx,
            const __nv_bfloat16 beta,
            __nv_bfloat16* RESTRICT y,
            const int incy);
#endif

// cublasTgemv kernel (float)
#if defined(USE_CUDA_FP32) && (USE_CUDA_FP32 == 1)
#if !defined(USE_CUBLAS) || (USE_CUBLAS != 1)
    template
    __global__ void cublas_impl_kernels::cublasTgemv_kernel<
        float, float, 640>(
            const bool trans,
            const int m,
            const int n,
            const float alpha,
            const float* RESTRICT A,
            const int lda,
            const float* RESTRICT x,
            const int incx,
            const float beta,
            float* RESTRICT y,
            const int incy);
#endif
#endif

// cublasTgemv kernel (double)
#if defined(USE_CUDA_FP64) && (USE_CUDA_FP64 == 1)
#if !defined(USE_CUBLAS) || (USE_CUBLAS != 1)
    template
    __global__ void cublas_impl_kernels::cublasTgemv_kernel<
        double, double, 640>(
            const bool trans,
            const int m,
            const int n,
            const double alpha,
            const double* RESTRICT A,
            const int lda,
            const double* RESTRICT x,
            const int incx,
            const double beta,
            double* RESTRICT y,
            const int incy);
#endif
#endif
565
566// cublasTgemv kernel (double)
567#if defined(USE_CUDA_FP64) && (USE_CUDA_FP64 == 1)
568#if !defined(USE_CUBLAS) || (USE_CUBLAS != 1)
569 template
571 double, double, 640>(
572 const bool trans,
573 const int m,
574 const int n,
575 const double alpha,
576 const double* RESTRICT A,
577 const int lda,
578 const double* RESTRICT x,
579 const int incx,
580 const double beta,
581 double* RESTRICT y,
582 const int incy);
583#endif
584#endif
585
// Explicit instantiations of cublasTcopy_kernel, one per enabled data type.
// Parameter order: n, x, incx, y, incy. The float and double versions are
// emitted only when USE_CUBLAS is not enabled.

// cublasTcopy kernel (__half)
#if defined(USE_CUDA_FP16) && (USE_CUDA_FP16 == 1)
    template __global__ void cublas_impl_kernels::cublasTcopy_kernel<__half>(
            const int, const __half* RESTRICT, const int,
            __half* RESTRICT, const int);
#endif

// cublasTcopy kernel (__nv_bfloat16)
#if defined(USE_CUDA_BF16) && (USE_CUDA_BF16 == 1)
    template __global__ void
    cublas_impl_kernels::cublasTcopy_kernel<__nv_bfloat16>(
            const int, const __nv_bfloat16* RESTRICT, const int,
            __nv_bfloat16* RESTRICT, const int);
#endif

// cublasTcopy kernel (float)
#if defined(USE_CUDA_FP32) && (USE_CUDA_FP32 == 1) && \
    (!defined(USE_CUBLAS) || (USE_CUBLAS != 1))
    template __global__ void cublas_impl_kernels::cublasTcopy_kernel<float>(
            const int, const float* RESTRICT, const int,
            float* RESTRICT, const int);
#endif

// cublasTcopy kernel (double)
#if defined(USE_CUDA_FP64) && (USE_CUDA_FP64 == 1) && \
    (!defined(USE_CUBLAS) || (USE_CUBLAS != 1))
    template __global__ void cublas_impl_kernels::cublasTcopy_kernel<double>(
            const int, const double* RESTRICT, const int,
            double* RESTRICT, const int);
#endif
633
// Explicit instantiations of cublasTaxpy_kernel, one per enabled data type.
// Parameter order: n, alpha, x, incx, y, incy. The float and double versions
// are emitted only when USE_CUBLAS is not enabled.

// cublasTaxpy kernel (__half)
#if defined(USE_CUDA_FP16) && (USE_CUDA_FP16 == 1)
    template __global__ void cublas_impl_kernels::cublasTaxpy_kernel<__half>(
            const int, const __half, const __half* RESTRICT, const int,
            __half* RESTRICT, const int);
#endif

// cublasTaxpy kernel (__nv_bfloat16)
#if defined(USE_CUDA_BF16) && (USE_CUDA_BF16 == 1)
    template __global__ void
    cublas_impl_kernels::cublasTaxpy_kernel<__nv_bfloat16>(
            const int, const __nv_bfloat16, const __nv_bfloat16* RESTRICT,
            const int, __nv_bfloat16* RESTRICT, const int);
#endif

// cublasTaxpy kernel (float)
#if defined(USE_CUDA_FP32) && (USE_CUDA_FP32 == 1) && \
    (!defined(USE_CUBLAS) || (USE_CUBLAS != 1))
    template __global__ void cublas_impl_kernels::cublasTaxpy_kernel<float>(
            const int, const float, const float* RESTRICT, const int,
            float* RESTRICT, const int);
#endif

// cublasTaxpy kernel (double)
#if defined(USE_CUDA_FP64) && (USE_CUDA_FP64 == 1) && \
    (!defined(USE_CUBLAS) || (USE_CUBLAS != 1))
    template __global__ void cublas_impl_kernels::cublasTaxpy_kernel<double>(
            const int, const double, const double* RESTRICT, const int,
            double* RESTRICT, const int);
#endif
685
// Explicit instantiations of cublasTdot_kernel. Every instantiation uses
// block_size = 256; the 16-bit types accumulate in float. The float and
// double versions are emitted only when USE_CUBLAS is not enabled.

// cublasTdot kernel (__half)
#if defined(USE_CUDA_FP16) && (USE_CUDA_FP16 == 1)
    template
    __global__ void cublas_impl_kernels::cublasTdot_kernel<
        __half, float, 256>(
            const int n,
            const __half* RESTRICT x,
            const int incx,
            const __half* RESTRICT y,
            const int incy,
            float* RESTRICT result);
#endif

// cublasTdot kernel (__nv_bfloat16)
#if defined(USE_CUDA_BF16) && (USE_CUDA_BF16 == 1)
    template
    __global__ void cublas_impl_kernels::cublasTdot_kernel<
        __nv_bfloat16, float, 256>(
            const int n,
            const __nv_bfloat16* RESTRICT x,
            const int incx,
            const __nv_bfloat16* RESTRICT y,
            const int incy,
            float* RESTRICT result);
#endif

// cublasTdot kernel (float)
#if defined(USE_CUDA_FP32) && (USE_CUDA_FP32 == 1)
#if !defined(USE_CUBLAS) || (USE_CUBLAS != 1)
    template
    __global__ void cublas_impl_kernels::cublasTdot_kernel<
        float, float, 256>(
            const int n,
            const float* RESTRICT x,
            const int incx,
            const float* RESTRICT y,
            const int incy,
            float* RESTRICT result);
#endif
#endif

// cublasTdot kernel (double)
#if defined(USE_CUDA_FP64) && (USE_CUDA_FP64 == 1)
#if !defined(USE_CUBLAS) || (USE_CUBLAS != 1)
    template
    __global__ void cublas_impl_kernels::cublasTdot_kernel<
        double, double, 256>(
            const int n,
            const double* RESTRICT x,
            const int incx,
            const double* RESTRICT y,
            const int incy,
            double* RESTRICT result);
#endif
#endif
741
// Explicit instantiations of cublasTnrm2_kernel. Every instantiation uses
// block_size = 256; the 16-bit types accumulate in float. The float and
// double versions are emitted only when USE_CUBLAS is not enabled.

// cublasTnrm2 kernel (__half)
#if defined(USE_CUDA_FP16) && (USE_CUDA_FP16 == 1)
    template
    __global__ void cublas_impl_kernels::cublasTnrm2_kernel<
        __half, float, 256>(
            const int n,
            const __half* RESTRICT x,
            const int incx,
            float* RESTRICT result);
#endif

// cublasTnrm2 kernel (__nv_bfloat16)
#if defined(USE_CUDA_BF16) && (USE_CUDA_BF16 == 1)
    template
    __global__ void cublas_impl_kernels::cublasTnrm2_kernel<
        __nv_bfloat16, float, 256>(
            const int n,
            const __nv_bfloat16* RESTRICT x,
            const int incx,
            float* RESTRICT result);
#endif

// cublasTnrm2 kernel (float)
#if defined(USE_CUDA_FP32) && (USE_CUDA_FP32 == 1)
#if !defined(USE_CUBLAS) || (USE_CUBLAS != 1)
    template
    __global__ void cublas_impl_kernels::cublasTnrm2_kernel<
        float, float, 256>(
            const int n,
            const float* RESTRICT x,
            const int incx,
            float* RESTRICT result);
#endif
#endif

// cublasTnrm2 kernel (double)
#if defined(USE_CUDA_FP64) && (USE_CUDA_FP64 == 1)
#if !defined(USE_CUBLAS) || (USE_CUBLAS != 1)
    template
    __global__ void cublas_impl_kernels::cublasTnrm2_kernel<
        double, double, 256>(
            const int n,
            const double* RESTRICT x,
            const int incx,
            double* RESTRICT result);
#endif
#endif
789
// Explicit instantiations of cublasTscal_kernel, one per enabled data type.
// Parameter order: n, alpha, x, incx. The float and double versions are
// emitted only when USE_CUBLAS is not enabled.

// cublasTscal kernel (__half)
#if defined(USE_CUDA_FP16) && (USE_CUDA_FP16 == 1)
    template __global__ void cublas_impl_kernels::cublasTscal_kernel<__half>(
            const int, const __half, __half* RESTRICT, const int);
#endif

// cublasTscal kernel (__nv_bfloat16)
#if defined(USE_CUDA_BF16) && (USE_CUDA_BF16 == 1)
    template __global__ void
    cublas_impl_kernels::cublasTscal_kernel<__nv_bfloat16>(
            const int, const __nv_bfloat16, __nv_bfloat16* RESTRICT,
            const int);
#endif

// cublasTscal kernel (float)
#if defined(USE_CUDA_FP32) && (USE_CUDA_FP32 == 1) && \
    (!defined(USE_CUBLAS) || (USE_CUBLAS != 1))
    template __global__ void cublas_impl_kernels::cublasTscal_kernel<float>(
            const int, const float, float* RESTRICT, const int);
#endif

// cublasTscal kernel (double)
#if defined(USE_CUDA_FP64) && (USE_CUDA_FP64 == 1) && \
    (!defined(USE_CUBLAS) || (USE_CUBLAS != 1))
    template __global__ void cublas_impl_kernels::cublasTscal_kernel<double>(
            const int, const double, double* RESTRICT, const int);
#endif
#define RESTRICT
__host__ __device__ DataType abs(const DataType x)
Absolute value of a floating point number.
Templated kernel code for implementations of several BLAS-type functions in CUDA.
__global__ void cublasTscal_kernel(const int n, const DataType alpha, DataType *RESTRICT x, const int incx)
Performs $x \leftarrow \alpha x$.
__global__ void cublasTaxpy_kernel(const int n, const DataType alpha, const DataType *RESTRICT x, const int incx, DataType *RESTRICT y, const int incy)
Performs $y \leftarrow \alpha x + y$.
__global__ void cublasTnrm2_kernel(const int n, const DataType *RESTRICT x, const int incx, ComputeType *RESTRICT result)
Computes $\sum_i x_i^2$, the squared Euclidean norm of $x$, accumulated into `result`.
__global__ void cublasTcopy_kernel(const int n, const DataType *RESTRICT x, const int incx, DataType *RESTRICT y, const int incy)
Performs $y \leftarrow x$.
__global__ void cublasTdot_kernel(const int n, const DataType *RESTRICT x, const int incx, const DataType *RESTRICT y, const int incy, ComputeType *RESTRICT result)
Computes $x^{\intercal} y$, accumulated into `result`.
__global__ void cublasTgemv_kernel(const bool trans, const int m, const int n, const DataType alpha, const DataType *RESTRICT A, const int lda, const DataType *RESTRICT x, const int incx, const DataType beta, DataType *RESTRICT y, const int incy)
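Performs the operation $y \leftarrow \alpha A x + \beta y$ or, when `trans` is set, $y \leftarrow \alpha A^{\intercal} x + \beta y$.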