|
| 1 | + |
| 2 | +import numpy as np |
| 3 | + |
| 4 | +from aix360.algorithms.lwbe import LocalWBExplainer |
| 5 | + |
| 6 | +from typing import Tuple, Optional, List |
| 7 | +from typing import NamedTuple |
| 8 | + |
# A match position (i, j): a pair of integer indices into the 2d matching.
Index = Tuple[int, int]
| 10 | + |
| 11 | + |
class AlternateMatching(NamedTuple):
    """
    One explanation produced by OTMatchingExplainer.

    OTMatchingExplainer returns an ordered list of these objects,
    each representing a single explanation.

    Attributes:

        matching (numpy 2d array): the alternative matching.
        salient (list of int tuples): the salient matches (i,j) that
            contrast with the explained matching.
    """

    # the alternative matching itself
    matching: np.ndarray
    # positions (i, j) where this alternate contrasts with the explained matching
    salient: List[Tuple]
| 26 | + |
class OTMatchingExplainer(LocalWBExplainer):
    """
    OTMatchingExplainer provides explanations for a matching
    that satisfies the transport polytope constraints.
    Given a matching, it produces a set of alternative matchings,
    where each alternate contrasts with the provided instance
    by a sparse set of salient matchings. [#]_.

    This is akin to a search engine providing alternative suggestions
    relevant to a search string. OTMatchingExplainer aims to provide
    the same for matchings.

    References:
        .. [#] `Fabian Lim, Laura Wynter, Shiau Hong Lim,
           "Order Constraints in Optimal Transport",
           2022
           <https://arxiv.org/abs/2110.07275>`_
    """

    def __init__(
        self,
        deactivate_bounds: bool = False,
        error_limit: float = 1e-3,
    ):
        """
        Initialize the OTMatchingExplainer.

        Args:
            deactivate_bounds (bool): forwarded unchanged as the
                ``deactivate_bounds`` argument of the otoc candidate search.
            error_limit (float): forwarded unchanged as the
                ``acceptableError`` argument of the otoc candidate search.
        """
        # This module relies on f-strings and annotations, so it can never be
        # imported under Python 2; the zero-argument super() is sufficient.
        super().__init__()

        self._deactivate_bounds = deactivate_bounds
        self._error_limit = error_limit

    def set_params(self, *args, **kwargs):
        """
        Set parameters for the explainer.

        Currently a no-op: all arguments are accepted and ignored.
        """
        pass

    @staticmethod
    def _check_match_pos_filter(search_match_pos_filter) -> None:
        """Raise ValueError unless the filter is None or holds only 2-tuples."""
        if search_match_pos_filter is None:
            return

        for pos in search_match_pos_filter:
            # Test the type before the length so that entries without a
            # length (e.g. a bare int) yield ValueError, not TypeError.
            if not isinstance(pos, tuple) or len(pos) != 2:
                raise ValueError("search_match_pos_filter must only contain 2-tuples")

    def explain_instance(
        self,
        matching: np.ndarray,
        costs: np.ndarray,
        constraints: Tuple[
            np.ndarray,
            np.ndarray,
        ],
        num_alternate_matchings: int = 1,
        search_thresholds: Tuple[float, float] = (.5, .5),
        search_node_limit: int = 20,
        search_depth_limit: int = 1,
        search_match_pos_filter: Optional[List[Index]] = None,
    ) -> List[AlternateMatching]:
        """
        Explain the matching

        Args:
            matching (numpy 2d array): the matching to be explained.
            costs: (numpy 2d array): the (non-negative) matching costs used to obtain above matching.
            constraints (numpy array, numpy array): the transport polytope row/column constraints.
            num_alternate_matchings (int): the number of alternate matchings to return back.
            search_thresholds (float, float): thresholds used to pick the candidate match positions to search over.
            search_node_limit (int): stop the search when this many nodes have been encountered.
            search_depth_limit (int): do not progress beyond this tree level in the search
            search_match_pos_filter ((int,int) array or None): if specified, this is a whitelist of positions (i,j) of candidate match positions

        Returns:
            list of AlternateMatching explanations. The list holds fewer than
            ``num_alternate_matchings`` entries when the search runs out of
            candidates.

        Raises:
            ValueError: if ``search_match_pos_filter`` contains an entry
                that is not a 2-tuple.
        """

        # the row and column constraints
        a, b = constraints

        # check the filter before starting any search work
        self._check_match_pos_filter(search_match_pos_filter)

        # TODO: remove the warnings here when the "import numpy.matlib"
        # issue has been resolved.
        import warnings
        with warnings.catch_warnings(record=True):
            from otoc import search_otoc_candidates2

            self._model = search_otoc_candidates2(
                a, b, costs,
                strategy=(
                    'least-saturated-coef',
                    {
                        'base_solution': matching,
                        'saturationThreshold': search_thresholds,
                        'a': a,
                        'b': b,
                        'index_filter': search_match_pos_filter,
                    }
                ),
                numCandidates=num_alternate_matchings,
                limitCandidates=search_node_limit,
                # NOTE(review): keyword spelling kept exactly as in the
                # original call; confirm against the otoc API before renaming.
                limitCandatesMode='candidates-obtained',
                limitDepth=search_depth_limit,
                deactivate_bounds=self._deactivate_bounds,
                acceptableError=self._error_limit,
            )

            # perform the search to get various match candidates;
            # the iterators are exhausted purely to drive the search
            for algo in self._model:
                for _ in algo:
                    pass

        # search history
        history = self._model._history

        # return the top candidate matches
        results = []
        for rank in range(1, num_alternate_matchings + 1):
            candidate = self._model.best_solution(n=rank)

            # if None is returned, then the search has
            # terminated early and there is a deficit of
            # candidates.
            # So then just terminate here
            if candidate is None:
                break

            results.append(
                AlternateMatching(
                    matching=candidate,
                    salient=history[
                        self._model.best_history_index(n=rank)
                    ].index,  # type: ignore
                )
            )
        return results
0 commit comments