1111import numpy as np
1212from collections import OrderedDict
1313from . import datasets
14+ from . import utils
1415import warnings
1516warnings .filterwarnings ("ignore" )
1617
@@ -191,6 +192,7 @@ class DenseNet(nn.Module):
191192 model = xrv.models.DenseNet(weights="densenet121-res224-mimic_ch") # MIMIC-CXR (MIT)
192193
193194 :param weights: Specify a weight name to load pre-trained weights
195+ :param cache_dir: Override where the weights will be stored (default is ~/.torchxrayvision/)
194196 :param op_threshs: Specify a weight name to load pre-trained weights
195197 :param apply_sigmoid: Apply a sigmoid
196198
@@ -227,6 +229,7 @@ def __init__(self,
227229 num_classes = len (datasets .default_pathologies ),
228230 in_channels = 1 ,
229231 weights = None ,
232+ cache_dir = None ,
230233 op_threshs = None ,
231234 apply_sigmoid = False
232235 ):
@@ -291,7 +294,7 @@ def __init__(self,
291294 self .register_buffer ('op_threshs' , op_threshs )
292295
293296 if self .weights != None :
294- self .weights_filename_local = get_weights (weights )
297+ self .weights_filename_local = get_weights (weights , cache_dir )
295298
296299 try :
297300 savedmodel = torch .load (self .weights_filename_local , map_location = 'cpu' )
@@ -355,6 +358,7 @@ class ResNet(nn.Module):
355358 model = xrv.models.ResNet(weights="resnet50-res512-all")
356359
357360 :param weights: Specify a weight name to load pre-trained weights
361+ :param cache_dir: Override where the weights will be stored (default is ~/.torchxrayvision/)
358362 :param op_threshs: Specify a weight name to load pre-trained weights
359363 :param apply_sigmoid: Apply a sigmoid
360364
@@ -382,7 +386,7 @@ class ResNet(nn.Module):
382386 ]
383387 """"""
384388
385- def __init__ (self , weights : str = None , apply_sigmoid : bool = False ):
389+ def __init__ (self , weights : str = None , apply_sigmoid : bool = False , cache_dir : str = None ):
386390 super (ResNet , self ).__init__ ()
387391
388392 self .weights = weights
@@ -392,7 +396,7 @@ def __init__(self, weights: str = None, apply_sigmoid: bool = False):
392396 possible_weights = [k for k in model_urls .keys () if k .startswith ("resnet" )]
393397 raise Exception ("Weights value must be in {}" .format (possible_weights ))
394398
395- self .weights_filename_local = get_weights (weights )
399+ self .weights_filename_local = get_weights (weights , cache_dir = cache_dir )
396400 self .weights_dict = model_urls [weights ]
397401 self .targets = model_urls [weights ]["labels" ]
398402 self .pathologies = self .targets # keep to be backward compatible
@@ -546,39 +550,22 @@ def get_model(weights: str, **kwargs):
546550 raise Exception ("Unknown model" )
547551
548552
549- def get_weights (weights : str ):
553+ def get_weights (weights : str , cache_dir : str = None ):
550554 if not weights in model_urls :
551555 raise Exception ("Weights not found. Valid options: {}" .format (list (model_urls .keys ())))
552556
553557 url = model_urls [weights ]["weights_url" ]
554558 weights_filename = os .path .basename (url )
555- weights_storage_folder = os .path .expanduser (os .path .join ("~" , ".torchxrayvision" , "models_data" ))
559+ if cache_dir is None :
560+ weights_storage_folder = utils .get_cache_dir ()
561+ else :
562+ weights_storage_folder = cache_dir
556563 weights_filename_local = os .path .expanduser (os .path .join (weights_storage_folder , weights_filename ))
557564
558565 if not os .path .isfile (weights_filename_local ):
559566 print ("Downloading weights..." )
560567 print ("If this fails you can run `wget {} -O {}`" .format (url , weights_filename_local ))
561568 pathlib .Path (weights_storage_folder ).mkdir (parents = True , exist_ok = True )
562- download (url , weights_filename_local )
569+ utils . download (url , weights_filename_local )
563570
564571 return weights_filename_local
565-
566-
567- # from here https://sumit-ghosh.com/articles/python-download-progress-bar/
568- def download (url : str , filename : str ):
569- with open (filename , 'wb' ) as f :
570- response = requests .get (url , stream = True )
571- total = response .headers .get ('content-length' )
572-
573- if total is None :
574- f .write (response .content )
575- else :
576- downloaded = 0
577- total = int (total )
578- for data in response .iter_content (chunk_size = max (int (total / 1000 ), 1024 * 1024 )):
579- downloaded += len (data )
580- f .write (data )
581- done = int (50 * downloaded / total )
582- sys .stdout .write ('\r [{}{}]' .format ('█' * done , '.' * (50 - done )))
583- sys .stdout .flush ()
584- sys .stdout .write ('\n ' )
0 commit comments