@@ -169,6 +169,7 @@ void Copy(const Context& dev_ctx,
169
169
auto src_gpu_place = src_place;
170
170
auto dst_gpu_place = dst_place;
171
171
auto ctx_place = dev_ctx.GetPlace ();
172
+
172
173
PADDLE_ENFORCE_EQ (
173
174
ctx_place.GetType () == AllocationType::GPU,
174
175
true ,
@@ -178,18 +179,19 @@ void Copy(const Context& dev_ctx,
178
179
auto stream =
179
180
blocking ? nullptr
180
181
: reinterpret_cast <const phi::GPUContext&>(dev_ctx).stream ();
181
- if (src_place.GetType () == dst_place.GetType ()) {
182
+ if (src_place.GetDeviceId () == dst_place.GetDeviceId ()) {
182
183
memory_utils::Copy (
183
184
dst_gpu_place, dst_ptr, src_gpu_place, src_ptr, size, stream);
184
185
} else {
185
- if (ctx_place.GetType () == src_place.GetType ()) {
186
+ if (ctx_place.GetDeviceId () == src_place.GetDeviceId ()) {
186
187
memory_utils::Copy (
187
188
dst_gpu_place, dst_ptr, src_gpu_place, src_ptr, size, stream);
188
189
phi::DeviceContextPool::Instance ().Get (src.place ())->Wait ();
189
- } else if (ctx_place.GetType () == dst_place.GetType ()) {
190
+ } else if (ctx_place.GetDeviceId () == dst_place.GetDeviceId ()) {
190
191
phi::DeviceContextPool::Instance ().Get (src.place ())->Wait ();
191
192
memory_utils::Copy (
192
193
dst_gpu_place, dst_ptr, src_gpu_place, src_ptr, size, stream);
194
+ phi::DeviceContextPool::Instance ().Get (dst_place)->Wait ();
193
195
} else {
194
196
PADDLE_THROW (errors::Unavailable (
195
197
" Context place dose not match the source and destination place." ));
0 commit comments