From bb762c057a15dd770b748d80e513ab961642b8f6 Mon Sep 17 00:00:00 2001
From: xu98bin
Date: Wed, 18 Jan 2023 11:57:28 +0800
Subject: [PATCH] Fix auto_parallel pp2 with fp16 issue

---
 python/paddle/distributed/auto_parallel/completion.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/python/paddle/distributed/auto_parallel/completion.py b/python/paddle/distributed/auto_parallel/completion.py
index 8960c47c1f5bf2..8979239df5f11c 100644
--- a/python/paddle/distributed/auto_parallel/completion.py
+++ b/python/paddle/distributed/auto_parallel/completion.py
@@ -1850,11 +1850,11 @@ def complete_update_annotation(self, serial_main_program):
                         op_dist_attr.set_output_dims_mapping(
                             input_var.name, ref_dims_mapping
                         )
-
-                    input_var_attr.process_mesh = ref_process_mesh
-                    self._dist_context.set_tensor_dist_attr_for_program(
-                        input_var, input_var_attr
-                    )
+                    if "SkipUpdate" not in input_name:
+                        input_var_attr.process_mesh = ref_process_mesh
+                        self._dist_context.set_tensor_dist_attr_for_program(
+                            input_var, input_var_attr
+                        )
 
                 self._dist_context.set_op_dist_attr_for_program(
                     op, op_dist_attr
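The guard added in this hunk keeps complete_update_annotation from forcing the reference process mesh onto the optimizer's "SkipUpdate" input, which with fp16 loss scaling under pipeline parallelism presumably already carries its own placement. Below is a minimal standalone sketch of that check, assuming nothing beyond the "SkipUpdate" substring test taken from the patch; the helper name and the example input slots are hypothetical, not Paddle API.

def should_inherit_ref_process_mesh(input_name: str) -> bool:
    # Optimizer inputs inherit the reference process mesh, except the
    # fp16 skip-update flag, whose existing placement is left untouched.
    return "SkipUpdate" not in input_name

if __name__ == "__main__":
    # Hypothetical optimizer input slot names, for illustration only.
    for name in ["Param", "Grad", "Moment1", "SkipUpdate"]:
        verdict = "annotate" if should_inherit_ref_process_mesh(name) else "skip"
        print(f"{name}: {verdict}")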