Skip to content

Commit d49d38a

Browse files
committed
add vjp test for new ir
1 parent 0738201 commit d49d38a

File tree

3 files changed

+72
-1
lines changed

3 files changed

+72
-1
lines changed

paddle/fluid/ir/dialect/ir_api.cc

-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
14-
#pragma once
1514

1615
#include "paddle/fluid/ir/dialect/ir_api.h"
1716
#include "paddle/fluid/ir/dialect/pd_dialect.h"

test/cpp/prim/CMakeLists.txt

+9
Original file line numberDiff line numberDiff line change
@@ -61,3 +61,12 @@ if(NOT (NOT WITH_PYTHON AND ON_INFER))
6161
init_env_utils
6262
python)
6363
endif()
64+
65+
# skip win32 since wget is not installed by default on windows machine.
66+
67+
if(NOT WIN32)
68+
cc_test(
69+
test_vjp_new_ir
70+
SRCS test_vjp.cc
71+
DEPS phi_kernel_adaptor pd_dialect ir)
72+
endif()

test/cpp/prim/test_vjp.cc

+63
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include <gtest/gtest.h>
16+
17+
#include "paddle/fluid/framework/new_executor/standalone_executor.h"
18+
#include "paddle/fluid/ir/dialect/pd_dialect.h"
19+
#include "paddle/fluid/ir/dialect/pd_op.h"
20+
#include "paddle/fluid/ir/dialect/pd_type.h"
21+
#include "paddle/fluid/ir/dialect/utils.h"
22+
#include "paddle/fluid/ir/interface/op_yaml_info.h"
23+
#include "paddle/fluid/platform/init_phi.h"
24+
#include "paddle/ir/core/block.h"
25+
#include "paddle/ir/core/builtin_attribute.h"
26+
#include "paddle/ir/core/builtin_dialect.h"
27+
#include "paddle/ir/core/builtin_op.h"
28+
#include "paddle/ir/core/ir_context.h"
29+
#include "paddle/ir/core/program.h"
30+
#include "paddle/ir/core/utils.h"
31+
#include "paddle/phi/core/meta_tensor.h"
32+
#include "paddle/phi/infermeta/binary.h"
33+
34+
DECLARE_FILE_SYMBOLS(kernel_dialect);
35+
36+
PD_DECLARE_KERNEL(full, CPU, ALL_LAYOUT);
37+
PD_DECLARE_KERNEL(tanh, CPU, ALL_LAYOUT);
38+
PD_DECLARE_KERNEL(tanh_grad, CPU, ALL_LAYOUT);
39+
40+
TEST(VJP, TanhBackwardTest) {
41+
ir::IrContext* ctx = ir::IrContext::Instance();
42+
ir::Program program((ctx));
43+
ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
44+
45+
ir::Builder builder = ir::Builder(ctx, program.block());
46+
47+
paddle::dialect::FullOp op1 = builder.Build<paddle::dialect::FullOp>(
48+
std::vector<int64_t>{1}, 1.0, phi::DataType::FLOAT32, phi::CPUPlace());
49+
50+
paddle::dialect::TanhOp op2 =
51+
builder.Build<paddle::dialect::TanhOp>(op1.out());
52+
53+
paddle::dialect::FullOp op3 = builder.Build<paddle::dialect::FullOp>(
54+
std::vector<int64_t>{1}, 2.0, phi::DataType::FLOAT32, phi::CPUPlace());
55+
56+
paddle::dialect::VjpInterface tanh_vjp_interface =
57+
op2->dyn_cast<paddle::dialect::VjpInterface>();
58+
59+
std::vector<int> stop_gradients{0};
60+
std::vector<ir::OpResult> out_grads{op3.out()};
61+
std::vector<ir::OpResult> grad_res =
62+
tanh_vjp_interface.Vjp(op2.operation(), out_grads, stop_gradients);
63+
}

0 commit comments

Comments
 (0)