diff --git a/pytensor/link/pytorch/dispatch/elemwise.py b/pytensor/link/pytorch/dispatch/elemwise.py
index 79ca5beec1..c22945d914 100644
--- a/pytensor/link/pytorch/dispatch/elemwise.py
+++ b/pytensor/link/pytorch/dispatch/elemwise.py
@@ -3,6 +3,7 @@ import torch
 
 from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
+from pytensor.scalar import ScalarLoop
 from pytensor.tensor.elemwise import DimShuffle, Elemwise
 from pytensor.tensor.math import All, Any, Max, Min, Prod, Sum
 from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
 
@@ -11,6 +12,7 @@
 
 @pytorch_funcify.register(Elemwise)
 def pytorch_funcify_Elemwise(op, node, **kwargs):
     scalar_op = op.scalar_op
+    base_fn = pytorch_funcify(scalar_op, node=node, **kwargs)
 
     def check_special_scipy(func_name):
@@ -33,6 +35,9 @@ def elemwise_fn(*inputs):
             Elemwise._check_runtime_broadcast(node, inputs)
             return base_fn(*inputs)
 
+    elif isinstance(scalar_op, ScalarLoop):
+        return elemwise_ravel_fn(base_fn, op, node, **kwargs)
+
     else:
 
         def elemwise_fn(*inputs):
@@ -176,3 +181,37 @@ def softmax_grad(dy, sm):
         return dy_times_sm - torch.sum(dy_times_sm, dim=axis, keepdim=True) * sm
 
     return softmax_grad
+
+
+def elemwise_ravel_fn(base_fn, op, node, **kwargs):
+    """
+    Dispatching the scalar op via `.item()` (as ScalarLoop + Elemwise does) is common,
+    but torch's vmap has a limitation with it: https://github.com/pymc-devs/pytensor/issues/1031.
+    Instead, broadcast the inputs with torch, ravel them, and loop over the elements.
+    """
+
+    n_outputs = len(node.outputs)
+
+    def elemwise_fn(*inputs):
+        bcasted_inputs = torch.broadcast_tensors(*inputs)
+        raveled_inputs = [inp.ravel() for inp in bcasted_inputs]
+
+        out_shape = bcasted_inputs[0].size()
+        out_size = out_shape.numel()
+        raveled_outputs = [torch.empty(out_size) for _ in node.outputs]
+
+        for i in range(out_size):
+            core_outs = base_fn(*(inp[i] for inp in raveled_inputs))
+            if n_outputs == 1:
+                raveled_outputs[0][i] = core_outs
+            else:
+                for o in range(n_outputs):
+                    raveled_outputs[o][i] = core_outs[o]
+
+        outputs = tuple(out.view(out_shape) for out in raveled_outputs)
+        if n_outputs == 1:
+            return outputs[0]
+        else:
+            return outputs
+
+    return elemwise_fn
diff --git a/pytensor/link/pytorch/dispatch/scalar.py b/pytensor/link/pytorch/dispatch/scalar.py
index 1416e58f55..65170b1f53 100644
--- a/pytensor/link/pytorch/dispatch/scalar.py
+++ b/pytensor/link/pytorch/dispatch/scalar.py
@@ -7,6 +7,7 @@
     Cast,
     ScalarOp,
 )
+from pytensor.scalar.loop import ScalarLoop
 from pytensor.scalar.math import Softplus
 
 
@@ -62,3 +63,37 @@ def cast(x):
 @pytorch_funcify.register(Softplus)
 def pytorch_funcify_Softplus(op, node, **kwargs):
     return torch.nn.Softplus()
+
+
+@pytorch_funcify.register(ScalarLoop)
+def pytorch_funcify_ScalarLoop(op, node, **kwargs):
+    update = pytorch_funcify(op.fgraph, **kwargs)
+    state_length = op.nout
+    if op.is_while:
+
+        def scalar_loop(steps, *start_and_constants):
+            carry, constants = (
+                start_and_constants[:state_length],
+                start_and_constants[state_length:],
+            )
+            done = True
+            for _ in range(steps):
+                *carry, done = update(*carry, *constants)
+                if torch.any(done):
+                    break
+            return *carry, done
+
+    else:
+
+        def scalar_loop(steps, *start_and_constants):
+            carry, constants = (
+                start_and_constants[:state_length],
+                start_and_constants[state_length:],
+            )
+            for _ in range(steps):
+                carry = update(*carry, *constants)
+            if len(node.outputs) == 1:
+                return carry[0]
+            else:
+                return carry
+
+    return scalar_loop
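Note: the broadcast-ravel-loop strategy used by `elemwise_ravel_fn` above, shown standalone. This is a minimal sketch, not part of the patch; `toy_step` and the shapes are made up, and it assumes a single 0-d carry, the way the compiled scalar function receives one element at a time.

import torch

def toy_step(x):
    # Stand-in for base_fn: 0-d tensor in, 0-d tensor out.
    return x * 2

a = torch.arange(6.0).reshape(3, 2)  # shape (3, 2)
b = torch.tensor([10.0, 20.0])       # shape (2,), broadcasts against `a`

bcast = torch.broadcast_tensors(a, b)  # both now (3, 2)
flat = [t.ravel() for t in bcast]      # 1-d tensors for element indexing

out_shape = bcast[0].size()
out = torch.empty(out_shape.numel())
for i in range(out_shape.numel()):
    # Each call sees 0-d tensors, so `.item()`-style scalar code is safe
    # here, which is exactly what torch's vmap could not handle.
    out[i] = toy_step(flat[0][i] + flat[1][i])

result = out.view(out_shape)  # restore the broadcast shape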
diff --git a/pytensor/link/pytorch/linker.py b/pytensor/link/pytorch/linker.py
index ac0b0c8c02..d47aa43dda 100644
--- a/pytensor/link/pytorch/linker.py
+++ b/pytensor/link/pytorch/linker.py
@@ -54,21 +54,22 @@ def __init__(self, fn, gen_functors):
         self.fn = torch.compile(fn)
         self.gen_functors = gen_functors.copy()
 
-    def __call__(self, *args, **kwargs):
+    def __call__(self, *inputs, **kwargs):
         import pytensor.link.utils
 
         # set attrs
         for n, fn in self.gen_functors:
             setattr(pytensor.link.utils, n[1:], fn)
 
-        res = self.fn(*args, **kwargs)
+        # Torch does not accept numpy inputs and may return GPU objects
+        outs = self.fn(*(pytorch_typify(inp) for inp in inputs), **kwargs)
 
         # unset attrs
         for n, _ in self.gen_functors:
             if getattr(pytensor.link.utils, n[1:], False):
                 delattr(pytensor.link.utils, n[1:])
 
-        return res
+        return tuple(out.cpu().numpy() for out in outs)
 
     def __del__(self):
         del self.gen_functors
@@ -76,12 +77,7 @@ def __del__(self):
         inner_fn = wrapper(fn, self.gen_functors)
         self.gen_functors = []
 
-        # Torch does not accept numpy inputs and may return GPU objects
-        def fn(*inputs, inner_fn=inner_fn):
-            outs = inner_fn(*(pytorch_typify(inp) for inp in inputs))
-            return tuple(out.cpu().numpy() for out in outs)
-
-        return fn
+        return inner_fn
 
     def create_thunk_inputs(self, storage_map):
         thunk_inputs = []
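Note: the linker change moves the numpy/torch conversion from an extra closure into the wrapper's own `__call__`, so `torch.compile` only ever sees torch tensors and callers only ever see numpy arrays. The pattern, sketched standalone with `torch.as_tensor` standing in for `pytorch_typify` (names here are illustrative, not the library's API):

import numpy as np
import torch

def make_wrapper(fn):
    compiled = torch.compile(fn)

    def wrapper(*inputs):
        # numpy in -> torch in; torch out -> numpy out (also moves
        # GPU-resident outputs back to host via .cpu()).
        outs = compiled(*(torch.as_tensor(i) for i in inputs))
        return tuple(o.cpu().numpy() for o in outs)

    return wrapper

f = make_wrapper(lambda a, b: (a + b, a * b))
print(f(np.ones(3), np.arange(3.0)))  # (array([1., 2., 3.]), array([0., 1., 2.]))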
diff --git a/tests/link/pytorch/test_basic.py b/tests/link/pytorch/test_basic.py
index d7e2aef47b..2ac8ee7c3b 100644
--- a/tests/link/pytorch/test_basic.py
+++ b/tests/link/pytorch/test_basic.py
@@ -4,6 +4,7 @@
 import numpy as np
 import pytest
 
+import pytensor.tensor as pt
 import pytensor.tensor.basic as ptb
 from pytensor.compile.builders import OpFromGraph
 from pytensor.compile.function import function
@@ -17,7 +18,10 @@
 from pytensor.ifelse import ifelse
 from pytensor.link.pytorch.linker import PytorchLinker
 from pytensor.raise_op import CheckAndRaise
+from pytensor.scalar import float64, int64
+from pytensor.scalar.loop import ScalarLoop
 from pytensor.tensor import alloc, arange, as_tensor, empty, expit, eye, softplus
+from pytensor.tensor.elemwise import Elemwise
 from pytensor.tensor.type import matrices, matrix, scalar, vector
 
 
@@ -385,3 +389,85 @@ def test_pytorch_softplus():
     out = softplus(x)
     f = FunctionGraph([x], [out])
     compare_pytorch_and_py(f, [np.random.rand(3)])
+
+
+def test_ScalarLoop():
+    n_steps = int64("n_steps")
+    x0 = float64("x0")
+    const = float64("const")
+    x = x0 + const
+
+    op = ScalarLoop(init=[x0], constant=[const], update=[x])
+    x = op(n_steps, x0, const)
+
+    fn = function([n_steps, x0, const], x, mode=pytorch_mode)
+    np.testing.assert_allclose(fn(5, 0, 1), 5)
+    np.testing.assert_allclose(fn(5, 0, 2), 10)
+    np.testing.assert_allclose(fn(4, 3, -1), -1)
+
+
+def test_ScalarLoop_while():
+    n_steps = int64("n_steps")
+    x0 = float64("x0")
+    x = x0 + 1
+    until = x >= 10
+
+    op = ScalarLoop(init=[x0], update=[x], until=until)
+    fn = function([n_steps, x0], op(n_steps, x0), mode=pytorch_mode)
+    for res, expected in zip(
+        [fn(n_steps=20, x0=0), fn(n_steps=20, x0=1), fn(n_steps=5, x0=1)],
+        [[10, True], [10, True], [6, False]],
+        strict=True,
+    ):
+        np.testing.assert_allclose(res[0], np.array(expected[0]))
+        np.testing.assert_allclose(res[1], np.array(expected[1]))
+
+
+def test_ScalarLoop_Elemwise_single_carries():
+    n_steps = int64("n_steps")
+    x0 = float64("x0")
+    x = x0 * 2
+    until = x >= 10
+
+    scalarop = ScalarLoop(init=[x0], update=[x], until=until)
+    op = Elemwise(scalarop)
+
+    n_steps = pt.scalar("n_steps", dtype="int32")
+    x0 = pt.vector("x0", dtype="float32")
+    state, done = op(n_steps, x0)
+
+    f = FunctionGraph([n_steps, x0], [state, done])
+    args = [
+        np.array(10).astype("int32"),
+        np.arange(0, 5).astype("float32"),
+    ]
+    compare_pytorch_and_py(
+        f, args, assert_fn=partial(np.testing.assert_allclose, rtol=1e-6)
+    )
+
+
+def test_ScalarLoop_Elemwise_multi_carries():
+    n_steps = int64("n_steps")
+    x0 = float64("x0")
+    x1 = float64("x1")
+    x = x0 * 2
+    x1_n = x1 * 3
+    until = x >= 10
+
+    scalarop = ScalarLoop(init=[x0, x1], update=[x, x1_n], until=until)
+    op = Elemwise(scalarop)
+
+    n_steps = pt.scalar("n_steps", dtype="int32")
+    x0 = pt.vector("x0", dtype="float32")
+    x1 = pt.tensor("c0", dtype="float32", shape=(7, 3, 1))
+    *states, done = op(n_steps, x0, x1)
+
+    f = FunctionGraph([n_steps, x0, x1], [*states, done])
+    args = [
+        np.array(10).astype("int32"),
+        np.arange(0, 5).astype("float32"),
+        np.random.rand(7, 3, 1).astype("float32"),
+    ]
+    compare_pytorch_and_py(
+        f, args, assert_fn=partial(np.testing.assert_allclose, rtol=1e-6)
+    )
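Note: end to end, the patch makes graphs like the ones in these tests runnable on the torch backend. A minimal usage sketch mirroring `test_ScalarLoop_Elemwise_single_carries`, assuming the registered "PYTORCH" mode name is available in your build:

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.scalar import float64
from pytensor.scalar.loop import ScalarLoop
from pytensor.tensor.elemwise import Elemwise

# Scalar recurrence x <- 2 * x, stopping early once x >= 10.
x0 = float64("x0")
x = x0 * 2
loop = ScalarLoop(init=[x0], update=[x], until=x >= 10)

n_steps = pt.scalar("n_steps", dtype="int32")
xs = pt.vector("xs", dtype="float32")
state, done = Elemwise(loop)(n_steps, xs)

fn = pytensor.function([n_steps, xs], [state, done], mode="PYTORCH")
print(fn(10, np.arange(1, 4, dtype="float32")))  # per-element loop results and done flags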