@@ -1194,30 +1194,12 @@ struct EmbeddingOpTranscriber : public OpTranscriber {
1194
1194
const auto & output_vars = op_desc.Output (" Out" );
1195
1195
const auto & output_name = output_vars[0 ];
1196
1196
1197
- const dialect::DenseTensorType& out_tensor_type = std::get<1 >(out_info);
1198
1197
pir::Value& out_value = std::get<2 >(out_info);
1199
-
1200
- ValueInfo ids_info = GetTensorInfoByVarName (
1201
- op_desc, op_desc.Input (" Ids" , true ), param_map, " Ids" );
1202
- const std::vector<int64_t >& ids_shape = std::get<0 >(ids_info);
1203
-
1204
- ValueInfo w_info = GetTensorInfoByVarName (
1205
- op_desc, op_desc.Input (" W" , true ), param_map, " W" );
1206
-
1207
- const std::vector<int64_t >& w_shape = std::get<0 >(w_info);
1208
-
1209
- std::vector<int64_t > out_new_shape (
1210
- ids_shape.begin (), ids_shape.begin () + ids_shape.size () - 1 );
1211
- out_new_shape.insert (out_new_shape.end (), w_shape[1 ]);
1212
-
1213
1198
pir::Builder builder (ctx, operation->GetParent ());
1214
- dialect::ReshapeOp reshape_op_out =
1215
- builder.Build <dialect::ReshapeOp>(out_value, out_new_shape);
1216
- pir::Value out_new = reshape_op_out.out ();
1217
- VLOG (6 ) << " [" << op_desc.Type () << " ] out_shape change from "
1218
- << out_tensor_type.dims () << " to "
1219
- << common::make_ddim (out_new_shape);
1220
-
1199
+ std::vector<int64_t > axis = {-2 };
1200
+ dialect::SqueezeOp squeeze_op_out =
1201
+ builder.Build <dialect::SqueezeOp>(out_value, axis);
1202
+ pir::Value out_new = squeeze_op_out.out ();
1221
1203
param_map->PushValue (output_name,
1222
1204
VariableDefiningInfo (out_new, false , -1 ));
1223
1205
}
0 commit comments