Skip to content

Commit 806919d

Browse files
authored
Add View classifier (#143)
* initial commit of view classifier * initial commit of view classifier * add model to docs * add test
1 parent 42e8624 commit 806919d

File tree

6 files changed

+572
-1
lines changed

6 files changed

+572
-1
lines changed

docs/source/models.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,3 +55,9 @@ Riken Age Model
5555

5656
.. automodule:: torchxrayvision.baseline_models.riken
5757
:members:
58+
59+
Xinario View Model
60+
+++++++++++++++
61+
62+
.. automodule:: torchxrayvision.baseline_models.xinario
63+
:members:

scripts/view_classifier.ipynb

Lines changed: 258 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,258 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": 1,
6+
"id": "b28dd454-0a3a-4e0f-8e74-6c4ad712b783",
7+
"metadata": {
8+
"tags": []
9+
},
10+
"outputs": [],
11+
"source": [
12+
"%load_ext autoreload\n",
13+
"%autoreload 2"
14+
]
15+
},
16+
{
17+
"cell_type": "code",
18+
"execution_count": 2,
19+
"id": "94e0fa4d-efad-41e1-9314-09a2ef12a438",
20+
"metadata": {
21+
"tags": []
22+
},
23+
"outputs": [],
24+
"source": [
25+
"import torchxrayvision as xrv\n",
26+
"import sys\n",
27+
"import numpy as np\n",
28+
"import torch\n",
29+
"import torchvision\n",
30+
"import matplotlib.pyplot as plt\n",
31+
"import dataset_utils"
32+
]
33+
},
34+
{
35+
"cell_type": "code",
36+
"execution_count": null,
37+
"id": "98ed4012-c6bb-4988-802f-0c11f2cd0057",
38+
"metadata": {},
39+
"outputs": [],
40+
"source": []
41+
},
42+
{
43+
"cell_type": "code",
44+
"execution_count": 3,
45+
"id": "794438a4-856b-4eaa-b46e-c1c654a242e1",
46+
"metadata": {
47+
"tags": []
48+
},
49+
"outputs": [],
50+
"source": [
51+
"model = xrv.baseline_models.xinario.ViewModel()"
52+
]
53+
},
54+
{
55+
"cell_type": "code",
56+
"execution_count": null,
57+
"id": "4d608d78-ad75-466f-84f4-f9d5a4237e37",
58+
"metadata": {
59+
"tags": []
60+
},
61+
"outputs": [],
62+
"source": []
63+
},
64+
{
65+
"cell_type": "code",
66+
"execution_count": 4,
67+
"id": "7abcf913-ee6b-4d6e-b8fd-d4dcd9cd058f",
68+
"metadata": {
69+
"tags": []
70+
},
71+
"outputs": [
72+
{
73+
"name": "stdout",
74+
"output_type": "stream",
75+
"text": [
76+
"['Granuloma', 'Hemidiaphragm Elevation', 'Pleural_Thickening', 'Nodule', 'Mass', 'Cardiomegaly', 'Consolidation', 'Fibrosis', 'Scoliosis', 'Fracture', 'Atelectasis', 'Emphysema', 'Effusion', 'Air Trapping', 'Aortic Atheromatosis', 'Support Devices', 'Tuberculosis', 'Pneumothorax', 'Costophrenic Angle Blunting', 'Hilar Enlargement', 'Flattened Diaphragm', 'Edema', 'Bronchiectasis', 'Infiltration', 'Tube', 'Aortic Elongation', 'Pneumonia', 'Hernia']\n"
77+
]
78+
}
79+
],
80+
"source": [
81+
"transform = torchvision.transforms.Compose([\n",
82+
" xrv.datasets.XRayCenterCrop(),\n",
83+
" xrv.datasets.XRayResizer(224)\n",
84+
"])\n",
85+
"d = dataset_utils.get_data('pc', views='*', transform=transform)"
86+
]
87+
},
88+
{
89+
"cell_type": "code",
90+
"execution_count": 5,
91+
"id": "7164bd40-b0c5-4dd6-b21a-50939ee5236e",
92+
"metadata": {
93+
"tags": []
94+
},
95+
"outputs": [
96+
{
97+
"data": {
98+
"text/plain": [
99+
"array(['PA', 'L', 'AP', 'AP Supine', 'COSTAL', 'UNK', 'EXCLUDE'],\n",
100+
" dtype=object)"
101+
]
102+
},
103+
"execution_count": 5,
104+
"metadata": {},
105+
"output_type": "execute_result"
106+
}
107+
],
108+
"source": [
109+
"d.csv.view.unique()"
110+
]
111+
},
112+
{
113+
"cell_type": "code",
114+
"execution_count": 6,
115+
"id": "2984fdc2-3ab8-4edf-9cb2-7aca381ffe2f",
116+
"metadata": {
117+
"tags": []
118+
},
119+
"outputs": [],
120+
"source": [
121+
"frontal = np.where(d.csv.view == 'PA')[0]"
122+
]
123+
},
124+
{
125+
"cell_type": "code",
126+
"execution_count": 7,
127+
"id": "47717a84-2120-459e-ad08-df074bdadce3",
128+
"metadata": {
129+
"tags": []
130+
},
131+
"outputs": [],
132+
"source": [
133+
"lateral = np.where(d.csv.view == 'L')[0]"
134+
]
135+
},
136+
{
137+
"cell_type": "code",
138+
"execution_count": 8,
139+
"id": "342ba82d-5504-4e65-80cd-40e467bcf3a6",
140+
"metadata": {
141+
"tags": []
142+
},
143+
"outputs": [
144+
{
145+
"name": "stdout",
146+
"output_type": "stream",
147+
"text": [
148+
"tensor([[23.1546, 16.9751]]) Frontal\n",
149+
"tensor([[23.6190, 15.1804]]) Frontal\n",
150+
"tensor([[23.9368, 15.9114]]) Frontal\n",
151+
"tensor([[20.4266, 14.5170]]) Frontal\n",
152+
"tensor([[25.9273, 14.4245]]) Frontal\n",
153+
"tensor([[24.4080, 13.7654]]) Frontal\n",
154+
"tensor([[25.0222, 15.7349]]) Frontal\n",
155+
"tensor([[23.8637, 16.7607]]) Frontal\n",
156+
"tensor([[22.3303, 13.5714]]) Frontal\n",
157+
"tensor([[21.2553, 14.0465]]) Frontal\n"
158+
]
159+
}
160+
],
161+
"source": [
162+
"for i in frontal[:10]:\n",
163+
" img = d[i]['img']\n",
164+
" with torch.no_grad():\n",
165+
" output = model(torch.from_numpy(img))\n",
166+
" print(output, model.targets[output.argmax()])"
167+
]
168+
},
169+
{
170+
"cell_type": "code",
171+
"execution_count": 9,
172+
"id": "f9161e17-f505-4cf1-87ad-40a73e79ae21",
173+
"metadata": {
174+
"tags": []
175+
},
176+
"outputs": [
177+
{
178+
"name": "stdout",
179+
"output_type": "stream",
180+
"text": [
181+
"tensor([[17.3186, 26.7156]]) Lateral\n",
182+
"tensor([[15.9319, 24.5127]]) Lateral\n",
183+
"tensor([[20.1788, 34.1056]]) Lateral\n",
184+
"tensor([[20.5084, 35.7469]]) Lateral\n",
185+
"tensor([[20.0122, 36.1225]]) Lateral\n",
186+
"tensor([[20.1512, 29.6003]]) Lateral\n",
187+
"tensor([[21.8098, 32.7101]]) Lateral\n",
188+
"tensor([[18.7384, 35.3062]]) Lateral\n",
189+
"tensor([[19.8528, 28.8093]]) Lateral\n",
190+
"tensor([[20.8488, 33.3455]]) Lateral\n"
191+
]
192+
}
193+
],
194+
"source": [
195+
"for i in lateral[:10]:\n",
196+
" img = d[i]['img']\n",
197+
" with torch.no_grad():\n",
198+
" output = model(torch.from_numpy(img))\n",
199+
" print(output, model.targets[output.argmax()])"
200+
]
201+
},
202+
{
203+
"cell_type": "code",
204+
"execution_count": null,
205+
"id": "69b9abb6-bf8c-451b-8eea-237da57dc6dc",
206+
"metadata": {},
207+
"outputs": [],
208+
"source": []
209+
},
210+
{
211+
"cell_type": "code",
212+
"execution_count": null,
213+
"id": "fdfb6d6b-3093-468e-b32a-cc3bb857994f",
214+
"metadata": {},
215+
"outputs": [],
216+
"source": []
217+
},
218+
{
219+
"cell_type": "code",
220+
"execution_count": null,
221+
"id": "b35d5965-10d4-4d1e-bfca-4df7afbd5f7d",
222+
"metadata": {
223+
"tags": []
224+
},
225+
"outputs": [],
226+
"source": []
227+
},
228+
{
229+
"cell_type": "code",
230+
"execution_count": null,
231+
"id": "5b25a4c6-e41f-4f0b-8103-672902ab2d28",
232+
"metadata": {},
233+
"outputs": [],
234+
"source": []
235+
}
236+
],
237+
"metadata": {
238+
"kernelspec": {
239+
"display_name": "Python 3 (ipykernel)",
240+
"language": "python",
241+
"name": "python3"
242+
},
243+
"language_info": {
244+
"codemirror_mode": {
245+
"name": "ipython",
246+
"version": 3
247+
},
248+
"file_extension": ".py",
249+
"mimetype": "text/x-python",
250+
"name": "python",
251+
"nbconvert_exporter": "python",
252+
"pygments_lexer": "ipython3",
253+
"version": "3.9.0"
254+
}
255+
},
256+
"nbformat": 4,
257+
"nbformat_minor": 5
258+
}

