Intrepid2
Intrepid2_DirectSumBasis.hpp
Go to the documentation of this file.
1 // @HEADER
2 // *****************************************************************************
3 // Intrepid2 Package
4 //
5 // Copyright 2007 NTESS and the Intrepid2 contributors.
6 // SPDX-License-Identifier: BSD-3-Clause
7 // *****************************************************************************
8 // @HEADER
9 
15 #ifndef Intrepid2_DirectSumBasis_h
16 #define Intrepid2_DirectSumBasis_h
17 
18 #include <Kokkos_DynRankView.hpp>
19 
20 namespace Intrepid2
21 {
32  template<typename BasisBaseClass>
33  class Basis_DirectSumBasis : public BasisBaseClass
34  {
35  public:
36  using BasisBase = BasisBaseClass;
37  using BasisPtr = Teuchos::RCP<BasisBase>;
38 
39  using DeviceType = typename BasisBase::DeviceType;
40  using ExecutionSpace = typename BasisBase::ExecutionSpace;
41  using OutputValueType = typename BasisBase::OutputValueType;
42  using PointValueType = typename BasisBase::PointValueType;
43 
44  using OrdinalTypeArray1DHost = typename BasisBase::OrdinalTypeArray1DHost;
45  using OrdinalTypeArray2DHost = typename BasisBase::OrdinalTypeArray2DHost;
46  using OutputViewType = typename BasisBase::OutputViewType;
47  using PointViewType = typename BasisBase::PointViewType;
48  using ScalarViewType = typename BasisBase::ScalarViewType;
49  protected:
50  BasisPtr basis1_;
51  BasisPtr basis2_;
52 
53  std::string name_;
54  public:
59  Basis_DirectSumBasis(BasisPtr basis1, BasisPtr basis2)
60  :
61  basis1_(basis1),basis2_(basis2)
62  {
63  INTREPID2_TEST_FOR_EXCEPTION(basis1->getBasisType() != basis2->getBasisType(), std::invalid_argument, "basis1 and basis2 must agree in basis type");
64  INTREPID2_TEST_FOR_EXCEPTION(basis1->getBaseCellTopology().getKey() != basis2->getBaseCellTopology().getKey(),
65  std::invalid_argument, "basis1 and basis2 must agree in cell topology");
66  INTREPID2_TEST_FOR_EXCEPTION(basis1->getNumTensorialExtrusions() != basis2->getNumTensorialExtrusions(),
67  std::invalid_argument, "basis1 and basis2 must agree in number of tensorial extrusions");
68  INTREPID2_TEST_FOR_EXCEPTION(basis1->getCoordinateSystem() != basis2->getCoordinateSystem(),
69  std::invalid_argument, "basis1 and basis2 must agree in coordinate system");
70 
71  this->basisCardinality_ = basis1->getCardinality() + basis2->getCardinality();
72  this->basisDegree_ = std::max(basis1->getDegree(), basis2->getDegree());
73 
74  {
75  std::ostringstream basisName;
76  basisName << basis1->getName() << " + " << basis2->getName();
77  name_ = basisName.str();
78  }
79 
80  this->basisCellTopologyKey_ = basis1->getBaseCellTopology().getKey();
81  this->basisType_ = basis1->getBasisType();
82  this->basisCoordinates_ = basis1->getCoordinateSystem();
83 
84  if (this->basisType_ == BASIS_FEM_HIERARCHICAL)
85  {
86  int degreeLength = basis1_->getPolynomialDegreeLength();
87  INTREPID2_TEST_FOR_EXCEPTION(degreeLength != basis2_->getPolynomialDegreeLength(), std::invalid_argument, "Basis1 and Basis2 must agree on polynomial degree length");
88 
89  this->fieldOrdinalPolynomialDegree_ = OrdinalTypeArray2DHost("DirectSumBasis degree lookup", this->basisCardinality_,degreeLength);
90  this->fieldOrdinalH1PolynomialDegree_ = OrdinalTypeArray2DHost("DirectSumBasis H^1 degree lookup",this->basisCardinality_,degreeLength);
91  // our field ordinals start with basis1_; basis2_ follows
92  for (int fieldOrdinal1=0; fieldOrdinal1<basis1_->getCardinality(); fieldOrdinal1++)
93  {
94  int fieldOrdinal = fieldOrdinal1;
95  auto polynomialDegree = basis1->getPolynomialDegreeOfField(fieldOrdinal1);
96  auto polynomialH1Degree = basis1->getH1PolynomialDegreeOfField(fieldOrdinal1);
97  for (int d=0; d<degreeLength; d++)
98  {
99  this->fieldOrdinalPolynomialDegree_ (fieldOrdinal,d) = polynomialDegree(d);
100  this->fieldOrdinalH1PolynomialDegree_(fieldOrdinal,d) = polynomialH1Degree(d);
101  }
102  }
103  for (int fieldOrdinal2=0; fieldOrdinal2<basis2_->getCardinality(); fieldOrdinal2++)
104  {
105  int fieldOrdinal = basis1->getCardinality() + fieldOrdinal2;
106 
107  auto polynomialDegree = basis2->getPolynomialDegreeOfField(fieldOrdinal2);
108  auto polynomialH1Degree = basis2->getH1PolynomialDegreeOfField(fieldOrdinal2);
109  for (int d=0; d<degreeLength; d++)
110  {
111  this->fieldOrdinalPolynomialDegree_ (fieldOrdinal,d) = polynomialDegree(d);
112  this->fieldOrdinalH1PolynomialDegree_(fieldOrdinal,d) = polynomialH1Degree(d);
113  }
114  }
115  }
116 
117  // initialize tags
118  {
119  const auto & cardinality = this->basisCardinality_;
120 
121  // Basis-dependent initializations
122  const ordinal_type tagSize = 4; // size of DoF tag, i.e., number of fields in the tag
123  const ordinal_type posScDim = 0; // position in the tag, counting from 0, of the subcell dim
124  const ordinal_type posScOrd = 1; // position in the tag, counting from 0, of the subcell ordinal
125  const ordinal_type posDfOrd = 2; // position in the tag, counting from 0, of DoF ordinal relative to the subcell
126 
127  OrdinalTypeArray1DHost tagView("tag view", cardinality*tagSize);
128 
129  shards::CellTopology cellTopo(getCellTopologyData(this->basisCellTopologyKey_));
130 
131  unsigned spaceDim = cellTopo.getDimension();
132 
133  ordinal_type basis2Offset = basis1_->getCardinality();
134 
135  for (unsigned d=0; d<=spaceDim; d++)
136  {
137  unsigned subcellCount = cellTopo.getSubcellCount(d);
138  for (unsigned subcellOrdinal=0; subcellOrdinal<subcellCount; subcellOrdinal++)
139  {
140  ordinal_type subcellDofCount1 = basis1->getDofCount(d, subcellOrdinal);
141  ordinal_type subcellDofCount2 = basis2->getDofCount(d, subcellOrdinal);
142 
143  ordinal_type subcellDofCount = subcellDofCount1 + subcellDofCount2;
144  for (ordinal_type localDofID=0; localDofID<subcellDofCount; localDofID++)
145  {
146  ordinal_type fieldOrdinal;
147  if (localDofID < subcellDofCount1)
148  {
149  // first basis: field ordinal matches the basis1 ordinal
150  fieldOrdinal = basis1_->getDofOrdinal(d, subcellOrdinal, localDofID);
151  }
152  else
153  {
154  // second basis: field ordinal is offset by basis1 cardinality
155  fieldOrdinal = basis2Offset + basis2_->getDofOrdinal(d, subcellOrdinal, localDofID - subcellDofCount1);
156  }
157  tagView(fieldOrdinal*tagSize+0) = d; // subcell dimension
158  tagView(fieldOrdinal*tagSize+1) = subcellOrdinal;
159  tagView(fieldOrdinal*tagSize+2) = localDofID;
160  tagView(fieldOrdinal*tagSize+3) = subcellDofCount;
161  }
162  }
163  }
164  // // Basis-independent function sets tag and enum data in tagToOrdinal_ and ordinalToTag_ arrays:
165  // // tags are constructed on host
166  this->setOrdinalTagData(this->tagToOrdinal_,
167  this->ordinalToTag_,
168  tagView,
169  this->basisCardinality_,
170  tagSize,
171  posScDim,
172  posScOrd,
173  posDfOrd);
174  }
175  }
176 
182  virtual BasisValues<OutputValueType,DeviceType> allocateBasisValues( TensorPoints<PointValueType,DeviceType> points, const EOperator operatorType = OPERATOR_VALUE) const override
183  {
184  BasisValues<OutputValueType,DeviceType> basisValues1 = basis1_->allocateBasisValues(points, operatorType);
185  BasisValues<OutputValueType,DeviceType> basisValues2 = basis2_->allocateBasisValues(points, operatorType);
186 
187  const int numScalarFamilies1 = basisValues1.numTensorDataFamilies();
188  if (numScalarFamilies1 > 0)
189  {
190  // then both basis1 and basis2 should be scalar-valued; check that for basis2:
191  const int numScalarFamilies2 = basisValues2.numTensorDataFamilies();
192  INTREPID2_TEST_FOR_EXCEPTION(basisValues2.numTensorDataFamilies() <=0, std::invalid_argument, "When basis1 has scalar value, basis2 must also");
193  std::vector< TensorData<OutputValueType,DeviceType> > scalarFamilies(numScalarFamilies1 + numScalarFamilies2);
194  for (int i=0; i<numScalarFamilies1; i++)
195  {
196  scalarFamilies[i] = basisValues1.tensorData(i);
197  }
198  for (int i=0; i<numScalarFamilies2; i++)
199  {
200  scalarFamilies[i+numScalarFamilies1] = basisValues2.tensorData(i);
201  }
202  return BasisValues<OutputValueType,DeviceType>(scalarFamilies);
203  }
204  else
205  {
206  // then both basis1 and basis2 should be vector-valued; check that:
207  INTREPID2_TEST_FOR_EXCEPTION(!basisValues1.vectorData().isValid(), std::invalid_argument, "When basis1 does not have tensorData() defined, it must have a valid vectorData()");
208  INTREPID2_TEST_FOR_EXCEPTION(basisValues2.numTensorDataFamilies() > 0, std::invalid_argument, "When basis1 has vector value, basis2 must also");
209 
210  const auto & vectorData1 = basisValues1.vectorData();
211  const auto & vectorData2 = basisValues2.vectorData();
212 
213  const int numFamilies1 = vectorData1.numFamilies();
214  const int numComponents = vectorData1.numComponents();
215  INTREPID2_TEST_FOR_EXCEPTION(numComponents != vectorData2.numComponents(), std::invalid_argument, "basis1 and basis2 must agree on the number of components in each vector");
216  const int numFamilies2 = vectorData2.numFamilies();
217 
218  const int numFamilies = numFamilies1 + numFamilies2;
219  std::vector< std::vector<TensorData<OutputValueType,DeviceType> > > vectorComponents(numFamilies, std::vector<TensorData<OutputValueType,DeviceType> >(numComponents));
220 
221  for (int i=0; i<numFamilies1; i++)
222  {
223  for (int j=0; j<numComponents; j++)
224  {
225  vectorComponents[i][j] = vectorData1.getComponent(i,j);
226  }
227  }
228  for (int i=0; i<numFamilies2; i++)
229  {
230  for (int j=0; j<numComponents; j++)
231  {
232  vectorComponents[i+numFamilies1][j] = vectorData2.getComponent(i,j);
233  }
234  }
235  VectorData<OutputValueType,DeviceType> vectorData(vectorComponents);
236  return BasisValues<OutputValueType,DeviceType>(vectorData);
237  }
238  }
239 
248  virtual void getDofCoords( ScalarViewType dofCoords ) const override {
249  const int basisCardinality1 = basis1_->getCardinality();
250  const int basisCardinality2 = basis2_->getCardinality();
251  const int basisCardinality = basisCardinality1 + basisCardinality2;
252 
253  auto dofCoords1 = Kokkos::subview(dofCoords, std::make_pair(0,basisCardinality1), Kokkos::ALL());
254  auto dofCoords2 = Kokkos::subview(dofCoords, std::make_pair(basisCardinality1,basisCardinality), Kokkos::ALL());
255 
256  basis1_->getDofCoords(dofCoords1);
257  basis2_->getDofCoords(dofCoords2);
258  }
259 
271  virtual void getDofCoeffs( ScalarViewType dofCoeffs ) const override {
272  const int basisCardinality1 = basis1_->getCardinality();
273  const int basisCardinality2 = basis2_->getCardinality();
274  const int basisCardinality = basisCardinality1 + basisCardinality2;
275 
276  auto dofCoeffs1 = Kokkos::subview(dofCoeffs, std::make_pair(0,basisCardinality1), Kokkos::ALL());
277  auto dofCoeffs2 = Kokkos::subview(dofCoeffs, std::make_pair(basisCardinality1,basisCardinality), Kokkos::ALL());
278 
279  basis1_->getDofCoeffs(dofCoeffs1);
280  basis2_->getDofCoeffs(dofCoeffs2);
281  }
282 
283 
288  virtual
289  const char*
290  getName() const override {
291  return name_.c_str();
292  }
293 
294  // since the getValues() below only overrides the FEM variants, we specify that
295  // we use the base class's getValues(), which implements the FVD variant by throwing an exception.
296  // (It's an error to use the FVD variant on this basis.)
297  using BasisBase::getValues;
298 
310  virtual
311  void
313  const TensorPoints<PointValueType,DeviceType> inputPoints,
314  const EOperator operatorType = OPERATOR_VALUE ) const override
315  {
316  const int fieldStartOrdinal1 = 0;
317  const int numFields1 = basis1_->getCardinality();
318  const int fieldStartOrdinal2 = numFields1;
319  const int numFields2 = basis2_->getCardinality();
320 
321  auto basisValues1 = outputValues.basisValuesForFields(fieldStartOrdinal1, numFields1);
322  auto basisValues2 = outputValues.basisValuesForFields(fieldStartOrdinal2, numFields2);
323 
324  basis1_->getValues(basisValues1, inputPoints, operatorType);
325  basis2_->getValues(basisValues2, inputPoints, operatorType);
326  }
327 
346  virtual void getValues( OutputViewType outputValues, const PointViewType inputPoints,
347  const EOperator operatorType = OPERATOR_VALUE ) const override
348  {
349  int cardinality1 = basis1_->getCardinality();
350  int cardinality2 = basis2_->getCardinality();
351 
352  auto range1 = std::make_pair(0,cardinality1);
353  auto range2 = std::make_pair(cardinality1,cardinality1+cardinality2);
354  if (outputValues.rank() == 2) // F,P
355  {
356  auto outputValues1 = Kokkos::subview(outputValues, range1, Kokkos::ALL());
357  auto outputValues2 = Kokkos::subview(outputValues, range2, Kokkos::ALL());
358 
359  basis1_->getValues(outputValues1, inputPoints, operatorType);
360  basis2_->getValues(outputValues2, inputPoints, operatorType);
361  }
362  else if (outputValues.rank() == 3) // F,P,D
363  {
364  auto outputValues1 = Kokkos::subview(outputValues, range1, Kokkos::ALL(), Kokkos::ALL());
365  auto outputValues2 = Kokkos::subview(outputValues, range2, Kokkos::ALL(), Kokkos::ALL());
366 
367  basis1_->getValues(outputValues1, inputPoints, operatorType);
368  basis2_->getValues(outputValues2, inputPoints, operatorType);
369  }
370  else if (outputValues.rank() == 4) // F,P,D,D
371  {
372  auto outputValues1 = Kokkos::subview(outputValues, range1, Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL());
373  auto outputValues2 = Kokkos::subview(outputValues, range2, Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL());
374 
375  basis1_->getValues(outputValues1, inputPoints, operatorType);
376  basis2_->getValues(outputValues2, inputPoints, operatorType);
377  }
378  else if (outputValues.rank() == 5) // F,P,D,D,D
379  {
380  auto outputValues1 = Kokkos::subview(outputValues, range1, Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL());
381  auto outputValues2 = Kokkos::subview(outputValues, range2, Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL());
382 
383  basis1_->getValues(outputValues1, inputPoints, operatorType);
384  basis2_->getValues(outputValues2, inputPoints, operatorType);
385  }
386  else if (outputValues.rank() == 6) // F,P,D,D,D,D
387  {
388  auto outputValues1 = Kokkos::subview(outputValues, range1, Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL());
389  auto outputValues2 = Kokkos::subview(outputValues, range2, Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL());
390 
391  basis1_->getValues(outputValues1, inputPoints, operatorType);
392  basis2_->getValues(outputValues2, inputPoints, operatorType);
393  }
394  else if (outputValues.rank() == 7) // F,P,D,D,D,D,D
395  {
396  auto outputValues1 = Kokkos::subview(outputValues, range1, Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL());
397  auto outputValues2 = Kokkos::subview(outputValues, range2, Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL());
398 
399  basis1_->getValues(outputValues1, inputPoints, operatorType);
400  basis2_->getValues(outputValues2, inputPoints, operatorType);
401  }
402  else
403  {
404  INTREPID2_TEST_FOR_EXCEPTION(true, std::invalid_argument, "Unsupported outputValues rank");
405  }
406  }
407 
408  virtual int getNumTensorialExtrusions() const override
409  {
410  return basis1_->getNumTensorialExtrusions();
411  }
412  };
413 } // end namespace Intrepid2
414 
415 #endif /* Intrepid2_DirectSumBasis_h */
virtual void getValues(BasisValues< OutputValueType, DeviceType > outputValues, const TensorPoints< PointValueType, DeviceType > inputPoints, const EOperator operatorType=OPERATOR_VALUE) const override
Evaluation of a FEM basis on a reference cell, using point and output value containers that allow pre...
virtual void getDofCoords(ScalarViewType dofCoords) const override
Fills in spatial locations (coordinates) of degrees of freedom (nodes) on the reference cell...
virtual BasisValues< OutputValueType, DeviceType > allocateBasisValues(TensorPoints< PointValueType, DeviceType > points, const EOperator operatorType=OPERATOR_VALUE) const override
Allocate BasisValues container suitable for passing to the getValues() variant that takes a TensorPoi...
View-like interface to tensor points; point components are stored separately; the appropriate coordin...
BasisValues< Scalar, DeviceType > basisValuesForFields(const int &fieldStartOrdinal, const int &numFields)
field start and length must align with families in vectorData_ or tensorDataFamilies_ (whichever is v...
const VectorDataType & vectorData() const
VectorData accessor.
TensorDataType & tensorData()
TensorData accessor for single-family scalar data.
The data containers in Intrepid2 that support sum factorization and other reduced-data optimizations ...
virtual void getValues(OutputViewType outputValues, const PointViewType inputPoints, const EOperator operatorType=OPERATOR_VALUE) const override
Evaluation of a FEM basis on a reference cell.
A basis that is the direct sum of two other bases.
virtual const char * getName() const override
Returns basis name.
virtual void getDofCoeffs(ScalarViewType dofCoeffs) const override
Fills in coefficients of degrees of freedom for Lagrangian basis on the reference cell...
View-like interface to tensor data; tensor components are stored separately and multiplied together a...
Reference-space field values for a basis, designed to support typical vector-valued bases...
Basis_DirectSumBasis(BasisPtr basis1, BasisPtr basis2)
Constructor.