@@ -24,21 +24,21 @@ class MypyTestItem:
24
24
lineno : int
25
25
end_lineno : int
26
26
expected_messages : List [Message ]
27
- func_node : Optional [ast .FunctionDef ] = None
27
+ func_node : Optional [Union [ ast .FunctionDef , ast . AsyncFunctionDef ] ] = None
28
28
marks : Set [str ] = dataclasses .field (default_factory = lambda : set ())
29
29
actual_messages : List [Message ] = dataclasses .field (default_factory = lambda : [])
30
30
31
31
@classmethod
32
32
def from_ast_node (
33
33
cls ,
34
- func_node : ast .FunctionDef ,
34
+ func_node : Union [ ast .FunctionDef , ast . AsyncFunctionDef ] ,
35
35
marks : Optional [Set [str ]] = None ,
36
36
unfiltered_messages : Optional [Iterable [Message ]] = None ,
37
37
) -> "MypyTestItem" :
38
- if not isinstance (func_node , ast .FunctionDef ):
38
+ if not isinstance (func_node , ( ast .FunctionDef , ast . AsyncFunctionDef ) ):
39
39
raise ValueError (
40
40
f"Invalid func_node type: Got { type (func_node )} , "
41
- f"expected { ast .FunctionDef } "
41
+ f"expected { ast .FunctionDef } or { ast . AsyncFunctionDef } "
42
42
)
43
43
lineno = func_node .lineno
44
44
end_lineno = getattr (func_node , "end_lineno" , 0 )
@@ -121,7 +121,7 @@ def parse_file(filename: Union[os.PathLike, str, pathlib.Path], config) -> MypyT
121
121
items : List [MypyTestItem ] = []
122
122
123
123
for node in ast .iter_child_nodes (tree ):
124
- if not isinstance (node , ast .FunctionDef ):
124
+ if not isinstance (node , ( ast .FunctionDef , ast . AsyncFunctionDef ) ):
125
125
continue
126
126
marks = _find_marks (node )
127
127
if "mypy_testing" in marks :
0 commit comments