Skip to content

compiler: Lazy IET visitors + Search #2621

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 20 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
253 changes: 136 additions & 117 deletions devito/ir/iet/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
"""

from collections import OrderedDict
from collections.abc import Iterable
from collections.abc import Callable, Iterable, Iterator, Sequence
from itertools import chain, groupby
from typing import Any, Generic, TypeVar
import ctypes

import cgen as c
Expand Down Expand Up @@ -58,6 +59,42 @@ def always_rebuild(self, o, *args, **kwargs):
return o._rebuild(*new_ops, **okwargs)


YieldType = TypeVar('YieldType', covariant=True)
ResultType = TypeVar('ResultType', covariant=True)


class LazyVisitor(GenericVisitor, Generic[YieldType, ResultType]):

"""
A generic visitor that lazily yields results instead of flattening results
from children at every step.

Subclass-defined visit methods should be generators.
"""

def lookup_method(self, instance) -> Callable[..., Iterator[YieldType]]:
return super().lookup_method(instance)

def _visit(self, o, *args, **kwargs) -> Iterator[YieldType]:
meth = self.lookup_method(o)
yield from meth(o, *args, **kwargs)

def _post_visit(self, ret: Iterator[YieldType]) -> ResultType:
return list(ret)

def visit_object(self, o: object, **kwargs) -> Iterator[YieldType]:
yield from ()

def visit_Node(self, o: Node, **kwargs) -> Iterator[YieldType]:
yield from self._visit(o.children, **kwargs)

def visit_tuple(self, o: Sequence[Any], **kwargs) -> Iterator[YieldType]:
for i in o:
yield from self._visit(i, **kwargs)

visit_list = visit_tuple


class PrintAST(Visitor):

_depth = 0
Expand Down Expand Up @@ -978,16 +1015,7 @@ def visit_Node(self, o, ret=None, parents=None, in_parent=False):
return ret


class FindSymbols(Visitor):

class Retval(list):
def __init__(self, *retvals):
elements = filter_ordered(flatten(retvals), key=id)
super().__init__(elements)

@classmethod
def default_retval(cls):
return cls.Retval()
class FindSymbols(LazyVisitor[Any, list[Any]]):

"""
Find symbols in an Iteration/Expression tree.
Expand All @@ -1006,32 +1034,32 @@ def default_retval(cls):
- `defines-aliases`: Collect all defined objects and their aliases
"""

@staticmethod
def _defines_aliases(n):
retval = []
for i in n.defines:
f = i.function
if f.is_ArrayBasic:
retval.extend([f, f.indexed])
yield from (f, f.indexed)
else:
retval.append(i)
return tuple(retval)
yield i

rules = {
RulesDict = dict[str, Callable[[Node], Iterator[Any]]]
rules: RulesDict = {
'symbolics': lambda n: n.functions,
'basics': lambda n: [i for i in n.expr_symbols if isinstance(i, Basic)],
'symbols': lambda n: [i for i in n.expr_symbols
if isinstance(i, AbstractSymbol)],
'dimensions': lambda n: [i for i in n.expr_symbols if isinstance(i, Dimension)],
'indexeds': lambda n: [i for i in n.expr_symbols if i.is_Indexed],
'indexedbases': lambda n: [i for i in n.expr_symbols
if isinstance(i, IndexedBase)],
'basics': lambda n: (i for i in n.expr_symbols if isinstance(i, Basic)),
'symbols': lambda n: (i for i in n.expr_symbols
if isinstance(i, AbstractSymbol)),
'dimensions': lambda n: (i for i in n.expr_symbols if isinstance(i, Dimension)),
'indexeds': lambda n: (i for i in n.expr_symbols if i.is_Indexed),
'indexedbases': lambda n: (i for i in n.expr_symbols
if isinstance(i, IndexedBase)),
'writes': lambda n: as_tuple(n.writes),
'defines': lambda n: as_tuple(n.defines),
'globals': lambda n: [f.base for f in n.functions if f._mem_global],
'globals': lambda n: (f.base for f in n.functions if f._mem_global),
'defines-aliases': _defines_aliases
}

def __init__(self, mode='symbolics'):
def __init__(self, mode: str = 'symbolics') -> None:
super().__init__()

modes = mode.split('|')
Expand All @@ -1041,33 +1069,27 @@ def __init__(self, mode='symbolics'):
self.rule = lambda n: chain(*[self.rules[mode](n) for mode in modes])

def _post_visit(self, ret):
return sorted(ret, key=lambda i: str(i))
return sorted(filter_ordered(ret, key=id), key=str)

def visit_tuple(self, o):
return self.Retval(*[self._visit(i) for i in o])
def visit_Node(self, o: Node) -> Iterator[Any]:
yield from self._visit(o.children)
yield from self.rule(o)

visit_list = visit_tuple

def visit_Node(self, o):
return self.Retval(self._visit(o.children), self.rule(o))

def visit_ThreadedProdder(self, o):
def visit_ThreadedProdder(self, o) -> Iterator[Any]:
# TODO: this handle required because ThreadedProdder suffers from the
# long-standing issue affecting all Node subclasses which rely on
# multiple inheritance
return self.Retval(self._visit(o.then_body), self.rule(o))

def visit_Operator(self, o):
ret = self._visit(o.body)
ret.extend(flatten(self._visit(v) for v in o._func_table.values()))
return self.Retval(ret, self.rule(o))
yield from self._visit(o.then_body)
yield from self.rule(o)

def visit_Operator(self, o) -> Iterator[Any]:
yield from self._visit(o.body)
yield from self.rule(o)
for i in o._func_table.values():
yield from self._visit(i)

class FindNodes(Visitor):

@classmethod
def default_retval(cls):
return []
class FindNodes(LazyVisitor[Node, list[Node]]):

"""
Find all instances of given type.
Expand All @@ -1083,126 +1105,123 @@ def default_retval(cls):
appears.
"""

rules = {
RulesDict = dict[str, Callable[[type, Node], bool]]
rules: RulesDict = {
'type': lambda match, o: isinstance(o, match),
'scope': lambda match, o: match in flatten(o.children)
}

def __init__(self, match, mode='type'):
def __init__(self, match: type, mode: str = 'type') -> None:
super().__init__()
self.match = match
self.rule = self.rules[mode]

def visit_object(self, o, ret=None):
return ret

def visit_tuple(self, o, ret=None):
for i in o:
ret = self._visit(i, ret=ret)
return ret

visit_list = visit_tuple

def visit_Node(self, o, ret=None):
if ret is None:
ret = self.default_retval()
def visit_Node(self, o: Node, **kwargs) -> Iterator[Node]:
if self.rule(self.match, o):
ret.append(o)
yield o
for i in o.children:
ret = self._visit(i, ret=ret)
return ret
yield from self._visit(i, **kwargs)


class FindWithin(FindNodes):

@classmethod
def default_retval(cls):
return [], False

"""
Like FindNodes, but given an additional parameter `within=(start, stop)`,
it starts collecting matching nodes only after `start` is found, and stops
collecting matching nodes after `stop` is found.
"""

def __init__(self, match, start, stop=None):
# Sentinel values to signal the start/end of a matching window
SET_FLAG = object()
Copy link
Contributor

@FabioLuporini FabioLuporini Jun 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have several concerns about the increased complexity in FindWithin

UNSET_FLAG = object()

def __init__(self, match: type, start: Node, stop: Node | None = None) -> None:
super().__init__(match)
self.start = start
self.stop = stop

def visit(self, o, ret=None):
found, _ = self._visit(o, ret=ret)
return found
def _post_visit(self, ret: Iterator[Node | object]) -> list[Node]:
return super()._post_visit(i for i in ret
if i not in (self.SET_FLAG, self.UNSET_FLAG))

def visit_Node(self, o, ret=None):
if ret is None:
ret = self.default_retval()
found, flag = ret
def visit_object(self, o: object, flag: bool = False) -> Iterator[Node | object]:
yield self.SET_FLAG if flag else self.UNSET_FLAG

def visit_tuple(self, o: Sequence[Any],
flag: bool = False) -> Iterator[Node | object]:
for el in o:
for i in self._visit(el, flag=flag):
# New flag state is yielded at the end of child results
if i is self.SET_FLAG:
flag = True
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just flag = self.SET_FLAG, and same below, so that you can eventually just return flag

continue
if i is self.UNSET_FLAG:
flag = False
continue

# Regular object
yield i

yield self.SET_FLAG if flag else self.UNSET_FLAG

visit_list = visit_tuple

if o is self.start:
flag = True
def visit_Node(self, o: Node, flag: bool = False) -> Iterator[Node | object]:
flag = flag or (o is self.start)

if flag and self.rule(self.match, o):
found.append(o)
for i in o.children:
found, newflag = self._visit(i, ret=(found, flag))
if flag and not newflag:
return found, newflag
flag = newflag
yield o

for child in o.children:
for i in self._visit(child, flag=flag):
# New flag state is yielded at the end of child results
if i is self.SET_FLAG:
flag = True
continue
if i is self.UNSET_FLAG:
if flag:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is this different from what happens in visit_tuple?

yield self.UNSET_FLAG
return
continue

# Regular object
yield i

if o is self.stop:
flag = False
flag &= (o is not self.stop)
yield self.SET_FLAG if flag else self.UNSET_FLAG

return found, flag

ApplicationType = TypeVar('ApplicationType')

class FindApplications(Visitor):

class FindApplications(LazyVisitor[ApplicationType, set[ApplicationType]]):

"""
Find all SymPy applied functions (aka, `Application`s). The user may refine
the search by supplying a different target class.
"""

def __init__(self, cls=Application):
def __init__(self, cls: type[ApplicationType] = Application):
super().__init__()
self.match = lambda i: isinstance(i, cls) and not isinstance(i, Basic)

@classmethod
def default_retval(cls):
return set()

def visit_object(self, o, **kwargs):
return self.default_retval()

def visit_tuple(self, o, ret=None):
ret = ret or self.default_retval()
for i in o:
ret.update(self._visit(i, ret=ret))
return ret

def visit_Node(self, o, ret=None):
ret = ret or self.default_retval()
for i in o.children:
ret.update(self._visit(i, ret=ret))
return ret
def _post_visit(self, ret):
return set(ret)

def visit_Expression(self, o, **kwargs):
return o.expr.find(self.match)
def visit_Expression(self, o: Expression, **kwargs) -> Iterator[ApplicationType]:
yield from o.expr.find(self.match)

def visit_Iteration(self, o, **kwargs):
ret = self._visit(o.children) or self.default_retval()
ret.update(o.symbolic_min.find(self.match))
ret.update(o.symbolic_max.find(self.match))
return ret
def visit_Iteration(self, o: Iteration, **kwargs) -> Iterator[ApplicationType]:
yield from self._visit(o.children)
yield from o.symbolic_min.find(self.match)
yield from o.symbolic_max.find(self.match)

def visit_Call(self, o, **kwargs):
ret = self.default_retval()
def visit_Call(self, o: Call, **kwargs) -> Iterator[ApplicationType]:
for i in o.arguments:
try:
ret.update(i.find(self.match))
yield from i.find(self.match)
except (AttributeError, TypeError):
ret.update(self._visit(i, ret=ret))
return ret
yield from self._visit(i)


class IsPerfectIteration(Visitor):
Expand Down
Loading
Loading