from __future__ import annotations
import jax
import jax.flatten_util
import jax.numpy as jnp
from flax import nnx
from jaxtyping import Array, Float, PyTree
from evermore.parameters.filter import is_dynamic_parameter, is_parameter
from evermore.parameters.parameter import PT, BaseParameter, V
from evermore.pdf import BasePDF, PoissonBase
# Public API of this module; also the name list used by ``from ... import *``.
__all__ = ["sample_from_covariance_matrix", "sample_from_priors"]


def __dir__():
    """Limit ``dir()`` on this module to the names exported in ``__all__``."""
    return __all__
def sample_from_covariance_matrix(
    rngs: nnx.Rngs,
    params: PT,
    *,
    covariance_matrix: Float[Array, "nparams nparams"],
    mask: PyTree[bool] | None = None,
    n_samples: int = 1,
) -> PT:
    """Samples new parameter configurations from a multivariate normal.

    The values of all dynamic parameters in ``params`` are promoted to at
    least 1-d, flattened into a single vector, and used as the mean of a
    multivariate normal with the given covariance matrix. Every draw is then
    unflattened back into the original parameter structure.

    Args:
        rngs: ``nnx.Rngs`` container used to draw randomness.
        params: PyTree of parameters providing the mean values.
        covariance_matrix: Covariance matrix defining the multivariate normal.
            Its trailing two dimensions must be ``(n, n)``, where ``n`` is the
            total number of flattened dynamic parameter values.
        mask: Optional PyTree indicating which parameters should be resampled.
            NOTE: this argument is currently not used by the implementation;
            all dynamic parameters are resampled regardless of its value.
        n_samples: Number of samples to draw. A leading batch dimension of
            size ``n_samples`` is always added to every sampled value, also
            when ``n_samples == 1``. Scalar parameters therefore come back
            with shape ``(n_samples, 1)``.

    Returns:
        PyTree with sampled parameter values replacing the originals. Parts
        of ``params`` that are not dynamic parameters are merged back
        unchanged.

    Raises:
        ValueError: If the trailing two dimensions of ``covariance_matrix``
            are not ``(n, n)`` for the number ``n`` of flattened dynamic
            parameter values.

    Examples:
        >>> import evermore as evm
        >>> import jax.numpy as jnp
        >>> from flax import nnx
        >>> cov = jnp.array([[1.0, 0.5], [0.5, 2.0]])
        >>> params = {
        ...     "a": evm.Parameter(value=jnp.array([1.0])),
        ...     "b": evm.Parameter(value=jnp.array([2.0])),
        ... }
        >>> rngs = nnx.Rngs(0)
        >>> samples = evm.sample.sample_from_covariance_matrix(
        ...     rngs, params, covariance_matrix=cov, n_samples=3
        ... )
        >>> samples["a"].get_value().shape
        (3, 1)
    """
    # Only dynamic parameters are resampled; everything else ends up in
    # `rest` and is merged back untouched at the end.
    graphdef, params_state, rest = nnx.split(params, is_dynamic_parameter, ...)

    # Make every value at least 1-d so that the vmapped unravel below inserts
    # the batch dimension consistently for scalars and arrays alike.
    values = jax.tree.map(jnp.atleast_1d, nnx.pure(params_state))
    flat_values, unravel_fn = jax.flatten_util.ravel_pytree(values)

    # Validate up front to give an error that names the parameter count.
    # This is the same condition that jax.random.multivariate_normal rejects.
    n = flat_values.shape[0]
    cov_shape = jnp.shape(covariance_matrix)
    if cov_shape[-2:] != (n, n):
        raise ValueError(
            f"covariance_matrix must have trailing shape ({n}, {n}) to match "
            f"the {n} flattened dynamic parameter values, but got shape "
            f"{cov_shape}."
        )

    # Sample parameter sets from the covariance matrix, centered around
    # `flat_values`; the result has shape (n_samples, n).
    flat_sampled_values = rngs.multivariate_normal(
        mean=flat_values,
        cov=covariance_matrix,
        shape=(n_samples,),
    )

    # Unflatten each draw; vmap keeps the leading (n_samples,) batch dim.
    sampled_param_values = jax.vmap(unravel_fn)(flat_sampled_values)

    def _update(path, variable, value):
        del path  # unused
        return variable.replace(value=value)

    # Build new parameter objects via tree mapping instead of in-place updates.
    sampled_params_state = jax.tree.map_with_path(
        _update,
        params_state,
        sampled_param_values,
        is_leaf=is_parameter,
        is_leaf_takes_path=True,
    )
    return nnx.merge(graphdef, sampled_params_state, rest)
def sample_from_priors(rngs: nnx.Rngs, params: PT) -> PT:
    """Samples independent values from each parameter's prior distribution.

    Every parameter whose ``prior`` is a ``BasePDF`` gets a new value drawn
    from that PDF, with the same shape as its current value. Parameters
    without such a prior are returned unchanged.

    Args:
        rngs: ``nnx.Rngs`` container used to draw randomness. One key is
            drawn for each parameter that has a ``BasePDF`` prior.
        params: PyTree containing the parameters to sample.

    Returns:
        PyTree mirroring ``params``. Each parameter with a ``BasePDF`` prior
        carries a freshly sampled value in place of its original
        ``.get_value()``.

    Examples:
        >>> import evermore as evm
        >>> import jax
        >>> from flax import nnx
        >>> params = {
        ...     "a": evm.Parameter(value=0.0),
        ...     "b": evm.NormalParameter(value=0.0),
        ... }
        >>> samples = evm.sample.sample_from_priors(nnx.Rngs(0), params)
        >>> isinstance(samples["b"].get_value(), jax.Array)
        True
    """
    graphdef, params_state, rest = nnx.split(params, is_parameter, ...)

    def _sample_from_prior(path, param: BaseParameter[V]) -> BaseParameter[V]:
        del path  # unused
        pdf = param.prior
        # Nothing to sample from without a pdf: keep the parameter as is.
        if not isinstance(pdf, BasePDF):
            return param

        # `jnp.shape` also works for plain Python scalars, which lack `.shape`.
        sampled_value = pdf.sample(rngs(), shape=jnp.shape(param.get_value()))
        if isinstance(pdf, PoissonBase):
            # TODO: this is not correct I assume
            # NOTE(review): this rescales the Poisson draw to its relative
            # deviation from `lamb`, i.e. (x - lamb) / lamb. Confirm that this
            # is the intended parameter convention.
            sampled_value = (sampled_value / pdf.lamb) - 1
        return param.replace(value=sampled_value)  # ty:ignore[invalid-return-type]

    # Sample for each parameter (one rng draw per parameter with a prior).
    sampled_params_state = nnx.map_state(_sample_from_prior, params_state)
    return nnx.merge(graphdef, sampled_params_state, rest)