#ifndef EIGEN_CXX11_TENSOR_TENSOR_BROADCASTING_H
#define EIGEN_CXX11_TENSOR_TENSOR_BROADCASTING_H

namespace Eigen {
namespace internal {

template<typename Broadcast, typename XprType>
struct traits<TensorBroadcastingOp<Broadcast, XprType> > : public traits<XprType>
{
  typedef typename XprType::Scalar Scalar;
  typedef traits<XprType> XprTraits;
  typedef typename XprTraits::StorageKind StorageKind;
  typedef typename XprTraits::Index Index;
  typedef typename XprType::Nested Nested;
  typedef typename remove_reference<Nested>::type _Nested;
  static const int NumDimensions = XprTraits::NumDimensions;
  static const int Layout = XprTraits::Layout;
  typedef typename XprTraits::PointerType PointerType;
};
template<typename Broadcast, typename XprType>
struct eval<TensorBroadcastingOp<Broadcast, XprType>, Eigen::Dense>
{
  typedef const TensorBroadcastingOp<Broadcast, XprType> EIGEN_DEVICE_REF type;
};
template<typename Broadcast, typename XprType>
struct nested<TensorBroadcastingOp<Broadcast, XprType>, 1, typename eval<TensorBroadcastingOp<Broadcast, XprType> >::type>
{
  typedef TensorBroadcastingOp<Broadcast, XprType> type;
};
template <typename Dims>
struct is_input_scalar {
  static const bool value = false;
};
template <>
struct is_input_scalar<Sizes<> > {
  static const bool value = true;
};
#ifndef EIGEN_EMULATE_CXX11_META_H
template <typename std::ptrdiff_t... Indices>
struct is_input_scalar<Sizes<Indices...> > {
  static const bool value = (Sizes<Indices...>::total_size == 1);
};
#endif

}  // end namespace internal
template<typename Broadcast, typename XprType>
class TensorBroadcastingOp : public TensorBase<TensorBroadcastingOp<Broadcast, XprType>, ReadOnlyAccessors>
{
  public:
    typedef typename Eigen::internal::traits<TensorBroadcastingOp>::Scalar Scalar;
    typedef typename XprType::CoeffReturnType CoeffReturnType;
    typedef typename Eigen::internal::nested<TensorBroadcastingOp>::type Nested;
    typedef typename Eigen::internal::traits<TensorBroadcastingOp>::StorageKind StorageKind;
    typedef typename Eigen::internal::traits<TensorBroadcastingOp>::Index Index;

    EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorBroadcastingOp(const XprType& expr, const Broadcast& broadcast)
        : m_xpr(expr), m_broadcast(broadcast) {}

    EIGEN_DEVICE_FUNC
    const Broadcast& broadcast() const { return m_broadcast; }

    EIGEN_DEVICE_FUNC
    const typename internal::remove_all<typename XprType::Nested>::type&
    expression() const { return m_xpr; }

  protected:
    typename XprType::Nested m_xpr;
    const Broadcast m_broadcast;
};
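
// A minimal usage sketch (not part of this header, sizes are illustrative):
// broadcasting repeats the input along each dimension by the factors stored in
// the Broadcast object, so the result dimensions are the element-wise product
// of the input dimensions and the broadcast factors.
//
//   Eigen::Tensor<float, 2> t(2, 3);
//   t.setRandom();
//   Eigen::array<Eigen::Index, 2> bcast = {4, 5};
//   Eigen::Tensor<float, 2> b = t.broadcast(bcast);  // result is 8 x 15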
// Eval as rvalue
template<typename Broadcast, typename ArgType, typename Device>
struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device>
{
  typedef TensorBroadcastingOp<Broadcast, ArgType> XprType;
  typedef typename XprType::Index Index;
  static const int NumDims = internal::array_size<typename TensorEvaluator<ArgType, Device>::Dimensions>::value;
  typedef DSizes<Index, NumDims> Dimensions;
  typedef typename XprType::Scalar Scalar;
  typedef typename TensorEvaluator<ArgType, Device>::Dimensions InputDimensions;
  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
  static const int PacketSize = PacketType<CoeffReturnType, Device>::size;

 protected:  // all the non-static fields must have the same access control, otherwise the TensorEvaluator won't be standard layout
  bool isCopy, nByOne, oneByN;

 public:
  typedef StorageMemory<CoeffReturnType, Device> Storage;
  typedef typename Storage::Type EvaluatorPointerType;

  enum {
    IsAligned         = TensorEvaluator<ArgType, Device>::IsAligned,
    PacketAccess      = TensorEvaluator<ArgType, Device>::PacketAccess,
    BlockAccess       = TensorEvaluator<ArgType, Device>::BlockAccess,
    PreferBlockAccess = true,
    Layout            = TensorEvaluator<ArgType, Device>::Layout,
    RawAccess         = false
  };

  typedef typename internal::remove_const<Scalar>::type ScalarNoConst;

  // Block based broadcasting uses a trick with 2x tensor rank and 0 strides,
  // see the block() method implementation for details.
  typedef DSizes<Index, 2 * NumDims> BroadcastDimensions;

  //===- Tensor block evaluation strategy (see TensorBlock.h) -------------===//
  typedef internal::TensorBlockDescriptor<NumDims, Index> TensorBlockDesc;
  typedef internal::TensorBlockScratchAllocator<Device> TensorBlockScratch;

  typedef typename TensorEvaluator<const ArgType, Device>::TensorBlock
      ArgTensorBlock;

  typedef typename internal::TensorMaterializedBlock<ScalarNoConst, NumDims,
                                                     Layout, Index>
      TensorBlock;
  //===--------------------------------------------------------------------===//
  EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
      : isCopy(false), nByOne(false), oneByN(false),
        m_device(device), m_broadcast(op.broadcast()), m_impl(op.expression(), device)
  {
    // The broadcasting op doesn't change the rank of the tensor. One can't
    // broadcast a scalar and store the result in a scalar: reshape the scalar
    // into an N-D tensor (N >= 1) with a single coefficient first.
    EIGEN_STATIC_ASSERT((NumDims > 0), YOU_MADE_A_PROGRAMMING_MISTAKE);
    const InputDimensions& input_dims = m_impl.dimensions();
    isCopy = true;
    for (int i = 0; i < NumDims; ++i) {
      eigen_assert(input_dims[i] > 0);
      m_dimensions[i] = input_dims[i] * m_broadcast[i];
      if (m_broadcast[i] != 1) {
        isCopy = false;
      }
    }

    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      m_inputStrides[0] = 1;
      m_outputStrides[0] = 1;
      for (int i = 1; i < NumDims; ++i) {
        m_inputStrides[i] = m_inputStrides[i-1] * input_dims[i-1];
        m_outputStrides[i] = m_outputStrides[i-1] * m_dimensions[i-1];
      }
    } else {
      m_inputStrides[NumDims-1] = 1;
      m_outputStrides[NumDims-1] = 1;
      for (int i = NumDims-2; i >= 0; --i) {
        m_inputStrides[i] = m_inputStrides[i+1] * input_dims[i+1];
        m_outputStrides[i] = m_outputStrides[i+1] * m_dimensions[i+1];
      }
    }

    if (input_dims[0] == 1) {
      oneByN = true;
      for (int i = 1; i < NumDims; ++i) {
        if (m_broadcast[i] != 1) {
          oneByN = false;
          break;
        }
      }
    } else if (input_dims[NumDims-1] == 1) {
      nByOne = true;
      for (int i = 0; i < NumDims-1; ++i) {
        if (m_broadcast[i] != 1) {
          nByOne = false;
          break;
        }
      }
    }

    // Handle special formats like NCHW: input shape '[1, N..., 1]' with
    // broadcast shape '[N, 1..., N]'.
    if (!oneByN && !nByOne) {
      if (input_dims[0] == 1 && input_dims[NumDims-1] == 1 && NumDims > 2) {
        nByOne = true;
        oneByN = true;
        for (int i = 1; i < NumDims-1; ++i) {
          if (m_broadcast[i] != 1) {
            nByOne = false;
            oneByN = false;
            break;
          }
        }
      }
    }
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }

  EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(EvaluatorPointerType) {
    m_impl.evalSubExprsIfNeeded(NULL);
    return true;
  }
#ifdef EIGEN_USE_THREADS
  template <typename EvalSubExprsCallback>
  EIGEN_STRONG_INLINE void evalSubExprsIfNeededAsync(
      EvaluatorPointerType, EvalSubExprsCallback done) {
    m_impl.evalSubExprsIfNeededAsync(nullptr, [done](bool) { done(true); });
  }
#endif  // EIGEN_USE_THREADS

  EIGEN_STRONG_INLINE void cleanup() {
    m_impl.cleanup();
  }
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE CoeffReturnType coeff(Index index) const
  {
    if (internal::is_input_scalar<typename internal::remove_all<InputDimensions>::type>::value) {
      return m_impl.coeff(0);
    }

    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      if (isCopy) {
        return m_impl.coeff(index);
      }
      return coeffColMajor(index);
    } else {
      if (isCopy) {
        return m_impl.coeff(index);
      }
      return coeffRowMajor(index);
    }
  }
  // TODO: attempt to speed this up. The integer divisions and modulos are slow.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index indexColMajor(Index index) const {
    Index inputIndex = 0;
    EIGEN_UNROLL_LOOP
    for (int i = NumDims - 1; i > 0; --i) {
      const Index idx = index / m_outputStrides[i];
      if (internal::index_statically_eq<Broadcast>(i, 1)) {
        eigen_assert(idx < m_impl.dimensions()[i]);
        inputIndex += idx * m_inputStrides[i];
      } else {
        if (internal::index_statically_eq<InputDimensions>(i, 1)) {
          eigen_assert(idx % m_impl.dimensions()[i] == 0);
        } else {
          inputIndex += (idx % m_impl.dimensions()[i]) * m_inputStrides[i];
        }
      }
      index -= idx * m_outputStrides[i];
    }
    if (internal::index_statically_eq<Broadcast>(0, 1)) {
      eigen_assert(index < m_impl.dimensions()[0]);
      inputIndex += index;
    } else {
      if (internal::index_statically_eq<InputDimensions>(0, 1)) {
        eigen_assert(index % m_impl.dimensions()[0] == 0);
      } else {
        inputIndex += (index % m_impl.dimensions()[0]);
      }
    }
    return inputIndex;
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeffColMajor(Index index) const
  {
    return m_impl.coeff(indexColMajor(index));
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index indexRowMajor(Index index) const {
    Index inputIndex = 0;
    EIGEN_UNROLL_LOOP
    for (int i = 0; i < NumDims - 1; ++i) {
      const Index idx = index / m_outputStrides[i];
      if (internal::index_statically_eq<Broadcast>(i, 1)) {
        eigen_assert(idx < m_impl.dimensions()[i]);
        inputIndex += idx * m_inputStrides[i];
      } else {
        if (internal::index_statically_eq<InputDimensions>(i, 1)) {
          eigen_assert(idx % m_impl.dimensions()[i] == 0);
        } else {
          inputIndex += (idx % m_impl.dimensions()[i]) * m_inputStrides[i];
        }
      }
      index -= idx * m_outputStrides[i];
    }
    if (internal::index_statically_eq<Broadcast>(NumDims - 1, 1)) {
      eigen_assert(index < m_impl.dimensions()[NumDims - 1]);
      inputIndex += index;
    } else {
      if (internal::index_statically_eq<InputDimensions>(NumDims - 1, 1)) {
        eigen_assert(index % m_impl.dimensions()[NumDims - 1] == 0);
      } else {
        inputIndex += (index % m_impl.dimensions()[NumDims - 1]);
      }
    }
    return inputIndex;
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeffRowMajor(Index index) const
  {
    return m_impl.coeff(indexRowMajor(index));
  }
  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketReturnType packet(Index index) const
  {
    if (internal::is_input_scalar<typename internal::remove_all<InputDimensions>::type>::value) {
      return internal::pset1<PacketReturnType>(m_impl.coeff(0));
    }

    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      if (isCopy) {
        #ifdef EIGEN_GPU_COMPILE_PHASE
        return m_impl.template packet<Unaligned>(index);
        #else
        return m_impl.template packet<LoadMode>(index);
        #endif
      } else if (oneByN && !nByOne) {
        return packetNByOne<LoadMode>(index);
      } else if (!oneByN && nByOne) {
        return packetOneByN<LoadMode>(index);
      } else if (oneByN && nByOne) {
        return packetOneByNByOne<LoadMode>(index);
      } else {
        return packetColMajor<LoadMode>(index);
      }
    } else {
      if (isCopy) {
        #ifdef EIGEN_GPU_COMPILE_PHASE
        return m_impl.template packet<Unaligned>(index);
        #else
        return m_impl.template packet<LoadMode>(index);
        #endif
      } else if (oneByN && !nByOne) {
        return packetOneByN<LoadMode>(index);
      } else if (!oneByN && nByOne) {
        return packetNByOne<LoadMode>(index);
      } else if (oneByN && nByOne) {
        return packetOneByNByOne<LoadMode>(index);
      } else {
        return packetRowMajor<LoadMode>(index);
      }
    }
  }
  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packetOneByNByOne
  (Index index) const
  {
    EIGEN_STATIC_ASSERT((PacketSize > 1), YOU_MADE_A_PROGRAMMING_MISTAKE)
    eigen_assert(index+PacketSize-1 < dimensions().TotalSize());

    EIGEN_ALIGN_MAX typename internal::remove_const<CoeffReturnType>::type values[PacketSize];
    Index startDim, endDim;
    Index inputIndex, outputOffset, batchedIndex;

    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      startDim = NumDims - 1;
      endDim = 1;
    } else {
      startDim = 0;
      endDim = NumDims - 2;
    }

    batchedIndex = index % m_outputStrides[startDim];
    inputIndex   = batchedIndex / m_outputStrides[endDim];
    outputOffset = batchedIndex % m_outputStrides[endDim];

    if (outputOffset + PacketSize <= m_outputStrides[endDim]) {
      values[0] = m_impl.coeff(inputIndex);
      return internal::pload1<PacketReturnType>(values);
    } else {
      EIGEN_UNROLL_LOOP
      for (int i = 0, cur = 0; i < PacketSize; ++i, ++cur) {
        if (outputOffset + cur < m_outputStrides[endDim]) {
          values[i] = m_impl.coeff(inputIndex);
        } else {
          ++inputIndex;
          inputIndex = (inputIndex == m_inputStrides[startDim] ? 0 : inputIndex);
          values[i] = m_impl.coeff(inputIndex);
          outputOffset = 0;
          cur = 0;
        }
      }
      return internal::pload<PacketReturnType>(values);
    }
  }
  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packetOneByN(Index index) const
  {
    EIGEN_STATIC_ASSERT((PacketSize > 1), YOU_MADE_A_PROGRAMMING_MISTAKE)
    eigen_assert(index+PacketSize-1 < dimensions().TotalSize());

    Index dim, inputIndex;

    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      dim = NumDims - 1;
    } else {
      dim = 0;
    }

    inputIndex = index % m_inputStrides[dim];
    if (inputIndex + PacketSize <= m_inputStrides[dim]) {
      return m_impl.template packet<Unaligned>(inputIndex);
    } else {
      EIGEN_ALIGN_MAX typename internal::remove_const<CoeffReturnType>::type values[PacketSize];
      EIGEN_UNROLL_LOOP
      for (int i = 0; i < PacketSize; ++i) {
        if (inputIndex > m_inputStrides[dim]-1) {
          inputIndex = 0;
        }
        values[i] = m_impl.coeff(inputIndex++);
      }
      return internal::pload<PacketReturnType>(values);
    }
  }
  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packetNByOne(Index index) const
  {
    EIGEN_STATIC_ASSERT((PacketSize > 1), YOU_MADE_A_PROGRAMMING_MISTAKE)
    eigen_assert(index+PacketSize-1 < dimensions().TotalSize());

    EIGEN_ALIGN_MAX typename internal::remove_const<CoeffReturnType>::type values[PacketSize];
    Index dim, inputIndex, outputOffset;

    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      dim = 1;
    } else {
      dim = NumDims - 2;
    }

    inputIndex   = index / m_outputStrides[dim];
    outputOffset = index % m_outputStrides[dim];
    if (outputOffset + PacketSize <= m_outputStrides[dim]) {
      values[0] = m_impl.coeff(inputIndex);
      return internal::pload1<PacketReturnType>(values);
    } else {
      EIGEN_UNROLL_LOOP
      for (int i = 0, cur = 0; i < PacketSize; ++i, ++cur) {
        if (outputOffset + cur < m_outputStrides[dim]) {
          values[i] = m_impl.coeff(inputIndex);
        } else {
          values[i] = m_impl.coeff(++inputIndex);
          outputOffset = 0;
          cur = 0;
        }
      }
      return internal::pload<PacketReturnType>(values);
    }
  }
  // Ignore the LoadMode and always use unaligned loads since we can't guarantee
  // the alignment at compile time.
  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packetColMajor(Index index) const
  {
    EIGEN_STATIC_ASSERT((PacketSize > 1), YOU_MADE_A_PROGRAMMING_MISTAKE)
    eigen_assert(index+PacketSize-1 < dimensions().TotalSize());

    const Index originalIndex = index;

    Index inputIndex = 0;
    EIGEN_UNROLL_LOOP
    for (int i = NumDims - 1; i > 0; --i) {
      const Index idx = index / m_outputStrides[i];
      if (internal::index_statically_eq<Broadcast>(i, 1)) {
        eigen_assert(idx < m_impl.dimensions()[i]);
        inputIndex += idx * m_inputStrides[i];
      } else {
        if (internal::index_statically_eq<InputDimensions>(i, 1)) {
          eigen_assert(idx % m_impl.dimensions()[i] == 0);
        } else {
          inputIndex += (idx % m_impl.dimensions()[i]) * m_inputStrides[i];
        }
      }
      index -= idx * m_outputStrides[i];
    }
    Index innermostLoc;
    if (internal::index_statically_eq<Broadcast>(0, 1)) {
      eigen_assert(index < m_impl.dimensions()[0]);
      innermostLoc = index;
    } else {
      if (internal::index_statically_eq<InputDimensions>(0, 1)) {
        eigen_assert(index % m_impl.dimensions()[0] == 0);
        innermostLoc = 0;
      } else {
        innermostLoc = index % m_impl.dimensions()[0];
      }
    }
    inputIndex += innermostLoc;

    // Todo: this could be extended to the second dimension if we're not
    // broadcasting alongside the first dimension, and so on.
    if (innermostLoc + PacketSize <= m_impl.dimensions()[0]) {
      return m_impl.template packet<Unaligned>(inputIndex);
    } else {
      EIGEN_ALIGN_MAX typename internal::remove_const<CoeffReturnType>::type values[PacketSize];
      values[0] = m_impl.coeff(inputIndex);
      EIGEN_UNROLL_LOOP
      for (int i = 1; i < PacketSize; ++i) {
        if (innermostLoc + i < m_impl.dimensions()[0]) {
          values[i] = m_impl.coeff(inputIndex+i);
        } else {
          values[i] = coeffColMajor(originalIndex+i);
        }
      }
      PacketReturnType rslt = internal::pload<PacketReturnType>(values);
      return rslt;
    }
  }
  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packetRowMajor(Index index) const
  {
    EIGEN_STATIC_ASSERT((PacketSize > 1), YOU_MADE_A_PROGRAMMING_MISTAKE)
    eigen_assert(index+PacketSize-1 < dimensions().TotalSize());

    const Index originalIndex = index;

    Index inputIndex = 0;
    EIGEN_UNROLL_LOOP
    for (int i = 0; i < NumDims - 1; ++i) {
      const Index idx = index / m_outputStrides[i];
      if (internal::index_statically_eq<Broadcast>(i, 1)) {
        eigen_assert(idx < m_impl.dimensions()[i]);
        inputIndex += idx * m_inputStrides[i];
      } else {
        if (internal::index_statically_eq<InputDimensions>(i, 1)) {
          eigen_assert(idx % m_impl.dimensions()[i] == 0);
        } else {
          inputIndex += (idx % m_impl.dimensions()[i]) * m_inputStrides[i];
        }
      }
      index -= idx * m_outputStrides[i];
    }
    Index innermostLoc;
    if (internal::index_statically_eq<Broadcast>(NumDims-1, 1)) {
      eigen_assert(index < m_impl.dimensions()[NumDims-1]);
      innermostLoc = index;
    } else {
      if (internal::index_statically_eq<InputDimensions>(NumDims-1, 1)) {
        eigen_assert(index % m_impl.dimensions()[NumDims-1] == 0);
        innermostLoc = 0;
      } else {
        innermostLoc = index % m_impl.dimensions()[NumDims-1];
      }
    }
    inputIndex += innermostLoc;

    // Todo: this could be extended to the second dimension if we're not
    // broadcasting alongside the first dimension, and so on.
    if (innermostLoc + PacketSize <= m_impl.dimensions()[NumDims-1]) {
      return m_impl.template packet<Unaligned>(inputIndex);
    } else {
      EIGEN_ALIGN_MAX typename internal::remove_const<CoeffReturnType>::type values[PacketSize];
      values[0] = m_impl.coeff(inputIndex);
      EIGEN_UNROLL_LOOP
      for (int i = 1; i < PacketSize; ++i) {
        if (innermostLoc + i < m_impl.dimensions()[NumDims-1]) {
          values[i] = m_impl.coeff(inputIndex+i);
        } else {
          values[i] = coeffRowMajor(originalIndex+i);
        }
      }
      PacketReturnType rslt = internal::pload<PacketReturnType>(values);
      return rslt;
    }
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost
  costPerCoeff(bool vectorized) const {
    double compute_cost = TensorOpCost::AddCost<Index>();
    if (!isCopy && NumDims > 0) {
      EIGEN_UNROLL_LOOP
      for (int i = NumDims - 1; i > 0; --i) {
        compute_cost += TensorOpCost::DivCost<Index>();
        if (internal::index_statically_eq<Broadcast>(i, 1)) {
          compute_cost +=
              TensorOpCost::MulCost<Index>() + TensorOpCost::AddCost<Index>();
        } else {
          if (!internal::index_statically_eq<InputDimensions>(i, 1)) {
            compute_cost += TensorOpCost::MulCost<Index>() +
                            TensorOpCost::ModCost<Index>() +
                            TensorOpCost::AddCost<Index>();
          }
        }
        compute_cost +=
            TensorOpCost::MulCost<Index>() + TensorOpCost::AddCost<Index>();
      }
    }
    return m_impl.costPerCoeff(vectorized) +
           TensorOpCost(0, 0, compute_cost, vectorized, PacketSize);
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
  internal::TensorBlockResourceRequirements getResourceRequirements() const {
    // Skew block sizes toward the first-level cache size.
    const size_t target_size = m_device.firstLevelCacheSize();
    return internal::TensorBlockResourceRequirements::merge(
        m_impl.getResourceRequirements(),
        internal::TensorBlockResourceRequirements::skewed<Scalar>(target_size));
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorBlock
  block(TensorBlockDesc& desc, TensorBlockScratch& scratch,
        bool /*root_of_expr_ast*/ = false) const {
    BlockBroadcastingParams params = blockBroadcastingParams(desc);

    if (params.inner_dim_size == 0 || params.bcast_dim_size == 0) {
      return emptyBlock();
    }

    // Prepare storage for the materialized broadcasting result.
    const typename TensorBlock::Storage block_storage =
        TensorBlock::prepareStorage(desc, scratch);
    ScalarNoConst* materialized_output = block_storage.data();

    // We potentially will need to materialize input blocks.
    size_t materialized_input_size = 0;
    ScalarNoConst* materialized_input = NULL;

    // Initialize the iterator state for the dimensions outside the broadcast
    // dimension (stored in inner-most to outer-most order).
    array<BlockBroadcastingIteratorState, NumDims> it;
    int idx = 0;

    for (int i = params.inner_dim_count + 1; i < NumDims; ++i) {
      const Index dim = IsColMajor ? i : NumDims - 1 - i;
      it[idx].size = params.output_dims[dim];
      it[idx].count = 0;
      it[idx].output_stride = m_outputStrides[dim];
      it[idx].output_span = it[idx].output_stride * (it[idx].size - 1);
      idx++;
    }

    // Write output into the beginning of `materialized_output`.
    Index output_offset = 0;

    // Fill the output block by broadcasting along the bcast dimension while
    // iterating over the outer dimensions.
    const Index output_size = NumDims == 0 ? 1 : params.output_dims.TotalSize();

    for (Index num_output_coeffs = 0; num_output_coeffs < output_size;) {
      ScalarNoConst* bcast_output = materialized_output + num_output_coeffs;
      Index bcast_offset = desc.offset() + output_offset;

      // Broadcast along the bcast dimension.
      num_output_coeffs += BroadcastBlockAlongBcastDim(
          params, bcast_offset, scratch, bcast_output, &materialized_input,
          &materialized_input_size);

      // Switch to the next outer dimension.
      for (int j = 0; j < idx; ++j) {
        if (++it[j].count < it[j].size) {
          output_offset += it[j].output_stride;
          break;
        }
        it[j].count = 0;
        output_offset -= it[j].output_span;
      }
    }

    return block_storage.AsTensorMaterializedBlock();
  }
  EIGEN_DEVICE_FUNC EvaluatorPointerType data() const { return NULL; }

  const TensorEvaluator<ArgType, Device>& impl() const { return m_impl; }

  Broadcast functor() const { return m_broadcast; }

#ifdef EIGEN_USE_SYCL
  // binds placeholder accessors to a command group handler for SYCL
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void bind(cl::sycl::handler& cgh) const {
    m_impl.bind(cgh);
  }
#endif

 private:
  static const bool IsColMajor =
      static_cast<int>(Layout) == static_cast<int>(ColMajor);
  struct BlockBroadcastingParams {
    Dimensions input_dims;      // input expression dimensions
    Dimensions output_dims;     // output block sizes
    Dimensions output_strides;  // output block strides
    int inner_dim_count;        // number of inner dimensions with matching sizes
    int bcast_dim;              // broadcasting dimension index
    Index bcast_dim_size;       // broadcasting dimension size
    Index inner_dim_size;       // total size of the inner dimensions
    // Input block sizes/strides: inner dimensions keep the input sizes, all
    // dimensions from `bcast_dim` outward are set to 1.
    Dimensions input_block_sizes;
    Dimensions input_block_strides;
    // Rank `2 * NumDims` sizes/strides used for the 0-stride broadcast copy.
    BroadcastDimensions bcast_block_sizes;
    BroadcastDimensions bcast_block_strides;
    BroadcastDimensions bcast_input_strides;
  };

  struct BlockBroadcastingIteratorState {
    Index size;
    Index count;
    Index output_stride;
    Index output_span;
  };
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE BlockBroadcastingParams
  blockBroadcastingParams(TensorBlockDesc& desc) const {
    BlockBroadcastingParams params;

    params.input_dims = Dimensions(m_impl.dimensions());

    // Output block sizes and strides.
    params.output_dims = desc.dimensions();
    params.output_strides = internal::strides<Layout>(params.output_dims);

    // Find the broadcasting dimension (first dimension with output size
    // smaller than the input size).
    params.bcast_dim = 0;
    params.bcast_dim_size = 1;
    params.inner_dim_size = 1;

    // Count the number of inner dimensions that have the same size in the
    // block and in the broadcast expression.
    params.inner_dim_count = 0;

    for (int i = 0; i < NumDims; ++i) {
      const int dim = IsColMajor ? i : NumDims - i - 1;

      if (params.output_dims[dim] == m_dimensions[dim]) {
        params.inner_dim_size *= params.output_dims[dim];
        ++params.inner_dim_count;
        continue;
      }

      // The first non-matching dimension is the broadcasting dimension.
      eigen_assert(params.output_dims[dim] < m_dimensions[dim]);
      params.bcast_dim = dim;
      params.bcast_dim_size = params.output_dims[dim];
      break;
    }

    // Calculate the input block sizes for looking into the input.
    for (int i = 0; i < params.inner_dim_count; ++i) {
      const int dim = IsColMajor ? i : NumDims - i - 1;
      params.input_block_sizes[dim] = params.input_dims[dim];
    }
    for (int i = params.inner_dim_count; i < NumDims; ++i) {
      const int dim = IsColMajor ? i : NumDims - i - 1;
      params.input_block_sizes[dim] = 1;
    }
    params.input_block_strides =
        internal::strides<Layout>(params.input_block_sizes);

    // Broadcast with the 0-stride trick: create one extra dimension for each
    // broadcast and set its input stride to 0.
    for (int i = 0; i < params.inner_dim_count; ++i) {
      const int dim = IsColMajor ? i : NumDims - i - 1;

      const int copy_dim = IsColMajor ? 2 * i : 2 * NumDims - 2 * i - 1;
      const int broadcast_dim = IsColMajor ? copy_dim + 1 : copy_dim - 1;

      params.bcast_block_sizes[copy_dim] = params.input_dims[dim];
      params.bcast_block_sizes[broadcast_dim] = m_broadcast[dim];
      params.bcast_block_strides[copy_dim] = params.output_strides[dim];
      params.bcast_block_strides[broadcast_dim] =
          params.output_strides[dim] * params.input_dims[dim];
      params.bcast_input_strides[copy_dim] = params.input_block_strides[dim];
      params.bcast_input_strides[broadcast_dim] = 0;
    }

    for (int i = 2 * params.inner_dim_count; i < 2 * NumDims; ++i) {
      const int dim = IsColMajor ? i : 2 * NumDims - i - 1;
      params.bcast_block_sizes[dim] = 1;
      params.bcast_block_strides[dim] = 0;
      params.bcast_input_strides[dim] = 0;
    }

    return params;
  }
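
  // Illustration of the 0-stride trick above (a sketch with made-up sizes,
  // assuming the block covers the whole output): for a ColMajor input of
  // dimensions [2, 3] broadcast by [3, 1], the (copy_dim, broadcast_dim) pairs
  // built above give
  //   bcast_block_sizes   = [2, 3, 3, 1]
  //   bcast_input_strides = [1, 0, 2, 0]
  // i.e. the two input values along dim 0 are re-read 3 times through an input
  // stride of 0, while the corresponding output strides keep advancing.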
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorBlock emptyBlock() const {
    DSizes<Index, NumDims> dimensions;
    for (int i = 0; i < NumDims; ++i) dimensions[i] = 0;
    return TensorBlock(internal::TensorBlockKind::kView, NULL, dimensions);
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index BroadcastBlockAlongBcastDim(
      BlockBroadcastingParams params, Index bcast_offset,
      TensorBlockScratch& scratch, ScalarNoConst* materialized_output,
      ScalarNoConst** materialized_input,
      size_t* materialized_input_size) const {
    if (params.bcast_dim_size == 1) {
      // We just need one block read using the ready-set values above.
      return BroadcastBlock(
          params.input_block_sizes, params.input_block_strides,
          params.bcast_block_sizes, params.bcast_block_strides,
          params.bcast_input_strides, bcast_offset, 0, scratch,
          materialized_output, materialized_input, materialized_input_size);

    } else if (params.input_dims[params.bcast_dim] == 1) {
      // Broadcast bcast dimension (< NumDims) by bcast_dim_size.
      const int broadcast_bcast_dim =
          IsColMajor ? 2 * params.inner_dim_count + 1
                     : 2 * NumDims - 2 * params.inner_dim_count - 2;

      params.bcast_block_sizes[broadcast_bcast_dim] = params.bcast_dim_size;
      params.bcast_input_strides[broadcast_bcast_dim] = 0;
      params.bcast_block_strides[broadcast_bcast_dim] =
          params.output_strides[params.bcast_dim];

      return BroadcastBlock(
          params.input_block_sizes, params.input_block_strides,
          params.bcast_block_sizes, params.bcast_block_strides,
          params.bcast_input_strides, bcast_offset, 0, scratch,
          materialized_output, materialized_input, materialized_input_size);

    } else {
      // Keep track of the total number of coefficients written to the output.
      Index num_output_coeffs = 0;

      // General case: the output slice a:a+bcast_dim_size along the bcast
      // dimension may have to be split into up to three sub-blocks: a head
      // a:b, a body b:c made of whole copies of the input dimension, and a
      // tail c:a+bcast_dim_size, where b and c are the smallest and largest
      // multiples of input_dims[bcast_dim] inside the slice.

      // Find a.
      const Index bcast_dim_left_index =
          bcast_offset / m_outputStrides[params.bcast_dim];

      // Input dim size for the broadcast dimension.
      const Index input_bcast_dim_size = params.input_dims[params.bcast_dim];

      // First multiple after a. This is b when it is <= bcast_dim_left_index +
      // bcast_dim_size.
      const Index first_multiple =
          divup<Index>(bcast_dim_left_index, input_bcast_dim_size) *
          input_bcast_dim_size;

      if (first_multiple <= bcast_dim_left_index + params.bcast_dim_size) {
        // b exists, so does c. Find it.
        const Index last_multiple =
            (bcast_dim_left_index + params.bcast_dim_size) /
            input_bcast_dim_size * input_bcast_dim_size;
        const int copy_bcast_dim =
            IsColMajor ? 2 * params.inner_dim_count
                       : 2 * NumDims - 2 * params.inner_dim_count - 1;
        const int broadcast_bcast_dim =
            IsColMajor ? 2 * params.inner_dim_count + 1
                       : 2 * NumDims - 2 * params.inner_dim_count - 2;

        if (first_multiple > bcast_dim_left_index) {
          const Index head_size = first_multiple - bcast_dim_left_index;
          params.input_block_sizes[params.bcast_dim] = head_size;
          params.bcast_block_sizes[copy_bcast_dim] = head_size;
          params.bcast_input_strides[copy_bcast_dim] =
              params.input_block_strides[params.bcast_dim];
          params.bcast_block_strides[copy_bcast_dim] =
              params.output_strides[params.bcast_dim];
          params.bcast_block_sizes[broadcast_bcast_dim] = 1;
          params.bcast_input_strides[broadcast_bcast_dim] = 0;
          params.bcast_block_strides[broadcast_bcast_dim] =
              params.output_strides[params.bcast_dim] *
              params.input_dims[params.bcast_dim];

          num_output_coeffs += BroadcastBlock(
              params.input_block_sizes, params.input_block_strides,
              params.bcast_block_sizes, params.bcast_block_strides,
              params.bcast_input_strides, bcast_offset, 0, scratch,
              materialized_output, materialized_input, materialized_input_size);
        }
        if (first_multiple < last_multiple) {
          params.input_block_sizes[params.bcast_dim] = input_bcast_dim_size;
          params.bcast_block_sizes[copy_bcast_dim] = input_bcast_dim_size;
          params.bcast_input_strides[copy_bcast_dim] =
              params.input_block_strides[params.bcast_dim];
          params.bcast_block_strides[copy_bcast_dim] =
              params.output_strides[params.bcast_dim];
          params.bcast_block_sizes[broadcast_bcast_dim] =
              (last_multiple - first_multiple) / input_bcast_dim_size;
          params.bcast_input_strides[broadcast_bcast_dim] = 0;
          params.bcast_block_strides[broadcast_bcast_dim] =
              params.output_strides[params.bcast_dim] *
              params.input_dims[params.bcast_dim];
          const Index offset = (first_multiple - bcast_dim_left_index) *
                               m_outputStrides[params.bcast_dim];

          num_output_coeffs += BroadcastBlock(
              params.input_block_sizes, params.input_block_strides,
              params.bcast_block_sizes, params.bcast_block_strides,
              params.bcast_input_strides, bcast_offset, offset, scratch,
              materialized_output, materialized_input, materialized_input_size);
        }
        if (last_multiple < bcast_dim_left_index + params.bcast_dim_size) {
          const Index tail_size =
              bcast_dim_left_index + params.bcast_dim_size - last_multiple;
          params.input_block_sizes[params.bcast_dim] = tail_size;
          params.bcast_block_sizes[copy_bcast_dim] = tail_size;
          params.bcast_input_strides[copy_bcast_dim] =
              params.input_block_strides[params.bcast_dim];
          params.bcast_block_strides[copy_bcast_dim] =
              params.output_strides[params.bcast_dim];
          params.bcast_block_sizes[broadcast_bcast_dim] = 1;
          params.bcast_input_strides[broadcast_bcast_dim] = 0;
          params.bcast_block_strides[broadcast_bcast_dim] =
              params.output_strides[params.bcast_dim] *
              params.input_dims[params.bcast_dim];
          const Index offset = (last_multiple - bcast_dim_left_index) *
                               m_outputStrides[params.bcast_dim];

          num_output_coeffs += BroadcastBlock(
              params.input_block_sizes, params.input_block_strides,
              params.bcast_block_sizes, params.bcast_block_strides,
              params.bcast_input_strides, bcast_offset, offset, scratch,
              materialized_output, materialized_input, materialized_input_size);
        }
      } else {
        // b and c do not exist: process the whole slice together.
        const int copy_bcast_dim =
            IsColMajor ? 2 * params.inner_dim_count
                       : 2 * NumDims - 2 * params.inner_dim_count - 1;
        params.input_block_sizes[params.bcast_dim] = params.bcast_dim_size;
        params.bcast_block_sizes[copy_bcast_dim] = params.bcast_dim_size;
        params.bcast_input_strides[copy_bcast_dim] =
            params.input_block_strides[params.bcast_dim];
        params.bcast_block_strides[copy_bcast_dim] =
            params.output_strides[params.bcast_dim];

        num_output_coeffs += BroadcastBlock(
            params.input_block_sizes, params.input_block_strides,
            params.bcast_block_sizes, params.bcast_block_strides,
            params.bcast_input_strides, bcast_offset, 0, scratch,
            materialized_output, materialized_input, materialized_input_size);
      }

      return num_output_coeffs;
    }
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index BroadcastBlock(
      const Dimensions& input_block_sizes,
      const Dimensions& input_block_strides,
      const BroadcastDimensions& bcast_block_sizes,
      const BroadcastDimensions& bcast_block_strides,
      const BroadcastDimensions& bcast_input_strides, Index bcast_offset,
      Index offset, TensorBlockScratch& scratch,
      ScalarNoConst* materialized_output, ScalarNoConst** materialized_input,
      size_t* materialized_input_size) const {
    // Tensor block descriptor for reading a block from the input.
    const Index input_offset = bcast_offset + offset;
    TensorBlockDesc input_desc(
        IsColMajor ? indexColMajor(input_offset) : indexRowMajor(input_offset),
        input_block_sizes);

    ArgTensorBlock input_block = m_impl.block(input_desc, scratch);

    // Materialize the input block into a temporary buffer only if its data is
    // not already available from the arg block.
    const ScalarNoConst* input_buffer = NULL;

    if (input_block.data() != NULL) {
      // Input block already has raw data, there is no need to materialize it.
      input_buffer = input_block.data();

    } else {
      // Otherwise do a block assignment into a temporary buffer, reusing a
      // previously allocated buffer when it is large enough.
      const size_t input_total_size = input_block_sizes.TotalSize();
      if (*materialized_input == NULL ||
          *materialized_input_size < input_total_size) {
        *materialized_input_size = input_total_size;
        void* mem = scratch.allocate(*materialized_input_size * sizeof(Scalar));
        *materialized_input = static_cast<ScalarNoConst*>(mem);
      }

      typedef internal::TensorBlockAssignment<
          ScalarNoConst, NumDims, typename ArgTensorBlock::XprType, Index>
          TensorBlockAssignment;

      TensorBlockAssignment::Run(
          TensorBlockAssignment::target(input_block_sizes, input_block_strides,
                                        *materialized_input),
          input_block.expr());

      input_buffer = *materialized_input;
    }

    // Copy data from the materialized input block to the materialized output,
    // using the given broadcast strides (strides with zeroes).
    typedef internal::TensorBlockIO<ScalarNoConst, Index, 2 * NumDims, Layout>
        TensorBlockIO;

    typename TensorBlockIO::Src src(bcast_input_strides, input_buffer);
    typename TensorBlockIO::Dst dst(bcast_block_sizes, bcast_block_strides,
                                    materialized_output + offset);

    return TensorBlockIO::Copy(dst, src);
  }
 protected:
  const Device EIGEN_DEVICE_REF m_device;
  const typename internal::remove_reference<Broadcast>::type m_broadcast;
  Dimensions m_dimensions;
  array<Index, NumDims> m_outputStrides;
  array<Index, NumDims> m_inputStrides;
  TensorEvaluator<ArgType, Device> m_impl;
};

} // end namespace Eigen

#endif // EIGEN_CXX11_TENSOR_TENSOR_BROADCASTING_H