44 #ifndef TPETRA_APPLYDIRICHLETBOUNDARYCONDITION_HPP
45 #define TPETRA_APPLYDIRICHLETBOUNDARYCONDITION_HPP
51 #include "Tpetra_CrsMatrix.hpp"
52 #include "Tpetra_Map.hpp"
53 #include "KokkosSparse_CrsMatrix.hpp"
54 #include "Kokkos_ArithTraits.hpp"
/// \brief For each k in [0, lclRowInds.extent(0)), set local row
///   lclRowInds(k) of A to 1 on the diagonal and 0 elsewhere, using the
///   given execution space instance.
///
/// Declaration only; the matching definition appears later in this file.
/// NOTE(review): this excerpt omits the name line and the
/// `CrsMatrixType& A` parameter of this declaration; only the fragments
/// below are visible.
///
/// \param execSpace [in] Execution space instance to run on.
/// \param lclRowInds [in] Local indices of the rows to modify, stored in
///   the matrix's device_type memory.
67 template<
class CrsMatrixType>
70 (
const typename CrsMatrixType::execution_space& execSpace,
73 typename CrsMatrixType::local_ordinal_type*,
74 typename CrsMatrixType::device_type> & lclRowInds);
/// \brief Overload of the Dirichlet row update that takes no execution
///   space instance; row indices are in the matrix's device_type memory.
///
/// Declaration only. The definition's body is not visible in this excerpt.
/// NOTE(review): presumably it forwards to the execSpace overload with a
/// default-constructed CrsMatrixType::execution_space; confirm.
///
/// \param lclRowInds [in] Local indices of the rows to modify.
85 template<
class CrsMatrixType>
90 typename CrsMatrixType::local_ordinal_type*,
91 typename CrsMatrixType::device_type>& lclRowInds);
/// \brief Overload of the Dirichlet row update that takes row indices in
///   a Kokkos::HostSpace View.
///
/// Declaration only. The definition passes runOnHost = true to the
/// implementation, so the row loop runs over
/// Kokkos::DefaultHostExecutionSpace.
///
/// \param lclRowInds [in] Local indices of the rows to modify (host memory).
102 template<
class CrsMatrixType>
107 typename CrsMatrixType::local_ordinal_type*,
108 Kokkos::HostSpace> & lclRowInds);
/// \brief Shared implementation behind the public
///   applyDirichletBoundaryConditionToLocalMatrixRows overloads.
///
/// For every listed local row, run() overwrites that row's stored entries:
/// an entry whose column global index equals the row's global index gets
/// KAT::one(); every other entry gets KAT::zero().
///
/// NOTE(review): several lines of this struct (type aliases, the `A`
/// parameter of run, the runOnHost branch, closing braces) are missing
/// from this excerpt.
112 template<
class SC,
class LO,
class GO,
class NT>
113 struct ApplyDirichletBoundaryConditionToLocalMatrixRows {
// Row indices are taken in AnonymousSpace, so the device-View and
// HostSpace-View overloads can both convert into this one type.
116 using local_row_indices_type =
117 Kokkos::View<const LO*, Kokkos::AnonymousSpace>;
/// \param execSpace [in] Instance used by the device branch only; the
///   host branch builds its RangePolicy without it.
/// \param lclRowInds [in] Local indices of the rows to modify.
/// \param runOnHost [in] If true, use Kokkos::DefaultHostExecutionSpace;
///   otherwise use execution_space.
/// \throws std::invalid_argument if A has no row Map, has no column Map,
///   or (when the row Map has local rows) its local matrix row count
///   differs from the row Map's local element count.
120 run (
const execution_space& execSpace,
122 const local_row_indices_type& lclRowInds,
123 const bool runOnHost)
126 using KAT = Kokkos::ArithTraits<IST>;
// Validate that both Maps exist; the diagonal test below needs both.
128 const auto rowMap = A.getRowMap ();
129 TEUCHOS_TEST_FOR_EXCEPTION
130 (rowMap.get () ==
nullptr, std::invalid_argument,
131 "The matrix must have a row Map.");
132 const auto colMap = A.getColMap ();
133 TEUCHOS_TEST_FOR_EXCEPTION
134 (colMap.get () ==
nullptr, std::invalid_argument,
135 "The matrix must have a column Map.");
136 auto A_lcl = A.getLocalMatrix ();
// The row-count consistency check is skipped when the row Map has no
// local rows (lclNumRows == 0).
138 const LO lclNumRows =
static_cast<LO
> (rowMap->getNodeNumElements ());
139 TEUCHOS_TEST_FOR_EXCEPTION
140 (lclNumRows != 0 && static_cast<LO>(A_lcl.graph.numRows ()) != lclNumRows,
141 std::invalid_argument,
"The matrix must have been either created "
142 "with a KokkosSparse::CrsMatrix, or must have been fill-completed "
// Local views of the Maps and of the CSR arrays (row offsets, column
// indices, values) that the kernels below read and write.
// NOTE(review): the host branch reads these from a host RangePolicy;
// presumably that requires host-accessible memory for the device_type;
// confirm.
145 auto lclRowMap = A.getRowMap ()->getLocalMap ();
146 auto lclColMap = A.getColMap ()->getLocalMap ();
147 auto rowptr = A_lcl.graph.row_map;
148 auto colind = A_lcl.graph.entries;
149 auto values = A_lcl.values;
// Remember the fill state so it can be restored after modifying values.
// NOTE(review): the body of this `if` is not visible here; presumably it
// resumes fill on A; confirm.
151 const bool wasFillComplete = A.isFillComplete ();
152 if (wasFillComplete) {
156 const LO numInputRows = lclRowInds.extent (0);
// Device branch: one iteration per entry of lclRowInds.
158 using range_type = Kokkos::RangePolicy<execution_space, LO>;
160 (
"Tpetra::CrsMatrix apply Dirichlet: Device",
161 range_type (execSpace, 0, numInputRows),
162 KOKKOS_LAMBDA (
const LO i) {
163 const GO row_gid = lclRowMap.getGlobalElement (lclRowInds(i));
// NOTE(review): BUG: `i` is the position in lclRowInds, not the local row
// being modified, yet it indexes rowptr. The entries of local row i are
// overwritten instead of those of local row lclRowInds(i), unless
// lclRowInds(i) == i. Expected fix:
//   const LO row = lclRowInds(i); ... rowptr(row) .. rowptr(row+1)
164 for (
auto j = rowptr(i); j < rowptr(i+1); ++j) {
// Diagonal test compares global indices: the row index lives in the row
// Map and the column index in the column Map, so local indices need not
// agree.
166 lclColMap.getGlobalElement (colind(j)) == row_gid;
167 values(j) = diagEnt ? KAT::one () : KAT::zero ();
// Host branch: same update over Kokkos::DefaultHostExecutionSpace.
173 Kokkos::RangePolicy<Kokkos::DefaultHostExecutionSpace, LO>;
175 (
"Tpetra::CrsMatrix apply Dirichlet: Host",
176 range_type (0, numInputRows),
178 const GO row_gid = lclRowMap.getGlobalElement (lclRowInds(i));
// NOTE(review): BUG: same wrong row extent as the device branch;
// rowptr must be indexed by lclRowInds(i), not i.
179 for (
auto j = rowptr(i); j < rowptr(i+1); ++j) {
181 lclColMap.getGlobalElement (colind(j)) == row_gid;
182 values(j) = diagEnt ? KAT::one () : KAT::zero ();
// Restore fill-complete state using A's existing domain and range Maps.
186 if (wasFillComplete) {
187 A.fillComplete (A.getDomainMap (), A.getRangeMap ());
/// \brief Definition of the overload taking an explicit execution space
///   instance and device_type row indices.
///
/// Wraps lclRowInds in a const AnonymousSpace View and dispatches to
/// Details::ApplyDirichletBoundaryConditionToLocalMatrixRows::run with
/// runOnHost = false, so the row loop uses execSpace.
/// NOTE(review): the name line, the `A` parameter, and the impl_type alias
/// line are not visible in this excerpt.
194 template<
class CrsMatrixType>
197 (
const typename CrsMatrixType::execution_space& execSpace,
200 typename CrsMatrixType::local_ordinal_type*,
201 typename CrsMatrixType::device_type> & lclRowInds)
// Template parameters of the implementation struct, taken from the matrix.
203 using SC =
typename CrsMatrixType::scalar_type;
204 using LO =
typename CrsMatrixType::local_ordinal_type;
205 using GO =
typename CrsMatrixType::global_ordinal_type;
206 using NT =
typename CrsMatrixType::node_type;
// Convert to the implementation's memory-space-agnostic index View type.
208 using local_row_indices_type =
209 Kokkos::View<const LO*, Kokkos::AnonymousSpace>;
210 const local_row_indices_type lclRowInds_a (lclRowInds);
212 using Details::ApplyDirichletBoundaryConditionToLocalMatrixRows;
214 ApplyDirichletBoundaryConditionToLocalMatrixRows<SC, LO, GO, NT>;
// Indices live in device_type memory, so use the device branch.
215 const bool runOnHost =
false;
216 impl_type::run (execSpace, A, lclRowInds_a, runOnHost);
/// \brief Definition of the overload taking device_type row indices but
///   no execution space instance.
///
/// NOTE(review): only the start of this definition is visible. It declares
/// an `execution_space` alias; presumably the body forwards to the
/// execSpace overload with a default-constructed instance; confirm.
219 template<
class CrsMatrixType>
224 typename CrsMatrixType::local_ordinal_type*,
225 typename CrsMatrixType::device_type> & lclRowInds)
227 using execution_space =
typename CrsMatrixType::execution_space;
/// \brief Definition of the overload taking row indices in a
///   Kokkos::HostSpace View.
///
/// Wraps lclRowInds in a const AnonymousSpace View and dispatches to
/// Details::ApplyDirichletBoundaryConditionToLocalMatrixRows::run with
/// runOnHost = true. The passed execution_space() instance is therefore
/// not used by the row loop.
/// NOTE(review): the name line, the `A` parameter, and the impl_type alias
/// line are not visible in this excerpt.
231 template<
class CrsMatrixType>
236 typename CrsMatrixType::local_ordinal_type*,
237 Kokkos::HostSpace> & lclRowInds)
// Template parameters of the implementation struct, taken from the matrix.
239 using SC =
typename CrsMatrixType::scalar_type;
240 using LO =
typename CrsMatrixType::local_ordinal_type;
241 using GO =
typename CrsMatrixType::global_ordinal_type;
242 using NT =
typename CrsMatrixType::node_type;
// NOTE(review): BUG (likely): `crs_matrix_type` is not declared anywhere in
// the visible code; every other alias here uses the template parameter
// `CrsMatrixType`. Unless a hidden line declares it, this should read
// `typename CrsMatrixType::execution_space`.
244 using execution_space =
typename crs_matrix_type::execution_space;
// Convert to the implementation's memory-space-agnostic index View type.
246 using local_row_indices_type =
247 Kokkos::View<const LO*, Kokkos::AnonymousSpace>;
248 const local_row_indices_type lclRowInds_a (lclRowInds);
250 using Details::ApplyDirichletBoundaryConditionToLocalMatrixRows;
252 ApplyDirichletBoundaryConditionToLocalMatrixRows<SC, LO, GO, NT>;
// Indices live in host memory, so take the host branch.
253 const bool runOnHost =
true;
254 impl_type::run (execution_space (), A, lclRowInds_a, runOnHost);
259 #endif // TPETRA_APPLYDIRICHLETBOUNDARYCONDITION_HPP
Sparse matrix that presents a row-oriented interface that lets users read or modify entries...
typename device_type::execution_space execution_space
The Kokkos execution space.
typename Kokkos::ArithTraits< Scalar >::val_type impl_scalar_type
The type used internally in place of Scalar.
void applyDirichletBoundaryConditionToLocalMatrixRows(const typename CrsMatrixType::execution_space &execSpace, CrsMatrixType &A, const Kokkos::View< typename CrsMatrixType::local_ordinal_type *, typename CrsMatrixType::device_type > &lclRowInds)
For all k in [0, lclRowInds.extent(0)), set local row lclRowInds[k] of A to have 1 on the diagonal and 0 elsewhere.