Skip to content

Commit 607db71

Browse files
committed
[Dy2St] Run PT in SOT mode only (#59658)
1 parent c294b45 commit 607db71

File tree

2 files changed

+9
-8
lines changed

2 files changed

+9
-8
lines changed

python/paddle/jit/dy2static/partial_program.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,9 @@ def __call__(self, inputs):
238238
in_vars, in_var_names = self._prepare_inputs(inputs)
239239
out_vars = self._prepare_outputs()
240240
self._cast_fp16_if_pure_fp16(in_vars)
241-
attrs = self._prepare_attributes()
241+
# TODO(dev): Currently AST + PT has some issues in control flow, so we only
242+
# enable SOT + PT in 2.6, we will fix it later.
243+
attrs = self._prepare_attributes(force_not_use_pt=True)
242244
attrs.extend(["x_names", in_var_names])
243245

244246
self._sync_lr_value_with_scheduler()
@@ -777,7 +779,7 @@ def _cast_fp16_if_pure_fp16(self, in_vars):
777779
in_vars[i] = var.astype('float16')
778780
in_vars[i].name = name
779781

780-
def _prepare_attributes(self):
782+
def _prepare_attributes(self, force_not_use_pt=False):
781783
attrs = [
782784
'forward_global_block',
783785
self.forward_program.desc.block(0),
@@ -822,6 +824,8 @@ def _prepare_attributes(self):
822824
is_cinn_enabled = self._build_strategy.build_cinn_pass
823825
if is_prim_enabled or in_cinn_backend or is_cinn_enabled:
824826
in_pir_pt_mode = False
827+
if force_not_use_pt:
828+
in_pir_pt_mode = False
825829
attrs.extend(['in_pir_pt_mode', in_pir_pt_mode])
826830

827831
return attrs

python/paddle/jit/sot/symbolic/compile_cache.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -104,16 +104,13 @@ def __call__(self, *args, **kwargs):
104104
),
105105
)
106106
if self.partial_program is None:
107-
with EventGuard("FallbackWrapper: call compiled_fn"):
108-
outputs = self.compiled_fn(*args, **kwargs)
107+
with EventGuard("FallbackWrapper: get_concrete_program"):
109108
(
110109
self.concrete_program,
111110
self.partial_program,
112111
) = self.compiled_fn.get_concrete_program(*args, **kwargs)
113-
else:
114-
# Speed up Resnet from 0.0068 --> 0.0057
115-
with EventGuard("FallbackWrapper: call partial_program"):
116-
outputs = self.partial_program.sot_call(*args, **kwargs)
112+
with EventGuard("FallbackWrapper: sot call partial_program"):
113+
outputs = self.partial_program.sot_call(*args, **kwargs)
117114

118115
clear_eager_tensor_name(outputs)
119116
log_do(

0 commit comments

Comments
 (0)