Sacado Package Browser (Single Doxygen Collection)  Version of the Day
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
dfad_sfc_example.cpp
Go to the documentation of this file.
1 // @HEADER
2 // ***********************************************************************
3 //
4 // Sacado Package
5 // Copyright (2006) Sandia Corporation
6 //
7 // Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation,
8 // the U.S. Government retains certain rights in this software.
9 //
10 // This library is free software; you can redistribute it and/or modify
11 // it under the terms of the GNU Lesser General Public License as
12 // published by the Free Software Foundation; either version 2.1 of the
13 // License, or (at your option) any later version.
14 //
15 // This library is distributed in the hope that it will be useful, but
16 // WITHOUT ANY WARRANTY; without even the implied warranty of
17 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
18 // Lesser General Public License for more details.
19 //
20 // You should have received a copy of the GNU Lesser General Public
21 // License along with this library; if not, write to the Free Software
22 // Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301
23 // USA
24 // Questions? Contact David M. Gay (dmgay@sandia.gov) or Eric T. Phipps
25 // (etphipp@sandia.gov).
26 //
27 // ***********************************************************************
28 // @HEADER
29 
30 // dfad_sfc_example
31 //
32 // usage:
33 // dfad_sfc_example
34 //
35 // output:
36 // Uses the scalar flop counter to count the flops needed to evaluate a simple
37 // function and its derivative, both analytically and with DFad
38 
39 #include <iostream>
40 #include <iomanip>
41 
42 #include "Sacado_No_Kokkos.hpp"
43 
// Function to differentiate: r(a,b,c) = c * log(b+1) / sin(a).
//
// Templated on the scalar type so the same code can be evaluated with plain
// doubles, with the flop-counting scalar, or with a forward-mode AD type.
// The expression is written out exactly as-is on purpose: when ScalarT counts
// flops, the operations performed here are what gets reported.
template <typename ScalarT>
ScalarT func(const ScalarT& a, const ScalarT& b, const ScalarT& c) {
  ScalarT value = c * std::log(b + 1.0) / std::sin(a);

  return value;
}
51 
// Hand-coded partial derivatives of func(a,b,c) = c*log(b+1)/sin(a):
//   dr/da = -(c*log(b+1)/sin(a)^2) * cos(a)
//   dr/db =  c / ((b+1)*sin(a))
// Results are written to drda and drdb, in that order.
// The expressions are kept in exactly this form: when ScalarT is the
// flop-counting scalar, these operations are what the caller's counters see.
template <typename ScalarT>
void func_deriv(const ScalarT& a, const ScalarT& b, const ScalarT& c,
                ScalarT& drda, ScalarT& drdb)
{
  // d/da of 1/sin(a) is -cos(a)/sin(a)^2, scaled by c*log(b+1).
  drda = -(c * std::log(b + 1.0) / std::pow(std::sin(a), 2.0)) * std::cos(a);
  // d/db of log(b+1) is 1/(b+1), scaled by c/sin(a).
  drdb = c / ((b + 1.0) * std::sin(a));
}
60 
63 
64 int main(int argc, char **argv)
65 {
66  double pi = std::atan(1.0)*4.0;
67 
68  // Values of function arguments
69  double a = pi/4;
70  double b = 2.0;
71  double c = 3.0;
72 
73  // Number of independent variables
74  int num_deriv = 2;
75 
76  // Compute function
77  SFC as(a);
78  SFC bs(b);
79  SFC cs(c);
81  SFC rs = func(as, bs, cs);
83 
84  std::cout << "Flop counts for function evaluation:";
85  SFC::printCounters(std::cout);
86 
87  // Compute derivative analytically
88  SFC drdas, drdbs;
90  func_deriv(as, bs, cs, drdas, drdbs);
92 
93  std::cout << "\nFlop counts for analytic derivative evaluation:";
94  SFC::printCounters(std::cout);
95 
96  // Compute function and derivative with AD
97  FAD_SFC afad(num_deriv, 0, a);
98  FAD_SFC bfad(num_deriv, 1, b);
99  FAD_SFC cfad(c);
101  FAD_SFC rfad = func(afad, bfad, cfad);
103 
104  std::cout << "\nFlop counts for AD function and derivative evaluation:";
105  SFC::printCounters(std::cout);
106 
107  // Extract value and derivatives
108  double r = rs.val(); // r
109  double drda = drdas.val(); // dr/da
110  double drdb = drdbs.val(); // dr/db
111 
112  double r_ad = rfad.val().val(); // r
113  double drda_ad = rfad.dx(0).val(); // dr/da
114  double drdb_ad = rfad.dx(1).val(); // dr/db
115 
116  // Print the results
117  int p = 4;
118  int w = p+7;
119  std::cout.setf(std::ios::scientific);
120  std::cout.precision(p);
121  std::cout << "\nValues/derivatives of computation" << std::endl
122  << " r = " << r << " (original) == " << std::setw(w) << r_ad
123  << " (AD) Error = " << std::setw(w) << r - r_ad << std::endl
124  << "dr/da = " << std::setw(w) << drda << " (analytic) == "
125  << std::setw(w) << drda_ad << " (AD) Error = " << std::setw(w)
126  << drda - drda_ad << std::endl
127  << "dr/db = " << std::setw(w) << drdb << " (analytic) == "
128  << std::setw(w) << drdb_ad << " (AD) Error = " << std::setw(w)
129  << drdb - drdb_ad << std::endl;
130 
131  double tol = 1.0e-14;
133  // The Solaris and Irix CC compilers get higher counts for operator=
134  // than does g++, which avoids an extra copy when returning a function value.
135  // The test on fc.totalFlopCount allows for this variation.
136  if (std::fabs(r - r_ad) < tol &&
137  std::fabs(drda - drda_ad) < tol &&
138  std::fabs(drdb - drdb_ad) < tol &&
139  (fc.totalFlopCount == 40)) {
140  std::cout << "\nExample passed!" << std::endl;
141  return 0;
142  }
143  else {
144  std::cout <<"\nSomething is wrong, example failed!" << std::endl;
145  return 1;
146  }
147 }
static FlopCounts getCounters()
Get the flop counts after a block of computations.
static void resetCounters()
Reset static flop counters before starting a block of computations.
const T & val() const
Return the current value.
atan(expr.val())
KOKKOS_INLINE_FUNCTION mpl::enable_if_c< ExprLevel< Expr< T1 > >::value==ExprLevel< Expr< T2 > >::value, Expr< PowerOp< Expr< T1 >, Expr< T2 > > > >::type pow(const Expr< T1 > &expr1, const Expr< T2 > &expr2)
expr expr1 expr1 expr1 c expr2 expr1 expr2 expr1 expr2 expr1 expr1 expr1 expr1 c expr2 expr1 expr2 expr1 expr2 expr1 expr1 expr1 expr1 c *expr2 expr1 expr2 expr1 expr2 expr1 expr1 expr1 expr1 c expr2 expr1 expr2 expr1 expr2 expr1 expr1 expr1 expr2 expr1 expr2 expr1 expr1 expr1 expr2 expr1 expr2 expr1 expr1 expr1 c
static std::ostream & printCounters(std::ostream &out)
Print the current static flop counts to out.
int main()
Definition: ad_example.cpp:191
Class storing flop counts and summary flop counts.
TypeTo as(const TypeFrom &t)
void func_deriv(double a, double b, double c, double &drda, double &drdb)
sin(expr.val())
Sacado::Fad::DFad< SFC > FAD_SFC
log(expr.val())
const double tol
const T func(int n, T *x)
Definition: ad_example.cpp:49
static void finalizeCounters()
Finalize total flop count after block of computations.
Sacado::FlopCounterPack::ScalarFlopCounter< double > SFC
fabs(expr.val())
cos(expr.val())