49 #ifndef Intrepid2_TensorViewIterator_h
50 #define Intrepid2_TensorViewIterator_h
55 #include "Kokkos_Vector.hpp"
72 template<
class TensorViewType,
class ViewType1,
class ViewType2 ,
typename ScalarType>
76 enum RankCombinationType
88 Kokkos::vector<RankCombinationType> rank_combination_types_;
109 KOKKOS_INLINE_FUNCTION
111 Kokkos::vector<RankCombinationType> rank_combination_types)
113 tensor_view_iterator_(tensor_view),
114 view1_iterator_(view1),
115 view2_iterator_(view2),
116 rank_combination_types_(rank_combination_types)
134 unsigned max_component_rank = (view1.rank() > view2.rank()) ? view1.rank() : view2.rank();
135 unsigned max_rank = (tensor_view.rank() > max_component_rank) ? tensor_view.rank() : max_component_rank;
137 INTREPID2_TEST_FOR_EXCEPTION_DEVICE_SAFE(rank_combination_types.extent(0) != max_rank, std::invalid_argument,
"need to provide RankCombinationType for the largest-rank View");
139 unsigned expected_rank = 0;
140 bool contracting =
false;
141 for (
unsigned d=0; d<rank_combination_types.extent(0); d++)
143 if (rank_combination_types[d] == TENSOR_CONTRACTION)
146 INTREPID2_TEST_FOR_EXCEPTION_DEVICE_SAFE(view1.extent_int(d) != view2.extent_int(d), std::invalid_argument,
"Contractions can only occur along ranks of equal length");
151 INTREPID2_TEST_FOR_EXCEPTION_DEVICE_SAFE(contracting, std::invalid_argument,
"encountered a non-contraction rank combination after a contraction; contractions can only go at the end");
153 if (rank_combination_types[d] == TENSOR_PRODUCT)
155 INTREPID2_TEST_FOR_EXCEPTION_DEVICE_SAFE(tensor_view.extent_int(d) != view1.extent_int(d) * view2.extent_int(d), std::invalid_argument,
"For TENSOR_PRODUCT rank combination, the tensor View must have length in that dimension equal to the product of the two component views in that dimension");
159 INTREPID2_TEST_FOR_EXCEPTION_DEVICE_SAFE(view1.extent_int(d) != view2.extent_int(d), std::invalid_argument,
"For DIMENSION_MATCH rank combination, all three views must have length equal to each other in that rank");
160 INTREPID2_TEST_FOR_EXCEPTION_DEVICE_SAFE(tensor_view.extent_int(d) != view1.extent_int(d), std::invalid_argument,
"For DIMENSION_MATCH rank combination, all three views must have length equal to each other in that rank");
164 INTREPID2_TEST_FOR_EXCEPTION_DEVICE_SAFE(expected_rank != tensor_view.rank(), std::invalid_argument,
"Tensor view does not match expected rank");
169 KOKKOS_INLINE_FUNCTION
174 if (view2_next_increment_rank > view1_next_increment_rank)
return view2_next_increment_rank;
175 else return view1_next_increment_rank;
180 KOKKOS_INLINE_FUNCTION
195 if (view2_next_increment_rank > view1_next_increment_rank)
198 device_assert(rank_combination_types_[view2_next_increment_rank]==TENSOR_PRODUCT);
199 view1_iterator_.
reset(view2_next_increment_rank);
202 return view2_next_increment_rank;
206 int view1_rank_change = view1_iterator_.
increment();
207 if (view1_rank_change >= 0)
209 switch (rank_combination_types_[view1_rank_change])
211 case DIMENSION_MATCH:
219 case TENSOR_CONTRACTION:
224 return view1_rank_change;
230 KOKKOS_INLINE_FUNCTION
241 KOKKOS_INLINE_FUNCTION
242 void setLocation(Kokkos::Array<int,7> location1, Kokkos::Array<int,7> location2)
246 Kokkos::Array<int,7> tensor_location = location1;
247 for (
unsigned d=0; d<rank_combination_types_.extent(0); d++)
249 switch (rank_combination_types_[d])
253 tensor_location[d] = location2[d] * view1_iterator_.
getExtent(d) + location1[d];
255 case DIMENSION_MATCH:
258 case TENSOR_CONTRACTION:
259 tensor_location[d] = 0;
263 #ifdef HAVE_INTREPID2_DEBUG
265 for (
unsigned d=0; d<rank_combination_types_.extent(0); d++)
267 switch (rank_combination_types_[d])
272 case DIMENSION_MATCH:
273 case TENSOR_CONTRACTION:
274 device_assert(location1[d] == location2[d]);
278 device_assert(location1[d] < view1_iterator_.
getExtent(d));
279 device_assert(location2[d] < view2_iterator_.
getExtent(d));
280 device_assert(tensor_location[d] < tensor_view_iterator_.
getExtent(d));
283 tensor_view_iterator_.
setLocation(tensor_location);
288 KOKKOS_INLINE_FUNCTION
291 return view1_iterator_.
get();
296 KOKKOS_INLINE_FUNCTION
299 return view2_iterator_.
get();
304 KOKKOS_INLINE_FUNCTION
305 void set(ScalarType value)
307 tensor_view_iterator_.
set(value);
KOKKOS_INLINE_FUNCTION ScalarType getView1Entry()
KOKKOS_INLINE_FUNCTION int nextIncrementRank()
KOKKOS_INLINE_FUNCTION int getExtent(int dimension)
KOKKOS_INLINE_FUNCTION TensorViewIterator(TensorViewType tensor_view, ViewType1 view1, ViewType2 view2, Kokkos::vector< RankCombinationType > rank_combination_types)
Constructor.
KOKKOS_INLINE_FUNCTION ScalarType get()
KOKKOS_INLINE_FUNCTION void setLocation(Kokkos::Array< int, 7 > location1, Kokkos::Array< int, 7 > location2)
Implementation of an assert that can safely be called from device code.
KOKKOS_INLINE_FUNCTION void reset(int from_rank_number=0)
KOKKOS_INLINE_FUNCTION void set(ScalarType &value)
Iterator allows linear traversal of (part of) a Kokkos View in a manner that is agnostic to its rank...
KOKKOS_INLINE_FUNCTION void set(ScalarType value)
KOKKOS_INLINE_FUNCTION int increment()
KOKKOS_INLINE_FUNCTION void setLocation(const Kokkos::Array< int, 7 > location)
KOKKOS_INLINE_FUNCTION ScalarType getView2Entry()
KOKKOS_INLINE_FUNCTION int increment()
KOKKOS_INLINE_FUNCTION void setLocation(const Kokkos::Array< int, 7 > &location)
KOKKOS_INLINE_FUNCTION int nextIncrementRank()
A helper class that allows iteration over three Kokkos Views simultaneously, according to tensor comb...