Source code for evermore.util
from __future__ import annotations
import operator
from collections.abc import Callable
from functools import partial
from typing import Any
import jax
import jax.numpy as jnp
from jaxtyping import Array, Float, PyTree, Shaped
__all__ = [
"dump_hlo_graph",
"dump_jaxpr",
"float_array",
"sum_over_leaves",
"tree_stack",
]
def __dir__():
return __all__
def float_array(x: Any) -> Float[Array, "..."]: # noqa: UP037
return jnp.asarray(x, dtype=jnp.result_type(jnp.float_))
def sum_over_leaves(tree: PyTree) -> Array:
return jax.tree.reduce_associative(operator.add, tree)
[docs]
def tree_stack(
trees: list[PyTree[Shaped[Array, "..."]]], # noqa: UP037
*,
broadcast_leaves: bool = False,
) -> PyTree[Shaped[Array, "batch_dim ..."]]:
"""Stacks a list of PyTrees into a batched PyTree (AOS → SOA).
Args:
trees: Sequence of PyTrees with identical static structure.
broadcast_leaves: Whether to broadcast each leaf to a common shape before stacking.
Returns:
PyTree[Shaped[Array, "batch_dim ..."]]: PyTree whose leaves now carry an
additional leading batch dimension.
Examples:
>>> import evermore as evm
>>> import jax.numpy as jnp
>>> modifiers = [
... evm.NormalParameter().scale_log_asymmetric(up=jnp.array([1.1]), down=jnp.array([0.9])),
... evm.NormalParameter().scale_log_asymmetric(up=jnp.array([1.2]), down=jnp.array([0.8])),
... ]
>>> stacked = evm.util.tree_stack(modifiers)
>>> stacked.parameter.get_value().shape
(2, 1)
"""
# check that all trees have the same structure
first_treedef = jax.tree.structure(trees[0])
for tree in trees[1:]:
other_treedef = jax.tree.structure(tree)
if other_treedef != first_treedef:
msg = (
"All static trees must have the same structure. "
f"Got {other_treedef} and {first_treedef}"
)
raise ValueError(msg)
# actual stacking function for leaves
def batch_axis_stack(*leaves: Array) -> Array:
leaves = jax.tree.map(jnp.atleast_1d, leaves) # ensure at least 1D
if broadcast_leaves:
shape = jnp.broadcast_shapes(*(leaf.shape for leaf in leaves))
return jnp.stack(
jax.tree.map(partial(jnp.broadcast_to, shape=shape), leaves)
)
return jnp.stack(leaves, axis=0)
return jax.tree.map(batch_axis_stack, *trees)
[docs]
def dump_jaxpr(fun: Callable, *args: Any, **kwargs: Any) -> str:
"""Pretty-prints the Jaxpr of ``fun`` evaluated at the given arguments.
Args:
fun: Callable to analyse.
*args: Positional arguments passed to ``fun``.
**kwargs: Keyword arguments forwarded to ``fun``.
Returns:
str: Human-readable representation of the traced Jaxpr.
Examples:
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x):
... return jnp.sin(x) ** 2 + jnp.cos(x) ** 2
>>> print(dump_jaxpr(f, jnp.array([1.0, 2.0, 3.0])))
{ lambda ; a:f32[3]. let ... }
"""
jaxpr = jax.make_jaxpr(fun)(*args, **kwargs)
return jaxpr.pretty_print(name_stack=True)
[docs]
def dump_hlo_graph(fun: Callable, *args: Any, **kwargs: Any) -> str:
"""Returns the HLO ``dot`` graph of ``fun`` evaluated at the inputs.
Args:
fun: Callable to trace.
*args: Positional arguments passed to ``fun``.
**kwargs: Keyword arguments forwarded to ``fun``.
Returns:
str: ``dot`` graph describing the lowered HLO program.
Examples:
>>> import pathlib
>>> import jax.numpy as jnp
>>> def f(x):
... return x + 1.0
>>> graph = dump_hlo_graph(f, jnp.array([1.0, 2.0, 3.0]))
>>> pathlib.Path("graph.gv").write_text(graph, encoding="ascii")
143
"""
return jax.jit(fun).lower(*args, **kwargs).compiler_ir("hlo").as_hlo_dot_graph() # ty:ignore[possibly-missing-attribute]