Skip to content

Commit 3d9fe74

Browse files
committed
FindSymbols tweaks
1 parent e56e768 commit 3d9fe74

File tree

1 file changed

+16
-13
lines changed

1 file changed

+16
-13
lines changed

devito/ir/iet/visitors.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ class LazyVisitor(GenericVisitor):
6363

6464
"""
6565
A generic visitor that lazily yields results instead of flattening results
66-
from children at every step (unless required for rebuilding).
66+
from children at every step.
6767
6868
Subclass-defined visit methods (and default_retval) should be generators.
6969
"""
@@ -1014,12 +1014,15 @@ def visit_Node(self, o, ret=None, parents=None, in_parent=False):
10141014
class FindSymbols(LazyVisitor):
10151015

10161016
@classmethod
1017-
def retval(cls, *its: Iterable[Any]) -> Iterator[Any]:
1017+
def join(cls, *its: Iterable[Any], key: Callable[[Any], Any] = str) -> Iterator[Any]:
10181018
"""
1019-
Filterrs and flattens nested iterables while preserving relative order.
1019+
Flattens nested iterables and filters for uniqueness, ordering results
1020+
lexicographically so we don't need to sort the final result.
10201021
"""
1021-
ret = chain(flatten_iter(it) for it in its)
1022-
return filter_ordered_iter(flatten_iter(ret), key=id)
1022+
1023+
_it = (flatten_iter(it) for it in its)
1024+
yield from filter_ordered_iter(flatten_iter(_it), key=id)
1025+
10231026

10241027
"""
10251028
Find symbols in an Iteration/Expression tree.
@@ -1070,27 +1073,27 @@ def __init__(self, mode: str = 'symbolics') -> None:
10701073
else:
10711074
self.rule = lambda n: chain(*[self.rules[mode](n) for mode in modes])
10721075

1073-
def _post_visit(self, ret: Iterable[Any]) -> list[Any]:
1074-
return sorted(ret, key=lambda i: str(i))
1076+
def _post_visit(self, ret: Iterable[Any]) -> Iterable[Any]:
1077+
return sorted(ret, key=str)
10751078

10761079
def visit_tuple(self, o: Sequence[Any]) -> Iterator[Any]:
1077-
yield from self.retval(self._visit(i) for i in o)
1080+
yield from self.join(self._visit(i) for i in o)
10781081

10791082
visit_list = visit_tuple
10801083

10811084
def visit_Node(self, o: Node) -> Iterator[Any]:
1082-
yield from self.retval(self._visit(o.children), self.rule(o))
1085+
yield from self.join(self._visit(o.children), self.rule(o))
10831086

10841087
def visit_ThreadedProdder(self, o) -> Iterator[Any]:
10851088
# TODO: this handle required because ThreadedProdder suffers from the
10861089
# long-standing issue affecting all Node subclasses which rely on
10871090
# multiple inheritance
1088-
yield from self.retval(self._visit(o.then_body), self.rule(o))
1091+
yield from self.join(self._visit(o.then_body), self.rule(o))
10891092

10901093
def visit_Operator(self, o) -> Iterator[Any]:
1091-
yield from self.retval(self._visit(o.body),
1092-
flatten_iter(self._visit(i) for i in o._dspace.parts),
1093-
self.rule(o))
1094+
yield from self.join(self._visit(o.body),
1095+
self.rule(o),
1096+
*(self._visit(i) for i in o._dspace.parts))
10941097

10951098

10961099
class FindNodes(LazyVisitor):

0 commit comments

Comments
 (0)