Binned Likelihood
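The binned likelihood quantifies the agreement between a model and data in terms of histograms. It is defined as follows:
\[
\mathcal{L}(d \mid \phi) = \prod_i \mathrm{Poisson}\left(d_i \mid \lambda_i(\phi)\right) \cdot \prod_j \pi_j\left(\phi_j\right) \tag{1}
\]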
where \(\lambda_i(\phi)\) is the model prediction for bin \(i\), \(d_i\) is the observed data in bin \(i\), and \(\pi_j\left(\phi_j\right)\) is the prior probability density function (BasePDF) for parameter \(j\). The first product contains one Poisson term per bin, and the second contains one constraint term per parameter, given by its prior BasePDF.
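Taking the negative logarithm of the binned likelihood turns both products into sums, which is the form that is minimized in practice. With the Poisson probability \(\lambda^{d} e^{-\lambda} / d!\) for each bin, this gives:

```latex
-\ln \mathcal{L}(d \mid \phi)
  = \sum_i \Big[ \lambda_i(\phi) - d_i \ln \lambda_i(\phi) + \ln \Gamma(d_i + 1) \Big]
  - \sum_j \ln \pi_j\left(\phi_j\right)
```

The \(\ln \Gamma(d_i + 1) = \ln d_i!\) term does not depend on \(\phi\), so it shifts the NLL by a constant without moving its minimum.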
Key to constructing this likelihood is the definition of the model \(\lambda(\phi)\) as a function of the parameters \(\phi\). evermore provides building blocks to define such models in a modular way.
These building blocks include:
evm.Parameter: A class that represents a parameter with a value, name, bounds, and a prior BasePDF used as its constraint.
evm.BaseEffect: Effects describe how data, e.g., histogram bins, may be varied.
evm.Modifier: Modifiers combine effects (evm.BaseEffect) and parameters (evm.Parameter) to modify data.
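To make the division of labour between these three concepts concrete, here is a minimal plain-JAX sketch. `ToyParameter`, `ToyModifier`, `scale_effect` and `log_normal_effect` are hypothetical names invented for illustration; they are not evermore's classes and do not mirror its actual interfaces. The sketch only shows the idea: a parameter carries a value and a prior, an effect says how a histogram changes with that value, and a modifier pairs the two.

```python
from dataclasses import dataclass
from typing import Callable, Optional

import jax.numpy as jnp
from jax.scipy.stats import norm


@dataclass
class ToyParameter:
    """Hypothetical parameter (not evm.Parameter): a name, a value and an optional Gaussian prior."""

    name: str
    value: float
    prior_mean: float = 0.0
    prior_width: Optional[float] = None  # None means a flat (unconstrained) prior

    def log_prior(self):
        # log pi_j(phi_j) evaluated at the current value
        if self.prior_width is None:
            return jnp.asarray(0.0)  # a flat prior only contributes a constant
        return norm.logpdf(self.value, loc=self.prior_mean, scale=self.prior_width)


# Hypothetical effect type: maps (parameter value, histogram) -> varied histogram.
Effect = Callable[[float, jnp.ndarray], jnp.ndarray]


def scale_effect() -> Effect:
    # multiply every bin by the parameter value (e.g. a signal strength)
    return lambda value, hist: hist * value


def log_normal_effect(kappa: float) -> Effect:
    # multiply every bin by kappa**value (a log-normal-style rate variation)
    return lambda value, hist: hist * kappa**value


@dataclass
class ToyModifier:
    """Hypothetical modifier (not evm.Modifier): pairs one parameter with one effect."""

    parameter: ToyParameter
    effect: Effect

    def __call__(self, hist: jnp.ndarray) -> jnp.ndarray:
        return self.effect(self.parameter.value, hist)


mu = ToyParameter(name="mu", value=2.0)  # unconstrained signal strength
lumi = ToyParameter(name="lumi", value=1.0, prior_width=1.0)  # constrained nuisance
signal = jnp.array([10.0, 20.0])

# modifiers compose: scale by mu first, then apply the lumi variation (2% per unit)
varied = ToyModifier(lumi, log_normal_effect(1.02))(ToyModifier(mu, scale_effect())(signal))
print(varied)
```

Summing `log_prior()` over all parameters gives the second product of the likelihood in log form, while the varied histograms feed the Poisson term.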
The negative log-likelihood (NLL) corresponding to Eq. (1) can be implemented with evermore as follows (copy and paste this snippet as a starting point for a new statistical model):
```python
from flax import nnx
import jax
import jax.numpy as jnp
from jaxtyping import PyTree, Array

import evermore as evm

# -- parameter definition --
# params: PyTree[evm.Parameter] = ...
# graphdef, dynamic_params, static_params = nnx.split(
#     params, evm.filter.is_dynamic_parameter, ...
# )

# -- model definition --
# def model(params: PyTree[evm.Parameter], hists: PyTree[Array]) -> PyTree[Array]:
#     ...


# -- NLL definition --
@nnx.jit
def nll(dynamic_params, args):
    graphdef, static_params, hists, observation = args
    # reassemble the full parameter PyTree from its split parts
    params = nnx.merge(graphdef, dynamic_params, static_params)
    expectations = model(params, hists)

    # first product of Eq. 1 (Poisson term), evaluated on the summed expectation
    loss_val = evm.pdf.Poisson(lamb=evm.util.sum_over_leaves(expectations)).log_prob(
        observation
    ).sum()

    # second product of Eq. 1 (constraint)
    constraints = evm.loss.get_log_probs(params)
    # parameters with `.get_value().size > 1` yield one log-prob per element, so sum them
    constraints = jax.tree.map(jnp.sum, constraints)
    loss_val += evm.util.sum_over_leaves(constraints)
    return -jnp.sum(loss_val)


# args = (graphdef, static_params, hists, observation)
# loss_val = nll(dynamic_params, args)
```
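The arithmetic inside `nll` can be checked without evermore. The following is a minimal plain-JAX sketch of Eq. (1) for a made-up three-bin example. The toy histograms, the parameter names `mu` and `theta`, and the 10% log-normal background variation are assumptions for illustration only, and `jax.scipy.stats` stands in for `evm.pdf`.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm, poisson

# Made-up toy inputs: two processes in three bins, plus observed counts
signal = jnp.array([1.0, 4.0, 2.0])
background = jnp.array([10.0, 8.0, 6.0])
observation = jnp.array([11.0, 13.0, 8.0])


def model(params):
    # lambda_i(phi): signal scaled by mu, background varied log-normally by theta
    return params["mu"] * signal + background * 1.1 ** params["theta"]


def nll(params):
    expectation = model(params)
    # first product of Eq. 1: one Poisson log-probability per bin
    log_likelihood = poisson.logpmf(observation, expectation).sum()
    # second product of Eq. 1: standard-normal constraint on theta;
    # mu has a flat prior, which only adds a constant and is omitted
    log_likelihood += norm.logpdf(params["theta"])
    return -log_likelihood


params = {"mu": 1.0, "theta": 0.0}
value, grads = jax.value_and_grad(nll)(params)
print(value, grads)
```

Because `nll` is an ordinary JAX function of a PyTree of parameters, `jax.value_and_grad` returns the NLL together with its gradient for every parameter, which is what gradient-based minimizers consume.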
The main work lies in building the parameters and the model; the relevant components are described in Building Blocks.