Tpetra parallel linear algebra  Version of the Day
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
Tpetra_Details_lclDot.hpp
Go to the documentation of this file.
1 // @HEADER
2 // *****************************************************************************
3 // Tpetra: Templated Linear Algebra Services Package
4 //
5 // Copyright 2008 NTESS and the Tpetra contributors.
6 // SPDX-License-Identifier: BSD-3-Clause
7 // *****************************************************************************
8 // @HEADER
9 
10 #ifndef TPETRA_DETAILS_LCLDOT_HPP
11 #define TPETRA_DETAILS_LCLDOT_HPP
12 
19 
20 #include "Kokkos_DualView.hpp"
21 #if KOKKOS_VERSION >= 40799
22 #include "KokkosKernels_ArithTraits.hpp"
23 #else
24 #include "Kokkos_ArithTraits.hpp"
25 #endif
26 #include "KokkosBlas1_dot.hpp"
27 #include "Teuchos_ArrayView.hpp"
28 #include "Teuchos_TestForException.hpp"
29 
30 namespace Tpetra {
31 namespace Details {
32 
33 template <class RV, class XMV>
34 void lclDot(const RV& dotsOut,
35  const XMV& X_lcl,
36  const XMV& Y_lcl,
37  const size_t lclNumRows,
38  const size_t numVecs,
39  const size_t whichVecsX[],
40  const size_t whichVecsY[],
41  const bool constantStrideX,
42  const bool constantStrideY) {
43  using Kokkos::ALL;
44  using Kokkos::subview;
45  typedef typename RV::non_const_value_type dot_type;
46 #ifdef HAVE_TPETRA_DEBUG
47  const char prefix[] = "Tpetra::MultiVector::lclDotImpl: ";
48 #endif // HAVE_TPETRA_DEBUG
49 
50  static_assert(Kokkos::is_view<RV>::value,
51  "Tpetra::MultiVector::lclDotImpl: "
52  "The first argument dotsOut is not a Kokkos::View.");
53  static_assert(RV::rank == 1,
54  "Tpetra::MultiVector::lclDotImpl: "
55  "The first argument dotsOut must have rank 1.");
56  static_assert(Kokkos::is_view<XMV>::value,
57  "Tpetra::MultiVector::lclDotImpl: The type of the 2nd and "
58  "3rd arguments (X_lcl and Y_lcl) is not a Kokkos::View.");
59  static_assert(XMV::rank == 2,
60  "Tpetra::MultiVector::lclDotImpl: "
61  "X_lcl and Y_lcl must have rank 2.");
62 
63  // In case the input dimensions don't match, make sure that we
64  // don't overwrite memory that doesn't belong to us, by using
65  // subset views with the minimum dimensions over all input.
66  const std::pair<size_t, size_t> rowRng(0, lclNumRows);
67  const std::pair<size_t, size_t> colRng(0, numVecs);
68  RV theDots = subview(dotsOut, colRng);
69  XMV X = subview(X_lcl, rowRng, ALL());
70  XMV Y = subview(Y_lcl, rowRng, ALL());
71 
72 #ifdef HAVE_TPETRA_DEBUG
73  if (lclNumRows != 0) {
74  TEUCHOS_TEST_FOR_EXCEPTION(X.extent(0) != lclNumRows, std::logic_error, prefix << "X.extent(0) = " << X.extent(0) << " != lclNumRows "
75  "= "
76  << lclNumRows << ". "
77  "Please report this bug to the Tpetra developers.");
78  TEUCHOS_TEST_FOR_EXCEPTION(Y.extent(0) != lclNumRows, std::logic_error, prefix << "Y.extent(0) = " << Y.extent(0) << " != lclNumRows "
79  "= "
80  << lclNumRows << ". "
81  "Please report this bug to the Tpetra developers.");
82  // If a MultiVector is constant stride, then numVecs should
83  // equal its View's number of columns. Otherwise, numVecs
84  // should be less than its View's number of columns.
85  TEUCHOS_TEST_FOR_EXCEPTION(constantStrideX &&
86  (X.extent(0) != lclNumRows || X.extent(1) != numVecs),
87  std::logic_error, prefix << "X is " << X.extent(0) << " x " << X.extent(1) << " (constant stride), which differs from the "
88  "local dimensions "
89  << lclNumRows << " x " << numVecs << ". "
90  "Please report this bug to the Tpetra developers.");
91  TEUCHOS_TEST_FOR_EXCEPTION(!constantStrideX &&
92  (X.extent(0) != lclNumRows || X.extent(1) < numVecs),
93  std::logic_error, prefix << "X is " << X.extent(0) << " x " << X.extent(1) << " (NOT constant stride), but the local "
94  "dimensions are "
95  << lclNumRows << " x " << numVecs << ". "
96  "Please report this bug to the Tpetra developers.");
97  TEUCHOS_TEST_FOR_EXCEPTION(constantStrideY &&
98  (Y.extent(0) != lclNumRows || Y.extent(1) != numVecs),
99  std::logic_error, prefix << "Y is " << Y.extent(0) << " x " << Y.extent(1) << " (constant stride), which differs from the "
100  "local dimensions "
101  << lclNumRows << " x " << numVecs << ". "
102  "Please report this bug to the Tpetra developers.");
103  TEUCHOS_TEST_FOR_EXCEPTION(!constantStrideY &&
104  (Y.extent(0) != lclNumRows || Y.extent(1) < numVecs),
105  std::logic_error, prefix << "Y is " << Y.extent(0) << " x " << Y.extent(1) << " (NOT constant stride), but the local "
106  "dimensions are "
107  << lclNumRows << " x " << numVecs << ". "
108  "Please report this bug to the Tpetra developers.");
109  }
110 #endif // HAVE_TPETRA_DEBUG
111 
112  if (lclNumRows == 0) {
113 #if KOKKOS_VERSION >= 40799
114  const dot_type zero = KokkosKernels::ArithTraits<dot_type>::zero();
115 #else
116  const dot_type zero = Kokkos::ArithTraits<dot_type>::zero();
117 #endif
118  // DEEP_COPY REVIEW - NOT TESTED
119  Kokkos::deep_copy(theDots, zero);
120  } else { // lclNumRows != 0
121  if (constantStrideX && constantStrideY) {
122  if (X.extent(1) == 1) {
123  typename RV::non_const_value_type result =
124  KokkosBlas::dot(subview(X, ALL(), 0), subview(Y, ALL(), 0));
125  // DEEP_COPY REVIEW - NOT TESTED
126  Kokkos::deep_copy(theDots, result);
127  } else {
128  KokkosBlas::dot(theDots, X, Y);
129  }
130  } else { // not constant stride
131  // NOTE (mfh 15 Jul 2014) This does a kernel launch for
132  // every column. It might be better to have a kernel that
133  // does the work all at once. On the other hand, we don't
134  // prioritize performance of MultiVector views of
135  // noncontiguous columns.
136  for (size_t k = 0; k < numVecs; ++k) {
137  const size_t X_col = constantStrideX ? k : whichVecsX[k];
138  const size_t Y_col = constantStrideY ? k : whichVecsY[k];
139  KokkosBlas::dot(subview(theDots, k), subview(X, ALL(), X_col),
140  subview(Y, ALL(), Y_col));
141  } // for each column
142  } // constantStride
143  } // lclNumRows != 0
144 }
145 
146 } // namespace Details
147 } // namespace Tpetra
148 
149 #endif // TPETRA_DETAILS_LCLDOT_HPP
void deep_copy(MultiVector< DS, DL, DG, DN > &dst, const MultiVector< SS, SL, SG, SN > &src)
Copy the contents of the MultiVector src into dst.