58 #include <Teuchos_Comm.hpp>
59 #include <Teuchos_CommHelpers.hpp>
60 #include <Teuchos_DefaultComm.hpp>
61 #include <Teuchos_RCP.hpp>
69 using Teuchos::rcp_const_cast;
70 using Teuchos::rcp_dynamic_cast;
72 using ztcrsmatrix_t = Tpetra::CrsMatrix<zscalar_t, zlno_t, zgno_t, znode_t>;
77 typename crsAdapter_t::ConstWeightsHostView1D::execution_space;
// NOTE(review): fragment of a doxygen-style listing (source lines 81-122);
// the function signature, blank/comment lines and closing braces are missing
// from this view. From the visible lines, this helper rebuilds the matrix's
// local rows as a host CRS structure (row offsets + global column IDs) and
// fetches the adapter's row-ID and CRS host views; the comparisons between
// them are on lines not shown here.
81 template <
typename adapter_t,
typename matrix_t>
// Host view types exposed by the adapter, plus (source lines 87/89) the
// Tpetra scratch view types presumably aliased as localInds_t / localVals_t,
// which are used below.
84 using idsHost_t =
typename adapter_t::ConstIdsHostView;
85 using offsetsHost_t =
typename adapter_t::ConstOffsetsHostView;
87 typename adapter_t::user_t::nonconst_local_inds_host_view_type;
89 typename adapter_t::user_t::nonconst_values_host_view_type;
92 const auto nrows = matrix.getLocalNumRows();
// NOTE(review): despite its name, `ncols` is the local number of stored
// entries (getLocalNumEntries); it sizes colIdsHost_, one slot per entry.
93 const auto ncols = matrix.getLocalNumEntries();
// Longest local row: sizes the per-row scratch buffers passed to getLocalRowCopy.
94 const auto maxNumEntries = matrix.getLocalMaxNumRowEntries();
// Reference CRS arrays built from the matrix, plus per-row scratch buffers.
// The offsets' size argument is on a missing line; it must be at least
// nrows + 1 because offsHost_(r + 1) is written in the loop below.
// offsHost_(0) is assumed to be 0, but its initialization is not visible here.
// Verify.
96 typename adapter_t::Base::ConstIdsHostView colIdsHost_(
"colIdsHost_", ncols);
97 typename adapter_t::Base::ConstOffsetsHostView offsHost_(
"offsHost_",
100 localInds_t localColInds(
"localColInds", maxNumEntries);
101 localVals_t localVals(
"localVals", maxNumEntries);
// For each local row: copy its local column indices, extend the offsets by
// the row length, and translate each local column index to a global ID via
// the matrix's column map.
103 for (
size_t r = 0; r < nrows; r++) {
104 size_t numEntries = 0;
// NOTE(review): the trailing ';;' is a harmless empty statement.
105 matrix.getLocalRowCopy(r, localColInds, localVals, numEntries);;
107 offsHost_(r + 1) = offsHost_(r) + numEntries;
// `e` walks this row's slots in colIdsHost_; `i` walks the row's copied
// local indices. NOTE(review): getColMap() is loop-invariant and could be
// hoisted out of both loops.
108 for (
size_t e = offsHost_(r), i = 0; e < offsHost_(r + 1); e++) {
109 colIdsHost_(e) = matrix.getColMap()->getGlobalElement(localColInds(i++));
// Row IDs reported by the adapter, and the matrix's own local row-map
// elements. They are presumably compared on missing lines.
113 idsHost_t rowIdsHost;
114 ia.getRowIDsHostView(rowIdsHost);
116 const auto matrixIDS = matrix.getRowMap()->getLocalElementList();
// CRS host views reported by the adapter (offsets and column IDs). They are
// presumably compared against offsHost_ / colIdsHost_ on missing lines.
120 idsHost_t colIdsHost;
121 offsetsHost_t offsetsHost;
122 ia.getCRSHostView(offsetsHost, colIdsHost);
// NOTE(review): fragment of a doxygen-style listing (source lines 128-204);
// the function signature, the weight setter/getter calls and most assertions
// are missing from this view. From the visible lines, this helper exercises
// the adapter's row-ID views (device and host) and its row-weight views.
128 template <
typename adapter_t,
typename matrix_t>
// Device/host view types for IDs and 1-D weights exposed by the adapter.
130 using idsDevice_t =
typename adapter_t::ConstIdsDeviceView;
131 using idsHost_t =
typename adapter_t::ConstIdsHostView;
132 using weightsDevice_t =
typename adapter_t::WeightsDeviceView1D;
133 using weightsHost_t =
typename adapter_t::WeightsHostView1D;
135 const auto nrows = ia.getLocalNumIDs();
// Row IDs fetched through both the device and the host accessor.
145 idsDevice_t rowIdsDevice;
146 ia.getRowIDsDeviceView(rowIdsDevice);
147 idsHost_t rowIdsHost;
148 ia.getRowIDsHostView(rowIdsHost);
// NOTE(review): tail of a call whose start is on missing lines. It passes an
// empty WeightsDeviceView1D and index 50, presumably as an expected-failure
// check for an out-of-range weight index. Verify against the full source.
156 typename adapter_t::WeightsDeviceView1D{}, 50),
// Weight 0: device view filled in parallel with 2 * row index.
159 weightsDevice_t wgts0(
"wgts0", nrows);
160 Kokkos::parallel_for(
161 nrows, KOKKOS_LAMBDA(
const int idx) { wgts0(idx) = idx * 2; });
// Weight 1: device view filled in parallel with 3 * row index.
167 weightsDevice_t wgts1(
"wgts1", nrows);
168 Kokkos::parallel_for(
169 nrows, KOKKOS_LAMBDA(
const int idx) { wgts1(idx) = idx * 3; });
// Views that receive weights read back from the adapter. The getter calls and
// comparisons are on missing lines. The repeated names (source lines 177/188
// and 180/191) imply separate inner scopes that are not shown here.
177 weightsDevice_t weightsDevice;
180 weightsHost_t weightsHost;
188 weightsDevice_t weightsDevice;
191 weightsHost_t weightsHost;
199 weightsDevice_t wgtsDevice;
// Requesting weight index 2 is expected to throw std::runtime_error. This is
// presumably because the adapter was built with two weights (indices 0 and 1;
// main passes nWeights = 2). Verify against the full source.
203 weightsHost_t wgtsHost;
204 Z2_TEST_THROW(ia.getRowWeightsHostView(wgtsHost, 2), std::runtime_error);
// NOTE(review): fragment of a doxygen-style listing (source lines 212-266).
// Many lines are missing from this view, including the `try {`, the
// construction of `uinput`/`env`/`mMigrate`, and the final return.
// From the visible lines: load a test matrix, wrap it in the CrsMatrix
// adapter with two weights, assign every local row to part 0, and migrate the
// matrix according to that solution.
212 int main(
int narg,
char *arg[]) {
// Initializes Tpetra for the lifetime of this scope and obtains its default
// communicator.
216 Tpetra::ScopeGuard tscope(&narg, &arg);
217 const auto comm = Tpetra::getDefaultComm();
// Input description: the "simple" test problem in Chaco format. It is
// presumably consumed by `uinput`, which is constructed on a missing line.
220 Teuchos::ParameterList params;
221 params.set(
"input file",
"simple");
222 params.set(
"file type",
"Chaco");
227 const auto crsMatrix = uinput->getUITpetraCrsMatrix();
229 const auto nrows = crsMatrix->getLocalNumRows();
// Number of weights per row, passed to the adapter constructor below.
234 const int nWeights = 2;
243 auto tpetraCrsMatrixInput = rcp(
new crsAdapter_t(crsMatrix, nWeights));
// Partition assignment: one zero-filled entry per local row, so every row is
// assigned to part 0. The final `true` makes the ArrayRCP take ownership of `p`.
247 crsPart_t *p =
new crsPart_t[nrows];
248 memset(p, 0,
sizeof(crsPart_t) * nrows);
249 ArrayRCP<crsPart_t> solnParts(p, 0, nrows,
true);
251 crsSoln_t solution(env, comm, nWeights);
252 solution.setParts(solnParts);
// NOTE(review): the remaining argument(s) of applyPartitioningSolution are on
// a missing line. mMigrate is presumably an out-pointer to the migrated
// matrix, which rcp() then takes ownership of. Verify.
255 tpetraCrsMatrixInput->applyPartitioningSolution(*crsMatrix, mMigrate,
257 const auto newM = rcp(mMigrate);
261 PrintFromRoot(
"Input adapter for Tpetra::CrsMatrix migrated to proc 0");
265 }
catch (std::exception &e) {
// Report the exception message. What main returns afterwards is on a
// missing line.
266 std::cout << e.what() << std::endl;
void PrintFromRoot(const std::string &message)
Provides access for Zoltan2 to Tpetra::CrsMatrix data.
int main(int narg, char **arg)
common code used by tests
typename InputTraits< User >::part_t part_t
#define Z2_TEST_COMPARE_ARRAYS(val1, val2)
A PartitioningSolution is a solution to a partitioning problem.
#define Z2_TEST_DEVICE_HOST_VIEWS(deviceView, hostView)
#define Z2_TEST_EQUALITY(val1, val2)
Defines the TpetraRowMatrixAdapter class.
The user parameters, debug, timing and memory profiling output objects, and error checking methods...
#define Z2_TEST_THROW(code, ExceptType)
#define Z2_TEST_NOTHROW(code)