Skip to content

Commit b874978

Browse files
authored
Merge pull request #92 from Ericgig/v0.1.1
V0.1.1
2 parents cf06d52 + f10d927 commit b874978

File tree

3 files changed

+11
-3
lines changed

3 files changed

+11
-3
lines changed

VERSION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
0.1.0
1+
0.1.1

src/qutip_jax/qobjevo.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,14 @@
1414
from functools import partial
1515

1616

17+
try:
18+
# Pre jac 0.6.1
19+
PjitFunction = jaxlib.xla_extension.PjitFunction
20+
except AttributeError:
21+
# Post jac 0.6.1
22+
PjitFunction = jaxlib._jax.PjitFunction
23+
24+
1725
__all__ = []
1826

1927

@@ -146,7 +154,7 @@ def unflatten(cls, aux_data, children):
146154
)
147155

148156

149-
coefficient_builders[jaxlib.xla_extension.PjitFunction] = JaxJitCoeff
157+
coefficient_builders[PjitFunction] = JaxJitCoeff
150158
jax.tree_util.register_pytree_node(
151159
JaxJitCoeff, JaxJitCoeff.flatten, JaxJitCoeff.unflatten
152160
)

tests/test_unary.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def _inv_jax(matrix):
6060
data.add(
6161
matrix,
6262
data.diag(
63-
[1.1] * matrix.shape[0], shape=matrix.shape, dtype="JaxArray"
63+
[2.0] * matrix.shape[0], shape=matrix.shape, dtype="JaxArray"
6464
),
6565
)
6666
)

0 commit comments

Comments
 (0)