Sacado Package Browser (Single Doxygen Collection)  Version of the Day
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
dfad_sfad_example.cpp
Go to the documentation of this file.
1 // @HEADER
2 // ***********************************************************************
3 //
4 // Sacado Package
5 // Copyright (2006) Sandia Corporation
6 //
7 // Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation,
8 // the U.S. Government retains certain rights in this software.
9 //
10 // This library is free software; you can redistribute it and/or modify
11 // it under the terms of the GNU Lesser General Public License as
12 // published by the Free Software Foundation; either version 2.1 of the
13 // License, or (at your option) any later version.
14 //
15 // This library is distributed in the hope that it will be useful, but
16 // WITHOUT ANY WARRANTY; without even the implied warranty of
17 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
18 // Lesser General Public License for more details.
19 //
20 // You should have received a copy of the GNU Lesser General Public
21 // License along with this library; if not, write to the Free Software
22 // Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301
23 // USA
24 // Questions? Contact David M. Gay (dmgay@sandia.gov) or Eric T. Phipps
25 // (etphipp@sandia.gov).
26 //
27 // ***********************************************************************
28 // @HEADER
29 
30 // dfad_sfad_example
31 //
32 // usage:
33 // dfad_sfad_example
34 //
35 // output:
36 // prints the results of computing the second derivative times a vector
37 // for a simple function with forward nested forward mode AD using the
38 // Sacado::Fad::DFad and Sacado::Fad::SFad classes.
39 
40 #include <iostream>
41 #include <iomanip>
42 
43 #include "Sacado.hpp"
44 
// The function to differentiate:
//
//   r(a, b, c) = c * log(b + 1) / sin(a)
//
// It is a template so that the same expression can be evaluated both with
// plain doubles (see main) and with Sacado AD types (see func_and_deriv).
template <typename ScalarT>
ScalarT func(const ScalarT& a, const ScalarT& b, const ScalarT& c) {
  // The math calls stay std::-qualified exactly as in the original listing.
  ScalarT r = c * std::log(b + 1.) / std::sin(a);
  return r;
}
51 
// The analytic first and second derivatives of func with respect to a and b,
// where func(a, b, c) = c*log(b+1)/sin(a).
//
// Inputs:
//   a, b, c   point at which the derivatives are evaluated
// Outputs:
//   drda      dr/da
//   drdb      dr/db
//   d2rda2    d^2r/da^2
//   d2rdb2    d^2r/db^2
//   d2rdadb   d^2r/(da db)
void analytic_deriv(double a, double b, double c,
                    double& drda, double& drdb,
                    double& d2rda2, double& d2rdb2, double& d2rdadb)
{
  // Sub-expressions shared by several of the formulas below; evaluating them
  // once keeps each formula short without changing any computed value.
  const double sin_a = std::sin(a);
  const double cos_a = std::cos(a);
  const double b_plus_1 = b + 1.;
  const double c_log = c * std::log(b_plus_1);

  drda = -(c_log / std::pow(sin_a, 2.)) * cos_a;
  drdb = c / (b_plus_1 * sin_a);
  d2rda2 = c_log / sin_a
    + 2. * (c_log / std::pow(sin_a, 3.)) * std::pow(cos_a, 2.);
  d2rdb2 = -c / (std::pow(b_plus_1, 2.) * sin_a);
  d2rdadb = -c / (b_plus_1 * std::pow(sin_a, 2.)) * cos_a;
}
63 
64 // Function that computes func and its first derivative w.r.t a & b using
65 // Sacado AD
66 template <typename ScalarT>
67 void func_and_deriv(const ScalarT& a, const ScalarT& b, const ScalarT& c,
68  ScalarT& r, ScalarT& drda, ScalarT& drdb) {
70  FadType a_fad(2, 0, a);
71  FadType b_fad(2, 1, b);
72  FadType c_fad = c;
73 
74  FadType r_fad = func(a_fad, b_fad, c_fad);
75  r = r_fad.val();
76  drda = r_fad.dx(0);
77  drdb = r_fad.dx(1);
78 }
79 
80 // Function that computes func, its first derivative w.r.t a & b, and its
81 // second derivative in the direction of [v_a, v_b] with Sacado AD
82 //
83 // Define x = [a, b], v = [v_a, v_b], and y(t) = x + t*v. Then
84 // df/dx*v = d/dt f(y(t)) |_{t=0}.
85 //
86 // In the code below, we differentiate with respect to t in this manner.
87 // Addtionally we take a short-cut and don't introduce t directly and
88 // compute a(t) = a + t*v_a, b(t) = b + t*v_b. Instead we
89 // initialize a_fad and b_fad directly as if we had computed them in this way.
90 template <typename ScalarT>
91 void func_and_deriv2(const ScalarT& a, const ScalarT& b, const ScalarT& c,
92  const ScalarT& v_a, const ScalarT& v_b,
93  ScalarT& r, ScalarT& drda, ScalarT& drdb,
94  ScalarT& z_a, ScalarT& z_b) {
96 
97  // The below is equivalent to:
98  // FadType t(1, 0.0); f_fad.fastAccessDx(0) = 1;
99  // FadType a_fad = a + t*v_a;
100  // FadType b_fad = b + t*v_b;
101  FadType a_fad(1, a); a_fad.fastAccessDx(0) = v_a;
102  FadType b_fad(1, b); b_fad.fastAccessDx(0) = v_b;
103  FadType c_fad = c;
104 
105  FadType r_fad, drda_fad, drdb_fad;
106  func_and_deriv(a_fad, b_fad, c_fad, r_fad, drda_fad, drdb_fad);
107  r = r_fad.val(); // r
108  // note: also have r_fad.dx(0) = dr/da*v_a + dr/db*v_b
109  drda = drda_fad.val(); // dr/da
110  drdb = drdb_fad.val(); // dr/db
111  z_a = drda_fad.dx(0); // d^2r/da^2 * v_a + d^2r/dadb * v_b
112  z_b = drdb_fad.dx(0); // d^2r/dadb * v_a + d^2r/db^2 * v_b
113 }
114 
115 int main(int argc, char **argv)
116 {
117  double pi = std::atan(1.0)*4.0;
118 
119  // Values of function arguments
120  double a = pi/4;
121  double b = 2.0;
122  double c = 3.0;
123 
124  // Direction we wish to differentiate for second derivative
125  double v_a = 1.5;
126  double v_b = 3.6;
127 
128  // Compute derivatives via AD
129  double r_ad, drda_ad, drdb_ad, z_a_ad, z_b_ad;
130  func_and_deriv2(a, b, c, v_a, v_b, r_ad, drda_ad, drdb_ad, z_a_ad, z_b_ad);
131 
132  // Compute function
133  double r = func(a, b, c);
134 
135  // Compute derivatives analytically
136  double drda, drdb, d2rda2, d2rdb2, d2rdadb;
137  analytic_deriv(a, b, c, drda, drdb, d2rda2, d2rdb2, d2rdadb);
138  double z_a = d2rda2*v_a + d2rdadb*v_b;
139  double z_b = d2rdadb*v_a + d2rdb2*v_b;
140 
141  // Print the results
142  int p = 4;
143  int w = p+7;
144  std::cout.setf(std::ios::scientific);
145  std::cout.precision(p);
146  std::cout << " r = " << std::setw(w) << r << " (original) == "
147  << std::setw(w) << r_ad << " (AD) Error = " << std::setw(w)
148  << r - r_ad << std::endl
149  << "dr/da = " << std::setw(w) << drda << " (analytic) == "
150  << std::setw(w) << drda_ad << " (AD) Error = " << std::setw(w)
151  << drda - drda_ad << std::endl
152  << "dr/db = " << std::setw(w) << drdb << " (analytic) == "
153  << std::setw(w) << drdb_ad << " (AD) Error = " << std::setw(w)
154  << drdb - drdb_ad << std::endl
155  << "z_a = " << std::setw(w) << z_a << " (analytic) == "
156  << std::setw(w) << z_a_ad << " (AD) Error = " << std::setw(w)
157  << z_a - z_a_ad << std::endl
158  << "z_b = " << std::setw(w) << z_b << " (analytic) == "
159  << std::setw(w) << z_b_ad << " (AD) Error = " << std::setw(w)
160  << z_b - z_b_ad << std::endl;
161 
162  double tol = 1.0e-14;
163  if (std::fabs(r - r_ad) < tol &&
164  std::fabs(drda - drda_ad) < tol &&
165  std::fabs(drdb - drdb_ad) < tol &&
166  std::fabs(z_a - z_a_ad) < tol &&
167  std::fabs(z_b - z_b_ad) < tol) {
168  std::cout << "\nExample passed!" << std::endl;
169  return 0;
170  }
171  else {
172  std::cout <<"\nSomething is wrong, example failed!" << std::endl;
173  return 1;
174  }
175 }
Sacado::Fad::DFad< double > FadType
atan(expr.val())
KOKKOS_INLINE_FUNCTION mpl::enable_if_c< ExprLevel< Expr< T1 > >::value==ExprLevel< Expr< T2 > >::value, Expr< PowerOp< Expr< T1 >, Expr< T2 > > > >::type pow(const Expr< T1 > &expr1, const Expr< T2 > &expr2)
expr expr1 expr1 expr1 c expr2 expr1 expr2 expr1 expr2 expr1 expr1 expr1 expr1 c expr2 expr1 expr2 expr1 expr2 expr1 expr1 expr1 expr1 c *expr2 expr1 expr2 expr1 expr2 expr1 expr1 expr1 expr1 c expr2 expr1 expr2 expr1 expr2 expr1 expr1 expr1 expr2 expr1 expr2 expr1 expr1 expr1 expr2 expr1 expr2 expr1 expr1 expr1 c
void func_and_deriv2(const ScalarT &a, const ScalarT &b, const ScalarT &c, const ScalarT &v_a, const ScalarT &v_b, ScalarT &r, ScalarT &drda, ScalarT &drdb, ScalarT &z_a, ScalarT &z_b)
int main()
Definition: ad_example.cpp:191
sin(expr.val())
log(expr.val())
const double tol
void analytic_deriv(double a, double b, double c, double &drda, double &drdb, double &d2rda2, double &d2rdb2, double &d2rdadb)
const T func(int n, T *x)
Definition: ad_example.cpp:49
fabs(expr.val())
void func_and_deriv(const ScalarT &a, const ScalarT &b, const ScalarT &c, ScalarT &r, ScalarT &drda, ScalarT &drdb)
cos(expr.val())