42 #ifndef STOKHOS_OPENMP_MKL_CRSMATRIX_HPP
43 #define STOKHOS_OPENMP_MKL_CRSMATRIX_HPP
45 #include "Kokkos_Macros.hpp"
47 #if defined(KOKKOS_ENABLE_OPENMP) && defined(HAVE_STOKHOS_MKL)
49 #include "Kokkos_Core.hpp"
// GatherTranspose: functor that copies a selected subset of the columns of the
// 2-D LayoutLeft view `x` into the transposed 2-D LayoutLeft view `xt`:
//   xt(col,row) = x(row, indices[col])   for col in [0, indices.size())
// for every row in [0, xt.extent(1)).
//
// Template parameters ValueType / OrdinalType / ExecutionSpace are presumably
// aliased to value_type / ordinal_type / execution_space by typedefs that are
// not present in this chunk.
// NOTE(review): the `ncol` member declaration and the struct's closing brace
// are also not visible here; they are assumed to exist in the full file.
// NOTE(review): operator() is a plain host function (no
// KOKKOS_INLINE_FUNCTION) and reads a std::vector member, so this functor
// presumably only works for host execution spaces such as OpenMP (the file is
// guarded by KOKKOS_ENABLE_OPENMP).
59 template <
typename ValueType,
typename OrdinalType,
typename ExecutionSpace>
60 struct GatherTranspose {
64 typedef typename execution_space::size_type size_type;
65 typedef Kokkos::View< value_type** , Kokkos::LayoutLeft, execution_space > multi_vector_type ;
66 typedef Kokkos::View< value_type** , Kokkos::LayoutLeft, execution_space > trans_multi_vector_type ;
// Source view, read as m_x(row, column index).
68 const multi_vector_type m_x;
// Destination view, written as m_xt(col, row).
69 const trans_multi_vector_type m_xt;
// Copy of the caller's column-index list (held by value).
70 const std::vector<ordinal_type> m_indices;
// Stores the view handles, copies `indices`, and sets ncol = indices.size().
72 GatherTranspose(
const multi_vector_type & x,
73 const trans_multi_vector_type& xt,
74 const std::vector<ordinal_type> & indices) :
75 m_x(x), m_xt(xt), m_indices(indices), ncol(indices.size()) {}
// Work item for one row: gathers the selected entries of that row of x into
// column `row` of xt. No bounds checking is performed.
77 inline void operator()(
const size_type row )
const {
78 for (
size_t col=0; col<ncol; ++col)
79 m_xt(col,row) = m_x(row,m_indices[col]);
// Runs the gather with Kokkos::parallel_for over n = xt.extent(1) rows.
// NOTE(review): assumes xt.extent(0) >= indices.size() and
// x.extent(0) >= xt.extent(1); nothing here checks it.
82 static void apply(
const multi_vector_type & x,
83 const trans_multi_vector_type& xt,
84 const std::vector<ordinal_type> & indices) {
85 const size_t n = xt.extent(1);
86 Kokkos::parallel_for( n, GatherTranspose(x,xt,indices) );
// ScatterTranspose: inverse of GatherTranspose. It copies the transposed 2-D
// LayoutLeft view `xt` back into the selected columns of the 2-D LayoutLeft
// view `x`:
//   x(row, indices[col]) = xt(col,row)   for col in [0, indices.size())
// for every row in [0, xt.extent(1)). Columns of x not named in `indices`
// are left untouched.
//
// Template parameters are presumably aliased to value_type / ordinal_type /
// execution_space by typedefs that are not present in this chunk.
// NOTE(review): the `ncol` member declaration and the closing brace are also
// not visible here.
// NOTE(review): as with GatherTranspose, operator() is a host-only function
// reading a std::vector, so this presumably targets host execution spaces.
90 template <
typename ValueType,
typename OrdinalType,
typename ExecutionSpace>
91 struct ScatterTranspose {
95 typedef typename execution_space::size_type size_type;
96 typedef Kokkos::View< value_type** , Kokkos::LayoutLeft, execution_space > multi_vector_type ;
97 typedef Kokkos::View< value_type** , Kokkos::LayoutLeft, execution_space > trans_multi_vector_type ;
// Destination view. `const` applies to the View handle only, so its elements
// can still be written.
99 const multi_vector_type m_x;
// Source view, read as m_xt(col, row).
100 const trans_multi_vector_type m_xt;
// Copy of the caller's column-index list (held by value).
101 const std::vector<ordinal_type> m_indices;
// Stores the view handles, copies `indices`, and sets ncol = indices.size().
103 ScatterTranspose(
const multi_vector_type & x,
104 const trans_multi_vector_type& xt,
105 const std::vector<ordinal_type> & indices) :
106 m_x(x), m_xt(xt), m_indices(indices), ncol(indices.size()) {}
// Work item for one row: writes column `row` of xt into the selected entries
// of that row of x. No bounds checking is performed.
108 inline void operator()(
const size_type row )
const {
109 for (
size_t col=0; col<ncol; ++col)
110 m_x(row,m_indices[col]) = m_xt(col,row);
// Runs the scatter with Kokkos::parallel_for over n = xt.extent(1) rows.
// `x` is taken by const reference, but its data is modified through the
// copied View handle.
113 static void apply(
const multi_vector_type & x,
114 const trans_multi_vector_type& xt,
115 const std::vector<ordinal_type> & indices) {
116 const size_t n = xt.extent(1);
117 Kokkos::parallel_for( n, ScatterTranspose(x,xt,indices) );
// GatherVecTranspose: functor that packs a std::vector of 1-D views into the
// 2-D LayoutLeft view `xt`, one input view per leading index:
//   xt(col,row) = x[col](row)   for col in [0, x.size())
// for every row in [0, xt.extent(1)).
//
// Template parameters are presumably aliased to value_type / execution_space
// by typedefs that are not present in this chunk.
// NOTE(review): the `ncol` member declaration and the closing brace are also
// not visible here.
// NOTE(review): operator() is a host-only function indexing a std::vector, so
// this functor presumably targets host execution spaces such as OpenMP.
121 template <
typename ValueType,
typename ExecutionSpace>
122 struct GatherVecTranspose {
125 typedef typename execution_space::size_type size_type;
126 typedef Kokkos::View< value_type* , execution_space > vector_type ;
127 typedef Kokkos::View< value_type** , Kokkos::LayoutLeft, execution_space > trans_multi_vector_type ;
// Copy of the caller's vector of view handles (the views share data with the
// caller's views).
129 const std::vector<vector_type> m_x;
// Destination view, written as m_xt(col, row).
130 const trans_multi_vector_type m_xt;
// Stores the handles and sets ncol = x.size().
132 GatherVecTranspose(
const std::vector<vector_type> & x,
133 const trans_multi_vector_type& xt) :
134 m_x(x), m_xt(xt), ncol(x.size()) {}
// Work item for one row: copies entry `row` of every input view into column
// `row` of xt.
136 inline void operator()(
const size_type row )
const {
137 for (
size_t col=0; col<ncol; ++col)
138 m_xt(col,row) = m_x[col](row);
// Runs the gather with Kokkos::parallel_for over n = xt.extent(1) rows.
// NOTE(review): assumes every x[col] has at least xt.extent(1) entries and
// xt.extent(0) >= x.size(); nothing here checks it.
141 static void apply(
const std::vector<vector_type> & x,
142 const trans_multi_vector_type& xt) {
143 const size_t n = xt.extent(1);
144 Kokkos::parallel_for( n, GatherVecTranspose(x,xt) );
// ScatterVecTranspose: inverse of GatherVecTranspose. It unpacks the 2-D
// LayoutLeft view `xt` into a std::vector of 1-D views:
//   x[col](row) = xt(col,row)   for col in [0, x.size())
// for every row in [0, xt.extent(1)).
//
// Template parameters are presumably aliased to value_type / execution_space
// by typedefs that are not present in this chunk.
// NOTE(review): the `ncol` member declaration and the closing brace are also
// not visible here.
// NOTE(review): host-only operator() indexing a std::vector, so this
// presumably targets host execution spaces.
148 template <
typename ValueType,
typename ExecutionSpace>
149 struct ScatterVecTranspose {
152 typedef typename execution_space::size_type size_type;
153 typedef Kokkos::View< value_type* , execution_space > vector_type ;
154 typedef Kokkos::View< value_type** , Kokkos::LayoutLeft, execution_space > trans_multi_vector_type ;
// Copy of the caller's view handles. Writes go to the caller's data because
// View copies share storage.
156 const std::vector<vector_type> m_x;
// Source view, read as m_xt(col, row).
157 const trans_multi_vector_type m_xt;
// Stores the handles and sets ncol = x.size().
159 ScatterVecTranspose(
const std::vector<vector_type> & x,
160 const trans_multi_vector_type& xt) :
161 m_x(x), m_xt(xt), ncol(x.size()) {}
// Work item for one row: writes column `row` of xt into entry `row` of every
// output view.
163 inline void operator()(
const size_type row )
const {
164 for (
size_t col=0; col<ncol; ++col)
165 m_x[col](row) = m_xt(col,row);
// Runs the scatter with Kokkos::parallel_for over n = xt.extent(1) rows.
168 static void apply(
const std::vector<vector_type> & x,
169 const trans_multi_vector_type& xt) {
170 const size_t n = xt.extent(1);
171 Kokkos::parallel_for( n, ScatterVecTranspose(x,xt) );
// MKLMultiply: empty tag type. It is presumably passed as the trailing
// argument to select the MKL-backed multiply overloads below (their last
// parameter is cut off in this chunk). Confirm against the callers.
178 class MKLMultiply {};
// multiply (double, single vector): sparse matrix-vector product of the CRS
// matrix A with the 1-D view x, written into y via MKL's mkl_dcsrmv
// (y = alpha*op(A)*x + beta*y).
//
// NOTE(review): this chunk omits any parameter after `y`, the opening brace,
// the declarations of `trans`, `alpha` and `beta`, and the closing brace.
// They are assumed to exist in the full file.
180 void multiply(
const CrsMatrix< double , Kokkos::OpenMP >&
A,
181 const Kokkos::View< double* , Kokkos::OpenMP >& x,
182 Kokkos::View< double* , Kokkos::OpenMP >& y,
// Row count: the CRS row map stores n+1 offsets.
185 MKL_INT n = A.graph.row_map.extent(0) - 1 ;
186 double *A_values = A.values.data() ;
// NOTE(review): assumes the graph's entry type is MKL_INT so this pointer
// initialization compiles. Confirm against the CrsMatrix definition.
187 MKL_INT *col_indices = A.graph.entries.data() ;
// MKL's C interface takes non-const pointers, hence the const_cast. MKL
// treats the row-pointer arrays as input only.
188 MKL_INT *row_beg =
const_cast<MKL_INT*
>(A.graph.row_map.data()) ;
// Four-array CSR form: row i spans [row_beg[i], row_end[i]), so row_end is
// the row map shifted by one entry.
189 MKL_INT *row_end = row_beg+1;
// MKL matrix descriptor: [0]='G' general matrix, [3]='C' zero-based
// indexing. Entries [1] and [2] are ignored for general matrices, and [4]
// and [5] are unused.
190 char matdescra[6] = {
'G',
'x',
'N',
'C',
'x',
'x' };
195 double *x_values = x.data() ;
196 double *y_values = y.data() ;
// n is passed as both the row count and the column count, so A is treated
// as square.
198 mkl_dcsrmv(&trans, &n, &n, &alpha, matdescra, A_values, col_indices,
199 row_beg, row_end, x_values, &beta, y_values);
// multiply (float, single vector): single-precision counterpart of the
// double overload. It computes the CRS matrix-vector product A*x into y via
// MKL's mkl_scsrmv (y = alpha*op(A)*x + beta*y).
//
// NOTE(review): this chunk omits any parameter after `y`, the opening brace,
// the declarations of `trans`, `alpha` and `beta`, and the closing brace.
202 void multiply(
const CrsMatrix< float , Kokkos::OpenMP >& A,
203 const Kokkos::View< float* , Kokkos::OpenMP >& x,
204 Kokkos::View< float* , Kokkos::OpenMP >& y,
// Row count: the CRS row map stores n+1 offsets.
207 MKL_INT n = A.graph.row_map.extent(0) - 1 ;
208 float *A_values = A.values.data() ;
// NOTE(review): assumes the graph's entry type is MKL_INT.
209 MKL_INT *col_indices = A.graph.entries.data() ;
// const_cast only because MKL's C interface takes non-const pointers.
210 MKL_INT *row_beg =
const_cast<MKL_INT*
>(A.graph.row_map.data()) ;
// Four-array CSR form: row i spans [row_beg[i], row_end[i]).
211 MKL_INT *row_end = row_beg+1;
// MKL descriptor: general matrix ('G'), zero-based indexing ('C').
212 char matdescra[6] = {
'G',
'x',
'N',
'C',
'x',
'x' };
217 float *x_values = x.data() ;
218 float *y_values = y.data() ;
// n is passed for both dimensions, so A is treated as square.
220 mkl_scsrmv(&trans, &n, &n, &alpha, matdescra, A_values, col_indices,
221 row_beg, row_end, x_values, &beta, y_values);
// multiply (double, vector of views): applies the CRS matrix A to every 1-D
// view in `x` at once, writing the results into the matching views in `y`.
// It packs the inputs into a dense block, calls MKL's mkl_dcsrmm, then
// unpacks the result.
//
// NOTE(review): this chunk omits any parameter after `y`, the opening brace,
// the typedefs for execution_space and value_type, the declarations of
// `trans`, `alpha` and `beta`, and the closing brace.
224 void multiply(
const CrsMatrix< double , Kokkos::OpenMP >& A,
225 const std::vector< Kokkos::View< double* , Kokkos::OpenMP > >& x,
226 std::vector< Kokkos::View< double* , Kokkos::OpenMP > >& y,
231 typedef Kokkos::View< double** , Kokkos::LayoutLeft, execution_space > trans_multi_vector_type ;
// Row count: the CRS row map stores n+1 offsets.
233 MKL_INT n = A.graph.row_map.extent(0) - 1 ;
234 double *A_values = A.values.data() ;
// NOTE(review): assumes the graph's entry type is MKL_INT.
235 MKL_INT *col_indices = A.graph.entries.data() ;
// const_cast only because MKL's C interface takes non-const pointers.
236 MKL_INT *row_beg =
const_cast<MKL_INT*
>(A.graph.row_map.data()) ;
// Four-array CSR form: row i spans [row_beg[i], row_end[i]).
237 MKL_INT *row_end = row_beg+1;
// MKL descriptor: general matrix ('G'), zero-based indexing ('C').
238 char matdescra[6] = {
'G',
'x',
'N',
'C',
'x',
'x' };
// One packed column per input vector.
244 MKL_INT ncol = x.size();
// xx/yy are (ncol x n) LayoutLeft, so element (col,row) sits at
// col + row*ncol. In memory that is a row-major n x ncol dense block with
// leading dimension ncol, which is the layout mkl_?csrmm uses for B and C
// with zero-based indexing.
// NOTE(review): both buffers are allocated on every call. Kokkos
// zero-initializes labeled allocations by default, so the beta*C term acts
// on a zeroed yy, not on the incoming y. Any previous contents of y are
// discarded if beta != 0 (beta is declared outside this chunk).
245 trans_multi_vector_type xx(
"xx" , ncol , n );
246 trans_multi_vector_type yy(
"yy" , ncol , n );
247 Impl::GatherVecTranspose<value_type,execution_space>::apply(x,xx);
248 double *x_values = xx.data() ;
249 double *y_values = yy.data() ;
// C(n x ncol) = alpha*op(A)(n x n)*B(n x ncol) + beta*C, with ldb = ldc = ncol.
// NOTE(review): no Kokkos::fence() separates the gather, the MKL call and the
// scatter. This relies on the OpenMP backend finishing parallel_for before it
// returns.
252 mkl_dcsrmm(&trans, &n, &ncol, &n, &alpha, matdescra, A_values, col_indices,
253 row_beg, row_end, x_values, &ncol, &beta, y_values, &ncol);
256 Impl::ScatterVecTranspose<value_type,execution_space>::apply(y,yy);
// multiply (float, vector of views): single-precision counterpart of the
// double overload. It packs every view of `x` into a dense block, applies the
// CRS matrix A with MKL's mkl_scsrmm, and unpacks the result into `y`.
//
// NOTE(review): this chunk omits any parameter after `y`, the opening brace,
// the typedefs for execution_space and value_type, the declarations of
// `trans`, `alpha` and `beta`, and the closing brace.
259 void multiply(
const CrsMatrix< float , Kokkos::OpenMP >& A,
260 const std::vector< Kokkos::View< float* , Kokkos::OpenMP > >& x,
261 std::vector< Kokkos::View< float* , Kokkos::OpenMP > >& y,
266 typedef Kokkos::View< float** , Kokkos::LayoutLeft, execution_space > trans_multi_vector_type ;
// Row count: the CRS row map stores n+1 offsets.
268 MKL_INT n = A.graph.row_map.extent(0) - 1 ;
269 float *A_values = A.values.data() ;
// NOTE(review): assumes the graph's entry type is MKL_INT.
270 MKL_INT *col_indices = A.graph.entries.data() ;
// const_cast only because MKL's C interface takes non-const pointers.
271 MKL_INT *row_beg =
const_cast<MKL_INT*
>(A.graph.row_map.data()) ;
// Four-array CSR form: row i spans [row_beg[i], row_end[i]).
272 MKL_INT *row_end = row_beg+1;
// MKL descriptor: general matrix ('G'), zero-based indexing ('C').
273 char matdescra[6] = {
'G',
'x',
'N',
'C',
'x',
'x' };
// One packed column per input vector.
279 MKL_INT ncol = x.size();
// (ncol x n) LayoutLeft == row-major n x ncol with leading dimension ncol,
// as mkl_?csrmm expects with zero-based indexing.
// NOTE(review): allocated on every call. yy starts zeroed, so any non-zero
// beta (declared outside this chunk) scales zeros rather than the incoming y.
280 trans_multi_vector_type xx(
"xx" , ncol , n );
281 trans_multi_vector_type yy(
"yy" , ncol , n );
282 Impl::GatherVecTranspose<value_type,execution_space>::apply(x,xx);
283 float *x_values = xx.data() ;
284 float *y_values = yy.data() ;
// C(n x ncol) = alpha*op(A)*B + beta*C, with ldb = ldc = ncol.
// NOTE(review): no explicit Kokkos::fence() around the MKL call. This relies
// on the OpenMP backend completing parallel_for synchronously.
287 mkl_scsrmm(&trans, &n, &ncol, &n, &alpha, matdescra, A_values, col_indices,
288 row_beg, row_end, x_values, &ncol, &beta, y_values, &ncol);
291 Impl::ScatterVecTranspose<value_type,execution_space>::apply(y,yy);
// multiply (double, 2-D view with column indices): applies the CRS matrix A
// to the columns of `x` listed in `indices` and writes the results into the
// same columns of `y`. Other columns of y are not written. It uses the
// GatherTranspose / mkl_dcsrmm / ScatterTranspose sequence.
//
// NOTE(review): the line holding the function name (presumably
// `void multiply(`) is missing between the template header and the first
// parameter. Also absent from this chunk: any parameter after `indices`, the
// opening brace, the typedefs for execution_space and value_type, the
// declarations of `trans`, `alpha` and `beta`, and the closing brace.
294 template <
typename ordinal_type>
296 const CrsMatrix< double , Kokkos::OpenMP >& A,
297 const Kokkos::View< double** , Kokkos::LayoutLeft, Kokkos::OpenMP >& x,
298 Kokkos::View< double** , Kokkos::LayoutLeft, Kokkos::OpenMP >& y,
299 const std::vector<ordinal_type>& indices,
304 typedef Kokkos::View< double** , Kokkos::LayoutLeft, execution_space > trans_multi_vector_type ;
// Row count: the CRS row map stores n+1 offsets.
306 MKL_INT n = A.graph.row_map.extent(0) - 1 ;
307 double *A_values = A.values.data() ;
// NOTE(review): assumes the graph's entry type is MKL_INT.
308 MKL_INT *col_indices = A.graph.entries.data() ;
// const_cast only because MKL's C interface takes non-const pointers.
309 MKL_INT *row_beg =
const_cast<MKL_INT*
>(A.graph.row_map.data()) ;
// Four-array CSR form: row i spans [row_beg[i], row_end[i]).
310 MKL_INT *row_end = row_beg+1;
// MKL descriptor: general matrix ('G'), zero-based indexing ('C').
311 char matdescra[6] = {
'G',
'x',
'N',
'C',
'x',
'x' };
// Number of selected columns. size_t is narrowed to MKL_INT here.
317 MKL_INT ncol = indices.size();
// (ncol x n) LayoutLeft buffers == row-major n x ncol dense blocks with
// leading dimension ncol, the layout mkl_?csrmm expects with zero-based
// indexing.
// NOTE(review): allocated on every call. yy starts zeroed, so a non-zero beta
// would not blend in the previous contents of y's selected columns.
318 trans_multi_vector_type xx(
"xx" , ncol , n );
319 trans_multi_vector_type yy(
"yy" , ncol , n );
320 Impl::GatherTranspose<value_type,ordinal_type,execution_space>::apply(x,xx,indices);
321 double *x_values = xx.data() ;
322 double *y_values = yy.data() ;
// C(n x ncol) = alpha*op(A)*B + beta*C, with ldb = ldc = ncol.
// NOTE(review): no explicit Kokkos::fence() around the MKL call. This relies
// on the OpenMP backend completing parallel_for synchronously.
325 mkl_dcsrmm(&trans, &n, &ncol, &n, &alpha, matdescra, A_values, col_indices,
326 row_beg, row_end, x_values, &ncol, &beta, y_values, &ncol);
329 Impl::ScatterTranspose<value_type,ordinal_type,execution_space>::apply(y,yy,indices);
// multiply (float, 2-D view with column indices): single-precision
// counterpart of the double overload. It applies the CRS matrix A to the
// columns of `x` named by `indices` via mkl_scsrmm and writes the results
// into the same columns of `y`. Other columns of y are not written.
//
// NOTE(review): the line holding the function name (presumably
// `void multiply(`) is missing between the template header and the first
// parameter. Also absent from this chunk: any parameter after `indices`, the
// opening brace, the typedefs for execution_space and value_type, the
// declarations of `trans`, `alpha` and `beta`, and the closing brace.
332 template <
typename ordinal_type>
334 const CrsMatrix< float , Kokkos::OpenMP >& A,
335 const Kokkos::View< float** , Kokkos::LayoutLeft, Kokkos::OpenMP >& x,
336 Kokkos::View< float** , Kokkos::LayoutLeft, Kokkos::OpenMP >& y,
337 const std::vector<ordinal_type>& indices,
342 typedef Kokkos::View< float** , Kokkos::LayoutLeft, execution_space > trans_multi_vector_type ;
// Row count: the CRS row map stores n+1 offsets.
344 MKL_INT n = A.graph.row_map.extent(0) - 1 ;
345 float *A_values = A.values.data() ;
// NOTE(review): assumes the graph's entry type is MKL_INT.
346 MKL_INT *col_indices = A.graph.entries.data() ;
// const_cast only because MKL's C interface takes non-const pointers.
347 MKL_INT *row_beg =
const_cast<MKL_INT*
>(A.graph.row_map.data()) ;
// Four-array CSR form: row i spans [row_beg[i], row_end[i]).
348 MKL_INT *row_end = row_beg+1;
// MKL descriptor: general matrix ('G'), zero-based indexing ('C').
349 char matdescra[6] = {
'G',
'x',
'N',
'C',
'x',
'x' };
// Number of selected columns. size_t is narrowed to MKL_INT here.
355 MKL_INT ncol = indices.size();
// (ncol x n) LayoutLeft buffers == row-major n x ncol dense blocks with
// leading dimension ncol.
// NOTE(review): allocated on every call. yy starts zeroed, so a non-zero beta
// would not blend in the previous contents of y's selected columns.
356 trans_multi_vector_type xx(
"xx" , ncol , n );
357 trans_multi_vector_type yy(
"yy" , ncol , n );
358 Impl::GatherTranspose<value_type,ordinal_type,execution_space>::apply(x,xx,indices);
359 float *x_values = xx.data() ;
360 float *y_values = yy.data() ;
// C(n x ncol) = alpha*op(A)*B + beta*C, with ldb = ldc = ncol.
// NOTE(review): no explicit Kokkos::fence() around the MKL call. This relies
// on the OpenMP backend completing parallel_for synchronously.
363 mkl_scsrmm(&trans, &n, &ncol, &n, &alpha, matdescra, A_values, col_indices,
364 row_beg, row_end, x_values, &ncol, &beta, y_values, &ncol);
367 Impl::ScatterTranspose<value_type,ordinal_type,execution_space>::apply(y,yy,indices);
Kokkos::DefaultExecutionSpace execution_space
void multiply(const CrsMatrix< MatrixValue, Device, Layout > &A, const InputMultiVectorType &x, OutputMultiVectorType &y, const std::vector< OrdinalType > &col_indices, SingleColumnMultivectorMultiply)