Skip to content

Commit c10ac29

Browse files
Handle lower flags more carefully
1 parent f72d7e5 commit c10ac29

File tree

3 files changed

+65
-15
lines changed

3 files changed

+65
-15
lines changed

pytensor/tensor/_linalg/solve/rewriting.py

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -100,40 +100,43 @@ def find_solve_clients(var, assume_a):
100100
elif isinstance(cl.op, DimShuffle) and cl.op.is_left_expand_dims:
101101
# If it's a left expand_dims, recurse on the output
102102
clients.extend(find_solve_clients(cl.outputs[0], assume_a))
103+
103104
return clients
104105

105106
assume_a = node.op.core_op.assume_a
106107

107108
if assume_a not in allowed_assume_a:
108109
return None
109110

110-
A, _ = get_root_A(node.inputs[0])
111+
root_A, root_A_transposed = get_root_A(node.inputs[0])
111112

112113
# Find Solve using A (or left expand_dims of A)
113114
# TODO: We could handle arbitrary shuffle of the batch dimensions, just need to propagate
114115
# that to the A_decomp outputs
115-
A_solve_clients_and_transpose = [
116-
(client, False) for client in find_solve_clients(A, assume_a)
116+
root_A_solve_clients_and_transpose = [
117+
(client, False) for client in find_solve_clients(root_A, assume_a)
117118
]
118119

119120
# Find Solves using A.T
120-
for cl, _ in fgraph.clients[A]:
121+
for cl, _ in fgraph.clients[root_A]:
121122
if isinstance(cl.op, DimShuffle) and is_matrix_transpose(cl.out):
122123
A_T = cl.out
123-
A_solve_clients_and_transpose.extend(
124+
root_A_solve_clients_and_transpose.extend(
124125
(client, True) for client in find_solve_clients(A_T, assume_a)
125126
)
126127

127-
if not eager and len(A_solve_clients_and_transpose) == 1:
128+
if not eager and len(root_A_solve_clients_and_transpose) == 1:
128129
# If there's a single use, don't do it... unless it's being broadcast in a Blockwise (or we're eager)
129130
# That's a "reuse" inside the inner vectorized loop
130131
batch_ndim = node.op.batch_ndim(node)
131-
(client, _) = A_solve_clients_and_transpose[0]
132-
original_A, b = client.inputs
132+
(client, _) = root_A_solve_clients_and_transpose[0]
133+
134+
A, b = client.inputs
135+
133136
if not any(
134137
a_bcast and not b_bcast
135138
for a_bcast, b_bcast in zip(
136-
original_A.type.broadcastable[:batch_ndim],
139+
A.type.broadcastable[:batch_ndim],
137140
b.type.broadcastable[:batch_ndim],
138141
strict=True,
139142
)
@@ -142,19 +145,27 @@ def find_solve_clients(var, assume_a):
142145

143146
# If any Op had check_finite=True, we also do it for the LU decomposition
144147
check_finite_decomp = False
145-
for client, _ in A_solve_clients_and_transpose:
148+
for client, _ in root_A_solve_clients_and_transpose:
146149
if client.op.core_op.check_finite:
147150
check_finite_decomp = True
148151
break
149152

150-
lower = node.op.core_op.lower
153+
(first_solve, transposed) = root_A_solve_clients_and_transpose[0]
154+
lower = first_solve.op.core_op.lower
155+
if transposed:
156+
lower = not lower
157+
151158
A_decomp = decompose_A(
152-
A, assume_a=assume_a, check_finite=check_finite_decomp, lower=lower
159+
root_A, assume_a=assume_a, check_finite=check_finite_decomp, lower=lower
153160
)
154161

155162
replacements = {}
156-
for client, transposed in A_solve_clients_and_transpose:
163+
for client, transposed in root_A_solve_clients_and_transpose:
157164
_, b = client.inputs
165+
lower = client.op.core_op.lower
166+
if transposed:
167+
lower = not lower
168+
158169
new_x = solve_decomposed_system(
159170
A_decomp,
160171
b,

pytensor/tensor/slinalg.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,9 @@ def cholesky(
201201
202202
"""
203203

204-
return Blockwise(Cholesky(lower=lower, on_error=on_error))(x)
204+
return Blockwise(
205+
Cholesky(lower=lower, on_error=on_error, check_finite=check_finite)
206+
)(x)
205207

206208

207209
class SolveBase(Op):

tests/tensor/linalg/test_rewriting.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,14 +112,15 @@ def test_lu_decomposition_reused_forward_and_gradient(assume_a, counter, transpo
112112
np.testing.assert_allclose(resg0, resg1, rtol=rtol)
113113

114114

115-
@pytest.mark.parametrize("transposed", (False, True))
115+
@pytest.mark.parametrize("transposed", (False, True), ids=["no_trans", "trans"])
116116
@pytest.mark.parametrize(
117117
"assume_a, counter",
118118
(
119119
("gen", LUOpCounter),
120120
("tridiagonal", TriDiagLUOpCounter),
121121
("pos", CholeskyOpCounter),
122122
),
123+
ids=["assume_gen", "assume_tridiagonal", "assume_pos"],
123124
)
124125
def test_lu_decomposition_reused_blockwise(assume_a, counter, transposed):
125126
rewrite_name = reuse_decomposition_multiple_solves.__name__
@@ -251,3 +252,39 @@ def test_decomposition_reused_preserves_check_finite(assume_a, counter):
251252
assert fn_opt(A_valid, b1_valid * np.nan, b2_valid)
252253
with pytest.raises(ValueError, match="array must not contain infs or NaNs"):
253254
assert fn_opt(A_valid * np.nan, b1_valid, b2_valid)
255+
256+
257+
@pytest.mark.parametrize(
258+
"lower_first", [True, False], ids=["lower_first", "upper_first"]
259+
)
260+
def test_cho_solve_handles_lower_flags(lower_first):
261+
A = tensor("A", shape=(2, None))
262+
b = tensor("b", shape=(2,))
263+
264+
x1 = solve(A, b, assume_a="pos", lower=lower_first, check_finite=False)
265+
x2 = solve(A.mT, b, assume_a="pos", lower=not lower_first, check_finite=False)
266+
267+
dx1_dA = grad(x1.sum(), A)
268+
dx2_dA = grad(x2.sum(), A)
269+
270+
fn = function([A, b], [x1, dx1_dA, x2, dx2_dA])
271+
272+
rng = np.random.default_rng()
273+
L_values = rng.normal(size=(2, 2))
274+
A_values = L_values @ L_values.T # Ensure A is positive definite
275+
276+
if lower_first:
277+
A_values[0, 1] = np.nan
278+
else:
279+
A_values[1, 0] = np.nan
280+
281+
b_values = rng.normal(size=(2,))
282+
283+
# This computation should not raise an error, and none of the results should contain NaN
284+
res = fn(A_values, b_values)
285+
for x in res:
286+
assert np.isfinite(x).all()
287+
288+
# If we put the NaN in the wrong place, it should raise an error
289+
with pytest.raises(np.linalg.LinAlgError):
290+
fn(A_values.T, b_values)

0 commit comments

Comments
 (0)