
Commit f62047d ("conflict")
Merge commit with 2 parents: 9f3fc89 + 604b7a5

File tree: 18 files changed (+500, -17 lines)

paddle/fluid/eager/auto_code_generator/generate_file_structures.py

Lines changed: 4 additions & 2 deletions
@@ -143,7 +143,9 @@ def GenerateFileStructureForIntermediateDygraph(eager_dir, split_count):
         for i in range(split_count):
             f.write("nodes" + str(i + 1) + ".cc ")
         f.write("${fluid_manual_nodes} DEPS ${eager_deps} ${fluid_deps})\n")
-        f.write("add_dependencies(dygraph_node copy_dygraph_node)\n")
+        f.write(
+            "add_dependencies(dygraph_node copy_dygraph_node copy_dygraph_forward_functions)\n"
+        )
 
     with open(forwards_level_cmakelist_path, "w") as f:
         f.write("add_custom_target(\n")
@@ -181,7 +183,7 @@ def GenerateFileStructureForIntermediateDygraph(eager_dir, split_count):
             "${fluid_manual_functions} DEPS ${eager_deps} ${fluid_deps} ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS})\n"
         )
         f.write(
-            "add_dependencies(dygraph_function copy_dygraph_forward_functions)\n"
+            "add_dependencies(dygraph_function copy_dygraph_forward_functions copy_dygraph_node)\n"
        )
 
     with open(generated_level_cmakelist_path, "w") as f:

paddle/fluid/prim/api/api.yaml

Lines changed: 7 additions & 0 deletions
@@ -2,6 +2,12 @@
 - subtract
 - multiply
 - divide
+- less_equal
+- less_than
+- equal
+- not_equal
+- greater_equal
+- greater_than
 - bitwise_and
 - bitwise_not
 - bitwise_or
@@ -35,3 +41,4 @@
 - less_equal
 - sin
 - cos
+- where

paddle/fluid/prim/api/composite_backward/composite_backward_api.h

Lines changed: 12 additions & 0 deletions
@@ -30,6 +30,18 @@ using Tensor = paddle::Tensor;
 using IntArray = paddle::experimental::IntArrayBase<paddle::Tensor>;
 // This function should have as same signature as phi, which defined in
 // paddle/phi/api/backward/backward_api.h
+template <typename T>
+void relu_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) {
+  if (x_grad) {
+    auto condition = greater_than<T>(
+        out, full<T>(phi::vectorize(out.dims()), 0.0, out.dtype()));
+    auto res = where<T>(condition,
+                        out_grad,
+                        full<T>(phi::vectorize(out.dims()), 0.0, out.dtype()));
+    set_output<T>(res, x_grad);
+  }
+}
+
 template <typename T>
 void softmax_grad(const Tensor& out,
                   const Tensor& out_grad,
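
Note: the composite relu_grad above rewrites relu's backward purely in terms of primitive ops (greater_than, full, where), so the gradient is passed through only where the forward output was positive, i.e. x_grad = where(out > 0, out_grad, 0). A standalone C++ sketch of that semantics, using plain vectors rather than Paddle's API (relu_grad_reference is a made-up name for illustration):

    // Standalone sketch of the semantics encoded by the composite relu_grad:
    // x_grad[i] = out[i] > 0 ? out_grad[i] : 0.
    #include <cstddef>
    #include <iostream>
    #include <vector>

    std::vector<float> relu_grad_reference(const std::vector<float>& out,
                                           const std::vector<float>& out_grad) {
      std::vector<float> x_grad(out.size());
      for (std::size_t i = 0; i < out.size(); ++i) {
        // where(out > 0, out_grad, 0): keep the gradient only where relu was active.
        x_grad[i] = out[i] > 0.0f ? out_grad[i] : 0.0f;
      }
      return x_grad;
    }

    int main() {
      std::vector<float> g = relu_grad_reference({0.0f, 2.5f}, {1.0f, 1.0f});
      std::cout << g[0] << " " << g[1] << "\n";  // prints: 0 1
      return 0;
    }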

paddle/fluid/prim/tests/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
@@ -38,7 +38,8 @@ cc_test_old(
   static_global_utils
   static_tensor_operants
   tensor_api
-  operants_manager)
+  operants_manager
+  generated_static_op)
 
 if(NOT (NOT WITH_PYTHON AND ON_INFER))
   cc_library(

paddle/fluid/prim/tests/test_eager_prim.cc

Lines changed: 56 additions & 0 deletions
@@ -35,6 +35,12 @@ PD_DECLARE_KERNEL(tanh_grad, CPU, ALL_LAYOUT);
 PD_DECLARE_KERNEL(pow, CPU, ALL_LAYOUT);
 PD_DECLARE_KERNEL(scale, CPU, ALL_LAYOUT);
 PD_DECLARE_KERNEL(multiply, CPU, ALL_LAYOUT);
+PD_DECLARE_KERNEL(less_equal, CPU, ALL_LAYOUT);
+PD_DECLARE_KERNEL(less_than, CPU, ALL_LAYOUT);
+PD_DECLARE_KERNEL(equal, CPU, ALL_LAYOUT);
+PD_DECLARE_KERNEL(not_equal, CPU, ALL_LAYOUT);
+PD_DECLARE_KERNEL(greater_equal, CPU, ALL_LAYOUT);
+PD_DECLARE_KERNEL(greater_than, CPU, ALL_LAYOUT);
 PD_DECLARE_KERNEL(bitwise_and, CPU, ALL_LAYOUT);
 PD_DECLARE_KERNEL(bitwise_or, CPU, ALL_LAYOUT);
 PD_DECLARE_KERNEL(bitwise_xor, CPU, ALL_LAYOUT);
@@ -46,6 +52,12 @@ PD_DECLARE_KERNEL(tanh_grad, GPU, ALL_LAYOUT);
 PD_DECLARE_KERNEL(pow, GPU, ALL_LAYOUT);
 PD_DECLARE_KERNEL(scale, GPU, ALL_LAYOUT);
 PD_DECLARE_KERNEL(multiply, KPS, ALL_LAYOUT);
+PD_DECLARE_KERNEL(less_equal, KPS, ALL_LAYOUT);
+PD_DECLARE_KERNEL(less_than, KPS, ALL_LAYOUT);
+PD_DECLARE_KERNEL(equal, KPS, ALL_LAYOUT);
+PD_DECLARE_KERNEL(not_equal, KPS, ALL_LAYOUT);
+PD_DECLARE_KERNEL(greater_equal, KPS, ALL_LAYOUT);
+PD_DECLARE_KERNEL(greater_than, KPS, ALL_LAYOUT);
 PD_DECLARE_KERNEL(bitwise_and, KPS, ALL_LAYOUT);
 PD_DECLARE_KERNEL(bitwise_or, KPS, ALL_LAYOUT);
 PD_DECLARE_KERNEL(bitwise_xor, KPS, ALL_LAYOUT);
@@ -151,6 +163,50 @@ TEST(EagerPrim, LogicalOperantsTest) {
   EXPECT_EQ(out0.data<int>()[0], out1.data<int>()[0]);
 }
 
+TEST(EagerPrim, CompareOperantsTest) {
+  // 1. Initialized
+  eager_test::InitEnv(paddle::platform::CPUPlace());
+  FLAGS_tensor_operants_mode = "eager";
+  paddle::prim::InitTensorOperants();
+  // 2. pre
+  paddle::framework::DDim ddim = phi::make_ddim({4, 16, 16, 32});
+  paddle::Tensor tensor0 =
+      ::egr::egr_utils_api::CreateTensorWithValue(ddim,
+                                                  paddle::platform::CPUPlace(),
+                                                  phi::DataType::INT32,
+                                                  phi::DataLayout::NCHW,
+                                                  1 /*value*/,
+                                                  true /*is_leaf*/);
+  ::egr::egr_utils_api::RetainGradForTensor(tensor0);
+  paddle::Tensor tensor1 =
+      ::egr::egr_utils_api::CreateTensorWithValue(ddim,
+                                                  paddle::platform::CPUPlace(),
+                                                  phi::DataType::INT32,
+                                                  phi::DataLayout::NCHW,
+                                                  0 /*value*/,
+                                                  true /*is_leaf*/);
+  ::egr::egr_utils_api::RetainGradForTensor(tensor1);
+  // 3. Run Forward once
+  paddle::Tensor out0 = (tensor0 < tensor1);
+  paddle::Tensor out1 = less_than_ad_func(tensor0, tensor1);
+  EXPECT_EQ(out0.data<bool>()[0], out1.data<bool>()[0]);
+  out0 = (tensor0 <= tensor1);
+  out1 = less_equal_ad_func(tensor0, tensor1);
+  EXPECT_EQ(out0.data<bool>()[0], out1.data<bool>()[0]);
+  out0 = (tensor0 == tensor1);
+  out1 = equal_ad_func(tensor0, tensor1);
+  EXPECT_EQ(out0.data<bool>()[0], out1.data<bool>()[0]);
+  out0 = (tensor0 != tensor1);
+  out1 = not_equal_ad_func(tensor0, tensor1);
+  EXPECT_EQ(out0.data<bool>()[0], out1.data<bool>()[0]);
+  out0 = (tensor0 > tensor1);
+  out1 = greater_than_ad_func(tensor0, tensor1);
+  EXPECT_EQ(out0.data<bool>()[0], out1.data<bool>()[0]);
+  out0 = (tensor0 >= tensor1);
+  out1 = greater_equal_ad_func(tensor0, tensor1);
+  EXPECT_EQ(out0.data<bool>()[0], out1.data<bool>()[0]);
+}
+
 TEST(EagerPrim, TestFlags) {
   PrimCommonUtils::SetBwdPrimEnabled(true);
   ASSERT_TRUE(PrimCommonUtils::IsBwdPrimEnabled());

paddle/fluid/prim/tests/test_static_prim.cc

Lines changed: 111 additions & 0 deletions
@@ -38,6 +38,12 @@ PD_DECLARE_KERNEL(scale, CPU, ALL_LAYOUT);
 PD_DECLARE_KERNEL(subtract, CPU, ALL_LAYOUT);
 PD_DECLARE_KERNEL(multiply, CPU, ALL_LAYOUT);
 PD_DECLARE_KERNEL(concat, CPU, ALL_LAYOUT);
+PD_DECLARE_KERNEL(less_equal, CPU, ALL_LAYOUT);
+PD_DECLARE_KERNEL(less_than, CPU, ALL_LAYOUT);
+PD_DECLARE_KERNEL(equal, CPU, ALL_LAYOUT);
+PD_DECLARE_KERNEL(not_equal, CPU, ALL_LAYOUT);
+PD_DECLARE_KERNEL(greater_equal, CPU, ALL_LAYOUT);
+PD_DECLARE_KERNEL(greater_than, CPU, ALL_LAYOUT);
 PD_DECLARE_KERNEL(bitwise_and, CPU, ALL_LAYOUT);
 PD_DECLARE_KERNEL(bitwise_or, CPU, ALL_LAYOUT);
 PD_DECLARE_KERNEL(bitwise_xor, CPU, ALL_LAYOUT);
@@ -51,6 +57,12 @@ PD_DECLARE_KERNEL(scale, GPU, ALL_LAYOUT);
 PD_DECLARE_KERNEL(subtract, KPS, ALL_LAYOUT);
 PD_DECLARE_KERNEL(multiply, KPS, ALL_LAYOUT);
 PD_DECLARE_KERNEL(concat, GPU, ALL_LAYOUT);
+PD_DECLARE_KERNEL(less_equal, KPS, ALL_LAYOUT);
+PD_DECLARE_KERNEL(less_than, KPS, ALL_LAYOUT);
+PD_DECLARE_KERNEL(equal, KPS, ALL_LAYOUT);
+PD_DECLARE_KERNEL(not_equal, KPS, ALL_LAYOUT);
+PD_DECLARE_KERNEL(greater_equal, KPS, ALL_LAYOUT);
+PD_DECLARE_KERNEL(greater_than, KPS, ALL_LAYOUT);
 PD_DECLARE_KERNEL(bitwise_and, KPS, ALL_LAYOUT);
 PD_DECLARE_KERNEL(bitwise_or, KPS, ALL_LAYOUT);
 PD_DECLARE_KERNEL(bitwise_xor, KPS, ALL_LAYOUT);
@@ -429,6 +441,99 @@ TEST(StaticCompositeGradMaker, LogicalOperantsTest) {
             std::size_t(1));
 }
 
+TEST(StaticCompositeGradMaker, CompareOperantsTest) {
+  // Initialized environment
+  FLAGS_tensor_operants_mode = "static";
+  paddle::OperantsManager::Instance().static_operants.reset(
+      new paddle::prim::StaticTensorOperants());
+
+  TestBaseProgram base_program = TestBaseProgram();
+  auto* target_block = base_program.GetBlock(0);
+  std::vector<int64_t> shape = {2, 2};
+  StaticCompositeContext::Instance().SetBlock(target_block);
+  Tensor x0 = prim::empty<prim::DescTensor>(
+      shape, phi::DataType::INT32, phi::CPUPlace());
+  std::string x0_name =
+      std::static_pointer_cast<prim::DescTensor>(x0.impl())->Name();
+  Tensor x1 = prim::empty<prim::DescTensor>(
+      shape, phi::DataType::INT32, phi::CPUPlace());
+  std::string x1_name =
+      std::static_pointer_cast<prim::DescTensor>(x1.impl())->Name();
+  Tensor x2 = prim::empty<prim::DescTensor>(
+      shape, phi::DataType::INT32, phi::CPUPlace());
+  std::string x2_name =
+      std::static_pointer_cast<prim::DescTensor>(x2.impl())->Name();
+  Tensor x3 = prim::empty<prim::DescTensor>(
+      shape, phi::DataType::INT32, phi::CPUPlace());
+  std::string x3_name =
+      std::static_pointer_cast<prim::DescTensor>(x3.impl())->Name();
+  Tensor x4 = prim::empty<prim::DescTensor>(
+      shape, phi::DataType::INT32, phi::CPUPlace());
+  std::string x4_name =
+      std::static_pointer_cast<prim::DescTensor>(x4.impl())->Name();
+  Tensor x5 = prim::empty<prim::DescTensor>(
+      shape, phi::DataType::INT32, phi::CPUPlace());
+  std::string x5_name =
+      std::static_pointer_cast<prim::DescTensor>(x5.impl())->Name();
+  Tensor x6 = prim::empty<prim::DescTensor>(
+      shape, phi::DataType::INT32, phi::CPUPlace());
+  std::string x6_name =
+      std::static_pointer_cast<prim::DescTensor>(x6.impl())->Name();
+
+  Tensor out_less = (x0 < x1);
+  Tensor out_less_equal = (out_less <= x2);
+  Tensor out_equal = (out_less_equal == x3);
+  Tensor out_not_equal = (out_equal != x4);
+  Tensor out_greater = (out_not_equal > x5);
+  Tensor out_greater_equal = (out_greater >= x6);
+
+  ASSERT_EQ(target_block->AllOps().size(), static_cast<std::size_t>(6));
+  ASSERT_EQ(target_block->AllOps()[0]->Type(), "less_than");
+  ASSERT_EQ(target_block->AllOps()[0]->Inputs().at("X").size(),
+            static_cast<std::size_t>(1));
+  ASSERT_EQ(target_block->AllOps()[0]->Inputs().at("X")[0], x0_name);
+  ASSERT_EQ(target_block->AllOps()[0]->Inputs().at("Y").size(),
+            static_cast<std::size_t>(1));
+  ASSERT_EQ(target_block->AllOps()[0]->Inputs().at("Y")[0], x1_name);
+  ASSERT_EQ(target_block->AllOps()[0]->Outputs().at("Out").size(),
+            std::size_t(1));
+
+  ASSERT_EQ(target_block->AllOps()[1]->Type(), "less_equal");
+  ASSERT_EQ(target_block->AllOps()[1]->Inputs().at("Y").size(),
+            static_cast<std::size_t>(1));
+  ASSERT_EQ(target_block->AllOps()[1]->Inputs().at("Y")[0], x2_name);
+  ASSERT_EQ(target_block->AllOps()[1]->Outputs().at("Out").size(),
+            std::size_t(1));
+
+  ASSERT_EQ(target_block->AllOps()[2]->Type(), "equal");
+  ASSERT_EQ(target_block->AllOps()[2]->Inputs().at("Y").size(),
+            static_cast<std::size_t>(1));
+  ASSERT_EQ(target_block->AllOps()[2]->Inputs().at("Y")[0], x3_name);
+  ASSERT_EQ(target_block->AllOps()[2]->Outputs().at("Out").size(),
+            std::size_t(1));
+
+  ASSERT_EQ(target_block->AllOps()[3]->Type(), "not_equal");
+  ASSERT_EQ(target_block->AllOps()[3]->Inputs().at("Y").size(),
+            static_cast<std::size_t>(1));
+  ASSERT_EQ(target_block->AllOps()[3]->Inputs().at("Y")[0], x4_name);
+  ASSERT_EQ(target_block->AllOps()[3]->Outputs().at("Out").size(),
+            std::size_t(1));
+
+  ASSERT_EQ(target_block->AllOps()[4]->Type(), "greater_than");
+  ASSERT_EQ(target_block->AllOps()[4]->Inputs().at("Y").size(),
+            static_cast<std::size_t>(1));
+  ASSERT_EQ(target_block->AllOps()[4]->Inputs().at("Y")[0], x5_name);
+  ASSERT_EQ(target_block->AllOps()[4]->Outputs().at("Out").size(),
+            std::size_t(1));
+
+  ASSERT_EQ(target_block->AllOps()[5]->Type(), "greater_equal");
+  ASSERT_EQ(target_block->AllOps()[5]->Inputs().at("Y").size(),
+            static_cast<std::size_t>(1));
+  ASSERT_EQ(target_block->AllOps()[5]->Inputs().at("Y")[0], x6_name);
+  ASSERT_EQ(target_block->AllOps()[5]->Outputs().at("Out").size(),
+            std::size_t(1));
+}
+
 TEST(StaticPrim, TestFlags) {
   PrimCommonUtils::SetBwdPrimEnabled(true);
   ASSERT_TRUE(PrimCommonUtils::IsBwdPrimEnabled());
@@ -445,6 +550,12 @@ USE_OP_ITSELF(elementwise_mul);
 USE_OP_ITSELF(elementwise_sub);
 USE_OP_ITSELF(elementwise_pow);
 USE_OP_ITSELF(scale);
+USE_OP_ITSELF(less_equal);
+USE_OP_ITSELF(less_than);
+USE_OP_ITSELF(equal);
+USE_OP_ITSELF(not_equal);
+USE_OP_ITSELF(greater_equal);
+USE_OP_ITSELF(greater_than);
 USE_OP_ITSELF(bitwise_xor);
 USE_OP_ITSELF(bitwise_and);
 USE_OP_ITSELF(bitwise_not);
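
Note: in static mode these comparisons compute nothing immediately; as the test above asserts, each one appends an op ("less_than", "less_equal", ...) to the current block. A standalone C++ sketch of that recording pattern, with toy SymTensor/Block types standing in for Paddle's DescTensor and BlockDesc:

    // Standalone sketch of the static-mode behaviour the test verifies:
    // a comparison on symbolic tensors records an op in a block instead of
    // producing values.
    #include <iostream>
    #include <string>
    #include <vector>

    struct Block {
      std::vector<std::string> ops;  // op types, in insertion order
    };

    Block* current_block = nullptr;  // set before tracing begins

    struct SymTensor {
      std::string name;
      SymTensor Binary(const std::string& op, const SymTensor& y) const {
        current_block->ops.push_back(op);  // record the op, like DescTensor does
        return SymTensor{op + "(" + name + "," + y.name + ")"};
      }
      SymTensor operator<(const SymTensor& y) const { return Binary("less_than", y); }
      SymTensor operator>=(const SymTensor& y) const { return Binary("greater_equal", y); }
    };

    int main() {
      Block block;
      current_block = &block;
      SymTensor x0{"x0"}, x1{"x1"}, x2{"x2"};
      SymTensor out = (x0 < x1) >= x2;  // records two ops, computes nothing
      std::cout << out.name << ":";
      for (const std::string& op : block.ops) std::cout << " " << op;  // less_than greater_equal
      std::cout << "\n";
      return 0;
    }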

paddle/phi/api/include/tensor.h

Lines changed: 12 additions & 12 deletions
@@ -534,29 +534,23 @@ class PADDLE_API Tensor final {
    * @return Tensor
    */
   Tensor operator+(const Tensor& other) const;
-
   Tensor operator-(const Tensor& other) const;
-
   Tensor operator*(const Tensor& other) const;
-
   Tensor operator/(const Tensor& other) const;
-
   Tensor operator+(const Scalar& other) const;
-
   Tensor operator-(const Scalar& other) const;
-
   Tensor operator*(const Scalar& other) const;
-
   Tensor operator/(const Scalar& other) const;
-
+  Tensor operator<(const Tensor& other) const;
+  Tensor operator<=(const Tensor& other) const;
+  Tensor operator==(const Tensor& other) const;
+  Tensor operator!=(const Tensor& other) const;
+  Tensor operator>(const Tensor& other) const;
+  Tensor operator>=(const Tensor& other) const;
   Tensor operator-() const;
-
   Tensor operator~() const;
-
   Tensor operator&(const Tensor& other) const;
-
   Tensor operator|(const Tensor& other) const;
-
   Tensor operator^(const Tensor& other) const;
 
   /* Part 8: Autograd methods */
@@ -678,6 +672,12 @@ class PADDLE_API Tensor final {
   Tensor divide(const Scalar& y) const;
   Tensor multiply(const Scalar& y) const;
   Tensor subtract(const Scalar& y) const;
+  Tensor less_equal(const Tensor& y) const;
+  Tensor less_than(const Tensor& y) const;
+  Tensor equal(const Tensor& y) const;
+  Tensor not_equal(const Tensor& y) const;
+  Tensor greater_equal(const Tensor& y) const;
+  Tensor greater_than(const Tensor& y) const;
 Tensor bitwise_and(const Tensor& y) const;
   Tensor bitwise_or(const Tensor& y) const;
   Tensor bitwise_xor(const Tensor& y) const;
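
Note: unlike the usual C++ convention, these comparison operators are declared to return a Tensor rather than a bool; the result is an element-wise boolean tensor (the eager test reads it back via data<bool>()). A standalone sketch of that return-type pattern with toy types (IntVec/BoolVec are illustrative, not Paddle classes):

    // Standalone sketch of element-wise comparison returning a result object
    // instead of bool, the pattern Tensor::operator< and friends follow above.
    #include <cstddef>
    #include <iostream>
    #include <vector>

    struct BoolVec {
      std::vector<bool> data;
    };

    struct IntVec {
      std::vector<int> data;
      // Element-wise comparison; analogous to Tensor::operator< returning a
      // boolean Tensor rather than a bool.
      BoolVec operator<(const IntVec& other) const {
        BoolVec out;
        out.data.reserve(data.size());
        for (std::size_t i = 0; i < data.size(); ++i) {
          out.data.push_back(data[i] < other.data[i]);
        }
        return out;
      }
    };

    int main() {
      IntVec a{{1, 5}};
      IntVec b{{2, 3}};
      BoolVec mask = (a < b);
      std::cout << mask.data[0] << " " << mask.data[1] << "\n";  // prints: 1 0
      return 0;
    }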

paddle/phi/api/yaml/backward.yaml

Lines changed: 1 addition & 0 deletions
@@ -1143,6 +1143,7 @@
   kernel :
     func : relu_grad
   backward: relu_double_grad
+  composite: relu_grad(out, out_grad, x_grad)
   inplace : (out_grad -> x_grad)
 
 - backward_op : renorm_grad

paddle/phi/api/yaml/generator/tensor_operants_gen.py

Lines changed: 24 additions & 0 deletions
@@ -144,6 +144,30 @@ class TensorOperantsBase {
   return paddle::OperantsManager::Instance().subtract(static_cast<const Tensor &>(*this), y);
 }
 
+Tensor Tensor::operator<(const Tensor &other) const {
+  return less_than(other);
+}
+
+Tensor Tensor::operator<=(const Tensor &other) const {
+  return less_equal(other);
+}
+
+Tensor Tensor::operator==(const Tensor &other) const {
+  return equal(other);
+}
+
+Tensor Tensor::operator!=(const Tensor &other) const {
+  return not_equal(other);
+}
+
+Tensor Tensor::operator>(const Tensor &other) const {
+  return greater_than(other);
+}
+
+Tensor Tensor::operator>=(const Tensor &other) const {
+  return greater_equal(other);
+}
+
 Tensor Tensor::operator-() const {
   return scale(-1.0, 0.0, true);
 }
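
Note: the generated overloads are thin forwarders. Each operator calls the matching named member (less_than, greater_equal, ...), which, like the subtract member shown in the context above, routes through paddle::OperantsManager, so the same `a < b` expression is served by the eager or the static backend depending on FLAGS_tensor_operants_mode. A standalone C++ sketch of that forwarding chain, with toy classes (Operants, Manager, Tensor here are illustrative stand-ins, not Paddle's types):

    // Standalone sketch of the forwarding pattern emitted by the generator:
    // operator< -> less_than() -> a singleton manager that picks a backend.
    #include <iostream>
    #include <memory>
    #include <string>

    struct Operants {                              // backend interface
      virtual ~Operants() = default;
      virtual bool less_than(int x, int y) = 0;
    };

    struct EagerOperants : Operants {              // computes immediately
      bool less_than(int x, int y) override { return x < y; }
    };

    struct StaticOperants : Operants {             // records the op instead
      bool less_than(int /*x*/, int /*y*/) override {
        std::cout << "recorded op: less_than\n";
        return false;
      }
    };

    struct Manager {                               // stand-in for OperantsManager
      static Manager& Instance() { static Manager m; return m; }
      std::string mode = "eager";
      std::unique_ptr<Operants> eager{new EagerOperants()};
      std::unique_ptr<Operants> static_{new StaticOperants()};
      bool less_than(int x, int y) {
        return (mode == "eager" ? eager : static_)->less_than(x, y);
      }
    };

    struct Tensor {
      int v;
      bool less_than(const Tensor& y) const { return Manager::Instance().less_than(v, y.v); }
      bool operator<(const Tensor& y) const { return less_than(y); }  // mirrors the generated code
    };

    int main() {
      Tensor a{1}, b{2};
      std::cout << (a < b) << "\n";  // eager backend: prints 1
      Manager::Instance().mode = "static";
      (void)(a < b);                 // static backend: prints "recorded op: less_than"
      return 0;
    }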

paddle/phi/api/yaml/tensor_operants.yaml

Lines changed: 6 additions & 0 deletions
@@ -4,6 +4,12 @@
 - subtract
 - multiply
 - divide
+- less_equal
+- less_than
+- equal
+- not_equal
+- greater_equal
+- greater_than
 - bitwise_and
 - bitwise_not
 - bitwise_or
