From b14baead540e40d2094950fbac2f1aa7001bfdb8 Mon Sep 17 00:00:00 2001 From: HangenYuu Date: Mon, 15 Jul 2024 16:44:42 +0800 Subject: [PATCH 1/3] Minor edits to contributing documentation --- doc/dev_start_guide.rst | 5 +-- doc/extending/creating_a_numba_jax_op.rst | 38 ++++++----------------- doc/internal/metadocumentation.rst | 28 +---------------- 3 files changed, 13 insertions(+), 58 deletions(-) diff --git a/doc/dev_start_guide.rst b/doc/dev_start_guide.rst index 010d0ffb75..6c5f5f581b 100644 --- a/doc/dev_start_guide.rst +++ b/doc/dev_start_guide.rst @@ -209,7 +209,8 @@ You can now build the documentation from the root of the project with: .. code-block:: bash - python -m sphinx -b html ./doc ./html + # -j for parallel and faster doc build + sphinx-build -b html ./doc ./html -j auto Afterward, you can go to `html/index.html` and navigate the changes in a browser. One way to do this is to go to the `html` directory and run: @@ -219,7 +220,7 @@ Afterward, you can go to `html/index.html` and navigate the changes in a browser python -m http.server -**Do not commit the `html` directory. The documentation is built automatically.** +**Do not commit the `html` directory.** For more documentation customizations such as different formats e.g., PDF, refer to the `Sphinx documentation `_. Other tools that might help diff --git a/doc/extending/creating_a_numba_jax_op.rst b/doc/extending/creating_a_numba_jax_op.rst index 42c7304b5c..b84a2001f7 100644 --- a/doc/extending/creating_a_numba_jax_op.rst +++ b/doc/extending/creating_a_numba_jax_op.rst @@ -1,5 +1,5 @@ Adding JAX, Numba and Pytorch support for `Op`\s -======================================= +================================================ PyTensor is able to convert its graphs into JAX, Numba and Pytorch compiled functions. In order to do this, each :class:`Op` in an PyTensor graph must have an equivalent JAX/Numba/Pytorch implementation function. @@ -7,7 +7,7 @@ this, each :class:`Op` in an PyTensor graph must have an equivalent JAX/Numba/Py This tutorial will explain how JAX, Numba and Pytorch implementations are created for an :class:`Op`. Step 1: Identify the PyTensor :class:`Op` you'd like to implement ------------------------------------------------------------------------- +----------------------------------------------------------------- Find the source for the PyTensor :class:`Op` you'd like to be supported and identify the function signature and return values. These can be determined by @@ -97,8 +97,8 @@ Next, we look at the :meth:`Op.perform` implementation to see exactly how the inputs and outputs are used to compute the outputs for an :class:`Op` in Python. This method is effectively what needs to be implemented. -Step 2: Find the relevant method in JAX/Numba/Pytorch (or something close) ---------------------------------------------------------- +Step 2: Find the relevant or close method in JAX/Numba/Pytorch +-------------------------------------------------------------- With a precise idea of what the PyTensor :class:`Op` does we need to figure out how to implement it in JAX, Numba or Pytorch. In the best case scenario, there is a similarly named @@ -269,7 +269,7 @@ and :func:`torch.cumprod` z[0] = np.cumprod(x, axis=self.axis) Step 3: Register the function with the respective dispatcher ---------------------------------------------------------------- +------------------------------------------------------------ With the PyTensor `Op` replicated, we'll need to register the function with the backends `Linker`. 
This is done through the use of @@ -626,28 +626,8 @@ Step 4: Write tests Note ---- -In out previous example of extending JAX, :class:`Eye`\ :class:`Op` was used with the test function as follows: - -.. code:: python - - def test_jax_Eye(): - """Test JAX conversion of the `Eye` `Op`.""" - - # Create a symbolic input for `Eye` - x_at = pt.scalar() - - # Create a variable that is the output of an `Eye` `Op` - eye_var = pt.eye(x_at) - - # Create an PyTensor `FunctionGraph` - out_fg = FunctionGraph(outputs=[eye_var]) - - # Pass the graph and any inputs to the testing function - compare_jax_and_py(out_fg, [3]) - -This one nowadays leads to a test failure due to new restrictions in JAX + JIT, -as reported in issue `#654 `_. -All jitted functions now must have constant shape, which means a graph like the +Due to new restrictions in JAX JIT as reported in issue `#654 `_, +all jitted functions now must have constant shape. In other words, only PyTensor graphs with static shapes +can be translated to JAX at the moment. It means a graph like the one of :class:`Eye` can never be translated to JAX, since it's fundamentally a -function with dynamic shapes. In other words, only PyTensor graphs with static shapes -can be translated to JAX at the moment. \ No newline at end of file +function with dynamic shapes. \ No newline at end of file diff --git a/doc/internal/metadocumentation.rst b/doc/internal/metadocumentation.rst index e5c2be28de..2dc7ad5c65 100644 --- a/doc/internal/metadocumentation.rst +++ b/doc/internal/metadocumentation.rst @@ -8,33 +8,7 @@ Documentation Documentation AKA Meta-Documentation How to build documentation -------------------------- -Let's say you are writing documentation, and want to see the `sphinx -`__ output before you push it. -The documentation will be generated in the ``html`` directory. - -.. code-block:: bash - - cd PyTensor/ - python ./doc/scripts/docgen.py - -If you don't want to generate the pdf, do the following: - -.. code-block:: bash - - cd PyTensor/ - python ./doc/scripts/docgen.py --nopdf - - -For more details: - -.. code-block:: bash - - $ python doc/scripts/docgen.py --help - Usage: doc/scripts/docgen.py [OPTIONS] - -o : output the html files in the specified dir - --rst: only compile the doc (requires sphinx) - --nopdf: do not produce a PDF file from the doc, only HTML - --help: this help +Refer to `relevant section of Developer Start Guide `_. Use ReST for documentation -------------------------- From b85397cfe9c0515c5d57c1947227567d2f00b4b9 Mon Sep 17 00:00:00 2001 From: HangenYuu Date: Wed, 17 Jul 2024 09:47:07 +0800 Subject: [PATCH 2/3] Edited issues pointed by Ricardo and Thomas --- doc/extending/creating_a_numba_jax_op.rst | 36 +++++++++++++++++++---- doc/internal/metadocumentation.rst | 2 +- 2 files changed, 32 insertions(+), 6 deletions(-) diff --git a/doc/extending/creating_a_numba_jax_op.rst b/doc/extending/creating_a_numba_jax_op.rst index b84a2001f7..16a6d8830c 100644 --- a/doc/extending/creating_a_numba_jax_op.rst +++ b/doc/extending/creating_a_numba_jax_op.rst @@ -626,8 +626,34 @@ Step 4: Write tests Note ---- -Due to new restrictions in JAX JIT as reported in issue `#654 `_, -all jitted functions now must have constant shape. In other words, only PyTensor graphs with static shapes -can be translated to JAX at the moment. It means a graph like the -one of :class:`Eye` can never be translated to JAX, since it's fundamentally a -function with dynamic shapes. 
\ No newline at end of file
+Due to restrictions in JAX JIT as reported in issue `#654 <https://github.com/pymc-devs/pytensor/issues/654>`_,
+all jitted functions must have constant shape. In other words, only PyTensor graphs with static shapes
+can be translated to JAX at the moment. It means a graph like the old test function for :class:`Eye` `Op`
+
+.. code:: python
+
+    def test_jax_eye():
+        # Create a symbolic input for `Eye`
+        x_at = pt.scalar(dtype=np.int64)
+
+        # Create a variable that is the output of an `Eye` `Op`
+        eye_var = pt.eye(x_at)
+
+        # Create an PyTensor `FunctionGraph`
+        out_fg = FunctionGraph(outputs=[eye_var])
+
+        # Pass the graph and any inputs to the testing function
+        compare_jax_and_py(out_fg, [3])
+
+cannot be translated to JAX, since it involved a function with dynamic shapes.
+That's why the JAX link test for :class:`Eye` `Op` is now
+.. code:: python
+
+    def test_jax_eye():
+        """Tests jaxification of the Eye operator"""
+        out = ptb.eye(3)
+        out_fg = FunctionGraph([], [out])
+
+        compare_jax_and_py(out_fg, [])
+
+with the shape specified explicitly instead of via a variable.
\ No newline at end of file
diff --git a/doc/internal/metadocumentation.rst b/doc/internal/metadocumentation.rst
index 2dc7ad5c65..a618c1e4ed 100644
--- a/doc/internal/metadocumentation.rst
+++ b/doc/internal/metadocumentation.rst
@@ -8,7 +8,7 @@ Documentation Documentation AKA Meta-Documentation
 
 How to build documentation
 --------------------------
-Refer to `relevant section of Developer Start Guide <https://pytensor.readthedocs.io/en/latest/dev_start_guide.html#building-documentation>`_.
+Refer to the relevant section of :doc:`../dev_start_guide`.
 
 Use ReST for documentation
 --------------------------

From e087ac049094e91aa4dcdf75f8b099aedf68f88a Mon Sep 17 00:00:00 2001
From: HangenYuu
Date: Wed, 17 Jul 2024 16:15:11 +0800
Subject: [PATCH 3/3] Edited the Note section

---
 doc/extending/creating_a_numba_jax_op.rst | 38 +++++++++--------------
 1 file changed, 15 insertions(+), 23 deletions(-)

diff --git a/doc/extending/creating_a_numba_jax_op.rst b/doc/extending/creating_a_numba_jax_op.rst
index 16a6d8830c..75063551b7 100644
--- a/doc/extending/creating_a_numba_jax_op.rst
+++ b/doc/extending/creating_a_numba_jax_op.rst
@@ -626,34 +626,26 @@ Step 4: Write tests
 
 Note
 ----
-Due to restrictions in JAX JIT as reported in issue `#654 <https://github.com/pymc-devs/pytensor/issues/654>`_,
-all jitted functions must have constant shape. In other words, only PyTensor graphs with static shapes
-can be translated to JAX at the moment. It means a graph like the old test function for :class:`Eye` `Op`
+Due to restrictions of the JAX JIT compiler, as reported in issue `#654 <https://github.com/pymc-devs/pytensor/issues/654>`_,
+PyTensor graphs with dynamic shapes may be untranslatable to JAX. For example, this code snippet for the :class:`Eye` `Op`
 
 .. code:: python
 
-    def test_jax_eye():
-        # Create a symbolic input for `Eye`
-        x_at = pt.scalar(dtype=np.int64)
+    x_at = pt.scalar(dtype=np.int64)
+    eye_var = pt.eye(x_at)
+    f = pytensor.function([x_at], eye_var, mode="JAX")
+    f(3)
 
-        # Create a variable that is the output of an `Eye` `Op`
-        eye_var = pt.eye(x_at)
+cannot be translated to JAX, since it involves a dynamic shape. This is one issue that may pop up when
+linking an `Op` to JAX.
 
-        # Create an PyTensor `FunctionGraph`
-        out_fg = FunctionGraph(outputs=[eye_var])
+Note that not all dynamic shapes are disallowed.
+For example, if the output shape depends only on the shape of an input, not on its value, the graph can still be compiled.
+The following snippet gives the result that was expected from the example above.
 
-        # Pass the graph and any inputs to the testing function
-        compare_jax_and_py(out_fg, [3])
-
-cannot be translated to JAX, since it involved a function with dynamic shapes.
-That's why the JAX link test for :class:`Eye` `Op` is now .. code:: python - def test_jax_eye(): - """Tests jaxification of the Eye operator""" - out = ptb.eye(3) - out_fg = FunctionGraph([], [out]) - - compare_jax_and_py(out_fg, []) - -with the shape specified explicitly instead of via a variable. \ No newline at end of file + x_at = pt.vector(dtype=np.int64) + eye_var = pt.eye(x_at.shape[0]) + f = pytensor.function([x_at], eye_var, mode="JAX") + f([3, 3, 3]) \ No newline at end of file
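+
+As a quick sanity check (a sketch that assumes `numpy` is imported as `np`,
+as in the snippets above), the result can be compared against `np.eye`:
+
+.. code:: python
+
+    # `f` is the JAX-compiled function from the snippet above; for a length-3
+    # input vector it should return a plain 3x3 identity matrix.
+    assert np.allclose(f([3, 3, 3]), np.eye(3))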