Skip to content

Commit 0bfe42f

Browse files
ieee8023Copilot
andauthored
Add Mira sex prediction model (#173)
* add mira sex model * cleanup * cleanup * cleanup * cleanup * cleanup * Update torchxrayvision/baseline_models/mira/__init__.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * tests * update tests * update tests * update tests --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 6fa28a2 commit 0bfe42f

File tree

7 files changed

+962
-8
lines changed

7 files changed

+962
-8
lines changed

.github/workflows/tests.yml

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,18 @@
33
name: XRV CI Tests
44

55
# Controls when the action will run. Triggers the workflow on push or pull request
6-
# events but only for the master branch
6+
# events but only for the main branch
77
on:
88
push:
9-
branches: [ master ]
9+
branches: [ main ]
1010
paths:
1111
- 'torchxrayvision/**'
1212
- 'tests/**'
1313
- 'setup.py'
1414
- 'requirements*.txt'
1515
- '.github/**'
1616
pull_request:
17-
branches: [ master ]
17+
branches: [ main ]
1818
paths:
1919
- 'torchxrayvision/**'
2020
- 'tests/**'
@@ -32,7 +32,7 @@ jobs:
3232
max-parallel: 2
3333
matrix:
3434
python-version: ['3.11']
35-
torch-version: [2.4.1]
35+
torch-version: [latest]
3636
os: [ubuntu-latest, macos-latest, windows-latest] # only run ubuntu for now because the other ones fail for no reason, macos-latest, windows-latest]
3737

3838
# Steps represent a sequence of tasks that will be executed as part of the job
@@ -55,9 +55,13 @@ jobs:
5555
pip install -e .
5656
5757
- name: Install torch version
58+
shell: bash
5859
run: |
59-
echo "Installing torch ${{ matrix.torch-version }}"
60-
python -m pip install torch==${{ matrix.torch-version }} torchvision
60+
if [ "${{ matrix.torch-version }}" = "latest" ]; then
61+
pip install --upgrade torch torchvision
62+
else
63+
pip install torch==${{ matrix.torch-version }} torchvision
64+
fi
6165
6266
- name: Run tests
6367
run: pytest

scripts/sex_prediction.ipynb

Lines changed: 728 additions & 0 deletions
Large diffs are not rendered by default.

tests/test_baseline_models.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
def test_baselinemodels_load():
99
model = xrv.baseline_models.jfhealthcare.DenseNet()
1010
model = xrv.baseline_models.emory_hiti.RaceModel()
11+
model = xrv.baseline_models.mira.SexModel()
1112

1213

1314
def test_baselinemodel_jfhealthcare_function():
@@ -69,3 +70,25 @@ def test_baselinemodel_xinario_function():
6970
assert dzdxp.shape == torch.Size([1, 1, 224, 224]), 'check grads are the correct size'
7071

7172
assert torch.isnan(dzdxp.flatten()).sum().cpu().numpy() == 0
73+
74+
75+
def test_baselinemodel_mira_sex_function():
76+
77+
model = xrv.baseline_models.mira.SexModel()
78+
79+
img = torch.ones(1, 1, 224, 224)
80+
img.requires_grad = True
81+
pred = model(img)[:,model.targets.index("Male")]
82+
assert pred.shape == torch.Size([1]), 'check output is correct shape'
83+
84+
dzdxp = torch.autograd.grad((pred), img)[0]
85+
assert dzdxp.shape == torch.Size([1, 1, 224, 224]), 'check grads are the correct size'
86+
87+
assert torch.isnan(dzdxp.flatten()).sum().cpu().numpy() == 0
88+
89+
# Test that targets are correct
90+
assert model.targets == ["Male", "Female"], 'check targets are correct'
91+
92+
# Test that output has correct number of classes
93+
pred_full = model(img)
94+
assert pred_full.shape == torch.Size([1, 2]), 'check full output is correct shape'

tests/test_covid_dataloader.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
sys.path.insert(0,"../torchxrayvision")
66

77

8+
@pytest.mark.skip
89
@pytest.fixture(scope="session", autouse=True)
910
def resource(request):
1011
print("setup")
@@ -15,15 +16,15 @@ def teardown():
1516
os.system("rm -rf /tmp/covid-chestxray-dataset")
1617
request.addfinalizer(teardown)
1718

18-
19+
@pytest.mark.skip
1920
def test_covid_dataloader_basic():
2021
d_covid19 = xrv.datasets.COVID19_Dataset(imgpath="/tmp/covid-chestxray-dataset/images/",
2122
csvpath="/tmp/covid-chestxray-dataset/metadata.csv",
2223
views=['PA', 'AP','AP Supine'])
2324

2425
print(d_covid19)
2526

26-
27+
@pytest.mark.skip
2728
def test_covid_dataloader_get():
2829

2930
d_covid19 = xrv.datasets.COVID19_Dataset(imgpath="/tmp/covid-chestxray-dataset/images/",
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
import sys, os
2+
import pytest
3+
import torch
4+
import numpy as np
5+
import torchxrayvision as xrv
6+
7+
def test_mira_sex_model_comprehensive():
8+
"""Comprehensive test for MIRA sex model including interface verification"""
9+
10+
# Test model loading without weights (for testing purposes)
11+
model = xrv.baseline_models.mira.SexModel(weights=False)
12+
13+
# Test targets
14+
assert hasattr(model, 'targets'), 'Model should have targets attribute'
15+
assert model.targets == ["Male", "Female"], 'Targets should be ["Male", "Female"]'
16+
assert len(model.targets) == 2, 'Should have exactly 2 targets'
17+
18+
# Test model architecture
19+
assert isinstance(model.model, torch.nn.Module), 'Model should contain a PyTorch module'
20+
21+
# Test forward pass with different input sizes
22+
test_sizes = [(1, 1, 224, 224), (2, 1, 320, 320), (1, 1, 512, 512)]
23+
24+
for batch_size, channels, height, width in test_sizes:
25+
img = torch.randn(batch_size, channels, height, width)
26+
img.requires_grad = True
27+
28+
# Forward pass
29+
with torch.no_grad():
30+
outputs = model(img)
31+
32+
# Check output shape
33+
assert outputs.shape == (batch_size, 2), f'Output shape should be ({batch_size}, 2) but got {outputs.shape}'
34+
35+
# Test softmax conversion
36+
with torch.no_grad():
37+
probs = torch.softmax(outputs, 1)
38+
39+
# Check probabilities sum to 1
40+
prob_sums = torch.sum(probs, dim=1)
41+
assert torch.allclose(prob_sums, torch.ones_like(prob_sums), atol=1e-6), 'Probabilities should sum to 1'
42+
43+
# Test gradient computation (need to compute outputs with grad enabled)
44+
outputs_with_grad = model(img)
45+
pred = outputs_with_grad[:, model.targets.index("Male")]
46+
grads = torch.autograd.grad(pred.sum(), img)[0]
47+
assert grads.shape == img.shape, 'Gradients should have same shape as input'
48+
assert not torch.isnan(grads).any(), 'Gradients should not contain NaN values'
49+
50+
# Test the expected interface
51+
img = torch.randn(1, 1, 224, 224)
52+
53+
# Test the exact interface specified in the requirements
54+
model = xrv.baseline_models.mira.SexModel(weights=False)
55+
assert model.targets == ["Male", "Female"], 'targets should return ["Male", "Female"]'
56+
57+
with torch.no_grad():
58+
outputs = torch.softmax(model(img), 1)
59+
60+
prediction_dict = dict(zip(model.targets, outputs.tolist()[0]))
61+
62+
# Verify prediction dict structure
63+
assert isinstance(prediction_dict, dict), 'Should return a dictionary'
64+
assert set(prediction_dict.keys()) == {"Female", "Male"}, 'Dictionary should have Female and Male keys'
65+
assert all(isinstance(v, float) for v in prediction_dict.values()), 'All values should be floats'
66+
assert all(0 <= v <= 1 for v in prediction_dict.values()), 'All probabilities should be between 0 and 1'
67+
assert abs(sum(prediction_dict.values()) - 1.0) < 1e-6, 'Probabilities should sum to 1'

torchxrayvision/baseline_models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,6 @@
22
from . import chexpert
33
from . import chestx_det
44
from . import emory_hiti
5+
from . import mira
56
from . import riken
67
from . import xinario
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
import sys, os
2+
from typing import List
3+
4+
import numpy as np
5+
import pathlib
6+
import torch
7+
import torch.nn as nn
8+
import torchvision
9+
import torchxrayvision as xrv
10+
from ... import utils
11+
12+
class SexModel(nn.Module):
13+
"""This model is from the MIRA (Medical Image Representation and Analysis)
14+
project and is trained to predict patient sex from a chest X-ray. The model
15+
uses a ResNet34 architecture and is trained on CheXpert dataset. The
16+
native resolution of the model is 224x224. Images are scaled automatically.
17+
18+
`Demo notebook <https://github.com/mlmed/torchxrayvision/blob/main/scripts/sex_prediction.ipynb>`__
19+
20+
Publication: `Algorithmic encoding of protected characteristics in chest X-ray disease detection models <https://www.thelancet.com/journals/ebiom/article/PIIS2352-3964(23)00032-4/fulltext>`__
21+
B. Glocker, C. Jones, M. Bernhardt, S. Winzeck
22+
eBioMedicine. Volume 89, 104467, 2023.
23+
24+
.. code-block:: python
25+
26+
model = xrv.baseline_models.mira.SexModel()
27+
28+
image = xrv.utils.load_image('00027426_000.png')
29+
image = torch.from_numpy(image)[None,...]
30+
31+
pred = model(image)
32+
33+
model.targets[torch.argmax(pred)]
34+
# 'Male' or 'Female'
35+
36+
.. code-block:: bibtex
37+
38+
@article{MIRA2023,
39+
title = {Chexploration: Medical Image Representation and Analysis},
40+
author = {MIRA Team},
41+
journal = {biomedia-mira/chexploration},
42+
url = {https://github.com/biomedia-mira/chexploration},
43+
year = {2023}
44+
}
45+
46+
"""
47+
48+
targets: List[str] = ["Male", "Female"]
49+
""""""
50+
51+
def __init__(self, weights=True):
52+
53+
super(SexModel, self).__init__()
54+
55+
# Use ResNet34 architecture as in the original MIRA implementation
56+
self.model = torchvision.models.resnet34(weights=None)
57+
n_classes = 2 # Male/Female
58+
59+
# Replace the final fully connected layer
60+
num_features = self.model.fc.in_features # 512 for ResNet34
61+
self.model.fc = nn.Linear(num_features, n_classes)
62+
63+
if weights:
64+
65+
url = 'https://github.com/mlmed/torchxrayvision/releases/download/v1/mira_sex_resnet-all_epoch_13-step_7125.ckpt'
66+
67+
weights_filename = "mira_sex_resnet-all_epoch_13-step_7125.ckpt"
68+
weights_storage_folder = os.path.expanduser(os.path.join("~", ".torchxrayvision", "models_data"))
69+
self.weights_filename_local = os.path.expanduser(os.path.join(weights_storage_folder, weights_filename))
70+
71+
if not os.path.isfile(self.weights_filename_local):
72+
print("Downloading weights...")
73+
print("If this fails you can run `wget {} -O {}`".format(url, self.weights_filename_local))
74+
pathlib.Path(weights_storage_folder).mkdir(parents=True, exist_ok=True)
75+
try:
76+
xrv.utils.download(url, self.weights_filename_local)
77+
except Exception as e:
78+
print(f"Failed to download weights from {url}")
79+
print(f"Please manually place the weights file '{weights_filename}' in {weights_storage_folder}")
80+
raise e
81+
82+
try:
83+
ckpt = torch.load(self.weights_filename_local, map_location="cpu")
84+
85+
# Extract state dict from PyTorch Lightning checkpoint
86+
if 'state_dict' in ckpt:
87+
state_dict = ckpt['state_dict']
88+
# Remove 'model.' prefix from keys if present (common in PyTorch Lightning)
89+
new_state_dict = {}
90+
for key, value in state_dict.items():
91+
if key.startswith('model.'):
92+
new_key = key[6:] # Remove 'model.' prefix
93+
new_state_dict[new_key] = value
94+
else:
95+
new_state_dict[key] = value
96+
self.model.load_state_dict(new_state_dict)
97+
else:
98+
# If it's a regular PyTorch checkpoint
99+
self.model.load_state_dict(ckpt)
100+
101+
except Exception as e:
102+
print("Loading failure. Check weights file:", self.weights_filename_local)
103+
print("Error:", str(e))
104+
raise e
105+
106+
self.model = self.model.eval() # Must be in eval mode to work correctly
107+
108+
# Define targets - order matters and should match training
109+
self.targets = ["Male" ,"Female"] # 0: Male, 1: Female
110+
111+
def forward(self, x):
112+
# Convert single channel to RGB (pseudo-RGB as in original implementation)
113+
x = x.repeat(1, 3, 1, 1)
114+
115+
# Resize to 224x224 as expected by ResNet
116+
x = utils.fix_resolution(x, 224, self)
117+
utils.warn_normalization(x)
118+
119+
# Convert from torchxrayvision range [-1024, 1024] to [0, 1]
120+
x = (x + 1024) / 2048
121+
122+
x = x*255 # Scale to [0, 255]
123+
124+
# Forward pass through ResNet
125+
y = self.model(x)
126+
127+
return y
128+
129+
def __repr__(self):
130+
return "MIRA-SexModel-resnet34"

0 commit comments

Comments
 (0)