Amesos2 - Direct Sparse Solver Interfaces  Version of the Day
Amesos2_cuSOLVER_FunctionMap.hpp
1 // @HEADER
2 // *****************************************************************************
3 // Amesos2: Templated Direct Sparse Solver Package
4 //
5 // Copyright 2011 NTESS and the Amesos2 contributors.
6 // SPDX-License-Identifier: BSD-3-Clause
7 // *****************************************************************************
8 // @HEADER
9 
10 #ifndef AMESOS2_CUSOLVER_FUNCTIONMAP_HPP
11 #define AMESOS2_CUSOLVER_FUNCTIONMAP_HPP
12 
13 #include "Amesos2_FunctionMap.hpp"
14 #include "Amesos2_cuSOLVER_TypeMap.hpp"
15 
16 #include <cuda.h>
17 #include <cusolverSp.h>
18 #include <cusolverDn.h>
19 #include <cusparse.h>
20 #include <cusolverSp_LOWLEVEL_PREVIEW.h>
21 
22 #ifdef HAVE_TEUCHOS_COMPLEX
23 #include <cuComplex.h>
24 #endif
25 
26 namespace Amesos2 {
27 
28  template <>
29  struct FunctionMap<cuSOLVER,double>
30  {
31  static cusolverStatus_t bufferInfo(
32  cusolverSpHandle_t handle,
33  int size,
34  int nnz,
35  cusparseMatDescr_t & desc,
36  const double * values,
37  const int * rowPtr,
38  const int * colIdx,
39  csrcholInfo_t & chol_info,
40  size_t * internalDataInBytes,
41  size_t * workspaceInBytes)
42  {
43  cusolverStatus_t status =
44  cusolverSpDcsrcholBufferInfo(handle, size, nnz, desc, values,
45  rowPtr, colIdx, chol_info, internalDataInBytes, workspaceInBytes);
46  return status;
47  }
48 
49  static cusolverStatus_t numeric(
50  cusolverSpHandle_t handle,
51  int size,
52  int nnz,
53  cusparseMatDescr_t & desc,
54  const double * values,
55  const int * rowPtr,
56  const int * colIdx,
57  csrcholInfo_t & chol_info,
58  void * buffer)
59  {
60  cusolverStatus_t status = cusolverSpDcsrcholFactor(
61  handle, size, nnz, desc, values, rowPtr, colIdx, chol_info, buffer);
62  cudaDeviceSynchronize();
63  return status;
64  }
65 
66  static cusolverStatus_t solve(
67  cusolverSpHandle_t handle,
68  int size,
69  const double * b,
70  double * x,
71  csrcholInfo_t & chol_info,
72  void * buffer)
73  {
74  cusolverStatus_t status = cusolverSpDcsrcholSolve(
75  handle, size, b, x, chol_info, buffer);
76  cudaDeviceSynchronize();
77  return status;
78  }
79  };
80 
81  template <>
82  struct FunctionMap<cuSOLVER,float>
83  {
84  static cusolverStatus_t bufferInfo(
85  cusolverSpHandle_t handle,
86  int size,
87  int nnz,
88  cusparseMatDescr_t & desc,
89  const float * values,
90  const int * rowPtr,
91  const int * colIdx,
92  csrcholInfo_t & chol_info,
93  size_t * internalDataInBytes,
94  size_t * workspaceInBytes)
95  {
96  cusolverStatus_t status =
97  cusolverSpScsrcholBufferInfo(handle, size, nnz, desc, values,
98  rowPtr, colIdx, chol_info, internalDataInBytes, workspaceInBytes);
99  return status;
100  }
101 
102  static cusolverStatus_t numeric(
103  cusolverSpHandle_t handle,
104  int size,
105  int nnz,
106  cusparseMatDescr_t & desc,
107  const float * values,
108  const int * rowPtr,
109  const int * colIdx,
110  csrcholInfo_t & chol_info,
111  void * buffer)
112  {
113  cusolverStatus_t status = cusolverSpScsrcholFactor(
114  handle, size, nnz, desc, values, rowPtr, colIdx, chol_info, buffer);
115  cudaDeviceSynchronize();
116  return status;
117  }
118 
119  static cusolverStatus_t solve(
120  cusolverSpHandle_t handle,
121  int size,
122  const float * b,
123  float * x,
124  csrcholInfo_t & chol_info,
125  void * buffer)
126  {
127  cusolverStatus_t status = cusolverSpScsrcholSolve(
128  handle, size, b, x, chol_info, buffer);
129  cudaDeviceSynchronize();
130  return status;
131  }
132  };
133 
134 #ifdef HAVE_TEUCHOS_COMPLEX
135  template <>
136  struct FunctionMap<cuSOLVER,Kokkos::complex<double>>
137  {
138  static cusolverStatus_t bufferInfo(
139  cusolverSpHandle_t handle,
140  int size,
141  int nnz,
142  cusparseMatDescr_t & desc,
143  const void * values,
144  const int * rowPtr,
145  const int * colIdx,
146  csrcholInfo_t & chol_info,
147  size_t * internalDataInBytes,
148  size_t * workspaceInBytes)
149  {
150  typedef cuDoubleComplex scalar_t;
151  const scalar_t * cu_values = reinterpret_cast<const scalar_t *>(values);
152  cusolverStatus_t status =
153  cusolverSpZcsrcholBufferInfo(handle, size, nnz, desc,
154  cu_values, rowPtr, colIdx, chol_info,
155  internalDataInBytes, workspaceInBytes);
156  return status;
157  }
158 
159  static cusolverStatus_t numeric(
160  cusolverSpHandle_t handle,
161  int size,
162  int nnz,
163  cusparseMatDescr_t & desc,
164  const void * values,
165  const int * rowPtr,
166  const int * colIdx,
167  csrcholInfo_t & chol_info,
168  void * buffer)
169  {
170  typedef cuDoubleComplex scalar_t;
171  const scalar_t * cu_values =
172  reinterpret_cast<const scalar_t *>(values);
173  cusolverStatus_t status = cusolverSpZcsrcholFactor(
174  handle, size, nnz, desc, cu_values, rowPtr, colIdx, chol_info, buffer);
175  cudaDeviceSynchronize();
176  return status;
177  }
178 
179  static cusolverStatus_t solve(
180  cusolverSpHandle_t handle,
181  int size,
182  const void * b,
183  void * x,
184  csrcholInfo_t & chol_info,
185  void * buffer)
186  {
187  typedef cuDoubleComplex scalar_t;
188  const scalar_t * cu_b = reinterpret_cast<const scalar_t *>(b);
189  scalar_t * cu_x = reinterpret_cast<scalar_t *>(x);
190  cusolverStatus_t status = cusolverSpZcsrcholSolve(
191  handle, size, cu_b, cu_x, chol_info, buffer);
192  cudaDeviceSynchronize();
193  return status;
194  }
195  };
196 
197  template <>
198  struct FunctionMap<cuSOLVER,Kokkos::complex<float>>
199  {
200  static cusolverStatus_t bufferInfo(
201  cusolverSpHandle_t handle,
202  int size,
203  int nnz,
204  cusparseMatDescr_t & desc,
205  const void * values,
206  const int * rowPtr,
207  const int * colIdx,
208  csrcholInfo_t & chol_info,
209  size_t * internalDataInBytes,
210  size_t * workspaceInBytes)
211  {
212  typedef cuFloatComplex scalar_t;
213  const scalar_t * cu_values = reinterpret_cast<const scalar_t *>(values);
214  cusolverStatus_t status =
215  cusolverSpCcsrcholBufferInfo(handle, size, nnz, desc,
216  cu_values, rowPtr, colIdx, chol_info,
217  internalDataInBytes, workspaceInBytes);
218  return status;
219  }
220 
221  static cusolverStatus_t numeric(
222  cusolverSpHandle_t handle,
223  int size,
224  int nnz,
225  cusparseMatDescr_t & desc,
226  const void * values,
227  const int * rowPtr,
228  const int * colIdx,
229  csrcholInfo_t & chol_info,
230  void * buffer)
231  {
232  typedef cuFloatComplex scalar_t;
233  const scalar_t * cu_values = reinterpret_cast<const scalar_t *>(values);
234  cusolverStatus_t status = cusolverSpCcsrcholFactor(
235  handle, size, nnz, desc, cu_values, rowPtr, colIdx, chol_info, buffer);
236  cudaDeviceSynchronize();
237  return status;
238  }
239 
240  static cusolverStatus_t solve(
241  cusolverSpHandle_t handle,
242  int size,
243  const void * b,
244  void * x,
245  csrcholInfo_t & chol_info,
246  void * buffer)
247  {
248  typedef cuFloatComplex scalar_t;
249  const scalar_t * cu_b = reinterpret_cast<const scalar_t *>(b);
250  scalar_t * cu_x = reinterpret_cast<scalar_t *>(x);
251  cusolverStatus_t status = cusolverSpCcsrcholSolve(
252  handle, size, cu_b, cu_x, chol_info, buffer);
253  cudaDeviceSynchronize();
254  return status;
255  }
256  };
257 #endif
258 
259 } // end namespace Amesos2
260 
261 #endif // AMESOS2_CUSOLVER_FUNCTIONMAP_HPP
const int size
Definition: klu2_simple.cpp:50
Declaration of Function mapping class for Amesos2.