Loss
- covariance_matrix(loss_fn: Callable, tree: PyTree[evermore.parameters.parameter.BaseParameter]) -> Float[Array, 'nparams nparams']
Derives a correlation matrix under the Laplace approximation.
The Fisher information matrix is inverted and re-scaled so that the resulting matrix has unit diagonal entries. This corresponds to the correlation matrix implied by the Laplace approximation.
- Parameters:
loss_fn – Callable that accepts a PyTree of parameters and returns a scalar loss.
tree – PyTree containing the parameters of interest.
- Returns:
Correlation matrix associated with the supplied parameter PyTree (diagonal entries are 1).
- Return type:
Float[Array, "nparams nparams"]
Examples
>>> import evermore as evm
>>> import jax.numpy as jnp
>>> params = {
...     "a": evm.Parameter(value=jnp.array([1.0])),
...     "b": evm.Parameter(value=jnp.array([2.0])),
... }
>>> def loss_fn(pytree):
...     return jnp.sum(
...         (pytree["a"].value - 1.0) ** 2 + (pytree["b"].value - 2.0) ** 2
...     )
>>> evm.loss.covariance_matrix(loss_fn, params).shape
(2, 2)
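The rescaling to unit diagonal described above can be sketched in plain JAX, independent of evermore. This is a minimal illustration, not the library's implementation: the loss is hypothetical, the parameters are a bare array, and the inverse Hessian is assumed to play the role of the matrix being rescaled.

```python
import jax
import jax.numpy as jnp


# Hypothetical loss with correlated parameters (not part of evermore).
def loss_fn(x):
    return 2.0 * x[0] ** 2 + x[0] * x[1] + x[1] ** 2


x0 = jnp.array([0.0, 0.0])

# Hessian of the loss: [[4, 1], [1, 2]].
hessian = jax.hessian(loss_fn)(x0)

# Inverse Hessian, assumed here to be the covariance-like matrix.
inv_hessian = jnp.linalg.inv(hessian)

# Divide each entry by the product of the corresponding standard deviations,
# which forces the diagonal to 1.
scale = jnp.sqrt(jnp.diag(inv_hessian))
correlation = inv_hessian / jnp.outer(scale, scale)
```

For this loss the off-diagonal entry is -1/sqrt(8), so the two parameters are mildly anti-correlated while the diagonal is exactly 1.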
- cramer_rao_uncertainty(loss_fn: Callable, tree: PyTree[evermore.parameters.parameter.BaseParameter]) -> PyTree[evermore.parameters.parameter.BaseParameter]
Estimates Cramér-Rao uncertainties under the Laplace approximation.
The uncertainties are the square roots of the diagonal of the Fisher information matrix for the provided parameter PyTree.
- Parameters:
loss_fn – Callable that accepts a PyTree of parameters and returns a scalar loss.
tree – PyTree containing the parameters of interest.
- Returns:
PyTree matching tree with each parameter replaced by its estimated standard deviation.
Examples
>>> import evermore as evm
>>> import jax.numpy as jnp
>>> params = {
...     "a": evm.Parameter(value=jnp.array([1.0])),
...     "b": evm.Parameter(value=jnp.array([2.0])),
... }
>>> def loss_fn(pytree):
...     return jnp.sum(
...         (pytree["a"].value - 1.0) ** 2 + (pytree["b"].value - 2.0) ** 2
...     )
>>> uncertainties = evm.loss.cramer_rao_uncertainty(loss_fn, params)
>>> {name: value.shape for name, value in uncertainties.items()}
{'a': (1,), 'b': (1,)}
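To illustrate how per-parameter standard deviations can be mapped back onto the input structure, the following sketch uses plain JAX arrays in place of evermore parameters. It assumes that the matrix whose diagonal is taken is the inverse Hessian of the loss; the loss and its widths (0.1 and 0.5) are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Plain arrays stand in for evermore parameters (an assumption of this sketch).
params = {"a": jnp.array([1.0]), "b": jnp.array([2.0])}


# Hypothetical Gaussian-like loss with widths 0.1 for "a" and 0.5 for "b".
def loss_fn(tree):
    return jnp.sum(
        0.5 * ((tree["a"] - 1.0) / 0.1) ** 2
        + 0.5 * ((tree["b"] - 2.0) / 0.5) ** 2
    )


# Flatten the PyTree to a vector so the Hessian is a square matrix.
flat, unravel = ravel_pytree(params)
hessian = jax.hessian(lambda v: loss_fn(unravel(v)))(flat)

# Standard deviations from the diagonal of the inverse Hessian.
std = jnp.sqrt(jnp.diag(jnp.linalg.inv(hessian)))

# Restore the original structure: one standard deviation per parameter.
uncertainties = unravel(std)
```

Because the loss is quadratic with the stated widths, the recovered uncertainties are approximately 0.1 for "a" and 0.5 for "b".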
- fisher_information_matrix(loss_fn: Callable, tree: PyTree[evermore.parameters.parameter.BaseParameter]) -> Float[Array, 'nparams nparams']
Builds the Fisher information matrix under the Laplace approximation.
The Fisher matrix is obtained by evaluating the Hessian of loss_fn and inverting it. Only differentiable parameters, as determined by flax.nnx.split and evermore.filter.is_dynamic_parameter, contribute.
- Parameters:
loss_fn – Callable that accepts a PyTree of parameters and returns a scalar loss.
tree – PyTree containing the parameters of interest.
- Returns:
Fisher information matrix evaluated at the provided parameter values.
- Return type:
Float[Array, "nparams nparams"]
Examples
>>> import evermore as evm
>>> import jax.numpy as jnp
>>> params = {
...     "a": evm.Parameter(value=jnp.array([1.0])),
...     "b": evm.Parameter(value=jnp.array([2.0])),
... }
>>> def loss_fn(pytree):
...     return jnp.sum(
...         (pytree["a"].value - 1.0) ** 2 + (pytree["b"].value - 2.0) ** 2
...     )
>>> evm.loss.fisher_information_matrix(loss_fn, params).shape
(2, 2)
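The "evaluate the Hessian, then invert it" step described above can be checked against a case with a known closed form. The sketch below is plain JAX rather than evermore code, and the data and sigma are made up for illustration: for a Gaussian negative log-likelihood with known width sigma and n observations, the Hessian with respect to the mean is n / sigma^2, so its inverse is sigma^2 / n.

```python
import jax
import jax.numpy as jnp

# Hypothetical observations and known width (illustrative values only).
data = jnp.array([0.8, 1.1, 0.9, 1.2])
sigma = 0.5


# Gaussian negative log-likelihood in the mean, up to a constant.
def nll(mu):
    return jnp.sum(0.5 * ((data - mu[0]) / sigma) ** 2)


mu0 = jnp.array([1.0])

# Hessian has shape (1, 1) and equals n / sigma**2 = 16.
hessian = jax.hessian(nll)(mu0)

# Inverting it yields sigma**2 / n = 0.0625.
inv_hessian = jnp.linalg.inv(hessian)
```

The square root of the inverted value, 0.25, is the familiar sigma / sqrt(n) uncertainty on the mean.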
- get_log_probs(tree: PyTree[evermore.parameters.parameter.BaseParameter]) -> State
Computes log probabilities for every parameter in a PyTree.
The function iterates over each parameter, evaluates its prior distribution (if present), and returns an nnx.State whose leaves store the corresponding log probabilities.
- Parameters:
tree – PyTree that may contain parameters and auxiliary nodes.
- Returns:
State matching the input structure with log probabilities in place of the original parameters.
- Return type:
nnx.State
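The idea of replacing each parameter with its log probability while preserving the tree structure can be sketched in plain JAX. This is not how evermore implements get_log_probs: the parameters here are bare arrays in a dict, and the unit normal prior applied to every leaf is an assumption made purely for illustration.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

# Plain arrays stand in for parameters (an assumption of this sketch).
params = {"a": jnp.array([0.0]), "b": jnp.array([1.0, -1.0])}

# Assumed prior: a unit normal on every leaf. Each leaf is replaced by
# its element-wise log probability, and the dict structure is preserved.
log_probs = jax.tree_util.tree_map(
    lambda v: norm.logpdf(v, loc=0.0, scale=1.0), params
)

# A total log-prior term can then be formed by summing over all leaves.
total = sum(jnp.sum(leaf) for leaf in jax.tree_util.tree_leaves(log_probs))
```

The result has the same keys and leaf shapes as the input, which is what allows it to be combined leaf-by-leaf with other quantities defined on the same tree.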
- hessian_matrix(loss_fn: Callable, tree: PyTree[evermore.parameters.parameter.BaseParameter]) -> Float[Array, 'nparams nparams']
Computes the Hessian of a scalar loss with respect to dynamic parameters.
The function leverages flax.nnx.split to separate differentiable and static state and evaluates the Hessian of loss_fn at the current parameter values.
- Parameters:
loss_fn – Callable that accepts a PyTree of parameters and returns a scalar loss.
tree – PyTree containing parameters and auxiliary nodes.
- Returns:
Hessian of loss_fn with respect to the dynamic parameter values.
- Return type:
Float[Array, "nparams nparams"]
Examples
>>> import evermore as evm
>>> import jax.numpy as jnp
>>> params = {
...     "a": evm.Parameter(value=jnp.array([1.0])),
...     "b": evm.Parameter(value=jnp.array([2.0])),
... }
>>> def loss_fn(pytree):
...     return jnp.sum(
...         (pytree["a"].value - 1.0) ** 2 + (pytree["b"].value - 2.0) ** 2
...     )
>>> evm.loss.hessian_matrix(loss_fn, params).shape
(2, 2)
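The separation of differentiable and static state can be mimicked without flax by keeping two dicts and differentiating only with respect to one of them. The sketch below is a plain-JAX stand-in under that assumption: the manual dynamic/static partition replaces flax.nnx.split, and the loss is hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Manual partition standing in for flax.nnx.split (an assumption of this sketch).
dynamic = {"a": jnp.array([1.0]), "b": jnp.array([2.0])}
static = {"c": jnp.array([3.0])}


# Hypothetical loss: "c" enters as a fixed coefficient, not a free parameter.
def loss_fn(dyn, stat):
    return jnp.sum(
        stat["c"] * dyn["a"] ** 2 + dyn["a"] * dyn["b"] + dyn["b"] ** 2
    )


# Flatten only the dynamic part; the static part is closed over.
flat, unravel = ravel_pytree(dynamic)
hessian = jax.hessian(lambda v: loss_fn(unravel(v), static))(flat)
```

The resulting matrix is [[6, 1], [1, 2]]: it has one row and column per dynamic parameter, and the static value only appears as the coefficient 2c in the top-left entry.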