10 #ifndef BELOS_PSEUDO_BLOCK_CG_ITER_MP_VECTOR_HPP
11 #define BELOS_PSEUDO_BLOCK_CG_ITER_MP_VECTOR_HPP
22 #include "BelosPseudoBlockCGIter.hpp"
41 template<
class Storage,
class MV,
class OP>
43 virtual public CGIteration<Sacado::MP::Vector<Storage>, MV, OP> {
// Shorthand for the multivector traits class instantiated on this
// specialization's ScalarType and the MV multivector type; used below for
// calls such as MVT::Clone, MVT::MvDot and MVT::MvAddMv.
51 typedef MultiVecTraits<ScalarType,MV>
MVT;
// Shorthand for the operator traits class instantiated on ScalarType, MV and
// the OP operator type.
52 typedef OperatorTraits<ScalarType,MV,OP>
OPT;
67 const Teuchos::RCP<StatusTest<ScalarType,MV,OP> > &tester,
// Initialize the CG iteration from the state supplied in `newstate`.
// NOTE(review): only part of the out-of-line definition is visible in this
// listing; confirm the exact requirements on `newstate` (e.g. newstate.R)
// against the full definition.
113 void initializeCG(CGIterationState<ScalarType,MV>& newstate);
120 CGIterationState<ScalarType,MV> empty;
132 CGIterationState<ScalarType,MV> state;
// Get a constant reference to the linear problem.
// Dereferences the lp_ handle held by this iteration object; the caller
// receives a reference, not a copy.
167 const LinearProblem<ScalarType,MV,OP>&
getProblem()
const {
return *lp_; }
175 "Belos::PseudoBlockCGIter::setBlockSize(): Cannot use a block size that is not one.");
192 if (static_cast<size_type> (iter_) >= diag_.size ()) {
195 return diag_ (0, iter_);
207 if (static_cast<size_type> (iter_) >= offdiag_.size ()) {
210 return offdiag_ (0, iter_);
270 template<
class StorageType,
class MV,
class OP>
273 const Teuchos::RCP<OutputManager<ScalarType> > &printer,
274 const Teuchos::RCP<StatusTest<ScalarType,MV,OP> > &tester,
282 assertPositiveDefiniteness_( params.get(
"Assert Positive Definiteness",
true) ),
283 numEntriesForCondEst_(params.get(
"Max Size For Condest",0) ),
284 breakDownTol_(params.get(
"Ensemble Breakdown Tolerance", 0.0))
291 template <
class StorageType,
class MV,
class OP>
292 void PseudoBlockCGIter<Sacado::MP::Vector<StorageType>,MV,OP>::
293 initializeCG(CGIterationState<ScalarType,MV>& newstate)
299 "Belos::PseudoBlockCGIter::initialize(): Cannot initialize state storage!");
305 int numRHS = MVT::GetNumberVecs(*tmp);
311 R_ = MVT::Clone( *tmp, numRHS_ );
312 Z_ = MVT::Clone( *tmp, numRHS_ );
313 P_ = MVT::Clone( *tmp, numRHS_ );
314 AP_ = MVT::Clone( *tmp, numRHS_ );
318 if(numEntriesForCondEst_ > 0) {
319 diag_.resize(numEntriesForCondEst_);
320 offdiag_.resize(numEntriesForCondEst_-1);
325 std::string errstr(
"Belos::BlockPseudoCGIter::initialize(): Specified multivectors must have a consistent length and width.");
334 std::invalid_argument, errstr );
336 std::invalid_argument, errstr );
339 if (newstate.R != R_) {
341 MVT::MvAddMv( one, *newstate.R, zero, *newstate.R, *R_ );
348 lp_->applyLeftPrec( *R_, *Z_ );
351 lp_->applyRightPrec( *Z_, *tmp1 );
356 lp_->applyRightPrec( *R_, *Z_ );
361 MVT::MvAddMv( one, *Z_, zero, *Z_, *P_ );
366 "Belos::CGIter::initialize(): CGStateIterState does not have initial residual.");
376 template <
class StorageType,
class MV,
class OP>
377 void PseudoBlockCGIter<Sacado::MP::Vector<StorageType>,MV,OP>::
383 if (initialized_ ==
false) {
389 std::vector<int> index(1);
390 std::vector<ScalarType> rHz( numRHS_ ), rHz_old( numRHS_ ), pAp( numRHS_ );
398 ScalarType pAp_old=one, beta_old=one ,rHz_old2=one;
404 MVT::MvDot( *R_, *Z_, rHz );
406 if ( assertPositiveDefiniteness_ )
407 for (i=0; i<numRHS_; ++i)
410 "Belos::PseudoBlockCGIter::iterate(): negative value for r^H*M*r encountered!" );
415 while (stest_->checkStatus(
this) != Passed) {
421 lp_->applyOp( *P_, *AP_ );
424 MVT::MvDot( *P_, *AP_, pAp );
426 for (i=0; i<numRHS_; ++i) {
427 if ( assertPositiveDefiniteness_ )
431 "Belos::PseudoBlockCGIter::iterate(): non-positive value for p^H*A*p encountered!" );
433 const int ensemble_size = pAp[i].size();
434 for (
int k=0; k<ensemble_size; ++k) {
435 if (SVT::magnitude(pAp[i].coeff(k)) <= breakDownTol_)
436 alpha(i,i).fastAccessCoeff(k) = 0.0;
438 alpha(i,i).fastAccessCoeff(k) = rHz[i].coeff(k) / pAp[i].coeff(k);
445 MVT::MvTimesMatAddMv( one, *P_, alpha, one, *cur_soln_vec );
446 lp_->updateSolution();
450 for (i=0; i<numRHS_; ++i) {
456 MVT::MvTimesMatAddMv( -one, *AP_, alpha, one, *R_ );
462 lp_->applyLeftPrec( *R_, *Z_ );
465 lp_->applyRightPrec( *Z_, *tmp );
470 lp_->applyRightPrec( *R_, *Z_ );
476 MVT::MvDot( *R_, *Z_, rHz );
477 if ( assertPositiveDefiniteness_ )
478 for (i=0; i<numRHS_; ++i)
481 "Belos::PseudoBlockCGIter::iterate(): negative value for r^H*M*r encountered!" );
484 for (i=0; i<numRHS_; ++i) {
485 const int ensemble_size = rHz[i].size();
486 for (
int k=0; k<ensemble_size; ++k) {
487 if (SVT::magnitude(rHz[i].coeff(k)) <= breakDownTol_)
488 beta(i,i).fastAccessCoeff(k) = 0.0;
490 beta(i,i).fastAccessCoeff(k) = rHz[i].coeff(k) / rHz_old[i].coeff(k);
495 MVT::MvAddMv( one, *Z_i, beta(i,i), *P_i, *P_i );
507 rHz_old2 = rHz_old[0];
508 beta_old = beta(0,0);
KOKKOS_INLINE_FUNCTION PCE< Storage > sqrt(const PCE< Storage > &a)
bool assertPositiveDefiniteness_
Teuchos::ScalarTraits< typename Storage::value_type > SVT
bool is_null(const std::shared_ptr< T > &p)
#define TEUCHOS_TEST_FOR_EXCEPTION(throw_exception_test, Exception, msg)
static magnitudeType real(T a)
const Teuchos::RCP< LinearProblem< ScalarType, MV, OP > > lp_
OperatorTraits< ScalarType, MV, OP > OPT
void initialize()
Initialize the solver with the initial vectors from the linear problem or random data.
Teuchos::RCP< const MV > getNativeResiduals(std::vector< MagnitudeType > *norms) const
Teuchos::RCP< MV > getCurrentUpdate() const
Get the current update to the linear system.
virtual ~PseudoBlockCGIter()
Destructor.
This class implements the pseudo-block CG iteration, where the basic CG algorithm is performed on all...
Teuchos::ScalarTraits< ScalarType > SCT
void setBlockSize(int blockSize)
Set the blocksize.
int getNumIters() const
Get the current iteration count.
void resetNumIters(int iter=0)
Reset the iteration count.
Teuchos::ArrayView< MagnitudeType > getDiag()
Gets the diagonal for condition estimation.
MultiVecTraits< ScalarType, MV > MVT
int getBlockSize() const
Get the blocksize to be used by the iterative solver in solving this linear problem.
int numEntriesForCondEst_
Teuchos::ArrayView< MagnitudeType > getOffDiag()
Gets the off-diagonal for condition estimation.
SCT::magnitudeType MagnitudeType
void setDoCondEst(bool val)
Sets whether or not to store the diagonal for condition estimation.
Sacado::MP::Vector< Storage > ScalarType
CGIterationState< ScalarType, MV > getState() const
Get the current state of the linear solver.
bool isInitialized()
States whether the solver has been initialized or not.
const LinearProblem< ScalarType, MV, OP > & getProblem() const
Get a constant reference to the linear problem.
SVT::magnitudeType breakDownTol_
const Teuchos::RCP< OutputManager< ScalarType > > om_
Teuchos::ArrayRCP< MagnitudeType > offdiag_
const Teuchos::RCP< StatusTest< ScalarType, MV, OP > > stest_