Skip to content

Commit a758919

Browse files
committed
Split generic tests into it's own file.
1 parent 28a47aa commit a758919

File tree

3 files changed

+101
-81
lines changed

3 files changed

+101
-81
lines changed

marshmallow_dataclass/__init__.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,8 @@ def dataclass(
100100
frozen: bool = False,
101101
base_schema: Optional[Type[marshmallow.Schema]] = None,
102102
cls_frame: Optional[types.FrameType] = None,
103-
) -> Type[_U]: ...
103+
) -> Type[_U]:
104+
...
104105

105106

106107
@overload
@@ -113,7 +114,8 @@ def dataclass(
113114
frozen: bool = False,
114115
base_schema: Optional[Type[marshmallow.Schema]] = None,
115116
cls_frame: Optional[types.FrameType] = None,
116-
) -> Callable[[Type[_U]], Type[_U]]: ...
117+
) -> Callable[[Type[_U]], Type[_U]]:
118+
...
117119

118120

119121
# _cls should never be specified by keyword, so start it with an
@@ -172,21 +174,24 @@ def dataclass(
172174

173175

174176
@overload
175-
def add_schema(_cls: Type[_U]) -> Type[_U]: ...
177+
def add_schema(_cls: Type[_U]) -> Type[_U]:
178+
...
176179

177180

178181
@overload
179182
def add_schema(
180183
base_schema: Optional[Type[marshmallow.Schema]] = None,
181-
) -> Callable[[Type[_U]], Type[_U]]: ...
184+
) -> Callable[[Type[_U]], Type[_U]]:
185+
...
182186

183187

184188
@overload
185189
def add_schema(
186190
_cls: Type[_U],
187191
base_schema: Optional[Type[marshmallow.Schema]] = None,
188192
cls_frame: Optional[types.FrameType] = None,
189-
) -> Type[_U]: ...
193+
) -> Type[_U]:
194+
...
190195

191196

192197
def add_schema(_cls=None, base_schema=None, cls_frame=None):
@@ -805,7 +810,9 @@ def field_for_schema(
805810
nested_schema
806811
or forward_reference
807812
or _RECURSION_GUARD.seen_classes.get(typ)
808-
or _internal_class_schema(typ, base_schema, typ_frame, generic_params_to_args) # type: ignore [arg-type]
813+
or _internal_class_schema(
814+
typ, base_schema, typ_frame, generic_params_to_args # type: ignore [arg-type]
815+
)
809816
)
810817

811818
return marshmallow.fields.Nested(nested, **metadata)
@@ -873,7 +880,7 @@ def _generic_params_to_args(clazz: type) -> Tuple[Tuple[type, type], ...]:
873880

874881
def _dataclass_type_hints(
875882
clazz: type,
876-
clazz_frame: types.FrameType = None,
883+
clazz_frame: Optional[types.FrameType] = None,
877884
generic_params_to_args: Optional[Tuple[Tuple[type, type], ...]] = None,
878885
) -> Mapping[str, type]:
879886
localns = clazz_frame.f_locals if clazz_frame else None

tests/test_class_schema.py

Lines changed: 1 addition & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from marshmallow.fields import List as ListField
1818
from marshmallow.validate import Validator
1919

20-
from marshmallow_dataclass import NewType, _is_generic_alias_of_dataclass, class_schema
20+
from marshmallow_dataclass import NewType, class_schema
2121

2222

2323
class TestClassSchema(unittest.TestCase):
@@ -460,79 +460,6 @@ class Meta:
460460
self.assertNotIn("no_init", class_schema(NoInit)().fields)
461461
self.assertIn("no_init", class_schema(Init)().fields)
462462

