diff --git a/.github/workflows/build-and-deploy.yml b/.github/workflows/build-and-deploy.yml index 67de9fc66..60a27741c 100644 --- a/.github/workflows/build-and-deploy.yml +++ b/.github/workflows/build-and-deploy.yml @@ -12,13 +12,13 @@ jobs: - uses: actions/checkout@v2 - uses: actions/setup-python@v2 with: - python-version: "3.10" + python-version: "3.11" - name: install dependencies run: | python -m pip install --upgrade pip - pip install setuptools wheel twine + pip install --upgrade build setuptools wheel twine - name: build - run: python setup.py sdist + run: python -m build - name: deploy env: TWINE_USERNAME: __token__ diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 8d7eae1db..cb4d1a453 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -64,4 +64,4 @@ jobs: TOXENV: ${{ matrix.toxenv }} run: tox - name: setup - run: python setup.py install + run: pip install . diff --git a/Cargo.lock b/Cargo.lock new file mode 100644 index 000000000..fa031a0b1 --- /dev/null +++ b/Cargo.lock @@ -0,0 +1,282 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. 
+version = 3 + +[[package]] +name = "autocfg" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" + +[[package]] +name = "bitflags" +version = "1.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" + +[[package]] +name = "cfg-if" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" + +[[package]] +name = "comb_spec_searcher_rs" +version = "0.1.0" +dependencies = [ + "pyo3", +] + +[[package]] +name = "indoc" +version = "1.0.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa799dd5ed20a7e349f3b4639aa80d74549c81716d9ec4f994c9b5815598306" + +[[package]] +name = "libc" +version = "0.2.142" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a987beff54b60ffa6d51982e1aa1146bc42f19bd26be28b0586f252fccf5317" + +[[package]] +name = "lock_api" +version = "0.4.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "435011366fe56583b16cf956f9df0095b405b82d76425bc8981c0e22e60ec4df" +dependencies = [ + "autocfg", + "scopeguard", +] + +[[package]] +name = "memoffset" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d61c719bcfbcf5d62b3a09efa6088de8c54bc0bfcd3ea7ae39fcc186108b8de1" +dependencies = [ + "autocfg", +] + +[[package]] +name = "once_cell" +version = "1.17.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7e5500299e16ebb147ae15a00a942af264cf3688f47923b8fc2cd5858f23ad3" + +[[package]] +name = "parking_lot" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3742b2c103b9f06bc9fff0a37ff4912935851bee6d36f3c02bcc755bcfec228f" +dependencies = [ + 
"lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.9.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9069cbb9f99e3a5083476ccb29ceb1de18b9118cafa53e90c9551235de2b9521" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall", + "smallvec", + "windows-sys", +] + +[[package]] +name = "proc-macro2" +version = "1.0.56" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b63bdb0cd06f1f4dedf69b254734f9b45af66e4a031e42a7480257d9898b435" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "pyo3" +version = "0.18.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3b1ac5b3731ba34fdaa9785f8d74d17448cd18f30cf19e0c7e7b1fdb5272109" +dependencies = [ + "cfg-if", + "indoc", + "libc", + "memoffset", + "parking_lot", + "pyo3-build-config", + "pyo3-ffi", + "pyo3-macros", + "unindent", +] + +[[package]] +name = "pyo3-build-config" +version = "0.18.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9cb946f5ac61bb61a5014924910d936ebd2b23b705f7a4a3c40b05c720b079a3" +dependencies = [ + "once_cell", + "target-lexicon", +] + +[[package]] +name = "pyo3-ffi" +version = "0.18.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fd4d7c5337821916ea2a1d21d1092e8443cf34879e53a0ac653fbb98f44ff65c" +dependencies = [ + "libc", + "pyo3-build-config", +] + +[[package]] +name = "pyo3-macros" +version = "0.18.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9d39c55dab3fc5a4b25bbd1ac10a2da452c4aca13bb450f22818a002e29648d" +dependencies = [ + "proc-macro2", + "pyo3-macros-backend", + "quote", + "syn", +] + +[[package]] +name = "pyo3-macros-backend" +version = "0.18.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97daff08a4c48320587b5224cc98d609e3c27b6d437315bd40b605c98eeb5918" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + 
+[[package]] +name = "quote" +version = "1.0.26" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4424af4bf778aae2051a77b60283332f386554255d722233d09fbfc7e30da2fc" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "redox_syscall" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fb5a58c1855b4b6819d59012155603f0b22ad30cad752600aadfcb695265519a" +dependencies = [ + "bitflags", +] + +[[package]] +name = "scopeguard" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" + +[[package]] +name = "smallvec" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a507befe795404456341dfab10cef66ead4c041f62b8b11bbb92bffe5d0953e0" + +[[package]] +name = "syn" +version = "1.0.109" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "target-lexicon" +version = "0.12.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fd1ba337640d60c3e96bc6f0638a939b9c9a7f2c316a1598c279828b3d1dc8c5" + +[[package]] +name = "unicode-ident" +version = "1.0.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5464a87b239f13a63a501f2701565754bae92d243d4bb7eb12f6d57d2269bf4" + +[[package]] +name = "unindent" +version = "0.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e1766d682d402817b5ac4490b3c3002d91dfa0d22812f341609f97b08757359c" + +[[package]] +name = "windows-sys" +version = "0.45.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75283be5efb2831d37ea142365f009c02ec203cd29a3ebecbc093d52315b66d0" +dependencies = [ + "windows-targets", +] + +[[package]] 
+name = "windows-targets" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e5180c00cd44c9b1c88adb3693291f1cd93605ded80c250a75d472756b4d071" +dependencies = [ + "windows_aarch64_gnullvm", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "597a5118570b68bc08d8d59125332c54f1ba9d9adeedeef5b99b02ba2b0698f8" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e08e8864a60f06ef0d0ff4ba04124db8b0fb3be5776a5cd47641e942e58c4d43" + +[[package]] +name = "windows_i686_gnu" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c61d927d8da41da96a81f029489353e68739737d3beca43145c8afec9a31a84f" + +[[package]] +name = "windows_i686_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44d840b6ec649f480a41c8d80f9c65108b92d89345dd94027bfe06ac444d1060" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8de912b8b8feb55c064867cf047dda097f92d51efad5b491dfb98f6bbb70cb36" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26d41b46a36d453748aedef1486d5c7a85db22e56aff34643984ea85514e94a3" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9aec5da331524158c6d1a4ac0ab1541149c0b9505fde06423b02f5ef0106b9f0" diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 000000000..2dee29b60 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,16 @@ 
+[package] +name = "comb_spec_searcher_rs" +version = "0.1.0" +edition = "2021" + +[lib] +# The name of the native library. This is the name which will be used in Python to import the +# library (i.e. `import comb_spec_searcher_rs`). If you change this, you must also change the name of the +# `#[pymodule]` in `src/lib.rs`. +name = "comb_spec_searcher_rs" + +# "cdylib" is necessary to produce a shared library for Python to import from. +crate-type = ["cdylib"] + +[dependencies] +pyo3 = { version = "0.18.3", features = ["extension-module"] } diff --git a/README.rst b/README.rst index e9e07ebee..034598027 100644 --- a/README.rst +++ b/README.rst @@ -44,7 +44,7 @@ cloning the repository: .. code:: bash - ./setup.py develop + pip install --editable . Combinatorial exploration ------------------------- diff --git a/comb_spec_searcher/rule_db/forest.py b/comb_spec_searcher/rule_db/forest.py index b4f8c19eb..ea4fb0b4f 100644 --- a/comb_spec_searcher/rule_db/forest.py +++ b/comb_spec_searcher/rule_db/forest.py @@ -1,24 +1,9 @@ -import gc import itertools -import platform import time from datetime import timedelta -from typing import ( - Callable, - Deque, - Dict, - Generic, - Iterable, - Iterator, - List, - Optional, - Set, - Tuple, - TypeVar, - Union, -) +from typing import Dict, Iterable, Iterator, List, Set, Tuple, Union -from logzero import logger +from comb_spec_searcher_rs import ForestRuleKey, RuleBucket, TableMethod from comb_spec_searcher.class_db import ClassDB from comb_spec_searcher.exception import StrategyDoesNotApply @@ -35,385 +20,12 @@ StrategyFactory, ) from comb_spec_searcher.strategies.strategy_pack import StrategyPack -from comb_spec_searcher.typing import CSSstrategy, ForestRuleKey, RuleBucket, RuleKey +from comb_spec_searcher.typing import CSSstrategy -T = TypeVar("T") -RuleWithShifts = Tuple[RuleKey, Tuple[int, ...]] SortedRWS = Dict[RuleBucket, List[ForestRuleKey]] empty_strategy: EmptyStrategy = EmptyStrategy() -class DefaultList(Generic[T]): - """ - A list 
data structure get automatically gets longer if an index not existing is - requested. - - When getting longer the list is filled by calling the provied `default_factory`. - This is similar to the `collections.defaultdict` data structure. - """ - - def __init__(self, default_factory: Callable[[], T]): - self._default_factory = default_factory - self._list: List[T] = [] - - def _increase_list_len(self, key: int) -> None: - """ - Increase the length of the list so that the given is valid. - """ - num_new_entry = key - len(self._list) + 1 - self._list.extend((self._default_factory() for _ in range(num_new_entry))) - - def __getitem__(self, key: int) -> T: - try: - return self._list[key] - except IndexError: - self._increase_list_len(key) - return self._list[key] - - def __setitem__(self, key: int, value: T) -> None: - self._list[key] = value - - def __iter__(self) -> Iterator[T]: - return iter(self._list) - - def __str__(self) -> str: - return str(self._list) - - -class Function: - """ - A python representation of a function. - - The function maps natural number to a natural number or infinity (represented - by the use of None) - - The default value of the function is 0. - """ - - def __init__(self) -> None: - self._value: List[Optional[int]] = [] - self._preimage_count: DefaultList[int] = DefaultList(int) - self._infinity_count: int = 0 - - @property - def preimage_count(self) -> List[int]: - """ - Return the number of classes for each value of the function. - """ - count = list(self._preimage_count) - while count and count[-1] == 0: - count.pop() - return count - - @property - def infinity_count(self) -> int: - return self._infinity_count - - def __getitem__(self, key: int) -> Optional[int]: - """ - Return the value of the function for the given key. 
- """ - try: - return self._value[key] - except IndexError: - self._increase_list_len(key) - return 0 - - def _increase_list_len(self, key: int) -> None: - num_new_entry = key - len(self._value) + 1 - self._value.extend((0 for _ in range(num_new_entry))) - self._preimage_count[0] += num_new_entry - - def increase_value(self, key: int) -> None: - """ - Increase by one the value of the function for the given key. - """ - try: - old_value = self._value[key] - if old_value is None: - raise ValueError(f"The function is already infinity for {key}") - except IndexError: - self._increase_list_len(key) - old_value = 0 - self._value[key] = old_value + 1 - self._preimage_count[old_value] -= 1 - self._preimage_count[old_value + 1] += 1 - - def set_infinite(self, key: int) -> None: - """ - Set the value of function for the given key to infinity. - """ - try: - old_value = self._value[key] - if old_value is None: - raise ValueError(f"The function is already infinity for {key}") - except IndexError: - self._increase_list_len(key) - old_value = 0 - self._value[key] = None - self._preimage_count[old_value] -= 1 - self._infinity_count += 1 - - def preimage_gap(self, length: int) -> int: - """ - Return the smallest k such that the preimage of the interval - [k, k+length-1] is empty. - """ - if length <= 0: - raise ValueError("length argument must be positive") - last_non_zero = -1 - for i, v in enumerate(self._preimage_count): - if v != 0: - last_non_zero = i - elif i - last_non_zero >= length: - return last_non_zero + 1 - return last_non_zero + 1 - - def preimage(self, value: Optional[int]) -> Iterator[int]: - """ - Return the preimage of the function for the given value. - """ - if value == 0: - raise ValueError("The preimage of 0 is infinite.") - return (k for k, v in enumerate(self._value) if v == value) - - def to_dict(self) -> Dict[int, Optional[int]]: - """ - Return a dictionary view of the function with only the non-zero value. 
- """ - return {i: v for i, v in enumerate(self._value) if v != 0} - - def __str__(self) -> str: - parts = ( - f"{i} -> {v if v is not None else '∞'}" for i, v in enumerate(self._value) - ) - return "\n".join(parts) - - -class TableMethod: - def __init__(self) -> None: - self._rules: List[ForestRuleKey] = [] - self._shifts: List[List[Optional[int]]] = [] - self._function: Function = Function() - self._gap_size: int = 1 - self._rules_using_class: DefaultList[List[Tuple[int, int]]] = DefaultList(list) - self._rules_pumping_class: DefaultList[List[int]] = DefaultList(list) - self._processing_queue: Deque[int] = Deque() - self._current_gap: Tuple[int, int] = (1, 1) - self._rule_holding_extra_terms: Set[int] = set() - - @property - def function(self) -> Dict[int, Optional[int]]: - """ - Return a dict representing the number of term that are computable for each - class. - Only the class where it can get at least a term are included. If a class is map - to None, then all the terms of the enumeration are computable. - """ - return self._function.to_dict() - - def add_rule_key( - self, - rule_key: ForestRuleKey, - ): - """ - Add the rule to the database. - - INPUTS: - - `rule_key` - - `shifts_for_zero`: The values of the shifts if no information was known - about any of the classes. 
- - `rule_bucket` the type of rule - """ - self._rules.append(rule_key) - self._shifts.append(self._compute_shift(rule_key.key, rule_key.shifts)) - max_gap = max((abs(s) for s in rule_key.shifts), default=0) - if max_gap > self._gap_size: - self._gap_size = max_gap - self._correct_gap() - if self._function[rule_key.parent] is not None: - rule_idx = len(self._rules) - 1 - self._rules_pumping_class[rule_key.parent].append(rule_idx) - for child_idx, child in enumerate(rule_key.children): - if self._function[child] is not None: - self._rules_using_class[child].append((rule_idx, child_idx)) - self._processing_queue.append(rule_idx) - self._process_queue() - - def is_pumping(self, label: int) -> bool: - """ - Determine if the comb_class is pumping in the current universe. - """ - return self._function[label] is None - - def status(self) -> str: - s = f"\tSize of the gap: {self._gap_size}\n" - s += f"\tSize of the stable subset: {self._function.infinity_count}\n" - s += f"\tSizes of the pre-images: {self._function.preimage_count}\n" - return s - - def stable_subset(self) -> Iterator[int]: - return self._function.preimage(None) - - def pumping_subuniverse( - self, - ) -> Iterator[ForestRuleKey]: - """ - Iterator over all the forest rule keys that contain only pumping - combinatorial classes. - """ - stable_subset = set(self.stable_subset()) - for forest_key in self._rules: - if forest_key.parent in stable_subset and stable_subset.issuperset( - forest_key.children - ): - yield forest_key - - def _compute_shift( - self, - rule_key: RuleKey, - shifts_for_zero: Tuple[int, ...], - ) -> List[Optional[int]]: - """ - Compute the initial value for the shifts a rule based on the current state of - the function. 
- """ - parent_current_value = self._function[rule_key[0]] - if parent_current_value is None: - return [None for _ in shifts_for_zero] - chidlren_function_value = map(self._function.__getitem__, rule_key[1]) - return [ - fvalue + sfz - parent_current_value if fvalue is not None else None - for fvalue, sfz in zip(chidlren_function_value, shifts_for_zero) - ] - - def _correct_gap(self) -> None: - """ - Correct the gap and if needed queue rules for the classes that were previously - on the right hand side of the gap. - - This should be toggled every time the gap changes whether the size changes or - the some value changes of the function caused the gap to change. - """ - k = self._function.preimage_gap(self._gap_size) - new_gap = (k, k + self._gap_size - 1) - if new_gap[1] > self._current_gap[1]: - self._processing_queue.extend(self._rule_holding_extra_terms) - self._rule_holding_extra_terms.clear() - self._current_gap = new_gap - - def _process_queue(self) -> None: - """ - Try to make improvement with all the class in the processing queue. - """ - while self._processing_queue or self._rule_holding_extra_terms: - while self._processing_queue: - rule_idx = self._processing_queue.popleft() - shifts = self._shifts[rule_idx] - if self._can_give_terms(shifts): - parent = self._rules[rule_idx].parent - self._increase_value(parent, rule_idx) - if self._rule_holding_extra_terms: - rule_idx = self._rule_holding_extra_terms.pop() - parent = self._rules[rule_idx].parent - self._set_infinite(parent) - - @staticmethod - def _can_give_terms(shifts: List[Optional[int]]) -> bool: - """ - Return True if the shifts indicate that a new terms can be computed. - """ - return all(s is None or s > 0 for s in shifts) - - def _increase_value(self, comb_class: int, rule_idx: int) -> None: - """ - Increase the value of the comb_class and put on the processing stack any rule - that can now give a new term. - - The rule_idx must indicate the rule used to justify the increase. 
- """ - current_value = self._function[comb_class] - if current_value is None: - return - if current_value > self._current_gap[1]: - self._rule_holding_extra_terms.add(rule_idx) - return - self._function.increase_value(comb_class) - # Correction of the gap - gap_start = self._function.preimage_gap(self._gap_size) - if self._current_gap[0] != gap_start: - self._correct_gap() - # Correction of the shifts for rule pumping comb_class - for r_idx in self._rules_pumping_class[comb_class]: - shifts = self._shifts[r_idx] - for i, v in enumerate(shifts): - shifts[i] = v - 1 if v is not None else None - if self._can_give_terms(shifts): - self._processing_queue.append(r_idx) - # Correction of the shifts for rules using comb_class to pump - for r_idx, class_idx in self._rules_using_class[comb_class]: - shifts = self._shifts[r_idx] - current_shift = shifts[class_idx] - assert current_shift is not None - shifts[class_idx] = current_shift + 1 - if self._can_give_terms(shifts): - self._processing_queue.append(r_idx) - - def _set_infinite(self, comb_class: int) -> None: - """ - Set the value if the class to infinity. - - This should happen when we know that we cannot pump anything on the left side - of the gap. 
- """ - current_value = self._function[comb_class] - if current_value is None: - return - assert current_value > self._current_gap[1] - assert not self._processing_queue - self._function.set_infinite(comb_class) - # This class will never be increased again so we remove any occurrence - # of the rule of any rule for that class from _rules_using_class and - # _rules_pumping_class - for rule_idx in self._rules_pumping_class[comb_class]: - for child in self._rules[rule_idx].children: - self._rules_using_class[child] = [ - (ri, ci) - for ri, ci in self._rules_using_class[child] - if ri != rule_idx - ] - self._rules_pumping_class[comb_class].clear() - # Correction of the shifts for rules using comb_class to pump - for rule_idx, class_idx in self._rules_using_class[comb_class]: - shifts = self._shifts[rule_idx] - shifts[class_idx] = None - if self._can_give_terms(shifts): - self._processing_queue.append(rule_idx) - self._rules_using_class[comb_class].clear() - - def rule_info(self, rule_idx: int) -> str: - """ - Return a string with information about a particular rule. - Mostly intended for debugging. 
- """ - - def v_to_str(v: Optional[int]) -> str: - """Return a string for the integer and infinity if None""" - if v is None: - return "∞" - return str(v) - - rule_key = self._rules[rule_idx] - current_value = f"{v_to_str(self._function[rule_key.parent])} -> " + ", ".join( - map(v_to_str, (self._function[c] for c in rule_key.children)) - ) - shifts = map(v_to_str, self._shifts[rule_idx]) - child_with_shift = ", ".join( - f"({c}, {s})" for c, s in zip(rule_key.children, shifts) - ) - return f"{rule_key.parent} -> {child_with_shift} || {current_value}" - - class ForestRuleExtractor: MINIMIZE_ORDER = ( RuleBucket.REVERSE, @@ -432,10 +44,7 @@ def __init__( self.pack = pack self.classdb = classdb self.root_label = root_label - self.rule_by_bucket = self._sorted_stable_rules(ruledb.table_method) - assert set(ForestRuleExtractor.MINIMIZE_ORDER) == set(self.rule_by_bucket) - self.needed_rules: List[ForestRuleKey] = [] - self._minimize() + self.needed_rules = ruledb.table_method.extract_specification(root_label) def check(self) -> None: """ @@ -475,64 +84,6 @@ def rules(self, cache: Tuple[AbstractRule, ...]) -> Iterator[AbstractRule]: else: yield rule - def _minimize(self): - """ - Perform the complete minimization of the forest. - """ - for key in ForestRuleExtractor.MINIMIZE_ORDER: - self._minimize_key(key) - - def _minimize_key(self, key: RuleBucket) -> None: - """ - Minimize the number of rules used for the type of rule given by key. - - The list of rule in `self.rule_by_bucket[key]` is cleared and a - minimal set from theses is added to `self.needed_rules`. 
- """ - logger.info("Minimizing %s", key.name) - maybe_useful: List[ForestRuleKey] = [] - not_minimizing: List[List[ForestRuleKey]] = [ - self.needed_rules, - maybe_useful, - ] - not_minimizing.extend( - rules for k, rules in self.rule_by_bucket.items() if k != key - ) - minimizing = self.rule_by_bucket[key] - while minimizing: - tb = TableMethod() - # Add the rule we are not trying to minimize - for rk in itertools.chain.from_iterable(not_minimizing): - tb.add_rule_key(rk) - if tb.is_pumping(self.root_label): - minimizing.clear() - break - # Add rule until it gets productive - for i, rk in enumerate(minimizing): - tb.add_rule_key(rk) - if tb.is_pumping(self.root_label): - break - else: - raise RuntimeError("Not pumping after adding all rules") - maybe_useful.append(rk) - assert minimizing, "variable i won't be set" - # pylint: disable=undefined-loop-variable - for _ in range(i, len(minimizing)): - minimizing.pop() - # added to avoid doubling in memory when minimizing with pypy - if platform.python_implementation() == "PyPy": - gc.collect_step() # type: ignore - counter = 0 - while maybe_useful: - rk = maybe_useful.pop() - if not self._is_productive(itertools.chain.from_iterable(not_minimizing)): - self.needed_rules.append(rk) - counter += 1 - # added to avoid doubling in memory when minimizing with pypy - if platform.python_implementation() == "PyPy": - gc.collect_step() # type: ignore - logger.info("Using %s rule for %s", counter, key.name) - def _is_productive(self, rule_keys: Iterable[ForestRuleKey]) -> bool: """ Check if the given set of rules is productive. @@ -542,23 +93,6 @@ def _is_productive(self, rule_keys: Iterable[ForestRuleKey]) -> bool: ruledb.add_rule_key(rk) return ruledb.is_pumping(self.root_label) - def _sorted_stable_rules(self, ruledb: TableMethod) -> SortedRWS: - """ - Extract all the rule from the stable subuniverse and return all of them in a - dict sorted by type. 
- """ - res: SortedRWS = {bucket: [] for bucket in self.MINIMIZE_ORDER} - for forest_key in ruledb.pumping_subuniverse(): - try: - res[forest_key.bucket].append(forest_key) - except KeyError as e: - msg = ( - f"{forest_key.bucket} type is not currently supported " - "by the extractor" - ) - raise RuntimeError(msg) from e - return res - def _find_rule(self, rule_key: ForestRuleKey) -> AbstractRule: """ Find a rule that have the given rule key. @@ -619,10 +153,7 @@ class RuleDBForest(RuleDBAbstract): """ def __init__( - self, - *, - reverse: bool = True, - rule_cache: Iterable[AbstractRule] = tuple(), + self, *, reverse: bool = True, rule_cache: Iterable[AbstractRule] = tuple() ) -> None: super().__init__() self.reverse = reverse diff --git a/comb_spec_searcher/strategies/rule.py b/comb_spec_searcher/strategies/rule.py index 7195b8194..336258d61 100644 --- a/comb_spec_searcher/strategies/rule.py +++ b/comb_spec_searcher/strategies/rule.py @@ -26,19 +26,12 @@ cast, ) +from comb_spec_searcher_rs import ForestRuleKey, RuleBucket from logzero import logger from sympy import Eq, Function, var from comb_spec_searcher.combinatorial_class import CombinatorialClass -from comb_spec_searcher.typing import ( - ForestRuleKey, - Objects, - ObjectsCache, - RuleBucket, - SubObjects, - SubTerms, - Terms, -) +from comb_spec_searcher.typing import Objects, ObjectsCache, SubObjects, SubTerms, Terms from ..combinatorial_class import CombinatorialClassType, CombinatorialObjectType from ..exception import SanityCheckFailure, SpecificationNotFound, StrategyDoesNotApply diff --git a/comb_spec_searcher/typing.py b/comb_spec_searcher/typing.py index 4bfc6c71c..ac232addf 100644 --- a/comb_spec_searcher/typing.py +++ b/comb_spec_searcher/typing.py @@ -1,4 +1,3 @@ -import enum from typing import ( TYPE_CHECKING, Callable, @@ -30,8 +29,6 @@ "Parameters", "ParametersMap", "RelianceProfile", - "RuleBucket", - "ForestRuleKey", "Objects", "ObjectsCache", "Terms", @@ -61,29 +58,9 @@ class 
WorkPacket(NamedTuple): inferral: bool -@enum.unique -class RuleBucket(enum.Enum): - UNDEFINED = enum.auto() - VERIFICATION = enum.auto() - EQUIV = enum.auto() - NORMAL = enum.auto() - REVERSE = enum.auto() - - RuleKey = Tuple[int, Tuple[int, ...]] -class ForestRuleKey(NamedTuple): - parent: int - children: Tuple[int, ...] - shifts: Tuple[int, ...] - bucket: RuleBucket - - @property - def key(self) -> RuleKey: - return (self.parent, self.children) - - # From constructor Parameters = Tuple[int, ...] ParametersMap = Callable[[Parameters], Parameters] diff --git a/comb_spec_searcher_rs.pyi b/comb_spec_searcher_rs.pyi new file mode 100644 index 000000000..daf340232 --- /dev/null +++ b/comb_spec_searcher_rs.pyi @@ -0,0 +1,31 @@ +import enum +from typing import List, Tuple + +class RuleBucket(enum.Enum): + UNDEFINED = enum.auto() + VERIFICATION = enum.auto() + EQUIV = enum.auto() + NORMAL = enum.auto() + REVERSE = enum.auto() + +class ForestRuleKey: + parent: int + bucket: RuleBucket + children: tuple[int] + key: tuple[int, tuple[int, ...]] + shifts: tuple[int] + + def __new__( + cls, + parent: int, + children: Tuple[int, ...], + shifts: Tuple[int, ...], + bucket: RuleBucket, + ): ... + +class TableMethod: + def add_rule_key(self, rule_key: ForestRuleKey) -> None: ... + def is_pumping(self, label: int) -> bool: ... + def pumping_subuniverse(self) -> List[ForestRuleKey]: ... + def status(self) -> str: ... + def extract_specification(self, root_class: int) -> List[ForestRuleKey]: ... diff --git a/pylintrc b/pylintrc index a9c893fc1..1db97f3ef 100644 --- a/pylintrc +++ b/pylintrc @@ -338,7 +338,7 @@ ignored-classes=optparse.Values,thread._local,_thread._local # (useful for modules/projects where namespaces are manipulated during runtime # and thus existing member attributes cannot be deduced by static analysis). It # supports qualified module names, as well as Unix pattern matching. 
-ignored-modules= +ignored-modules=comb_spec_searcher_rs # Show a hint with possible names when a member name was not found. The aspect # of finding the hint is based on edit distance. diff --git a/pyproject.toml b/pyproject.toml index 22e31c67c..b047f5b0c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,60 @@ +[project] +name = "comb_spec_searcher" +version = "4.2.0" +authors = [ + {name = "Permuta Triangle", email = "permutatriangle@gmail.com"}, +] +description = "A library for performing combinatorial exploration." +readme = "README.rst" +requires-python = ">=3.8" +keywords = [ + "enumerative", + "combinatorics", + "combinatorial", + "specification", + "counting", +] +license = {text = "BSD-3-Clause"} +classifiers = [ + "Development Status :: 5 - Production/Stable", + "Intended Audience :: Education", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: BSD License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: Implementation :: CPython", + "Programming Language :: Python :: Implementation :: PyPy", + "Programming Language :: Rust", + "Topic :: Education", + "Topic :: Scientific/Engineering :: Mathematics", +] +dependencies = [ + "logzero==1.7.0", + "sympy==1.10.1", + "psutil==5.9.4", + "pympler==1.0.1", + "requests==2.28.1", + "typing-extensions==4.4.0", + "tabulate==0.9.0", +] + +[project.urls] +homepage = "https://github.com/PermutaTriangle/comb_spec_searcher" +source = "https://github.com/PermutaTriangle/comb_spec_searcher" +tracker = "https://github.com/PermutaTriangle/comb_spec_searcher/issues" + +[build-system] +requires = ["maturin>=0.14,<0.15"] +build-backend = "maturin" + +[tool.maturin] +exclude = ["tests/**/*"] + [tool.black] -target-version = ['py37'] +target-version = ['py38'] include = '\.pyi?$' exclude = ''' ( 
diff --git a/setup.py b/setup.py deleted file mode 100755 index 670424e5a..000000000 --- a/setup.py +++ /dev/null @@ -1,61 +0,0 @@ -#!/usr/bin/env python -import os - -from setuptools import find_packages, setup - - -def read(fname): - return open(os.path.join(os.path.dirname(__file__), fname)).read() - - -def get_version(rel_path): - for line in read(rel_path).splitlines(): - if line.startswith("__version__"): - delim = '"' if '"' in line else "'" - return line.split(delim)[1] - raise RuntimeError("Unable to find version string.") - - -setup( - name="comb_spec_searcher", - version=get_version("comb_spec_searcher/__init__.py"), - author="Permuta Triangle", - author_email="permutatriangle@gmail.com", - description="A library for performing combinatorial exploration.", - license="BSD-3", - keywords="enumerative combinatorics combinatorial specification counting", - url="https://github.com/PermutaTriangle/comb_spec_searcher", - project_urls={ - "Source": "https://github.com/PermutaTriangle/comb_spec_searcher", - "Tracker": ("https://github.com/PermutaTriangle/comb_spec_searcher" "/issues"), - }, - packages=find_packages(exclude=["*.tests", "*.tests.*", "tests.*", "tests"]), - package_data={"comb_spec_searcher": ["py.typed"]}, - long_description=read("README.rst"), - python_requires=">=3.8", - include_package_data=True, - classifiers=[ - "Development Status :: 5 - Production/Stable", - "Intended Audience :: Education", - "Intended Audience :: Science/Research", - "License :: OSI Approved :: BSD License", - "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.7", - "Programming Language :: Python :: 3.8", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: 3.10", - "Programming Language :: Python :: Implementation :: CPython", - "Programming Language :: Python :: Implementation :: PyPy", - "Topic :: Education", - "Topic :: Scientific/Engineering :: Mathematics", - ], - install_requires=[ - "logzero==1.7.0", - 
"sympy==1.10.1", - "psutil==5.9.4", - "pympler==1.0.1", - "requests==2.28.1", - "typing-extensions==4.4.0", - "tabulate==0.9.0", - ], -) diff --git a/src/forest.rs b/src/forest.rs new file mode 100644 index 000000000..fc495a444 --- /dev/null +++ b/src/forest.rs @@ -0,0 +1,5 @@ +mod extractor; +mod function; +mod table_method; + +pub use table_method::{ForestRuleKey, RuleBucket, TableMethodPyWrapper}; diff --git a/src/forest/extractor.rs b/src/forest/extractor.rs new file mode 100644 index 000000000..7c0206de2 --- /dev/null +++ b/src/forest/extractor.rs @@ -0,0 +1,108 @@ +use super::table_method::TableMethod; +use super::ForestRuleKey; +use super::RuleBucket; +use std::collections::HashSet; + +const MINIMIZE_ORDER: [RuleBucket; 4] = [ + RuleBucket::Reverse, + RuleBucket::Normal, + RuleBucket::Equiv, + RuleBucket::Verification, +]; + +enum MinimizationRoundResult { + Done(TableMethod), + NotDone(TableMethod), +} + +/// Perform one round of the minimization. +/// +/// Insert rules into the table method until it pumps for the root class, then add +/// the last inserted rule to the maybe useful set. 
+///
+/// Every pumping rule is re-inserted into a fresh table method, except the rules
+/// of `bucket` that are not yet in `maybe_useful`, which are set aside. If the
+/// root class already pumps, `Done` is returned and the rules set aside are
+/// dropped. Otherwise they are re-added one at a time until the root class
+/// pumps; the rule that completed the pumping is recorded in `maybe_useful`, the
+/// rules still set aside are dropped and `NotDone` is returned.
+///
+/// # Panics
+///
+/// Panics if the root class does not pump even after every rule set aside has
+/// been re-added.
+fn minimization_bucket_round(
+    tb: TableMethod,
+    bucket: &RuleBucket,
+    root_class: u32,
+    maybe_useful: &mut HashSet<ForestRuleKey>,
+) -> MinimizationRoundResult {
+    let mut new_tb = TableMethod::new();
+    let mut rules_in_bucket = vec![];
+    for rk in tb.into_pumping_subuniverse() {
+        if rk.get_bucket() == bucket && !maybe_useful.contains(&rk) {
+            rules_in_bucket.push(rk);
+        } else {
+            new_tb.add_rule_key(rk);
+        }
+    }
+    if new_tb.is_pumping(root_class) {
+        return MinimizationRoundResult::Done(new_tb);
+    }
+    loop {
+        let rk = rules_in_bucket.pop().expect("Not pumping after adding all rules");
+        new_tb.add_rule_key(rk);
+        if new_tb.is_pumping(root_class) {
+            // The rule that made the root pump must survive the next rounds.
+            maybe_useful.insert(new_tb.get_last_added_rule_key().unwrap().clone());
+            break;
+        }
+    }
+    MinimizationRoundResult::NotDone(new_tb)
+}
+
+/// Minimize the rules for a given bucket
+///
+/// Rounds are repeated, sharing one `maybe_useful` set, until a round reports
+/// `Done`.
+fn minimize_bucket(mut tb: TableMethod, bucket: &RuleBucket, root_class: u32) -> TableMethod {
+    let mut done = false;
+    let mut maybe_useful = HashSet::new();
+    while !done {
+        (tb, done) = match minimization_bucket_round(tb, bucket, root_class, &mut maybe_useful) {
+            MinimizationRoundResult::Done(tb) => (tb, true),
+            MinimizationRoundResult::NotDone(tb) => (tb, false),
+        }
+    }
+    tb
+}
+
+/// Perform the complete minimization of the forest
+///
+/// The buckets are minimized one after the other in the order given by
+/// `MINIMIZE_ORDER`.
+fn minimize(tb: TableMethod, root_class: u32) -> TableMethod {
+    let mut tb = tb;
+    for bucket in MINIMIZE_ORDER.iter() {
+        tb = minimize_bucket(tb, bucket, root_class);
+    }
+    tb
+}
+
+/// Minimize the table method and return the rules that are left.
+///
+/// # Panics
+///
+/// Panics if two of the remaining rules share the same parent, or if a child of
+/// a remaining rule is not the parent of a remaining rule. The minimization
+/// itself panics if the root class cannot be made to pump.
+pub fn extract_specification(root_class: u32, tb: TableMethod) -> Vec<ForestRuleKey> {
+    let minimized = minimize(tb, root_class);
+    let rules: Vec<_> = minimized.into_rules().collect();
+    let parents: HashSet<_> = rules.iter().map(|rk| rk.get_parent()).collect();
+    // Exactly one rule per parent.
+    assert_eq!(parents.len(), rules.len());
+    for rk in rules.iter() {
+        for c in rk.iter_children() {
+            assert!(parents.contains(c));
+        }
+    }
+    rules
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn extract_132_test() {
+        let rules = vec![
+            ForestRuleKey::new(0, vec![1, 2], vec![0, 0], RuleBucket::Normal),
+            ForestRuleKey::new(1, vec![], vec![], RuleBucket::Verification),
+            ForestRuleKey::new(2, vec![3], vec![0], RuleBucket::Equiv),
+            ForestRuleKey::new(3, vec![4], vec![0], RuleBucket::Equiv),
+            ForestRuleKey::new(4, vec![5, 0, 0], vec![0, 1, 1], RuleBucket::Normal),
+            ForestRuleKey::new(5, vec![], vec![], RuleBucket::Verification),
+            ForestRuleKey::new(2, vec![6], vec![2], RuleBucket::Undefined),
+        ];
+        let mut tb = TableMethod::new();
+        for rule in rules.into_iter() {
+            tb.add_rule_key(rule);
+        }
+        let spec = extract_specification(0, tb);
+        assert_eq!(spec.len(), 6);
+    }
+}
diff --git a/src/forest/function.rs b/src/forest/function.rs
new file mode 100644
index 000000000..9c69c9a41
--- /dev/null
+++ b/src/forest/function.rs
@@ -0,0 +1,337 @@
+/// A value that is either a small non-negative integer or infinity.
+#[derive(Clone, Debug, PartialEq, Eq)]
+pub enum IntOrInf {
+    Int(u8),
+    Infinity,
+}
+
+impl IntOrInf {
+    /// Return true for any `Int` value.
+    pub fn is_finite(&self) -> bool {
+        match self {
+            IntOrInf::Int(_) => true,
+            IntOrInf::Infinity => false,
+        }
+    }
+
+    /// Return true for `Infinity` only.
+    pub fn is_infinite(&self) -> bool {
+        !self.is_finite()
+    }
+}
+
+/// A representation of a function from N to N U {∞}
+///
+/// The default value of the function is 0.
+pub struct Function {
+    // Explicitly stored values for the inputs `0..values.len()`; any input past
+    // the end has the default value 0.
+    values: Vec<IntOrInf>,
+    // `preimage_count[v]` is the number of stored inputs whose value is `Int(v)`.
+    preimage_count: Vec<u32>,
+    // Number of inputs whose value has been set to infinity.
+    infinity_count: u32,
+}
+
+impl Function {
+    /// Create the function that is 0 everywhere.
+    pub fn new() -> Function {
+        Function {
+            values: vec![],
+            preimage_count: vec![],
+            infinity_count: 0,
+        }
+    }
+
+    /// Get the current function value for the given input
+    pub fn get_value(&self, input: u32) -> &IntOrInf {
+        self.values.get(input as usize).unwrap_or(&IntOrInf::Int(0))
+    }
+
+    /// Increase by one the value of the function for the given input.
+    ///
+    /// An input that already has an infinite value is left untouched. An input
+    /// past the stored prefix is materialised first: the inputs in between are
+    /// stored as explicit zeros and counted in `preimage_count[0]`.
+    pub fn increase_value(&mut self, input: u32) {
+        let old_value = self.values.get_mut(input as usize);
+        match old_value {
+            Some(IntOrInf::Int(value)) => {
+                self.preimage_count[*value as usize] -= 1;
+                *value += 1;
+                if self.preimage_count.len() <= *value as usize {
+                    self.preimage_count.resize(*value as usize + 1, 0);
+                }
+                self.preimage_count[*value as usize] += 1;
+            }
+            Some(IntOrInf::Infinity) => (),
+            None => {
+                if self.preimage_count.len() < 2 {
+                    self.preimage_count.resize(2, 0);
+                }
+                self.preimage_count[0] += input - self.values.len() as u32;
+                self.preimage_count[1] += 1;
+                self.values.resize(input as usize, IntOrInf::Int(0));
+                self.values.push(IntOrInf::Int(1));
+            }
+        }
+    }
+
+    /// Set the value to infinity for the given input
+    ///
+    /// An input that is already infinite is left untouched. An input past the
+    /// stored prefix is materialised first, the inputs in between being stored as
+    /// explicit zeros.
+    pub fn set_infinite(&mut self, input: u32) {
+        let old_value = self.values.get_mut(input as usize);
+        match old_value {
+            Some(IntOrInf::Int(value)) => {
+                self.preimage_count[*value as usize] -= 1;
+                self.infinity_count += 1;
+                *old_value.unwrap() = IntOrInf::Infinity;
+            }
+            Some(IntOrInf::Infinity) => (),
+            None => {
+                // Make sure the slot for the value 0 exists before touching it,
+                // and never shrink the vector: counts for larger values must be
+                // kept.
+                if self.preimage_count.is_empty() {
+                    self.preimage_count.push(0);
+                }
+                self.preimage_count[0] += input - self.values.len() as u32;
+                self.values.resize(input as usize, IntOrInf::Int(0));
+                self.values.push(IntOrInf::Infinity);
+                self.infinity_count += 1;
+            }
+        }
+    }
+
+    /// Number of inputs for which a value is explicitly stored
+    pub fn len(&self) -> u32 {
+        self.values.len() as u32
+    }
+
+    /// Return the preimage of the given value
+    ///
+    /// # Panic
+    ///
+    /// This function will panic on a value of 0 as the preimage is not well
+    /// defined.
+    pub fn preimage(&self, value: IntOrInf) -> FunctionPreImageIterator {
+        match value {
+            IntOrInf::Int(0) => panic!("The preimage of 0 is infinite"),
+            _ => FunctionPreImageIterator::new(self, value),
+        }
+    }
+
+    /// Return the smallest k such that the preimage of the interval
+    /// [k, k+gap_size-1] is empty.
+    ///
+    /// Infinite values are not taken into account: only the finite values
+    /// recorded in the preimage counts are.
+    ///
+    /// # Panic
+    ///
+    /// This function will panic on a gap size of 0 as such a gap is not well
+    /// defined.
+    pub fn preimage_gap(&self, gap_size: u32) -> u32 {
+        if gap_size == 0 {
+            panic!("Gap of size 0 is not well defined.");
+        }
+        let mut last_non_zero: u32 = 0;
+        for (i, v) in self.preimage_count.iter().enumerate() {
+            if *v != 0 {
+                last_non_zero = i as u32;
+            } else if i as u32 - last_non_zero >= gap_size {
+                return last_non_zero + 1;
+            }
+        }
+        // No gap inside the recorded values: it starts right after the last one.
+        last_non_zero + 1
+    }
+
+    /// Number of inputs whose value is infinity
+    pub fn get_infinity_count(&self) -> u32 {
+        self.infinity_count
+    }
+
+    /// For each finite value, the number of stored inputs having that value
+    pub fn get_preimage_count(&self) -> &Vec<u32> {
+        &self.preimage_count
+    }
+}
+
+/// Iterator over the stored inputs of a function that have a given value.
+pub struct FunctionPreImageIterator<'a> {
+    function: &'a Function,
+    value: IntOrInf,
+    // Next input to look at.
+    pos: u32,
+}
+
+impl<'a> FunctionPreImageIterator<'a> {
+    fn new(function: &Function, value: IntOrInf) -> FunctionPreImageIterator {
+        FunctionPreImageIterator {
+            function,
+            value,
+            pos: 0,
+        }
+    }
+}
+
+impl<'a> Iterator for FunctionPreImageIterator<'a> {
+    type Item = u32;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        while self.pos < self.function.len() {
+            let value_for_pos = self.function.get_value(self.pos);
+            self.pos += 1;
+            if *value_for_pos == self.value {
+                return Some(self.pos as u32 - 1);
+            }
+        }
+        None
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn create_function() {
+        let f = Function::new();
+        assert_eq!(f.len(), 0);
+        assert_eq!(*f.get_value(0), IntOrInf::Int(0));
+    }
+
+    #[test]
+    fn add_value() {
+        let mut f = Function::new();
+        assert_eq!(*f.get_value(0), IntOrInf::Int(0));
+        assert_eq!(*f.get_value(0), IntOrInf::Int(0));
+        f.increase_value(0);
+        assert_eq!(*f.get_value(0), IntOrInf::Int(1));
+        f.increase_value(3);
+        assert_eq!(*f.get_value(4), IntOrInf::Int(0));
+        f.increase_value(4);
+        assert_eq!(*f.get_value(0), IntOrInf::Int(1));
+        assert_eq!(*f.get_value(1), IntOrInf::Int(0));
+        assert_eq!(*f.get_value(2), IntOrInf::Int(0));
+        assert_eq!(*f.get_value(3), IntOrInf::Int(1));
+        assert_eq!(*f.get_value(4), IntOrInf::Int(1));
+        assert_eq!(*f.get_value(5), IntOrInf::Int(0));
+        assert_eq!(*f.get_value(6), IntOrInf::Int(0));
+    }
+
+    #[test]
+    fn preimage() {
+        let mut f = Function::new();
+        f.increase_value(0);
+        f.increase_value(3);
+        f.increase_value(4);
+        let mut preimage1: Vec<_> = f.preimage(IntOrInf::Int(1)).collect();
+        preimage1.sort();
+        assert_eq!(preimage1, vec![0, 3, 4]);
+        assert_eq!(f.preimage(IntOrInf::Int(2)).count(), 0);
+    }
+
+    #[test]
+    #[should_panic(expected = "The preimage of 0 is infinite")]
+    fn preimage_0() {
+        let mut f = Function::new();
+        f.increase_value(0);
+        f.increase_value(3);
+        f.increase_value(4);
+        f.preimage(IntOrInf::Int(0));
+    }
+
+    #[test]
+    fn infinity() {
+        let mut f = Function::new();
+        f.increase_value(0);
+        f.increase_value(3);
+        f.increase_value(4);
+        f.set_infinite(3);
+        assert_eq!(*f.get_value(0), IntOrInf::Int(1));
+        assert_eq!(*f.get_value(1), IntOrInf::Int(0));
+        assert_eq!(*f.get_value(2), IntOrInf::Int(0));
+        assert_eq!(*f.get_value(3), IntOrInf::Infinity);
+        assert_eq!(*f.get_value(4), IntOrInf::Int(1));
+    }
+
+    #[test]
+    fn set_infinite_on_unseen_input() {
+        let mut f = Function::new();
+        f.set_infinite(2);
+        assert_eq!(*f.get_value(0), IntOrInf::Int(0));
+        assert_eq!(*f.get_value(1), IntOrInf::Int(0));
+        assert_eq!(*f.get_value(2), IntOrInf::Infinity);
+        assert_eq!(f.get_infinity_count(), 1);
+        assert_eq!(f.preimage_count, vec![2]);
+        assert_eq!(f.preimage(IntOrInf::Infinity).collect::<Vec<_>>(), vec![2]);
+    }
+
+    #[test]
+    fn preimage_inf() {
+        let mut f = Function::new();
+        f.increase_value(0);
+        f.increase_value(3);
+        f.increase_value(4);
+        f.set_infinite(3);
+        assert_eq!(f.preimage(IntOrInf::Infinity).collect::<Vec<_>>(), vec![3]);
+        let mut preimage1: Vec<_> = f.preimage(IntOrInf::Int(1)).collect();
+        preimage1.sort();
+        assert_eq!(preimage1, vec![0, 4]);
+    }
+
+    #[test]
+    fn bug() {
+        let mut f = Function::new();
+        f.increase_value(1);
+        println!("{:?}", f.preimage_count);
+        f.increase_value(1);
+        println!("{:?}", f.preimage_count);
+        f.increase_value(1);
+        println!("{:?}", f.preimage_count);
+        f.increase_value(5);
+        println!("{:?}", f.preimage_count);
+        f.increase_value(5);
+        println!("{:?}", f.preimage_count);
+        f.increase_value(4);
+        println!("{:?}", f.preimage_count);
+        f.increase_value(5);
+        println!("{:?}", f.preimage_count);
+        // Stored values are [0, 3, 0, 0, 1, 3].
+        assert_eq!(f.preimage_count, vec![3, 1, 0, 2]);
+    }
+
+    #[test]
+    fn preimage_count() {
+        let mut f = Function::new();
+        f.increase_value(0);
+        assert_eq!(f.preimage_count, vec![0, 1]);
+        f.increase_value(0);
+        assert_eq!(f.preimage_count, vec![0, 0, 1]);
+        f.increase_value(0);
+        assert_eq!(f.preimage_count, vec![0, 0, 0, 1]);
+        f.increase_value(0);
+        assert_eq!(f.preimage_count, vec![0, 0, 0, 0, 1]);
+        f.increase_value(1);
+        assert_eq!(f.preimage_count, vec![0, 1, 0, 0, 1]);
+        f.increase_value(2);
+        assert_eq!(f.preimage_count, vec![0, 2, 0, 0, 1]);
+    }
+
+    #[test]
+    fn preimage_gap_inf() {
+        let mut f = Function::new();
+        f.increase_value(0);
+        f.increase_value(3);
+        f.increase_value(4);
+        f.set_infinite(3);
+        assert_eq!(f.preimage_gap(100), 2);
+    }
+
+    #[test]
+    fn find_gap() {
+        let mut f = Function::new();
+        f.increase_value(0);
+        f.increase_value(0);
+        f.increase_value(0);
+        f.increase_value(0);
+        f.increase_value(1);
+        f.increase_value(2);
+        assert_eq!(*f.get_value(0), IntOrInf::Int(4));
+        assert_eq!(*f.get_value(1), IntOrInf::Int(1));
+        assert_eq!(*f.get_value(2), IntOrInf::Int(1));
+        assert_eq!(*f.get_value(3), IntOrInf::Int(0));
+        assert_eq!(*f.get_value(4), IntOrInf::Int(0));
+        assert_eq!(*f.get_value(5), IntOrInf::Int(0));
+        assert_eq!(*f.get_value(6), IntOrInf::Int(0));
+        assert_eq!(f.preimage_gap(1), 2);
+        assert_eq!(f.preimage_gap(2), 2);
+        assert_eq!(f.preimage_gap(3), 5);
+    }
+
+    #[test]
+    #[should_panic(expected = "0")]
+    fn find_size_zero_gap() {
+        let f = Function::new();
+        f.preimage_gap(0);
+    }
+
+    #[test]
+    fn find_gap2() {
+        let mut f = Function::new();
+        f.increase_value(2);
+        f.increase_value(3);
+        f.increase_value(4);
+        f.increase_value(5);
+        f.increase_value(5);
+        f.increase_value(5);
+        assert_eq!(*f.get_value(0), IntOrInf::Int(0));
+        assert_eq!(*f.get_value(1), IntOrInf::Int(0));
+        assert_eq!(*f.get_value(2), IntOrInf::Int(1));
+        assert_eq!(*f.get_value(3), IntOrInf::Int(1));
+        assert_eq!(*f.get_value(4), IntOrInf::Int(1));
+        assert_eq!(*f.get_value(5), IntOrInf::Int(3));
+        assert_eq!(f.preimage_gap(1), 2);
+        assert_eq!(f.preimage_gap(2), 4);
+        assert_eq!(f.preimage_gap(3), 4);
+    }
+}
diff --git a/src/forest/table_method.rs
b/src/forest/table_method.rs
new file mode 100644
index 000000000..f28ca58f4
--- /dev/null
+++ b/src/forest/table_method.rs
@@ -0,0 +1,855 @@
+use super::function::{Function, IntOrInf};
+use std::collections::hash_map::DefaultHasher;
+
+use core::slice::Iter;
+use pyo3::class::basic::CompareOp;
+use pyo3::prelude::*;
+use pyo3::types::PyTuple;
+use std;
+use std::collections::{HashMap, HashSet, VecDeque};
+use std::hash::{Hash, Hasher};
+use std::mem;
+
+use super::extractor::extract_specification;
+
+/// The bucket a rule belongs to, exposed to Python with upper-case names.
+#[pyclass]
+#[derive(Debug, PartialEq, Eq, Hash, Clone)]
+pub enum RuleBucket {
+    #[pyo3(name = "UNDEFINED")]
+    Undefined,
+    #[pyo3(name = "VERIFICATION")]
+    Verification,
+    #[pyo3(name = "EQUIV")]
+    Equiv,
+    #[pyo3(name = "NORMAL")]
+    Normal,
+    #[pyo3(name = "REVERSE")]
+    Reverse,
+}
+
+#[pymethods]
+impl RuleBucket {
+    /// Python hash: a fixed integer per variant.
+    fn __hash__(&self) -> u64 {
+        match self {
+            RuleBucket::Undefined => 0,
+            RuleBucket::Verification => 1,
+            RuleBucket::Equiv => 2,
+            RuleBucket::Normal => 3,
+            RuleBucket::Reverse => 4,
+        }
+    }
+
+    /// The upper-case name of the variant.
+    #[getter]
+    fn get_name(&self) -> &str {
+        match self {
+            RuleBucket::Undefined => "UNDEFINED",
+            RuleBucket::Verification => "VERIFICATION",
+            RuleBucket::Equiv => "EQUIV",
+            RuleBucket::Normal => "NORMAL",
+            RuleBucket::Reverse => "REVERSE",
+        }
+    }
+}
+
+/// Index from a class to the rules that mention it.
+struct RuleClassConnector {
+    // class -> (rule index, position of the class among the rule's children)
+    rule_using_class: HashMap<u32, Vec<(usize, usize)>>,
+    // class -> indices of the rules having the class as parent
+    rule_pumping_class: HashMap<u32, Vec<usize>>,
+}
+
+impl RuleClassConnector {
+    pub fn new() -> RuleClassConnector {
+        RuleClassConnector {
+            rule_using_class: HashMap::new(),
+            rule_pumping_class: HashMap::new(),
+        }
+    }
+
+    /// Record that the rule `rule_idx` has `class` as parent.
+    pub fn add_rule_pumping_class(&mut self, class: u32, rule_idx: usize) {
+        let entry = self.rule_pumping_class.entry(class).or_insert(vec![]);
+        entry.push(rule_idx);
+    }
+
+    /// Record that the rule `rule_idx` has `class` as its child number `child_idx`.
+    pub fn add_rule_using_class(&mut self, class: u32, rule_idx: usize, child_idx: usize) {
+        let entry = self.rule_using_class.entry(class).or_insert(vec![]);
+        entry.push((rule_idx, child_idx));
+    }
+
+    /// Indices of the rules recorded as having `class` as parent (possibly none).
+    pub fn get_rules_pumping_class(&self, class: u32) -> impl Iterator<Item = &usize> {
+        self.rule_pumping_class
+            .get(&class)
+            .map(|v| v.iter())
+            .unwrap_or([][..].iter())
+    }
+
+    /// `(rule index, child position)` pairs recorded for `class` (possibly none).
+    pub fn get_rules_using_class(&self, class: u32) -> impl Iterator<Item = &(usize, usize)> {
+        self.rule_using_class
+            .get(&class)
+            .map(|v| v.iter())
+            .unwrap_or([][..].iter())
+    }
+
+    /// Not implemented yet: always panics through `todo!`, so the two removals
+    /// below are never reached.
+    pub fn remove_class_information(&mut self, class: u32) {
+        todo!("This implementation is not correct");
+        self.rule_using_class.remove(&class);
+        self.rule_pumping_class.remove(&class);
+    }
+}
+
+/// A rule of the forest: a parent class, its children, one shift per child and
+/// the bucket of the rule.
+#[pyclass]
+#[derive(Debug, Hash, Clone, PartialEq, Eq)]
+pub struct ForestRuleKey {
+    #[pyo3(get)]
+    parent: u32,
+    children: Vec<u32>,
+    shifts: Vec<i8>,
+    #[pyo3(get)]
+    bucket: RuleBucket,
+}
+
+impl ForestRuleKey {
+    pub fn new(
+        parent: u32,
+        children: Vec<u32>,
+        shifts: Vec<i8>,
+        bucket: RuleBucket,
+    ) -> ForestRuleKey {
+        ForestRuleKey {
+            parent,
+            children,
+            shifts,
+            bucket,
+        }
+    }
+
+    /// The parent together with the children of the rule.
+    pub fn key(&self) -> (&u32, &Vec<u32>) {
+        (&self.parent, &self.children)
+    }
+
+    pub fn get_bucket(&self) -> &RuleBucket {
+        &self.bucket
+    }
+
+    pub fn get_parent(&self) -> &u32 {
+        &self.parent
+    }
+
+    pub fn iter_children(&self) -> Iter<u32> {
+        self.children.iter()
+    }
+}
+
+#[pymethods]
+impl ForestRuleKey {
+    #[new]
+    fn py_new(parent: u32, children: Vec<u32>, shifts: Vec<i8>, bucket: RuleBucket) -> Self {
+        ForestRuleKey::new(parent, children, shifts, bucket)
+    }
+
+    /// `(parent, children)` with the children as a Python tuple.
+    #[getter]
+    fn get_key(&self, py: Python<'_>) -> (u32, Py<PyTuple>) {
+        (self.parent, self.get_children(py))
+    }
+
+    #[getter]
+    fn get_children(&self, py: Python<'_>) -> Py<PyTuple> {
+        PyTuple::new(py, self.children.clone()).into()
+    }
+
+    #[getter]
+    fn get_shifts(&self, py: Python<'_>) -> Py<PyTuple> {
+        PyTuple::new(py, self.shifts.clone()).into()
+    }
+
+    /// Only equality and inequality are supported; other comparisons return
+    /// `NotImplemented`.
+    fn __richcmp__(&self, other: &Self, op: CompareOp, py: Python<'_>) -> PyObject {
+        match op {
+            CompareOp::Eq => (self == other).into_py(py),
+            CompareOp::Ne => (self != other).into_py(py),
+            _ => py.NotImplemented(),
+        }
+    }
+
+    /// Python hash, computed from the derived Rust `Hash`.
+    fn __hash__(&self) -> u64 {
+        let mut s = DefaultHasher::new();
+        self.hash(&mut s);
+        s.finish()
+    }
+}
+
+pub struct TableMethod {
+    rules: Vec<ForestRuleKey>,
+    // One entry per rule, one optional shift per child of the rule. `None` marks
+    // a child (or a parent) whose value is infinite.
+    shifts: Vec<Vec<Option<i8>>>,
+    function: Function,
+    gap_size: u32,
+    // Both for rule using and rule pumping class
+    rule_class_connector: RuleClassConnector,
+    // Indices of the rules waiting to be processed.
+    processing_queue: VecDeque<usize>,
+    current_gap: (u32, u32),
+    // Indices of the rules whose increase was held back by the gap.
+    rule_holding_extra_terms: HashSet<usize>,
+}
+
+impl TableMethod {
+    pub fn new() -> TableMethod {
+        TableMethod {
+            rules: vec![],
+            shifts: vec![],
+            function: Function::new(),
+            gap_size: 1,
+            rule_class_connector: RuleClassConnector::new(),
+            processing_queue: VecDeque::new(),
+            current_gap: (1, 1),
+            rule_holding_extra_terms: HashSet::new(),
+        }
+    }
+
+    /// Add the rule to the database
+    ///
+    /// The rule is stored with its initial shifts, the gap size is enlarged when
+    /// the rule has a shift of larger absolute value, and the processing queue is
+    /// run. A reference to the stored rule is returned.
+    pub fn add_rule_key(&mut self, rule_key: ForestRuleKey) -> &ForestRuleKey {
+        self.rules.push(rule_key);
+        let rule_key = self.rules.last().unwrap();
+        self.shifts.push(self.compute_shift(&rule_key));
+        // `unsigned_abs` cannot overflow, unlike `abs` on `i8::MIN`.
+        let max_gap = rule_key
+            .shifts
+            .iter()
+            .map(|s| s.unsigned_abs())
+            .max()
+            .unwrap_or(0) as u32;
+        if max_gap > self.gap_size {
+            self.gap_size = max_gap;
+            self.correct_gap();
+        }
+        // Because the correct gap need a mutable reference to self we need to
+        // invalidate the immutable reference hold by rule_key.
+        let rule_key = self.rules.last().unwrap();
+        if self.function.get_value(rule_key.parent).is_finite() {
+            let rule_idx = self.rules.len() - 1;
+            self.rule_class_connector
+                .add_rule_pumping_class(rule_key.parent, rule_idx);
+            for (child_idx, child) in rule_key.children.iter().enumerate() {
+                if self.function.get_value(*child).is_finite() {
+                    self.rule_class_connector
+                        .add_rule_using_class(*child, rule_idx, child_idx);
+                }
+            }
+            self.processing_queue.push_back(rule_idx);
+        }
+        self.process_queue();
+        let rule_key = self.rules.last().unwrap();
+        rule_key
+    }
+
+    /// Determine if the comb_class is pumping in the current universe.
+    ///
+    /// A class is pumping when its function value is infinite.
+    pub fn is_pumping(&self, class: u32) -> bool {
+        self.function.get_value(class).is_infinite()
+    }
+
+    /// Iterator over the classes whose function value is infinite.
+    pub fn stable_subset(&self) -> impl Iterator<Item = u32> + '_ {
+        self.function.preimage(IntOrInf::Infinity)
+    }
+
+    /// Iterator over all the forest rule keys that contain only pumping
+    /// combinatorial classes.
+    pub fn pumping_subuniverse(&self) -> impl Iterator<Item = &ForestRuleKey> {
+        self.rules.iter().filter(move |forest_key| {
+            self.is_pumping(forest_key.parent)
+                && forest_key.children.iter().all(|c| self.is_pumping(*c))
+        })
+    }
+
+    /// Consumes self and iterates over all the forest rule keys that contain only
+    /// pumping combinatorial classes.
+    pub fn into_pumping_subuniverse(self) -> impl Iterator<Item = ForestRuleKey> {
+        // Collected up front so that the rules can be moved out of self.
+        let stable_subset: HashSet<_> = self.stable_subset().collect();
+        self.rules.into_iter().filter(move |forest_key| {
+            stable_subset.contains(&forest_key.parent)
+                && forest_key
+                    .children
+                    .iter()
+                    .all(|c| stable_subset.contains(c))
+        })
+    }
+
+    /// Consumes self and returns all the forest rule keys it contains.
+    pub fn into_rules(self) -> impl Iterator<Item = ForestRuleKey> {
+        self.rules.into_iter()
+    }
+
+    /// Human readable summary: gap size, size of the stable subset and sizes of
+    /// the pre-images, one per line.
+    pub fn status(&self) -> String {
+        let mut s = String::new();
+        s += &format!("\tSize of the gap: {}\n", self.gap_size);
+        s += &format!("\tSize of the stable subset: {}\n", self.function.get_infinity_count());
+        s += &format!("\tSizes of the pre-images: {:?}\n", self.function.get_preimage_count());
+        s
+    }
+
+    /// The rule added last, if any.
+    pub fn get_last_added_rule_key(&self) -> Option<&ForestRuleKey> {
+        self.rules.last()
+    }
+
+    /// Compute the initial value for the shifts of a rule based on the current
+    /// state of the function.
+    ///
+    /// The shift of a child is its function value plus the shift declared by the
+    /// rule minus the function value of the parent. It is `None` when the child
+    /// has an infinite value, and every shift is `None` when the parent has an
+    /// infinite value.
+    fn compute_shift(&self, rule_key: &ForestRuleKey) -> Vec<Option<i8>> {
+        let parent_current_value = self.function.get_value(rule_key.parent);
+        match parent_current_value {
+            IntOrInf::Infinity => vec![None; rule_key.children.len()],
+            IntOrInf::Int(parent_current_value) => {
+                let children_function_values = rule_key
+                    .children
+                    .iter()
+                    .map(|c| self.function.get_value(*c));
+                children_function_values
+                    .zip(&rule_key.shifts)
+                    .map(|(fvalue, sfz)| match fvalue {
+                        IntOrInf::Infinity => None,
+                        IntOrInf::Int(fvalue) => {
+                            Some(*fvalue as i8 + sfz - *parent_current_value as i8)
+                        }
+                    })
+                    .collect()
+            }
+        }
+    }
+
+    /// Correct the gap and if needed queue rules for the classes that were previously
+    /// on the right hand side of the gap.
+    ///
+    /// This should be toggled every time the gap changes whether the size changes or
+    /// some value changes of the function caused the gap to change.
+    fn correct_gap(&mut self) {
+        let k = self.function.preimage_gap(self.gap_size);
+        let new_gap = (k, k + self.gap_size);
+        if new_gap.1 > self.current_gap.1 {
+            // The gap moved to the right: the rules that were held back get
+            // another chance.
+            self.processing_queue
+                .extend(self.rule_holding_extra_terms.iter());
+            self.rule_holding_extra_terms.clear();
+        }
+        self.current_gap = new_gap;
+    }
+
+    /// Try to make improvement with all the class in the processing queue.
+    ///
+    /// Once the queue is empty, one of the rules held back by the gap is picked
+    /// and its parent is set to infinity, which may queue rules again. This is
+    /// repeated until both the queue and the set of held rules are empty.
+    fn process_queue(&mut self) {
+        while !self.processing_queue.is_empty() || !self.rule_holding_extra_terms.is_empty() {
+            while let Some(rule_idx) = self.processing_queue.pop_front() {
+                let shifts = self.shifts.get(rule_idx).unwrap();
+                if TableMethod::can_give_terms(shifts) {
+                    let parent = self.rules[rule_idx].parent;
+                    self.increase_value(parent, rule_idx);
+                }
+            }
+            if let Some(&rule_idx) = self.rule_holding_extra_terms.iter().next() {
+                self.rule_holding_extra_terms.remove(&rule_idx);
+                let parent = self.rules[rule_idx].parent;
+                self.set_infinite(parent);
+            }
+        }
+    }
+
+    /// Return true if the shifts indicate that a new term can be computed
+    ///
+    /// This is the case when every shift that is not `None` is positive.
+    fn can_give_terms(shifts: &Vec<Option<i8>>) -> bool {
+        shifts.iter().all(|&s| s.map_or(true, |x| x > 0))
+    }
+
+    /// Increase the value of the comb_class and put on the processing stack any rule
+    /// that can now give a new term.
+    ///
+    /// The rule_idx must indicate the rule used to justify the increase.
+    fn increase_value(&mut self, class: u32, rule_idx: usize) {
+        let current_value = self.function.get_value(class);
+        let current_value = match current_value {
+            IntOrInf::Infinity => return,
+            IntOrInf::Int(v) => *v,
+        };
+        if current_value as u32 > self.current_gap.1 {
+            // Past the gap: remember the rule instead of increasing the value.
+            self.rule_holding_extra_terms.insert(rule_idx);
+            return;
+        }
+        self.function.increase_value(class);
+        // Correction of the gap
+        let gap_start = self.function.preimage_gap(self.gap_size);
+        if self.current_gap.0 != gap_start {
+            self.correct_gap()
+        }
+        // Correction of shifts for rule pumping class
+        for &r_idx in self.rule_class_connector.get_rules_pumping_class(class) {
+            let shifts = self.shifts.get_mut(r_idx).unwrap();
+            for v in shifts.iter_mut() {
+                *v = v.map(|v| v - 1);
+            }
+            if TableMethod::can_give_terms(shifts) {
+                self.processing_queue.push_back(r_idx)
+            }
+        }
+        // Correction of the shifts for rule using the class
+        for &(r_idx, class_idx) in self.rule_class_connector.get_rules_using_class(class) {
+            let shifts = self.shifts.get_mut(r_idx).unwrap();
+            let current_shift = shifts.get_mut(class_idx).unwrap();
+            *current_shift = current_shift.map(|v| v + 1);
+            if TableMethod::can_give_terms(shifts) {
+                self.processing_queue.push_back(r_idx);
+            }
+        }
+    }
+
+    /// Set the value of the class to infinity.
+    ///
+    /// This should happen when we know that we cannot pump anything on the left side
+    /// of the gap.
+    ///
+    /// # Panics
+    ///
+    /// Panics if the value of the class is finite and not greater than the end of
+    /// the current gap, or if the processing queue is not empty.
+    fn set_infinite(&mut self, class: u32) {
+        let current_value = self.function.get_value(class);
+        let current_value = match current_value {
+            IntOrInf::Infinity => return,
+            IntOrInf::Int(v) => v,
+        };
+        assert!(*current_value as u32 > self.current_gap.1);
+        assert!(self.processing_queue.is_empty());
+        self.function.set_infinite(class);
+        // This class will never be increased again, so the rules pumping it
+        // could be removed from the rule class connector.
+        // TODO: implement that later since it is only for saving memory.
+        // Correction of the shifts for rules using comb_class to pump
+        for &(rule_idx, class_idx) in self.rule_class_connector.get_rules_using_class(class) {
+            let shifts = self.shifts.get_mut(rule_idx).unwrap();
+            shifts[class_idx] = None;
+            if TableMethod::can_give_terms(shifts) {
+                self.processing_queue.push_back(rule_idx)
+            }
+        }
+        // TODO: same as above for the rules using the class.
+    }
+}
+
+/// Python facade over `TableMethod`, exposed as `TableMethod`.
+///
+/// The inner table method is consumed by `extract_specification`; apart from
+/// `status`, every method panics when called after that.
+#[pyclass(name = "TableMethod")]
+pub struct TableMethodPyWrapper {
+    table_method: Option<TableMethod>,
+}
+
+#[pymethods]
+impl TableMethodPyWrapper {
+    #[new]
+    fn py_new() -> Self {
+        Self {
+            table_method: Some(TableMethod::new()),
+        }
+    }
+
+    /// Consume the inner table method and return the extracted rules.
+    fn extract_specification(&mut self, root_class: u32) -> Vec<ForestRuleKey> {
+        let table_method = mem::replace(&mut self.table_method, None).unwrap();
+        extract_specification(root_class, table_method)
+    }
+
+    fn add_rule_key(&mut self, rule_key: ForestRuleKey) {
+        self.table_method.as_mut().unwrap().add_rule_key(rule_key);
+    }
+
+    /// The non-zero values of the function; `None` stands for infinity.
+    #[getter]
+    fn get_function(&self) -> HashMap<u32, Option<u8>> {
+        let mut map = HashMap::new();
+        for pos in 0..self.table_method.as_ref().unwrap().function.len() {
+            let value = self.table_method.as_ref().unwrap().function.get_value(pos);
+            match value {
+                IntOrInf::Infinity => {
+                    map.insert(pos, None);
+                }
+                IntOrInf::Int(0) => (),
+                IntOrInf::Int(x) => {
+                    map.insert(pos, Some(*x));
+                }
+            }
+        }
+        map
+    }
+
+    fn is_pumping(&self, label: u32) -> bool {
+        self.table_method.as_ref().unwrap().is_pumping(label)
+    }
+
+    fn pumping_subuniverse(&self) -> Vec<ForestRuleKey> {
+        self.table_method
+            .as_ref()
+            .unwrap()
+            .pumping_subuniverse()
+            .map(|rk| rk.clone())
+            .collect()
+    }
+
+    /// Status of the inner table method, or "No status" once it was consumed.
+    fn status(&self) -> String {
+        match self.table_method.as_ref() {
+            Some(tb) => tb.status(),
+            None => String::from("No status"),
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    /// The universe consists of the rules of the usual 132 tree plus a dummy rule
+    /// that is useless.
+    #[test]
+    fn pumping_132_universe_test() {
+        // Each rule is built from (parent, children, shifts, bucket).
+        let rules = vec![
+            ForestRuleKey::new(0, vec![1, 2], vec![0, 0], RuleBucket::Normal),
+            ForestRuleKey::new(1, vec![], vec![], RuleBucket::Verification),
+            ForestRuleKey::new(2, vec![3], vec![0], RuleBucket::Equiv),
+            ForestRuleKey::new(3, vec![4], vec![0], RuleBucket::Equiv),
+            ForestRuleKey::new(4, vec![5, 0, 0], vec![0, 1, 1], RuleBucket::Normal),
+            ForestRuleKey::new(5, vec![], vec![], RuleBucket::Verification),
+            ForestRuleKey::new(2, vec![6], vec![2], RuleBucket::Undefined),
+        ];
+        let mut tb = TableMethod::new();
+        for rule in rules.into_iter() {
+            tb.add_rule_key(rule);
+        }
+        // Classes 0 to 5 must all pump; class 6 has no rule and must not.
+        for i in 0..6 {
+            assert_eq!(*tb.function.get_value(i), IntOrInf::Infinity)
+        }
+        assert!((0..6).all(|c| tb.is_pumping(c)));
+        assert!(!tb.is_pumping(6));
+        // The dummy rule 2 -> 6 is excluded since class 6 is not pumping.
+        let mut pu: Vec<_> = tb
+            .pumping_subuniverse()
+            .map(|forest_key| forest_key.key())
+            .collect();
+        pu.sort();
+        assert_eq!(
+            pu,
+            vec![
+                (&0, &vec![1, 2]),
+                (&1, &vec![]),
+                (&2, &vec![3]),
+                (&3, &vec![4]),
+                (&4, &vec![5, 0, 0]),
+                (&5, &vec![])
+            ]
+        );
+    }
+
+    /// The universe consists of the rules of the usual 132 tree plus a dummy rule
+    /// that is useless.
+
+    /// We add the rules progressively and make sure the function is always up to
+    /// date.
+ #[test] + fn universe132_pumping_progressive_test() { + let mut tb = TableMethod::new(); + + // Point insertion + tb.add_rule_key(ForestRuleKey::new( + 0, + vec![1, 2], + vec![0, 0], + RuleBucket::Normal, + )); + assert_eq!(tb.function.get_value(0), &IntOrInf::Int(0)); + assert_eq!(tb.function.get_value(1), &IntOrInf::Int(0)); + assert_eq!(tb.function.get_value(2), &IntOrInf::Int(0)); + // Empty verif + tb.add_rule_key(ForestRuleKey::new( + 1, + vec![], + vec![], + RuleBucket::Verification, + )); + assert_eq!(tb.function.get_value(0), &IntOrInf::Int(0)); + assert_eq!(tb.function.get_value(1), &IntOrInf::Infinity); + assert_eq!(tb.function.get_value(2), &IntOrInf::Int(0)); + // Point placement + tb.add_rule_key(ForestRuleKey::new(2, vec![3], vec![0], RuleBucket::Equiv)); + assert_eq!(tb.function.get_value(0), &IntOrInf::Int(0)); + assert_eq!(tb.function.get_value(1), &IntOrInf::Infinity); + assert_eq!(tb.function.get_value(2), &IntOrInf::Int(0)); + assert_eq!(tb.function.get_value(3), &IntOrInf::Int(0)); + // Row col sept + tb.add_rule_key(ForestRuleKey::new(3, vec![4], vec![0], RuleBucket::Equiv)); + assert_eq!(tb.function.get_value(0), &IntOrInf::Int(0)); + assert_eq!(tb.function.get_value(1), &IntOrInf::Infinity); + assert_eq!(tb.function.get_value(2), &IntOrInf::Int(0)); + assert_eq!(tb.function.get_value(3), &IntOrInf::Int(0)); + assert_eq!(tb.function.get_value(4), &IntOrInf::Int(0)); + // Point verif + tb.add_rule_key(ForestRuleKey::new( + 5, + vec![], + vec![], + RuleBucket::Verification, + )); + assert_eq!(tb.function.get_value(0), &IntOrInf::Int(0)); + assert_eq!(tb.function.get_value(1), &IntOrInf::Infinity); + assert_eq!(tb.function.get_value(2), &IntOrInf::Int(0)); + assert_eq!(tb.function.get_value(3), &IntOrInf::Int(0)); + assert_eq!(tb.function.get_value(4), &IntOrInf::Int(0)); + assert_eq!(tb.function.get_value(5), &IntOrInf::Infinity); + // Dumb rule + tb.add_rule_key(ForestRuleKey::new( + 2, + vec![6], + vec![-2], + RuleBucket::Undefined, + )); 
+ assert_eq!(tb.function.get_value(0), &IntOrInf::Int(0)); + assert_eq!(tb.function.get_value(1), &IntOrInf::Infinity); + assert_eq!(tb.function.get_value(2), &IntOrInf::Int(0)); + assert_eq!(tb.function.get_value(3), &IntOrInf::Int(0)); + assert_eq!(tb.function.get_value(4), &IntOrInf::Int(0)); + assert_eq!(tb.function.get_value(5), &IntOrInf::Infinity); + assert_eq!(tb.function.get_value(6), &IntOrInf::Int(0)); + // Dumb rule. This will pump 2 and 0 a little bit + tb.add_rule_key(ForestRuleKey::new( + 2, + vec![7], + vec![2], + RuleBucket::Undefined, + )); + assert_eq!(tb.function.get_value(0), &IntOrInf::Int(2)); + assert_eq!(tb.function.get_value(1), &IntOrInf::Infinity); + assert_eq!(tb.function.get_value(2), &IntOrInf::Int(2)); + assert_eq!(tb.function.get_value(3), &IntOrInf::Int(0)); + assert_eq!(tb.function.get_value(4), &IntOrInf::Int(0)); + assert_eq!(tb.function.get_value(5), &IntOrInf::Infinity); + assert_eq!(tb.function.get_value(6), &IntOrInf::Int(0)); + assert_eq!(tb.function.get_value(7), &IntOrInf::Int(0)); + // Factor + tb.add_rule_key(ForestRuleKey::new( + 4, + vec![5, 0, 0], + vec![0, 1, 1], + RuleBucket::Normal, + )); + assert_eq!(tb.function.get_value(0), &IntOrInf::Infinity); + assert_eq!(tb.function.get_value(1), &IntOrInf::Infinity); + assert_eq!(tb.function.get_value(2), &IntOrInf::Infinity); + assert_eq!(tb.function.get_value(3), &IntOrInf::Infinity); + assert_eq!(tb.function.get_value(4), &IntOrInf::Infinity); + assert_eq!(tb.function.get_value(5), &IntOrInf::Infinity); + assert_eq!(tb.function.get_value(6), &IntOrInf::Int(0)); + assert_eq!(tb.function.get_value(7), &IntOrInf::Int(0)); + } + + #[test] + fn universe_not_pumping_test() { + let rules = vec![ + ForestRuleKey::new(0, vec![1, 2], vec![0, 0], RuleBucket::Normal), + ForestRuleKey::new(5, vec![], vec![], RuleBucket::Verification), + ForestRuleKey::new(2, vec![3], vec![0], RuleBucket::Normal), + ForestRuleKey::new(3, vec![4], vec![0], RuleBucket::Normal), + ForestRuleKey::new(4, 
vec![5, 0, 0], vec![0, 1, 1], RuleBucket::Normal), + ]; + let mut tb = TableMethod::new(); + for rule in rules.into_iter() { + tb.add_rule_key(rule); + } + assert_eq!(tb.function.get_value(0), &IntOrInf::Int(0)); + assert_eq!(tb.function.get_value(1), &IntOrInf::Int(0)); + assert_eq!(tb.function.get_value(2), &IntOrInf::Int(1)); + assert_eq!(tb.function.get_value(3), &IntOrInf::Int(1)); + assert_eq!(tb.function.get_value(4), &IntOrInf::Int(1)); + assert_eq!(tb.function.get_value(5), &IntOrInf::Infinity); + } + + #[test] + fn segmented_test() { + let mut tb = TableMethod::new(); + + tb.add_rule_key(ForestRuleKey::new( + 0, + vec![1, 2], + vec![0, 0], + RuleBucket::Undefined, + )); + tb.add_rule_key(ForestRuleKey::new( + 1, + vec![4, 14], + vec![0, 0], + RuleBucket::Undefined, + )); + tb.add_rule_key(ForestRuleKey::new(2, vec![], vec![], RuleBucket::Undefined)); + assert_eq!(tb.function.get_value(2), &IntOrInf::Infinity); + + tb.add_rule_key(ForestRuleKey::new( + 3, + vec![16, 5], + vec![1, 0], + RuleBucket::Undefined, + )); + tb.add_rule_key(ForestRuleKey::new(4, vec![], vec![], RuleBucket::Undefined)); + tb.add_rule_key(ForestRuleKey::new(5, vec![], vec![], RuleBucket::Undefined)); + assert_eq!(tb.function.get_value(2), &IntOrInf::Infinity); + assert_eq!(tb.function.get_value(3), &IntOrInf::Int(1)); + assert_eq!(tb.function.get_value(4), &IntOrInf::Infinity); + assert_eq!(tb.function.get_value(5), &IntOrInf::Infinity); + + // Induced a gap size change + tb.add_rule_key(ForestRuleKey::new( + 6, + vec![7, 5, 17], + vec![2, 1, 1], + RuleBucket::Undefined, + )); + assert_eq!(tb.function.get_value(2), &IntOrInf::Infinity); + assert_eq!(tb.function.get_value(3), &IntOrInf::Int(1)); + assert_eq!(tb.function.get_value(4), &IntOrInf::Infinity); + assert_eq!(tb.function.get_value(5), &IntOrInf::Infinity); + assert_eq!(tb.function.get_value(6), &IntOrInf::Int(1)); + + tb.add_rule_key(ForestRuleKey::new( + 16, + vec![6], + vec![0], + RuleBucket::Undefined, + )); + 
assert_eq!(tb.function.get_value(2), &IntOrInf::Infinity); + assert_eq!(tb.function.get_value(3), &IntOrInf::Int(2)); + assert_eq!(tb.function.get_value(4), &IntOrInf::Infinity); + assert_eq!(tb.function.get_value(5), &IntOrInf::Infinity); + assert_eq!(tb.function.get_value(6), &IntOrInf::Int(1)); + assert_eq!(tb.function.get_value(16), &IntOrInf::Int(1)); + + tb.add_rule_key(ForestRuleKey::new(7, vec![], vec![], RuleBucket::Undefined)); + tb.add_rule_key(ForestRuleKey::new( + 8, + vec![9, 5], + vec![1, 0], + RuleBucket::Undefined, + )); + assert_eq!(tb.function.get_value(2), &IntOrInf::Infinity); + assert_eq!(tb.function.get_value(3), &IntOrInf::Int(2)); + assert_eq!(tb.function.get_value(4), &IntOrInf::Infinity); + assert_eq!(tb.function.get_value(5), &IntOrInf::Infinity); + assert_eq!(tb.function.get_value(6), &IntOrInf::Int(1)); + assert_eq!(tb.function.get_value(7), &IntOrInf::Infinity); + assert_eq!(tb.function.get_value(8), &IntOrInf::Int(1)); + assert_eq!(tb.function.get_value(16), &IntOrInf::Int(1)); + + tb.add_rule_key(ForestRuleKey::new( + 12, + vec![20, 5], + vec![-1, 0], + RuleBucket::Undefined, + )); + tb.add_rule_key(ForestRuleKey::new( + 20, + vec![13], + vec![0], + RuleBucket::Undefined, + )); + tb.add_rule_key(ForestRuleKey::new( + 13, + vec![15, 2, 5], + vec![-1, 1, 0], + RuleBucket::Undefined, + )); + tb.add_rule_key(ForestRuleKey::new( + 15, + vec![1], + vec![0], + RuleBucket::Undefined, + )); + tb.add_rule_key(ForestRuleKey::new( + 14, + vec![3], + vec![0], + RuleBucket::Undefined, + )); + assert_eq!(tb.function.get_value(0), &IntOrInf::Int(2)); + assert_eq!(tb.function.get_value(1), &IntOrInf::Int(2)); + assert_eq!(tb.function.get_value(2), &IntOrInf::Infinity); + assert_eq!(tb.function.get_value(3), &IntOrInf::Int(2)); + assert_eq!(tb.function.get_value(4), &IntOrInf::Infinity); + assert_eq!(tb.function.get_value(5), &IntOrInf::Infinity); + assert_eq!(tb.function.get_value(6), &IntOrInf::Int(1)); + assert_eq!(tb.function.get_value(7), 
&IntOrInf::Infinity); + assert_eq!(tb.function.get_value(8), &IntOrInf::Int(1)); + assert_eq!(tb.function.get_value(13), &IntOrInf::Int(1)); + assert_eq!(tb.function.get_value(14), &IntOrInf::Int(2)); + assert_eq!(tb.function.get_value(15), &IntOrInf::Int(2)); + assert_eq!(tb.function.get_value(16), &IntOrInf::Int(1)); + assert_eq!(tb.function.get_value(20), &IntOrInf::Int(1)); + + tb.add_rule_key(ForestRuleKey::new( + 18, + vec![8], + vec![0], + RuleBucket::Undefined, + )); + tb.add_rule_key(ForestRuleKey::new( + 11, + vec![12, 18], + vec![0, 0], + RuleBucket::Undefined, + )); + assert_eq!(tb.function.get_value(0), &IntOrInf::Int(2)); + assert_eq!(tb.function.get_value(1), &IntOrInf::Int(2)); + assert_eq!(tb.function.get_value(2), &IntOrInf::Infinity); + assert_eq!(tb.function.get_value(3), &IntOrInf::Int(2)); + assert_eq!(tb.function.get_value(4), &IntOrInf::Infinity); + assert_eq!(tb.function.get_value(5), &IntOrInf::Infinity); + assert_eq!(tb.function.get_value(6), &IntOrInf::Int(1)); + assert_eq!(tb.function.get_value(7), &IntOrInf::Infinity); + assert_eq!(tb.function.get_value(8), &IntOrInf::Int(1)); + assert_eq!(tb.function.get_value(13), &IntOrInf::Int(1)); + assert_eq!(tb.function.get_value(14), &IntOrInf::Int(2)); + assert_eq!(tb.function.get_value(15), &IntOrInf::Int(2)); + assert_eq!(tb.function.get_value(16), &IntOrInf::Int(1)); + assert_eq!(tb.function.get_value(18), &IntOrInf::Int(1)); + assert_eq!(tb.function.get_value(20), &IntOrInf::Int(1)); + + tb.add_rule_key(ForestRuleKey::new( + 17, + vec![8], + vec![0], + RuleBucket::Undefined, + )); + assert_eq!(tb.function.get_value(0), &IntOrInf::Int(3)); + assert_eq!(tb.function.get_value(1), &IntOrInf::Int(3)); + assert_eq!(tb.function.get_value(2), &IntOrInf::Infinity); + assert_eq!(tb.function.get_value(3), &IntOrInf::Int(3)); + assert_eq!(tb.function.get_value(4), &IntOrInf::Infinity); + assert_eq!(tb.function.get_value(5), &IntOrInf::Infinity); + assert_eq!(tb.function.get_value(6), 
&IntOrInf::Int(2)); + assert_eq!(tb.function.get_value(7), &IntOrInf::Infinity); + assert_eq!(tb.function.get_value(8), &IntOrInf::Int(1)); + assert_eq!(tb.function.get_value(11), &IntOrInf::Int(1)); + assert_eq!(tb.function.get_value(12), &IntOrInf::Int(1)); + assert_eq!(tb.function.get_value(13), &IntOrInf::Int(2)); + assert_eq!(tb.function.get_value(14), &IntOrInf::Int(3)); + assert_eq!(tb.function.get_value(15), &IntOrInf::Int(3)); + assert_eq!(tb.function.get_value(16), &IntOrInf::Int(2)); + assert_eq!(tb.function.get_value(17), &IntOrInf::Int(1)); + assert_eq!(tb.function.get_value(18), &IntOrInf::Int(1)); + assert_eq!(tb.function.get_value(20), &IntOrInf::Int(2)); + + tb.add_rule_key(ForestRuleKey::new( + 9, + vec![0, 19], + vec![0, 0], + RuleBucket::Undefined, + )); + tb.add_rule_key(ForestRuleKey::new( + 10, + vec![5, 11], + vec![0, 1], + RuleBucket::Undefined, + )); + assert_eq!(tb.function.get_value(0), &IntOrInf::Int(3)); + assert_eq!(tb.function.get_value(1), &IntOrInf::Int(3)); + assert_eq!(tb.function.get_value(2), &IntOrInf::Infinity); + assert_eq!(tb.function.get_value(3), &IntOrInf::Int(3)); + assert_eq!(tb.function.get_value(4), &IntOrInf::Infinity); + assert_eq!(tb.function.get_value(5), &IntOrInf::Infinity); + assert_eq!(tb.function.get_value(6), &IntOrInf::Int(2)); + assert_eq!(tb.function.get_value(7), &IntOrInf::Infinity); + assert_eq!(tb.function.get_value(8), &IntOrInf::Int(1)); + assert_eq!(tb.function.get_value(10), &IntOrInf::Int(2)); + assert_eq!(tb.function.get_value(11), &IntOrInf::Int(1)); + assert_eq!(tb.function.get_value(12), &IntOrInf::Int(1)); + assert_eq!(tb.function.get_value(13), &IntOrInf::Int(2)); + assert_eq!(tb.function.get_value(14), &IntOrInf::Int(3)); + assert_eq!(tb.function.get_value(15), &IntOrInf::Int(3)); + assert_eq!(tb.function.get_value(16), &IntOrInf::Int(2)); + assert_eq!(tb.function.get_value(17), &IntOrInf::Int(1)); + assert_eq!(tb.function.get_value(18), &IntOrInf::Int(1)); + 
assert_eq!(tb.function.get_value(20), &IntOrInf::Int(2)); + + tb.add_rule_key(ForestRuleKey::new( + 19, + vec![10], + vec![0], + RuleBucket::Undefined, + )); + assert!((0..21).all(|c| tb.function.get_value(c) == &IntOrInf::Infinity)); + assert!((0..21).all(|c| tb.is_pumping(c))); + } +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 000000000..203c0d21a --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,21 @@ +use pyo3::prelude::*; + +/// Formats the sum of two numbers as string. +#[pyfunction] +fn sum_as_string(a: usize, b: usize) -> PyResult<String> { + Ok((a + b).to_string()) +} + +mod forest; + +/// A Python module implemented in Rust. The name of this function must match +/// the `lib.name` setting in the `Cargo.toml`, else Python will not be able to +/// import the module. +#[pymodule] +fn comb_spec_searcher_rs(_py: Python<'_>, m: &PyModule) -> PyResult<()> { + m.add_function(wrap_pyfunction!(sum_as_string, m)?)?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + Ok(()) +} diff --git a/tests/test_forest.py b/tests/test_forest.py index b4a3268d5..0bd537964 100644 --- a/tests/test_forest.py +++ b/tests/test_forest.py @@ -1,119 +1,19 @@ -import itertools -from typing import Dict, List, Union - -import pytest +from comb_spec_searcher_rs import ForestRuleKey, RuleBucket from comb_spec_searcher import CombinatorialSpecificationSearcher -from comb_spec_searcher.rule_db.forest import Function, RuleDBForest, TableMethod +from comb_spec_searcher.rule_db.forest import RuleDBForest, TableMethod from comb_spec_searcher.strategies.strategy import EmptyStrategy -from comb_spec_searcher.typing import ForestRuleKey, RuleBucket from example import AvoidingWithPrefix, pack -def assert_function_values( - function: Function, values: Union[Dict[int, int], List[int]] -): - """ - Check that the function values matches the given values. 
- """ - if isinstance(values, dict): - values = [values[i] if i in values else 0 for i in range(max(values) + 1)] - failing_index = [] - for n, (fv, gv) in enumerate( - itertools.zip_longest(function._value, values, fillvalue=0) - ): - if not (fv == gv): - failing_index.append(n) - - if failing_index: - m = f"Function differs on index {failing_index}\n" - m += f"The function values are {function._value}\n" - m += f"The expected values are {values}" - raise AssertionError(m) - - -class TestFunction: - def test_add_value(self): - f = Function() - assert f[0] == 0 - assert f[4] == 0 - f.increase_value(0) - assert f[0] == 1 - f.increase_value(3) - assert f[4] == 0 - f.increase_value(4) - assert f[0] == 1 - assert f[1] == 0 - assert f[2] == 0 - assert f[3] == 1 - assert f[4] == 1 - assert f[5] == 0 - assert f[6] == 0 - with pytest.raises(ValueError): - assert sorted(f.preimage(0)) == [] - assert sorted(f.preimage(1)) == [0, 3, 4] - assert sorted(f.preimage(2)) == [] - - def test_infinity(self): - f = Function() - f.increase_value(0) - f.increase_value(3) - f.increase_value(4) - f.set_infinite(3) - assert f[0] == 1 - assert f[1] == 0 - assert f[2] == 0 - assert f[3] is None - assert f[4] == 1 - assert f.preimage_gap(100) == 2 - assert sorted(f.preimage(None)) == [3] - assert sorted(f.preimage(1)) == [0, 4] - - def test_find_gap(self): - f = Function() - f.increase_value(0) - f.increase_value(0) - f.increase_value(0) - f.increase_value(0) - f.increase_value(1) - f.increase_value(2) - assert f[0] == 4 - assert f[1] == 1 - assert f[2] == 1 - assert f[3] == 0 - assert f[4] == 0 - assert f[5] == 0 - assert f[6] == 0 - assert f.preimage_gap(1) == 2 - assert f.preimage_gap(2) == 2 - assert f.preimage_gap(3) == 5 - with pytest.raises(ValueError): - f.preimage_gap(0) - with pytest.raises(ValueError): - f.preimage_gap(-1) - - def test_find_gap2(self): - f = Function() - f.increase_value(2) - f.increase_value(3) - f.increase_value(4) - f.increase_value(5) - f.increase_value(5) - 
f.increase_value(5) - print(f) - print(f._preimage_count) - assert f[0] == 0 - assert f[1] == 0 - assert f[2] == 1 - assert f[3] == 1 - assert f[4] == 1 - assert f[5] == 3 - assert f.preimage_gap(1) == 2 - assert f.preimage_gap(2) == 4 - assert f.preimage_gap(3) == 4 - - # Test of the table method +def test_rule_key_eq(): + key1 = ForestRuleKey(0, (1, 2), (0, 0), RuleBucket.NORMAL) + key2 = ForestRuleKey(0, (1, 2), (0, 0), RuleBucket.NORMAL) + key3 = ForestRuleKey(1, (), (), RuleBucket.VERIFICATION) + assert key1 == key2 + assert key1 != key3 + assert key2 != key3 def test_132_universe_pumping(): @@ -224,16 +124,7 @@ def test_segmented(): tb.add_rule_key(ForestRuleKey(7, tuple(), tuple(), RuleBucket.UNDEFINED)) tb.add_rule_key(ForestRuleKey(8, (9, 5), (1, 0), RuleBucket.UNDEFINED)) - assert tb.function == { - 2: None, - 3: 2, - 4: None, - 5: None, - 6: 1, - 7: None, - 8: 1, - 16: 1, - } + assert tb.function == {2: None, 3: 2, 4: None, 5: None, 6: 1, 7: None, 8: 1, 16: 1} tb.add_rule_key(ForestRuleKey(12, (20, 5), (-1, 0), RuleBucket.UNDEFINED)) tb.add_rule_key(ForestRuleKey(20, (13,), (0,), RuleBucket.UNDEFINED)) diff --git a/tests/test_rust.py b/tests/test_rust.py new file mode 100644 index 000000000..3680a7be6 --- /dev/null +++ b/tests/test_rust.py @@ -0,0 +1,5 @@ +from comb_spec_searcher_rs import sum_as_string + + +def test_rust_module(): + assert sum_as_string(2, 5) == "7" diff --git a/tox.ini b/tox.ini index df9d27805..86fe00e33 100644 --- a/tox.ini +++ b/tox.ini @@ -24,6 +24,7 @@ basepython = deps = pytest==7.2.0 pytest-repeat==0.9.1 + pytest-timeout==2.1.0 docutils==0.19 Pygments==2.13.0 commands = pytest @@ -40,7 +41,7 @@ allowlist_externals=sh commands_pre = sh setup_tilescope_test.sh commands = - pytest .tilings/tests + python -m pytest .tilings/tests [testenv:flake8] description = run flake8 (linter) @@ -50,7 +51,7 @@ deps = flake8==5.0.4 flake8-isort==5.0.0 commands = - flake8 --isort-show-traceback comb_spec_searcher tests setup.py example.py + flake8 
--isort-show-traceback comb_spec_searcher tests example.py [testenv:pylint] description = run pylint (static code analysis)