// NOTE(review): this listing appears to be a scraped source view. Original line
// numbers (11, 12, 26, 28, 33, ...) are fused into the text, and many lines
// (namespace/enum openers, braces, #endif) are missing. It is not compilable
// as shown; restore the pristine header before changing any logic.
//
// Preamble: include guard EIGEN_GENERAL_PRODUCT_H.
// EIGEN_GEMM_TO_COEFFBASED_THRESHOLD defaults to 20 unless it is already
// defined (its consumer is not visible in this chunk).
// The primary template product_type_selector<Rows, Cols, Depth> is declared but
// not defined. The specializations further down form the lookup table that
// product_type queries with the size categories of the result rows, the result
// columns and the inner (depth) dimension.
11 #ifndef EIGEN_GENERAL_PRODUCT_H 12 #define EIGEN_GENERAL_PRODUCT_H 26 #ifndef EIGEN_GEMM_TO_COEFFBASED_THRESHOLD 28 #define EIGEN_GEMM_TO_COEFFBASED_THRESHOLD 20 33 template<
int Rows,
int Cols,
int Depth>
struct product_type_selector;
// Classifies one product dimension into a size category (`value`) that
// product_type_selector consumes. The inputs are the dimension's
// compile-time Size and MaxSize.
// Visible host-side rule (inside #ifndef EIGEN_GPU_COMPILE_PHASE): the
// dimension is large when any of these holds:
//  * MaxSize is Dynamic;
//  * Size >= EIGEN_CACHEFRIENDLY_PRODUCT_THRESHOLD;
//  * Size is Dynamic and MaxSize >= that threshold.
// NOTE(review): several parts are missing from this listing: the struct
// braces, the enum opener, the GPU (#else) branch, and the tail of the
// `value` expression after `is_large ? Large`. Confirm against the original
// header.
35 template<
int Size,
int MaxSize>
struct product_size_category
38 #ifndef EIGEN_GPU_COMPILE_PHASE 39 is_large = MaxSize ==
Dynamic ||
40 Size >= EIGEN_CACHEFRIENDLY_PRODUCT_THRESHOLD ||
41 (Size==
Dynamic && MaxSize>=EIGEN_CACHEFRIENDLY_PRODUCT_THRESHOLD),
45 value = is_large ? Large
// Compile-time classification of the product Lhs * Rhs.
// Both operand types are first normalized with remove_all.
// Result shape: Rows x Cols (Lhs rows, Rhs cols).
// Inner dimension: Depth = EIGEN_SIZE_MIN_PREFER_FIXED(Lhs cols, Rhs rows).
// The Max* variants are derived the same way from the Max*AtCompileTime traits.
// NOTE(review): braces, the enum openers and the debug() signature are missing
// from this listing.
51 template<
typename Lhs,
typename Rhs>
struct product_type
53 typedef typename remove_all<Lhs>::type _Lhs;
54 typedef typename remove_all<Rhs>::type _Rhs;
56 MaxRows = traits<_Lhs>::MaxRowsAtCompileTime,
57 Rows = traits<_Lhs>::RowsAtCompileTime,
58 MaxCols = traits<_Rhs>::MaxColsAtCompileTime,
59 Cols = traits<_Rhs>::ColsAtCompileTime,
60 MaxDepth = EIGEN_SIZE_MIN_PREFER_FIXED(traits<_Lhs>::MaxColsAtCompileTime,
61 traits<_Rhs>::MaxRowsAtCompileTime),
62 Depth = EIGEN_SIZE_MIN_PREFER_FIXED(traits<_Lhs>::ColsAtCompileTime,
63 traits<_Rhs>::RowsAtCompileTime)
// Each dimension is mapped to its size category. The resulting
// (rows, cols, depth) triple is then looked up in the product_type_selector
// table.
70 rows_select = product_size_category<Rows,MaxRows>::value,
71 cols_select = product_size_category<Cols,MaxCols>::value,
72 depth_select = product_size_category<Depth,MaxDepth>::value
74 typedef product_type_selector<rows_select, cols_select, depth_select> selector;
// The implementation tag chosen by the table (e.g. GemmProduct, GemvProduct).
78 value = selector::ret,
// When EIGEN_DEBUG_PRODUCT is defined, each intermediate value is reported
// through EIGEN_DEBUG_VAR.
81 #ifdef EIGEN_DEBUG_PRODUCT 84 EIGEN_DEBUG_VAR(Rows);
85 EIGEN_DEBUG_VAR(Cols);
86 EIGEN_DEBUG_VAR(Depth);
87 EIGEN_DEBUG_VAR(rows_select);
88 EIGEN_DEBUG_VAR(cols_select);
89 EIGEN_DEBUG_VAR(depth_select);
90 EIGEN_DEBUG_VAR(value);
99 template<
int M,
int N>
struct product_type_selector<M,N,1> {
enum { ret = OuterProduct }; };
100 template<
int M>
struct product_type_selector<M, 1, 1> {
enum { ret = LazyCoeffBasedProductMode }; };
101 template<
int N>
struct product_type_selector<1, N, 1> {
enum { ret = LazyCoeffBasedProductMode }; };
102 template<
int Depth>
struct product_type_selector<1, 1, Depth> {
enum { ret = InnerProduct }; };
103 template<>
struct product_type_selector<1, 1, 1> {
enum { ret = InnerProduct }; };
104 template<>
struct product_type_selector<Small,1, Small> {
enum { ret = CoeffBasedProductMode }; };
105 template<>
struct product_type_selector<1, Small,Small> {
enum { ret = CoeffBasedProductMode }; };
106 template<>
struct product_type_selector<Small,Small,Small> {
enum { ret = CoeffBasedProductMode }; };
107 template<>
struct product_type_selector<Small, Small, 1> {
enum { ret = LazyCoeffBasedProductMode }; };
108 template<>
struct product_type_selector<Small, Large, 1> {
enum { ret = LazyCoeffBasedProductMode }; };
109 template<>
struct product_type_selector<Large, Small, 1> {
enum { ret = LazyCoeffBasedProductMode }; };
110 template<>
struct product_type_selector<1, Large,Small> {
enum { ret = CoeffBasedProductMode }; };
111 template<>
struct product_type_selector<1, Large,Large> {
enum { ret = GemvProduct }; };
112 template<>
struct product_type_selector<1, Small,Large> {
enum { ret = CoeffBasedProductMode }; };
113 template<>
struct product_type_selector<Large,1, Small> {
enum { ret = CoeffBasedProductMode }; };
114 template<>
struct product_type_selector<Large,1, Large> {
enum { ret = GemvProduct }; };
115 template<>
struct product_type_selector<Small,1, Large> {
enum { ret = CoeffBasedProductMode }; };
116 template<>
struct product_type_selector<Small,Small,Large> {
enum { ret = GemmProduct }; };
117 template<>
struct product_type_selector<Large,Small,Large> {
enum { ret = GemmProduct }; };
118 template<>
struct product_type_selector<Small,Large,Large> {
enum { ret = GemmProduct }; };
119 template<>
struct product_type_selector<Large,Large,Large> {
enum { ret = GemmProduct }; };
120 template<>
struct product_type_selector<Large,Small,Small> {
enum { ret = CoeffBasedProductMode }; };
121 template<>
struct product_type_selector<Small,Large,Small> {
enum { ret = CoeffBasedProductMode }; };
122 template<>
struct product_type_selector<Large,Large,Small> {
enum { ret = GemmProduct }; };
// Forward declaration of the matrix*vector dispatcher. Specializations are
// selected by three parameters:
//  * Side: OnTheLeft or OnTheRight;
//  * StorageOrder: the matrix's storage order;
//  * BlasCompatible: the specializations with `false` fall back to plain
//    expression loops instead of the general_matrix_vector_product kernel.
template<int Side, int StorageOrder, bool BlasCompatible>
struct gemv_dense_selector;
// Forward declaration of an optional scratch-storage helper for the gemv paths.
// Cond states whether a temporary buffer may be needed. With Cond == false,
// data() must never be called. With Cond == true, data() yields either an
// in-object buffer (fixed MaxSize) or a null pointer (Dynamic MaxSize).
template<typename Scalar, int Size, int MaxSize, bool Cond>
struct gemv_static_vector_if;
// Cond == false: the caller never needs scratch storage, and no member storage
// is visible here. data() is not meant to be reached. It calls
// eigen_internal_assert(false && "should never be called") and returns 0.
// NOTE(review): the struct braces are missing from this listing.
163 template<
typename Scalar,
int Size,
int MaxSize>
164 struct gemv_static_vector_if<Scalar,Size,MaxSize,false>
166 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Scalar* data() { eigen_internal_assert(
false &&
"should never be called");
return 0; }
// Cond == true but MaxSize is Dynamic: no compile-time bound exists for an
// in-object buffer, so data() returns a null pointer.
// NOTE(review): presumably the caller's ei_declare_aligned_stack_constructed_variable
// supplies its own buffer when handed this null pointer; confirm.
// NOTE(review): the struct braces are missing from this listing.
169 template<
typename Scalar,
int Size>
170 struct gemv_static_vector_if<Scalar,Size,
Dynamic,true>
172 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Scalar* data() {
return 0; }
// Cond == true with a fixed MaxSize: owns an in-object buffer of
// EIGEN_SIZE_MIN_PREFER_FIXED(Size, MaxSize) scalars for use as a temporary.
// ForceAlignment is taken from packet_traits<Scalar>::Vectorizable.
// If EIGEN_MAX_STATIC_ALIGN_BYTES != 0, the plain_array is declared with
// EIGEN_PLAIN_ENUM_MIN(AlignedMax, PacketSize) as its last template argument,
// and data() returns the array directly.
// Presumably in the #else branch (missing), the array is instead padded by
// EIGEN_MAX_ALIGN_BYTES extra scalars when ForceAlignment is set, and data()
// aligns the pointer by hand.
// NOTE(review): the braces, the enum opener, #else/#endif, and the
// non-ForceAlignment arm of the second data() ternary are missing from this
// listing.
175 template<
typename Scalar,
int Size,
int MaxSize>
176 struct gemv_static_vector_if<Scalar,Size,MaxSize,true>
179 ForceAlignment = internal::packet_traits<Scalar>::Vectorizable,
180 PacketSize = internal::packet_traits<Scalar>::size
182 #if EIGEN_MAX_STATIC_ALIGN_BYTES!=0 183 internal::plain_array<Scalar,EIGEN_SIZE_MIN_PREFER_FIXED(Size,MaxSize),0,EIGEN_PLAIN_ENUM_MIN(AlignedMax,PacketSize)> m_data;
184 EIGEN_STRONG_INLINE Scalar* data() {
return m_data.array; }
188 internal::plain_array<Scalar,EIGEN_SIZE_MIN_PREFER_FIXED(Size,MaxSize)+(ForceAlignment?EIGEN_MAX_ALIGN_BYTES:0),0> m_data;
// Manual alignment: round the array address down to an EIGEN_MAX_ALIGN_BYTES
// boundary, then advance one full boundary. The padding of EIGEN_MAX_ALIGN_BYTES
// extra scalars keeps the result inside m_data. The mask trick assumes
// EIGEN_MAX_ALIGN_BYTES is a power of two.
189 EIGEN_STRONG_INLINE Scalar* data() {
190 return ForceAlignment
191 ?
reinterpret_cast<Scalar*
>((internal::UIntPtr(m_data.array) & ~(std::size_t(EIGEN_MAX_ALIGN_BYTES-1))) + EIGEN_MAX_ALIGN_BYTES)
// Vector-times-matrix (OnTheLeft) is handled by transposition.
// dest^T = rhs^T * lhs^T is a matrix-times-vector product, so run() forwards
// to the OnTheRight selector. It passes rhs.transpose() as the matrix and
// lhs.transpose() as the vector, and writes through a Transpose<Dest> view of
// dest.
// OtherStorageOrder is not defined in this listing. It is presumably the
// opposite of StorageOrder, since transposing the matrix flips its storage
// order.
// NOTE(review): braces and the OtherStorageOrder enum are missing here.
198 template<
int StorageOrder,
bool BlasCompatible>
199 struct gemv_dense_selector<
OnTheLeft,StorageOrder,BlasCompatible>
201 template<
typename Lhs,
typename Rhs,
typename Dest>
202 static void run(
const Lhs &lhs,
const Rhs &rhs, Dest& dest,
const typename Dest::Scalar& alpha)
204 Transpose<Dest> destT(dest);
206 gemv_dense_selector<OnTheRight,OtherStorageOrder,BlasCompatible>
207 ::run(rhs.transpose(), lhs.transpose(), destT, alpha);
// Matrix * vector with a column-major, BLAS-compatible lhs.
// NOTE(review): the enclosing struct header (source lines 208-212) is missing.
// The ColMajor mappers and kernel arguments below suggest this is the
// <OnTheRight, ColMajor, true> specialization; confirm.
// blas_traits extracts directly accessible operands. Any scalar factors are
// merged with alpha into actualAlpha (combine_scalar_factors).
213 template<
typename Lhs,
typename Rhs,
typename Dest>
214 static inline void run(
const Lhs &lhs,
const Rhs &rhs, Dest& dest,
const typename Dest::Scalar& alpha)
216 typedef typename Lhs::Scalar LhsScalar;
217 typedef typename Rhs::Scalar RhsScalar;
218 typedef typename Dest::Scalar ResScalar;
219 typedef typename Dest::RealScalar RealScalar;
221 typedef internal::blas_traits<Lhs> LhsBlasTraits;
222 typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
223 typedef internal::blas_traits<Rhs> RhsBlasTraits;
224 typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
226 typedef Map<Matrix<ResScalar,Dynamic,1>, EIGEN_PLAIN_ENUM_MIN(AlignedMax,internal::packet_traits<ResScalar>::size)> MappedDest;
228 ActualLhsType actualLhs = LhsBlasTraits::extract(lhs);
229 ActualRhsType actualRhs = RhsBlasTraits::extract(rhs);
231 ResScalar actualAlpha = combine_scalar_factors(alpha, lhs, rhs);
234 typedef typename conditional<Dest::IsVectorAtCompileTime, Dest, typename Dest::ColXpr>::type ActualDest;
// dest can receive the kernel output directly only if both hold:
//  * it has unit inner stride;
//  * the product is not complex-lhs times real-rhs. In that case alpha may
//    not be representable as RhsScalar.
// MightCannotUseDest is also forced false when dest's MaxSizeAtCompileTime is 0.
239 EvalToDestAtCompileTime = (ActualDest::InnerStrideAtCompileTime==1),
240 ComplexByReal = (NumTraits<LhsScalar>::IsComplex) && (!NumTraits<RhsScalar>::IsComplex),
241 MightCannotUseDest = ((!EvalToDestAtCompileTime) || ComplexByReal) && (ActualDest::MaxSizeAtCompileTime!=0)
244 typedef const_blas_data_mapper<LhsScalar,Index,ColMajor> LhsMapper;
245 typedef const_blas_data_mapper<RhsScalar,Index,RowMajor> RhsMapper;
246 RhsScalar compatibleAlpha = get_factor<ResScalar,RhsScalar>::run(actualAlpha);
// Fast path: run the column-major GEMV kernel straight into dest.
248 if(!MightCannotUseDest)
252 general_matrix_vector_product
<
Index,LhsScalar,LhsMapper,
ColMajor,LhsBlasTraits::NeedToConjugate,RhsScalar,RhsMapper,RhsBlasTraits::NeedToConjugate>::run(
254 actualLhs.rows(), actualLhs.cols(),
255 LhsMapper(actualLhs.data(), actualLhs.outerStride()),
256 RhsMapper(actualRhs.data(), actualRhs.innerStride()),
// Slow path: work in a contiguous buffer. This is dest itself when it is
// usable at run time (evalToDest); otherwise static_dest's storage, or
// presumably a temporary provided by ei_declare_aligned_stack_constructed_variable.
262 gemv_static_vector_if<ResScalar,ActualDest::SizeAtCompileTime,ActualDest::MaxSizeAtCompileTime,MightCannotUseDest> static_dest;
264 const bool alphaIsCompatible = (!ComplexByReal) || (numext::imag(actualAlpha)==RealScalar(0));
265 const bool evalToDest = EvalToDestAtCompileTime && alphaIsCompatible;
267 ei_declare_aligned_stack_constructed_variable(ResScalar,actualDestPtr,dest.size(),
268 evalToDest ? dest.data() : static_dest.data());
// Optional user hook: when EIGEN_DENSE_STORAGE_CTOR_PLUGIN is defined, it is
// expanded here with a local `size` in scope.
272 #ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN 273 Index size = dest.size();
274 EIGEN_DENSE_STORAGE_CTOR_PLUGIN
// When alpha cannot be passed as RhsScalar, it is applied afterwards:
// accumulate into a zeroed buffer with alpha = 1, then add
// actualAlpha * buffer to dest.
// NOTE(review): the else / if(!evalToDest) guard in front of the copy-in
// below is missing from this listing.
276 if(!alphaIsCompatible)
278 MappedDest(actualDestPtr, dest.size()).setZero();
279 compatibleAlpha = RhsScalar(1);
282 MappedDest(actualDestPtr, dest.size()) = dest;
285 general_matrix_vector_product
<
Index,LhsScalar,LhsMapper,
ColMajor,LhsBlasTraits::NeedToConjugate,RhsScalar,RhsMapper,RhsBlasTraits::NeedToConjugate>::run(
287 actualLhs.rows(), actualLhs.cols(),
288 LhsMapper(actualLhs.data(), actualLhs.outerStride()),
289 RhsMapper(actualRhs.data(), actualRhs.innerStride()),
// Write-back: scale-and-add in the incompatible-alpha case. Otherwise copy
// the buffer back, presumably only when it is not dest itself (that guard is
// missing from this listing).
295 if(!alphaIsCompatible)
296 dest.matrix() += actualAlpha * MappedDest(actualDestPtr, dest.size());
298 dest = MappedDest(actualDestPtr, dest.size());
// Matrix * vector with a row-major, BLAS-compatible lhs.
// NOTE(review): the enclosing struct header is missing. The RowMajor mappers
// and kernel arguments below suggest the <OnTheRight, RowMajor, true>
// specialization; confirm.
// blas_traits extracts directly accessible operands. Any scalar factors are
// merged with alpha into actualAlpha.
306 template<
typename Lhs,
typename Rhs,
typename Dest>
307 static void run(
const Lhs &lhs,
const Rhs &rhs, Dest& dest,
const typename Dest::Scalar& alpha)
309 typedef typename Lhs::Scalar LhsScalar;
310 typedef typename Rhs::Scalar RhsScalar;
311 typedef typename Dest::Scalar ResScalar;
313 typedef internal::blas_traits<Lhs> LhsBlasTraits;
314 typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
315 typedef internal::blas_traits<Rhs> RhsBlasTraits;
316 typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
317 typedef typename internal::remove_all<ActualRhsType>::type ActualRhsTypeCleaned;
319 typename add_const<ActualLhsType>::type actualLhs = LhsBlasTraits::extract(lhs);
320 typename add_const<ActualRhsType>::type actualRhs = RhsBlasTraits::extract(rhs);
322 ResScalar actualAlpha = combine_scalar_factors(alpha, lhs, rhs);
// The kernel below reads rhs with stride 1. rhs is used in place when its
// inner stride is 1 or its MaxSizeAtCompileTime is 0. Otherwise it goes
// through a contiguous buffer (static_rhs, or a temporary).
327 DirectlyUseRhs = ActualRhsTypeCleaned::InnerStrideAtCompileTime==1 || ActualRhsTypeCleaned::MaxSizeAtCompileTime==0
330 gemv_static_vector_if<RhsScalar,ActualRhsTypeCleaned::SizeAtCompileTime,ActualRhsTypeCleaned::MaxSizeAtCompileTime,!DirectlyUseRhs> static_rhs;
332 ei_declare_aligned_stack_constructed_variable(RhsScalar,actualRhsPtr,actualRhs.size(),
333 DirectlyUseRhs ?
const_cast<RhsScalar*
>(actualRhs.data()) : static_rhs.data());
// Optional user hook: when EIGEN_DENSE_STORAGE_CTOR_PLUGIN is defined, it is
// expanded here with a local `size` in scope.
337 #ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN 338 Index size = actualRhs.size();
339 EIGEN_DENSE_STORAGE_CTOR_PLUGIN
// Copy rhs into the contiguous buffer. This is presumably guarded by
// !DirectlyUseRhs; the guard is missing from this listing.
341 Map<typename ActualRhsTypeCleaned::PlainObject>(actualRhsPtr, actualRhs.size()) = actualRhs;
// Row-major GEMV kernel: rhs is read with stride 1, and the result goes into
// dest using dest.col(0)'s inner stride. The remaining arguments are missing
// from this listing.
344 typedef const_blas_data_mapper<LhsScalar,Index,RowMajor> LhsMapper;
345 typedef const_blas_data_mapper<RhsScalar,Index,ColMajor> RhsMapper;
346 general_matrix_vector_product
<
Index,LhsScalar,LhsMapper,
RowMajor,LhsBlasTraits::NeedToConjugate,RhsScalar,RhsMapper,RhsBlasTraits::NeedToConjugate>::run(
348 actualLhs.rows(), actualLhs.cols(),
349 LhsMapper(actualLhs.data(), actualLhs.outerStride()),
350 RhsMapper(actualRhsPtr, 1),
351 dest.data(), dest.col(0).innerStride(),
// Fallback for a column-major lhs that cannot use the BLAS kernel.
// dest is accumulated one lhs column at a time:
//   dest += (alpha * rhs(k)) * lhs.col(k)   for each k.
// rhs is evaluated once up front (nested_eval<Rhs,1>). The static assertion
// requires that nested_eval<Lhs,1> does not evaluate lhs.
// NOTE(review): braces are missing from this listing.
356 template<>
struct gemv_dense_selector<
OnTheRight,ColMajor,false>
358 template<
typename Lhs,
typename Rhs,
typename Dest>
359 static void run(
const Lhs &lhs,
const Rhs &rhs, Dest& dest,
const typename Dest::Scalar& alpha)
361 EIGEN_STATIC_ASSERT((!nested_eval<Lhs,1>::Evaluate),EIGEN_INTERNAL_COMPILATION_ERROR_OR_YOU_MADE_A_PROGRAMMING_MISTAKE);
363 typename nested_eval<Rhs,1>::type actual_rhs(rhs);
364 const Index size = rhs.rows();
365 for(Index k=0; k<size; ++k)
366 dest += (alpha*actual_rhs.coeff(k)) * lhs.col(k);
// Fallback for a row-major lhs that cannot use the BLAS kernel.
// Each dest coefficient i receives alpha times the dot product of lhs.row(i)
// with rhs, computed as cwiseProduct(...).sum().
// rhs is read once per lhs row, hence nested_eval<Rhs, Lhs::RowsAtCompileTime>.
// Presumably this lets nested_eval decide whether to evaluate rhs into a
// temporary.
// NOTE(review): braces are missing from this listing.
370 template<>
struct gemv_dense_selector<
OnTheRight,RowMajor,false>
372 template<
typename Lhs,
typename Rhs,
typename Dest>
373 static void run(
const Lhs &lhs,
const Rhs &rhs, Dest& dest,
const typename Dest::Scalar& alpha)
375 EIGEN_STATIC_ASSERT((!nested_eval<Lhs,1>::Evaluate),EIGEN_INTERNAL_COMPILATION_ERROR_OR_YOU_MADE_A_PROGRAMMING_MISTAKE);
376 typename nested_eval<Rhs,Lhs::RowsAtCompileTime>::type actual_rhs(rhs);
377 const Index rows = dest.rows();
378 for(Index i=0; i<rows; ++i)
379 dest.coeffRef(i) += alpha * (lhs.row(i).cwiseProduct(actual_rhs.transpose())).sum();
// Presumably MatrixBase<Derived>::operator*(const MatrixBase<OtherDerived>&).
// The declarator line is missing from this listing; the visible return type is
// const Product<Derived, OtherDerived>.
// Compile-time validation: the product is valid when either inner dimension is
// Dynamic, or when Derived's columns equal OtherDerived's rows. An invalid
// product triggers a targeted static assertion:
//  * same-size vectors: the message points to the explicit dot /
//    coefficient-wise functions;
//  * same-size non-vectors: the message points to the explicit
//    coefficient-wise function;
//  * any other mismatch: INVALID_MATRIX_PRODUCT.
// When EIGEN_DEBUG_PRODUCT is defined, product_type<...>::debug() reports the
// classification.
395 template<
typename Derived>
396 template<
typename OtherDerived>
397 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
398 const Product<Derived, OtherDerived>
406 ProductIsValid = Derived::ColsAtCompileTime==
Dynamic 407 || OtherDerived::RowsAtCompileTime==
Dynamic 408 || int(Derived::ColsAtCompileTime)==int(OtherDerived::RowsAtCompileTime),
409 AreVectors = Derived::IsVectorAtCompileTime && OtherDerived::IsVectorAtCompileTime,
410 SameSizes = EIGEN_PREDICATE_SAME_MATRIX_SIZE(Derived,OtherDerived)
415 EIGEN_STATIC_ASSERT(ProductIsValid || !(AreVectors && SameSizes),
416 INVALID_VECTOR_VECTOR_PRODUCT__IF_YOU_WANTED_A_DOT_OR_COEFF_WISE_PRODUCT_YOU_MUST_USE_THE_EXPLICIT_FUNCTIONS)
417 EIGEN_STATIC_ASSERT(ProductIsValid || !(SameSizes && !AreVectors),
418 INVALID_MATRIX_PRODUCT__IF_YOU_WANTED_A_COEFF_WISE_PRODUCT_YOU_MUST_USE_THE_EXPLICIT_FUNCTION)
419 EIGEN_STATIC_ASSERT(ProductIsValid || SameSizes, INVALID_MATRIX_PRODUCT)
420 #ifdef EIGEN_DEBUG_PRODUCT 421 internal::product_type<Derived,OtherDerived>::debug();
// lazyProduct(other). The tooltip residue in this listing declares it as
// returning Product<Derived, OtherDerived, LazyProduct>. Its declarator and
// return statement are missing here.
// Only the compile-time validity checks are visible, and they match the
// default product: inner dimensions must agree unless one is Dynamic.
// Targeted static assertions cover the vector/vector and same-size mistakes.
438 template<
typename Derived>
439 template<
typename OtherDerived>
440 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
445 ProductIsValid = Derived::ColsAtCompileTime==
Dynamic 446 || OtherDerived::RowsAtCompileTime==
Dynamic 447 || int(Derived::ColsAtCompileTime)==int(OtherDerived::RowsAtCompileTime),
448 AreVectors = Derived::IsVectorAtCompileTime && OtherDerived::IsVectorAtCompileTime,
449 SameSizes = EIGEN_PREDICATE_SAME_MATRIX_SIZE(Derived,OtherDerived)
454 EIGEN_STATIC_ASSERT(ProductIsValid || !(AreVectors && SameSizes),
455 INVALID_VECTOR_VECTOR_PRODUCT__IF_YOU_WANTED_A_DOT_OR_COEFF_WISE_PRODUCT_YOU_MUST_USE_THE_EXPLICIT_FUNCTIONS)
456 EIGEN_STATIC_ASSERT(ProductIsValid || !(SameSizes && !AreVectors),
457 INVALID_MATRIX_PRODUCT__IF_YOU_WANTED_A_COEFF_WISE_PRODUCT_YOU_MUST_USE_THE_EXPLICIT_FUNCTION)
458 EIGEN_STATIC_ASSERT(ProductIsValid || SameSizes, INVALID_MATRIX_PRODUCT)
465 #endif // EIGEN_GENERAL_PRODUCT_H Definition: Constants.h:319
Expression of the product of two arbitrary matrices or vectors.
Definition: Product.h:71
Definition: Constants.h:334
Namespace containing all symbols from the Eigen library.
Definition: Core:141
Derived & derived()
Definition: EigenBase.h:46
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition: Meta.h:74
Definition: Constants.h:332
Definition: Eigen_Colamd.h:50
Definition: Constants.h:321
const int Dynamic
Definition: Constants.h:22
Base class for all dense matrices, vectors, and expressions.
Definition: MatrixBase.h:48
const CwiseBinaryOp< internal::scalar_product_op< Scalar, T >, Derived, Constant< T > > operator*(const T &scalar) const
const Product< Derived, OtherDerived, LazyProduct > lazyProduct(const MatrixBase< OtherDerived > &other) const
Definition: GeneralProduct.h:442