Intrepid2
Intrepid2_DirectSumBasis.hpp
Go to the documentation of this file.
1 // @HEADER
2 // ************************************************************************
3 //
4 // Intrepid2 Package
5 // Copyright (2007) Sandia Corporation
6 //
7 // Under terms of Contract DE-AC04-94AL85000, there is a non-exclusive
8 // license for use of this work by or on behalf of the U.S. Government.
9 //
10 // Redistribution and use in source and binary forms, with or without
11 // modification, are permitted provided that the following conditions are
12 // met:
13 //
14 // 1. Redistributions of source code must retain the above copyright
15 // notice, this list of conditions and the following disclaimer.
16 //
17 // 2. Redistributions in binary form must reproduce the above copyright
18 // notice, this list of conditions and the following disclaimer in the
19 // documentation and/or other materials provided with the distribution.
20 //
21 // 3. Neither the name of the Corporation nor the names of the
22 // contributors may be used to endorse or promote products derived from
23 // this software without specific prior written permission.
24 //
25 // THIS SOFTWARE IS PROVIDED BY SANDIA CORPORATION "AS IS" AND ANY
26 // EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
27 // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
28 // PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL SANDIA CORPORATION OR THE
29 // CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
30 // EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
31 // PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
32 // PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
33 // LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
34 // NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
35 // SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
36 //
37 // Questions? Contact Kyungjoo Kim (kyukim@sandia.gov),
38 // Mauro Perego (mperego@sandia.gov), or
39 // Nate Roberts (nvrober@sandia.gov)
40 //
41 // ************************************************************************
42 // @HEADER
43 
49 #ifndef Intrepid2_DirectSumBasis_h
50 #define Intrepid2_DirectSumBasis_h
51 
52 #include <Kokkos_DynRankView.hpp>
53 
54 namespace Intrepid2
55 {
66  template<typename BasisBaseClass>
67  class Basis_DirectSumBasis : public BasisBaseClass
68  {
69  public:
70  using BasisBase = BasisBaseClass;
71  using BasisPtr = Teuchos::RCP<BasisBase>;
72 
73  using DeviceType = typename BasisBase::DeviceType;
74  using ExecutionSpace = typename BasisBase::ExecutionSpace;
75  using OutputValueType = typename BasisBase::OutputValueType;
76  using PointValueType = typename BasisBase::PointValueType;
77 
78  using OrdinalTypeArray1DHost = typename BasisBase::OrdinalTypeArray1DHost;
79  using OrdinalTypeArray2DHost = typename BasisBase::OrdinalTypeArray2DHost;
80  using OutputViewType = typename BasisBase::OutputViewType;
81  using PointViewType = typename BasisBase::PointViewType;
82  using ScalarViewType = typename BasisBase::ScalarViewType;
83  protected:
84  BasisPtr basis1_;
85  BasisPtr basis2_;
86 
87  std::string name_;
88  public:
93  Basis_DirectSumBasis(BasisPtr basis1, BasisPtr basis2)
94  :
95  basis1_(basis1),basis2_(basis2)
96  {
97  INTREPID2_TEST_FOR_EXCEPTION(basis1->getBasisType() != basis2->getBasisType(), std::invalid_argument, "basis1 and basis2 must agree in basis type");
98  INTREPID2_TEST_FOR_EXCEPTION(basis1->getBaseCellTopology().getKey() != basis2->getBaseCellTopology().getKey(),
99  std::invalid_argument, "basis1 and basis2 must agree in cell topology");
100  INTREPID2_TEST_FOR_EXCEPTION(basis1->getNumTensorialExtrusions() != basis2->getNumTensorialExtrusions(),
101  std::invalid_argument, "basis1 and basis2 must agree in number of tensorial extrusions");
102  INTREPID2_TEST_FOR_EXCEPTION(basis1->getCoordinateSystem() != basis2->getCoordinateSystem(),
103  std::invalid_argument, "basis1 and basis2 must agree in coordinate system");
104 
105  this->basisCardinality_ = basis1->getCardinality() + basis2->getCardinality();
106  this->basisDegree_ = std::max(basis1->getDegree(), basis2->getDegree());
107 
108  {
109  std::ostringstream basisName;
110  basisName << basis1->getName() << " + " << basis2->getName();
111  name_ = basisName.str();
112  }
113 
114  this->basisCellTopology_ = basis1->getBaseCellTopology();
115  this->basisType_ = basis1->getBasisType();
116  this->basisCoordinates_ = basis1->getCoordinateSystem();
117 
118  if (this->basisType_ == BASIS_FEM_HIERARCHICAL)
119  {
120  int degreeLength = basis1_->getPolynomialDegreeLength();
121  INTREPID2_TEST_FOR_EXCEPTION(degreeLength != basis2_->getPolynomialDegreeLength(), std::invalid_argument, "Basis1 and Basis2 must agree on polynomial degree length");
122 
123  this->fieldOrdinalPolynomialDegree_ = OrdinalTypeArray2DHost("DirectSumBasis degree lookup", this->basisCardinality_,degreeLength);
124  this->fieldOrdinalH1PolynomialDegree_ = OrdinalTypeArray2DHost("DirectSumBasis H^1 degree lookup",this->basisCardinality_,degreeLength);
125  // our field ordinals start with basis1_; basis2_ follows
126  for (int fieldOrdinal1=0; fieldOrdinal1<basis1_->getCardinality(); fieldOrdinal1++)
127  {
128  int fieldOrdinal = fieldOrdinal1;
129  auto polynomialDegree = basis1->getPolynomialDegreeOfField(fieldOrdinal1);
130  auto polynomialH1Degree = basis1->getH1PolynomialDegreeOfField(fieldOrdinal1);
131  for (int d=0; d<degreeLength; d++)
132  {
133  this->fieldOrdinalPolynomialDegree_ (fieldOrdinal,d) = polynomialDegree(d);
134  this->fieldOrdinalH1PolynomialDegree_(fieldOrdinal,d) = polynomialH1Degree(d);
135  }
136  }
137  for (int fieldOrdinal2=0; fieldOrdinal2<basis2_->getCardinality(); fieldOrdinal2++)
138  {
139  int fieldOrdinal = basis1->getCardinality() + fieldOrdinal2;
140 
141  auto polynomialDegree = basis2->getPolynomialDegreeOfField(fieldOrdinal2);
142  auto polynomialH1Degree = basis2->getH1PolynomialDegreeOfField(fieldOrdinal2);
143  for (int d=0; d<degreeLength; d++)
144  {
145  this->fieldOrdinalPolynomialDegree_ (fieldOrdinal,d) = polynomialDegree(d);
146  this->fieldOrdinalH1PolynomialDegree_(fieldOrdinal,d) = polynomialH1Degree(d);
147  }
148  }
149  }
150 
151  // initialize tags
152  {
153  const auto & cardinality = this->basisCardinality_;
154 
155  // Basis-dependent initializations
156  const ordinal_type tagSize = 4; // size of DoF tag, i.e., number of fields in the tag
157  const ordinal_type posScDim = 0; // position in the tag, counting from 0, of the subcell dim
158  const ordinal_type posScOrd = 1; // position in the tag, counting from 0, of the subcell ordinal
159  const ordinal_type posDfOrd = 2; // position in the tag, counting from 0, of DoF ordinal relative to the subcell
160 
161  OrdinalTypeArray1DHost tagView("tag view", cardinality*tagSize);
162 
163  shards::CellTopology cellTopo = this->basisCellTopology_;
164 
165  unsigned spaceDim = cellTopo.getDimension();
166 
167  ordinal_type basis2Offset = basis1_->getCardinality();
168 
169  for (unsigned d=0; d<=spaceDim; d++)
170  {
171  unsigned subcellCount = cellTopo.getSubcellCount(d);
172  for (unsigned subcellOrdinal=0; subcellOrdinal<subcellCount; subcellOrdinal++)
173  {
174  ordinal_type subcellDofCount1 = basis1->getDofCount(d, subcellOrdinal);
175  ordinal_type subcellDofCount2 = basis2->getDofCount(d, subcellOrdinal);
176 
177  ordinal_type subcellDofCount = subcellDofCount1 + subcellDofCount2;
178  for (ordinal_type localDofID=0; localDofID<subcellDofCount; localDofID++)
179  {
180  ordinal_type fieldOrdinal;
181  if (localDofID < subcellDofCount1)
182  {
183  // first basis: field ordinal matches the basis1 ordinal
184  fieldOrdinal = basis1_->getDofOrdinal(d, subcellOrdinal, localDofID);
185  }
186  else
187  {
188  // second basis: field ordinal is offset by basis1 cardinality
189  fieldOrdinal = basis2Offset + basis2_->getDofOrdinal(d, subcellOrdinal, localDofID - subcellDofCount1);
190  }
191  tagView(fieldOrdinal*tagSize+0) = d; // subcell dimension
192  tagView(fieldOrdinal*tagSize+1) = subcellOrdinal;
193  tagView(fieldOrdinal*tagSize+2) = localDofID;
194  tagView(fieldOrdinal*tagSize+3) = subcellDofCount;
195  }
196  }
197  }
198  // // Basis-independent function sets tag and enum data in tagToOrdinal_ and ordinalToTag_ arrays:
199  // // tags are constructed on host
200  this->setOrdinalTagData(this->tagToOrdinal_,
201  this->ordinalToTag_,
202  tagView,
203  this->basisCardinality_,
204  tagSize,
205  posScDim,
206  posScOrd,
207  posDfOrd);
208  }
209  }
210 
216  virtual BasisValues<OutputValueType,DeviceType> allocateBasisValues( TensorPoints<PointValueType,DeviceType> points, const EOperator operatorType = OPERATOR_VALUE) const override
217  {
218  BasisValues<OutputValueType,DeviceType> basisValues1 = basis1_->allocateBasisValues(points, operatorType);
219  BasisValues<OutputValueType,DeviceType> basisValues2 = basis2_->allocateBasisValues(points, operatorType);
220 
221  const int numScalarFamilies1 = basisValues1.numTensorDataFamilies();
222  if (numScalarFamilies1 > 0)
223  {
224  // then both basis1 and basis2 should be scalar-valued; check that for basis2:
225  const int numScalarFamilies2 = basisValues2.numTensorDataFamilies();
226  INTREPID2_TEST_FOR_EXCEPTION(basisValues2.numTensorDataFamilies() <=0, std::invalid_argument, "When basis1 has scalar value, basis2 must also");
227  std::vector< TensorData<OutputValueType,DeviceType> > scalarFamilies(numScalarFamilies1 + numScalarFamilies2);
228  for (int i=0; i<numScalarFamilies1; i++)
229  {
230  scalarFamilies[i] = basisValues1.tensorData(i);
231  }
232  for (int i=0; i<numScalarFamilies2; i++)
233  {
234  scalarFamilies[i+numScalarFamilies1] = basisValues2.tensorData(i);
235  }
236  return BasisValues<OutputValueType,DeviceType>(scalarFamilies);
237  }
238  else
239  {
240  // then both basis1 and basis2 should be vector-valued; check that:
241  INTREPID2_TEST_FOR_EXCEPTION(!basisValues1.vectorData().isValid(), std::invalid_argument, "When basis1 does not have tensorData() defined, it must have a valid vectorData()");
242  INTREPID2_TEST_FOR_EXCEPTION(basisValues2.numTensorDataFamilies() > 0, std::invalid_argument, "When basis1 has vector value, basis2 must also");
243 
244  const auto & vectorData1 = basisValues1.vectorData();
245  const auto & vectorData2 = basisValues2.vectorData();
246 
247  const int numFamilies1 = vectorData1.numFamilies();
248  const int numComponents = vectorData1.numComponents();
249  INTREPID2_TEST_FOR_EXCEPTION(numComponents != vectorData2.numComponents(), std::invalid_argument, "basis1 and basis2 must agree on the number of components in each vector");
250  const int numFamilies2 = vectorData2.numFamilies();
251 
252  const int numFamilies = numFamilies1 + numFamilies2;
253  std::vector< std::vector<TensorData<OutputValueType,DeviceType> > > vectorComponents(numFamilies, std::vector<TensorData<OutputValueType,DeviceType> >(numComponents));
254 
255  for (int i=0; i<numFamilies1; i++)
256  {
257  for (int j=0; j<numComponents; j++)
258  {
259  vectorComponents[i][j] = vectorData1.getComponent(i,j);
260  }
261  }
262  for (int i=0; i<numFamilies2; i++)
263  {
264  for (int j=0; j<numComponents; j++)
265  {
266  vectorComponents[i+numFamilies1][j] = vectorData2.getComponent(i,j);
267  }
268  }
269  VectorData<OutputValueType,DeviceType> vectorData(vectorComponents);
270  return BasisValues<OutputValueType,DeviceType>(vectorData);
271  }
272  }
273 
282  virtual void getDofCoords( ScalarViewType dofCoords ) const override {
283  const int basisCardinality1 = basis1_->getCardinality();
284  const int basisCardinality2 = basis2_->getCardinality();
285  const int basisCardinality = basisCardinality1 + basisCardinality2;
286 
287  auto dofCoords1 = Kokkos::subview(dofCoords, std::make_pair(0,basisCardinality1), Kokkos::ALL());
288  auto dofCoords2 = Kokkos::subview(dofCoords, std::make_pair(basisCardinality1,basisCardinality), Kokkos::ALL());
289 
290  basis1_->getDofCoords(dofCoords1);
291  basis2_->getDofCoords(dofCoords2);
292  }
293 
305  virtual void getDofCoeffs( ScalarViewType dofCoeffs ) const override {
306  const int basisCardinality1 = basis1_->getCardinality();
307  const int basisCardinality2 = basis2_->getCardinality();
308  const int basisCardinality = basisCardinality1 + basisCardinality2;
309 
310  auto dofCoeffs1 = Kokkos::subview(dofCoeffs, std::make_pair(0,basisCardinality1), Kokkos::ALL());
311  auto dofCoeffs2 = Kokkos::subview(dofCoeffs, std::make_pair(basisCardinality1,basisCardinality), Kokkos::ALL());
312 
313  basis1_->getDofCoeffs(dofCoeffs1);
314  basis2_->getDofCoeffs(dofCoeffs2);
315  }
316 
317 
322  virtual
323  const char*
324  getName() const override {
325  return name_.c_str();
326  }
327 
328  // since the getValues() below only overrides the FEM variants, we specify that
329  // we use the base class's getValues(), which implements the FVD variant by throwing an exception.
330  // (It's an error to use the FVD variant on this basis.)
331  using BasisBase::getValues;
332 
344  virtual
345  void
347  const TensorPoints<PointValueType,DeviceType> inputPoints,
348  const EOperator operatorType = OPERATOR_VALUE ) const override
349  {
350  const int fieldStartOrdinal1 = 0;
351  const int numFields1 = basis1_->getCardinality();
352  const int fieldStartOrdinal2 = numFields1;
353  const int numFields2 = basis2_->getCardinality();
354 
355  auto basisValues1 = outputValues.basisValuesForFields(fieldStartOrdinal1, numFields1);
356  auto basisValues2 = outputValues.basisValuesForFields(fieldStartOrdinal2, numFields2);
357 
358  basis1_->getValues(basisValues1, inputPoints, operatorType);
359  basis2_->getValues(basisValues2, inputPoints, operatorType);
360  }
361 
380  virtual void getValues( OutputViewType outputValues, const PointViewType inputPoints,
381  const EOperator operatorType = OPERATOR_VALUE ) const override
382  {
383  int cardinality1 = basis1_->getCardinality();
384  int cardinality2 = basis2_->getCardinality();
385 
386  auto range1 = std::make_pair(0,cardinality1);
387  auto range2 = std::make_pair(cardinality1,cardinality1+cardinality2);
388  if (outputValues.rank() == 2) // F,P
389  {
390  auto outputValues1 = Kokkos::subview(outputValues, range1, Kokkos::ALL());
391  auto outputValues2 = Kokkos::subview(outputValues, range2, Kokkos::ALL());
392 
393  basis1_->getValues(outputValues1, inputPoints, operatorType);
394  basis2_->getValues(outputValues2, inputPoints, operatorType);
395  }
396  else if (outputValues.rank() == 3) // F,P,D
397  {
398  auto outputValues1 = Kokkos::subview(outputValues, range1, Kokkos::ALL(), Kokkos::ALL());
399  auto outputValues2 = Kokkos::subview(outputValues, range2, Kokkos::ALL(), Kokkos::ALL());
400 
401  basis1_->getValues(outputValues1, inputPoints, operatorType);
402  basis2_->getValues(outputValues2, inputPoints, operatorType);
403  }
404  else if (outputValues.rank() == 4) // F,P,D,D
405  {
406  auto outputValues1 = Kokkos::subview(outputValues, range1, Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL());
407  auto outputValues2 = Kokkos::subview(outputValues, range2, Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL());
408 
409  basis1_->getValues(outputValues1, inputPoints, operatorType);
410  basis2_->getValues(outputValues2, inputPoints, operatorType);
411  }
412  else if (outputValues.rank() == 5) // F,P,D,D,D
413  {
414  auto outputValues1 = Kokkos::subview(outputValues, range1, Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL());
415  auto outputValues2 = Kokkos::subview(outputValues, range2, Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL());
416 
417  basis1_->getValues(outputValues1, inputPoints, operatorType);
418  basis2_->getValues(outputValues2, inputPoints, operatorType);
419  }
420  else if (outputValues.rank() == 6) // F,P,D,D,D,D
421  {
422  auto outputValues1 = Kokkos::subview(outputValues, range1, Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL());
423  auto outputValues2 = Kokkos::subview(outputValues, range2, Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL());
424 
425  basis1_->getValues(outputValues1, inputPoints, operatorType);
426  basis2_->getValues(outputValues2, inputPoints, operatorType);
427  }
428  else if (outputValues.rank() == 7) // F,P,D,D,D,D,D
429  {
430  auto outputValues1 = Kokkos::subview(outputValues, range1, Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL());
431  auto outputValues2 = Kokkos::subview(outputValues, range2, Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL());
432 
433  basis1_->getValues(outputValues1, inputPoints, operatorType);
434  basis2_->getValues(outputValues2, inputPoints, operatorType);
435  }
436  else
437  {
438  INTREPID2_TEST_FOR_EXCEPTION(true, std::invalid_argument, "Unsupported outputValues rank");
439  }
440  }
441 
442  virtual int getNumTensorialExtrusions() const override
443  {
444  return basis1_->getNumTensorialExtrusions();
445  }
446  };
447 } // end namespace Intrepid2
448 
449 #endif /* Intrepid2_DirectSumBasis_h */
virtual void getValues(BasisValues< OutputValueType, DeviceType > outputValues, const TensorPoints< PointValueType, DeviceType > inputPoints, const EOperator operatorType=OPERATOR_VALUE) const override
Evaluation of a FEM basis on a reference cell, using point and output value containers that allow pre...
virtual void getDofCoords(ScalarViewType dofCoords) const override
Fills in spatial locations (coordinates) of degrees of freedom (nodes) on the reference cell...
virtual BasisValues< OutputValueType, DeviceType > allocateBasisValues(TensorPoints< PointValueType, DeviceType > points, const EOperator operatorType=OPERATOR_VALUE) const override
Allocate BasisValues container suitable for passing to the getValues() variant that takes a TensorPoi...
View-like interface to tensor points; point components are stored separately; the appropriate coordin...
const VectorDataType & vectorData() const
VectorData accessor.
BasisValues< Scalar, ExecSpaceType > basisValuesForFields(const int &fieldStartOrdinal, const int &numFields)
field start and length must align with families in vectorData_ or tensorDataFamilies_ (whichever is v...
The data containers in Intrepid2 that support sum factorization and other reduced-data optimizations ...
virtual void getValues(OutputViewType outputValues, const PointViewType inputPoints, const EOperator operatorType=OPERATOR_VALUE) const override
Evaluation of a FEM basis on a reference cell.
A basis that is the direct sum of two other bases.
virtual const char * getName() const override
Returns basis name.
virtual void getDofCoeffs(ScalarViewType dofCoeffs) const override
Fills in coefficients of degrees of freedom for Lagrangian basis on the reference cell...
TensorDataType & tensorData()
TensorData accessor for single-family scalar data.
View-like interface to tensor data; tensor components are stored separately and multiplied together a...
Reference-space field values for a basis, designed to support typical vector-valued bases...
Basis_DirectSumBasis(BasisPtr basis1, BasisPtr basis2)
Constructor.