
Commit 39fda14

Polish bfloat16 main_grad unittest for data parallel and sharding stage1. (#58842)
* Polish bfloat16 main_grad unittest for data parallel.
* Optimize unittest of sharding stage1.
* Polish codes and add check of weights.
* Polish unittest for sharding stage1.
* Revert some minor changes.
* Polish the compare of parameters.
* Compute loss in float32.
1 parent aea6907 commit 39fda14
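The commit message points at two behaviors the polished tests exercise: the loss is computed in float32 even though the forward pass runs in bfloat16, and the resulting parameters are checked against the optimizer's float32 master weights. Below is a minimal, hypothetical sketch of the first point only; `model`, `opt`, and `batch` are placeholders and this exact step function does not appear in the commit's diff.

    import paddle

    def train_step(model, opt, batch):
        # Forward runs in bfloat16 (pure-bf16 / main_grad mode);
        # the loss itself is accumulated in float32, as the commit message describes.
        out = model(batch)
        loss = out.astype("float32").mean()
        loss.backward()
        opt.step()
        opt.clear_grad()
        return loss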

File tree

3 files changed: +350 −253 lines

+156 lines (new file)
@@ -0,0 +1,156 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import re
from collections import OrderedDict

import numpy as np

import paddle
from paddle.distributed.fleet.utils import mix_precision_utils
from paddle.nn import Linear, ReLU

logging.basicConfig(level="INFO", format="%(message)s")


class MLP(paddle.nn.Layer):
    def __init__(self, linear_size=1000):
        super().__init__()

        self._linear1 = Linear(linear_size, linear_size)
        self._linear2 = Linear(linear_size, linear_size)
        self._linear3 = Linear(linear_size, 10)
        self._relu = ReLU()

    def forward(self, inputs):
        y = self._linear1(inputs)
        y = self._linear2(y)
        y = self._linear3(y)
        y = self._relu(y)
        return y


class RandomDataset(paddle.io.Dataset):
    def __init__(self, num_samples=200, linear_size=1000):
        self.num_samples = num_samples
        self.linear_size = linear_size
        self.samples = []
        for i in range(num_samples):
            img = np.random.rand(self.linear_size).astype('float32')
            self.samples.append(img)

    def __getitem__(self, idx):
        return self.samples[idx]

    def __len__(self):
        return self.num_samples


def create_optimizer(model, use_pure_bf16, use_main_grad):
    # main_grad keeps float32 gradients for a pure-bf16 model via
    # MixPrecisionLayer/MixPrecisionOptimizer.
    if use_main_grad:
        assert use_pure_bf16
        model = mix_precision_utils.MixPrecisionLayer(model, dtype="bfloat16")
    optimizer = paddle.optimizer.AdamW(
        parameters=model.parameters(),
        learning_rate=0.00001,
        weight_decay=0.00001,
        grad_clip=paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0),
        multi_precision=use_pure_bf16,
    )
    if use_main_grad:
        optimizer = mix_precision_utils.MixPrecisionOptimizer(optimizer)

    return optimizer


def save_model_parameters(model):
    param_dict = OrderedDict()
    for param in model.parameters():
        param_dict[param.name] = param
    return param_dict


def _extract_linear_order(param_names):
    # param_names from model.state_dict() look like: ["_linear1.weight", "_linear1.bias"]
    # master weight names from optimizer.state_dict() look like: ["linear_6.w_0", "linear_6.b_0"]
    param_order = []
    for name in param_names:
        param_id = re.findall(r"\d+", name)
        assert len(param_id) >= 1
        param_order.append(int(param_id[0]))
    return list(set(param_order))


def _extract_param_order_dict(model_param_dict_o1, model_param_dict_o2):
    param_names_o1 = list(model_param_dict_o1.keys())
    param_order_o1 = _extract_linear_order(param_names_o1)
    param_order_o1.sort()

    param_names_o2 = list(model_param_dict_o2.keys())
    param_order_o2 = _extract_linear_order(param_names_o2)
    param_order_o2.sort()

    assert len(param_order_o1) == len(param_order_o2)

    param_order_dict = {}
    for i in range(len(param_order_o1)):
        param_order_dict[param_order_o2[i]] = param_order_o1[i]

    logging.info(f"-- param_names_o1: {param_names_o1}")
    logging.info(f"-- param_names_o2: {param_names_o2}")
    logging.info(f"param_order_dict: {param_order_dict}")
    return param_order_dict


def compare_state_dict(
    model_param_dict_o1, model_param_dict_o2, optimizer_state_dict_o2
):
    # Check that each O1 (float32) parameter equals the matching O2 master weight.
    master_weights = None
    if optimizer_state_dict_o2.get("master_weights", None) is not None:
        master_weights = optimizer_state_dict_o2["master_weights"]
    assert master_weights is not None
    master_weights_names = list(master_weights.keys())

    param_names = list(model_param_dict_o1.keys())
    param_order_dict = _extract_param_order_dict(
        model_param_dict_o1, model_param_dict_o2
    )
    param_master_pair = []

    # We assume the order of params in param_names and master_weights_names is the same.
    param_id = 0
    for master_weight_name in master_weights_names:
        master_weight_id = re.findall(r"\d+", master_weight_name)[0]
        param_id = param_order_dict[int(master_weight_id)]
        for param_name in param_names:
            if (
                master_weight_name.endswith("w_0")
                and param_name.endswith("weight")
            ) or (
                master_weight_name.endswith("b_0")
                and param_name.endswith("bias")
            ):
                # param_id is an int, so convert it before building the prefix.
                name_prefix = "linear" + str(param_id)
                if name_prefix in param_name:
                    param_master_pair.append([param_name, master_weight_name])

    logging.info(f"-- master_weights_names: {master_weights_names}")
    for pair in param_master_pair:
        param_name = pair[0]
        master_weight_name = pair[1]
        logging.info(f"-- compare {param_name} with {master_weight_name}")
        param_o1 = model_param_dict_o1[param_name]
        master_param_o2 = master_weights[master_weight_name]
        np.testing.assert_array_equal(param_o1.numpy(), master_param_o2.numpy())
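The unittests that drive these helpers for data parallel and sharding stage1 are in the other two files of this commit and are not shown in this hunk. As a rough, hypothetical sketch of how the helpers fit together (the function name `run_compare`, the seeds, and the elided training loop are illustrative, not taken from the diff):

    def run_compare():
        # Reference run: plain float32 model and AdamW.
        paddle.seed(2023)
        model_o1 = MLP()
        opt_o1 = create_optimizer(model_o1, use_pure_bf16=False, use_main_grad=False)

        # Pure-bf16 run with main_grad: the model/optimizer get wrapped in
        # MixPrecisionLayer/MixPrecisionOptimizer inside create_optimizer.
        paddle.seed(2023)
        model_o2 = MLP()
        opt_o2 = create_optimizer(model_o2, use_pure_bf16=True, use_main_grad=True)

        # ... train both models on the same RandomDataset batches ...

        compare_state_dict(
            model_o1.state_dict(),            # keys like "_linear1.weight"
            save_model_parameters(model_o2),  # keys like "linear_6.w_0"
            opt_o2.state_dict(),              # provides "master_weights"
        )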

Comments (0)