Skip to content

Commit 14ce793

Browse files
committed
WorkQueue update
1 parent 4759bc8 commit 14ce793

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

paddle/fluid/framework/new_executor/nonblocking_threadpool.h

+10 -4
Original file line number | Diff line number | Diff line change
@@ -73,6 +73,7 @@ class ThreadPoolTempl {
7373
allow_spinning_(allow_spinning),
7474
global_steal_partition_(EncodePartition(0, num_threads_)),
7575
blocked_(0),
76+
num_tasks_(0),
7677
spinning_(0),
7778
done_(false),
7879
cancelled_(false),
@@ -143,6 +144,7 @@ class ThreadPoolTempl {
143144
void AddTaskWithHint(std::function<void()> fn, int start, int limit) {
144145
Task t = env_.CreateTask(std::move(fn));
145146
PerThread* pt = GetPerThread();
147+
uint64_t num_tasks = num_tasks_.fetch_add(1, std::memory_order_relaxed) + 1;
146148
if (pt->pool == this) {
147149
// Worker thread of this pool, push onto the thread's queue.
148150
Queue& q = thread_data_[pt->thread_id].queue;
@@ -166,8 +168,11 @@ class ThreadPoolTempl {
166168
// this. We expect that such scenario is prevented by program, that is,
167169
// this is kept alive while any threads can potentially be in Schedule.
168170
if (!t.f) {
169-
ec_.Notify(false);
171+
if (num_tasks > num_threads_ - blocked_.load(std::memory_order_relaxed)) {
172+
ec_.Notify(false);
173+
}
170174
} else {
175+
num_tasks_.fetch_sub(1, std::memory_order_relaxed);
171176
env_.ExecuteTask(t); // Push failed, execute directly.
172177
}
173178
}
@@ -263,6 +268,7 @@ class ThreadPoolTempl {
263268
std::vector<std::vector<unsigned>> all_coprimes_;
264269
unsigned global_steal_partition_;
265270
std::atomic<unsigned> blocked_;
271+
std::atomic<uint64_t> num_tasks_;
266272
std::atomic<bool> spinning_;
267273
std::atomic<bool> done_;
268274
std::atomic<bool> cancelled_;
@@ -305,6 +311,7 @@ class ThreadPoolTempl {
305311
}
306312
if (t.f) {
307313
env_.ExecuteTask(t);
314+
num_tasks_.fetch_sub(1, std::memory_order_relaxed);
308315
}
309316
}
310317
} else {
@@ -315,16 +322,14 @@ class ThreadPoolTempl {
315322
if (!t.f) {
316323
t = GlobalSteal();
317324
if (!t.f) {
318-
// Leave one thread spinning. This reduces latency.
319-
if (allow_spinning_ && !spinning_ && !spinning_.exchange(true)) {
325+
if (allow_spinning_) {
320326
for (int i = 0; i < spin_count && !t.f; i++) {
321327
if (!cancelled_.load(std::memory_order_relaxed)) {
322328
t = GlobalSteal();
323329
} else {
324330
return;
325331
}
326332
}
327-
spinning_ = false;
328333
}
329334
if (!t.f) {
330335
if (!WaitForWork(waiter, &t)) {
@@ -336,6 +341,7 @@ class ThreadPoolTempl {
336341
}
337342
if (t.f) {
338343
env_.ExecuteTask(t);
344+
num_tasks_.fetch_sub(1, std::memory_order_relaxed);
339345
}
340346
}
341347
}

0 commit comments

Comments
 (0)