    # Case for log1mexp(log(x)) -> log1p(-x)
    # Only valid for 0 <= x <= 1; for x > 1 both sides evaluate to nan
    # (log1mexp of a positive input, log1p of an input below -1), so the
    # ge(x, 0) switch alone is sufficient.
        data_valid[0, 1] = 1  # edge case: log1mexp(log(1)) = log1p(-1) = -inf