Sacado Package Browser (Single Doxygen Collection)  Version of the Day
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
dfad_sfad_example.cpp
Go to the documentation of this file.
1 // @HEADER
2 // *****************************************************************************
3 // Sacado Package
4 //
5 // Copyright 2006 NTESS and the Sacado contributors.
6 // SPDX-License-Identifier: LGPL-2.1-or-later
7 // *****************************************************************************
8 // @HEADER
9 
10 // dfad_sfad_example
11 //
12 // usage:
13 // dfad_sfad_example
14 //
15 // output:
16 // prints the results of computing the second derivative times a vector
17 // for a simple function with forward nested forward mode AD using the
18 // Sacado::Fad::DFad and Sacado::Fad::SFad classes.
19 
20 #include <iostream>
21 #include <iomanip>
22 
23 #include "Sacado.hpp"
24 
// func: the scalar function differentiated in this example,
//   r(a, b, c) = c * log(b + 1) / sin(a).
// It is templated on ScalarT so the same code can be evaluated with plain
// doubles or with (nested) Sacado forward-AD types.
template <typename ScalarT>
ScalarT func(const ScalarT& a, const ScalarT& b, const ScalarT& c)
{
  return c*std::log(b+1.)/std::sin(a);
}
31 
// The analytic first and second derivatives of func with respect to a and b.
// These serve as the reference values that main() compares the AD results
// against.
//   r       = c*log(b+1)/sin(a)
//   drda    = dr/da          drdb    = dr/db
//   d2rda2  = d^2r/da^2      d2rdb2  = d^2r/db^2      d2rdadb = d^2r/dadb
void analytic_deriv(double a, double b, double c,
                    double& drda, double& drdb,
                    double& d2rda2, double& d2rdb2, double& d2rdadb)
{
  // Shared sub-expressions; each formula below keeps the original
  // evaluation order so results are bit-for-bit unchanged.
  const double sin_a   = std::sin(a);
  const double cos_a   = std::cos(a);
  const double b_plus1 = b+1.;
  const double log_bp1 = std::log(b_plus1);

  drda    = -(c*log_bp1/std::pow(sin_a,2.))*cos_a;
  drdb    = c / (b_plus1*sin_a);
  d2rda2  = c*log_bp1/sin_a + 2.*(c*log_bp1/std::pow(sin_a,3.))*std::pow(cos_a,2.);
  d2rdb2  = -c / (std::pow(b_plus1,2.)*sin_a);
  d2rdadb = -c / (b_plus1*std::pow(sin_a,2.))*cos_a;
}
43 
44 // Function that computes func and its first derivative w.r.t a & b using
45 // Sacado AD
46 template <typename ScalarT>
47 void func_and_deriv(const ScalarT& a, const ScalarT& b, const ScalarT& c,
48  ScalarT& r, ScalarT& drda, ScalarT& drdb) {
50  FadType a_fad(2, 0, a);
51  FadType b_fad(2, 1, b);
52  FadType c_fad = c;
53 
54  FadType r_fad = func(a_fad, b_fad, c_fad);
55  r = r_fad.val();
56  drda = r_fad.dx(0);
57  drdb = r_fad.dx(1);
58 }
59 
60 // Function that computes func, its first derivative w.r.t a & b, and its
61 // second derivative in the direction of [v_a, v_b] with Sacado AD
62 //
63 // Define x = [a, b], v = [v_a, v_b], and y(t) = x + t*v. Then
64 // df/dx*v = d/dt f(y(t)) |_{t=0}.
65 //
66 // In the code below, we differentiate with respect to t in this manner.
67 // Addtionally we take a short-cut and don't introduce t directly and
68 // compute a(t) = a + t*v_a, b(t) = b + t*v_b. Instead we
69 // initialize a_fad and b_fad directly as if we had computed them in this way.
70 template <typename ScalarT>
71 void func_and_deriv2(const ScalarT& a, const ScalarT& b, const ScalarT& c,
72  const ScalarT& v_a, const ScalarT& v_b,
73  ScalarT& r, ScalarT& drda, ScalarT& drdb,
74  ScalarT& z_a, ScalarT& z_b) {
76 
77  // The below is equivalent to:
78  // FadType t(1, 0.0); f_fad.fastAccessDx(0) = 1;
79  // FadType a_fad = a + t*v_a;
80  // FadType b_fad = b + t*v_b;
81  FadType a_fad(1, a); a_fad.fastAccessDx(0) = v_a;
82  FadType b_fad(1, b); b_fad.fastAccessDx(0) = v_b;
83  FadType c_fad = c;
84 
85  FadType r_fad, drda_fad, drdb_fad;
86  func_and_deriv(a_fad, b_fad, c_fad, r_fad, drda_fad, drdb_fad);
87  r = r_fad.val(); // r
88  // note: also have r_fad.dx(0) = dr/da*v_a + dr/db*v_b
89  drda = drda_fad.val(); // dr/da
90  drdb = drdb_fad.val(); // dr/db
91  z_a = drda_fad.dx(0); // d^2r/da^2 * v_a + d^2r/dadb * v_b
92  z_b = drdb_fad.dx(0); // d^2r/dadb * v_a + d^2r/db^2 * v_b
93 }
94 
95 int main(int argc, char **argv)
96 {
97  double pi = std::atan(1.0)*4.0;
98 
99  // Values of function arguments
100  double a = pi/4;
101  double b = 2.0;
102  double c = 3.0;
103 
104  // Direction we wish to differentiate for second derivative
105  double v_a = 1.5;
106  double v_b = 3.6;
107 
108  // Compute derivatives via AD
109  double r_ad, drda_ad, drdb_ad, z_a_ad, z_b_ad;
110  func_and_deriv2(a, b, c, v_a, v_b, r_ad, drda_ad, drdb_ad, z_a_ad, z_b_ad);
111 
112  // Compute function
113  double r = func(a, b, c);
114 
115  // Compute derivatives analytically
116  double drda, drdb, d2rda2, d2rdb2, d2rdadb;
117  analytic_deriv(a, b, c, drda, drdb, d2rda2, d2rdb2, d2rdadb);
118  double z_a = d2rda2*v_a + d2rdadb*v_b;
119  double z_b = d2rdadb*v_a + d2rdb2*v_b;
120 
121  // Print the results
122  int p = 4;
123  int w = p+7;
124  std::cout.setf(std::ios::scientific);
125  std::cout.precision(p);
126  std::cout << " r = " << std::setw(w) << r << " (original) == "
127  << std::setw(w) << r_ad << " (AD) Error = " << std::setw(w)
128  << r - r_ad << std::endl
129  << "dr/da = " << std::setw(w) << drda << " (analytic) == "
130  << std::setw(w) << drda_ad << " (AD) Error = " << std::setw(w)
131  << drda - drda_ad << std::endl
132  << "dr/db = " << std::setw(w) << drdb << " (analytic) == "
133  << std::setw(w) << drdb_ad << " (AD) Error = " << std::setw(w)
134  << drdb - drdb_ad << std::endl
135  << "z_a = " << std::setw(w) << z_a << " (analytic) == "
136  << std::setw(w) << z_a_ad << " (AD) Error = " << std::setw(w)
137  << z_a - z_a_ad << std::endl
138  << "z_b = " << std::setw(w) << z_b << " (analytic) == "
139  << std::setw(w) << z_b_ad << " (AD) Error = " << std::setw(w)
140  << z_b - z_b_ad << std::endl;
141 
142  double tol = 1.0e-14;
143  if (std::fabs(r - r_ad) < tol &&
144  std::fabs(drda - drda_ad) < tol &&
145  std::fabs(drdb - drdb_ad) < tol &&
146  std::fabs(z_a - z_a_ad) < tol &&
147  std::fabs(z_b - z_b_ad) < tol) {
148  std::cout << "\nExample passed!" << std::endl;
149  return 0;
150  }
151  else {
152  std::cout <<"\nSomething is wrong, example failed!" << std::endl;
153  return 1;
154  }
155 }
const char * p
Sacado::Fad::DFad< double > FadType
atan(expr.val())
expr expr1 expr1 expr1 c expr2 expr1 expr2 expr1 expr2 expr1 expr1 expr1 expr1 c expr2 expr1 expr2 expr1 expr2 expr1 expr1 expr1 expr1 c *expr2 expr1 expr2 expr1 expr2 expr1 expr1 expr1 expr1 c expr2 expr1 expr2 expr1 expr2 expr1 expr1 expr1 expr2 expr1 expr2 expr1 expr1 expr1 expr2 expr1 expr2 expr1 expr1 expr1 c
void func_and_deriv2(const ScalarT &a, const ScalarT &b, const ScalarT &c, const ScalarT &v_a, const ScalarT &v_b, ScalarT &r, ScalarT &drda, ScalarT &drdb, ScalarT &z_a, ScalarT &z_b)
int main()
Definition: ad_example.cpp:171
sin(expr.val())
log(expr.val())
const double tol
void analytic_deriv(double a, double b, double c, double &drda, double &drdb, double &d2rda2, double &d2rdb2, double &d2rdadb)
const T func(int n, T *x)
Definition: ad_example.cpp:29
SACADO_INLINE_FUNCTION mpl::enable_if_c< ExprLevel< Expr< T1 > >::value==ExprLevel< Expr< T2 > >::value, Expr< PowerOp< Expr< T1 >, Expr< T2 > > > >::type pow(const Expr< T1 > &expr1, const Expr< T2 > &expr2)
fabs(expr.val())
void func_and_deriv(const ScalarT &a, const ScalarT &b, const ScalarT &c, ScalarT &r, ScalarT &drda, ScalarT &drdb)
cos(expr.val())