42 #ifndef TPETRA_TRANSFORM_MULTIVECTOR_HPP
43 #define TPETRA_TRANSFORM_MULTIVECTOR_HPP
49 #include "Tpetra_Map.hpp"
50 #include "Teuchos_Comm.hpp"
51 #include "Teuchos_TestForException.hpp"
52 #include "Kokkos_Core.hpp"
/// \brief Kokkos::parallel_for functor that applies a unary function
///   entrywise to a rank-2 view, processing one row per invocation.
///
/// \tparam InputViewType Rank-2 view type that is read.
/// \tparam OutputViewType Rank-2 view type that is written.
/// \tparam UnaryFunctionType Callable taking one input entry and
///   returning the value to store in the output.
/// \tparam LocalIndexType Integer type used for row and column indices.
template<class InputViewType,
         class OutputViewType,
         class UnaryFunctionType,
         class LocalIndexType>
class MultiVectorUnaryTransformLoopBody {
private:
  static_assert (static_cast<int> (InputViewType::Rank) == 2,
                 "InputViewType must be a rank-2 Kokkos::View.");
  static_assert (static_cast<int> (OutputViewType::Rank) == 2,
                 "OutputViewType must be a rank-2 Kokkos::View.");

public:
  /// \brief Store (shallow) copies of the views and the function.
  ///
  /// \param in [in] View to read.
  /// \param out [out] View to write.
  /// \param f [in] Function applied to each entry of \c in.
  MultiVectorUnaryTransformLoopBody (const InputViewType& in,
                                     const OutputViewType& out,
                                     UnaryFunctionType f) :
    in_ (in), out_ (out), f_ (f)
  {}

  /// \brief Set <tt>out(i,j) = f(in(i,j))</tt> for every column j of row i.
  ///
  /// The column count is taken from the input view's second extent.
  KOKKOS_INLINE_FUNCTION void
  operator () (const LocalIndexType i) const {
    using LO = LocalIndexType;
    const LO numCols = static_cast<LO> (in_.extent (1));
    for (LO j = 0; j < numCols; ++j) {
      out_(i,j) = f_ (in_(i,j));
    }
  }

private:
  InputViewType in_;     //!< View that is read.
  OutputViewType out_;   //!< View that is written.
  UnaryFunctionType f_;  //!< User function applied to each entry.
};
/// \brief Kokkos::parallel_for functor that applies a binary function
///   entrywise to two rank-2 views, processing one row per invocation.
///
/// \tparam InputViewType1 Rank-2 view type of the first operand.
/// \tparam InputViewType2 Rank-2 view type of the second operand.
/// \tparam OutputViewType Rank-2 view type that is written.
/// \tparam BinaryFunctionType Callable taking one entry from each
///   input and returning the value to store in the output.
/// \tparam LocalIndexType Integer type used for row and column indices.
template<class InputViewType1,
         class InputViewType2,
         class OutputViewType,
         class BinaryFunctionType,
         class LocalIndexType>
class MultiVectorBinaryTransformLoopBody {
private:
  static_assert (static_cast<int> (InputViewType1::Rank) == 2,
                 "InputViewType1 must be a rank-2 Kokkos::View.");
  static_assert (static_cast<int> (InputViewType2::Rank) == 2,
                 "InputViewType2 must be a rank-2 Kokkos::View.");
  static_assert (static_cast<int> (OutputViewType::Rank) == 2,
                 "OutputViewType must be a rank-2 Kokkos::View.");

public:
  /// \brief Store (shallow) copies of the views and the function.
  ///
  /// \param in1 [in] First view to read.
  /// \param in2 [in] Second view to read.
  /// \param out [out] View to write.
  /// \param f [in] Function applied to each pair of input entries.
  MultiVectorBinaryTransformLoopBody (const InputViewType1& in1,
                                      const InputViewType2& in2,
                                      const OutputViewType& out,
                                      BinaryFunctionType f) :
    in1_ (in1), in2_ (in2), out_ (out), f_ (f)
  {}

