Skip to content

Commit 5712fe3

Browse files
committed
remove xshape for reshape op
1 parent 827eb1e commit 5712fe3

File tree

3 files changed

+8
-9
lines changed

3 files changed

+8
-9
lines changed

paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,7 @@ class ReshapeOpPattern
268268
out_shape_attr[i].dyn_cast<::pir::Int64Attribute>().data());
269269
}
270270
}
271-
ReplaceWithCinnReshapeOp(op, rewriter, vec_out_shape);
271+
rewriter.ReplaceAllUsesWith(op.result(0), cinn_reshape.result(0));
272272
rewriter.EraseOp(op);
273273
}
274274
};

paddle/phi/ops/yaml/inconsistent/static_backward.yaml

+5-5
Original file line numberDiff line numberDiff line change
@@ -364,7 +364,7 @@
364364
inplace : (out_grad_in -> out_grad_out)
365365

366366
- backward_op : reshape_double_grad
367-
forward : reshape_grad (Tensor xshape, Tensor grad_out) -> Tensor(grad_x)
367+
forward : reshape_grad (Tensor x, Tensor grad_out) -> Tensor(grad_x)
368368
args : (Tensor grad_out, Tensor grad_x_grad)
369369
output : Tensor(grad_out_grad)
370370
infer_meta :
@@ -376,12 +376,12 @@
376376
inplace : (grad_x_grad -> grad_out_grad)
377377

378378
- backward_op : reshape_grad
379-
forward : reshape (Tensor x, IntArray shape) -> Tensor(out), Tensor(xshape)
380-
args : (Tensor xshape, Tensor out_grad)
379+
forward : reshape (Tensor x, IntArray shape) -> Tensor(out)
380+
args : (Tensor x, Tensor out_grad)
381381
output : Tensor(x_grad)
382382
infer_meta :
383-
func : KernelWithXShapeInferMeta
384-
param : [xshape, out_grad]
383+
func : UnchangedInferMeta
384+
param : [x, out_grad]
385385
spmd_rule: StaticReshapeGradInferSpmd
386386
kernel :
387387
func : reshape_grad

paddle/phi/ops/yaml/inconsistent/static_ops.yaml

+2-3
Original file line numberDiff line numberDiff line change
@@ -721,15 +721,14 @@
721721

722722
- op : reshape
723723
args : (Tensor x, IntArray shape)
724-
output : Tensor(out), Tensor(xshape)
724+
output : Tensor(out)
725725
infer_meta :
726-
func : ReshapeWithXShapeInferMeta
726+
func : ReshapeInferMeta
727727
spmd_rule : ReshapeInferSpmdDynamic
728728
kernel :
729729
func : reshape
730730
inplace : (x -> out)
731731
view: (x -> out)
732-
intermediate : xshape
733732
backward: reshape_grad
734733
interfaces : paddle::dialect::InferSymbolicShapeInterface
735734

0 commit comments

Comments
 (0)