Skip to content

Commit 3e1d0ad

Browse files
Allow initialize dataloader without specifying 'sampler' (#809)
1 parent 14430f1 commit 3e1d0ad

File tree

1 file changed

+12
-7
lines changed

1 file changed

+12
-7
lines changed

ppsci/data/__init__.py

+12-7
Original file line numberDiff line numberDiff line change
@@ -82,16 +82,21 @@ def build_dataloader(_dataset, cfg):
8282
sampler_cfg["batch_size"] = cfg["batch_size"]
8383
batch_sampler = getattr(io, batch_sampler_cls)(_dataset, **sampler_cfg)
8484
else:
85-
if cfg["batch_size"] != 1:
86-
raise ValueError(
87-
f"`batch_size` should be 1 when sampler config is None, but got {cfg['batch_size']}."
85+
batch_sampler_cls = "BatchSampler"
86+
if world_size > 1:
87+
batch_sampler_cls = "DistributedBatchSampler"
88+
logger.warning(
89+
f"Automatically use 'DistributedBatchSampler' instead of "
90+
f"'BatchSampler' when world_size({world_size}) > 1."
8891
)
89-
logger.warning(
90-
"`batch_size` is set to 1 as neither sampler config nor batch_size is set."
91-
)
92-
batch_sampler = io.BatchSampler(
92+
batch_sampler = getattr(io, batch_sampler_cls)(
9393
_dataset,
9494
batch_size=cfg["batch_size"],
95+
shuffle=False,
96+
drop_last=False,
97+
)
98+
logger.message(
99+
"'shuffle' and 'drop_last' are both set to False in default as sampler config is not specified."
95100
)
96101

97102
# build collate_fn if specified

0 commit comments

Comments
 (0)