Commit 158ee2a

Add xtensor docs

Co-authored-by: Oriol Abril-Pla <oriol.abril.pla@gmail.com>
1 parent a43f6b8 commit 158ee2a

14 files changed: +897 -220 lines changed


doc/library/index.rst

Lines changed: 1 addition & 0 deletions
@@ -25,6 +25,7 @@ Modules
    sparse/index
    tensor/index
    typed_list
+   xtensor/index
 
 .. module:: pytensor
    :platform: Unix, Windows

doc/library/xtensor/index.md

Lines changed: 101 additions & 0 deletions
@@ -0,0 +1,101 @@
(libdoc_xtensor)=
# `xtensor` -- XTensor operations

This module implements an abstraction layer on top of regular tensor operations that behaves like xarray.

A new type, {class}`pytensor.xtensor.type.XTensorType`, generalizes {class}`pytensor.tensor.TensorType`
with the addition of a `dims` attribute that labels the dimensions of the tensor.

Variables of XTensorType (i.e., {class}`pytensor.xtensor.type.XTensorVariable`s) are the symbolic counterpart
to xarray DataArray objects.

The module implements several PyTensor operations, {class}`pytensor.xtensor.basic.XOp`s, whose signatures mimic those of
xarray (and xarray_einstats) DataArray operations. Unlike most regular PyTensor operations, these cannot be
evaluated directly; they first require a rewrite (lowering) into a regular tensor graph, which can then be evaluated as usual.

Like regular PyTensor, we don't need an Op for every possible method or function in the public API of xarray.
If the existing XOps can be composed to produce the desired result, then we can use them directly.
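
For instance, no dedicated XOp is needed to scale one labeled tensor by another: ordinary Python operators already build the required `XElemwise` graph, aligning and broadcasting by dimension name. A minimal sketch (the variable names and shapes are illustrative, not part of this commit):

```python
import pytensor.tensor as pt
import pytensor.xtensor as ptx

# Illustrative inputs: a ("batch", "feature") matrix and a per-"feature" weight vector.
X = ptx.as_xtensor(pt.tensor("X", shape=(10, 5)), dims=["batch", "feature"])
w = ptx.as_xtensor(pt.tensor("w", shape=(5,)), dims=["feature"])

# Composing existing XOps: the product aligns on "feature" and keeps both dimensions,
# without a purpose-built Op for this particular expression.
scaled = X * w
scaled.dprint()
```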

## Coordinates
For now, there is no analogue of xarray coordinates, so you won't be able to do coordinate operations like `.sel`.
The graphs produced by an xarray program without coords are much more amenable to the numpy-like backend of PyTensor.
Coords involve aspects of Pandas/database querying and joining that are not trivially expressible in PyTensor.

## Example

```{testcode}
import pytensor.tensor as pt
import pytensor.xtensor as ptx

a = pt.tensor("a", shape=(3,))
b = pt.tensor("b", shape=(4,))

ax = ptx.as_xtensor(a, dims=["x"])
bx = ptx.as_xtensor(b, dims=["y"])

zx = ax + bx
assert zx.type == ptx.type.XTensorType("float64", dims=["x", "y"], shape=(3, 4))

z = zx.values
z.dprint()
```

```{testoutput}
TensorFromXTensor [id A]
 └─ XElemwise{scalar_op=Add()} [id B]
    ├─ XTensorFromTensor{dims=('x',)} [id C]
    │  └─ a [id D]
    └─ XTensorFromTensor{dims=('y',)} [id E]
       └─ b [id F]
```

Once we compile the graph, no XOps are left.

```{testcode}
import pytensor

with pytensor.config.change_flags(optimizer_verbose=True):
    fn = pytensor.function([a, b], z)
```

```{testoutput}
rewriting: rewrite lower_elemwise replaces XElemwise{scalar_op=Add()}.0 of XElemwise{scalar_op=Add()}(XTensorFromTensor{dims=('x',)}.0, XTensorFromTensor{dims=('y',)}.0) with XTensorFromTensor{dims=('x', 'y')}.0 of XTensorFromTensor{dims=('x', 'y')}(Add.0)
rewriting: rewrite useless_tensor_from_xtensor replaces TensorFromXTensor.0 of TensorFromXTensor(XTensorFromTensor{dims=('x',)}.0) with a of None
rewriting: rewrite useless_tensor_from_xtensor replaces TensorFromXTensor.0 of TensorFromXTensor(XTensorFromTensor{dims=('y',)}.0) with b of None
rewriting: rewrite useless_tensor_from_xtensor replaces TensorFromXTensor.0 of TensorFromXTensor(XTensorFromTensor{dims=('x', 'y')}.0) with Add.0 of Add(ExpandDims{axis=1}.0, ExpandDims{axis=0}.0)
```

```{testcode}
fn.dprint()
```

```{testoutput}
Add [id A] 2
 ├─ ExpandDims{axis=1} [id B] 1
 │  └─ a [id C]
 └─ ExpandDims{axis=0} [id D] 0
    └─ b [id E]
```
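
The compiled function is an ordinary PyTensor function over plain tensors, so it can be called directly with NumPy inputs. A quick check, continuing the example above (the concrete input values are illustrative):

```python
import numpy as np

# The lowered graph takes and returns plain NumPy arrays.
out = fn(np.arange(3.0), np.arange(4.0))
assert out.shape == (3, 4)  # "x" (length 3) broadcast against "y" (length 4)
```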

## Index

:::{toctree}
:maxdepth: 1

module_functions
math
linalg
random
type
:::

doc/library/xtensor/linalg.md

Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
(libdoc_xtensor_linalg)=
# `xtensor.linalg` -- Linear algebra operations

```{eval-rst}
.. automodule:: pytensor.xtensor.linalg
   :members:
```

doc/library/xtensor/math.md

Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
(libdoc_xtensor_math)=
# `xtensor.math` -- Mathematical operations

```{eval-rst}
.. automodule:: pytensor.xtensor.math
   :members:
   :exclude-members: XDot, dot
```

doc/library/xtensor/module_functions.md

Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
(libdoc_xtensor_module_function)=
# `xtensor` -- Module level operations

```{eval-rst}
.. automodule:: pytensor.xtensor
   :members: concat, dot
```

doc/library/xtensor/random.md

Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
(libdoc_xtensor_random)=
# `xtensor.random` -- Random number generator operations

```{eval-rst}
.. automodule:: pytensor.xtensor.random
   :members:
```

doc/library/xtensor/type.md

Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
(libdoc_xtensor_type)=

# `xtensor.type` -- Types and Variables

## XTensorVariable creation functions

```{eval-rst}
.. currentmodule:: pytensor.xtensor.type
.. autosummary::

   xtensor
   xtensor_constant
   as_xtensor
```

## XTensor Type and Variable classes

```{eval-rst}
.. currentmodule:: pytensor.xtensor.type
.. autosummary::

   XTensorType
   XTensorVariable
   XTensorConstant
```
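
A short sketch of how these creation functions can be used. `as_xtensor` follows the usage shown on the xtensor index page; the keyword names for `xtensor` (`dims`, `shape`) are assumed here by analogy with `pt.tensor`, so treat them as illustrative rather than authoritative:

```python
import pytensor.tensor as pt
import pytensor.xtensor as ptx
from pytensor.xtensor.type import xtensor

# Wrap an existing tensor variable and label its dimension (as in the index example).
a = pt.tensor("a", shape=(3,))
ax = ptx.as_xtensor(a, dims=["x"])

# Create a labeled variable directly. The `dims`/`shape` keywords below are an
# assumption mirroring `pt.tensor`; check the generated API reference for the exact signature.
bx = xtensor("b", dims=("x", "y"), shape=(3, 4))
```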

pytensor/xtensor/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 import warnings
 
 import pytensor.xtensor.rewriting
-from pytensor.xtensor import linalg, random
+from pytensor.xtensor import linalg, math, random
 from pytensor.xtensor.math import dot
 from pytensor.xtensor.shape import concat
 from pytensor.xtensor.type import (

pytensor/xtensor/linalg.py

Lines changed: 40 additions & 2 deletions
@@ -11,17 +11,31 @@ def cholesky(
     lower: bool = True,
     *,
     check_finite: bool = False,
-    overwrite_a: bool = False,
     on_error: Literal["raise", "nan"] = "raise",
     dims: Sequence[str],
 ):
+    """Compute the Cholesky decomposition of an XTensorVariable.
+
+    Parameters
+    ----------
+    x : XTensorVariable
+        The input variable to decompose.
+    lower : bool, optional
+        Whether to return the lower triangular matrix. Default is True.
+    check_finite : bool, optional
+        Whether to check that the input is finite. Default is False.
+    on_error : {'raise', 'nan'}, optional
+        What to do if the input is not positive definite. If 'raise', an error is raised.
+        If 'nan', the output will contain NaNs. Default is 'raise'.
+    dims : Sequence[str]
+        The two core dimensions of the input variable, over which the Cholesky decomposition is computed.
+    """
     if len(dims) != 2:
         raise ValueError(f"Cholesky needs two dims, got {len(dims)}")
 
     core_op = Cholesky(
         lower=lower,
         check_finite=check_finite,
-        overwrite_a=overwrite_a,
         on_error=on_error,
     )
     core_dims = (
@@ -40,6 +54,30 @@ def solve(
     lower: bool = False,
     check_finite: bool = False,
 ):
+    """Solve a system of linear equations using XTensorVariables.
+
+    Parameters
+    ----------
+    a : XTensorVariable
+        The left-hand side xtensor.
+    b : XTensorVariable
+        The right-hand side xtensor.
+    dims : Sequence[str]
+        The core dimensions over which to solve the linear equations.
+        If length is 2, we are solving a matrix-vector equation,
+        and the two dimensions should be present in `a`, but only one in `b`.
+        If length is 3, we are solving a matrix-matrix equation,
+        and two dimensions should be present in `a`, two in `b`, and only one should be shared.
+        In both cases the shared dimension will not appear in the output.
+    assume_a : str, optional
+        The type of matrix that `a` is assumed to be. Default is 'gen' (general).
+        Options are ["gen", "sym", "her", "pos", "tridiagonal", "banded"].
+        Long-form options can also be used ["general", "symmetric", "hermitian", "positive_definite"].
+    lower : bool, optional
+        Whether `a` is lower triangular. Default is False. Only relevant if `assume_a` is "sym", "her", or "pos".
+    check_finite : bool, optional
+        Whether to check that the input is finite. Default is False.
+    """
     a, b = as_xtensor(a), as_xtensor(b)
     input_core_dims: tuple[tuple[str, str], tuple[str] | tuple[str, str]]
     output_core_dims: tuple[tuple[str] | tuple[str, str]]
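
Taken together, these docstrings suggest the following usage pattern. The sketch below is built only from the signatures shown in this diff; the dimension names, shapes, and the exact ordering convention for `dims` in `solve` are assumptions, not part of the commit:

```python
import pytensor.tensor as pt
import pytensor.xtensor as ptx
from pytensor.xtensor.linalg import cholesky, solve

# A square "rows" x "cols" matrix and a vector sharing the "rows" dimension.
A = ptx.as_xtensor(pt.tensor("A", shape=(3, 3)), dims=["rows", "cols"])
b = ptx.as_xtensor(pt.tensor("b", shape=(3,)), dims=["rows"])

# Cholesky factor over the two core dimensions (dims is keyword-only).
L = cholesky(A, dims=["rows", "cols"])

# Matrix-vector solve: both core dims appear in A, only one of them in b,
# and (per the docstring) the shared dimension drops out of the output.
x = solve(A, b, dims=["rows", "cols"])
```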
