50 #ifndef _ZOLTAN2_TPETRAROWMATRIXADAPTER_HPP_
51 #define _ZOLTAN2_TPETRAROWMATRIXADAPTER_HPP_
57 #include <Tpetra_RowMatrix.hpp>
// TpetraRowMatrixAdapter: gives Zoltan2 access to the data of a
// Tpetra::RowMatrix (template parameter User) through the MatrixAdapter
// interface.
// NOTE(review): this listing is a partial extraction -- the class head,
// access specifiers and several member declarations are not visible here.
77 template <
typename User,
typename UserCoord = User>
82 #ifndef DOXYGEN_SHOULD_SKIP_THIS
// Kokkos device type of the matrix's node type.
89 using device_t =
typename node_t::device_type;
// Host memory space, used for host-side views.
90 using host_t =
typename Kokkos::HostSpace::memory_space;
92 using userCoord_t = UserCoord;
// Tail of the constructor declaration: the number of weights per row
// defaults to 0 (no row weights).
103 int nWeightsPerRow = 0);
// Set weight index idx of the primary entity type from a device view.
128 void setWeightsDevice(
typename Base::ConstWeightsDeviceView1D val,
int idx);
// Set weight index idx of the primary entity type from a host view.
139 void setWeightsHost(
typename Base::ConstWeightsHostView1D val,
int idx);
// Global IDs of the local rows, as a host view or a device view
// (leading declaration lines not visible).
218 typename Base::ConstIdsHostView &rowIds)
const override;
221 typename Base::ConstIdsDeviceView &rowIds)
const override;
// CRS structure (row offsets + global column IDs) as ArrayRCPs, host views
// and device views.
223 void getCRSView(ArrayRCP<const offset_t> &offsets,
224 ArrayRCP<const gno_t> &colIds)
const;
227 typename Base::ConstOffsetsHostView &offsets,
228 typename Base::ConstIdsHostView &colIds)
const override;
231 typename Base::ConstOffsetsDeviceView &offsets,
232 typename Base::ConstIdsDeviceView &colIds)
const override;
// CRS structure plus nonzero values, in the same three flavors.
234 void getCRSView(ArrayRCP<const offset_t> &offsets,
235 ArrayRCP<const gno_t> &colIds,
236 ArrayRCP<const scalar_t> &values)
const;
239 typename Base::ConstOffsetsHostView &offsets,
240 typename Base::ConstIdsHostView &colIds,
241 typename Base::ConstScalarsHostView &values)
const override;
244 typename Base::ConstOffsetsDeviceView &offsets,
245 typename Base::ConstIdsDeviceView &colIds,
246 typename Base::ConstScalarsDeviceView &values)
const override;
// All row weights as a device view or a host view (leading declaration
// lines not visible).
257 typename Base::WeightsDeviceView &weights)
const override;
263 typename Base::WeightsHostView &weights)
const override;
// Apply a partitioning solution to `in`, returning the migrated matrix
// either through a raw pointer or through an RCP.
267 template <
typename Adapter>
269 const User &in, User *&out,
272 template <
typename Adapter>
274 const User &in, RCP<User> &out,
// NOTE(review): trailing parameter of a declaration whose head is not
// visible; presumably the (nWeightsPerRow, inmatrix) constructor -- confirm.
280 const RCP<const User> &inmatrix)
// Build a matrix holding the numLocalRows rows whose global IDs are listed
// in myNewRows. Virtual so that a subclass can provide migration for
// RowMatrix types other than Tpetra::CrsMatrix.
304 virtual RCP<User>
doMigration(
const User &from,
size_t numLocalRows,
305 const gno_t *myNewRows)
const;
// Constructor: wraps `inmatrix` and records how many weights each row
// carries. It allocates host views for the local CRS data (column IDs, row
// offsets, values) and reads every local row with getLocalRowCopy.
// NOTE(review): the lines that fill the host views, and the device copies
// used by the getters, are not visible in this listing.
312 template <
typename User,
typename UserCoord>
314 const RCP<const User> &inmatrix,
int nWeightsPerRow):
315 matrix_(inmatrix), offset_(), columnIds_(),
316 nWeightsPerRow_(nWeightsPerRow), rowWeights_(),
317 mayHaveDiagonalEntries(true) {
// Host view types that getLocalRowCopy fills as output buffers.
319 using localInds_t =
typename User::nonconst_local_inds_host_view_type;
320 using localVals_t =
typename User::nonconst_values_host_view_type;
// Local sizes: rows, nonzeros, and the longest row (used to size the
// per-row scratch buffers below).
322 const auto nrows =
matrix_->getLocalNumRows();
323 const auto nnz =
matrix_->getLocalNumEntries();
324 auto maxNumEntries =
matrix_->getLocalMaxNumRowEntries();
// Host storage for the whole local CRS structure: nnz column IDs,
// nrows + 1 row offsets, nnz values.
329 colIdsHost_ =
typename Base::ConstIdsHostView(
"colIdsHost_", nnz);
330 offsHost_ =
typename Base::ConstOffsetsHostView(
"offsHost_", nrows + 1);
331 valuesHost_ =
typename Base::ScalarsHostView(
"valuesHost_", nnz);
// Scratch buffers sized to the longest row, allocated once and reused for
// every row.
333 localInds_t localColInds(
"localColInds", maxNumEntries);
334 localVals_t localVals(
"localVals", maxNumEntries);
// Copy each local row into the scratch buffers; numEntries receives the
// row's length.
336 for (
size_t r = 0; r < nrows; r++) {
337 size_t numEntries = 0;
338 matrix_->getLocalRowCopy(r, localColInds, localVals, numEntries);
// Loop over all nnz local entries (body not visible in this listing).
344 for (
size_t j = 0; j < nnz; j++) {
// Presumably the allocation of rowWeightsDevice_: one row per local row,
// one column per weight index.
360 "rowWeightsDevice_", nrows, nWeightsPerRow_);
// setWeights: attach weight index `idx` (values in weightVal, read with the
// given stride) to the primary entity type. Only MATRIX_ROW is supported,
// and the call is forwarded to setRowWeights. Otherwise (the else line is
// not visible here) a std::runtime_error is thrown.
// NOTE(review): the signature line and the closing lines are not visible in
// this listing.
372 template <
typename User,
typename UserCoord>
374 const scalar_t *weightVal,
int stride,
int idx) {
375 if (this->getPrimaryEntityType() ==
MATRIX_ROW)
376 setRowWeights(weightVal, stride, idx);
// Columns and nonzeros cannot carry weights yet.
379 std::ostringstream emsg;
380 emsg << __FILE__ <<
"," << __LINE__
381 <<
" error: setWeights not yet supported for"
382 <<
" columns or nonzeros." << std::endl;
383 throw std::runtime_error(emsg.str());
// setWeightsDevice: device-view counterpart of setWeights. For MATRIX_ROW
// primary entities, weight index idx is taken from `val` through
// setRowWeightsDevice. Otherwise (else line not visible here) a
// std::runtime_error is thrown.
388 template <
typename User,
typename UserCoord>
390 typename Base::ConstWeightsDeviceView1D val,
int idx) {
391 if (this->getPrimaryEntityType() ==
MATRIX_ROW)
392 setRowWeightsDevice(val, idx);
// Columns and nonzeros cannot carry weights yet.
395 std::ostringstream emsg;
396 emsg << __FILE__ <<
"," << __LINE__
397 <<
" error: setWeights not yet supported for"
398 <<
" columns or nonzeros." << std::endl;
399 throw std::runtime_error(emsg.str());
// setWeightsHost: host-view counterpart of setWeights. For MATRIX_ROW
// primary entities, weight index idx is taken from `val` through
// setRowWeightsHost. Otherwise (else line not visible here) a
// std::runtime_error is thrown.
404 template <
typename User,
typename UserCoord>
406 typename Base::ConstWeightsHostView1D val,
int idx) {
407 if (this->getPrimaryEntityType() ==
MATRIX_ROW)
408 setRowWeightsHost(val, idx);
// Columns and nonzeros cannot carry weights yet.
411 std::ostringstream emsg;
412 emsg << __FILE__ <<
"," << __LINE__
413 <<
" error: setWeights not yet supported for"
414 <<
" columns or nonzeros." << std::endl;
415 throw std::runtime_error(emsg.str());
// setRowWeights: store weight index `idx` for the local rows from a raw
// array read with the given stride.
// The ArrayRCP below is built with has_ownership = false (last argument),
// so the adapter neither copies nor frees weightVal. The caller must keep
// that array alive for as long as the adapter uses the weights.
// NOTE(review): the signature line and the head of the index-check call
// are not visible in this listing.
420 template <
typename User,
typename UserCoord>
422 const scalar_t *weightVal,
int stride,
int idx) {
425 "Invalid row weight index: " + std::to_string(idx));
427 size_t nrows = getLocalNumRows();
// Covers nrows entries spaced `stride` apart.
428 ArrayRCP<const scalar_t> weightV(weightVal, 0, nrows * stride,
false);
429 rowWeights_[idx] = input_t(weightV, stride);
// setRowWeightsDevice: copy a device view of per-row weights into column
// `idx` of rowWeightsDevice_, one row per parallel_for iteration.
// NOTE(review): the signature line, the head of the index check and the
// closing lines are not visible here. rowID is `int`, so this assumes the
// local row count fits in int.
433 template <
typename User,
typename UserCoord>
435 typename Base::ConstWeightsDeviceView1D
weights,
int idx) {
438 "Invalid row weight index: " + std::to_string(idx));
440 Kokkos::parallel_for(
441 rowWeightsDevice_.extent(0), KOKKOS_CLASS_LAMBDA(
const int rowID) {
442 rowWeightsDevice_(rowID, idx) =
weights(rowID);
// setRowWeightsHost: copy a host view of per-row weights to device memory
// with create_mirror_view_and_copy (its arguments are not visible here),
// then store it through setRowWeightsDevice.
449 template <
typename User,
typename UserCoord>
451 typename Base::ConstWeightsHostView1D weightsHost,
int idx) {
453 "Invalid row weight index: " + std::to_string(idx));
455 auto weightsDevice = Kokkos::create_mirror_view_and_copy(
458 setRowWeightsDevice(weightsDevice, idx);
// Presumably setWeightIsDegree(idx); the signature line is not visible.
// For MATRIX_ROW primary entities, weight index idx becomes the row's
// nonzero count, forwarded to setRowWeightIsNumberOfNonZeros. Otherwise
// (else line not visible here) a std::runtime_error is thrown.
462 template <
typename User,
typename UserCoord>
464 if (this->getPrimaryEntityType() ==
MATRIX_ROW)
465 setRowWeightIsNumberOfNonZeros(idx);
// Degree weights are only supported for rows.
468 std::ostringstream emsg;
469 emsg << __FILE__ <<
"," << __LINE__
470 <<
" error: setWeightIsNumberOfNonZeros not yet supported for"
471 <<
" columns" << std::endl;
472 throw std::runtime_error(emsg.str());
// setRowWeightIsNumberOfNonZeros: mark weight index idx as "use the row's
// nonzero count" by setting numNzWeight_(idx).
// Throws std::runtime_error when idx is outside [0, nWeightsPerRow_).
// NOTE(review): the signature line and the end of the message expression
// (original line 483) are not visible in this listing.
477 template <
typename User,
typename UserCoord>
480 if (idx < 0 || idx >= nWeightsPerRow_) {
481 std::ostringstream emsg;
482 emsg << __FILE__ <<
":" << __LINE__ <<
" Invalid row weight index " << idx
484 throw std::runtime_error(emsg.str());
487 numNzWeight_(idx) =
true;
// getLocalNumRows: number of rows owned by this process, as reported by
// the wrapped matrix.
491 template <
typename User,
typename UserCoord>
493 return matrix_->getLocalNumRows();
// getLocalNumColumns: number of local columns, as reported by the wrapped
// matrix (getLocalNumCols).
497 template <
typename User,
typename UserCoord>
499 return matrix_->getLocalNumCols();
// getLocalNumEntries: number of nonzeros stored on this process, as
// reported by the wrapped matrix.
503 template <
typename User,
typename UserCoord>
505 return matrix_->getLocalNumEntries();
// Presumably CRSViewAvailable(): only the template header is visible here;
// the signature and body are not part of this listing.
509 template <
typename User,
typename UserCoord>
// getRowIDsView: expose the global IDs of the locally owned rows as a raw
// pointer into the row map's local element list (no copy).
// NOTE(review): the pointer refers to storage owned by the row map, not by
// the local ArrayView. It is presumably valid only while matrix_'s row map
// is alive -- confirm.
513 template <
typename User,
typename UserCoord>
515 ArrayView<const gno_t> rowView = matrix_->getRowMap()->getLocalElementList();
516 rowIds = rowView.getRawPtr();
// getRowIDsHostView: copy the row map's global indices into a newly
// allocated, unlabeled host view.
// NOTE(review): the assignment of the copy to rowIds is not visible in this
// listing.
520 template <
typename User,
typename UserCoord>
522 typename Base::ConstIdsHostView &rowIds)
const {
523 auto idsDevice = matrix_->getRowMap()->getMyGlobalIndices();
524 auto tmpIds =
typename Base::IdsHostView(
"", idsDevice.extent(0));
526 Kokkos::deep_copy(tmpIds, idsDevice);
// getRowIDsDeviceView: copy the row map's global indices into a newly
// allocated, unlabeled device view.
// NOTE(review): the assignment of the copy to rowIds is not visible in this
// listing.
532 template <
typename User,
typename UserCoord>
534 typename Base::ConstIdsDeviceView &rowIds)
const {
536 auto idsDevice = matrix_->getRowMap()->getMyGlobalIndices();
537 auto tmpIds =
typename Base::IdsDeviceView(
"", idsDevice.extent(0));
539 Kokkos::deep_copy(tmpIds, idsDevice);
// getCRSView (offsets + column IDs as ArrayRCPs): the signature head and
// body are not visible in this listing.
545 template <
typename User,
typename UserCoord>
547 ArrayRCP<const gno_t> &colIds)
const {
// getCRSHostView: return host copies of the cached device CRS structure
// (row offsets and global column IDs). Each call creates host mirrors of
// offsDevice_ and colIdsDevice_ and deep-copies into them. A mirror is a
// new allocation unless the device memory is host-accessible.
// NOTE(review): the constructor also fills offsHost_/colIdsHost_. If those
// are kept in sync with the device views they could be returned directly
// instead of re-copying -- confirm.
// NOTE(review): the assignment of hostColIds to colIds is not visible here.
553 template <
typename User,
typename UserCoord>
555 typename Base::ConstOffsetsHostView &offsets,
556 typename Base::ConstIdsHostView &colIds)
const {
557 auto hostOffsets = Kokkos::create_mirror_view(offsDevice_);
558 Kokkos::deep_copy(hostOffsets, offsDevice_);
559 offsets = hostOffsets;
561 auto hostColIds = Kokkos::create_mirror_view(colIdsDevice_);
562 Kokkos::deep_copy(hostColIds, colIdsDevice_);
// getCRSDeviceView: return the cached device CRS views themselves (row
// offsets and global column IDs). No copy is made.
567 template <
typename User,
typename UserCoord>
569 typename Base::ConstOffsetsDeviceView &offsets,
570 typename Base::ConstIdsDeviceView &colIds)
const {
571 offsets = offsDevice_;
572 colIds = colIdsDevice_;
// getCRSView (offsets + column IDs + values as ArrayRCPs): the signature
// head and body are not visible in this listing.
576 template <
typename User,
typename UserCoord>
578 ArrayRCP<const gno_t> &colIds,
579 ArrayRCP<const scalar_t> &values)
const {
// getCRSHostView (with values): return host copies of the cached device
// CRS structure and nonzero values. Each call creates host mirrors of
// offsDevice_, colIdsDevice_ and valuesDevice_ and deep-copies into them.
// NOTE(review): the assignments of hostColIds/hostValues to the output
// arguments are not visible here. As in the two-argument overload, the
// constructor's host views might make these copies unnecessary -- confirm.
586 template <
typename User,
typename UserCoord>
588 typename Base::ConstOffsetsHostView &offsets,
589 typename Base::ConstIdsHostView &colIds,
590 typename Base::ConstScalarsHostView &values)
const {
591 auto hostOffsets = Kokkos::create_mirror_view(offsDevice_);
592 Kokkos::deep_copy(hostOffsets, offsDevice_);
593 offsets = hostOffsets;
595 auto hostColIds = Kokkos::create_mirror_view(colIdsDevice_);
596 Kokkos::deep_copy(hostColIds, colIdsDevice_);
599 auto hostValues = Kokkos::create_mirror_view(valuesDevice_);
600 Kokkos::deep_copy(hostValues, valuesDevice_);
// getCRSDeviceView (with values): return the cached device views of row
// offsets, global column IDs and nonzero values. No copy is made.
605 template <
typename User,
typename UserCoord>
607 typename Base::ConstOffsetsDeviceView &offsets,
608 typename Base::ConstIdsDeviceView &colIds,
609 typename Base::ConstScalarsDeviceView &values)
const {
610 offsets = offsDevice_;
611 colIds = colIdsDevice_;
612 values = valuesDevice_;
// Presumably getNumWeightsPerRow(): only the template header is visible;
// the signature and body are not part of this listing.
616 template <
typename User,
typename UserCoord>
// getRowWeightsView: return the weights stored for index idx (and their
// stride) from the StridedData in rowWeights_.
// Throws std::runtime_error when idx is outside [0, nWeightsPerRow_).
// NOTE(review): the signature, the end of the message expression and the
// declaration of `length` are not visible in this listing.
620 template <
typename User,
typename UserCoord>
623 if (idx < 0 || idx >= nWeightsPerRow_) {
624 std::ostringstream emsg;
625 emsg << __FILE__ <<
":" << __LINE__ <<
" Invalid row weight index "
627 throw std::runtime_error(emsg.str());
631 rowWeights_[idx].getStridedList(length, weights, stride);
// getRowWeightsDeviceView (single index): allocate a new 1-D device view of
// length rowWeightsDevice_.extent(0) and copy column idx of
// rowWeightsDevice_ into it in parallel. Any view previously held by
// `weights` is replaced by this new allocation.
635 template <
typename User,
typename UserCoord>
637 typename Base::WeightsDeviceView1D &
weights,
int idx)
const {
639 "Invalid row weight index.");
641 const auto size = rowWeightsDevice_.extent(0);
642 weights =
typename Base::WeightsDeviceView1D(
"weights", size);
644 Kokkos::parallel_for(
645 size, KOKKOS_CLASS_LAMBDA(
const int id) {
646 weights(
id) = rowWeightsDevice_(
id, idx);
// getRowWeightsDeviceView (all weights): hand out rowWeightsDevice_ itself.
// No copy is made, so the caller's view shares storage with the adapter's
// row weights.
653 template <
typename User,
typename UserCoord>
655 typename Base::WeightsDeviceView &
weights)
const {
657 weights = rowWeightsDevice_;
// getRowWeightsHostView (single index): gather column idx of the device
// row weights with getRowWeightsDeviceView, then mirror it to host.
661 template <
typename User,
typename UserCoord>
663 typename Base::WeightsHostView1D &
weights,
int idx)
const {
665 "Invalid row weight index.");
// NOTE(review): getRowWeightsDeviceView assigns a fresh allocation to its
// output argument, so this first allocation is never used. A
// default-constructed view would avoid the extra device allocation.
667 auto weightsDevice =
typename Base::WeightsDeviceView1D(
668 "weights", rowWeightsDevice_.extent(0));
669 getRowWeightsDeviceView(weightsDevice, idx);
671 weights = Kokkos::create_mirror_view(weightsDevice);
672 Kokkos::deep_copy(weights, weightsDevice);
// getRowWeightsHostView (all weights): host mirror of rowWeightsDevice_,
// filled with deep_copy.
676 template <
typename User,
typename UserCoord>
678 typename Base::WeightsHostView &
weights)
const {
680 weights = Kokkos::create_mirror_view(rowWeightsDevice_);
681 Kokkos::deep_copy(weights, rowWeightsDevice_);
// Presumably useNumNonzerosAsRowWeight(idx): only the template header is
// visible; the signature and body are not part of this listing.
685 template <
typename User,
typename UserCoord>
// applyPartitioningSolution (raw-pointer output): compute the global IDs
// of the rows this process receives under `solution` (getImportList) and
// migrate `in` to them with doMigration.
// NOTE(review): the lines that compute numNewRows and hand outPtr to `out`
// are not visible in this listing.
689 template <
typename User,
typename UserCoord>
690 template <
typename Adapter>
692 const User &in, User *&out,
696 ArrayRCP<gno_t> importList;
699 Zoltan2::getImportList<Adapter, TpetraRowMatrixAdapter<User, UserCoord>>(
700 solution,
this, importList);
705 RCP<User> outPtr = doMigration(in, numNewRows, importList.getRawPtr());
// applyPartitioningSolution (RCP output): same as the raw-pointer overload,
// but the migrated matrix is returned directly in `out`.
// NOTE(review): the line that computes numNewRows is not visible here.
711 template <
typename User,
typename UserCoord>
712 template <
typename Adapter>
714 const User &in, RCP<User> &out,
718 ArrayRCP<gno_t> importList;
721 Zoltan2::getImportList<Adapter, TpetraRowMatrixAdapter<User, UserCoord>>(
722 solution,
this, importList);
727 out = doMigration(in, numNewRows, importList.getRawPtr());
// doMigration: build a new Tpetra::CrsMatrix holding the numLocalRows rows
// whose global IDs are listed in myNewRows.
// Only Tpetra::CrsMatrix inputs can be migrated (checked with dynamic_cast).
// Other RowMatrix types get std::logic_error; the null-check line itself
// is not visible in this listing.
// Steps:
//   1. build the target row map from myNewRows;
//   2. build an Import from the old row map;
//   3. import the per-row entry counts;
//   4. allocate a CrsMatrix with those counts;
//   5. import the matrix entries.
731 template <
typename User,
typename UserCoord>
733 const User &from,
size_t numLocalRows,
const gno_t *myNewRows)
const {
734 typedef Tpetra::Map<lno_t, gno_t, node_t>
map_t;
735 typedef Tpetra::CrsMatrix<scalar_t, lno_t, gno_t, node_t> tcrsmatrix_t;
746 const tcrsmatrix_t *pCrsMatrix =
dynamic_cast<const tcrsmatrix_t *
>(&from);
749 throw std::logic_error(
"TpetraRowMatrixAdapter cannot migrate data for "
750 "your RowMatrix; it can migrate data only for "
751 "Tpetra::CrsMatrix. "
752 "You can inherit from TpetraRowMatrixAdapter and "
753 "implement migration for your RowMatrix.");
// The global row count and index base of the new map are taken from the
// source row map.
757 const RCP<const map_t> &smap = from.getRowMap();
758 gno_t numGlobalRows = smap->getGlobalNumElements();
759 gno_t base = smap->getMinAllGlobalIndex();
// Target map: this process owns exactly the rows in myNewRows.
762 ArrayView<const gno_t> rowList(myNewRows, numLocalRows);
763 const RCP<const Teuchos::Comm<int>> &comm = from.getComm();
764 RCP<const map_t> tmap = rcp(
new map_t(numGlobalRows, rowList, base, comm));
// One Import object moves both the row counts and the matrix entries.
767 Tpetra::Import<lno_t, gno_t, node_t> importer(smap, tmap);
// NOTE(review): getLocalNumElements() returns size_t, and numLocalRows is
// size_t; storing them in int narrows for very large local sizes.
769 int oldNumElts = smap->getLocalNumElements();
770 int newNumElts = numLocalRows;
// Row lengths are carried in a Vector of scalar_t so the same Import can
// redistribute them.
// NOTE(review): this assumes scalar_t holds the counts exactly and that
// static_cast<size_t> from scalar_t is valid (e.g. not a complex type).
773 typedef Tpetra::Vector<scalar_t, lno_t, gno_t, node_t> vector_t;
774 vector_t numOld(smap);
775 vector_t numNew(tmap);
// NOTE(review): lid is already a local index; replaceLocalValue(lid, ...)
// would avoid the local->global lookup done here on every row.
776 for (
int lid = 0; lid < oldNumElts; lid++) {
777 numOld.replaceGlobalValue(smap->getGlobalElement(lid),
778 scalar_t(from.getNumEntriesInLocalRow(lid)));
780 numNew.doImport(numOld, importer, Tpetra::INSERT);
// Convert the imported counts back to size_t to preallocate each row.
783 ArrayRCP<size_t> nnz(newNumElts);
784 if (newNumElts > 0) {
785 ArrayRCP<scalar_t> ptr = numNew.getDataNonConst(0);
786 for (
int lid = 0; lid < newNumElts; lid++) {
787 nnz[lid] =
static_cast<size_t>(ptr[lid]);
// New matrix on the target map with per-row preallocation, then import
// the entries from `from`.
791 RCP<tcrsmatrix_t> M = rcp(
new tcrsmatrix_t(tmap, nnz()));
793 M->doImport(from, importer, Tpetra::INSERT);
// Cast the CrsMatrix back to the user's matrix type.
796 return Teuchos::rcp_dynamic_cast<User>(M);
801 #endif // _ZOLTAN2_TPETRAROWMATRIXADAPTER_HPP_
bool CRSViewAvailable() const
Indicates whether the MatrixAdapter implements a view of the matrix in compressed sparse row (CRS) fo...
void setWeightsHost(typename Base::ConstWeightsHostView1D val, int idx)
Provide a host view of weights for the primary entity type.
ArrayRCP< StridedData< lno_t, scalar_t > > rowWeights_
bool mayHaveDiagonalEntries
void getRowWeightsHostView(typename Base::WeightsHostView1D &weights, int idx=0) const
virtual RCP< User > doMigration(const User &from, size_t numLocalRows, const gno_t *myNewRows) const
Base::ScalarsHostView valuesHost_
void setRowWeightsHost(typename Base::ConstWeightsHostView1D val, int idx)
Provide a host view to row weights.
Helper functions for Partitioning Problems.
Base::ConstIdsDeviceView colIdsDevice_
#define Z2_FORWARD_EXCEPTIONS
Forward an exception back through call stack.
typename InputTraits< User >::scalar_t scalar_t
MatrixAdapter defines the adapter interface for matrices.
static void AssertCondition(bool condition, const std::string &message, const char *file=__FILE__, int line=__LINE__)
typename Kokkos::HostSpace::memory_space host_t
map_t::global_ordinal_type gno_t
void setWeightsDevice(typename Base::ConstWeightsDeviceView1D val, int idx)
Provide a device view of weights for the primary entity type.
typename node_t::device_type device_t
ArrayRCP< scalar_t > values_
Base::ConstIdsHostView colIdsHost_
typename InputTraits< User >::part_t part_t
bool useNumNonzerosAsRowWeight(int idx) const
Indicate whether row weight with index idx should be the global number of nonzeros in the row...
Provides access for Zoltan2 to Tpetra::RowMatrix data.
typename InputTraits< User >::node_t node_t
void getRowWeightsView(const scalar_t *&weights, int &stride, int idx=0) const
Provide a pointer to the row weights, if any.
size_t getLocalNumEntries() const
Returns the number of nonzeros on this process.
A PartitioningSolution is a solution to a partitioning problem.
void getRowIDsView(const gno_t *&rowIds) const override
void setRowWeightsDevice(typename Base::ConstWeightsDeviceView1D val, int idx)
Provide a device view to row weights.
int getNumWeightsPerRow() const
Returns the number of weights per row (0 or greater). Row weights may be used when partitioning matri...
void getCRSDeviceView(typename Base::ConstOffsetsDeviceView &offsets, typename Base::ConstIdsDeviceView &colIds) const override
Base::ScalarsDeviceView valuesDevice_
size_t getLocalNumColumns() const
Returns the number of columns on this process.
Base::ConstOffsetsDeviceView offsDevice_
The StridedData class manages lists of weights or coordinates.
map_t::local_ordinal_type lno_t
void getRowIDsDeviceView(typename Base::ConstIdsDeviceView &rowIds) const override
ArrayRCP< gno_t > columnIds_
void applyPartitioningSolution(const User &in, User *&out, const PartitioningSolution< Adapter > &solution) const
void setRowWeights(const scalar_t *weightVal, int stride, int idx=0)
Specify a weight for each row.
RCP< const User > matrix_
typename InputTraits< User >::offset_t offset_t
void getCRSHostView(typename Base::ConstOffsetsHostView &offsets, typename Base::ConstIdsHostView &colIds) const override
void getRowWeightsDeviceView(typename Base::WeightsDeviceView1D &weights, int idx=0) const
void setWeightIsDegree(int idx)
Specify an index for which the weight should be the degree of the entity.
Base::ConstOffsetsHostView offsHost_
Defines the MatrixAdapter interface.
void setWeights(const scalar_t *weightVal, int stride, int idx=0)
Specify a weight for each entity of the primaryEntityType.
ArrayRCP< offset_t > offset_
Kokkos::View< bool *, host_t > numNzWeight_
void getRowIDsHostView(typename Base::ConstIdsHostView &rowIds) const override
size_t getLocalNumRows() const
Returns the number of rows on this process.
TpetraRowMatrixAdapter(int nWeightsPerRow, const RCP< const User > &inmatrix)
Base::WeightsDeviceView rowWeightsDevice_
void getCRSView(ArrayRCP< const offset_t > &offsets, ArrayRCP< const gno_t > &colIds) const
void setRowWeightIsNumberOfNonZeros(int idx)
Specify an index for which the row weight should be the global number of nonzeros in the row...
TpetraRowMatrixAdapter(const RCP< const User > &inmatrix, int nWeightsPerRow=0)
Constructor.
This file defines the StridedData class.