From 8e32b2b49f6681e6324bb451c2d03e742d018c0d Mon Sep 17 00:00:00 2001 From: Ankur Dedania Date: Wed, 18 Jan 2023 18:39:52 -0600 Subject: [PATCH 1/2] add support for union with metadata --- marshmallow_dataclass/__init__.py | 7 +++- tests/test_field_for_schema.py | 53 +++++++++++++++++++++++++++++++ tests/test_union.py | 21 ++++++++++++ 3 files changed, 80 insertions(+), 1 deletion(-) diff --git a/marshmallow_dataclass/__init__.py b/marshmallow_dataclass/__init__.py index 106cde2..1e01f6d 100644 --- a/marshmallow_dataclass/__init__.py +++ b/marshmallow_dataclass/__init__.py @@ -578,13 +578,18 @@ def _field_for_generic_type( ) from . import union_field + union_metadata = { + k: v + for k, v in metadata.items() + if k not in ("allow_none", "dump_default", "load_default", "required") + } return union_field.Union( [ ( subtyp, field_for_schema( subtyp, - metadata={"required": True}, + metadata=union_metadata, base_schema=base_schema, typ_frame=typ_frame, ), diff --git a/tests/test_field_for_schema.py b/tests/test_field_for_schema.py index 0e60f0b..0e6673c 100644 --- a/tests/test_field_for_schema.py +++ b/tests/test_field_for_schema.py @@ -11,6 +11,7 @@ from typing_extensions import Final, Literal # type: ignore[assignment] from marshmallow import fields, Schema, validate +from marshmallow.warnings import RemovedInMarshmallow4Warning from marshmallow_dataclass import ( field_for_schema, @@ -132,6 +133,58 @@ class Color(Enum): marshmallow_enum.EnumField(enum=Color, required=True), ) + def test_union_enum(self): + self.maxDiff = None + + class Fruit(Enum): + apple = "Apple" + banana = "Banana" + tomato = "Tomato" + + with self.assertWarns(RemovedInMarshmallow4Warning): + if hasattr(fields, "Enum"): + self.assertFieldsEqual( + field_for_schema(Union[Fruit, str], metadata={"by_value": True}), + union_field.Union( + [ + ( + Fruit, + fields.Enum(enum=Fruit, required=True, by_value=True), + ), + ( + str, + fields.String( + required=True, metadata={"by_value": True} + ), + ), + ], + required=True, + metadata={"by_value": True}, + ), + ) + else: + import marshmallow_enum + + self.assertFieldsEqual( + field_for_schema(Union[Fruit, str], metadata={"by_value": True}), + marshmallow_enum.EnumField( + [ + ( + Fruit, + fields.Enum(enum=Fruit, required=True, by_value=True), + ), + ( + str, + fields.String( + required=True, metadata={"by_value": True} + ), + ), + ], + required=True, + metadata={"by_value": True}, + ), + ) + def test_literal(self): self.assertFieldsEqual( field_for_schema(Literal["a"]), diff --git a/tests/test_union.py b/tests/test_union.py index 5f0eb58..aca4db7 100644 --- a/tests/test_union.py +++ b/tests/test_union.py @@ -1,8 +1,10 @@ from dataclasses import field +from enum import Enum import sys import unittest from typing import List, Optional, Union, Dict +from marshmallow.warnings import RemovedInMarshmallow4Warning import marshmallow from marshmallow_dataclass import dataclass @@ -196,3 +198,22 @@ class PEP604IntOrStr: data_in = {"value": 42} self.assertEqual(schema.dump(schema.load(data_in)), data_in) + + def test_union_enum(self): + class Fruit(Enum): + apple = "Apple" + banana = "Banana" + tomato = "Tomato" + + @dataclass + class Dclass: + value: Union[Fruit, dict] = field(metadata={"by_value": True}) + + with self.assertWarns(RemovedInMarshmallow4Warning): + schema = Dclass.Schema() + + data_in = {"value": "Apple"} + self.assertEqual(schema.dump(schema.load(data_in)), data_in) + + data_in = {"value": {"fruit": "Orange"}} + self.assertEqual(schema.dump(schema.load(data_in)), data_in) From 22427e1a80334916898f64b2d0b69ef2d9c34a45 Mon Sep 17 00:00:00 2001 From: Ankur Dedania Date: Wed, 18 Jan 2023 19:29:20 -0600 Subject: [PATCH 2/2] post test cleanup --- tests/test_field_for_schema.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/test_field_for_schema.py b/tests/test_field_for_schema.py index 0e6673c..2d6d7af 100644 --- a/tests/test_field_for_schema.py +++ b/tests/test_field_for_schema.py @@ -134,8 +134,6 @@ class Color(Enum): ) def test_union_enum(self): - self.maxDiff = None - class Fruit(Enum): apple = "Apple" banana = "Banana"