Skip to content

Commit fbf0d32

Browse files
committed
Better lazy search
1 parent 0450647 commit fbf0d32

File tree

1 file changed

+32
-36
lines changed

1 file changed

+32
-36
lines changed

devito/symbolics/search.py

Lines changed: 32 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from collections.abc import Callable, Iterable, Iterator
2-
from typing import Literal
2+
from itertools import chain
3+
from typing import Any, Literal
34

45
import sympy
56

@@ -11,20 +12,21 @@
1112
'retrieve_terminals', 'retrieve_symbols', 'retrieve_dimensions',
1213
'retrieve_derivatives', 'search']
1314

14-
class Set(set[sympy.Basic]):
15+
16+
class Set(set):
1517

1618
@staticmethod
17-
def wrap(obj: sympy.Basic) -> set[sympy.Basic]:
19+
def wrap(obj) -> set:
1820
return {obj}
1921

2022

21-
class List(list[sympy.Basic]):
23+
class List(list):
2224

2325
@staticmethod
24-
def wrap(obj: sympy.Basic) -> list[sympy.Basic]:
26+
def wrap(obj) -> list:
2527
return [obj]
2628

27-
def update(self, obj: sympy.Basic) -> None:
29+
def update(self, obj: Iterable[Any]) -> None:
2830
self.extend(obj)
2931

3032

@@ -35,48 +37,42 @@ def update(self, obj: sympy.Basic) -> None:
3537

3638

3739
class Search:
38-
39-
def __init__(self, query: Callable[[sympy.Basic], bool],
40-
order: Literal['postorder', 'preorder'], deep: bool = False) -> None:
40+
def __init__(self, query: Callable[[Any], bool], deep: bool = False) -> None:
4141
"""
42-
Search objects in an expression. This is much quicker than the more
43-
general SymPy's find.
42+
Search objects in an expression. This is much quicker than the more general
43+
SymPy's find.
4444
4545
Parameters
4646
----------
4747
query
4848
Any query from :mod:`queries`.
49-
order : str
50-
Either `preorder` or `postorder`, for the search order.
5149
deep : bool, optional
5250
If True, propagate the search within an Indexed's indices. Defaults to False.
5351
"""
5452
self.query = query
55-
self.order = order
5653
self.deep = deep
5754

58-
def _next(self, expr) -> Iterator[sympy.Basic]:
55+
def _next(self, expr) -> Iterator[Any]:
5956
if self.deep and expr.is_Indexed:
6057
yield from expr.indices
6158
elif not q_leaf(expr):
6259
yield from expr.args
6360

64-
def visit(self, expr: sympy.Basic) -> Iterator[sympy.Basic]:
65-
"""Visit the expression in the specified order."""
66-
if self.order == 'preorder':
67-
if self.query(expr):
68-
yield expr
69-
for child in self._next(expr):
70-
yield from self.visit(child)
71-
else:
72-
for child in self._next(expr):
73-
yield from self.visit(child)
74-
if self.query(expr):
75-
yield expr
61+
def visit_preorder(self, expr) -> Iterator[Any]:
62+
if self.query(expr):
63+
yield expr
64+
for i in self._next(expr):
65+
yield from self.visit_preorder(i)
7666

67+
def visit_postorder(self, expr) -> Iterator[Any]:
68+
for i in self._next(expr):
69+
yield from self.visit_postorder(i)
70+
if self.query(expr):
71+
yield expr
7772

78-
def search(exprs: sympy.Basic | Iterable[sympy.Basic],
79-
query: type | Callable[[sympy.Basic], bool],
73+
74+
def search(exprs,
75+
query: type | Callable[[Any], bool],
8076
mode: Literal['all', 'unique'] = 'unique',
8177
visit: Literal['dfs', 'bfs', 'bfs_first_hit'] = 'dfs',
8278
deep: bool = False) -> List | Set:
@@ -92,21 +88,21 @@ def search(exprs: sympy.Basic | Iterable[sympy.Basic],
9288

9389
# Search doesn't actually use a BFS (rather, a preorder DFS), but the terminology
9490
# is retained in this function's parameters for backwards compatibility
95-
order = 'postorder' if visit == 'dfs' else 'preorder'
96-
searcher = Search(Q, order, deep)
91+
searcher = Search(Q, deep)
92+
_visit = searcher.visit_postorder if visit == 'dfs' else searcher.visit_preorder
9793

9894
Collection = modes[mode]
9995
found = Collection()
10096
for e in as_tuple(exprs):
10197
if not isinstance(e, sympy.Basic):
10298
continue
10399

104-
for i in searcher.visit(e):
105-
found.update(Collection.wrap(i))
106-
107-
if visit == 'bfs_first_hit':
108-
# Stop at the first hit for this outer expression
100+
if visit == 'bfs_first_hit':
101+
for i in _visit(e):
102+
found.update(Collection.wrap(i))
109103
break
104+
else:
105+
found.update(_visit(e))
110106

111107
return found
112108

0 commit comments

Comments
 (0)