10 #ifndef TPETRA_DETAILS_NORMIMPL_HPP
11 #define TPETRA_DETAILS_NORMIMPL_HPP
22 #include "TpetraCore_config.h"
23 #include "Kokkos_Core.hpp"
24 #include "Teuchos_ArrayView.hpp"
25 #include "Teuchos_CommHelpers.hpp"
26 #include "KokkosBlas.hpp"
27 #if KOKKOS_VERSION >= 40799
28 #include "KokkosKernels_ArithTraits.hpp"
30 #include "Kokkos_ArithTraits.hpp"
// NOTE(review): this file looks like a text extract of a Doxygen source
// listing. The original line numbers are fused onto the code and several
// source lines are missing, so it does not compile as-is.
// The guarded span below is a template declaration hidden from Doxygen.
// Its declarator (original lines 38-39) is absent from this extract.
// It is presumably the forward declaration of isInterComm, which
// gblNormImpl calls as ::Tpetra::Details::isInterComm(*comm) -- confirm.
33 #ifndef DOXYGEN_SHOULD_SKIP_THIS
37 template <
class OrdinalType>
40 #endif // DOXYGEN_SHOULD_SKIP_THIS
// Declaration of normImpl, the implementation of MultiVector norms.
// The full signature (per the Doxygen summary at the end of this extract) is:
//   normImpl(MagnitudeType norms[], X, const EWhichNorm whichNorm,
//            whichVecs, isConstantStride, isDistributed, comm)
// The remaining template parameters and the norms/whichNorm parameter
// lines are missing from this extract. The definition appears further
// down in this header.
//   norms            - output array, one magnitude per computed norm.
//   X                - local 2-D data whose columns are normed.
//   whichNorm        - which norm to compute (NORM_ONE, NORM_TWO, NORM_INF).
//   whichVecs        - column indices into X, used when !isConstantStride.
//   isConstantStride - if true, every column of X is normed.
//   isDistributed    - whether local results are combined over comm.
//   comm             - communicator; a null pointer disables the reduction.
57 template <
class ValueType,
62 const Kokkos::View<const ValueType**, ArrayLayout, DeviceType>& X,
64 const Teuchos::ArrayView<const size_t>& whichVecs,
65 const bool isConstantStride,
66 const bool isDistributed,
67 const Teuchos::Comm<int>* comm);
// lclNormImpl: compute the process-local norm of each selected column of X
// and store one result per column in normsOut (entry k for result k).
//   normsOut       - rank-1 Kokkos::View of magnitudes (static_assert below).
//   whichVecs      - column indices into X, used when !constantStride.
//   constantStride - if true, result k is the norm of column k of X.
// The declarations of X, numVecs and whichNorm are on lines missing from
// this extract:
//   X must be a rank-2 Kokkos::View (static_assert below).
//   numVecs is the number of results written.
//   whichNorm is compared against NORM_INF, NORM_ONE and NORM_TWO.
// For NORM_TWO this computes the *squared* 2-norm (nrm2_squared). The
// square root is applied later, after the global reduction, in gblNormImpl.
// Throws std::logic_error if X's column count is inconsistent with
// numVecs, or if whichNorm is none of the three handled values.
80 template <
class RV,
class XMV>
81 void lclNormImpl(
const RV& normsOut,
84 const Teuchos::ArrayView<const size_t>& whichVecs,
85 const bool constantStride,
88 using Kokkos::subview;
89 using mag_type =
typename RV::non_const_value_type;
91 static_assert(static_cast<int>(RV::rank) == 1,
92 "Tpetra::MultiVector::lclNormImpl: "
93 "The first argument normsOut must have rank 1.");
94 static_assert(Kokkos::is_view<XMV>::value,
95 "Tpetra::MultiVector::lclNormImpl: "
96 "The second argument X is not a Kokkos::View.");
97 static_assert(static_cast<int>(XMV::rank) == 2,
98 "Tpetra::MultiVector::lclNormImpl: "
99 "The second argument X must have rank 2.");
// Shape consistency checks. They are skipped when there are no local rows.
// With constant stride, X must have exactly numVecs columns. Otherwise
// X must have at least numVecs columns.
101 const size_t lclNumRows =
static_cast<size_t>(X.extent(0));
102 TEUCHOS_TEST_FOR_EXCEPTION(lclNumRows != 0 && constantStride &&
103 static_cast<size_t>(X.extent(1)) != numVecs,
104 std::logic_error,
"Constant Stride X's dimensions are " << X.extent(0) <<
" x " << X.extent(1) <<
", which differ from the local dimensions " << lclNumRows <<
" x " << numVecs <<
". Please report this bug to "
105 "the Tpetra developers.");
106 TEUCHOS_TEST_FOR_EXCEPTION(lclNumRows != 0 && !constantStride &&
107 static_cast<size_t>(X.extent(1)) < numVecs,
108 std::logic_error,
"Strided X's dimensions are " << X.extent(0) <<
" x " << X.extent(1) <<
", which are incompatible with the local dimensions " << lclNumRows <<
" x " << numVecs <<
". Please report this bug to "
109 "the Tpetra developers.");
// No local rows: the local contribution to every norm is zero.
111 if (lclNumRows == 0) {
// The ArithTraits implementation is chosen by KOKKOS_VERSION. The
// matching #else/#endif lines are not present in this extract.
112 #if KOKKOS_VERSION >= 40799
113 const mag_type zeroMag = KokkosKernels::ArithTraits<mag_type>::zero();
115 const mag_type zeroMag = Kokkos::ArithTraits<mag_type>::zero();
// NOTE(review): the statement that writes zeroMag into normsOut is on lines
// missing from this extract. It is presumably a Kokkos::deep_copy using
// this execution_space -- confirm.
118 using execution_space =
typename RV::execution_space;
// Constant stride: a single KokkosBlas call computes the norms of all
// columns of X at once.
121 if (constantStride) {
122 if (whichNorm == NORM_INF) {
123 KokkosBlas::nrminf(normsOut, X);
124 }
else if (whichNorm == NORM_ONE) {
125 KokkosBlas::nrm1(normsOut, X);
126 }
else if (whichNorm == NORM_TWO) {
// Squared 2-norm. The square root is taken after the global sum.
127 KokkosBlas::nrm2_squared(normsOut, X);
129 TEUCHOS_TEST_FOR_EXCEPTION(
true, std::logic_error,
"Should never get here!");
// Column-by-column path: each result is computed from a rank-1 column
// subview of X and written to the rank-0 subview normsOut(k).
137 for (
size_t k = 0; k < numVecs; ++k) {
// whichVecs[k] selects the column of X for result k.
// NOTE(review): if this loop lies in the else branch of
// `if (constantStride)` (the lines in between are missing), then
// `constantStride` is always false here and the ternary could be reduced
// to whichVecs[k] -- confirm against the full file.
138 const size_t X_col = constantStride ? k : whichVecs[k];
139 if (whichNorm == NORM_INF) {
140 KokkosBlas::nrminf(subview(normsOut, k),
141 subview(X, ALL(), X_col));
142 }
else if (whichNorm == NORM_ONE) {
143 KokkosBlas::nrm1(subview(normsOut, k),
144 subview(X, ALL(), X_col));
145 }
else if (whichNorm == NORM_TWO) {
146 KokkosBlas::nrm2_squared(subview(normsOut, k),
147 subview(X, ALL(), X_col));
149 TEUCHOS_TEST_FOR_EXCEPTION(
true, std::logic_error,
"Should never get here!");
// SquareRootFunctor: Kokkos functor that overwrites entry i of a rank-1
// View with its square root, theView_(i) = sqrt(theView_(i)). The square
// root comes from ArithTraits of the View's non-const value type.
// gblNormImpl launches it with Kokkos::parallel_for over [0, numVecs) to
// turn squared 2-norms into 2-norms.
// NOTE(review): the theView_ member declaration and the closing lines of
// the class are missing from this extract.
158 template <
class ViewType>
159 class SquareRootFunctor {
161 typedef typename ViewType::execution_space execution_space;
162 typedef typename ViewType::size_type size_type;
// Keeps a copy of the View handle. Kokkos View copies are shallow, so
// writes through theView_ modify the caller's data.
164 SquareRootFunctor(
const ViewType& theView)
165 : theView_(theView) {}
// Replace entry i in place with its square root.
167 KOKKOS_INLINE_FUNCTION
void
168 operator()(
const size_type& i)
const {
169 typedef typename ViewType::non_const_value_type value_type;
// ArithTraits selected by KOKKOS_VERSION. The #else/#endif lines are
// not present in this extract.
170 #if KOKKOS_VERSION >= 40799
171 typedef KokkosKernels::ArithTraits<value_type> KAT;
173 typedef Kokkos::ArithTraits<value_type> KAT;
175 theView_(i) = KAT::sqrt(theView_(i));
// gblNormImpl: finish the norms that lclNormImpl computed into normsOut.
//  1. If distributed is true and comm is non-null, the per-process values
//     are combined across comm and the results land in normsOut. NORM_INF
//     uses a max reduction; the other norms use a sum reduction.
//  2. For NORM_TWO, each entry of normsOut is replaced by its square root.
//     The local pass produced squared 2-norms, and those add across
//     processes.
// The template-parameter line and the whichNorm parameter line are
// missing from this extract.
183 void gblNormImpl(
const RV& normsOut,
184 const Teuchos::Comm<int>*
const comm,
185 const bool distributed,
187 using Teuchos::REDUCE_MAX;
188 using Teuchos::REDUCE_SUM;
189 using Teuchos::reduceAll;
190 typedef typename RV::non_const_value_type mag_type;
// One result per entry of normsOut.
192 const size_t numVecs = normsOut.extent(0);
// Reduce only when the data is distributed and a communicator was given.
207 if (distributed && comm !=
nullptr) {
// reduceAll takes an int element count.
// NOTE(review): numVecs > INT_MAX would not fit in this cast.
211 const int nv =
static_cast<int>(numVecs);
212 const bool commIsInterComm = ::Tpetra::Details::isInterComm(*comm);
// Intercommunicator: the send and receive buffers are kept distinct here.
// Presumably this is because MPI does not allow in-place (aliased)
// all-reductions on intercommunicators -- confirm.
214 if (commIsInterComm) {
// Separate, uninitialized buffer for the local contributions.
// NOTE(review): the copy of normsOut into lclNorms is on lines missing
// from this extract.
215 RV lclNorms(Kokkos::ViewAllocateWithoutInitializing(
"MV::normImpl lcl"), numVecs);
217 using execution_space =
typename RV::execution_space;
219 const mag_type*
const lclSum = lclNorms.data();
220 mag_type*
const gblSum = normsOut.data();
// NORM_INF: max across processes. Otherwise (1-norms and squared
// 2-norms): sum across processes.
222 if (whichNorm == NORM_INF) {
223 reduceAll<int, mag_type>(*comm, REDUCE_MAX, nv, lclSum, gblSum);
225 reduceAll<int, mag_type>(*comm, REDUCE_SUM, nv, lclSum, gblSum);
// Non-intercommunicator path: reduce in place in normsOut (the send and
// receive pointers are the same).
228 mag_type*
const gblSum = normsOut.data();
229 if (whichNorm == NORM_INF) {
230 reduceAll<int, mag_type>(*comm, REDUCE_MAX, nv, gblSum, gblSum);
232 reduceAll<int, mag_type>(*comm, REDUCE_SUM, nv, gblSum, gblSum);
// The 2-norm is the square root of the (globally summed) squared 2-norm.
237 if (whichNorm == NORM_TWO) {
// True when normsOut's memory space is its host mirror's memory space.
// The branch that uses this flag is on missing lines. Presumably:
//   - host data takes the serial loop below;
//   - other data takes the parallel_for path.
// Confirm against the full file.
243 const bool inHostMemory =
244 std::is_same<
typename RV::memory_space,
245 typename RV::host_mirror_space::memory_space>::value;
// Host path: take the square root of each entry serially.
247 for (
size_t j = 0; j < numVecs; ++j) {
248 #if KOKKOS_VERSION >= 40799
249 normsOut(j) = KokkosKernels::ArithTraits<mag_type>::sqrt(normsOut(j));
251 normsOut(j) = Kokkos::ArithTraits<mag_type>::sqrt(normsOut(j));
// Parallel path: apply the square root on RV's execution space over
// [0, numVecs).
259 SquareRootFunctor<RV> f(normsOut);
260 typedef typename RV::execution_space execution_space;
261 typedef Kokkos::RangePolicy<execution_space, size_t> range_type;
262 Kokkos::parallel_for(range_type(0, numVecs), f);
// normImpl (definition): compute the requested norm of each selected
// column of the local data X and write the results to the caller's array
// norms. When isDistributed is true and comm is non-null, the results are
// then combined across comm.
//   norms            - output array; numVecs entries are written.
//   X                - local 2-D data.
//   whichNorm        - NORM_ONE, NORM_TWO or NORM_INF.
//   whichVecs        - column indices into X, used when !isConstantStride.
//   isConstantStride - if true, all X.extent(1) columns are normed.
//   isDistributed    - whether to reduce the results over comm.
//   comm             - communicator; a null pointer skips the reduction.
// The remaining template-parameter lines and the norms/whichNorm parameter
// lines are missing from this extract. So are body lines 282-284, 286-288
// and 293-298; they may include handling for numVecs == 0 -- confirm.
269 template <
class ValueType,
274 const Kokkos::View<const ValueType**, ArrayLayout, DeviceType>& X,
276 const Teuchos::ArrayView<const size_t>& whichVecs,
277 const bool isConstantStride,
278 const bool isDistributed,
279 const Teuchos::Comm<int>* comm) {
280 using execution_space =
typename DeviceType::execution_space;
// Host-space View type used to wrap the caller's norms array.
// NOTE(review): because RV is in Kokkos::HostSpace, the inHostMemory test
// in gblNormImpl is always true on this call path.
281 using RV = Kokkos::View<MagnitudeType*, Kokkos::HostSpace>;
// Number of norms: every column of X with constant stride, otherwise one
// per entry of whichVecs.
285 const size_t numVecs = isConstantStride ?
static_cast<size_t>(X.extent(1)) : static_cast<size_t>(whichVecs.size());
// View over the caller-provided array. A View constructed from a raw
// pointer does not allocate, own, or free that memory.
289 RV normsOut(norms, numVecs);
// Local pass. For NORM_TWO this leaves squared 2-norms in normsOut.
291 Impl::lclNormImpl(normsOut, X, numVecs, whichVecs,
292 isConstantStride, whichNorm);
// Kokkos kernels may run asynchronously. Fence DeviceType's execution
// space so the local results are complete before the host-side reduction
// reads normsOut.
299 execution_space exec_space_instance = execution_space();
300 exec_space_instance.fence();
// Global pass: reduce over comm if requested; take square roots for
// NORM_TWO.
302 Impl::gblNormImpl(normsOut, comm, isDistributed, whichNorm);
308 #endif // TPETRA_DETAILS_NORMIMPL_HPP
void deep_copy(MultiVector< DS, DL, DG, DN > &dst, const MultiVector< SS, SL, SG, SN > &src)
Copy the contents of the MultiVector src into dst.
void normImpl(MagnitudeType norms[], const Kokkos::View< const ValueType **, ArrayLayout, DeviceType > &X, const EWhichNorm whichNorm, const Teuchos::ArrayView< const size_t > &whichVecs, const bool isConstantStride, const bool isDistributed, const Teuchos::Comm< int > *comm)
Implementation of MultiVector norms.
EWhichNorm
Input argument for normImpl(); see that function for details.