15 #ifndef Intrepid2_DirectSumBasis_h
16 #define Intrepid2_DirectSumBasis_h
18 #include <Kokkos_DynRankView.hpp>
32 template<
typename BasisBaseClass>
36 using BasisBase = BasisBaseClass;
37 using BasisPtr = Teuchos::RCP<BasisBase>;
39 using DeviceType =
typename BasisBase::DeviceType;
40 using ExecutionSpace =
typename BasisBase::ExecutionSpace;
41 using OutputValueType =
typename BasisBase::OutputValueType;
42 using PointValueType =
typename BasisBase::PointValueType;
44 using OrdinalTypeArray1DHost =
typename BasisBase::OrdinalTypeArray1DHost;
45 using OrdinalTypeArray2DHost =
typename BasisBase::OrdinalTypeArray2DHost;
46 using OutputViewType =
typename BasisBase::OutputViewType;
47 using PointViewType =
typename BasisBase::PointViewType;
48 using ScalarViewType =
typename BasisBase::ScalarViewType;
// Constructor of the direct-sum basis, taking the two component bases basis1 and basis2.
// NOTE(review): this listing is fragmentary -- the signature line, the braces, and part of the
// final setOrdinalTagData(...) argument list are not present here.
// Member initializers: keep both component bases.
61 basis1_(basis1),basis2_(basis2)
// Compatibility checks: the two bases must agree in basis type, cell topology key, number of
// tensorial extrusions and coordinate system. Each mismatch is reported through
// INTREPID2_TEST_FOR_EXCEPTION with std::invalid_argument.
63 INTREPID2_TEST_FOR_EXCEPTION(basis1->getBasisType() != basis2->getBasisType(), std::invalid_argument,
"basis1 and basis2 must agree in basis type");
64 INTREPID2_TEST_FOR_EXCEPTION(basis1->getBaseCellTopology().getKey() != basis2->getBaseCellTopology().getKey(),
65 std::invalid_argument,
"basis1 and basis2 must agree in cell topology");
66 INTREPID2_TEST_FOR_EXCEPTION(basis1->getNumTensorialExtrusions() != basis2->getNumTensorialExtrusions(),
67 std::invalid_argument,
"basis1 and basis2 must agree in number of tensorial extrusions");
68 INTREPID2_TEST_FOR_EXCEPTION(basis1->getCoordinateSystem() != basis2->getCoordinateSystem(),
69 std::invalid_argument,
"basis1 and basis2 must agree in coordinate system");
// Cardinality is the sum of the component cardinalities; degree is the larger of the two degrees.
71 this->basisCardinality_ = basis1->getCardinality() + basis2->getCardinality();
72 this->basisDegree_ = std::max(basis1->getDegree(), basis2->getDegree());
// The name is "<basis1 name> + <basis2 name>", stored in name_.
75 std::ostringstream basisName;
76 basisName << basis1->getName() <<
" + " << basis2->getName();
77 name_ = basisName.str();
// Topology key, basis type and coordinate system are taken from basis1; the checks above
// already established that basis2 reports the same values.
80 this->basisCellTopologyKey_ = basis1->getBaseCellTopology().getKey();
81 this->basisType_ = basis1->getBasisType();
82 this->basisCoordinates_ = basis1->getCoordinateSystem();
// Hierarchical bases only: build the per-field polynomial degree lookup tables.
84 if (this->basisType_ == BASIS_FEM_HIERARCHICAL)
// Both bases must use the same polynomial degree length, since it is the second extent of both tables.
86 int degreeLength = basis1_->getPolynomialDegreeLength();
87 INTREPID2_TEST_FOR_EXCEPTION(degreeLength != basis2_->getPolynomialDegreeLength(), std::invalid_argument,
"Basis1 and Basis2 must agree on polynomial degree length");
// Tables have one row per field of the direct sum and degreeLength entries per row.
89 this->fieldOrdinalPolynomialDegree_ = OrdinalTypeArray2DHost(
"DirectSumBasis degree lookup", this->basisCardinality_,degreeLength);
90 this->fieldOrdinalH1PolynomialDegree_ = OrdinalTypeArray2DHost(
"DirectSumBasis H^1 degree lookup",this->basisCardinality_,degreeLength);
// Fields of basis1 keep their own ordinals in the direct sum.
92 for (
int fieldOrdinal1=0; fieldOrdinal1<basis1_->getCardinality(); fieldOrdinal1++)
94 int fieldOrdinal = fieldOrdinal1;
95 auto polynomialDegree = basis1->getPolynomialDegreeOfField(fieldOrdinal1);
96 auto polynomialH1Degree = basis1->getH1PolynomialDegreeOfField(fieldOrdinal1);
97 for (
int d=0; d<degreeLength; d++)
99 this->fieldOrdinalPolynomialDegree_ (fieldOrdinal,d) = polynomialDegree(d);
100 this->fieldOrdinalH1PolynomialDegree_(fieldOrdinal,d) = polynomialH1Degree(d);
// Fields of basis2 follow, offset by the cardinality of basis1.
103 for (
int fieldOrdinal2=0; fieldOrdinal2<basis2_->getCardinality(); fieldOrdinal2++)
105 int fieldOrdinal = basis1->getCardinality() + fieldOrdinal2;
107 auto polynomialDegree = basis2->getPolynomialDegreeOfField(fieldOrdinal2);
108 auto polynomialH1Degree = basis2->getH1PolynomialDegreeOfField(fieldOrdinal2);
109 for (
int d=0; d<degreeLength; d++)
111 this->fieldOrdinalPolynomialDegree_ (fieldOrdinal,d) = polynomialDegree(d);
112 this->fieldOrdinalH1PolynomialDegree_(fieldOrdinal,d) = polynomialH1Degree(d);
// DoF tag setup. Each field gets tagSize consecutive entries in tagView, laid out as
// [subcell dimension, subcell ordinal, DoF index local to the subcell, DoF count on the subcell].
119 const auto & cardinality = this->basisCardinality_;
122 const ordinal_type tagSize = 4;
// NOTE(review): posScDim/posScOrd/posDfOrd match the first three tag slots written below;
// presumably they are passed to setOrdinalTagData in the lines missing from this listing -- confirm.
123 const ordinal_type posScDim = 0;
124 const ordinal_type posScOrd = 1;
125 const ordinal_type posDfOrd = 2;
127 OrdinalTypeArray1DHost tagView(
"tag view", cardinality*tagSize);
129 shards::CellTopology cellTopo(getCellTopologyData(this->basisCellTopologyKey_));
131 unsigned spaceDim = cellTopo.getDimension();
// Field ordinals of basis2 are shifted by this amount in the direct sum.
133 ordinal_type basis2Offset = basis1_->getCardinality();
// Visit every subcell of every dimension from 0 through spaceDim (inclusive).
135 for (
unsigned d=0; d<=spaceDim; d++)
137 unsigned subcellCount = cellTopo.getSubcellCount(d);
138 for (
unsigned subcellOrdinal=0; subcellOrdinal<subcellCount; subcellOrdinal++)
140 ordinal_type subcellDofCount1 = basis1->getDofCount(d, subcellOrdinal);
141 ordinal_type subcellDofCount2 = basis2->getDofCount(d, subcellOrdinal);
// On each subcell the local DoF numbering lists basis1's DoFs first, then basis2's.
143 ordinal_type subcellDofCount = subcellDofCount1 + subcellDofCount2;
144 for (ordinal_type localDofID=0; localDofID<subcellDofCount; localDofID++)
146 ordinal_type fieldOrdinal;
147 if (localDofID < subcellDofCount1)
// DoF belongs to basis1: its field ordinal is unchanged.
150 fieldOrdinal = basis1_->getDofOrdinal(d, subcellOrdinal, localDofID);
// DoF belongs to basis2: shift the local ID back into basis2's numbering and offset the field ordinal.
155 fieldOrdinal = basis2Offset + basis2_->getDofOrdinal(d, subcellOrdinal, localDofID - subcellDofCount1);
157 tagView(fieldOrdinal*tagSize+0) = d;
158 tagView(fieldOrdinal*tagSize+1) = subcellOrdinal;
159 tagView(fieldOrdinal*tagSize+2) = localDofID;
160 tagView(fieldOrdinal*tagSize+3) = subcellDofCount;
// Hand the tag data to setOrdinalTagData. NOTE(review): only part of the argument list is
// present in this listing.
166 this->setOrdinalTagData(this->tagToOrdinal_,
169 this->basisCardinality_,
// Body fragment of allocateBasisValues(points, operatorType): merges the value containers of the
// two component bases, basis1's families first, then basis2's.
// NOTE(review): the definitions of basisValues1, basisValues2 and vectorComponents, and the
// return statements, are not present in this listing.
187 const int numScalarFamilies1 = basisValues1.numTensorDataFamilies();
// Scalar-valued case: basis1 reports at least one tensor data family.
188 if (numScalarFamilies1 > 0)
191 const int numScalarFamilies2 = basisValues2.numTensorDataFamilies();
// basis2 must then also provide tensor data families.
192 INTREPID2_TEST_FOR_EXCEPTION(basisValues2.numTensorDataFamilies() <=0, std::invalid_argument,
"When basis1 has scalar value, basis2 must also");
// Concatenate the families: indices [0, numScalarFamilies1) come from basis1, the rest from basis2.
193 std::vector< TensorData<OutputValueType,DeviceType> > scalarFamilies(numScalarFamilies1 + numScalarFamilies2);
194 for (
int i=0; i<numScalarFamilies1; i++)
196 scalarFamilies[i] = basisValues1.
tensorData(i);
198 for (
int i=0; i<numScalarFamilies2; i++)
200 scalarFamilies[i+numScalarFamilies1] = basisValues2.
tensorData(i);
// Vector-valued case (basis1 has no tensor data families): basis1 must have valid vectorData()
// and basis2 must not report any tensor data families.
207 INTREPID2_TEST_FOR_EXCEPTION(!basisValues1.
vectorData().isValid(), std::invalid_argument,
"When basis1 does not have tensorData() defined, it must have a valid vectorData()");
208 INTREPID2_TEST_FOR_EXCEPTION(basisValues2.numTensorDataFamilies() > 0, std::invalid_argument,
"When basis1 has vector value, basis2 must also");
210 const auto & vectorData1 = basisValues1.
vectorData();
211 const auto & vectorData2 = basisValues2.
vectorData();
213 const int numFamilies1 = vectorData1.numFamilies();
// The component count is taken from basis1 and must match basis2's.
214 const int numComponents = vectorData1.numComponents();
215 INTREPID2_TEST_FOR_EXCEPTION(numComponents != vectorData2.numComponents(), std::invalid_argument,
"basis1 and basis2 must agree on the number of components in each vector");
216 const int numFamilies2 = vectorData2.numFamilies();
218 const int numFamilies = numFamilies1 + numFamilies2;
// Copy components family by family: basis1's families occupy the first numFamilies1 rows.
221 for (
int i=0; i<numFamilies1; i++)
223 for (
int j=0; j<numComponents; j++)
225 vectorComponents[i][j] = vectorData1.getComponent(i,j);
// basis2's families follow, offset by numFamilies1.
228 for (
int i=0; i<numFamilies2; i++)
230 for (
int j=0; j<numComponents; j++)
232 vectorComponents[i+numFamilies1][j] = vectorData2.getComponent(i,j);
249 const int basisCardinality1 = basis1_->getCardinality();
250 const int basisCardinality2 = basis2_->getCardinality();
251 const int basisCardinality = basisCardinality1 + basisCardinality2;
253 auto dofCoords1 = Kokkos::subview(dofCoords, std::make_pair(0,basisCardinality1), Kokkos::ALL());
254 auto dofCoords2 = Kokkos::subview(dofCoords, std::make_pair(basisCardinality1,basisCardinality), Kokkos::ALL());
256 basis1_->getDofCoords(dofCoords1);
257 basis2_->getDofCoords(dofCoords2);
272 const int basisCardinality1 = basis1_->getCardinality();
273 const int basisCardinality2 = basis2_->getCardinality();
274 const int basisCardinality = basisCardinality1 + basisCardinality2;
276 auto dofCoeffs1 = Kokkos::subview(dofCoeffs, std::make_pair(0,basisCardinality1), Kokkos::ALL());
277 auto dofCoeffs2 = Kokkos::subview(dofCoeffs, std::make_pair(basisCardinality1,basisCardinality), Kokkos::ALL());
279 basis1_->getDofCoeffs(dofCoeffs1);
280 basis2_->getDofCoeffs(dofCoeffs2);
// Body of getName(): returns the C string of name_, which the constructor sets to
// "<basis1 name> + <basis2 name>". The pointer refers to storage owned by name_.
291 return name_.c_str();
// Make the base class's getValues overloads visible in this class, so that the getValues
// overrides declared here do not hide the remaining base-class overloads.
297 using BasisBase::getValues;
// Tail of the getValues overload that takes BasisValues output and TensorPoints input.
// NOTE(review): the start of the signature and the definitions of basisValues1/basisValues2
// are not present in this listing. Presumably they are the portions of outputValues for the
// two field ranges computed below -- confirm against the full source.
314 const EOperator operatorType = OPERATOR_VALUE )
const override
// Field ranges in the direct sum: basis1's fields start at 0, basis2's start right after them.
316 const int fieldStartOrdinal1 = 0;
317 const int numFields1 = basis1_->getCardinality();
318 const int fieldStartOrdinal2 = numFields1;
319 const int numFields2 = basis2_->getCardinality();
// Each component basis evaluates into its own container at the same points and operator.
324 basis1_->getValues(basisValues1, inputPoints, operatorType);
325 basis2_->getValues(basisValues2, inputPoints, operatorType);
346 virtual void getValues( OutputViewType outputValues,
const PointViewType inputPoints,
347 const EOperator operatorType = OPERATOR_VALUE )
const override
349 int cardinality1 = basis1_->getCardinality();
350 int cardinality2 = basis2_->getCardinality();
352 auto range1 = std::make_pair(0,cardinality1);
353 auto range2 = std::make_pair(cardinality1,cardinality1+cardinality2);
354 if (outputValues.rank() == 2)
356 auto outputValues1 = Kokkos::subview(outputValues, range1, Kokkos::ALL());
357 auto outputValues2 = Kokkos::subview(outputValues, range2, Kokkos::ALL());
359 basis1_->getValues(outputValues1, inputPoints, operatorType);
360 basis2_->getValues(outputValues2, inputPoints, operatorType);
362 else if (outputValues.rank() == 3)
364 auto outputValues1 = Kokkos::subview(outputValues, range1, Kokkos::ALL(), Kokkos::ALL());
365 auto outputValues2 = Kokkos::subview(outputValues, range2, Kokkos::ALL(), Kokkos::ALL());
367 basis1_->getValues(outputValues1, inputPoints, operatorType);
368 basis2_->getValues(outputValues2, inputPoints, operatorType);
370 else if (outputValues.rank() == 4)
372 auto outputValues1 = Kokkos::subview(outputValues, range1, Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL());
373 auto outputValues2 = Kokkos::subview(outputValues, range2, Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL());
375 basis1_->getValues(outputValues1, inputPoints, operatorType);
376 basis2_->getValues(outputValues2, inputPoints, operatorType);
378 else if (outputValues.rank() == 5)
380 auto outputValues1 = Kokkos::subview(outputValues, range1, Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL());
381 auto outputValues2 = Kokkos::subview(outputValues, range2, Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL());
383 basis1_->getValues(outputValues1, inputPoints, operatorType);
384 basis2_->getValues(outputValues2, inputPoints, operatorType);
386 else if (outputValues.rank() == 6)
388 auto outputValues1 = Kokkos::subview(outputValues, range1, Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL());
389 auto outputValues2 = Kokkos::subview(outputValues, range2, Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL());
391 basis1_->getValues(outputValues1, inputPoints, operatorType);
392 basis2_->getValues(outputValues2, inputPoints, operatorType);
394 else if (outputValues.rank() == 7)
396 auto outputValues1 = Kokkos::subview(outputValues, range1, Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL());
397 auto outputValues2 = Kokkos::subview(outputValues, range2, Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL());
399 basis1_->getValues(outputValues1, inputPoints, operatorType);
400 basis2_->getValues(outputValues2, inputPoints, operatorType);
404 INTREPID2_TEST_FOR_EXCEPTION(
true, std::invalid_argument,
"Unsupported outputValues rank");
408 virtual int getNumTensorialExtrusions()
const override
410 return basis1_->getNumTensorialExtrusions();
virtual void getValues(BasisValues< OutputValueType, DeviceType > outputValues, const TensorPoints< PointValueType, DeviceType > inputPoints, const EOperator operatorType=OPERATOR_VALUE) const override
Evaluation of a FEM basis on a reference cell, using point and output value containers that allow pre...
virtual void getDofCoords(ScalarViewType dofCoords) const override
Fills in spatial locations (coordinates) of degrees of freedom (nodes) on the reference cell...
virtual BasisValues< OutputValueType, DeviceType > allocateBasisValues(TensorPoints< PointValueType, DeviceType > points, const EOperator operatorType=OPERATOR_VALUE) const override
Allocate BasisValues container suitable for passing to the getValues() variant that takes a TensorPoi...
View-like interface to tensor points; point components are stored separately; the appropriate coordin...
BasisValues< Scalar, DeviceType > basisValuesForFields(const int &fieldStartOrdinal, const int &numFields)
field start and length must align with families in vectorData_ or tensorDataFamilies_ (whichever is v...
const VectorDataType & vectorData() const
VectorData accessor.
TensorDataType & tensorData()
TensorData accessor for single-family scalar data.
The data containers in Intrepid2 that support sum factorization and other reduced-data optimizations ...
virtual void getValues(OutputViewType outputValues, const PointViewType inputPoints, const EOperator operatorType=OPERATOR_VALUE) const override
Evaluation of a FEM basis on a reference cell.
A basis that is the direct sum of two other bases.
virtual const char * getName() const override
Returns basis name.
virtual void getDofCoeffs(ScalarViewType dofCoeffs) const override
Fills in coefficients of degrees of freedom for Lagrangian basis on the reference cell...
View-like interface to tensor data; tensor components are stored separately and multiplied together a...
Reference-space field values for a basis, designed to support typical vector-valued bases...
Basis_DirectSumBasis(BasisPtr basis1, BasisPtr basis2)
Constructor.