Skip to content

Commit 81d800b

Browse files
authored
[AutoParallel] Fix dist_loader when batch size is None (#60234)
1 parent 1815b99 commit 81d800b

File tree

1 file changed

+14
-9
lines changed

1 file changed

+14
-9
lines changed

python/paddle/distributed/auto_parallel/static/dist_loader.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -254,21 +254,26 @@ def __init__(
254254
self.dp_world_sizes = data_parallel_world_size
255255
self.dp_ranks = data_parallel_rank
256256
self.split_data = split_data
257-
# TODO: rank info
258-
self.batch_sampler = DistributedBatchSampler(
259-
dataset=self.dataset,
260-
batch_size=self.batch_size,
261-
num_replicas=self.dp_world_sizes[0],
262-
rank=self.dp_ranks[0],
263-
shuffle=self.shuffle,
264-
drop_last=self.drop_last,
265-
)
257+
258+
if self.batch_size is None:
259+
self.batch_sampler = None
260+
else:
261+
self.batch_sampler = DistributedBatchSampler(
262+
dataset=self.dataset,
263+
batch_size=self.batch_size,
264+
num_replicas=self.dp_world_sizes[0],
265+
rank=self.dp_ranks[0],
266+
shuffle=self.shuffle,
267+
drop_last=self.drop_last,
268+
)
269+
266270
self._dataloader = paddle.io.DataLoader(
267271
self.dataset,
268272
feed_list=self.feed_list,
269273
places=self.places,
270274
return_list=self.return_list,
271275
batch_sampler=self.batch_sampler,
276+
batch_size=1 if self.batch_sampler else self.batch_size,
272277
collate_fn=self.collate_fn,
273278
num_workers=self.num_workers,
274279
use_buffer_reader=self.use_buffer_reader,

0 commit comments

Comments (0)