10 #ifndef MUELU_MATRIXFREETENTATIVEP_DEF_HPP
11 #define MUELU_MATRIXFREETENTATIVEP_DEF_HPP
15 #include "MueLu_Aggregates.hpp"
17 #include <Tpetra_KokkosCompat_ClassicNodeAPI_Wrapper.hpp>
21 #include "Xpetra_Map.hpp"
29 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
35 using impl_scalar_type =
typename Kokkos::ArithTraits<Scalar>::val_type;
36 impl_scalar_type implAlpha = alpha;
42 typename Aggregates::aggregates_sizes_type::const_type aggSizes = aggregates_->ComputeAggregateSizes();
44 auto kokkos_view_X = X.getDeviceLocalView(Xpetra::Access::ReadOnly);
45 auto kokkos_view_Y = Y.getDeviceLocalView(Xpetra::Access::ReadWrite);
46 LO numCols = kokkos_view_X.extent(1);
49 auto vertex2AggId = aggregates_->GetVertex2AggId();
50 auto vertex2AggIdView = vertex2AggId->getDeviceLocalView(Xpetra::Access::ReadOnly);
51 LO numNodes = kokkos_view_X.extent(0);
56 "MueLu:MatrixFreeTentativeR_kokkos:apply",
md_range_type({0, 0}, {numCols, numNodes}),
57 KOKKOS_LAMBDA(
const int colIdx,
const int NodeIdx) {
58 LO aggIdx = vertex2AggIdView(NodeIdx, 0);
60 Kokkos::atomic_add(&kokkos_view_Y(aggIdx, colIdx), implAlpha * kokkos_view_X(NodeIdx, colIdx) / Kokkos::sqrt(aggSizes(aggIdx)));
64 const auto vertex2Agg = aggregates_->GetVertex2AggId();
65 auto vertex2AggView = vertex2Agg->getDeviceLocalView(Xpetra::Access::ReadOnly);
66 LO numNodes = kokkos_view_Y.extent(0);
71 "MueLu:MatrixFreeTentativeP:apply",
md_range_type({0, 0}, {numCols, numNodes}),
72 KOKKOS_LAMBDA(
const int colIdx,
const int fineIdx) {
73 LO aggIdx = vertex2AggView(fineIdx, 0);
74 kokkos_view_Y(fineIdx, colIdx) += implAlpha * kokkos_view_X(aggIdx, colIdx) / Kokkos::sqrt(aggSizes(aggIdx));
80 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
87 #define MUELU_MATRIXFREETENTATIVEP_SHORT
88 #endif // MUELU_MATRIXFREETENTATIVEP_DEF_HPP
#define TEUCHOS_TEST_FOR_EXCEPTION(throw_exception_test, Exception, msg)
void residual(const MultiVector &X, const MultiVector &B, MultiVector &R) const override
MueLu::DefaultScalar Scalar
virtual void scale(const Scalar &alpha)=0
Kokkos::MDRangePolicy< local_ordinal_type, execution_space, Kokkos::Rank< 2 > > md_range_type
Exception throws to report errors in the internal logical of the program.
void apply(const MultiVector &X, MultiVector &Y, Teuchos::ETransp mode=Teuchos::NO_TRANS, Scalar alpha=Teuchos::ScalarTraits< Scalar >::one(), Scalar beta=Teuchos::ScalarTraits< Scalar >::zero()) const override