Sampling Parameters
- sample_from_covariance_matrix(rngs: Rngs, params: PT, *, covariance_matrix: Float[Array, 'nparams nparams'], mask: PyTree[bool] | None = None, n_samples: int = 1) → PT
Samples new parameter configurations from a multivariate normal distribution whose mean is given by the current parameter values.
- Parameters:
rngs – nnx.Rngs container used to draw randomness.
params – PyTree of parameters providing the mean values.
covariance_matrix – Covariance matrix defining the multivariate normal.
mask – Optional PyTree indicating which parameters should be resampled.
n_samples – Number of samples to draw; adds a leading batch dimension when n_samples > 1.
- Returns:
PyTree with sampled parameter values replacing the originals.
Examples
>>> import evermore as evm
>>> import jax.numpy as jnp
>>> from flax import nnx
>>> cov = jnp.array([[1.0, 0.5], [0.5, 2.0]])
>>> params = {
...     "a": evm.Parameter(value=jnp.array([1.0])),
...     "b": evm.Parameter(value=jnp.array([2.0])),
... }
>>> rngs = nnx.Rngs(0)
>>> samples = evm.sample.sample_from_covariance_matrix(
...     rngs, params, covariance_matrix=cov, n_samples=3
... )
>>> samples["a"].get_value().shape
(3, 1)
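The underlying technique can be illustrated with the standard Cholesky construction: if the covariance matrix factors as C = L L^T, then mean + L z, with z drawn from a standard normal, has covariance C. The following is a minimal NumPy sketch of that idea under stated assumptions, not evermore's implementation. It assumes the parameters have already been flattened into one vector of means, and the helper name `sample_mvn_sketch` and its flat boolean `mask` array are hypothetical stand-ins for the PyTree `mask` and `nnx.Rngs` used by the real function. It mirrors the documented rule that a leading batch dimension appears only when `n_samples > 1`. It also assumes that parameters excluded by the mask keep their original values.

```python
import numpy as np


def sample_mvn_sketch(rng, means, cov, mask=None, n_samples=1):
    """Hypothetical helper: draw from N(means, cov) via a Cholesky factor.

    means: 1-D array of length nparams (flattened parameter values).
    cov:   (nparams, nparams) symmetric positive-definite covariance matrix.
    mask:  optional boolean array; False entries keep their mean value
           (an assumption about how masking behaves, not documented API).
    """
    means = np.asarray(means, dtype=float)
    cov = np.asarray(cov, dtype=float)
    # cov = L @ L.T, so means + L @ z has covariance cov when z ~ N(0, I).
    L = np.linalg.cholesky(cov)
    z = rng.standard_normal((n_samples, means.size))
    # Row i of z @ L.T equals (L @ z_i)^T; result has shape (n_samples, nparams).
    draws = means + z @ L.T
    if mask is not None:
        # Entries not selected by the mask are reset to their original values.
        draws = np.where(np.asarray(mask, dtype=bool), draws, means)
    # Leading batch dimension only when more than one sample is requested.
    return draws if n_samples > 1 else draws[0]


rng = np.random.default_rng(0)
cov = np.array([[1.0, 0.5], [0.5, 2.0]])
samples = sample_mvn_sketch(rng, [1.0, 2.0], cov, n_samples=3)
print(samples.shape)  # (3, 2)
```

One design simplification to note: masked-in entries here are taken from the full joint draw rather than from the conditional distribution given the fixed entries, so this sketch should not be read as describing how the library treats correlations under a mask.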
- sample_from_priors(rngs: Rngs, params: PT) → PT
Samples independent values from each parameter’s prior distribution.
- Parameters:
rngs – nnx.Rngs container used to draw randomness.
params – PyTree containing the parameters to sample.
- Returns:
PyTree mirroring params, with each parameter's value (as returned by .get_value()) replaced by a sampled value.
Examples
>>> import evermore as evm
>>> import jax
>>> from flax import nnx
>>> params = {
...     "a": evm.Parameter(value=0.0),
...     "b": evm.NormalParameter(value=0.0),
... }
>>> samples = evm.sample.sample_from_priors(nnx.Rngs(0), params)
>>> isinstance(samples["b"].get_value(), jax.Array)
True
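Unlike covariance-based sampling, this function draws each parameter from its own prior, so the resulting values are uncorrelated by construction. The following minimal NumPy sketch illustrates that independence. It assumes a hypothetical prior description of the form `(kind, a, b)` and a hypothetical helper `sample_from_priors_sketch`. Neither is evermore's prior API, and the sketch does not model how the library handles parameters that carry no prior.

```python
import numpy as np


def sample_from_priors_sketch(rng, priors):
    """Hypothetical helper: draw one value per parameter, independently.

    priors maps a parameter name to a hypothetical ("normal", loc, scale)
    or ("uniform", low, high) spec; this is not evermore's prior API.
    """
    samples = {}
    for name, (kind, a, b) in priors.items():
        # Each draw uses only this parameter's prior, so draws are independent.
        if kind == "normal":
            samples[name] = rng.normal(loc=a, scale=b)
        elif kind == "uniform":
            samples[name] = rng.uniform(low=a, high=b)
        else:
            raise ValueError(f"unknown prior kind: {kind!r}")
    return samples


rng = np.random.default_rng(0)
priors = {"b": ("normal", 0.0, 1.0), "c": ("uniform", -1.0, 1.0)}
print(sample_from_priors_sketch(rng, priors))
```

Drawing many such samples and computing the correlation between "b" and "c" should give a value close to zero. That is the practical difference from sampling with a non-diagonal covariance matrix.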