@@ -59,10 +59,11 @@ def always_rebuild(self, o, *args, **kwargs):
59
59
return o ._rebuild (* new_ops , ** okwargs )
60
60
61
61
62
- ResultType = TypeVar ('ResultType' )
62
+ YieldType = TypeVar ('YieldType' , covariant = True )
63
+ ResultType = TypeVar ('ResultType' , covariant = True )
63
64
64
65
65
- class LazyVisitor (GenericVisitor , Generic [ResultType ]):
66
+ class LazyVisitor (GenericVisitor , Generic [YieldType , ResultType ]):
66
67
67
68
"""
68
69
A generic visitor that lazily yields results instead of flattening results
@@ -71,25 +72,25 @@ class LazyVisitor(GenericVisitor, Generic[ResultType]):
71
72
Subclass-defined visit methods should be generators.
72
73
"""
73
74
74
- def lookup_method (self , instance ) -> Callable [..., Iterator [Any ]]:
75
+ def lookup_method (self , instance ) -> Callable [..., Iterator [YieldType ]]:
75
76
return super ().lookup_method (instance )
76
77
77
- def _visit (self , o , * args , ** kwargs ) -> Iterator [Any ]:
78
+ def _visit (self , o , * args , ** kwargs ) -> Iterator [YieldType ]:
78
79
meth = self .lookup_method (o )
79
80
yield from meth (o , * args , ** kwargs )
80
81
81
- def _post_visit (self , ret : Iterator [Any ]) -> ResultType :
82
+ def _post_visit (self , ret : Iterator [YieldType ]) -> ResultType :
82
83
return list (ret )
83
84
84
- def visit_object (self , o : object , ** kwargs ) -> Iterator [Any ]:
85
+ def visit_object (self , o : object , ** kwargs ) -> Iterator [YieldType ]:
85
86
yield from ()
86
87
87
- def visit_Node (self , o : Node , ** kwargs ) -> Iterator [Any ]:
88
+ def visit_Node (self , o : Node , ** kwargs ) -> Iterator [YieldType ]:
88
89
yield from self ._visit (o .children , ** kwargs )
89
90
90
- def visit_tuple (self , o : Sequence [Any ]) -> Iterator [Any ]:
91
+ def visit_tuple (self , o : Sequence [Any ], ** kwargs ) -> Iterator [YieldType ]:
91
92
for i in o :
92
- yield from self ._visit (i )
93
+ yield from self ._visit (i , ** kwargs )
93
94
94
95
visit_list = visit_tuple
95
96
@@ -1014,7 +1015,7 @@ def visit_Node(self, o, ret=None, parents=None, in_parent=False):
1014
1015
return ret
1015
1016
1016
1017
1017
- class FindSymbols (LazyVisitor [list [Any ]]):
1018
+ class FindSymbols (LazyVisitor [Any , list [Any ]]):
1018
1019
1019
1020
"""
1020
1021
Find symbols in an Iteration/Expression tree.
@@ -1088,7 +1089,7 @@ def visit_Operator(self, o) -> Iterator[Any]:
1088
1089
yield from self ._visit (i )
1089
1090
1090
1091
1091
- class FindNodes (LazyVisitor [list [Node ]]):
1092
+ class FindNodes (LazyVisitor [Node , list [Node ]]):
1092
1093
1093
1094
"""
1094
1095
Find all instances of given type.
@@ -1110,65 +1111,61 @@ class FindNodes(LazyVisitor[list[Node]]):
1110
1111
'scope' : lambda match , o : match in flatten (o .children )
1111
1112
}
1112
1113
1113
- def __init__ (self , match : type , mode : str = 'type' ):
1114
+ def __init__ (self , match : type , mode : str = 'type' ) -> None :
1114
1115
super ().__init__ ()
1115
1116
self .match = match
1116
1117
self .rule = self .rules [mode ]
1117
1118
1118
- def visit_Node (self , o : Node ) -> Iterator [Any ]:
1119
+ def visit_Node (self , o : Node , ** kwargs ) -> Iterator [Node ]:
1119
1120
if self .rule (self .match , o ):
1120
1121
yield o
1121
1122
for i in o .children :
1122
- yield from self ._visit (i )
1123
+ yield from self ._visit (i , ** kwargs )
1123
1124
1124
1125
1125
1126
class FindWithin (FindNodes ):
1126
1127
1127
- @classmethod
1128
- def default_retval (cls ):
1129
- return [], False
1130
-
1131
1128
"""
1132
1129
Like FindNodes, but given an additional parameter `within=(start, stop)`,
1133
1130
it starts collecting matching nodes only after `start` is found, and stops
1134
1131
collecting matching nodes after `stop` is found.
1135
1132
"""
1136
1133
1137
- def __init__ (self , match , start , stop = None ):
1134
+ # Dummy object to signal the end of the search
1135
+ STOP = object ()
1136
+
1137
+ def __init__ (self , match : type , start : Node , stop : Node | None = None ) -> None :
1138
1138
super ().__init__ (match )
1139
1139
self .start = start
1140
1140
self .stop = stop
1141
1141
1142
- def visit (self , o , ret = None ):
1143
- found , _ = self ._visit (o , ret = ret )
1144
- return found
1145
-
1146
- def visit_Node (self , o , ret = None ):
1147
- if ret is None :
1148
- ret = self .default_retval ()
1149
- found , flag = ret
1142
+ def _post_visit (self , ret : Iterator [Node ]) -> list [Node ]:
1143
+ ret = super ()._post_visit (ret )
1144
+ if ret [- 1 ] is self .STOP :
1145
+ ret .pop ()
1146
+ return ret
1150
1147
1148
+ def visit_Node (self , o : Node , flag : bool = False ) -> Iterator [Node ]:
1151
1149
if o is self .start :
1152
1150
flag = True
1153
1151
1154
1152
if flag and self .rule (self .match , o ):
1155
- found . append ( o )
1153
+ yield o
1156
1154
for i in o .children :
1157
- found , newflag = self ._visit (i , ret = ( found , flag ))
1158
- if flag and not newflag :
1159
- return found , newflag
1160
- flag = newflag
1155
+ for r in self ._visit (i , flag = flag ):
1156
+ yield r
1157
+ if r is self . STOP :
1158
+ return
1161
1159
1162
1160
if o is self .stop :
1163
- flag = False
1164
-
1165
- return found , flag
1161
+ yield self .STOP
1166
1162
1167
1163
1168
1164
ApplicationType = TypeVar ('ApplicationType' )
1169
1165
1170
1166
1171
- class FindApplications (LazyVisitor [set [ApplicationType ]]):
1167
+ class FindApplications (LazyVisitor [ApplicationType , set [ApplicationType ]]):
1168
+
1172
1169
"""
1173
1170
Find all SymPy applied functions (aka, `Application`s). The user may refine
1174
1171
the search by supplying a different target class.
0 commit comments