From bd79ccb4e7d982f11b1c81ca99179b0452b92fbf Mon Sep 17 00:00:00 2001 From: zyfncg Date: Wed, 16 Apr 2025 15:12:21 +0000 Subject: [PATCH 1/2] refine infer_symbol_shape for pad3d and set_value op --- .../infer_symbolic_shape/unary_infer_sym.cc | 57 ++++++++++++++++++- test/dygraph_to_static/test_jit_setitem.py | 3 - 2 files changed, 54 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc index 564fb564c13820..6dfdc0cda67912 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc @@ -2782,7 +2782,8 @@ bool Pad3dOpInferSymbolicShape(pir::Operation *op, const std::string &data_format = op->attribute("data_format").AsString(); const std::vector &paddings = - paddle::dialect::details::GetDataFromTensorOrTensorList(paddings_shape); + paddle::dialect::details::GetOrCreateExprVecFromData(paddings_shape, + infer_context); const std::vector &out_dims = [&] { std::vector out_dims = x_shape; PADDLE_ENFORCE_EQ(paddings.size(), @@ -3638,12 +3639,62 @@ bool SetValue_OpInferSymbolicShape( bool SetValueWithTensorOpInferSymbolicShape( pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { - return SetValueOpInferSymbolicShape(op, infer_context); + const auto &input_shape_or_data = + infer_context->GetShapeOrDataForValue(op->operand_source(0)); + const auto &input_shape = input_shape_or_data.shape(); + PADDLE_ENFORCE_LT( + input_shape.size(), + 7, + common::errors::InvalidArgument("The SetValueOp's rank of input should " + "be less than 7, but received %d.", + input_shape.size())); + + if (input_shape_or_data.isa() && + input_shape_or_data.data()) { + const auto &value = + infer_context->GetShapeOrDataForValue(op->operand_source(1)); + const auto &start = + infer_context->GetShapeOrDataForValue(op->operand_source(2)); + const auto &end = + infer_context->GetShapeOrDataForValue(op->operand_source(3)); + + const bool need_set_data = [&] { + if (!value.data().has_value() || value.data()->size() != 1) return false; + if (!start.data().has_value() || start.data()->size() != 1 || + !start.data()->at(0).isa()) + return false; + if (!end.data().has_value() || end.data()->size() != 1 || + !end.data()->at(0).isa()) + return false; + + int64_t start_val = start.data()->at(0).dyn_cast(); + int64_t end_val = end.data()->at(0).dyn_cast(); + if (end_val - start_val != 1ll || start_val < 0ll || + start_val >= static_cast(input_shape_or_data.data()->size())) + return false; + + return true; + }(); + + if (need_set_data) { + auto out_data = input_shape_or_data.data().value(); + out_data.at(start.data()->at(0).dyn_cast()) = + value.data()->at(0); + infer_context->SetShapeOrDataForValue( + op->result(0), + symbol::TensorShapeOrDataDimExprs(input_shape, out_data)); + return true; + } + } + + infer_context->SetShapeOrDataForValue( + op->result(0), symbol::TensorShapeOrDataDimExprs(input_shape)); + return true; } bool SetValueWithTensor_OpInferSymbolicShape( pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { - return SetValueOpInferSymbolicShape(op, infer_context); + return SetValueWithTensorOpInferSymbolicShape(op, infer_context); } // bool TensorUnfoldOpInferSymbolicShape( diff --git a/test/dygraph_to_static/test_jit_setitem.py b/test/dygraph_to_static/test_jit_setitem.py index 15fd8f973848ef..96cd220a50f0ea 100644 --- a/test/dygraph_to_static/test_jit_setitem.py +++ b/test/dygraph_to_static/test_jit_setitem.py @@ -18,7 +18,6 @@ import numpy as np from dygraph_to_static_utils import ( Dy2StTestBase, - test_phi_only, ) import paddle @@ -265,8 +264,6 @@ def run_dygraph(self, func): y = func(x, H, W) return (y,) - # NOTE(SigureMo): Please remove this function after the CINN case fixed - @test_phi_only def test_case(self): func = self.init_func() dy_res = self.run_dygraph(func) From 67d4ddfe196586782dde50f2769971b80f23da17 Mon Sep 17 00:00:00 2001 From: zyfncg Date: Wed, 16 Apr 2025 16:09:46 +0000 Subject: [PATCH 2/2] polish code --- test/dygraph_to_static/test_jit_setitem.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/test/dygraph_to_static/test_jit_setitem.py b/test/dygraph_to_static/test_jit_setitem.py index 96cd220a50f0ea..cde8ee51f87b35 100644 --- a/test/dygraph_to_static/test_jit_setitem.py +++ b/test/dygraph_to_static/test_jit_setitem.py @@ -264,14 +264,6 @@ def run_dygraph(self, func): y = func(x, H, W) return (y,) - def test_case(self): - func = self.init_func() - dy_res = self.run_dygraph(func) - st_res = self.run_to_static(func) - - for dy_out, st_out in zip(dy_res, st_res): - np.testing.assert_allclose(dy_out.numpy(), st_out.numpy()) - if __name__ == '__main__': unittest.main()