# Source code for evermore.parameters.transform

from __future__ import annotations

import abc
import typing as tp

import jax
import jax.numpy as jnp
from flax import nnx
from jax.experimental import checkify

from evermore.parameters.filter import is_parameter
from evermore.parameters.parameter import PT, BaseParameter, V
from evermore.util import float_array

# Public API of this module; ``__dir__`` below reports exactly these names.
__all__ = [
    "BaseParameterTransformation",
    "MinuitTransform",
    "SoftPlusTransform",
    "unwrap",
    "wrap",
]


def __dir__() -> list[str]:
    """Limit ``dir(module)`` to the public names listed in ``__all__`` (PEP 562)."""
    return __all__


# error handling taken from https://github.com/JaxGaussianProcesses/GPJax/blob/main/gpjax/parameters.py
def _safe_assert(fn: tp.Callable, value: V, **kwargs) -> None:
    """Run a checkified validator and raise if it reported a failure.

    Args:
        fn: A function wrapped with ``checkify.checkify``; calling it returns an
            ``(error, output)`` pair.
        value: First positional argument forwarded to ``fn``.
        **kwargs: Extra keyword arguments forwarded to ``fn``.

    Raises:
        Whatever ``checkify.check_error`` raises when ``fn`` flagged an error.
    """
    check_result, _unused_output = fn(value, **kwargs)
    # turns the functional error value back into an exception
    checkify.check_error(check_result)


@checkify.checkify
def _check_in_bounds(value: V, lower: V, upper: V) -> None:
    """Verify that every element of ``value`` lies strictly inside ``(lower, upper)``.

    Because of the ``checkify.checkify`` decorator, calling this returns an
    ``(error, output)`` pair rather than raising; hand it to ``_safe_assert`` to
    turn a failed check into an exception.

    Args:
        value: The value to check.
        lower: Exclusive lower bound.
        upper: Exclusive upper bound.
    """
    above_lower = value > lower
    below_upper = value < upper
    checkify.check(
        jnp.all(above_lower & below_upper),
        "value needs to be bounded between {lower} and {upper}, got {value}",
        value=value,
        lower=lower,
        upper=upper,
    )


@checkify.checkify
def _check_is_finite(value: V) -> None:
    """Verify that no element of ``value`` is ``inf`` or ``nan``.

    Because of the ``checkify.checkify`` decorator, calling this returns an
    ``(error, output)`` pair rather than raising; hand it to ``_safe_assert`` to
    turn a failed check into an exception.

    Args:
        value: The value to check.
    """
    every_entry_finite = jnp.all(jnp.isfinite(value))
    checkify.check(every_entry_finite, "value needs to be finite")


@checkify.checkify
def _check_is_non_negative(value: V) -> None:
    """Verify that every element of ``value`` is ``>= 0``.

    Because of the ``checkify.checkify`` decorator, calling this returns an
    ``(error, output)`` pair rather than raising; hand it to ``_safe_assert`` to
    turn a failed check into an exception.

    Args:
        value: Values to validate.
    """
    no_negatives = jnp.all(value >= 0)
    checkify.check(
        no_negatives,
        "value needs to be non-negative, got {value}",
        value=value,
    )


def unwrap(params: PT) -> PT:
    """Move every transformed parameter in ``params`` into unconstrained space.

    Nodes selected by ``is_parameter`` are split off from the rest of the tree.
    Each one with a ``transform`` is replaced by ``param.transform.unwrap(param)``;
    those whose ``transform`` is ``None`` are kept as they are. The result is
    merged back with the untouched remainder using the original graph definition.

    Args:
        params: PyTree that may contain parameters with attached transformations.

    Returns:
        PyTree where each transformed parameter has been passed through ``unwrap``.
    """

    def _to_unconstrained(path, param: BaseParameter[V]) -> BaseParameter[V]:
        # ``path`` is supplied by ``nnx.map_state`` but not needed here
        transform = param.transform
        if transform is None:
            return param
        return transform.unwrap(param)

    graphdef, selected_params, remainder = nnx.split(params, is_parameter, ...)
    unconstrained = nnx.map_state(_to_unconstrained, selected_params)
    return nnx.merge(graphdef, unconstrained, remainder)
