Skip to content

Commit 528d8f5

Browse files
committed
misc: LazyVisitor tweaks + lazy FindWithin
1 parent 1cdf3e8 commit 528d8f5

File tree

1 file changed

+33
-36
lines changed

1 file changed

+33
-36
lines changed

devito/ir/iet/visitors.py

Lines changed: 33 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -59,10 +59,11 @@ def always_rebuild(self, o, *args, **kwargs):
5959
return o._rebuild(*new_ops, **okwargs)
6060

6161

62-
ResultType = TypeVar('ResultType')
62+
YieldType = TypeVar('YieldType', covariant=True)
63+
ResultType = TypeVar('ResultType', covariant=True)
6364

6465

65-
class LazyVisitor(GenericVisitor, Generic[ResultType]):
66+
class LazyVisitor(GenericVisitor, Generic[YieldType, ResultType]):
6667

6768
"""
6869
A generic visitor that lazily yields results instead of flattening results
@@ -71,25 +72,25 @@ class LazyVisitor(GenericVisitor, Generic[ResultType]):
7172
Subclass-defined visit methods should be generators.
7273
"""
7374

74-
def lookup_method(self, instance) -> Callable[..., Iterator[Any]]:
75+
def lookup_method(self, instance) -> Callable[..., Iterator[YieldType]]:
7576
return super().lookup_method(instance)
7677

77-
def _visit(self, o, *args, **kwargs) -> Iterator[Any]:
78+
def _visit(self, o, *args, **kwargs) -> Iterator[YieldType]:
7879
meth = self.lookup_method(o)
7980
yield from meth(o, *args, **kwargs)
8081

81-
def _post_visit(self, ret: Iterator[Any]) -> ResultType:
82+
def _post_visit(self, ret: Iterator[YieldType]) -> ResultType:
8283
return list(ret)
8384

84-
def visit_object(self, o: object, **kwargs) -> Iterator[Any]:
85+
def visit_object(self, o: object, **kwargs) -> Iterator[YieldType]:
8586
yield from ()
8687

87-
def visit_Node(self, o: Node, **kwargs) -> Iterator[Any]:
88+
def visit_Node(self, o: Node, **kwargs) -> Iterator[YieldType]:
8889
yield from self._visit(o.children, **kwargs)
8990

90-
def visit_tuple(self, o: Sequence[Any]) -> Iterator[Any]:
91+
def visit_tuple(self, o: Sequence[Any], **kwargs) -> Iterator[YieldType]:
9192
for i in o:
92-
yield from self._visit(i)
93+
yield from self._visit(i, **kwargs)
9394

9495
visit_list = visit_tuple
9596

@@ -1014,7 +1015,7 @@ def visit_Node(self, o, ret=None, parents=None, in_parent=False):
10141015
return ret
10151016

10161017

1017-
class FindSymbols(LazyVisitor[list[Any]]):
1018+
class FindSymbols(LazyVisitor[Any, list[Any]]):
10181019

10191020
"""
10201021
Find symbols in an Iteration/Expression tree.
@@ -1088,7 +1089,7 @@ def visit_Operator(self, o) -> Iterator[Any]:
10881089
yield from self._visit(i)
10891090

10901091

1091-
class FindNodes(LazyVisitor[list[Node]]):
1092+
class FindNodes(LazyVisitor[Node, list[Node]]):
10921093

10931094
"""
10941095
Find all instances of given type.
@@ -1110,65 +1111,61 @@ class FindNodes(LazyVisitor[list[Node]]):
11101111
'scope': lambda match, o: match in flatten(o.children)
11111112
}
11121113

1113-
def __init__(self, match: type, mode: str = 'type'):
1114+
def __init__(self, match: type, mode: str = 'type') -> None:
11141115
super().__init__()
11151116
self.match = match
11161117
self.rule = self.rules[mode]
11171118

1118-
def visit_Node(self, o: Node) -> Iterator[Any]:
1119+
def visit_Node(self, o: Node, **kwargs) -> Iterator[Node]:
11191120
if self.rule(self.match, o):
11201121
yield o
11211122
for i in o.children:
1122-
yield from self._visit(i)
1123+
yield from self._visit(i, **kwargs)
11231124

11241125

11251126
class FindWithin(FindNodes):
11261127

1127-
@classmethod
1128-
def default_retval(cls):
1129-
return [], False
1130-
11311128
"""
11321129
Like FindNodes, but given an additional parameter `within=(start, stop)`,
11331130
it starts collecting matching nodes only after `start` is found, and stops
11341131
collecting matching nodes after `stop` is found.
11351132
"""
11361133

1137-
def __init__(self, match, start, stop=None):
1134+
# Dummy object to signal the end of the search
1135+
STOP = object()
1136+
1137+
def __init__(self, match: type, start: Node, stop: Node | None = None) -> None:
11381138
super().__init__(match)
11391139
self.start = start
11401140
self.stop = stop
11411141

1142-
def visit(self, o, ret=None):
1143-
found, _ = self._visit(o, ret=ret)
1144-
return found
1145-
1146-
def visit_Node(self, o, ret=None):
1147-
if ret is None:
1148-
ret = self.default_retval()
1149-
found, flag = ret
1142+
def _post_visit(self, ret: Iterator[Node]) -> list[Node]:
1143+
ret = super()._post_visit(ret)
1144+
if ret[-1] is self.STOP:
1145+
ret.pop()
1146+
return ret
11501147

1148+
def visit_Node(self, o: Node, flag: bool = False) -> Iterator[Node]:
11511149
if o is self.start:
11521150
flag = True
11531151

11541152
if flag and self.rule(self.match, o):
1155-
found.append(o)
1153+
yield o
11561154
for i in o.children:
1157-
found, newflag = self._visit(i, ret=(found, flag))
1158-
if flag and not newflag:
1159-
return found, newflag
1160-
flag = newflag
1155+
for r in self._visit(i, flag=flag):
1156+
yield r
1157+
if r is self.STOP:
1158+
return
11611159

11621160
if o is self.stop:
1163-
flag = False
1164-
1165-
return found, flag
1161+
yield self.STOP
11661162

11671163

11681164
ApplicationType = TypeVar('ApplicationType')
11691165

11701166

1171-
class FindApplications(LazyVisitor[set[ApplicationType]]):
1167+
class FindApplications(LazyVisitor[ApplicationType, set[ApplicationType]]):
1168+
11721169
"""
11731170
Find all SymPy applied functions (aka, `Application`s). The user may refine
11741171
the search by supplying a different target class.

0 commit comments

Comments
 (0)