#ifndef EIGEN_CXX11_THREADPOOL_NONBLOCKING_THREAD_POOL_H
#define EIGEN_CXX11_THREADPOOL_NONBLOCKING_THREAD_POOL_H

namespace Eigen {

template <typename Environment>
class ThreadPoolTempl : public Eigen::ThreadPoolInterface {
 public:
  typedef typename Environment::Task Task;
  typedef RunQueue<Task, 1024> Queue;
  ThreadPoolTempl(int num_threads, Environment env = Environment())
      : ThreadPoolTempl(num_threads, true, env) {}
  ThreadPoolTempl(int num_threads, bool allow_spinning,
                  Environment env = Environment())
      : env_(env),
        num_threads_(num_threads),
        allow_spinning_(allow_spinning),
        thread_data_(num_threads),
        all_coprimes_(num_threads),
        waiters_(num_threads),
        global_steal_partition_(EncodePartition(0, num_threads_)),
        blocked_(0),
        spinning_(false),
        done_(false),
        cancelled_(false),
        ec_(waiters_) {
    waiters_.resize(num_threads_);
    // Calculate coprimes of all numbers in [1, num_threads]. Coprimes are
    // used for random walks over all threads in Steal and NonEmptyQueueIndex:
    // starting from a random thread index and stepping by a coprime of
    // num_threads (modulo num_threads) visits every thread exactly once.
    eigen_plain_assert(num_threads_ < kMaxThreads);
    for (int i = 1; i <= num_threads_; ++i) {
      all_coprimes_.emplace_back(i);
      ComputeCoprimes(i, &all_coprimes_.back());
    }
#ifndef EIGEN_THREAD_LOCAL
    init_barrier_.reset(new Barrier(num_threads_));
#endif
    thread_data_.resize(num_threads_);
    for (int i = 0; i < num_threads_; i++) {
      SetStealPartition(i, EncodePartition(0, num_threads_));
      thread_data_[i].thread.reset(
          env_.CreateThread([this, i]() { WorkerLoop(i); }));
    }
#ifndef EIGEN_THREAD_LOCAL
    // Wait for workers to initialize per_thread_map_; otherwise we might race
    // with them in Schedule or CurrentThreadId.
    init_barrier_->Wait();
#endif
  }
  ~ThreadPoolTempl() {
    done_ = true;

    // If all threads are now blocked without work, they will start exiting.
    // If we were cancelled, the queues may still hold entries; flush them so
    // their destructors do not assert.
    if (!cancelled_) {
      ec_.Notify(true);
    } else {
      for (size_t i = 0; i < thread_data_.size(); i++) {
        thread_data_[i].queue.Flush();
      }
    }
    // Join threads explicitly (by destroying them) to avoid destruction-order
    // issues within this class.
    for (size_t i = 0; i < thread_data_.size(); ++i)
      thread_data_[i].thread.reset();
  }
  void SetStealPartitions(
      const std::vector<std::pair<unsigned, unsigned>>& partitions) {
    eigen_plain_assert(partitions.size() ==
                       static_cast<std::size_t>(num_threads_));

    // Pass this information to each thread queue.
    for (int i = 0; i < num_threads_; i++) {
      const auto& pair = partitions[i];
      unsigned start = pair.first, end = pair.second;
      AssertBounds(start, end);
      unsigned val = EncodePartition(start, end);
      SetStealPartition(i, val);
    }
  }
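  // For illustration (a hypothetical call, not part of this header): on a
  // 4-thread pool, the partitioning below confines stealing so that workers
  // 0-1 and workers 2-3 only steal from within their own pair:
  //
  //   pool.SetStealPartitions({{0, 2}, {0, 2}, {2, 4}, {2, 4}});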
  void Schedule(std::function<void()> fn) EIGEN_OVERRIDE {
    ScheduleWithHint(std::move(fn), 0, num_threads_);
  }
  void ScheduleWithHint(std::function<void()> fn, int start,
                        int limit) override {
    Task t = env_.CreateTask(std::move(fn));
    PerThread* pt = GetPerThread();
    if (pt->pool == this) {
      // Worker thread of this pool, push onto the thread's own queue.
      Queue& q = thread_data_[pt->thread_id].queue;
      t = q.PushFront(std::move(t));
    } else {
      // A free-standing thread (or worker of another pool), push onto a
      // random queue in [start, limit).
      eigen_plain_assert(start < limit);
      eigen_plain_assert(limit <= num_threads_);
      int num_queues = limit - start;
      int rnd = Rand(&pt->rand) % num_queues;
      eigen_plain_assert(start + rnd < limit);
      Queue& q = thread_data_[start + rnd].queue;
      t = q.PushBack(std::move(t));
    }
    if (!t.f) {
      // The queue accepted the task; wake up a blocked worker if any.
      ec_.Notify(false);
    } else {
      // Push failed (queue full), execute directly.
      env_.ExecuteTask(t);
    }
  }
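  // For illustration (a hypothetical call, assuming a pool with at least 4
  // threads): the hint below restricts placement of the closure to the queues
  // of workers 2 and 3, which can improve locality for related tasks:
  //
  //   pool.ScheduleWithHint([] { /* work */ }, /*start=*/2, /*limit=*/4);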
  void Cancel() EIGEN_OVERRIDE {
    cancelled_ = true;
    done_ = true;

    // Let each thread know it's been cancelled.
#ifdef EIGEN_THREAD_ENV_SUPPORTS_CANCELLATION
    for (size_t i = 0; i < thread_data_.size(); i++) {
      thread_data_[i].thread->OnCancel();
    }
#endif

    // Wake up the threads without work to let them exit on their own.
    ec_.Notify(true);
  }
  int NumThreads() const EIGEN_FINAL { return num_threads_; }

  int CurrentThreadId() const EIGEN_FINAL {
    const PerThread* pt = const_cast<ThreadPoolTempl*>(this)->GetPerThread();
    if (pt->pool == this) {
      return pt->thread_id;
    } else {
      return -1;
    }
  }
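  // Note: a closure running on a worker therefore sees its own index in
  // [0, NumThreads()), while a call from a non-worker thread (e.g. the main
  // thread) returns -1.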
 private:
  // A steal partition is packed into a single unsigned: the high bits hold
  // start and the low kMaxPartitionBits bits hold limit.
  static const int kMaxPartitionBits = 16;
  static const int kMaxThreads = 1 << kMaxPartitionBits;

  inline unsigned EncodePartition(unsigned start, unsigned limit) {
    return (start << kMaxPartitionBits) | limit;
  }
  inline void DecodePartition(unsigned val, unsigned* start, unsigned* limit) {
    *limit = val & (kMaxThreads - 1);
    val >>= kMaxPartitionBits;
    *start = val;
  }
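  // For example, with kMaxPartitionBits == 16, EncodePartition(2, 5) yields
  // (2 << 16) | 5 == 0x00020005; DecodePartition recovers limit == 5 from the
  // low 16 bits and start == 2 from the remaining high bits.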
  void AssertBounds(int start, int end) {
    eigen_plain_assert(start >= 0);
    // Non-zero-sized partition.
    eigen_plain_assert(start < end);
    eigen_plain_assert(end <= num_threads_);
  }
  inline void SetStealPartition(size_t i, unsigned val) {
    thread_data_[i].steal_partition.store(val, std::memory_order_relaxed);
  }
  inline unsigned GetStealPartition(int i) {
    return thread_data_[i].steal_partition.load(std::memory_order_relaxed);
  }
  void ComputeCoprimes(int N, MaxSizeVector<unsigned>* coprimes) {
    for (int i = 1; i <= N; i++) {
      // i and N are coprime iff gcd(i, N) == 1 (Euclid's algorithm).
      unsigned a = i;
      unsigned b = N;
      while (b != 0) {
        unsigned tmp = a;
        a = b;
        b = tmp % b;
      }
      if (a == 1) {
        coprimes->push_back(i);
      }
    }
  }
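  // For example, ComputeCoprimes(10, &c) yields {1, 3, 7, 9}: stepping a
  // queue index by any of these values modulo 10 visits all 10 queues exactly
  // once before repeating, giving a cheap pseudo-random permutation.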
  typedef typename Environment::EnvThread Thread;

  struct PerThread {
    constexpr PerThread() : pool(NULL), rand(0), thread_id(-1) {}
    ThreadPoolTempl* pool;  // Parent pool, or null for normal threads.
    uint64_t rand;          // Random generator state.
    int thread_id;          // Worker thread index in pool.
#ifndef EIGEN_THREAD_LOCAL
    char pad_[128];  // Padding to prevent false sharing.
#endif
  };

  struct ThreadData {
    constexpr ThreadData() : thread(), steal_partition(0), queue() {}
    std::unique_ptr<Thread> thread;
    std::atomic<unsigned> steal_partition;
    Queue queue;
  };
  Environment env_;
  const int num_threads_;
  const bool allow_spinning_;
  MaxSizeVector<ThreadData> thread_data_;
  MaxSizeVector<MaxSizeVector<unsigned>> all_coprimes_;
  MaxSizeVector<EventCount::Waiter> waiters_;
  unsigned global_steal_partition_;
  std::atomic<unsigned> blocked_;
  std::atomic<bool> spinning_;
  std::atomic<bool> done_;
  std::atomic<bool> cancelled_;
  EventCount ec_;
#ifndef EIGEN_THREAD_LOCAL
  std::unique_ptr<Barrier> init_barrier_;
  std::mutex per_thread_map_mutex_;  // Protects per_thread_map_.
  std::unordered_map<uint64_t, std::unique_ptr<PerThread>> per_thread_map_;
#endif
  // Main worker thread loop.
  void WorkerLoop(int thread_id) {
#ifndef EIGEN_THREAD_LOCAL
    std::unique_ptr<PerThread> new_pt(new PerThread());
    per_thread_map_mutex_.lock();
    bool insertOK =
        per_thread_map_.emplace(GlobalThreadIdHash(), std::move(new_pt)).second;
    eigen_plain_assert(insertOK);
    EIGEN_UNUSED_VARIABLE(insertOK);
    per_thread_map_mutex_.unlock();
    init_barrier_->Notify();
    init_barrier_->Wait();
#endif
    PerThread* pt = GetPerThread();
    pt->pool = this;
    pt->rand = GlobalThreadIdHash();
    pt->thread_id = thread_id;
    Queue& q = thread_data_[thread_id].queue;
    EventCount::Waiter* waiter = &waiters_[thread_id];
    // Time spent in NonEmptyQueueIndex() is proportional to num_threads_, so
    // spin for roughly 5000 / num_threads_ iterations before blocking.
    const int spin_count =
        allow_spinning_ && num_threads_ > 0 ? 5000 / num_threads_ : 0;
    if (num_threads_ == 1) {
      // With a single thread there is no point in the expensive steal loop:
      // spin on the local queue, then block.
      while (!cancelled_) {
        Task t = q.PopFront();
        for (int i = 0; i < spin_count && !t.f; i++) {
          if (!cancelled_.load(std::memory_order_relaxed)) {
            t = q.PopFront();
          }
        }
        if (!t.f && !WaitForWork(waiter, &t)) {
          return;
        }
        if (t.f) {
          env_.ExecuteTask(t);
        }
      }
    } else {
      while (!cancelled_) {
        Task t = q.PopFront();
        if (!t.f) t = LocalSteal();
        if (!t.f) t = GlobalSteal();
        if (!t.f) {
          // Leave one thread spinning; this reduces latency.
          if (allow_spinning_ && !spinning_ && !spinning_.exchange(true)) {
            for (int i = 0; i < spin_count && !t.f; i++) {
              if (!cancelled_.load(std::memory_order_relaxed)) {
                t = GlobalSteal();
              } else {
                return;
              }
            }
            spinning_ = false;
          }
          if (!t.f && !WaitForWork(waiter, &t)) {
            return;
          }
        }
        if (t.f) {
          env_.ExecuteTask(t);
        }
      }
    }
  }
  // Steal tries to steal work from other worker threads in the range
  // [start, limit) in a best-effort manner.
  Task Steal(unsigned start, unsigned limit) {
    PerThread* pt = GetPerThread();
    const size_t size = limit - start;
    unsigned r = Rand(&pt->rand);
    // Reduce r into [0, size) using a multiply-and-shift instead of modulo.
    eigen_plain_assert(all_coprimes_[size - 1].size() < (1 << 30));
    unsigned victim = ((uint64_t)r * (uint64_t)size) >> 32;
    unsigned index =
        ((uint64_t)all_coprimes_[size - 1].size() * (uint64_t)r) >> 32;
    unsigned inc = all_coprimes_[size - 1][index];

    for (unsigned i = 0; i < size; i++) {
      eigen_plain_assert(start + victim < limit);
      Task t = thread_data_[start + victim].queue.PopBack();
      if (t.f) {
        return t;
      }
      victim += inc;
      if (victim >= size) {
        victim -= size;
      }
    }
    return Task();
  }
  // Steals work within threads belonging to the partition.
  Task LocalSteal() {
    PerThread* pt = GetPerThread();
    unsigned partition = GetStealPartition(pt->thread_id);
    // If the thread's steal partition is the same as the global partition,
    // there is no need to go through the steal loop twice.
    if (global_steal_partition_ == partition) return Task();
    unsigned start, limit;
    DecodePartition(partition, &start, &limit);
    AssertBounds(start, limit);

    return Steal(start, limit);
  }
  // Steals work from any other thread in the pool.
  Task GlobalSteal() { return Steal(0, num_threads_); }
  // WaitForWork blocks until new work is available (returns true), or if it
  // is time to exit (returns false).
  bool WaitForWork(EventCount::Waiter* waiter, Task* t) {
    eigen_plain_assert(!t->f);
    // We already did a best-effort emptiness check in Steal, so prepare for
    // blocking.
    ec_.Prewait();
    // Now do a reliable emptiness check.
    int victim = NonEmptyQueueIndex();
    if (victim != -1) {
      ec_.CancelWait();
      if (cancelled_) return false;
      *t = thread_data_[victim].queue.PopBack();
      return true;
    }
    // The number of blocked threads is the termination condition: we are done
    // once we are shutting down and every worker is blocked without work.
    blocked_++;
    if (done_ && blocked_ == static_cast<unsigned>(num_threads_)) {
      ec_.CancelWait();
      // Almost done, but we must re-check the queues: work may have been
      // submitted right before done_ was set, and exiting now would leave it
      // unexecuted.
      if (NonEmptyQueueIndex() != -1) {
        blocked_--;
        return true;
      }
      // Reached stable termination state.
      ec_.Notify(true);
      return false;
    }
    ec_.CommitWait(waiter);
    blocked_--;
    return true;
  }
  int NonEmptyQueueIndex() {
    PerThread* pt = GetPerThread();
    // NonEmptyQueueIndex intentionally scans all queues (not just the steal
    // partition) so threads don't block forever in WaitForWork when all
    // threads in their partition go to sleep.
    const size_t size = thread_data_.size();
    unsigned r = Rand(&pt->rand);
    unsigned inc = all_coprimes_[size - 1][r % all_coprimes_[size - 1].size()];
    unsigned victim = r % size;
    for (unsigned i = 0; i < size; i++) {
      if (!thread_data_[victim].queue.Empty()) {
        return victim;
      }
      victim += inc;
      if (victim >= size) {
        victim -= size;
      }
    }
    return -1;
  }
  static EIGEN_STRONG_INLINE uint64_t GlobalThreadIdHash() {
    return std::hash<std::thread::id>()(std::this_thread::get_id());
  }
  EIGEN_STRONG_INLINE PerThread* GetPerThread() {
#ifndef EIGEN_THREAD_LOCAL
    static PerThread dummy;
    auto it = per_thread_map_.find(GlobalThreadIdHash());
    if (it == per_thread_map_.end()) {
      return &dummy;
    } else {
      return it->second.get();
    }
#else
    EIGEN_THREAD_LOCAL PerThread per_thread_;
    PerThread* pt = &per_thread_;
    return pt;
#endif
  }
  static EIGEN_STRONG_INLINE unsigned Rand(uint64_t* state) {
    uint64_t current = *state;
    // Update the internal state (linear congruential step).
    *state = current * 6364136223846793005ULL + 0xda3e39cb94b95bdbULL;
    // Generate the random output (using the PCG-XSH-RS scheme).
    return static_cast<unsigned>((current ^ (current >> 22)) >>
                                 (22 + (current >> 61)));
  }
};
typedef ThreadPoolTempl<StlThreadEnvironment> ThreadPool;
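// Example usage (a minimal sketch; Barrier is the helper from this module's
// Barrier.h, used here only to wait for task completion):
//
//   Eigen::ThreadPool pool(4 /*num_threads*/);
//   Eigen::Barrier barrier(1);
//   pool.Schedule([&barrier]() {
//     // ... do some work ...
//     barrier.Notify();
//   });
//   barrier.Wait();  // Blocks until the scheduled task has run.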
}  // namespace Eigen

#endif  // EIGEN_CXX11_THREADPOOL_NONBLOCKING_THREAD_POOL_H