def wrap(params: PT) -> PT:
    """Move every transformed parameter in ``params`` back to its constrained space.

    Nodes selected by ``is_parameter`` are split off from the rest of the tree.
    Each one with a ``transform`` is replaced by ``param.transform.wrap(param)``;
    those whose ``transform`` is ``None`` are kept as they are. The result is
    merged back with the untouched remainder using the original graph definition.

    Args:
        params: PyTree that may contain parameters with attached transformations.

    Returns:
        PyTree where each transformed parameter has been passed through ``wrap``.
    """

    def _to_constrained(path, param: BaseParameter[V]) -> BaseParameter[V]:
        # ``path`` is supplied by ``nnx.map_state`` but not needed here
        transform = param.transform
        if transform is None:
            return param
        return transform.wrap(param)

    graphdef, selected_params, remainder = nnx.split(params, is_parameter, ...)
    constrained = nnx.map_state(_to_constrained, selected_params)
    return nnx.merge(graphdef, constrained, remainder)
class BaseParameterTransformation(nnx.Module):
    """Interface shared by all parameter transformations.

    Concrete subclasses implement a pair of operations that are meant to undo
    each other: ``unwrap`` takes a parameter from its constrained (user-facing)
    domain to an unconstrained representation, and ``wrap`` takes it back.
    """

    @abc.abstractmethod
    def unwrap(self, parameter: BaseParameter[V]) -> BaseParameter[V]:
        """Map ``parameter`` from its constrained domain to unconstrained space.

        Args:
            parameter: Parameter to transform.

        Returns:
            BaseParameter: A parameter holding the unconstrained value.
        """

    @abc.abstractmethod
    def wrap(self, parameter: BaseParameter[V]) -> BaseParameter[V]:
        """Map ``parameter`` from unconstrained space back to its original domain.

        Args:
            parameter: Parameter to transform.

        Returns:
            BaseParameter: A parameter holding the constrained value.
        """
class MinuitTransform(BaseParameterTransformation):
    """MINUIT-style reparameterisation for parameters with a lower and an upper limit.

    An external value ``x`` strictly inside ``(lower, upper)`` is mapped to an
    internal value via ``arcsin(2 * (x - lower) / (upper - lower) - 1)``. An
    internal value ``y`` is mapped back via
    ``lower + (upper - lower) / 2 * (sin(y) + 1)``, which lands in
    ``[lower, upper]`` for any real ``y``.

    Parameters with neither limit set pass through unchanged. Parameters with
    exactly one limit are rejected, and both limits must be finite.

    References:
        MINUIT User's Guide, Section 1.2.1 ``The transformation for parameters with limits``.

    Examples:
        >>> import evermore as evm
        >>> from evermore.parameters import transform as tr
        >>> minuit = tr.MinuitTransform()
        >>> params = {
        ...     "a": evm.Parameter(2.0, lower=-0.1, upper=2.2, transform=minuit),
        ...     "b": evm.Parameter(0.1, lower=0.0, upper=1.1, transform=minuit),
        ... }
        >>> unconstrained = tr.unwrap(params)
        >>> restored = tr.wrap(unconstrained)
        >>> restored["a"].get_value() == params["a"].get_value()
        Array(True, dtype=bool)
    """

    def _check_and_regularize(self, parameter: BaseParameter[V]) -> tuple[V, V, V]:
        """Validate the limits of ``parameter`` and return them with its value.

        Args:
            parameter: Parameter with both ``lower`` and ``upper`` set.

        Returns:
            ``(value, lower, upper)``, each converted with ``float_array``.

        Raises:
            ValueError: If exactly one of ``lower``/``upper`` is set.
            Also fails via ``_safe_assert`` if either limit is not finite.
        """
        # one limit without the other is not supported by this transform
        if (parameter.lower is None) != (parameter.upper is None):
            msg = f"{parameter} must have both lower and upper boundaries set, or none of them."
            raise ValueError(msg)

        value = float_array(parameter.get_value())
        lower = float_array(parameter.lower)
        upper = float_array(parameter.upper)

        # the transformation is only well-defined for finite limits
        for bound in (lower, upper):
            _safe_assert(_check_is_finite, bound)

        return value, lower, upper  # ty:ignore[invalid-return-type]

    def unwrap(self, parameter: BaseParameter[V]) -> BaseParameter[V]:
        """Convert an external (bounded) value into MINUIT's internal value.

        Args:
            parameter: Parameter to transform; returned unchanged if it has no limits.

        Returns:
            Parameter whose value is ``arcsin(2 * (x - lower) / (upper - lower) - 1)``.
        """
        if parameter.lower is None and parameter.upper is None:
            return parameter

        value, lower, upper = self._check_and_regularize(parameter)
        # the external value must start strictly inside its limits
        _safe_assert(_check_in_bounds, value, lower=lower, upper=upper)

        offset = value - lower
        span = upper - lower
        internal = jnp.arcsin(2.0 * offset / span - 1.0)
        return parameter.replace(value=internal)  # ty:ignore[invalid-return-type]

    def wrap(self, parameter: BaseParameter[V]) -> BaseParameter[V]:
        """Convert MINUIT's internal value back into an external (bounded) value.

        Args:
            parameter: Parameter to transform; returned unchanged if it has no limits.

        Returns:
            Parameter whose value is ``lower + (upper - lower) / 2 * (sin(y) + 1)``.
        """
        if parameter.lower is None and parameter.upper is None:
            return parameter

        value, lower, upper = self._check_and_regularize(parameter)
        half_span = (upper - lower) / 2.0
        external = lower + half_span * (jnp.sin(value) + 1.0)
        return parameter.replace(value=external)  # ty:ignore[invalid-return-type]
