Utilities
- dump_hlo_graph(fun: Callable, *args: Any, **kwargs: Any) → str
Returns the HLO dot graph of fun evaluated at the given inputs.
- Parameters:
fun – Callable to trace.
*args – Positional arguments passed to fun.
**kwargs – Keyword arguments forwarded to fun.
- Returns:
DOT-format graph describing the lowered HLO program.
- Return type:
str
Examples
>>> import pathlib
>>> import jax.numpy as jnp
>>> def f(x):
...     return x + 1.0
>>> graph = dump_hlo_graph(f, jnp.array([1.0, 2.0, 3.0]))
>>> pathlib.Path("graph.gv").write_text(graph, encoding="ascii")
143
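The entry above does not show how dump_hlo_graph builds its graph. As a related, hedged illustration, JAX's ahead-of-time API can lower a function for concrete inputs without compiling it and print the lowered program as text. Note that this yields StableHLO (MLIR) text, not a DOT graph, and the helper name `dump_lowered_text_sketch` is hypothetical rather than part of evermore.

```python
import jax
import jax.numpy as jnp


def dump_lowered_text_sketch(fun, *args, **kwargs) -> str:
    """Hypothetical helper: return the lowered program of `fun` as text.

    This is NOT evermore's dump_hlo_graph; it prints StableHLO/MLIR text
    instead of a DOT graph.
    """
    # jax.jit(...).lower(...) traces and lowers `fun` for these concrete
    # argument shapes/dtypes without running the XLA compiler.
    lowered = jax.jit(fun).lower(*args, **kwargs)
    # as_text() renders the lowered module as human-readable MLIR text.
    return lowered.as_text()


def f(x):
    return x + 1.0


text = dump_lowered_text_sketch(f, jnp.array([1.0, 2.0, 3.0]))
print(text)
```

Reading the lowered text this way is a quick check on what JAX hands to XLA before any compiler optimizations are applied.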
- dump_jaxpr(fun: Callable, *args: Any, **kwargs: Any) → str
Pretty-prints the Jaxpr of fun evaluated at the given arguments.
- Parameters:
fun – Callable to analyse.
*args – Positional arguments passed to fun.
**kwargs – Keyword arguments forwarded to fun.
- Returns:
Human-readable representation of the traced Jaxpr.
- Return type:
str
Examples
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x):
...     return jnp.sin(x) ** 2 + jnp.cos(x) ** 2
>>> print(dump_jaxpr(f, jnp.array([1.0, 2.0, 3.0])))
{ lambda ; a:f32[3]. let ... }
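A helper like dump_jaxpr can plausibly be built on `jax.make_jaxpr`, which traces a function abstractly and returns a `ClosedJaxpr` whose string form is the familiar `{ lambda ... }` listing. The sketch below assumes that approach; `dump_jaxpr_sketch` is a hypothetical name and not necessarily how evermore implements it.

```python
import jax
import jax.numpy as jnp


def dump_jaxpr_sketch(fun, *args, **kwargs) -> str:
    """Hypothetical stand-in for dump_jaxpr: trace `fun` and pretty-print its Jaxpr."""
    # make_jaxpr(fun) returns a function that traces `fun` with abstract
    # values of the same shapes/dtypes as the arguments and yields a ClosedJaxpr.
    closed_jaxpr = jax.make_jaxpr(fun)(*args, **kwargs)
    # str() renders the ClosedJaxpr as "{ lambda ; a:f32[3]. let ... }".
    return str(closed_jaxpr)


def f(x):
    return jnp.sin(x) ** 2 + jnp.cos(x) ** 2


text = dump_jaxpr_sketch(f, jnp.array([1.0, 2.0, 3.0]))
print(text)
```

Because tracing is abstract, only the shapes and dtypes of the inputs matter; the printed Jaxpr is the same for any float32 array of length 3.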
- tree_stack(trees: list[PyTree[jaxtyping.Shaped[Array, '...']]], *, broadcast_leaves: bool = False) → PyTree[jaxtyping.Shaped[Array, 'batch_dim ...']]
Stacks a list of PyTrees into a single batched PyTree, converting an array of structures (AOS) into a structure of arrays (SOA).
- Parameters:
trees – Sequence of PyTrees with identical static structure.
broadcast_leaves – Whether to broadcast each leaf to a common shape before stacking.
- Returns:
PyTree whose leaves carry an additional leading batch dimension.
- Return type:
PyTree[Shaped[Array, 'batch_dim ...']]
Examples
>>> import evermore as evm
>>> import jax.numpy as jnp
>>> modifiers = [
...     evm.NormalParameter().scale_log_asymmetric(up=jnp.array([1.1]), down=jnp.array([0.9])),
...     evm.NormalParameter().scale_log_asymmetric(up=jnp.array([1.2]), down=jnp.array([0.8])),
... ]
>>> stacked = evm.util.tree_stack(modifiers)
>>> stacked.parameter.get_value().shape
(2, 1)
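The AOS → SOA conversion can be sketched with `jax.tree_util.tree_map` applied to all trees at once, which pairs up leaves at matching positions and stacks them along a new leading axis. This is a minimal sketch under that assumption; `tree_stack_sketch` is hypothetical, and evermore's own tree_stack may differ in details such as validation or how broadcasting is performed.

```python
import jax
import jax.numpy as jnp


def tree_stack_sketch(trees, *, broadcast_leaves=False):
    """Hypothetical stand-in for tree_stack: list of PyTrees -> one batched PyTree."""

    def stack_leaves(*leaves):
        # Optionally broadcast the corresponding leaves to a common shape first.
        if broadcast_leaves:
            leaves = jnp.broadcast_arrays(*leaves)
        # Stack along a new leading axis, which becomes the batch dimension.
        return jnp.stack(leaves)

    # tree_map over several trees requires them to share the same structure
    # and calls stack_leaves with one leaf from each tree.
    return jax.tree_util.tree_map(stack_leaves, *trees)


trees = [
    {"a": jnp.array([1.0]), "b": jnp.array(2.0)},
    {"a": jnp.array([3.0]), "b": jnp.array(4.0)},
]
stacked = tree_stack_sketch(trees)
print(stacked["a"].shape, stacked["b"].shape)  # (2, 1) (2,)

# With broadcasting, leaves of shape (3,) and () can be stacked together.
mixed = tree_stack_sketch([jnp.zeros(3), jnp.ones(())], broadcast_leaves=True)
print(mixed.shape)  # (2, 3)

# The SOA layout is what jax.vmap maps over: one tree per batch element.
totals = jax.vmap(lambda t: t["a"][0] + t["b"])(stacked)
print(totals)
```

The leading batch axis is what makes the stacked result convenient: a function written for a single tree can be vectorised over all of them with `jax.vmap` instead of a Python loop.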