Skip to content

Re-prefetch related objects after updating #8043

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jan 11, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions rest_framework/generics.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""
Generic views that provide commonly needed behaviour.
"""
from typing import Iterable

from django.core.exceptions import ValidationError
from django.db.models.query import QuerySet
from django.http import Http404
Expand Down Expand Up @@ -45,6 +47,8 @@ class GenericAPIView(views.APIView):
# The style to use for queryset pagination.
pagination_class = api_settings.DEFAULT_PAGINATION_CLASS

prefetch_related = []

def get_queryset(self):
"""
Get the list of items for this view.
Expand All @@ -68,10 +72,31 @@ def get_queryset(self):

queryset = self.queryset
if isinstance(queryset, QuerySet):
# Prefetch related objects
if self.get_prefetch_related():
queryset = queryset.prefetch_related(*self.get_prefetch_related())
# Ensure queryset is re-evaluated on each request.
queryset = queryset.all()
return queryset

def get_prefetch_related(self):
"""
Get the list of prefetch related objects for self.queryset or instance.
This must be an iterable.
Defaults to using `self.prefetch_related`.

You may want to override this if you need to provide prefetched objects
depending on the incoming request.

(Eg. `['toppings', Prefetch('restaurants', queryset=Restaurant.objects.select_related('best_pizza'))]`)
"""
assert isinstance(self.prefetch_related, Iterable), (
"'%s' should either include an iterable `prefetch_related` attribute, "
"or override the `get_prefetch_related()` method."
% self.__class__.__name__
)
return self.prefetch_related

def get_object(self):
"""
Returns the object the view is displaying.
Expand Down
6 changes: 5 additions & 1 deletion rest_framework/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
We don't bind behaviour to http method handlers yet,
which allows mixin classes to be composed in interesting ways.
"""
from django.db.models.query import prefetch_related_objects

from rest_framework import status
from rest_framework.response import Response
from rest_framework.settings import api_settings
Expand Down Expand Up @@ -69,8 +71,10 @@ def update(self, request, *args, **kwargs):

if getattr(instance, '_prefetched_objects_cache', None):
# If 'prefetch_related' has been applied to a queryset, we need to
# forcibly invalidate the prefetch cache on the instance.
# forcibly invalidate the prefetch cache on the instance,
# and then re-prefetch related objects
instance._prefetched_objects_cache = {}
prefetch_related_objects([instance], *self.get_prefetch_related())

return Response(serializer.data)

Expand Down
47 changes: 21 additions & 26 deletions tests/test_prefetch_related.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,45 +14,40 @@ class Meta:


class UserUpdate(generics.UpdateAPIView):
queryset = User.objects.exclude(username='exclude').prefetch_related('groups')
queryset = User.objects.exclude(username='exclude')
serializer_class = UserSerializer
prefetch_related = ['groups']


class TestPrefetchRelatedUpdates(TestCase):
def setUp(self):
self.user = User.objects.create(username='tom', email='tom@example.com')
self.groups = [Group.objects.create(name='a'), Group.objects.create(name='b')]
self.user.groups.set(self.groups)

def test_prefetch_related_updates(self):
view = UserUpdate.as_view()
pk = self.user.pk
groups_pk = self.groups[0].pk
request = factory.put('/', {'username': 'new', 'groups': [groups_pk]}, format='json')
response = view(request, pk=pk)
assert User.objects.get(pk=pk).groups.count() == 1
expected = {
'id': pk,
self.expected = {
'id': self.user.pk,
'username': 'new',
'groups': [1],
'email': 'tom@example.com'
'email': 'tom@example.com',
}
assert response.data == expected
self.view = UserUpdate.as_view()

def test_prefetch_related_updates(self):
request = factory.put(
'/', {'username': 'new', 'groups': [self.groups[0].pk]}, format='json'
)
response = self.view(request, pk=self.user.pk)
assert User.objects.get(pk=self.user.pk).groups.count() == 1
assert response.data == self.expected

def test_prefetch_related_excluding_instance_from_original_queryset(self):
"""
Regression test for https://github.com/encode/django-rest-framework/issues/4661
"""
view = UserUpdate.as_view()
pk = self.user.pk
groups_pk = self.groups[0].pk
request = factory.put('/', {'username': 'exclude', 'groups': [groups_pk]}, format='json')
response = view(request, pk=pk)
assert User.objects.get(pk=pk).groups.count() == 1
expected = {
'id': pk,
'username': 'exclude',
'groups': [1],
'email': 'tom@example.com'
}
assert response.data == expected
request = factory.put(
'/', {'username': 'exclude', 'groups': [self.groups[0].pk]}, format='json'
)
response = self.view(request, pk=self.user.pk)
assert User.objects.get(pk=self.user.pk).groups.count() == 1
self.expected['username'] = 'exclude'
assert response.data == self.expected