From d27f90a40c72840e94524a86e9ed233f890df15a Mon Sep 17 00:00:00 2001
From: feifei-111 <2364819892@qq.com>
Date: Fri, 24 Mar 2023 08:50:44 +0000
Subject: [PATCH 1/5] fix dy2s grad name parse

---
 .../dygraph_to_static/test_gradname_parse.py  | 63 +++++++++++++++++++
 python/paddle/jit/dy2static/utils.py          |  2 +-
 2 files changed, 64 insertions(+), 1 deletion(-)
 create mode 100644 python/paddle/fluid/tests/unittests/dygraph_to_static/test_gradname_parse.py

diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_gradname_parse.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_gradname_parse.py
new file mode 100644
index 00000000000000..f08cfcacf74c20
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_gradname_parse.py
@@ -0,0 +1,63 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import collections
+import unittest
+
+import paddle
+from paddle import ParamAttr
+from paddle.nn import BatchNorm, Linear
+
+
+class SimpleNet(paddle.nn.Layer):
+    def __init__(self):
+        super(SimpleNet, self).__init__()
+        self.linear0 = Linear(100,50)
+        self.linear1 = Linear(50,10)
+
+        param_attr0 = ParamAttr(name="aaaprefix_bn_scale")
+        bias_attr0 = ParamAttr(name="aaaprefix_bn_offset")
+        self.bn0 = BatchNorm(50, param_attr=param_attr0, bias_attr=bias_attr0)
+
+        param_attr1 = ParamAttr(name="bn_scale")
+        bias_attr1 = ParamAttr(name="bn_offset")
+        self.bn1 = BatchNorm(10, param_attr=param_attr1, bias_attr=bias_attr1)
+
+    def forward(self, x):
+        x = self.linear0(x)
+        x = self.bn0(x)
+        x = self.linear1(x)
+        x = self.bn1(x)
+        return x
+
+class TestGradNameParse(unittest.TestCase):
+    def test_grad_name_parse(self):
+        net = SimpleNet()
+        opt = paddle.optimizer.Adam(learning_rate=0.1, parameters=net.parameters(), weight_decay=paddle.regularizer.L1Decay(0.01))
+        net = paddle.jit.to_static(net)
+        inp = paddle.rand([100,100], dtype="float32")
+        out = net(inp)
+        loss = out.mean()
+        loss.backward()
+
+        for name, param in net.bn1.named_parameters():
+            if name in ["bn_scale", "bn_offset"]:
+                assert param.shape == param.grad.shape
+
+        opt.minimize(loss)
+
+if __name__ == "__main__":
+    unittest.main()
\ No newline at end of file
diff --git a/python/paddle/jit/dy2static/utils.py b/python/paddle/jit/dy2static/utils.py
index ac96163d704614..05c5cd1eeec5e1 100644
--- a/python/paddle/jit/dy2static/utils.py
+++ b/python/paddle/jit/dy2static/utils.py
@@ -1460,7 +1460,7 @@ def _param_grad_names(program_desc, params):
             for var in program_desc.block(0).all_vars()
             if var.name().endswith(param.name + '@GRAD')
         ]
-        if candidate:
+        if candidate and 'grad/' in param.name:
             names.append(max(candidate, key=lambda name: name.count('grad/')))
         else:
             names.append(param.name + '@GRAD')

From 5147e66d7afc02f67bdbd941b98669d8eadda519 Mon Sep 17 00:00:00 2001
From: feifei-111 <2364819892@qq.com>
Date: Fri, 24 Mar 2023 08:53:35 +0000
Subject: [PATCH 2/5] pre-commit

---
 .../dygraph_to_static/test_gradname_parse.py | 20 +++++++++++--------
 1 file changed, 12 insertions(+), 8 deletions(-)

diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_gradname_parse.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_gradname_parse.py
index f08cfcacf74c20..970e22113c1335 100644
--- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_gradname_parse.py
+++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_gradname_parse.py
@@ -12,9 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from __future__ import print_function
 
-import collections
 import unittest
 
 import paddle
@@ -24,9 +22,9 @@
 class SimpleNet(paddle.nn.Layer):
     def __init__(self):
-        super(SimpleNet, self).__init__()
-        self.linear0 = Linear(100,50)
-        self.linear1 = Linear(50,10)
+        super().__init__()
+        self.linear0 = Linear(100, 50)
+        self.linear1 = Linear(50, 10)
 
         param_attr0 = ParamAttr(name="aaaprefix_bn_scale")
         bias_attr0 = ParamAttr(name="aaaprefix_bn_offset")
         self.bn0 = BatchNorm(50, param_attr=param_attr0, bias_attr=bias_attr0)
@@ -43,12 +41,17 @@ def forward(self, x):
         x = self.bn1(x)
         return x
 
+
 class TestGradNameParse(unittest.TestCase):
     def test_grad_name_parse(self):
         net = SimpleNet()
-        opt = paddle.optimizer.Adam(learning_rate=0.1, parameters=net.parameters(), weight_decay=paddle.regularizer.L1Decay(0.01))
+        opt = paddle.optimizer.Adam(
+            learning_rate=0.1,
+            parameters=net.parameters(),
+            weight_decay=paddle.regularizer.L1Decay(0.01),
+        )
         net = paddle.jit.to_static(net)
-        inp = paddle.rand([100,100], dtype="float32")
+        inp = paddle.rand([100, 100], dtype="float32")
         out = net(inp)
         loss = out.mean()
         loss.backward()
@@ -59,5 +62,6 @@ def test_grad_name_parse(self):
 
         opt.minimize(loss)
 
+
 if __name__ == "__main__":
-    unittest.main()
\ No newline at end of file
+    unittest.main()

From eb07a4c41410eefab49add1213ba903b15154a5e Mon Sep 17 00:00:00 2001
From: feifei-111 <2364819892@qq.com>
Date: Mon, 27 Mar 2023 02:52:14 +0000
Subject: [PATCH 3/5] bug fix

---
 python/paddle/jit/dy2static/utils.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/python/paddle/jit/dy2static/utils.py b/python/paddle/jit/dy2static/utils.py
index 05c5cd1eeec5e1..056dacb8a94743 100644
--- a/python/paddle/jit/dy2static/utils.py
+++ b/python/paddle/jit/dy2static/utils.py
@@ -1460,8 +1460,12 @@ def _param_grad_names(program_desc, params):
             for var in program_desc.block(0).all_vars()
             if var.name().endswith(param.name + '@GRAD')
         ]
-        if candidate and 'grad/' in param.name:
-            names.append(max(candidate, key=lambda name: name.count('grad/')))
+        if candidate:
+            new_name = max(candidate, key=lambda name: name.count('grad/'))
+            if 'grad/' in new_name:
+                names.append(new_name)
+            else:
+                names.append(param.name + '@GRAD')
         else:
             names.append(param.name + '@GRAD')
 

From 45f6feec53d17fd323094dca6a8ae01fd6bd89c2 Mon Sep 17 00:00:00 2001
From: 0x45f
Date: Thu, 6 Apr 2023 07:28:58 +0000
Subject: [PATCH 4/5] Fix grad/ error

---
 .../dygraph_to_static/test_gradname_parse.py | 12 ++++++----
 python/paddle/jit/dy2static/utils.py         | 24 +++++++++----------
 2 files changed, 19 insertions(+), 17 deletions(-)

diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_gradname_parse.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_gradname_parse.py
index 970e22113c1335..d51bcbf0684239 100644
--- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_gradname_parse.py
+++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_gradname_parse.py
@@ -35,11 +35,12 @@ def __init__(self):
         self.bn1 = BatchNorm(10, param_attr=param_attr1, bias_attr=bias_attr1)
 
     def forward(self, x):
-        x = self.linear0(x)
-        x = self.bn0(x)
-        x = self.linear1(x)
-        x = self.bn1(x)
-        return x
+        x1 = self.linear0(x)
+        x2 = self.bn0(x1)
+        x3 = self.linear1(x2)
+        x4 = self.bn1(x3)
+        dx = paddle.grad(x4, x)
+        return dx[0]
 
 
 class TestGradNameParse(unittest.TestCase):
@@ -52,6 +53,7 @@ def test_grad_name_parse(self):
         )
         net = paddle.jit.to_static(net)
         inp = paddle.rand([100, 100], dtype="float32")
+        inp.stop_gradient = False
         out = net(inp)
         loss = out.mean()
         loss.backward()
diff --git a/python/paddle/jit/dy2static/utils.py b/python/paddle/jit/dy2static/utils.py
index 056dacb8a94743..c5b99c59ec0646 100644
--- a/python/paddle/jit/dy2static/utils.py
+++ b/python/paddle/jit/dy2static/utils.py
@@ -1454,21 +1454,21 @@ def _param_grad_names(program_desc, params):
     names = []
     # NOTE: `names` and `params` must be in the same order so that
     # the param grad name can be set correctly in the run_program.
+
     for param in params:
-        candidate = [
-            var.name()
-            for var in program_desc.block(0).all_vars()
-            if var.name().endswith(param.name + '@GRAD')
-        ]
+        candidate = []
+        suffix = param.name + '@GRAD'
+        for var in program_desc.block(0).all_vars():
+            var_name = var.name()
+            if var_name.endswith(suffix):
+                prefix_count = var_name.count('grad/')
+                if 'grad/' * prefix_count + suffix == var_name:
+                    candidate.append(var_name)
+
         if candidate:
-            new_name = max(candidate, key=lambda name: name.count('grad/'))
-            if 'grad/' in new_name:
-                names.append(new_name)
-            else:
-                names.append(param.name + '@GRAD')
+            names.append(max(candidate, key=lambda name: name.count('grad/')))
         else:
-            names.append(param.name + '@GRAD')
-
+            names.append(suffix)
     return names
 
 

From 69e7f1a54c0611ccd3debd2c2db6c2e689472878 Mon Sep 17 00:00:00 2001
From: 0x45f
Date: Thu, 6 Apr 2023 07:31:28 +0000
Subject: [PATCH 5/5] Format code

---
 python/paddle/jit/dy2static/utils.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/python/paddle/jit/dy2static/utils.py b/python/paddle/jit/dy2static/utils.py
index c5b99c59ec0646..6696b6c39038bb 100644
--- a/python/paddle/jit/dy2static/utils.py
+++ b/python/paddle/jit/dy2static/utils.py
@@ -1454,7 +1454,6 @@ def _param_grad_names(program_desc, params):
     names = []
     # NOTE: `names` and `params` must be in the same order so that
     # the param grad name can be set correctly in the run_program.
-
    for param in params:
         candidate = []
         suffix = param.name + '@GRAD'
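
The rule the series converges on in `_param_grad_names` is: a variable is the gradient of `param` only if its name is exactly `param.name + '@GRAD'`, optionally wrapped in repeated `grad/` prefixes that higher-order gradient programs (such as the double-grad program created by the `paddle.grad` call the test adds to `forward`) prepend, and among the matches the most deeply nested name wins. A plain `endswith` check is not enough, which is why the test registers both `bn_scale` and `aaaprefix_bn_scale`: `'aaaprefix_bn_scale@GRAD'.endswith('bn_scale@GRAD')` is true, so the shorter parameter name could be resolved to the longer parameter's gradient. Below is a minimal standalone sketch of that rule; the helper name `match_grad_name` and the sample variable names are illustrative, not taken from a real program desc:

    def match_grad_name(all_var_names, param_name):
        # Mirror of the matching logic _param_grad_names ends up with in
        # PATCH 4/5, operating on plain strings for demonstration.
        suffix = param_name + '@GRAD'
        candidate = []
        for var_name in all_var_names:
            if var_name.endswith(suffix):
                # Accept only names of the exact form 'grad/' * k + suffix,
                # so 'aaaprefix_bn_scale@GRAD' cannot match param 'bn_scale'.
                prefix_count = var_name.count('grad/')
                if 'grad/' * prefix_count + suffix == var_name:
                    candidate.append(var_name)
        if candidate:
            # Prefer the candidate with the most 'grad/' prefixes, i.e. the
            # gradient variable produced by the deepest backward program.
            return max(candidate, key=lambda name: name.count('grad/'))
        return suffix  # fall back to the plain '<param>@GRAD' name

    var_names = [
        'bn_scale@GRAD',
        'grad/bn_scale@GRAD',       # as produced by a double-grad program
        'aaaprefix_bn_scale@GRAD',  # also ends with 'bn_scale@GRAD'
    ]
    assert match_grad_name(var_names, 'bn_scale') == 'grad/bn_scale@GRAD'
    assert match_grad_name(var_names, 'aaaprefix_bn_scale') == 'aaaprefix_bn_scale@GRAD'
    assert match_grad_name(var_names, 'bn_offset') == 'bn_offset@GRAD'

This also illustrates the collision the exact-form check in PATCH 4/5 rules out: under PATCH 3/5, a name such as `grad/aaaprefix_bn_scale@GRAD` still ends with `bn_scale@GRAD` and carries a `grad/` prefix, so it could be picked for the wrong parameter, while PATCH 1/5's `'grad/' in param.name` guard tested the parameter name, which does not carry that prefix, rather than the candidate gradient name.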