Please help us get to know our user community better by answering the following short survey: https://forms.gle/wpyrxWi18ox9Z5ae9
SpecialFunctionsFunctors.h
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2016 Eugene Brevdo <ebrevdo@gmail.com>
5 // Copyright (C) 2016 Gael Guennebaud <gael.guennebaud@inria.fr>
6 //
7 // This Source Code Form is subject to the terms of the Mozilla
8 // Public License v. 2.0. If a copy of the MPL was not distributed
9 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
10 
11 #ifndef EIGEN_SPECIALFUNCTIONS_FUNCTORS_H
12 #define EIGEN_SPECIALFUNCTIONS_FUNCTORS_H
13 
14 namespace Eigen {
15 
16 namespace internal {
17 
18 
24 template<typename Scalar> struct scalar_igamma_op : binary_op_base<Scalar,Scalar>
25 {
26  EIGEN_EMPTY_STRUCT_CTOR(scalar_igamma_op)
27  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator() (const Scalar& a, const Scalar& x) const {
28  using numext::igamma; return igamma(a, x);
29  }
30  template<typename Packet>
31  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet packetOp(const Packet& a, const Packet& x) const {
32  return internal::pigamma(a, x);
33  }
34 };
35 template<typename Scalar>
36 struct functor_traits<scalar_igamma_op<Scalar> > {
37  enum {
38  // Guesstimate
39  Cost = 20 * NumTraits<Scalar>::MulCost + 10 * NumTraits<Scalar>::AddCost,
40  PacketAccess = packet_traits<Scalar>::HasIGamma
41  };
42 };
43 
50 template <typename Scalar>
51 struct scalar_igamma_der_a_op {
52  EIGEN_EMPTY_STRUCT_CTOR(scalar_igamma_der_a_op)
53  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator()(const Scalar& a, const Scalar& x) const {
54  using numext::igamma_der_a;
55  return igamma_der_a(a, x);
56  }
57  template <typename Packet>
58  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet packetOp(const Packet& a, const Packet& x) const {
59  return internal::pigamma_der_a(a, x);
60  }
61 };
62 template <typename Scalar>
63 struct functor_traits<scalar_igamma_der_a_op<Scalar> > {
64  enum {
65  // 2x the cost of igamma
66  Cost = 40 * NumTraits<Scalar>::MulCost + 20 * NumTraits<Scalar>::AddCost,
67  PacketAccess = packet_traits<Scalar>::HasIGammaDerA
68  };
69 };
70 
78 template <typename Scalar>
79 struct scalar_gamma_sample_der_alpha_op {
80  EIGEN_EMPTY_STRUCT_CTOR(scalar_gamma_sample_der_alpha_op)
81  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator()(const Scalar& alpha, const Scalar& sample) const {
82  using numext::gamma_sample_der_alpha;
83  return gamma_sample_der_alpha(alpha, sample);
84  }
85  template <typename Packet>
86  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet packetOp(const Packet& alpha, const Packet& sample) const {
87  return internal::pgamma_sample_der_alpha(alpha, sample);
88  }
89 };
90 template <typename Scalar>
91 struct functor_traits<scalar_gamma_sample_der_alpha_op<Scalar> > {
92  enum {
93  // 2x the cost of igamma, minus the lgamma cost (the lgamma cancels out)
94  Cost = 30 * NumTraits<Scalar>::MulCost + 15 * NumTraits<Scalar>::AddCost,
95  PacketAccess = packet_traits<Scalar>::HasGammaSampleDerAlpha
96  };
97 };
98 
104 template<typename Scalar> struct scalar_igammac_op : binary_op_base<Scalar,Scalar>
105 {
106  EIGEN_EMPTY_STRUCT_CTOR(scalar_igammac_op)
107  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator() (const Scalar& a, const Scalar& x) const {
108  using numext::igammac; return igammac(a, x);
109  }
110  template<typename Packet>
111  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet packetOp(const Packet& a, const Packet& x) const
112  {
113  return internal::pigammac(a, x);
114  }
115 };
116 template<typename Scalar>
117 struct functor_traits<scalar_igammac_op<Scalar> > {
118  enum {
119  // Guesstimate
120  Cost = 20 * NumTraits<Scalar>::MulCost + 10 * NumTraits<Scalar>::AddCost,
121  PacketAccess = packet_traits<Scalar>::HasIGammac
122  };
123 };
124 
125 
130 template<typename Scalar> struct scalar_betainc_op {
131  EIGEN_EMPTY_STRUCT_CTOR(scalar_betainc_op)
132  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator() (const Scalar& x, const Scalar& a, const Scalar& b) const {
133  using numext::betainc; return betainc(x, a, b);
134  }
135  template<typename Packet>
136  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet packetOp(const Packet& x, const Packet& a, const Packet& b) const
137  {
138  return internal::pbetainc(x, a, b);
139  }
140 };
141 template<typename Scalar>
142 struct functor_traits<scalar_betainc_op<Scalar> > {
143  enum {
144  // Guesstimate
145  Cost = 400 * NumTraits<Scalar>::MulCost + 400 * NumTraits<Scalar>::AddCost,
146  PacketAccess = packet_traits<Scalar>::HasBetaInc
147  };
148 };
149 
150 
156 template<typename Scalar> struct scalar_lgamma_op {
157  EIGEN_EMPTY_STRUCT_CTOR(scalar_lgamma_op)
158  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator() (const Scalar& a) const {
159  using numext::lgamma; return lgamma(a);
160  }
161  typedef typename packet_traits<Scalar>::type Packet;
162  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a) const { return internal::plgamma(a); }
163 };
164 template<typename Scalar>
165 struct functor_traits<scalar_lgamma_op<Scalar> >
166 {
167  enum {
168  // Guesstimate
169  Cost = 10 * NumTraits<Scalar>::MulCost + 5 * NumTraits<Scalar>::AddCost,
170  PacketAccess = packet_traits<Scalar>::HasLGamma
171  };
172 };
173 
178 template<typename Scalar> struct scalar_digamma_op {
179  EIGEN_EMPTY_STRUCT_CTOR(scalar_digamma_op)
180  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator() (const Scalar& a) const {
181  using numext::digamma; return digamma(a);
182  }
183  typedef typename packet_traits<Scalar>::type Packet;
184  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a) const { return internal::pdigamma(a); }
185 };
186 template<typename Scalar>
187 struct functor_traits<scalar_digamma_op<Scalar> >
188 {
189  enum {
190  // Guesstimate
191  Cost = 10 * NumTraits<Scalar>::MulCost + 5 * NumTraits<Scalar>::AddCost,
192  PacketAccess = packet_traits<Scalar>::HasDiGamma
193  };
194 };
195 
200 template<typename Scalar> struct scalar_zeta_op {
201  EIGEN_EMPTY_STRUCT_CTOR(scalar_zeta_op)
202  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator() (const Scalar& x, const Scalar& q) const {
203  using numext::zeta; return zeta(x, q);
204  }
205  typedef typename packet_traits<Scalar>::type Packet;
206  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x, const Packet& q) const { return internal::pzeta(x, q); }
207 };
208 template<typename Scalar>
209 struct functor_traits<scalar_zeta_op<Scalar> >
210 {
211  enum {
212  // Guesstimate
213  Cost = 10 * NumTraits<Scalar>::MulCost + 5 * NumTraits<Scalar>::AddCost,
214  PacketAccess = packet_traits<Scalar>::HasZeta
215  };
216 };
217 
222 template<typename Scalar> struct scalar_polygamma_op {
223  EIGEN_EMPTY_STRUCT_CTOR(scalar_polygamma_op)
224  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator() (const Scalar& n, const Scalar& x) const {
225  using numext::polygamma; return polygamma(n, x);
226  }
227  typedef typename packet_traits<Scalar>::type Packet;
228  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& n, const Packet& x) const { return internal::ppolygamma(n, x); }
229 };
230 template<typename Scalar>
231 struct functor_traits<scalar_polygamma_op<Scalar> >
232 {
233  enum {
234  // Guesstimate
235  Cost = 10 * NumTraits<Scalar>::MulCost + 5 * NumTraits<Scalar>::AddCost,
236  PacketAccess = packet_traits<Scalar>::HasPolygamma
237  };
238 };
239 
244 template<typename Scalar> struct scalar_erf_op {
245  EIGEN_EMPTY_STRUCT_CTOR(scalar_erf_op)
246  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar
247  operator()(const Scalar& a) const {
248  return numext::erf(a);
249  }
250  template <typename Packet>
251  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x) const {
252  return perf(x);
253  }
254 };
255 template <typename Scalar>
256 struct functor_traits<scalar_erf_op<Scalar> > {
257  enum {
258  PacketAccess = packet_traits<Scalar>::HasErf,
259  Cost =
260  (PacketAccess
261 #ifdef EIGEN_VECTORIZE_FMA
262  // TODO(rmlarsen): Move the FMA cost model to a central location.
263  // Haswell can issue 2 add/mul/madd per cycle.
264  // 10 pmadd, 2 pmul, 1 div, 2 other
265  ? (2 * NumTraits<Scalar>::AddCost +
266  7 * NumTraits<Scalar>::MulCost +
267  scalar_div_cost<Scalar, packet_traits<Scalar>::HasDiv>::value)
268 #else
269  ? (12 * NumTraits<Scalar>::AddCost +
270  12 * NumTraits<Scalar>::MulCost +
271  scalar_div_cost<Scalar, packet_traits<Scalar>::HasDiv>::value)
272 #endif
273  // Assume for simplicity that this is as expensive as an exp().
274  : (functor_traits<scalar_exp_op<Scalar> >::Cost))
275  };
276 };
277 
283 template<typename Scalar> struct scalar_erfc_op {
284  EIGEN_EMPTY_STRUCT_CTOR(scalar_erfc_op)
285  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator() (const Scalar& a) const {
286  using numext::erfc; return erfc(a);
287  }
288  typedef typename packet_traits<Scalar>::type Packet;
289  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a) const { return internal::perfc(a); }
290 };
291 template<typename Scalar>
292 struct functor_traits<scalar_erfc_op<Scalar> >
293 {
294  enum {
295  // Guesstimate
296  Cost = 10 * NumTraits<Scalar>::MulCost + 5 * NumTraits<Scalar>::AddCost,
297  PacketAccess = packet_traits<Scalar>::HasErfc
298  };
299 };
300 
306 template<typename Scalar> struct scalar_ndtri_op {
307  EIGEN_EMPTY_STRUCT_CTOR(scalar_ndtri_op)
308  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator() (const Scalar& a) const {
309  using numext::ndtri; return ndtri(a);
310  }
311  typedef typename packet_traits<Scalar>::type Packet;
312  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a) const { return internal::pndtri(a); }
313 };
314 template<typename Scalar>
315 struct functor_traits<scalar_ndtri_op<Scalar> >
316 {
317  enum {
318  // On average, We are evaluating rational functions with degree N=9 in the
319  // numerator and denominator. This results in 2*N additions and 2*N
320  // multiplications.
321  Cost = 18 * NumTraits<Scalar>::MulCost + 18 * NumTraits<Scalar>::AddCost,
322  PacketAccess = packet_traits<Scalar>::HasNdtri
323  };
324 };
325 
326 } // end namespace internal
327 
328 } // end namespace Eigen
329 
330 #endif // EIGEN_SPECIALFUNCTIONS_FUNCTORS_H
const Eigen::CwiseBinaryOp< Eigen::internal::scalar_igamma_der_a_op< typename Derived::Scalar >, const Derived, const ExponentDerived > igamma_der_a(const Eigen::ArrayBase< Derived > &a, const Eigen::ArrayBase< ExponentDerived > &x)
Definition: SpecialFunctionsArrayAPI.h:51
Namespace containing all symbols from the Eigen library.
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_ndtri_op< typename Derived::Scalar >, const Derived > ndtri(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_lgamma_op< typename Derived::Scalar >, const Derived > lgamma(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseBinaryOp< Eigen::internal::scalar_igammac_op< typename Derived::Scalar >, const Derived, const ExponentDerived > igammac(const Eigen::ArrayBase< Derived > &a, const Eigen::ArrayBase< ExponentDerived > &x)
Definition: SpecialFunctionsArrayAPI.h:90
const Eigen::CwiseBinaryOp< Eigen::internal::scalar_igamma_op< typename Derived::Scalar >, const Derived, const ExponentDerived > igamma(const Eigen::ArrayBase< Derived > &a, const Eigen::ArrayBase< ExponentDerived > &x)
Definition: SpecialFunctionsArrayAPI.h:28
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_erfc_op< typename Derived::Scalar >, const Derived > erfc(const Eigen::ArrayBase< Derived > &x)
const TensorCwiseTernaryOp< internal::scalar_betainc_op< typename XDerived::Scalar >, const ADerived, const BDerived, const XDerived > betainc(const ADerived &a, const BDerived &b, const XDerived &x)
Definition: TensorGlobalFunctions.h:24
const Eigen::CwiseBinaryOp< Eigen::internal::scalar_gamma_sample_der_alpha_op< typename AlphaDerived::Scalar >, const AlphaDerived, const SampleDerived > gamma_sample_der_alpha(const Eigen::ArrayBase< AlphaDerived > &alpha, const Eigen::ArrayBase< SampleDerived > &sample)
Definition: SpecialFunctionsArrayAPI.h:72
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_digamma_op< typename Derived::Scalar >, const Derived > digamma(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseBinaryOp< Eigen::internal::scalar_zeta_op< typename DerivedX::Scalar >, const DerivedX, const DerivedQ > zeta(const Eigen::ArrayBase< DerivedX > &x, const Eigen::ArrayBase< DerivedQ > &q)
Definition: SpecialFunctionsArrayAPI.h:156
const Eigen::CwiseBinaryOp< Eigen::internal::scalar_polygamma_op< typename DerivedX::Scalar >, const DerivedN, const DerivedX > polygamma(const Eigen::ArrayBase< DerivedN > &n, const Eigen::ArrayBase< DerivedX > &x)
Definition: SpecialFunctionsArrayAPI.h:112