10 #ifndef TPETRA_DETAILS_LCLDOT_HPP
11 #define TPETRA_DETAILS_LCLDOT_HPP
20 #include "Kokkos_DualView.hpp"
21 #include "Kokkos_ArithTraits.hpp"
22 #include "KokkosBlas1_dot.hpp"
23 #include "Teuchos_ArrayView.hpp"
24 #include "Teuchos_TestForException.hpp"
29 template<
class RV,
class XMV>
31 lclDot (
const RV& dotsOut,
34 const size_t lclNumRows,
36 const size_t whichVecsX[],
37 const size_t whichVecsY[],
38 const bool constantStrideX,
39 const bool constantStrideY)
42 using Kokkos::subview;
43 typedef typename RV::non_const_value_type dot_type;
44 #ifdef HAVE_TPETRA_DEBUG
45 const char prefix[] =
"Tpetra::MultiVector::lclDotImpl: ";
46 #endif // HAVE_TPETRA_DEBUG
48 static_assert (Kokkos::is_view<RV>::value,
49 "Tpetra::MultiVector::lclDotImpl: "
50 "The first argument dotsOut is not a Kokkos::View.");
51 static_assert (RV::rank == 1,
"Tpetra::MultiVector::lclDotImpl: "
52 "The first argument dotsOut must have rank 1.");
53 static_assert (Kokkos::is_view<XMV>::value,
54 "Tpetra::MultiVector::lclDotImpl: The type of the 2nd and "
55 "3rd arguments (X_lcl and Y_lcl) is not a Kokkos::View.");
56 static_assert (XMV::rank == 2,
"Tpetra::MultiVector::lclDotImpl: "
57 "X_lcl and Y_lcl must have rank 2.");
62 const std::pair<size_t, size_t> rowRng (0, lclNumRows);
63 const std::pair<size_t, size_t> colRng (0, numVecs);
64 RV theDots = subview (dotsOut, colRng);
65 XMV X = subview (X_lcl, rowRng, ALL ());
66 XMV Y = subview (Y_lcl, rowRng, ALL ());
68 #ifdef HAVE_TPETRA_DEBUG
69 if (lclNumRows != 0) {
70 TEUCHOS_TEST_FOR_EXCEPTION
71 (X.extent (0) != lclNumRows, std::logic_error, prefix <<
72 "X.extent(0) = " << X.extent (0) <<
" != lclNumRows "
73 "= " << lclNumRows <<
". "
74 "Please report this bug to the Tpetra developers.");
75 TEUCHOS_TEST_FOR_EXCEPTION
76 (Y.extent (0) != lclNumRows, std::logic_error, prefix <<
77 "Y.extent(0) = " << Y.extent (0) <<
" != lclNumRows "
78 "= " << lclNumRows <<
". "
79 "Please report this bug to the Tpetra developers.");
83 TEUCHOS_TEST_FOR_EXCEPTION
85 (X.extent (0) != lclNumRows || X.extent (1) != numVecs),
86 std::logic_error, prefix <<
"X is " << X.extent (0) <<
" x " <<
87 X.extent (1) <<
" (constant stride), which differs from the "
88 "local dimensions " << lclNumRows <<
" x " << numVecs <<
". "
89 "Please report this bug to the Tpetra developers.");
90 TEUCHOS_TEST_FOR_EXCEPTION
92 (X.extent (0) != lclNumRows || X.extent (1) < numVecs),
93 std::logic_error, prefix <<
"X is " << X.extent (0) <<
" x " <<
94 X.extent (1) <<
" (NOT constant stride), but the local "
95 "dimensions are " << lclNumRows <<
" x " << numVecs <<
". "
96 "Please report this bug to the Tpetra developers.");
97 TEUCHOS_TEST_FOR_EXCEPTION
99 (Y.extent (0) != lclNumRows || Y.extent (1) != numVecs),
100 std::logic_error, prefix <<
"Y is " << Y.extent (0) <<
" x " <<
101 Y.extent (1) <<
" (constant stride), which differs from the "
102 "local dimensions " << lclNumRows <<
" x " << numVecs <<
". "
103 "Please report this bug to the Tpetra developers.");
104 TEUCHOS_TEST_FOR_EXCEPTION
105 (! constantStrideY &&
106 (Y.extent (0) != lclNumRows || Y.extent (1) < numVecs),
107 std::logic_error, prefix <<
"Y is " << Y.extent (0) <<
" x " <<
108 Y.extent (1) <<
" (NOT constant stride), but the local "
109 "dimensions are " << lclNumRows <<
" x " << numVecs <<
". "
110 "Please report this bug to the Tpetra developers.");
112 #endif // HAVE_TPETRA_DEBUG
114 if (lclNumRows == 0) {
115 const dot_type zero = Kokkos::ArithTraits<dot_type>::zero ();
120 if (constantStrideX && constantStrideY) {
121 if (X.extent (1) == 1) {
122 typename RV::non_const_value_type result =
123 KokkosBlas::dot (subview (X, ALL (), 0), subview (Y, ALL (), 0));
128 KokkosBlas::dot (theDots, X, Y);
137 for (
size_t k = 0; k < numVecs; ++k) {
138 const size_t X_col = constantStrideX ? k : whichVecsX[k];
139 const size_t Y_col = constantStrideY ? k : whichVecsY[k];
140 KokkosBlas::dot (subview (theDots, k), subview (X, ALL (), X_col),
141 subview (Y, ALL (), Y_col));
150 #endif // TPETRA_DETAILS_LCLDOT_HPP
// Cross-reference from the generated documentation:
//   void deep_copy(MultiVector< DS, DL, DG, DN > &dst, const MultiVector< SS, SL, SG, SN > &src)
//   Copy the contents of the MultiVector src into dst.