"""Parameter transformations between constrained and unconstrained space (``evermore.parameters.transform``)."""
from __future__ import annotations
import abc
import typing as tp
import jax
import jax.numpy as jnp
from flax import nnx
from jax.experimental import checkify
from jaxtyping import PyTree
from evermore.parameters.filter import is_parameter
from evermore.parameters.parameter import BaseParameter, V
from evermore.util import float_array
# Public API of this module; ``__dir__`` returns exactly this list.
__all__ = [
    "BaseParameterTransformation", "MinuitTransform", "SoftPlusTransform",
    "unwrap", "wrap",
]
def __dir__():
    """Limit ``dir()`` on this module (PEP 562 hook) to the names in ``__all__``."""
    public_names = __all__
    return public_names
# error handling taken from https://github.com/JaxGaussianProcesses/GPJax/blob/main/gpjax/parameters.py
def _safe_assert(fn: tp.Callable, value: V, **kwargs) -> None:
error, _ = fn(value, **kwargs)
checkify.check_error(error)
return
@checkify.checkify
def _check_in_bounds(value: V, lower: V, upper: V) -> None:
"""Check if a value is bounded between lower and upper.
Args:
value: The value to check.
lower: The lower bound.
upper: The upper bound.
Raises:
ValueError: If any element of value is outside the bounds.
"""
checkify.check(
jnp.all((value > lower) & (value < upper)),
"value needs to be bounded between {lower} and {upper}, got {value}",
value=value,
lower=lower,
upper=upper,
)
@checkify.checkify
def _check_is_finite(value: V) -> None:
"""Check if a value is finite.
Args:
value: The value to check.
Raises:
ValueError: If any element of value is not finite.
"""
checkify.check(
jnp.all(jnp.isfinite(value)),
"value needs to be finite",
)
@checkify.checkify
def _check_is_non_negative(value: V) -> None:
"""Check if a value is element-wise non-negative.
Args:
value: Values to validate.
Raises:
ValueError: If any element is negative.
"""
checkify.check(
jnp.all(value >= 0), "value needs to be non-negative, got {value}", value=value
)
def unwrap(params: PyTree[BaseParameter]) -> PyTree[BaseParameter]:
    """Move every transformable parameter in ``params`` into unconstrained space.

    Only the state selected by ``is_parameter`` is mapped. The remaining state is
    split off and merged back untouched.

    Args:
        params: PyTree that may contain parameters with attached transformations.

    Returns:
        A PyTree of the same structure. Each parameter with a ``transform`` is
        replaced by ``param.transform.unwrap(param)``. Parameters whose
        ``transform`` is ``None`` pass through unchanged.
    """

    def _to_unconstrained(path, param: BaseParameter[V]) -> BaseParameter[V]:
        del path  # nnx.map_state passes the state path; it is not needed here
        transform = param.transform
        return param if transform is None else transform.unwrap(param)

    graphdef, selected, remainder = nnx.split(params, is_parameter, ...)
    return nnx.merge(graphdef, nnx.map_state(_to_unconstrained, selected), remainder)
def wrap(params: PyTree[BaseParameter]) -> PyTree[BaseParameter]:
    """Move every transformable parameter in ``params`` back to constrained space.

    This is the counterpart of ``unwrap``. Only the state selected by
    ``is_parameter`` is mapped; the remaining state is merged back untouched.

    Args:
        params: PyTree that may contain parameters with attached transformations.

    Returns:
        A PyTree of the same structure. Each parameter with a ``transform`` is
        replaced by ``param.transform.wrap(param)``. Parameters whose
        ``transform`` is ``None`` pass through unchanged.
    """

    def _to_constrained(path, param: BaseParameter[V]) -> BaseParameter[V]:
        del path  # nnx.map_state passes the state path; it is not needed here
        transform = param.transform
        return param if transform is None else transform.wrap(param)

    graphdef, selected, remainder = nnx.split(params, is_parameter, ...)
    return nnx.merge(graphdef, nnx.map_state(_to_constrained, selected), remainder)
class BaseParameterTransformation(nnx.Module):
    """Abstract interface for parameter transformations.

    A transformation translates a parameter between its constrained
    representation and an unconstrained one. Concrete subclasses implement
    both directions. ``unwrap`` goes from constrained to unconstrained space;
    ``wrap`` goes back. The module-level ``unwrap``/``wrap`` helpers call these
    methods through ``param.transform``.
    """

    @abc.abstractmethod
    def unwrap(self, parameter: BaseParameter[V]) -> BaseParameter[V]:
        """Map ``parameter`` from its constrained domain to unconstrained space.

        Args:
            parameter: Parameter to transform.

        Returns:
            BaseParameter: A parameter carrying the unconstrained value.
        """

    @abc.abstractmethod
    def wrap(self, parameter: BaseParameter[V]) -> BaseParameter[V]:
        """Map ``parameter`` from unconstrained space back to its original domain.

        Args:
            parameter: Parameter to transform.

        Returns:
            BaseParameter: A parameter carrying the value in its original domain.
        """
