Commit f406545

Browse files
[AutoParallel] Fix dataloader in to_static mode. (#64334)
* [AutoParallel] Fix dataloader in to_static mode.
* Polish code according to review comment.
* Polish code.
* Fix some problems.
* Polish code.
* Polish code.
* Add test case.
* Polish code.
* Polish code.
* Polish code.
1 parent be632d8 commit f406545

File tree

3 files changed: +292 -0 lines changed

python/paddle/distributed/auto_parallel/static/engine.py (+9)

@@ -283,6 +283,15 @@ def _prepare_data_spec_from_dataloader(self, dataloader):
         inputs_spec = []
         labels_spec = []
         data = next(iter(dataloader))
+        if hasattr(dataloader, "batch_sampler"):
+            batch_sampler = dataloader.batch_sampler
+        else:
+            batch_sampler = dataloader._dataloader.batch_sampler
+        if isinstance(batch_sampler, paddle.io.DistributedBatchSampler):
+            # Getting data directly from the DataLoader iterator may affect the data
+            # generation randomness of the BatchSampler when `shuffle=True`. It may cause
+            # a difference in data feeding between dynamic and to_static mode.
+            batch_sampler.epoch -= 1
         if isinstance(data, dict):
             data = tuple(data.values())
         if len(data) != 2:
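
Why the `epoch -= 1` above: `_prepare_data_spec_from_dataloader` peeks one batch via `next(iter(dataloader))` to infer the input specs, and when the underlying sampler is a `paddle.io.DistributedBatchSampler` with `shuffle=True`, each `__iter__` call derives its shuffle order from `self.epoch` and then advances it. The peek alone would therefore shift the shuffle seed seen by the real training loop, so to_static mode would feed different batches than a dynamic run; rolling the epoch back by one restores the original order. (The `hasattr` fallback handles the wrapper returned by `dist.shard_dataloader`, which keeps the raw `DataLoader` in `_dataloader`.) Below is a minimal standalone sketch of the idea, not taken from the commit and assuming the epoch-advancing behaviour described above; the toy dataset and names are illustrative only.

# Toy illustration (hypothetical names): peeking one batch advances the sampler's
# epoch-based shuffle seed; decrementing `epoch` afterwards undoes that shift.
import numpy as np
from paddle.io import DataLoader, Dataset, DistributedBatchSampler


class ToyDataset(Dataset):
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return np.array([idx], dtype="int64")


dataset = ToyDataset()
sampler = DistributedBatchSampler(dataset, batch_size=2, shuffle=True)
loader = DataLoader(dataset, batch_sampler=sampler, num_workers=0)

epoch_before = sampler.epoch
_ = next(iter(loader))  # peek one batch, as the Engine does for spec inference
if sampler.epoch != epoch_before:
    # The peek consumed one __iter__ call and moved the shuffle seed forward;
    # roll it back so the first real epoch shuffles exactly as it would have
    # without the peek (this mirrors the engine.py fix).
    sampler.epoch -= 1
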
test/auto_parallel/hybrid_strategy/semi_auto_llama_dataloader.py (+260, new file)

@@ -0,0 +1,260 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from functools import reduce

import numpy as np
from semi_auto_parallel_llama_model import (
    LlamaForCausalLMAuto,
    LlamaPretrainingCriterionAuto,
    get_mesh,
)

import paddle
import paddle.distributed as dist
from paddle import LazyGuard
from paddle.io import BatchSampler, DataLoader, Dataset


class Config:
    vocab_size = 32000
    hidden_size = 4096
    intermediate_size = 11008
    max_position_embeddings = 2048
    seq_length = 2048
    num_hidden_layers = 2
    num_attention_heads = 32
    num_key_value_heads = 32
    initializer_range = 0.02
    rms_norm_eps = 1e-6
    use_cache = True
    use_flash_attention = False
    sequence_parallel = False
    rope = True
    recompute = False
    recompute_granularity = None
    use_lazy_init = False


inputs = []
labels = []

for i in range(100):
    inputs.append(
        np.random.uniform(low=0, high=32000, size=[Config().seq_length]).astype(
            "int64"
        )
    )
    labels.append(
        (np.random.uniform(size=[Config().seq_length]) * 10).astype("int64")
    )


class RandomDataset(Dataset):
    def __init__(self, seq_len, num_samples=100):
        super().__init__()
        self.seq_len = seq_len
        self.num_samples = num_samples

    def __getitem__(self, index):
        global inputs, labels
        return inputs[index], labels[index]

    def __len__(self):
        return self.num_samples


def create_optimizer(model, lr_scheduler):
    decay_parameters = [
        p.name
        for n, p in model.named_parameters()
        if not any(nd in n for nd in ["bias", "norm"])
    ]

    def apply_decay_param_fun(x):
        return x in decay_parameters

    # test global_clip in auto_parallel
    if os.getenv("use_param_group") == "true":
        param_group = {}
        param_group["params"] = list(model.parameters())
        param_group["weight_decay"] = 0.01
        param_group["grad_clip"] = paddle.nn.ClipGradByGlobalNorm(1.0)
        optimizer = paddle.optimizer.adamw.AdamW(
            learning_rate=lr_scheduler,
            apply_decay_param_fun=apply_decay_param_fun,
            parameters=[param_group],
        )
    else:
        optimizer = paddle.optimizer.adamw.AdamW(
            learning_rate=lr_scheduler,
            apply_decay_param_fun=apply_decay_param_fun,
            parameters=model.parameters(),
            weight_decay=0.01,
            grad_clip=paddle.nn.ClipGradByGlobalNorm(1.0),
        )
    return optimizer


class TestLlamaAuto:
    def __init__(self):
        self.config = Config()
        self.dp = int(os.getenv("dp"))
        self.mp = int(os.getenv("mp"))
        self.pp = int(os.getenv("pp"))
        if os.getenv("use_sp") == "true":
            self.config.sequence_parallel = True
        if os.getenv("recompute") == "true":
            self.config.recompute = True
        self.config.recompute_granularity = os.getenv("recompute_granularity")
        if os.getenv("use_lazy_init") == "true":
            self.config.use_lazy_init = True
        self.gradient_accumulation_steps = int(os.getenv("acc_step"))
        self.amp = False
        self.amp_dtype = "float16"
        self.amp_level = "O1"
        self.amp_master_grad = False
        if os.getenv("amp") == "true":
            self.amp = True
        if os.getenv("amp_dtype") in ["float16", "bfloat16"]:
            self.amp_dtype = os.getenv("amp_dtype")
        if os.getenv("amp_level") in ["O0", "O1", "O2"]:
            self.amp_level = os.getenv("amp_level")
        if os.getenv("amp_master_grad") == "true":
            self.amp_master_grad = True

        self.init_dist_env()

    def init_dist_env(self):
        order = ["dp", "pp", "mp"]
        dp_degree = self.dp
        mp_degree = self.mp
        pp_degree = self.pp
        degree = [dp_degree, pp_degree, mp_degree]
        mesh_dims = list(filter(lambda x: x[1] > 1, list(zip(order, degree))))
        if not mesh_dims:
            mesh_dims = [("dp", 1)]
        dim_names = [mesh_dim[0] for mesh_dim in mesh_dims]
        mesh_shape = [mesh_dim[1] for mesh_dim in mesh_dims]
        mesh_arr = np.arange(
            0, reduce(lambda x, y: x * y, mesh_shape, 1)
        ).reshape(mesh_shape)
        global_mesh = dist.ProcessMesh(mesh_arr, dim_names)
        dist.auto_parallel.set_mesh(global_mesh)

    def run_llama(self, to_static=0):
        if self.config.use_lazy_init:
            with LazyGuard():
                model = LlamaForCausalLMAuto(self.config)
            for param in model.parameters():
                assert not param._is_initialized()
                param.initialize()
        else:
            model = LlamaForCausalLMAuto(self.config)
        criterion = LlamaPretrainingCriterionAuto(self.config)

        lr_scheduler = paddle.optimizer.lr.LinearWarmup(
            learning_rate=0.0001, warmup_steps=2, start_lr=0, end_lr=0.0001
        )
        optimizer = create_optimizer(model, lr_scheduler)
        if self.amp and not to_static:
            model, optimizer = paddle.amp.decorate(
                models=model,
                optimizers=optimizer,
                level=self.amp_level,
                dtype=self.amp_dtype,
                master_grad=self.amp_master_grad,
            )
        optimizer = dist.shard_optimizer(optimizer)

        train_dataset = RandomDataset(self.config.seq_length)
        train_sampler = BatchSampler(
            train_dataset,
            batch_size=2,
            shuffle=True,
            drop_last=False,
        )
        train_dataloader = DataLoader(
            train_dataset,
            batch_sampler=train_sampler,
            num_workers=0,
        )

        if self.pp == 1:
            meshes = [get_mesh(0)]
        elif self.pp > 1:
            meshes = [get_mesh(0), get_mesh(-1)]
        else:
            raise ValueError("pp should be greater or equal to 1")

        dist_loader = dist.shard_dataloader(
            dataloader=train_dataloader,
            meshes=meshes,
            shard_dims="dp",
        )

        global_step = 1
        tr_loss = float(0)

        if not to_static:
            model.train()
            scaler = None
            if self.amp and self.amp_dtype == "float16":
                scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
                scaler = dist.shard_scaler(scaler)

            for epoch_idx in range(1):
                for step, inputs in enumerate(dist_loader()):
                    input_ids, labels = inputs
                    return input_ids._local_value()._md5sum()
                    break
        else:
            strategy = dist.Strategy()
            if self.gradient_accumulation_steps > 1:
                strategy.pipeline.accumulate_steps = (
                    self.gradient_accumulation_steps
                )

            if self.amp:
                amp = strategy.amp
                amp.enable = self.amp
                amp.dtype = self.amp_dtype
                amp.level = self.amp_level.lower()
                if self.amp_master_grad:
                    amp.use_master_grad = True

            dist_model = dist.to_static(
                model,
                dist_loader,
                criterion,
                optimizer,
                strategy=strategy,
            )

            dist_model.train()
            for step, inputs in enumerate(dist_loader()):
                input_ids, labels = inputs
                return input_ids._local_value()._md5sum()
                break

    def run_test_cases(self):
        dynamic_input_mdsum = self.run_llama(to_static=0)
        static_input_md5sum = self.run_llama(to_static=1)
        if dist.get_rank() == 0:
            assert dynamic_input_mdsum == static_input_md5sum


if __name__ == '__main__':
    TestLlamaAuto().run_test_cases()

test/auto_parallel/hybrid_strategy/test_semi_auto_parallel_llama_model.py (+23)

@@ -204,5 +204,28 @@ def test_simple_net_hybrid_strategy(self):
             )
 
 
+class TestSemiAutoParallelLlamaDataLoader(test_base.CommunicationTestDistBase):
+    def setUp(self):
+        super().setUp(num_of_devices=8, timeout=200, nnode=1)
+        self._default_envs = {"dp": "2", "mp": "2", "pp": "2", "acc_step": "1"}
+        self._changeable_envs = {
+            "backend": ["gpu"],
+            "use_sp": ["false"],
+            "use_param_group": ["false"],
+            "recompute": ["true"],
+            "recompute_granularity": ["full"],
+        }
+
+    def test_simple_net_hybrid_strategy(self):
+        envs_list = test_base.gen_product_envs_list(
+            self._default_envs, self._changeable_envs
+        )
+        for envs in envs_list:
+            self.run_test_case(
+                "semi_auto_llama_dataloader.py",
+                user_defined_envs=envs,
+            )
+
+
 if __name__ == "__main__":
     unittest.main()
