Commit 6c51e49 (parent c65f056)

[AutoParallel] fp16 pass support assign op (#47649)

* fp16 pass support assign op
* choose assign op exec mode
* add unittest
* add cmakelist

3 files changed: +167 −0

python/paddle/distributed/passes/auto_parallel_fp16.py (+21)

@@ -156,6 +156,7 @@ def __init__(
             list
         )  # {forward_op_id: [(output_name, input_name, out_dtype, in_dtype, slot_name), ]}
         self.is_train = False
+        self.out_var_op_deps = {}

     def _is_fp16_op(self, op_id):
         return self._op_fp16_dict.get(op_id, None)

@@ -169,6 +170,14 @@ def _build_state(self):
         # assume all backward blocks are behind forward blocks
         for block in self.program.blocks:
             for op in block.ops:
+                for name in op.output_arg_names:
+                    if name not in self.out_var_op_deps:
+                        self.out_var_op_deps[name] = [op.desc.original_id()]
+                    else:
+                        self.out_var_op_deps[name].extend(
+                            [op.desc.original_id()]
+                        )
                 self._mark_op(op)

         # set forward tensor dtype

@@ -192,6 +201,18 @@ def _mark_op(self, op):
         if op.type == "assign" and "array_" in op.input_arg_names[0]:
             self._op_fp16_dict[op.desc.original_id()] = False
             return
+        # If the assign op is an in-place operation, its exec mode must match
+        # that of the op that created the output variable.
+        if op.type == "assign":
+            out_name = op.output_arg_names[0]
+            if len(self.out_var_op_deps[out_name]) > 1:
+                if not self._op_fp16_dict[
+                    self.out_var_op_deps[out_name][0]
+                ]:
+                    self._op_fp16_dict[op.desc.original_id()] = False
+                else:
+                    self._op_fp16_dict[op.desc.original_id()] = True
+                return
+
         if _need_keep_fp32(
             op, self.amp_list.unsupported_list, self.use_fp16_guard
         ):
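The core of the change is a small dependency lookup: when an `assign` writes into a variable that an earlier op already produced (the in-place case), the pass reuses that earlier op's fp16/fp32 decision instead of deciding independently. Below is a minimal standalone sketch of that rule, using plain dicts and hypothetical op ids rather than Paddle's real operator objects; only the control flow mirrors the pass.

```python
# Sketch of the decision rule added in this commit. The op ids, variable names,
# and helper functions are hypothetical; they are not part of the Paddle API.
out_var_op_deps = {}  # output var name -> [ids of ops that wrote it, in order]
op_fp16_dict = {}     # op id -> True (run in fp16) / False (keep fp32)


def record_outputs(op_id, output_names):
    """First pass: remember every op that writes each output variable."""
    for name in output_names:
        out_var_op_deps.setdefault(name, []).append(op_id)


def mark_assign(op_id, output_name, default_fp16=True):
    """Second pass: an in-place assign follows the op that created its output."""
    writers = out_var_op_deps[output_name]
    if len(writers) > 1:  # some earlier op already produced this variable
        op_fp16_dict[op_id] = op_fp16_dict[writers[0]]
    else:
        op_fp16_dict[op_id] = default_fp16


# Toy trace: a black-listed 'where' op kept in fp32 produces `where_0`, then an
# assign writes into the same variable, so the assign is forced to fp32 as well.
record_outputs(op_id=1, output_names=["where_0"])
op_fp16_dict[1] = False
record_outputs(op_id=2, output_names=["where_0"])
mark_assign(op_id=2, output_name="where_0")
assert op_fp16_dict[2] is False
```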

python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt (+1)

@@ -115,5 +115,6 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
   py_test_modules(test_conditional_block_reshard MODULES
                   test_conditional_block_reshard)
   py_test_modules(test_engine_api_error MODULES test_engine_api_error)
+  py_test_modules(test_fp16_assign MODULES test_fp16_assign)

 endif()
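As with the other entries in this file, `py_test_modules` wires the new module into the auto_parallel test suite; the surrounding `if(WITH_DISTRIBUTE AND WITH_GPU)` block means the test is only registered for distributed GPU builds.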
python/paddle/fluid/tests/unittests/auto_parallel/test_fp16_assign.py (new file, +145)

# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import copy

import paddle
from paddle.distributed.fleet import auto
from paddle.distributed.passes import new_pass

paddle.enable_static()


def make_program():
    main_program = paddle.fluid.Program()
    start_program = paddle.fluid.Program()
    with paddle.static.program_guard(main_program, start_program):
        x = paddle.static.data(name='x', shape=[4, 6, 8], dtype='float32')
        y = paddle.static.data(name='y', shape=[4, 6, 6], dtype='float32')
        z = paddle.static.data(name='z', shape=[4, 6, 6], dtype='float32')

        auto.shard_tensor(x, auto.ProcessMesh([0], ['d0']), [None, None, None])

        out0 = paddle.static.nn.fc(
            x,
            size=6,
            num_flatten_dims=2,
            weight_attr=paddle.ParamAttr(
                initializer=paddle.nn.initializer.Constant(value=0.5)
            ),
            bias_attr=paddle.ParamAttr(
                initializer=paddle.nn.initializer.Constant(value=1.0)
            ),
        )
        where_0 = paddle.where(y > 1, y, out0)

        out1 = paddle.static.nn.fc(
            out0,
            size=6,
            num_flatten_dims=2,
            weight_attr=paddle.ParamAttr(
                initializer=paddle.nn.initializer.Constant(value=0.5)
            ),
            bias_attr=paddle.ParamAttr(
                initializer=paddle.nn.initializer.Constant(value=1.0)
            ),
        )
        where_1 = paddle.where(y > 1, y, out1)

        paddle.fluid.layers.assign(where_1, where_0)

    return main_program, start_program


def parallelizer(program_func, rank):
    from paddle.distributed.auto_parallel.completion import Completer
    from paddle.distributed.auto_parallel.partitioner import Partitioner
    from paddle.distributed.auto_parallel.dist_context import DistributedContext

    main_program, start_program = program_func()

    dist_context = DistributedContext()
    completer = Completer(dist_context)
    completer.complete_forward_annotation(main_program)
    dist_context.block_state.parse_forward_blocks(main_program)

    strategy = auto.Strategy()
    amp = strategy.amp
    amp.enable = True
    amp.use_pure_fp16 = True
    amp.init_loss_scaling = 32768
    amp.use_fp16_guard = False
    amp.custom_black_list = ['where']

    config = copy.deepcopy(strategy.amp.to_dict())
    config["dist_context"] = dist_context
    config["params_grads"] = []
    config["loss"] = None
    config["base_opt"] = None
    auto_parallel_fp16_pass = new_pass("auto_parallel_fp16", config)
    auto_parallel_fp16_pass.apply([main_program], [start_program], None)

    partitioner = Partitioner(dist_context, rank)
    dist_main_prog, _, _ = partitioner.partition(
        main_program, start_program, []
    )

    return dist_main_prog, dist_context


class TestFp16Assign(unittest.TestCase):
    def assert_fp32_dtype(self, block, op):
        for slot in op.input_names:
            for name in op.input(slot):
                if block.vars[name].dtype == paddle.bool:
                    continue
                assert block.vars[name].dtype == paddle.float32
        for slot in op.output_names:
            for name in op.output(slot):
                if block.vars[name].dtype == paddle.bool:
                    continue
                assert block.vars[name].dtype == paddle.float32

    def assert_fp16_dtype(self, block, op):
        for slot in op.input_names:
            if slot == "Condition":
                continue
            for name in op.input(slot):
                if block.vars[name].dtype == paddle.bool:
                    continue
                assert block.vars[name].dtype == paddle.float16
        for slot in op.output_names:
            for name in op.output(slot):
                if block.vars[name].dtype == paddle.bool:
                    continue
                assert block.vars[name].dtype == paddle.float16

    def test_fp16_assign(self):
        dist_main_prog, dist_context = parallelizer(make_program, 0)
        block = dist_main_prog.global_block()
        for op in block.ops:
            if op.type == "cast":
                continue
            if op.type == "where":
                self.assert_fp32_dtype(block, op)
            elif op.type == "assign":
                self.assert_fp32_dtype(block, op)
            else:
                self.assert_fp16_dtype(block, op)


if __name__ == "__main__":
    unittest.main()
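The expectation encoded in the test follows directly from the pass change: with 'where' on the custom black list, both `where` ops stay in fp32, and because the final `assign` writes in place into `where_0`'s output variable, the new dependency check keeps that assign in fp32 as well, while the remaining ops (other than the inserted casts) are asserted to run in fp16.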
