7 #include <Teuchos_ParameterEntryValidator.hpp>
14 #ifndef HAVE_ZOLTAN2_SARMA
// Fallback constructor fragment that follows the `#ifndef HAVE_ZOLTAN2_SARMA`
// guard. The visible body does nothing except throw std::runtime_error with a
// message telling the user to enable the TPL_ENABLE_SARMA CMake flag.
// NOTE(review): the class header and the first constructor parameter are not
// visible in this listing (embedded line numbers jump from 17 to 23) - confirm
// against the full file.
17 template<
typename Adapter>
23 const RCP<
const Comm<int> > &,
24 const RCP<const base_adapter_t> &) {
// All parameters are unnamed: nothing is read from them before throwing.
25 throw std::runtime_error(
"SARMA requested but not enabled into Zoltan2\n"
26 "Please set CMake flag TPL_ENABLE_SARMA:BOOL=ON");
38 #include <sarma/sarma.hpp>
// AlgSarma: an Algorithm<Adapter> subclass declared after <sarma/sarma.hpp> is
// included. Its constructors, partition() and getParams() appear further down.
// NOTE(review): embedded line numbers skip here (42 -> 47, 48 -> 51, 51 -> 54),
// so further aliases/members and access specifiers are not visible.
41 template<
typename Adapter>
42 class AlgSarma :
public Algorithm<Adapter> {
// Type aliases forwarded from the Adapter template argument.
47 using offset_t =
typename Adapter::offset_t;
48 using scalar_t =
typename Adapter::scalar_t;
51 using userCoord_t =
typename Adapter::userCoord_t;
// Pointer passed to adapter->getCRSView() in the MatrixAdapter constructor and
// later read over [offsets, offsets + offsize) when partition() builds the
// sarma::Matrix.
54 const offset_t *offsets;
// Constructor overload taking an IdentifierAdapter.
// Initializes env, comm, adapter and the two lookup maps (algs, orders) from
// sarma::get_algorithm_map / sarma::get_order_map, then throws
// std::runtime_error: the body contains only the throw statement.
61 AlgSarma(
const RCP<const Environment> &env,
const RCP<
const Comm<int> > &comm,
62 const RCP<
const IdentifierAdapter <user_t> > &adapter)
63 : env(env), comm(comm), adapter(adapter),
64 algs(sarma::get_algorithm_map<Ordinal, Value>()), orders(sarma::get_order_map()) {
65 throw std::runtime_error(
"Cannot build SARMA from IdentifierAdapter");
// Constructor overload taking a VectorAdapter.
// Performs the same member initialization as the other overloads (env, comm,
// adapter, algs, orders) and then throws std::runtime_error: the body contains
// only the throw statement.
68 AlgSarma(
const RCP<const Environment> &env,
const RCP<
const Comm<int> > &comm,
69 const RCP<
const VectorAdapter <user_t> > &adapter)
70 : env(env), comm(comm), adapter(adapter),
71 algs(sarma::get_algorithm_map<Ordinal, Value>()), orders(sarma::get_order_map()) {
72 throw std::runtime_error(
"Cannot build SARMA from VectorAdapter");
// Constructor overload taking a GraphAdapter.
// Initializes env, comm, adapter, algs and orders, then throws
// std::runtime_error: the body contains only the throw statement.
76 AlgSarma(
const RCP<const Environment> &env,
const RCP<
const Comm<int> > &comm,
77 const RCP<
const GraphAdapter <user_t, userCoord_t> > &adapter)
78 : env(env), comm(comm), adapter(adapter),
79 algs(sarma::get_algorithm_map<Ordinal, Value>()), orders(sarma::get_order_map()) {
80 throw std::runtime_error(
"Cannot build SARMA from GraphAdapter");
// Constructor overload taking a MeshAdapter.
// Initializes env, comm, adapter, algs and orders; the visible body then throws
// std::runtime_error.
// NOTE(review): embedded line number 87 is missing between the opening brace
// and the throw - confirm nothing else runs before it.
83 AlgSarma(
const RCP<const Environment> &env,
const RCP<
const Comm<int> > &comm,
84 const RCP<
const MeshAdapter <user_t> > &adapter)
85 : env(env), comm(comm), adapter(adapter),
86 algs(sarma::get_algorithm_map<Ordinal, Value>()), orders(sarma::get_order_map()) {
88 throw std::runtime_error(
"Cannot build SARMA from MeshAdapter");
// Constructor overload taking a MatrixAdapter - the only overload in view whose
// body does more than throw.
// Besides env/comm/adapter/algs/orders it records:
//   offsize = 1 + adapter->getLocalNumRows()
//   nnz     = adapter->getLocalNumEntries()
// and, when the adapter reports a CRS view, requests it via getCRSView().
// Throws std::runtime_error("NO CRS Available"); presumably only on the
// no-CRS path, see the note on missing lines below.
// NOTE(review): embedded line numbers 96, 100 and 102 are missing from this
// listing, so at least one statement before the `if`, the setup of `vals`, and
// the branch structure around the throw (likely an `else`) are not visible -
// confirm against the full file.
91 AlgSarma(
const RCP<const Environment> &env,
const RCP<
const Comm<int> > &comm,
92 const RCP<
const MatrixAdapter <user_t, userCoord_t> > &adapter)
93 : env(env), comm(comm), adapter(adapter),
94 offsize(1 + adapter->getLocalNumRows()), nnz(adapter->getLocalNumEntries()),
95 algs(sarma::get_algorithm_map<Ordinal, Value>()), orders(sarma::get_order_map()) {
97 if (adapter->CRSViewAvailable()) {
// Zero-initialized buffers sized for the row offsets and the column ids.
// NOTE(review): these raw new[] allocations have no matching delete[] in view,
// and they are passed straight to getCRSView(); if getCRSView() rebinds the
// pointers to adapter-owned storage, both buffers leak and are unnecessary -
// confirm the getCRSView contract.
98 offsets =
new offset_t[offsize]();
99 colids =
new gno_t[nnz]();
101 adapter->getCRSView(offsets, colids, vals);
103 throw std::runtime_error(
"NO CRS Available");
// Registers the SARMA options on `pl`: sets the entry "sarma_parameters" to an
// empty ParameterList with the doc string "Sarma options".
// The validator passed along is a default-constructed RCP, i.e. no concrete
// validator object is created here.
106 static void getValidParameters(ParameterList &pl) {
107 RCP<const Teuchos::ParameterEntryValidator> validator;
108 pl.set(
"sarma_parameters", ParameterList(),
"Sarma options", validator);
// Declaration of the partitioning entry point; it receives the
// PartitioningSolution that the implementation fills through setParts().
111 void partition(
const RCP <PartitioningSolution<Adapter>> &solution);
// Handles captured by every constructor overload.
113 const RCP<const Environment> env;
114 const RCP<const Comm<int> > comm;
115 const RCP<const base_adapter_t> adapter;
// NOTE(review): `model` is not assigned or read anywhere in the visible code.
116 RCP<const CoordinateModel <base_adapter_t> > model;
// Default option values. getParams() accesses these as config.<name>, so they
// presumably live inside a nested parameter struct whose opening and closing
// lines (and the `max_load` field read by getParams) are not visible in this
// listing - TODO confirm.
119 std::string alg =
"pal";
120 sarma::Order order_type = sarma::Order::NAT;
121 std::string order_str =
"nat";
122 Ordinal row_parts = 8, col_parts = 0;
// 2147483647 == 2^31 - 1.
124 int seed = 2147483647;
125 double sparsify = 1.0;
126 bool triangular =
false, use_data =
false;
// Map from an algorithm name to a pair of (partitioning function, bool flag).
// partition() looks up config.alg here and forwards .first and .second to
// sarma::Run.
// NOTE(review): the declarator name is on a line missing from this listing;
// the constructors initialize a member called `algs` with
// sarma::get_algorithm_map, so that is presumably this member.
129 std::map <std::string, std::pair<std::function < std::pair < std::vector < Ordinal>, std::vector<Ordinal>>(
130 const sarma::Matrix <Ordinal, Value> &,
const Ordinal,
const Ordinal,
const Value,
const int)>,
bool>>
// Map from an order name (e.g. the "order" option string) to sarma::Order;
// used by getParams() to validate and translate the option.
132 std::map <std::string, sarma::Order> orders;
// Reads the "sarma_parameters" sublist of the environment's parameter list and
// copies the entries "alg", "order", "row_parts", "col_parts", "max_load",
// "sparsify", "triangular", "use_data" and "seed" into `config`.
// Throws std::logic_error("Wrong order type.") when the "order" value is not a
// key of `orders`.
// NOTE(review): the embedded line numbers skip one line between every
// getEntryPtr() and the following getValue() (e.g. 146 -> 148), so any null
// check on `pe` is not visible here. getEntryPtr() yields a null pointer for a
// missing entry, so confirm a guard exists before each dereference.
140 template<
typename Adapter>
141 void AlgSarma<Adapter>::getParams() {
142 const Teuchos::ParameterList &pl = env->getParameters();
143 const Teuchos::ParameterList &sparams = pl.sublist(
"sarma_parameters");
146 const Teuchos::ParameterEntry *pe = sparams.getEntryPtr(
"alg");
// getValue() takes a pointer only to select the value type to extract.
148 config.alg = pe->getValue(&config.alg);
149 pe = sparams.getEntryPtr(
"order");
// NOTE(review): the declaration of `s` is on a line missing from this listing.
152 s = pe->getValue(&s);
// Reject order names that SARMA's order map does not know.
153 if (orders.find(s) == orders.end()) {
154 throw std::logic_error(
"Wrong order type.");
// Keep both the enum value and the original string form of the order.
156 config.order_type = orders.at(s);
157 config.order_str = s;
159 pe = sparams.getEntryPtr(
"row_parts");
161 config.row_parts = pe->getValue(&config.row_parts);
162 pe = sparams.getEntryPtr(
"col_parts");
164 config.col_parts = pe->getValue(&config.col_parts);
165 pe = sparams.getEntryPtr(
"max_load");
167 config.max_load = pe->getValue(&config.max_load);
168 pe = sparams.getEntryPtr(
"sparsify");
170 config.sparsify = pe->getValue(&config.sparsify);
171 pe = sparams.getEntryPtr(
"triangular");
173 config.triangular = pe->getValue(&config.triangular);
174 pe = sparams.getEntryPtr(
"use_data");
176 config.use_data = pe->getValue(&config.use_data);
177 pe = sparams.getEntryPtr(
"seed");
179 config.seed = pe->getValue(&config.seed);
// Out-of-class member definition; it ends by calling solution->setParts(), so it
// is presumably AlgSarma<Adapter>::partition.
// NOTE(review): the signature and the declaration of `null` (embedded line
// numbers 183-185) are missing from this listing - confirm.
// Steps visible here: build a sarma::Matrix from the CRS arrays, run the
// configured SARMA algorithm, then flatten the two returned vectors into one
// part list stored on the solution.
182 template<
typename Adapter>
// Put the stream passed to sarma::Run into a failed state; presumably this is
// meant to discard whatever SARMA writes to it.
186 null.setstate(std::ios_base::badbit);
// Copy offsets/colids/vals into vectors owned by the matrix; the last
// constructor argument is 1 + the largest column id.
// NOTE(review): std::max_element returns the end iterator for an empty range,
// so dereferencing it is undefined when nnz == 0 - confirm nnz > 0 is
// guaranteed. The std::move calls wrap temporaries and have no effect.
187 auto M = std::make_shared<sarma::Matrix<Ordinal, Value> >(std::move(std::vector<Ordinal>(offsets, offsets + offsize)),
188 std::move(std::vector<Ordinal>(colids, colids + nnz)), std::move(std::vector<Value>(vals, vals + nnz)),
189 1 + *std::max_element(colids, colids + nnz));
// algs.at() throws std::out_of_range if config.alg is not a registered name.
190 auto parts = sarma::Run<Ordinal, Value>(algs.at(config.alg).first, null, M, config.order_type, config.row_parts,
191 config.col_parts, config.max_load, config.triangular,
false, config.sparsify,
192 algs.at(config.alg).second, config.use_data, config.seed);
// The result holds parts.first followed immediately by parts.second.
// NOTE(review): result_size is `unsigned` while the sizes being summed are
// size_t - a narrowing conversion on typical 64-bit platforms.
194 unsigned result_size = parts.first.size() + parts.second.size();
// NOTE(review): the trailing `true` is presumably ArrayRCP's ownership flag, so
// the new[] buffer would be released by the ArrayRCP - confirm.
195 auto partl = ArrayRCP<int>(
new int[result_size], 0, result_size,
true);
196 for (
size_t i = 0; i < parts.first.size(); ++i) partl[i] = parts.first[i];
197 for (
size_t i = parts.first.size(); i < result_size; ++i)
198 partl[i] = parts.second[i - parts.first.size()];
200 solution->setParts(partl);
Zoltan2::BaseAdapter< userTypes_t > base_adapter_t
virtual void partition(const RCP< PartitioningSolution< Adapter > > &)
Partitioning method.
typename Adapter::base_adapter_t base_adapter_t
map_t::global_ordinal_type gno_t
Defines the PartitioningSolution class.
AlgSarma(const RCP< const Environment > &, const RCP< const Comm< int > > &, const RCP< const base_adapter_t > &)
SparseMatrixAdapter_t::part_t part_t
Adapter::scalar_t scalar_t
Algorithm defines the base class for all algorithms.
map_t::local_ordinal_type lno_t
Traits class to handle conversions between gno_t/lno_t and TPL data types (e.g., ParMETIS's idx_t...
Defines the CoordinateModel classes.
A gathering of useful namespace methods.
Zoltan2::BasicUserTypes< zscalar_t, zlno_t, zgno_t > user_t