10 #ifndef AMESOS2_CUSOLVER_FUNCTIONMAP_HPP
11 #define AMESOS2_CUSOLVER_FUNCTIONMAP_HPP
14 #include "Amesos2_cuSOLVER_TypeMap.hpp"
17 #include <cusolverSp.h>
18 #include <cusolverDn.h>
20 #include <cusolverSp_LOWLEVEL_PREVIEW.h>
22 #ifdef HAVE_TEUCHOS_COMPLEX
23 #include <cuComplex.h>
// FunctionMap specialization for the cuSOLVER tag with double scalars.
// Each static member forwards to the "D"-prefixed cusolverSp csrchol routine.
// NOTE(review): this listing carries leading line-number residue and several
// lines are not visible (braces, return statements, and the declarations of
// size, nnz, rowPtr, colIdx, buffer, b and x) — confirm against the real header.
29 struct FunctionMap<cuSOLVER,double>
// bufferInfo: passes its arguments straight to cusolverSpDcsrcholBufferInfo
// and stores the returned cusolverStatus_t in `status`.
31 static cusolverStatus_t bufferInfo(
32 cusolverSpHandle_t handle,
35 cusparseMatDescr_t & desc,
36 const double * values,
39 csrcholInfo_t & chol_info,
40 size_t * internalDataInBytes,
41 size_t * workspaceInBytes)
43 cusolverStatus_t status =
44 cusolverSpDcsrcholBufferInfo(handle, size, nnz, desc, values,
45 rowPtr, colIdx, chol_info, internalDataInBytes, workspaceInBytes);
// numeric: passes its arguments to cusolverSpDcsrcholFactor, then blocks on
// cudaDeviceSynchronize().
49 static cusolverStatus_t numeric(
50 cusolverSpHandle_t handle,
53 cusparseMatDescr_t & desc,
54 const double * values,
57 csrcholInfo_t & chol_info,
60 cusolverStatus_t status = cusolverSpDcsrcholFactor(
61 handle, size, nnz, desc, values, rowPtr, colIdx, chol_info, buffer);
// NOTE(review): the cudaError_t returned by cudaDeviceSynchronize() is not
// captured in the visible code, so a failure reported here would go unnoticed
// — confirm this is intended.
62 cudaDeviceSynchronize();
// solve: passes handle, size, b, x, chol_info and buffer to
// cusolverSpDcsrcholSolve, then blocks on cudaDeviceSynchronize().
66 static cusolverStatus_t solve(
67 cusolverSpHandle_t handle,
71 csrcholInfo_t & chol_info,
74 cusolverStatus_t status = cusolverSpDcsrcholSolve(
75 handle, size, b, x, chol_info, buffer);
// NOTE(review): return value of cudaDeviceSynchronize() is discarded here as well.
76 cudaDeviceSynchronize();
// FunctionMap specialization for the cuSOLVER tag with float scalars.
// Each static member forwards to the "S"-prefixed cusolverSp csrchol routine.
// NOTE(review): this listing carries leading line-number residue and several
// lines are not visible (braces, return statements, and the declarations of
// size, nnz, rowPtr, colIdx, buffer, b, x, plus `values` in bufferInfo) —
// confirm against the real header.
82 struct FunctionMap<cuSOLVER,float>
// bufferInfo: passes its arguments straight to cusolverSpScsrcholBufferInfo
// and stores the returned cusolverStatus_t in `status`.
84 static cusolverStatus_t bufferInfo(
85 cusolverSpHandle_t handle,
88 cusparseMatDescr_t & desc,
92 csrcholInfo_t & chol_info,
93 size_t * internalDataInBytes,
94 size_t * workspaceInBytes)
96 cusolverStatus_t status =
97 cusolverSpScsrcholBufferInfo(handle, size, nnz, desc, values,
98 rowPtr, colIdx, chol_info, internalDataInBytes, workspaceInBytes);
// numeric: passes its arguments to cusolverSpScsrcholFactor, then blocks on
// cudaDeviceSynchronize().
102 static cusolverStatus_t numeric(
103 cusolverSpHandle_t handle,
106 cusparseMatDescr_t & desc,
107 const float * values,
110 csrcholInfo_t & chol_info,
113 cusolverStatus_t status = cusolverSpScsrcholFactor(
114 handle, size, nnz, desc, values, rowPtr, colIdx, chol_info, buffer);
// NOTE(review): the cudaError_t returned by cudaDeviceSynchronize() is not
// captured in the visible code — confirm this is intended.
115 cudaDeviceSynchronize();
// solve: passes handle, size, b, x, chol_info and buffer to
// cusolverSpScsrcholSolve, then blocks on cudaDeviceSynchronize().
119 static cusolverStatus_t solve(
120 cusolverSpHandle_t handle,
124 csrcholInfo_t & chol_info,
127 cusolverStatus_t status = cusolverSpScsrcholSolve(
128 handle, size, b, x, chol_info, buffer);
// NOTE(review): return value of cudaDeviceSynchronize() is discarded here as well.
129 cudaDeviceSynchronize();
134 #ifdef HAVE_TEUCHOS_COMPLEX
// FunctionMap specialization for the cuSOLVER tag with Kokkos::complex<double>
// scalars. Each static member reinterprets its value pointers as
// cuDoubleComplex and forwards to the "Z"-prefixed cusolverSp csrchol routine.
// NOTE(review): this listing carries leading line-number residue and several
// lines are not visible (braces, return statements, and the declarations of
// size, nnz, values, rowPtr, colIdx, buffer, b and x) — the declared pointer
// types of values/b/x cannot be determined from here; confirm against the
// real header.
136 struct FunctionMap<cuSOLVER,Kokkos::complex<double>>
// bufferInfo: casts `values` to const cuDoubleComplex* and passes it, with the
// remaining arguments, to cusolverSpZcsrcholBufferInfo.
138 static cusolverStatus_t bufferInfo(
139 cusolverSpHandle_t handle,
142 cusparseMatDescr_t & desc,
146 csrcholInfo_t & chol_info,
147 size_t * internalDataInBytes,
148 size_t * workspaceInBytes)
// scalar_t is a local alias for the cuSOLVER-side complex type.
// NOTE(review): the reinterpret_cast assumes the caller's data is
// layout-compatible with cuDoubleComplex — confirm.
150 typedef cuDoubleComplex scalar_t;
151 const scalar_t * cu_values =
reinterpret_cast<const scalar_t *
>(values);
152 cusolverStatus_t status =
153 cusolverSpZcsrcholBufferInfo(handle, size, nnz, desc,
154 cu_values, rowPtr, colIdx, chol_info,
155 internalDataInBytes, workspaceInBytes);
// numeric: casts `values` to const cuDoubleComplex*, calls
// cusolverSpZcsrcholFactor, then blocks on cudaDeviceSynchronize().
159 static cusolverStatus_t numeric(
160 cusolverSpHandle_t handle,
163 cusparseMatDescr_t & desc,
167 csrcholInfo_t & chol_info,
170 typedef cuDoubleComplex scalar_t;
171 const scalar_t * cu_values =
172 reinterpret_cast<const scalar_t *
>(values);
173 cusolverStatus_t status = cusolverSpZcsrcholFactor(
174 handle, size, nnz, desc, cu_values, rowPtr, colIdx, chol_info, buffer);
// NOTE(review): the cudaError_t returned by cudaDeviceSynchronize() is not
// captured in the visible code — confirm this is intended.
175 cudaDeviceSynchronize();
// solve: casts `b` (const) and `x` (mutable) to cuDoubleComplex pointers,
// calls cusolverSpZcsrcholSolve, then blocks on cudaDeviceSynchronize().
179 static cusolverStatus_t solve(
180 cusolverSpHandle_t handle,
184 csrcholInfo_t & chol_info,
187 typedef cuDoubleComplex scalar_t;
188 const scalar_t * cu_b =
reinterpret_cast<const scalar_t *
>(b);
189 scalar_t * cu_x =
reinterpret_cast<scalar_t *
>(x);
190 cusolverStatus_t status = cusolverSpZcsrcholSolve(
191 handle, size, cu_b, cu_x, chol_info, buffer);
// NOTE(review): return value of cudaDeviceSynchronize() is discarded here as well.
192 cudaDeviceSynchronize();
// FunctionMap specialization for the cuSOLVER tag with Kokkos::complex<float>
// scalars. Each static member reinterprets its value pointers as
// cuFloatComplex and forwards to the "C"-prefixed cusolverSp csrchol routine.
// NOTE(review): this listing carries leading line-number residue and several
// lines are not visible (braces, return statements, and the declarations of
// size, nnz, values, rowPtr, colIdx, buffer, b and x) — the declared pointer
// types of values/b/x cannot be determined from here; confirm against the
// real header.
198 struct FunctionMap<cuSOLVER,Kokkos::complex<float>>
// bufferInfo: casts `values` to const cuFloatComplex* and passes it, with the
// remaining arguments, to cusolverSpCcsrcholBufferInfo.
200 static cusolverStatus_t bufferInfo(
201 cusolverSpHandle_t handle,
204 cusparseMatDescr_t & desc,
208 csrcholInfo_t & chol_info,
209 size_t * internalDataInBytes,
210 size_t * workspaceInBytes)
// scalar_t is a local alias for the cuSOLVER-side complex type.
// NOTE(review): the reinterpret_cast assumes the caller's data is
// layout-compatible with cuFloatComplex — confirm.
212 typedef cuFloatComplex scalar_t;
213 const scalar_t * cu_values =
reinterpret_cast<const scalar_t *
>(values);
214 cusolverStatus_t status =
215 cusolverSpCcsrcholBufferInfo(handle, size, nnz, desc,
216 cu_values, rowPtr, colIdx, chol_info,
217 internalDataInBytes, workspaceInBytes);
// numeric: casts `values` to const cuFloatComplex*, calls
// cusolverSpCcsrcholFactor, then blocks on cudaDeviceSynchronize().
221 static cusolverStatus_t numeric(
222 cusolverSpHandle_t handle,
225 cusparseMatDescr_t & desc,
229 csrcholInfo_t & chol_info,
232 typedef cuFloatComplex scalar_t;
233 const scalar_t * cu_values =
reinterpret_cast<const scalar_t *
>(values);
234 cusolverStatus_t status = cusolverSpCcsrcholFactor(
235 handle, size, nnz, desc, cu_values, rowPtr, colIdx, chol_info, buffer);
// NOTE(review): the cudaError_t returned by cudaDeviceSynchronize() is not
// captured in the visible code — confirm this is intended.
236 cudaDeviceSynchronize();
// solve: casts `b` (const) and `x` (mutable) to cuFloatComplex pointers,
// calls cusolverSpCcsrcholSolve, then blocks on cudaDeviceSynchronize().
240 static cusolverStatus_t solve(
241 cusolverSpHandle_t handle,
245 csrcholInfo_t & chol_info,
248 typedef cuFloatComplex scalar_t;
249 const scalar_t * cu_b =
reinterpret_cast<const scalar_t *
>(b);
250 scalar_t * cu_x =
reinterpret_cast<scalar_t *
>(x);
251 cusolverStatus_t status = cusolverSpCcsrcholSolve(
252 handle, size, cu_b, cu_x, chol_info, buffer);
// NOTE(review): return value of cudaDeviceSynchronize() is discarded here as well.
253 cudaDeviceSynchronize();
261 #endif // AMESOS2_CUSOLVER_FUNCTIONMAP_HPP
const int size
Definition: klu2_simple.cpp:50
Declaration of Function mapping class for Amesos2.