imate
C++/CUDA Reference
Loading...
Searching...
No Matches
cu_linear_operator.h
Go to the documentation of this file.
1/*
2 * SPDX-FileCopyrightText: Copyright 2021, Siavash Ameli <sameli@berkeley.edu>
3 * SPDX-License-Identifier: BSD-3-Clause
4 * SPDX-FileType: SOURCE
5 *
6 * This program is free software: you can redistribute it and/or modify it
7 * under the terms of the license found in the LICENSE.txt file in the root
8 * directory of this source tree.
9 */
10
11
12#ifndef _CU_LINEAR_OPERATOR_CU_LINEAR_OPERATOR_H_
13#define _CU_LINEAR_OPERATOR_CU_LINEAR_OPERATOR_H_
14
15// =======
16// Headers
17// =======
18
19#include <cusparse.h> // cusparseHandle_t
20#include "../_definitions/types.h" // FlagType, IndexType, LongIndexType
21 //
22// cLinearOperatorBase
23#include "../_c_linear_operator/c_linear_operator_base.h"
24
 25// Avoid the CUBLAS "enumeration value not handled in switch" [-Wswitch-enum] warning
26#ifdef _MSC_VER
27 #pragma warning(push, 0) // Suppress all warnings from the followings
28 #include <cublas_v2.h> // cublasHandle_t
29 #pragma warning(pop) // Restore previous warning level
30#elif defined(__INTEL_LLVM_COMPILER) || defined(__INTEL_COMPILER)
31 #pragma warning(push, 0)
32 #include <cublas_v2.h> // cublasHandle_t
33 #pragma warning(pop)
34#elif defined(__GNUC__) || defined(__clang__)
35 #pragma GCC diagnostic push
36 #pragma GCC diagnostic ignored "-Wswitch-enum"
37 #include <cublas_v2.h> // cublasHandle_t
38 #pragma GCC diagnostic pop
39#else
40 #include <cublas_v2.h> // cublasHandle_t, cublasCreate, cublasSetMathMode
41#endif
42
43
44// ==================
45// cu Linear Operator
46// ==================
47
61
62template <typename DataType>
64{
65 public:
66
67 // Member methods
69
70 explicit cuLinearOperator(const int num_gpu_devices_);
71
72 virtual ~cuLinearOperator();
73
74 cublasHandle_t get_cublas_handle() const;
75
76 void set_parameters(DataType* parameters_);
77
78 virtual DataType get_eigenvalue(
79 const DataType* known_parameters,
80 const DataType known_eigenvalue,
81 const DataType* inquiry_parameters) const = 0;
82
83 virtual void dot(
84 const DataType* vector,
85 DataType* product) = 0;
86
87 virtual void transpose_dot(
88 const DataType* vector,
89 DataType* product) = 0;
90
91 protected:
92
93 // Member methods
94 int query_gpu_devices() const;
97
98 // Member data
101 cublasHandle_t* cublas_handle;
102 cusparseHandle_t* cusparse_handle;
103 DataType* parameters;
104};
105
106#endif // _CU_LINEAR_OPERATOR_CU_LINEAR_OPERATOR_H_
Base class for cLinearOperator and cuLinearOperator. This class is not templated so that both cpp an...
Base class for linear operators. This class serves as interface for all derived classes.
cuLinearOperator()
Default constructor.
void initialize_cusparse_handle()
Creates a cusparseHandle_t object, if not created already.
cublasHandle_t * cublas_handle
virtual void dot(const DataType *vector, DataType *product)=0
int query_gpu_devices() const
Before any numerical computation, this method checks whether any GPU device is available on the machine,...
cublasHandle_t get_cublas_handle() const
This function returns the cublasHandle_t handle object. The object will be created,...
cusparseHandle_t * cusparse_handle
virtual ~cuLinearOperator()
Destructor.
virtual void transpose_dot(const DataType *vector, DataType *product)=0
void initialize_cublas_handle()
Creates a cublasHandle_t object, if not created already.
void set_parameters(DataType *parameters_)
Sets the scalar parameter this->parameters. Parameter is initialized to NULL. However,...
virtual DataType get_eigenvalue(const DataType *known_parameters, const DataType known_eigenvalue, const DataType *inquiry_parameters) const =0