  /// \brief Set <tt>out(i,j) = f(in1(i,j), in2(i,j))</tt> for every
  ///   column j of row i.
  ///
  /// The column count is taken from the first input view's second extent.
  KOKKOS_INLINE_FUNCTION void
  operator () (const LocalIndexType i) const {
    using LO = LocalIndexType;
    const LO numCols = static_cast<LO> (in1_.extent (1));
    for (LO j = 0; j < numCols; ++j) {
      out_(i,j) = f_ (in1_(i,j), in2_(i,j));
    }
  }

private:
  InputViewType1 in1_;    //!< First view that is read.
  InputViewType2 in2_;    //!< Second view that is read.
  OutputViewType out_;    //!< View that is written.
  BinaryFunctionType f_;  //!< User function applied to each entry pair.
};
/// \brief Kokkos::parallel_for functor that applies a unary function
///   entrywise to a rank-1 view, processing one entry per invocation.
///
/// \tparam InputViewType Rank-1 view type that is read.
/// \tparam OutputViewType Rank-1 view type that is written.
/// \tparam UnaryFunctionType Callable taking one input entry and
///   returning the value to store in the output.
/// \tparam LocalIndexType Integer type used for the entry index.
template<class InputViewType,
         class OutputViewType,
         class UnaryFunctionType,
         class LocalIndexType>
class VectorUnaryTransformLoopBody {
private:
  static_assert (static_cast<int> (InputViewType::Rank) == 1,
                 "InputViewType must be a rank-1 Kokkos::View.");
  static_assert (static_cast<int> (OutputViewType::Rank) == 1,
                 "OutputViewType must be a rank-1 Kokkos::View.");

public:
  /// \brief Store (shallow) copies of the views and the function.
  ///
  /// \param in [in] View to read.
  /// \param out [out] View to write.
  /// \param f [in] Function applied to each entry of \c in.
  VectorUnaryTransformLoopBody (const InputViewType& in,
                                const OutputViewType& out,
                                UnaryFunctionType f) :
    in_ (in), out_ (out), f_ (f)
  {}

  /// \brief Set <tt>out(i) = f(in(i))</tt>.
  KOKKOS_INLINE_FUNCTION void
  operator () (const LocalIndexType i) const {
    out_(i) = f_ (in_(i));
  }

private:
  InputViewType in_;     //!< View that is read.
  OutputViewType out_;   //!< View that is written.
  UnaryFunctionType f_;  //!< User function applied to each entry.
};
/// \brief Kokkos::parallel_for functor that applies a binary function
///   entrywise to two rank-1 views, processing one entry per invocation.
///
/// \tparam InputViewType1 Rank-1 view type of the first operand.
/// \tparam InputViewType2 Rank-1 view type of the second operand.
/// \tparam OutputViewType Rank-1 view type that is written.
/// \tparam BinaryFunctionType Callable taking one entry from each
///   input and returning the value to store in the output.
/// \tparam LocalIndexType Integer type used for the entry index.
template<class InputViewType1,
         class InputViewType2,
         class OutputViewType,
         class BinaryFunctionType,
         class LocalIndexType>
class VectorBinaryTransformLoopBody {
private:
  static_assert (static_cast<int> (InputViewType1::Rank) == 1,
                 "InputViewType1 must be a rank-1 Kokkos::View.");
  // Check the SECOND input type here; checking InputViewType1 twice
  // would let a wrong-rank InputViewType2 through unnoticed.
  static_assert (static_cast<int> (InputViewType2::Rank) == 1,
                 "InputViewType2 must be a rank-1 Kokkos::View.");
  static_assert (static_cast<int> (OutputViewType::Rank) == 1,
                 "OutputViewType must be a rank-1 Kokkos::View.");

public:
  /// \brief Store (shallow) copies of the views and the function.
  ///
  /// \param in1 [in] First view to read.
  /// \param in2 [in] Second view to read.
  /// \param out [out] View to write.
  /// \param f [in] Function applied to each pair of input entries.
  VectorBinaryTransformLoopBody (const InputViewType1& in1,
                                 const InputViewType2& in2,
                                 const OutputViewType& out,
                                 BinaryFunctionType f) :
    in1_ (in1), in2_ (in2), out_ (out), f_ (f)
  {}

