Skip to content

Commit 5a8984c

Browse files
authored
add cache_dir arg to allow files to be stored anywhere (#159)
* add cache_dir to a few models
1 parent 543946b commit 5a8984c

File tree

5 files changed

+41
-53
lines changed

5 files changed

+41
-53
lines changed

.github/workflows/tests.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,8 @@ jobs:
3131
strategy:
3232
max-parallel: 2
3333
matrix:
34-
python-version: ['3.9']
35-
torch-version: [2.1.1]
34+
python-version: ['3.11']
35+
torch-version: [2.4.1]
3636
os: [ubuntu-latest, macos-latest, windows-latest] # previously restricted to ubuntu because the other runners were failing; all three are enabled now
3737

3838
# Steps represent a sequence of tasks that will be executed as part of the job

torchxrayvision/autoencoders.py

Lines changed: 12 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import os
55
import sys
66
import requests
7+
from . import utils
78

89

910
model_urls = {}
@@ -218,7 +219,7 @@ def ResNetAE101(**kwargs):
218219
return _ResNetAE(Bottleneck, DeconvBottleneck, [3, 4, 23, 2], 1, **kwargs)
219220

220221

221-
def ResNetAE(weights=None):
222+
def ResNetAE(weights=None, cache_dir=None):
222223
"""A ResNet based autoencoder.
223224
224225
Possible weights for this class include:
@@ -231,6 +232,11 @@ def ResNetAE(weights=None):
231232
z = ae.encode(image)
232233
image2 = ae.decode(z)
233234
235+
236+
params:
237+
weights (str): Weights to use. See above for options.
238+
cache_dir (str): Override directory used to store cached weights (default: ~/.torchxrayvision/models_data/)
239+
234240
"""
235241

236242
if weights == None:
@@ -245,14 +251,17 @@ def ResNetAE(weights=None):
245251
# load pretrained models
246252
url = model_urls[weights]["weights_url"]
247253
weights_filename = os.path.basename(url)
248-
weights_storage_folder = os.path.expanduser(os.path.join("~", ".torchxrayvision", "models_data"))
254+
if cache_dir is None:
255+
weights_storage_folder = utils.get_cache_dir()
256+
else:
257+
weights_storage_folder = cache_dir
249258
weights_filename_local = os.path.expanduser(os.path.join(weights_storage_folder, weights_filename))
250259

251260
if not os.path.isfile(weights_filename_local):
252261
print("Downloading weights...")
253262
print("If this fails you can run `wget {} -O {}`".format(url, weights_filename_local))
254263
pathlib.Path(weights_storage_folder).mkdir(parents=True, exist_ok=True)
255-
download(url, weights_filename_local)
264+
utils.download(url, weights_filename_local)
256265

257266
try:
258267
state_dict = torch.load(weights_filename_local, map_location='cpu')
@@ -268,23 +277,3 @@ def ResNetAE(weights=None):
268277
ae.description = model_urls[weights]["description"]
269278

270279
return ae
271-
272-
273-
# from here https://sumit-ghosh.com/articles/python-download-progress-bar/
274-
def download(url, filename):
275-
with open(filename, 'wb') as f:
276-
response = requests.get(url, stream=True)
277-
total = response.headers.get('content-length')
278-
279-
if total is None:
280-
f.write(response.content)
281-
else:
282-
downloaded = 0
283-
total = int(total)
284-
for data in response.iter_content(chunk_size=max(int(total / 1000), 1024 * 1024)):
285-
downloaded += len(data)
286-
f.write(data)
287-
done = int(50 * downloaded / total)
288-
sys.stdout.write('\r[{}{}]'.format('█' * done, '.' * (50 - done)))
289-
sys.stdout.flush()
290-
sys.stdout.write('\n')

torchxrayvision/baseline_models/chestx_det/__init__.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import torchvision
1010

1111
from .ptsemseg.pspnet import pspnet
12+
from ... import utils
1213

1314

1415
def _convert_state_dict(state_dict):
@@ -51,6 +52,10 @@ class PSPNet(nn.Module):
5152
url = {https://arxiv.org/abs/2104.10326},
5253
year = {2021}
5354
}
55+
56+
params:
57+
cache_dir (str): Override directory used to store cached weights (default: ~/.torchxrayvision/models_data/)
58+
5459
"""
5560

5661
targets: List[str] = [
@@ -62,7 +67,7 @@ class PSPNet(nn.Module):
6267
]
6368
""""""
6469

65-
def __init__(self):
70+
def __init__(self, cache_dir:str = None):
6671

6772
super(PSPNet, self).__init__()
6873

@@ -78,7 +83,10 @@ def __init__(self):
7883
url = "https://github.com/mlmed/torchxrayvision/releases/download/v1/pspnet_chestxray_best_model_4.pth"
7984

8085
weights_filename = os.path.basename(url)
81-
weights_storage_folder = os.path.expanduser(os.path.join("~", ".torchxrayvision", "models_data"))
86+
if cache_dir is None:
87+
weights_storage_folder = utils.get_cache_dir()
88+
else:
89+
weights_storage_folder = cache_dir
8290
self.weights_filename_local = os.path.expanduser(os.path.join(weights_storage_folder, weights_filename))
8391

8492
if not os.path.isfile(self.weights_filename_local):

torchxrayvision/models.py

Lines changed: 13 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import numpy as np
1212
from collections import OrderedDict
1313
from . import datasets
14+
from . import utils
1415
import warnings
1516
warnings.filterwarnings("ignore")
1617

@@ -191,6 +192,7 @@ class DenseNet(nn.Module):
191192
model = xrv.models.DenseNet(weights="densenet121-res224-mimic_ch") # MIMIC-CXR (MIT)
192193
193194
:param weights: Specify a weight name to load pre-trained weights
195+
:param cache_dir: Override where the weights will be stored (default is ~/.torchxrayvision/models_data/)
194196
:param op_threshs: Specify a weight name to load pre-trained weights
195197
:param apply_sigmoid: Apply a sigmoid
196198
@@ -227,6 +229,7 @@ def __init__(self,
227229
num_classes=len(datasets.default_pathologies),
228230
in_channels=1,
229231
weights=None,
232+
cache_dir=None,
230233
op_threshs=None,
231234
apply_sigmoid=False
232235
):
@@ -291,7 +294,7 @@ def __init__(self,
291294
self.register_buffer('op_threshs', op_threshs)
292295

293296
if self.weights != None:
294-
self.weights_filename_local = get_weights(weights)
297+
self.weights_filename_local = get_weights(weights, cache_dir)
295298

296299
try:
297300
savedmodel = torch.load(self.weights_filename_local, map_location='cpu')
@@ -355,6 +358,7 @@ class ResNet(nn.Module):
355358
model = xrv.models.ResNet(weights="resnet50-res512-all")
356359
357360
:param weights: Specify a weight name to load pre-trained weights
361+
:param cache_dir: Override where the weights will be stored (default is ~/.torchxrayvision/models_data/)
358362
:param op_threshs: Specify a weight name to load pre-trained weights
359363
:param apply_sigmoid: Apply a sigmoid
360364
@@ -382,7 +386,7 @@ class ResNet(nn.Module):
382386
]
383387
""""""
384388

385-
def __init__(self, weights: str = None, apply_sigmoid: bool = False):
389+
def __init__(self, weights: str = None, apply_sigmoid: bool = False, cache_dir: str = None):
386390
super(ResNet, self).__init__()
387391

388392
self.weights = weights
@@ -392,7 +396,7 @@ def __init__(self, weights: str = None, apply_sigmoid: bool = False):
392396
possible_weights = [k for k in model_urls.keys() if k.startswith("resnet")]
393397
raise Exception("Weights value must be in {}".format(possible_weights))
394398

395-
self.weights_filename_local = get_weights(weights)
399+
self.weights_filename_local = get_weights(weights, cache_dir=cache_dir)
396400
self.weights_dict = model_urls[weights]
397401
self.targets = model_urls[weights]["labels"]
398402
self.pathologies = self.targets # keep to be backward compatible
@@ -546,39 +550,22 @@ def get_model(weights: str, **kwargs):
546550
raise Exception("Unknown model")
547551

548552

549-
def get_weights(weights: str):
553+
def get_weights(weights: str, cache_dir:str = None):
    """Return the local path of the weights file for ``weights``, downloading it if missing.

    :param weights: Key of ``model_urls`` identifying the pre-trained weights.
    :param cache_dir: Folder in which the weights file is stored. A leading
        ``~`` is expanded. When ``None``, the folder returned by
        ``utils.get_cache_dir()`` is used.
    :return: Path of the weights file on the local filesystem.
    :raises Exception: If ``weights`` is not a key of ``model_urls``.
    """
    if weights not in model_urls:
        raise Exception("Weights not found. Valid options: {}".format(list(model_urls.keys())))

    url = model_urls[weights]["weights_url"]
    weights_filename = os.path.basename(url)

    storage_folder = utils.get_cache_dir() if cache_dir is None else cache_dir
    # Expand "~" once, on the folder itself, so that the directory created
    # below and the file path handed to the downloader point at the same
    # location. Expanding only the joined file path would make mkdir() create
    # a literal "~" directory for a cache_dir such as "~/weights".
    weights_storage_folder = os.path.expanduser(storage_folder)
    weights_filename_local = os.path.join(weights_storage_folder, weights_filename)

    if not os.path.isfile(weights_filename_local):
        print("Downloading weights...")
        print("If this fails you can run `wget {} -O {}`".format(url, weights_filename_local))
        pathlib.Path(weights_storage_folder).mkdir(parents=True, exist_ok=True)
        utils.download(url, weights_filename_local)

    return weights_filename_local
565-
566-
567-
# from here https://sumit-ghosh.com/articles/python-download-progress-bar/
568-
def download(url: str, filename: str):
569-
with open(filename, 'wb') as f:
570-
response = requests.get(url, stream=True)
571-
total = response.headers.get('content-length')
572-
573-
if total is None:
574-
f.write(response.content)
575-
else:
576-
downloaded = 0
577-
total = int(total)
578-
for data in response.iter_content(chunk_size=max(int(total / 1000), 1024 * 1024)):
579-
downloaded += len(data)
580-
f.write(data)
581-
done = int(50 * downloaded / total)
582-
sys.stdout.write('\r[{}{}]'.format('█' * done, '.' * (50 - done)))
583-
sys.stdout.flush()
584-
sys.stdout.write('\n')

torchxrayvision/utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,17 @@
33
import numpy as np
44
import skimage
55
import torch
6+
import os
67

78
from os import PathLike
89
from numpy import ndarray
910
import warnings
1011
from tqdm.autonotebook import tqdm
1112

1213

14+
def get_cache_dir():
    """Return the default folder used to cache downloaded model weights.

    The result is ``~/.torchxrayvision/models_data`` with ``~`` expanded and a
    trailing path separator.
    """
    # The trailing empty component makes os.path.join append the platform's
    # own separator instead of a hard-coded "/" inside a path component.
    return os.path.expanduser(os.path.join("~", ".torchxrayvision", "models_data", ""))
16+
1317
def in_notebook():
1418
try:
1519
from IPython import get_ipython

0 commit comments

Comments
 (0)