@@ -202,24 +202,26 @@ def loop_check(callback, item):
202
202
for each in item :
203
203
callback (each )
204
204
205
+
205
206
class CheckInputTypeWrapper (object ):
206
207
def __init__ (self , generator , input_types , logger ):
207
208
self .generator = generator
208
209
self .input_types = input_types
209
210
self .logger = logger
210
211
211
- def __call__ (self , obj , filename ):
212
- for items in self .generator (obj , filename ):
213
- try :
214
- # dict type is required for input_types when item is dict type
215
- assert (isinstance (items , dict ) and \
216
- not isinstance (self .input_types , dict ))== False
217
- yield items
218
- except AssertionError as e :
219
- self .logger .error (
212
+ def __call__ (self , obj , filename ):
213
+ for items in self .generator (obj , filename ):
214
+ try :
215
+ # dict type is required for input_types when item is dict type
216
+ assert (isinstance (items , dict ) and \
217
+ not isinstance (self .input_types , dict ))== False
218
+ yield items
219
+ except AssertionError as e :
220
+ self .logger .error (
220
221
"%s type is required for input type but got %s" %
221
222
(repr (type (items )), repr (type (self .input_types ))))
222
- raise
223
+ raise
224
+
223
225
224
226
def provider (input_types = None ,
225
227
should_shuffle = None ,
@@ -374,8 +376,8 @@ def __init__(self, file_list, **kwargs):
374
376
self .generator = InputOrderWrapper (self .generator ,
375
377
self .input_order )
376
378
else :
377
- self .generator = CheckInputTypeWrapper (self . generator , self . slots ,
378
- self .logger )
379
+ self .generator = CheckInputTypeWrapper (
380
+ self . generator , self . slots , self .logger )
379
381
if self .check :
380
382
self .generator = CheckWrapper (self .generator , self .slots ,
381
383
check_fail_continue ,
0 commit comments