@checkify.checkify
def _check_is_positive(value: V) -> None:
    """Verify that every element of ``value`` is strictly greater than zero.

    Because of the ``checkify.checkify`` decorator, calling this returns an
    ``(error, output)`` pair rather than raising; hand it to ``_safe_assert`` to
    turn a failed check into an exception.

    Args:
        value: Values to validate.
    """
    checkify.check(
        jnp.all(value > 0), "value needs to be positive, got {value}", value=value
    )


class SoftPlusTransform(BaseParameterTransformation):
    """Ensures parameters remain positive by using the softplus bijection.

    This transformation does not require explicit bounds. ``unwrap`` maps a
    strictly positive value onto the real line via the inverse softplus, and
    ``wrap`` maps back onto the positive reals via ``jax.nn.softplus``.

    Examples:
        >>> import evermore as evm
        >>> from evermore.parameters import transform as tr
        >>> positive = tr.SoftPlusTransform()
        >>> params = {
        ...     "a": evm.Parameter(2.0, transform=positive),
        ...     "b": evm.Parameter(0.1, transform=positive),
        ... }
        >>> unconstrained = tr.unwrap(params)
        >>> restored = tr.wrap(unconstrained)
        >>> restored["b"].get_value()
        Array(0.1, dtype=float32)
    """

    def unwrap(self, parameter: BaseParameter[V]) -> BaseParameter[V]:
        # from: https://github.com/danielward27/paramax/blob/main/paramax/utils.py
        """Apply the inverse softplus ``x + log(-expm1(-x))`` after validating ``x > 0``.

        Args:
            parameter: Parameter whose value must be strictly positive.

        Returns:
            Parameter whose value has been mapped onto the real line.

        Raises:
            Fails via ``_safe_assert`` if any element of the value is ``<= 0``.
            Zero is rejected because ``log(-expm1(-0)) == log(0) == -inf``, a
            degenerate unconstrained value that softplus can only reach in the limit.
        """
        value = float_array(parameter.get_value())
        _safe_assert(_check_is_positive, value)
        new_value = jnp.log(-jnp.expm1(-value)) + value
        return parameter.replace(value=new_value)  # ty:ignore[invalid-return-type]

    def wrap(self, parameter: BaseParameter[V]) -> BaseParameter[V]:
        """Apply ``jax.nn.softplus`` to map the value back onto the positive reals.

        Args:
            parameter: Parameter holding an unconstrained value.

        Returns:
            Parameter whose value is ``softplus(value)``.
        """
        new_value = jax.nn.softplus(float_array(parameter.get_value()))
        return parameter.replace(value=new_value)  # ty:ignore[invalid-return-type]