42 #ifndef STOKHOS_TPETRA_UTILITIES_UQ_PCE_HPP
43 #define STOKHOS_TPETRA_UTILITIES_UQ_PCE_HPP
47 #include "Tpetra_Map.hpp"
48 #include "Tpetra_MultiVector.hpp"
49 #include "Tpetra_CrsGraph.hpp"
50 #include "Tpetra_CrsMatrix.hpp"
54 #if defined(TPETRA_HAVE_KOKKOS_REFACTOR)
60 Kokkos::Compat::KokkosDeviceWrapperNode<Device> > >
61 create_cijk_crs_graph(
const CijkType&
cijk,
63 const size_t matrix_pce_size) {
65 using Teuchos::arrayView;
67 typedef Kokkos::Compat::KokkosDeviceWrapperNode<Device>
Node;
68 typedef Tpetra::Map<LocalOrdinal,GlobalOrdinal,Node> Map;
69 typedef Tpetra::CrsGraph<LocalOrdinal,GlobalOrdinal,Node> Graph;
71 const size_t pce_sz = cijk.dimension();
73 Tpetra::createLocalMapWithNode<LocalOrdinal,GlobalOrdinal,Node>(pce_sz, comm);
74 RCP<Graph> graph = Tpetra::createCrsGraph(map);
75 if (matrix_pce_size == 1) {
77 for (
size_t i=0; i<pce_sz; ++i) {
79 graph->insertGlobalIndices(row, arrayView(&row, 1));
84 for (
size_t i=0; i<pce_sz; ++i) {
86 const size_t num_entry = cijk.num_entry(i);
87 const size_t entry_beg = cijk.entry_begin(i);
88 const size_t entry_end = entry_beg + num_entry;
89 for (
size_t entry = entry_beg; entry < entry_end; ++entry) {
92 graph->insertGlobalIndices(row, arrayView(&j, 1));
93 graph->insertGlobalIndices(row, arrayView(&k, 1));
97 graph->fillComplete();
108 Kokkos::Compat::KokkosDeviceWrapperNode<Device> > >
109 create_flat_pce_graph(
111 const CijkType& cijk,
115 const size_t matrix_pce_size) {
122 typedef Kokkos::Compat::KokkosDeviceWrapperNode<Device>
Node;
123 typedef Tpetra::Map<LocalOrdinal,GlobalOrdinal,Node> Map;
124 typedef Tpetra::CrsGraph<LocalOrdinal,GlobalOrdinal,Node> Graph;
130 flat_domain_map =
create_flat_map(*(graph.getDomainMap()), block_size);
137 RCP<const Map> flat_col_map =
143 RCP<const Map> flat_row_map;
144 if (graph.getRangeMap() == graph.getRowMap())
145 flat_row_map = flat_range_map;
151 cijk_graph = create_cijk_crs_graph<LocalOrdinal,GlobalOrdinal,Device>(
153 flat_domain_map->getComm(),
158 RCP<Graph> flat_graph =
rcp(
new Graph(flat_row_map, flat_col_map, 0));
161 ArrayView<const LocalOrdinal> outer_cols;
162 ArrayView<const LocalOrdinal> inner_cols;
163 Array<LocalOrdinal> flat_col_indices;
164 flat_col_indices.reserve(graph.getNodeMaxNumRowEntries()*block_size);
165 const LocalOrdinal num_outer_rows = graph.getNodeNumRows();
166 for (
LocalOrdinal outer_row=0; outer_row < num_outer_rows; outer_row++) {
169 graph.getLocalRowView(outer_row, outer_cols);
173 for (
LocalOrdinal inner_row=0; inner_row < block_size; inner_row++) {
176 const LocalOrdinal flat_row = outer_row*block_size + inner_row;
179 cijk_graph->getLocalRowView(inner_row, inner_cols);
183 flat_col_indices.resize(0);
186 for (
LocalOrdinal outer_entry=0; outer_entry<num_outer_cols;
191 for (
LocalOrdinal inner_entry=0; inner_entry<num_inner_cols;
196 const LocalOrdinal flat_col = outer_col*block_size + inner_col;
197 flat_col_indices.push_back(flat_col);
203 flat_graph->insertLocalIndices(flat_row, flat_col_indices());
208 flat_graph->fillComplete(flat_domain_map, flat_range_map);
219 Kokkos::Compat::KokkosDeviceWrapperNode<Device> > >
223 Kokkos::Compat::KokkosDeviceWrapperNode<Device> >& vec,
225 Kokkos::Compat::KokkosDeviceWrapperNode<Device> > >& flat_map) {
230 typedef Kokkos::Compat::KokkosDeviceWrapperNode<Device>
Node;
231 typedef Tpetra::MultiVector<BaseScalar,LocalOrdinal,GlobalOrdinal,Node> FlatVector;
232 typedef typename FlatVector::dual_view_type flat_view_type;
235 flat_view_type flat_vals (vec.getLocalViewDevice(), vec.getLocalViewHost());
236 if (vec.need_sync_device ()) {
237 flat_vals.modify_host ();
239 else if (vec.need_sync_host ()) {
240 flat_vals.modify_device ();
244 RCP<FlatVector> flat_vec =
rcp(
new FlatVector(flat_map, flat_vals));
255 Kokkos::Compat::KokkosDeviceWrapperNode<Device> > >
259 Kokkos::Compat::KokkosDeviceWrapperNode<Device> >& vec,
261 Kokkos::Compat::KokkosDeviceWrapperNode<Device> > >& flat_map) {
266 typedef Kokkos::Compat::KokkosDeviceWrapperNode<Device>
Node;
267 typedef Tpetra::MultiVector<BaseScalar,LocalOrdinal,GlobalOrdinal,Node> FlatVector;
268 typedef typename FlatVector::dual_view_type flat_view_type;
271 flat_view_type flat_vals (vec.getLocalViewDevice(), vec.getLocalViewHost());
272 if (vec.need_sync_device ()) {
273 flat_vals.modify_host ();
275 else if (vec.need_sync_host ()) {
276 flat_vals.modify_device ();
280 RCP<FlatVector> flat_vec =
rcp(
new FlatVector(flat_map, flat_vals));
292 Kokkos::Compat::KokkosDeviceWrapperNode<Device> > >
296 Kokkos::Compat::KokkosDeviceWrapperNode<Device> >& vec,
298 Kokkos::Compat::KokkosDeviceWrapperNode<Device> > >& flat_map) {
299 typedef Kokkos::Compat::KokkosDeviceWrapperNode<Device>
Node;
316 Kokkos::Compat::KokkosDeviceWrapperNode<Device> > >
320 Kokkos::Compat::KokkosDeviceWrapperNode<Device> >& vec,
322 Kokkos::Compat::KokkosDeviceWrapperNode<Device> > >& flat_map) {
323 typedef Kokkos::Compat::KokkosDeviceWrapperNode<Device>
Node;
339 Kokkos::Compat::KokkosDeviceWrapperNode<Device> > >
343 Kokkos::Compat::KokkosDeviceWrapperNode<Device> >& vec_const,
345 Kokkos::Compat::KokkosDeviceWrapperNode<Device> > >& flat_map) {
346 typedef Kokkos::Compat::KokkosDeviceWrapperNode<Device>
Node;
349 return flat_mv->getVector(0);
359 Kokkos::Compat::KokkosDeviceWrapperNode<Device> > >
362 LocalOrdinal,GlobalOrdinal,
363 Kokkos::Compat::KokkosDeviceWrapperNode<Device> >& vec,
364 Teuchos::RCP<
const Tpetra::Map<LocalOrdinal,GlobalOrdinal,
365 Kokkos::Compat::KokkosDeviceWrapperNode<Device> > >& flat_map) {
366 typedef Kokkos::Compat::KokkosDeviceWrapperNode<Device>
Node;
368 const LocalOrdinal pce_size =
382 Kokkos::Compat::KokkosDeviceWrapperNode<Device> > >
385 LocalOrdinal,GlobalOrdinal,
386 Kokkos::Compat::KokkosDeviceWrapperNode<Device> >& vec,
387 const Teuchos::RCP<
const Tpetra::Map<LocalOrdinal,GlobalOrdinal,
388 Kokkos::Compat::KokkosDeviceWrapperNode<Device> > >& flat_map) {
389 typedef Kokkos::Compat::KokkosDeviceWrapperNode<Device>
Node;
392 return flat_mv->getVectorNonConst(0);
402 Kokkos::Compat::KokkosDeviceWrapperNode<Device> > >
405 LocalOrdinal,GlobalOrdinal,
406 Kokkos::Compat::KokkosDeviceWrapperNode<Device> >& vec,
407 Teuchos::RCP<
const Tpetra::Map<LocalOrdinal,GlobalOrdinal,
408 Kokkos::Compat::KokkosDeviceWrapperNode<Device> > >& flat_map) {
409 typedef Kokkos::Compat::KokkosDeviceWrapperNode<Device>
Node;
411 const LocalOrdinal pce_size =
422 typename Device,
typename CijkType>
425 Kokkos::Compat::KokkosDeviceWrapperNode<Device> > >
428 LocalOrdinal,GlobalOrdinal,Kokkos::Compat::KokkosDeviceWrapperNode<Device> >& mat,
429 const Teuchos::RCP<
const Tpetra::CrsGraph<LocalOrdinal,GlobalOrdinal,Kokkos::Compat::KokkosDeviceWrapperNode<Device> > >& flat_graph,
430 const Teuchos::RCP<
const Tpetra::CrsGraph<LocalOrdinal,GlobalOrdinal,Kokkos::Compat::KokkosDeviceWrapperNode<Device> > >& cijk_graph,
431 const CijkType& cijk) {
437 typedef Kokkos::Compat::KokkosDeviceWrapperNode<Device>
Node;
440 typedef Tpetra::CrsMatrix<BaseScalar,LocalOrdinal,GlobalOrdinal,Node> FlatMatrix;
442 const LocalOrdinal block_size = cijk.dimension();
443 const LocalOrdinal matrix_pce_size =
447 RCP<FlatMatrix> flat_mat =
rcp(
new FlatMatrix(flat_graph));
450 ArrayView<const Scalar> outer_values;
451 ArrayView<const LocalOrdinal> outer_cols;
452 ArrayView<const LocalOrdinal> inner_cols;
453 ArrayView<const LocalOrdinal> flat_cols;
454 Array<BaseScalar> flat_values;
455 flat_values.reserve(flat_graph->getNodeMaxNumRowEntries());
456 const LocalOrdinal num_outer_rows = mat.getNodeNumRows();
457 for (LocalOrdinal outer_row=0; outer_row < num_outer_rows; outer_row++) {
460 mat.getLocalRowView(outer_row, outer_cols, outer_values);
461 const LocalOrdinal num_outer_cols = outer_cols.size();
464 for (LocalOrdinal inner_row=0; inner_row < block_size; inner_row++) {
467 const LocalOrdinal flat_row = outer_row*block_size + inner_row;
470 cijk_graph->getLocalRowView(inner_row, inner_cols);
471 const LocalOrdinal num_inner_cols = inner_cols.size();
474 const LocalOrdinal num_flat_indices = num_outer_cols*num_inner_cols;
476 flat_values.assign(num_flat_indices, BaseScalar(0));
478 if (matrix_pce_size == 1) {
482 for (LocalOrdinal outer_entry=0; outer_entry<num_outer_cols;
486 flat_values[outer_entry] = outer_values[outer_entry].coeff(0);
494 const size_t num_entry = cijk.num_entry(inner_row);
495 const size_t entry_beg = cijk.entry_begin(inner_row);
496 const size_t entry_end = entry_beg + num_entry;
497 for (
size_t entry = entry_beg; entry < entry_end; ++entry) {
498 const LocalOrdinal j = cijk.coord(entry,0);
499 const LocalOrdinal k = cijk.coord(entry,1);
500 const BaseScalar c = cijk.value(entry);
503 typedef typename ArrayView<const LocalOrdinal>::iterator iterator;
505 std::find(inner_cols.begin(), inner_cols.end(),
j);
507 std::find(inner_cols.begin(), inner_cols.end(), k);
508 const LocalOrdinal j_offset = ptr_j - inner_cols.begin();
509 const LocalOrdinal k_offset = ptr_k - inner_cols.begin();
512 for (LocalOrdinal outer_entry=0; outer_entry<num_outer_cols;
516 flat_values[outer_entry*num_inner_cols + j_offset] +=
517 c*outer_values[outer_entry].coeff(k);
518 flat_values[outer_entry*num_inner_cols + k_offset] +=
519 c*outer_values[outer_entry].coeff(j);
528 flat_graph->getLocalRowView(flat_row, flat_cols);
529 flat_mat->replaceLocalValues(flat_row, flat_cols, flat_values());
534 flat_mat->fillComplete(flat_graph->getDomainMap(),
535 flat_graph->getRangeMap());
544 #endif // STOKHOS_TPETRA_UTILITIES_MP_VECTOR_HPP
Stokhos::StandardStorage< int, double > Storage
KOKKOS_INLINE_FUNCTION constexpr std::enable_if< is_view_uq_pce< View< T, P...> >::value, unsigned >::type dimension_scalar(const View< T, P...> &view)
TEUCHOS_DEPRECATED RCP< T > rcp(T *p, Dealloc_T dealloc, bool owns_mem)
Teuchos::RCP< const Tpetra::MultiVector< typename Storage::value_type, LocalOrdinal, GlobalOrdinal, Node > > create_flat_vector_view(const Tpetra::MultiVector< Sacado::MP::Vector< Storage >, LocalOrdinal, GlobalOrdinal, Node > &vec_const, const Teuchos::RCP< const Tpetra::Map< LocalOrdinal, GlobalOrdinal, Node > > &flat_map)
KokkosClassic::DefaultNode::DefaultNodeType Node
Teuchos::RCP< Tpetra::CrsMatrix< typename Storage::value_type, LocalOrdinal, GlobalOrdinal, Node > > create_flat_matrix(const Tpetra::CrsMatrix< Sacado::MP::Vector< Storage >, LocalOrdinal, GlobalOrdinal, Node > &mat, const Teuchos::RCP< const Tpetra::CrsGraph< LocalOrdinal, GlobalOrdinal, Node > > &flat_graph, const LocalOrdinal block_size)
KOKKOS_INLINE_FUNCTION constexpr std::enable_if< is_view_uq_pce< view_type >::value, typename CijkType< view_type >::type >::type cijk(const view_type &view)
Teuchos::RCP< Tpetra::Map< LocalOrdinal, GlobalOrdinal, Node > > create_flat_map(const Tpetra::Map< LocalOrdinal, GlobalOrdinal, Node > &map, const LocalOrdinal block_size)