Skip to content

Commit a9e9d01

Browse files
committed
add test for tanh vjp
1 parent d49d38a commit a9e9d01

File tree

2 files changed

+40
-3
lines changed

2 files changed

+40
-3
lines changed

paddle/fluid/ir/dialect/vjp_interface.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ std::vector<ir::OpResult> TanhOp::Vjp(ir::Operation* op,
3232
std::vector<ir::OpResult> res;
3333
res.reserve(tensor_res.size());
3434
// TODO(wanghao107): maybe combile here
35-
for (int i = 0; i < tensor_res.size(); ++i) {
35+
for (size_t i = 0; i < tensor_res.size(); ++i) {
3636
res.emplace_back(
3737
std::static_pointer_cast<primitive::experimental::DescTensor>(
3838
tensor_res[i][0].impl())

test/cpp/prim/test_vjp.cc

+39-2
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,14 @@
1414

1515
#include <gtest/gtest.h>
1616

17+
#include "paddle/fluid/framework/new_executor/new_ir_interpreter.h"
1718
#include "paddle/fluid/framework/new_executor/standalone_executor.h"
1819
#include "paddle/fluid/ir/dialect/pd_dialect.h"
1920
#include "paddle/fluid/ir/dialect/pd_op.h"
2021
#include "paddle/fluid/ir/dialect/pd_type.h"
2122
#include "paddle/fluid/ir/dialect/utils.h"
2223
#include "paddle/fluid/ir/interface/op_yaml_info.h"
24+
#include "paddle/fluid/ir/transforms/pd_op_to_kernel_pass.h"
2325
#include "paddle/fluid/platform/init_phi.h"
2426
#include "paddle/ir/core/block.h"
2527
#include "paddle/ir/core/builtin_attribute.h"
@@ -28,14 +30,14 @@
2830
#include "paddle/ir/core/ir_context.h"
2931
#include "paddle/ir/core/program.h"
3032
#include "paddle/ir/core/utils.h"
31-
#include "paddle/phi/core/meta_tensor.h"
32-
#include "paddle/phi/infermeta/binary.h"
3333

3434
DECLARE_FILE_SYMBOLS(kernel_dialect);
3535

3636
PD_DECLARE_KERNEL(full, CPU, ALL_LAYOUT);
3737
PD_DECLARE_KERNEL(tanh, CPU, ALL_LAYOUT);
3838
PD_DECLARE_KERNEL(tanh_grad, CPU, ALL_LAYOUT);
39+
namespace paddle {
40+
namespace framework {
3941

4042
TEST(VJP, TanhBackwardTest) {
4143
ir::IrContext* ctx = ir::IrContext::Instance();
@@ -60,4 +62,39 @@ TEST(VJP, TanhBackwardTest) {
6062
std::vector<ir::OpResult> out_grads{op3.out()};
6163
std::vector<ir::OpResult> grad_res =
6264
tanh_vjp_interface.Vjp(op2.operation(), out_grads, stop_gradients);
65+
66+
std::ostringstream print_stream;
67+
program.Print(print_stream);
68+
std::cout << print_stream.str();
69+
70+
auto kernel_program = paddle::dialect::PdOpLowerToKernelPass(&program);
71+
72+
auto place = platform::CPUPlace();
73+
Scope scope;
74+
75+
ProgramDesc prog_desc;
76+
InterpreterCore test_core(place, std::move(kernel_program), &scope);
77+
test_core.BetaRun({});
78+
std::stringstream os;
79+
os << reinterpret_cast<NewIRInterpreter*>(
80+
const_cast<InterpreterBaseImpl*>(test_core.Impl()));
81+
std::string prefix_str = os.str();
82+
auto out_tensor =
83+
test_core.local_scope() == nullptr
84+
? scope.FindVar(prefix_str + "_inner_var_1")->Get<phi::DenseTensor>()
85+
: test_core.local_scope()
86+
->FindVar(prefix_str + "_inner_var_1")
87+
->Get<phi::DenseTensor>();
88+
auto grad_out_tensor =
89+
test_core.local_scope() == nullptr
90+
? scope.FindVar(prefix_str + "_inner_var_3")->Get<phi::DenseTensor>()
91+
: test_core.local_scope()
92+
->FindVar(prefix_str + "_inner_var_3")
93+
->Get<phi::DenseTensor>();
94+
95+
ASSERT_NEAR(out_tensor.data<float>()[0], 0.76159, 1e-5);
96+
ASSERT_NEAR(grad_out_tensor.data<float>()[0], 0.83995, 1e-5);
6397
}
98+
99+
} // namespace framework
100+
} // namespace paddle

0 commit comments

Comments
 (0)