463-
def test_generic_dataclass(self):
464-
T = typing.TypeVar("T")
465-
466-
@dataclasses.dataclass
467-
class SimpleGeneric(typing.Generic[T]):
468-
data: T
469-
470-
@dataclasses.dataclass
471-
class NestedFixed:
472-
data: SimpleGeneric[int]
473-
474-
@dataclasses.dataclass
475-
class NestedGeneric(typing.Generic[T]):
476-
data: SimpleGeneric[T]
477-
478-
self.assertTrue(_is_generic_alias_of_dataclass(SimpleGeneric[int]))
479-
self.assertFalse(_is_generic_alias_of_dataclass(SimpleGeneric))
480-
481-
schema_s = class_schema(SimpleGeneric[str])()
482-
self.assertEqual(SimpleGeneric(data="a"), schema_s.load({"data": "a"}))
483-
self.assertEqual(schema_s.dump(SimpleGeneric(data="a")), {"data": "a"})
484-
with self.assertRaises(ValidationError):
485-
schema_s.load({"data": 2})
486-
487-
schema_nested = class_schema(NestedFixed)()
488-
self.assertEqual(
489-
NestedFixed(data=SimpleGeneric(1)),
490-
schema_nested.load({"data": {"data": 1}}),
491-
)
492-
self.assertEqual(
493-
schema_nested.dump(NestedFixed(data=SimpleGeneric(data=1))),
494-
{"data": {"data": 1}},
495-
)
496-
with self.assertRaises(ValidationError):
497-
schema_nested.load({"data": {"data": "str"}})
498-
499-
schema_nested_generic = class_schema(NestedGeneric[int])()
500-
self.assertEqual(
501-
NestedGeneric(data=SimpleGeneric(1)),
502-
schema_nested_generic.load({"data": {"data": 1}}),
503-
)
504-
self.assertEqual(
505-
schema_nested_generic.dump(NestedGeneric(data=SimpleGeneric(data=1))),
506-
{"data": {"data": 1}},
507-
)
508-
with self.assertRaises(ValidationError):
509-
schema_nested_generic.load({"data": {"data": "str"}})
510-
511-
def test_generic_dataclass_repeated_fields(self):
512-
T = typing.TypeVar("T")
513-
514-
@dataclasses.dataclass
515-
class AA:
516-
a: int
517-
518-
@dataclasses.dataclass
519-
class BB(typing.Generic[T]):
520-
b: T
521-
522-
@dataclasses.dataclass
523-
class Nested:
524-
x: BB[float]
525-
z: BB[float]
526-
# if y is the first field in this class, deserialisation will fail.
527-
# see https://github.com/lovasoa/marshmallow_dataclass/pull/172#issuecomment-1334024027
528-
y: BB[AA]
529-
530-
schema_nested = class_schema(Nested)()
531-
self.assertEqual(
532-
Nested(x=BB(b=1), z=BB(b=1), y=BB(b=AA(1))),
533-
schema_nested.load({"x": {"b": 1}, "z": {"b": 1}, "y": {"b": {"a": 1}}}),
534-
)
535-
536463

537464
if __name__ == "__main__":
538465
unittest.main()

tests/test_generics.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
import dataclasses
2+
import typing
3+
import unittest
4+
5+
from marshmallow import ValidationError
6+
7+
from marshmallow_dataclass import _is_generic_alias_of_dataclass, class_schema
8+
9+
10+
class TestGenerics(unittest.TestCase):
11+
def test_generic_dataclass(self):
12+
T = typing.TypeVar("T")
13+
14+
@dataclasses.dataclass
15+
class SimpleGeneric(typing.Generic[T]):
16+
data: T
17+
18+
@dataclasses.dataclass
19+
class NestedFixed:
20+
data: SimpleGeneric[int]
21+
22+
@dataclasses.dataclass
23+
class NestedGeneric(typing.Generic[T]):
24+
data: SimpleGeneric[T]
25+
26+
self.assertTrue(_is_generic_alias_of_dataclass(SimpleGeneric[int]))
27+
self.assertFalse(_is_generic_alias_of_dataclass(SimpleGeneric))
28+
29+
schema_s = class_schema(SimpleGeneric[str])()
30+
self.assertEqual(SimpleGeneric(data="a"), schema_s.load({"data": "a"}))
31+
self.assertEqual(schema_s.dump(SimpleGeneric(data="a")), {"data": "a"})
32+
with self.assertRaises(ValidationError):
33+
schema_s.load({"data": 2})
34+
35+
schema_nested = class_schema(NestedFixed)()
36+
self.assertEqual(
37+
NestedFixed(data=SimpleGeneric(1)),
38+
schema_nested.load({"data": {"data": 1}}),
39+
)
40+
self.assertEqual(
41+
schema_nested.dump(NestedFixed(data=SimpleGeneric(data=1))),
42+
{"data": {"data": 1}},
43+
)
44+
with self.assertRaises(ValidationError):
45+
schema_nested.load({"data": {"data": "str"}})
46+
47+
schema_nested_generic = class_schema(NestedGeneric[int])()
48+
self.assertEqual(
49+
NestedGeneric(data=SimpleGeneric(1)),
50+
schema_nested_generic.load({"data": {"data": 1}}),
51+
)
52+
self.assertEqual(
53+
schema_nested_generic.dump(NestedGeneric(data=SimpleGeneric(data=1))),
54+
{"data": {"data": 1}},
55+
)
56+
with self.assertRaises(ValidationError):
57+
schema_nested_generic.load({"data": {"data": "str"}})
58+
59+
def test_generic_dataclass_repeated_fields(self):
60+
T = typing.TypeVar("T")
61+
62+
@dataclasses.dataclass
63+
class AA:
64+
a: int
65+
66+
@dataclasses.dataclass
67+
class BB(typing.Generic[T]):
68+
b: T
69+
70+
@dataclasses.dataclass
71+
class Nested:
72+
x: BB[float]
73+
z: BB[float]
74+
# if y is the first field in this class, deserialisation will fail.
75+
# see https://github.com/lovasoa/marshmallow_dataclass/pull/172#issuecomment-1334024027
76+
y: BB[AA]
77+
78+
schema_nested = class_schema(Nested)()
79+
self.assertEqual(
80+
Nested(x=BB(b=1), z=BB(b=1), y=BB(b=AA(1))),
81+
schema_nested.load({"x": {"b": 1}, "z": {"b": 1}, "y": {"b": {"a": 1}}}),
82+
)
83+
84+
85+
if __name__ == "__main__":
86+
unittest.main()

0 commit comments

Comments
 (0)