@@ -254,21 +254,26 @@ def __init__(
         self.dp_world_sizes = data_parallel_world_size
         self.dp_ranks = data_parallel_rank
         self.split_data = split_data
-        # TODO: rank info
-        self.batch_sampler = DistributedBatchSampler(
-            dataset=self.dataset,
-            batch_size=self.batch_size,
-            num_replicas=self.dp_world_sizes[0],
-            rank=self.dp_ranks[0],
-            shuffle=self.shuffle,
-            drop_last=self.drop_last,
-        )
+
+        if self.batch_size is None:
+            self.batch_sampler = None
+        else:
+            self.batch_sampler = DistributedBatchSampler(
+                dataset=self.dataset,
+                batch_size=self.batch_size,
+                num_replicas=self.dp_world_sizes[0],
+                rank=self.dp_ranks[0],
+                shuffle=self.shuffle,
+                drop_last=self.drop_last,
+            )
+
         self._dataloader = paddle.io.DataLoader(
             self.dataset,
             feed_list=self.feed_list,
             places=self.places,
             return_list=self.return_list,
             batch_sampler=self.batch_sampler,
+            batch_size=1 if self.batch_sampler else self.batch_size,
             collate_fn=self.collate_fn,
             num_workers=self.num_workers,
             use_buffer_reader=self.use_buffer_reader,
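
For context, a minimal sketch of the two paths this change introduces: with a batch size set, batching is delegated to a DistributedBatchSampler sharded across data-parallel ranks; with batch_size=None, no sampler is built and the value is passed through so paddle.io.DataLoader skips auto-batching. The toy RandomDataset and the single-rank sampler arguments below are illustrative assumptions, not part of the patch:

import numpy as np
import paddle
from paddle.io import DataLoader, Dataset, DistributedBatchSampler

class RandomDataset(Dataset):
    """Toy dataset (hypothetical): 8 float32 samples of shape (2,)."""
    def __init__(self, n=8):
        self.data = np.random.rand(n, 2).astype("float32")

    def __getitem__(self, idx):
        return self.data[idx]

    def __len__(self):
        return len(self.data)

dataset = RandomDataset()

# Path 1: batch_size is set -> shard fixed-size batches across ranks.
# num_replicas/rank stand in for dp_world_sizes[0]/dp_ranks[0] in the patch;
# a single rank is used here for illustration.
sampler = DistributedBatchSampler(
    dataset, batch_size=4, num_replicas=1, rank=0, shuffle=False, drop_last=False
)
batched = DataLoader(dataset, batch_sampler=sampler)

# Path 2: batch_size=None -> no sampler; the value is forwarded so the
# DataLoader yields individual samples (the scenario this patch enables).
unbatched = DataLoader(dataset, batch_size=None)

print(len(list(batched)))    # 2 batches of 4 samples each
print(len(list(unbatched)))  # 8 individual samples

Note that the new DataLoader call passes batch_size=1 (Paddle's default) whenever a batch_sampler is supplied, since the two arguments are mutually exclusive; the user-provided batch_size only reaches the DataLoader when no sampler was built.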