Skip to content

Commit c5b2ad6

Browse files
authored
[PIR]Open uts for scatter (#60694)
1 parent db4ba94 commit c5b2ad6

File tree

2 files changed

+32
-21
lines changed

2 files changed

+32
-21
lines changed

paddle/phi/kernels/impl/data_impl.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,11 @@ template <typename T, typename Context>
2727
void ShadowFeedKernel(const Context& ctx,
2828
const DenseTensor& x,
2929
DenseTensor* out) {
30-
ctx.template Alloc<T>(out);
3130
if (!x.initialized()) {
31+
ctx.template Alloc<T>(out);
3232
return;
3333
}
34-
if (x.place() == out->place()) {
34+
if (x.place() == ctx.GetPlace()) {
3535
out->ShareDataWith(x);
3636
out->set_lod(x.lod());
3737
} else {

test/legacy_test/test_scatter_op.py

Lines changed: 30 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -711,25 +711,36 @@ def test_static_graph():
711711
with paddle.static.program_guard(
712712
paddle.static.Program(), paddle.static.Program()
713713
):
714-
x_t = paddle.static.data(name="x", dtype=x.dtype, shape=x.shape)
715-
index_t = paddle.static.data(
716-
name="index", dtype=index.dtype, shape=index.shape
717-
)
718-
updates_t = paddle.static.data(
719-
name="updates", dtype=updates.dtype, shape=updates.shape
720-
)
721-
out_t = paddle.scatter(x_t, index_t, updates_t)
722-
feed = {
723-
x_t.name: x,
724-
index_t.name: index,
725-
updates_t.name: updates,
726-
}
727-
fetch = [out_t]
728-
gpu_exe = paddle.static.Executor(paddle.CUDAPlace(0))
729-
gpu_value = gpu_exe.run(feed=feed, fetch_list=fetch)[0]
730-
return gpu_value
731-
732-
np.testing.assert_array_equal(test_dygraph(), test_static_graph())
714+
scope = paddle.static.Scope()
715+
with paddle.static.scope_guard(scope):
716+
x_t = paddle.static.data(
717+
name="x", dtype=x.dtype, shape=x.shape
718+
)
719+
index_t = paddle.static.data(
720+
name="index", dtype=index.dtype, shape=index.shape
721+
)
722+
updates_t = paddle.static.data(
723+
name="updates", dtype=updates.dtype, shape=updates.shape
724+
)
725+
out_t = paddle.scatter(x_t, index_t, updates_t)
726+
feed = {
727+
x_t.name: x,
728+
index_t.name: index,
729+
updates_t.name: updates,
730+
}
731+
fetch = [out_t]
732+
gpu_exe = paddle.static.Executor(paddle.CUDAPlace(0))
733+
gpu_value = gpu_exe.run(feed=feed, fetch_list=fetch)[0]
734+
scope._remove_from_pool()
735+
return gpu_value
736+
737+
def test_pir_static_graph():
738+
with paddle.pir_utils.IrGuard():
739+
return test_static_graph()
740+
741+
dy_out = test_dygraph()
742+
np.testing.assert_array_equal(dy_out, test_static_graph())
743+
np.testing.assert_array_equal(dy_out, test_pir_static_graph())
733744

734745

735746
@unittest.skipIf(

0 commit comments

Comments
 (0)