10 #ifndef EIGEN_SPECIAL_FUNCTIONS_H 11 #define EIGEN_SPECIAL_FUNCTIONS_H 44 template <
typename Scalar>
47 static EIGEN_STRONG_INLINE Scalar run(
const Scalar) {
48 EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value ==
false),
49 THIS_TYPE_IS_NOT_SUPPORTED);
54 template <
typename Scalar>
55 struct lgamma_retval {
59 #if EIGEN_HAS_C99_MATH 61 #if defined(__GLIBC__) && ((__GLIBC__>=2 && __GLIBC_MINOR__ >= 19) || __GLIBC__>2) \ 62 && (defined(_DEFAULT_SOURCE) || defined(_BSD_SOURCE) || defined(_SVID_SOURCE)) 63 #define EIGEN_HAS_LGAMMA_R 67 #if defined(__GLIBC__) && ((__GLIBC__==2 && __GLIBC_MINOR__ < 19) || __GLIBC__<2) \ 68 && (defined(_BSD_SOURCE) || defined(_SVID_SOURCE)) 69 #define EIGEN_HAS_LGAMMA_R 73 struct lgamma_impl<float> {
75 static EIGEN_STRONG_INLINE
float run(
float x) {
76 #if !defined(EIGEN_GPU_COMPILE_PHASE) && defined (EIGEN_HAS_LGAMMA_R) && !defined(__APPLE__) 78 return ::lgammaf_r(x, &dummy);
79 #elif defined(SYCL_DEVICE_ONLY) 80 return cl::sycl::lgamma(x);
88 struct lgamma_impl<double> {
90 static EIGEN_STRONG_INLINE
double run(
double x) {
91 #if !defined(EIGEN_GPU_COMPILE_PHASE) && defined(EIGEN_HAS_LGAMMA_R) && !defined(__APPLE__) 93 return ::lgamma_r(x, &dummy);
94 #elif defined(SYCL_DEVICE_ONLY) 95 return cl::sycl::lgamma(x);
102 #undef EIGEN_HAS_LGAMMA_R 109 template <
typename Scalar>
110 struct digamma_retval {
127 template <
typename Scalar>
128 struct digamma_impl_maybe_poly {
130 static EIGEN_STRONG_INLINE Scalar run(
const Scalar) {
131 EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value ==
false),
132 THIS_TYPE_IS_NOT_SUPPORTED);
139 struct digamma_impl_maybe_poly<float> {
141 static EIGEN_STRONG_INLINE
float run(
const float s) {
143 -4.16666666666666666667E-3f,
144 3.96825396825396825397E-3f,
145 -8.33333333333333333333E-3f,
146 8.33333333333333333333E-2f
152 return z * internal::ppolevl<float, 3>::run(z, A);
158 struct digamma_impl_maybe_poly<double> {
160 static EIGEN_STRONG_INLINE
double run(
const double s) {
162 8.33333333333333333333E-2,
163 -2.10927960927960927961E-2,
164 7.57575757575757575758E-3,
165 -4.16666666666666666667E-3,
166 3.96825396825396825397E-3,
167 -8.33333333333333333333E-3,
168 8.33333333333333333333E-2
174 return z * internal::ppolevl<double, 6>::run(z, A);
180 template <
typename Scalar>
181 struct digamma_impl {
183 static Scalar run(Scalar x) {
241 Scalar p, q, nz, s, w, y;
242 bool negative =
false;
244 const Scalar nan = NumTraits<Scalar>::quiet_NaN();
245 const Scalar m_pi = Scalar(EIGEN_PI);
247 const Scalar zero = Scalar(0);
248 const Scalar one = Scalar(1);
249 const Scalar half = Scalar(0.5);
255 p = numext::floor(q);
268 nz = m_pi / numext::tan(m_pi * nz);
279 while (s < Scalar(10)) {
284 y = digamma_impl_maybe_poly<Scalar>::run(s);
286 y = numext::log(s) - (half / s) - y - w;
288 return (negative) ? y - nz : y;
303 template <
typename T>
304 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T generic_fast_erf_float(
const T& a_x) {
307 const T plus_4 = pset1<T>(4.f);
308 const T minus_4 = pset1<T>(-4.f);
309 const T x = pmax(pmin(a_x, plus_4), minus_4);
311 const T alpha_1 = pset1<T>(-1.60960333262415e-02f);
312 const T alpha_3 = pset1<T>(-2.95459980854025e-03f);
313 const T alpha_5 = pset1<T>(-7.34990630326855e-04f);
314 const T alpha_7 = pset1<T>(-5.69250639462346e-05f);
315 const T alpha_9 = pset1<T>(-2.10102402082508e-06f);
316 const T alpha_11 = pset1<T>(2.77068142495902e-08f);
317 const T alpha_13 = pset1<T>(-2.72614225801306e-10f);
320 const T beta_0 = pset1<T>(-1.42647390514189e-02f);
321 const T beta_2 = pset1<T>(-7.37332916720468e-03f);
322 const T beta_4 = pset1<T>(-1.68282697438203e-03f);
323 const T beta_6 = pset1<T>(-2.13374055278905e-04f);
324 const T beta_8 = pset1<T>(-1.45660718464996e-05f);
327 const T x2 = pmul(x, x);
330 T p = pmadd(x2, alpha_13, alpha_11);
331 p = pmadd(x2, p, alpha_9);
332 p = pmadd(x2, p, alpha_7);
333 p = pmadd(x2, p, alpha_5);
334 p = pmadd(x2, p, alpha_3);
335 p = pmadd(x2, p, alpha_1);
339 T q = pmadd(x2, beta_8, beta_6);
340 q = pmadd(x2, q, beta_4);
341 q = pmadd(x2, q, beta_2);
342 q = pmadd(x2, q, beta_0);
348 template <
typename T>
351 static EIGEN_STRONG_INLINE T run(
const T& x) {
352 return generic_fast_erf_float(x);
356 template <
typename Scalar>
361 #if EIGEN_HAS_C99_MATH 363 struct erf_impl<float> {
365 static EIGEN_STRONG_INLINE
float run(
float x) {
366 #if defined(SYCL_DEVICE_ONLY) 367 return cl::sycl::erf(x);
369 return generic_fast_erf_float(x);
375 struct erf_impl<double> {
377 static EIGEN_STRONG_INLINE
double run(
double x) {
378 #if defined(SYCL_DEVICE_ONLY) 379 return cl::sycl::erf(x);
385 #endif // EIGEN_HAS_C99_MATH 391 template <
typename Scalar>
394 static EIGEN_STRONG_INLINE Scalar run(
const Scalar) {
395 EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value ==
false),
396 THIS_TYPE_IS_NOT_SUPPORTED);
401 template <
typename Scalar>
406 #if EIGEN_HAS_C99_MATH 408 struct erfc_impl<float> {
410 static EIGEN_STRONG_INLINE
float run(
const float x) {
411 #if defined(SYCL_DEVICE_ONLY) 412 return cl::sycl::erfc(x);
420 struct erfc_impl<double> {
422 static EIGEN_STRONG_INLINE
double run(
const double x) {
423 #if defined(SYCL_DEVICE_ONLY) 424 return cl::sycl::erfc(x);
430 #endif // EIGEN_HAS_C99_MATH 491 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T flipsign(
492 const T& should_flipsign,
const T& x) {
493 typedef typename unpacket_traits<T>::type Scalar;
494 const T sign_mask = pset1<T>(Scalar(-0.0));
495 T sign_bit = pand<T>(should_flipsign, sign_mask);
496 return pxor<T>(sign_bit, x);
500 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
double flipsign<double>(
501 const double& should_flipsign,
const double& x) {
502 return should_flipsign == 0 ? x : -x;
506 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
float flipsign<float>(
507 const float& should_flipsign,
const float& x) {
508 return should_flipsign == 0 ? x : -x;
515 template <
typename T,
typename ScalarType>
516 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T generic_ndtri_gt_exp_neg_two(
const T& b) {
517 const ScalarType p0[] = {
518 ScalarType(-5.99633501014107895267e1),
519 ScalarType(9.80010754185999661536e1),
520 ScalarType(-5.66762857469070293439e1),
521 ScalarType(1.39312609387279679503e1),
522 ScalarType(-1.23916583867381258016e0)
524 const ScalarType q0[] = {
526 ScalarType(1.95448858338141759834e0),
527 ScalarType(4.67627912898881538453e0),
528 ScalarType(8.63602421390890590575e1),
529 ScalarType(-2.25462687854119370527e2),
530 ScalarType(2.00260212380060660359e2),
531 ScalarType(-8.20372256168333339912e1),
532 ScalarType(1.59056225126211695515e1),
533 ScalarType(-1.18331621121330003142e0)
535 const T sqrt2pi = pset1<T>(ScalarType(2.50662827463100050242e0));
536 const T half = pset1<T>(ScalarType(0.5));
537 T c, c2, ndtri_gt_exp_neg_two;
541 ndtri_gt_exp_neg_two = pmadd(c, pmul(
543 internal::ppolevl<T, 4>::run(c2, p0),
544 internal::ppolevl<T, 8>::run(c2, q0))), c);
545 return pmul(ndtri_gt_exp_neg_two, sqrt2pi);
548 template <
typename T,
typename ScalarType>
549 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T generic_ndtri_lt_exp_neg_two(
550 const T& b,
const T& should_flipsign) {
554 const ScalarType p1[] = {
555 ScalarType(4.05544892305962419923e0),
556 ScalarType(3.15251094599893866154e1),
557 ScalarType(5.71628192246421288162e1),
558 ScalarType(4.40805073893200834700e1),
559 ScalarType(1.46849561928858024014e1),
560 ScalarType(2.18663306850790267539e0),
561 ScalarType(-1.40256079171354495875e-1),
562 ScalarType(-3.50424626827848203418e-2),
563 ScalarType(-8.57456785154685413611e-4)
565 const ScalarType q1[] = {
567 ScalarType(1.57799883256466749731e1),
568 ScalarType(4.53907635128879210584e1),
569 ScalarType(4.13172038254672030440e1),
570 ScalarType(1.50425385692907503408e1),
571 ScalarType(2.50464946208309415979e0),
572 ScalarType(-1.42182922854787788574e-1),
573 ScalarType(-3.80806407691578277194e-2),
574 ScalarType(-9.33259480895457427372e-4)
579 const ScalarType p2[] = {
580 ScalarType(3.23774891776946035970e0),
581 ScalarType(6.91522889068984211695e0),
582 ScalarType(3.93881025292474443415e0),
583 ScalarType(1.33303460815807542389e0),
584 ScalarType(2.01485389549179081538e-1),
585 ScalarType(1.23716634817820021358e-2),
586 ScalarType(3.01581553508235416007e-4),
587 ScalarType(2.65806974686737550832e-6),
588 ScalarType(6.23974539184983293730e-9)
590 const ScalarType q2[] = {
592 ScalarType(6.02427039364742014255e0),
593 ScalarType(3.67983563856160859403e0),
594 ScalarType(1.37702099489081330271e0),
595 ScalarType(2.16236993594496635890e-1),
596 ScalarType(1.34204006088543189037e-2),
597 ScalarType(3.28014464682127739104e-4),
598 ScalarType(2.89247864745380683936e-6),
599 ScalarType(6.79019408009981274425e-9)
601 const T eight = pset1<T>(ScalarType(8.0));
602 const T one = pset1<T>(ScalarType(1));
603 const T neg_two = pset1<T>(ScalarType(-2));
606 x = psqrt(pmul(neg_two, plog(b)));
607 x0 = psub(x, pdiv(plog(x), x));
612 pdiv(internal::ppolevl<T, 8>::run(z, p1),
613 internal::ppolevl<T, 8>::run(z, q1)),
614 pdiv(internal::ppolevl<T, 8>::run(z, p2),
615 internal::ppolevl<T, 8>::run(z, q2))));
616 return flipsign(should_flipsign, psub(x0, x1));
619 template <
typename T,
typename ScalarType>
620 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
621 T generic_ndtri(
const T& a) {
622 const T maxnum = pset1<T>(NumTraits<ScalarType>::infinity());
623 const T neg_maxnum = pset1<T>(-NumTraits<ScalarType>::infinity());
625 const T zero = pset1<T>(ScalarType(0));
626 const T one = pset1<T>(ScalarType(1));
628 const T exp_neg_two = pset1<T>(ScalarType(0.13533528323661269189));
629 T b,
ndtri, should_flipsign;
631 should_flipsign = pcmp_le(a, psub(one, exp_neg_two));
632 b = pselect(should_flipsign, a, psub(one, a));
635 pcmp_lt(exp_neg_two, b),
636 generic_ndtri_gt_exp_neg_two<T, ScalarType>(b),
637 generic_ndtri_lt_exp_neg_two<T, ScalarType>(b, should_flipsign));
640 pcmp_le(a, zero), neg_maxnum,
641 pselect(pcmp_le(one, a), maxnum, ndtri));
644 template <
typename Scalar>
645 struct ndtri_retval {
649 #if !EIGEN_HAS_C99_MATH 651 template <
typename Scalar>
654 static EIGEN_STRONG_INLINE Scalar run(
const Scalar) {
655 EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value ==
false),
656 THIS_TYPE_IS_NOT_SUPPORTED);
663 template <
typename Scalar>
666 static EIGEN_STRONG_INLINE Scalar run(
const Scalar x) {
667 return generic_ndtri<Scalar, Scalar>(x);
671 #endif // EIGEN_HAS_C99_MATH 678 template <
typename Scalar>
679 struct igammac_retval {
684 template <
typename Scalar>
685 struct cephes_helper {
687 static EIGEN_STRONG_INLINE Scalar machep() { assert(
false &&
"machep not supported for this type");
return 0.0; }
689 static EIGEN_STRONG_INLINE Scalar big() { assert(
false &&
"big not supported for this type");
return 0.0; }
691 static EIGEN_STRONG_INLINE Scalar biginv() { assert(
false &&
"biginv not supported for this type");
return 0.0; }
695 struct cephes_helper<float> {
697 static EIGEN_STRONG_INLINE
float machep() {
698 return NumTraits<float>::epsilon() / 2;
701 static EIGEN_STRONG_INLINE
float big() {
703 return 1.0f / (NumTraits<float>::epsilon() / 2);
706 static EIGEN_STRONG_INLINE
float biginv() {
713 struct cephes_helper<double> {
715 static EIGEN_STRONG_INLINE
double machep() {
716 return NumTraits<double>::epsilon() / 2;
719 static EIGEN_STRONG_INLINE
double big() {
720 return 1.0 / NumTraits<double>::epsilon();
723 static EIGEN_STRONG_INLINE
double biginv() {
725 return NumTraits<double>::epsilon();
729 enum IgammaComputationMode { VALUE, DERIVATIVE, SAMPLE_DERIVATIVE };
731 template <
typename Scalar>
733 static EIGEN_STRONG_INLINE Scalar main_igamma_term(Scalar a, Scalar x) {
735 Scalar logax = a * numext::log(x) - x - lgamma_impl<Scalar>::run(a);
736 if (logax < -numext::log(NumTraits<Scalar>::highest()) ||
738 (numext::isnan)(logax)) {
741 return numext::exp(logax);
744 template <
typename Scalar, IgammaComputationMode mode>
746 int igamma_num_iterations() {
753 if (internal::is_same<Scalar, float>::value) {
755 }
else if (internal::is_same<Scalar, double>::value) {
762 template <
typename Scalar, IgammaComputationMode mode>
763 struct igammac_cf_impl {
774 static Scalar run(Scalar a, Scalar x) {
775 const Scalar zero = 0;
776 const Scalar one = 1;
777 const Scalar two = 2;
778 const Scalar machep = cephes_helper<Scalar>::machep();
779 const Scalar big = cephes_helper<Scalar>::big();
780 const Scalar biginv = cephes_helper<Scalar>::biginv();
782 if ((numext::isinf)(x)) {
786 Scalar ax = main_igamma_term<Scalar>(a, x);
797 Scalar z = x + y + one;
801 Scalar pkm1 = x + one;
803 Scalar ans = pkm1 / qkm1;
805 Scalar dpkm2_da = zero;
806 Scalar dqkm2_da = zero;
807 Scalar dpkm1_da = zero;
808 Scalar dqkm1_da = -x;
809 Scalar dans_da = (dpkm1_da - ans * dqkm1_da) / qkm1;
811 for (
int i = 0; i < igamma_num_iterations<Scalar, mode>(); i++) {
817 Scalar pk = pkm1 * z - pkm2 * yc;
818 Scalar qk = qkm1 * z - qkm2 * yc;
820 Scalar dpk_da = dpkm1_da * z - pkm1 - dpkm2_da * yc + pkm2 * c;
821 Scalar dqk_da = dqkm1_da * z - qkm1 - dqkm2_da * yc + qkm2 * c;
824 Scalar ans_prev = ans;
827 Scalar dans_da_prev = dans_da;
828 dans_da = (dpk_da - ans * dqk_da) / qk;
831 if (numext::abs(ans_prev - ans) <= machep * numext::abs(ans)) {
835 if (numext::abs(dans_da - dans_da_prev) <= machep) {
851 if (numext::abs(pk) > big) {
865 Scalar dlogax_da = numext::log(x) - digamma_impl<Scalar>::run(a);
866 Scalar dax_da = ax * dlogax_da;
872 return ans * dax_da + dans_da * ax;
873 case SAMPLE_DERIVATIVE:
875 return -(dans_da + ans * dlogax_da) * x;
880 template <
typename Scalar, IgammaComputationMode mode>
881 struct igamma_series_impl {
891 static Scalar run(Scalar a, Scalar x) {
892 const Scalar zero = 0;
893 const Scalar one = 1;
894 const Scalar machep = cephes_helper<Scalar>::machep();
896 Scalar ax = main_igamma_term<Scalar>(a, x);
914 Scalar dans_da = zero;
916 for (
int i = 0; i < igamma_num_iterations<Scalar, mode>(); i++) {
919 Scalar dterm_da = -x / (r * r);
920 dc_da = term * dc_da + dterm_da * c;
926 if (c <= machep * ans) {
930 if (numext::abs(dc_da) <= machep * numext::abs(dans_da)) {
936 Scalar dlogax_da = numext::log(x) - digamma_impl<Scalar>::run(a + one);
937 Scalar dax_da = ax * dlogax_da;
943 return ans * dax_da + dans_da * ax;
944 case SAMPLE_DERIVATIVE:
946 return -(dans_da + ans * dlogax_da) * x / a;
951 #if !EIGEN_HAS_C99_MATH 953 template <
typename Scalar>
954 struct igammac_impl {
956 static Scalar run(Scalar a, Scalar x) {
957 EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value ==
false),
958 THIS_TYPE_IS_NOT_SUPPORTED);
965 template <
typename Scalar>
966 struct igammac_impl {
968 static Scalar run(Scalar a, Scalar x) {
1023 const Scalar zero = 0;
1024 const Scalar one = 1;
1025 const Scalar nan = NumTraits<Scalar>::quiet_NaN();
1027 if ((x < zero) || (a <= zero)) {
1032 if ((numext::isnan)(a) || (numext::isnan)(x)) {
1036 if ((x < one) || (x < a)) {
1037 return (one - igamma_series_impl<Scalar, VALUE>::run(a, x));
1040 return igammac_cf_impl<Scalar, VALUE>::run(a, x);
1044 #endif // EIGEN_HAS_C99_MATH 1050 #if !EIGEN_HAS_C99_MATH 1052 template <
typename Scalar, IgammaComputationMode mode>
1053 struct igamma_generic_impl {
1055 static EIGEN_STRONG_INLINE Scalar run(Scalar a, Scalar x) {
1056 EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value ==
false),
1057 THIS_TYPE_IS_NOT_SUPPORTED);
1064 template <
typename Scalar, IgammaComputationMode mode>
1065 struct igamma_generic_impl {
1067 static Scalar run(Scalar a, Scalar x) {
1076 const Scalar zero = 0;
1077 const Scalar one = 1;
1078 const Scalar nan = NumTraits<Scalar>::quiet_NaN();
1080 if (x == zero)
return zero;
1082 if ((x < zero) || (a <= zero)) {
1086 if ((numext::isnan)(a) || (numext::isnan)(x)) {
1090 if ((x > one) && (x > a)) {
1091 Scalar ret = igammac_cf_impl<Scalar, mode>::run(a, x);
1092 if (mode == VALUE) {
1099 return igamma_series_impl<Scalar, mode>::run(a, x);
1103 #endif // EIGEN_HAS_C99_MATH 1105 template <
typename Scalar>
1106 struct igamma_retval {
1107 typedef Scalar type;
1110 template <
typename Scalar>
1111 struct igamma_impl : igamma_generic_impl<Scalar, VALUE> {
1181 template <
typename Scalar>
1182 struct igamma_der_a_retval : igamma_retval<Scalar> {};
1184 template <
typename Scalar>
1185 struct igamma_der_a_impl : igamma_generic_impl<Scalar, DERIVATIVE> {
1202 template <
typename Scalar>
1203 struct gamma_sample_der_alpha_retval : igamma_retval<Scalar> {};
1205 template <
typename Scalar>
1206 struct gamma_sample_der_alpha_impl
1207 : igamma_generic_impl<Scalar, SAMPLE_DERIVATIVE> {
1251 template <
typename Scalar>
1252 struct zeta_retval {
1253 typedef Scalar type;
1256 template <
typename Scalar>
1257 struct zeta_impl_series {
1259 static EIGEN_STRONG_INLINE Scalar run(
const Scalar) {
1260 EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value ==
false),
1261 THIS_TYPE_IS_NOT_SUPPORTED);
1267 struct zeta_impl_series<float> {
1269 static EIGEN_STRONG_INLINE
bool run(
float& a,
float& b,
float& s,
const float x,
const float machep) {
1275 b = numext::pow( a, -x );
1277 if( numext::abs(b/s) < machep )
1287 struct zeta_impl_series<double> {
1289 static EIGEN_STRONG_INLINE
bool run(
double& a,
double& b,
double& s,
const double x,
const double machep) {
1291 while( (i < 9) || (a <= 9.0) )
1295 b = numext::pow( a, -x );
1297 if( numext::abs(b/s) < machep )
1306 template <
typename Scalar>
1309 static Scalar run(Scalar x, Scalar q) {
1372 Scalar p, r, a, b, k, s, t, w;
1374 const Scalar A[] = {
1380 Scalar(-1.8924375803183791606e9),
1381 Scalar(7.47242496e10),
1382 Scalar(-2.950130727918164224e12),
1383 Scalar(1.1646782814350067249e14),
1384 Scalar(-4.5979787224074726105e15),
1385 Scalar(1.8152105401943546773e17),
1386 Scalar(-7.1661652561756670113e18)
1389 const Scalar maxnum = NumTraits<Scalar>::infinity();
1390 const Scalar zero = 0.0, half = 0.5, one = 1.0;
1391 const Scalar machep = cephes_helper<Scalar>::machep();
1392 const Scalar nan = NumTraits<Scalar>::quiet_NaN();
1404 if(q == numext::floor(q))
1406 if (x == numext::floor(x) &&
long(x) % 2 == 0) {
1414 r = numext::floor(p);
1424 s = numext::pow( q, -x );
1428 if (zeta_impl_series<Scalar>::run(a, b, s, x, machep)) {
1437 for( i=0; i<12; i++ )
1443 t = numext::abs(t/s);
1460 template <
typename Scalar>
1461 struct polygamma_retval {
1462 typedef Scalar type;
1465 #if !EIGEN_HAS_C99_MATH 1467 template <
typename Scalar>
1468 struct polygamma_impl {
1470 static EIGEN_STRONG_INLINE Scalar run(Scalar n, Scalar x) {
1471 EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value ==
false),
1472 THIS_TYPE_IS_NOT_SUPPORTED);
1479 template <
typename Scalar>
1480 struct polygamma_impl {
1482 static Scalar run(Scalar n, Scalar x) {
1483 Scalar zero = 0.0, one = 1.0;
1484 Scalar nplus = n + one;
1485 const Scalar nan = NumTraits<Scalar>::quiet_NaN();
1488 if (numext::floor(n) != n || n < zero) {
1492 else if (n == zero) {
1493 return digamma_impl<Scalar>::run(x);
1497 Scalar factorial = numext::exp(lgamma_impl<Scalar>::run(nplus));
1498 return numext::pow(-one, nplus) * factorial * zeta_impl<Scalar>::run(nplus, x);
1503 #endif // EIGEN_HAS_C99_MATH 1509 template <
typename Scalar>
1510 struct betainc_retval {
1511 typedef Scalar type;
1514 #if !EIGEN_HAS_C99_MATH 1516 template <
typename Scalar>
1517 struct betainc_impl {
1519 static EIGEN_STRONG_INLINE Scalar run(Scalar a, Scalar b, Scalar x) {
1520 EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value ==
false),
1521 THIS_TYPE_IS_NOT_SUPPORTED);
1528 template <
typename Scalar>
1529 struct betainc_impl {
1531 static EIGEN_STRONG_INLINE Scalar run(Scalar, Scalar, Scalar) {
1601 EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value ==
false),
1602 THIS_TYPE_IS_NOT_SUPPORTED);
1610 template <
typename Scalar>
1611 struct incbeta_cfe {
1613 static EIGEN_STRONG_INLINE Scalar run(Scalar a, Scalar b, Scalar x,
bool small_branch) {
1614 EIGEN_STATIC_ASSERT((internal::is_same<Scalar, float>::value ||
1615 internal::is_same<Scalar, double>::value),
1616 THIS_TYPE_IS_NOT_SUPPORTED);
1617 const Scalar big = cephes_helper<Scalar>::big();
1618 const Scalar machep = cephes_helper<Scalar>::machep();
1619 const Scalar biginv = cephes_helper<Scalar>::biginv();
1621 const Scalar zero = 0;
1622 const Scalar one = 1;
1623 const Scalar two = 2;
1625 Scalar xk, pk, pkm1, pkm2, qk, qkm1, qkm2;
1626 Scalar k1, k2, k3, k4, k5, k6, k7, k8, k26update;
1630 const int num_iters = (internal::is_same<Scalar, float>::value) ? 100 : 300;
1631 const Scalar thresh =
1632 (internal::is_same<Scalar, float>::value) ? machep : Scalar(3) * machep;
1633 Scalar r = (internal::is_same<Scalar, float>::value) ? zero : one;
1666 xk = -(x * k1 * k2) / (k3 * k4);
1667 pk = pkm1 + pkm2 * xk;
1668 qk = qkm1 + qkm2 * xk;
1674 xk = (x * k5 * k6) / (k7 * k8);
1675 pk = pkm1 + pkm2 * xk;
1676 qk = qkm1 + qkm2 * xk;
1684 if (numext::abs(ans - r) < numext::abs(r) * thresh) {
1699 if ((numext::abs(qk) + numext::abs(pk)) > big) {
1705 if ((numext::abs(qk) < biginv) || (numext::abs(pk) < biginv)) {
1711 }
while (++n < num_iters);
1718 template <
typename Scalar>
1719 struct betainc_helper {};
1722 struct betainc_helper<float> {
1724 EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE
float incbsa(
float aa,
float bb,
1726 float ans, a, b, t, x, onemx;
1727 bool reversed_a_b =
false;
1732 if (xx > (aa / (aa + bb))) {
1733 reversed_a_b =
true;
1747 if (numext::abs(b * x / a) < 0.3f) {
1748 t = betainc_helper<float>::incbps(a, b, x);
1749 if (reversed_a_b) t = 1.0f - t;
1754 ans = x * (a + b - 2.0f) / (a - 1.0f);
1756 ans = incbeta_cfe<float>::run(a, b, x,
true );
1757 t = b * numext::log(t);
1759 ans = incbeta_cfe<float>::run(a, b, x,
false );
1760 t = (b - 1.0f) * numext::log(t);
1763 t += a * numext::log(x) + lgamma_impl<float>::run(a + b) -
1764 lgamma_impl<float>::run(a) - lgamma_impl<float>::run(b);
1765 t += numext::log(ans / a);
1768 if (reversed_a_b) t = 1.0f - t;
1773 static EIGEN_STRONG_INLINE
float incbps(
float a,
float b,
float x) {
1775 const float machep = cephes_helper<float>::machep();
1777 y = a * numext::log(x) + (b - 1.0f) * numext::log1p(-x) - numext::log(a);
1778 y -= lgamma_impl<float>::run(a) + lgamma_impl<float>::run(b);
1779 y += lgamma_impl<float>::run(a + b);
1792 }
while (numext::abs(u) > machep);
1794 return numext::exp(y) * (1.0f + s);
1799 struct betainc_impl<float> {
1801 static float run(
float a,
float b,
float x) {
1802 const float nan = NumTraits<float>::quiet_NaN();
1805 if (a <= 0.0f)
return nan;
1806 if (b <= 0.0f)
return nan;
1807 if ((x <= 0.0f) || (x >= 1.0f)) {
1808 if (x == 0.0f)
return 0.0f;
1809 if (x == 1.0f)
return 1.0f;
1816 ans = betainc_helper<float>::incbsa(a + 1.0f, b, x);
1817 t = a * numext::log(x) + b * numext::log1p(-x) +
1818 lgamma_impl<float>::run(a + b) - lgamma_impl<float>::run(a + 1.0f) -
1819 lgamma_impl<float>::run(b);
1820 return (ans + numext::exp(t));
1822 return betainc_helper<float>::incbsa(a, b, x);
1828 struct betainc_helper<double> {
1830 static EIGEN_STRONG_INLINE
double incbps(
double a,
double b,
double x) {
1831 const double machep = cephes_helper<double>::machep();
1833 double s, t, u, v, n, t1, z, ai;
1843 while (numext::abs(v) > z) {
1844 u = (n - b) * x / n;
1853 u = a * numext::log(x);
1861 t = lgamma_impl<double>::run(a + b) - lgamma_impl<double>::run(a) -
1862 lgamma_impl<double>::run(b) + u + numext::log(s);
1863 return s = numext::exp(t);
1868 struct betainc_impl<double> {
1870 static double run(
double aa,
double bb,
double xx) {
1871 const double nan = NumTraits<double>::quiet_NaN();
1872 const double machep = cephes_helper<double>::machep();
1875 double a, b, t, x, xc, w, y;
1876 bool reversed_a_b =
false;
1878 if (aa <= 0.0 || bb <= 0.0) {
1882 if ((xx <= 0.0) || (xx >= 1.0)) {
1883 if (xx == 0.0)
return (0.0);
1884 if (xx == 1.0)
return (1.0);
1889 if ((bb * xx) <= 1.0 && xx <= 0.95) {
1890 return betainc_helper<double>::incbps(aa, bb, xx);
1896 if (xx > (aa / (aa + bb))) {
1897 reversed_a_b =
true;
1909 if (reversed_a_b && (b * x) <= 1.0 && x <= 0.95) {
1910 t = betainc_helper<double>::incbps(a, b, x);
1920 y = x * (a + b - 2.0) - (a - 1.0);
1922 w = incbeta_cfe<double>::run(a, b, x,
true );
1924 w = incbeta_cfe<double>::run(a, b, x,
false ) / xc;
1931 y = a * numext::log(x);
1932 t = b * numext::log(xc);
1945 y += t + lgamma_impl<double>::run(a + b) - lgamma_impl<double>::run(a) -
1946 lgamma_impl<double>::run(b);
1947 y += numext::log(w / a);
1964 #endif // EIGEN_HAS_C99_MATH 1970 template <
typename Scalar>
1971 EIGEN_DEVICE_FUNC
inline EIGEN_MATHFUNC_RETVAL(
lgamma, Scalar)
1972 lgamma(
const Scalar& x) {
1973 return EIGEN_MATHFUNC_IMPL(
lgamma, Scalar)::run(x);
1976 template <
typename Scalar>
1977 EIGEN_DEVICE_FUNC
inline EIGEN_MATHFUNC_RETVAL(
digamma, Scalar)
1979 return EIGEN_MATHFUNC_IMPL(
digamma, Scalar)::run(x);
1982 template <
typename Scalar>
1983 EIGEN_DEVICE_FUNC
inline EIGEN_MATHFUNC_RETVAL(
zeta, Scalar)
1984 zeta(
const Scalar& x,
const Scalar& q) {
1985 return EIGEN_MATHFUNC_IMPL(
zeta, Scalar)::run(x, q);
1988 template <
typename Scalar>
1989 EIGEN_DEVICE_FUNC
inline EIGEN_MATHFUNC_RETVAL(
polygamma, Scalar)
1990 polygamma(
const Scalar& n,
const Scalar& x) {
1991 return EIGEN_MATHFUNC_IMPL(
polygamma, Scalar)::run(n, x);
1994 template <
typename Scalar>
1995 EIGEN_DEVICE_FUNC
inline EIGEN_MATHFUNC_RETVAL(
erf, Scalar)
1996 erf(
const Scalar& x) {
1997 return EIGEN_MATHFUNC_IMPL(
erf, Scalar)::run(x);
2000 template <
typename Scalar>
2001 EIGEN_DEVICE_FUNC
inline EIGEN_MATHFUNC_RETVAL(
erfc, Scalar)
2002 erfc(
const Scalar& x) {
2003 return EIGEN_MATHFUNC_IMPL(
erfc, Scalar)::run(x);
2006 template <
typename Scalar>
2007 EIGEN_DEVICE_FUNC
inline EIGEN_MATHFUNC_RETVAL(ndtri, Scalar)
2008 ndtri(
const Scalar& x) {
2009 return EIGEN_MATHFUNC_IMPL(ndtri, Scalar)::run(x);
2012 template <
typename Scalar>
2013 EIGEN_DEVICE_FUNC
inline EIGEN_MATHFUNC_RETVAL(
igamma, Scalar)
2014 igamma(
const Scalar& a,
const Scalar& x) {
2015 return EIGEN_MATHFUNC_IMPL(
igamma, Scalar)::run(a, x);
2018 template <
typename Scalar>
2019 EIGEN_DEVICE_FUNC
inline EIGEN_MATHFUNC_RETVAL(
igamma_der_a, Scalar)
2021 return EIGEN_MATHFUNC_IMPL(
igamma_der_a, Scalar)::run(a, x);
2024 template <
typename Scalar>
2030 template <
typename Scalar>
2031 EIGEN_DEVICE_FUNC
inline EIGEN_MATHFUNC_RETVAL(
igammac, Scalar)
2032 igammac(
const Scalar& a,
const Scalar& x) {
2033 return EIGEN_MATHFUNC_IMPL(
igammac, Scalar)::run(a, x);
2036 template <
typename Scalar>
2037 EIGEN_DEVICE_FUNC
inline EIGEN_MATHFUNC_RETVAL(
betainc, Scalar)
2038 betainc(
const Scalar& a,
const Scalar& b,
const Scalar& x) {
2039 return EIGEN_MATHFUNC_IMPL(
betainc, Scalar)::run(a, b, x);
2045 #endif // EIGEN_SPECIAL_FUNCTIONS_H const Eigen::CwiseBinaryOp< Eigen::internal::scalar_igamma_der_a_op< typename Derived::Scalar >, const Derived, const ExponentDerived > igamma_der_a(const Eigen::ArrayBase< Derived > &a, const Eigen::ArrayBase< ExponentDerived > &x)
Definition: SpecialFunctionsArrayAPI.h:51
Namespace containing all symbols from the Eigen library.
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_ndtri_op< typename Derived::Scalar >, const Derived > ndtri(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_lgamma_op< typename Derived::Scalar >, const Derived > lgamma(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_erf_op< typename Derived::Scalar >, const Derived > erf(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseBinaryOp< Eigen::internal::scalar_igammac_op< typename Derived::Scalar >, const Derived, const ExponentDerived > igammac(const Eigen::ArrayBase< Derived > &a, const Eigen::ArrayBase< ExponentDerived > &x)
Definition: SpecialFunctionsArrayAPI.h:90
const Eigen::CwiseBinaryOp< Eigen::internal::scalar_igamma_op< typename Derived::Scalar >, const Derived, const ExponentDerived > igamma(const Eigen::ArrayBase< Derived > &a, const Eigen::ArrayBase< ExponentDerived > &x)
Definition: SpecialFunctionsArrayAPI.h:28
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_erfc_op< typename Derived::Scalar >, const Derived > erfc(const Eigen::ArrayBase< Derived > &x)
const TensorCwiseTernaryOp< internal::scalar_betainc_op< typename XDerived::Scalar >, const ADerived, const BDerived, const XDerived > betainc(const ADerived &a, const BDerived &b, const XDerived &x)
Definition: TensorGlobalFunctions.h:24
const Eigen::CwiseBinaryOp< Eigen::internal::scalar_gamma_sample_der_alpha_op< typename AlphaDerived::Scalar >, const AlphaDerived, const SampleDerived > gamma_sample_der_alpha(const Eigen::ArrayBase< AlphaDerived > &alpha, const Eigen::ArrayBase< SampleDerived > &sample)
Definition: SpecialFunctionsArrayAPI.h:72
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_digamma_op< typename Derived::Scalar >, const Derived > digamma(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseBinaryOp< Eigen::internal::scalar_zeta_op< typename DerivedX::Scalar >, const DerivedX, const DerivedQ > zeta(const Eigen::ArrayBase< DerivedX > &x, const Eigen::ArrayBase< DerivedQ > &q)
Definition: SpecialFunctionsArrayAPI.h:156
const Eigen::CwiseBinaryOp< Eigen::internal::scalar_polygamma_op< typename DerivedX::Scalar >, const DerivedN, const DerivedX > polygamma(const Eigen::ArrayBase< DerivedN > &n, const Eigen::ArrayBase< DerivedX > &x)
Definition: SpecialFunctionsArrayAPI.h:112