Skip to content

Commit 75f59f4

Browse files
authored
Merge pull request #1713 from HexToString/add_2_lod_and_padding_new
Add 2 lod and padding new
2 parents 67b17ec + 3d9b462 commit 75f59f4

File tree

2 files changed

+555
-81
lines changed

2 files changed

+555
-81
lines changed

core/predictor/framework/bsf-inl.h

+19-14
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,6 @@ bool Task<InItemT, OutItemT>::task_fetch_create(BatchTasks<TaskT>& batchTask) {
7070
// 每个lod型的fetchvar拷贝到对应的临时空间中
7171
// 最后再计算临时空间的总量,合并fetchvar和lod
7272
fetchvar_batch = 0;
73-
7473
} else {
7574
// 普通fetchvar情况,此时该Task总的fetchvar_batch =
7675
// 输入的总的batch_size()
@@ -86,14 +85,15 @@ bool Task<InItemT, OutItemT>::task_fetch_create(BatchTasks<TaskT>& batchTask) {
8685
// 此时 lod 为空。
8786
tensor_out.lod = batchTask._batch_out[fetchvar_index].lod;
8887
// resize all batch memory at one time
89-
88+
9089
size_t databuf_size = fetchvar_batch * fetchvar_bytesize_index;
91-
92-
void* databuf_data = MempoolWrapper::instance().malloc(databuf_size,memoryPtr);
90+
91+
void* databuf_data =
92+
MempoolWrapper::instance().malloc(databuf_size, memoryPtr);
9393
paddle::PaddleBuf paddleBuf(databuf_data, databuf_size);
9494
tensor_out.data = paddleBuf;
95-
96-
//tensor_out.data.Resize(databuf_size);
95+
96+
// tensor_out.data.Resize(databuf_size);
9797
} else {
9898
// 当taskmeta_num = 1时,由于同时只有一个taskMeta操作task
9999
// 不涉及线程安全问题,所以此时可以直接由taskMeta->task->resize->copy
@@ -213,7 +213,8 @@ void TaskExecutor<TaskT>::stop() {
213213
template <typename TaskT>
214214
TaskHandler<TaskT> TaskExecutor<TaskT>::schedule(
215215
const void* inVectorT_ptr,
216-
void* outVectorT_ptr, MempoolRegion* memoryPtr) { // NOLINT
216+
void* outVectorT_ptr,
217+
MempoolRegion* memoryPtr) { // NOLINT
217218
TaskT* task = butil::get_object<TaskT>();
218219
if (!task) {
219220
LOG(ERROR) << "Failed get TaskT from object pool";
@@ -240,7 +241,7 @@ TaskHandler<TaskT> TaskExecutor<TaskT>::schedule(
240241
task->write_fd = fds[1];
241242
task->owner_tid = ::syscall(SYS_gettid);
242243
task->memoryPtr = memoryPtr;
243-
//task->_bspec_key = _bspec_key;
244+
// task->_bspec_key = _bspec_key;
244245
task->inVectorT_ptr = (const InVectorT*)inVectorT_ptr;
245246
task->outVectorT_ptr = (OutVectorT*)outVectorT_ptr;
246247
if (!task->task_init()) {
@@ -309,20 +310,23 @@ bool TaskExecutor<TaskT>::move_task_to_batch(
309310
}
310311

311312
// combine_task_valid负责判断是否能够合并
312-
// 除最外层的shape外,内层shape应一致才能合并
313+
// 除最外层的shape外,内层shape应一致或者允许Padding才能合并
313314
// 否则跳出循环,放入下一个batchTask中。
314315
// 以此保证batch.append_task(task)中的task的内层shape相同。
315316

316317
// 对于Shape[0] = 1 而!=batch的情况,因为合并时,取其中一个的值
317318
// 所以要求该feedvar必须相等,才能合并。
318319
// 否则跳出循环,放入下一个batchTask中。
319320
// 目前PaddleTensor和PaddleBuf都没有重载==,所以只能比较内存.
320-
// TODO(HexToString): 可以考虑后期支持AutoPadding.
321321
if (previous_task != nullptr) {
322-
if (!task->combine_task_valid(previous_task)) {
322+
if (task->combine_task_valid(previous_task) == 0) {
323323
break;
324324
}
325325
}
326+
327+
if (batchTask.padding(task) != 2) {
328+
break;
329+
}
326330
size_t rem = batchTask.append_task(task);
327331
previous_task = task;
328332
if (task->rem <= 0) {
@@ -407,10 +411,11 @@ int TaskExecutor<TaskT>::work(ThreadContext<TaskT>* context) {
407411
}
408412

409413
template <typename InItemT, typename OutItemT>
410-
bool TaskManager<InItemT, OutItemT>::schedule(const void* in,
411-
void* out, MempoolRegion* memoryPtr) { // NOLINT
414+
bool TaskManager<InItemT, OutItemT>::schedule(
415+
const void* in, void* out, MempoolRegion* memoryPtr) { // NOLINT
412416
TaskHandler<TaskT> handler =
413-
TaskExecutorVector<TaskT>::instance()[_model_index].schedule(in, out, memoryPtr);
417+
TaskExecutorVector<TaskT>::instance()[_model_index].schedule(
418+
in, out, memoryPtr);
414419

415420
if (handler.valid()) {
416421
_task_owned = handler;

0 commit comments

Comments
 (0)