10 #ifndef AMESOS2_CUSOLVER_FUNCTIONMAP_HPP
11 #define AMESOS2_CUSOLVER_FUNCTIONMAP_HPP
14 #include "Amesos2_cuSOLVER_TypeMap.hpp"
17 #include <cusolverSp.h>
18 #include <cusolverDn.h>
20 #include <cusolverSp_LOWLEVEL_PREVIEW.h>
22 #ifdef HAVE_TEUCHOS_COMPLEX
23 #include <cuComplex.h>
// FunctionMap specialization for cuSOLVER with double scalars: each static
// member forwards its arguments to the double-precision ("D") cuSOLVER sparse
// Cholesky routines cusolverSpDcsrcholBufferInfo / Factor / Solve.
// NOTE(review): this listing omits several lines of the definition (template
// header, braces, some parameter declarations such as size/nnz/rowPtr/colIdx/
// buffer/b/x, and the statements after each cuSOLVER call); the comments below
// describe only what is visible.
29 struct FunctionMap<cuSOLVER,double>
// Queries the buffer sizes needed to Cholesky-factor the CSR matrix given by
// values/rowPtr/colIdx (described by desc) via cusolverSpDcsrcholBufferInfo.
// cuSOLVER reports the internal-data and workspace sizes, in bytes, through
// internalDataInBytes and workspaceInBytes.
// NOTE(review): desc and chol_info are taken by non-const reference, while the
// cuSOLVER routines take them by value.
31 static cusolverStatus_t bufferInfo(
32 cusolverSpHandle_t handle,
35 cusparseMatDescr_t & desc,
36 const double * values,
39 csrcholInfo_t & chol_info,
40 size_t * internalDataInBytes,
41 size_t * workspaceInBytes)
// The cuSOLVER status is captured in `status`; presumably it is returned to the
// caller (the return statement is not visible in this listing).
43 cusolverStatus_t status =
44 cusolverSpDcsrcholBufferInfo(handle, size, nnz, desc, values,
45 rowPtr, colIdx, chol_info, internalDataInBytes, workspaceInBytes);
// Computes the numeric Cholesky factorization of the same CSR matrix with
// cusolverSpDcsrcholFactor. The factor is kept in chol_info, and buffer is the
// caller-supplied workspace (presumably sized from bufferInfo's workspaceInBytes;
// confirm against the caller).
49 static cusolverStatus_t numeric(
50 cusolverSpHandle_t handle,
53 cusparseMatDescr_t & desc,
54 const double * values,
57 csrcholInfo_t & chol_info,
60 cusolverStatus_t status = cusolverSpDcsrcholFactor(
61 handle, size, nnz, desc, values, rowPtr, colIdx, chol_info, buffer);
// Solves A*x = b using the factor stored in chol_info, via
// cusolverSpDcsrcholSolve. b is the right-hand side and x receives the solution.
65 static cusolverStatus_t solve(
66 cusolverSpHandle_t handle,
70 csrcholInfo_t & chol_info,
73 cusolverStatus_t status = cusolverSpDcsrcholSolve(
74 handle, size, b, x, chol_info, buffer);
// FunctionMap specialization for cuSOLVER with float scalars: each static
// member forwards its arguments to the single-precision ("S") cuSOLVER sparse
// Cholesky routines cusolverSpScsrcholBufferInfo / Factor / Solve.
// NOTE(review): this listing omits several lines of the definition (template
// header, braces, some parameter declarations such as size/nnz/values/rowPtr/
// colIdx/buffer/b/x, and the statements after each cuSOLVER call); the comments
// below describe only what is visible.
80 struct FunctionMap<cuSOLVER,float>
// Queries the buffer sizes needed to Cholesky-factor the CSR matrix given by
// values/rowPtr/colIdx (described by desc) via cusolverSpScsrcholBufferInfo.
// cuSOLVER reports the internal-data and workspace sizes, in bytes, through
// internalDataInBytes and workspaceInBytes.
// NOTE(review): desc and chol_info are taken by non-const reference, while the
// cuSOLVER routines take them by value.
82 static cusolverStatus_t bufferInfo(
83 cusolverSpHandle_t handle,
86 cusparseMatDescr_t & desc,
90 csrcholInfo_t & chol_info,
91 size_t * internalDataInBytes,
92 size_t * workspaceInBytes)
// The cuSOLVER status is captured in `status`; presumably it is returned to the
// caller (the return statement is not visible in this listing).
94 cusolverStatus_t status =
95 cusolverSpScsrcholBufferInfo(handle, size, nnz, desc, values,
96 rowPtr, colIdx, chol_info, internalDataInBytes, workspaceInBytes);
// Computes the numeric Cholesky factorization of the same CSR matrix with
// cusolverSpScsrcholFactor. The factor is kept in chol_info, and buffer is the
// caller-supplied workspace (presumably sized from bufferInfo's workspaceInBytes;
// confirm against the caller).
100 static cusolverStatus_t numeric(
101 cusolverSpHandle_t handle,
104 cusparseMatDescr_t & desc,
105 const float * values,
108 csrcholInfo_t & chol_info,
111 cusolverStatus_t status = cusolverSpScsrcholFactor(
112 handle, size, nnz, desc, values, rowPtr, colIdx, chol_info, buffer);
// Solves A*x = b using the factor stored in chol_info, via
// cusolverSpScsrcholSolve. b is the right-hand side and x receives the solution.
116 static cusolverStatus_t solve(
117 cusolverSpHandle_t handle,
121 csrcholInfo_t & chol_info,
124 cusolverStatus_t status = cusolverSpScsrcholSolve(
125 handle, size, b, x, chol_info, buffer);
130 #ifdef HAVE_TEUCHOS_COMPLEX
// FunctionMap specialization for cuSOLVER with Kokkos::complex<double> scalars.
// Each static member reinterprets its scalar pointers (values, b, x) as
// cuDoubleComplex and forwards to the double-complex ("Z") cuSOLVER sparse
// Cholesky routines cusolverSpZcsrcholBufferInfo / Factor / Solve.
// NOTE(review): the reinterpret_casts assume Kokkos::complex<double> and
// cuDoubleComplex share the same layout (real part then imaginary part, same
// size and alignment). Confirm this, e.g. with a static_assert on sizeof/alignof.
// NOTE(review): this listing omits several lines of the definition (template
// header, braces, most parameter declarations, and the statements after each
// cuSOLVER call); the comments below describe only what is visible.
132 struct FunctionMap<cuSOLVER,Kokkos::complex<double>>
// Queries the buffer sizes needed to Cholesky-factor the CSR matrix given by
// values/rowPtr/colIdx (described by desc) via cusolverSpZcsrcholBufferInfo.
// cuSOLVER reports the internal-data and workspace sizes, in bytes, through
// internalDataInBytes and workspaceInBytes.
// NOTE(review): desc and chol_info are taken by non-const reference, while the
// cuSOLVER routines take them by value.
134 static cusolverStatus_t bufferInfo(
135 cusolverSpHandle_t handle,
138 cusparseMatDescr_t & desc,
142 csrcholInfo_t & chol_info,
143 size_t * internalDataInBytes,
144 size_t * workspaceInBytes)
// View the matrix values as cuSOLVER's double-complex type.
146 typedef cuDoubleComplex scalar_t;
147 const scalar_t * cu_values =
reinterpret_cast<const scalar_t *
>(values);
// The cuSOLVER status is captured in `status`; presumably it is returned to the
// caller (the return statement is not visible in this listing).
148 cusolverStatus_t status =
149 cusolverSpZcsrcholBufferInfo(handle, size, nnz, desc,
150 cu_values, rowPtr, colIdx, chol_info,
151 internalDataInBytes, workspaceInBytes);
// Computes the numeric Cholesky factorization with cusolverSpZcsrcholFactor.
// The factor is kept in chol_info, and buffer is the caller-supplied workspace
// (presumably sized from bufferInfo's workspaceInBytes; confirm against the caller).
155 static cusolverStatus_t numeric(
156 cusolverSpHandle_t handle,
159 cusparseMatDescr_t & desc,
163 csrcholInfo_t & chol_info,
166 typedef cuDoubleComplex scalar_t;
167 const scalar_t * cu_values =
168 reinterpret_cast<const scalar_t *
>(values);
169 cusolverStatus_t status = cusolverSpZcsrcholFactor(
170 handle, size, nnz, desc, cu_values, rowPtr, colIdx, chol_info, buffer);
// Solves A*x = b using the factor stored in chol_info, via
// cusolverSpZcsrcholSolve. b (read-only) and x (written) are reinterpreted as
// cuDoubleComplex before the call.
174 static cusolverStatus_t solve(
175 cusolverSpHandle_t handle,
179 csrcholInfo_t & chol_info,
182 typedef cuDoubleComplex scalar_t;
183 const scalar_t * cu_b =
reinterpret_cast<const scalar_t *
>(b);
184 scalar_t * cu_x =
reinterpret_cast<scalar_t *
>(x);
185 cusolverStatus_t status = cusolverSpZcsrcholSolve(
186 handle, size, cu_b, cu_x, chol_info, buffer);
// FunctionMap specialization for cuSOLVER with Kokkos::complex<float> scalars.
// Each static member reinterprets its scalar pointers (values, b, x) as
// cuFloatComplex and forwards to the single-complex ("C") cuSOLVER sparse
// Cholesky routines cusolverSpCcsrcholBufferInfo / Factor / Solve.
// NOTE(review): the reinterpret_casts assume Kokkos::complex<float> and
// cuFloatComplex share the same layout (real part then imaginary part, same
// size and alignment). Confirm this, e.g. with a static_assert on sizeof/alignof.
// NOTE(review): this listing omits several lines of the definition (template
// header, braces, most parameter declarations, and the statements after each
// cuSOLVER call); the comments below describe only what is visible.
192 struct FunctionMap<cuSOLVER,Kokkos::complex<float>>
// Queries the buffer sizes needed to Cholesky-factor the CSR matrix given by
// values/rowPtr/colIdx (described by desc) via cusolverSpCcsrcholBufferInfo.
// cuSOLVER reports the internal-data and workspace sizes, in bytes, through
// internalDataInBytes and workspaceInBytes.
// NOTE(review): desc and chol_info are taken by non-const reference, while the
// cuSOLVER routines take them by value.
194 static cusolverStatus_t bufferInfo(
195 cusolverSpHandle_t handle,
198 cusparseMatDescr_t & desc,
202 csrcholInfo_t & chol_info,
203 size_t * internalDataInBytes,
204 size_t * workspaceInBytes)
// View the matrix values as cuSOLVER's single-complex type.
206 typedef cuFloatComplex scalar_t;
207 const scalar_t * cu_values =
reinterpret_cast<const scalar_t *
>(values);
// The cuSOLVER status is captured in `status`; presumably it is returned to the
// caller (the return statement is not visible in this listing).
208 cusolverStatus_t status =
209 cusolverSpCcsrcholBufferInfo(handle, size, nnz, desc,
210 cu_values, rowPtr, colIdx, chol_info,
211 internalDataInBytes, workspaceInBytes);
// Computes the numeric Cholesky factorization with cusolverSpCcsrcholFactor.
// The factor is kept in chol_info, and buffer is the caller-supplied workspace
// (presumably sized from bufferInfo's workspaceInBytes; confirm against the caller).
215 static cusolverStatus_t numeric(
216 cusolverSpHandle_t handle,
219 cusparseMatDescr_t & desc,
223 csrcholInfo_t & chol_info,
226 typedef cuFloatComplex scalar_t;
227 const scalar_t * cu_values =
reinterpret_cast<const scalar_t *
>(values);
228 cusolverStatus_t status = cusolverSpCcsrcholFactor(
229 handle, size, nnz, desc, cu_values, rowPtr, colIdx, chol_info, buffer);
// Solves A*x = b using the factor stored in chol_info, via
// cusolverSpCcsrcholSolve. b (read-only) and x (written) are reinterpreted as
// cuFloatComplex before the call.
233 static cusolverStatus_t solve(
234 cusolverSpHandle_t handle,
238 csrcholInfo_t & chol_info,
241 typedef cuFloatComplex scalar_t;
242 const scalar_t * cu_b =
reinterpret_cast<const scalar_t *
>(b);
243 scalar_t * cu_x =
reinterpret_cast<scalar_t *
>(x);
244 cusolverStatus_t status = cusolverSpCcsrcholSolve(
245 handle, size, cu_b, cu_x, chol_info, buffer);
253 #endif // AMESOS2_CUSOLVER_FUNCTIONMAP_HPP
const int size
Definition: klu2_simple.cpp:50
Declaration of Function mapping class for Amesos2.