Skip to content

Commit c7ae49f

Browse files
authored
Merge pull request #90 from SunAhong1993/syf0520
add interpret vis docs
2 parents 19e7e12 + 35fb9cb commit c7ae49f

File tree

6 files changed

+124
-41
lines changed

6 files changed

+124
-41
lines changed

docs/apis/visualize.md

+39-12
Original file line numberDiff line numberDiff line change
@@ -114,27 +114,54 @@ pdx.slim.visualize(model, 'mobilenetv2.sensitivities', save_dir='./')
114114
# 可视化结果保存在./sensitivities.png
115115
```
116116

117-
## 可解释性结果可视化
117+
## LIME可解释性结果可视化
118118
```
119-
paddlex.interpret.visualize(img_file,
120-
model,
121-
dataset=None,
122-
algo='lime',
123-
num_samples=3000,
124-
batch_size=50,
125-
save_dir='./')
119+
paddlex.interpret.lime(img_file,
120+
model,
121+
num_samples=3000,
122+
batch_size=50,
123+
save_dir='./')
126124
```
127-
将模型预测结果的可解释性可视化,目前只支持分类模型。
125+
使用LIME算法将模型预测结果的可解释性可视化。
126+
LIME表示与模型无关的局部可解释性,可以解释任何模型。LIME的思想是以输入样本为中心,在其附近的空间中进行随机采样,每个采样通过原模型得到新的输出,这样得到一系列的输入和对应的输出,LIME用一个简单的、可解释的模型(比如线性回归模型)来拟合这个映射关系,得到每个输入维度的权重,以此来解释模型。
127+
128+
**注意:** 可解释性结果可视化目前只支持分类模型。
128129

129130
### 参数
130131
>* **img_file** (str): 预测图像路径。
131132
>* **model** (paddlex.cv.models): paddlex中的模型。
132-
>* **dataset** (paddlex.datasets): 数据集读取器,默认为None。
133-
>* **algo** (str): 可解释性方式,当前可选'lime'和'normlime'。
134133
>* **num_samples** (int): LIME用于学习线性模型的采样数,默认为3000。
135134
>* **batch_size** (int): 预测数据batch大小,默认为50。
136135
>* **save_dir** (str): 可解释性可视化结果(保存为png格式文件)和中间文件存储路径。
137136
138137

139138
### 使用示例
140-
> 对预测可解释性结果可视化的过程可参见[代码](https://github.com/PaddlePaddle/PaddleX/blob/develop/tutorials/interpret/interpret.py)
139+
> 对预测可解释性结果可视化的过程可参见[代码](https://github.com/PaddlePaddle/PaddleX/blob/develop/tutorials/interpret/lime.py)
140+
141+
142+
## NormLIME可解释性结果可视化
143+
```
144+
paddlex.interpret.normlime(img_file,
145+
model,
146+
dataset=None,
147+
num_samples=3000,
148+
batch_size=50,
149+
save_dir='./')
150+
```
151+
使用NormLIME算法将模型预测结果的可解释性可视化。
152+
NormLIME是利用一定数量的样本来出一个全局的解释。NormLIME会提前计算一定数量的测试样本的LIME结果,然后对相同的特征进行权重的归一化,这样来得到一个全局的输入和输出的关系。
153+
154+
**注意:** 可解释性结果可视化目前只支持分类模型。
155+
156+
### 参数
157+
>* **img_file** (str): 预测图像路径。
158+
>* **model** (paddlex.cv.models): paddlex中的模型。
159+
>* **dataset** (paddlex.datasets): 数据集读取器,默认为None。
160+
>* **num_samples** (int): LIME用于学习线性模型的采样数,默认为3000。
161+
>* **batch_size** (int): 预测数据batch大小,默认为50。
162+
>* **save_dir** (str): 可解释性可视化结果(保存为png格式文件)和中间文件存储路径。
163+
164+
**注意:** dataset`读取的是一个数据集,该数据集不宜过大,否则计算时间会较长,但应包含所有类别的数据。
165+
### 使用示例
166+
> 对预测可解释性结果可视化的过程可参见[代码](https://github.com/PaddlePaddle/PaddleX/blob/develop/tutorials/interpret/normlime.py)
167+

paddlex/interpret/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,5 @@
1515
from __future__ import absolute_import
1616
from . import visualize
1717

18-
visualize = visualize.visualize
18+
lime = visualize.lime
19+
normlime = visualize.normlime

paddlex/interpret/core/normlime_base.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -116,9 +116,8 @@ def precompute_lime_weights(list_data_, predict_fn, num_samples, batch_size, sav
116116
if os.path.exists(save_path):
117117
logging.info(save_path + ' exists, not computing this one.', use_color=True)
118118
continue
119-
120-
logging.info('processing'+each_data_ if isinstance(each_data_, str) else data_index + \
121-
f'+{data_index}/{len(list_data_)}', use_color=True)
119+
img_file_name = each_data_ if isinstance(each_data_, str) else data_index
120+
logging.info('processing '+ img_file_name + ' [{}/{}]'.format(data_index, len(list_data_)), use_color=True)
122121

123122
image_show = read_image(each_data_)
124123
result = predict_fn(image_show)

paddlex/interpret/visualize.py

+57-17
Original file line numberDiff line numberDiff line change
@@ -22,20 +22,65 @@
2222
from .core.interpretation import Interpretation
2323
from .core.normlime_base import precompute_normlime_weights
2424
from .core._session_preparation import gen_user_home
25-
26-
def visualize(img_file,
25+
26+
def lime(img_file,
27+
model,
28+
num_samples=3000,
29+
batch_size=50,
30+
save_dir='./'):
31+
"""使用LIME算法将模型预测结果的可解释性可视化。
32+
33+
LIME表示与模型无关的局部可解释性,可以解释任何模型。LIME的思想是以输入样本为中心,
34+
在其附近的空间中进行随机采样,每个采样通过原模型得到新的输出,这样得到一系列的输入
35+
和对应的输出,LIME用一个简单的、可解释的模型(比如线性回归模型)来拟合这个映射关系,
36+
得到每个输入维度的权重,以此来解释模型。
37+
38+
注意:LIME可解释性结果可视化目前只支持分类模型。
39+
40+
Args:
41+
img_file (str): 预测图像路径。
42+
model (paddlex.cv.models): paddlex中的模型。
43+
num_samples (int): LIME用于学习线性模型的采样数,默认为3000。
44+
batch_size (int): 预测数据batch大小,默认为50。
45+
save_dir (str): 可解释性可视化结果(保存为png格式文件)和中间文件存储路径。
46+
"""
47+
assert model.model_type == 'classifier', \
48+
'Now the interpretation visualize only be supported in classifier!'
49+
if model.status != 'Normal':
50+
raise Exception('The interpretation only can deal with the Normal model')
51+
if not osp.exists(save_dir):
52+
os.makedirs(save_dir)
53+
model.arrange_transforms(
54+
transforms=model.test_transforms, mode='test')
55+
tmp_transforms = copy.deepcopy(model.test_transforms)
56+
tmp_transforms.transforms = tmp_transforms.transforms[:-2]
57+
img = tmp_transforms(img_file)[0]
58+
img = np.around(img).astype('uint8')
59+
img = np.expand_dims(img, axis=0)
60+
interpreter = None
61+
interpreter = get_lime_interpreter(img, model, num_samples=num_samples, batch_size=batch_size)
62+
img_name = osp.splitext(osp.split(img_file)[-1])[0]
63+
interpreter.interpret(img, save_dir=save_dir)
64+
65+
66+
def normlime(img_file,
2767
model,
2868
dataset=None,
29-
algo='lime',
3069
num_samples=3000,
3170
batch_size=50,
3271
save_dir='./'):
33-
"""可解释性可视化。
72+
"""使用NormLIME算法将模型预测结果的可解释性可视化。
73+
74+
NormLIME是利用一定数量的样本来出一个全局的解释。NormLIME会提前计算一定数量的测
75+
试样本的LIME结果,然后对相同的特征进行权重的归一化,这样来得到一个全局的输入和输出的关系。
76+
77+
注意1:dataset读取的是一个数据集,该数据集不宜过大,否则计算时间会较长,但应包含所有类别的数据。
78+
注意2:NormLIME可解释性结果可视化目前只支持分类模型。
79+
3480
Args:
3581
img_file (str): 预测图像路径。
3682
model (paddlex.cv.models): paddlex中的模型。
3783
dataset (paddlex.datasets): 数据集读取器,默认为None。
38-
algo (str): 可解释性方式,当前可选'lime'和'normlime'。
3984
num_samples (int): LIME用于学习线性模型的采样数,默认为3000。
4085
batch_size (int): 预测数据batch大小,默认为50。
4186
save_dir (str): 可解释性可视化结果(保存为png格式文件)和中间文件存储路径。
@@ -54,21 +99,16 @@ def visualize(img_file,
5499
img = np.around(img).astype('uint8')
55100
img = np.expand_dims(img, axis=0)
56101
interpreter = None
57-
if algo == 'lime':
58-
interpreter = get_lime_interpreter(img, model, dataset, num_samples=num_samples, batch_size=batch_size)
59-
elif algo == 'normlime':
60-
if dataset is None:
61-
raise Exception('The dataset is None. Cannot implement this kind of interpretation')
62-
interpreter = get_normlime_interpreter(img, model, dataset,
63-
num_samples=num_samples, batch_size=batch_size,
102+
if dataset is None:
103+
raise Exception('The dataset is None. Cannot implement this kind of interpretation')
104+
interpreter = get_normlime_interpreter(img, model, dataset,
105+
num_samples=num_samples, batch_size=batch_size,
64106
save_dir=save_dir)
65-
else:
66-
raise Exception('The {} interpretation method is not supported yet!'.format(algo))
67107
img_name = osp.splitext(osp.split(img_file)[-1])[0]
68108
interpreter.interpret(img, save_dir=save_dir)
69109

70110

71-
def get_lime_interpreter(img, model, dataset, num_samples=3000, batch_size=50):
111+
def get_lime_interpreter(img, model, num_samples=3000, batch_size=50):
72112
def predict_func(image):
73113
image = image.astype('float32')
74114
for i in range(image.shape[0]):
@@ -79,8 +119,8 @@ def predict_func(image):
79119
model.test_transforms.transforms = tmp_transforms
80120
return out[0]
81121
labels_name = None
82-
if dataset is not None:
83-
labels_name = dataset.labels
122+
if hasattr(model, 'labels'):
123+
labels_name = model.labels
84124
interpreter = Interpretation('lime',
85125
predict_func,
86126
labels_name,

tutorials/interpret/lime.py

+23
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import os
2+
# 选择使用0号卡
3+
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
4+
5+
import os.path as osp
6+
import paddlex as pdx
7+
8+
# 下载和解压Imagenet果蔬分类数据集
9+
veg_dataset = 'https://bj.bcebos.com/paddlex/interpret/mini_imagenet_veg.tar.gz'
10+
pdx.utils.download_and_decompress(veg_dataset, path='./')
11+
12+
# 下载和解压已训练好的MobileNetV2模型
13+
model_file = 'https://bj.bcebos.com/paddlex/interpret/mini_imagenet_veg_mobilenetv2.tar.gz'
14+
pdx.utils.download_and_decompress(model_file, path='./')
15+
16+
# 加载模型
17+
model = pdx.load_model('mini_imagenet_veg_mobilenetv2')
18+
19+
# 可解释性可视化
20+
pdx.interpret.lime(
21+
'mini_imagenet_veg/mushroom/n07734744_1106.JPEG',
22+
model,
23+
save_dir='./')

tutorials/interpret/interpret.py renamed to tutorials/interpret/normlime.py

+1-8
Original file line numberDiff line numberDiff line change
@@ -24,15 +24,8 @@
2424
transforms=model.test_transforms)
2525

2626
# 可解释性可视化
27-
pdx.interpret.visualize(
28-
'mini_imagenet_veg/mushroom/n07734744_1106.JPEG',
29-
model,
30-
test_dataset,
31-
algo='lime',
32-
save_dir='./')
33-
pdx.interpret.visualize(
27+
pdx.interpret.normlime(
3428
'mini_imagenet_veg/mushroom/n07734744_1106.JPEG',
3529
model,
3630
test_dataset,
37-
algo='normlime',
3831
save_dir='./')

0 commit comments

Comments
 (0)