10 #ifndef TPETRA_DETAILS_GETNUMDIAGS_HPP
11 #define TPETRA_DETAILS_GETNUMDIAGS_HPP
20 #include "Tpetra_CrsGraph.hpp"
21 #include "Teuchos_CommHelpers.hpp"
/// Kokkos::parallel_reduce functor for counting the local number of
/// diagonal entries in a sparse graph.
///
/// NOTE(review): this listing omits the class head, the start of the
/// constructor, the data-member declarations, and the tail of operator();
/// only the lines below are visible here.
///
/// \tparam LocalGraphType Local graph type (instantiated in this file with
///   the CrsGraph's local_graph_device_type); used via rowConst(lclRow),
///   whose result exposes `length` and operator()(k).
/// \tparam LocalMapType Local Map type (instantiated with
///   map_type::local_map_type); used via local_ordinal_type,
///   getGlobalElement(), and getLocalElement().
33 template<
class LocalGraphType,
class LocalMapType>
// Constructor tail: initializes the stored local graph (G_) and the local
// row and column Maps (rowMap_, colMap_) from the constructor arguments.
37 const LocalMapType& rowMap,
38 const LocalMapType& colMap) :
39 G_ (G), rowMap_ (rowMap), colMap_ (colMap)
// The reduction value is a single scalar count of type local_ordinal_type;
// there is no separate error count.
44 using result_type =
typename LocalMapType::local_ordinal_type;
// Per-row reduction body. Finds the local column index of lclRow's diagonal
// by mapping local row -> global index (row Map) -> local column index
// (column Map), then scans the row's column indices for it. The statements
// that update diagCount on a match are not visible in this listing.
47 KOKKOS_INLINE_FUNCTION
void
48 operator () (
const typename LocalMapType::local_ordinal_type lclRow,
49 result_type& diagCount)
const
51 using LO =
typename LocalMapType::local_ordinal_type;
52 using LOT = typename ::Tpetra::Details::OrdinalTraits<LO>;
54 auto G_row = G_.rowConst (lclRow);
55 const LO numEnt = G_row.length;
60 const LO lclDiagCol = colMap_.getLocalElement (rowMap_.getGlobalElement (lclRow));
// LOT::invalid() is the "not found" flag of getLocalElement: the diagonal's
// global index has no local column index on this process. In that case the
// row cannot store its diagonal, so the scan is skipped.
62 if (lclDiagCol != LOT::invalid ()) {
// Linear scan over every entry; the row is not assumed to be sorted.
65 for (LO k = 0; k < numEnt; ++k) {
66 if (lclDiagCol == G_row(k)) {
/// Count the diagonal entries stored on the calling process in a
/// fill-complete CrsGraph. It runs CountLocalNumDiags over all local rows
/// with Kokkos::parallel_reduce on the graph's execution space.
///
/// The caller is expected to pass a fill-complete graph. In this file,
/// getLocalNumDiags(RowGraph) checks isFillComplete() before calling it.
///
/// NOTE(review): crs_graph_type, rowMap and lclNumDiags are declared on
/// lines not included in this listing. The null-Map branch body and the
/// return statement are also not shown.
///
/// \param G [in] Fill-complete graph.
/// \return Local number of stored diagonal entries (local ordinal type).
81 template<
class LO,
class GO,
class NT>
82 typename ::Tpetra::CrsGraph<LO, GO, NT>::local_ordinal_type
83 countLocalNumDiagsInFillCompleteGraph (const ::Tpetra::CrsGraph<LO, GO, NT>& G)
86 using local_map_type =
typename crs_graph_type::map_type::local_map_type;
87 using local_graph_device_type =
typename crs_graph_type::local_graph_device_type;
88 using functor_type = CountLocalNumDiags<local_graph_device_type, local_map_type>;
89 using execution_space =
typename crs_graph_type::device_type::execution_space;
// The loop index over local rows has type LO.
90 using policy_type = Kokkos::RangePolicy<execution_space, LO>;
93 const auto colMap = G.getColMap ();
// A null row or column Map is handled by a separate branch. Its body is not
// visible here; presumably it returns zero because there is nothing to
// count on this process (confirm).
94 if (rowMap.get () ==
nullptr || colMap.get () ==
nullptr) {
// Build the functor from the device local graph and the Maps' local Maps.
99 functor_type f (G.getLocalGraphDevice (), rowMap->getLocalMap (), colMap->getLocalMap ());
// Sum the per-row contributions of local rows [0, getLocalNumRows()) into lclNumDiags.
100 Kokkos::parallel_reduce (policy_type (0, G.getLocalNumRows ()), f, lclNumDiags);
/// Local column index of the diagonal entry of local row lclRow.
///
/// Maps lclRow to its global index with rowMap, then maps that global index
/// to a local index with colMap.
///
/// NOTE(review): the counting routines in this file compute the same
/// expression inline and compare it with OrdinalTraits<LO>::invalid(). They
/// could call this helper instead.
///
/// \param lclRow [in] Local row index with respect to rowMap.
/// \param rowMap [in] Row Map.
/// \param colMap [in] Column Map.
/// \return colMap.getLocalElement() of the row's global index.
114 template<
class MapType>
115 typename MapType::local_ordinal_type
116 getLocalDiagonalColumnIndex (
const typename MapType::local_ordinal_type lclRow,
117 const MapType& rowMap,
118 const MapType& colMap)
120 return colMap.getLocalElement (rowMap.getGlobalElement (lclRow));
/// Count the diagonal entries stored on the calling process in a locally
/// indexed graph that did not take the fill-complete CrsGraph path. Each
/// row's local column indices are scanned on the host through row views.
///
/// \throws std::logic_error ("Not implemented!") if !G.supportsRowViews().
///
/// NOTE(review): the following are not part of this listing:
///  - the null-Map branch body;
///  - the variable name declared with the row-view type (lclColInds);
///  - the counter;
///  - the loop body after the match test.
124 template<
class LO,
class GO,
class NT>
125 typename ::Tpetra::RowGraph<LO, GO, NT>::local_ordinal_type
126 countLocalNumDiagsInNonFillCompleteLocallyIndexedGraphWithRowViews (const ::Tpetra::RowGraph<LO, GO, NT>& G)
128 using LOT = typename ::Tpetra::Details::OrdinalTraits<LO>;
130 const auto rowMap = G.getRowMap ();
131 const auto colMap = G.getColMap ();
// Both Maps are needed to find each row's diagonal column. A null Map takes
// a separate branch whose body is not shown here.
132 if (rowMap.get () ==
nullptr || colMap.get () ==
nullptr) {
136 TEUCHOS_TEST_FOR_EXCEPTION
137 (! G.supportsRowViews (), std::logic_error,
"Not implemented!");
// One host view declared outside the loop. getLocalRowView sets it to each
// row's local column indices in turn.
139 typename ::Tpetra::RowGraph<LO, GO, NT>::local_inds_host_view_type
141 const LO lclNumRows =
static_cast<LO
> (G.getLocalNumRows ());
144 for (LO lclRow = 0; lclRow < lclNumRows; ++lclRow) {
145 G.getLocalRowView (lclRow, lclColInds);
146 const LO numEnt =
static_cast<LO
> (lclColInds.size ());
// Same row Map -> column Map lookup as getLocalDiagonalColumnIndex.
148 const LO lclDiagCol = colMap->getLocalElement (rowMap->getGlobalElement (lclRow));
// Skip the scan when the diagonal has no local column index on this process.
150 if (lclDiagCol != LOT::invalid ()) {
// Linear scan; the row is not assumed to be sorted.
153 for (LO k = 0; k < numEnt; ++k) {
154 if (lclDiagCol == lclColInds[k]) {
/// Count the diagonal entries stored on the calling process in a locally
/// indexed graph that does not support row views. Each row's local column
/// indices are copied into a host buffer and scanned.
///
/// The buffer is allocated once with length G.getLocalMaxNumRowEntries().
/// Each row uses the leading subview of length getNumEntriesInLocalRow().
///
/// NOTE(review): the null-Map branch body, the counter, and the loop body
/// after the match test are not part of this listing.
168 template<
class LO,
class GO,
class NT>
169 typename ::Tpetra::RowGraph<LO, GO, NT>::local_ordinal_type
170 countLocalNumDiagsInNonFillCompleteLocallyIndexedGraphWithoutRowViews (const ::Tpetra::RowGraph<LO, GO, NT>& G)
172 using LOT = typename ::Tpetra::Details::OrdinalTraits<LO>;
174 const auto rowMap = G.getRowMap ();
175 const auto colMap = G.getColMap ();
// A null row or column Map takes a separate branch (body not shown here).
176 if (rowMap.get () ==
nullptr || colMap.get () ==
nullptr) {
180 using inds_type = typename ::Tpetra::RowGraph<LO,GO,NT>::nonconst_local_inds_host_view_type;
// A single scratch buffer sized for the longest local row.
181 inds_type lclColIndsBuf(
"lclColIndsBuf",G.getLocalMaxNumRowEntries());
182 const LO lclNumRows =
static_cast<LO
> (G.getLocalNumRows ());
185 for (LO lclRow = 0; lclRow < lclNumRows; ++lclRow) {
186 size_t numEntSizeT = G.getNumEntriesInLocalRow (lclRow);
187 const LO numEnt =
static_cast<LO
> (numEntSizeT);
// Leading subview sized to this row; getLocalRowCopy fills it.
// NOTE(review): std::make_pair(0,numEnt) pairs an int with an LO. The
// globally indexed variant writes std::make_pair((LO)0, numEnt). Consider
// the same cast here so both ends have type LO when LO is not int.
189 inds_type lclColInds = Kokkos::subview(lclColIndsBuf,std::make_pair(0,numEnt));
190 G.getLocalRowCopy (lclRow, lclColInds, numEntSizeT);
// Same row Map -> column Map lookup as getLocalDiagonalColumnIndex.
193 const LO lclDiagCol =
194 colMap->getLocalElement (rowMap->getGlobalElement (lclRow));
// Skip the scan when the diagonal has no local column index on this process.
196 if (lclDiagCol != LOT::invalid ()) {
// Linear scan over the copied indices; no sortedness is assumed.
199 for (LO k = 0; k < numEnt; ++k) {
200 if (lclDiagCol == lclColInds[k]) {
/// Count the diagonal entries stored on the calling process in a globally
/// indexed graph that supports row views. Each row's global index is
/// compared with its global column indices, so only the row Map is
/// consulted. The column Map is not used.
///
/// NOTE(review): the following are not part of this listing:
///  - the null-row-Map branch body;
///  - the variable name declared with the row-view type (gblColInds);
///  - the counter;
///  - the loop body after the match test.
214 template<
class LO,
class GO,
class NT>
215 typename ::Tpetra::RowGraph<LO, GO, NT>::local_ordinal_type
216 countLocalNumDiagsInNonFillCompleteGloballyIndexedGraphWithRowViews (const ::Tpetra::RowGraph<LO, GO, NT>& G)
218 const auto rowMap = G.getRowMap ();
// A null row Map takes a separate branch (body not shown here).
219 if (rowMap.get () ==
nullptr) {
// One host view declared outside the loop. getGlobalRowView sets it to each
// row's global column indices in turn.
223 typename ::Tpetra::RowGraph<LO,GO,NT>::global_inds_host_view_type
225 const LO lclNumRows =
static_cast<LO
> (G.getLocalNumRows ());
228 for (LO lclRow = 0; lclRow < lclNumRows; ++lclRow) {
// Global row index. The diagonal entry is the one whose global column
// index equals it.
229 const GO gblRow = rowMap->getGlobalElement (lclRow);
230 G.getGlobalRowView (gblRow, gblColInds);
231 const LO numEnt =
static_cast<LO
> (gblColInds.size ());
// Linear scan; the row is not assumed to be sorted.
235 for (LO k = 0; k < numEnt; ++k) {
236 if (gblRow == gblColInds[k]) {
/// Count the diagonal entries stored on the calling process in a globally
/// indexed graph that does not support row views. Each row's global column
/// indices are copied into a host buffer and compared with the row's
/// global index.
///
/// The buffer starts default-constructed (empty). It is enlarged with
/// Kokkos::resize only when a row has more entries than its current size.
///
/// NOTE(review): the null-row-Map branch body, the counter, and the loop
/// body after the match test are not part of this listing.
249 template<
class LO,
class GO,
class NT>
250 typename ::Tpetra::RowGraph<LO, GO, NT>::local_ordinal_type
251 countLocalNumDiagsInNonFillCompleteGloballyIndexedGraphWithoutRowViews (const ::Tpetra::RowGraph<LO, GO, NT>& G)
253 using gids_type = typename ::Tpetra::RowGraph<LO,GO,NT>::nonconst_global_inds_host_view_type ;
254 const auto rowMap = G.getRowMap ();
// A null row Map takes a separate branch (body not shown here).
255 if (rowMap.get () ==
nullptr) {
259 gids_type gblColIndsBuf;
260 const LO lclNumRows =
static_cast<LO
> (G.getLocalNumRows ());
263 for (LO lclRow = 0; lclRow < lclNumRows; ++lclRow) {
264 size_t numEntSizeT = G.getNumEntriesInLocalRow (lclRow);
265 const LO numEnt =
static_cast<LO
> (numEntSizeT);
// Grow the scratch buffer on demand.
// NOTE(review): Kokkos::resize preserves the old contents, but
// getGlobalRowCopy overwrites them anyway. Kokkos::realloc (or allocating
// getLocalMaxNumRowEntries() up front, as the locally indexed variant does)
// would avoid that copy.
266 if (static_cast<LO> (gblColIndsBuf.size ()) < numEnt) {
267 Kokkos::resize(gblColIndsBuf,numEnt);
// Leading subview sized to this row; getGlobalRowCopy fills it.
270 gids_type gblColInds = Kokkos::subview(gblColIndsBuf,std::make_pair((LO)0, numEnt));
271 const GO gblRow = rowMap->getGlobalElement (lclRow);
272 G.getGlobalRowCopy (gblRow, gblColInds, numEntSizeT);
// Linear scan for the entry whose global column index equals the global row.
277 for (LO k = 0; k < numEnt; ++k) {
278 if (gblRow == gblColInds[k]) {
/// Local number of stored diagonal entries of a matrix A, obtained via
/// A.getGraph().
///
/// NOTE(review): the `static` suggests this is a member of an enclosing
/// type that is not part of this listing. The null-graph branch body and
/// the rest of the function are also not shown. Presumably it forwards
/// *G to the graph overload; confirm.
///
/// \tparam MatrixType Must provide local_ordinal_type, global_ordinal_type,
///   node_type, and getGraph().
296 template<
class MatrixType>
298 static typename MatrixType::local_ordinal_type
299 getLocalNumDiags (
const MatrixType& A)
301 using LO =
typename MatrixType::local_ordinal_type;
302 using GO =
typename MatrixType::global_ordinal_type;
303 using NT =
typename MatrixType::node_type;
// getGraph() may return null; that case takes a separate branch.
306 auto G = A.getGraph ();
307 if (G.get () ==
nullptr) {
/// Local number of stored diagonal entries of a RowGraph. It dispatches on
/// the graph's state:
///  - a fill-complete CrsGraph uses countLocalNumDiagsInFillCompleteGraph
///    (Kokkos::parallel_reduce over local rows);
///  - otherwise a locally indexed graph is scanned on the host through row
///    views if G.supportsRowViews(), else through row copies;
///  - otherwise a globally indexed graph is scanned the same way, in global
///    indices.
///
/// NOTE(review): the return-type line, the crs_graph_type alias, and the
/// branch for a graph that is neither locally nor globally indexed are not
/// part of this listing.
317 template<
class LO,
class GO,
class NT>
320 getLocalNumDiags (const ::Tpetra::RowGraph<LO, GO, NT>& G)
// Runtime type check: only a CrsGraph can use the parallel fill-complete path.
324 const crs_graph_type* G_crs =
dynamic_cast<const crs_graph_type*
> (&G);
325 if (G_crs !=
nullptr && G_crs->isFillComplete ()) {
326 return countLocalNumDiagsInFillCompleteGraph (*G_crs);
329 if (G.isLocallyIndexed ()) {
330 if (G.supportsRowViews ()) {
331 return countLocalNumDiagsInNonFillCompleteLocallyIndexedGraphWithRowViews (G);
334 return countLocalNumDiagsInNonFillCompleteLocallyIndexedGraphWithoutRowViews (G);
337 else if (G.isGloballyIndexed ()) {
338 if (G.supportsRowViews ()) {
339 return countLocalNumDiagsInNonFillCompleteGloballyIndexedGraphWithRowViews (G);
342 return countLocalNumDiagsInNonFillCompleteGloballyIndexedGraphWithoutRowViews (G);
/// getLocalNumDiags overload for CrsGraph arguments.
///
/// NOTE(review): the return-type line and the body are not part of this
/// listing. Presumably it uses the same dispatch as the RowGraph overload;
/// confirm.
353 template<
class LO,
class GO,
class NT>
356 getLocalNumDiags (const ::Tpetra::CrsGraph<LO, GO, NT>& G)
/// Number of populated diagonal entries in the given sparse graph, on the
/// calling (MPI) process.
///
/// \tparam CrsGraphType Graph type; must provide local_ordinal_type.
///
/// NOTE(review): only the template header and return type appear in this
/// listing. The documented signature is getLocalNumDiags(const CrsGraphType& G).
366 template<
class CrsGraphType>
367 typename CrsGraphType::local_ordinal_type
/// Number of populated diagonal entries in the given sparse graph, summed
/// over all processes in the communicator of the graph's row Map.
///
/// The per-process counts are combined with Teuchos::reduceAll using
/// REDUCE_SUM. That is a collective operation, so every process in the
/// communicator must call this function.
///
/// NOTE(review): the following are not part of this listing:
///  - the function-name line;
///  - the lclNumDiags / gblNumDiags declarations;
///  - the null-Map and null-communicator branch bodies;
///  - the return statement.
/// The documented signature is getGlobalNumDiags(const CrsGraphType& G).
375 template<
class CrsGraphType>
376 typename CrsGraphType::global_ordinal_type
379 using GO =
typename CrsGraphType::global_ordinal_type;
// The communicator comes from the row Map. A null Map or a null
// communicator each take a separate branch (bodies not shown here).
381 const auto map = G.getRowMap ();
382 if (map.get () ==
nullptr) {
386 const auto comm = map->getComm ();
387 if (comm.get () ==
nullptr) {
393 using Teuchos::REDUCE_SUM;
394 using Teuchos::reduceAll;
395 using Teuchos::outArg;
// Sum the local counts into gblNumDiags, as type GO, on all processes.
398 reduceAll<int, GO> (*comm, REDUCE_SUM, lclNumDiags, outArg (gblNumDiags));
407 #endif // TPETRA_DETAILS_GETNUMDIAGS_HPP
Teuchos::RCP< const map_type > getRowMap() const override
Returns the Map that describes the row distribution in this graph.
Import KokkosSparse::OrdinalTraits, a traits class for "invalid" (flag) values of integer types, into the Tpetra::Details namespace.
An abstract interface for graphs accessed by rows.
Kokkos::parallel_reduce functor for counting the local number of diagonal entries in a sparse graph.
CrsGraphType::global_ordinal_type getGlobalNumDiags(const CrsGraphType &G)
Number of populated diagonal entries in the given sparse graph, summed over all processes in the graph's (MPI) communicator.
CrsGraphType::local_ordinal_type getLocalNumDiags(const CrsGraphType &G)
Number of populated diagonal entries in the given sparse graph, on the calling (MPI) process.
KOKKOS_INLINE_FUNCTION void operator()(const typename LocalMapType::local_ordinal_type lclRow, result_type &diagCount) const
Reduction function: result is the diagonal count (a single local ordinal; there is no separate error count).
A distributed graph accessed by rows (adjacency lists) and stored sparsely.
Implementation of Tpetra::Details::getLocalNumDiags (see below).