Skip to content

Commit 67a800b

Browse files
Add SerializedRelatedField for write-by-slug, read-as-nested representation
- Introduce SerializedRelatedField in rest_framework/relations.py - Supports writing by slug/PK and reading as nested serializer - Add corresponding tests in tests/test_relations.py
1 parent 2001878 commit 67a800b

File tree

3 files changed

+142
-1
lines changed

3 files changed

+142
-1
lines changed

rest_framework/relations.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from urllib import parse
55

66
from django.core.exceptions import ImproperlyConfigured, ObjectDoesNotExist
7+
from django.db import models
78
from django.db.models import Manager
89
from django.db.models.query import QuerySet
910
from django.urls import NoReverseMatch, Resolver404, get_script_prefix, resolve
@@ -583,3 +584,40 @@ def iter_options(self):
583584
cutoff=self.html_cutoff,
584585
cutoff_text=self.html_cutoff_text
585586
)
587+
588+
589+
class SerializedRelatedField(SlugRelatedField):
590+
"""
591+
A relational field that accepts a simple slug for writes
592+
(like SlugRelatedField), but expands to a nested serializer
593+
for reads if `serializer_class` is provided.
594+
595+
Example:
596+
class OrderSerializer(serializers.ModelSerializer):
597+
address = SerializedRelatedField(
598+
serializer_class=AddressSerializer,
599+
queryset=Address.objects.all(),
600+
lookup_field="pk",
601+
)
602+
"""
603+
604+
def __init__(self, serializer_class=None, lookup_field="pk", **kwargs):
605+
self.serializer_class = serializer_class
606+
kwargs["slug_field"] = lookup_field
607+
super().__init__(**kwargs)
608+
609+
if self.serializer_class is not None and self.queryset is None:
610+
raise AssertionError(
611+
"SerializedRelatedField with serializer_class requires a queryset"
612+
)
613+
614+
def to_representation(self, value):
615+
# Ensure PKOnlyObject (used in select_related/prefetch) is resolved
616+
if hasattr(value, "pk") and not isinstance(value, models.Model):
617+
value = self.get_queryset().get(pk=value.pk)
618+
619+
if self.serializer_class is not None:
620+
serializer = self.serializer_class(value, context=self.context)
621+
return serializer.data
622+
623+
return super().to_representation(value)

rest_framework/serializers.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262
from rest_framework.relations import ( # NOQA # isort:skip
6363
HyperlinkedIdentityField, HyperlinkedRelatedField, ManyRelatedField,
6464
PrimaryKeyRelatedField, RelatedField, SlugRelatedField, StringRelatedField,
65+
SerializedRelatedField
6566
)
6667

6768
# Non-field imports, but public API

tests/test_relations.py

Lines changed: 103 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,16 @@
33
import pytest
44
from _pytest.monkeypatch import MonkeyPatch
55
from django.core.exceptions import ImproperlyConfigured, ObjectDoesNotExist
6-
from django.test import override_settings
6+
from django.db import models
7+
from django.test import TestCase, override_settings
78
from django.urls import re_path
89
from django.utils.datastructures import MultiValueDict
10+
from django.utils.translation import gettext_lazy as _
911

1012
from rest_framework import relations, serializers
1113
from rest_framework.fields import empty
1214
from rest_framework.test import APISimpleTestCase
15+
from tests.models import RESTFrameworkModel
1316

1417
from .utils import (
1518
BadType, MockObject, MockQueryset, fail_reverse, mock_reverse
@@ -518,3 +521,102 @@ def test_can_be_pickled(self):
518521
upkled = pickle.loads(pickle.dumps(self.default_hyperlink))
519522
assert upkled == self.default_hyperlink
520523
assert upkled.name == self.default_hyperlink.name
524+
525+
526+
class Address(RESTFrameworkModel):
527+
postal_code = models.CharField(
528+
max_length=20, unique=True, verbose_name=_("Postal Code")
529+
)
530+
province = models.CharField(max_length=100, verbose_name=_("Province"))
531+
city = models.CharField(max_length=100, verbose_name=_("City"))
532+
street = models.CharField(
533+
max_length=255, blank=True, null=True, verbose_name=_("Street")
534+
)
535+
additional_info = models.TextField(
536+
verbose_name=_("Additional Info"), blank=True, null=True
537+
)
538+
539+
540+
class AddressSerializer(serializers.ModelSerializer):
541+
class Meta:
542+
model = Address
543+
fields = "__all__"
544+
545+
546+
class SerializedRelatedFieldTests(TestCase):
547+
548+
class OrderSerializerPostalCode(serializers.Serializer):
549+
address = relations.SerializedRelatedField(
550+
serializer_class=AddressSerializer,
551+
queryset=Address.objects.all(),
552+
lookup_field='postal_code',
553+
)
554+
555+
class OrderSerializerCity(serializers.Serializer):
556+
address = relations.SerializedRelatedField(
557+
serializer_class=AddressSerializer,
558+
queryset=Address.objects.all(),
559+
lookup_field='city',
560+
)
561+
562+
class OrderSerializerPK(serializers.Serializer):
563+
address = relations.SerializedRelatedField(
564+
serializer_class=AddressSerializer,
565+
queryset=Address.objects.all(),
566+
)
567+
568+
def setUp(self):
569+
self.address = Address.objects.create(
570+
postal_code="12345",
571+
province="Tehran",
572+
city="Tehran",
573+
street="Valiasr",
574+
additional_info="Test info"
575+
)
576+
Address.objects.create(
577+
postal_code="123456",
578+
province="Tehran",
579+
city="Tehran",
580+
street="Valiasr",
581+
additional_info="Test info"
582+
)
583+
584+
def test_write_slug(self):
585+
data = {"address": self.address.postal_code}
586+
serializer = self.OrderSerializerPostalCode(data=data)
587+
assert serializer.is_valid(), serializer.errors
588+
assert serializer.validated_data["address"] == self.address
589+
590+
def test_read_nested(self):
591+
data = {"address": self.address.postal_code}
592+
serializer = self.OrderSerializerPostalCode(data=data)
593+
assert serializer.is_valid(), serializer.errors
594+
expected = AddressSerializer(self.address).data
595+
assert serializer.data["address"] == expected
596+
597+
def test_write_default(self):
598+
data = {"address": self.address.pk}
599+
serializer = self.OrderSerializerPK(data=data)
600+
assert serializer.is_valid(), serializer.errors
601+
expected = AddressSerializer(self.address).data
602+
assert serializer.data["address"] == expected
603+
604+
def test_read_default(self):
605+
data = {"address": self.address.pk}
606+
serializer = self.OrderSerializerPK(data=data)
607+
assert serializer.is_valid(), serializer.errors
608+
expected = AddressSerializer(self.address).data
609+
assert serializer.data["address"] == expected
610+
611+
def test_duplicated(self):
612+
data = {"address": "Tehran"}
613+
serializer = self.OrderSerializerCity(data=data)
614+
with pytest.raises(Address.MultipleObjectsReturned) as exc_info:
615+
serializer.is_valid(raise_exception=True)
616+
assert "returned more than one Address -- it returned 2!" in str(exc_info.value)
617+
618+
def test_not_fount(self):
619+
data = {"address": "Isfahan"}
620+
serializer = self.OrderSerializerCity(data=data)
621+
serializer.is_valid()
622+
assert "Object with city=Isfahan does not exist." in serializer.errors["address"]

0 commit comments

Comments
 (0)