@@ -258,6 +258,7 @@ def test_decomposition_reused_preserves_check_finite(assume_a, counter):
258
258
"lower_first" , [True , False ], ids = ["lower_first" , "upper_first" ]
259
259
)
260
260
def test_cho_solve_handles_lower_flags (lower_first ):
261
+ rewrite_name = reuse_decomposition_multiple_solves .__name__
261
262
A = tensor ("A" , shape = (2 , None ))
262
263
b = tensor ("b" , shape = (2 ,))
263
264
@@ -268,22 +269,35 @@ def test_cho_solve_handles_lower_flags(lower_first):
268
269
dx2_dA = grad (x2 .sum (), A )
269
270
270
271
fn = function ([A , b ], [x1 , dx1_dA , x2 , dx2_dA ])
272
+ fn_no_rewrite = function (
273
+ [A , b ],
274
+ [x1 , dx1_dA , x2 , dx2_dA ],
275
+ mode = get_default_mode ().excluding (rewrite_name ),
276
+ )
271
277
272
278
rng = np .random .default_rng ()
273
- L_values = rng .normal (size = (2 , 2 ))
279
+ L_values = rng .normal (size = (2 , 2 )). astype ( config . floatX )
274
280
A_values = L_values @ L_values .T # Ensure A is positive definite
275
281
276
282
if lower_first :
277
283
A_values [0 , 1 ] = np .nan
278
284
else :
279
285
A_values [1 , 0 ] = np .nan
280
286
281
- b_values = rng .normal (size = (2 ,))
287
+ b_values = rng .normal (size = (2 ,)). astype ( config . floatX )
282
288
283
289
# This computation should not raise an error, and none of them should be NaN
284
290
res = fn (A_values , b_values )
285
- for x in res :
291
+ expected_res = fn_no_rewrite (A_values , b_values )
292
+
293
+ for x , expected_x in zip (res , expected_res ):
286
294
assert np .isfinite (x ).all ()
295
+ np .testing .assert_allclose (
296
+ x ,
297
+ expected_x ,
298
+ atol = 1e-6 if config .floatX == "float64" else 1e-3 ,
299
+ rtol = 1e-6 if config .floatX == "float64" else 1e-3 ,
300
+ )
287
301
288
302
# If we put the NaN in the wrong place, it should raise an error
289
303
with pytest .raises (np .linalg .LinAlgError ):
0 commit comments