  /// \brief Set <tt>out(i) = f(in1(i), in2(i))</tt>.
  KOKKOS_INLINE_FUNCTION void
  operator () (const LocalIndexType i) const {
    out_(i) = f_ (in1_(i), in2_(i));
  }

private:
  InputViewType1 in1_;    //!< First view that is read.
  InputViewType2 in2_;    //!< Second view that is read.
  OutputViewType out_;    //!< View that is written.
  BinaryFunctionType f_;  //!< User function applied to each entry pair.
};
// Functor that stores a user-supplied unary function so that it can be
// invoked on a single entry passed by non-const reference. The name and
// the call sites (which use it when sameObject() reports that input and
// output coincide) indicate it serves the in-place transform case.
// NOTE(review): the IST alias and the body of operator() are not visible
// in this excerpt; presumably the body is `X_ij = f_ (X_ij);` -- confirm.
217 template<
class ExecutionSpace,
218 class SC,
class LO,
class GO,
class NT,
219 class UnaryFunctionType>
220 class UnaryTransformSameMultiVector {
// Keep a copy of the user's function for later invocation.
226 UnaryTransformSameMultiVector (UnaryFunctionType f) : f_ (f) {}
// Call operator, marked KOKKOS_INLINE_FUNCTION; receives the entry to
// update by reference and does not modify the functor itself (const).
228 KOKKOS_INLINE_FUNCTION
void operator() (IST& X_ij)
const {
// The wrapped user function.
236 UnaryFunctionType f_;
// Free-function helper for the in-place unary transform. It takes a kernel
// label and an execution space instance and names the
// UnaryTransformSameMultiVector instantiation it works with.
// NOTE(review): the return type, the remaining parameters, and the code
// that uses functor_type are not visible in this excerpt.
239 template<
class ExecutionSpace,
240 class SC,
class LO,
class GO,
class NT,
241 class UnaryFunctionType>
243 unaryTransformSameMultiVector (
const char kernelLabel[],
244 ExecutionSpace execSpace,
// Adapter type wrapping the user's function for per-entry, by-reference use.
248 using functor_type = UnaryTransformSameMultiVector<ExecutionSpace,
249 SC, LO, GO, NT, UnaryFunctionType>;
// Template header of the enclosing implementation struct, parameterized on
// the execution space and the four MultiVector template parameters.
// NOTE(review): the struct's own header line is not visible in this excerpt.
255 template<
class ExecutionSpace,
256 class SC,
class LO,
class GO,
class NT>
// Compile-time choice of the memory space in which to run a transform:
// MemorySpace::memory_space if the SpaceAccessibility query reports it
// accessible, otherwise ExecutionSpace::memory_space.
// NOTE(review): the first argument of SpaceAccessibility is not visible
// here; presumably it is ExecutionSpace -- confirm.
265 template<
class MemorySpace>
266 using transform_memory_space =
267 typename std::conditional<
268 Kokkos::SpaceAccessibility<
270 typename MemorySpace::memory_space>::accessible,
271 typename MemorySpace::memory_space,
272 typename ExecutionSpace::memory_space>::type;
// The MultiVector's own device memory space is the preferred candidate.
275 using preferred_memory_space =
276 typename MV::device_type::memory_space;
// Result of applying the selection rule to the preferred space.
// NOTE(review): the left-hand side of this alias is not visible; later
// code declares `memory_space memSpace;`, so it is presumably
// `using memory_space =` -- confirm.
278 transform_memory_space<preferred_memory_space>;
286 sameObject (const ::Tpetra::MultiVector<SC, LO, GO, NT>& input,
287 const ::Tpetra::MultiVector<SC, LO, GO, NT>& output)
289 return &input == &output ||
291 output.getLocalViewHost ().data () ||
292 input.getLocalViewDevice ().data () ==
293 output.getLocalViewDevice ().data ();
// Unary transform for a single column whose input and output are distinct
// objects: launches a VectorUnaryTransformLoopBody over the local rows.
// NOTE(review): the return type, the input/output/f parameters, and the
// head of the call that receives the lambda are not visible in this excerpt.
296 template<
class UnaryFunctionType>
298 transform_vec_notSameObject
299 (
const char kernelLabel[],
300 ExecutionSpace execSpace,
// Default-constructed instance of the memory space used for local access.
305 memory_space memSpace;
// Local view types, deduced from the read-only (input) and write-only
// (output) local-access declarations.
// NOTE(review): the wrapping type function on the elided lines is
// presumably with_local_access_function_argument_type -- confirm.
307 using input_view_type =
309 decltype (
readOnly (input).on (memSpace). at(execSpace))>;
310 using output_view_type =
312 decltype (
writeOnly (output).on (memSpace). at(execSpace))>;
// Lambda that receives the local views and runs the kernel; captures by
// value (kernelLabel, execSpace, f).
315 ([=] (
const input_view_type& input_lcl,
316 const output_view_type& output_lcl) {
317 using functor_type = VectorUnaryTransformLoopBody<
318 input_view_type, output_view_type, UnaryFunctionType, LO>;
319 functor_type g (input_lcl, output_lcl, f);
// Iterate over every local row of the input view.
321 const LO lclNumRows =
static_cast<LO
> (input_lcl.extent (0));
322 using range_type = Kokkos::RangePolicy<ExecutionSpace, LO>;
323 range_type range (execSpace, 0, lclNumRows);
325 Kokkos::parallel_for (kernelLabel, range, g);
// Local-access declarations passed alongside the lambda: input read-only,
// output write-only, both on memSpace and at execSpace.
327 readOnly (input).on (memSpace).at (execSpace),
328 writeOnly (output).on (memSpace).at (execSpace));
// Unary transform for a whole multivector whose input and output are
// distinct objects: launches a MultiVectorUnaryTransformLoopBody (one row
// per iteration, all columns inside) over the local rows.
// NOTE(review): the return type, the input/output/f parameters, and the
// head of the call that receives the lambda are not visible in this excerpt.
331 template<
class UnaryFunctionType>
333 transform_mv_notSameObject
334 (
const char kernelLabel[],
335 ExecutionSpace execSpace,
// Default-constructed instance of the memory space used for local access.
340 memory_space memSpace;
// Local view types, deduced from the read-only (input) and write-only
// (output) local-access declarations.
342 using input_view_type =
344 decltype (
readOnly (input).on (memSpace). at(execSpace))>;
345 using output_view_type =
347 decltype (
writeOnly (output).on (memSpace). at(execSpace))>;
// Lambda that receives the local views and runs the kernel; captures by
// value (kernelLabel, execSpace, f).
350 ([=] (
const input_view_type& input_lcl,
351 const output_view_type& output_lcl) {
352 using functor_type = MultiVectorUnaryTransformLoopBody<
353 input_view_type, output_view_type, UnaryFunctionType, LO>;
354 functor_type g (input_lcl, output_lcl, f);
// Iterate over every local row of the input view.
356 const LO lclNumRows =
static_cast<LO
> (input_lcl.extent (0));
357 using range_type = Kokkos::RangePolicy<ExecutionSpace, LO>;
358 range_type range (execSpace, 0, lclNumRows);
360 Kokkos::parallel_for (kernelLabel, range, g);
// Local-access declarations passed alongside the lambda: input read-only,
// output write-only, both on memSpace and at execSpace.
362 readOnly (input).on (memSpace).at (execSpace),
363 writeOnly (output).on (memSpace).at (execSpace));
// Unary transform entry point of the implementation struct. It dispatches
// on (a) whether to work column by column or on the whole multivector, and
// (b) whether input and output are the same object.
// NOTE(review): the function name line, most parameters, and several
// guards/else branches are not visible in this excerpt.
367 template<
class UnaryFunctionType>
370 ExecutionSpace execSpace,
375 using Teuchos::TypeNameTraits;
// Rank of the calling process, used only to tag the diagnostic output.
378 const int myRank = output.
getMap ()->getComm ()->getRank ();
// Build the diagnostic message in one string and emit it with a single
// write to std::cerr.
// NOTE(review): presumably guarded by a verbose-mode check whose `if` is
// not visible here -- confirm.
381 std::ostringstream os;
382 os <<
"Proc " << myRank <<
": Tpetra::transform:" << endl
383 <<
" kernelLabel: " << kernelLabel << endl
384 <<
" ExecutionSpace: "
385 << TypeNameTraits<ExecutionSpace>::name () << endl;
386 std::cerr << os.str ();
// Error check whose message reports a column-count mismatch against
// output.getNumVectors(); the tested condition itself is not visible.
390 TEUCHOS_TEST_FOR_EXCEPTION
393 " != output.getNumVectors() = " << numVecs <<
".");
398 memory_space memSpace;
// Single column, or columns without constant stride: handle each column
// on its own.
399 if (numVecs ==
size_t (1) || ! constStride) {
400 for (
size_t j = 0; j < numVecs; ++j) {
// Column j aliases itself: use the in-place helper; otherwise use the
// two-view rank-1 kernel.
407 if (sameObject (*output_j, *input_j)) {
408 unaryTransformSameMultiVector (kernelLabel, execSpace,
412 transform_vec_notSameObject (kernelLabel, execSpace,
413 *input_j, *output_j, f);
// Otherwise (the `else` line is not visible): process all columns at
// once, again choosing the in-place or the two-view kernel.
418 if (sameObject (output, input)) {
419 unaryTransformSameMultiVector (kernelLabel, execSpace,
423 transform_mv_notSameObject (kernelLabel, execSpace,
// Binary transform entry point of the implementation struct. It works
// either column by column (rank-1 kernels) or on the whole multivector
// (rank-2 kernels). In each mode it picks one of five launch variants
// according to which of input1, input2 and output are the same object, so
// that each distinct object is declared for local access only once.
// NOTE(review): the function name line, several parameters, some local
// declarations (numVecs, constStride, lclNumRows, the per-column vectors)
// and the `if`/`else` lines opening some branches are not visible here.
430 template<
class BinaryFunctionType>
433 ExecutionSpace execSpace,
437 BinaryFunctionType f)
439 using Teuchos::TypeNameTraits;
// Common prefix for diagnostic and exception messages.
441 const char prefix[] =
"Tpetra::transform (binary): ";
// Rank of the calling process, used only to tag the diagnostic output.
443 const int myRank = output.
getMap ()->getComm ()->getRank ();
// Build the diagnostic message (template arguments, kernel label,
// execution space name) in one string and emit it with a single write to
// std::cerr.
// NOTE(review): presumably guarded by a verbose-mode check whose `if` is
// not visible here -- confirm.
446 std::ostringstream os;
447 os <<
"Proc " << myRank <<
": " << prefix << endl
448 <<
" Tpetra::MultiVector<" << TypeNameTraits<SC>::name ()
449 <<
", " << TypeNameTraits<LO>::name () <<
", "
450 << TypeNameTraits<GO>::name () <<
", "
451 << TypeNameTraits<NT>::name () <<
">" << endl
452 <<
" kernelLabel: " << kernelLabel << endl
453 <<
" ExecutionSpace: "
454 << TypeNameTraits<ExecutionSpace>::name () << endl;
455 std::cerr << os.str ();
// Error checks whose messages report that an input's column count differs
// from the output's; the tested conditions themselves are not visible.
459 TEUCHOS_TEST_FOR_EXCEPTION
461 prefix <<
"input1.getNumVectors() = " << input1.
getNumVectors ()
462 <<
" != output.getNumVectors() = " << numVecs <<
".");
463 TEUCHOS_TEST_FOR_EXCEPTION
465 prefix <<
"input2.getNumVectors() = " << input2.
getNumVectors ()
466 <<
" != output.getNumVectors() = " << numVecs <<
".");
470 memory_space memSpace;
// One range over the local rows, shared by every launch below.
473 using range_type = Kokkos::RangePolicy<ExecutionSpace, LO>;
474 range_type range (execSpace, 0, lclNumRows);
// Single column, or columns without constant stride: handle each column
// on its own with the rank-1 kernel.
476 if (numVecs ==
size_t (1) || ! constStride) {
477 for (
size_t j = 0; j < numVecs; ++j) {
// Determine which of the three column vectors coincide.
485 const bool outin1same = sameObject (*output_j, *input1_j);
486 const bool outin2same = sameObject (*output_j, *input2_j);
488 const bool in1in2same = sameObject (*input1_j, *input2_j);
489 const bool allsame = outin1same && outin2same;
// Named references to the column vectors, used in the local-access
// declarations below.
495 vec_type& input1_j_ref = *input1_j;
496 vec_type& input2_j_ref = *input2_j;
497 vec_type& output_j_ref = *output_j;
// Local view types for each access mode: read-only inputs, and the output
// either read-write (when it aliases an input) or write-only.
502 using input1_view_type =
504 decltype (
readOnly (input1_j_ref).on (memSpace). at(execSpace))>;
505 using input2_view_type =
507 decltype (
readOnly (input2_j_ref).on (memSpace). at(execSpace))>;
508 using rw_output_view_type =
510 decltype (
readWrite (output_j_ref).on (memSpace). at(execSpace))>;
511 using wo_output_view_type =
513 decltype (
writeOnly (output_j_ref).on (memSpace). at(execSpace))>;
// Case 1 -- all three are the same object (presumably guarded by
// `allsame`; the `if` line is not visible): a single read-write view
// serves as both inputs and as the output.
517 ([=] (
const rw_output_view_type& output_lcl) {
518 using functor_type = VectorBinaryTransformLoopBody<
519 typename rw_output_view_type::const_type,
520 typename rw_output_view_type::const_type,
522 BinaryFunctionType, LO>;
523 functor_type functor (output_lcl, output_lcl, output_lcl, f);
524 Kokkos::parallel_for (kernelLabel, range, functor);
526 readWrite (output_j_ref).on (memSpace).at (execSpace));
// Case 2 -- the two inputs coincide: input1's view is used for both
// operands and the output is write-only.
528 else if (in1in2same) {
530 ([=] (
const input1_view_type& input1_lcl,
531 const wo_output_view_type& output_lcl) {
532 using functor_type = VectorBinaryTransformLoopBody<
536 BinaryFunctionType, LO>;
537 functor_type functor (input1_lcl, input1_lcl, output_lcl, f);
538 Kokkos::parallel_for (kernelLabel, range, functor);
540 readOnly (input1_j_ref).on (memSpace).at (execSpace),
541 writeOnly (output_j_ref).on (memSpace).at (execSpace));
// Case 3 -- output coincides with input1: the read-write output view
// doubles as the first operand.
543 else if (outin1same) {
545 ([=] (
const input2_view_type& input2_lcl,
546 const rw_output_view_type& output_lcl) {
547 using functor_type = VectorBinaryTransformLoopBody<
548 typename rw_output_view_type::const_type,
551 BinaryFunctionType, LO>;
552 functor_type functor (output_lcl, input2_lcl, output_lcl, f);
553 Kokkos::parallel_for (kernelLabel, range, functor);
555 readOnly (input2_j_ref).on (memSpace).at (execSpace),
556 readWrite (output_j_ref).on (memSpace).at (execSpace));
// Case 4 -- output coincides with input2: the read-write output view
// doubles as the second operand.
558 else if (outin2same) {
560 ([=] (
const input1_view_type& input1_lcl,
561 const rw_output_view_type& output_lcl) {
562 using functor_type = VectorBinaryTransformLoopBody<
564 typename rw_output_view_type::const_type,
566 BinaryFunctionType, LO>;
567 functor_type functor (input1_lcl, output_lcl, output_lcl, f);
568 Kokkos::parallel_for (kernelLabel, range, functor);
570 readOnly (input1_j_ref).on (memSpace).at (execSpace),
571 readWrite (output_j_ref).on (memSpace).at (execSpace));
// Case 5 -- remaining case (the `else` line is not visible), reached when
// none of the tests above held: two read-only inputs and a write-only
// output.
575 ([=] (
const input1_view_type& input1_lcl,
576 const input2_view_type& input2_lcl,
577 const wo_output_view_type& output_lcl) {
578 using functor_type = VectorBinaryTransformLoopBody<
582 BinaryFunctionType, LO>;
583 functor_type functor (input1_lcl, input2_lcl, output_lcl, f);
584 Kokkos::parallel_for (kernelLabel, range, functor);
586 readOnly (input1_j_ref).on (memSpace).at (execSpace),
587 readOnly (input2_j_ref).on (memSpace).at (execSpace),
588 writeOnly (output_j_ref).on (memSpace).at (execSpace));
// Whole-multivector path (the `else` line is not visible): the same five
// aliasing cases, using the rank-2 kernel on the full multivectors.
596 const bool outin1same = sameObject (output, input1);
597 const bool outin2same = sameObject (output, input2);
599 const bool in1in2same = sameObject (input1, input2);
600 const bool allsame = outin1same && outin2same;
// Local view types for each access mode, now deduced from the
// multivectors themselves.
605 using input1_view_type =
607 decltype (
readOnly (input1).on (memSpace). at(execSpace))>;
608 using input2_view_type =
610 decltype (
readOnly (input2).on (memSpace). at(execSpace))>;
611 using rw_output_view_type =
613 decltype (
readWrite (output).on (memSpace). at(execSpace))>;
614 using wo_output_view_type =
616 decltype (
writeOnly (output).on (memSpace). at(execSpace))>;
// Case 1 -- all three are the same object (presumably guarded by
// `allsame`; the `if` line is not visible).
620 ([=] (
const rw_output_view_type& output_lcl) {
621 using functor_type = MultiVectorBinaryTransformLoopBody<
622 typename rw_output_view_type::const_type,
623 typename rw_output_view_type::const_type,
625 BinaryFunctionType, LO>;
626 functor_type functor (output_lcl, output_lcl, output_lcl, f);
627 Kokkos::parallel_for (kernelLabel, range, functor);
629 readWrite (output).on (memSpace).at (execSpace));
// Case 2 -- the two inputs coincide.
631 else if (in1in2same) {
633 ([=] (
const input1_view_type& input1_lcl,
634 const wo_output_view_type& output_lcl) {
635 using functor_type = MultiVectorBinaryTransformLoopBody<
639 BinaryFunctionType, LO>;
640 functor_type functor (input1_lcl, input1_lcl, output_lcl, f);
641 Kokkos::parallel_for (kernelLabel, range, functor);
643 readOnly (input1).on (memSpace).at (execSpace),
644 writeOnly (output).on (memSpace).at (execSpace));
// Case 3 -- output coincides with input1.
646 else if (outin1same) {
648 ([=] (
const input2_view_type& input2_lcl,
649 const rw_output_view_type& output_lcl) {
650 using functor_type = MultiVectorBinaryTransformLoopBody<
651 typename rw_output_view_type::const_type,
654 BinaryFunctionType, LO>;
655 functor_type functor (output_lcl, input2_lcl, output_lcl, f);
656 Kokkos::parallel_for (kernelLabel, range, functor);
658 readOnly (input2).on (memSpace).at (execSpace),
659 readWrite (output).on (memSpace).at (execSpace));
// Case 4 -- output coincides with input2.
661 else if (outin2same) {
663 ([=] (
const input1_view_type& input1_lcl,
664 const rw_output_view_type& output_lcl) {
665 using functor_type = MultiVectorBinaryTransformLoopBody<
667 typename rw_output_view_type::const_type,
669 BinaryFunctionType, LO>;
670 functor_type functor (input1_lcl, output_lcl, output_lcl, f);
671 Kokkos::parallel_for (kernelLabel, range, functor);
673 readOnly (input1).on (memSpace).at (execSpace),
674 readWrite (output).on (memSpace).at (execSpace));
// Case 5 -- remaining case (the `else` line is not visible), reached when
// none of the tests above held.
678 ([=] (
const input1_view_type& input1_lcl,
679 const input2_view_type& input2_lcl,
680 const wo_output_view_type& output_lcl) {
681 using functor_type = MultiVectorBinaryTransformLoopBody<
685 BinaryFunctionType, LO>;
686 functor_type functor (input1_lcl, input2_lcl, output_lcl, f);
687 Kokkos::parallel_for (kernelLabel, range, functor);
689 readOnly (input1).on (memSpace).at (execSpace),
690 readOnly (input2).on (memSpace).at (execSpace),
691 writeOnly (output).on (memSpace).at (execSpace));
// A second struct-level template header followed by two thin `transform`
// wrappers (unary and binary) that forward their arguments to the static
// `transform` of `impl_type`.
// NOTE(review): the struct header, the definition of impl_type, and most of
// both signatures are not visible in this excerpt.
703 template<
class ExecutionSpace,
704 class SC,
class LO,
class GO,
class NT>
// Unary wrapper: forwards to impl_type with the function type spelled out
// explicitly as a template argument.
709 template<
class UnaryFunctionType>
712 ExecutionSpace execSpace,
719 using UFT = UnaryFunctionType;
721 impl_type::template transform<UFT> (kernelLabel, execSpace,
// Binary wrapper: forwards both inputs, the output and the function to
// impl_type, again with the function type given explicitly.
726 template<
class BinaryFunctionType>
729 ExecutionSpace execSpace,
733 BinaryFunctionType f)
737 using BFT = BinaryFunctionType;
739 impl_type::template transform<BFT> (kernelLabel, execSpace,
740 input1, input2, output, f);
748 #endif // TPETRA_TRANSFORM_MULTIVECTOR_HPP
Include this file to make Tpetra::MultiVector and Tpetra::Vector work with Tpetra::withLocalAccess.
dual_view_type::t_host getLocalViewHost() const
A local Kokkos::View of host memory.
size_t getNumVectors() const
Number of columns in the multivector.
size_t getLocalLength() const
Local number of rows on the calling process.
bool isConstantStride() const
Whether this multivector has constant stride between columns.
One or more distributed dense vectors.
Details::LocalAccess< GlobalObjectType, Details::write_only > writeOnly(GlobalObjectType &)
Declare that you want to access the given global object's local data in write-only mode...
Teuchos::RCP< Vector< Scalar, LocalOrdinal, GlobalOrdinal, Node > > getVectorNonConst(const size_t j)
Return a Vector which is a nonconst view of column j.
Details::LocalAccess< GlobalObjectType, Details::read_write > readWrite(GlobalObjectType &)
Declare that you want to access the given global object's local data in read-and-write mode...
static bool verbose()
Whether Tpetra is in verbose mode.
void withLocalAccess(typename Details::ArgsToFunction< LocalAccessTypes...>::type userFunction, LocalAccessTypes...localAccesses)
Get access to a Tpetra global object's local data.
Include this file to make Tpetra::for_each work with Tpetra::MultiVector and Tpetra::Vector.
typename Kokkos::Details::ArithTraits< Scalar >::val_type impl_scalar_type
The type used internally in place of Scalar.
A distributed dense vector.
virtual Teuchos::RCP< const map_type > getMap() const
The Map describing the parallel distribution of this object.
Details::LocalAccess< GlobalObjectType, Details::read_only > readOnly(GlobalObjectType &)
Declare that you want to access the given global object's local data in read-only mode...
typename Details::GetNonowningLocalObject< LocalAccessType >::nonowning_local_object_type with_local_access_function_argument_type
Type of the local object, that is an argument to the function the user gives to withLocalAccess.
void for_each(const char kernelLabel[], ExecutionSpace execSpace, GlobalDataStructure &X, UserFunctionType f)
Apply a function entrywise to each local entry of a Tpetra global data structure, analogously to std::for_each.
Declaration of Tpetra::Details::Behavior, a class that describes Tpetra's behavior.