Skip to content

Commit b41cfc7

Browse files
Numerically test rewrite
1 parent c10ac29 commit b41cfc7

File tree

1 file changed

+17
-3
lines changed

1 file changed

+17
-3
lines changed

tests/tensor/linalg/test_rewriting.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,7 @@ def test_decomposition_reused_preserves_check_finite(assume_a, counter):
258258
"lower_first", [True, False], ids=["lower_first", "upper_first"]
259259
)
260260
def test_cho_solve_handles_lower_flags(lower_first):
261+
rewrite_name = reuse_decomposition_multiple_solves.__name__
261262
A = tensor("A", shape=(2, None))
262263
b = tensor("b", shape=(2,))
263264

@@ -268,22 +269,35 @@ def test_cho_solve_handles_lower_flags(lower_first):
268269
dx2_dA = grad(x2.sum(), A)
269270

270271
fn = function([A, b], [x1, dx1_dA, x2, dx2_dA])
272+
fn_no_rewrite = function(
273+
[A, b],
274+
[x1, dx1_dA, x2, dx2_dA],
275+
mode=get_default_mode().excluding(rewrite_name),
276+
)
271277

272278
rng = np.random.default_rng()
273-
L_values = rng.normal(size=(2, 2))
279+
L_values = rng.normal(size=(2, 2)).astype(config.floatX)
274280
A_values = L_values @ L_values.T # Ensure A is positive definite
275281

276282
if lower_first:
277283
A_values[0, 1] = np.nan
278284
else:
279285
A_values[1, 0] = np.nan
280286

281-
b_values = rng.normal(size=(2,))
287+
b_values = rng.normal(size=(2,)).astype(config.floatX)
282288

283289
# This computation should not raise an error, and none of them should be NaN
284290
res = fn(A_values, b_values)
285-
for x in res:
291+
expected_res = fn_no_rewrite(A_values, b_values)
292+
293+
for x, expected_x in zip(res, expected_res):
286294
assert np.isfinite(x).all()
295+
np.testing.assert_allclose(
296+
x,
297+
expected_x,
298+
atol=1e-6 if config.floatX == "float64" else 1e-3,
299+
rtol=1e-6 if config.floatX == "float64" else 1e-3,
300+
)
287301

288302
# If we put the NaN in the wrong place, it should raise an error
289303
with pytest.raises(np.linalg.LinAlgError):

0 commit comments

Comments
 (0)