10 #ifndef TPETRA_DETAILS_LCLDOT_HPP
11 #define TPETRA_DETAILS_LCLDOT_HPP
20 #include "Kokkos_DualView.hpp"
21 #if KOKKOS_VERSION >= 40799
22 #include "KokkosKernels_ArithTraits.hpp"
24 #include "Kokkos_ArithTraits.hpp"
26 #include "KokkosBlas1_dot.hpp"
27 #include "Teuchos_ArrayView.hpp"
28 #include "Teuchos_TestForException.hpp"
33 template <
class RV,
class XMV>
34 void lclDot(
const RV& dotsOut,
37 const size_t lclNumRows,
39 const size_t whichVecsX[],
40 const size_t whichVecsY[],
41 const bool constantStrideX,
42 const bool constantStrideY) {
44 using Kokkos::subview;
45 typedef typename RV::non_const_value_type dot_type;
46 #ifdef HAVE_TPETRA_DEBUG
47 const char prefix[] =
"Tpetra::MultiVector::lclDotImpl: ";
48 #endif // HAVE_TPETRA_DEBUG
50 static_assert(Kokkos::is_view<RV>::value,
51 "Tpetra::MultiVector::lclDotImpl: "
52 "The first argument dotsOut is not a Kokkos::View.");
53 static_assert(RV::rank == 1,
54 "Tpetra::MultiVector::lclDotImpl: "
55 "The first argument dotsOut must have rank 1.");
56 static_assert(Kokkos::is_view<XMV>::value,
57 "Tpetra::MultiVector::lclDotImpl: The type of the 2nd and "
58 "3rd arguments (X_lcl and Y_lcl) is not a Kokkos::View.");
59 static_assert(XMV::rank == 2,
60 "Tpetra::MultiVector::lclDotImpl: "
61 "X_lcl and Y_lcl must have rank 2.");
66 const std::pair<size_t, size_t> rowRng(0, lclNumRows);
67 const std::pair<size_t, size_t> colRng(0, numVecs);
68 RV theDots = subview(dotsOut, colRng);
69 XMV X = subview(X_lcl, rowRng, ALL());
70 XMV Y = subview(Y_lcl, rowRng, ALL());
72 #ifdef HAVE_TPETRA_DEBUG
73 if (lclNumRows != 0) {
74 TEUCHOS_TEST_FOR_EXCEPTION(X.extent(0) != lclNumRows, std::logic_error, prefix <<
"X.extent(0) = " << X.extent(0) <<
" != lclNumRows "
77 "Please report this bug to the Tpetra developers.");
78 TEUCHOS_TEST_FOR_EXCEPTION(Y.extent(0) != lclNumRows, std::logic_error, prefix <<
"Y.extent(0) = " << Y.extent(0) <<
" != lclNumRows "
81 "Please report this bug to the Tpetra developers.");
85 TEUCHOS_TEST_FOR_EXCEPTION(constantStrideX &&
86 (X.extent(0) != lclNumRows || X.extent(1) != numVecs),
87 std::logic_error, prefix <<
"X is " << X.extent(0) <<
" x " << X.extent(1) <<
" (constant stride), which differs from the "
89 << lclNumRows <<
" x " << numVecs <<
". "
90 "Please report this bug to the Tpetra developers.");
91 TEUCHOS_TEST_FOR_EXCEPTION(!constantStrideX &&
92 (X.extent(0) != lclNumRows || X.extent(1) < numVecs),
93 std::logic_error, prefix <<
"X is " << X.extent(0) <<
" x " << X.extent(1) <<
" (NOT constant stride), but the local "
95 << lclNumRows <<
" x " << numVecs <<
". "
96 "Please report this bug to the Tpetra developers.");
97 TEUCHOS_TEST_FOR_EXCEPTION(constantStrideY &&
98 (Y.extent(0) != lclNumRows || Y.extent(1) != numVecs),
99 std::logic_error, prefix <<
"Y is " << Y.extent(0) <<
" x " << Y.extent(1) <<
" (constant stride), which differs from the "
101 << lclNumRows <<
" x " << numVecs <<
". "
102 "Please report this bug to the Tpetra developers.");
103 TEUCHOS_TEST_FOR_EXCEPTION(!constantStrideY &&
104 (Y.extent(0) != lclNumRows || Y.extent(1) < numVecs),
105 std::logic_error, prefix <<
"Y is " << Y.extent(0) <<
" x " << Y.extent(1) <<
" (NOT constant stride), but the local "
107 << lclNumRows <<
" x " << numVecs <<
". "
108 "Please report this bug to the Tpetra developers.");
110 #endif // HAVE_TPETRA_DEBUG
112 if (lclNumRows == 0) {
113 #if KOKKOS_VERSION >= 40799
114 const dot_type zero = KokkosKernels::ArithTraits<dot_type>::zero();
116 const dot_type zero = Kokkos::ArithTraits<dot_type>::zero();
121 if (constantStrideX && constantStrideY) {
122 if (X.extent(1) == 1) {
123 typename RV::non_const_value_type result =
124 KokkosBlas::dot(subview(X, ALL(), 0), subview(Y, ALL(), 0));
128 KokkosBlas::dot(theDots, X, Y);
136 for (
size_t k = 0; k < numVecs; ++k) {
137 const size_t X_col = constantStrideX ? k : whichVecsX[k];
138 const size_t Y_col = constantStrideY ? k : whichVecsY[k];
139 KokkosBlas::dot(subview(theDots, k), subview(X, ALL(), X_col),
140 subview(Y, ALL(), Y_col));
149 #endif // TPETRA_DETAILS_LCLDOT_HPP
// void deep_copy(MultiVector< DS, DL, DG, DN > &dst, const MultiVector< SS, SL, SG, SN > &src)
//   Copy the contents of the MultiVector src into dst.