1
1
from collections .abc import Callable , Iterable , Iterator
2
- from typing import Literal
2
+ from itertools import chain
3
+ from typing import Any , Literal
3
4
4
5
import sympy
5
6
11
12
'retrieve_terminals' , 'retrieve_symbols' , 'retrieve_dimensions' ,
12
13
'retrieve_derivatives' , 'search' ]
13
14
14
- class Set (set [sympy .Basic ]):
15
+
16
+ class Set (set ):
15
17
16
18
@staticmethod
17
- def wrap (obj : sympy . Basic ) -> set [ sympy . Basic ] :
19
+ def wrap (obj ) -> set :
18
20
return {obj }
19
21
20
22
21
- class List (list [ sympy . Basic ] ):
23
+ class List (list ):
22
24
23
25
@staticmethod
24
- def wrap (obj : sympy . Basic ) -> list [ sympy . Basic ] :
26
+ def wrap (obj ) -> list :
25
27
return [obj ]
26
28
27
- def update (self , obj : sympy . Basic ) -> None :
29
+ def update (self , obj : Iterable [ Any ] ) -> None :
28
30
self .extend (obj )
29
31
30
32
@@ -35,48 +37,42 @@ def update(self, obj: sympy.Basic) -> None:
35
37
36
38
37
39
class Search :
38
-
39
- def __init__ (self , query : Callable [[sympy .Basic ], bool ],
40
- order : Literal ['postorder' , 'preorder' ], deep : bool = False ) -> None :
40
+ def __init__ (self , query : Callable [[Any ], bool ], deep : bool = False ) -> None :
41
41
"""
42
- Search objects in an expression. This is much quicker than the more
43
- general SymPy's find.
42
+ Search objects in an expression. This is much quicker than the more general
43
+ SymPy's find.
44
44
45
45
Parameters
46
46
----------
47
47
query
48
48
Any query from :mod:`queries`.
49
- order : str
50
- Either `preorder` or `postorder`, for the search order.
51
49
deep : bool, optional
52
50
If True, propagate the search within an Indexed's indices. Defaults to False.
53
51
"""
54
52
self .query = query
55
- self .order = order
56
53
self .deep = deep
57
54
58
- def _next (self , expr ) -> Iterator [sympy . Basic ]:
55
+ def _next (self , expr ) -> Iterator [Any ]:
59
56
if self .deep and expr .is_Indexed :
60
57
yield from expr .indices
61
58
elif not q_leaf (expr ):
62
59
yield from expr .args
63
60
64
- def visit (self , expr : sympy .Basic ) -> Iterator [sympy .Basic ]:
65
- """Visit the expression in the specified order."""
66
- if self .order == 'preorder' :
67
- if self .query (expr ):
68
- yield expr
69
- for child in self ._next (expr ):
70
- yield from self .visit (child )
71
- else :
72
- for child in self ._next (expr ):
73
- yield from self .visit (child )
74
- if self .query (expr ):
75
- yield expr
61
+ def visit_preorder (self , expr ) -> Iterator [Any ]:
62
+ if self .query (expr ):
63
+ yield expr
64
+ for i in self ._next (expr ):
65
+ yield from self .visit_preorder (i )
76
66
67
+ def visit_postorder (self , expr ) -> Iterator [Any ]:
68
+ for i in self ._next (expr ):
69
+ yield from self .visit_postorder (i )
70
+ if self .query (expr ):
71
+ yield expr
77
72
78
- def search (exprs : sympy .Basic | Iterable [sympy .Basic ],
79
- query : type | Callable [[sympy .Basic ], bool ],
73
+
74
+ def search (exprs ,
75
+ query : type | Callable [[Any ], bool ],
80
76
mode : Literal ['all' , 'unique' ] = 'unique' ,
81
77
visit : Literal ['dfs' , 'bfs' , 'bfs_first_hit' ] = 'dfs' ,
82
78
deep : bool = False ) -> List | Set :
@@ -92,21 +88,21 @@ def search(exprs: sympy.Basic | Iterable[sympy.Basic],
92
88
93
89
# Search doesn't actually use a BFS (rather, a preorder DFS), but the terminology
94
90
# is retained in this function's parameters for backwards compatibility
95
- order = 'postorder' if visit == 'dfs' else 'preorder'
96
- searcher = Search ( Q , order , deep )
91
+ searcher = Search ( Q , deep )
92
+ _visit = searcher . visit_postorder if visit == 'dfs' else searcher . visit_preorder
97
93
98
94
Collection = modes [mode ]
99
95
found = Collection ()
100
96
for e in as_tuple (exprs ):
101
97
if not isinstance (e , sympy .Basic ):
102
98
continue
103
99
104
- for i in searcher .visit (e ):
105
- found .update (Collection .wrap (i ))
106
-
107
- if visit == 'bfs_first_hit' :
108
- # Stop at the first hit for this outer expression
100
+ if visit == 'bfs_first_hit' :
101
+ for i in _visit (e ):
102
+ found .update (Collection .wrap (i ))
109
103
break
104
+ else :
105
+ found .update (_visit (e ))
110
106
111
107
return found
112
108
0 commit comments