@@ -38,6 +38,10 @@ PD_DECLARE_KERNEL(scale, CPU, ALL_LAYOUT);
38
38
PD_DECLARE_KERNEL (subtract, CPU, ALL_LAYOUT);
39
39
PD_DECLARE_KERNEL (multiply, CPU, ALL_LAYOUT);
40
40
PD_DECLARE_KERNEL (concat, CPU, ALL_LAYOUT);
41
+ PD_DECLARE_KERNEL (bitwise_and, CPU, ALL_LAYOUT);
42
+ PD_DECLARE_KERNEL (bitwise_or, CPU, ALL_LAYOUT);
43
+ PD_DECLARE_KERNEL (bitwise_xor, CPU, ALL_LAYOUT);
44
+ PD_DECLARE_KERNEL (bitwise_not, CPU, ALL_LAYOUT);
41
45
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
42
46
PD_DECLARE_KERNEL (full, GPU, ALL_LAYOUT);
43
47
PD_DECLARE_KERNEL (tanh, GPU, ALL_LAYOUT);
@@ -47,6 +51,10 @@ PD_DECLARE_KERNEL(scale, GPU, ALL_LAYOUT);
47
51
PD_DECLARE_KERNEL (subtract, KPS, ALL_LAYOUT);
48
52
PD_DECLARE_KERNEL (multiply, KPS, ALL_LAYOUT);
49
53
PD_DECLARE_KERNEL (concat, GPU, ALL_LAYOUT);
54
+ PD_DECLARE_KERNEL (bitwise_and, KPS, ALL_LAYOUT);
55
+ PD_DECLARE_KERNEL (bitwise_or, KPS, ALL_LAYOUT);
56
+ PD_DECLARE_KERNEL (bitwise_xor, KPS, ALL_LAYOUT);
57
+ PD_DECLARE_KERNEL (bitwise_not, KPS, ALL_LAYOUT);
50
58
#endif
51
59
namespace paddle {
52
60
namespace prim {
@@ -362,6 +370,68 @@ TEST(StaticCompositeGradMaker, TestMutiOutputMethod) {
362
370
ASSERT_EQ (fw_out_name[1 ], " out2" );
363
371
}
364
372
373
+ TEST (StaticCompositeGradMaker, LogicalOperantsTest) {
374
+ // Initialized environment
375
+ FLAGS_tensor_operants_mode = " static" ;
376
+ paddle::OperantsManager::Instance ().static_operants .reset (
377
+ new paddle::prim::StaticTensorOperants ());
378
+
379
+ TestBaseProgram base_program = TestBaseProgram ();
380
+ auto * target_block = base_program.GetBlock (0 );
381
+ std::vector<int64_t > shape = {2 , 2 };
382
+ StaticCompositeContext::Instance ().SetBlock (target_block);
383
+ Tensor x0 = prim::empty<prim::DescTensor>(
384
+ shape, phi::DataType::INT32, phi::CPUPlace ());
385
+ std::string x0_name =
386
+ std::static_pointer_cast<prim::DescTensor>(x0.impl ())->Name ();
387
+ Tensor x1 = prim::empty<prim::DescTensor>(
388
+ shape, phi::DataType::INT32, phi::CPUPlace ());
389
+ std::string x1_name =
390
+ std::static_pointer_cast<prim::DescTensor>(x1.impl ())->Name ();
391
+ Tensor x2 = prim::empty<prim::DescTensor>(
392
+ shape, phi::DataType::INT32, phi::CPUPlace ());
393
+ std::string x2_name =
394
+ std::static_pointer_cast<prim::DescTensor>(x2.impl ())->Name ();
395
+ Tensor x3 = prim::empty<prim::DescTensor>(
396
+ shape, phi::DataType::INT32, phi::CPUPlace ());
397
+ std::string x3_name =
398
+ std::static_pointer_cast<prim::DescTensor>(x3.impl ())->Name ();
399
+
400
+ Tensor out_not = ~x0;
401
+ Tensor out_and = out_not & x1;
402
+ Tensor out_or = out_and | x2;
403
+ Tensor out_xor = out_or ^ x3;
404
+
405
+ ASSERT_EQ (target_block->AllOps ().size (), static_cast <std::size_t >(4 ));
406
+ ASSERT_EQ (target_block->AllOps ()[0 ]->Type (), " bitwise_not" );
407
+ ASSERT_EQ (target_block->AllOps ()[0 ]->Inputs ().at (" X" ).size (),
408
+ static_cast <std::size_t >(1 ));
409
+ ASSERT_EQ (target_block->AllOps ()[0 ]->Inputs ().at (" X" )[0 ], x0_name);
410
+ ASSERT_EQ (target_block->AllOps ()[0 ]->Outputs ().at (" Out" ).size (),
411
+ std::size_t (1 ));
412
+
413
+ ASSERT_EQ (target_block->AllOps ()[1 ]->Type (), " bitwise_and" );
414
+ ASSERT_EQ (target_block->AllOps ()[1 ]->Inputs ().at (" Y" ).size (),
415
+ static_cast <std::size_t >(1 ));
416
+ ASSERT_EQ (target_block->AllOps ()[1 ]->Inputs ().at (" Y" )[0 ], x1_name);
417
+ ASSERT_EQ (target_block->AllOps ()[1 ]->Outputs ().at (" Out" ).size (),
418
+ std::size_t (1 ));
419
+
420
+ ASSERT_EQ (target_block->AllOps ()[2 ]->Type (), " bitwise_or" );
421
+ ASSERT_EQ (target_block->AllOps ()[2 ]->Inputs ().at (" Y" ).size (),
422
+ static_cast <std::size_t >(1 ));
423
+ ASSERT_EQ (target_block->AllOps ()[2 ]->Inputs ().at (" Y" )[0 ], x2_name);
424
+ ASSERT_EQ (target_block->AllOps ()[2 ]->Outputs ().at (" Out" ).size (),
425
+ std::size_t (1 ));
426
+
427
+ ASSERT_EQ (target_block->AllOps ()[3 ]->Type (), " bitwise_xor" );
428
+ ASSERT_EQ (target_block->AllOps ()[3 ]->Inputs ().at (" Y" ).size (),
429
+ static_cast <std::size_t >(1 ));
430
+ ASSERT_EQ (target_block->AllOps ()[3 ]->Inputs ().at (" Y" )[0 ], x3_name);
431
+ ASSERT_EQ (target_block->AllOps ()[3 ]->Outputs ().at (" Out" ).size (),
432
+ std::size_t (1 ));
433
+ }
434
+
365
435
TEST (StaticPrim, TestFlags) {
366
436
PrimCommonUtils::SetBwdPrimEnabled (true );
367
437
ASSERT_TRUE (PrimCommonUtils::IsBwdPrimEnabled ());
@@ -378,3 +448,7 @@ USE_OP_ITSELF(elementwise_mul);
378
448
USE_OP_ITSELF (elementwise_sub);
379
449
USE_OP_ITSELF (elementwise_pow);
380
450
USE_OP_ITSELF (scale);
451
+ USE_OP_ITSELF (bitwise_xor);
452
+ USE_OP_ITSELF (bitwise_and);
453
+ USE_OP_ITSELF (bitwise_not);
454
+ USE_OP_ITSELF (bitwise_or);
0 commit comments