Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 51 additions & 42 deletions manim/utils/hashing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
import inspect
import json
import zlib
from collections.abc import Callable, Hashable, Iterable
from collections.abc import Callable, Hashable, Iterable, Sequence
from time import perf_counter
from types import FunctionType, MappingProxyType, MethodType, ModuleType
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, overload

import numpy as np

Expand Down Expand Up @@ -47,44 +47,46 @@
content-equality detection.
"""

# Integer membership signatures (for example ``id(obj)``) of the objects that
# have already been seen; emptied by ``reset_already_processed``.
_already_processed: set[int] = set()

# Can be changed to any string to help debugging the JSON generation.
ALREADY_PROCESSED_PLACEHOLDER = "AP"
THRESHOLD_WARNING = 170_000

@classmethod
def reset_already_processed(cls: type[_Memoizer]) -> None:
    """Forget every object that has been marked as processed so far.

    The shared ``_already_processed`` set is cleared in place, so any other
    reference to that same set sees the change as well.
    """
    cls._already_processed.clear()

@classmethod
def check_already_processed_decorator(
    cls: type[_Memoizer], is_method: bool = False
) -> Callable:
    """Decorator to handle the arguments that go through the decorated function.

    Returns the value of ALREADY_PROCESSED_PLACEHOLDER if the obj has been processed,
    or lets the decorated function call go ahead.

    Parameters
    ----------
    is_method
        Whether the function passed is a method, by default False.

    Returns
    -------
    Callable
        A decorator. Applied to ``func``, it returns a wrapper taking ``obj``
        (or ``self, obj`` when ``is_method`` is True) and delegating to
        ``cls._handle_already_processed``.
    """

    def layer(func: Callable[..., Any]) -> Callable:
        # NOTE : There is probably a better way to separate both cases, when func
        # is a method or a function.
        if is_method:
            # ``func`` is an unbound method here: the instance must be forwarded
            # explicitly, otherwise ``obj`` would be bound to ``self`` and the
            # call would fail with a missing positional argument.
            return lambda self, obj: cls._handle_already_processed(
                obj,
                default_function=lambda obj: func(self, obj),
            )
        return lambda obj: cls._handle_already_processed(obj, default_function=func)

    return layer

@classmethod
def check_already_processed(cls, obj: Any) -> Any:
def check_already_processed(cls: type[_Memoizer], obj: Any) -> Any:
"""Checks if obj has been already processed. Returns itself if it has not been,
or the value of _ALREADY_PROCESSED_PLACEHOLDER if it has.
or the value of ALREADY_PROCESSED_PLACEHOLDER if it has.
Marks the object as processed in the second case.

Parameters
Expand All @@ -101,7 +103,7 @@
return cls._handle_already_processed(obj, lambda x: x)

@classmethod
def mark_as_processed(cls, obj: Any) -> None:
def mark_as_processed(cls: type[_Memoizer], obj: Any) -> None:
"""Marks an object as processed.

Parameters
Expand All @@ -110,14 +112,14 @@
The object to mark as processed.
"""
cls._handle_already_processed(obj, lambda x: x)
return cls._return(obj, id, lambda x: x, memoizing=False)
cls._return(obj, id, lambda x: x, memoizing=False)

