ROL
ROL_KLDivergence.hpp
Go to the documentation of this file.
1 // @HEADER
2 // ************************************************************************
3 //
4 // Rapid Optimization Library (ROL) Package
5 // Copyright (2014) Sandia Corporation
6 //
7 // Under terms of Contract DE-AC04-94AL85000, there is a non-exclusive
8 // license for use of this work by or on behalf of the U.S. Government.
9 //
10 // Redistribution and use in source and binary forms, with or without
11 // modification, are permitted provided that the following conditions are
12 // met:
13 //
14 // 1. Redistributions of source code must retain the above copyright
15 // notice, this list of conditions and the following disclaimer.
16 //
17 // 2. Redistributions in binary form must reproduce the above copyright
18 // notice, this list of conditions and the following disclaimer in the
19 // documentation and/or other materials provided with the distribution.
20 //
21 // 3. Neither the name of the Corporation nor the names of the
22 // contributors may be used to endorse or promote products derived from
23 // this software without specific prior written permission.
24 //
25 // THIS SOFTWARE IS PROVIDED BY SANDIA CORPORATION "AS IS" AND ANY
26 // EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
27 // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
28 // PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL SANDIA CORPORATION OR THE
29 // CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
30 // EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
31 // PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
32 // PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
33 // LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
34 // NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
35 // SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
36 //
37 // Questions? Contact lead developers:
38 // Drew Kouri (dpkouri@sandia.gov) and
39 // Denis Ridzal (dridzal@sandia.gov)
40 //
41 // ************************************************************************
42 // @HEADER
43 
44 #ifndef ROL_KLDIVERGENCE_HPP
45 #define ROL_KLDIVERGENCE_HPP
46 
48 
78 namespace ROL {
79 
80 template<class Real>
81 class KLDivergence : public RandVarFunctional<Real> {
82 private:
83  Real eps_;
84 
85  Real gval_;
86  Real gvval_;
87  Real hval_;
88  ROL::Ptr<Vector<Real> > scaledGradient_;
89  ROL::Ptr<Vector<Real> > scaledHessVec_;
90 
92 
98 
101 
106 
107  void checkInputs(void) const {
108  Real zero(0);
109  ROL_TEST_FOR_EXCEPTION((eps_ <= zero), std::invalid_argument,
110  ">>> ERROR (ROL::KLDivergence): Threshold must be positive!");
111  }
112 
113 public:
118  KLDivergence(const Real eps = 1.e-2)
119  : RandVarFunctional<Real>(), eps_(eps), firstResetKLD_(true) {
120  checkInputs();
121  }
122 
131  KLDivergence(ROL::ParameterList &parlist)
132  : RandVarFunctional<Real>(), firstResetKLD_(true) {
133  ROL::ParameterList &list
134  = parlist.sublist("SOL").sublist("Risk Measure").sublist("KL Divergence");
135  eps_ = list.get<Real>("Threshold");
136  checkInputs();
137  }
138 
/** \brief Reset this class's accumulators before a new sampling pass.

    On the first call only, allocates scaledGradient_ and scaledHessVec_ as
    clones of x.dual().  On every call, zeroes gval_, gvval_, hval_ and both
    scaled vectors.

    NOTE(review): doxygen source line 140 is absent from this listing (the
    numbering jumps from 139 to 141).  RandVarFunctional<Real>'s virtual
    initialize(x) appears in this page's symbol index, so the missing line is
    presumably a call to it -- confirm against the upstream header.
*/
139  void initialize(const Vector<Real> &x) {
141  if ( firstResetKLD_ ) {
  // One-time allocation: the workspaces take the shape of the dual of x.
142  scaledGradient_ = x.dual().clone();
143  scaledHessVec_ = x.dual().clone();
144  firstResetKLD_ = false;
145  }
  // Per-pass reset of the scalar and vector accumulators used by the
  // gradient and Hessian-vector updates.
146  const Real zero(0);
147  gval_ = zero; gvval_ = zero; hval_ = zero;
148  scaledGradient_->zero(); scaledHessVec_->zero();
149  }
150 
152  const Vector<Real> &x,
153  const std::vector<Real> &xstat,
154  Real &tol) {
  // NOTE(review): the first signature line (doxygen line 151) is missing from
  // this listing; the page's symbol index gives the full signature as
  //   updateValue(Objective<Real> &obj, const Vector<Real> &x,
  //               const std::vector<Real> &xstat, Real &tol).
  //
  // Per sample: f = computeValue(obj,x,tol) and
  // e = exponential(f, xstat[0]*eps_), i.e. exp(xstat[0]*eps_*f) capped at
  // ROL_INF<Real>().  val_ accumulates weight_*e; getValue later reduces it
  // with sampler.sumAll and takes its logarithm.
155  Real val = computeValue(obj,x,tol);
156  Real ev = exponential(val,xstat[0]*eps_);
157  val_ += weight_ * ev;
158  }
159 
160  Real getValue(const Vector<Real> &x,
161  const std::vector<Real> &xstat,
162  SampleGenerator<Real> &sampler) {
163  if ( xstat[0] == static_cast<Real>(0) ) {
164  return ROL_INF<Real>();
165  }
166  Real ev(0);
167  sampler.sumAll(&val_,&ev,1);
168  return (static_cast<Real>(1) + std::log(ev)/eps_)/xstat[0];
169  }
170 
172  const Vector<Real> &x,
173  const std::vector<Real> &xstat,
174  Real &tol) {
  // NOTE(review): the first signature line (doxygen line 171) is missing from
  // this listing; the page's symbol index gives the full signature as
  //   updateGradient(Objective<Real> &obj, const Vector<Real> &x,
  //                  const std::vector<Real> &xstat, Real &tol).
  //
  // Per sample, with f = computeValue(...) and e = exponential(f, xstat[0]*eps_):
  //   val_  += weight_*e      (normalizer)
  //   gval_ += weight_*e*f    (needed for the xstat derivative)
  //   g_    += weight_*e*grad (grad presumably written into *dualVector_ by
  //                            computeGradient -- its output argument)
175  Real val = computeValue(obj,x,tol);
176  Real ev = exponential(val,xstat[0]*eps_);
177  val_ += weight_ * ev;
178  gval_ += weight_ * ev * val;
179  computeGradient(*dualVector_,obj,x,tol);
180  g_->axpy(weight_*ev,*dualVector_);
181  }
182 
184  std::vector<Real> &gstat,
185  const Vector<Real> &x,
186  const std::vector<Real> &xstat,
187  SampleGenerator<Real> &sampler) {
  // NOTE(review): the first signature line (doxygen line 183) is missing from
  // this listing; the page's symbol index gives the full signature as
  //   getGradient(Vector<Real> &g, std::vector<Real> &gstat,
  //               const Vector<Real> &x, const std::vector<Real> &xstat,
  //               SampleGenerator<Real> &sampler).
  //
  // Reduce both scalar accumulators in a single sumAll call:
  //   ev = sum of weight_*e,  egval = sum of weight_*e*f.
188  std::vector<Real> local(2), global(2);
189  local[0] = val_;
190  local[1] = gval_;
191  sampler.sumAll(&local[0],&global[0],2);
192  Real ev = global[0], egval = global[1];
193 
  // Gradient in x: the reduced g_ divided by ev, i.e. the e-weighted average
  // of the per-sample gradients.
194  sampler.sumAll(*g_,g);
195  g.scale(static_cast<Real>(1)/ev);
196 
  // Derivative of the getValue expression (1 + log(ev)/eps_)/xstat[0] with
  // respect to xstat[0]; ROL_INF<Real>() is reported at xstat[0] == 0.
197  if ( xstat[0] == static_cast<Real>(0) ) {
198  gstat[0] = ROL_INF<Real>();
199  }
200  else {
201  gstat[0] = -((static_cast<Real>(1) + std::log(ev)/eps_)/xstat[0]
202  - egval/ev)/xstat[0];
203  }
204  }
205 
207  const Vector<Real> &v,
208  const std::vector<Real> &vstat,
209  const Vector<Real> &x,
210  const std::vector<Real> &xstat,
211  Real &tol) {
  // NOTE(review): the first signature line (doxygen line 206) is missing from
  // this listing; the page's symbol index gives the full signature as
  //   updateHessVec(Objective<Real> &obj, const Vector<Real> &v,
  //                 const std::vector<Real> &vstat, const Vector<Real> &x,
  //                 const std::vector<Real> &xstat, Real &tol).
  //
  // Per sample, with f = computeValue(...), e = exponential(f, xstat[0]*eps_)
  // and gv = computeGradVec(...) (which presumably also leaves the sample
  // gradient in *dualVector_, as the axpy calls below use it that way):
  //   scalars: val_ += w*e, gv_ += w*e*gv, gval_ += w*e*f,
  //            gvval_ += w*e*f*gv, hval_ += w*e*f*f          (w = weight_)
  //   vectors: g_ += w*e*grad, scaledGradient_ += w*e*gv*grad,
  //            scaledHessVec_ += w*e*f*grad, hv_ += w*e*(Hessian times v)
212  Real val = computeValue(obj,x,tol);
213  Real ev = exponential(val,xstat[0]*eps_);
214  Real gv = computeGradVec(*dualVector_,obj,v,x,tol);
215  val_ += weight_ * ev;
216  gv_ += weight_ * ev * gv;
217  gval_ += weight_ * ev * val;
218  gvval_ += weight_ * ev * val * gv;
219  hval_ += weight_ * ev * val * val;
220  g_->axpy(weight_*ev,*dualVector_);
221  scaledGradient_->axpy(weight_*ev*gv,*dualVector_);
222  scaledHessVec_->axpy(weight_*ev*val,*dualVector_);
  // dualVector_ is reused: it now receives the Hessian-vector product.
223  computeHessVec(*dualVector_,obj,v,x,tol);
224  hv_->axpy(weight_*ev,*dualVector_);
225  }
226 
228  std::vector<Real> &hvstat,
229  const Vector<Real> &v,
230  const std::vector<Real> &vstat,
231  const Vector<Real> &x,
232  const std::vector<Real> &xstat,
233  SampleGenerator<Real> &sampler) {
  // NOTE(review): the first signature line (doxygen line 227) is missing from
  // this listing; the page's symbol index gives the full signature as
  //   getHessVec(Vector<Real> &hv, std::vector<Real> &hvstat,
  //              const Vector<Real> &v, const std::vector<Real> &vstat,
  //              const Vector<Real> &x, const std::vector<Real> &xstat,
  //              SampleGenerator<Real> &sampler).
  //
  // Reduce the five scalar accumulators from updateHessVec in one sumAll.
234  std::vector<Real> local(5), global(5);
235  local[0] = val_;
236  local[1] = gv_;
237  local[2] = gval_;
238  local[3] = gvval_;
239  local[4] = hval_;
240  sampler.sumAll(&local[0],&global[0],5);
241  Real ev = global[0], egv = global[1], egval = global[2];
242  Real egvval = global[3], ehval = global[4];
  // c0 = 1/ev, c1 = egval/ev, c2 = egv/ev, c3 = eps_/ev.
243  Real c0 = static_cast<Real>(1)/ev, c1 = c0*egval, c2 = c0*egv, c3 = eps_*c0;
244 
  // x-block, part 1: hv = c0 * (sum of hv_ + xstat[0]*eps_ * <dualVector_>).
245  sampler.sumAll(*hv_,hv);
246  dualVector_->zero();
  // NOTE(review): doxygen line 247 is missing from this listing.  As shown,
  // dualVector_ is still zero here, so the following axpy adds nothing, and
  // scaledGradient_ (accumulated in updateHessVec) is never read anywhere
  // visible.  The missing line presumably reduced scaledGradient_ into
  // *dualVector_ via sampler.sumAll -- confirm against the upstream header.
248  hv.axpy(xstat[0]*eps_,*dualVector_);
249  hv.scale(c0);
250 
  // x-block, part 2: subtract c3*(vstat[0]*c1 + xstat[0]*c2) times the
  // reduced g_ (sum of weight_*e*grad).
251  dualVector_->zero();
252  sampler.sumAll(*g_,*dualVector_);
253  hv.axpy(-c3*(vstat[0]*c1 + xstat[0]*c2),*dualVector_);
254 
  // Mixed (x, xstat) term: add vstat[0]*c3 times the reduced scaledHessVec_.
255  dualVector_->zero();
256  sampler.sumAll(*scaledHessVec_,*dualVector_);
257  hv.axpy(vstat[0]*c3,*dualVector_);
258 
  // Statistic component: vstat[0]*h11 (the xstat-xstat curvature) plus the
  // mixed term c3*egvval - eps_*c1*c2.  ROL_INF<Real>() at xstat[0] == 0.
259  if ( xstat[0] == static_cast<Real>(0) ) {
260  hvstat[0] = ROL_INF<Real>();
261  }
262  else {
263  Real xstat2 = static_cast<Real>(2)/(xstat[0]*xstat[0]);
264  Real h11 = xstat2*((static_cast<Real>(1) + std::log(ev)/eps_)/xstat[0] - c1)
265  + (c3*ehval - eps_*c1*c1)/xstat[0];
266  hvstat[0] = vstat[0] * h11 + (c3*egvval - eps_*c1*c2);
267  }
268  }
269 
270 private:
271  Real exponential(const Real arg1, const Real arg2) const {
272  if ( arg1 < arg2 ) {
273  return power(exponential(arg1),arg2);
274  }
275  else {
276  return power(exponential(arg2),arg1);
277  }
278  }
279 
280  Real exponential(const Real arg) const {
281  if ( arg >= std::log(ROL_INF<Real>()) ) {
282  return ROL_INF<Real>();
283  }
284  else {
285  return std::exp(arg);
286  }
287  }
288 
289  Real power(const Real arg, const Real pow) const {
290  if ( arg >= std::pow(ROL_INF<Real>(),static_cast<Real>(1)/pow) ) {
291  return ROL_INF<Real>();
292  }
293  else {
294  return std::pow(arg,pow);
295  }
296  }
297 };
298 
299 }
300 
301 #endif
Provides the interface to evaluate objective functions.
virtual const Vector & dual() const
Return dual representation of $\mathtt{*this}$, for example, the result of applying a Riesz map, or change of basis, or change of memory layout.
Definition: ROL_Vector.hpp:226
void computeHessVec(Vector< Real > &hv, Objective< Real > &obj, const Vector< Real > &v, const Vector< Real > &x, Real &tol)
virtual void scale(const Real alpha)=0
Compute $y \leftarrow \alpha y$ where $y = \mathtt{*this}$.
Ptr< Vector< Real > > g_
Real computeValue(Objective< Real > &obj, const Vector< Real > &x, Real &tol)
Real getValue(const Vector< Real > &x, const std::vector< Real > &xstat, SampleGenerator< Real > &sampler)
Return risk measure value.
virtual void axpy(const Real alpha, const Vector &x)
Compute $y \leftarrow \alpha x + y$ where $y = \mathtt{*this}$.
Definition: ROL_Vector.hpp:153
Ptr< Vector< Real > > hv_
Provides an interface for the Kullback-Leibler distributionally robust expectation.
void checkInputs(void) const
void updateHessVec(Objective< Real > &obj, const Vector< Real > &v, const std::vector< Real > &vstat, const Vector< Real > &x, const std::vector< Real > &xstat, Real &tol)
Update internal risk measure storage for Hessian-times-a-vector computation.
Ptr< Vector< Real > > dualVector_
Defines the linear algebra or vector space interface.
Definition: ROL_Vector.hpp:80
Real exponential(const Real arg) const
void sumAll(Real *input, Real *output, int dim) const
Objective_SerialSimOpt(const Ptr< Obj > &obj, const V &ui) z0_ zero()
KLDivergence(ROL::ParameterList &parlist)
Constructor.
void updateGradient(Objective< Real > &obj, const Vector< Real > &x, const std::vector< Real > &xstat, Real &tol)
Update internal risk measure storage for gradient computation.
void initialize(const Vector< Real > &x)
Initialize temporary variables.
ROL::Ptr< Vector< Real > > scaledGradient_
void getGradient(Vector< Real > &g, std::vector< Real > &gstat, const Vector< Real > &x, const std::vector< Real > &xstat, SampleGenerator< Real > &sampler)
Return risk measure (sub)gradient.
Real exponential(const Real arg1, const Real arg2) const
void computeGradient(Vector< Real > &g, Objective< Real > &obj, const Vector< Real > &x, Real &tol)
Real power(const Real arg, const Real pow) const
void updateValue(Objective< Real > &obj, const Vector< Real > &x, const std::vector< Real > &xstat, Real &tol)
Update internal storage for value computation.
Real computeGradVec(Vector< Real > &g, Objective< Real > &obj, const Vector< Real > &v, const Vector< Real > &x, Real &tol)
Provides the interface to implement any functional that maps a random variable to an (extended) real number.
KLDivergence(const Real eps=1.e-2)
Constructor.
ROL::Ptr< Vector< Real > > scaledHessVec_
virtual void initialize(const Vector< Real > &x)
Initialize temporary variables.
void getHessVec(Vector< Real > &hv, std::vector< Real > &hvstat, const Vector< Real > &v, const std::vector< Real > &vstat, const Vector< Real > &x, const std::vector< Real > &xstat, SampleGenerator< Real > &sampler)
Return risk measure Hessian-times-a-vector.