[AutoParallel] fp16 pass support assign op #47649

Merged 4 commits on Nov 7, 2022
21 changes: 21 additions & 0 deletions python/paddle/distributed/passes/auto_parallel_fp16.py
@@ -156,6 +156,7 @@ def __init__(
            list
        )  # {forward_op_id: [(output_name, input_name, out_dtype, in_dtype, slot_name), ]}
        self.is_train = False
        self.out_var_op_deps = {}

    def _is_fp16_op(self, op_id):
        return self._op_fp16_dict.get(op_id, None)
@@ -169,6 +170,14 @@ def _build_state(self):
        # assume all backward block are behind forward blocks
        for block in self.program.blocks:
            for op in block.ops:
                for name in op.output_arg_names:
                    if name not in self.out_var_op_deps:
                        self.out_var_op_deps[name] = [op.desc.original_id()]
                    else:
                        self.out_var_op_deps[name].extend(
                            [op.desc.original_id()]
                        )

                self._mark_op(op)

        # set forward tensor dtype
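For reference, out_var_op_deps (initialized in __init__ above) maps each output variable name to the original ids of the ops that write it, in program order. A minimal sketch of the resulting structure; the variable names and op ids below are illustrative, not taken from a real program:

# Illustrative only: keys are output variable names, values are the original
# ids of the ops that write that variable, in the order they appear.
out_var_op_deps = {
    "fc_0.tmp_1": [101],          # written once, by an fc op
    "where_0.tmp_0": [103, 110],  # written by `where`, later rewritten by `assign`
}
# A variable with more than one writer is the in-place case: the later assign
# must follow the fp16/fp32 decision already made for the first writer.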
@@ -192,6 +201,18 @@ def _mark_op(self, op):
            if op.type == "assign" and "array_" in op.input_arg_names[0]:
                self._op_fp16_dict[op.desc.original_id()] = False
                return
            # If the assign op is an in-place operation, it should run in the same fp16/fp32 mode as the op that created its output variable.
            if op.type == "assign":
                out_name = op.output_arg_names[0]
                if len(self.out_var_op_deps[out_name]) > 1:
                    if not self._op_fp16_dict[
                        self.out_var_op_deps[out_name][0]
                    ]:
                        self._op_fp16_dict[op.desc.original_id()] = False
                    else:
                        self._op_fp16_dict[op.desc.original_id()] = True
                    return

            if _need_keep_fp32(
                op, self.amp_list.unsupported_list, self.use_fp16_guard
            ):
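Read in isolation, the rule added to _mark_op reduces to the following. This is a minimal standalone sketch with hypothetical names, not the pass's actual API:

def decide_assign_fp16(op_fp16_dict, out_var_op_deps, assign_op):
    # Sketch: an in-place assign inherits the fp16/fp32 decision of the op
    # that first created its output variable; returning None means "no
    # decision here, fall through to the generic marking logic".
    out_name = assign_op.output_arg_names[0]
    writers = out_var_op_deps[out_name]
    if len(writers) > 1:
        # False keeps the assign in fp32, True lets it run in fp16.
        return bool(op_fp16_dict[writers[0]])
    return None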
1 change: 1 addition & 0 deletions python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt
@@ -115,5 +115,6 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
  py_test_modules(test_conditional_block_reshard MODULES
                  test_conditional_block_reshard)
  py_test_modules(test_engine_api_error MODULES test_engine_api_error)
  py_test_modules(test_fp16_assign MODULES test_fp16_assign)

endif()
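With this registration the new test is built and run with the rest of the auto-parallel suite only when Paddle is compiled with WITH_DISTRIBUTE and WITH_GPU; on such a build it can also be launched directly, since the file below defines a standard unittest entry point.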
145 changes: 145 additions & 0 deletions python/paddle/fluid/tests/unittests/auto_parallel/test_fp16_assign.py
@@ -0,0 +1,145 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import copy

import paddle
from paddle.distributed.fleet import auto
from paddle.distributed.passes import new_pass

paddle.enable_static()


def make_program():
    main_program = paddle.fluid.Program()
    start_program = paddle.fluid.Program()
    with paddle.static.program_guard(main_program, start_program):
        x = paddle.static.data(name='x', shape=[4, 6, 8], dtype='float32')
        y = paddle.static.data(name='y', shape=[4, 6, 6], dtype='float32')
        z = paddle.static.data(name='z', shape=[4, 6, 6], dtype='float32')

        auto.shard_tensor(x, auto.ProcessMesh([0], ['d0']), [None, None, None])

        out0 = paddle.static.nn.fc(
            x,
            size=6,
            num_flatten_dims=2,
            weight_attr=paddle.ParamAttr(
                initializer=paddle.nn.initializer.Constant(value=0.5)
            ),
            bias_attr=paddle.ParamAttr(
                initializer=paddle.nn.initializer.Constant(value=1.0)
            ),
        )
        where_0 = paddle.where(y > 1, y, out0)

        out1 = paddle.static.nn.fc(
            out0,
            size=6,
            num_flatten_dims=2,
            weight_attr=paddle.ParamAttr(
                initializer=paddle.nn.initializer.Constant(value=0.5)
            ),
            bias_attr=paddle.ParamAttr(
                initializer=paddle.nn.initializer.Constant(value=1.0)
            ),
        )
        where_1 = paddle.where(y > 1, y, out1)

        paddle.fluid.layers.assign(where_1, where_0)

    return main_program, start_program
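
# Note: the program above deliberately makes assign an in-place write.
# where_0 is first produced by a `where` op (kept in fp32 through the custom
# black list configured in parallelizer below), and assign(where_1, where_0)
# then rewrites it, which is exactly the case the fp16 pass must also keep
# in fp32.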


def parallelizer(program_func, rank):
    from paddle.distributed.auto_parallel.completion import Completer
    from paddle.distributed.auto_parallel.partitioner import Partitioner
    from paddle.distributed.auto_parallel.dist_context import DistributedContext

    main_program, start_program = program_func()

    dist_context = DistributedContext()
    completer = Completer(dist_context)
    completer.complete_forward_annotation(main_program)
    dist_context.block_state.parse_forward_blocks(main_program)

    strategy = auto.Strategy()
    amp = strategy.amp
    amp.enable = True
    amp.use_pure_fp16 = True
    amp.init_loss_scaling = 32768
    amp.use_fp16_guard = False
    amp.custom_black_list = ['where']

    config = copy.deepcopy(strategy.amp.to_dict())
    config["dist_context"] = dist_context
    config["params_grads"] = []
    config["loss"] = None
    config["base_opt"] = None
    auto_parallel_fp16_pass = new_pass("auto_parallel_fp16", config)
    auto_parallel_fp16_pass.apply([main_program], [start_program], None)

    partitioner = Partitioner(dist_context, rank)
    dist_main_prog, _, _ = partitioner.partition(
        main_program, start_program, []
    )

    return dist_main_prog, dist_context


class TestFp16Assign(unittest.TestCase):
    def assert_fp32_dtype(self, block, op):
        for slot in op.input_names:
            for name in op.input(slot):
                if block.vars[name].dtype == paddle.bool:
                    continue
                assert block.vars[name].dtype == paddle.float32
        for slot in op.output_names:
            for name in op.output(slot):
                if block.vars[name].dtype == paddle.bool:
                    continue
                assert block.vars[name].dtype == paddle.float32

    def assert_fp16_dtype(self, block, op):
        for slot in op.input_names:
            if slot == "Condition":
                continue
            for name in op.input(slot):
                if block.vars[name].dtype == paddle.bool:
                    continue
                assert block.vars[name].dtype == paddle.float16
        for slot in op.output_names:
            for name in op.output(slot):
                if block.vars[name].dtype == paddle.bool:
                    continue
                assert block.vars[name].dtype == paddle.float16

    def test_fp16_assign(self):

        dist_main_prog, dist_context = parallelizer(make_program, 0)
        block = dist_main_prog.global_block()
        for op in block.ops:
            if op.type == "cast":
                continue
            if op.type == "where":
                self.assert_fp32_dtype(block, op)
            elif op.type == "assign":
                self.assert_fp32_dtype(block, op)
            else:
                self.assert_fp16_dtype(block, op)


if __name__ == "__main__":
    unittest.main()
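
Not part of the PR, but for local debugging a small helper along these lines (reusing make_program and parallelizer from this file) can dump what the pass decided; where and assign should report float32 outputs, while the matmul and elementwise_add ops created by fc report float16, mirroring the assertions above:

def dump_op_output_dtypes():
    # Debug sketch: print each op type with the dtypes of its outputs after
    # the fp16 pass and partitioning have been applied.
    dist_main_prog, _ = parallelizer(make_program, 0)
    block = dist_main_prog.global_block()
    for op in block.ops:
        dtypes = [
            str(block.vars[name].dtype)
            for name in op.output_arg_names
            if name in block.vars
        ]
        print(op.type, dtypes)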