Tpetra parallel linear algebra  Version of the Day
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
Tpetra_Details_lclDot.hpp
Go to the documentation of this file.
1 // @HEADER
2 // *****************************************************************************
3 // Tpetra: Templated Linear Algebra Services Package
4 //
5 // Copyright 2008 NTESS and the Tpetra contributors.
6 // SPDX-License-Identifier: BSD-3-Clause
7 // *****************************************************************************
8 // @HEADER
9 
10 #ifndef TPETRA_DETAILS_LCLDOT_HPP
11 #define TPETRA_DETAILS_LCLDOT_HPP
12 
19 
20 #include "Kokkos_DualView.hpp"
21 #include "Kokkos_ArithTraits.hpp"
22 #include "KokkosBlas1_dot.hpp"
23 #include "Teuchos_ArrayView.hpp"
24 #include "Teuchos_TestForException.hpp"
25 
26 namespace Tpetra {
27 namespace Details {
28 
29 template<class RV, class XMV>
30 void
31 lclDot (const RV& dotsOut,
32  const XMV& X_lcl,
33  const XMV& Y_lcl,
34  const size_t lclNumRows,
35  const size_t numVecs,
36  const size_t whichVecsX[],
37  const size_t whichVecsY[],
38  const bool constantStrideX,
39  const bool constantStrideY)
40 {
41  using Kokkos::ALL;
42  using Kokkos::subview;
43  typedef typename RV::non_const_value_type dot_type;
44 #ifdef HAVE_TPETRA_DEBUG
45  const char prefix[] = "Tpetra::MultiVector::lclDotImpl: ";
46 #endif // HAVE_TPETRA_DEBUG
47 
48  static_assert (Kokkos::is_view<RV>::value,
49  "Tpetra::MultiVector::lclDotImpl: "
50  "The first argument dotsOut is not a Kokkos::View.");
51  static_assert (RV::rank == 1, "Tpetra::MultiVector::lclDotImpl: "
52  "The first argument dotsOut must have rank 1.");
53  static_assert (Kokkos::is_view<XMV>::value,
54  "Tpetra::MultiVector::lclDotImpl: The type of the 2nd and "
55  "3rd arguments (X_lcl and Y_lcl) is not a Kokkos::View.");
56  static_assert (XMV::rank == 2, "Tpetra::MultiVector::lclDotImpl: "
57  "X_lcl and Y_lcl must have rank 2.");
58 
59  // In case the input dimensions don't match, make sure that we
60  // don't overwrite memory that doesn't belong to us, by using
61  // subset views with the minimum dimensions over all input.
62  const std::pair<size_t, size_t> rowRng (0, lclNumRows);
63  const std::pair<size_t, size_t> colRng (0, numVecs);
64  RV theDots = subview (dotsOut, colRng);
65  XMV X = subview (X_lcl, rowRng, ALL ());
66  XMV Y = subview (Y_lcl, rowRng, ALL ());
67 
68 #ifdef HAVE_TPETRA_DEBUG
69  if (lclNumRows != 0) {
70  TEUCHOS_TEST_FOR_EXCEPTION
71  (X.extent (0) != lclNumRows, std::logic_error, prefix <<
72  "X.extent(0) = " << X.extent (0) << " != lclNumRows "
73  "= " << lclNumRows << ". "
74  "Please report this bug to the Tpetra developers.");
75  TEUCHOS_TEST_FOR_EXCEPTION
76  (Y.extent (0) != lclNumRows, std::logic_error, prefix <<
77  "Y.extent(0) = " << Y.extent (0) << " != lclNumRows "
78  "= " << lclNumRows << ". "
79  "Please report this bug to the Tpetra developers.");
80  // If a MultiVector is constant stride, then numVecs should
81  // equal its View's number of columns. Otherwise, numVecs
82  // should be less than its View's number of columns.
83  TEUCHOS_TEST_FOR_EXCEPTION
84  (constantStrideX &&
85  (X.extent (0) != lclNumRows || X.extent (1) != numVecs),
86  std::logic_error, prefix << "X is " << X.extent (0) << " x " <<
87  X.extent (1) << " (constant stride), which differs from the "
88  "local dimensions " << lclNumRows << " x " << numVecs << ". "
89  "Please report this bug to the Tpetra developers.");
90  TEUCHOS_TEST_FOR_EXCEPTION
91  (! constantStrideX &&
92  (X.extent (0) != lclNumRows || X.extent (1) < numVecs),
93  std::logic_error, prefix << "X is " << X.extent (0) << " x " <<
94  X.extent (1) << " (NOT constant stride), but the local "
95  "dimensions are " << lclNumRows << " x " << numVecs << ". "
96  "Please report this bug to the Tpetra developers.");
97  TEUCHOS_TEST_FOR_EXCEPTION
98  (constantStrideY &&
99  (Y.extent (0) != lclNumRows || Y.extent (1) != numVecs),
100  std::logic_error, prefix << "Y is " << Y.extent (0) << " x " <<
101  Y.extent (1) << " (constant stride), which differs from the "
102  "local dimensions " << lclNumRows << " x " << numVecs << ". "
103  "Please report this bug to the Tpetra developers.");
104  TEUCHOS_TEST_FOR_EXCEPTION
105  (! constantStrideY &&
106  (Y.extent (0) != lclNumRows || Y.extent (1) < numVecs),
107  std::logic_error, prefix << "Y is " << Y.extent (0) << " x " <<
108  Y.extent (1) << " (NOT constant stride), but the local "
109  "dimensions are " << lclNumRows << " x " << numVecs << ". "
110  "Please report this bug to the Tpetra developers.");
111  }
112 #endif // HAVE_TPETRA_DEBUG
113 
114  if (lclNumRows == 0) {
115  const dot_type zero = Kokkos::ArithTraits<dot_type>::zero ();
116  // DEEP_COPY REVIEW - NOT TESTED
117  Kokkos::deep_copy (theDots, zero);
118  }
119  else { // lclNumRows != 0
120  if (constantStrideX && constantStrideY) {
121  if (X.extent (1) == 1) {
122  typename RV::non_const_value_type result =
123  KokkosBlas::dot (subview (X, ALL (), 0), subview (Y, ALL (), 0));
124  // DEEP_COPY REVIEW - NOT TESTED
125  Kokkos::deep_copy (theDots, result);
126  }
127  else {
128  KokkosBlas::dot (theDots, X, Y);
129  }
130  }
131  else { // not constant stride
132  // NOTE (mfh 15 Jul 2014) This does a kernel launch for
133  // every column. It might be better to have a kernel that
134  // does the work all at once. On the other hand, we don't
135  // prioritize performance of MultiVector views of
136  // noncontiguous columns.
137  for (size_t k = 0; k < numVecs; ++k) {
138  const size_t X_col = constantStrideX ? k : whichVecsX[k];
139  const size_t Y_col = constantStrideY ? k : whichVecsY[k];
140  KokkosBlas::dot (subview (theDots, k), subview (X, ALL (), X_col),
141  subview (Y, ALL (), Y_col));
142  } // for each column
143  } // constantStride
144  } // lclNumRows != 0
145 }
146 
147 } // namespace Details
148 } // namespace Tpetra
149 
150 #endif // TPETRA_DETAILS_LCLDOT_HPP
void deep_copy(MultiVector< DS, DL, DG, DN > &dst, const MultiVector< SS, SL, SG, SN > &src)
Copy the contents of the MultiVector src into dst.