class MinuitTransform(BaseParameterTransformation):
    """MINUIT-style transformation for parameters with a lower and an upper bound.

    The mapping is only well-defined when both bounds are finite. A parameter
    with neither bound set passes through untouched. Setting exactly one bound
    is rejected with ``ValueError``.

    References:
        MINUIT User's Guide, Section 1.2.1 ``The transformation for parameters with limits``.

    Examples:
        >>> import evermore as evm
        >>> from evermore.parameters import transform as tr
        >>> minuit = tr.MinuitTransform()
        >>> params = {
        ...     "a": evm.Parameter(2.0, lower=-0.1, upper=2.2, transform=minuit),
        ...     "b": evm.Parameter(0.1, lower=0.0, upper=1.1, transform=minuit),
        ... }
        >>> unconstrained = tr.unwrap(params)
        >>> restored = tr.wrap(unconstrained)
        >>> restored["a"].value == params["a"].value
        Array(True, dtype=bool)
    """

    def _check_and_regularize(self, parameter: BaseParameter[V]) -> tuple[V, V, V]:
        """Validate the bounds of ``parameter`` and return ``(value, lower, upper)`` as float arrays."""
        has_lower = parameter.lower is not None
        has_upper = parameter.upper is not None
        # Exactly one bound being set is an unsupported configuration.
        if has_lower != has_upper:
            msg = f"{parameter} must have both lower and upper boundaries set, or none of them."
            raise ValueError(msg)
        value, lower, upper = (
            float_array(parameter.value),
            float_array(parameter.lower),
            float_array(parameter.upper),
        )
        # Infinite bounds would make ``upper - lower`` meaningless in the formulas.
        for bound in (lower, upper):
            _safe_assert(_check_is_finite, bound)
        return value, lower, upper  # ty:ignore[invalid-return-type]

    def unwrap(self, parameter: BaseParameter[V]) -> BaseParameter[V]:
        """Map a bounded ("external") value to its unconstrained ("internal") value.

        Computes ``arcsin(2 * (value - lower) / (upper - lower) - 1)``. Parameters
        without any bounds are returned unchanged.
        """
        if parameter.lower is None and parameter.upper is None:
            return parameter
        value, lower, upper = self._check_and_regularize(parameter)
        # The value has to start inside the bounds; the check is strict, so a
        # value sitting exactly on a bound is rejected as well.
        _safe_assert(_check_in_bounds, value, lower=lower, upper=upper)
        scaled = 2.0 * (value - lower) / (upper - lower) - 1.0  # type: ignore[operator]
        return parameter.replace(value=jnp.arcsin(scaled))  # ty:ignore[invalid-return-type]

    def wrap(self, parameter: BaseParameter[V]) -> BaseParameter[V]:
        """Map an unconstrained ("internal") value back to its bounded ("external") value.

        Computes ``lower + (upper - lower) / 2 * (sin(value) + 1)``. Parameters
        without any bounds are returned unchanged.
        """
        if parameter.lower is None and parameter.upper is None:
            return parameter
        value, lower, upper = self._check_and_regularize(parameter)
        half_width = (upper - lower) / 2.0
        external = lower + half_width * (jnp.sin(value) + 1.0)
        return parameter.replace(value=external)  # ty:ignore[invalid-return-type]
@checkify.checkify
def _check_is_positive(value: V) -> None:
    """Record a checkify failure if any element of ``value`` is not strictly positive.

    Because of the ``checkify.checkify`` decorator, calling this returns an
    ``(error, None)`` pair instead of raising; pass the function to
    ``_safe_assert`` to raise.

    Args:
        value: Values to validate.
    """
    checkify.check(
        jnp.all(value > 0), "value needs to be strictly positive, got {value}", value=value
    )


class SoftPlusTransform(BaseParameterTransformation):
    """Ensures parameters remain positive by using the softplus bijection.

    This transformation does not require explicit bounds. ``unwrap`` maps strictly
    positive values to the unconstrained real line, and ``wrap`` maps back to the
    positive reals.

    Examples:
        >>> import evermore as evm
        >>> from evermore.parameters import transform as tr
        >>> positive = tr.SoftPlusTransform()
        >>> params = {
        ...     "a": evm.Parameter(2.0, transform=positive),
        ...     "b": evm.Parameter(0.1, transform=positive),
        ... }
        >>> unconstrained = tr.unwrap(params)
        >>> restored = tr.wrap(unconstrained)
        >>> restored["b"].value
        Array(0.1, dtype=float32)
    """

    def unwrap(self, parameter: BaseParameter[V]) -> BaseParameter[V]:
        # from: https://github.com/danielward27/paramax/blob/main/paramax/utils.py
        """Applies the inverse softplus transformation after validating the value.

        The inverse softplus ``log(exp(x) - 1)`` is evaluated in the numerically
        stable form ``log(-expm1(-x)) + x``. That expression is finite only for
        ``x > 0``: ``x == 0`` would give ``-inf`` and ``x < 0`` would give ``nan``.
        Both are therefore rejected before the transform is applied.
        """
        value = float_array(parameter.value)
        _safe_assert(_check_is_positive, value)
        new_value = jnp.log(-jnp.expm1(-value)) + value
        return parameter.replace(value=new_value)  # ty:ignore[invalid-return-type]

    def wrap(self, parameter: BaseParameter[V]) -> BaseParameter[V]:
        """Applies ``jax.nn.softplus`` to map an unconstrained value back to the positive reals."""
        new_value = jax.nn.softplus(float_array(parameter.value))
        return parameter.replace(value=new_value)  # ty:ignore[invalid-return-type]