[Tensor Operants & Prim-Relevant] Tensor supports logical operants #50983

Merged: 5 commits, Mar 1, 2023
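This PR overloads the `~`, `&`, `|`, and `^` operators on `paddle::experimental::Tensor` so they dispatch to the `bitwise_not`, `bitwise_and`, `bitwise_or`, and `bitwise_xor` APIs. A minimal usage sketch, modeled on the eager-mode test added below (the tensor setup here is illustrative, not part of this diff, and assumes an initialized eager environment with FLAGS_tensor_operants_mode = "eager" and paddle::prim::InitTensorOperants() already called):

    // Two small INT32 tensors, created the same way as in the eager test below.
    paddle::experimental::Tensor x = ::egr::egr_utils_api::CreateTensorWithValue(
        phi::make_ddim({2, 2}), paddle::platform::CPUPlace(),
        phi::DataType::INT32, phi::DataLayout::NCHW, 1 /*value*/, true /*is_leaf*/);
    paddle::experimental::Tensor y = ::egr::egr_utils_api::CreateTensorWithValue(
        phi::make_ddim({2, 2}), paddle::platform::CPUPlace(),
        phi::DataType::INT32, phi::DataLayout::NCHW, 0 /*value*/, true /*is_leaf*/);

    // Each operator forwards to the corresponding bitwise API.
    paddle::experimental::Tensor a = x & y;  // same result as bitwise_and_ad_func(x, y)
    paddle::experimental::Tensor o = x | y;  // same result as bitwise_or_ad_func(x, y)
    paddle::experimental::Tensor e = x ^ y;  // same result as bitwise_xor_ad_func(x, y)
    paddle::experimental::Tensor n = ~x;     // same result as bitwise_not_ad_func(x)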
4 changes: 4 additions & 0 deletions paddle/fluid/prim/api/api.yaml
@@ -2,6 +2,10 @@
- subtract
- multiply
- divide
- bitwise_and
- bitwise_not
- bitwise_or
- bitwise_xor
- unsqueeze
- exp
- scale
49 changes: 48 additions & 1 deletion paddle/fluid/prim/tests/test_eager_prim.cc
@@ -35,13 +35,22 @@ PD_DECLARE_KERNEL(tanh_grad, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(pow, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(scale, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(multiply, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(bitwise_and, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(bitwise_or, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(bitwise_xor, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(bitwise_not, CPU, ALL_LAYOUT);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_DECLARE_KERNEL(full, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(tanh, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(tanh_grad, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(pow, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(scale, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(multiply, KPS, ALL_LAYOUT);
PD_DECLARE_KERNEL(bitwise_and, KPS, ALL_LAYOUT);
PD_DECLARE_KERNEL(bitwise_or, KPS, ALL_LAYOUT);
PD_DECLARE_KERNEL(bitwise_xor, KPS, ALL_LAYOUT);
PD_DECLARE_KERNEL(bitwise_not, KPS, ALL_LAYOUT);

#endif

namespace paddle {
@@ -81,7 +90,7 @@ TEST(EagerPrim, TanhBackwardTest) {

paddle::experimental::Tensor out1 = tanh_ad_func(tensor1);
std::vector<paddle::experimental::Tensor> outs1 = {out1};
// Disable prim
// Enable prim
PrimCommonUtils::SetBwdPrimEnabled(true);
ASSERT_TRUE(PrimCommonUtils::IsBwdPrimEnabled());
// 4. Run Backward
@@ -104,6 +113,44 @@ TEST(EagerPrim, TanhBackwardTest) {
->data<float>()[0]);
}

TEST(EagerPrim, LogicalOperantsTest) {
// 1. Initialize environment
eager_test::InitEnv(paddle::platform::CPUPlace());
FLAGS_tensor_operants_mode = "eager";
paddle::prim::InitTensorOperants();
// 2. Prepare input tensors
paddle::framework::DDim ddim = phi::make_ddim({4, 16, 16, 32});
paddle::experimental::Tensor tensor0 =
::egr::egr_utils_api::CreateTensorWithValue(ddim,
paddle::platform::CPUPlace(),
phi::DataType::INT32,
phi::DataLayout::NCHW,
1 /*value*/,
true /*is_leaf*/);
::egr::egr_utils_api::RetainGradForTensor(tensor0);
paddle::experimental::Tensor tensor1 =
::egr::egr_utils_api::CreateTensorWithValue(ddim,
paddle::platform::CPUPlace(),
phi::DataType::INT32,
phi::DataLayout::NCHW,
0 /*value*/,
true /*is_leaf*/);
::egr::egr_utils_api::RetainGradForTensor(tensor1);
// 3. Run Forward once
paddle::experimental::Tensor out0 = tensor0 & tensor1;
paddle::experimental::Tensor out1 = bitwise_and_ad_func(tensor0, tensor1);
EXPECT_EQ(out0.data<int>()[0], out1.data<int>()[0]);
out0 = tensor0 | tensor1;
out1 = bitwise_or_ad_func(tensor0, tensor1);
EXPECT_EQ(out0.data<int>()[0], out1.data<int>()[0]);
out0 = tensor0 ^ tensor1;
out1 = bitwise_xor_ad_func(tensor0, tensor1);
EXPECT_EQ(out0.data<int>()[0], out1.data<int>()[0]);
out0 = ~tensor0;
out1 = bitwise_not_ad_func(tensor0);
EXPECT_EQ(out0.data<int>()[0], out1.data<int>()[0]);
}

TEST(EagerPrim, TestFlags) {
PrimCommonUtils::SetBwdPrimEnabled(true);
ASSERT_TRUE(PrimCommonUtils::IsBwdPrimEnabled());
74 changes: 74 additions & 0 deletions paddle/fluid/prim/tests/test_static_prim.cc
@@ -38,6 +38,10 @@ PD_DECLARE_KERNEL(scale, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(subtract, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(multiply, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(concat, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(bitwise_and, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(bitwise_or, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(bitwise_xor, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(bitwise_not, CPU, ALL_LAYOUT);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_DECLARE_KERNEL(full, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(tanh, GPU, ALL_LAYOUT);
@@ -47,6 +51,10 @@ PD_DECLARE_KERNEL(scale, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(subtract, KPS, ALL_LAYOUT);
PD_DECLARE_KERNEL(multiply, KPS, ALL_LAYOUT);
PD_DECLARE_KERNEL(concat, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(bitwise_and, KPS, ALL_LAYOUT);
PD_DECLARE_KERNEL(bitwise_or, KPS, ALL_LAYOUT);
PD_DECLARE_KERNEL(bitwise_xor, KPS, ALL_LAYOUT);
PD_DECLARE_KERNEL(bitwise_not, KPS, ALL_LAYOUT);
#endif
namespace paddle {
namespace prim {
@@ -362,6 +370,68 @@ TEST(StaticCompositeGradMaker, TestMutiOutputMethod) {
ASSERT_EQ(fw_out_name[1], "out2");
}

TEST(StaticCompositeGradMaker, LogicalOperantsTest) {
// Initialize environment
FLAGS_tensor_operants_mode = "static";
paddle::OperantsManager::Instance().static_operants.reset(
new paddle::prim::StaticTensorOperants());

TestBaseProgram base_program = TestBaseProgram();
auto* target_block = base_program.GetBlock(0);
std::vector<int64_t> shape = {2, 2};
StaticCompositeContext::Instance().SetBlock(target_block);
Tensor x0 = prim::empty<prim::DescTensor>(
shape, phi::DataType::INT32, phi::CPUPlace());
std::string x0_name =
std::static_pointer_cast<prim::DescTensor>(x0.impl())->Name();
Tensor x1 = prim::empty<prim::DescTensor>(
shape, phi::DataType::INT32, phi::CPUPlace());
std::string x1_name =
std::static_pointer_cast<prim::DescTensor>(x1.impl())->Name();
Tensor x2 = prim::empty<prim::DescTensor>(
shape, phi::DataType::INT32, phi::CPUPlace());
std::string x2_name =
std::static_pointer_cast<prim::DescTensor>(x2.impl())->Name();
Tensor x3 = prim::empty<prim::DescTensor>(
shape, phi::DataType::INT32, phi::CPUPlace());
std::string x3_name =
std::static_pointer_cast<prim::DescTensor>(x3.impl())->Name();

Tensor out_not = ~x0;
Tensor out_and = out_not & x1;
Tensor out_or = out_and | x2;
Tensor out_xor = out_or ^ x3;

ASSERT_EQ(target_block->AllOps().size(), static_cast<std::size_t>(4));
ASSERT_EQ(target_block->AllOps()[0]->Type(), "bitwise_not");
ASSERT_EQ(target_block->AllOps()[0]->Inputs().at("X").size(),
static_cast<std::size_t>(1));
ASSERT_EQ(target_block->AllOps()[0]->Inputs().at("X")[0], x0_name);
ASSERT_EQ(target_block->AllOps()[0]->Outputs().at("Out").size(),
std::size_t(1));

ASSERT_EQ(target_block->AllOps()[1]->Type(), "bitwise_and");
ASSERT_EQ(target_block->AllOps()[1]->Inputs().at("Y").size(),
static_cast<std::size_t>(1));
ASSERT_EQ(target_block->AllOps()[1]->Inputs().at("Y")[0], x1_name);
ASSERT_EQ(target_block->AllOps()[1]->Outputs().at("Out").size(),
std::size_t(1));

ASSERT_EQ(target_block->AllOps()[2]->Type(), "bitwise_or");
ASSERT_EQ(target_block->AllOps()[2]->Inputs().at("Y").size(),
static_cast<std::size_t>(1));
ASSERT_EQ(target_block->AllOps()[2]->Inputs().at("Y")[0], x2_name);
ASSERT_EQ(target_block->AllOps()[2]->Outputs().at("Out").size(),
std::size_t(1));

ASSERT_EQ(target_block->AllOps()[3]->Type(), "bitwise_xor");
ASSERT_EQ(target_block->AllOps()[3]->Inputs().at("Y").size(),
static_cast<std::size_t>(1));
ASSERT_EQ(target_block->AllOps()[3]->Inputs().at("Y")[0], x3_name);
ASSERT_EQ(target_block->AllOps()[3]->Outputs().at("Out").size(),
std::size_t(1));
}

TEST(StaticPrim, TestFlags) {
PrimCommonUtils::SetBwdPrimEnabled(true);
ASSERT_TRUE(PrimCommonUtils::IsBwdPrimEnabled());
@@ -378,3 +448,7 @@ USE_OP_ITSELF(elementwise_mul);
USE_OP_ITSELF(elementwise_sub);
USE_OP_ITSELF(elementwise_pow);
USE_OP_ITSELF(scale);
USE_OP_ITSELF(bitwise_xor);
USE_OP_ITSELF(bitwise_and);
USE_OP_ITSELF(bitwise_not);
USE_OP_ITSELF(bitwise_or);
12 changes: 12 additions & 0 deletions paddle/phi/api/include/tensor.h
@@ -550,6 +550,14 @@ class PADDLE_API Tensor final {

Tensor operator-() const;

Tensor operator~() const;

Tensor operator&(const Tensor& other) const;

Tensor operator|(const Tensor& other) const;

Tensor operator^(const Tensor& other) const;

/* Part 8: Autograd methods */

/**
@@ -669,6 +677,10 @@ class PADDLE_API Tensor final {
Tensor divide(const Scalar& y) const;
Tensor multiply(const Scalar& y) const;
Tensor subtract(const Scalar& y) const;
Tensor bitwise_and(const Tensor& y) const;
Tensor bitwise_or(const Tensor& y) const;
Tensor bitwise_xor(const Tensor& y) const;
Tensor bitwise_not() const;
Tensor pow(const Tensor& y) const;
Tensor pow(const Scalar& y) const;

17 changes: 17 additions & 0 deletions paddle/phi/api/yaml/generator/tensor_operants_gen.py
@@ -29,6 +29,7 @@

indent = " "

# E.g.: Prim uses `elementwise_pow + fill_constant` to replace `pow`, so this map is used to generate the `pow` signature when iterating over the `elementwise_pow` API.
specific_ops_map = {"elementwise_pow": "pow"}


@@ -149,6 +150,22 @@ class TensorOperantsBase {
return scale(-1.0, 0.0, true);
}

Tensor Tensor::operator~() const {
return bitwise_not();
}

Tensor Tensor::operator&(const Tensor &other) const {
return bitwise_and(other);
}

Tensor Tensor::operator|(const Tensor &other) const {
return bitwise_or(other);
}

Tensor Tensor::operator^(const Tensor &other) const {
return bitwise_xor(other);
}

Tensor Tensor::pow(const Tensor& y) const {
return paddle::OperantsManager::Instance().pow(static_cast<const Tensor &>(*this), y);
}
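
The generated source above only shows the `pow` forwarding explicitly; the bodies generated for the new `bitwise_*` entries in `tensor_operants.yaml` presumably follow the same `OperantsManager` forwarding pattern. A sketch of the expected generated code, assuming the generated `OperantsManager` interface gains matching `bitwise_*` entries (not part of this diff):

    // Hypothetical generated body, mirroring the pow() forwarding shown above.
    Tensor Tensor::bitwise_and(const Tensor& y) const {
      return paddle::OperantsManager::Instance().bitwise_and(
          static_cast<const Tensor &>(*this), y);
    }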
5 changes: 5 additions & 0 deletions paddle/phi/api/yaml/tensor_operants.yaml
@@ -1,8 +1,13 @@
# Attach operants to Tensor; this file should be consistent with the declarations in `tensor.h`
# Ensure this file is a subset of `paddle/fluid/prim/api/api.yaml`
- add
- subtract
- multiply
- divide
- bitwise_and
- bitwise_not
- bitwise_or
- bitwise_xor
- unsqueeze
- exp
- scale