Skip to content

Commit 0a9cdd4

Browse files
authored
fix expand_as (#66311)
1 parent c7a4bfa commit 0a9cdd4

File tree

2 files changed

+5
-1
lines changed

2 files changed

+5
-1
lines changed

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,10 @@ bool ExpandAsOpInferSymbolicShape(
215215
std::vector<int> target_shape =
216216
paddle::dialect::details::GetVectorAttr<int>(op, "target_shape");
217217
const std::vector<symbol::DimExpr> &output_dims = [&] {
218+
if (op->operand_source(0)) {
219+
return infer_context->GetShapeOrDataForValue(op->operand_source(1))
220+
.shape();
221+
}
218222
std::vector<symbol::DimExpr> output_dims;
219223
output_dims.reserve(target_shape.size());
220224
for (int shape : target_shape) {

python/paddle/tensor/manipulation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4447,7 +4447,7 @@ def expand_as(x: Tensor, y: Tensor, name: str | None = None) -> Tensor:
44474447
"some_var.stop_gradient = True, supporting "
44484448
"some_var as the input 'x'."
44494449
)
4450-
return _C_ops.expand_as(x, None, y.shape)
4450+
return _C_ops.expand_as(x, y, y.shape)
44514451
else:
44524452
check_variable_and_dtype(
44534453
x,

0 commit comments

Comments
 (0)