@@ -188,6 +188,8 @@ void InterpreterCore::Convert() {
188
188
BuildAndCacheInstructionCtx (&vec_instruction_[i], *global_scope_, place_);
189
189
}
190
190
191
+ BuildSkipShareLoDInfo ();
192
+
191
193
for (size_t i = 0 ; i < vec_instruction_.size (); ++i) {
192
194
gc_event_.emplace_back (vec_instruction_[i].execution_ctx_ .get ()->GetPlace (),
193
195
platform::GenerateDeviceEventFlag ());
@@ -225,8 +227,8 @@ void InterpreterCore::BuildAndCacheInstructionCtx(
225
227
instr_node->runtime_ctx_ ->inputs .swap (ins_map);
226
228
instr_node->runtime_ctx_ ->outputs .swap (outs_map);
227
229
228
- instr_node->infershape_ctx_ .reset (
229
- new RuntimeInferShapeContext ( *op_base, *instr_node->runtime_ctx_ .get ()));
230
+ instr_node->infershape_ctx_ .reset (new InterpretercoreInferShapeContext (
231
+ *op_base, *instr_node->runtime_ctx_ .get ()));
230
232
231
233
auto * dev_ctx = instr_node->dev_ctx_ ;
232
234
Scope scope;
@@ -235,6 +237,26 @@ void InterpreterCore::BuildAndCacheInstructionCtx(
235
237
*op_base, scope, *dev_ctx, *instr_node->runtime_ctx_ .get ()));
236
238
}
237
239
240
+ void InterpreterCore::BuildSkipShareLoDInfo () {
241
+ for (size_t i = 0 ; i < vec_instruction_.size (); ++i) {
242
+ bool can_skip_lod = true ;
243
+ for (auto & input : vec_instruction_[i].runtime_ctx_ .get ()->inputs ) {
244
+ for (auto & var : input.second ) {
245
+ if (var->IsType <LoDTensor>()) {
246
+ if (var->Get <LoDTensor>().lod ().size () != 0 ) {
247
+ can_skip_lod = false ;
248
+ break ;
249
+ }
250
+ } else {
251
+ can_skip_lod = false ;
252
+ break ;
253
+ }
254
+ }
255
+ }
256
+ vec_instruction_[i].infershape_ctx_ .get ()->SetSkipLoD (can_skip_lod);
257
+ }
258
+ }
259
+
238
260
void InterpreterCore::RunInstruction (const Instruction& instr_node) {
239
261
VLOG (3 ) << " RunInstruction: "
240
262
<< instr_node.kernel_func_ .operator_base_ ->Type ();
0 commit comments