Skip to content

Commit af2c4a6

Browse files
committed
Re-prefetch related objects after updating
1 parent 24a938a commit af2c4a6

File tree

3 files changed

+50
-27
lines changed

3 files changed

+50
-27
lines changed

rest_framework/generics.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""
22
Generic views that provide commonly needed behaviour.
33
"""
4+
from typing import Iterable
45
from django.core.exceptions import ValidationError
56
from django.db.models.query import QuerySet
67
from django.http import Http404
@@ -45,6 +46,8 @@ class GenericAPIView(views.APIView):
4546
# The style to use for queryset pagination.
4647
pagination_class = api_settings.DEFAULT_PAGINATION_CLASS
4748

49+
prefetch_related = []
50+
4851
def get_queryset(self):
4952
"""
5053
Get the list of items for this view.
@@ -68,9 +71,30 @@ def get_queryset(self):
6871

6972
queryset = self.queryset
7073
if isinstance(queryset, QuerySet):
74+
# Prefetch related objects
75+
if self.get_prefetch_related():
76+
queryset = queryset.prefetch_related(*self.get_prefetch_related())
7177
# Ensure queryset is re-evaluated on each request.
7278
queryset = queryset.all()
7379
return queryset
80+
81+
def get_prefetch_related(self):
82+
"""
83+
Get the list of prefetch related objects for self.queryset or instance.
84+
This must be an iterable.
85+
Defaults to using `self.prefetch_related`.
86+
87+
You may want to override this if you need to provide prefetched objects
88+
depending on the incoming request.
89+
90+
(Eg. `['toppings', Prefetch('restaurants', queryset=Restaurant.objects.select_related('best_pizza'))]`)
91+
"""
92+
assert isinstance(self.prefetch_related, Iterable), (
93+
"'%s' should either include an iterable `prefetch_related` attribute, "
94+
"or override the `get_prefetch_related()` method."
95+
% self.__class__.__name__
96+
)
97+
return self.prefetch_related
7498

7599
def get_object(self):
76100
"""

rest_framework/mixins.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
We don't bind behaviour to http method handlers yet,
55
which allows mixin classes to be composed in interesting ways.
66
"""
7+
from django.db.models.query import prefetch_related_objects
8+
79
from rest_framework import status
810
from rest_framework.response import Response
911
from rest_framework.settings import api_settings
@@ -69,8 +71,10 @@ def update(self, request, *args, **kwargs):
6971

7072
if getattr(instance, '_prefetched_objects_cache', None):
7173
# If 'prefetch_related' has been applied to a queryset, we need to
72-
# forcibly invalidate the prefetch cache on the instance.
74+
# forcibly invalidate the prefetch cache on the instance,
75+
# and then re-prefetch related objects
7376
instance._prefetched_objects_cache = {}
77+
prefetch_related_objects([instance], *self.get_prefetch_related())
7478

7579
return Response(serializer.data)
7680

tests/test_prefetch_related.py

Lines changed: 21 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -14,45 +14,40 @@ class Meta:
1414

1515

1616
class UserUpdate(generics.UpdateAPIView):
17-
queryset = User.objects.exclude(username='exclude').prefetch_related('groups')
17+
queryset = User.objects.exclude(username='exclude')
1818
serializer_class = UserSerializer
19+
prefetch_related = ['groups']
1920

2021

2122
class TestPrefetchRelatedUpdates(TestCase):
2223
def setUp(self):
2324
self.user = User.objects.create(username='tom', email='tom@example.com')
2425
self.groups = [Group.objects.create(name='a'), Group.objects.create(name='b')]
2526
self.user.groups.set(self.groups)
26-
27-
def test_prefetch_related_updates(self):
28-
view = UserUpdate.as_view()
29-
pk = self.user.pk
30-
groups_pk = self.groups[0].pk
31-
request = factory.put('/', {'username': 'new', 'groups': [groups_pk]}, format='json')
32-
response = view(request, pk=pk)
33-
assert User.objects.get(pk=pk).groups.count() == 1
34-
expected = {
35-
'id': pk,
27+
self.expected = {
28+
'id': self.user.pk,
3629
'username': 'new',
3730
'groups': [1],
38-
'email': 'tom@example.com'
31+
'email': 'tom@example.com',
3932
}
40-
assert response.data == expected
33+
self.view = UserUpdate.as_view()
34+
35+
def test_prefetch_related_updates(self):
36+
request = factory.put(
37+
'/', {'username': 'new', 'groups': [self.groups[0].pk]}, format='json'
38+
)
39+
response = self.view(request, pk=self.user.pk)
40+
assert User.objects.get(pk=self.user.pk).groups.count() == 1
41+
assert response.data == self.expected
4142

4243
def test_prefetch_related_excluding_instance_from_original_queryset(self):
4344
"""
4445
Regression test for https://github.com/encode/django-rest-framework/issues/4661
4546
"""
46-
view = UserUpdate.as_view()
47-
pk = self.user.pk
48-
groups_pk = self.groups[0].pk
49-
request = factory.put('/', {'username': 'exclude', 'groups': [groups_pk]}, format='json')
50-
response = view(request, pk=pk)
51-
assert User.objects.get(pk=pk).groups.count() == 1
52-
expected = {
53-
'id': pk,
54-
'username': 'exclude',
55-
'groups': [1],
56-
'email': 'tom@example.com'
57-
}
58-
assert response.data == expected
47+
request = factory.put(
48+
'/', {'username': 'exclude', 'groups': [self.groups[0].pk]}, format='json'
49+
)
50+
response = self.view(request, pk=self.user.pk)
51+
assert User.objects.get(pk=self.user.pk).groups.count() == 1
52+
self.expected['username'] = 'exclude'
53+
assert response.data == self.expected

0 commit comments

Comments
 (0)