10 #ifndef STOKHOS_OPENMP_MKL_CRSMATRIX_HPP
11 #define STOKHOS_OPENMP_MKL_CRSMATRIX_HPP
13 #include "Kokkos_Macros.hpp"
15 #if defined(KOKKOS_ENABLE_OPENMP) && defined(HAVE_STOKHOS_MKL)
17 #include "Kokkos_Core.hpp"
// Functor that gathers selected columns of a 2-D view into a transposed 2-D
// view: for a given row, operator() copies m_x(row, m_indices[col]) into
// m_xt(col, row) for every col in [0, ncol).
// NOTE(review): the declarations of value_type, ordinal_type, execution_space
// and the ncol member, as well as the closing braces, are not visible in this
// listing -- confirm against the full header.
27 template <
typename ValueType,
typename OrdinalType,
typename ExecutionSpace>
28 struct GatherTranspose {
32 typedef typename execution_space::size_type size_type;
// Both aliases below name the same LayoutLeft rank-2 View type.
33 typedef Kokkos::View< value_type** , Kokkos::LayoutLeft, execution_space > multi_vector_type ;
34 typedef Kokkos::View< value_type** , Kokkos::LayoutLeft, execution_space > trans_multi_vector_type ;
// m_x is read from, m_xt is written to. The index list is held by value, so
// constructing the functor copies the std::vector.
36 const multi_vector_type m_x;
37 const trans_multi_vector_type m_xt;
38 const std::vector<ordinal_type> m_indices;
// Stores the two views and the index list; ncol is set to indices.size().
40 GatherTranspose(
const multi_vector_type & x,
41 const trans_multi_vector_type& xt,
42 const std::vector<ordinal_type> & indices) :
43 m_x(x), m_xt(xt), m_indices(indices), ncol(indices.size()) {}
// Per-row kernel: column `col` of the output takes its value from column
// m_indices[col] of the input.
45 inline void operator()(
const size_type row )
const {
46 for (
size_t col=0; col<ncol; ++col)
47 m_xt(col,row) = m_x(row,m_indices[col]);
// Convenience launcher: runs the functor through Kokkos::parallel_for over
// n = xt.extent(1) row indices.
50 static void apply(
const multi_vector_type & x,
51 const trans_multi_vector_type& xt,
52 const std::vector<ordinal_type> & indices) {
53 const size_t n = xt.extent(1);
54 Kokkos::parallel_for( n, GatherTranspose(x,xt,indices) );
// Functor that scatters a transposed 2-D view back into selected columns of a
// 2-D view: for a given row, operator() copies m_xt(col, row) into
// m_x(row, m_indices[col]) for every col in [0, ncol).
// NOTE(review): the declarations of value_type, ordinal_type, execution_space
// and the ncol member, as well as the closing braces, are not visible in this
// listing -- confirm against the full header.
58 template <
typename ValueType,
typename OrdinalType,
typename ExecutionSpace>
59 struct ScatterTranspose {
63 typedef typename execution_space::size_type size_type;
// Both aliases below name the same LayoutLeft rank-2 View type.
64 typedef Kokkos::View< value_type** , Kokkos::LayoutLeft, execution_space > multi_vector_type ;
65 typedef Kokkos::View< value_type** , Kokkos::LayoutLeft, execution_space > trans_multi_vector_type ;
// m_x is written to, m_xt is read from. The index list is held by value, so
// constructing the functor copies the std::vector.
67 const multi_vector_type m_x;
68 const trans_multi_vector_type m_xt;
69 const std::vector<ordinal_type> m_indices;
// Stores the two views and the index list; ncol is set to indices.size().
71 ScatterTranspose(
const multi_vector_type & x,
72 const trans_multi_vector_type& xt,
73 const std::vector<ordinal_type> & indices) :
74 m_x(x), m_xt(xt), m_indices(indices), ncol(indices.size()) {}
// Per-row kernel: column m_indices[col] of the destination receives the value
// held in row `col` of the transposed source.
76 inline void operator()(
const size_type row )
const {
77 for (
size_t col=0; col<ncol; ++col)
78 m_x(row,m_indices[col]) = m_xt(col,row);
// Convenience launcher: runs the functor through Kokkos::parallel_for over
// n = xt.extent(1) row indices.
81 static void apply(
const multi_vector_type & x,
82 const trans_multi_vector_type& xt,
83 const std::vector<ordinal_type> & indices) {
84 const size_t n = xt.extent(1);
85 Kokkos::parallel_for( n, ScatterTranspose(x,xt,indices) );
// Functor that packs a std::vector of rank-1 views into one rank-2 view:
// for a given row, operator() copies m_x[col](row) into m_xt(col, row) for
// every col in [0, ncol).
// NOTE(review): the declarations of value_type, execution_space and the ncol
// member, as well as the closing braces, are not visible in this listing --
// confirm against the full header.
89 template <
typename ValueType,
typename ExecutionSpace>
90 struct GatherVecTranspose {
93 typedef typename execution_space::size_type size_type;
94 typedef Kokkos::View< value_type* , execution_space > vector_type ;
95 typedef Kokkos::View< value_type** , Kokkos::LayoutLeft, execution_space > trans_multi_vector_type ;
// The vector of input views is held by value, so constructing the functor
// copies the std::vector (and the view handles it contains).
97 const std::vector<vector_type> m_x;
98 const trans_multi_vector_type m_xt;
// Stores the inputs; ncol is set to the number of views in x.
100 GatherVecTranspose(
const std::vector<vector_type> & x,
101 const trans_multi_vector_type& xt) :
102 m_x(x), m_xt(xt), ncol(x.size()) {}
// Per-row kernel: entry `row` of each input view lands in row `col` of m_xt.
104 inline void operator()(
const size_type row )
const {
105 for (
size_t col=0; col<ncol; ++col)
106 m_xt(col,row) = m_x[col](row);
// Convenience launcher: runs the functor through Kokkos::parallel_for over
// n = xt.extent(1) row indices.
109 static void apply(
const std::vector<vector_type> & x,
110 const trans_multi_vector_type& xt) {
111 const size_t n = xt.extent(1);
112 Kokkos::parallel_for( n, GatherVecTranspose(x,xt) );
// Functor that unpacks one rank-2 view into a std::vector of rank-1 views:
// for a given row, operator() copies m_xt(col, row) into m_x[col](row) for
// every col in [0, ncol).
// NOTE(review): the declarations of value_type, execution_space and the ncol
// member, as well as the closing braces, are not visible in this listing --
// confirm against the full header.
116 template <
typename ValueType,
typename ExecutionSpace>
117 struct ScatterVecTranspose {
120 typedef typename execution_space::size_type size_type;
121 typedef Kokkos::View< value_type* , execution_space > vector_type ;
122 typedef Kokkos::View< value_type** , Kokkos::LayoutLeft, execution_space > trans_multi_vector_type ;
// The vector of destination views is held by value, so constructing the
// functor copies the std::vector (and the view handles it contains).
124 const std::vector<vector_type> m_x;
125 const trans_multi_vector_type m_xt;
// Stores the inputs; ncol is set to the number of views in x.
127 ScatterVecTranspose(
const std::vector<vector_type> & x,
128 const trans_multi_vector_type& xt) :
129 m_x(x), m_xt(xt), ncol(x.size()) {}
// Per-row kernel: row `col` of m_xt is written to entry `row` of view m_x[col].
131 inline void operator()(
const size_type row )
const {
132 for (
size_t col=0; col<ncol; ++col)
133 m_x[col](row) = m_xt(col,row);
// Convenience launcher: runs the functor through Kokkos::parallel_for over
// n = xt.extent(1) row indices.
136 static void apply(
const std::vector<vector_type> & x,
137 const trans_multi_vector_type& xt) {
138 const size_t n = xt.extent(1);
139 Kokkos::parallel_for( n, ScatterVecTranspose(x,xt) );
// Empty class with no members.
// NOTE(review): presumably a tag type used to select the MKL-backed multiply
// overloads; the parameter lists that would show this are truncated in this
// listing -- confirm against the full header.
146 class MKLMultiply {};
// Double-precision sparse matrix-vector product on Kokkos::OpenMP, forwarded
// to MKL's mkl_dcsrmv. Raw pointers taken from A.values, A.graph.entries,
// A.graph.row_map, x and y are handed straight to MKL.
// NOTE(review): the end of the parameter list, the opening/closing braces and
// the declarations of trans, alpha and beta are not visible in this listing --
// confirm against the full header.
148 void multiply(
const CrsMatrix< double , Kokkos::OpenMP >&
A,
149 const Kokkos::View< double* , Kokkos::OpenMP >& x,
150 Kokkos::View< double* , Kokkos::OpenMP >& y,
// n is the length of row_map minus one; it is passed as both dimensions
// in the MKL call below.
153 MKL_INT n = A.graph.row_map.extent(0) - 1 ;
154 double *A_values = A.values.data() ;
155 MKL_INT *col_indices = A.graph.entries.data() ;
// const is cast away because the MKL call takes non-const pointers.
// NOTE(review): assumes the row_map element type is layout-compatible with
// MKL_INT -- confirm.
156 MKL_INT *row_beg =
const_cast<MKL_INT*
>(A.graph.row_map.data()) ;
// row_end aliases the same array shifted by one, so row_end[i] == row_beg[i+1].
157 MKL_INT *row_end = row_beg+1;
// NOTE(review): per the MKL matdescra convention 'G' selects a general matrix
// and 'C' zero-based indexing; confirm against the MKL reference.
158 char matdescra[6] = {
'G',
'x',
'N',
'C',
'x',
'x' };
163 double *x_values = x.data() ;
164 double *y_values = y.data() ;
// The result is written into y's storage through y_values.
166 mkl_dcsrmv(&trans, &n, &n, &alpha, matdescra, A_values, col_indices,
167 row_beg, row_end, x_values, &beta, y_values);
// Single-precision sparse matrix-vector product on Kokkos::OpenMP, forwarded
// to MKL's mkl_scsrmv. Raw pointers taken from A.values, A.graph.entries,
// A.graph.row_map, x and y are handed straight to MKL.
// NOTE(review): the end of the parameter list, the opening/closing braces and
// the declarations of trans, alpha and beta are not visible in this listing --
// confirm against the full header.
170 void multiply(
const CrsMatrix< float , Kokkos::OpenMP >& A,
171 const Kokkos::View< float* , Kokkos::OpenMP >& x,
172 Kokkos::View< float* , Kokkos::OpenMP >& y,
// n is the length of row_map minus one; it is passed as both dimensions
// in the MKL call below.
175 MKL_INT n = A.graph.row_map.extent(0) - 1 ;
176 float *A_values = A.values.data() ;
177 MKL_INT *col_indices = A.graph.entries.data() ;
// const is cast away because the MKL call takes non-const pointers.
// NOTE(review): assumes the row_map element type is layout-compatible with
// MKL_INT -- confirm.
178 MKL_INT *row_beg =
const_cast<MKL_INT*
>(A.graph.row_map.data()) ;
// row_end aliases the same array shifted by one, so row_end[i] == row_beg[i+1].
179 MKL_INT *row_end = row_beg+1;
// NOTE(review): per the MKL matdescra convention 'G' selects a general matrix
// and 'C' zero-based indexing; confirm against the MKL reference.
180 char matdescra[6] = {
'G',
'x',
'N',
'C',
'x',
'x' };
185 float *x_values = x.data() ;
186 float *y_values = y.data() ;
// The result is written into y's storage through y_values.
188 mkl_scsrmv(&trans, &n, &n, &alpha, matdescra, A_values, col_indices,
189 row_beg, row_end, x_values, &beta, y_values);
// Double-precision sparse matrix times several vectors, where the vectors are
// supplied as a std::vector of rank-1 views. The inputs are packed into a
// temporary (ncol x n) view, mkl_dcsrmm is called once, and the result is
// unpacked into the views in y.
// NOTE(review): the end of the parameter list, the braces, the value_type /
// execution_space typedefs and the declarations of trans, alpha and beta are
// not visible in this listing -- confirm against the full header.
192 void multiply(
const CrsMatrix< double , Kokkos::OpenMP >& A,
193 const std::vector< Kokkos::View< double* , Kokkos::OpenMP > >& x,
194 std::vector< Kokkos::View< double* , Kokkos::OpenMP > >& y,
199 typedef Kokkos::View< double** , Kokkos::LayoutLeft, execution_space > trans_multi_vector_type ;
// n is the length of row_map minus one.
201 MKL_INT n = A.graph.row_map.extent(0) - 1 ;
202 double *A_values = A.values.data() ;
203 MKL_INT *col_indices = A.graph.entries.data() ;
// const is cast away because the MKL call takes non-const pointers.
204 MKL_INT *row_beg =
const_cast<MKL_INT*
>(A.graph.row_map.data()) ;
// row_end aliases the same array shifted by one, so row_end[i] == row_beg[i+1].
205 MKL_INT *row_end = row_beg+1;
// NOTE(review): per the MKL matdescra convention 'G' selects a general matrix
// and 'C' zero-based indexing; confirm against the MKL reference.
206 char matdescra[6] = {
'G',
'x',
'N',
'C',
'x',
'x' };
// One column per input view; both temporaries are allocated as (ncol x n).
212 MKL_INT ncol = x.size();
213 trans_multi_vector_type xx(
"xx" , ncol , n );
214 trans_multi_vector_type yy(
"yy" , ncol , n );
// Pack: xx(col,row) = x[col](row).
215 Impl::GatherVecTranspose<value_type,execution_space>::apply(x,xx);
216 double *x_values = xx.data() ;
217 double *y_values = yy.data() ;
// ncol is passed as the column count and as both leading dimensions.
220 mkl_dcsrmm(&trans, &n, &ncol, &n, &alpha, matdescra, A_values, col_indices,
221 row_beg, row_end, x_values, &ncol, &beta, y_values, &ncol);
// Unpack: y[col](row) = yy(col,row).
224 Impl::ScatterVecTranspose<value_type,execution_space>::apply(y,yy);
// Single-precision sparse matrix times several vectors, where the vectors are
// supplied as a std::vector of rank-1 views. The inputs are packed into a
// temporary (ncol x n) view, mkl_scsrmm is called once, and the result is
// unpacked into the views in y.
// NOTE(review): the end of the parameter list, the braces, the value_type /
// execution_space typedefs and the declarations of trans, alpha and beta are
// not visible in this listing -- confirm against the full header.
227 void multiply(
const CrsMatrix< float , Kokkos::OpenMP >& A,
228 const std::vector< Kokkos::View< float* , Kokkos::OpenMP > >& x,
229 std::vector< Kokkos::View< float* , Kokkos::OpenMP > >& y,
234 typedef Kokkos::View< float** , Kokkos::LayoutLeft, execution_space > trans_multi_vector_type ;
// n is the length of row_map minus one.
236 MKL_INT n = A.graph.row_map.extent(0) - 1 ;
237 float *A_values = A.values.data() ;
238 MKL_INT *col_indices = A.graph.entries.data() ;
// const is cast away because the MKL call takes non-const pointers.
239 MKL_INT *row_beg =
const_cast<MKL_INT*
>(A.graph.row_map.data()) ;
// row_end aliases the same array shifted by one, so row_end[i] == row_beg[i+1].
240 MKL_INT *row_end = row_beg+1;
// NOTE(review): per the MKL matdescra convention 'G' selects a general matrix
// and 'C' zero-based indexing; confirm against the MKL reference.
241 char matdescra[6] = {
'G',
'x',
'N',
'C',
'x',
'x' };
// One column per input view; both temporaries are allocated as (ncol x n).
247 MKL_INT ncol = x.size();
248 trans_multi_vector_type xx(
"xx" , ncol , n );
249 trans_multi_vector_type yy(
"yy" , ncol , n );
// Pack: xx(col,row) = x[col](row).
250 Impl::GatherVecTranspose<value_type,execution_space>::apply(x,xx);
251 float *x_values = xx.data() ;
252 float *y_values = yy.data() ;
// ncol is passed as the column count and as both leading dimensions.
255 mkl_scsrmm(&trans, &n, &ncol, &n, &alpha, matdescra, A_values, col_indices,
256 row_beg, row_end, x_values, &ncol, &beta, y_values, &ncol);
// Unpack: y[col](row) = yy(col,row).
259 Impl::ScatterVecTranspose<value_type,execution_space>::apply(y,yy);
// Double-precision sparse matrix times selected columns of a LayoutLeft
// rank-2 view. The columns listed in `indices` are gathered from x into a
// temporary (ncol x n) view, mkl_dcsrmm is called once, and the result is
// scattered into the same-indexed columns of y.
// NOTE(review): the line carrying the function name, the end of the parameter
// list, the braces, the value_type / execution_space typedefs and the
// declarations of trans, alpha and beta are not visible in this listing --
// confirm against the full header.
262 template <
typename ordinal_type>
264 const CrsMatrix< double , Kokkos::OpenMP >& A,
265 const Kokkos::View< double** , Kokkos::LayoutLeft, Kokkos::OpenMP >& x,
266 Kokkos::View< double** , Kokkos::LayoutLeft, Kokkos::OpenMP >& y,
267 const std::vector<ordinal_type>& indices,
272 typedef Kokkos::View< double** , Kokkos::LayoutLeft, execution_space > trans_multi_vector_type ;
// n is the length of row_map minus one.
274 MKL_INT n = A.graph.row_map.extent(0) - 1 ;
275 double *A_values = A.values.data() ;
276 MKL_INT *col_indices = A.graph.entries.data() ;
// const is cast away because the MKL call takes non-const pointers.
277 MKL_INT *row_beg =
const_cast<MKL_INT*
>(A.graph.row_map.data()) ;
// row_end aliases the same array shifted by one, so row_end[i] == row_beg[i+1].
278 MKL_INT *row_end = row_beg+1;
// NOTE(review): per the MKL matdescra convention 'G' selects a general matrix
// and 'C' zero-based indexing; confirm against the MKL reference.
279 char matdescra[6] = {
'G',
'x',
'N',
'C',
'x',
'x' };
// One column per requested index; both temporaries are allocated as (ncol x n).
285 MKL_INT ncol = indices.size();
286 trans_multi_vector_type xx(
"xx" , ncol , n );
287 trans_multi_vector_type yy(
"yy" , ncol , n );
// Gather: xx(col,row) = x(row, indices[col]).
288 Impl::GatherTranspose<value_type,ordinal_type,execution_space>::apply(x,xx,indices);
289 double *x_values = xx.data() ;
290 double *y_values = yy.data() ;
// ncol is passed as the column count and as both leading dimensions.
293 mkl_dcsrmm(&trans, &n, &ncol, &n, &alpha, matdescra, A_values, col_indices,
294 row_beg, row_end, x_values, &ncol, &beta, y_values, &ncol);
// Scatter: y(row, indices[col]) = yy(col,row).
297 Impl::ScatterTranspose<value_type,ordinal_type,execution_space>::apply(y,yy,indices);
// Single-precision sparse matrix times selected columns of a LayoutLeft
// rank-2 view. The columns listed in `indices` are gathered from x into a
// temporary (ncol x n) view, mkl_scsrmm is called once, and the result is
// scattered into the same-indexed columns of y.
// NOTE(review): the line carrying the function name, the end of the parameter
// list, the braces, the value_type / execution_space typedefs and the
// declarations of trans, alpha and beta are not visible in this listing --
// confirm against the full header.
300 template <
typename ordinal_type>
302 const CrsMatrix< float , Kokkos::OpenMP >& A,
303 const Kokkos::View< float** , Kokkos::LayoutLeft, Kokkos::OpenMP >& x,
304 Kokkos::View< float** , Kokkos::LayoutLeft, Kokkos::OpenMP >& y,
305 const std::vector<ordinal_type>& indices,
310 typedef Kokkos::View< float** , Kokkos::LayoutLeft, execution_space > trans_multi_vector_type ;
// n is the length of row_map minus one.
312 MKL_INT n = A.graph.row_map.extent(0) - 1 ;
313 float *A_values = A.values.data() ;
314 MKL_INT *col_indices = A.graph.entries.data() ;
// const is cast away because the MKL call takes non-const pointers.
315 MKL_INT *row_beg =
const_cast<MKL_INT*
>(A.graph.row_map.data()) ;
// row_end aliases the same array shifted by one, so row_end[i] == row_beg[i+1].
316 MKL_INT *row_end = row_beg+1;
// NOTE(review): per the MKL matdescra convention 'G' selects a general matrix
// and 'C' zero-based indexing; confirm against the MKL reference.
317 char matdescra[6] = {
'G',
'x',
'N',
'C',
'x',
'x' };
// One column per requested index; both temporaries are allocated as (ncol x n).
323 MKL_INT ncol = indices.size();
324 trans_multi_vector_type xx(
"xx" , ncol , n );
325 trans_multi_vector_type yy(
"yy" , ncol , n );
// Gather: xx(col,row) = x(row, indices[col]).
326 Impl::GatherTranspose<value_type,ordinal_type,execution_space>::apply(x,xx,indices);
327 float *x_values = xx.data() ;
328 float *y_values = yy.data() ;
// ncol is passed as the column count and as both leading dimensions.
331 mkl_scsrmm(&trans, &n, &ncol, &n, &alpha, matdescra, A_values, col_indices,
332 row_beg, row_end, x_values, &ncol, &beta, y_values, &ncol);
// Scatter: y(row, indices[col]) = yy(col,row).
335 Impl::ScatterTranspose<value_type,ordinal_type,execution_space>::apply(y,yy,indices);
Kokkos::DefaultExecutionSpace execution_space
void multiply(const CrsMatrix< MatrixValue, Device, Layout > &A, const InputMultiVectorType &x, OutputMultiVectorType &y, const std::vector< OrdinalType > &col_indices, SingleColumnMultivectorMultiply)