TensorContractionSycl.h
1 // This file is part of Eigen, a lightweight C++ template library for linear algebra.
2 //
3 // Mehdi Goli Codeplay Software Ltd.
4 // Ralph Potter Codeplay Software Ltd.
5 // Luke Iwanski Codeplay Software Ltd.
6 // Contact: <eigen@codeplay.com>
7 //
8 // This Source Code Form is subject to the terms of the Mozilla Public License v. 2.0. If a copy of the MPL was not
9 // distributed with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
10 
11 /*****************************************************************
12  * TensorContractionSycl.h
13  *
14  * \brief:
15  * TensorContractionSycl.h provides various tensor contraction kernels for the SYCL backend
16  *
17  *****************************************************************/
18 
19 #ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_SYCL_H
20 #define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_SYCL_H
21 
22 namespace Eigen {
23 
24 namespace TensorSycl {
25 namespace internal {
26 
27 #ifndef EIGEN_SYCL_DISABLE_GEMV
28 
42 template <typename Scalar, typename StorageIndex, StorageIndex NCWindow, StorageIndex CFactor, StorageIndex NCFactor>
43 struct TVPanelSize {
44  // LocalThreadSizeC: determines the total number of threads per workgroup for the contracting dimension
45  static EIGEN_CONSTEXPR StorageIndex LocalThreadSizeC = EIGEN_SYCL_LOCAL_THREAD_DIM0;
46  // LocalThreadSizeNC: determines the total number of threads per workgroup for the non-contracting dimension
47  static EIGEN_CONSTEXPR StorageIndex LocalThreadSizeNC = EIGEN_SYCL_LOCAL_THREAD_DIM1;
48  // TileSizeDimNC: determines the tile size for the non-contracting dimension
49  static EIGEN_CONSTEXPR StorageIndex TileSizeDimNC = NCWindow / NCFactor;
50  // TileSizeDimC: determines the tile size for the contracting dimension
51  static EIGEN_CONSTEXPR StorageIndex TileSizeDimC = CFactor * LocalThreadSizeNC * LocalThreadSizeC;
52  // WorkLoadPerThreadNC : determines workload per thread for loading the non-contracting dimension
53  static EIGEN_CONSTEXPR StorageIndex WorkLoadPerThreadNC = TileSizeDimNC / LocalThreadSizeNC;
54  // WorkLoadPerThreadC: determines the workload per thread for loading the contracting dimension
55  static EIGEN_CONSTEXPR StorageIndex WorkLoadPerThreadC = TileSizeDimC / LocalThreadSizeC;
56  // BC: determines whether padding to avoid local-memory bank conflicts is required
57  static EIGEN_CONSTEXPR bool BC = false;
58 };
59 #endif
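// As an illustration only: assuming the default EIGEN_SYCL_LOCAL_THREAD_DIM0/1 of 16 and hypothetical template
// parameters (the real values are supplied by the GEMV launcher later in this file), the derived sizes would be:
//
//   typedef TVPanelSize<float, int, /*NCWindow=*/64, /*CFactor=*/4, /*NCFactor=*/4> Panel;
//   // Panel::TileSizeDimNC       == 64 / 4        == 16
//   // Panel::TileSizeDimC        == 4 * 16 * 16   == 1024
//   // Panel::WorkLoadPerThreadNC == 16 / 16       == 1
//   // Panel::WorkLoadPerThreadC  == 1024 / 16     == 64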
60 
78 template <typename Scalar, typename StorageIndex, StorageIndex REG_SIZE_M, StorageIndex REG_SIZE_N, StorageIndex TSDK>
79 struct TTPanelSize {
80  // TileSizeDimK: determines the tile size for dimension K. The value is assumed to already account for the packet size
81  static EIGEN_CONSTEXPR StorageIndex TileSizeDimK = TSDK;
82  // WorkLoadPerThreadM: determines the workload per thread for loading the M dimension. This can be varied based on
83  // the available registers of the chosen device (it can be controlled by the EIGEN_SYCL_REG_M macro)
84 #ifndef EIGEN_SYCL_REG_M
85  static EIGEN_CONSTEXPR StorageIndex WorkLoadPerThreadM = REG_SIZE_M;
86 #else
87  static EIGEN_CONSTEXPR StorageIndex WorkLoadPerThreadM = EIGEN_SYCL_REG_M;
88 #endif
89 // WorkLoadPerThreadN: determines the workload per thread for loading the N dimension. This can be varied based on
90 // the available registers of the chosen device (it can be controlled by the EIGEN_SYCL_REG_N macro)
91 #ifndef EIGEN_SYCL_REG_N
92  static EIGEN_CONSTEXPR StorageIndex WorkLoadPerThreadN = REG_SIZE_N;
93 #else
94  static EIGEN_CONSTEXPR StorageIndex WorkLoadPerThreadN = EIGEN_SYCL_REG_N;
95 #endif
96  // LocalThreadSizeM: determines the total number of threads per workgroup for the m dimension
97  static EIGEN_CONSTEXPR StorageIndex LocalThreadSizeM = EIGEN_SYCL_LOCAL_THREAD_DIM0;
98  // LocalThreadSizeN: determines the total number of threads per workgroup for the n dimension
99  static EIGEN_CONSTEXPR StorageIndex LocalThreadSizeN = EIGEN_SYCL_LOCAL_THREAD_DIM1;
100  // TileSizeDimM: determines the tile size for the m dimension
101  static EIGEN_CONSTEXPR StorageIndex TileSizeDimM = LocalThreadSizeM * WorkLoadPerThreadM;
102  // TileSizeDimN: determines the tile size for the n dimension
103  static EIGEN_CONSTEXPR StorageIndex TileSizeDimN = LocalThreadSizeN * WorkLoadPerThreadN;
104  // LoadPerThreadLhs: determines the workload per thread for loading the Lhs tensor. This must be divisible by the packet size
105  static EIGEN_CONSTEXPR StorageIndex LoadPerThreadLhs =
106  ((TileSizeDimK * WorkLoadPerThreadM * WorkLoadPerThreadN) / (TileSizeDimN));
107  // LoadPerThreadRhs: determines the workload per thread for loading the Rhs tensor. This must be divisible by the packet size
108  static EIGEN_CONSTEXPR StorageIndex LoadPerThreadRhs =
109  ((TileSizeDimK * WorkLoadPerThreadM * WorkLoadPerThreadN) / (TileSizeDimM));
110  // BC: determines whether padding to avoid local-memory bank conflicts is required
111  static EIGEN_CONSTEXPR bool BC = true;
112  // DoubleBuffer: determines whether the double-buffering technique should be used (this can be disabled via the
113  // EIGEN_SYCL_DISABLE_DOUBLE_BUFFER macro when the device does not have sufficient local memory)
114  static EIGEN_CONSTEXPR bool DoubleBuffer =
115 #ifdef EIGEN_SYCL_DISABLE_DOUBLE_BUFFER
116  false;
117 #else
118  true;
119 #endif
120 };
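// For example, the local-memory path of adjustTT below instantiates TTPanelSize<CoeffReturnType, StorageIndex, 4, 4, 16>.
// Assuming the default EIGEN_SYCL_LOCAL_THREAD_DIM0/1 of 16 and no EIGEN_SYCL_REG_M/N override, the derived sizes are
// (illustrative sketch):
//
//   typedef TTPanelSize<float, int, /*REG_SIZE_M=*/4, /*REG_SIZE_N=*/4, /*TSDK=*/16> Panel;
//   // Panel::TileSizeDimK     == 16
//   // Panel::TileSizeDimM     == 16 * 4 == 64,            Panel::TileSizeDimN     == 16 * 4 == 64
//   // Panel::LoadPerThreadLhs == (16 * 4 * 4) / 64 == 4,  Panel::LoadPerThreadRhs == 4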
121 
122 /* !
123  * \brief contraction_type: an enum class representing the Tensor Contraction implementation algorithm. This is used to
124  * specialize the contraction algorithm based on device support for dedicated local memory.
125  */
126 enum class contraction_type { local, no_local };
127 /* !
128  * \brief data_source an enum class determining the location of the data in a memory hierarchy (global, local, private).
129  */
130 enum class data_source { global_mem, local_mem, private_mem };
131 
157 template <bool PacketLoad, bool is_coalesced_layout, bool, typename PacketType, typename TensorMapper,
158  typename StorageIndex>
159 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE typename ::Eigen::internal::enable_if<PacketLoad, PacketType>::type read(
160  const TensorMapper &tensorMapper, const StorageIndex &NCIndex, const StorageIndex &CIndex, const StorageIndex &ld) {
161  const StorageIndex row = (is_coalesced_layout) ? NCIndex : CIndex;
162  const StorageIndex col = (is_coalesced_layout) ? CIndex : NCIndex;
163  return tensorMapper.get_tensor().template packet<Unaligned>(row + (col * ld));
164 }
165 
188 template <bool PacketLoad, bool, bool IsRhs, typename PacketType, typename TensorMapper, typename StorageIndex>
189 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE typename ::Eigen::internal::enable_if<!PacketLoad, PacketType>::type read(
190  const TensorMapper &tensorMapper, const StorageIndex &NCIndex, const StorageIndex &CIndex, const StorageIndex &) {
191  const StorageIndex row = (IsRhs) ? CIndex : NCIndex;
192  const StorageIndex col = (IsRhs) ? NCIndex : CIndex;
193  return tensorMapper(row, col);
194 }
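// Illustrative use of the two read overloads above (hypothetical indices m, k and leading dimension M): for a
// coalesced, vectorized LHS load the packet overload is selected and gathers PacketSize consecutive non-contracting
// elements, while the scalar overload simply forwards to the mapper:
//
//   // auto p = read<true, /*is_coalesced_layout=*/true, /*IsRhs=*/false, PacketType>(lhs, m, k, M);
//   //   -> packet starting at element (m, k), i.e. linear offset m + k * M
//   // auto s = read<false, true, /*IsRhs=*/false, OutScalar>(lhs, m, k, M);  // -> lhs(m, k), one element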
195 
217 template <typename StorageIndex, StorageIndex ld, data_source dt, typename PacketType, typename DataScalar>
218 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
219  typename ::Eigen::internal::enable_if<dt != data_source::global_mem, void>::type
220  write(PacketType &packet_data, DataScalar ptr) {
221  EIGEN_CONSTEXPR int PacketSize = Eigen::internal::unpacket_traits<PacketType>::size;
222  EIGEN_UNROLL_LOOP
223  for (int i = 0; i < PacketSize; i++) {
224  *ptr = PacketWrapper<PacketType, PacketSize>::scalarize(i, packet_data);
225  ptr += ld;
226  }
227 }
228 
244 template <data_source dt, typename PacketType, typename DataScalar>
245 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE typename ::Eigen::internal::enable_if<
246  Eigen::internal::unpacket_traits<PacketType>::size != 1 && dt == data_source::global_mem, void>::type
247 write(PacketType &packet_data, DataScalar *ptr) {
248  ::Eigen::internal::pstoreu<DataScalar, PacketType>(ptr, packet_data);
249 }
250 
264 template <data_source dt, typename PacketType, typename DataScalar>
265 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE typename ::Eigen::internal::enable_if<
266  Eigen::internal::unpacket_traits<PacketType>::size == 1 && dt == data_source::global_mem, void>::type
267 write(PacketType &packet_data, DataScalar *ptr) {
268  *ptr = packet_data;
269 }
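// Illustrative sketch of how the write overloads above are dispatched by the callers in this file: a strided,
// element-wise scatter into per-thread registers or local memory versus a contiguous store to global memory:
//
//   // write<StorageIndex, /*ld=*/WorkLoadPerThreadNC, data_source::private_mem>(val, private_ptr);  // scatter, stride ld
//   // write<data_source::global_mem>(result_packet, out_ptr + offset);  // pstoreu when PacketSize > 1, plain store otherwise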
270 
276 template <bool is_internal>
277 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool check_boundary(bool) {
278  return true;
279 }
280 
286 template <>
287 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool check_boundary<false>(bool cond) {
288  return cond;
289 }
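// In other words: for a fully interior tile check_boundary<true>(cond) folds to true, so the bound checks in the
// kernels below compile away; for an edge tile check_boundary<false>(cond) keeps the runtime test.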
290 
317 template <bool is_transposed, bool is_rhs_, bool packet_load_, typename PacketType>
318 struct BlockProperties {
319  static EIGEN_CONSTEXPR bool packet_load = packet_load_;
320  typedef typename Eigen::internal::unpacket_traits<PacketType>::type OutScalar;
321  static EIGEN_CONSTEXPR bool is_rhs = is_rhs_;
322  typedef typename Eigen::internal::conditional<packet_load, PacketType, OutScalar>::type OutType;
323  static EIGEN_CONSTEXPR int elements_per_access = Eigen::internal::unpacket_traits<OutType>::size;
324  static EIGEN_CONSTEXPR bool is_coalesced_layout = !(is_transposed ^ is_rhs);
325  static EIGEN_CONSTEXPR int nc_stride = (is_coalesced_layout ? elements_per_access : 1);
326  static EIGEN_CONSTEXPR int c_stride = (is_coalesced_layout ? 1 : elements_per_access);
327 };
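// For example (illustrative, assuming a packet of 4 elements): a non-transposed LHS instantiated as
// BlockProperties<false, false, true, PacketType> gives is_coalesced_layout == !(false ^ false) == true,
// elements_per_access == 4, nc_stride == 4 and c_stride == 1, so one access covers 4 non-contracting elements;
// a transposed LHS (is_transposed == true) flips this to nc_stride == 1 and c_stride == 4.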
328 
368 template <typename StorageIndex>
369 struct ThreadProperties {
370  const StorageIndex linearLocalThreadId;
371  const StorageIndex kGroupId;
372  const StorageIndex mGroupOffset;
373  const StorageIndex nGroupOffset;
374  const StorageIndex kGroupOffset;
375  const StorageIndex mLocalOffset;
376  const StorageIndex nLocalOffset;
377  const StorageIndex mGlobalOffset;
378  const StorageIndex nGlobalOffset;
379  StorageIndex kSize;
380  const bool is_internal;
381  // this is used to adjust the last block
382  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ThreadProperties(
383  const StorageIndex linearLocalThreadId_, const StorageIndex kGroupId_, const StorageIndex mGroupOffset_,
384  const StorageIndex nGroupOffset_, const StorageIndex kGroupOffset_, const StorageIndex mLocalOffset_,
385  const StorageIndex nLocalOffset_, const StorageIndex mGlobalOffset_, const StorageIndex nGlobalOffset_,
386  StorageIndex kSize_, const bool is_internal_)
387  : linearLocalThreadId(linearLocalThreadId_),
388  kGroupId(kGroupId_),
389  mGroupOffset(mGroupOffset_),
390  nGroupOffset(nGroupOffset_),
391  kGroupOffset(kGroupOffset_),
392  mLocalOffset(mLocalOffset_),
393  nLocalOffset(nLocalOffset_),
394  mGlobalOffset(mGlobalOffset_),
395  nGlobalOffset(nGlobalOffset_),
396  kSize(kSize_),
397  is_internal(is_internal_) {}
398 };
399 
450 template <typename OutScalar, typename LhsScalar, typename RhsScalar, typename OutAccessor, typename LhsMapper,
451  typename RhsMapper, typename StorageIndex, typename Properties, typename TripleDim, bool Vectorizable,
452  typename input_mapper_properties, bool IsFinal, contraction_type contraction_tp>
453 class TensorContractionKernel {
454  public:
455  typedef typename Eigen::TensorSycl::internal::Vectorise<OutScalar, Eigen::SyclDevice, Vectorizable>::PacketReturnType
456  PacketReturnType;
457  static EIGEN_CONSTEXPR int PacketSize =
458  Eigen::TensorSycl::internal::Vectorise<OutScalar, Eigen::SyclDevice, Vectorizable>::PacketSize;
459  static EIGEN_CONSTEXPR bool is_lhs_transposed =
460  !::Eigen::internal::TensorContractionInputMapperTrait<LhsMapper>::inner_dim_contiguous;
461  static EIGEN_CONSTEXPR bool is_rhs_transposed =
462  !::Eigen::internal::TensorContractionInputMapperTrait<RhsMapper>::inner_dim_contiguous;
463 
464  typedef BlockProperties<is_lhs_transposed, false, input_mapper_properties::is_lhs_matrix && Vectorizable,
465  PacketReturnType>
466  LHSBlockProperties;
467 
468  typedef BlockProperties<is_rhs_transposed, true, input_mapper_properties::is_rhs_matrix && Vectorizable,
469  PacketReturnType>
470  RHSBlockProperties;
471 
472  static EIGEN_CONSTEXPR StorageIndex NStride =
473  contraction_tp == contraction_type::local ? Properties::WorkLoadPerThreadN : RHSBlockProperties::nc_stride;
474 
475  typedef cl::sycl::accessor<OutScalar, 1, cl::sycl::access::mode::read_write, cl::sycl::access::target::local> Scratch;
476  typedef cl::sycl::multi_ptr<OutScalar, cl::sycl::access::address_space::local_space> local_ptr;
477  typedef OutScalar * /*cl::sycl::multi_ptr<OutScalar, cl::sycl::access::address_space::private_space>*/ private_ptr;
478  typedef
479  typename ::Eigen::internal::conditional<contraction_tp == contraction_type::local, local_ptr, private_ptr>::type
480  tile_ptr;
481  static EIGEN_CONSTEXPR StorageIndex LSDL = contraction_tp == contraction_type::local
482  ? Properties::TileSizeDimM + Properties::BC
483  : Properties::WorkLoadPerThreadM;
484  static EIGEN_CONSTEXPR StorageIndex LSDR = contraction_tp == contraction_type::local
485  ? Properties::TileSizeDimN + Properties::BC
486  : Properties::WorkLoadPerThreadN;
487  static EIGEN_CONSTEXPR StorageIndex LocalOffset = Properties::LocalThreadSizeM * Properties::LocalThreadSizeN;
488 
501  template <contraction_type, StorageIndex>
502  struct MemHolder {
503  tile_ptr ptr;
504  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE MemHolder(local_ptr block_start_ptr) : ptr(block_start_ptr) {}
505  };
509  template <StorageIndex MemSize>
510  struct MemHolder<contraction_type::no_local, MemSize> {
511  OutScalar ptr[MemSize] = {OutScalar{0}};
512  };
535  struct TiledMemory {
536  MemHolder<contraction_tp, Properties::WorkLoadPerThreadM * Properties::TileSizeDimK> lhs_scratch_extract;
537  MemHolder<contraction_tp, Properties::WorkLoadPerThreadN * Properties::TileSizeDimK> rhs_scratch_extract;
538  tile_ptr lhs_scratch_ptr_compute;
539  tile_ptr rhs_scratch_ptr_compute;
540  const std::pair<StorageIndex, StorageIndex> lhs_extract_index;
541  const std::pair<StorageIndex, StorageIndex> rhs_extract_index;
542  template <contraction_type tp = contraction_tp>
543  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
544  TiledMemory(const ThreadProperties<StorageIndex> &, local_ptr,
545  typename ::Eigen::internal::enable_if<tp == contraction_type::no_local>::type * = 0)
546  : lhs_scratch_extract{},
547  rhs_scratch_extract{},
548  lhs_scratch_ptr_compute(lhs_scratch_extract.ptr),
549  rhs_scratch_ptr_compute(rhs_scratch_extract.ptr),
550  lhs_extract_index(std::pair<StorageIndex, StorageIndex>(StorageIndex{0}, StorageIndex{0})),
551  rhs_extract_index(std::pair<StorageIndex, StorageIndex>(StorageIndex{0}, StorageIndex{0})) {}
552 
553  template <contraction_type tp = contraction_tp>
554  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
555  TiledMemory(const ThreadProperties<StorageIndex> &thread_properties, local_ptr block_start_ptr,
556  typename ::Eigen::internal::enable_if<tp == contraction_type::local>::type * = 0)
557  : lhs_scratch_extract{block_start_ptr},
558  rhs_scratch_extract{lhs_scratch_extract.ptr +
559  ((Properties::DoubleBuffer + 1) * LSDL * Properties::TileSizeDimK)},
560  lhs_scratch_ptr_compute(lhs_scratch_extract.ptr + thread_properties.mLocalOffset),
561  rhs_scratch_ptr_compute(rhs_scratch_extract.ptr + thread_properties.nLocalOffset),
562  lhs_extract_index(
563  local_id_extract<LHSBlockProperties, Properties::TileSizeDimM>(thread_properties.linearLocalThreadId)),
564  rhs_extract_index(
565  local_id_extract<RHSBlockProperties, Properties::TileSizeDimN>(thread_properties.linearLocalThreadId)) {}
566  };
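// Illustrative local-memory budget for the TTPanelSize<_, _, 4, 4, 16> configuration (assuming the default 16x16
// work-group and double buffering enabled): LSDL == LSDR == 64 + 1 == 65, so
//
//   // lhs scratch: (DoubleBuffer + 1) * LSDL * TileSizeDimK == 2 * 65 * 16 == 2080 OutScalars
//   // rhs scratch: (DoubleBuffer + 1) * TileSizeDimK * LSDR == 2 * 16 * 65 == 2080 OutScalars
//
// which matches the scratchSize requested by launchTT later in this file (4160 OutScalars in total).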
567 
568  Scratch scratch;
569  const LhsMapper lhs;
570  const RhsMapper rhs;
571  OutAccessor out_res;
572  const StorageIndex groupSizeM;
573  const StorageIndex groupSizeN;
574  const StorageIndex numTiles;
575  const TripleDim triple_dim;
576 
577  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionKernel(Scratch scratch_, const LhsMapper lhs_,
578  const RhsMapper rhs_, OutAccessor out_res_,
579  const StorageIndex groupSizeM_,
580  const StorageIndex groupSizeN_,
581  const StorageIndex numTiles_,
582  const TripleDim triple_dim_)
583  : scratch(scratch_),
584  lhs(lhs_),
585  rhs(rhs_),
586  out_res(out_res_),
587  groupSizeM(groupSizeM_),
588  groupSizeN(groupSizeN_),
589  numTiles(numTiles_),
590  triple_dim(triple_dim_) {}
591 
592  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionKernel(Scratch scratch_, const LhsMapper lhs_,
593  const RhsMapper rhs_, OutAccessor out_res_,
594  const StorageIndex groupSizeM_,
595  const StorageIndex numTiles_,
596  const TripleDim triple_dim_)
597  : TensorContractionKernel(scratch_, lhs_, rhs_, out_res_, groupSizeM_, 1, numTiles_, triple_dim_) {}
598 
599  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(cl::sycl::nd_item<1> itemID) {
600  const StorageIndex linearLocalThreadId = itemID.get_local_id(0);
601  const StorageIndex nLocalThreadId = linearLocalThreadId / Properties::LocalThreadSizeM;
602  const StorageIndex mLocalThreadId = linearLocalThreadId % Properties::LocalThreadSizeM;
603  const StorageIndex mGroupId = itemID.get_group(0) % groupSizeM;
604  const StorageIndex tmp = itemID.get_group(0) / groupSizeM;
605  const StorageIndex nGroupId = IsFinal ? tmp : tmp % groupSizeN;
606  const StorageIndex kGroupId = IsFinal ? 0 : tmp / groupSizeN;
607  const StorageIndex mGroupOffset = mGroupId * Properties::TileSizeDimM;
608  const StorageIndex nGroupOffset = nGroupId * Properties::TileSizeDimN;
609  const StorageIndex mLocalOffset = PacketSize * mLocalThreadId;
610  const StorageIndex nLocalOffset = NStride * nLocalThreadId;
611  const StorageIndex mGlobalOffset = mGroupOffset + mLocalOffset;
612  const StorageIndex nGlobalOffset = nGroupOffset + nLocalOffset;
613 
614  const StorageIndex kSizePerWG = IsFinal ? triple_dim.K : numTiles * Properties::TileSizeDimK;
615  StorageIndex kGroupOffset = kGroupId * kSizePerWG;
616  const bool is_internal = triple_dim.M - mGroupOffset >= Properties::TileSizeDimM &&
617  triple_dim.N - nGroupOffset >= Properties::TileSizeDimN &&
618  triple_dim.K - kGroupOffset >= kSizePerWG;
619  // this is used to adjust the last block
620  StorageIndex kSize = IsFinal ? triple_dim.K : std::min(kSizePerWG, triple_dim.K - kGroupOffset);
621  // This is used to find the last K offset so that kGroupOffset - kSize can compute the cOffset for loading the
622  // tile
623  kGroupOffset += kSize;
624 
625  auto thread_properties =
626  ThreadProperties<StorageIndex>(linearLocalThreadId, kGroupId, mGroupOffset, nGroupOffset, kGroupOffset,
627  mLocalOffset, nLocalOffset, mGlobalOffset, nGlobalOffset, kSize, is_internal);
628 
629  auto out_ptr = out_res.get_pointer() + (IsFinal ? 0 : thread_properties.kGroupId * triple_dim.M * triple_dim.N);
630 
631  (thread_properties.is_internal) ? compute_panel<true>(itemID, thread_properties, out_ptr)
632  : compute_panel<false>(itemID, thread_properties, out_ptr);
633  }
634  // The compute block computes the contraction operation for each thread's private block and stores the result in
635  // that thread's privateRes memory. The compute block function is independent of the local and no-local concepts as
636  // it only computes the block in each thread's private memory space
637  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void compute_block_per_tile(OutScalar *lhs_block_ptr, OutScalar *rhs_block_ptr,
638  PacketReturnType *privateRes) {
639  StorageIndex idx = 0;
640  EIGEN_CONSTEXPR StorageIndex lhs_stride =
641  contraction_tp == contraction_type::local ? (PacketSize * Properties::LocalThreadSizeM) : 1;
642  EIGEN_UNROLL_LOOP
643  for (StorageIndex wLPTN = 0; wLPTN < Properties::WorkLoadPerThreadN; wLPTN++) {
644  auto rhsPacket = PacketReturnType{*(rhs_block_ptr + wLPTN)};
645  StorageIndex lhs_index = 0;
646  EIGEN_UNROLL_LOOP
647  for (StorageIndex wLPTM = 0; wLPTM < Properties::WorkLoadPerThreadM / PacketSize; wLPTM++) {
648  PacketReturnType lhsPack{};
649  Eigen::TensorSycl::internal::PacketWrapper<PacketReturnType, PacketSize>::set_packet(lhsPack,
650  lhs_block_ptr + lhs_index);
651  privateRes[idx] = ::Eigen::internal::pmadd(lhsPack, rhsPacket, privateRes[idx]);
652 
653  lhs_index += lhs_stride;
654  idx++;
655  }
656  }
657  }
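  // As a concrete sketch (assuming WorkLoadPerThreadM == WorkLoadPerThreadN == 4 and PacketSize == 4): privateRes
  // holds 4 * 4 / 4 == 4 packets (16 scalars) per thread, and each call to compute_block_per_tile issues 4 pmadd
  // operations, one per (wLPTN, wLPTM) pair, for the current k slice.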
658  // The store function writes the contraction result computed in each thread's private memory out to global
659  // memory. The store function is independent of the local and no-local concepts so that it can be abstracted out in
660  // the base class.
661  template <bool is_internal_block, StorageIndex PrivateNStride, typename OutPtr>
662  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void store(OutPtr *out_ptr, PacketReturnType *privateRes,
663  StorageIndex mGlobalOffset, StorageIndex nGlobalOffset) {
664  auto chk_bound = [&](const StorageIndex &mIndex, const StorageIndex &nIndex) EIGEN_DEVICE_FUNC {
665  return (mIndex + PacketSize - 1 < triple_dim.M && nGlobalOffset + nIndex < triple_dim.N);
666  };
667  // when local memory is not used M and N are both accessed in a coalesced way. However, when local memory is
668  // available the K*N tile is transposed in local memory to N*K; therefore, each block operates on a blockId *
669  // WorkLoadPerThreadN slice of N
670  EIGEN_CONSTEXPR StorageIndex GlobalNStride =
671  contraction_tp == contraction_type::local ? 1 : Properties::LocalThreadSizeN;
672  EIGEN_UNROLL_LOOP
673  for (StorageIndex wLPTN = 0; wLPTN < Properties::WorkLoadPerThreadN / PrivateNStride; wLPTN++) {
674  // output leading dimension
675  StorageIndex outputLD = 0;
676  // When local memory is used the PrivateNStride is always 1 because the coalesced access on N is loaded into local
677  // memory and extracting from local to global is the same as in the non-transposed version. However, when local
678  // memory is not used and RHS is transposed we packetize the load for RHS.
679  EIGEN_UNROLL_LOOP
680  for (StorageIndex nId = 0; nId < PrivateNStride; nId++) {
681  StorageIndex globalRow = mGlobalOffset;
682  EIGEN_UNROLL_LOOP
683  for (StorageIndex wLPTM = 0; wLPTM < Properties::WorkLoadPerThreadM / PacketSize; wLPTM++) {
684  PacketReturnType privetOut = privateRes[wLPTM];
685  if (check_boundary<is_internal_block>(chk_bound(globalRow, nId))) {
686  // Store the final results in C. The C matrix always has M as the first StorageIndex and N as the second
687  // StorageIndex; therefore it always has a coalesced layout
688  write<data_source::global_mem>(privetOut, out_ptr + outputLD + globalRow);
689  } else {
690  EIGEN_UNROLL_LOOP
691  for (StorageIndex mId = 0; mId < PacketSize; mId++) {
692  StorageIndex mOffset = globalRow + mId;
693  if (mOffset < triple_dim.M && (nGlobalOffset + nId < triple_dim.N)) {
694  out_ptr[mOffset + outputLD] =
695  Eigen::TensorSycl::internal::PacketWrapper<PacketReturnType, PacketSize>::scalarize(mId, privetOut);
696  }
697  }
698  }
699  globalRow += (PacketSize * Properties::LocalThreadSizeM);
700  }
701  outputLD += triple_dim.M;
702  privateRes += Properties::WorkLoadPerThreadM / PacketSize;
703  }
704  out_ptr += (GlobalNStride * outputLD);
705 
706  nGlobalOffset += (PrivateNStride * GlobalNStride);
707  }
708  }
709  // when no local memory is used the following extract_block will be enabled
710  template <typename InputBlockProperties, bool is_internal_block, typename Input, typename PrivateReg,
711  contraction_type contract_tp = contraction_tp>
712  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
713  typename ::Eigen::internal::enable_if<contract_tp == contraction_type::no_local>::type
714  extract_block(const Input &inpt, PrivateReg private_ptr, const std::pair<StorageIndex, StorageIndex> &,
715  const StorageIndex &ncOffset, const StorageIndex cOffset) {
716  EIGEN_CONSTEXPR StorageIndex LocalThreadSizeNC =
717  InputBlockProperties::is_rhs ? Properties::LocalThreadSizeN : Properties::LocalThreadSizeM;
718  EIGEN_CONSTEXPR StorageIndex WorkLoadPerThreadNC =
719  InputBlockProperties::is_rhs ? Properties::WorkLoadPerThreadN : Properties::WorkLoadPerThreadM;
720  const StorageIndex &NC = InputBlockProperties::is_rhs ? triple_dim.N : triple_dim.M;
721 
722  auto chk_bound = [&](const StorageIndex &CIndex, const StorageIndex &NCIndex) EIGEN_DEVICE_FUNC {
723  return ((CIndex + InputBlockProperties::c_stride - 1 < triple_dim.K) &&
724  (NCIndex + InputBlockProperties::nc_stride - 1 < NC));
725  };
726  const StorageIndex ld = InputBlockProperties::is_coalesced_layout ? NC : triple_dim.K;
727  StorageIndex cIndex = cOffset;
728 
729  EIGEN_UNROLL_LOOP
730  for (StorageIndex cId = 0; cId < Properties::TileSizeDimK / InputBlockProperties::c_stride; cId++) {
731  StorageIndex ncIndex = ncOffset;
732  EIGEN_UNROLL_LOOP
733  for (StorageIndex ncId = 0; ncId < WorkLoadPerThreadNC / InputBlockProperties::nc_stride; ncId++) {
734  if (check_boundary<is_internal_block>(chk_bound(cIndex, ncIndex))) {
735  auto val =
736  read<InputBlockProperties::packet_load, InputBlockProperties::is_coalesced_layout,
737  InputBlockProperties::is_rhs, typename InputBlockProperties::OutType>(inpt, ncIndex, cIndex, ld);
738 
739  write<StorageIndex, (InputBlockProperties::is_coalesced_layout ? 1 : WorkLoadPerThreadNC),
740  data_source::private_mem>(val, private_ptr);
741  } else {
742  EIGEN_UNROLL_LOOP
743  for (StorageIndex i = 0; i < InputBlockProperties::elements_per_access; i++) {
744  const StorageIndex ncInd = ncIndex + (InputBlockProperties::is_coalesced_layout ? i : 0);
745  const StorageIndex cInd = cIndex + (InputBlockProperties::is_coalesced_layout ? 0 : i);
746  OutScalar val =
747  (ncInd < NC && cInd < triple_dim.K)
748  ? read<false, InputBlockProperties::is_coalesced_layout, InputBlockProperties::is_rhs, OutScalar>(
749  inpt, ncInd, cInd, ld)
750  : OutScalar(0);
751  write<StorageIndex, (InputBlockProperties::is_coalesced_layout ? 1 : WorkLoadPerThreadNC),
752  data_source::private_mem>(
753  val, private_ptr + (InputBlockProperties::is_coalesced_layout ? i : 0) +
754  ((InputBlockProperties::is_coalesced_layout ? 0 : i) * WorkLoadPerThreadNC));
755  }
756  }
757 
758  // if it is lhs we have to load it packetised when the packet size is > 1, because the output is coalesced. So
759  // even if M is not accessed in a coalesced mode, we have to load packet_size elements of m per thread.
760  ncIndex = (!InputBlockProperties::is_rhs && InputBlockProperties::nc_stride == 1 && PacketSize != 1)
761  ? ncOffset + (ncId + 1) % PacketSize + ((ncId + 1) / PacketSize) * LocalThreadSizeNC
762  : (ncIndex + InputBlockProperties::nc_stride * LocalThreadSizeNC);
763  private_ptr += InputBlockProperties::nc_stride;
764  }
765  // the previous for loop (private_ptr += nc_stride per iteration) has already advanced the pointer by one WorkLoadPerThreadNC
766  private_ptr += (InputBlockProperties::c_stride - 1) * WorkLoadPerThreadNC;
767  cIndex += InputBlockProperties::c_stride;
768  }
769  }
770  template <typename InputBlockProperties, StorageIndex TileSizeDimNC>
771  static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::pair<StorageIndex, StorageIndex> local_id_extract(
772  const StorageIndex &linearLocalThreadId) {
773  const StorageIndex localThreadNC =
774  (InputBlockProperties::is_coalesced_layout)
775  ? linearLocalThreadId % (TileSizeDimNC / InputBlockProperties::nc_stride)
776  : linearLocalThreadId / (Properties::TileSizeDimK / InputBlockProperties::c_stride);
777  const StorageIndex localThreadC =
778  (InputBlockProperties::is_coalesced_layout)
779  ? linearLocalThreadId / (TileSizeDimNC / InputBlockProperties::nc_stride)
780  : linearLocalThreadId % (Properties::TileSizeDimK / InputBlockProperties::c_stride);
781  return std::pair<StorageIndex, StorageIndex>(localThreadNC, localThreadC);
782  }
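  // Example (illustrative values): for a coalesced LHS block with TileSizeDimNC == 64 and nc_stride == 4, the divisor
  // is 64 / 4 == 16, so linearLocalThreadId == 37 yields localThreadNC == 37 % 16 == 5 and localThreadC == 37 / 16 == 2;
  // that thread starts its loads at (ncOffset + 4 * 5, cOffset + 2).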
783 
784  template <bool db = Properties::DoubleBuffer, contraction_type ctp = contraction_tp>
785  static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
786  typename ::Eigen::internal::enable_if<db && ctp == contraction_type::local>::type
787  sync_mem(const cl::sycl::nd_item<1> &, bool &db_offset) noexcept {
788  db_offset = !db_offset;
789  }
790 
791  template <bool db = Properties::DoubleBuffer, contraction_type ctp = contraction_tp>
792  static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
793  typename ::Eigen::internal::enable_if<!db && ctp == contraction_type::local>::type
794  sync_mem(const cl::sycl::nd_item<1> &itemID, bool &) noexcept {
795  itemID.barrier(cl::sycl::access::fence_space::local_space);
796  }
797 
798  template <contraction_type ctp = contraction_tp>
799  static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
800  typename ::Eigen::internal::enable_if<ctp == contraction_type::no_local>::type
801  sync_mem(const cl::sycl::nd_item<1> &, bool &) noexcept {
802  return;
803  }
804 
805  template <bool need_sync, contraction_type ctp = contraction_tp>
806  static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
807  typename ::Eigen::internal::enable_if<need_sync && ctp == contraction_type::no_local>::type
808  sync_thread(const cl::sycl::nd_item<1> &
809 #ifdef EIGEN_SYCL_ARM_GPU_CACHE_OPTIMISATION
810  itemID
811 #endif
812  ) noexcept {
813 #ifdef EIGEN_SYCL_ARM_GPU_CACHE_OPTIMISATION
814  itemID.barrier(cl::sycl::access::fence_space::local_space);
815 #else
816  return;
817 #endif
818  }
819  template <bool need_sync, contraction_type ctp = contraction_tp>
820  static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
821  typename ::Eigen::internal::enable_if<need_sync && ctp == contraction_type::local>::type
822  sync_thread(const cl::sycl::nd_item<1> &itemID) {
823  itemID.barrier(cl::sycl::access::fence_space::local_space);
824  }
825  template <bool need_sync>
826  static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE typename ::Eigen::internal::enable_if<!need_sync>::type sync_thread(
827  const cl::sycl::nd_item<1> &) {
828  return;
829  }
830 
831  template <bool is_internal_block>
832  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void compute_tile_per_panel(const cl::sycl::nd_item<1> &itemID,
833  ThreadProperties<StorageIndex> &thread_properties,
834  TiledMemory &tiled_input_block,
835  PacketReturnType *privateRes, bool &db_offset) {
836  // Tiling the Rhs block from global to local memory
837  extract_block<RHSBlockProperties, is_internal_block>(
838  rhs, tiled_input_block.rhs_scratch_extract.ptr + (db_offset * Properties::TileSizeDimK * LSDR),
839  tiled_input_block.rhs_extract_index,
840  contraction_tp == contraction_type::local ? thread_properties.nGroupOffset : thread_properties.nGlobalOffset,
841  thread_properties.kGroupOffset - thread_properties.kSize);
842 
843  sync_thread<contraction_tp == contraction_type::no_local>(itemID);
844 
845  // Tiling the Lhs block from global to local memory
846  extract_block<LHSBlockProperties, is_internal_block>(
847  lhs, tiled_input_block.lhs_scratch_extract.ptr + (db_offset * LSDL * Properties::TileSizeDimK),
848  tiled_input_block.lhs_extract_index,
849  contraction_tp == contraction_type::local ? thread_properties.mGroupOffset : thread_properties.mGlobalOffset,
850  thread_properties.kGroupOffset - thread_properties.kSize);
851 
852  // itemID.barrier(cl::sycl::access::fence_space::local_space);
853  sync_thread<contraction_tp == contraction_type::local>(itemID);
854  // switch to compute mode
855  StorageIndex lhs_offset = (db_offset * LSDL * Properties::TileSizeDimK);
856  StorageIndex rhs_offset = (db_offset * Properties::TileSizeDimK * LSDR);
857  // Loop over the values of a single tile
858  for (StorageIndex k = 0; k < Properties::TileSizeDimK; k++) {
859  compute_block_per_tile(tiled_input_block.lhs_scratch_ptr_compute + lhs_offset,
860  tiled_input_block.rhs_scratch_ptr_compute + rhs_offset, privateRes);
861  lhs_offset += LSDL;
862  rhs_offset += LSDR;
863  }
864  // computing the K index for the next tile
865  thread_properties.kSize -= Properties::TileSizeDimK;
866  sync_mem(itemID, db_offset);
867  }
868 
869  // when local memory is available the following compute_panel will be enabled
870  template <bool is_internal_block, typename OutPtr>
871  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void compute_panel(const cl::sycl::nd_item<1> &itemID,
872  ThreadProperties<StorageIndex> &thread_properties,
873  OutPtr out_ptr) {
874  auto tiled_input_block = TiledMemory{thread_properties, scratch.get_pointer()};
875  // Allocate register space
876  PacketReturnType privateRes[Properties::WorkLoadPerThreadM * Properties::WorkLoadPerThreadN / PacketSize] = {
877  PacketReturnType{0}};
878  bool db_offset = 0;
879 
880  while (thread_properties.kSize >= Properties::TileSizeDimK) {
881  compute_tile_per_panel<is_internal_block>(itemID, thread_properties, tiled_input_block, privateRes, db_offset);
882  }
883  if (thread_properties.kSize > 0) {
884  compute_tile_per_panel<false>(itemID, thread_properties, tiled_input_block, privateRes, db_offset);
885  }
886 
887  // Storing the final results in the output
888  store<is_internal_block,
889  contraction_tp == contraction_type::local ? static_cast<StorageIndex>(1) : RHSBlockProperties::nc_stride>(
890  out_ptr + thread_properties.nGlobalOffset * triple_dim.M, privateRes, thread_properties.mGlobalOffset,
891  thread_properties.nGlobalOffset);
892  }
893  // When local memory is available the following extract_block will be enabled
894  template <typename InputBlockProperties, bool is_internal_block, typename Input, typename Local,
895  contraction_type contract_tp = contraction_tp>
896  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
897  typename ::Eigen::internal::enable_if<contract_tp == contraction_type::local>::type
898  extract_block(const Input &inpt, Local local_ptr, const std::pair<StorageIndex, StorageIndex>& local_index,
899  const StorageIndex &ncOffset, const StorageIndex cOffset) {
900  EIGEN_CONSTEXPR StorageIndex TileSizeDimNC =
901  InputBlockProperties::is_rhs ? Properties::TileSizeDimN : Properties::TileSizeDimM;
902  EIGEN_CONSTEXPR StorageIndex LoadPerThread =
903  InputBlockProperties::is_rhs ? Properties::LoadPerThreadRhs : Properties::LoadPerThreadLhs;
904  EIGEN_CONSTEXPR StorageIndex LSD = InputBlockProperties::is_rhs ? LSDR : LSDL;
905  static_assert(((LocalOffset % (TileSizeDimNC / InputBlockProperties::nc_stride) == 0) &&
906  (LocalOffset % (Properties::TileSizeDimK / InputBlockProperties::c_stride) == 0)),
907  " LocalOffset must be divisible by stride");
908  const StorageIndex &NC = InputBlockProperties::is_rhs ? triple_dim.N : triple_dim.M;
909  StorageIndex localThreadNC = local_index.first;
910  StorageIndex localThreadC = local_index.second;
911  auto chk_bound = [&](const StorageIndex &CIndex, const StorageIndex &NCIndex) EIGEN_DEVICE_FUNC {
912  return ((CIndex + InputBlockProperties::c_stride - 1 < triple_dim.K) &&
913  (NCIndex + InputBlockProperties::nc_stride - 1 < NC));
914  };
915  EIGEN_UNROLL_LOOP
916  for (StorageIndex lPT = 0; lPT < LoadPerThread / InputBlockProperties::elements_per_access; lPT++) {
917  const StorageIndex CIndex = cOffset + (InputBlockProperties::c_stride * localThreadC);
918  const StorageIndex NCIndex = ncOffset + (InputBlockProperties::nc_stride * localThreadNC);
919  const StorageIndex ld = InputBlockProperties::is_coalesced_layout ? NC : triple_dim.K;
920  if (check_boundary<is_internal_block>(chk_bound(CIndex, NCIndex))) {
921  auto val =
922  read<InputBlockProperties::packet_load, InputBlockProperties::is_coalesced_layout,
923  InputBlockProperties::is_rhs, typename InputBlockProperties::OutType>(inpt, NCIndex, CIndex, ld);
924  write<StorageIndex, (InputBlockProperties::is_coalesced_layout ? 1 : LSD), data_source::local_mem>(
925  val, local_ptr + (InputBlockProperties::nc_stride * localThreadNC) +
926  (InputBlockProperties::c_stride * localThreadC * LSD));
927  } else {
928  EIGEN_UNROLL_LOOP
929  for (StorageIndex i = 0; i < InputBlockProperties::elements_per_access; i++) {
930  const StorageIndex nCInd = NCIndex + (InputBlockProperties::is_coalesced_layout ? i : 0);
931  const StorageIndex cInd = CIndex + (InputBlockProperties::is_coalesced_layout ? 0 : i);
932  OutScalar val =
933  (nCInd < NC && cInd < triple_dim.K)
934  ? read<false, InputBlockProperties::is_coalesced_layout, InputBlockProperties::is_rhs, OutScalar>(
935  inpt, nCInd, cInd, ld)
936  : OutScalar(0);
937 
938  write<StorageIndex, (InputBlockProperties::is_coalesced_layout ? 1 : LSD), data_source::local_mem>(
939  val, local_ptr + (InputBlockProperties::nc_stride * localThreadNC) +
940  (InputBlockProperties::is_coalesced_layout ? i : 0) +
941  ((InputBlockProperties::c_stride * localThreadC +
942  (InputBlockProperties::is_coalesced_layout ? 0 : i)) *
943  LSD));
944  }
945  }
946  localThreadNC += (InputBlockProperties::is_coalesced_layout)
947  ? LocalOffset % (TileSizeDimNC / InputBlockProperties::nc_stride)
948  : LocalOffset / (Properties::TileSizeDimK / InputBlockProperties::c_stride);
949  localThreadC += (InputBlockProperties::is_coalesced_layout)
950  ? LocalOffset / (TileSizeDimNC / InputBlockProperties::nc_stride)
951  : LocalOffset % (Properties::TileSizeDimK / InputBlockProperties::c_stride);
952  }
953  }
954 };
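// A worked example of the work decomposition performed by launchTT below for this kernel (illustrative sizes,
// assuming 64x64x16 tiles, i.e. TileSizeDimM x TileSizeDimN x TileSizeDimK, and a non-skinny M == N == K == 512
// contraction):
//
//   // groupSizeM == groupSizeN == 512 / 64 == 8,  groupSizeK == 1,  numTiles == 512 / 16 == 32
//   // localRange == 16 * 16 == 256 work-items,    globalRange == 8 * 8 * 1 * 256 == 16384 work-items
//
// Every work-group then runs 32 iterations of compute_tile_per_panel, and is_internal is true for all of them since
// the tile sizes divide the problem exactly.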
955 
956 #ifndef EIGEN_SYCL_DISABLE_GEMV
957 
999 template <typename OutScalar, typename OutAccessor, typename VectorMapper, typename TensorMapper, typename StorageIndex,
1000  typename Properties, StorageIndex KFactor, bool Vectorizable, bool is_lhs_vec, bool IsFinal>
1001 struct GeneralVectorTensor {
1002  typedef typename Eigen::TensorSycl::internal::Vectorise<OutScalar, Eigen::SyclDevice, Vectorizable>::PacketReturnType
1003  PacketReturnType;
1004  static EIGEN_CONSTEXPR int PacketSize =
1005  Eigen::TensorSycl::internal::Vectorise<OutScalar, Eigen::SyclDevice, Vectorizable>::PacketSize;
1006  typedef cl::sycl::accessor<OutScalar, 1, cl::sycl::access::mode::read_write, cl::sycl::access::target::local> Scratch;
1007 
1008  static EIGEN_CONSTEXPR StorageIndex OutScratchOffset =
1009  KFactor * Properties::LocalThreadSizeC * Properties::LocalThreadSizeNC;
1010 
1011  // Since the access layout for a vector can always be coalesced, when LHS is a vector we pass false and false to
1012  // make sure that the !^ is true. When RHS is a vector, we pass true and true to make sure that the !^ is true.
1013  typedef BlockProperties<is_lhs_vec ? false : true, is_lhs_vec ? false : true, Vectorizable, PacketReturnType>
1014  VecBlockProperties;
1015 
1016  Scratch scratch;
1017  const VectorMapper vec;
1018  const TensorMapper mat;
1019  OutAccessor out_res;
1020  const StorageIndex nonContractGroupSize;
1021  const StorageIndex nonContractDim;
1022  const StorageIndex contractDim;
1023 
1024  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE GeneralVectorTensor(Scratch scratch_, const VectorMapper vec_,
1025  const TensorMapper mat_, OutAccessor out_res_,
1026  const StorageIndex nonContractGroupSize_,
1027  const StorageIndex nonContractDim_,
1028  const StorageIndex contractDim_)
1029  : scratch(scratch_),
1030  vec(vec_),
1031  mat(mat_),
1032  out_res(out_res_),
1033  nonContractGroupSize(nonContractGroupSize_),
1034  nonContractDim(nonContractDim_),
1035  contractDim(contractDim_) {}
1036 
1037  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(cl::sycl::nd_item<1> itemID) {
1038  auto scratch_ptr = scratch.get_pointer();
1039  const StorageIndex linearLocalThreadId = itemID.get_local_id(0);
1040  StorageIndex nonContractId = is_lhs_vec ? linearLocalThreadId / Properties::LocalThreadSizeC
1041  : linearLocalThreadId % Properties::LocalThreadSizeNC;
1042  StorageIndex contractId = is_lhs_vec ? linearLocalThreadId % Properties::LocalThreadSizeC
1043  : linearLocalThreadId / Properties::LocalThreadSizeNC;
1044  const StorageIndex cGroupSize = itemID.get_group_range(0) / nonContractGroupSize;
1045  const StorageIndex nonContractGroupId =
1046  is_lhs_vec ? itemID.get_group(0) / cGroupSize : itemID.get_group(0) % nonContractGroupSize;
1047  const StorageIndex contractGroupId =
1048  is_lhs_vec ? itemID.get_group(0) % cGroupSize : itemID.get_group(0) / nonContractGroupSize;
1049  auto out_ptr = out_res.get_pointer() + (IsFinal ? 0 : contractGroupId * nonContractDim);
1050 
1051  const StorageIndex nonContractGroupOffset = nonContractGroupId * Properties::TileSizeDimNC;
1052  const StorageIndex contractGroupOffset = contractGroupId * Properties::TileSizeDimC;
1053  auto outScratchIndex = nonContractId + contractId * Properties::LocalThreadSizeNC;
1054  const StorageIndex globalNonContractDimOffset = nonContractGroupOffset + nonContractId;
1055  const StorageIndex globalContractDimOffset = contractGroupOffset + contractId;
1056  auto local_output = scratch_ptr + OutScratchOffset;
1057  const bool is_internal = nonContractDim - nonContractGroupOffset >= Properties::TileSizeDimNC &&
1058  contractDim - contractGroupOffset >= Properties::TileSizeDimC;
1059  is_internal
1060  ? compute_panel<true>(itemID, vec, mat, local_output, out_ptr,
1061 #ifdef EIGEN_SYCL_LOCAL_MEM_UNSET_OR_ON
1062  scratch_ptr, contractGroupOffset,
1063 #endif
1064  nonContractGroupOffset, linearLocalThreadId, contractDim, nonContractDim, contractId,
1065  nonContractId, globalContractDimOffset, globalNonContractDimOffset, outScratchIndex)
1066  : compute_panel<false>(itemID, vec, mat, local_output, out_ptr,
1067 #ifdef EIGEN_SYCL_LOCAL_MEM_UNSET_OR_ON
1068  scratch_ptr, contractGroupOffset,
1069 #endif
1070  nonContractGroupOffset, linearLocalThreadId, contractDim, nonContractDim, contractId,
1071  nonContractId, globalContractDimOffset, globalNonContractDimOffset, outScratchIndex);
1072  }
1073  template <bool is_internal_block, typename OutPtr>
1074  static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void compute_panel(
1075  const cl::sycl::nd_item<1> &itemID, const VectorMapper &vec, const TensorMapper &mat, OutScalar *local_output,
1076  OutPtr out_ptr,
1077 #ifdef EIGEN_SYCL_LOCAL_MEM_UNSET_OR_ON
1078  OutScalar *scratch_ptr, const StorageIndex contractGroupOffset,
1079 #endif
1080  const StorageIndex nonContractGroupOffset, const StorageIndex linearLocalThreadId, StorageIndex contractDim,
1081  StorageIndex nonContractDim, StorageIndex contractId, StorageIndex nonContractId,
1082  StorageIndex globalContractDimOffset, StorageIndex globalNonContractDimOffset, StorageIndex outScratchIndex) {
1083  OutScalar outScalar[Properties::WorkLoadPerThreadNC] = {OutScalar(0)};
1084  // Reading the vector
1085 #ifdef EIGEN_SYCL_LOCAL_MEM_UNSET_OR_ON
1086  const StorageIndex vectorOffset = contractGroupOffset + linearLocalThreadId;
1087  extract_block<VecBlockProperties, is_internal_block, KFactor,
1088  Properties::LocalThreadSizeNC * Properties::LocalThreadSizeC>(vec, scratch_ptr, linearLocalThreadId,
1089  vectorOffset, contractDim);
1090 
1091  itemID.barrier(cl::sycl::access::fence_space::local_space);
1092  auto in_scratch_ptr = scratch_ptr + contractId;
1093 #endif
1094 
1095  StorageIndex privateOffsetC = 0;
1096  EIGEN_UNROLL_LOOP
1097  for (StorageIndex i = 0; i < Properties::WorkLoadPerThreadC; i++) {
1098  StorageIndex privateOffsetNC = 0;
1099  bool contract_conds = ((globalContractDimOffset + privateOffsetC) < contractDim);
1100 #ifdef EIGEN_SYCL_LOCAL_MEM_UNSET_OR_ON
1101  auto vecScalar = *in_scratch_ptr;
1102 #else
1103  auto vecScalar = (check_boundary<is_internal_block>(contract_conds))
1104  ? vec(is_lhs_vec ? StorageIndex(0) : globalContractDimOffset + privateOffsetC,
1105  is_lhs_vec ? globalContractDimOffset + privateOffsetC : StorageIndex(0))
1106  : OutScalar(0);
1107 #endif
1108  EIGEN_UNROLL_LOOP
1109  for (StorageIndex j = 0; j < Properties::WorkLoadPerThreadNC; j++) {
1110  auto matScalar = (check_boundary<is_internal_block>(
1111  contract_conds && ((globalNonContractDimOffset + privateOffsetNC) < nonContractDim)))
1112  ? mat(is_lhs_vec ? globalContractDimOffset + privateOffsetC
1113  : globalNonContractDimOffset + privateOffsetNC,
1114  is_lhs_vec ? globalNonContractDimOffset + privateOffsetNC
1115  : globalContractDimOffset + privateOffsetC)
1116  : OutScalar(0);
1117 
1118  outScalar[j] = cl::sycl::mad(matScalar, vecScalar, outScalar[j]);
1119  privateOffsetNC += Properties::LocalThreadSizeNC;
1120  }
1121  privateOffsetC += Properties::LocalThreadSizeC;
1122 #ifdef EIGEN_SYCL_LOCAL_MEM_UNSET_OR_ON
1123  in_scratch_ptr += Properties::LocalThreadSizeC;
1124 #endif
1125  }
1126 
1127  auto out_scratch_ptr = local_output + outScratchIndex;
1128  // Each block of 16*16 elements in shared memory should reduce to 16*1
1129  EIGEN_UNROLL_LOOP
1130  for (StorageIndex j = 0; j < Properties::WorkLoadPerThreadNC; j++) {
1131  *out_scratch_ptr = outScalar[j];
1132 
1133  out_scratch_ptr += (Properties::LocalThreadSizeNC * Properties::LocalThreadSizeC);
1134  }
1135  if (is_lhs_vec) {
1136  nonContractId = linearLocalThreadId % Properties::LocalThreadSizeNC;
1137  contractId = linearLocalThreadId / Properties::LocalThreadSizeNC;
1138  outScratchIndex = nonContractId + contractId * Properties::LocalThreadSizeNC;
1139  }
1140 
1141  out_scratch_ptr = local_output + outScratchIndex;
1142  EIGEN_UNROLL_LOOP
1143  for (StorageIndex j = 0; j < Properties::WorkLoadPerThreadNC; j++) {
1144  EIGEN_UNROLL_LOOP
1145  for (StorageIndex offset = Properties::LocalThreadSizeC >> 1; offset > 0; offset >>= 1) {
1146  itemID.barrier(cl::sycl::access::fence_space::local_space);
1147  if (contractId < offset) {
1148  StorageIndex myNeigbourId = (Properties::LocalThreadSizeNC * offset);
1149  *out_scratch_ptr += out_scratch_ptr[myNeigbourId];
1150  }
1151  }
1152  // moving to the next 16 by 16 block
1153  out_scratch_ptr += (Properties::LocalThreadSizeNC * Properties::LocalThreadSizeC);
1154  }
1155 
1156  if (contractId == 0) {
1157  out_scratch_ptr = local_output + nonContractId;
1158  StorageIndex global_final_offset = nonContractGroupOffset + nonContractId;
1159  out_ptr += global_final_offset;
1160  EIGEN_UNROLL_LOOP
1161  for (StorageIndex j = 0; j < Properties::WorkLoadPerThreadNC; j++) {
1162  if (check_boundary<is_internal_block>(global_final_offset < nonContractDim)) {
1163  auto res = *out_scratch_ptr;
1164 
1165  *out_ptr = res;
1166  out_ptr += Properties::LocalThreadSizeNC;
1167  }
1168  // moving to the next 16 by 16 block to get the next 16 reduced elements
1169  out_scratch_ptr += (Properties::LocalThreadSizeNC * Properties::LocalThreadSizeC);
1170  if (!(is_internal_block)) global_final_offset += Properties::LocalThreadSizeNC;
1171  }
1172  }
1173  }
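  // With the commonly used 16x16 work-group (LocalThreadSizeC == LocalThreadSizeNC == 16), the reduction loop above
  // runs with offset == 8, 4, 2, 1, i.e. four barrier-separated steps per WorkLoadPerThreadNC output, collapsing each
  // 16x16 tile of partial sums in local memory into a single row of 16 results.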
1174 
1175  template <typename InputBlockProperties, bool is_internal_block, int CFactor, int GroupSize, typename Input,
1176  typename Local>
1177  static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void extract_block(const Input &inpt, Local *local_ptr,
1178  const StorageIndex &linearLocalThreadId,
1179  const StorageIndex &cOffset, const StorageIndex &C) {
1180  local_ptr += InputBlockProperties::c_stride * linearLocalThreadId;
1181  StorageIndex cIndex = cOffset;
1182  for (StorageIndex cId = 0; cId < CFactor / InputBlockProperties::c_stride; cId++) {
1183  if (check_boundary<is_internal_block>(cIndex + InputBlockProperties::c_stride - 1 < C)) {
1184  auto val = read<InputBlockProperties::packet_load, InputBlockProperties::is_coalesced_layout,
1185  InputBlockProperties::is_rhs, typename InputBlockProperties::OutType>(inpt, StorageIndex(0),
1186  cIndex, StorageIndex(1));
1187  write<StorageIndex, 1, data_source::local_mem>(val, local_ptr);
1188  } else {
1189  EIGEN_UNROLL_LOOP
1190  for (StorageIndex i = 0; i < InputBlockProperties::elements_per_access; i++) {
1191  OutScalar val =
1192  (cIndex + i < C)
1193  ? read<false, InputBlockProperties::is_coalesced_layout, InputBlockProperties::is_rhs, OutScalar>(
1194  inpt, StorageIndex(0), cIndex + i, StorageIndex(1))
1195  : OutScalar(0);
1196  write<StorageIndex, 1, data_source::local_mem>(val, local_ptr + i);
1197  }
1198  }
1199  local_ptr += InputBlockProperties::c_stride * GroupSize;
1200  cIndex += InputBlockProperties::c_stride * GroupSize;
1201  }
1202  }
1203 };
1204 #endif
1205 
1206 #ifndef EIGEN_SYCL_DISABLE_SCALAR
1207 
1239 template <typename OutScalar, typename LhsScalar, typename RhsScalar, typename OutAccessor, typename LhsMapper,
1240  typename RhsMapper, typename StorageIndex, bool Vectorizable>
1241 struct GeneralScalarContraction {
1242  typedef cl::sycl::accessor<OutScalar, 1, cl::sycl::access::mode::read_write, cl::sycl::access::target::local> Scratch;
1243  Scratch scratch;
1244  const LhsMapper lhs;
1245  const RhsMapper rhs;
1246  OutAccessor out_res;
1247  const StorageIndex rng;
1248 
1249  EIGEN_DEVICE_FUNC
1250  GeneralScalarContraction(Scratch scratch_, const LhsMapper lhs_, const RhsMapper rhs_, OutAccessor out_res_,
1251  const StorageIndex rng_)
1252  : scratch(scratch_), lhs(lhs_), rhs(rhs_), out_res(out_res_), rng(rng_) {}
1253 
1254  EIGEN_DEVICE_FUNC void operator()(cl::sycl::nd_item<1> itemID) {
1255  auto out_ptr = out_res.get_pointer();
1256  auto scratch_ptr = scratch.get_pointer().get();
1257 
1258  StorageIndex globalid = itemID.get_global_id(0);
1259  StorageIndex localid = itemID.get_local_id(0);
1260  OutScalar accumulator = OutScalar(0);
1261  for (StorageIndex i = globalid; i < rng; i += itemID.get_global_range(0)) {
1262  accumulator = cl::sycl::mad(lhs(0, i), rhs(i, 0), accumulator);
1263  }
1264  auto out_scratch_ptr = scratch_ptr + localid;
1265  *out_scratch_ptr = accumulator;
1266  for (StorageIndex offset = itemID.get_local_range(0) >> 1; offset > 0; offset >>= 1) {
1267  itemID.barrier(cl::sycl::access::fence_space::local_space);
1268  if (localid < offset) {
1269  *out_scratch_ptr = (accumulator += out_scratch_ptr[offset]);
1270  }
1271  }
1272  if (localid == 0) {
1273  out_ptr[itemID.get_group(0)] = accumulator;
1274  }
1275  }
1276 };
1277 #endif
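// Illustrative sketch of the scalar (1x1 output) path above with hypothetical sizes: for rng == 1 << 20, a global
// range of 4096 work-items and work-groups of 256, each work-item accumulates 256 products in the grid-stride loop,
// the in-group tree reduction takes 8 barrier-separated halving steps, and each of the 16 groups writes one partial
// sum to out_ptr[group].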
1278 
1279 } // namespace internal
1280 } // namespace TensorSycl
1281 
1282 template <typename Indices, typename LeftArgType, typename RightArgType, typename OutputKernelType>
1283 struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>,
1284  Eigen::SyclDevice>
1285  : public TensorContractionEvaluatorBase<TensorEvaluator<
1286  const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, Eigen::SyclDevice>> {
1287  static_assert(std::is_same<OutputKernelType, const NoOpOutputKernel>::value,
1288  "SYCL tensor contraction does not support output kernels.");
1289 
1290  typedef Eigen::SyclDevice Device;
1291 
1292  typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, Device> Self;
1293  typedef TensorContractionEvaluatorBase<Self> Base;
1294  typedef TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType> XprType;
1295  typedef typename internal::remove_const<typename XprType::Scalar>::type Scalar;
1296  typedef typename XprType::Index StorageIndex;
1297  typedef typename XprType::CoeffReturnType CoeffReturnType;
1298  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
1299  typedef typename Base::Storage Storage;
1300  typedef typename Base::EvaluatorPointerType EvaluatorPointerType;
1301  struct TripleDim {
1302  const StorageIndex M;
1303  const StorageIndex N;
1304  const StorageIndex K;
1305  TripleDim(const StorageIndex M_, const StorageIndex N_, const StorageIndex K_) : M(M_), N(N_), K(K_) {}
1306  };
1307  enum {
1308  Layout = TensorEvaluator<LeftArgType, Device>::Layout,
1309  PacketAccess = (PacketType<CoeffReturnType, Device>::size > 1),
1310  BlockAccess = false,
1311  };
1312 
1313  static EIGEN_CONSTEXPR int LDims = Base::LDims;
1314  static EIGEN_CONSTEXPR int RDims = Base::RDims;
1315  static EIGEN_CONSTEXPR int ContractDims = Base::ContractDims;
1316 
1317  typedef array<StorageIndex, LDims> left_dim_mapper_t;
1318  typedef array<StorageIndex, RDims> right_dim_mapper_t;
1319 
1320  typedef array<StorageIndex, ContractDims> contract_t;
1321  typedef array<StorageIndex, LDims - ContractDims> left_nocontract_t;
1322  typedef array<StorageIndex, RDims - ContractDims> right_nocontract_t;
1323 
1324  static const int NumDims = LDims + RDims - 2 * ContractDims;
1325 
1326  typedef DSizes<StorageIndex, NumDims> Dimensions;
1327 
1328  typedef TensorEvaluator<typename Base::EvalLeftArgType, Device> LeftEvaluator;
1329  typedef TensorEvaluator<typename Base::EvalRightArgType, Device> RightEvaluator;
1330  typedef typename Eigen::internal::remove_const<typename LeftEvaluator::CoeffReturnType>::type LhsScalar;
1331  typedef typename Eigen::internal::remove_const<typename RightEvaluator::CoeffReturnType>::type RhsScalar;
1332 
1333  typedef typename LeftEvaluator::Dimensions LeftDimensions;
1334  typedef typename RightEvaluator::Dimensions RightDimensions;
1335 
1336  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered>
1337  struct input_mapper_propertis {
1338  static EIGEN_CONSTEXPR bool is_lhs_matrix = (LDims == 2 && ContractDims == 1) || lhs_inner_dim_contiguous;
1339  static EIGEN_CONSTEXPR bool is_rhs_matrix =
1340  (RDims == 2 && ContractDims == 1) || (rhs_inner_dim_contiguous && !rhs_inner_dim_reordered);
1341  };
1342 
1343  TensorEvaluator(const XprType &op, const Device &device) : Base(op, device) {}
1344 
1345  // We need to redefine this method to make nvcc happy
1346  EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(typename Base::EvaluatorPointerType data) {
1347  this->m_leftImpl.evalSubExprsIfNeeded(NULL);
1348  this->m_rightImpl.evalSubExprsIfNeeded(NULL);
1349  if (!data) {
1350  this->m_result = this->m_device.get(
1351  static_cast<Scalar *>(this->m_device.allocate_temp(this->dimensions().TotalSize() * sizeof(Scalar))));
1352  data = this->m_result;
1353  }
1354  evalToSycl(data);
1355  return (this->m_result != NULL);
1356  }
1357  const Eigen::SyclDevice &device() const { return this->m_device; }
1358  void evalToSycl(typename Base::EvaluatorPointerType buffer) const {
1359  if (this->m_lhs_inner_dim_contiguous) {
1360  if (this->m_rhs_inner_dim_contiguous) {
1361  if (this->m_rhs_inner_dim_reordered) {
1362  evalTyped<true, true, true, Unaligned>(buffer);
1363  } else {
1364  evalTyped<true, true, false, Unaligned>(buffer);
1365  }
1366  } else {
1367  if (this->m_rhs_inner_dim_reordered) {
1368  evalTyped<true, false, true, Unaligned>(buffer);
1369  } else {
1370  evalTyped<true, false, false, Unaligned>(buffer);
1371  }
1372  }
1373  } else {
1374  if (this->m_rhs_inner_dim_contiguous) {
1375  if (this->m_rhs_inner_dim_reordered) {
1376  evalTyped<false, true, true, Unaligned>(buffer);
1377  } else {
1378  evalTyped<false, true, false, Unaligned>(buffer);
1379  }
1380  } else {
1381  if (this->m_rhs_inner_dim_reordered) {
1382  evalTyped<false, false, true, Unaligned>(buffer);
1383  } else {
1384  evalTyped<false, false, false, Unaligned>(buffer);
1385  }
1386  }
1387  }
1388  }
1389 
1390  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
1391  void evalTyped(typename Base::EvaluatorPointerType buffer) const {
1392  const auto triple_dim = TripleDim{this->m_i_size, this->m_j_size, this->m_k_size};
1393  typedef internal::TensorContractionInputMapper<
1394  LhsScalar, StorageIndex, internal::Lhs, LeftEvaluator, left_nocontract_t, contract_t,
1395  PacketType<CoeffReturnType, Device>::size, lhs_inner_dim_contiguous, false, Unaligned, MakeSYCLPointer>
1396  LhsMapper;
1397 
1398  typedef internal::TensorContractionInputMapper<RhsScalar, StorageIndex, internal::Rhs, RightEvaluator,
1399  right_nocontract_t, contract_t,
1400  PacketType<CoeffReturnType, Device>::size, rhs_inner_dim_contiguous,
1401  rhs_inner_dim_reordered, Unaligned, MakeSYCLPointer>
1402  RhsMapper;
1403 
1404  // initialize data mappers
1405  LhsMapper lhs(this->m_leftImpl, this->m_left_nocontract_strides, this->m_i_strides,
1406  this->m_left_contracting_strides, this->m_k_strides);
1407 
1408  RhsMapper rhs(this->m_rightImpl, this->m_right_nocontract_strides, this->m_j_strides,
1409  this->m_right_contracting_strides, this->m_k_strides);
1410 
1411 #ifndef EIGEN_SYCL_DISABLE_SCALAR
1412  if (triple_dim.M == 1 && triple_dim.N == 1) {
1413  launchSC(buffer, lhs, rhs, triple_dim.K);
1414  } else
1415 #endif
1416 #ifndef EIGEN_SYCL_DISABLE_GEMV
1417  if (triple_dim.M != 1 && triple_dim.N == 1) {
1418  LaunchVT<false>(buffer, rhs, lhs, triple_dim.M, triple_dim.K);
1419  } else if (triple_dim.M == 1 && triple_dim.N != 1) {
1420  LaunchVT<true>(buffer, lhs, rhs, triple_dim.N, triple_dim.K);
1421  } else // This is equivalent to if (m != 1 && n != 1)
1422 #endif
1423  {
      typedef input_mapper_propertis<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered>
          inpt_mapper_properties;
#ifndef EIGEN_SYCL_DISABLE_SKINNY
      // A contraction is treated as "skinny" when one dimension is much larger than the others; such shapes
      // are handled by splitting the K dimension across work-groups (see launchTT below).
      bool skinny = false;
      auto platform_name = this->device().getPlatformName();
      // This heuristic is based on empirical tuning for AMD R9 Nano and Fiji GPUs.
      if (platform_name.find("AMD") == 0) {
        skinny = (triple_dim.M < triple_dim.K || triple_dim.N < triple_dim.K) &&
                 ((triple_dim.M < 1024 && triple_dim.N < 1024) ||
                  (uint64_t(triple_dim.M * triple_dim.N) < uint64_t(triple_dim.K)));
      } else {
        skinny = (((std::max(triple_dim.K, triple_dim.N) / std::min(triple_dim.K, triple_dim.N)) > 100) ||
                  ((std::max(triple_dim.K, triple_dim.M) / std::min(triple_dim.K, triple_dim.M)) > 100) ||
                  ((std::max(triple_dim.N, triple_dim.M) / std::min(triple_dim.N, triple_dim.M)) > 100));
      }
      if (skinny)
        adjustTT<true, inpt_mapper_properties>(buffer, lhs, rhs, triple_dim);
      else
#endif  // EIGEN_SYCL_DISABLE_SKINNY
        adjustTT<false, inpt_mapper_properties>(buffer, lhs, rhs, triple_dim);
    }
  }

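  // adjustTT() selects the tensor-tensor kernel variant. When local (shared) memory is available, a
  // TTPanelSize with 4x4 register blocking and TileSizeDimK = 16 is used with local-memory staging;
  // otherwise a no-local variant with TileSizeDimK = 4 is launched. The EIGEN_SYCL_LOCAL_MEM_UNSET_OR_ON /
  // EIGEN_SYCL_LOCAL_MEM_UNSET_OR_OFF guards (set up elsewhere in the SYCL backend) control which of the
  // two paths is compiled in.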
  template <bool skinny, typename input_mapper_properties, typename LhsMapper, typename RhsMapper>
  void EIGEN_ALWAYS_INLINE adjustTT(EvaluatorPointerType buffer, const LhsMapper &lhs, const RhsMapper &rhs,
                                    const TripleDim &triple_dim) const {
#ifdef EIGEN_SYCL_LOCAL_MEM_UNSET_OR_ON
    if (device().has_local_memory()) {
      typedef TensorSycl::internal::TTPanelSize<CoeffReturnType, StorageIndex, 4, 4, 16> PanelParameters;
      launchTT<TensorSycl::internal::contraction_type::local, skinny, input_mapper_properties, PanelParameters>(
          buffer, lhs, rhs, triple_dim);
    }
#endif
#ifdef EIGEN_SYCL_LOCAL_MEM_UNSET_OR_OFF
    if (!(device().has_local_memory())) {
      typedef TensorSycl::internal::TTPanelSize<CoeffReturnType, StorageIndex, 4, 4, 4> PanelParameters;
      launchTT<TensorSycl::internal::contraction_type::no_local, skinny, input_mapper_properties, PanelParameters>(
          buffer, lhs, rhs, triple_dim);
    }
#endif
  }

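  // launchTT() computes the ND-range for the tiled tensor-tensor kernel: M and N are rounded up to whole
  // tiles, and for skinny problems the K dimension is additionally split over groupSizeK work-groups whose
  // partial results are summed by a SecondStepPartialReduction pass. As a rough worked example (assuming
  // the default 16x16 local thread dimensions and the 4x4 register blocking chosen in adjustTT, i.e.
  // TileSizeDimM = TileSizeDimN = 64): for M = N = K = 1024 this gives groupSizeM = groupSizeN = 16,
  // groupSizeK = 1 (non-skinny), a local range of 256 threads, and a global range of 16 * 16 * 256 = 65536.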
  template <TensorSycl::internal::contraction_type ct, bool skinny, typename input_mapper_properties,
            typename Properties, typename LhsMapper, typename RhsMapper>
  void launchTT(EvaluatorPointerType buffer, const LhsMapper &lhs, const RhsMapper &rhs,
                const TripleDim &triple_dim) const {
    const StorageIndex roundUpM = Eigen::TensorSycl::internal::roundUp(triple_dim.M, Properties::TileSizeDimM);
    const StorageIndex roundUpN = Eigen::TensorSycl::internal::roundUp(triple_dim.N, Properties::TileSizeDimN);
    const StorageIndex groupSizeM = roundUpM / Properties::TileSizeDimM;
    const StorageIndex groupSizeN = roundUpN / Properties::TileSizeDimN;

    const StorageIndex roundUpK = Eigen::TensorSycl::internal::roundUp(triple_dim.K, Properties::TileSizeDimK);
    StorageIndex totalTilesK = roundUpK / Properties::TileSizeDimK;
    // For skinny problems, split the K tiles across work-groups so that the total number of work-groups is
    // roughly four per compute unit, clamped between 1 and the number of K tiles.
    StorageIndex groupSizeK =
        skinny
            ? std::max(std::min(totalTilesK,
                                (StorageIndex)(device().getPowerOfTwo(device().getNumSyclMultiProcessors(), true) * 4) /
                                    (groupSizeM * groupSizeN)),
                       StorageIndex(1))
            : StorageIndex(1);

    const StorageIndex numTilesPerGroup = Eigen::TensorSycl::internal::roundUp(totalTilesK, groupSizeK) / groupSizeK;

    const StorageIndex totalGroupSize = groupSizeM * groupSizeN * groupSizeK;

    const StorageIndex localRange = Properties::LocalThreadSizeM * Properties::LocalThreadSizeN;
    const StorageIndex globalRange = totalGroupSize * localRange;

    // The local-memory kernel stages (optionally double-buffered) tiles of both operands in scratch memory;
    // the no-local variant only needs a dummy one-element allocation.
    const StorageIndex scratchSize = (ct == TensorSycl::internal::contraction_type::local)
                                         ? ((Properties::DoubleBuffer + 1) *
                                            (Properties::TileSizeDimM + Properties::BC) * (Properties::TileSizeDimK)) +
                                               ((Properties::DoubleBuffer + 1) * (Properties::TileSizeDimK) *
                                                (Properties::TileSizeDimN + Properties::BC))
                                         : StorageIndex(1);

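    // When the whole K dimension is handled inside a single set of work-groups (groupSizeK == 1), the
    // contraction kernel writes directly into the output buffer. Otherwise each K-slice writes its partial
    // result to a temporary M * N * groupSizeK buffer, and a SecondStepPartialReduction kernel sums the
    // slices into the output before the temporary is released.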
    auto thread_range = cl::sycl::nd_range<1>(cl::sycl::range<1>(globalRange), cl::sycl::range<1>(localRange));
    if (groupSizeK == 1) {
      typedef TensorSycl::internal::TensorContractionKernel<CoeffReturnType, LhsScalar, RhsScalar, EvaluatorPointerType,
                                                            LhsMapper, RhsMapper, StorageIndex, Properties, TripleDim,
                                                            PacketAccess, input_mapper_properties, true, ct>
          ContractKernelName;
      device().template binary_kernel_launcher<CoeffReturnType, ContractKernelName>(
          lhs, rhs, buffer, thread_range, scratchSize, groupSizeM, groupSizeN, numTilesPerGroup, triple_dim);
    } else {
      typedef TensorSycl::internal::TensorContractionKernel<CoeffReturnType, LhsScalar, RhsScalar, EvaluatorPointerType,
                                                            LhsMapper, RhsMapper, StorageIndex, Properties, TripleDim,
                                                            PacketAccess, input_mapper_properties, false, ct>
          ContractKernelName;
      CoeffReturnType *temp_pointer = static_cast<CoeffReturnType *>(
          device().allocate_temp(triple_dim.M * triple_dim.N * groupSizeK * sizeof(CoeffReturnType)));
      EvaluatorPointerType tmp_global_accessor = device().get(temp_pointer);

      device().template binary_kernel_launcher<CoeffReturnType, ContractKernelName>(
          lhs, rhs, tmp_global_accessor, thread_range, scratchSize, groupSizeM, groupSizeN, numTilesPerGroup,
          triple_dim);

      typedef Eigen::internal::SumReducer<CoeffReturnType> Op;
      auto op = Op();
      typedef TensorSycl::internal::SecondStepPartialReduction<CoeffReturnType, StorageIndex, EvaluatorPointerType,
                                                               EvaluatorPointerType, Op>
          ReductionKernel;

      device().template unary_kernel_launcher<CoeffReturnType, ReductionKernel>(
          tmp_global_accessor, buffer,
          cl::sycl::nd_range<1>(cl::sycl::range<1>(StorageIndex(
                                    Eigen::TensorSycl::internal::roundUp(triple_dim.M * triple_dim.N, localRange))),
                                cl::sycl::range<1>(localRange)),
          StorageIndex(1), op, StorageIndex(triple_dim.M * triple_dim.N), groupSizeK);

      device().deallocate_temp(temp_pointer);
    }
  }

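  // LaunchVT() handles matrix-vector and vector-matrix contractions. The non-contracting dimension is
  // tiled over work-groups; when the contracting dimension spans more than one work-group tile
  // (cNumGroups > 1), partial results are written to a temporary buffer of size nonContractDim * cNumGroups
  // and combined by a SecondStepPartialReduction pass, otherwise a single kernel writes straight into the
  // output buffer. The is_lhs_vec flag records whether the vector operand is the contraction's lhs.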
#ifndef EIGEN_SYCL_DISABLE_GEMV
  template <bool is_lhs_vec, typename VectorMapper, typename TensorMapper, typename StorageIndex>
  void EIGEN_ALWAYS_INLINE LaunchVT(EvaluatorPointerType buffer, const VectorMapper &vec, const TensorMapper &mat,
                                    StorageIndex NC, StorageIndex C) const {
    const StorageIndex nonContractDim = NC;
    EIGEN_CONSTEXPR StorageIndex NCFactor = 1;
    EIGEN_CONSTEXPR StorageIndex CFactor = 1;
    EIGEN_CONSTEXPR StorageIndex NCWindow = 16;
    typedef Eigen::TensorSycl::internal::TVPanelSize<CoeffReturnType, StorageIndex, NCWindow, CFactor, NCFactor>
        Properties;
    const StorageIndex roundUpC = Eigen::TensorSycl::internal::roundUp(C, Properties::TileSizeDimC);
    const StorageIndex cNumGroups = roundUpC / (Properties::LocalThreadSizeC * Properties::WorkLoadPerThreadC);
    const StorageIndex roundUpNC = Eigen::TensorSycl::internal::roundUp(nonContractDim, Properties::TileSizeDimNC);
    const StorageIndex nCNumGroups = roundUpNC / (Properties::LocalThreadSizeNC * Properties::WorkLoadPerThreadNC);
    const StorageIndex globalRange =
        (roundUpNC / (Properties::WorkLoadPerThreadNC)) * (roundUpC / (Properties::WorkLoadPerThreadC));
    const StorageIndex localRange = Properties::LocalThreadSizeNC * Properties::LocalThreadSizeC;
    const StorageIndex scratchSize =
        (Properties::WorkLoadPerThreadNC + CFactor) * Properties::LocalThreadSizeC * Properties::LocalThreadSizeNC;
    auto thread_range = cl::sycl::nd_range<1>(cl::sycl::range<1>(globalRange), cl::sycl::range<1>(localRange));
    if (cNumGroups > 1) {
      typedef Eigen::TensorSycl::internal::GeneralVectorTensor<CoeffReturnType, EvaluatorPointerType, VectorMapper,
                                                               TensorMapper, StorageIndex, Properties, CFactor, false,
                                                               is_lhs_vec, false>
          ContractKernelName;
      CoeffReturnType *temp_pointer =
          static_cast<CoeffReturnType *>(device().allocate_temp(nonContractDim * cNumGroups * sizeof(CoeffReturnType)));
      EvaluatorPointerType tmp_global_accessor = device().get(temp_pointer);

      device().template binary_kernel_launcher<CoeffReturnType, ContractKernelName>(
          vec, mat, tmp_global_accessor, thread_range, scratchSize, nCNumGroups, nonContractDim, C);

      typedef Eigen::internal::SumReducer<CoeffReturnType> Op;
      typedef TensorSycl::internal::SecondStepPartialReduction<CoeffReturnType, StorageIndex, EvaluatorPointerType,
                                                               EvaluatorPointerType, Op>
          ReductionKernel;

      device().template unary_kernel_launcher<CoeffReturnType, ReductionKernel>(
          tmp_global_accessor, buffer,
          cl::sycl::nd_range<1>(cl::sycl::range<1>(Eigen::TensorSycl::internal::roundUp(nonContractDim, localRange)),
                                cl::sycl::range<1>(localRange)),
          StorageIndex(1), Op(), nonContractDim, cNumGroups);

      device().deallocate_temp(temp_pointer);
    } else {
      typedef Eigen::TensorSycl::internal::GeneralVectorTensor<CoeffReturnType, EvaluatorPointerType, VectorMapper,
                                                               TensorMapper, StorageIndex, Properties, CFactor, false,
                                                               is_lhs_vec, true>
          ContractKernelName;
      device().template binary_kernel_launcher<CoeffReturnType, ContractKernelName>(
          vec, mat, buffer, thread_range, scratchSize, nCNumGroups, nonContractDim, C);
    }
  }
#endif

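  // launchSC() handles the degenerate case where both non-contracting dimensions are 1, i.e. the
  // contraction is a plain dot product of length K. It is implemented as an at-most-two-step reduction:
  // when more than one work-group is needed, per-group partial sums are written to a temporary buffer and
  // combined by a SecondStepFullReducer.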
#ifndef EIGEN_SYCL_DISABLE_SCALAR
  template <typename LhsMapper, typename RhsMapper>
  EIGEN_ALWAYS_INLINE void launchSC(EvaluatorPointerType buffer, const LhsMapper &lhs, const RhsMapper &rhs,
                                    StorageIndex K) const {
    EIGEN_STATIC_ASSERT(!((EIGEN_SYCL_LOCAL_THREAD_DIM0 * EIGEN_SYCL_LOCAL_THREAD_DIM1) &
                          (EIGEN_SYCL_LOCAL_THREAD_DIM0 * EIGEN_SYCL_LOCAL_THREAD_DIM1 - 1)),
                        "The Local thread size must be a power of 2 for the reduction "
                        "operation");
    EIGEN_CONSTEXPR StorageIndex local_range = EIGEN_SYCL_LOCAL_THREAD_DIM0 * EIGEN_SYCL_LOCAL_THREAD_DIM1;

    // Here we force the code to perform at most a two-step reduction: our empirical research shows that if
    // each thread reduces at least 512 elements individually, we get better performance.
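    // For example, assuming the default 16x16 local thread dimensions (local_range == 256, which depends on
    // the EIGEN_SYCL_LOCAL_THREAD_DIM* macros), a second reduction step is only used when
    // K > 512 * 256 = 131072; below that, a single work-group reduces the whole dot product.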
    const StorageIndex num_work_group = ((K + (512 * local_range - 1)) / (512 * local_range) > 1 ? local_range : 1);
    const StorageIndex global_range = num_work_group * local_range;

    typedef Eigen::TensorSycl::internal::GeneralScalarContraction<
        CoeffReturnType, LhsScalar, RhsScalar, EvaluatorPointerType, LhsMapper, RhsMapper, StorageIndex, false>
        ContractKernelName;
    auto thread_range = cl::sycl::nd_range<1>(cl::sycl::range<1>(global_range), cl::sycl::range<1>(local_range));
    if (num_work_group > 1) {
      CoeffReturnType *temp_pointer =
          static_cast<CoeffReturnType *>(device().allocate_temp(num_work_group * sizeof(CoeffReturnType)));
      EvaluatorPointerType tmp_global_accessor = device().get(temp_pointer);
      device().template binary_kernel_launcher<CoeffReturnType, ContractKernelName>(lhs, rhs, tmp_global_accessor,
                                                                                    thread_range, local_range, K);
      typedef Eigen::internal::SumReducer<CoeffReturnType> Op;
      typedef TensorSycl::internal::SecondStepFullReducer<CoeffReturnType, Op, EvaluatorPointerType,
                                                          EvaluatorPointerType, StorageIndex, local_range>
          GenericRKernel;
      device().template unary_kernel_launcher<CoeffReturnType, GenericRKernel>(
          tmp_global_accessor, buffer,
          cl::sycl::nd_range<1>(cl::sycl::range<1>(local_range), cl::sycl::range<1>(local_range)), local_range, Op());

      device().deallocate_temp(temp_pointer);
    } else {
      device().template binary_kernel_launcher<CoeffReturnType, ContractKernelName>(lhs, rhs, buffer, thread_range,
                                                                                    local_range, K);
    }
  }
#endif

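  // cleanup() releases the operand evaluators and, if evalSubExprsIfNeeded() allocated a temporary result
  // buffer, returns it to the device allocator.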
  EIGEN_STRONG_INLINE void cleanup() {
    this->m_leftImpl.cleanup();
    this->m_rightImpl.cleanup();

    if (this->m_result) {
      this->m_device.deallocate_temp(this->m_result);
      this->m_result = NULL;
    }
  }
  // The placeholder accessors must be bound to a command group handler for SYCL
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void bind(cl::sycl::handler &cgh) const {
    this->m_leftImpl.bind(cgh);
    this->m_rightImpl.bind(cgh);
    this->m_result.bind(cgh);
  }
};
}  // namespace Eigen
#endif  // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_SYCL_H