Rythmos - Transient Integration for Differential Equations  Version of the Day
 All Classes Functions Variables Typedefs Pages
Rythmos_AdjointModelEvaluator.hpp
1 //@HEADER
2 // ***********************************************************************
3 //
4 // Rythmos Package
5 // Copyright (2006) Sandia Corporation
6 //
7 // Under terms of Contract DE-AC04-94AL85000, there is a non-exclusive
8 // license for use of this work by or on behalf of the U.S. Government.
9 //
10 // This library is free software; you can redistribute it and/or modify
11 // it under the terms of the GNU Lesser General Public License as
12 // published by the Free Software Foundation; either version 2.1 of the
13 // License, or (at your option) any later version.
14 //
15 // This library is distributed in the hope that it will be useful, but
16 // WITHOUT ANY WARRANTY; without even the implied warranty of
17 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
18 // Lesser General Public License for more details.
19 //
20 // You should have received a copy of the GNU Lesser General Public
21 // License along with this library; if not, write to the Free Software
22 // Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301
23 // USA
24 // Questions? Contact Todd S. Coffey (tscoffe@sandia.gov)
25 //
26 // ***********************************************************************
27 //@HEADER
28 
29 #ifndef RYTHMOS_ADJOINT_MODEL_EVALUATOR_HPP
30 #define RYTHMOS_ADJOINT_MODEL_EVALUATOR_HPP
31 
32 
33 #include "Rythmos_IntegratorBase.hpp"
34 #include "Thyra_ModelEvaluator.hpp" // Interface
35 #include "Thyra_StateFuncModelEvaluatorBase.hpp" // Implementation
36 #include "Thyra_ModelEvaluatorDelegatorBase.hpp"
37 #include "Thyra_DefaultScaledAdjointLinearOp.hpp"
38 #include "Thyra_DefaultAdjointLinearOpWithSolve.hpp"
39 #include "Thyra_VectorStdOps.hpp"
40 #include "Thyra_MultiVectorStdOps.hpp"
41 #include "Teuchos_implicit_cast.hpp"
42 #include "Teuchos_Assert.hpp"
43 
44 
45 namespace Rythmos {
46 
47 
172 template<class Scalar>
174  : virtual public Thyra::StateFuncModelEvaluatorBase<Scalar>
175 {
176 public:
177 
180 
183 
185  void setFwdStateModel(
186  const RCP<const Thyra::ModelEvaluator<Scalar> > &fwdStateModel,
187  const Thyra::ModelEvaluatorBase::InArgs<Scalar> &basePoint );
188 
192  void setFwdTimeRange( const TimeRange<Scalar> &fwdTimeRange );
193 
207  const RCP<const InterpolationBufferBase<Scalar> > &fwdStateSolutionBuffer );
208 
210 
213 
215  RCP<const Thyra::VectorSpaceBase<Scalar> > get_x_space() const;
217  RCP<const Thyra::VectorSpaceBase<Scalar> > get_f_space() const;
219  Thyra::ModelEvaluatorBase::InArgs<Scalar> getNominalValues() const;
221  RCP<Thyra::LinearOpWithSolveBase<Scalar> > create_W() const;
223  RCP<Thyra::LinearOpBase<Scalar> > create_W_op() const;
225  Thyra::ModelEvaluatorBase::InArgs<Scalar> createInArgs() const;
226 
228 
229 private:
230 
233 
235  Thyra::ModelEvaluatorBase::OutArgs<Scalar> createOutArgsImpl() const;
237  void evalModelImpl(
238  const Thyra::ModelEvaluatorBase::InArgs<Scalar> &inArgs_bar,
239  const Thyra::ModelEvaluatorBase::OutArgs<Scalar> &outArgs_bar
240  ) const;
241 
243 
244 private:
245 
246  // /////////////////////////
247  // Private data members
248 
249  RCP<const Thyra::ModelEvaluator<Scalar> > fwdStateModel_;
250  Thyra::ModelEvaluatorBase::InArgs<Scalar> basePoint_;
251  TimeRange<Scalar> fwdTimeRange_;
252  RCP<const InterpolationBufferBase<Scalar> > fwdStateSolutionBuffer_;
253 
254  mutable bool isInitialized_;
255  mutable Thyra::ModelEvaluatorBase::InArgs<Scalar> prototypeInArgs_bar_;
256  mutable Thyra::ModelEvaluatorBase::OutArgs<Scalar> prototypeOutArgs_bar_;
257  mutable Thyra::ModelEvaluatorBase::InArgs<Scalar> adjointNominalValues_;
258  mutable RCP<Thyra::LinearOpBase<Scalar> > my_W_bar_adj_op_;
259  mutable RCP<Thyra::LinearOpBase<Scalar> > my_d_f_d_x_dot_op_;
260 
261  // /////////////////////////
262  // Private member functions
263 
264  // Just-in-time initialization function
265  void initialize() const;
266 
267 };
268 
269 
274 template<class Scalar>
275 RCP<AdjointModelEvaluator<Scalar> >
277  const RCP<const Thyra::ModelEvaluator<Scalar> > &fwdStateModel,
278  const TimeRange<Scalar> &fwdTimeRange
279  )
280 {
281  RCP<AdjointModelEvaluator<Scalar> >
282  adjointModel = Teuchos::rcp(new AdjointModelEvaluator<Scalar>);
283  adjointModel->setFwdStateModel(fwdStateModel, fwdStateModel->getNominalValues());
284  adjointModel->setFwdTimeRange(fwdTimeRange);
285  return adjointModel;
286 }
287 
288 
289 // /////////////////////////////////
290 // Implementations
291 
292 
293 // Constructors/Initializers/Accessors
294 
295 
296 template<class Scalar>
298  :isInitialized_(false)
299 {}
300 
301 
302 template<class Scalar>
304  const RCP<const Thyra::ModelEvaluator<Scalar> > &fwdStateModel,
305  const Thyra::ModelEvaluatorBase::InArgs<Scalar> &basePoint
306  )
307 {
308  TEUCHOS_TEST_FOR_EXCEPT(is_null(fwdStateModel));
309  fwdStateModel_ = fwdStateModel;
310  basePoint_ = basePoint;
311  isInitialized_ = false;
312 }
313 
314 
315 template<class Scalar>
317  const TimeRange<Scalar> &fwdTimeRange )
318 {
319  fwdTimeRange_ = fwdTimeRange;
320 }
321 
322 
323 template<class Scalar>
325  const RCP<const InterpolationBufferBase<Scalar> > &fwdStateSolutionBuffer )
326 {
327  TEUCHOS_TEST_FOR_EXCEPT(is_null(fwdStateSolutionBuffer));
328  fwdStateSolutionBuffer_ = fwdStateSolutionBuffer;
329 }
330 
331 
332 // Public functions overridden from ModelEvaluator
333 
334 
335 template<class Scalar>
336 RCP<const Thyra::VectorSpaceBase<Scalar> >
338 {
339  initialize();
340  return fwdStateModel_->get_f_space();
341 }
342 
343 
344 template<class Scalar>
345 RCP<const Thyra::VectorSpaceBase<Scalar> >
347 {
348  initialize();
349  return fwdStateModel_->get_x_space();
350 }
351 
352 
353 template<class Scalar>
354 Thyra::ModelEvaluatorBase::InArgs<Scalar>
356 {
357  initialize();
358  return adjointNominalValues_;
359 }
360 
361 
362 template<class Scalar>
363 RCP<Thyra::LinearOpWithSolveBase<Scalar> >
365 {
366  initialize();
367  return Thyra::nonconstAdjointLows<Scalar>(fwdStateModel_->create_W());
368 }
369 
370 
371 template<class Scalar>
372 RCP<Thyra::LinearOpBase<Scalar> >
374 {
375  initialize();
376  return Thyra::nonconstAdjoint<Scalar>(fwdStateModel_->create_W_op());
377 }
378 
379 
380 template<class Scalar>
381 Thyra::ModelEvaluatorBase::InArgs<Scalar>
383 {
384  initialize();
385  return prototypeInArgs_bar_;
386 }
387 
388 
389 // Private functions overridden from ModelEvaluatorDefaultBase
390 
391 
392 template<class Scalar>
393 Thyra::ModelEvaluatorBase::OutArgs<Scalar>
395 {
396  initialize();
397  return prototypeOutArgs_bar_;
398 }
399 
400 
401 template<class Scalar>
402 void AdjointModelEvaluator<Scalar>::evalModelImpl(
403  const Thyra::ModelEvaluatorBase::InArgs<Scalar> &inArgs_bar,
404  const Thyra::ModelEvaluatorBase::OutArgs<Scalar> &outArgs_bar
405  ) const
406 {
407 
408  using Teuchos::rcp_dynamic_cast;
409  using Teuchos::describe;
410  typedef Teuchos::ScalarTraits<Scalar> ST;
411  typedef Thyra::ModelEvaluatorBase MEB;
412  typedef Thyra::DefaultScaledAdjointLinearOp<Scalar> DSALO;
413  typedef Thyra::DefaultAdjointLinearOpWithSolve<Scalar> DALOWS;
414  typedef Teuchos::VerboseObjectTempState<Thyra::ModelEvaluatorBase> VOTSME;
415 
416  //
417  // A) Header stuff
418  //
419 
420  THYRA_MODEL_EVALUATOR_DECORATOR_EVAL_MODEL_GEN_BEGIN(
421  "AdjointModelEvaluator", inArgs_bar, outArgs_bar, Teuchos::null );
422 
423  initialize();
424 
425  VOTSME fwdStateModel_outputTempState(fwdStateModel_,out,verbLevel);
426 
427  //const bool trace = includesVerbLevel(verbLevel, Teuchos::VERB_LOW);
428  const bool dumpAll = includesVerbLevel(localVerbLevel, Teuchos::VERB_EXTREME);
429 
430  //
431  // B) Unpack the input and output arguments to see what we have to compute
432  //
433 
434  // B.1) InArgs
435 
436  const Scalar t_bar = inArgs_bar.get_t();
437  const RCP<const Thyra::VectorBase<Scalar> >
438  lambda_rev_dot = inArgs_bar.get_x_dot().assert_not_null(), // x_bar_dot
439  lambda = inArgs_bar.get_x().assert_not_null(); // x_bar
440  const Scalar alpha_bar = inArgs_bar.get_alpha();
441  const Scalar beta_bar = inArgs_bar.get_beta();
442 
443  if (dumpAll) {
444  *out << "\nlambda_rev_dot = " << describe(*lambda_rev_dot, Teuchos::VERB_EXTREME);
445  *out << "\nlambda = " << describe(*lambda, Teuchos::VERB_EXTREME);
446  *out << "\nalpha_bar = " << alpha_bar << "\n";
447  *out << "\nbeta_bar = " << beta_bar << "\n";
448  }
449 
450  // B.2) OutArgs
451 
452  const RCP<Thyra::VectorBase<Scalar> > f_bar = outArgs_bar.get_f();
453 
454  RCP<DALOWS> W_bar;
455  if (outArgs_bar.supports(MEB::OUT_ARG_W))
456  W_bar = rcp_dynamic_cast<DALOWS>(outArgs_bar.get_W(), true);
457 
458  RCP<DSALO> W_bar_op;
459  if (outArgs_bar.supports(MEB::OUT_ARG_W_op))
460  W_bar_op = rcp_dynamic_cast<DSALO>(outArgs_bar.get_W_op(), true);
461 
462  if (dumpAll) {
463  if (!is_null(W_bar)) {
464  *out << "\nW_bar = " << describe(*W_bar, Teuchos::VERB_EXTREME);
465  }
466  if (!is_null(W_bar_op)) {
467  *out << "\nW_bar_op = " << describe(*W_bar_op, Teuchos::VERB_EXTREME);
468  }
469  }
470 
471  //
472  // C) Evaluate the needed quantities from the underlying forward Model
473  //
474 
475  MEB::InArgs<Scalar> fwdInArgs = fwdStateModel_->createInArgs();
476 
477  // C.1) Set the required input arguments
478 
479  fwdInArgs = basePoint_;
480 
481  if (!is_null(fwdStateSolutionBuffer_)) {
482  const Scalar t = fwdTimeRange_.length() - t_bar;
483  RCP<const Thyra::VectorBase<Scalar> > x, x_dot;
484  get_x_and_x_dot<Scalar>( *fwdStateSolutionBuffer_, t,
485  outArg(x), outArg(x_dot) );
486  fwdInArgs.set_x(x);
487  fwdInArgs.set_x_dot(x);
488  }
489  else {
490  // If we don't have an IB object to get the state from, we will assume
491  // that the problem is linear and, therefore, we can pass in any old value
492  // of x, x_dot, and t and get the W_bar_adj object that we need. For this
493  // purpose, we will assume the model's base point will do.
494 
495  // 2008/05/14: rabartl: ToDo: Implement real variable dependancy
496  // communication support to make sure that this is okay! If the model is
497  // really nonlinear we need to check for this and throw if the user did
498  // not set up a fwdStateSolutionBuffer object!
499  }
500 
501 
502  // C.2) Evaluate W_bar_adj if needed
503 
504  RCP<Thyra::LinearOpWithSolveBase<Scalar> > W_bar_adj;
505  RCP<Thyra::LinearOpBase<Scalar> > W_bar_adj_op;
506  {
507 
508  MEB::OutArgs<Scalar> fwdOutArgs = fwdStateModel_->createOutArgs();
509 
510  // Get or create W_bar_adj or W_bar_adj_op if needed
511  if (!is_null(W_bar)) {
512  // If we have W_bar, the W_bar_adj was already created in
513  // this->create_W()
514  W_bar_adj = W_bar->getNonconstOp();
515  W_bar_adj_op = W_bar_adj;
516  }
517  else if (!is_null(W_bar_op)) {
518  // If we have W_bar_op, the W_bar_adj_op was already created in
519  // this->create_W_op()
520  W_bar_adj_op = W_bar_op->getNonconstOp();
521  }
522  else if (!is_null(f_bar)) {
523  TEUCHOS_TEST_FOR_EXCEPT_MSG(true, "ToDo: Unit test this code!");
524  // If the user did not pass in W_bar or W_bar_op, then we need to create
525  // our own local LOB form W_bar_adj_op of W_bar_adj in order to evaluate
526  // the residual f_bar
527  if (is_null(my_W_bar_adj_op_)) {
528  my_W_bar_adj_op_ = fwdStateModel_->create_W_op();
529  }
530  W_bar_adj_op = my_W_bar_adj_op_;
531  }
532 
533  // Set W_bar_adj or W_bar_adj_op on the OutArgs object
534  if (!is_null(W_bar_adj)) {
535  fwdOutArgs.set_W(W_bar_adj);
536  }
537  else if (!is_null(W_bar_adj_op)) {
538  fwdOutArgs.set_W_op(W_bar_adj_op);
539  }
540 
541  // Set alpha and beta on OutArgs object
542  if (!is_null(W_bar_adj) || !is_null(W_bar_adj_op)) {
543  fwdInArgs.set_alpha(alpha_bar);
544  fwdInArgs.set_beta(beta_bar);
545  }
546 
547  // Evaluate the model
548  if (!is_null(W_bar_adj) || !is_null(W_bar_adj_op)) {
549  fwdStateModel_->evalModel( fwdInArgs, fwdOutArgs );
550  }
551 
552  // Print the objects if requested
553  if (!is_null(W_bar_adj) && dumpAll)
554  *out << "\nW_bar_adj = " << describe(*W_bar_adj, Teuchos::VERB_EXTREME);
555  if (!is_null(W_bar_adj_op) && dumpAll)
556  *out << "\nW_bar_adj_op = " << describe(*W_bar_adj_op, Teuchos::VERB_EXTREME);
557 
558  }
559 
560  // C.3) Evaluate d(f)/d(x_dot) if needed
561 
562  RCP<Thyra::LinearOpBase<Scalar> > d_f_d_x_dot_op;
563  if (!is_null(f_bar)) {
564  if (is_null(my_d_f_d_x_dot_op_)) {
565  my_d_f_d_x_dot_op_ = fwdStateModel_->create_W_op();
566  }
567  d_f_d_x_dot_op = my_d_f_d_x_dot_op_;
568  MEB::OutArgs<Scalar> fwdOutArgs = fwdStateModel_->createOutArgs();
569  fwdOutArgs.set_W_op(d_f_d_x_dot_op);
570  fwdInArgs.set_alpha(ST::one());
571  fwdInArgs.set_beta(ST::zero());
572  fwdStateModel_->evalModel( fwdInArgs, fwdOutArgs );
573  if (dumpAll) {
574  *out << "\nd_f_d_x_dot_op = " << describe(*d_f_d_x_dot_op, Teuchos::VERB_EXTREME);
575  }
576  }
577 
578  //
579  // D) Evaluate the adjoint equation residual:
580  //
581  // f_bar = d(f)/d(x_dot)^T * lambda_hat + 1/beta_bar * W_bar_adj^T * lambda
582  // - d(g)/d(x)^T
583  //
584 
585  if (!is_null(f_bar)) {
586 
587  // D.1) lambda_hat = lambda_rev_dot - alpha_bar/beta_bar * lambda
588  const RCP<Thyra::VectorBase<Scalar> >
589  lambda_hat = createMember(lambda_rev_dot->space());
590  Thyra::V_VpStV<Scalar>( outArg(*lambda_hat),
591  *lambda_rev_dot, -alpha_bar/beta_bar, *lambda );
592  if (dumpAll)
593  *out << "\nlambda_hat = " << describe(*lambda_hat, Teuchos::VERB_EXTREME);
594 
595  // D.2) f_bar = d(f)/d(x_dot)^T * lambda_hat
596  Thyra::apply<Scalar>( *d_f_d_x_dot_op, Thyra::CONJTRANS, *lambda_hat,
597  outArg(*f_bar) );
598 
599  // D.3) f_bar += 1/beta_bar * W_bar_adj^T * lambda
600  Thyra::apply<Scalar>( *W_bar_adj_op, Thyra::CONJTRANS, *lambda,
601  outArg(*f_bar), 1.0/beta_bar, ST::one() );
602 
603  // D.4) f_bar += - d(g)/d(x)^T
604  // 2008/05/15: rabart: ToDo: Implement once we add support for
605  // distributed response functions
606 
607  if (dumpAll)
608  *out << "\nf_bar = " << describe(*f_bar, Teuchos::VERB_EXTREME);
609 
610  }
611 
612  if (dumpAll) {
613  if (!is_null(W_bar)) {
614  *out << "\nW_bar = " << describe(*W_bar, Teuchos::VERB_EXTREME);
615  }
616  if (!is_null(W_bar_op)) {
617  *out << "\nW_bar_op = " << describe(*W_bar_op, Teuchos::VERB_EXTREME);
618  }
619  }
620 
621 
622  //
623  // E) Do any remaining post processing
624  //
625 
626  THYRA_MODEL_EVALUATOR_DECORATOR_EVAL_MODEL_END();
627 
628 }
629 
630 
631 // private
632 
633 
634 template<class Scalar>
635 void AdjointModelEvaluator<Scalar>::initialize() const
636 {
637 
638  typedef Thyra::ModelEvaluatorBase MEB;
639 
640  if (isInitialized_)
641  return;
642 
643  //
644  // A) Validate the that forward Model is of the correct form!
645  //
646 
647  MEB::InArgs<Scalar> fwdStateModelInArgs = fwdStateModel_->createInArgs();
648  MEB::OutArgs<Scalar> fwdStateModelOutArgs = fwdStateModel_->createOutArgs();
649 
650 #ifdef HAVE_RYTHMOS_DEBUG
651  TEUCHOS_ASSERT( fwdStateModelInArgs.supports(MEB::IN_ARG_x_dot) );
652  TEUCHOS_ASSERT( fwdStateModelInArgs.supports(MEB::IN_ARG_x) );
653  TEUCHOS_ASSERT( fwdStateModelInArgs.supports(MEB::IN_ARG_t) );
654  TEUCHOS_ASSERT( fwdStateModelInArgs.supports(MEB::IN_ARG_alpha) );
655  TEUCHOS_ASSERT( fwdStateModelInArgs.supports(MEB::IN_ARG_beta) );
656  TEUCHOS_ASSERT( fwdStateModelOutArgs.supports(MEB::OUT_ARG_f) );
657  TEUCHOS_ASSERT( fwdStateModelOutArgs.supports(MEB::OUT_ARG_W) );
658 #endif
659 
660  //
661  // B) Set up the prototypical InArgs and OutArgs
662  //
663 
664  {
665  MEB::InArgsSetup<Scalar> inArgs_bar;
666  inArgs_bar.setModelEvalDescription(this->description());
667  inArgs_bar.setSupports( MEB::IN_ARG_x_dot );
668  inArgs_bar.setSupports( MEB::IN_ARG_x );
669  inArgs_bar.setSupports( MEB::IN_ARG_t );
670  inArgs_bar.setSupports( MEB::IN_ARG_alpha );
671  inArgs_bar.setSupports( MEB::IN_ARG_beta );
672  prototypeInArgs_bar_ = inArgs_bar;
673  }
674 
675  {
676  MEB::OutArgsSetup<Scalar> outArgs_bar;
677  outArgs_bar.setModelEvalDescription(this->description());
678  outArgs_bar.setSupports(MEB::OUT_ARG_f);
679  if (fwdStateModelOutArgs.supports(MEB::OUT_ARG_W) ) {
680  outArgs_bar.setSupports(MEB::OUT_ARG_W);
681  outArgs_bar.set_W_properties(fwdStateModelOutArgs.get_W_properties());
682  }
683  if (fwdStateModelOutArgs.supports(MEB::OUT_ARG_W_op) ) {
684  outArgs_bar.setSupports(MEB::OUT_ARG_W_op);
685  outArgs_bar.set_W_properties(fwdStateModelOutArgs.get_W_properties());
686  }
687  prototypeOutArgs_bar_ = outArgs_bar;
688  }
689 
690  //
691  // D) Set up the nominal values for the adjoint
692  //
693 
694  // Copy structure
695  adjointNominalValues_ = prototypeInArgs_bar_;
696  // Just set a zero initial condition for the adjoint
697  const RCP<Thyra::VectorBase<Scalar> > zero_lambda_vec =
698  createMember(fwdStateModel_->get_f_space());
699  V_S( zero_lambda_vec.ptr(), ScalarTraits<Scalar>::zero() );
700  adjointNominalValues_.set_x_dot(zero_lambda_vec);
701  adjointNominalValues_.set_x(zero_lambda_vec);
702 
703  //
704  // E) Wipe out other cached objects
705  //
706 
707  my_W_bar_adj_op_ = Teuchos::null;
708  my_d_f_d_x_dot_op_ = Teuchos::null;
709 
710  //
711  // F) We are initialized!
712  //
713 
714  isInitialized_ = true;
715 
716 }
717 
718 
719 } // namespace Rythmos
720 
721 
722 #endif // RYTHMOS_ADJOINT_MODEL_EVALUATOR_HPP
RCP< const Thyra::VectorSpaceBase< Scalar > > get_f_space() const
Standard concrete adjoint ModelEvaluator for time-constant mass matrix models.
RCP< Thyra::LinearOpBase< Scalar > > create_W_op() const
Thyra::ModelEvaluatorBase::InArgs< Scalar > getNominalValues() const
RCP< Thyra::LinearOpWithSolveBase< Scalar > > create_W() const
Base class for an interpolation buffer.
RCP< AdjointModelEvaluator< Scalar > > adjointModelEvaluator(const RCP< const Thyra::ModelEvaluator< Scalar > > &fwdStateModel, const TimeRange< Scalar > &fwdTimeRange)
Nonmember constructor.
void setFwdStateSolutionBuffer(const RCP< const InterpolationBufferBase< Scalar > > &fwdStateSolutionBuffer)
Set the interpolation buffer that will return values of the state solution x and x_dot at various poi...
Thyra::ModelEvaluatorBase::InArgs< Scalar > createInArgs() const
void setFwdStateModel(const RCP< const Thyra::ModelEvaluator< Scalar > > &fwdStateModel, const Thyra::ModelEvaluatorBase::InArgs< Scalar > &basePoint)
Set the underlying forward model and base point.
void setFwdTimeRange(const TimeRange< Scalar > &fwdTimeRange)
Set the forward time range that this adjoint model will be defined over.
RCP< const Thyra::VectorSpaceBase< Scalar > > get_x_space() const