Please help us get to know our user community better by answering the following short survey: https://forms.gle/wpyrxWi18ox9Z5ae9
TensorExpr.h
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_CXX11_TENSOR_TENSOR_EXPR_H
11 #define EIGEN_CXX11_TENSOR_TENSOR_EXPR_H
12 
13 namespace Eigen {
14 
30 namespace internal {
31 template<typename NullaryOp, typename XprType>
32 struct traits<TensorCwiseNullaryOp<NullaryOp, XprType> >
33  : traits<XprType>
34 {
35  typedef traits<XprType> XprTraits;
36  typedef typename XprType::Scalar Scalar;
37  typedef typename XprType::Nested XprTypeNested;
38  typedef typename remove_reference<XprTypeNested>::type _XprTypeNested;
39  static const int NumDimensions = XprTraits::NumDimensions;
40  static const int Layout = XprTraits::Layout;
41  typedef typename XprTraits::PointerType PointerType;
42  enum {
43  Flags = 0
44  };
45 };
46 
47 } // end namespace internal
48 
49 
50 
51 template<typename NullaryOp, typename XprType>
52 class TensorCwiseNullaryOp : public TensorBase<TensorCwiseNullaryOp<NullaryOp, XprType>, ReadOnlyAccessors>
53 {
54  public:
55  typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::Scalar Scalar;
56  typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
57  typedef typename XprType::CoeffReturnType CoeffReturnType;
58  typedef TensorCwiseNullaryOp<NullaryOp, XprType> Nested;
59  typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::StorageKind StorageKind;
60  typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::Index Index;
61 
62  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseNullaryOp(const XprType& xpr, const NullaryOp& func = NullaryOp())
63  : m_xpr(xpr), m_functor(func) {}
64 
65  EIGEN_DEVICE_FUNC
66  const typename internal::remove_all<typename XprType::Nested>::type&
67  nestedExpression() const { return m_xpr; }
68 
69  EIGEN_DEVICE_FUNC
70  const NullaryOp& functor() const { return m_functor; }
71 
72  protected:
73  typename XprType::Nested m_xpr;
74  const NullaryOp m_functor;
75 };
76 
77 
78 
79 namespace internal {
80 template<typename UnaryOp, typename XprType>
81 struct traits<TensorCwiseUnaryOp<UnaryOp, XprType> >
82  : traits<XprType>
83 {
84  // TODO(phli): Add InputScalar, InputPacket. Check references to
85  // current Scalar/Packet to see if the intent is Input or Output.
86  typedef typename result_of<UnaryOp(typename XprType::Scalar)>::type Scalar;
87  typedef traits<XprType> XprTraits;
88  typedef typename XprType::Nested XprTypeNested;
89  typedef typename remove_reference<XprTypeNested>::type _XprTypeNested;
90  static const int NumDimensions = XprTraits::NumDimensions;
91  static const int Layout = XprTraits::Layout;
92  typedef typename TypeConversion<Scalar,
93  typename XprTraits::PointerType
94  >::type
95  PointerType;
96 };
97 
98 template<typename UnaryOp, typename XprType>
99 struct eval<TensorCwiseUnaryOp<UnaryOp, XprType>, Eigen::Dense>
100 {
101  typedef const TensorCwiseUnaryOp<UnaryOp, XprType>& type;
102 };
103 
104 template<typename UnaryOp, typename XprType>
105 struct nested<TensorCwiseUnaryOp<UnaryOp, XprType>, 1, typename eval<TensorCwiseUnaryOp<UnaryOp, XprType> >::type>
106 {
107  typedef TensorCwiseUnaryOp<UnaryOp, XprType> type;
108 };
109 
110 } // end namespace internal
111 
112 
113 
114 template<typename UnaryOp, typename XprType>
115 class TensorCwiseUnaryOp : public TensorBase<TensorCwiseUnaryOp<UnaryOp, XprType>, ReadOnlyAccessors>
116 {
117  public:
118  // TODO(phli): Add InputScalar, InputPacket. Check references to
119  // current Scalar/Packet to see if the intent is Input or Output.
120  typedef typename Eigen::internal::traits<TensorCwiseUnaryOp>::Scalar Scalar;
121  typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
122  typedef Scalar CoeffReturnType;
123  typedef typename Eigen::internal::nested<TensorCwiseUnaryOp>::type Nested;
124  typedef typename Eigen::internal::traits<TensorCwiseUnaryOp>::StorageKind StorageKind;
125  typedef typename Eigen::internal::traits<TensorCwiseUnaryOp>::Index Index;
126 
127  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseUnaryOp(const XprType& xpr, const UnaryOp& func = UnaryOp())
128  : m_xpr(xpr), m_functor(func) {}
129 
130  EIGEN_DEVICE_FUNC
131  const UnaryOp& functor() const { return m_functor; }
132 
134  EIGEN_DEVICE_FUNC
135  const typename internal::remove_all<typename XprType::Nested>::type&
136  nestedExpression() const { return m_xpr; }
137 
138  protected:
139  typename XprType::Nested m_xpr;
140  const UnaryOp m_functor;
141 };
142 
143 
144 namespace internal {
145 template<typename BinaryOp, typename LhsXprType, typename RhsXprType>
146 struct traits<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType> >
147 {
148  // Type promotion to handle the case where the types of the lhs and the rhs
149  // are different.
150  // TODO(phli): Add Lhs/RhsScalar, Lhs/RhsPacket. Check references to
151  // current Scalar/Packet to see if the intent is Inputs or Output.
152  typedef typename result_of<
153  BinaryOp(typename LhsXprType::Scalar,
154  typename RhsXprType::Scalar)>::type Scalar;
155  typedef traits<LhsXprType> XprTraits;
156  typedef typename promote_storage_type<
157  typename traits<LhsXprType>::StorageKind,
158  typename traits<RhsXprType>::StorageKind>::ret StorageKind;
159  typedef typename promote_index_type<
160  typename traits<LhsXprType>::Index,
161  typename traits<RhsXprType>::Index>::type Index;
162  typedef typename LhsXprType::Nested LhsNested;
163  typedef typename RhsXprType::Nested RhsNested;
164  typedef typename remove_reference<LhsNested>::type _LhsNested;
165  typedef typename remove_reference<RhsNested>::type _RhsNested;
166  static const int NumDimensions = XprTraits::NumDimensions;
167  static const int Layout = XprTraits::Layout;
168  typedef typename TypeConversion<Scalar,
169  typename conditional<Pointer_type_promotion<typename LhsXprType::Scalar, Scalar>::val,
170  typename traits<LhsXprType>::PointerType,
171  typename traits<RhsXprType>::PointerType>::type
172  >::type
173  PointerType;
174  enum {
175  Flags = 0
176  };
177 };
178 
179 template<typename BinaryOp, typename LhsXprType, typename RhsXprType>
180 struct eval<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType>, Eigen::Dense>
181 {
182  typedef const TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType>& type;
183 };
184 
185 template<typename BinaryOp, typename LhsXprType, typename RhsXprType>
186 struct nested<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType>, 1, typename eval<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType> >::type>
187 {
188  typedef TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType> type;
189 };
190 
191 } // end namespace internal
192 
193 
194 
195 template<typename BinaryOp, typename LhsXprType, typename RhsXprType>
196 class TensorCwiseBinaryOp : public TensorBase<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType>, ReadOnlyAccessors>
197 {
198  public:
199  // TODO(phli): Add Lhs/RhsScalar, Lhs/RhsPacket. Check references to
200  // current Scalar/Packet to see if the intent is Inputs or Output.
201  typedef typename Eigen::internal::traits<TensorCwiseBinaryOp>::Scalar Scalar;
202  typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
203  typedef Scalar CoeffReturnType;
204  typedef typename Eigen::internal::nested<TensorCwiseBinaryOp>::type Nested;
205  typedef typename Eigen::internal::traits<TensorCwiseBinaryOp>::StorageKind StorageKind;
206  typedef typename Eigen::internal::traits<TensorCwiseBinaryOp>::Index Index;
207 
208  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseBinaryOp(const LhsXprType& lhs, const RhsXprType& rhs, const BinaryOp& func = BinaryOp())
209  : m_lhs_xpr(lhs), m_rhs_xpr(rhs), m_functor(func) {}
210 
211  EIGEN_DEVICE_FUNC
212  const BinaryOp& functor() const { return m_functor; }
213 
215  EIGEN_DEVICE_FUNC
216  const typename internal::remove_all<typename LhsXprType::Nested>::type&
217  lhsExpression() const { return m_lhs_xpr; }
218 
219  EIGEN_DEVICE_FUNC
220  const typename internal::remove_all<typename RhsXprType::Nested>::type&
221  rhsExpression() const { return m_rhs_xpr; }
222 
223  protected:
224  typename LhsXprType::Nested m_lhs_xpr;
225  typename RhsXprType::Nested m_rhs_xpr;
226  const BinaryOp m_functor;
227 };
228 
229 
230 namespace internal {
231 template<typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType>
232 struct traits<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType> >
233 {
234  // Type promotion to handle the case where the types of the args are different.
235  typedef typename result_of<
236  TernaryOp(typename Arg1XprType::Scalar,
237  typename Arg2XprType::Scalar,
238  typename Arg3XprType::Scalar)>::type Scalar;
239  typedef traits<Arg1XprType> XprTraits;
240  typedef typename traits<Arg1XprType>::StorageKind StorageKind;
241  typedef typename traits<Arg1XprType>::Index Index;
242  typedef typename Arg1XprType::Nested Arg1Nested;
243  typedef typename Arg2XprType::Nested Arg2Nested;
244  typedef typename Arg3XprType::Nested Arg3Nested;
245  typedef typename remove_reference<Arg1Nested>::type _Arg1Nested;
246  typedef typename remove_reference<Arg2Nested>::type _Arg2Nested;
247  typedef typename remove_reference<Arg3Nested>::type _Arg3Nested;
248  static const int NumDimensions = XprTraits::NumDimensions;
249  static const int Layout = XprTraits::Layout;
250  typedef typename TypeConversion<Scalar,
251  typename conditional<Pointer_type_promotion<typename Arg2XprType::Scalar, Scalar>::val,
252  typename traits<Arg2XprType>::PointerType,
253  typename traits<Arg3XprType>::PointerType>::type
254  >::type
255  PointerType;
256  enum {
257  Flags = 0
258  };
259 };
260 
261 template<typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType>
262 struct eval<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType>, Eigen::Dense>
263 {
264  typedef const TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType>& type;
265 };
266 
267 template<typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType>
268 struct nested<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType>, 1, typename eval<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType> >::type>
269 {
270  typedef TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType> type;
271 };
272 
273 } // end namespace internal
274 
275 
276 
277 template<typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType>
278 class TensorCwiseTernaryOp : public TensorBase<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType>, ReadOnlyAccessors>
279 {
280  public:
281  typedef typename Eigen::internal::traits<TensorCwiseTernaryOp>::Scalar Scalar;
282  typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
283  typedef Scalar CoeffReturnType;
284  typedef typename Eigen::internal::nested<TensorCwiseTernaryOp>::type Nested;
285  typedef typename Eigen::internal::traits<TensorCwiseTernaryOp>::StorageKind StorageKind;
286  typedef typename Eigen::internal::traits<TensorCwiseTernaryOp>::Index Index;
287 
288  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseTernaryOp(const Arg1XprType& arg1, const Arg2XprType& arg2, const Arg3XprType& arg3, const TernaryOp& func = TernaryOp())
289  : m_arg1_xpr(arg1), m_arg2_xpr(arg2), m_arg3_xpr(arg3), m_functor(func) {}
290 
291  EIGEN_DEVICE_FUNC
292  const TernaryOp& functor() const { return m_functor; }
293 
295  EIGEN_DEVICE_FUNC
296  const typename internal::remove_all<typename Arg1XprType::Nested>::type&
297  arg1Expression() const { return m_arg1_xpr; }
298 
299  EIGEN_DEVICE_FUNC
300  const typename internal::remove_all<typename Arg2XprType::Nested>::type&
301  arg2Expression() const { return m_arg2_xpr; }
302 
303  EIGEN_DEVICE_FUNC
304  const typename internal::remove_all<typename Arg3XprType::Nested>::type&
305  arg3Expression() const { return m_arg3_xpr; }
306 
307  protected:
308  typename Arg1XprType::Nested m_arg1_xpr;
309  typename Arg2XprType::Nested m_arg2_xpr;
310  typename Arg3XprType::Nested m_arg3_xpr;
311  const TernaryOp m_functor;
312 };
313 
314 
315 namespace internal {
316 template<typename IfXprType, typename ThenXprType, typename ElseXprType>
317 struct traits<TensorSelectOp<IfXprType, ThenXprType, ElseXprType> >
318  : traits<ThenXprType>
319 {
320  typedef typename traits<ThenXprType>::Scalar Scalar;
321  typedef traits<ThenXprType> XprTraits;
322  typedef typename promote_storage_type<typename traits<ThenXprType>::StorageKind,
323  typename traits<ElseXprType>::StorageKind>::ret StorageKind;
324  typedef typename promote_index_type<typename traits<ElseXprType>::Index,
325  typename traits<ThenXprType>::Index>::type Index;
326  typedef typename IfXprType::Nested IfNested;
327  typedef typename ThenXprType::Nested ThenNested;
328  typedef typename ElseXprType::Nested ElseNested;
329  static const int NumDimensions = XprTraits::NumDimensions;
330  static const int Layout = XprTraits::Layout;
331  typedef typename conditional<Pointer_type_promotion<typename ThenXprType::Scalar, Scalar>::val,
332  typename traits<ThenXprType>::PointerType,
333  typename traits<ElseXprType>::PointerType>::type PointerType;
334 };
335 
336 template<typename IfXprType, typename ThenXprType, typename ElseXprType>
337 struct eval<TensorSelectOp<IfXprType, ThenXprType, ElseXprType>, Eigen::Dense>
338 {
339  typedef const TensorSelectOp<IfXprType, ThenXprType, ElseXprType>& type;
340 };
341 
342 template<typename IfXprType, typename ThenXprType, typename ElseXprType>
343 struct nested<TensorSelectOp<IfXprType, ThenXprType, ElseXprType>, 1, typename eval<TensorSelectOp<IfXprType, ThenXprType, ElseXprType> >::type>
344 {
345  typedef TensorSelectOp<IfXprType, ThenXprType, ElseXprType> type;
346 };
347 
348 } // end namespace internal
349 
350 
351 template<typename IfXprType, typename ThenXprType, typename ElseXprType>
352 class TensorSelectOp : public TensorBase<TensorSelectOp<IfXprType, ThenXprType, ElseXprType>, ReadOnlyAccessors>
353 {
354  public:
355  typedef typename Eigen::internal::traits<TensorSelectOp>::Scalar Scalar;
356  typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
357  typedef typename internal::promote_storage_type<typename ThenXprType::CoeffReturnType,
358  typename ElseXprType::CoeffReturnType>::ret CoeffReturnType;
359  typedef typename Eigen::internal::nested<TensorSelectOp>::type Nested;
360  typedef typename Eigen::internal::traits<TensorSelectOp>::StorageKind StorageKind;
361  typedef typename Eigen::internal::traits<TensorSelectOp>::Index Index;
362 
363  EIGEN_DEVICE_FUNC
364  TensorSelectOp(const IfXprType& a_condition,
365  const ThenXprType& a_then,
366  const ElseXprType& a_else)
367  : m_condition(a_condition), m_then(a_then), m_else(a_else)
368  { }
369 
370  EIGEN_DEVICE_FUNC
371  const IfXprType& ifExpression() const { return m_condition; }
372 
373  EIGEN_DEVICE_FUNC
374  const ThenXprType& thenExpression() const { return m_then; }
375 
376  EIGEN_DEVICE_FUNC
377  const ElseXprType& elseExpression() const { return m_else; }
378 
379  protected:
380  typename IfXprType::Nested m_condition;
381  typename ThenXprType::Nested m_then;
382  typename ElseXprType::Nested m_else;
383 };
384 
385 
386 } // end namespace Eigen
387 
388 #endif // EIGEN_CXX11_TENSOR_TENSOR_EXPR_H
Namespace containing all symbols from the Eigen library.
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index