-
Notifications
You must be signed in to change notification settings - Fork 239
compiler: Lazy IET visitors + Search #2621
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
09de795
bb123c4
2886742
7314651
a6b6124
9f4b8c5
9bef19c
252eaa0
14ff62a
d6a722f
130efc1
a3b4048
8e1bc1b
d69462b
1cdf3e8
528d8f5
48373f5
f0bc509
d5c4aab
4f5a6ec
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,8 +5,9 @@ | |
""" | ||
|
||
from collections import OrderedDict | ||
from collections.abc import Iterable | ||
from collections.abc import Callable, Iterable, Iterator, Sequence | ||
from itertools import chain, groupby | ||
from typing import Any, Generic, TypeVar | ||
import ctypes | ||
|
||
import cgen as c | ||
|
@@ -58,6 +59,42 @@ def always_rebuild(self, o, *args, **kwargs): | |
return o._rebuild(*new_ops, **okwargs) | ||
|
||
|
||
YieldType = TypeVar('YieldType', covariant=True) | ||
ResultType = TypeVar('ResultType', covariant=True) | ||
|
||
|
||
class LazyVisitor(GenericVisitor, Generic[YieldType, ResultType]): | ||
|
||
""" | ||
A generic visitor that lazily yields results instead of flattening results | ||
from children at every step. | ||
|
||
Subclass-defined visit methods should be generators. | ||
""" | ||
|
||
def lookup_method(self, instance) -> Callable[..., Iterator[YieldType]]: | ||
return super().lookup_method(instance) | ||
|
||
def _visit(self, o, *args, **kwargs) -> Iterator[YieldType]: | ||
meth = self.lookup_method(o) | ||
yield from meth(o, *args, **kwargs) | ||
|
||
def _post_visit(self, ret: Iterator[YieldType]) -> ResultType: | ||
return list(ret) | ||
|
||
def visit_object(self, o: object, **kwargs) -> Iterator[YieldType]: | ||
yield from () | ||
|
||
def visit_Node(self, o: Node, **kwargs) -> Iterator[YieldType]: | ||
yield from self._visit(o.children, **kwargs) | ||
|
||
def visit_tuple(self, o: Sequence[Any], **kwargs) -> Iterator[YieldType]: | ||
for i in o: | ||
yield from self._visit(i, **kwargs) | ||
|
||
visit_list = visit_tuple | ||
|
||
|
||
class PrintAST(Visitor): | ||
|
||
_depth = 0 | ||
|
@@ -978,16 +1015,7 @@ def visit_Node(self, o, ret=None, parents=None, in_parent=False): | |
return ret | ||
|
||
|
||
class FindSymbols(Visitor): | ||
|
||
class Retval(list): | ||
def __init__(self, *retvals): | ||
elements = filter_ordered(flatten(retvals), key=id) | ||
super().__init__(elements) | ||
|
||
@classmethod | ||
def default_retval(cls): | ||
return cls.Retval() | ||
class FindSymbols(LazyVisitor[Any, list[Any]]): | ||
|
||
""" | ||
Find symbols in an Iteration/Expression tree. | ||
|
@@ -1006,32 +1034,32 @@ def default_retval(cls): | |
- `defines-aliases`: Collect all defined objects and their aliases | ||
""" | ||
|
||
@staticmethod | ||
def _defines_aliases(n): | ||
retval = [] | ||
for i in n.defines: | ||
f = i.function | ||
if f.is_ArrayBasic: | ||
retval.extend([f, f.indexed]) | ||
yield from (f, f.indexed) | ||
else: | ||
retval.append(i) | ||
return tuple(retval) | ||
yield i | ||
|
||
rules = { | ||
RulesDict = dict[str, Callable[[Node], Iterator[Any]]] | ||
rules: RulesDict = { | ||
'symbolics': lambda n: n.functions, | ||
'basics': lambda n: [i for i in n.expr_symbols if isinstance(i, Basic)], | ||
'symbols': lambda n: [i for i in n.expr_symbols | ||
if isinstance(i, AbstractSymbol)], | ||
'dimensions': lambda n: [i for i in n.expr_symbols if isinstance(i, Dimension)], | ||
'indexeds': lambda n: [i for i in n.expr_symbols if i.is_Indexed], | ||
'indexedbases': lambda n: [i for i in n.expr_symbols | ||
if isinstance(i, IndexedBase)], | ||
'basics': lambda n: (i for i in n.expr_symbols if isinstance(i, Basic)), | ||
'symbols': lambda n: (i for i in n.expr_symbols | ||
if isinstance(i, AbstractSymbol)), | ||
'dimensions': lambda n: (i for i in n.expr_symbols if isinstance(i, Dimension)), | ||
'indexeds': lambda n: (i for i in n.expr_symbols if i.is_Indexed), | ||
'indexedbases': lambda n: (i for i in n.expr_symbols | ||
if isinstance(i, IndexedBase)), | ||
'writes': lambda n: as_tuple(n.writes), | ||
'defines': lambda n: as_tuple(n.defines), | ||
'globals': lambda n: [f.base for f in n.functions if f._mem_global], | ||
'globals': lambda n: (f.base for f in n.functions if f._mem_global), | ||
'defines-aliases': _defines_aliases | ||
} | ||
|
||
def __init__(self, mode='symbolics'): | ||
def __init__(self, mode: str = 'symbolics') -> None: | ||
super().__init__() | ||
|
||
modes = mode.split('|') | ||
|
@@ -1041,33 +1069,27 @@ def __init__(self, mode='symbolics'): | |
self.rule = lambda n: chain(*[self.rules[mode](n) for mode in modes]) | ||
|
||
def _post_visit(self, ret): | ||
return sorted(ret, key=lambda i: str(i)) | ||
return sorted(filter_ordered(ret, key=id), key=str) | ||
enwask marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
def visit_tuple(self, o): | ||
return self.Retval(*[self._visit(i) for i in o]) | ||
def visit_Node(self, o: Node) -> Iterator[Any]: | ||
yield from self._visit(o.children) | ||
yield from self.rule(o) | ||
|
||
visit_list = visit_tuple | ||
|
||
def visit_Node(self, o): | ||
return self.Retval(self._visit(o.children), self.rule(o)) | ||
|
||
def visit_ThreadedProdder(self, o): | ||
def visit_ThreadedProdder(self, o) -> Iterator[Any]: | ||
# TODO: this handle required because ThreadedProdder suffers from the | ||
# long-standing issue affecting all Node subclasses which rely on | ||
# multiple inheritance | ||
return self.Retval(self._visit(o.then_body), self.rule(o)) | ||
|
||
def visit_Operator(self, o): | ||
ret = self._visit(o.body) | ||
ret.extend(flatten(self._visit(v) for v in o._func_table.values())) | ||
return self.Retval(ret, self.rule(o)) | ||
yield from self._visit(o.then_body) | ||
yield from self.rule(o) | ||
|
||
def visit_Operator(self, o) -> Iterator[Any]: | ||
yield from self._visit(o.body) | ||
yield from self.rule(o) | ||
for i in o._func_table.values(): | ||
yield from self._visit(i) | ||
|
||
class FindNodes(Visitor): | ||
|
||
@classmethod | ||
def default_retval(cls): | ||
return [] | ||
class FindNodes(LazyVisitor[Node, list[Node]]): | ||
|
||
""" | ||
Find all instances of given type. | ||
|
@@ -1083,126 +1105,123 @@ def default_retval(cls): | |
appears. | ||
""" | ||
|
||
rules = { | ||
RulesDict = dict[str, Callable[[type, Node], bool]] | ||
rules: RulesDict = { | ||
'type': lambda match, o: isinstance(o, match), | ||
'scope': lambda match, o: match in flatten(o.children) | ||
} | ||
|
||
def __init__(self, match, mode='type'): | ||
def __init__(self, match: type, mode: str = 'type') -> None: | ||
super().__init__() | ||
self.match = match | ||
self.rule = self.rules[mode] | ||
|
||
def visit_object(self, o, ret=None): | ||
return ret | ||
|
||
def visit_tuple(self, o, ret=None): | ||
for i in o: | ||
ret = self._visit(i, ret=ret) | ||
return ret | ||
|
||
visit_list = visit_tuple | ||
|
||
def visit_Node(self, o, ret=None): | ||
if ret is None: | ||
ret = self.default_retval() | ||
def visit_Node(self, o: Node, **kwargs) -> Iterator[Node]: | ||
if self.rule(self.match, o): | ||
ret.append(o) | ||
yield o | ||
for i in o.children: | ||
ret = self._visit(i, ret=ret) | ||
return ret | ||
yield from self._visit(i, **kwargs) | ||
|
||
|
||
class FindWithin(FindNodes): | ||
|
||
@classmethod | ||
def default_retval(cls): | ||
return [], False | ||
|
||
""" | ||
Like FindNodes, but given an additional parameter `within=(start, stop)`, | ||
it starts collecting matching nodes only after `start` is found, and stops | ||
collecting matching nodes after `stop` is found. | ||
""" | ||
|
||
def __init__(self, match, start, stop=None): | ||
# Sentinel values to signal the start/end of a matching window | ||
SET_FLAG = object() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I have several concerns about the increased complexity in FindWithin |
||
UNSET_FLAG = object() | ||
|
||
def __init__(self, match: type, start: Node, stop: Node | None = None) -> None: | ||
super().__init__(match) | ||
self.start = start | ||
self.stop = stop | ||
|
||
def visit(self, o, ret=None): | ||
found, _ = self._visit(o, ret=ret) | ||
return found | ||
def _post_visit(self, ret: Iterator[Node | object]) -> list[Node]: | ||
return super()._post_visit(i for i in ret | ||
if i not in (self.SET_FLAG, self.UNSET_FLAG)) | ||
|
||
def visit_Node(self, o, ret=None): | ||
if ret is None: | ||
ret = self.default_retval() | ||
found, flag = ret | ||
def visit_object(self, o: object, flag: bool = False) -> Iterator[Node | object]: | ||
yield self.SET_FLAG if flag else self.UNSET_FLAG | ||
|
||
def visit_tuple(self, o: Sequence[Any], | ||
flag: bool = False) -> Iterator[Node | object]: | ||
for el in o: | ||
for i in self._visit(el, flag=flag): | ||
# New flag state is yielded at the end of child results | ||
if i is self.SET_FLAG: | ||
flag = True | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. just |
||
continue | ||
if i is self.UNSET_FLAG: | ||
flag = False | ||
continue | ||
|
||
# Regular object | ||
yield i | ||
|
||
yield self.SET_FLAG if flag else self.UNSET_FLAG | ||
|
||
visit_list = visit_tuple | ||
|
||
if o is self.start: | ||
flag = True | ||
def visit_Node(self, o: Node, flag: bool = False) -> Iterator[Node | object]: | ||
EdCaunt marked this conversation as resolved.
Show resolved
Hide resolved
|
||
flag = flag or (o is self.start) | ||
|
||
if flag and self.rule(self.match, o): | ||
found.append(o) | ||
for i in o.children: | ||
found, newflag = self._visit(i, ret=(found, flag)) | ||
if flag and not newflag: | ||
return found, newflag | ||
flag = newflag | ||
yield o | ||
|
||
for child in o.children: | ||
for i in self._visit(child, flag=flag): | ||
# New flag state is yielded at the end of child results | ||
if i is self.SET_FLAG: | ||
flag = True | ||
continue | ||
if i is self.UNSET_FLAG: | ||
if flag: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why is this different from what happens in visit_tuple? |
||
yield self.UNSET_FLAG | ||
return | ||
continue | ||
|
||
# Regular object | ||
yield i | ||
|
||
if o is self.stop: | ||
flag = False | ||
flag &= (o is not self.stop) | ||
yield self.SET_FLAG if flag else self.UNSET_FLAG | ||
|
||
return found, flag | ||
|
||
ApplicationType = TypeVar('ApplicationType') | ||
|
||
class FindApplications(Visitor): | ||
|
||
class FindApplications(LazyVisitor[ApplicationType, set[ApplicationType]]): | ||
|
||
""" | ||
Find all SymPy applied functions (aka, `Application`s). The user may refine | ||
the search by supplying a different target class. | ||
""" | ||
|
||
def __init__(self, cls=Application): | ||
def __init__(self, cls: type[ApplicationType] = Application): | ||
super().__init__() | ||
self.match = lambda i: isinstance(i, cls) and not isinstance(i, Basic) | ||
|
||
@classmethod | ||
def default_retval(cls): | ||
return set() | ||
|
||
def visit_object(self, o, **kwargs): | ||
return self.default_retval() | ||
|
||
def visit_tuple(self, o, ret=None): | ||
ret = ret or self.default_retval() | ||
for i in o: | ||
ret.update(self._visit(i, ret=ret)) | ||
return ret | ||
|
||
def visit_Node(self, o, ret=None): | ||
ret = ret or self.default_retval() | ||
for i in o.children: | ||
ret.update(self._visit(i, ret=ret)) | ||
return ret | ||
def _post_visit(self, ret): | ||
return set(ret) | ||
|
||
def visit_Expression(self, o, **kwargs): | ||
return o.expr.find(self.match) | ||
def visit_Expression(self, o: Expression, **kwargs) -> Iterator[ApplicationType]: | ||
yield from o.expr.find(self.match) | ||
|
||
def visit_Iteration(self, o, **kwargs): | ||
ret = self._visit(o.children) or self.default_retval() | ||
ret.update(o.symbolic_min.find(self.match)) | ||
ret.update(o.symbolic_max.find(self.match)) | ||
return ret | ||
def visit_Iteration(self, o: Iteration, **kwargs) -> Iterator[ApplicationType]: | ||
yield from self._visit(o.children) | ||
yield from o.symbolic_min.find(self.match) | ||
yield from o.symbolic_max.find(self.match) | ||
|
||
def visit_Call(self, o, **kwargs): | ||
ret = self.default_retval() | ||
def visit_Call(self, o: Call, **kwargs) -> Iterator[ApplicationType]: | ||
for i in o.arguments: | ||
try: | ||
ret.update(i.find(self.match)) | ||
yield from i.find(self.match) | ||
except (AttributeError, TypeError): | ||
ret.update(self._visit(i, ret=ret)) | ||
return ret | ||
yield from self._visit(i) | ||
|
||
|
||
class IsPerfectIteration(Visitor): | ||
|
Uh oh!
There was an error while loading. Please reload this page.