diff --git a/README.md b/README.md index 2079155..4482db3 100644 --- a/README.md +++ b/README.md @@ -304,6 +304,36 @@ class Point: ordered = True ``` +### Partial Loading + +The `.load` method of a `marshmallow_dataclass` schema can not be used to load partial data. +This is because `load` always wraps the deserialized data in the target dataclass type, +and it is not possible to partially construct a dataclass. + +However `marshmallow_dataclass` schema now provide a `load_to_dict` method that can be +used to deserialize data with `partial=True`. `Load_to_dict` works just like +the plain `marshmallow.Schema.load` method. It returns the deserialized data as a `dict` +(or a list of `dict`s if `many=True`) — no construction of dataclasses is done. + +```pycon +from marshmallow_dataclass import dataclass + + +@dataclass +class Person: + first_name: str + last_name: str + + +>>> Person.Schema().load_to_dict({"first_name": "Joe"}, partial=True) +# => {"first_name": "Joe"} + +>>> Person.Schema().load({"first_name": "Joe"}, partial=True) +# => Traceback (most recent call last): +# ... +# TypeError: Person.__init__() missing 1 required positional argument: 'last_name' +``` + ## Documentation The project documentation is hosted on GitHub Pages: https://lovasoa.github.io/marshmallow_dataclass/ diff --git a/marshmallow_dataclass/__init__.py b/marshmallow_dataclass/__init__.py index c7b1e0a..b7c57ae 100644 --- a/marshmallow_dataclass/__init__.py +++ b/marshmallow_dataclass/__init__.py @@ -40,6 +40,7 @@ class User: import threading import types import warnings +from contextlib import contextmanager from enum import EnumMeta from functools import lru_cache, partial from typing import ( @@ -733,6 +734,28 @@ def field_for_schema( return marshmallow.fields.Nested(nested, **metadata) +class _ThreadLocalBool(threading.local): + """A thread-local boolean flag.""" + + def __init__(self, value: bool = False): + self.value = value + + def __bool__(self): + return self.value + + @contextmanager + def temporarily(self, value: bool): + orig = self.value + self.value = value + try: + yield self + finally: + self.value = orig + + +_disable_magic = _ThreadLocalBool(False) + + def _base_schema( clazz: type, base_schema: Optional[Type[marshmallow.Schema]] = None ) -> Type[marshmallow.Schema]: @@ -746,12 +769,20 @@ def _base_schema( class BaseSchema(base_schema or marshmallow.Schema): # type: ignore def load(self, data: Mapping, *, many: bool = None, **kwargs): all_loaded = super().load(data, many=many, **kwargs) + if _disable_magic: + return all_loaded many = self.many if many is None else bool(many) if many: return [clazz(**loaded) for loaded in all_loaded] else: return clazz(**all_loaded) + def load_to_dict( + self, data: Mapping, **kwargs + ) -> Union[Dict[str, Any], List[Dict[str, Any]]]: + with _disable_magic.temporarily(True): + return self.load(data, **kwargs) + return BaseSchema diff --git a/tests/test_load_to_dict.py b/tests/test_load_to_dict.py new file mode 100644 index 0000000..817ffd4 --- /dev/null +++ b/tests/test_load_to_dict.py @@ -0,0 +1,47 @@ +import dataclasses +import unittest + +from marshmallow_dataclass import class_schema + + +class Test_Schema_load_to_dict(unittest.TestCase): + def test_simple(self): + @dataclasses.dataclass + class Simple: + one: int = dataclasses.field() + two: str = dataclasses.field() + + simple_schema = class_schema(Simple)() + assert simple_schema.load_to_dict({"one": "1", "two": "b"}) == { + "one": 1, + "two": "b", + } + + def test_partial(self): + @dataclasses.dataclass + class Simple: + one: int = dataclasses.field() + two: str = dataclasses.field() + + simple_schema = class_schema(Simple)() + assert simple_schema.load_to_dict({"one": "1"}, partial=True) == {"one": 1} + + def test_nested(self): + @dataclasses.dataclass + class Simple: + one: int = dataclasses.field() + two: str = dataclasses.field() + + @dataclasses.dataclass + class Nested: + x: str = dataclasses.field() + child: Simple = dataclasses.field() + + nested_schema = class_schema(Nested)() + assert nested_schema.load_to_dict({"child": {"one": "1"}}, partial=True) == { + "child": {"one": 1}, + } + + +if __name__ == "__main__": + unittest.main()