10 #ifndef NESTED_FADUNITTESTS_HPP
11 #define NESTED_FADUNITTESTS_HPP
26 template <
class FadFadType>
63 for (
int j=0; j<
n2; j++) {
65 val2 =
urand.number();
66 a_dfad_.val().fastAccessDx(j) = val2;
67 a_fad_.val().fastAccessDx(j) = val2;
69 val2 =
urand.number();
70 b_dfad_.val().fastAccessDx(j) = val2;
71 b_fad_.val().fastAccessDx(j) = val2;
74 for (
int i=0;
i<
n1;
i++) {
83 for (
int j=0; j<
n2; j++) {
85 val2 =
urand.number();
86 a_dfad_.fastAccessDx(
i).fastAccessDx(j) = val2;
87 a_fad_.fastAccessDx(
i).fastAccessDx(j) = val2;
89 val2 =
urand.number();
90 b_dfad_.fastAccessDx(
i).fastAccessDx(j) = val2;
91 b_fad_.fastAccessDx(
i).fastAccessDx(j) = val2;
102 template <
typename ScalarT>
104 ScalarT t1 = 3. * a +
sin(b) /
log(
fabs(a - b * 7.));
110 t1 /=
cosh(b - 0.7) + 7.*
sinh(t1 + 0.8)*
tanh(9./a) - 9.;
116 template <
typename ScalarT>
134 #define BINARY_OP_TEST(FIXTURENAME,TESTNAME,OP) \
135 TYPED_TEST_P(FIXTURENAME, TESTNAME) { \
136 typedef decltype(this->a_dfad_) FadFadType; \
137 typedef typename Sacado::ValueType<FadFadType>::type FadType; \
138 auto a_dfad = this->a_dfad_; \
139 auto b_dfad = this->b_dfad_; \
140 auto c_dfad = this->c_dfad_; \
141 auto a_fad = this->a_fad_; \
142 auto b_fad = this->b_fad_; \
143 auto c_fad = this->c_fad_; \
144 c_dfad = a_dfad OP b_dfad; \
145 c_fad = a_fad OP b_fad; \
146 COMPARE_NESTED_FADS(c_dfad, c_fad); \
148 double val = this->urand.number(); \
149 c_dfad = a_dfad OP val; \
150 c_fad = a_fad OP FadType(val); \
151 COMPARE_NESTED_FADS(c_dfad, c_fad); \
153 c_dfad = val OP b_dfad; \
154 c_fad = FadType(val) OP b_fad; \
155 COMPARE_NESTED_FADS(c_dfad, c_fad); \
158 #define RELOP_TEST(FIXTURENAME,TESTNAME,OP) \
159 TYPED_TEST_P(FIXTURENAME, TESTNAME) { \
160 typedef decltype(this->a_dfad_) FadFadType; \
161 typedef typename Sacado::ValueType<FadFadType>::type FadType; \
162 auto a_dfad = this->a_dfad_; \
163 auto b_dfad = this->b_dfad_; \
164 auto a_fad = this->a_fad_; \
165 auto b_fad = this->b_fad_; \
166 bool r1 = a_dfad OP b_dfad; \
167 bool r2 = a_fad OP b_fad; \
168 ASSERT_TRUE(r1 == r2); \
170 double val = this->urand.number(); \
171 r1 = a_dfad OP val; \
172 r2 = a_fad OP FadType(val); \
173 ASSERT_TRUE(r1 == r2); \
175 r1 = val OP b_dfad; \
176 r2 = FadType(val) OP b_fad; \
177 ASSERT_TRUE(r1 == r2); \
180 #define BINARY_FUNC_TEST(FIXTURENAME,TESTNAME,FUNC) \
181 TYPED_TEST_P(FIXTURENAME, TESTNAME) { \
182 typedef decltype(this->a_dfad_) FadFadType; \
183 typedef typename Sacado::ValueType<FadFadType>::type FadType; \
184 auto a_dfad = this->a_dfad_; \
185 auto b_dfad = this->b_dfad_; \
186 auto c_dfad = this->c_dfad_; \
187 auto a_fad = this->a_fad_; \
188 auto b_fad = this->b_fad_; \
189 auto c_fad = this->c_fad_; \
190 c_dfad = FUNC (a_dfad,b_dfad); \
191 c_fad = FUNC (a_fad,b_fad); \
192 COMPARE_NESTED_FADS(c_dfad, c_fad); \
194 double val = this->urand.number(); \
195 c_dfad = FUNC (a_dfad,val); \
196 c_fad = FUNC (a_fad,FadType(val)); \
197 COMPARE_NESTED_FADS(c_dfad, c_fad); \
199 c_dfad = FUNC (val,b_dfad); \
200 c_fad = FUNC (FadType(val),b_fad); \
201 COMPARE_NESTED_FADS(c_dfad, c_fad); \
204 #define UNARY_OP_TEST(FIXTURENAME,TESTNAME,OP) \
205 TYPED_TEST_P(FIXTURENAME, TESTNAME) { \
206 auto a_dfad = this->a_dfad_; \
207 auto c_dfad = this->c_dfad_; \
208 auto a_fad = this->a_fad_; \
209 auto c_fad = this->c_fad_; \
210 c_dfad = OP a_dfad; \
212 COMPARE_NESTED_FADS(c_dfad, c_fad); \
215 #define UNARY_FUNC_TEST(FIXTURENAME,TESTNAME,FUNC) \
216 TYPED_TEST_P(FIXTURENAME, TESTNAME) { \
217 auto a_dfad = this->a_dfad_; \
218 auto c_dfad = this->c_dfad_; \
219 auto a_fad = this->a_fad_; \
220 auto c_fad = this->c_fad_; \
221 c_dfad = FUNC (a_dfad); \
222 c_fad = FUNC (a_fad); \
223 COMPARE_NESTED_FADS(c_dfad, c_fad); \
226 #define UNARY_ASSIGNOP_TEST(FIXTURENAME,TESTNAME,OP) \
227 TYPED_TEST_P(FIXTURENAME, TESTNAME){ \
228 auto a_dfad = this->a_dfad_; \
229 auto b_dfad = this->b_dfad_; \
230 auto c_dfad = this->c_dfad_; \
231 auto a_fad = this->a_fad_; \
232 auto b_fad = this->b_fad_; \
233 auto c_fad = this->c_fad_; \
238 COMPARE_NESTED_FADS(c_dfad, c_fad); \
240 double val = this->urand.number(); \
243 COMPARE_NESTED_FADS(c_dfad, c_fad); \
252 RELOP_TEST(FadFadOpsUnitTest, testNotEquals, !=)
253 RELOP_TEST(FadFadOpsUnitTest, testLessThanOrEquals, <=)
254 RELOP_TEST(FadFadOpsUnitTest, testGreaterThanOrEquals, >=)
255 RELOP_TEST(FadFadOpsUnitTest, testLessThan, <)
256 RELOP_TEST(FadFadOpsUnitTest, testGreaterThan, >)
285 typedef decltype(this->a_dfad_) FadFadType;
290 auto a_dfad = this->a_dfad_;
291 auto b_dfad = this->b_dfad_;
292 auto c_dfad = this->c_dfad_;
293 auto a_fad = this->a_fad_;
294 auto b_fad = this->b_fad_;
295 auto c_fad = this->c_fad_;
297 FadFadType aa_dfad = a_dfad + 1.0;
298 c_dfad =
max(aa_dfad, a_dfad);
300 for (
int i=0;
i<this->n1;
i++) {
305 c_dfad =
max(a_dfad, aa_dfad);
307 for (
int i=0;
i<this->n1;
i++) {
312 c_dfad =
max(a_dfad+1.0, a_dfad);
314 for (
int i=0;
i<this->n1;
i++) {
319 c_dfad =
max(a_dfad, a_dfad+1.0);
321 for (
int i=0;
i<this->n1;
i++) {
326 val = a_dfad.val() + 1;
327 c_dfad =
max(a_dfad, val);
329 for (
int i=0;
i<this->n1;
i++) {
333 val = a_dfad.val() - 1;
334 c_dfad =
max(a_dfad, val);
336 for (
int i=0;
i<this->n1;
i++) {
341 val = b_dfad.val() + 1;
342 c_dfad =
max(val, b_dfad);
344 for (
int i=0;
i<this->n1;
i++) {
348 val = b_dfad.val() - 1;
349 c_dfad =
max(val, b_dfad);
351 for (
int i=0;
i<this->n1;
i++) {
358 typedef decltype(this->a_dfad_) FadFadType;
363 auto a_dfad = this->a_dfad_;
364 auto b_dfad = this->b_dfad_;
365 auto c_dfad = this->c_dfad_;
366 auto a_fad = this->a_fad_;
367 auto b_fad = this->b_fad_;
368 auto c_fad = this->c_fad_;
370 FadFadType aa_dfad = a_dfad - 1.0;
371 c_dfad =
min(aa_dfad, a_dfad);
373 for (
int i=0;
i<this->n1;
i++) {
378 c_dfad =
min(a_dfad, aa_dfad);
380 for (
int i=0;
i<this->n1;
i++) {
385 val = a_dfad.val() - 1;
386 c_dfad =
min(a_dfad, val);
388 for (
int i=0;
i<this->n1;
i++) {
392 val = a_dfad.val() + 1;
393 c_dfad =
min(a_dfad, val);
395 for (
int i=0;
i<this->n1;
i++) {
400 val = b_dfad.val() - 1;
401 c_dfad =
min(val, b_dfad);
403 for (
int i=0;
i<this->n1;
i++) {
407 val = b_dfad.val() + 1;
408 c_dfad =
min(val, b_dfad);
410 for (
int i=0;
i<this->n1;
i++) {
417 auto a_dfad = this->a_dfad_;
418 auto b_dfad = this->b_dfad_;
419 auto c_dfad = this->c_dfad_;
420 auto a_fad = this->a_fad_;
421 auto b_fad = this->b_fad_;
422 auto c_fad = this->c_fad_;
424 c_dfad = this->composite1(a_dfad, b_dfad);
425 c_fad = this->composite1_fad(a_fad, b_fad);
430 typedef decltype(this->a_dfad_) FadFadType;
433 auto a_dfad = this->a_dfad_;
434 auto b_dfad = this->b_dfad_;
435 auto a_fad = this->a_fad_;
436 auto b_fad = this->b_fad_;
438 FadFadType aa_dfad = a_dfad;
439 FAD::Fad< FadType > aa_fad = a_fad;
442 aa_dfad = aa_dfad + b_dfad;
443 aa_fad = aa_fad + b_fad;
448 typedef decltype(this->a_dfad_) FadFadType;
451 auto a_dfad = this->a_dfad_;
452 auto b_dfad = this->b_dfad_;
453 auto a_fad = this->a_fad_;
454 auto b_fad = this->b_fad_;
456 FadFadType aa_dfad = a_dfad;
457 FAD::Fad< FadType > aa_fad = a_fad;
460 aa_dfad = aa_dfad - b_dfad;
461 aa_fad = aa_fad - b_fad;
466 typedef decltype(this->a_dfad_) FadFadType;
469 auto a_dfad = this->a_dfad_;
470 auto b_dfad = this->b_dfad_;
471 auto a_fad = this->a_fad_;
472 auto b_fad = this->b_fad_;
474 FadFadType aa_dfad = a_dfad;
475 FAD::Fad< FadType > aa_fad = a_fad;
478 aa_dfad = aa_dfad * b_dfad;
479 aa_fad = aa_fad * b_fad;
484 typedef decltype(this->a_dfad_) FadFadType;
487 auto a_dfad = this->a_dfad_;
488 auto b_dfad = this->b_dfad_;
489 auto a_fad = this->a_fad_;
490 auto b_fad = this->b_fad_;
492 FadFadType aa_dfad = a_dfad;
493 FAD::Fad< FadType > aa_fad = a_fad;
496 aa_dfad = aa_dfad / b_dfad;
497 aa_fad = aa_fad / b_fad;
503 typedef decltype(this->a_dfad_) FadFadType;
507 auto a_dfad = this->a_dfad_;
509 FadFadType
a, b,
c, cc;
515 ScalarType
f =
pow(a.val().val(), b.val().val());
516 ScalarType fp = b.val().val()*
pow(a.val().val(),b.val().val()-1);
517 ScalarType fpp = b.val().val()*(b.val().val()-1)*
pow(a.val().val(),b.val().val()-2);
518 cc = FadFadType(this->n1,
FadType(this->n2,f));
519 for (
int i=0;
i<this->n2; ++
i)
520 cc.val().fastAccessDx(
i) = fp*a.val().dx(
i);
521 for (
int i=0;
i<this->n1; ++
i) {
522 cc.fastAccessDx(
i) =
FadType(this->n2,fp*a.dx(
i).val());
523 for (
int j=0; j<this->n2; ++j)
524 cc.fastAccessDx(
i).fastAccessDx(j) = fpp*a.dx(
i).val()*a.val().dx(j) + fp*a.dx(
i).dx(j);
531 c =
pow(a, b.val().val());
537 cc.val() =
FadType(this->n2,1.0);
538 for (
int i=0;
i<this->n1; ++
i)
539 cc.fastAccessDx(
i) = 0.0;
544 cc.val() =
FadType(this->n2,1.0);
545 for (
int i=0;
i<this->n1; ++
i)
546 cc.fastAccessDx(
i) = 0.0;
548 c =
pow(a, b.val().val());
559 for (
int i=0;
i<this->n1; ++
i)
560 cc.fastAccessDx(
i) =
FadType(this->n2,0.0);
569 for (
int i=0;
i<this->n1; ++
i)
570 cc.fastAccessDx(
i) =
FadType(this->n2,0.0);
572 c =
pow(a, b.val().val());
578 for (
int i=0;
i<this->n1; ++
i)
579 cc.fastAccessDx(
i) = 0.0;
586 c =
pow(a, b.val().val());
594 f =
pow(a.val().val(), b.val().val());
595 fp = b.val().val()*
pow(a.val().val(),b.val().val()-1);
596 fpp = b.val().val()*(b.val().val()-1)*
pow(a.val().val(),b.val().val()-2);
597 cc = FadFadType(this->n1,
FadType(this->n2,f));
598 for (
int i=0;
i<this->n2; ++
i)
599 cc.val().fastAccessDx(
i) = fp*a.val().dx(
i);
600 for (
int i=0;
i<this->n1; ++
i) {
601 cc.fastAccessDx(
i) =
FadType(this->n2,fp*a.dx(
i).val());
602 for (
int j=0; j<this->n2; ++j)
603 cc.fastAccessDx(
i).fastAccessDx(j) = fpp*a.dx(
i).val()*a.val().dx(j) + fp*a.dx(
i).dx(j);
612 c =
pow(a, b.val().val());
624 c =
pow(a, b.val().val());
630 cc.val() =
FadType(this->n2, 1.0);
631 for (
int i=0;
i<this->n1; ++
i)
632 cc.fastAccessDx(
i) = 0.0;
638 c =
pow(a, b.val().val());
655 testLessThanOrEquals,
656 testGreaterThanOrEquals,
690 #endif // NESETD_FADUNITTESTS_HPP
FAD::Fad< FadType > a_fad_
REGISTER_TYPED_TEST_SUITE_P(FadBLASUnitTests, testSCAL1, testSCAL2, testSCAL3, testSCAL4, testCOPY1, testCOPY2, testCOPY3, testCOPY4, testAXPY1, testAXPY2, testAXPY3, testAXPY4, testDOT1, testDOT2, testDOT3, testDOT4, testNRM21, testNRM22, testGEMV1, testGEMV2, testGEMV3, testGEMV4, testGEMV5, testGEMV6, testGEMV7, testGEMV8, testGEMV9, testTRMV1, testTRMV2, testTRMV3, testTRMV4, testGER1, testGER2, testGER3, testGER4, testGER5, testGER6, testGER7, testGEMM1, testGEMM2, testGEMM3, testGEMM4, testGEMM5, testGEMM6, testGEMM7, testGEMM8, testGEMM9, testGEMM10, testSYMM1, testSYMM2, testSYMM3, testSYMM4, testSYMM5, testSYMM6, testSYMM7, testSYMM8, testSYMM9, testTRMM1, testTRMM2, testTRMM3, testTRMM4, testTRMM5, testTRMM6, testTRMM7, testTRSM1, testTRSM2, testTRSM3, testTRSM4, testTRSM5, testTRSM6, testTRSM7)
ScalarT composite1(const ScalarT &a, const ScalarT &b)
TYPED_TEST_P(FadBLASUnitTests, testSCAL1)
#define COMPARE_NESTED_FADS(a, b)
FAD::Fad< FadType > b_fad_
Sacado::Fad::DFad< double > FadType
Sacado::ScalarType< FadFadType >::type ScalarType
#define BINARY_OP_TEST(FIXTURENAME, TESTNAME, OP)
#define UNARY_FUNC_TEST(FIXTURENAME, TESTNAME, FUNC)
expr expr1 expr1 expr1 c expr2 expr1 expr2 expr1 expr2 expr1 expr1 expr1 expr1 c expr2 expr1 expr2 expr1 expr2 expr1 expr1 expr1 expr1 c *expr2 expr1 expr2 expr1 expr2 expr1 expr1 expr1 expr1 c expr2 expr1 expr2 expr1 expr2 expr1 expr1 expr1 expr2 expr1 expr2 expr1 expr1 expr1 expr2 expr1 expr2 expr1 expr1 expr1 c
FAD::Fad< FadType > c_fad_
SimpleFad< ValueT > min(const SimpleFad< ValueT > &a, const SimpleFad< ValueT > &b)
#define UNARY_OP_TEST(FIXTURENAME, TESTNAME, OP)
Sacado::ValueType< FadFadType >::type FadType
TYPED_TEST_SUITE_P(FadBLASUnitTests)
Sacado::Random< ScalarType > urand
SimpleFad< ValueT > max(const SimpleFad< ValueT > &a, const SimpleFad< ValueT > &b)
SACADO_INLINE_FUNCTION mpl::enable_if_c< ExprLevel< Expr< T1 > >::value==ExprLevel< Expr< T2 > >::value, Expr< PowerOp< Expr< T1 >, Expr< T2 > > > >::type pow(const Expr< T1 > &expr1, const Expr< T2 > &expr2)
#define BINARY_FUNC_TEST(FIXTURENAME, TESTNAME, FUNC)
A random number generator that generates random numbers uniformly distributed in the interval (a...
#define COMPARE_FADS(a, b)
ScalarT composite1_fad(const ScalarT &a, const ScalarT &b)
#define UNARY_ASSIGNOP_TEST(FIXTURENAME, TESTNAME, OP)
#define RELOP_TEST(FIXTURENAME, TESTNAME, OP)
Base template specification for testing whether type is statically sized.