tests/test_baseline_models.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,4 +53,19 @@ def test_baselinemodel_riken_age_function():
5353

5454
assert dzdxp.shape == torch.Size([2, 1, 224, 224]), 'check grads are the correct size'
5555

56-
assert torch.isnan(dzdxp.flatten()).sum().cpu().numpy() == 0 , 'check no grads are nans'
56+
assert torch.isnan(dzdxp.flatten()).sum().cpu().numpy() == 0 , 'check no grads are nans'
57+
58+
59+
def test_baselinemodel_xinario_function():
60+
61+
model = xrv.baseline_models.xinario.ViewModel()
62+
63+
img = torch.ones(1, 1, 224, 224)
64+
img.requires_grad = True
65+
pred = model(img)[:,model.targets.index("Lateral")]
66+
assert pred.shape == torch.Size([1]), 'check output is correct shape'
67+
68+
dzdxp = torch.autograd.grad((pred), img)[0]
69+
assert dzdxp.shape == torch.Size([1, 1, 224, 224]), 'check grads are the correct size'
70+
71+
assert torch.isnan(dzdxp.flatten()).sum().cpu().numpy() == 0

torchxrayvision/baseline_models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@
33
from . import chestx_det
44
from . import emory_hiti
55
from . import riken
6+
from . import xinario

0 commit comments

Comments
 (0)