Amesos2 - Direct Sparse Solver Interfaces  Version of the Day
Amesos2_cuSOLVER_FunctionMap.hpp
1 // @HEADER
2 // *****************************************************************************
3 // Amesos2: Templated Direct Sparse Solver Package
4 //
5 // Copyright 2011 NTESS and the Amesos2 contributors.
6 // SPDX-License-Identifier: BSD-3-Clause
7 // *****************************************************************************
8 // @HEADER
9 
10 #ifndef AMESOS2_CUSOLVER_FUNCTIONMAP_HPP
11 #define AMESOS2_CUSOLVER_FUNCTIONMAP_HPP
12 
13 #include "Amesos2_FunctionMap.hpp"
14 #include "Amesos2_cuSOLVER_TypeMap.hpp"
15 
16 #include <cuda.h>
17 #include <cusolverSp.h>
18 #include <cusolverDn.h>
19 #include <cusparse.h>
20 #include <cusolverSp_LOWLEVEL_PREVIEW.h>
21 
22 #ifdef HAVE_TEUCHOS_COMPLEX
23 #include <cuComplex.h>
24 #endif
25 
26 namespace Amesos2 {
27 
28  template <>
29  struct FunctionMap<cuSOLVER,double>
30  {
31  static cusolverStatus_t bufferInfo(
32  cusolverSpHandle_t handle,
33  int size,
34  int nnz,
35  cusparseMatDescr_t & desc,
36  const double * values,
37  const int * rowPtr,
38  const int * colIdx,
39  csrcholInfo_t & chol_info,
40  size_t * internalDataInBytes,
41  size_t * workspaceInBytes)
42  {
43  cusolverStatus_t status =
44  cusolverSpDcsrcholBufferInfo(handle, size, nnz, desc, values,
45  rowPtr, colIdx, chol_info, internalDataInBytes, workspaceInBytes);
46  return status;
47  }
48 
49  static cusolverStatus_t numeric(
50  cusolverSpHandle_t handle,
51  int size,
52  int nnz,
53  cusparseMatDescr_t & desc,
54  const double * values,
55  const int * rowPtr,
56  const int * colIdx,
57  csrcholInfo_t & chol_info,
58  void * buffer)
59  {
60  cusolverStatus_t status = cusolverSpDcsrcholFactor(
61  handle, size, nnz, desc, values, rowPtr, colIdx, chol_info, buffer);
62  return status;
63  }
64 
65  static cusolverStatus_t solve(
66  cusolverSpHandle_t handle,
67  int size,
68  const double * b,
69  double * x,
70  csrcholInfo_t & chol_info,
71  void * buffer)
72  {
73  cusolverStatus_t status = cusolverSpDcsrcholSolve(
74  handle, size, b, x, chol_info, buffer);
75  return status;
76  }
77  };
78 
79  template <>
80  struct FunctionMap<cuSOLVER,float>
81  {
82  static cusolverStatus_t bufferInfo(
83  cusolverSpHandle_t handle,
84  int size,
85  int nnz,
86  cusparseMatDescr_t & desc,
87  const float * values,
88  const int * rowPtr,
89  const int * colIdx,
90  csrcholInfo_t & chol_info,
91  size_t * internalDataInBytes,
92  size_t * workspaceInBytes)
93  {
94  cusolverStatus_t status =
95  cusolverSpScsrcholBufferInfo(handle, size, nnz, desc, values,
96  rowPtr, colIdx, chol_info, internalDataInBytes, workspaceInBytes);
97  return status;
98  }
99 
100  static cusolverStatus_t numeric(
101  cusolverSpHandle_t handle,
102  int size,
103  int nnz,
104  cusparseMatDescr_t & desc,
105  const float * values,
106  const int * rowPtr,
107  const int * colIdx,
108  csrcholInfo_t & chol_info,
109  void * buffer)
110  {
111  cusolverStatus_t status = cusolverSpScsrcholFactor(
112  handle, size, nnz, desc, values, rowPtr, colIdx, chol_info, buffer);
113  return status;
114  }
115 
116  static cusolverStatus_t solve(
117  cusolverSpHandle_t handle,
118  int size,
119  const float * b,
120  float * x,
121  csrcholInfo_t & chol_info,
122  void * buffer)
123  {
124  cusolverStatus_t status = cusolverSpScsrcholSolve(
125  handle, size, b, x, chol_info, buffer);
126  return status;
127  }
128  };
129 
130 #ifdef HAVE_TEUCHOS_COMPLEX
131  template <>
132  struct FunctionMap<cuSOLVER,Kokkos::complex<double>>
133  {
134  static cusolverStatus_t bufferInfo(
135  cusolverSpHandle_t handle,
136  int size,
137  int nnz,
138  cusparseMatDescr_t & desc,
139  const void * values,
140  const int * rowPtr,
141  const int * colIdx,
142  csrcholInfo_t & chol_info,
143  size_t * internalDataInBytes,
144  size_t * workspaceInBytes)
145  {
146  typedef cuDoubleComplex scalar_t;
147  const scalar_t * cu_values = reinterpret_cast<const scalar_t *>(values);
148  cusolverStatus_t status =
149  cusolverSpZcsrcholBufferInfo(handle, size, nnz, desc,
150  cu_values, rowPtr, colIdx, chol_info,
151  internalDataInBytes, workspaceInBytes);
152  return status;
153  }
154 
155  static cusolverStatus_t numeric(
156  cusolverSpHandle_t handle,
157  int size,
158  int nnz,
159  cusparseMatDescr_t & desc,
160  const void * values,
161  const int * rowPtr,
162  const int * colIdx,
163  csrcholInfo_t & chol_info,
164  void * buffer)
165  {
166  typedef cuDoubleComplex scalar_t;
167  const scalar_t * cu_values =
168  reinterpret_cast<const scalar_t *>(values);
169  cusolverStatus_t status = cusolverSpZcsrcholFactor(
170  handle, size, nnz, desc, cu_values, rowPtr, colIdx, chol_info, buffer);
171  return status;
172  }
173 
174  static cusolverStatus_t solve(
175  cusolverSpHandle_t handle,
176  int size,
177  const void * b,
178  void * x,
179  csrcholInfo_t & chol_info,
180  void * buffer)
181  {
182  typedef cuDoubleComplex scalar_t;
183  const scalar_t * cu_b = reinterpret_cast<const scalar_t *>(b);
184  scalar_t * cu_x = reinterpret_cast<scalar_t *>(x);
185  cusolverStatus_t status = cusolverSpZcsrcholSolve(
186  handle, size, cu_b, cu_x, chol_info, buffer);
187  return status;
188  }
189  };
190 
191  template <>
192  struct FunctionMap<cuSOLVER,Kokkos::complex<float>>
193  {
194  static cusolverStatus_t bufferInfo(
195  cusolverSpHandle_t handle,
196  int size,
197  int nnz,
198  cusparseMatDescr_t & desc,
199  const void * values,
200  const int * rowPtr,
201  const int * colIdx,
202  csrcholInfo_t & chol_info,
203  size_t * internalDataInBytes,
204  size_t * workspaceInBytes)
205  {
206  typedef cuFloatComplex scalar_t;
207  const scalar_t * cu_values = reinterpret_cast<const scalar_t *>(values);
208  cusolverStatus_t status =
209  cusolverSpCcsrcholBufferInfo(handle, size, nnz, desc,
210  cu_values, rowPtr, colIdx, chol_info,
211  internalDataInBytes, workspaceInBytes);
212  return status;
213  }
214 
215  static cusolverStatus_t numeric(
216  cusolverSpHandle_t handle,
217  int size,
218  int nnz,
219  cusparseMatDescr_t & desc,
220  const void * values,
221  const int * rowPtr,
222  const int * colIdx,
223  csrcholInfo_t & chol_info,
224  void * buffer)
225  {
226  typedef cuFloatComplex scalar_t;
227  const scalar_t * cu_values = reinterpret_cast<const scalar_t *>(values);
228  cusolverStatus_t status = cusolverSpCcsrcholFactor(
229  handle, size, nnz, desc, cu_values, rowPtr, colIdx, chol_info, buffer);
230  return status;
231  }
232 
233  static cusolverStatus_t solve(
234  cusolverSpHandle_t handle,
235  int size,
236  const void * b,
237  void * x,
238  csrcholInfo_t & chol_info,
239  void * buffer)
240  {
241  typedef cuFloatComplex scalar_t;
242  const scalar_t * cu_b = reinterpret_cast<const scalar_t *>(b);
243  scalar_t * cu_x = reinterpret_cast<scalar_t *>(x);
244  cusolverStatus_t status = cusolverSpCcsrcholSolve(
245  handle, size, cu_b, cu_x, chol_info, buffer);
246  return status;
247  }
248  };
249 #endif
250 
251 } // end namespace Amesos2
252 
253 #endif // AMESOS2_CUSOLVER_FUNCTIONMAP_HPP
const int size
Definition: klu2_simple.cpp:50
Declaration of Function mapping class for Amesos2.