49 #ifndef Intrepid2_DirectSumBasis_h
50 #define Intrepid2_DirectSumBasis_h
52 #include <Kokkos_DynRankView.hpp>
// Template header for the direct-sum basis. BasisBaseClass is the basis base type whose
// nested types are re-exported by the aliases below.
// NOTE(review): the class declaration line itself is not present in this listing.
66 template<
typename BasisBaseClass>
// BasisBase names the base type; BasisPtr is the Teuchos::RCP handle type for a basis.
70 using BasisBase = BasisBaseClass;
71 using BasisPtr = Teuchos::RCP<BasisBase>;
// Device, execution-space and value types, taken unchanged from BasisBase.
73 using DeviceType =
typename BasisBase::DeviceType;
74 using ExecutionSpace =
typename BasisBase::ExecutionSpace;
75 using OutputValueType =
typename BasisBase::OutputValueType;
76 using PointValueType =
typename BasisBase::PointValueType;
// Container types, taken unchanged from BasisBase. The two OrdinalTypeArray*Host types
// are used in the constructor for the tag array and the polynomial-degree tables.
78 using OrdinalTypeArray1DHost =
typename BasisBase::OrdinalTypeArray1DHost;
79 using OrdinalTypeArray2DHost =
typename BasisBase::OrdinalTypeArray2DHost;
80 using OutputViewType =
typename BasisBase::OutputViewType;
81 using PointViewType =
typename BasisBase::PointViewType;
82 using ScalarViewType =
typename BasisBase::ScalarViewType;
// Constructor body for the direct sum of basis1 and basis2.
// Stores both bases, checks that they are compatible, then fills in cardinality, degree,
// name, cell topology, basis type, coordinate system, the polynomial-degree tables
// (hierarchical bases only) and the DoF tag data.
// NOTE(review): the signature line and all braces are absent from this listing.
95 basis1_(basis1),basis2_(basis2)
// Compatibility checks. Each call hands INTREPID2_TEST_FOR_EXCEPTION the failure condition,
// std::invalid_argument, and a message. The bases must agree in basis type, cell topology
// key, number of tensorial extrusions, and coordinate system.
97 INTREPID2_TEST_FOR_EXCEPTION(basis1->getBasisType() != basis2->getBasisType(), std::invalid_argument,
"basis1 and basis2 must agree in basis type");
98 INTREPID2_TEST_FOR_EXCEPTION(basis1->getBaseCellTopology().getKey() != basis2->getBaseCellTopology().getKey(),
99 std::invalid_argument,
"basis1 and basis2 must agree in cell topology");
100 INTREPID2_TEST_FOR_EXCEPTION(basis1->getNumTensorialExtrusions() != basis2->getNumTensorialExtrusions(),
101 std::invalid_argument,
"basis1 and basis2 must agree in number of tensorial extrusions");
102 INTREPID2_TEST_FOR_EXCEPTION(basis1->getCoordinateSystem() != basis2->getCoordinateSystem(),
103 std::invalid_argument,
"basis1 and basis2 must agree in coordinate system");
// Cardinality is the sum of the two cardinalities; degree is the larger of the two degrees.
105 this->basisCardinality_ = basis1->getCardinality() + basis2->getCardinality();
106 this->basisDegree_ = std::max(basis1->getDegree(), basis2->getDegree());
// The stored name is "<basis1 name> + <basis2 name>".
109 std::ostringstream basisName;
110 basisName << basis1->getName() <<
" + " << basis2->getName();
111 name_ = basisName.str();
// Topology, basis type and coordinate system are copied from basis1; the checks above
// compared these against basis2.
114 this->basisCellTopology_ = basis1->getBaseCellTopology();
115 this->basisType_ = basis1->getBasisType();
116 this->basisCoordinates_ = basis1->getCoordinateSystem();
// Hierarchical bases: build the per-field polynomial-degree and H^1-degree tables, each
// sized (basisCardinality_ x degreeLength). Both bases must report the same degree length.
// NOTE(review): the extent of this `if` is not visible here because braces are missing.
118 if (this->basisType_ == BASIS_FEM_HIERARCHICAL)
120 int degreeLength = basis1_->getPolynomialDegreeLength();
121 INTREPID2_TEST_FOR_EXCEPTION(degreeLength != basis2_->getPolynomialDegreeLength(), std::invalid_argument,
"Basis1 and Basis2 must agree on polynomial degree length");
123 this->fieldOrdinalPolynomialDegree_ = OrdinalTypeArray2DHost(
"DirectSumBasis degree lookup", this->basisCardinality_,degreeLength);
124 this->fieldOrdinalH1PolynomialDegree_ = OrdinalTypeArray2DHost(
"DirectSumBasis H^1 degree lookup",this->basisCardinality_,degreeLength);
// basis1 fields keep their own ordinals in the combined tables.
126 for (
int fieldOrdinal1=0; fieldOrdinal1<basis1_->getCardinality(); fieldOrdinal1++)
128 int fieldOrdinal = fieldOrdinal1;
129 auto polynomialDegree = basis1->getPolynomialDegreeOfField(fieldOrdinal1);
130 auto polynomialH1Degree = basis1->getH1PolynomialDegreeOfField(fieldOrdinal1);
131 for (
int d=0; d<degreeLength; d++)
133 this->fieldOrdinalPolynomialDegree_ (fieldOrdinal,d) = polynomialDegree(d);
134 this->fieldOrdinalH1PolynomialDegree_(fieldOrdinal,d) = polynomialH1Degree(d);
// basis2 fields follow all basis1 fields: their ordinals are offset by basis1's cardinality.
137 for (
int fieldOrdinal2=0; fieldOrdinal2<basis2_->getCardinality(); fieldOrdinal2++)
139 int fieldOrdinal = basis1->getCardinality() + fieldOrdinal2;
141 auto polynomialDegree = basis2->getPolynomialDegreeOfField(fieldOrdinal2);
142 auto polynomialH1Degree = basis2->getH1PolynomialDegreeOfField(fieldOrdinal2);
143 for (
int d=0; d<degreeLength; d++)
145 this->fieldOrdinalPolynomialDegree_ (fieldOrdinal,d) = polynomialDegree(d);
146 this->fieldOrdinalH1PolynomialDegree_(fieldOrdinal,d) = polynomialH1Degree(d);
// DoF tag construction: a flat array holding tagSize (4) entries per field.
153 const auto & cardinality = this->basisCardinality_;
// Tag layout constants. The three pos* values are the tag slots 0, 1 and 2, which the loop
// below fills with subcell dimension, subcell ordinal and local DoF ID respectively.
156 const ordinal_type tagSize = 4;
157 const ordinal_type posScDim = 0;
158 const ordinal_type posScOrd = 1;
159 const ordinal_type posDfOrd = 2;
161 OrdinalTypeArray1DHost tagView(
"tag view", cardinality*tagSize);
163 shards::CellTopology cellTopo = this->basisCellTopology_;
165 unsigned spaceDim = cellTopo.getDimension();
// Offset added to basis2's own DoF ordinals to obtain combined field ordinals.
167 ordinal_type basis2Offset = basis1_->getCardinality();
// Visit every subcell of every dimension from 0 through spaceDim (inclusive).
169 for (
unsigned d=0; d<=spaceDim; d++)
171 unsigned subcellCount = cellTopo.getSubcellCount(d);
172 for (
unsigned subcellOrdinal=0; subcellOrdinal<subcellCount; subcellOrdinal++)
// On each subcell the combined DoF count is the sum of the two bases' counts.
174 ordinal_type subcellDofCount1 = basis1->getDofCount(d, subcellOrdinal);
175 ordinal_type subcellDofCount2 = basis2->getDofCount(d, subcellOrdinal);
177 ordinal_type subcellDofCount = subcellDofCount1 + subcellDofCount2;
178 for (ordinal_type localDofID=0; localDofID<subcellDofCount; localDofID++)
180 ordinal_type fieldOrdinal;
// Local DoF IDs below subcellDofCount1 belong to basis1 and map to basis1's own ordinal.
181 if (localDofID < subcellDofCount1)
184 fieldOrdinal = basis1_->getDofOrdinal(d, subcellOrdinal, localDofID);
// Remaining local DoF IDs belong to basis2: re-based by subcellDofCount1 for the lookup,
// then shifted by basis2Offset.
// NOTE(review): no `else` appears in this listing between these two assignments;
// presumably this one is the else-branch -- confirm against the original file.
189 fieldOrdinal = basis2Offset + basis2_->getDofOrdinal(d, subcellOrdinal, localDofID - subcellDofCount1);
// Four tag entries for this field: subcell dimension, subcell ordinal, DoF ID local to the
// subcell (in the combined numbering), and the combined DoF count on the subcell.
191 tagView(fieldOrdinal*tagSize+0) = d;
192 tagView(fieldOrdinal*tagSize+1) = subcellOrdinal;
193 tagView(fieldOrdinal*tagSize+2) = localDofID;
194 tagView(fieldOrdinal*tagSize+3) = subcellDofCount;
// Hands the tag data to setOrdinalTagData.
// NOTE(review): only two of the call's argument lines survive in this listing.
200 this->setOrdinalTagData(this->tagToOrdinal_,
203 this->basisCardinality_,
// Body fragment that merges two value containers, basisValues1 and basisValues2, into one
// set of families ordered basis1 first, basis2 second.
// NOTE(review): the signature, the definitions of basisValues1/basisValues2, the
// declaration of vectorComponents, and the return statements are not present in this listing.
221 const int numScalarFamilies1 = basisValues1.numTensorDataFamilies();
// Scalar case: basisValues1 reports at least one tensor-data family.
222 if (numScalarFamilies1 > 0)
// basisValues2 must then also report at least one tensor-data family.
225 const int numScalarFamilies2 = basisValues2.numTensorDataFamilies();
226 INTREPID2_TEST_FOR_EXCEPTION(basisValues2.numTensorDataFamilies() <=0, std::invalid_argument,
"When basis1 has scalar value, basis2 must also");
// Combined family list: basis1's families occupy [0, numScalarFamilies1), basis2's follow.
227 std::vector< TensorData<OutputValueType,DeviceType> > scalarFamilies(numScalarFamilies1 + numScalarFamilies2);
228 for (
int i=0; i<numScalarFamilies1; i++)
230 scalarFamilies[i] = basisValues1.
tensorData(i);
232 for (
int i=0; i<numScalarFamilies2; i++)
234 scalarFamilies[i+numScalarFamilies1] = basisValues2.
tensorData(i);
// Vector case: basisValues1 must hold valid vectorData(), and basisValues2 must not report
// any tensor-data families.
// NOTE(review): presumably this is the else-branch of the test above; the `else` is not
// present in this listing.
241 INTREPID2_TEST_FOR_EXCEPTION(!basisValues1.
vectorData().isValid(), std::invalid_argument,
"When basis1 does not have tensorData() defined, it must have a valid vectorData()");
242 INTREPID2_TEST_FOR_EXCEPTION(basisValues2.numTensorDataFamilies() > 0, std::invalid_argument,
"When basis1 has vector value, basis2 must also");
244 const auto & vectorData1 = basisValues1.
vectorData();
245 const auto & vectorData2 = basisValues2.
vectorData();
// Both vector containers must have the same number of components per vector; the family
// counts are added.
247 const int numFamilies1 = vectorData1.numFamilies();
248 const int numComponents = vectorData1.numComponents();
249 INTREPID2_TEST_FOR_EXCEPTION(numComponents != vectorData2.numComponents(), std::invalid_argument,
"basis1 and basis2 must agree on the number of components in each vector");
250 const int numFamilies2 = vectorData2.numFamilies();
252 const int numFamilies = numFamilies1 + numFamilies2;
// Copy components family by family: basis1's families first, then basis2's at an offset of
// numFamilies1.
255 for (
int i=0; i<numFamilies1; i++)
257 for (
int j=0; j<numComponents; j++)
259 vectorComponents[i][j] = vectorData1.getComponent(i,j);
262 for (
int i=0; i<numFamilies2; i++)
264 for (
int j=0; j<numComponents; j++)
266 vectorComponents[i+numFamilies1][j] = vectorData2.getComponent(i,j);
// Fills dofCoords for the combined basis. Entries [0, basisCardinality1) along the first
// dimension are filled by basis1_, and entries [basisCardinality1, basisCardinality) by
// basis2_; each basis writes into its own subview.
// NOTE(review): the signature line is not present in this listing.
283 const int basisCardinality1 = basis1_->getCardinality();
284 const int basisCardinality2 = basis2_->getCardinality();
285 const int basisCardinality = basisCardinality1 + basisCardinality2;
// Split the first dimension into the two field ranges, keeping the second dimension whole.
287 auto dofCoords1 = Kokkos::subview(dofCoords, std::make_pair(0,basisCardinality1), Kokkos::ALL());
288 auto dofCoords2 = Kokkos::subview(dofCoords, std::make_pair(basisCardinality1,basisCardinality), Kokkos::ALL());
290 basis1_->getDofCoords(dofCoords1);
291 basis2_->getDofCoords(dofCoords2);
// Fills dofCoeffs for the combined basis. Entries [0, basisCardinality1) along the first
// dimension are filled by basis1_, and entries [basisCardinality1, basisCardinality) by
// basis2_; each basis writes into its own subview.
// NOTE(review): the signature line is not present in this listing.
306 const int basisCardinality1 = basis1_->getCardinality();
307 const int basisCardinality2 = basis2_->getCardinality();
308 const int basisCardinality = basisCardinality1 + basisCardinality2;
// Split the first dimension into the two field ranges, keeping the second dimension whole.
310 auto dofCoeffs1 = Kokkos::subview(dofCoeffs, std::make_pair(0,basisCardinality1), Kokkos::ALL());
311 auto dofCoeffs2 = Kokkos::subview(dofCoeffs, std::make_pair(basisCardinality1,basisCardinality), Kokkos::ALL());
313 basis1_->getDofCoeffs(dofCoeffs1);
314 basis2_->getDofCoeffs(dofCoeffs2);
// Returns the C string held by name_, which the constructor sets to
// "<basis1 name> + <basis2 name>".
325 return name_.c_str();
// Using-declaration: makes BasisBase's getValues overloads visible in this class's scope
// alongside the overrides declared here.
331 using BasisBase::getValues;
// Tail of the parameter list and part of the body of a getValues override. Fields
// [0, numFields1) belong to basis1_ and the following numFields2 fields to basis2_; each
// basis is evaluated at the same inputPoints with the same operatorType.
// NOTE(review): the leading parameters and the definitions of basisValues1/basisValues2
// are not present in this listing -- presumably they are the sub-containers for the two
// field ranges computed below; confirm against the original file.
348 const EOperator operatorType = OPERATOR_VALUE )
const override
350 const int fieldStartOrdinal1 = 0;
351 const int numFields1 = basis1_->getCardinality();
352 const int fieldStartOrdinal2 = numFields1;
353 const int numFields2 = basis2_->getCardinality();
358 basis1_->getValues(basisValues1, inputPoints, operatorType);
359 basis2_->getValues(basisValues2, inputPoints, operatorType);
// Evaluates the direct-sum basis into outputValues.
//   outputValues  destination; along its first dimension, [0, cardinality1) is filled by
//                 basis1_ and [cardinality1, cardinality1+cardinality2) by basis2_.
//   inputPoints   passed unchanged to both bases.
//   operatorType  passed unchanged to both bases; defaults to OPERATOR_VALUE.
// outputValues of rank 2 through 7 is handled. The final check hands `true`,
// std::invalid_argument and "Unsupported outputValues rank" to INTREPID2_TEST_FOR_EXCEPTION.
// NOTE(review): braces are absent from this listing, and so is the `else` that presumably
// guards the final check -- confirm against the original file.
380 virtual void getValues( OutputViewType outputValues,
const PointViewType inputPoints,
381 const EOperator operatorType = OPERATOR_VALUE )
const override
383 int cardinality1 = basis1_->getCardinality();
384 int cardinality2 = basis2_->getCardinality();
// First-dimension ranges owned by basis1_ and basis2_ respectively.
386 auto range1 = std::make_pair(0,cardinality1);
387 auto range2 = std::make_pair(cardinality1,cardinality1+cardinality2);
// Each branch below repeats the same steps for one rank: Kokkos::subview receives the
// field range plus one Kokkos::ALL() per remaining dimension, then each basis fills its
// own subview.
388 if (outputValues.rank() == 2)
390 auto outputValues1 = Kokkos::subview(outputValues, range1, Kokkos::ALL());
391 auto outputValues2 = Kokkos::subview(outputValues, range2, Kokkos::ALL());
393 basis1_->getValues(outputValues1, inputPoints, operatorType);
394 basis2_->getValues(outputValues2, inputPoints, operatorType);
396 else if (outputValues.rank() == 3)
398 auto outputValues1 = Kokkos::subview(outputValues, range1, Kokkos::ALL(), Kokkos::ALL());
399 auto outputValues2 = Kokkos::subview(outputValues, range2, Kokkos::ALL(), Kokkos::ALL());
401 basis1_->getValues(outputValues1, inputPoints, operatorType);
402 basis2_->getValues(outputValues2, inputPoints, operatorType);
404 else if (outputValues.rank() == 4)
406 auto outputValues1 = Kokkos::subview(outputValues, range1, Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL());
407 auto outputValues2 = Kokkos::subview(outputValues, range2, Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL());
409 basis1_->getValues(outputValues1, inputPoints, operatorType);
410 basis2_->getValues(outputValues2, inputPoints, operatorType);
412 else if (outputValues.rank() == 5)
414 auto outputValues1 = Kokkos::subview(outputValues, range1, Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL());
415 auto outputValues2 = Kokkos::subview(outputValues, range2, Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL());
417 basis1_->getValues(outputValues1, inputPoints, operatorType);
418 basis2_->getValues(outputValues2, inputPoints, operatorType);
420 else if (outputValues.rank() == 6)
422 auto outputValues1 = Kokkos::subview(outputValues, range1, Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL());
423 auto outputValues2 = Kokkos::subview(outputValues, range2, Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL());
425 basis1_->getValues(outputValues1, inputPoints, operatorType);
426 basis2_->getValues(outputValues2, inputPoints, operatorType);
428 else if (outputValues.rank() == 7)
430 auto outputValues1 = Kokkos::subview(outputValues, range1, Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL());
431 auto outputValues2 = Kokkos::subview(outputValues, range2, Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL());
433 basis1_->getValues(outputValues1, inputPoints, operatorType);
434 basis2_->getValues(outputValues2, inputPoints, operatorType);
// Any other rank is reported as an error.
438 INTREPID2_TEST_FOR_EXCEPTION(
true, std::invalid_argument,
"Unsupported outputValues rank");
// Returns basis1_'s number of tensorial extrusions. The constructor's compatibility check
// compares this value between basis1 and basis2.
442 virtual int getNumTensorialExtrusions()
const override
444 return basis1_->getNumTensorialExtrusions();
virtual void getValues(BasisValues< OutputValueType, DeviceType > outputValues, const TensorPoints< PointValueType, DeviceType > inputPoints, const EOperator operatorType=OPERATOR_VALUE) const override
Evaluation of a FEM basis on a reference cell, using point and output value containers that allow pre...
virtual void getDofCoords(ScalarViewType dofCoords) const override
Fills in spatial locations (coordinates) of degrees of freedom (nodes) on the reference cell...
virtual BasisValues< OutputValueType, DeviceType > allocateBasisValues(TensorPoints< PointValueType, DeviceType > points, const EOperator operatorType=OPERATOR_VALUE) const override
Allocate BasisValues container suitable for passing to the getValues() variant that takes a TensorPoints container as argument.
View-like interface to tensor points; point components are stored separately; the appropriate coordin...
const VectorDataType & vectorData() const
VectorData accessor.
BasisValues< Scalar, ExecSpaceType > basisValuesForFields(const int &fieldStartOrdinal, const int &numFields)
field start and length must align with families in vectorData_ or tensorDataFamilies_ (whichever is valid).
The data containers in Intrepid2 that support sum factorization and other reduced-data optimizations ...
virtual void getValues(OutputViewType outputValues, const PointViewType inputPoints, const EOperator operatorType=OPERATOR_VALUE) const override
Evaluation of a FEM basis on a reference cell.
A basis that is the direct sum of two other bases.
virtual const char * getName() const override
Returns basis name.
virtual void getDofCoeffs(ScalarViewType dofCoeffs) const override
Fills in coefficients of degrees of freedom for Lagrangian basis on the reference cell...
TensorDataType & tensorData()
TensorData accessor for single-family scalar data.
View-like interface to tensor data; tensor components are stored separately and multiplied together a...
Reference-space field values for a basis, designed to support typical vector-valued bases...
Basis_DirectSumBasis(BasisPtr basis1, BasisPtr basis2)
Constructor.