@@ -70,7 +70,6 @@ bool Task<InItemT, OutItemT>::task_fetch_create(BatchTasks<TaskT>& batchTask) {
70
70
// 每个lod型的fetchvar拷贝到对应的临时空间中
71
71
// 最后再计算临时空间的总量,合并fetchvar和lod
72
72
fetchvar_batch = 0 ;
73
-
74
73
} else {
75
74
// 普通fetchvar情况,此时该Task总的fetchvar_batch =
76
75
// 输入的总的batch_size()
@@ -86,14 +85,15 @@ bool Task<InItemT, OutItemT>::task_fetch_create(BatchTasks<TaskT>& batchTask) {
86
85
// 此时 lod 为空。
87
86
tensor_out.lod = batchTask._batch_out [fetchvar_index].lod ;
88
87
// resize all batch memory at one time
89
-
88
+
90
89
size_t databuf_size = fetchvar_batch * fetchvar_bytesize_index;
91
-
92
- void * databuf_data = MempoolWrapper::instance ().malloc (databuf_size,memoryPtr);
90
+
91
+ void * databuf_data =
92
+ MempoolWrapper::instance ().malloc (databuf_size, memoryPtr);
93
93
paddle::PaddleBuf paddleBuf (databuf_data, databuf_size);
94
94
tensor_out.data = paddleBuf;
95
-
96
- // tensor_out.data.Resize(databuf_size);
95
+
96
+ // tensor_out.data.Resize(databuf_size);
97
97
} else {
98
98
// 当taskmeta_num = 1时,由于同时只有一个taskMeta操作task
99
99
// 不涉及线程安全问题,所以此时可以直接由taskMeta->task->resize->copy
@@ -213,7 +213,8 @@ void TaskExecutor<TaskT>::stop() {
213
213
template <typename TaskT>
214
214
TaskHandler<TaskT> TaskExecutor<TaskT>::schedule(
215
215
const void * inVectorT_ptr,
216
- void * outVectorT_ptr, MempoolRegion* memoryPtr) { // NOLINT
216
+ void * outVectorT_ptr,
217
+ MempoolRegion* memoryPtr) { // NOLINT
217
218
TaskT* task = butil::get_object<TaskT>();
218
219
if (!task) {
219
220
LOG (ERROR) << " Failed get TaskT from object pool" ;
@@ -240,7 +241,7 @@ TaskHandler<TaskT> TaskExecutor<TaskT>::schedule(
240
241
task->write_fd = fds[1 ];
241
242
task->owner_tid = ::syscall (SYS_gettid);
242
243
task->memoryPtr = memoryPtr;
243
- // task->_bspec_key = _bspec_key;
244
+ // task->_bspec_key = _bspec_key;
244
245
task->inVectorT_ptr = (const InVectorT*)inVectorT_ptr;
245
246
task->outVectorT_ptr = (OutVectorT*)outVectorT_ptr;
246
247
if (!task->task_init ()) {
@@ -309,20 +310,23 @@ bool TaskExecutor<TaskT>::move_task_to_batch(
309
310
}
310
311
311
312
// combine_task_valid负责判断是否能够合并
312
- // 除最外层的shape外,内层shape应一致才能合并 。
313
+ // 除最外层的shape外,内层shape应一致或者允许Padding才能合并 。
313
314
// 否则跳出循环,放入下一个batchTask中。
314
315
// 以此保证batch.append_task(task)中的task的内层shape相同。
315
316
316
317
// 对于Shape[0] = 1 而!=batch的情况,因为合并时,取其中一个的值
317
318
// 所以要求该feedvar必须相等,才能合并。
318
319
// 否则跳出循环,放入下一个batchTask中。
319
320
// 目前没有PaddleTensor和PaddleBuff没有重载==,所以只能比较内存.
320
- // TODO(HexToString): 可以考虑后期支持AutoPadding.
321
321
if (previous_task != nullptr ) {
322
- if (! task->combine_task_valid (previous_task)) {
322
+ if (task->combine_task_valid (previous_task) == 0 ) {
323
323
break ;
324
324
}
325
325
}
326
+
327
+ if (batchTask.padding (task) != 2 ) {
328
+ break ;
329
+ }
326
330
size_t rem = batchTask.append_task (task);
327
331
previous_task = task;
328
332
if (task->rem <= 0 ) {
@@ -407,10 +411,11 @@ int TaskExecutor<TaskT>::work(ThreadContext<TaskT>* context) {
407
411
}
408
412
409
413
template <typename InItemT, typename OutItemT>
410
- bool TaskManager<InItemT, OutItemT>::schedule(const void * in,
411
- void * out, MempoolRegion* memoryPtr) { // NOLINT
414
+ bool TaskManager<InItemT, OutItemT>::schedule(
415
+ const void * in, void * out, MempoolRegion* memoryPtr) { // NOLINT
412
416
TaskHandler<TaskT> handler =
413
- TaskExecutorVector<TaskT>::instance ()[_model_index].schedule (in, out, memoryPtr);
417
+ TaskExecutorVector<TaskT>::instance ()[_model_index].schedule (
418
+ in, out, memoryPtr);
414
419
415
420
if (handler.valid ()) {
416
421
_task_owned = handler;
0 commit comments