14
14
15
15
#include < gtest/gtest.h>
16
16
17
+ #include " paddle/fluid/framework/new_executor/new_ir_interpreter.h"
17
18
#include " paddle/fluid/framework/new_executor/standalone_executor.h"
18
19
#include " paddle/fluid/ir/dialect/pd_dialect.h"
19
20
#include " paddle/fluid/ir/dialect/pd_op.h"
20
21
#include " paddle/fluid/ir/dialect/pd_type.h"
21
22
#include " paddle/fluid/ir/dialect/utils.h"
22
23
#include " paddle/fluid/ir/interface/op_yaml_info.h"
24
+ #include " paddle/fluid/ir/transforms/pd_op_to_kernel_pass.h"
23
25
#include " paddle/fluid/platform/init_phi.h"
24
26
#include " paddle/ir/core/block.h"
25
27
#include " paddle/ir/core/builtin_attribute.h"
28
30
#include " paddle/ir/core/ir_context.h"
29
31
#include " paddle/ir/core/program.h"
30
32
#include " paddle/ir/core/utils.h"
31
- #include " paddle/phi/core/meta_tensor.h"
32
- #include " paddle/phi/infermeta/binary.h"
33
33
34
34
DECLARE_FILE_SYMBOLS (kernel_dialect);
35
35
36
36
PD_DECLARE_KERNEL (full, CPU, ALL_LAYOUT);
37
37
PD_DECLARE_KERNEL (tanh, CPU, ALL_LAYOUT);
38
38
PD_DECLARE_KERNEL (tanh_grad, CPU, ALL_LAYOUT);
39
+ namespace paddle {
40
+ namespace framework {
39
41
40
42
TEST (VJP, TanhBackwardTest) {
41
43
ir::IrContext* ctx = ir::IrContext::Instance ();
@@ -60,4 +62,39 @@ TEST(VJP, TanhBackwardTest) {
60
62
std::vector<ir::OpResult> out_grads{op3.out ()};
61
63
std::vector<ir::OpResult> grad_res =
62
64
tanh_vjp_interface.Vjp (op2.operation (), out_grads, stop_gradients);
65
+
66
+ std::ostringstream print_stream;
67
+ program.Print (print_stream);
68
+ std::cout << print_stream.str ();
69
+
70
+ auto kernel_program = paddle::dialect::PdOpLowerToKernelPass (&program);
71
+
72
+ auto place = platform::CPUPlace ();
73
+ Scope scope;
74
+
75
+ ProgramDesc prog_desc;
76
+ InterpreterCore test_core (place, std::move (kernel_program), &scope);
77
+ test_core.BetaRun ({});
78
+ std::stringstream os;
79
+ os << reinterpret_cast <NewIRInterpreter*>(
80
+ const_cast <InterpreterBaseImpl*>(test_core.Impl ()));
81
+ std::string prefix_str = os.str ();
82
+ auto out_tensor =
83
+ test_core.local_scope () == nullptr
84
+ ? scope.FindVar (prefix_str + " _inner_var_1" )->Get <phi::DenseTensor>()
85
+ : test_core.local_scope ()
86
+ ->FindVar (prefix_str + " _inner_var_1" )
87
+ ->Get <phi::DenseTensor>();
88
+ auto grad_out_tensor =
89
+ test_core.local_scope () == nullptr
90
+ ? scope.FindVar (prefix_str + " _inner_var_3" )->Get <phi::DenseTensor>()
91
+ : test_core.local_scope ()
92
+ ->FindVar (prefix_str + " _inner_var_3" )
93
+ ->Get <phi::DenseTensor>();
94
+
95
+ ASSERT_NEAR (out_tensor.data <float >()[0 ], 0.76159 , 1e-5 );
96
+ ASSERT_NEAR (grad_out_tensor.data <float >()[0 ], 0.83995 , 1e-5 );
63
97
}
98
+
99
+ } // namespace framework
100
+ } // namespace paddle
0 commit comments