@@ -711,25 +711,36 @@ def test_static_graph():
711
711
with paddle .static .program_guard (
712
712
paddle .static .Program (), paddle .static .Program ()
713
713
):
714
- x_t = paddle .static .data (name = "x" , dtype = x .dtype , shape = x .shape )
715
- index_t = paddle .static .data (
716
- name = "index" , dtype = index .dtype , shape = index .shape
717
- )
718
- updates_t = paddle .static .data (
719
- name = "updates" , dtype = updates .dtype , shape = updates .shape
720
- )
721
- out_t = paddle .scatter (x_t , index_t , updates_t )
722
- feed = {
723
- x_t .name : x ,
724
- index_t .name : index ,
725
- updates_t .name : updates ,
726
- }
727
- fetch = [out_t ]
728
- gpu_exe = paddle .static .Executor (paddle .CUDAPlace (0 ))
729
- gpu_value = gpu_exe .run (feed = feed , fetch_list = fetch )[0 ]
730
- return gpu_value
731
-
732
- np .testing .assert_array_equal (test_dygraph (), test_static_graph ())
714
+ scope = paddle .static .Scope ()
715
+ with paddle .static .scope_guard (scope ):
716
+ x_t = paddle .static .data (
717
+ name = "x" , dtype = x .dtype , shape = x .shape
718
+ )
719
+ index_t = paddle .static .data (
720
+ name = "index" , dtype = index .dtype , shape = index .shape
721
+ )
722
+ updates_t = paddle .static .data (
723
+ name = "updates" , dtype = updates .dtype , shape = updates .shape
724
+ )
725
+ out_t = paddle .scatter (x_t , index_t , updates_t )
726
+ feed = {
727
+ x_t .name : x ,
728
+ index_t .name : index ,
729
+ updates_t .name : updates ,
730
+ }
731
+ fetch = [out_t ]
732
+ gpu_exe = paddle .static .Executor (paddle .CUDAPlace (0 ))
733
+ gpu_value = gpu_exe .run (feed = feed , fetch_list = fetch )[0 ]
734
+ scope ._remove_from_pool ()
735
+ return gpu_value
736
+
737
+ def test_pir_static_graph ():
738
+ with paddle .pir_utils .IrGuard ():
739
+ return test_static_graph ()
740
+
741
+ dy_out = test_dygraph ()
742
+ np .testing .assert_array_equal (dy_out , test_static_graph ())
743
+ np .testing .assert_array_equal (dy_out , test_pir_static_graph ())
733
744
734
745
735
746
@unittest .skipIf (
0 commit comments