Panzer — Version of the Day
Navigation: All | Classes | Namespaces | Files | Functions | Variables | Typedefs | Enumerations | Enumerator | Friends | Macros | Groups | Pages
Panzer_Product_impl.hpp
Go to the documentation of this file.
1 // @HEADER
2 // *****************************************************************************
3 // Panzer: A partial differential equation assembly
4 // engine for strongly coupled complex multiphysics systems
5 //
6 // Copyright 2011 NTESS and the Panzer contributors.
7 // SPDX-License-Identifier: BSD-3-Clause
8 // *****************************************************************************
9 // @HEADER
10 
11 #ifndef PANZER_PRODUCT_IMPL_HPP
12 #define PANZER_PRODUCT_IMPL_HPP
13 
14 #include <cstddef>
15 #include <string>
16 #include <vector>
17 
18 namespace panzer {
19 
20 //**********************************************************************
// Constructor of the Product evaluator: registers one evaluated field
// ("product") and one read-only dependent field per entry of "Values Names".
//
// NOTE(review): listing lines 22-23 and 28 are missing from this extraction.
// Per the declaration index on this page, lines 22-23 are presumably the
// qualified name `Product<EvalT, Traits>::Product(`. Line 28 is presumably
// `Teuchos::RCP<std::vector<std::string> > value_names =`, i.e. the
// declaration that the "Values Names" lookup below initializes.
// TODO restore from the upstream header.
//
// Keys read from p:
//   "Product Name" (std::string)                    - name of the evaluated output field.
//   "Values Names" (Teuchos::RCP<std::vector<std::string> >)
//                                                   - names of the fields to multiply together.
//   "Data Layout"  (Teuchos::RCP<PHX::DataLayout>)  - layout used for the output AND every input.
//   "Scaling"      (double, optional)               - constant prefactor. It keeps the member-
//                                                     initializer value 1.0 unless the entry
//                                                     exists with type double.
21 template<typename EvalT, typename Traits>
24  const Teuchos::ParameterList& p)
25  : scaling(1.0)
26 {
27  std::string product_name = p.get<std::string>("Product Name");
29  p.get<Teuchos::RCP<std::vector<std::string> > >("Values Names");
30  Teuchos::RCP<PHX::DataLayout> data_layout =
31  p.get< Teuchos::RCP<PHX::DataLayout> >("Data Layout");
32 
// Only override the default when "Scaling" is present as a double; any other type is ignored.
33  if(p.isType<double>("Scaling"))
34  scaling = p.get<double>("Scaling");
35 
36  product = PHX::MDField<ScalarT>(product_name, data_layout);
37 
38  this->addEvaluatedField(product);
39 
// One read-only input field per name, all sharing data_layout with product.
// The shared layout is what lets evaluateFields multiply them entry-by-entry.
40  values.resize(value_names->size());
41  for (std::size_t i=0; i < value_names->size(); ++i) {
42  values[i] = PHX::MDField<const ScalarT>( (*value_names)[i], data_layout);
43  this->addDependentField(values[i]);
44  }
45 
46  std::string n = "Product Evaluator";
47  this->setName(n);
48 }
49 
50 
// V_MultiplyFunctor<RANK, Scalar>: Kokkos functor that multiplies base_ by
// source_ in place, entry by entry, for fields of rank RANK (handled: 1..7).
//
// NOTE(review): listing lines 52, 54-55 and 57 are missing from this extraction.
// Per the declaration index on this page they are presumably:
//   the struct/class header `V_MultiplyFunctor {`
//   const PHX::MDField<Scalar> base_;          // field that is overwritten
//   const PHX::MDField<const Scalar> source_;  // read-only factor
//   V_MultiplyFunctor(PHX::MDField<Scalar>& base, const PHX::MDField<const Scalar>& source)
// TODO confirm against the upstream header.
51 template<int RANK, typename Scalar>
53 
56 
58  : base_(base),source_(source) {}
59 
// Body for one index ind1 of the first dimension. evaluateFields launches it
// with Kokkos::parallel_for over product.extent(0). It loops over every
// remaining extent of base_ and sets base_(...) = base_(...) * source_(...).
// Loop bounds come from base_ only, so source_ is assumed to have the same
// extents. In this file both fields are built from the same DataLayout.
// NOTE(review): a const operator() writing through the const member base_
// presumably relies on MDField having Kokkos::View-like handle semantics
// (constness of the handle, not of the data). TODO confirm.
60  KOKKOS_INLINE_FUNCTION
61  void operator() (const PHX::index_t & ind1) const
62  {
63  using idx_t = PHX::index_t;
64 
// RANK is a template parameter, so exactly one branch is live per instantiation.
// NOTE(review): with a plain `if`, every branch is still instantiated for every
// RANK. `if constexpr` (C++17) would discard the dead ones, but only if the
// project builds as C++17 or later.
65  if (RANK == 1){
66  base_(ind1) = base_(ind1)*source_(ind1);
67  }
68  else if (RANK == 2){
69  for (idx_t ind2=0; ind2 < static_cast<idx_t>(base_.extent(1)); ind2++)
70  base_(ind1,ind2) = base_(ind1,ind2)*source_(ind1,ind2);
71  }
72  else if (RANK == 3){
73  for (idx_t ind2=0; ind2 < static_cast<idx_t>(base_.extent(1)); ind2++)
74  for (idx_t ind3=0; ind3 < static_cast<idx_t>(base_.extent(2)); ind3++)
75  base_(ind1,ind2,ind3) = base_(ind1,ind2,ind3)*source_(ind1,ind2,ind3);
76  }
77  else if (RANK == 4){
78  for (idx_t ind2=0; ind2 < static_cast<idx_t>(base_.extent(1)); ind2++)
79  for (idx_t ind3=0; ind3 < static_cast<idx_t>(base_.extent(2)); ind3++)
80  for (idx_t ind4=0; ind4 < static_cast<idx_t>(base_.extent(3)); ind4++)
81  base_(ind1,ind2,ind3,ind4) = base_(ind1,ind2,ind3,ind4)*source_(ind1,ind2,ind3,ind4);
82  }
83  else if (RANK == 5){
84  for (idx_t ind2=0; ind2 < static_cast<idx_t>(base_.extent(1)); ind2++)
85  for (idx_t ind3=0; ind3 < static_cast<idx_t>(base_.extent(2)); ind3++)
86  for (idx_t ind4=0; ind4 < static_cast<idx_t>(base_.extent(3)); ind4++)
87  for (idx_t ind5=0; ind5 < static_cast<idx_t>(base_.extent(4)); ind5++)
88  base_(ind1,ind2,ind3,ind4,ind5) = base_(ind1,ind2,ind3,ind4,ind5)*source_(ind1,ind2,ind3,ind4,ind5);
89  }
90  else if (RANK == 6){
91  for (idx_t ind2=0; ind2 < static_cast<idx_t>(base_.extent(1)); ind2++)
92  for (idx_t ind3=0; ind3 < static_cast<idx_t>(base_.extent(2)); ind3++)
93  for (idx_t ind4=0; ind4 < static_cast<idx_t>(base_.extent(3)); ind4++)
94  for (idx_t ind5=0; ind5 < static_cast<idx_t>(base_.extent(4)); ind5++)
95  for (idx_t ind6=0; ind6 < static_cast<idx_t>(base_.extent(5)); ind6++)
96  base_(ind1,ind2,ind3,ind4,ind5,ind6) = base_(ind1,ind2,ind3,ind4,ind5,ind6)*source_(ind1,ind2,ind3,ind4,ind5,ind6);
97  }
98  else if (RANK == 7){
99  for (idx_t ind2=0; ind2 < static_cast<idx_t>(base_.extent(1)); ind2++)
100  for (idx_t ind3=0; ind3 < static_cast<idx_t>(base_.extent(2)); ind3++)
101  for (idx_t ind4=0; ind4 < static_cast<idx_t>(base_.extent(3)); ind4++)
102  for (idx_t ind5=0; ind5 < static_cast<idx_t>(base_.extent(4)); ind5++)
103  for (idx_t ind6=0; ind6 < static_cast<idx_t>(base_.extent(5)); ind6++)
104  for (idx_t ind7=0; ind7 < static_cast<idx_t>(base_.extent(6)); ind7++)
105  base_(ind1,ind2,ind3,ind4,ind5,ind6,ind7) = base_(ind1,ind2,ind3,ind4,ind5,ind6,ind7)*source_(ind1,ind2,ind3,ind4,ind5,ind6,ind7);
106  }
107  }
108 };
109 
110 //**********************************************************************
// evaluateFields: computes, entry by entry,
//   product = scaling * values[0] * values[1] * ... * values[N-1].
// The workset argument is unused. The kernels cover product.extent(0) in full,
// and the functor loops over every remaining extent.
//
// NOTE(review): listing lines 113-114 are missing from this extraction.
// Per the declaration index on this page they are presumably
// `Product<EvalT, Traits>::` and `evaluateFields(`. TODO restore.
111 template<typename EvalT, typename Traits>
112 void
115  typename Traits::EvalData /* workset */)
116 {
// Start from the constant prefactor. With an empty values list the loop below
// never runs and product is left equal to scaling everywhere.
117  product.deep_copy(ScalarT(scaling));
118 
// Multiply one input in at a time, each as a separate parallel_for over [0, length).
119  for (std::size_t i = 0; i < values.size(); ++i) {
// NOTE(review): length does not depend on i and could be hoisted out of the loop.
120  const auto length = product.extent(0);
// Map the runtime rank of product onto the matching compile-time RANK instantiation.
// NOTE(review): a rank outside 1..7 matches no branch. The factor is then
// silently skipped and product stays equal to scaling with no diagnostic.
// Consider adding a final else that throws.
121  if (product.rank() == 1){
122  Kokkos::parallel_for( length, V_MultiplyFunctor<1,ScalarT>(product,values[i]) );
123  }
124  else if (product.rank() == 2){
125  Kokkos::parallel_for( length, V_MultiplyFunctor<2,ScalarT>(product,values[i]) );
126  }
127  else if (product.rank() == 3){
128  Kokkos::parallel_for( length, V_MultiplyFunctor<3,ScalarT>(product,values[i]) );
129  }
130  else if (product.rank() == 4){
131  Kokkos::parallel_for( length, V_MultiplyFunctor<4,ScalarT>(product,values[i]) );
132  }
133  else if (product.rank() == 5){
134  Kokkos::parallel_for( length, V_MultiplyFunctor<5,ScalarT>(product,values[i]) );
135  }
136  else if (product.rank() == 6){
137  Kokkos::parallel_for( length, V_MultiplyFunctor<6,ScalarT>(product,values[i]) );
138  }
139  else if (product.rank() == 7){
140  Kokkos::parallel_for( length, V_MultiplyFunctor<7,ScalarT>(product,values[i]) );
141  }
142  }
143 }
144 
145 //**********************************************************************
146 
147 }
148 
149 #endif
KOKKOS_INLINE_FUNCTION void operator()(const PHX::index_t &ind1) const
T & get(const std::string &name, T def_value)
const PHX::MDField< Scalar > base_
Product(const Teuchos::ParameterList &p)
V_MultiplyFunctor(PHX::MDField< Scalar > &base, const PHX::MDField< const Scalar > &source)
const PHX::MDField< const Scalar > source_
PHX::MDField< ScalarT > product
bool isType(const std::string &name) const
std::vector< PHX::MDField< const ScalarT > > values
typename EvalT::ScalarT ScalarT
void evaluateFields(typename Traits::EvalData d)