diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index cfb54de138..916f8bec44 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -121,6 +121,10 @@ def __new__(cls, *args, **kwargs): return cls.many_init(*args, **kwargs) return super().__new__(cls, *args, **kwargs) + # Allow type checkers to make serializers generic. + def __class_getitem__(cls, *args, **kwargs): + return cls + @classmethod def many_init(cls, *args, **kwargs): """ diff --git a/tests/test_serializer.py b/tests/test_serializer.py index a58c46b2d9..afefd70e1c 100644 --- a/tests/test_serializer.py +++ b/tests/test_serializer.py @@ -1,6 +1,7 @@ import inspect import pickle import re +import sys from collections import ChainMap from collections.abc import Mapping @@ -204,6 +205,13 @@ class ExampleSerializer(serializers.Serializer): exceptions.ErrorDetail(string='Raised error', code='invalid') ]} + @pytest.mark.skipif( + sys.version_info < (3, 7), + reason="subscriptable classes requires Python 3.7 or higher", + ) + def test_serializer_is_subscriptable(self): + assert serializers.Serializer is serializers.Serializer["foo"] + class TestValidateMethod: def test_non_field_error_validate_method(self): diff --git a/tests/test_serializer_lists.py b/tests/test_serializer_lists.py index 98e72385a2..f35c4fcc9e 100644 --- a/tests/test_serializer_lists.py +++ b/tests/test_serializer_lists.py @@ -1,3 +1,5 @@ +import sys + import pytest from django.http import QueryDict from django.utils.datastructures import MultiValueDict @@ -55,6 +57,13 @@ def test_validate_html_input(self): assert serializer.is_valid() assert serializer.validated_data == expected_output + @pytest.mark.skipif( + sys.version_info < (3, 7), + reason="subscriptable classes requires Python 3.7 or higher", + ) + def test_list_serializer_is_subscriptable(self): + assert serializers.ListSerializer is serializers.ListSerializer["foo"] + class TestListSerializerContainingNestedSerializer: """