Utilities

dump_hlo_graph(fun: Callable, *args: Any, **kwargs: Any) -> str

Returns the HLO graph of fun, evaluated at the given inputs, in Graphviz DOT format.

Parameters:
  • fun – Callable to trace.

  • *args – Positional arguments passed to fun.

  • **kwargs – Keyword arguments forwarded to fun.

Returns:

DOT graph source describing the lowered HLO program.

Return type:

str

Examples

>>> import pathlib
>>> import jax.numpy as jnp
>>> def f(x):
...     return x + 1.0
>>> graph = dump_hlo_graph(f, jnp.array([1.0, 2.0, 3.0]))
>>> pathlib.Path("graph.gv").write_text(graph, encoding="ascii")
143
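
For a quick look at the lowered program without rendering a graph, JAX's own ahead-of-time API can print it as text. The following is a minimal sketch that uses only public JAX calls (`jax.jit(...).lower(...)` and `Lowered.as_text()`). It is not how `dump_hlo_graph` is implemented, and the string it yields is JAX's default textual lowering, not a DOT graph.

```python
import jax
import jax.numpy as jnp


def f(x):
    return x + 1.0


# Lower f for this concrete input without executing it.
lowered = jax.jit(f).lower(jnp.array([1.0, 2.0, 3.0]))

# Textual form of the lowered module (plain text, not a DOT graph).
lowered_text = lowered.as_text()
print(lowered_text)
```

The text form is handy for diffing or searching for a specific operation, while the DOT output of `dump_hlo_graph` is better suited to visual inspection.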

dump_jaxpr(fun: Callable, *args: Any, **kwargs: Any) -> str

Pretty-prints the Jaxpr of fun evaluated at the given arguments.

Parameters:
  • fun – Callable to analyse.

  • *args – Positional arguments passed to fun.

  • **kwargs – Keyword arguments forwarded to fun.

Returns:

Human-readable representation of the traced Jaxpr.

Return type:

str

Examples

>>> import jax
>>> import jax.numpy as jnp
>>> def f(x):
...     return jnp.sin(x) ** 2 + jnp.cos(x) ** 2
>>> print(dump_jaxpr(f, jnp.array([1.0, 2.0, 3.0])))
{ lambda ; a:f32[3]. let ... }
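
A comparable string can be obtained with JAX's public `jax.make_jaxpr`. The sketch below makes no claim about how `dump_jaxpr` is implemented. It only illustrates the underlying tracing step with standard JAX calls.

```python
import jax
import jax.numpy as jnp


def f(x):
    return jnp.sin(x) ** 2 + jnp.cos(x) ** 2


# Trace f at an example input; tracing is abstract, so only the
# shape and dtype of the input determine the resulting Jaxpr.
closed_jaxpr = jax.make_jaxpr(f)(jnp.array([1.0, 2.0, 3.0]))

# str() of the traced object is the pretty-printed Jaxpr.
jaxpr_text = str(closed_jaxpr)
print(jaxpr_text)
```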

tree_stack(trees: list[PyTree[jaxtyping.Shaped[Array, '...']]], *, broadcast_leaves: bool = False) -> PyTree[jaxtyping.Shaped[Array, 'batch_dim ...']]

Stacks a list of PyTrees into a single batched PyTree (array of structures → structure of arrays).

Parameters:
  • trees – Sequence of PyTrees with identical static structure.

  • broadcast_leaves – Whether to broadcast each leaf to a common shape before stacking.

Returns:

PyTree whose leaves now carry an additional leading batch dimension.

Return type:

PyTree[Shaped[Array, 'batch_dim ...']]

Examples

>>> import evermore as evm
>>> import jax.numpy as jnp
>>> modifiers = [
...     evm.NormalParameter().scale_log_asymmetric(up=jnp.array([1.1]), down=jnp.array([0.9])),
...     evm.NormalParameter().scale_log_asymmetric(up=jnp.array([1.2]), down=jnp.array([0.8])),
... ]
>>> stacked = evm.util.tree_stack(modifiers)
>>> stacked.parameter.get_value().shape
(2, 1)
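
The array-of-structures to structure-of-arrays conversion can be illustrated on plain dictionaries. This is a minimal sketch, not evermore's implementation: `stack_trees` is a hypothetical stand-in for `tree_stack`, and its `broadcast_leaves` option is interpreted here as broadcasting corresponding leaves against each other, which may differ from the library's behavior.

```python
import jax
import jax.numpy as jnp


def stack_trees(trees, broadcast_leaves=False):
    """Hypothetical stand-in for tree_stack: stack matching leaves along a new axis 0."""

    def stack_leaves(*leaves):
        if broadcast_leaves:
            # Assumption: bring corresponding leaves to one common shape first.
            leaves = jnp.broadcast_arrays(*leaves)
        return jnp.stack(leaves)

    # tree_map calls stack_leaves once per leaf position across all trees.
    return jax.tree_util.tree_map(stack_leaves, *trees)


# Array of structures: a list of dicts with the same keys.
trees = [
    {"up": jnp.array([1.1]), "down": jnp.array([0.9])},
    {"up": jnp.array([1.2]), "down": jnp.array([0.8])},
]

# Structure of arrays: one dict whose leaves have a leading batch axis.
stacked = stack_trees(trees)
print(stacked["up"].shape)  # (2, 1)
```

A batched PyTree of this kind can be passed to `jax.vmap`, which maps a function over the leading axis of every leaf at once.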