ROL
ROL_FDivergence.hpp
Go to the documentation of this file.
1 // @HEADER
2 // *****************************************************************************
3 // Rapid Optimization Library (ROL) Package
4 //
5 // Copyright 2014 NTESS and the ROL contributors.
6 // SPDX-License-Identifier: BSD-3-Clause
7 // *****************************************************************************
8 // @HEADER
9 
10 #ifndef ROL_FDIVERGENCE_HPP
11 #define ROL_FDIVERGENCE_HPP
12 
14 #include "ROL_Types.hpp"
15 
50 namespace ROL {
51 
52 template<class Real>
53 class FDivergence : public RandVarFunctional<Real> {
54 private:
55  Real thresh_;
56 
57  Real valLam_;
58  Real valLam2_;
59  Real valMu_;
60  Real valMu2_;
61 
67 
70 
75 
76  void checkInputs(void) const {
77  Real zero(0);
78  ROL_TEST_FOR_EXCEPTION((thresh_ <= zero), std::invalid_argument,
79  ">>> ERROR (ROL::FDivergence): Threshold must be positive!");
80  }
81 
82 public:
87  FDivergence(const Real thresh) : RandVarFunctional<Real>(), thresh_(thresh),
88  valLam_(0),valLam2_(0), valMu_(0), valMu2_(0) {
89  checkInputs();
90  }
91 
100  FDivergence(ROL::ParameterList &parlist) : RandVarFunctional<Real>(),
101  valLam_(0),valLam2_(0), valMu_(0), valMu2_(0) {
102  ROL::ParameterList &list
103  = parlist.sublist("SOL").sublist("Risk Measure").sublist("F-Divergence");
104  thresh_ = list.get<Real>("Threshold");
105  checkInputs();
106  }
107 
115  virtual Real Fprimal(Real x, int deriv = 0) const = 0;
116 
129  virtual Real Fdual(Real x, int deriv = 0) const = 0;
130 
131  bool check(std::ostream &outStream = std::cout) const {
132  const Real tol(std::sqrt(ROL_EPSILON<Real>()));
133  bool flag = true;
134 
135  Real x = static_cast<Real>(rand())/static_cast<Real>(RAND_MAX);
136  Real t = static_cast<Real>(rand())/static_cast<Real>(RAND_MAX);
137  Real fp = Fprimal(x);
138  Real fd = Fdual(t);
139  outStream << "Check Fenchel-Young Inequality: F(x) + F*(t) >= xt" << std::endl;
140  outStream << "x = " << x << std::endl;
141  outStream << "t = " << t << std::endl;
142  outStream << "F(x) = " << fp << std::endl;
143  outStream << "F*(t) = " << fd << std::endl;
144  outStream << "Is Valid? " << (fp+fd >= x*t) << std::endl;
145  flag = (fp+fd >= x*t) ? flag : false;
146 
147  x = static_cast<Real>(rand())/static_cast<Real>(RAND_MAX);
148  t = Fprimal(x,1);
149  fp = Fprimal(x);
150  fd = Fdual(t);
151  outStream << "Check Fenchel-Young Equality: F(x) + F(t) = xt for t = d/dx F(x)" << std::endl;
152  outStream << "x = " << x << std::endl;
153  outStream << "t = " << t << std::endl;
154  outStream << "F(x) = " << fp << std::endl;
155  outStream << "F*(t) = " << fd << std::endl;
156  outStream << "Is Valid? " << (std::abs(fp+fd - x*t)<=tol) << std::endl;
157  flag = (std::abs(fp+fd - x*t)<=tol) ? flag : false;
158 
159  t = static_cast<Real>(rand())/static_cast<Real>(RAND_MAX);
160  x = Fdual(t,1);
161  fp = Fprimal(x);
162  fd = Fdual(t);
163  outStream << "Check Fenchel-Young Equality: F(x) + F(t) = xt for x = d/dt F*(t)" << std::endl;
164  outStream << "x = " << x << std::endl;
165  outStream << "t = " << t << std::endl;
166  outStream << "F(x) = " << fp << std::endl;
167  outStream << "F*(t) = " << fd << std::endl;
168  outStream << "Is Valid? " << (std::abs(fp+fd - x*t)<=tol) << std::endl;
169  flag = (std::abs(fp+fd - x*t)<=tol) ? flag : false;
170 
171  return flag;
172  }
173 
174  void initialize(const Vector<Real> &x) {
176  valLam_ = 0; valLam2_ = 0; valMu_ = 0; valMu2_ = 0;
177  }
178 
179  // Value update and get functions
181  const Vector<Real> &x,
182  const std::vector<Real> &xstat,
183  Real &tol) {
184  Real val = computeValue(obj,x,tol);
185  Real xlam = xstat[0];
186  Real xmu = xstat[1];
187  Real r = Fdual((val-xmu)/xlam,0);
188  val_ += weight_ * r;
189  }
190 
191  Real getValue(const Vector<Real> &x,
192  const std::vector<Real> &xstat,
193  SampleGenerator<Real> &sampler) {
194  Real val(0);
195  sampler.sumAll(&val_,&val,1);
196  Real xlam = xstat[0];
197  Real xmu = xstat[1];
198  return xlam*(thresh_ + val) + xmu;
199  }
200 
201  // Gradient update and get functions
203  const Vector<Real> &x,
204  const std::vector<Real> &xstat,
205  Real &tol) {
206  Real val = computeValue(obj,x,tol);
207  Real xlam = xstat[0];
208  Real xmu = xstat[1];
209  Real inp = (val-xmu)/xlam;
210  Real r0 = Fdual(inp,0), r1 = Fdual(inp,1);
211 
212  if (std::abs(r0) >= ROL_EPSILON<Real>()) {
213  val_ += weight_ * r0;
214  }
215  if (std::abs(r1) >= ROL_EPSILON<Real>()) {
216  valLam_ -= weight_ * r1 * inp;
217  valMu_ -= weight_ * r1;
218  computeGradient(*dualVector_,obj,x,tol);
219  g_->axpy(weight_*r1,*dualVector_);
220  }
221  }
222 
224  std::vector<Real> &gstat,
225  const Vector<Real> &x,
226  const std::vector<Real> &xstat,
227  SampleGenerator<Real> &sampler) {
228  std::vector<Real> mygval(3), gval(3);
229  mygval[0] = val_;
230  mygval[1] = valLam_;
231  mygval[2] = valMu_;
232  sampler.sumAll(&mygval[0],&gval[0],3);
233 
234  gstat[0] = thresh_ + gval[0] + gval[1];
235  gstat[1] = static_cast<Real>(1) + gval[2];
236 
237  sampler.sumAll(*g_,g);
238  }
239 
241  const Vector<Real> &v,
242  const std::vector<Real> &vstat,
243  const Vector<Real> &x,
244  const std::vector<Real> &xstat,
245  Real &tol) {
246  Real val = computeValue(obj,x,tol);
247  Real xlam = xstat[0];
248  Real xmu = xstat[1];
249  Real vlam = vstat[0];
250  Real vmu = vstat[1];
251  Real inp = (val-xmu)/xlam;
252  Real r1 = Fdual(inp,1), r2 = Fdual(inp,2);
253  if (std::abs(r2) >= ROL_EPSILON<Real>()) {
254  Real gv = computeGradVec(*dualVector_,obj,v,x,tol);
255  val_ += weight_ * r2 * inp;
256  valLam_ += weight_ * r2 * inp * inp;
257  valLam2_ -= weight_ * r2 * gv * inp;
258  valMu_ += weight_ * r2;
259  valMu2_ -= weight_ * r2 * gv;
260  hv_->axpy(weight_ * r2 * (gv - vmu - vlam*inp)/xlam, *dualVector_);
261  }
262  if (std::abs(r1) >= ROL_EPSILON<Real>()) {
263  computeHessVec(*dualVector_,obj,v,x,tol);
264  hv_->axpy(weight_ * r1, *dualVector_);
265  }
266  }
267 
269  std::vector<Real> &hvstat,
270  const Vector<Real> &v,
271  const std::vector<Real> &vstat,
272  const Vector<Real> &x,
273  const std::vector<Real> &xstat,
274  SampleGenerator<Real> &sampler) {
275  std::vector<Real> myhval(5), hval(5);
276  myhval[0] = val_;
277  myhval[1] = valLam_;
278  myhval[2] = valLam2_;
279  myhval[3] = valMu_;
280  myhval[4] = valMu2_;
281  sampler.sumAll(&myhval[0],&hval[0],5);
282 
283  std::vector<Real> stat(2);
284  Real xlam = xstat[0];
285  //Real xmu = xstat[1];
286  Real vlam = vstat[0];
287  Real vmu = vstat[1];
288  hvstat[0] = (vlam * hval[1] + vmu * hval[0] + hval[2])/xlam;
289  hvstat[1] = (vlam * hval[0] + vmu * hval[3] + hval[4])/xlam;
290 
291  sampler.sumAll(*hv_,hv);
292  }
293 };
294 
295 }
296 
297 #endif
Provides the interface to evaluate objective functions.
void computeHessVec(Vector< Real > &hv, Objective< Real > &obj, const Vector< Real > &v, const Vector< Real > &x, Real &tol)
Ptr< Vector< Real > > g_
Real computeValue(Objective< Real > &obj, const Vector< Real > &x, Real &tol)
Ptr< Vector< Real > > hv_
virtual Real Fdual(Real x, int deriv=0) const =0
Implementation of the scalar dual F function.
Contains definitions of custom data types in ROL.
void initialize(const Vector< Real > &x)
Initialize temporary variables.
Ptr< Vector< Real > > dualVector_
virtual Real Fprimal(Real x, int deriv=0) const =0
Implementation of the scalar primal F function.
Defines the linear algebra or vector space interface.
Definition: ROL_Vector.hpp:46
void sumAll(Real *input, Real *output, int dim) const
Objective_SerialSimOpt(const Ptr< Obj > &obj, const V &ui) z0_ zero()
bool check(std::ostream &outStream=std::cout) const
void checkInputs(void) const
Real getValue(const Vector< Real > &x, const std::vector< Real > &xstat, SampleGenerator< Real > &sampler)
Return risk measure value.
void getGradient(Vector< Real > &g, std::vector< Real > &gstat, const Vector< Real > &x, const std::vector< Real > &xstat, SampleGenerator< Real > &sampler)
Return risk measure (sub)gradient.
void getHessVec(Vector< Real > &hv, std::vector< Real > &hvstat, const Vector< Real > &v, const std::vector< Real > &vstat, const Vector< Real > &x, const std::vector< Real > &xstat, SampleGenerator< Real > &sampler)
Return risk measure Hessian-times-a-vector.
void updateHessVec(Objective< Real > &obj, const Vector< Real > &v, const std::vector< Real > &vstat, const Vector< Real > &x, const std::vector< Real > &xstat, Real &tol)
Update internal risk measure storage for Hessian-times-a-vector computation.
void updateGradient(Objective< Real > &obj, const Vector< Real > &x, const std::vector< Real > &xstat, Real &tol)
Update internal risk measure storage for gradient computation.
FDivergence(const Real thresh)
Constructor.
void computeGradient(Vector< Real > &g, Objective< Real > &obj, const Vector< Real > &x, Real &tol)
FDivergence(ROL::ParameterList &parlist)
Constructor.
Provides a general interface for the F-divergence distributionally robust expectation.
Real computeGradVec(Vector< Real > &g, Objective< Real > &obj, const Vector< Real > &v, const Vector< Real > &x, Real &tol)
Provides the interface to implement any functional that maps a random variable to an (extended) real number.
void updateValue(Objective< Real > &obj, const Vector< Real > &x, const std::vector< Real > &xstat, Real &tol)
Update internal storage for value computation.
virtual void initialize(const Vector< Real > &x)
Initialize temporary variables.