@classmethod
def _handle_already_processed(
cls,
obj,
cls: type[_Memoizer],
obj: Any,
default_function: Callable[[Any], Any],
):
) -> str | Any:
if isinstance(
obj,
(
Expand All @@ -142,11 +144,11 @@

@classmethod
def _return(
cls,
cls: type[_Memoizer],
obj: Any,
obj_to_membership_sign: Callable[[Any], int],
default_func,
memoizing=True,
default_func: Callable[[Any], Any],
memoizing: bool = True,
) -> str | Any:
obj_membership_sign = obj_to_membership_sign(obj)
if obj_membership_sign in cls._already_processed:
Expand All @@ -172,9 +174,8 @@


class _CustomEncoder(json.JSONEncoder):
def default(self, obj: Any):
"""
This method is used to serialize objects to JSON format.
def default(self, obj: Any) -> Any:
"""This method is used to serialize objects to JSON format.

If obj is a function, then it will return a dict with two keys : 'code', for
the code source, and 'nonlocals' for all nonlocals values. (including nonlocals
Expand Down Expand Up @@ -219,7 +220,7 @@
if obj.size > 1000:
obj = np.resize(obj, (100, 100))
return f"TRUNCATED ARRAY: {repr(obj)}"
# We return the repr and not a list to avoid the JsonEncoder to iterate over it.
# We return the repr and not a list to avoid the JSONEncoder to iterate over it.
return repr(obj)
elif hasattr(obj, "__dict__"):
temp = obj.__dict__
Expand All @@ -233,60 +234,68 @@
# Serialize it with only the type of the object. You can change this to whatever string when debugging the serialization process.
return str(type(obj))

@overload
def _cleaned_iterable(self, iterable: dict[Any, Any]) -> dict[Any, Any]: ...

@overload
def _cleaned_iterable(self, iterable: Sequence) -> list[Any]: ...

def _cleaned_iterable(self, iterable: Iterable) -> dict[Any, Any] | list[Any]:
    """Check for circular reference at each iterable that will go through the JSONEncoder, as well as key of the wrong format.

    If a key with a bad format is found (i.e not a int, string, or float), it gets replaced by its hash using the same process implemented here.
    If a circular reference is found within the iterable, it will be replaced by the value of ALREADY_PROCESSED_PLACEHOLDER.

    Parameters
    ----------
    iterable
        The iterable to check.

    Returns
    -------
    dict[Any, Any] | list[Any]
        A cleaned copy: a dict when ``iterable`` is a dict, a list for any
        other iterable.

    Raises
    ------
    TypeError
        If ``iterable`` is a ``str``/``bytes`` or is not iterable at all.
    """

    def _key_to_hash(key: Any) -> int:
        # CRC32 of the key's own JSON serialization.
        return zlib.crc32(json.dumps(key, cls=_CustomEncoder).encode())

    def _is_nested_iterable(value: Any) -> bool:
        # Strings and bytes are iterable but must be kept as plain values.
        # NOTE(review): every other object implementing ``__iter__``
        # (generators, sets, custom classes, ...) is walked as a list;
        # confirm this is intended for all such types reaching the encoder.
        return isinstance(value, Iterable) and not isinstance(value, (str, bytes))

    def _clean_checked(value: Any) -> Any:
        # ``value`` has already gone through ``check_already_processed``.
        if isinstance(value, dict):
            return _iter_check_dict(value)
        if _is_nested_iterable(value):
            return _iter_check_list(value)
        return value

    def _iter_check_list(lst: Iterable[Any]) -> list[Any]:
        return [
            _clean_checked(_Memoizer.check_already_processed(el)) for el in lst
        ]

    def _iter_check_dict(dct: dict[Any, Any]) -> dict[Any, Any]:
        processed_dict = {}
        for k, v in dct.items():
            # The value is checked before the key filter, as in the original.
            v = _Memoizer.check_already_processed(v)
            if k in KEYS_TO_FILTER_OUT:
                continue
            # We check if the k is of the right format (supported by JSON)
            if not isinstance(k, (str, int, float, bool)) and k is not None:
                k_new = _key_to_hash(k)
            else:
                k_new = k
            processed_dict[k_new] = _clean_checked(v)
        return processed_dict

    if isinstance(iterable, dict):
        return _iter_check_dict(iterable)
    if _is_nested_iterable(iterable):
        return _iter_check_list(iterable)
    raise TypeError("'iterable' is neither an iterable nor a dictionary.")

def encode(self, obj: Any):
def encode(self, obj: Any) -> str:
"""Overriding of :meth:`JSONEncoder.encode`, to make our own process.

Parameters
Expand All @@ -305,7 +314,7 @@
return super().encode(obj)


def get_json(obj: dict):
def get_json(obj: Any) -> str:
"""Recursively serialize `object` to JSON using the :class:`CustomEncoder` class.

Parameters
Expand Down
Loading