
Commit 1794927

[Tensor Operants & Prim-Relevant] Tensor supports logical operants (#50983)
* Add comments for #50886
* [Tensor Operants & Prim-Relevant] Tensor supports logical operants
* add prim dynamic unit test
* add prim static unit test
1 parent 296b3ff commit 1794927
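In practice, the change lets both user code and prim composite rules write tensor expressions with C-style operators. Below is a minimal sketch of what the new overloads allow, assuming only the `paddle::experimental::Tensor` declarations added in this commit; the helper `logical_demo` is invented for illustration, and `x`/`y` are assumed to be same-shape INT32 tensors:

#include "paddle/phi/api/include/tensor.h"

using paddle::experimental::Tensor;

// Hypothetical helper, not part of the commit: exercises the four new
// operator overloads, each of which forwards to a named bitwise_* method.
Tensor logical_demo(const Tensor& x, const Tensor& y) {
  Tensor a = x & y;     // equivalent to x.bitwise_and(y)
  Tensor o = x | y;     // equivalent to x.bitwise_or(y)
  Tensor e = x ^ y;     // equivalent to x.bitwise_xor(y)
  return ~a & (o | e);  // ~a is equivalent to a.bitwise_not()
}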

8 files changed: +374 −1 lines

paddle/fluid/prim/api/api.yaml

+4
@@ -2,6 +2,10 @@
 - subtract
 - multiply
 - divide
+- bitwise_and
+- bitwise_not
+- bitwise_or
+- bitwise_xor
 - unsqueeze
 - exp
 - scale

paddle/fluid/prim/tests/test_eager_prim.cc

+48 −1
@@ -35,13 +35,22 @@ PD_DECLARE_KERNEL(tanh_grad, CPU, ALL_LAYOUT);
 PD_DECLARE_KERNEL(pow, CPU, ALL_LAYOUT);
 PD_DECLARE_KERNEL(scale, CPU, ALL_LAYOUT);
 PD_DECLARE_KERNEL(multiply, CPU, ALL_LAYOUT);
+PD_DECLARE_KERNEL(bitwise_and, CPU, ALL_LAYOUT);
+PD_DECLARE_KERNEL(bitwise_or, CPU, ALL_LAYOUT);
+PD_DECLARE_KERNEL(bitwise_xor, CPU, ALL_LAYOUT);
+PD_DECLARE_KERNEL(bitwise_not, CPU, ALL_LAYOUT);
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
 PD_DECLARE_KERNEL(full, GPU, ALL_LAYOUT);
 PD_DECLARE_KERNEL(tanh, GPU, ALL_LAYOUT);
 PD_DECLARE_KERNEL(tanh_grad, GPU, ALL_LAYOUT);
 PD_DECLARE_KERNEL(pow, GPU, ALL_LAYOUT);
 PD_DECLARE_KERNEL(scale, GPU, ALL_LAYOUT);
 PD_DECLARE_KERNEL(multiply, KPS, ALL_LAYOUT);
+PD_DECLARE_KERNEL(bitwise_and, KPS, ALL_LAYOUT);
+PD_DECLARE_KERNEL(bitwise_or, KPS, ALL_LAYOUT);
+PD_DECLARE_KERNEL(bitwise_xor, KPS, ALL_LAYOUT);
+PD_DECLARE_KERNEL(bitwise_not, KPS, ALL_LAYOUT);
+
 #endif
 
 namespace paddle {
@@ -81,7 +90,7 @@ TEST(EagerPrim, TanhBackwardTest) {
 
   paddle::experimental::Tensor out1 = tanh_ad_func(tensor1);
   std::vector<paddle::experimental::Tensor> outs1 = {out1};
-  // Disable prim
+  // Enable prim
   PrimCommonUtils::SetBwdPrimEnabled(true);
   ASSERT_TRUE(PrimCommonUtils::IsBwdPrimEnabled());
   // 4. Run Backward
@@ -104,6 +113,44 @@ TEST(EagerPrim, TanhBackwardTest) {
           ->data<float>()[0]);
 }
 
+TEST(EagerPrim, LogicalOperantsTest) {
+  // 1. Initialize the environment
+  eager_test::InitEnv(paddle::platform::CPUPlace());
+  FLAGS_tensor_operants_mode = "eager";
+  paddle::prim::InitTensorOperants();
+  // 2. Prepare the input tensors
+  paddle::framework::DDim ddim = phi::make_ddim({4, 16, 16, 32});
+  paddle::experimental::Tensor tensor0 =
+      ::egr::egr_utils_api::CreateTensorWithValue(ddim,
+                                                  paddle::platform::CPUPlace(),
+                                                  phi::DataType::INT32,
+                                                  phi::DataLayout::NCHW,
+                                                  1 /*value*/,
+                                                  true /*is_leaf*/);
+  ::egr::egr_utils_api::RetainGradForTensor(tensor0);
+  paddle::experimental::Tensor tensor1 =
+      ::egr::egr_utils_api::CreateTensorWithValue(ddim,
+                                                  paddle::platform::CPUPlace(),
+                                                  phi::DataType::INT32,
+                                                  phi::DataLayout::NCHW,
+                                                  0 /*value*/,
+                                                  true /*is_leaf*/);
+  ::egr::egr_utils_api::RetainGradForTensor(tensor1);
+  // 3. Run forward once: each operator must match its generated *_ad_func
+  paddle::experimental::Tensor out0 = tensor0 & tensor1;
+  paddle::experimental::Tensor out1 = bitwise_and_ad_func(tensor0, tensor1);
+  EXPECT_EQ(out0.data<int>()[0], out1.data<int>()[0]);
+  out0 = tensor0 | tensor1;
+  out1 = bitwise_or_ad_func(tensor0, tensor1);
+  EXPECT_EQ(out0.data<int>()[0], out1.data<int>()[0]);
+  out0 = tensor0 ^ tensor1;
+  out1 = bitwise_xor_ad_func(tensor0, tensor1);
+  EXPECT_EQ(out0.data<int>()[0], out1.data<int>()[0]);
+  out0 = ~tensor0;
+  out1 = bitwise_not_ad_func(tensor0);
+  EXPECT_EQ(out0.data<int>()[0], out1.data<int>()[0]);
+}
+
 TEST(EagerPrim, TestFlags) {
   PrimCommonUtils::SetBwdPrimEnabled(true);
   ASSERT_TRUE(PrimCommonUtils::IsBwdPrimEnabled());
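Both new tests pivot on `FLAGS_tensor_operants_mode`: the eager test above sets it to "eager", while the static test below sets it to "static" and installs a `StaticTensorOperants` instance into `OperantsManager`. The standalone toy below sketches that dispatch-by-mode pattern; `ToyTensor`, `ToyOperants`, and `ToyManager` are invented stand-ins, not Paddle types:

// Standalone toy of the mode-switched operants pattern suggested by these
// tests; all names are invented for illustration.
#include <cassert>
#include <memory>
#include <string>

struct ToyTensor { int value; };

struct ToyOperants {
  virtual ~ToyOperants() = default;
  virtual ToyTensor bitwise_and(const ToyTensor& x, const ToyTensor& y) = 0;
};

// "Eager" operants compute immediately.
struct EagerOperants : ToyOperants {
  ToyTensor bitwise_and(const ToyTensor& x, const ToyTensor& y) override {
    return ToyTensor{x.value & y.value};
  }
};

// A manager singleton routes each call to whichever operants the current
// mode selects, mirroring how FLAGS_tensor_operants_mode picks a backend.
struct ToyManager {
  static ToyManager& Instance() { static ToyManager m; return m; }
  std::string mode = "eager";
  std::unique_ptr<ToyOperants> eager_operants;
  ToyTensor bitwise_and(const ToyTensor& x, const ToyTensor& y) {
    assert(mode == "eager" && eager_operants);
    return eager_operants->bitwise_and(x, y);
  }
};

// Tensor operators delegate to the manager rather than to a kernel directly.
ToyTensor operator&(const ToyTensor& x, const ToyTensor& y) {
  return ToyManager::Instance().bitwise_and(x, y);
}

int main() {
  ToyManager::Instance().eager_operants.reset(new EagerOperants());
  ToyTensor a{1}, b{0};
  assert((a & b).value == 0);  // dispatched through the manager
  return 0;
}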

paddle/fluid/prim/tests/test_static_prim.cc

+74
@@ -38,6 +38,10 @@ PD_DECLARE_KERNEL(scale, CPU, ALL_LAYOUT);
 PD_DECLARE_KERNEL(subtract, CPU, ALL_LAYOUT);
 PD_DECLARE_KERNEL(multiply, CPU, ALL_LAYOUT);
 PD_DECLARE_KERNEL(concat, CPU, ALL_LAYOUT);
+PD_DECLARE_KERNEL(bitwise_and, CPU, ALL_LAYOUT);
+PD_DECLARE_KERNEL(bitwise_or, CPU, ALL_LAYOUT);
+PD_DECLARE_KERNEL(bitwise_xor, CPU, ALL_LAYOUT);
+PD_DECLARE_KERNEL(bitwise_not, CPU, ALL_LAYOUT);
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
 PD_DECLARE_KERNEL(full, GPU, ALL_LAYOUT);
 PD_DECLARE_KERNEL(tanh, GPU, ALL_LAYOUT);
@@ -47,6 +51,10 @@ PD_DECLARE_KERNEL(scale, GPU, ALL_LAYOUT);
 PD_DECLARE_KERNEL(subtract, KPS, ALL_LAYOUT);
 PD_DECLARE_KERNEL(multiply, KPS, ALL_LAYOUT);
 PD_DECLARE_KERNEL(concat, GPU, ALL_LAYOUT);
+PD_DECLARE_KERNEL(bitwise_and, KPS, ALL_LAYOUT);
+PD_DECLARE_KERNEL(bitwise_or, KPS, ALL_LAYOUT);
+PD_DECLARE_KERNEL(bitwise_xor, KPS, ALL_LAYOUT);
+PD_DECLARE_KERNEL(bitwise_not, KPS, ALL_LAYOUT);
 #endif
 namespace paddle {
 namespace prim {
@@ -362,6 +370,68 @@ TEST(StaticCompositeGradMaker, TestMutiOutputMethod) {
   ASSERT_EQ(fw_out_name[1], "out2");
 }
 
+TEST(StaticCompositeGradMaker, LogicalOperantsTest) {
+  // Initialize the environment with static-mode operants
+  FLAGS_tensor_operants_mode = "static";
+  paddle::OperantsManager::Instance().static_operants.reset(
+      new paddle::prim::StaticTensorOperants());
+
+  TestBaseProgram base_program = TestBaseProgram();
+  auto* target_block = base_program.GetBlock(0);
+  std::vector<int64_t> shape = {2, 2};
+  StaticCompositeContext::Instance().SetBlock(target_block);
+  Tensor x0 = prim::empty<prim::DescTensor>(
+      shape, phi::DataType::INT32, phi::CPUPlace());
+  std::string x0_name =
+      std::static_pointer_cast<prim::DescTensor>(x0.impl())->Name();
+  Tensor x1 = prim::empty<prim::DescTensor>(
+      shape, phi::DataType::INT32, phi::CPUPlace());
+  std::string x1_name =
+      std::static_pointer_cast<prim::DescTensor>(x1.impl())->Name();
+  Tensor x2 = prim::empty<prim::DescTensor>(
+      shape, phi::DataType::INT32, phi::CPUPlace());
+  std::string x2_name =
+      std::static_pointer_cast<prim::DescTensor>(x2.impl())->Name();
+  Tensor x3 = prim::empty<prim::DescTensor>(
+      shape, phi::DataType::INT32, phi::CPUPlace());
+  std::string x3_name =
+      std::static_pointer_cast<prim::DescTensor>(x3.impl())->Name();
+
+  // Each operator call should append exactly one op to the target block
+  Tensor out_not = ~x0;
+  Tensor out_and = out_not & x1;
+  Tensor out_or = out_and | x2;
+  Tensor out_xor = out_or ^ x3;
+
+  ASSERT_EQ(target_block->AllOps().size(), static_cast<std::size_t>(4));
+  ASSERT_EQ(target_block->AllOps()[0]->Type(), "bitwise_not");
+  ASSERT_EQ(target_block->AllOps()[0]->Inputs().at("X").size(),
+            static_cast<std::size_t>(1));
+  ASSERT_EQ(target_block->AllOps()[0]->Inputs().at("X")[0], x0_name);
+  ASSERT_EQ(target_block->AllOps()[0]->Outputs().at("Out").size(),
+            std::size_t(1));
+
+  ASSERT_EQ(target_block->AllOps()[1]->Type(), "bitwise_and");
+  ASSERT_EQ(target_block->AllOps()[1]->Inputs().at("Y").size(),
+            static_cast<std::size_t>(1));
+  ASSERT_EQ(target_block->AllOps()[1]->Inputs().at("Y")[0], x1_name);
+  ASSERT_EQ(target_block->AllOps()[1]->Outputs().at("Out").size(),
+            std::size_t(1));
+
+  ASSERT_EQ(target_block->AllOps()[2]->Type(), "bitwise_or");
+  ASSERT_EQ(target_block->AllOps()[2]->Inputs().at("Y").size(),
+            static_cast<std::size_t>(1));
+  ASSERT_EQ(target_block->AllOps()[2]->Inputs().at("Y")[0], x2_name);
+  ASSERT_EQ(target_block->AllOps()[2]->Outputs().at("Out").size(),
+            std::size_t(1));
+
+  ASSERT_EQ(target_block->AllOps()[3]->Type(), "bitwise_xor");
+  ASSERT_EQ(target_block->AllOps()[3]->Inputs().at("Y").size(),
+            static_cast<std::size_t>(1));
+  ASSERT_EQ(target_block->AllOps()[3]->Inputs().at("Y")[0], x3_name);
+  ASSERT_EQ(target_block->AllOps()[3]->Outputs().at("Out").size(),
+            std::size_t(1));
+}
+
 TEST(StaticPrim, TestFlags) {
   PrimCommonUtils::SetBwdPrimEnabled(true);
   ASSERT_TRUE(PrimCommonUtils::IsBwdPrimEnabled());
@@ -378,3 +448,7 @@ USE_OP_ITSELF(elementwise_mul);
 USE_OP_ITSELF(elementwise_sub);
 USE_OP_ITSELF(elementwise_pow);
 USE_OP_ITSELF(scale);
+USE_OP_ITSELF(bitwise_xor);
+USE_OP_ITSELF(bitwise_and);
+USE_OP_ITSELF(bitwise_not);
+USE_OP_ITSELF(bitwise_or);
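The assertions in this test hinge on one idea: in static mode an operator call computes nothing and instead appends an op descriptor to the current block. Here is a self-contained sketch of that trace-by-operator-overloading behavior; `RecordedOp`, `Block`, and `SymTensor` are invented names, not Paddle's:

// Standalone sketch of op recording via operator overloading; all names
// are invented stand-ins for Paddle's DescTensor/BlockDesc machinery.
#include <cassert>
#include <string>
#include <vector>

struct RecordedOp {
  std::string type;  // e.g. "bitwise_and"
  std::string x, y;  // input variable names ("" when unused)
  std::string out;   // output variable name
};

struct Block {
  std::vector<RecordedOp> ops;
  int counter = 0;
  std::string FreshName() { return "tmp_" + std::to_string(counter++); }
};

struct SymTensor {
  std::string name;
  Block* block;
};

// Append an op to the block and return a symbolic handle to its output.
SymTensor Record(Block* b, const std::string& type,
                 const std::string& x, const std::string& y) {
  SymTensor out{b->FreshName(), b};
  b->ops.push_back(RecordedOp{type, x, y, out.name});
  return out;
}

SymTensor operator~(const SymTensor& x) {
  return Record(x.block, "bitwise_not", x.name, "");
}
SymTensor operator&(const SymTensor& x, const SymTensor& y) {
  return Record(x.block, "bitwise_and", x.name, y.name);
}

int main() {
  Block block;
  SymTensor x0{"x0", &block}, x1{"x1", &block};
  SymTensor out = ~x0 & x1;  // records bitwise_not, then bitwise_and
  (void)out;
  assert(block.ops.size() == 2);
  assert(block.ops[0].type == "bitwise_not" && block.ops[0].x == "x0");
  assert(block.ops[1].type == "bitwise_and" && block.ops[1].y == "x1");
  return 0;
}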

paddle/phi/api/include/tensor.h

+12
@@ -550,6 +550,14 @@ class PADDLE_API Tensor final {
 
   Tensor operator-() const;
 
+  Tensor operator~() const;
+
+  Tensor operator&(const Tensor& other) const;
+
+  Tensor operator|(const Tensor& other) const;
+
+  Tensor operator^(const Tensor& other) const;
+
   /* Part 8: Autograd methods */
 
   /**
@@ -669,6 +677,10 @@ class PADDLE_API Tensor final {
   Tensor divide(const Scalar& y) const;
   Tensor multiply(const Scalar& y) const;
   Tensor subtract(const Scalar& y) const;
+  Tensor bitwise_and(const Tensor& y) const;
+  Tensor bitwise_or(const Tensor& y) const;
+  Tensor bitwise_xor(const Tensor& y) const;
+  Tensor bitwise_not() const;
   Tensor pow(const Tensor& y) const;
   Tensor pow(const Scalar& y) const;

paddle/phi/api/yaml/generator/tensor_operants_gen.py

+17
@@ -29,6 +29,7 @@
 
 indent = " "
 
+# E.g.: Prim uses `elementwise_pow + fill_constant` to replace `pow`, so this map is used to generate the `pow` signature when iterating over the `elementwise_pow` API.
 specific_ops_map = {"elementwise_pow": "pow"}
 
@@ -149,6 +150,22 @@ class TensorOperantsBase {
   return scale(-1.0, 0.0, true);
 }
 
+Tensor Tensor::operator~() const {
+  return bitwise_not();
+}
+
+Tensor Tensor::operator&(const Tensor &other) const {
+  return bitwise_and(other);
+}
+
+Tensor Tensor::operator|(const Tensor &other) const {
+  return bitwise_or(other);
+}
+
+Tensor Tensor::operator^(const Tensor &other) const {
+  return bitwise_xor(other);
+}
+
 Tensor Tensor::pow(const Tensor& y) const {
   return paddle::OperantsManager::Instance().pow(static_cast<const Tensor &>(*this), y);
 }
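This hunk shows the generated operator overloads forwarding to the named `bitwise_*` methods, and shows `Tensor::pow` forwarding to `OperantsManager`. By analogy with that visible `pow` template, the generated `bitwise_and` implementation presumably takes the same shape; the body below is an inference from this diff, not code shown in it:

// Inferred shape of the generated named-method implementation, by analogy
// with the Tensor::pow body visible in this hunk (not shown in the diff):
Tensor Tensor::bitwise_and(const Tensor& y) const {
  return paddle::OperantsManager::Instance().bitwise_and(
      static_cast<const Tensor &>(*this), y);
}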

paddle/phi/api/yaml/tensor_operants.yaml

+5
@@ -1,8 +1,13 @@
 # Attach operants to Tensor; this file should be consistent with the declaration in `tensor.h`
+# Ensure this file is a subset of `paddle/fluid/prim/api/api.yaml`
 - add
 - subtract
 - multiply
 - divide
+- bitwise_and
+- bitwise_not
+- bitwise_or
+- bitwise_xor
 - unsqueeze
 - exp
 - scale
