Skip to content

Commit af80a84

Browse files
refine lookup_table translator (#69399)
1 parent aa323f6 commit af80a84

File tree

1 file changed

+4
-22
lines changed

1 file changed

+4
-22
lines changed

paddle/fluid/ir_adaptor/translator/op_translator.cc

Lines changed: 4 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1194,30 +1194,12 @@ struct EmbeddingOpTranscriber : public OpTranscriber {
11941194
const auto& output_vars = op_desc.Output("Out");
11951195
const auto& output_name = output_vars[0];
11961196

1197-
const dialect::DenseTensorType& out_tensor_type = std::get<1>(out_info);
11981197
pir::Value& out_value = std::get<2>(out_info);
1199-
1200-
ValueInfo ids_info = GetTensorInfoByVarName(
1201-
op_desc, op_desc.Input("Ids", true), param_map, "Ids");
1202-
const std::vector<int64_t>& ids_shape = std::get<0>(ids_info);
1203-
1204-
ValueInfo w_info = GetTensorInfoByVarName(
1205-
op_desc, op_desc.Input("W", true), param_map, "W");
1206-
1207-
const std::vector<int64_t>& w_shape = std::get<0>(w_info);
1208-
1209-
std::vector<int64_t> out_new_shape(
1210-
ids_shape.begin(), ids_shape.begin() + ids_shape.size() - 1);
1211-
out_new_shape.insert(out_new_shape.end(), w_shape[1]);
1212-
12131198
pir::Builder builder(ctx, operation->GetParent());
1214-
dialect::ReshapeOp reshape_op_out =
1215-
builder.Build<dialect::ReshapeOp>(out_value, out_new_shape);
1216-
pir::Value out_new = reshape_op_out.out();
1217-
VLOG(6) << "[" << op_desc.Type() << "] out_shape change from "
1218-
<< out_tensor_type.dims() << " to "
1219-
<< common::make_ddim(out_new_shape);
1220-
1199+
std::vector<int64_t> axis = {-2};
1200+
dialect::SqueezeOp squeeze_op_out =
1201+
builder.Build<dialect::SqueezeOp>(out_value, axis);
1202+
pir::Value out_new = squeeze_op_out.out();
12211203
param_map->PushValue(output_name,
12221204
VariableDefiningInfo(out_new, false, -1));
12231205
}

0 commit comments

Comments
 (0)