Skip to content

Commit b94ea21

Browse files
Fix GPU-to-GPU copy bug (#72792)
1 parent 4a03621 commit b94ea21

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

paddle/phi/core/tensor_utils.cc

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -169,6 +169,7 @@ void Copy(const Context& dev_ctx,
169169
auto src_gpu_place = src_place;
170170
auto dst_gpu_place = dst_place;
171171
auto ctx_place = dev_ctx.GetPlace();
172+
172173
PADDLE_ENFORCE_EQ(
173174
ctx_place.GetType() == AllocationType::GPU,
174175
true,
@@ -178,18 +179,19 @@ void Copy(const Context& dev_ctx,
178179
auto stream =
179180
blocking ? nullptr
180181
: reinterpret_cast<const phi::GPUContext&>(dev_ctx).stream();
181-
if (src_place.GetType() == dst_place.GetType()) {
182+
if (src_place.GetDeviceId() == dst_place.GetDeviceId()) {
182183
memory_utils::Copy(
183184
dst_gpu_place, dst_ptr, src_gpu_place, src_ptr, size, stream);
184185
} else {
185-
if (ctx_place.GetType() == src_place.GetType()) {
186+
if (ctx_place.GetDeviceId() == src_place.GetDeviceId()) {
186187
memory_utils::Copy(
187188
dst_gpu_place, dst_ptr, src_gpu_place, src_ptr, size, stream);
188189
phi::DeviceContextPool::Instance().Get(src.place())->Wait();
189-
} else if (ctx_place.GetType() == dst_place.GetType()) {
190+
} else if (ctx_place.GetDeviceId() == dst_place.GetDeviceId()) {
190191
phi::DeviceContextPool::Instance().Get(src.place())->Wait();
191192
memory_utils::Copy(
192193
dst_gpu_place, dst_ptr, src_gpu_place, src_ptr, size, stream);
194+
phi::DeviceContextPool::Instance().Get(dst_place)->Wait();
193195
} else {
194196
PADDLE_THROW(errors::Unavailable(
195197
"Context place dose not match the source and destination place."));

0 commit comments

Comments
 (0)