File tree 1 file changed +12
-7
lines changed
1 file changed +12
-7
lines changed Original file line number Diff line number Diff line change @@ -82,16 +82,21 @@ def build_dataloader(_dataset, cfg):
82
82
sampler_cfg ["batch_size" ] = cfg ["batch_size" ]
83
83
batch_sampler = getattr (io , batch_sampler_cls )(_dataset , ** sampler_cfg )
84
84
else :
85
- if cfg ["batch_size" ] != 1 :
86
- raise ValueError (
87
- f"`batch_size` should be 1 when sampler config is None, but got { cfg ['batch_size' ]} ."
85
+ batch_sampler_cls = "BatchSampler"
86
+ if world_size > 1 :
87
+ batch_sampler_cls = "DistributedBatchSampler"
88
+ logger .warning (
89
+ f"Automatically use 'DistributedBatchSampler' instead of "
90
+ f"'BatchSampler' when world_size({ world_size } ) > 1."
88
91
)
89
- logger .warning (
90
- "`batch_size` is set to 1 as neither sampler config nor batch_size is set."
91
- )
92
- batch_sampler = io .BatchSampler (
92
+ batch_sampler = getattr (io , batch_sampler_cls )(
93
93
_dataset ,
94
94
batch_size = cfg ["batch_size" ],
95
+ shuffle = False ,
96
+ drop_last = False ,
97
+ )
98
+ logger .message (
99
+ "'shuffle' and 'drop_last' are both set to False in default as sampler config is not specified."
95
100
)
96
101
97
102
# build collate_fn if specified
You can’t perform that action at this time.
0 commit comments