Skip to content

Commit 388e93e

Browse files
More carefully handle lower flag in Solve
1 parent 35444df commit 388e93e

File tree

1 file changed

+24
-13
lines changed

1 file changed

+24
-13
lines changed

pytensor/tensor/_linalg/solve/rewriting.py

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -100,40 +100,43 @@ def find_solve_clients(var, assume_a):
100100
elif isinstance(cl.op, DimShuffle) and cl.op.is_left_expand_dims:
101101
# If it's a left expand_dims, recurse on the output
102102
clients.extend(find_solve_clients(cl.outputs[0], assume_a))
103+
103104
return clients
104105

105106
assume_a = node.op.core_op.assume_a
106107

107108
if assume_a not in allowed_assume_a:
108109
return None
109110

110-
A, _ = get_root_A(node.inputs[0])
111+
root_A, root_A_transposed = get_root_A(node.inputs[0])
111112

112113
# Find Solve using A (or left expand_dims of A)
113114
# TODO: We could handle arbitrary shuffle of the batch dimensions, just need to propagate
114115
# that to the A_decomp outputs
115-
A_solve_clients_and_transpose = [
116-
(client, False) for client in find_solve_clients(A, assume_a)
116+
root_A_solve_clients_and_transpose = [
117+
(client, False) for client in find_solve_clients(root_A, assume_a)
117118
]
118119

119120
# Find Solves using A.T
120-
for cl, _ in fgraph.clients[A]:
121+
for cl, _ in fgraph.clients[root_A]:
121122
if isinstance(cl.op, DimShuffle) and is_matrix_transpose(cl.out):
122123
A_T = cl.out
123-
A_solve_clients_and_transpose.extend(
124+
root_A_solve_clients_and_transpose.extend(
124125
(client, True) for client in find_solve_clients(A_T, assume_a)
125126
)
126127

127-
if not eager and len(A_solve_clients_and_transpose) == 1:
128+
if not eager and len(root_A_solve_clients_and_transpose) == 1:
128129
# If theres' a single use don't do it... unless it's being broadcast in a Blockwise (or we're eager)
129130
# That's a "reuse" inside the inner vectorized loop
130131
batch_ndim = node.op.batch_ndim(node)
131-
(client, _) = A_solve_clients_and_transpose[0]
132-
original_A, b = client.inputs
132+
(client, _) = root_A_solve_clients_and_transpose[0]
133+
134+
A, b = client.inputs
135+
133136
if not any(
134137
a_bcast and not b_bcast
135138
for a_bcast, b_bcast in zip(
136-
original_A.type.broadcastable[:batch_ndim],
139+
A.type.broadcastable[:batch_ndim],
137140
b.type.broadcastable[:batch_ndim],
138141
strict=True,
139142
)
@@ -142,19 +145,27 @@ def find_solve_clients(var, assume_a):
142145

143146
# If any Op had check_finite=True, we also do it for the LU decomposition
144147
check_finite_decomp = False
145-
for client, _ in A_solve_clients_and_transpose:
148+
for client, _ in root_A_solve_clients_and_transpose:
146149
if client.op.core_op.check_finite:
147150
check_finite_decomp = True
148151
break
149152

150-
lower = node.op.core_op.lower
153+
(first_solve, transposed) = root_A_solve_clients_and_transpose[0]
154+
lower = first_solve.op.core_op.lower
155+
if transposed:
156+
lower = not lower
157+
151158
A_decomp = decompose_A(
152-
A, assume_a=assume_a, check_finite=check_finite_decomp, lower=lower
159+
root_A, assume_a=assume_a, check_finite=check_finite_decomp, lower=lower
153160
)
154161

155162
replacements = {}
156-
for client, transposed in A_solve_clients_and_transpose:
163+
for client, transposed in root_A_solve_clients_and_transpose:
157164
_, b = client.inputs
165+
lower = client.op.core_op.lower
166+
if transposed:
167+
lower = not lower
168+
158169
new_x = solve_decomposed_system(
159170
A_decomp,
160171
b,

0 commit comments

Comments
 (0)