44 #ifndef AMESOS2_CUSOLVER_FUNCTIONMAP_HPP
45 #define AMESOS2_CUSOLVER_FUNCTIONMAP_HPP
48 #include "Amesos2_cuSOLVER_TypeMap.hpp"
51 #include <cusolverSp.h>
52 #include <cusolverDn.h>
54 #include <cusolverSp_LOWLEVEL_PREVIEW.h>
56 #ifdef HAVE_TEUCHOS_COMPLEX
57 #include <cuComplex.h>
// FunctionMap specialization for cuSOLVER with real double-precision scalars.
// Each static member forwards its arguments unchanged to the matching
// double-precision ("D") cuSOLVER sparse CSR Cholesky ("csrchol") routine
// and stores the returned cusolverStatus_t in a local `status`.
//
// NOTE(review): this listing omits several source lines (note the gaps in
// the embedded line numbers): the `template <>` header required by this
// explicit specialization, the declarations of `size`, `nnz`, `rowPtr`,
// `colIdx`, `buffer`, `b` and `x`, the braces and the return statements.
// Those names are taken from the call sites below; presumably each member
// returns `status` -- confirm against the complete file.
63 struct FunctionMap<cuSOLVER,double>
// bufferInfo: asks cusolverSpDcsrcholBufferInfo for the buffer sizes the
// csrchol factorization needs for this CSR matrix. `internalDataInBytes`
// and `workspaceInBytes` are passed straight through (per the cuSOLVER
// low-level Cholesky API they receive the required sizes).
65 static cusolverStatus_t bufferInfo(
66 cusolverSpHandle_t handle,
69 cusparseMatDescr_t & desc,
70 const double * values,
73 csrcholInfo_t & chol_info,
74 size_t * internalDataInBytes,
75 size_t * workspaceInBytes)
77 cusolverStatus_t status =
78 cusolverSpDcsrcholBufferInfo(handle, size, nnz, desc, values,
79 rowPtr, colIdx, chol_info, internalDataInBytes, workspaceInBytes);
// numeric: numeric Cholesky factorization of the CSR matrix
// (values/rowPtr/colIdx) via cusolverSpDcsrcholFactor, with the factor held
// in `chol_info` and `buffer` as workspace.
// NOTE(review): `buffer` is presumably allocated from bufferInfo's
// workspaceInBytes -- confirm with the caller.
83 static cusolverStatus_t numeric(
84 cusolverSpHandle_t handle,
87 cusparseMatDescr_t & desc,
88 const double * values,
91 csrcholInfo_t & chol_info,
94 cusolverStatus_t status = cusolverSpDcsrcholFactor(
95 handle, size, nnz, desc, values, rowPtr, colIdx, chol_info, buffer);
// solve: solves with the factor stored in `chol_info` via
// cusolverSpDcsrcholSolve; presumably `b` is the right-hand side and `x`
// receives the solution (their declarations are not visible here).
99 static cusolverStatus_t solve(
100 cusolverSpHandle_t handle,
104 csrcholInfo_t & chol_info,
107 cusolverStatus_t status = cusolverSpDcsrcholSolve(
108 handle, size, b, x, chol_info, buffer);
// FunctionMap specialization for cuSOLVER with real single-precision scalars.
// Each static member forwards its arguments unchanged to the matching
// single-precision ("S") cuSOLVER sparse CSR Cholesky ("csrchol") routine
// and stores the returned cusolverStatus_t in a local `status`.
//
// NOTE(review): this listing omits several source lines (note the gaps in
// the embedded line numbers): the `template <>` header required by this
// explicit specialization, the declarations of `size`, `nnz`, `rowPtr`,
// `colIdx`, `buffer`, `b` and `x`, the braces and the return statements.
// Those names are taken from the call sites below; presumably each member
// returns `status` -- confirm against the complete file.
114 struct FunctionMap<cuSOLVER,float>
// bufferInfo: asks cusolverSpScsrcholBufferInfo for the buffer sizes the
// csrchol factorization needs for this CSR matrix. `internalDataInBytes`
// and `workspaceInBytes` are passed straight through (per the cuSOLVER
// low-level Cholesky API they receive the required sizes).
116 static cusolverStatus_t bufferInfo(
117 cusolverSpHandle_t handle,
120 cusparseMatDescr_t & desc,
121 const float * values,
124 csrcholInfo_t & chol_info,
125 size_t * internalDataInBytes,
126 size_t * workspaceInBytes)
128 cusolverStatus_t status =
129 cusolverSpScsrcholBufferInfo(handle, size, nnz, desc, values,
130 rowPtr, colIdx, chol_info, internalDataInBytes, workspaceInBytes);
// numeric: numeric Cholesky factorization of the CSR matrix
// (values/rowPtr/colIdx) via cusolverSpScsrcholFactor, with the factor held
// in `chol_info` and `buffer` as workspace.
// NOTE(review): `buffer` is presumably allocated from bufferInfo's
// workspaceInBytes -- confirm with the caller.
134 static cusolverStatus_t numeric(
135 cusolverSpHandle_t handle,
138 cusparseMatDescr_t & desc,
139 const float * values,
142 csrcholInfo_t & chol_info,
145 cusolverStatus_t status = cusolverSpScsrcholFactor(
146 handle, size, nnz, desc, values, rowPtr, colIdx, chol_info, buffer);
// solve: solves with the factor stored in `chol_info` via
// cusolverSpScsrcholSolve; presumably `b` is the right-hand side and `x`
// receives the solution (their declarations are not visible here).
150 static cusolverStatus_t solve(
151 cusolverSpHandle_t handle,
155 csrcholInfo_t & chol_info,
158 cusolverStatus_t status = cusolverSpScsrcholSolve(
159 handle, size, b, x, chol_info, buffer);
164 #ifdef HAVE_TEUCHOS_COMPLEX
// FunctionMap specialization for cuSOLVER with Kokkos::complex<double>
// scalars. Each member reinterpret_casts the incoming complex arrays to
// cuDoubleComplex and forwards them to the matching double-complex ("Z")
// cuSOLVER sparse CSR Cholesky ("csrchol") routine, storing the returned
// cusolverStatus_t in a local `status`.
//
// NOTE(review): the casts assume Kokkos::complex<double> and cuDoubleComplex
// have identical size, alignment and layout (real part then imaginary part);
// nothing visible here static_asserts that assumption.
// NOTE(review): this listing omits several source lines (note the gaps in
// the embedded line numbers): the declarations of `size`, `nnz`, `values`,
// `rowPtr`, `colIdx`, `buffer`, `b` and `x`, the braces and the return
// statements. `values`, `b` and `x` are presumably Kokkos::complex<double>
// pointers, and each member presumably returns `status` -- confirm.
166 struct FunctionMap<cuSOLVER,Kokkos::complex<double>>
// bufferInfo: asks cusolverSpZcsrcholBufferInfo for the buffer sizes the
// csrchol factorization needs; `internalDataInBytes` and `workspaceInBytes`
// are passed straight through (per the cuSOLVER low-level Cholesky API they
// receive the required sizes).
168 static cusolverStatus_t bufferInfo(
169 cusolverSpHandle_t handle,
172 cusparseMatDescr_t & desc,
176 csrcholInfo_t & chol_info,
177 size_t * internalDataInBytes,
178 size_t * workspaceInBytes)
180 typedef cuDoubleComplex scalar_t;
181 const scalar_t * cu_values =
reinterpret_cast<const scalar_t *
>(values);
182 cusolverStatus_t status =
183 cusolverSpZcsrcholBufferInfo(handle, size, nnz, desc,
184 cu_values, rowPtr, colIdx, chol_info,
185 internalDataInBytes, workspaceInBytes);
// numeric: numeric Cholesky factorization via cusolverSpZcsrcholFactor on
// the cuDoubleComplex view of `values`, with the factor held in `chol_info`
// and `buffer` as workspace (presumably sized from bufferInfo -- confirm).
189 static cusolverStatus_t numeric(
190 cusolverSpHandle_t handle,
193 cusparseMatDescr_t & desc,
197 csrcholInfo_t & chol_info,
200 typedef cuDoubleComplex scalar_t;
201 const scalar_t * cu_values =
202 reinterpret_cast<const scalar_t *
>(values);
203 cusolverStatus_t status = cusolverSpZcsrcholFactor(
204 handle, size, nnz, desc, cu_values, rowPtr, colIdx, chol_info, buffer);
// solve: solves with the factor in `chol_info` via cusolverSpZcsrcholSolve.
// `b` is viewed as const input and `x` as writable output, both as
// cuDoubleComplex.
208 static cusolverStatus_t solve(
209 cusolverSpHandle_t handle,
213 csrcholInfo_t & chol_info,
216 typedef cuDoubleComplex scalar_t;
217 const scalar_t * cu_b =
reinterpret_cast<const scalar_t *
>(b);
218 scalar_t * cu_x =
reinterpret_cast<scalar_t *
>(x);
219 cusolverStatus_t status = cusolverSpZcsrcholSolve(
220 handle, size, cu_b, cu_x, chol_info, buffer);
// FunctionMap specialization for cuSOLVER with Kokkos::complex<float>
// scalars. Each member reinterpret_casts the incoming complex arrays to
// cuFloatComplex and forwards them to the matching single-complex ("C")
// cuSOLVER sparse CSR Cholesky ("csrchol") routine, storing the returned
// cusolverStatus_t in a local `status`.
//
// NOTE(review): the casts assume Kokkos::complex<float> and cuFloatComplex
// have identical size, alignment and layout (real part then imaginary part);
// nothing visible here static_asserts that assumption.
// NOTE(review): this listing omits several source lines (note the gaps in
// the embedded line numbers): the declarations of `size`, `nnz`, `values`,
// `rowPtr`, `colIdx`, `buffer`, `b` and `x`, the braces and the return
// statements. `values`, `b` and `x` are presumably Kokkos::complex<float>
// pointers, and each member presumably returns `status` -- confirm.
226 struct FunctionMap<cuSOLVER,Kokkos::complex<float>>
// bufferInfo: asks cusolverSpCcsrcholBufferInfo for the buffer sizes the
// csrchol factorization needs; `internalDataInBytes` and `workspaceInBytes`
// are passed straight through (per the cuSOLVER low-level Cholesky API they
// receive the required sizes).
228 static cusolverStatus_t bufferInfo(
229 cusolverSpHandle_t handle,
232 cusparseMatDescr_t & desc,
236 csrcholInfo_t & chol_info,
237 size_t * internalDataInBytes,
238 size_t * workspaceInBytes)
240 typedef cuFloatComplex scalar_t;
241 const scalar_t * cu_values =
reinterpret_cast<const scalar_t *
>(values);
242 cusolverStatus_t status =
243 cusolverSpCcsrcholBufferInfo(handle, size, nnz, desc,
244 cu_values, rowPtr, colIdx, chol_info,
245 internalDataInBytes, workspaceInBytes);
// numeric: numeric Cholesky factorization via cusolverSpCcsrcholFactor on
// the cuFloatComplex view of `values`, with the factor held in `chol_info`
// and `buffer` as workspace (presumably sized from bufferInfo -- confirm).
249 static cusolverStatus_t numeric(
250 cusolverSpHandle_t handle,
253 cusparseMatDescr_t & desc,
257 csrcholInfo_t & chol_info,
260 typedef cuFloatComplex scalar_t;
261 const scalar_t * cu_values =
reinterpret_cast<const scalar_t *
>(values);
262 cusolverStatus_t status = cusolverSpCcsrcholFactor(
263 handle, size, nnz, desc, cu_values, rowPtr, colIdx, chol_info, buffer);
// solve: solves with the factor in `chol_info` via cusolverSpCcsrcholSolve.
// `b` is viewed as const input and `x` as writable output, both as
// cuFloatComplex.
267 static cusolverStatus_t solve(
268 cusolverSpHandle_t handle,
272 csrcholInfo_t & chol_info,
275 typedef cuFloatComplex scalar_t;
276 const scalar_t * cu_b =
reinterpret_cast<const scalar_t *
>(b);
277 scalar_t * cu_x =
reinterpret_cast<scalar_t *
>(x);
278 cusolverStatus_t status = cusolverSpCcsrcholSolve(
279 handle, size, cu_b, cu_x, chol_info, buffer);
287 #endif // AMESOS2_CUSOLVER_FUNCTIONMAP_HPP
Declaration of the FunctionMap class, which maps Amesos2 scalar types to the corresponding cuSOLVER routines.