ROL
ROL_ScalarMinimizationLineSearch_U.hpp
Go to the documentation of this file.
1 // @HEADER
2 // *****************************************************************************
3 // Rapid Optimization Library (ROL) Package
4 //
5 // Copyright 2014 NTESS and the ROL contributors.
6 // SPDX-License-Identifier: BSD-3-Clause
7 // *****************************************************************************
8 // @HEADER
9 
10 #ifndef ROL_SCALARMINIMIZATIONLINESEARCH_U_H
11 #define ROL_SCALARMINIMIZATIONLINESEARCH_U_H
12 
18 #include "ROL_LineSearch_U.hpp"
19 #include "ROL_BrentsScalarMinimization.hpp"
20 #include "ROL_BisectionScalarMinimization.hpp"
21 #include "ROL_GoldenSectionScalarMinimization.hpp"
22 #include "ROL_ScalarFunction.hpp"
23 #include "ROL_Bracketing.hpp"
24 
25 namespace ROL {
26 
27 template<typename Real>
29 private:
30  Ptr<Vector<Real>> xnew_;
31  Ptr<Vector<Real>> g_;
32  Ptr<ScalarMinimization<Real>> sm_;
33  Ptr<Bracketing<Real>> br_;
34  Ptr<ScalarFunction<Real>> sf_;
35 
37  Real c1_;
38  Real c2_;
39  Real c3_;
41 
43 
44  class Phi : public ScalarFunction<Real> {
45  private:
46  const Ptr<Vector<Real>> xnew_;
47  const Ptr<const Vector<Real>> x_, s_;
48  const Ptr<Objective<Real>> obj_;
49  Real ftol_, alpha_, val_;
51  public:
52  Phi(const Ptr<Vector<Real>> &xnew,
53  const Ptr<const Vector<Real>> &x,
54  const Ptr<const Vector<Real>> &s,
55  const Ptr<Objective<Real>> &obj,
56  const bool FDdirDeriv = false)
57  : xnew_(xnew), x_(x), s_(s), obj_(obj),
58  ftol_(std::sqrt(ROL_EPSILON<Real>())),
59  alpha_(ROL_INF<Real>()), val_(ROL_INF<Real>()),
60  FDdirDeriv_(FDdirDeriv) {}
61  Real value(const Real alpha) {
62  if (alpha_ != alpha) {
63  alpha_ = alpha;
64  xnew_->set(*x_); xnew_->axpy(alpha,*s_);
65  obj_->update(*xnew_,UpdateType::Trial);
66  val_ = obj_->value(*xnew_,ftol_);
67  }
68  return val_;
69  }
70  Real deriv(const Real alpha) {
71  Real tol = std::sqrt(ROL_EPSILON<Real>());
72  Real val(0);
73  xnew_->set(*x_); xnew_->axpy(alpha,*s_);
74  if (FDdirDeriv_) {
75  Real snorm = s_->norm();
76  if (snorm > static_cast<Real>(0)) {
77  Real xnorm = xnew_->norm();
78  Real cbrteps = std::cbrt(ROL_EPSILON<Real>());
79  Real h = cbrteps*std::max(xnorm/snorm,static_cast<Real>(1));
80  Real fnew = value(alpha);
81  xnew_->axpy(h,*s_);
82  obj_->update(*xnew_,UpdateType::Trial);
83  Real ftrial = obj_->value(*xnew_,tol);
84  val = (ftrial - fnew) / h;
85  }
86  }
87  else {
88  val = obj_->dirDeriv(*xnew_,*s_,tol);
89  }
90  return val;
91  }
92  };
93 
94  class StatusTest : public ScalarMinimizationStatusTest<Real> {
95  private:
96  Ptr<ScalarFunction<Real>> phi_;
97 
98  const Real f0_;
99  const Real g0_;
100 
101  const Real c1_;
102  const Real c2_;
103  const Real c3_;
104  const int max_nfval_;
106 
107  public:
108  StatusTest(const Real f0, const Real g0,
109  const Real c1, const Real c2, const Real c3,
110  const int max_nfval, ECurvatureConditionU econd,
111  const Ptr<ScalarFunction<Real>> &phi)
112  : phi_(phi), f0_(f0), g0_(g0), c1_(c1), c2_(c2), c3_(c3),
113  max_nfval_(max_nfval), econd_(econd) {}
114 
115  bool check(Real &x, Real &fx, Real &gx,
116  int &nfval, int &ngval, const bool deriv = false) {
117  Real one(1), two(2);
118  bool armijo = (fx <= f0_ + c1_*x*g0_);
119 // bool itcond = (nfval >= max_nfval_);
120  bool curvcond = false;
121 // if (armijo && !itcond) {
122  if (armijo) {
124  curvcond = (fx >= f0_ + (one-c1_)*x*g0_);
125  }
126  else if (econd_ == CURVATURECONDITION_U_NULL) {
127  curvcond = true;
128  }
129  else {
130  if (!deriv) {
131  gx = phi_->deriv(x); ngval++;
132  }
134  curvcond = (gx >= c2_*g0_);
135  }
137  curvcond = (std::abs(gx) <= c2_*std::abs(g0_));
138  }
140  curvcond = (c2_*g0_ <= gx && gx <= -c3_*g0_);
141  }
143  curvcond = (c2_*g0_ <= gx && gx <= (two*c1_ - one)*g0_);
144  }
145  }
146  }
147  //return (armijo && curvcond) || itcond;
148  return (armijo && curvcond);
149  }
150  };
151 
154 
155 public:
156  // Constructor
157  ScalarMinimizationLineSearch_U( ParameterList &parlist,
158  const Ptr<ScalarMinimization<Real>> &sm = nullPtr,
159  const Ptr<Bracketing<Real>> &br = nullPtr,
160  const Ptr<ScalarFunction<Real>> &sf = nullPtr )
161  : LineSearch_U<Real>(parlist) {
162  const Real zero(0), p4(0.4), p6(0.6), p9(0.9), oem4(1.e-4), oem10(1.e-10), one(1);
163  ParameterList &list0 = parlist.sublist("Step").sublist("Line Search");
164  FDdirDeriv_ = list0.get("Finite Difference Directional Derivative",false);
165  ParameterList &list = list0.sublist("Line-Search Method");
166  // Get Bracketing Method
167  if( br == nullPtr ) {
168  br_ = makePtr<Bracketing<Real>>();
169  }
170  else {
171  br_ = br;
172  }
173  // Get ScalarMinimization Method
174  std::string type = list.get("Type","Brent's");
175  Real tol = list.sublist(type).get("Tolerance",oem10);
176  int niter = list.sublist(type).get("Iteration Limit",1000);
177  ROL::ParameterList plist;
178  plist.sublist("Scalar Minimization").set("Type",type);
179  plist.sublist("Scalar Minimization").sublist(type).set("Tolerance",tol);
180  plist.sublist("Scalar Minimization").sublist(type).set("Iteration Limit",niter);
181 
182  if( sm == nullPtr ) { // No user-provided ScalarMinimization object
183  if ( type == "Brent's" ) {
184  sm_ = makePtr<BrentsScalarMinimization<Real>>(plist);
185  }
186  else if ( type == "Bisection" ) {
187  sm_ = makePtr<BisectionScalarMinimization<Real>>(plist);
188  }
189  else if ( type == "Golden Section" ) {
190  sm_ = makePtr<GoldenSectionScalarMinimization<Real>>(plist);
191  }
192  else {
193  ROL_TEST_FOR_EXCEPTION(true, std::invalid_argument,
194  ">>> (ROL::ScalarMinimizationLineSearch): Undefined ScalarMinimization type!");
195  }
196  }
197  else {
198  sm_ = sm;
199  }
200  sf_ = sf;
201 
202  // Status test for line search
203  econd_ = StringToECurvatureConditionU(list0.sublist("Curvature Condition").get("Type","Strong Wolfe Conditions"));
204  max_nfval_ = list0.get("Function Evaluation Limit",20);
205  c1_ = list0.get("Sufficient Decrease Tolerance",oem4);
206  c2_ = list0.sublist("Curvature Condition").get("General Parameter",p9);
207  c3_ = list0.sublist("Curvature Condition").get("Generalized Wolfe Parameter",p6);
208  // Check status test inputs
209  c1_ = ((c1_ < zero) ? oem4 : c1_);
210  c2_ = ((c2_ < zero) ? p9 : c2_);
211  c3_ = ((c3_ < zero) ? p9 : c3_);
212  if ( c2_ <= c1_ ) {
213  c1_ = oem4;
214  c2_ = p9;
215  }
216  EDescentU edesc = StringToEDescentU(list0.sublist("Descent Method").get("Type","Quasi-Newton Method"));
217  if ( edesc == DESCENT_U_NONLINEARCG ) {
218  c2_ = p4;
219  c3_ = std::min(one-c2_,c3_);
220  }
221  }
222 
223  void initialize(const Vector<Real> &x, const Vector<Real> &g) override {
225  xnew_ = x.clone();
226  g_ = g.clone();
227  }
228 
229  // Find the minimum of phi(alpha) = f(x + alpha*s) using Brent's method
230  void run( Real &alpha, Real &fval, int &ls_neval, int &ls_ngrad,
231  const Real &gs, const Vector<Real> &s, const Vector<Real> &x,
232  Objective<Real> &obj ) override {
233  ls_neval = 0; ls_ngrad = 0;
234 
235  // Get initial line search parameter
236  alpha = getInitialAlpha(ls_neval,ls_ngrad,fval,gs,x,s,obj);
237 
238  // Build ScalarFunction and ScalarMinimizationStatusTest
239  Ptr<const Vector<Real>> x_ptr = ROL::makePtrFromRef(x);
240  Ptr<const Vector<Real>> s_ptr = ROL::makePtrFromRef(s);
241  Ptr<Objective<Real>> obj_ptr = ROL::makePtrFromRef(obj);
242 
243  Ptr<ScalarFunction<Real>> phi;
244  if( sf_ == ROL::nullPtr ) {
245  phi = ROL::makePtr<Phi>(xnew_,x_ptr,s_ptr,obj_ptr,FDdirDeriv_);
246  }
247  else {
248  phi = sf_;
249  }
250 
251  Ptr<ScalarMinimizationStatusTest<Real>> test
252  = makePtr<StatusTest>(fval,gs,c1_,c2_,c3_,max_nfval_,econd_,phi);
253 
254  // Run Bracketing
255  int nfval = 0, ngrad = 0;
256  Real A(0), fA = fval;
257  Real B = alpha, fB = phi->value(B);
258  br_->run(alpha,fval,A,fA,B,fB,nfval,ngrad,*phi,*test);
259  B = alpha;
260  ls_neval += nfval; ls_ngrad += ngrad;
261 
262  // Run ScalarMinimization
263  nfval = 0, ngrad = 0;
264  sm_->run(fval, alpha, nfval, ngrad, *phi, A, B, *test);
265  ls_neval += nfval; ls_ngrad += ngrad;
266 
267  setNextInitialAlpha(alpha);
268  }
269 }; // class ROL::ScalarMinimization_U
270 
271 } // namespace ROL
272 
273 #endif
Provides interface for and implements line searches.
virtual void initialize(const Vector< Real > &x, const Vector< Real > &g)
Provides the interface to evaluate objective functions.
void setNextInitialAlpha(Real alpha)
virtual ROL::Ptr< Vector > clone() const =0
Clone to make a new (uninitialized) vector.
EDescentU
Enumeration of descent direction types.
Real ROL_INF(void)
Definition: ROL_Types.hpp:71
Implements line search methods that attempt to minimize the scalar function $\phi(\alpha) := f(x + \alpha s)$.
Phi(const Ptr< Vector< Real >> &xnew, const Ptr< const Vector< Real >> &x, const Ptr< const Vector< Real >> &s, const Ptr< Objective< Real >> &obj, const bool FDdirDeriv=false)
ECurvatureConditionU
Enumeration of line-search curvature conditions.
Defines the linear algebra or vector space interface.
Definition: ROL_Vector.hpp:46
ScalarMinimizationLineSearch_U(ParameterList &parlist, const Ptr< ScalarMinimization< Real >> &sm=nullPtr, const Ptr< Bracketing< Real >> &br=nullPtr, const Ptr< ScalarFunction< Real >> &sf=nullPtr)
virtual Real getInitialAlpha(int &ls_neval, int &ls_ngrad, const Real fval, const Real gs, const Vector< Real > &x, const Vector< Real > &s, Objective< Real > &obj)
Objective_SerialSimOpt(const Ptr< Obj > &obj, const V &ui) z0_ zero()
EDescentU StringToEDescentU(std::string s)
StatusTest(const Real f0, const Real g0, const Real c1, const Real c2, const Real c3, const int max_nfval, ECurvatureConditionU econd, const Ptr< ScalarFunction< Real >> &phi)
void initialize(const Vector< Real > &x, const Vector< Real > &g) override
void run(Real &alpha, Real &fval, int &ls_neval, int &ls_ngrad, const Real &gs, const Vector< Real > &s, const Vector< Real > &x, Objective< Real > &obj) override
Real ROL_EPSILON(void)
Platform-dependent machine epsilon.
Definition: ROL_Types.hpp:57
bool check(Real &x, Real &fx, Real &gx, int &nfval, int &ngval, const bool deriv=false)
ECurvatureConditionU StringToECurvatureConditionU(std::string s)