42 #ifndef TPETRA_DETAILS_LCLDOT_HPP
43 #define TPETRA_DETAILS_LCLDOT_HPP
52 #include "Kokkos_DualView.hpp"
53 #include "Kokkos_ArithTraits.hpp"
54 #include "KokkosBlas1_dot.hpp"
55 #include "Teuchos_ArrayView.hpp"
56 #include "Teuchos_TestForException.hpp"
// Computes local dot products between columns of X_lcl and Y_lcl and writes
// them into dotsOut.
//
// Template parameters (enforced by the static_asserts below):
//   RV  - Kokkos::View type of the output dotsOut; must have rank 1.
//   XMV - Kokkos::View type of the inputs X_lcl and Y_lcl; must have rank 2.
//
// Parameters (as visible in this excerpt):
//   dotsOut           - output view. Visible writes go through theDots,
//                       the [0, numVecs) subview of dotsOut.
//   lclNumRows        - number of leading rows of X_lcl / Y_lcl that
//                       take part (rows [0, lclNumRows)).
//   whichVecsX/Y      - column indices into X / Y. They are read only when
//                       the matching constantStride flag is false.
//   constantStrideX/Y - when true, output entry k uses column k directly.
//
// In HAVE_TPETRA_DEBUG builds, inconsistent extents throw std::logic_error.
//
// NOTE(review): this excerpt skips several original source lines
// (e.g. 62, 64-65, 67, 72-73, 116, 130, 148-150, 155-157, 159-166).
// As a result, numVecs, X_lcl, Y_lcl and ALL are used below without a
// visible declaration, and some branches and braces are missing.
// The comments describe only what is visible. TODO confirm against the
// full header.
61 template<
class RV,
class XMV>
63 lclDot (
const RV& dotsOut,
66 const size_t lclNumRows,
68 const size_t whichVecsX[],
69 const size_t whichVecsY[],
70 const bool constantStrideX,
71 const bool constantStrideY)
74 using Kokkos::subview;
75 typedef typename RV::non_const_value_type dot_type;
// Error-message prefix. It is only defined, and only used, in debug builds.
76 #ifdef HAVE_TPETRA_DEBUG
77 const char prefix[] =
"Tpetra::MultiVector::lclDotImpl: ";
78 #endif // HAVE_TPETRA_DEBUG
// Compile-time checks: dotsOut must be a rank-1 View, and X_lcl / Y_lcl
// must be rank-2 Views.
80 static_assert (Kokkos::Impl::is_view<RV>::value,
81 "Tpetra::MultiVector::lclDotImpl: "
82 "The first argument dotsOut is not a Kokkos::View.");
83 static_assert (RV::rank == 1,
"Tpetra::MultiVector::lclDotImpl: "
84 "The first argument dotsOut must have rank 1.");
85 static_assert (Kokkos::Impl::is_view<XMV>::value,
86 "Tpetra::MultiVector::lclDotImpl: The type of the 2nd and "
87 "3rd arguments (X_lcl and Y_lcl) is not a Kokkos::View.");
88 static_assert (XMV::rank == 2,
"Tpetra::MultiVector::lclDotImpl: "
89 "X_lcl and Y_lcl must have rank 2.");
// Restrict the output to its first numVecs entries, and X / Y to their
// first lclNumRows rows (all columns kept).
94 const std::pair<size_t, size_t> rowRng (0, lclNumRows);
95 const std::pair<size_t, size_t> colRng (0, numVecs);
96 RV theDots = subview (dotsOut, colRng);
97 XMV X = subview (X_lcl, rowRng, ALL ());
98 XMV Y = subview (Y_lcl, rowRng, ALL ());
// Debug-only consistency checks, guarded by lclNumRows != 0. The closing
// brace of that `if` is not visible in this excerpt. Each failing check
// throws std::logic_error via TEUCHOS_TEST_FOR_EXCEPTION.
100 #ifdef HAVE_TPETRA_DEBUG
101 if (lclNumRows != 0) {
102 TEUCHOS_TEST_FOR_EXCEPTION
103 (X.extent (0) != lclNumRows, std::logic_error, prefix <<
104 "X.extent(0) = " << X.extent (0) <<
" != lclNumRows "
105 "= " << lclNumRows <<
". "
106 "Please report this bug to the Tpetra developers.");
107 TEUCHOS_TEST_FOR_EXCEPTION
108 (Y.extent (0) != lclNumRows, std::logic_error, prefix <<
109 "Y.extent(0) = " << Y.extent (0) <<
" != lclNumRows "
110 "= " << lclNumRows <<
". "
111 "Please report this bug to the Tpetra developers.");
// Constant-stride X must be exactly lclNumRows x numVecs.
// NOTE(review): the opening of this condition (original line 116,
// presumably `(constantStrideX &&` given the "(constant stride)" message)
// is missing from this excerpt.
115 TEUCHOS_TEST_FOR_EXCEPTION
117 (X.extent (0) != lclNumRows || X.extent (1) != numVecs),
118 std::logic_error, prefix <<
"X is " << X.extent (0) <<
" x " <<
119 X.extent (1) <<
" (constant stride), which differs from the "
120 "local dimensions " << lclNumRows <<
" x " << numVecs <<
". "
121 "Please report this bug to the Tpetra developers.");
// Non-constant-stride X only needs at least numVecs columns, because
// whichVecsX selects which of its columns are used.
122 TEUCHOS_TEST_FOR_EXCEPTION
123 (! constantStrideX &&
124 (X.extent (0) != lclNumRows || X.extent (1) < numVecs),
125 std::logic_error, prefix <<
"X is " << X.extent (0) <<
" x " <<
126 X.extent (1) <<
" (NOT constant stride), but the local "
127 "dimensions are " << lclNumRows <<
" x " << numVecs <<
". "
128 "Please report this bug to the Tpetra developers.");
// Constant-stride Y must be exactly lclNumRows x numVecs.
// NOTE(review): the opening of this condition (original line 130,
// presumably `(constantStrideY &&`) is missing from this excerpt.
129 TEUCHOS_TEST_FOR_EXCEPTION
131 (Y.extent (0) != lclNumRows || Y.extent (1) != numVecs),
132 std::logic_error, prefix <<
"Y is " << Y.extent (0) <<
" x " <<
133 Y.extent (1) <<
" (constant stride), which differs from the "
134 "local dimensions " << lclNumRows <<
" x " << numVecs <<
". "
135 "Please report this bug to the Tpetra developers.");
// Non-constant-stride Y only needs at least numVecs columns
// (selected through whichVecsY).
136 TEUCHOS_TEST_FOR_EXCEPTION
137 (! constantStrideY &&
138 (Y.extent (0) != lclNumRows || Y.extent (1) < numVecs),
139 std::logic_error, prefix <<
"Y is " << Y.extent (0) <<
" x " <<
140 Y.extent (1) <<
" (NOT constant stride), but the local "
141 "dimensions are " << lclNumRows <<
" x " << numVecs <<
". "
142 "Please report this bug to the Tpetra developers.");
144 #endif // HAVE_TPETRA_DEBUG
// No local rows: a zero of dot_type is built here. Presumably theDots is
// filled with it, but the rest of this branch is not visible in this
// excerpt.
146 if (lclNumRows == 0) {
147 const dot_type zero = Kokkos::Details::ArithTraits<dot_type>::zero ();
// Both inputs are constant stride: output k pairs column k of X with
// column k of Y.
151 if (constantStrideX && constantStrideY) {
// Single column: compute one scalar dot of column 0 of X and Y. How
// `result` is then stored into theDots is not visible in this excerpt.
152 if (X.extent (1) == 1) {
153 typename RV::non_const_value_type result =
154 KokkosBlas::dot (subview (X, ALL (), 0), subview (Y, ALL (), 0));
// Multiple columns (the enclosing `else` is not visible): the View-output
// overload of KokkosBlas::dot writes into theDots, presumably one entry
// per column.
158 KokkosBlas::dot (theDots, X, Y);
// At least one input is not constant stride: compute each dot separately.
// The column index is k for a constant-stride input and whichVecs*[k]
// otherwise.
167 for (
size_t k = 0; k < numVecs; ++k) {
168 const size_t X_col = constantStrideX ? k : whichVecsX[k];
169 const size_t Y_col = constantStrideY ? k : whichVecsY[k];
170 KokkosBlas::dot (subview (theDots, k), subview (X, ALL (), X_col),
171 subview (Y, ALL (), Y_col));
180 #endif // TPETRA_DETAILS_LCLDOT_HPP
void deep_copy(MultiVector< DS, DL, DG, DN > &dst, const MultiVector< SS, SL, SG, SN > &src)
Copy the contents of the MultiVector src into dst.