Skip to content

Commit 147d767

Browse files
authored
Fix (#64092)
1 parent 774f3e3 commit 147d767

File tree

4 files changed

+28
-24
lines changed

4 files changed

+28
-24
lines changed

paddle/fluid/operators/math/beam_search.cc

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -69,15 +69,15 @@ class BeamSearchFunctor<phi::CPUContext, T> {
6969
// the output tensor shape should be [num_instances, 1]
7070
auto dims = common::make_ddim(
7171
std::vector<int64_t>({static_cast<int>(num_instances), 1}));
72-
auto *selected_ids_data =
73-
selected_ids->mutable_data<int64_t>(dims, platform::CPUPlace());
74-
auto *selected_scores_data =
75-
selected_scores->mutable_data<float>(dims, platform::CPUPlace());
72+
selected_ids->Resize(dims);
73+
auto *selected_ids_data = context.template Alloc<int64_t>(selected_ids);
74+
selected_scores->Resize(dims);
75+
auto *selected_scores_data = context.template Alloc<float>(selected_scores);
76+
if (parent_idx != nullptr) {
77+
parent_idx->Resize({static_cast<int64_t>(num_instances)});
78+
}
7679
auto *parent_idx_data =
77-
parent_idx
78-
? parent_idx->mutable_data<int>(
79-
{static_cast<int64_t>(num_instances)}, platform::CPUPlace())
80-
: nullptr;
80+
parent_idx ? context.template Alloc<int>(parent_idx) : nullptr;
8181

8282
// fill in data
8383
std::vector<size_t> low_level;

paddle/fluid/operators/math/beam_search.cu

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -433,17 +433,18 @@ class BeamSearchFunctor<phi::GPUContext, T> {
433433
// Reserve a big enough memory.
434434
auto selected_dims =
435435
common::make_ddim({static_cast<int64_t>(num_seqs * beam_size), 1});
436-
int64_t* selected_ids_data =
437-
selected_ids->mutable_data<int64_t>(selected_dims, context.GetPlace());
436+
selected_ids->Resize(selected_dims);
437+
int64_t* selected_ids_data = context.template Alloc<int64_t>(selected_ids);
438+
selected_scores->Resize(selected_dims);
438439
float* selected_scores_data =
439-
selected_scores->mutable_data<float>(selected_dims, context.GetPlace());
440+
context.template Alloc<float>(selected_scores);
441+
if (parent_idx != nullptr) {
442+
parent_idx->Resize({static_cast<int64_t>(num_seqs * beam_size)});
443+
}
440444
int* parent_idx_data =
441-
parent_idx ? parent_idx->mutable_data<int>(
442-
{static_cast<int64_t>(num_seqs * beam_size)},
443-
context.GetPlace())
444-
: nullptr;
445+
parent_idx ? context.template Alloc<int>(parent_idx) : nullptr;
445446

446-
framework::LoD selected_lod(2);
447+
phi::LoD selected_lod(2);
447448
selected_lod[0].assign(abs_lod[level].begin(), abs_lod[level].end());
448449
selected_lod[1].resize(scores->dims()[0] + 1);
449450
phi::MixVector<size_t> mix_vector(&selected_lod[1]);

paddle/fluid/operators/math/beam_search_xpu.cc

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -94,15 +94,15 @@ class BeamSearchFunctor<platform::XPUDeviceContext, T> {
9494
// the output tensor shape should be [num_instances, 1]
9595
auto dims = common::make_ddim(
9696
std::vector<int64_t>({static_cast<int>(num_instances), 1}));
97-
auto *selected_ids_data =
98-
selected_ids->mutable_data<int64_t>(dims, platform::CPUPlace());
99-
auto *selected_scores_data =
100-
selected_scores->mutable_data<float>(dims, platform::CPUPlace());
97+
selected_ids->Resize(dims);
98+
auto *selected_ids_data = context.template Alloc<int64_t>(selected_ids);
99+
selected_scores->Resize(dims);
100+
auto *selected_scores_data = context.template Alloc<float>(selected_scores);
101+
if (parent_idx != nullptr) {
102+
parent_idx->Resize({static_cast<int64_t>(num_instances)});
103+
}
101104
auto *parent_idx_data =
102-
parent_idx
103-
? parent_idx->mutable_data<int>(
104-
{static_cast<int64_t>(num_instances)}, platform::CPUPlace())
105-
: nullptr;
105+
parent_idx ? context.template Alloc<int>(parent_idx) : nullptr;
106106

107107
// fill in data
108108
std::vector<size_t> low_level;

test/cpp/fluid/math/beam_search_test.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,9 @@ void TestBeamSearch() {
7373

7474
auto* place = new Place();
7575
DeviceContext* context = new DeviceContext(*place);
76+
context->SetAllocator(paddle::memory::allocation::AllocatorFacade::Instance()
77+
.GetAllocator(phi::CPUPlace())
78+
.get());
7679
if (paddle::platform::is_cpu_place(*place)) {
7780
PrepareCPUTensors(&ids, &scores, &pre_ids, &pre_scores);
7881
} else {

0 commit comments

Comments
 (0)