File tree 3 files changed +8
-9
lines changed
cinn/hlir/dialect/operator/transforms
phi/ops/yaml/inconsistent
3 files changed +8
-9
lines changed Original file line number Diff line number Diff line change @@ -268,7 +268,7 @@ class ReshapeOpPattern
268
268
out_shape_attr[i].dyn_cast <::pir::Int64Attribute>().data ());
269
269
}
270
270
}
271
- ReplaceWithCinnReshapeOp (op, rewriter, vec_out_shape );
271
+ rewriter. ReplaceAllUsesWith (op. result ( 0 ), cinn_reshape. result ( 0 ) );
272
272
rewriter.EraseOp (op);
273
273
}
274
274
};
Original file line number Diff line number Diff line change 364
364
inplace : (out_grad_in -> out_grad_out)
365
365
366
366
- backward_op : reshape_double_grad
367
- forward : reshape_grad (Tensor xshape , Tensor grad_out) -> Tensor(grad_x)
367
+ forward : reshape_grad (Tensor x , Tensor grad_out) -> Tensor(grad_x)
368
368
args : (Tensor grad_out, Tensor grad_x_grad)
369
369
output : Tensor(grad_out_grad)
370
370
infer_meta :
376
376
inplace : (grad_x_grad -> grad_out_grad)
377
377
378
378
- backward_op : reshape_grad
379
- forward : reshape (Tensor x, IntArray shape) -> Tensor(out), Tensor(xshape)
380
- args : (Tensor xshape , Tensor out_grad)
379
+ forward : reshape (Tensor x, IntArray shape) -> Tensor(out)
380
+ args : (Tensor x , Tensor out_grad)
381
381
output : Tensor(x_grad)
382
382
infer_meta :
383
- func : KernelWithXShapeInferMeta
384
- param : [xshape , out_grad]
383
+ func : UnchangedInferMeta
384
+ param : [x , out_grad]
385
385
spmd_rule : StaticReshapeGradInferSpmd
386
386
kernel :
387
387
func : reshape_grad
Original file line number Diff line number Diff line change 721
721
722
722
- op : reshape
723
723
args : (Tensor x, IntArray shape)
724
- output : Tensor(out), Tensor(xshape)
724
+ output : Tensor(out)
725
725
infer_meta :
726
- func : ReshapeWithXShapeInferMeta
726
+ func : ReshapeInferMeta
727
727
spmd_rule : ReshapeInferSpmdDynamic
728
728
kernel :
729
729
func : reshape
730
730
inplace : (x -> out)
731
731
view : (x -> out)
732
- intermediate : xshape
733
732
backward : reshape_grad
734
733
interfaces : paddle::dialect::InferSymbolicShapeInterface
735
734
You can’t perform that action at this time.
0 commit comments