44 #ifndef TPETRA_TRANSFORM_MULTIVECTOR_HPP
45 #define TPETRA_TRANSFORM_MULTIVECTOR_HPP
51 #include "Tpetra_Map.hpp"
52 #include "Teuchos_Comm.hpp"
53 #include "Teuchos_TestForException.hpp"
54 #include "Kokkos_Core.hpp"
/// \brief Kokkos loop body for an entrywise unary transform from one
///   rank-2 View into another rank-2 View.
///
/// For the local row index i it is invoked with, operator() visits every
/// column j in [0, in_.extent(1)) and writes out_(i,j) = f_(in_(i,j)).
///
/// NOTE(review): OutputViewType and LocalIndexType are used below, but
///   their template-parameter lines are not present in this listing.
/// NOTE(review): the column count is taken from the input view only; out_
///   is assumed to have at least as many columns -- nothing here checks it.
69 template<
class InputViewType,
71 class UnaryFunctionType,
73 class MultiVectorUnaryTransformLoopBody {
// Compile-time check: both the input and the output must be rank-2 Views.
75 static_assert (static_cast<int> (InputViewType::Rank) == 2,
76 "InputViewType must be a rank-2 Kokkos::View.");
77 static_assert (static_cast<int> (OutputViewType::Rank) == 2,
78 "OutputViewType must be a rank-2 Kokkos::View.");
// Stores the input view, output view, and function object in members.
81 MultiVectorUnaryTransformLoopBody (
const InputViewType& in,
82 const OutputViewType& out,
83 UnaryFunctionType f) :
84 in_ (in), out_ (out), f_ (f)
// Work for one local row i: apply f_ to every entry of that row.
87 KOKKOS_INLINE_FUNCTION
void
88 operator () (
const LocalIndexType i)
const {
89 using LO = LocalIndexType;
90 const LO numCols =
static_cast<LO
> (in_.extent (1));
91 for (LO j = 0; j < numCols; ++j) {
92 out_(i,j) = f_ (in_(i,j));
/// \brief Kokkos loop body for an entrywise binary transform of two
///   rank-2 input Views into a rank-2 output View.
///
/// For the local row index i it is invoked with, operator() visits every
/// column j in [0, in1_.extent(1)) and writes
/// out_(i,j) = f_(in1_(i,j), in2_(i,j)).
///
/// NOTE(review): the column count is taken from in1_ only; in2_ and out_
///   are assumed to have at least as many columns -- not checked here.
106 template<
class InputViewType1,
107 class InputViewType2,
108 class OutputViewType,
109 class BinaryFunctionType,
110 class LocalIndexType>
111 class MultiVectorBinaryTransformLoopBody {
// Compile-time check: both inputs and the output must be rank-2 Views.
113 static_assert (static_cast<int> (InputViewType1::Rank) == 2,
114 "InputViewType1 must be a rank-2 Kokkos::View.");
115 static_assert (static_cast<int> (InputViewType2::Rank) == 2,
116 "InputViewType2 must be a rank-2 Kokkos::View.");
117 static_assert (static_cast<int> (OutputViewType::Rank) == 2,
118 "OutputViewType must be a rank-2 Kokkos::View.");
// Stores both input views, the output view, and the function object.
121 MultiVectorBinaryTransformLoopBody (
const InputViewType1& in1,
122 const InputViewType2& in2,
123 const OutputViewType& out,
124 BinaryFunctionType f) :
125 in1_ (in1), in2_ (in2), out_ (out), f_ (f)
// Work for one local row i: combine the two inputs entrywise across that row.
128 KOKKOS_INLINE_FUNCTION
void
129 operator () (
const LocalIndexType i)
const {
130 using LO = LocalIndexType;
131 const LO numCols =
static_cast<LO
> (in1_.extent (1));
132 for (LO j = 0; j < numCols; ++j) {
133 out_(i,j) = f_ (in1_(i,j), in2_(i,j));
141 BinaryFunctionType f_;
/// \brief Kokkos loop body for an entrywise unary transform from one
///   rank-1 View into another rank-1 View.
///
/// operator() handles a single index i: out_(i) = f_(in_(i)).  In this
/// file it is launched by transform_vec_notSameObject with a
/// Kokkos::RangePolicy over [0, input_lcl.extent(0)).
148 template<
class InputViewType,
149 class OutputViewType,
150 class UnaryFunctionType,
151 class LocalIndexType>
152 class VectorUnaryTransformLoopBody {
// Compile-time check: both the input and the output must be rank-1 Views.
154 static_assert (static_cast<int> (InputViewType::Rank) == 1,
155 "InputViewType must be a rank-1 Kokkos::View.");
156 static_assert (static_cast<int> (OutputViewType::Rank) == 1,
157 "OutputViewType must be a rank-1 Kokkos::View.");
// Stores the input view, output view, and function object in members.
160 VectorUnaryTransformLoopBody (
const InputViewType& in,
161 const OutputViewType& out,
162 UnaryFunctionType f) :
163 in_ (in), out_ (out), f_ (f)
// Work for one index i.
166 KOKKOS_INLINE_FUNCTION
void
167 operator () (
const LocalIndexType i)
const {
168 out_(i) = f_ (in_(i));
174 UnaryFunctionType f_;
/// \brief Kokkos loop body for an entrywise binary transform of two
///   rank-1 input Views into a rank-1 output View.
///
/// operator() handles a single index i: out_(i) = f_(in1_(i), in2_(i)).
/// The binary transform in this file launches it over a
/// Kokkos::RangePolicy, including the aliased cases where one or both
/// inputs are the output view itself.
181 template<
class InputViewType1,
182 class InputViewType2,
183 class OutputViewType,
184 class BinaryFunctionType,
185 class LocalIndexType>
186 class VectorBinaryTransformLoopBody {
// Compile-time check: each input view type and the output view type must
// be rank 1.  InputViewType2 is checked on its own, because it may be a
// different type from InputViewType1.
188 static_assert (static_cast<int> (InputViewType1::Rank) == 1,
189 "InputViewType1 must be a rank-1 Kokkos::View.");
190 static_assert (static_cast<int> (InputViewType2::Rank) == 1,
191 "InputViewType2 must be a rank-1 Kokkos::View.");
192 static_assert (static_cast<int> (OutputViewType::Rank) == 1,
193 "OutputViewType must be a rank-1 Kokkos::View.");
// Stores both input views, the output view, and the function object.
196 VectorBinaryTransformLoopBody (
const InputViewType1& in1,
197 const InputViewType2& in2,
198 const OutputViewType& out,
199 BinaryFunctionType f) :
200 in1_ (in1), in2_ (in2), out_ (out), f_ (f)
// Work for one index i.
203 KOKKOS_INLINE_FUNCTION
void
204 operator () (
const LocalIndexType i)
const {
205 out_(i) = f_ (in1_(i), in2_(i));
212 BinaryFunctionType f_;
/// \brief Functor that applies a unary function to one entry of a
///   MultiVector in place.
///
/// It is used (via unaryTransformSameMultiVector) when the input and the
/// output of a unary transform are the same object.  operator() receives
/// the entry X_ij by non-const reference.
///
/// NOTE(review): the body of operator() and the definition of IST are not
///   visible in this listing.  IST is presumably the MultiVector's
///   impl_scalar_type -- confirm against the full header.
219 template<
class ExecutionSpace,
220 class SC,
class LO,
class GO,
class NT,
221 class UnaryFunctionType>
222 class UnaryTransformSameMultiVector {
228 UnaryTransformSameMultiVector (UnaryFunctionType f) : f_ (f) {}
230 KOKKOS_INLINE_FUNCTION
void operator() (IST& X_ij)
const {
238 UnaryFunctionType f_;
/// \brief In-place unary transform helper for the case where the input and
///   the output are the same MultiVector.
///
/// Wraps the user's function in UnaryTransformSameMultiVector
/// (functor_type below).  It is called from the unary transform when
/// sameObject(output, input) is true.
///
/// NOTE(review): the return type, the remaining parameters, and the launch
///   itself are not visible in this listing.  Presumably it dispatches the
///   functor through Tpetra::for_each -- verify against the full header.
241 template<
class ExecutionSpace,
242 class SC,
class LO,
class GO,
class NT,
243 class UnaryFunctionType>
245 unaryTransformSameMultiVector (
const char kernelLabel[],
246 ExecutionSpace execSpace,
250 using functor_type = UnaryTransformSameMultiVector<ExecutionSpace,
251 SC, LO, GO, NT, UnaryFunctionType>;
257 template<
class ExecutionSpace,
258 class SC,
class LO,
class GO,
class NT>
267 template<
class MemorySpace>
268 using transform_memory_space =
269 typename std::conditional<
270 Kokkos::SpaceAccessibility<
272 typename MemorySpace::memory_space>::accessible,
273 typename MemorySpace::memory_space,
274 typename ExecutionSpace::memory_space>::type;
277 using preferred_memory_space =
278 typename MV::device_type::memory_space;
280 transform_memory_space<preferred_memory_space>;
/// \brief Whether input and output should be treated as the same object
///   (aliased), so the transform must run in place.
///
/// Returns true if the two MultiVector objects have the same address, or if
/// their local data pointers coincide.  The transforms use this result to
/// choose the read-write (in-place) path over the read-only / write-only
/// path.
///
/// NOTE(review): as listed, the operand before
///   "output.getLocalViewHost ().data () ||" is missing (source line 292
///   is absent).  Presumably it is "input.getLocalViewHost ().data () ==".
///   Read literally, the expression would be true whenever output's host
///   pointer is non-null.  Verify against the real header.
288 sameObject (const ::Tpetra::MultiVector<SC, LO, GO, NT>& input,
289 const ::Tpetra::MultiVector<SC, LO, GO, NT>& output)
291 return &input == &output ||
293 output.getLocalViewHost ().data () ||
294 input.getLocalViewDevice ().data () ==
295 output.getLocalViewDevice ().data ();
298 template<
class UnaryFunctionType>
300 transform_vec_notSameObject
301 (
const char kernelLabel[],
302 ExecutionSpace execSpace,
307 memory_space memSpace;
309 using input_view_type =
311 decltype (
readOnly (input).on (memSpace))>;
312 using output_view_type =
314 decltype (
writeOnly (output).on (memSpace))>;
317 ([=] (
const input_view_type& input_lcl,
318 const output_view_type& output_lcl) {
319 using functor_type = VectorUnaryTransformLoopBody<
320 input_view_type, output_view_type, UnaryFunctionType, LO>;
321 functor_type g (input_lcl, output_lcl, f);
323 const LO lclNumRows =
static_cast<LO
> (input_lcl.extent (0));
324 using range_type = Kokkos::RangePolicy<ExecutionSpace, LO>;
325 range_type range (execSpace, 0, lclNumRows);
326 Kokkos::parallel_for (kernelLabel, range, g);
332 template<
class UnaryFunctionType>
334 transform_mv_notSameObject
335 (
const char kernelLabel[],
336 ExecutionSpace execSpace,
341 memory_space memSpace;
343 using input_view_type =
345 decltype (
readOnly (input).on (memSpace))>;
346 using output_view_type =
348 decltype (
writeOnly (output).on (memSpace))>;
351 ([=] (
const input_view_type& input_lcl,
352 const output_view_type& output_lcl) {
353 using functor_type = MultiVectorUnaryTransformLoopBody<
354 input_view_type, output_view_type, UnaryFunctionType, LO>;
355 functor_type g (input_lcl, output_lcl, f);
357 const LO lclNumRows =
static_cast<LO
> (input_lcl.extent (0));
358 using range_type = Kokkos::RangePolicy<ExecutionSpace, LO>;
359 range_type range (execSpace, 0, lclNumRows);
360 Kokkos::parallel_for (kernelLabel, range, g);
367 template<
class UnaryFunctionType>
370 ExecutionSpace execSpace,
375 using Teuchos::TypeNameTraits;
378 const int myRank = output.
getMap ()->getComm ()->getRank ();
381 std::ostringstream os;
382 os <<
"Proc " << myRank <<
": Tpetra::transform:" << endl
383 <<
" kernelLabel: " << kernelLabel << endl
384 <<
" ExecutionSpace: "
385 << TypeNameTraits<ExecutionSpace>::name () << endl;
386 std::cerr << os.str ();
390 TEUCHOS_TEST_FOR_EXCEPTION
393 " != output.getNumVectors() = " << numVecs <<
".");
398 memory_space memSpace;
399 if (numVecs ==
size_t (1) || ! constStride) {
400 for (
size_t j = 0; j < numVecs; ++j) {
407 if (sameObject (*output_j, *input_j)) {
408 unaryTransformSameMultiVector (kernelLabel, execSpace,
412 transform_vec_notSameObject (kernelLabel, execSpace,
413 *input_j, *output_j, f);
418 if (sameObject (output, input)) {
419 unaryTransformSameMultiVector (kernelLabel, execSpace,
423 transform_mv_notSameObject (kernelLabel, execSpace,
430 template<
class BinaryFunctionType>
433 ExecutionSpace execSpace,
437 BinaryFunctionType f)
439 using Teuchos::TypeNameTraits;
441 const char prefix[] =
"Tpetra::transform (binary): ";
443 const int myRank = output.
getMap ()->getComm ()->getRank ();
446 std::ostringstream os;
447 os <<
"Proc " << myRank <<
": " << prefix << endl
448 <<
" Tpetra::MultiVector<" << TypeNameTraits<SC>::name ()
449 <<
", " << TypeNameTraits<LO>::name () <<
", "
450 << TypeNameTraits<GO>::name () <<
", "
451 << TypeNameTraits<NT>::name () <<
">" << endl
452 <<
" kernelLabel: " << kernelLabel << endl
453 <<
" ExecutionSpace: "
454 << TypeNameTraits<ExecutionSpace>::name () << endl;
455 std::cerr << os.str ();
459 TEUCHOS_TEST_FOR_EXCEPTION
461 prefix <<
"input1.getNumVectors() = " << input1.
getNumVectors ()
462 <<
" != output.getNumVectors() = " << numVecs <<
".");
463 TEUCHOS_TEST_FOR_EXCEPTION
465 prefix <<
"input2.getNumVectors() = " << input2.
getNumVectors ()
466 <<
" != output.getNumVectors() = " << numVecs <<
".");
470 memory_space memSpace;
473 using range_type = Kokkos::RangePolicy<ExecutionSpace, LO>;
474 range_type range (execSpace, 0, lclNumRows);
476 if (numVecs ==
size_t (1) || ! constStride) {
477 for (
size_t j = 0; j < numVecs; ++j) {
485 const bool outin1same = sameObject (*output_j, *input1_j);
486 const bool outin2same = sameObject (*output_j, *input2_j);
488 const bool in1in2same = sameObject (*input1_j, *input2_j);
489 const bool allsame = outin1same && outin2same;
495 vec_type& input1_j_ref = *input1_j;
496 vec_type& input2_j_ref = *input2_j;
497 vec_type& output_j_ref = *output_j;
502 using input1_view_type =
504 decltype (
readOnly (input1_j_ref).on (memSpace))>;
505 using input2_view_type =
507 decltype (
readOnly (input2_j_ref).on (memSpace))>;
508 using rw_output_view_type =
510 decltype (
readWrite (output_j_ref).on (memSpace))>;
511 using wo_output_view_type =
513 decltype (
writeOnly (output_j_ref).on (memSpace))>;
517 ([=] (
const rw_output_view_type& output_lcl) {
518 using functor_type = VectorBinaryTransformLoopBody<
519 typename rw_output_view_type::const_type,
520 typename rw_output_view_type::const_type,
522 BinaryFunctionType, LO>;
523 functor_type functor (output_lcl, output_lcl, output_lcl, f);
524 Kokkos::parallel_for (kernelLabel, range, functor);
528 else if (in1in2same) {
530 ([=] (
const input1_view_type& input1_lcl,
531 const wo_output_view_type& output_lcl) {
532 using functor_type = VectorBinaryTransformLoopBody<
536 BinaryFunctionType, LO>;
537 functor_type functor (input1_lcl, input1_lcl, output_lcl, f);
538 Kokkos::parallel_for (kernelLabel, range, functor);
540 readOnly (input1_j_ref).on (memSpace),
543 else if (outin1same) {
545 ([=] (
const input2_view_type& input2_lcl,
546 const rw_output_view_type& output_lcl) {
547 using functor_type = VectorBinaryTransformLoopBody<
548 typename rw_output_view_type::const_type,
551 BinaryFunctionType, LO>;
552 functor_type functor (output_lcl, input2_lcl, output_lcl, f);
553 Kokkos::parallel_for (kernelLabel, range, functor);
555 readOnly (input2_j_ref).on (memSpace),
558 else if (outin2same) {
560 ([=] (
const input1_view_type& input1_lcl,
561 const rw_output_view_type& output_lcl) {
562 using functor_type = VectorBinaryTransformLoopBody<
564 typename rw_output_view_type::const_type,
566 BinaryFunctionType, LO>;
567 functor_type functor (input1_lcl, output_lcl, output_lcl, f);
568 Kokkos::parallel_for (kernelLabel, range, functor);
570 readOnly (input1_j_ref).on (memSpace),
575 ([=] (
const input1_view_type& input1_lcl,
576 const input2_view_type& input2_lcl,
577 const wo_output_view_type& output_lcl) {
578 using functor_type = VectorBinaryTransformLoopBody<
582 BinaryFunctionType, LO>;
583 functor_type functor (input1_lcl, input2_lcl, output_lcl, f);
584 Kokkos::parallel_for (kernelLabel, range, functor);
586 readOnly (input1_j_ref).on (memSpace),
587 readOnly (input2_j_ref).on (memSpace),
596 const bool outin1same = sameObject (output, input1);
597 const bool outin2same = sameObject (output, input2);
599 const bool in1in2same = sameObject (input1, input2);
600 const bool allsame = outin1same && outin2same;
605 using input1_view_type =
607 decltype (
readOnly (input1).on (memSpace))>;
608 using input2_view_type =
610 decltype (
readOnly (input2).on (memSpace))>;
611 using rw_output_view_type =
613 decltype (
readWrite (output).on (memSpace))>;
614 using wo_output_view_type =
616 decltype (
writeOnly (output).on (memSpace))>;
620 ([=] (
const rw_output_view_type& output_lcl) {
621 using functor_type = MultiVectorBinaryTransformLoopBody<
622 typename rw_output_view_type::const_type,
623 typename rw_output_view_type::const_type,
625 BinaryFunctionType, LO>;
626 functor_type functor (output_lcl, output_lcl, output_lcl, f);
627 Kokkos::parallel_for (kernelLabel, range, functor);
631 else if (in1in2same) {
633 ([=] (
const input1_view_type& input1_lcl,
634 const wo_output_view_type& output_lcl) {
635 using functor_type = MultiVectorBinaryTransformLoopBody<
639 BinaryFunctionType, LO>;
640 functor_type functor (input1_lcl, input1_lcl, output_lcl, f);
641 Kokkos::parallel_for (kernelLabel, range, functor);
646 else if (outin1same) {
648 ([=] (
const input2_view_type& input2_lcl,
649 const rw_output_view_type& output_lcl) {
650 using functor_type = MultiVectorBinaryTransformLoopBody<
651 typename rw_output_view_type::const_type,
654 BinaryFunctionType, LO>;
655 functor_type functor (output_lcl, input2_lcl, output_lcl, f);
656 Kokkos::parallel_for (kernelLabel, range, functor);
661 else if (outin2same) {
663 ([=] (
const input1_view_type& input1_lcl,
664 const rw_output_view_type& output_lcl) {
665 using functor_type = MultiVectorBinaryTransformLoopBody<
667 typename rw_output_view_type::const_type,
669 BinaryFunctionType, LO>;
670 functor_type functor (input1_lcl, output_lcl, output_lcl, f);
671 Kokkos::parallel_for (kernelLabel, range, functor);
678 ([=] (
const input1_view_type& input1_lcl,
679 const input2_view_type& input2_lcl,
680 const wo_output_view_type& output_lcl) {
681 using functor_type = MultiVectorBinaryTransformLoopBody<
685 BinaryFunctionType, LO>;
686 functor_type functor (input1_lcl, input2_lcl, output_lcl, f);
687 Kokkos::parallel_for (kernelLabel, range, functor);
703 template<
class ExecutionSpace,
704 class SC,
class LO,
class GO,
class NT>
709 template<
class UnaryFunctionType>
712 ExecutionSpace execSpace,
719 using UFT = UnaryFunctionType;
721 impl_type::template transform<UFT> (kernelLabel, execSpace,
726 template<
class BinaryFunctionType>
729 ExecutionSpace execSpace,
733 BinaryFunctionType f)
737 using BFT = BinaryFunctionType;
739 impl_type::template transform<BFT> (kernelLabel, execSpace,
740 input1, input2, output, f);
748 #endif // TPETRA_TRANSFORM_MULTIVECTOR_HPP
Include this file to make Tpetra::MultiVector and Tpetra::Vector work with Tpetra::withLocalAccess.
dual_view_type::t_host getLocalViewHost() const
A local Kokkos::View of host memory.
Details::LocalAccess< GlobalObjectType, typename Details::DefaultMemorySpace< GlobalObjectType >::type, Details::AccessMode::WriteOnly > writeOnly(GlobalObjectType &)
Declare that you want to access the given global object's local data in write-only mode...
size_t getNumVectors() const
Number of columns in the multivector.
size_t getLocalLength() const
Local number of rows on the calling process.
bool isConstantStride() const
Whether this multivector has constant stride between columns.
One or more distributed dense vectors.
Details::LocalAccess< GlobalObjectType, typename Details::DefaultMemorySpace< GlobalObjectType >::type, Details::AccessMode::ReadWrite > readWrite(GlobalObjectType &)
Declare that you want to access the given global object's local data in read-and-write mode...
Teuchos::RCP< Vector< Scalar, LocalOrdinal, GlobalOrdinal, Node > > getVectorNonConst(const size_t j)
Return a Vector which is a nonconst view of column j.
static bool verbose()
Whether Tpetra is in verbose mode.
void withLocalAccess(typename Details::ArgsToFunction< LocalAccessTypes...>::type userFunction, LocalAccessTypes...localAccesses)
Get access to a Tpetra global object's local data.
Include this file to make Tpetra::for_each work with Tpetra::MultiVector and Tpetra::Vector.
typename Kokkos::Details::ArithTraits< Scalar >::val_type impl_scalar_type
The type used internally in place of Scalar.
A distributed dense vector.
virtual Teuchos::RCP< const map_type > getMap() const
The Map describing the parallel distribution of this object.
typename Details::GetNonowningLocalObject< LocalAccessType >::nonowning_local_object_type with_local_access_function_argument_type
Type of the local object, that is an argument to the function the user gives to withLocalAccess.
void for_each(const char kernelLabel[], ExecutionSpace execSpace, GlobalDataStructure &X, UserFunctionType f)
Apply a function entrywise to each local entry of a Tpetra global data structure, analogously to std::for_each.
Details::LocalAccess< GlobalObjectType, typename Details::DefaultMemorySpace< GlobalObjectType >::type, Details::AccessMode::ReadOnly > readOnly(GlobalObjectType &)
Declare that you want to access the given global object's local data in read-only mode...
Declaration of Tpetra::Details::Behavior, a class that describes Tpetra's behavior.