32 template <
typename OrdinalType,
typename FadType>
35 OrdinalType workspace_size_) :
36 use_dynamic(use_dynamic_),
37 workspace_size(workspace_size_),
39 workspace_pointer(NULL)
47 template <
typename OrdinalType,
typename FadType>
50 use_dynamic(a.use_dynamic),
51 workspace_size(a.workspace_size),
53 workspace_pointer(NULL)
62 template <
typename OrdinalType,
typename FadType>
76 if (workspace_size > 0)
80 template <
typename OrdinalType,
typename FadType>
89 dot = &a.fastAccessDx(0);
94 template <
typename OrdinalType,
typename FadType>
98 OrdinalType& n_dot, OrdinalType& inc_val, OrdinalType& inc_dot,
111 bool is_contiguous = is_array_contiguous(a, n, n_dot);
117 cdot = &a[0].fastAccessDx(0);
126 dot = allocate_array(n*n_dot);
128 for (OrdinalType i=0; i<n; i++) {
129 val[i] = a[i*inc].val();
130 for (OrdinalType j=0; j<n_dot; j++)
139 template <
typename OrdinalType,
typename FadType>
143 OrdinalType& n_dot, OrdinalType& lda_val, OrdinalType& lda_dot,
156 bool is_contiguous = is_array_contiguous(A, m*n, n_dot);
162 cdot = &A[0].fastAccessDx(0);
171 dot = allocate_array(m*n*n_dot);
173 for (OrdinalType j=0; j<n; j++) {
174 for (OrdinalType i=0; i<m; i++) {
175 val[j*m+i] = A[j*lda+i].val();
176 for (OrdinalType k=0; k<n_dot; k++)
177 dot[(k*n+j)*m+i] = A[j*lda+i].fastAccessDx(k);
186 template <
typename OrdinalType,
typename FadType>
197 template <
typename OrdinalType,
typename FadType>
201 OrdinalType& n_dot, OrdinalType& inc_val, OrdinalType& inc_dot,
211 template <
typename OrdinalType,
typename FadType>
214 unpack(
const ValueType*
A, OrdinalType m, OrdinalType n, OrdinalType lda,
215 OrdinalType& n_dot, OrdinalType& lda_val, OrdinalType& lda_dot,
216 const ValueType*& cval,
const ValueType*& cdot)
const
225 template <
typename OrdinalType,
typename FadType>
236 template <
typename OrdinalType,
typename FadType>
240 OrdinalType& n_dot, OrdinalType& inc_val, OrdinalType& inc_dot,
250 template <
typename OrdinalType,
typename FadType>
254 OrdinalType& n_dot, OrdinalType& lda_val, OrdinalType& lda_dot,
264 template <
typename OrdinalType,
typename FadType>
275 "ArrayTraits::unpack(): FadType has wrong number of " <<
276 "derivative components. Got " << n_dot <<
277 ", expected " << final_n_dot <<
".");
279 if (n_dot > final_n_dot)
// Pick the storage that `dot` will point at. If the Fad object has fewer
// derivative slots available than final_n_dot, a scratch array is obtained
// from the traits allocator; otherwise, when any slots exist, `dot` aliases
// the Fad's own derivative storage directly.
282 OrdinalType n_avail = a.availableSize();
283 if (n_avail < final_n_dot) {
// Fix: this previously called the misspelled `alloate(...)`. The scratch
// buffer must come from allocate_array() so that the matching scalar free()
// overload, which releases `dot` through free_array(dot, n_dot), stays
// balanced with the allocation (dynamic or workspace-based).
284 dot = allocate_array(final_n_dot);
285 for (OrdinalType i=0; i<final_n_dot; i++)
288 else if (n_avail > 0)
289 dot = &a.fastAccessDx(0);
294 template <
typename OrdinalType,
typename FadType>
298 OrdinalType& final_n_dot, OrdinalType& inc_val, OrdinalType& inc_dot,
310 bool is_contiguous = is_array_contiguous(a, n, n_dot);
314 "ArrayTraits::unpack(): FadType has wrong number of " <<
315 "derivative components. Got " << n_dot <<
316 ", expected " << final_n_dot <<
".");
318 if (n_dot > final_n_dot)
327 val = allocate_array(n);
328 for (OrdinalType i=0; i<n; i++)
329 val[i] = a[i*inc].
val();
332 OrdinalType n_avail = a[0].availableSize();
333 if (is_contiguous && n_avail >= final_n_dot && final_n_dot > 0) {
335 dot = &a[0].fastAccessDx(0);
337 else if (final_n_dot > 0) {
339 dot = allocate_array(n*final_n_dot);
340 for (OrdinalType i=0; i<n; i++) {
342 for (OrdinalType j=0; j<n_dot; j++)
345 for (OrdinalType j=0; j<final_n_dot; j++)
355 template <
typename OrdinalType,
typename FadType>
359 OrdinalType& n_dot, OrdinalType& final_n_dot,
360 OrdinalType& lda_val, OrdinalType& lda_dot,
372 bool is_contiguous = is_array_contiguous(A, m*n, n_dot);
376 "ArrayTraits::unpack(): FadType has wrong number of " <<
377 "derivative components. Got " << n_dot <<
378 ", expected " << final_n_dot <<
".");
380 if (n_dot > final_n_dot)
389 val = allocate_array(m*n);
390 for (OrdinalType j=0; j<n; j++)
391 for (OrdinalType i=0; i<m; i++)
392 val[j*m+i] = A[j*lda+i].
val();
395 OrdinalType n_avail = A[0].availableSize();
396 if (is_contiguous && n_avail >= final_n_dot && final_n_dot > 0) {
398 dot = &A[0].fastAccessDx(0);
400 else if (final_n_dot > 0) {
402 dot = allocate_array(m*n*final_n_dot);
403 for (OrdinalType j=0; j<n; j++) {
404 for (OrdinalType i=0; i<m; i++) {
406 for (OrdinalType k=0; k<n_dot; k++)
407 dot[(k*n+j)*m+i] = A[j*lda+i].fastAccessDx(k);
409 for (OrdinalType k=0; k<final_n_dot; k++)
410 dot[(k*n+j)*m+i] = 0.0;
420 template <
typename OrdinalType,
typename FadType>
431 if (a.size() != n_dot)
434 for (OrdinalType i=0; i<n_dot; i++)
435 a.fastAccessDx(i) = dot[i];
438 template <
typename OrdinalType,
typename FadType>
442 OrdinalType n_dot, OrdinalType inc_val, OrdinalType inc_dot,
449 if (&a[0].
val() != val)
450 for (OrdinalType i=0; i<n; i++)
451 a[i*inc].
val() = val[i*inc_val];
457 if (a[0].size() != n_dot)
458 for (OrdinalType i=0; i<n; i++)
459 a[i*inc].resize(n_dot);
462 if (a[0].
dx() != dot)
463 for (OrdinalType i=0; i<n; i++)
464 for (OrdinalType j=0; j<n_dot; j++)
468 template <
typename OrdinalType,
typename FadType>
472 OrdinalType n_dot, OrdinalType lda_val, OrdinalType lda_dot,
479 if (&A[0].
val() != val)
480 for (OrdinalType j=0; j<n; j++)
481 for (OrdinalType i=0; i<m; i++)
482 A[i+j*lda].
val() = val[i+j*lda_val];
488 if (A[0].size() != n_dot)
489 for (OrdinalType j=0; j<n; j++)
490 for (OrdinalType i=0; i<m; i++)
491 A[i+j*lda].resize(n_dot);
494 if (A[0].
dx() != dot)
495 for (OrdinalType j=0; j<n; j++)
496 for (OrdinalType i=0; i<m; i++)
497 for (OrdinalType k=0; k<n_dot; k++)
501 template <
typename OrdinalType,
typename FadType>
506 if (n_dot > 0 && a.dx() != dot) {
507 free_array(dot, n_dot);
511 template <
typename OrdinalType,
typename FadType>
515 OrdinalType inc_val, OrdinalType inc_dot,
521 if (val != &a[0].
val())
522 free_array(val, n*inc_val);
524 if (n_dot > 0 && a[0].
dx() != dot)
525 free_array(dot, n*inc_dot*n_dot);
528 template <
typename OrdinalType,
typename FadType>
531 free(
const FadType*
A, OrdinalType m, OrdinalType n, OrdinalType n_dot,
532 OrdinalType lda_val, OrdinalType lda_dot,
538 if (val != &A[0].
val())
539 free_array(val, lda_val*n);
541 if (n_dot > 0 && A[0].
dx() != dot)
542 free_array(dot, lda_dot*n*n_dot);
545 template <
typename OrdinalType,
typename FadType>
556 "ArrayTraits::allocate_array(): " <<
557 "Requested workspace memory beyond size allocated. " <<
558 "Workspace size is " << workspace_size <<
559 ", currently used is " << workspace_pointer-workspace <<
560 ", requested size is " << size <<
".");
565 workspace_pointer += size;
569 template <
typename OrdinalType,
typename FadType>
574 if (use_dynamic && ptr != NULL)
577 workspace_pointer -= size;
580 template <
typename OrdinalType,
typename FadType>
586 (&(a[n-1].val())-&(a[0].
val()) == n-1) &&
587 (a[n-1].dx()-a[0].dx() == n-1);
590 template <
typename OrdinalType,
typename FadType>
593 bool use_dynamic_, OrdinalType static_workspace_size_) :
595 arrayTraits(use_dynamic_, static_workspace_size_),
597 use_default_impl(use_default_impl_)
601 template <
typename OrdinalType,
typename FadType>
605 arrayTraits(x.arrayTraits),
607 use_default_impl(x.use_default_impl)
611 template <
typename OrdinalType,
typename FadType>
617 template <
typename OrdinalType,
typename FadType>
621 const OrdinalType incx)
const
623 if (use_default_impl) {
624 BLASType::SCAL(n,alpha,x,incx);
632 OrdinalType n_alpha_dot = 0, n_x_dot = 0, n_dot = 0;
633 OrdinalType incx_val, incx_dot;
634 arrayTraits.unpack(alpha, n_alpha_dot, alpha_val, alpha_dot);
636 arrayTraits.unpack(x, n, incx, n_x_dot, n_dot, incx_val, incx_dot,
642 (n_x_dot != n_dot && n_x_dot != 0),
644 "BLAS::SCAL(): All arguments must have " <<
645 "the same number of derivative components, or none");
650 blas.SCAL(n*n_x_dot, alpha_val, x_dot, incx_dot);
651 for (OrdinalType i=0; i<n_alpha_dot; i++)
652 blas.AXPY(n, alpha_dot[i], x_val, incx_val, x_dot+i*n*incx_dot, incx_dot);
653 blas.SCAL(n, alpha_val, x_val, incx_val);
656 arrayTraits.pack(x, n, incx, n_dot, incx_val, incx_dot, x_val, x_dot);
659 arrayTraits.free(alpha, n_alpha_dot, alpha_dot);
660 arrayTraits.free(x, n, n_dot, incx_val, incx_dot, x_val, x_dot);
663 template <
typename OrdinalType,
typename FadType>
666 COPY(
const OrdinalType n,
const FadType* x,
const OrdinalType incx,
667 FadType* y,
const OrdinalType incy)
const
669 if (use_default_impl) {
670 BLASType::COPY(n,x,incx,y,incy);
677 OrdinalType n_x_dot = x[0].size();
678 OrdinalType n_y_dot = y[0].size();
679 if (n_x_dot == 0 || n_y_dot == 0 || n_x_dot != n_y_dot ||
680 !arrayTraits.is_array_contiguous(x, n, n_x_dot) ||
681 !arrayTraits.is_array_contiguous(y, n, n_y_dot))
682 BLASType::COPY(n,x,incx,y,incy);
684 blas.COPY(n, &x[0].
val(), incx, &y[0].
val(), incy);
690 template <
typename OrdinalType,
typename FadType>
691 template <
typename alpha_type,
typename x_type>
694 AXPY(
const OrdinalType n,
const alpha_type& alpha,
const x_type* x,
695 const OrdinalType incx,
FadType* y,
const OrdinalType incy)
const
697 if (use_default_impl) {
698 BLASType::AXPY(n,alpha,x,incx,y,incy);
707 OrdinalType n_alpha_dot, n_x_dot, n_y_dot, n_dot;
708 OrdinalType incx_val, incy_val, incx_dot, incy_dot;
709 arrayTraits.unpack(alpha, n_alpha_dot, alpha_val, alpha_dot);
710 arrayTraits.unpack(x, n, incx, n_x_dot, incx_val, incx_dot, x_val, x_dot);
716 else if (n_x_dot > 0)
720 arrayTraits.unpack(y, n, incy, n_y_dot, n_dot, incy_val, incy_dot, y_val,
726 (n_x_dot != n_dot && n_x_dot != 0) ||
727 (n_y_dot != n_dot && n_y_dot != 0),
729 "BLAS::AXPY(): All arguments must have " <<
730 "the same number of derivative components, or none");
735 blas.AXPY(n*n_x_dot, alpha_val, x_dot, incx_dot, y_dot, incy_dot);
736 for (OrdinalType i=0; i<n_alpha_dot; i++)
737 blas.AXPY(n, alpha_dot[i], x_val, incx_val, y_dot+i*n*incy_dot, incy_dot);
738 blas.AXPY(n, alpha_val, x_val, incx_val, y_val, incy_val);
741 arrayTraits.pack(y, n, incy, n_dot, incy_val, incy_dot, y_val, y_dot);
744 arrayTraits.free(alpha, n_alpha_dot, alpha_dot);
745 arrayTraits.free(x, n, n_x_dot, incx_val, incx_dot, x_val, x_dot);
746 arrayTraits.free(y, n, n_dot, incy_val, incy_dot, y_val, y_dot);
749 template <
typename OrdinalType,
typename FadType>
750 template <
typename x_type,
typename y_type>
753 DOT(
const OrdinalType n,
const x_type* x,
const OrdinalType incx,
754 const y_type* y,
const OrdinalType incy)
const
756 if (use_default_impl)
757 return BLASType::DOT(n,x,incx,y,incy);
762 OrdinalType n_x_dot, n_y_dot;
763 OrdinalType incx_val, incy_val, incx_dot, incy_dot;
764 arrayTraits.unpack(x, n, incx, n_x_dot, incx_val, incx_dot, x_val, x_dot);
765 arrayTraits.unpack(y, n, incy, n_y_dot, incy_val, incy_dot, y_val, y_dot);
768 OrdinalType n_z_dot = 0;
771 else if (n_y_dot > 0)
780 Fad_DOT(n, x_val, incx_val, n_x_dot, x_dot, incx_dot,
781 y_val, incy_val, n_y_dot, y_dot, incy_dot,
782 z_val, n_z_dot, z_dot);
785 arrayTraits.free(x, n, n_x_dot, incx_val, incx_dot, x_val, x_dot);
786 arrayTraits.free(y, n, n_y_dot, incy_val, incy_dot, y_val, y_dot);
791 template <
typename OrdinalType,
typename FadType>
794 NRM2(
const OrdinalType n,
const FadType* x,
const OrdinalType incx)
const
796 if (use_default_impl)
797 return BLASType::NRM2(n,x,incx);
801 OrdinalType n_x_dot, incx_val, incx_dot;
802 arrayTraits.unpack(x, n, incx, n_x_dot, incx_val, incx_dot, x_val, x_dot);
808 z.val() = blas.NRM2(n, x_val, incx_val);
813 for (OrdinalType i=0; i<n_x_dot; i++)
818 arrayTraits.free(x, n, n_x_dot, incx_val, incx_dot, x_val, x_dot);
823 template <
typename OrdinalType,
typename FadType>
824 template <
typename alpha_type,
typename A_type,
typename x_type,
829 const alpha_type& alpha,
const A_type*
A,
830 const OrdinalType lda,
const x_type* x,
831 const OrdinalType incx,
const beta_type& beta,
832 FadType* y,
const OrdinalType incy)
const
834 if (use_default_impl) {
835 BLASType::GEMV(trans,m,n,alpha,A,lda,x,incx,beta,y,incy);
839 OrdinalType n_x_rows = n;
840 OrdinalType n_y_rows = m;
854 OrdinalType n_alpha_dot, n_A_dot, n_x_dot, n_beta_dot, n_y_dot = 0, n_dot;
855 OrdinalType lda_val, incx_val, incy_val, lda_dot, incx_dot, incy_dot;
856 arrayTraits.unpack(alpha, n_alpha_dot, alpha_val, alpha_dot);
857 arrayTraits.unpack(A, m, n, lda, n_A_dot, lda_val, lda_dot, A_val, A_dot);
858 arrayTraits.unpack(x, n_x_rows, incx, n_x_dot, incx_val, incx_dot, x_val,
860 arrayTraits.unpack(beta, n_beta_dot, beta_val, beta_dot);
866 else if (n_A_dot > 0)
868 else if (n_x_dot > 0)
870 else if (n_beta_dot > 0)
874 arrayTraits.unpack(y, n_y_rows, incy, n_y_dot, n_dot, incy_val, incy_dot,
878 Fad_GEMV(trans, m, n, alpha_val, n_alpha_dot, alpha_dot, A_val, lda_val,
879 n_A_dot, A_dot, lda_dot, x_val, incx_val, n_x_dot, x_dot, incx_dot,
880 beta_val, n_beta_dot, beta_dot, y_val, incy_val, n_y_dot, y_dot,
884 arrayTraits.pack(y, n_y_rows, incy, n_dot, incy_val, incy_dot, y_val, y_dot);
887 arrayTraits.free(alpha, n_alpha_dot, alpha_dot);
888 arrayTraits.free(A, m, n, n_A_dot, lda_val, lda_dot, A_val, A_dot);
889 arrayTraits.free(x, n_x_rows, n_x_dot, incx_val, incx_dot, x_val, x_dot);
890 arrayTraits.free(beta, n_beta_dot, beta_dot);
891 arrayTraits.free(y, n_y_rows, n_dot, incy_val, incy_dot, y_val, y_dot);
894 template <
typename OrdinalType,
typename FadType>
895 template <
typename A_type>
899 const OrdinalType n,
const A_type*
A,
const OrdinalType lda,
900 FadType* x,
const OrdinalType incx)
const
902 if (use_default_impl) {
903 BLASType::TRMV(uplo,trans,diag,n,A,lda,x,incx);
910 OrdinalType n_A_dot = 0, n_x_dot = 0, n_dot = 0;
911 OrdinalType lda_val, incx_val, lda_dot, incx_dot;
912 arrayTraits.unpack(A, n, n, lda, n_A_dot, lda_val, lda_dot, A_val, A_dot);
914 arrayTraits.unpack(x, n, incx, n_x_dot, n_dot, incx_val, incx_dot, x_val,
920 (n_x_dot != n_dot && n_x_dot != 0),
922 "BLAS::TRMV(): All arguments must have " <<
923 "the same number of derivative components, or none");
932 for (OrdinalType i=0; i<n_x_dot; i++)
933 blas.TRMV(uplo, trans, diag, n, A_val, lda_val, x_dot+i*incx_dot*n,
938 if (gemv_Ax.size() != std::size_t(n))
940 for (OrdinalType i=0; i<n_A_dot; i++) {
941 blas.COPY(n, x_val, incx_val, &gemv_Ax[0], OrdinalType(1));
943 lda_dot, &gemv_Ax[0], OrdinalType(1));
944 blas.AXPY(n, 1.0, &gemv_Ax[0], OrdinalType(1), x_dot+i*incx_dot*n,
949 blas.TRMV(uplo, trans, diag, n, A_val, lda_val, x_val, incx_val);
952 arrayTraits.pack(x, n, incx, n_dot, incx_val, incx_dot, x_val, x_dot);
955 arrayTraits.free(A, n, n, n_A_dot, lda_val, lda_dot, A_val, A_dot);
956 arrayTraits.free(x, n, n_dot, incx_val, incx_dot, x_val, x_dot);
959 template <
typename OrdinalType,
typename FadType>
960 template <
typename alpha_type,
typename x_type,
typename y_type>
963 GER(
const OrdinalType m,
const OrdinalType n,
const alpha_type& alpha,
964 const x_type* x,
const OrdinalType incx,
965 const y_type* y,
const OrdinalType incy,
966 FadType*
A,
const OrdinalType lda)
const
968 if (use_default_impl) {
969 BLASType::GER(m,n,alpha,x,incx,y,incy,A,lda);
979 OrdinalType n_alpha_dot, n_x_dot, n_y_dot, n_A_dot, n_dot;
980 OrdinalType lda_val, incx_val, incy_val, lda_dot, incx_dot, incy_dot;
981 arrayTraits.unpack(alpha, n_alpha_dot, alpha_val, alpha_dot);
982 arrayTraits.unpack(x, m, incx, n_x_dot, incx_val, incx_dot, x_val, x_dot);
983 arrayTraits.unpack(y, n, incy, n_y_dot, incy_val, incy_dot, y_val, y_dot);
989 else if (n_x_dot > 0)
991 else if (n_y_dot > 0)
995 arrayTraits.unpack(A, m, n, lda, n_A_dot, n_dot, lda_val, lda_dot, A_val,
999 Fad_GER(m, n, alpha_val, n_alpha_dot, alpha_dot, x_val, incx_val,
1000 n_x_dot, x_dot, incx_dot, y_val, incy_val, n_y_dot, y_dot,
1001 incy_dot, A_val, lda_val, n_A_dot, A_dot, lda_dot, n_dot);
1004 arrayTraits.pack(A, m, n, lda, n_dot, lda_val, lda_dot, A_val, A_dot);
1007 arrayTraits.free(alpha, n_alpha_dot, alpha_dot);
1008 arrayTraits.free(x, m, n_x_dot, incx_val, incx_dot, x_val, x_dot);
1009 arrayTraits.free(y, n, n_y_dot, incy_val, incy_dot, y_val, y_dot);
1010 arrayTraits.free(A, m, n, n_dot, lda_val, lda_dot, A_val, A_dot);
1013 template <
typename OrdinalType,
typename FadType>
1014 template <
typename alpha_type,
typename A_type,
typename B_type,
1019 const OrdinalType m,
const OrdinalType n,
const OrdinalType k,
1020 const alpha_type& alpha,
const A_type*
A,
const OrdinalType lda,
1021 const B_type*
B,
const OrdinalType ldb,
const beta_type& beta,
1022 FadType*
C,
const OrdinalType ldc)
const
1024 if (use_default_impl) {
1025 BLASType::GEMM(transa,transb,m,n,k,alpha,A,lda,B,ldb,beta,C,ldc);
1029 OrdinalType n_A_rows = m;
1030 OrdinalType n_A_cols = k;
1036 OrdinalType n_B_rows = k;
1037 OrdinalType n_B_cols = n;
1051 OrdinalType n_alpha_dot, n_A_dot, n_B_dot, n_beta_dot, n_C_dot = 0, n_dot;
1052 OrdinalType lda_val, ldb_val, ldc_val, lda_dot, ldb_dot, ldc_dot;
1053 arrayTraits.unpack(alpha, n_alpha_dot, alpha_val, alpha_dot);
1054 arrayTraits.unpack(A, n_A_rows, n_A_cols, lda, n_A_dot, lda_val, lda_dot,
1056 arrayTraits.unpack(B, n_B_rows, n_B_cols, ldb, n_B_dot, ldb_val, ldb_dot,
1058 arrayTraits.unpack(beta, n_beta_dot, beta_val, beta_dot);
1062 if (n_alpha_dot > 0)
1063 n_dot = n_alpha_dot;
1064 else if (n_A_dot > 0)
1066 else if (n_B_dot > 0)
1068 else if (n_beta_dot > 0)
1072 arrayTraits.unpack(C, m, n, ldc, n_C_dot, n_dot, ldc_val, ldc_dot, C_val,
1076 Fad_GEMM(transa, transb, m, n, k,
1077 alpha_val, n_alpha_dot, alpha_dot,
1078 A_val, lda_val, n_A_dot, A_dot, lda_dot,
1079 B_val, ldb_val, n_B_dot, B_dot, ldb_dot,
1080 beta_val, n_beta_dot, beta_dot,
1081 C_val, ldc_val, n_C_dot, C_dot, ldc_dot, n_dot);
1084 arrayTraits.pack(C, m, n, ldc, n_dot, ldc_val, ldc_dot, C_val, C_dot);
1087 arrayTraits.free(alpha, n_alpha_dot, alpha_dot);
1088 arrayTraits.free(A, n_A_rows, n_A_cols, n_A_dot, lda_val, lda_dot, A_val,
1090 arrayTraits.free(B, n_B_rows, n_B_cols, n_B_dot, ldb_val, ldb_dot, B_val,
1092 arrayTraits.free(beta, n_beta_dot, beta_dot);
1093 arrayTraits.free(C, m, n, n_dot, ldc_val, ldc_dot, C_val, C_dot);
1096 template <
typename OrdinalType,
typename FadType>
1097 template <
typename alpha_type,
typename A_type,
typename B_type,
1102 const OrdinalType m,
const OrdinalType n,
1103 const alpha_type& alpha,
const A_type*
A,
const OrdinalType lda,
1104 const B_type*
B,
const OrdinalType ldb,
const beta_type& beta,
1105 FadType*
C,
const OrdinalType ldc)
const
1107 if (use_default_impl) {
1108 BLASType::SYMM(side,uplo,m,n,alpha,A,lda,B,ldb,beta,C,ldc);
1112 OrdinalType n_A_rows = m;
1113 OrdinalType n_A_cols = m;
1127 OrdinalType n_alpha_dot, n_A_dot, n_B_dot, n_beta_dot, n_C_dot, n_dot;
1128 OrdinalType lda_val, ldb_val, ldc_val, lda_dot, ldb_dot, ldc_dot;
1129 arrayTraits.unpack(alpha, n_alpha_dot, alpha_val, alpha_dot);
1130 arrayTraits.unpack(A, n_A_rows, n_A_cols, lda, n_A_dot, lda_val, lda_dot,
1132 arrayTraits.unpack(B, m, n, ldb, n_B_dot, ldb_val, ldb_dot, B_val, B_dot);
1133 arrayTraits.unpack(beta, n_beta_dot, beta_val, beta_dot);
1137 if (n_alpha_dot > 0)
1138 n_dot = n_alpha_dot;
1139 else if (n_A_dot > 0)
1141 else if (n_B_dot > 0)
1143 else if (n_beta_dot > 0)
1147 arrayTraits.unpack(C, m, n, ldc, n_C_dot, n_dot, ldc_val, ldc_dot, C_val,
1151 Fad_SYMM(side, uplo, m, n,
1152 alpha_val, n_alpha_dot, alpha_dot,
1153 A_val, lda_val, n_A_dot, A_dot, lda_dot,
1154 B_val, ldb_val, n_B_dot, B_dot, ldb_dot,
1155 beta_val, n_beta_dot, beta_dot,
1156 C_val, ldc_val, n_C_dot, C_dot, ldc_dot, n_dot);
1159 arrayTraits.pack(C, m, n, ldc, n_dot, ldc_val, ldc_dot, C_val, C_dot);
1162 arrayTraits.free(alpha, n_alpha_dot, alpha_dot);
1163 arrayTraits.free(A, n_A_rows, n_A_cols, n_A_dot, lda_val, lda_dot, A_val,
1165 arrayTraits.free(B, m, n, n_B_dot, ldb_val, ldb_dot, B_val, B_dot);
1166 arrayTraits.free(beta, n_beta_dot, beta_dot);
1167 arrayTraits.free(C, m, n, n_dot, ldc_val, ldc_dot, C_val, C_dot);
1170 template <
typename OrdinalType,
typename FadType>
1171 template <
typename alpha_type,
typename A_type>
1176 const OrdinalType m,
const OrdinalType n,
1177 const alpha_type& alpha,
const A_type*
A,
const OrdinalType lda,
1178 FadType*
B,
const OrdinalType ldb)
const
1180 if (use_default_impl) {
1181 BLASType::TRMM(side,uplo,transa,diag,m,n,alpha,A,lda,B,ldb);
1185 OrdinalType n_A_rows = m;
1186 OrdinalType n_A_cols = m;
1197 OrdinalType n_alpha_dot, n_A_dot, n_B_dot, n_dot;
1198 OrdinalType lda_val, ldb_val, lda_dot, ldb_dot;
1199 arrayTraits.unpack(alpha, n_alpha_dot, alpha_val, alpha_dot);
1200 arrayTraits.unpack(A, n_A_rows, n_A_cols, lda, n_A_dot, lda_val, lda_dot,
1205 if (n_alpha_dot > 0)
1206 n_dot = n_alpha_dot;
1207 else if (n_A_dot > 0)
1211 arrayTraits.unpack(B, m, n, ldb, n_B_dot, n_dot, ldb_val, ldb_dot, B_val,
1215 Fad_TRMM(side, uplo, transa, diag, m, n,
1216 alpha_val, n_alpha_dot, alpha_dot,
1217 A_val, lda_val, n_A_dot, A_dot, lda_dot,
1218 B_val, ldb_val, n_B_dot, B_dot, ldb_dot, n_dot);
1221 arrayTraits.pack(B, m, n, ldb, n_dot, ldb_val, ldb_dot, B_val, B_dot);
1224 arrayTraits.free(alpha, n_alpha_dot, alpha_dot);
1225 arrayTraits.free(A, n_A_rows, n_A_cols, n_A_dot, lda_val, lda_dot, A_val,
1227 arrayTraits.free(B, m, n, n_dot, ldb_val, ldb_dot, B_val, B_dot);
1230 template <
typename OrdinalType,
typename FadType>
1231 template <
typename alpha_type,
typename A_type>
1236 const OrdinalType m,
const OrdinalType n,
1237 const alpha_type& alpha,
const A_type*
A,
const OrdinalType lda,
1238 FadType*
B,
const OrdinalType ldb)
const
1240 if (use_default_impl) {
1241 BLASType::TRSM(side,uplo,transa,diag,m,n,alpha,A,lda,B,ldb);
1245 OrdinalType n_A_rows = m;
1246 OrdinalType n_A_cols = m;
1257 OrdinalType n_alpha_dot, n_A_dot, n_B_dot, n_dot;
1258 OrdinalType lda_val, ldb_val, lda_dot, ldb_dot;
1259 arrayTraits.unpack(alpha, n_alpha_dot, alpha_val, alpha_dot);
1260 arrayTraits.unpack(A, n_A_rows, n_A_cols, lda, n_A_dot, lda_val, lda_dot,
1265 if (n_alpha_dot > 0)
1266 n_dot = n_alpha_dot;
1267 else if (n_A_dot > 0)
1271 arrayTraits.unpack(B, m, n, ldb, n_B_dot, n_dot, ldb_val, ldb_dot, B_val,
1275 Fad_TRSM(side, uplo, transa, diag, m, n,
1276 alpha_val, n_alpha_dot, alpha_dot,
1277 A_val, lda_val, n_A_dot, A_dot, lda_dot,
1278 B_val, ldb_val, n_B_dot, B_dot, ldb_dot, n_dot);
1281 arrayTraits.pack(B, m, n, ldb, n_dot, ldb_val, ldb_dot, B_val, B_dot);
1284 arrayTraits.free(alpha, n_alpha_dot, alpha_dot);
1285 arrayTraits.free(A, n_A_rows, n_A_cols, n_A_dot, lda_val, lda_dot, A_val,
1287 arrayTraits.free(B, m, n, n_dot, ldb_val, ldb_dot, B_val, B_dot);
1290 template <
typename OrdinalType,
typename FadType>
1291 template <
typename x_type,
typename y_type>
1296 const OrdinalType incx,
1297 const OrdinalType n_x_dot,
1298 const x_type* x_dot,
1299 const OrdinalType incx_dot,
1301 const OrdinalType incy,
1302 const OrdinalType n_y_dot,
1303 const y_type* y_dot,
1304 const OrdinalType incy_dot,
1306 const OrdinalType n_z_dot,
1312 (n_y_dot != n_z_dot && n_y_dot != 0),
1314 "BLAS::Fad_DOT(): All arguments must have " <<
1315 "the same number of derivative components, or none");
1320 if (incx_dot == OrdinalType(1))
1321 blas.GEMV(
Teuchos::TRANS, n, n_x_dot, 1.0, x_dot, n, y, incy, 0.0, z_dot,
1324 for (OrdinalType i=0; i<n_z_dot; i++)
1325 z_dot[i] = blas.DOT(n, x_dot+i*incx_dot*n, incx_dot, y, incy);
1330 if (incy_dot == OrdinalType(1) &&
1332 blas.GEMV(
Teuchos::TRANS, n, n_y_dot, 1.0, y_dot, n, x, incx, 1.0, z_dot,
1335 for (OrdinalType i=0; i<n_z_dot; i++)
1336 z_dot[i] += blas.DOT(n, x, incx, y_dot+i*incy_dot*n, incy_dot);
1340 z = blas.DOT(n, x, incx, y, incy);
1343 template <
typename OrdinalType,
typename FadType>
1344 template <
typename alpha_type,
typename A_type,
typename x_type,
1349 const OrdinalType m,
1350 const OrdinalType n,
1351 const alpha_type& alpha,
1352 const OrdinalType n_alpha_dot,
1353 const alpha_type* alpha_dot,
1355 const OrdinalType lda,
1356 const OrdinalType n_A_dot,
1357 const A_type* A_dot,
1358 const OrdinalType lda_dot,
1360 const OrdinalType incx,
1361 const OrdinalType n_x_dot,
1362 const x_type* x_dot,
1363 const OrdinalType incx_dot,
1364 const beta_type& beta,
1365 const OrdinalType n_beta_dot,
1366 const beta_type* beta_dot,
1368 const OrdinalType incy,
1369 const OrdinalType n_y_dot,
1371 const OrdinalType incy_dot,
1372 const OrdinalType n_dot)
const
1377 (n_A_dot != n_dot && n_A_dot != 0) ||
1378 (n_x_dot != n_dot && n_x_dot != 0) ||
1379 (n_beta_dot != n_dot && n_beta_dot != 0) ||
1380 (n_y_dot != n_dot && n_y_dot != 0),
1382 "BLAS::Fad_GEMV(): All arguments must have " <<
1383 "the same number of derivative components, or none");
1385 OrdinalType n_A_rows = m;
1386 OrdinalType n_A_cols = n;
1387 OrdinalType n_x_rows = n;
1388 OrdinalType n_y_rows = m;
1398 blas.SCAL(n_y_rows*n_y_dot, beta, y_dot, incy_dot);
1404 alpha, A, lda, x_dot, n_x_rows, 1.0, y_dot, n_y_rows);
1406 for (OrdinalType i=0; i<n_x_dot; i++)
1407 blas.GEMV(trans, m, n, alpha, A, lda, x_dot+i*incx_dot*n_x_rows,
1408 incx_dot, 1.0, y_dot+i*incy_dot*n_y_rows, incy_dot);
1412 if (n_alpha_dot > 0) {
1413 if (gemv_Ax.size() != std::size_t(n))
1415 blas.GEMV(trans, m, n, 1.0, A, lda, x, incx, 0.0, &gemv_Ax[0],
1417 for (OrdinalType i=0; i<n_alpha_dot; i++)
1418 blas.AXPY(n_y_rows, alpha_dot[i], &gemv_Ax[0], OrdinalType(1),
1419 y_dot+i*incy_dot*n_y_rows, incy_dot);
1423 for (OrdinalType i=0; i<n_A_dot; i++)
1424 blas.GEMV(trans, m, n, alpha, A_dot+i*lda_dot*n, lda_dot, x, incx, 1.0,
1425 y_dot+i*incy_dot*n_y_rows, incy_dot);
1428 for (OrdinalType i=0; i<n_beta_dot; i++)
1429 blas.AXPY(n_y_rows, beta_dot[i], y, incy, y_dot+i*incy_dot*n_y_rows,
1433 if (n_alpha_dot > 0) {
1434 blas.SCAL(n_y_rows, beta, y, incy);
1435 blas.AXPY(n_y_rows, alpha, &gemv_Ax[0], OrdinalType(1), y, incy);
1438 blas.GEMV(trans, m, n, alpha, A, lda, x, incx, beta, y, incy);
1441 template <
typename OrdinalType,
typename FadType>
1442 template <
typename alpha_type,
typename x_type,
typename y_type>
1446 const OrdinalType n,
1447 const alpha_type& alpha,
1448 const OrdinalType n_alpha_dot,
1449 const alpha_type* alpha_dot,
1451 const OrdinalType incx,
1452 const OrdinalType n_x_dot,
1453 const x_type* x_dot,
1454 const OrdinalType incx_dot,
1456 const OrdinalType incy,
1457 const OrdinalType n_y_dot,
1458 const y_type* y_dot,
1459 const OrdinalType incy_dot,
1461 const OrdinalType lda,
1462 const OrdinalType n_A_dot,
1464 const OrdinalType lda_dot,
1465 const OrdinalType n_dot)
const
1470 (n_A_dot != n_dot && n_A_dot != 0) ||
1471 (n_x_dot != n_dot && n_x_dot != 0) ||
1472 (n_y_dot != n_dot && n_y_dot != 0),
1474 "BLAS::Fad_GER(): All arguments must have " <<
1475 "the same number of derivative components, or none");
1479 for (OrdinalType i=0; i<n_alpha_dot; i++)
1480 blas.GER(m, n, alpha_dot[i], x, incx, y, incy, A_dot+i*lda_dot*n, lda_dot);
1483 for (OrdinalType i=0; i<n_x_dot; i++)
1484 blas.GER(m, n, alpha, x_dot+i*incx_dot*m, incx_dot, y, incy,
1485 A_dot+i*lda_dot*n, lda_dot);
1489 blas.GER(m, n*n_y_dot, alpha, x, incx, y_dot, incy_dot, A_dot, lda_dot);
1492 blas.GER(m, n, alpha, x, incx, y, incy, A, lda);
1495 template <
typename OrdinalType,
typename FadType>
1496 template <
typename alpha_type,
typename A_type,
typename B_type,
1502 const OrdinalType m,
1503 const OrdinalType n,
1504 const OrdinalType k,
1505 const alpha_type& alpha,
1506 const OrdinalType n_alpha_dot,
1507 const alpha_type* alpha_dot,
1509 const OrdinalType lda,
1510 const OrdinalType n_A_dot,
1511 const A_type* A_dot,
1512 const OrdinalType lda_dot,
1514 const OrdinalType ldb,
1515 const OrdinalType n_B_dot,
1516 const B_type* B_dot,
1517 const OrdinalType ldb_dot,
1518 const beta_type& beta,
1519 const OrdinalType n_beta_dot,
1520 const beta_type* beta_dot,
1522 const OrdinalType ldc,
1523 const OrdinalType n_C_dot,
1525 const OrdinalType ldc_dot,
1526 const OrdinalType n_dot)
const
1531 (n_A_dot != n_dot && n_A_dot != 0) ||
1532 (n_B_dot != n_dot && n_B_dot != 0) ||
1533 (n_beta_dot != n_dot && n_beta_dot != 0) ||
1534 (n_C_dot != n_dot && n_C_dot != 0),
1536 "BLAS::Fad_GEMM(): All arguments must have " <<
1537 "the same number of derivative components, or none");
1539 OrdinalType n_A_cols = k;
1544 OrdinalType n_B_cols = n;
1552 blas.SCAL(m*n*n_C_dot, beta, C_dot, OrdinalType(1));
1554 for (OrdinalType i=0; i<n_C_dot; i++)
1555 for (OrdinalType j=0; j<n; j++)
1556 blas.SCAL(m, beta, C_dot+i*ldc_dot*n+j*ldc_dot, OrdinalType(1));
1560 for (OrdinalType i=0; i<n_B_dot; i++)
1561 blas.GEMM(transa, transb, m, n, k, alpha, A, lda, B_dot+i*ldb_dot*n_B_cols,
1562 ldb_dot, 1.0, C_dot+i*ldc_dot*n, ldc_dot);
1565 if (n_alpha_dot > 0) {
1566 if (gemm_AB.size() != std::size_t(m*n))
1567 gemm_AB.resize(m*n);
1568 blas.GEMM(transa, transb, m, n, k, 1.0, A, lda, B, ldb, 0.0, &gemm_AB[0],
1571 for (OrdinalType i=0; i<n_alpha_dot; i++)
1572 blas.AXPY(m*n, alpha_dot[i], &gemm_AB[0], OrdinalType(1),
1573 C_dot+i*ldc_dot*n, OrdinalType(1));
1575 for (OrdinalType i=0; i<n_alpha_dot; i++)
1576 for (OrdinalType j=0; j<n; j++)
1577 blas.AXPY(m, alpha_dot[i], &gemm_AB[j*m], OrdinalType(1),
1578 C_dot+i*ldc_dot*n+j*ldc_dot, OrdinalType(1));
1582 for (OrdinalType i=0; i<n_A_dot; i++)
1583 blas.GEMM(transa, transb, m, n, k, alpha, A_dot+i*lda_dot*n_A_cols,
1584 lda_dot, B, ldb, 1.0, C_dot+i*ldc_dot*n, ldc_dot);
1587 if (ldc == m && ldc_dot == m)
1588 for (OrdinalType i=0; i<n_beta_dot; i++)
1589 blas.AXPY(m*n, beta_dot[i], C, OrdinalType(1), C_dot+i*ldc_dot*n,
1592 for (OrdinalType i=0; i<n_beta_dot; i++)
1593 for (OrdinalType j=0; j<n; j++)
1594 blas.AXPY(m, beta_dot[i], C+j*ldc, OrdinalType(1),
1595 C_dot+i*ldc_dot*n+j*ldc_dot, OrdinalType(1));
1598 if (n_alpha_dot > 0) {
1600 blas.SCAL(m*n, beta, C, OrdinalType(1));
1601 blas.AXPY(m*n, alpha, &gemm_AB[0], OrdinalType(1), C, OrdinalType(1));
1604 for (OrdinalType j=0; j<n; j++) {
1605 blas.SCAL(m, beta, C+j*ldc, OrdinalType(1));
1606 blas.AXPY(m, alpha, &gemm_AB[j*m], OrdinalType(1), C+j*ldc,
1611 blas.GEMM(transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
1614 template <
typename OrdinalType,
typename FadType>
1615 template <
typename alpha_type,
typename A_type,
typename B_type,
1620 const OrdinalType m,
1621 const OrdinalType n,
1622 const alpha_type& alpha,
1623 const OrdinalType n_alpha_dot,
1624 const alpha_type* alpha_dot,
1626 const OrdinalType lda,
1627 const OrdinalType n_A_dot,
1628 const A_type* A_dot,
1629 const OrdinalType lda_dot,
1631 const OrdinalType ldb,
1632 const OrdinalType n_B_dot,
1633 const B_type* B_dot,
1634 const OrdinalType ldb_dot,
1635 const beta_type& beta,
1636 const OrdinalType n_beta_dot,
1637 const beta_type* beta_dot,
1639 const OrdinalType ldc,
1640 const OrdinalType n_C_dot,
1642 const OrdinalType ldc_dot,
1643 const OrdinalType n_dot)
const
1648 (n_A_dot != n_dot && n_A_dot != 0) ||
1649 (n_B_dot != n_dot && n_B_dot != 0) ||
1650 (n_beta_dot != n_dot && n_beta_dot != 0) ||
1651 (n_C_dot != n_dot && n_C_dot != 0),
1653 "BLAS::Fad_SYMM(): All arguments must have " <<
1654 "the same number of derivative components, or none");
1656 OrdinalType n_A_cols = m;
1664 blas.SCAL(m*n*n_C_dot, beta, C_dot, OrdinalType(1));
1666 for (OrdinalType i=0; i<n_C_dot; i++)
1667 for (OrdinalType j=0; j<n; j++)
1668 blas.SCAL(m, beta, C_dot+i*ldc_dot*n+j*ldc_dot, OrdinalType(1));
1672 for (OrdinalType i=0; i<n_B_dot; i++)
1673 blas.SYMM(side, uplo, m, n, alpha, A, lda, B_dot+i*ldb_dot*n,
1674 ldb_dot, 1.0, C_dot+i*ldc_dot*n, ldc_dot);
1677 if (n_alpha_dot > 0) {
1678 if (gemm_AB.size() != std::size_t(m*n))
1679 gemm_AB.resize(m*n);
1680 blas.SYMM(side, uplo, m, n, 1.0, A, lda, B, ldb, 0.0, &gemm_AB[0],
1683 for (OrdinalType i=0; i<n_alpha_dot; i++)
1684 blas.AXPY(m*n, alpha_dot[i], &gemm_AB[0], OrdinalType(1),
1685 C_dot+i*ldc_dot*n, OrdinalType(1));
1687 for (OrdinalType i=0; i<n_alpha_dot; i++)
1688 for (OrdinalType j=0; j<n; j++)
1689 blas.AXPY(m, alpha_dot[i], &gemm_AB[j*m], OrdinalType(1),
1690 C_dot+i*ldc_dot*n+j*ldc_dot, OrdinalType(1));
1694 for (OrdinalType i=0; i<n_A_dot; i++)
1695 blas.SYMM(side, uplo, m, n, alpha, A_dot+i*lda_dot*n_A_cols, lda_dot, B,
1696 ldb, 1.0, C_dot+i*ldc_dot*n, ldc_dot);
1699 if (ldc == m && ldc_dot == m)
1700 for (OrdinalType i=0; i<n_beta_dot; i++)
1701 blas.AXPY(m*n, beta_dot[i], C, OrdinalType(1), C_dot+i*ldc_dot*n,
1704 for (OrdinalType i=0; i<n_beta_dot; i++)
1705 for (OrdinalType j=0; j<n; j++)
1706 blas.AXPY(m, beta_dot[i], C+j*ldc, OrdinalType(1),
1707 C_dot+i*ldc_dot*n+j*ldc_dot, OrdinalType(1));
1710 if (n_alpha_dot > 0) {
1712 blas.SCAL(m*n, beta, C, OrdinalType(1));
1713 blas.AXPY(m*n, alpha, &gemm_AB[0], OrdinalType(1), C, OrdinalType(1));
1716 for (OrdinalType j=0; j<n; j++) {
1717 blas.SCAL(m, beta, C+j*ldc, OrdinalType(1));
1718 blas.AXPY(m, alpha, &gemm_AB[j*m], OrdinalType(1), C+j*ldc,
1723 blas.SYMM(side, uplo, m, n, alpha, A, lda, B, ldb, beta, C, ldc);
1726 template <
typename OrdinalType,
typename FadType>
1727 template <
typename alpha_type,
typename A_type>
1734 const OrdinalType m,
1735 const OrdinalType n,
1736 const alpha_type& alpha,
1737 const OrdinalType n_alpha_dot,
1738 const alpha_type* alpha_dot,
1740 const OrdinalType lda,
1741 const OrdinalType n_A_dot,
1742 const A_type* A_dot,
1743 const OrdinalType lda_dot,
1745 const OrdinalType ldb,
1746 const OrdinalType n_B_dot,
1748 const OrdinalType ldb_dot,
1749 const OrdinalType n_dot)
const
1754 (n_A_dot != n_dot && n_A_dot != 0) ||
1755 (n_B_dot != n_dot && n_B_dot != 0),
1757 "BLAS::Fad_TRMM(): All arguments must have " <<
1758 "the same number of derivative components, or none");
1760 OrdinalType n_A_cols = m;
1766 for (OrdinalType i=0; i<n_B_dot; i++)
1767 blas.TRMM(side, uplo, transa, diag, m, n, alpha, A, lda, B_dot+i*ldb_dot*n,
1771 if (n_alpha_dot > 0) {
1772 if (gemm_AB.size() != std::size_t(m*n))
1773 gemm_AB.resize(m*n);
1775 blas.COPY(m*n, B, OrdinalType(1), &gemm_AB[0], OrdinalType(1));
1777 for (OrdinalType j=0; j<n; j++)
1778 blas.COPY(m, B+j*ldb, OrdinalType(1), &gemm_AB[j*m], OrdinalType(1));
1779 blas.TRMM(side, uplo, transa, diag, m, n, 1.0, A, lda, &gemm_AB[0],
1782 for (OrdinalType i=0; i<n_alpha_dot; i++)
1783 blas.AXPY(m*n, alpha_dot[i], &gemm_AB[0], OrdinalType(1),
1784 B_dot+i*ldb_dot*n, OrdinalType(1));
1786 for (OrdinalType i=0; i<n_alpha_dot; i++)
1787 for (OrdinalType j=0; j<n; j++)
1788 blas.AXPY(m, alpha_dot[i], &gemm_AB[j*m], OrdinalType(1),
1789 B_dot+i*ldb_dot*n+j*ldb_dot, OrdinalType(1));
1794 if (gemm_AB.size() != std::size_t(m*n))
1795 gemm_AB.resize(m*n);
1796 for (OrdinalType i=0; i<n_A_dot; i++) {
1798 blas.COPY(m*n, B, OrdinalType(1), &gemm_AB[0], OrdinalType(1));
1800 for (OrdinalType j=0; j<n; j++)
1801 blas.COPY(m, B+j*ldb, OrdinalType(1), &gemm_AB[j*m], OrdinalType(1));
1803 A_dot+i*lda_dot*n_A_cols, lda_dot, &gemm_AB[0],
1806 blas.AXPY(m*n, 1.0, &gemm_AB[0], OrdinalType(1),
1807 B_dot+i*ldb_dot*n, OrdinalType(1));
1809 for (OrdinalType j=0; j<n; j++)
1810 blas.AXPY(m, 1.0, &gemm_AB[j*m], OrdinalType(1),
1811 B_dot+i*ldb_dot*n+j*ldb_dot, OrdinalType(1));
1816 if (n_alpha_dot > 0 && n_A_dot == 0) {
1818 blas.SCAL(m*n, 0.0, B, OrdinalType(1));
1819 blas.AXPY(m*n, alpha, &gemm_AB[0], OrdinalType(1), B, OrdinalType(1));
1822 for (OrdinalType j=0; j<n; j++) {
1823 blas.SCAL(m, 0.0, B+j*ldb, OrdinalType(1));
1824 blas.AXPY(m, alpha, &gemm_AB[j*m], OrdinalType(1), B+j*ldb,
1829 blas.TRMM(side, uplo, transa, diag, m, n, alpha, A, lda, B, ldb);
1832 template <
typename OrdinalType,
typename FadType>
1833 template <
typename alpha_type,
typename A_type>
1840 const OrdinalType m,
1841 const OrdinalType n,
1842 const alpha_type& alpha,
1843 const OrdinalType n_alpha_dot,
1844 const alpha_type* alpha_dot,
1846 const OrdinalType lda,
1847 const OrdinalType n_A_dot,
1848 const A_type* A_dot,
1849 const OrdinalType lda_dot,
1851 const OrdinalType ldb,
1852 const OrdinalType n_B_dot,
1854 const OrdinalType ldb_dot,
1855 const OrdinalType n_dot)
const
1860 (n_A_dot != n_dot && n_A_dot != 0) ||
1861 (n_B_dot != n_dot && n_B_dot != 0),
1863 "BLAS::Fad_TRSM(): All arguments must have " <<
1864 "the same number of derivative components, or none");
1866 OrdinalType n_A_cols = m;
1874 blas.SCAL(m*n*n_B_dot, alpha, B_dot, OrdinalType(1));
1876 for (OrdinalType i=0; i<n_B_dot; i++)
1877 for (OrdinalType j=0; j<n; j++)
1878 blas.SCAL(m, alpha, B_dot+i*ldb_dot*n+j*ldb_dot, OrdinalType(1));
1882 if (n_alpha_dot > 0) {
1883 if (ldb == m && ldb_dot == m)
1884 for (OrdinalType i=0; i<n_alpha_dot; i++)
1885 blas.AXPY(m*n, alpha_dot[i], B, OrdinalType(1),
1886 B_dot+i*ldb_dot*n, OrdinalType(1));
1888 for (OrdinalType i=0; i<n_alpha_dot; i++)
1889 for (OrdinalType j=0; j<n; j++)
1890 blas.AXPY(m, alpha_dot[i], B+j*ldb, OrdinalType(1),
1891 B_dot+i*ldb_dot*n+j*ldb_dot, OrdinalType(1));
1895 blas.TRSM(side, uplo, transa, diag, m, n, alpha, A, lda, B, ldb);
1899 if (gemm_AB.size() != std::size_t(m*n))
1900 gemm_AB.resize(m*n);
1901 for (OrdinalType i=0; i<n_A_dot; i++) {
1903 blas.COPY(m*n, B, OrdinalType(1), &gemm_AB[0], OrdinalType(1));
1905 for (OrdinalType j=0; j<n; j++)
1906 blas.COPY(m, B+j*ldb, OrdinalType(1), &gemm_AB[j*m], OrdinalType(1));
1908 A_dot+i*lda_dot*n_A_cols, lda_dot, &gemm_AB[0],
1911 blas.AXPY(m*n, -1.0, &gemm_AB[0], OrdinalType(1),
1912 B_dot+i*ldb_dot*n, OrdinalType(1));
1914 for (OrdinalType j=0; j<n; j++)
1915 blas.AXPY(m, -1.0, &gemm_AB[j*m], OrdinalType(1),
1916 B_dot+i*ldb_dot*n+j*ldb_dot, OrdinalType(1));
1922 blas.TRSM(side, uplo, transa, diag, m, n*n_dot, 1.0, A, lda, B_dot,
1925 for (OrdinalType i=0; i<n_dot; i++)
1926 blas.TRSM(side, uplo, transa, diag, m, n, 1.0, A, lda, B_dot+i*ldb_dot*n,
void TRSM(Teuchos::ESide side, Teuchos::EUplo uplo, Teuchos::ETransp transa, Teuchos::EDiag diag, const OrdinalType m, const OrdinalType n, const alpha_type &alpha, const A_type *A, const OrdinalType lda, FadType *B, const OrdinalType ldb) const
Solves the matrix equations: op(A)*X=alpha*B or X*op(A)=alpha*B where X and B are m by n matrices...
void TRMV(Teuchos::EUplo uplo, Teuchos::ETransp trans, Teuchos::EDiag diag, const OrdinalType n, const A_type *A, const OrdinalType lda, FadType *x, const OrdinalType incx) const
Performs the matrix-vector operation: x <- A*x or x <- A'*x where A is a unit/non-unit n by n uppe...
void Fad_GEMM(Teuchos::ETransp transa, Teuchos::ETransp transb, const OrdinalType m, const OrdinalType n, const OrdinalType k, const alpha_type &alpha, const OrdinalType n_alpha_dot, const alpha_type *alpha_dot, const A_type *A, const OrdinalType lda, const OrdinalType n_A_dot, const A_type *A_dot, const OrdinalType lda_dot, const B_type *B, const OrdinalType ldb, const OrdinalType n_B_dot, const B_type *B_dot, const OrdinalType ldb_dot, const beta_type &beta, const OrdinalType n_beta_dot, const beta_type *beta_dot, ValueType *C, const OrdinalType ldc, const OrdinalType n_C_dot, ValueType *C_dot, const OrdinalType ldc_dot, const OrdinalType n_dot) const
Implementation of GEMM.
void GEMV(Teuchos::ETransp trans, const OrdinalType m, const OrdinalType n, const alpha_type &alpha, const A_type *A, const OrdinalType lda, const x_type *x, const OrdinalType incx, const beta_type &beta, FadType *y, const OrdinalType incy) const
Performs the matrix-vector operation: y <- alpha*A*x+beta*y or y <- alpha*A'*x+beta*y where A is a...
void GER(const OrdinalType m, const OrdinalType n, const alpha_type &alpha, const x_type *x, const OrdinalType incx, const y_type *y, const OrdinalType incy, FadType *A, const OrdinalType lda) const
Performs the rank 1 operation: A <- alpha*x*y'+A.
void Fad_TRSM(Teuchos::ESide side, Teuchos::EUplo uplo, Teuchos::ETransp transa, Teuchos::EDiag diag, const OrdinalType m, const OrdinalType n, const alpha_type &alpha, const OrdinalType n_alpha_dot, const alpha_type *alpha_dot, const A_type *A, const OrdinalType lda, const OrdinalType n_A_dot, const A_type *A_dot, const OrdinalType lda_dot, ValueType *B, const OrdinalType ldb, const OrdinalType n_B_dot, ValueType *B_dot, const OrdinalType ldb_dot, const OrdinalType n_dot) const
Implementation of TRSM.
bool is_array_contiguous(const FadType *a, OrdinalType n, OrdinalType n_dot) const
FadType DOT(const OrdinalType n, const x_type *x, const OrdinalType incx, const y_type *y, const OrdinalType incy) const
Form the dot product of the vectors x and y.
virtual ~BLAS()
Destructor.
Fad specializations for Teuchos::BLAS wrappers.
#define TEUCHOS_TEST_FOR_EXCEPTION(throw_exception_test, Exception, msg)
void GEMM(Teuchos::ETransp transa, Teuchos::ETransp transb, const OrdinalType m, const OrdinalType n, const OrdinalType k, const alpha_type &alpha, const A_type *A, const OrdinalType lda, const B_type *B, const OrdinalType ldb, const beta_type &beta, FadType *C, const OrdinalType ldc) const
Performs the matrix-matrix operation: C <- alpha*op(A)*op(B)+beta*C where op(A) is either A or A'...
ValueType * allocate_array(OrdinalType size) const
void TRMM(Teuchos::ESide side, Teuchos::EUplo uplo, Teuchos::ETransp transa, Teuchos::EDiag diag, const OrdinalType m, const OrdinalType n, const alpha_type &alpha, const A_type *A, const OrdinalType lda, FadType *B, const OrdinalType ldb) const
Performs the matrix-matrix operation: B <- alpha*op(A)*B or B <- alpha*B*op(A) where op...
ValueType * workspace_pointer
Pointer to current free entry in workspace.
OrdinalType workspace_size
Size of static workspace.
void Fad_SYMM(Teuchos::ESide side, Teuchos::EUplo uplo, const OrdinalType m, const OrdinalType n, const alpha_type &alpha, const OrdinalType n_alpha_dot, const alpha_type *alpha_dot, const A_type *A, const OrdinalType lda, const OrdinalType n_A_dot, const A_type *A_dot, const OrdinalType lda_dot, const B_type *B, const OrdinalType ldb, const OrdinalType n_B_dot, const B_type *B_dot, const OrdinalType ldb_dot, const beta_type &beta, const OrdinalType n_beta_dot, const beta_type *beta_dot, ValueType *C, const OrdinalType ldc, const OrdinalType n_C_dot, ValueType *C_dot, const OrdinalType ldc_dot, const OrdinalType n_dot) const
Implementation of SYMM.
Sacado::dummy< ValueType, scalar_type >::type ScalarType
ArrayTraits(bool use_dynamic=true, OrdinalType workspace_size=0)
void AXPY(const OrdinalType n, const alpha_type &alpha, const x_type *x, const OrdinalType incx, FadType *y, const OrdinalType incy) const
Perform the operation: y <- y+alpha*x.
void COPY(const OrdinalType n, const FadType *x, const OrdinalType incx, FadType *y, const OrdinalType incy) const
Copy the vector x to the vector y.
void SYMM(Teuchos::ESide side, Teuchos::EUplo uplo, const OrdinalType m, const OrdinalType n, const alpha_type &alpha, const A_type *A, const OrdinalType lda, const B_type *B, const OrdinalType ldb, const beta_type &beta, FadType *C, const OrdinalType ldc) const
Performs the matrix-matrix operation: C <- alpha*A*B+beta*C or C <- alpha*B*A+beta*C where A is an m ...
void Fad_TRMM(Teuchos::ESide side, Teuchos::EUplo uplo, Teuchos::ETransp transa, Teuchos::EDiag diag, const OrdinalType m, const OrdinalType n, const alpha_type &alpha, const OrdinalType n_alpha_dot, const alpha_type *alpha_dot, const A_type *A, const OrdinalType lda, const OrdinalType n_A_dot, const A_type *A_dot, const OrdinalType lda_dot, ValueType *B, const OrdinalType ldb, const OrdinalType n_B_dot, ValueType *B_dot, const OrdinalType ldb_dot, const OrdinalType n_dot) const
Implementation of TRMM.
static magnitudeType magnitude(T a)
void free_array(const ValueType *ptr, OrdinalType size) const
expr expr expr fastAccessDx(i)) FAD_UNARYOP_MACRO(exp
MagnitudeType NRM2(const OrdinalType n, const FadType *x, const OrdinalType incx) const
Compute the 2-norm of the vector x.
ValueType * workspace
Workspace for holding contiguous values/derivatives.
void Fad_GEMV(Teuchos::ETransp trans, const OrdinalType m, const OrdinalType n, const alpha_type &alpha, const OrdinalType n_alpha_dot, const alpha_type *alpha_dot, const A_type *A, const OrdinalType lda, const OrdinalType n_A_dot, const A_type *A_dot, const OrdinalType lda_dot, const x_type *x, const OrdinalType incx, const OrdinalType n_x_dot, const x_type *x_dot, const OrdinalType incx_dot, const beta_type &beta, const OrdinalType n_beta_dot, const beta_type *beta_dot, ValueType *y, const OrdinalType incy, const OrdinalType n_y_dot, ValueType *y_dot, const OrdinalType incy_dot, const OrdinalType n_dot) const
Implementation of GEMV.
void SCAL(const OrdinalType n, const FadType &alpha, FadType *x, const OrdinalType incx) const
Scale the vector x by the constant alpha.
void Fad_DOT(const OrdinalType n, const x_type *x, const OrdinalType incx, const OrdinalType n_x_dot, const x_type *x_dot, const OrdinalType incx_dot, const y_type *y, const OrdinalType incy, const OrdinalType n_y_dot, const y_type *y_dot, const OrdinalType incy_dot, ValueType &z, const OrdinalType n_z_dot, ValueType *zdot) const
Implementation of DOT.
Base template specification for ValueType.
void Fad_GER(const OrdinalType m, const OrdinalType n, const alpha_type &alpha, const OrdinalType n_alpha_dot, const alpha_type *alpha_dot, const x_type *x, const OrdinalType incx, const OrdinalType n_x_dot, const x_type *x_dot, const OrdinalType incx_dot, const y_type *y, const OrdinalType incy, const OrdinalType n_y_dot, const y_type *y_dot, const OrdinalType incy_dot, ValueType *A, const OrdinalType lda, const OrdinalType n_A_dot, ValueType *A_dot, const OrdinalType lda_dot, const OrdinalType n_dot) const
Implementation of GER.
BLAS(bool use_default_impl=true, bool use_dynamic=true, OrdinalType static_workspace_size=0)
Default constructor.