Sampling Parameters

sample_from_covariance_matrix(rngs: Rngs, params: PT, *, covariance_matrix: Float[Array, 'nparams nparams'], mask: PyTree[bool] | None = None, n_samples: int = 1) → PT

Samples new parameter configurations from a multivariate normal distribution centered on the current parameter values.

Parameters:
  • rngs – nnx.Rngs container used to draw randomness.

  • params – PyTree of parameters providing the mean values.

  • covariance_matrix – Covariance matrix of shape (nparams, nparams) defining the multivariate normal.

  • mask – Optional PyTree indicating which parameters should be resampled.

  • n_samples – Number of samples to draw; when greater than 1, each sampled value gains a leading batch dimension of size n_samples.

Returns:

PyTree with the same structure as params, with sampled parameter values replacing the originals.

Examples

>>> import evermore as evm
>>> import jax.numpy as jnp
>>> from flax import nnx
>>> cov = jnp.array([[1.0, 0.5], [0.5, 2.0]])
>>> params = {
...     "a": evm.Parameter(value=jnp.array([1.0])),
...     "b": evm.Parameter(value=jnp.array([2.0])),
... }
>>> rngs = nnx.Rngs(0)
>>> samples = evm.sample.sample_from_covariance_matrix(
...     rngs, params, covariance_matrix=cov, n_samples=3
... )
>>> samples["a"].get_value().shape
(3, 1)
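
The sampling step can be pictured as the standard Cholesky construction: flatten the parameter means into one vector μ, factor the covariance as Σ = L Lᵀ, and set x = μ + L z with z drawn from a standard normal, so that x has mean μ and covariance Σ. The sketch below does this with NumPy on a plain dict of arrays. It is not taken from evermore's source; the helper name sample_mvn_dict, the insertion-order flattening, and the mask semantics (False keeps the original value) are assumptions for illustration.

```python
import numpy as np


def sample_mvn_dict(rng, means, cov, n_samples=1, mask=None):
    """Hypothetical helper: correlated Gaussian draws for a dict of arrays.

    means: dict name -> array of mean values, flattened in insertion order.
    cov: (nparams, nparams) covariance over the flattened vector.
    mask: optional dict name -> bool; False keeps the mean (assumption).
    """
    names = list(means)
    shapes = [np.shape(means[k]) for k in names]
    # Concatenate every entry into one flat mean vector mu
    mu = np.concatenate([np.ravel(np.asarray(means[k], dtype=float)) for k in names])
    cov = np.asarray(cov, dtype=float)
    if cov.shape != (mu.size, mu.size):
        raise ValueError("covariance must be (nparams, nparams)")
    # Cholesky factor L with cov = L @ L.T; requires a positive-definite cov
    chol = np.linalg.cholesky(cov)
    # Rows of z are N(0, I); each row of mu + z @ L.T then has covariance cov
    z = rng.standard_normal((n_samples, mu.size))
    flat = mu + z @ chol.T
    out, start = {}, 0
    for name, shape in zip(names, shapes):
        size = int(np.prod(shape))
        block = flat[:, start:start + size].reshape((n_samples,) + shape)
        start += size
        if mask is not None and not mask.get(name, True):
            # Not selected for resampling: repeat the original value per sample
            block = np.broadcast_to(np.asarray(means[name], dtype=float), block.shape)
        # Only add the leading batch axis when more than one sample is drawn
        out[name] = block if n_samples > 1 else block[0]
    return out
```

With the means and covariance from the example above, sample_mvn_dict(np.random.default_rng(0), {"a": np.array([1.0]), "b": np.array([2.0])}, cov, n_samples=3)["a"] has shape (3, 1), matching the batched shape shown for sample_from_covariance_matrix. Cholesky is the usual choice because it is cheap and exact for symmetric positive-definite matrices; np.linalg.cholesky raises LinAlgError otherwise, and a singular covariance would need an eigendecomposition instead.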

sample_from_priors(rngs: Rngs, params: PT) → PT

Samples independent values from each parameter’s prior distribution.

Parameters:
  • rngs – nnx.Rngs container used to draw randomness.

  • params – PyTree containing the parameters to sample.

Returns:

PyTree mirroring params, with each parameter's current value (as returned by .get_value()) replaced by its sampled value.

Examples

>>> import evermore as evm
>>> import jax
>>> from flax import nnx
>>> params = {
...     "a": evm.Parameter(value=0.0),
...     "b": evm.NormalParameter(value=0.0),
... }
>>> samples = evm.sample.sample_from_priors(nnx.Rngs(0), params)
>>> isinstance(samples["b"].get_value(), jax.Array)
True
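
Because the draws are independent, sampling from priors amounts to one separate draw per parameter, with no correlations between parameters. A minimal NumPy sketch, assuming a hypothetical Param record that pairs a value with an optional prior sampler, and assuming that a parameter without a prior keeps its current value (the example above only inspects "b", so this detail is not confirmed by the source):

```python
from dataclasses import dataclass, replace
from typing import Callable, Optional

import numpy as np

# A prior sampler takes a Generator and a shape and returns draws of that shape
PriorSampler = Callable[[np.random.Generator, tuple], np.ndarray]


@dataclass(frozen=True)
class Param:
    """Hypothetical stand-in for a parameter: a value plus an optional prior."""
    value: np.ndarray
    prior: Optional[PriorSampler] = None


def standard_normal_prior(rng, shape):
    # Illustrative N(0, 1) prior: one independent draw per element
    return np.asarray(rng.standard_normal(shape))


def sample_priors(rng, params):
    """Return a new dict where each prior-carrying Param gets a fresh draw.

    Assumption: a Param without a prior is returned unchanged.
    """
    out = {}
    for name, p in params.items():
        if p.prior is None:
            out[name] = p
        else:
            # Draw with the same shape as the current value; the input is not mutated
            out[name] = replace(p, value=p.prior(rng, np.shape(p.value)))
    return out
```

Calling sample_priors(np.random.default_rng(0), {"a": Param(np.array(0.0)), "b": Param(np.array(0.0), standard_normal_prior)}) leaves "a" untouched and gives "b" a fresh scalar draw. Unlike sample_from_covariance_matrix, no covariance couples the parameters, so each draw depends only on that parameter's own prior.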