
Commit ac7c7f8

Merge pull request #147 from fabianlim/master
Add Matching Explainer Algorithm
2 parents 0a6ab2c + ce28d05

21 files changed: +993 −6 lines

.github/workflows/Build.yml (+4)

@@ -60,3 +60,7 @@ jobs:
       - name: Step 12 - Test Logistic Rule Regression
         run: python ./tests/rbm/test_Logistic_Rule_Regression.py
+
+      - name: Step 13 - Test Matching Explainer
+        run: python ./tests/matching/test_order_constraints.py
+

CONTRIBUTING.md (+1)

@@ -42,3 +42,4 @@ via code, tests, or documentation:
 * Kush Varshney
 * Dennis Wei
 * Yunfeng Zhang
+* Fabian Lim

README.md (+2 −1)

@@ -30,7 +30,8 @@ We have developed the package with extensibility in mind. This library is still
 ### Local direct explanation

 - Teaching AI to Explain its Decisions ([Hind et al., 2019](https://doi.org/10.1145/3306618.3314273))
-
+- Order Constraints in Optimal Transport ([Lim et al., 2022](https://arxiv.org/abs/2110.07275), [GitHub](https://github.com/IBM/otoc))
+
 ### Global direct explanation

 - Boolean Decision Rules via Column Generation (Light Edition) ([Dash et al., 2018](https://papers.nips.cc/paper/7716-boolean-decision-rules-via-column-generation))

aix360/algorithms/matching/__init__.py (+1)

@@ -0,0 +1 @@
from .order_constraints import OTMatchingExplainer

aix360/algorithms/matching/order_constraints.py (+170)

@@ -0,0 +1,170 @@
import numpy as np

from aix360.algorithms.lwbe import LocalWBExplainer

from typing import Tuple, Optional, List
from typing import NamedTuple

Index = Tuple[int, int]


class AlternateMatching(NamedTuple):
    """
    OTMatchingExplainer returns an ordered list of objects,
    each representing an explanation.

    Attributes:
        matching (numpy 2d array): alternative matching
        salient (list of int tuples): salient matches (i,j) that contrast with the explained matching
    """
    matching: np.ndarray
    salient: List[Tuple]


class OTMatchingExplainer(LocalWBExplainer):
    """
    OTMatchingExplainer provides explanations for a matching
    that satisfies the transport polytope constraints.
    Given a matching, it produces a set of alternative matchings,
    where each alternative contrasts with the provided instance
    by a sparse set of salient matchings [#]_.

    This is akin to a search engine providing alternative suggestions
    relevant to a search string. OTMatchingExplainer aims to provide
    the same for matchings.

    References:
        .. [#] `Fabian Lim, Laura Wynter, Shiau Hong Lim,
           "Order Constraints in Optimal Transport",
           2022
           <https://arxiv.org/abs/2110.07275>`_
    """

    def __init__(
        self,
        deactivate_bounds: bool = False,
        error_limit: float = 1e-3,
    ):
        """
        Initialize the OTMatchingExplainer.
        """

        import sys
        if sys.version_info.major == 2:
            super(OTMatchingExplainer, self).__init__()
        else:
            super().__init__()

        self._deactivate_bounds = deactivate_bounds
        self._error_limit = error_limit

    def set_params(self, *args, **kwargs):
        """
        Set parameters for the explainer.
        """
        pass

    def explain_instance(
        self,
        matching: np.ndarray,
        costs: np.ndarray,
        constraints: Tuple[
            np.ndarray,
            np.ndarray,
        ],
        num_alternate_matchings: int = 1,
        search_thresholds: Tuple[float, float] = (.5, .5),
        search_node_limit: int = 20,
        search_depth_limit: int = 1,
        search_match_pos_filter: Optional[List[Index]] = None,
    ):
        """
        Explain the matching.

        Args:
            matching (numpy 2d array): the matching to be explained.
            costs (numpy 2d array): the (non-negative) matching costs used to obtain the above matching.
            constraints (numpy array, numpy array): the transport polytope row/column constraints.
            num_alternate_matchings (int): the number of alternate matchings to return.
            search_node_limit (int): stop the search when this many nodes have been encountered.
            search_depth_limit (int): do not progress beyond this tree level in the search.
            search_match_pos_filter ((int,int) array or None): if specified, a whitelist of positions (i,j) of candidate match positions.
            search_thresholds (float, float): thresholds used to pick the candidate match positions to search over.

        Returns:
            list of AlternateMatching explanations.
        """

        # the row and column constraints
        a, b = constraints

        # check the filter
        if search_match_pos_filter is not None:
            for x in search_match_pos_filter:
                if (
                    (len(x) != 2)
                    or
                    (type(x) != tuple)
                ):
                    raise ValueError("search_match_pos_filter must only contain 2-tuples")

        # TODO: remove the warnings here when the "import numpy.matlib"
        # issue has been resolved.
        import warnings
        with warnings.catch_warnings(record=True):
            from otoc import search_otoc_candidates2

        self._model = search_otoc_candidates2(
            a, b, costs,
            strategy=(
                'least-saturated-coef',
                {
                    'base_solution': matching,
                    'saturationThreshold': search_thresholds,
                    'a': a,
                    'b': b,
                    'index_filter': search_match_pos_filter,
                }
            ),
            numCandidates=num_alternate_matchings,
            limitCandidates=search_node_limit,
            limitCandatesMode='candidates-obtained',
            limitDepth=search_depth_limit,
            deactivate_bounds=self._deactivate_bounds,
            acceptableError=self._error_limit,
        )

        # perform the search to get various match candidates
        for algo in self._model:
            for _ in algo:
                pass

        # search history
        history = self._model._history

        # return the top candidate matches
        results = []
        for i in range(1, num_alternate_matchings + 1):
            x = self._model.best_solution(n=i)

            # if None is returned, the search has terminated early
            # and there is a deficit of candidates, so stop here
            if x is None:
                break

            results.append(
                AlternateMatching(
                    matching=x,
                    salient=history[
                        self._model.best_history_index(n=i)
                    ].index,  # type: ignore
                )
            )
        return results
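
For context when reviewing, here is a minimal usage sketch of the new explainer, not part of this commit. The toy costs, marginals and base matching below are synthetic and only illustrate the expected shapes; it assumes the `otoc` package is installed, and the alternatives returned depend on the otoc search.

# minimal usage sketch of OTMatchingExplainer (toy, synthetic inputs)
import numpy as np
from aix360.algorithms.matching import OTMatchingExplainer

rng = np.random.default_rng(0)
costs = rng.random((3, 3))        # non-negative matching costs
a = np.ones(3) / 3                # row constraints (marginals)
b = np.ones(3) / 3                # column constraints (marginals)
matching = np.eye(3) / 3          # a feasible matching: rows sum to a, columns sum to b

explainer = OTMatchingExplainer()
explanations = explainer.explain_instance(
    matching,
    costs,
    constraints=(a, b),
    num_alternate_matchings=2,
)
for e in explanations:
    # each alternate matching and the salient (i, j) matches that contrast with `matching`
    print(e.matching.shape, e.salient)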

aix360/data/README.md (+9)

@@ -51,5 +51,14 @@ The datasets supported by the AI Explainability 360 toolkit are listed below. Pl

    No download required, appropriate python code generates the data.

+7. e-SNLI dataset
+
+   Source: https://www.eraserbenchmark.com/zipped/esnli.tar.gz
+
+   Follow the download instructions in the matching examples [[nbviewer](https://nbviewer.jupyter.org/github/IBM/AIX360/tree/master/examples/matching/)]
+
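
One possible way to fetch and unpack the archive, sketched under the assumption that docs.jsonl should end up directly under aix360/data/esnli_data/; the matching notebooks linked above have the authoritative steps, and the layout inside the archive may differ.

# hedged download sketch for the e-SNLI archive referenced above
import tarfile
import urllib.request

url = "https://www.eraserbenchmark.com/zipped/esnli.tar.gz"
urllib.request.urlretrieve(url, "esnli.tar.gz")
with tarfile.open("esnli.tar.gz", "r:gz") as tar:
    # verify afterwards that docs.jsonl sits directly in aix360/data/esnli_data/
    tar.extractall("aix360/data/esnli_data/")
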

aix360/data/esnli_data/__init__.py

Whitespace-only changes.

aix360/datasets/__init__.py (+1)

@@ -6,3 +6,4 @@
 from .cifar_dataset import CIFARDataset
 from .ted_dataset import TEDDataset
 from .fashion_mnist_dataset import FMnistDataset
+from .esnli_dataset import eSNLIDataset

aix360/datasets/esnli_dataset.py (+71)

@@ -0,0 +1,71 @@
import os

from functools import lru_cache
from typing import Dict


# helper function that returns a specific
# sentence pair example from the e-SNLI dataset
@lru_cache(maxsize=120)
def _example(file: str, id: str) -> Dict:
    import json
    with open(file, 'r', encoding='utf-8') as f:
        while True:
            try:
                line = f.readline().strip()

                if line == '':
                    raise EOFError

                d = json.loads(line)
                if d['docid'] == id:
                    return d
            except EOFError:
                raise RuntimeError(f"example {id} not found")


class eSNLIDataset:
    """
    The e-SNLI dataset [#]_ contains pairs of sentences,
    each accompanied by human-rationale annotations
    as to which words in each pair are most
    important for matching.

    The sentence pairs are from the Stanford Natural
    Language Inference dataset with labels that indicate
    if the sentence pair is a logical entailment,
    contradiction or neutral.

    References:
        .. [#] `Camburu, Oana-Maria, Tim Rocktäschel, Thomas Lukasiewicz, and Phil Blunsom,
           “E-SNLI: Natural Language Inference with Natural Language Explanations.”,
           2018
           <https://arxiv.org/abs/1812.01193>`_
    """

    def __init__(self):
        self._dirpath = os.path.join(
            os.path.dirname(os.path.abspath(__file__)),
            '..', 'data', 'esnli_data'
        )

        self._cache_doc = {}

    def get_example(self, example_id: str) -> Dict:
        """
        Return an e-SNLI example.

        The example_id indexes the "docs.jsonl" file of the downloaded dataset.

        Args:
            example_id (str): the example index.

        Returns:
            e-SNLI example in dictionary form.
        """
        return _example(
            os.path.join(
                self._dirpath,
                'docs.jsonl'
            ),
            example_id,
        )
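
A short usage sketch for the dataset class above, not part of this diff. It assumes the e-SNLI archive has already been downloaded so that docs.jsonl sits in aix360/data/esnli_data/, and the example id shown is a hypothetical placeholder.

# fetch one sentence-pair example by its docid (the id below is hypothetical)
from aix360.datasets import eSNLIDataset

dataset = eSNLIDataset()
example = dataset.get_example("2677109430.jpg#1r1n")  # hypothetical e-SNLI docid
print(sorted(example.keys()))                          # the raw fields stored for this example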

docs/datasets.rst (+7)

@@ -42,3 +42,10 @@ TED Dataset

 .. autoclass:: aix360.datasets.TEDDataset
     :members:
+
+eSNLI Dataset
+-------------
+
+.. autoclass:: aix360.datasets.eSNLIDataset
+    :members:
+

docs/lwbe.rst (+7)

@@ -25,3 +25,10 @@ SHAP Explainers

 .. autoclass:: aix360.algorithms.shap.shap_wrapper.LinearExplainer
     :members:
+
+
+Matching Explainers
+---------------------------
+
+.. autoclass:: aix360.algorithms.matching.order_constraints.OTMatchingExplainer
+    :members:

examples/README.md (+1)

@@ -30,3 +30,4 @@ the user through the various steps of the notebook.

 - [GLRMExplainer and BRCGExplainer using Boston, Breast-cancer datasets](./rbm) [[on nbviewer](https://nbviewer.jupyter.org/github/IBM/AIX360/tree/master/examples/rbm/)]

+- [OTMatchingExplainer on Sentence Matching in NLP](./matching) [[on nbviewer](https://nbviewer.jupyter.org/github/IBM/AIX360/tree/master/examples/matching/matching-pairs-of-sentences.ipynb)]

examples/matching/data/__init__.py

Whitespace-only changes.
