Loss

covariance_matrix(loss_fn: Callable, tree: PyTree[evermore.parameters.parameter.BaseParameter]) → Float[Array, 'nparams nparams']

Derives a correlation matrix under the Laplace approximation.

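The covariance matrix implied by the Laplace approximation (the inverse Hessian of loss_fn) is re-scaled so that the resulting matrix has unit diagonal entries: entry (i, j) is divided by the square root of the product of diagonal entries i and j. The result is the corresponding correlation matrix.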

Parameters:
  • loss_fn – Callable that accepts a PyTree of parameters and returns a scalar loss.

  • tree – PyTree containing the parameters of interest.

Returns:

Correlation matrix associated with the supplied parameter PyTree (diagonal entries are 1).

Return type:

Float[Array, 'nparams nparams']

Examples

>>> import evermore as evm
>>> import jax.numpy as jnp
>>> params = {
...     "a": evm.Parameter(value=jnp.array([1.0])),
...     "b": evm.Parameter(value=jnp.array([2.0])),
... }
>>> def loss_fn(pytree):
...     return jnp.sum(
...         (pytree["a"].get_value() - 1.0) ** 2 + (pytree["b"].get_value() - 2.0) ** 2
...     )
>>> evm.loss.covariance_matrix(loss_fn, params).shape
(2, 2)
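The re-scaling step can be illustrated with a minimal sketch in plain JAX. This is not evermore's implementation: the parameters are plain arrays rather than evm.Parameter objects, and the helper name correlation_from_loss is hypothetical. A coupling term is added to the loss so that the off-diagonal correlation is non-zero.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def correlation_from_loss(loss_fn, params):
    """Hypothetical helper: Laplace-approximation correlation matrix."""
    # Flatten the PyTree of plain arrays into one parameter vector.
    flat, unravel = ravel_pytree(params)
    # Hessian of the loss with respect to the flat parameter vector.
    hess = jax.hessian(lambda x: loss_fn(unravel(x)))(flat)
    # Laplace approximation: the covariance is the inverse Hessian.
    cov = jnp.linalg.inv(hess)
    # Divide entry (i, j) by sqrt(cov[i, i] * cov[j, j]) to get unit diagonal.
    std = jnp.sqrt(jnp.diag(cov))
    return cov / jnp.outer(std, std)


# Quadratic loss with a coupling term (a - 1) * (b - 2).
def loss_fn(p):
    a, b = p["a"][0], p["b"][0]
    return (a - 1.0) ** 2 + (b - 2.0) ** 2 + (a - 1.0) * (b - 2.0)


params = {"a": jnp.array([1.0]), "b": jnp.array([2.0])}
corr = correlation_from_loss(loss_fn, params)
print(corr)
```

Here the Hessian is [[2, 1], [1, 2]], its inverse is [[2, -1], [-1, 2]] / 3, and re-scaling gives a correlation of -0.5 between a and b.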
cramer_rao_uncertainty(loss_fn: Callable, tree: PT) → PT

Estimates Cramér-Rao uncertainties under the Laplace approximation.

The uncertainties are the square roots of the diagonal entries of the inverse Hessian of loss_fn (the matrix returned by fisher_information_matrix), evaluated at the provided parameter PyTree.

Parameters:
  • loss_fn – Callable that accepts a PyTree of parameters and returns a scalar loss.

  • tree – PyTree containing the parameters of interest.

Returns:

PyTree matching tree with each parameter replaced by its estimated standard deviation.

Examples

>>> import evermore as evm
>>> import jax.numpy as jnp
>>> params = {
...     "a": evm.Parameter(value=jnp.array([1.0])),
...     "b": evm.Parameter(value=jnp.array([2.0])),
... }
>>> def loss_fn(pytree):
...     return jnp.sum(
...         (pytree["a"].get_value() - 1.0) ** 2 + (pytree["b"].get_value() - 2.0) ** 2
...     )
>>> uncertainties = evm.loss.cramer_rao_uncertainty(loss_fn, params)
>>> {name: value.shape for name, value in uncertainties.items()}
{'a': (1,), 'b': (1,)}
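A minimal plain-JAX sketch of the same idea, assuming plain arrays instead of evm.Parameter objects; cramer_rao_sketch is a hypothetical name, not part of evermore. The loss is a Gaussian negative log-likelihood with known widths 0.1 and 0.5, so the estimated standard deviations should recover exactly those widths, and the flat result is mapped back onto the input structure.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def cramer_rao_sketch(loss_fn, params):
    """Hypothetical helper: sqrt of the diagonal of the inverse Hessian."""
    flat, unravel = ravel_pytree(params)
    hess = jax.hessian(lambda x: loss_fn(unravel(x)))(flat)
    # Standard deviations from the diagonal of the inverse Hessian.
    sigma = jnp.sqrt(jnp.diag(jnp.linalg.inv(hess)))
    # Map the flat vector back onto the structure of the input PyTree.
    return unravel(sigma)


# Gaussian negative log-likelihood with widths 0.1 (for a) and 0.5 (for b).
def loss_fn(p):
    return jnp.sum(
        0.5 * ((p["a"] - 1.0) / 0.1) ** 2 + 0.5 * ((p["b"] - 2.0) / 0.5) ** 2
    )


params = {"a": jnp.array([1.0]), "b": jnp.array([2.0])}
uncertainties = cramer_rao_sketch(loss_fn, params)
print(uncertainties)
```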
fisher_information_matrix(loss_fn: Callable, tree: PyTree[evermore.parameters.parameter.BaseParameter]) → Float[Array, 'nparams nparams']

Builds the Fisher information matrix under the Laplace approximation.

The Fisher matrix is obtained by evaluating the Hessian of loss_fn and inverting it. Only differentiable parameters, as determined by flax.nnx.split and evermore.filter.is_dynamic_parameter, contribute.

Parameters:
  • loss_fn – Callable that accepts a PyTree of parameters and returns a scalar loss.

  • tree – PyTree containing the parameters of interest.

Returns:

Fisher information matrix evaluated at the provided parameter values.

Return type:

Float[Array, 'nparams nparams']

Examples

>>> import evermore as evm
>>> import jax.numpy as jnp
>>> params = {
...     "a": evm.Parameter(value=jnp.array([1.0])),
...     "b": evm.Parameter(value=jnp.array([2.0])),
... }
>>> def loss_fn(pytree):
...     return jnp.sum(
...         (pytree["a"].get_value() - 1.0) ** 2 + (pytree["b"].get_value() - 2.0) ** 2
...     )
>>> evm.loss.fisher_information_matrix(loss_fn, params).shape
(2, 2)
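The restriction to dynamic parameters can be pictured with a plain-JAX sketch. It does not use flax.nnx.split or evermore.filter.is_dynamic_parameter; instead it assumes the split has already happened, with the hypothetical dicts dynamic and static standing in for the two parts and inverse_hessian as a hypothetical helper name. Only the dynamic part is differentiated, so the frozen parameter gets no row or column.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def inverse_hessian(loss_fn, dynamic, static):
    """Hypothetical helper: inverse Hessian w.r.t. `dynamic` only."""
    flat, unravel = ravel_pytree(dynamic)

    def flat_loss(x):
        # Re-merge the differentiable and fixed parts before calling the loss.
        return loss_fn({**unravel(x), **static})

    return jnp.linalg.inv(jax.hessian(flat_loss)(flat))


def loss_fn(p):
    return jnp.sum((p["a"] - 1.0) ** 2 + (p["b"] - 2.0) ** 2 + p["c"] * p["a"])


# Assumed split: "c" is treated as a frozen (non-dynamic) parameter.
dynamic = {"a": jnp.array([1.0]), "b": jnp.array([2.0])}
static = {"c": jnp.array([3.0])}
fisher = inverse_hessian(loss_fn, dynamic, static)
print(fisher.shape)  # (2, 2): "c" is held fixed and not differentiated
```

Because c only enters linearly, the Hessian with respect to a and b is 2I and its inverse is 0.5I.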
get_log_probs(tree: PyTree[evermore.parameters.parameter.BaseParameter]) → State

Computes log probabilities for every parameter in a PyTree.

The function iterates over each parameter, evaluates its prior distribution (if present), and returns an nnx.State whose leaves store the corresponding log probabilities.

Parameters:

tree – PyTree that may contain parameters and auxiliary nodes.

Returns:

State matching the input structure with log probabilities in place of the original parameters.

Return type:

nnx.State
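A conceptual sketch of the per-parameter log-probability map in plain JAX. It does not use evermore's prior objects or nnx.State: the values and priors dicts, the (mean, width) Gaussian prior encoding, and the choice of returning zeros for a parameter without a prior are all assumptions made for illustration.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm

# Hypothetical stand-ins: parameter values and an optional Gaussian prior
# (mean, width) per parameter; None means "no prior".
values = {"a": jnp.array([0.0]), "b": jnp.array([1.0]), "c": jnp.array([5.0])}
priors = {"a": (0.0, 1.0), "b": (0.0, 1.0), "c": None}


def log_prob(value, prior):
    if prior is None:
        # Assumption: a parameter without a prior contributes zero.
        return jnp.zeros_like(value)
    mean, width = prior
    return norm.logpdf(value, loc=mean, scale=width)


# Same keys as the input, with log probabilities in place of the values.
log_probs = {name: log_prob(values[name], priors[name]) for name in values}
# A typical use: add the summed prior terms to a likelihood-based loss.
total = sum(jnp.sum(lp) for lp in log_probs.values())
print(log_probs, total)
```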

hessian_matrix(loss_fn: Callable, tree: PyTree[evermore.parameters.parameter.BaseParameter]) → Float[Array, 'nparams nparams']

Computes the Hessian of a scalar loss with respect to dynamic parameters.

The function uses flax.nnx.split to separate differentiable (dynamic) state from static state, then evaluates the Hessian of loss_fn at the current parameter values.

Parameters:
  • loss_fn – Callable that accepts a PyTree of parameters and returns a scalar loss.

  • tree – PyTree containing parameters and auxiliary nodes.

Returns:

Hessian of loss_fn with respect to the dynamic parameter values.

Return type:

Float[Array, 'nparams nparams']

Examples

>>> import evermore as evm
>>> import jax.numpy as jnp
>>> params = {
...     "a": evm.Parameter(value=jnp.array([1.0])),
...     "b": evm.Parameter(value=jnp.array([2.0])),
... }
>>> def loss_fn(pytree):
...     return jnp.sum(
...         (pytree["a"].get_value() - 1.0) ** 2 + (pytree["b"].get_value() - 2.0) ** 2
...     )
>>> evm.loss.hessian_matrix(loss_fn, params).shape
(2, 2)
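Why a flattening step is useful before taking the Hessian can be seen in plain JAX, assuming plain arrays in place of evm.Parameter objects. This sketch is not evermore's code. Calling jax.hessian on a dict of arrays returns a nested dict of blocks, while raveling the PyTree first (here with jax.flatten_util.ravel_pytree) yields the single dense (nparams, nparams) matrix described above.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def loss_fn(p):
    return jnp.sum((p["a"] - 1.0) ** 2 + (p["b"] - 2.0) ** 2 + p["a"] * p["b"])


params = {"a": jnp.array([1.0]), "b": jnp.array([2.0])}

# Differentiating w.r.t. the dict directly gives a nested dict of blocks;
# nested["a"]["b"] holds d2L / (da db) with shape (1, 1).
nested = jax.hessian(loss_fn)(params)

# Raveling the PyTree first gives one dense (nparams, nparams) matrix.
flat, unravel = ravel_pytree(params)
dense = jax.hessian(lambda x: loss_fn(unravel(x)))(flat)
print(dense)
```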