22
22
from .core .interpretation import Interpretation
23
23
from .core .normlime_base import precompute_normlime_weights
24
24
from .core ._session_preparation import gen_user_home
25
-
26
- def visualize (img_file ,
25
+
26
+ def lime (img_file ,
27
+ model ,
28
+ num_samples = 3000 ,
29
+ batch_size = 50 ,
30
+ save_dir = './' ):
31
+ """使用LIME算法将模型预测结果的可解释性可视化。
32
+
33
+ LIME表示与模型无关的局部可解释性,可以解释任何模型。LIME的思想是以输入样本为中心,
34
+ 在其附近的空间中进行随机采样,每个采样通过原模型得到新的输出,这样得到一系列的输入
35
+ 和对应的输出,LIME用一个简单的、可解释的模型(比如线性回归模型)来拟合这个映射关系,
36
+ 得到每个输入维度的权重,以此来解释模型。
37
+
38
+ 注意:LIME可解释性结果可视化目前只支持分类模型。
39
+
40
+ Args:
41
+ img_file (str): 预测图像路径。
42
+ model (paddlex.cv.models): paddlex中的模型。
43
+ num_samples (int): LIME用于学习线性模型的采样数,默认为3000。
44
+ batch_size (int): 预测数据batch大小,默认为50。
45
+ save_dir (str): 可解释性可视化结果(保存为png格式文件)和中间文件存储路径。
46
+ """
47
+ assert model .model_type == 'classifier' , \
48
+ 'Now the interpretation visualize only be supported in classifier!'
49
+ if model .status != 'Normal' :
50
+ raise Exception ('The interpretation only can deal with the Normal model' )
51
+ if not osp .exists (save_dir ):
52
+ os .makedirs (save_dir )
53
+ model .arrange_transforms (
54
+ transforms = model .test_transforms , mode = 'test' )
55
+ tmp_transforms = copy .deepcopy (model .test_transforms )
56
+ tmp_transforms .transforms = tmp_transforms .transforms [:- 2 ]
57
+ img = tmp_transforms (img_file )[0 ]
58
+ img = np .around (img ).astype ('uint8' )
59
+ img = np .expand_dims (img , axis = 0 )
60
+ interpreter = None
61
+ interpreter = get_lime_interpreter (img , model , num_samples = num_samples , batch_size = batch_size )
62
+ img_name = osp .splitext (osp .split (img_file )[- 1 ])[0 ]
63
+ interpreter .interpret (img , save_dir = save_dir )
64
+
65
+
66
+ def normlime (img_file ,
27
67
model ,
28
68
dataset = None ,
29
- algo = 'lime' ,
30
69
num_samples = 3000 ,
31
70
batch_size = 50 ,
32
71
save_dir = './' ):
33
- """可解释性可视化。
72
+ """使用NormLIME算法将模型预测结果的可解释性可视化。
73
+
74
+ NormLIME是利用一定数量的样本来出一个全局的解释。NormLIME会提前计算一定数量的测
75
+ 试样本的LIME结果,然后对相同的特征进行权重的归一化,这样来得到一个全局的输入和输出的关系。
76
+
77
+ 注意1:dataset读取的是一个数据集,该数据集不宜过大,否则计算时间会较长,但应包含所有类别的数据。
78
+ 注意2:NormLIME可解释性结果可视化目前只支持分类模型。
79
+
34
80
Args:
35
81
img_file (str): 预测图像路径。
36
82
model (paddlex.cv.models): paddlex中的模型。
37
83
dataset (paddlex.datasets): 数据集读取器,默认为None。
38
- algo (str): 可解释性方式,当前可选'lime'和'normlime'。
39
84
num_samples (int): LIME用于学习线性模型的采样数,默认为3000。
40
85
batch_size (int): 预测数据batch大小,默认为50。
41
86
save_dir (str): 可解释性可视化结果(保存为png格式文件)和中间文件存储路径。
@@ -54,21 +99,16 @@ def visualize(img_file,
54
99
img = np .around (img ).astype ('uint8' )
55
100
img = np .expand_dims (img , axis = 0 )
56
101
interpreter = None
57
- if algo == 'lime' :
58
- interpreter = get_lime_interpreter (img , model , dataset , num_samples = num_samples , batch_size = batch_size )
59
- elif algo == 'normlime' :
60
- if dataset is None :
61
- raise Exception ('The dataset is None. Cannot implement this kind of interpretation' )
62
- interpreter = get_normlime_interpreter (img , model , dataset ,
63
- num_samples = num_samples , batch_size = batch_size ,
102
+ if dataset is None :
103
+ raise Exception ('The dataset is None. Cannot implement this kind of interpretation' )
104
+ interpreter = get_normlime_interpreter (img , model , dataset ,
105
+ num_samples = num_samples , batch_size = batch_size ,
64
106
save_dir = save_dir )
65
- else :
66
- raise Exception ('The {} interpretation method is not supported yet!' .format (algo ))
67
107
img_name = osp .splitext (osp .split (img_file )[- 1 ])[0 ]
68
108
interpreter .interpret (img , save_dir = save_dir )
69
109
70
110
71
- def get_lime_interpreter (img , model , dataset , num_samples = 3000 , batch_size = 50 ):
111
+ def get_lime_interpreter (img , model , num_samples = 3000 , batch_size = 50 ):
72
112
def predict_func (image ):
73
113
image = image .astype ('float32' )
74
114
for i in range (image .shape [0 ]):
@@ -79,8 +119,8 @@ def predict_func(image):
79
119
model .test_transforms .transforms = tmp_transforms
80
120
return out [0 ]
81
121
labels_name = None
82
- if dataset is not None :
83
- labels_name = dataset .labels
122
+ if hasattr ( model , 'labels' ) :
123
+ labels_name = model .labels
84
124
interpreter = Interpretation ('lime' ,
85
125
predict_func ,
86
126
labels_name ,
0 commit comments