from __future__ import annotations
import typing as tp
from collections.abc import Hashable
from flax import nnx
from jaxtyping import Array, ArrayLike, Float, PyTree
from evermore.util import float_array
if tp.TYPE_CHECKING:
from evermore.binned.modifier import H, Modifier
from evermore.parameters.transform import BaseParameterTransformation
from evermore.pdf import BasePDF
# Names exported by this module; the module-level ``__dir__`` defined below
# returns this same list.
__all__ = [
"PT",
"BaseParameter",
"NormalParameter",
"Parameter",
"V",
]
def __dir__():
    """Return the module's exported names.

    Module-level ``__dir__`` hook: it hands back the ``__all__`` list itself
    (not a copy).
    """
    return __all__
# Type variable for a parameter's value; bounded by a float-dtype array of
# arbitrary shape (jaxtyping ``Float[Array, "..."]``).
V = tp.TypeVar("V", bound=Float[Array, "..."])
class BaseParameter(nnx.Variable[V]):
    """Variable that carries a float value together with descriptive metadata.

    The value is coerced with ``float_array`` and passed to ``nnx.Variable``.
    Name, bounds, frozen flag, transform and tags are recorded through
    ``set_metadata``.
    """

    def __init__(
        self,
        value: V | ArrayLike = 0.0,
        name: str | None = None,
        lower: V | ArrayLike | None = None,
        upper: V | ArrayLike | None = None,
        frozen: bool = False,
        transform: BaseParameterTransformation | None = None,
        tags: frozenset[Hashable] = frozenset(),
        **kwargs: tp.Any,
    ) -> None:
        """Initialise the variable and attach its metadata.

        Args:
            value: Initial value; converted with ``float_array``.
            name: Optional identifier stored as metadata.
            lower: Optional lower bound; converted with ``float_array`` unless
                it is ``None``.
            upper: Optional upper bound; converted with ``float_array`` unless
                it is ``None``.
            frozen: If True, the parameter is not updated during optimization.
            transform: Optional parameter transformation stored as metadata.
            tags: Hashable tags stored as metadata.
            **kwargs: Forwarded unchanged to ``nnx.Variable.__init__``.
        """
        super().__init__(value=float_array(value), **kwargs)

        # A missing bound stays ``None``; a given bound becomes a float array.
        lower_bound = None if lower is None else float_array(lower)
        upper_bound = None if upper is None else float_array(upper)

        self.set_metadata(
            name=name,
            lower=lower_bound,
            upper=upper_bound,
            frozen=frozen,
            transform=transform,
            tags=tags,
        )

    @property
    def prior(self) -> BasePDF | None:
        """Prior distribution attached to this parameter.

        Returns:
            BasePDF | None: Always ``None`` here; subclasses may override.
        """
        return None

    # modifier shorthands
    def scale(self, slope: ArrayLike = 1.0, offset: ArrayLike = 0.0) -> Modifier:
        """Build a linear modifier from this parameter's current value.

        Args:
            slope: Multiplicative factor passed to the ``Linear`` effect.
            offset: Additive shift passed to the ``Linear`` effect.

        Returns:
            Modifier: Modifier combining the current value with the effect.
        """
        # Imported at call time: the module only imports ``Modifier`` under
        # ``TYPE_CHECKING`` (presumably to avoid an import cycle).
        from evermore.binned.effect import Linear
        from evermore.binned.modifier import Modifier

        current_value = self.get_value()
        linear_effect = Linear(slope=slope, offset=offset)
        return Modifier(value=current_value, effect=linear_effect)
class Parameter(BaseParameter[V]):
    """General-purpose parameter with optional bounds and metadata.

    Adds nothing on top of ``BaseParameter``: the constructor, the ``prior``
    property and the ``scale`` shorthand are all inherited.

    Attributes:
        value: Current parameter value (read via ``.get_value()``).
        name: Optional human-readable identifier.
        lower: Optional lower bound enforced via transformations.
        upper: Optional upper bound enforced via transformations.
        prior: Optional prior distribution (``None`` unless a subclass
            overrides the property).
        frozen: Whether the parameter is excluded from optimisation.
        transform: Optional transformation applied during ``unwrap``/``wrap``.
        tags: Additional metadata tags.

    Examples:
        >>> import evermore as evm
        >>> theta = evm.Parameter(value=1.0, lower=0.0, upper=2.0)
        >>> theta.get_value()
        Array(1., dtype=float32)
    """
class NormalParameter(Parameter[V]):
    """Parameter whose ``prior`` is a standard normal distribution.

    Also offers shorthands that build log-normal scaling and template
    morphing modifiers from the parameter's current value.
    """

    def __init__(
        self,
        value: V | ArrayLike = 0.0,
        name: str | None = None,
        lower: V | ArrayLike | None = None,
        upper: V | ArrayLike | None = None,
        frozen: bool = False,
        transform: BaseParameterTransformation | None = None,
        tags: frozenset[Hashable] = frozenset(),
        **kwargs: tp.Any,
    ) -> None:
        """Forward every argument unchanged to the parent initialiser.

        Args:
            value: Initial parameter value.
            name: Optional identifier.
            lower: Optional lower bound.
            upper: Optional upper bound.
            frozen: If True, the parameter is not updated during optimization.
            transform: Optional parameter transformation.
            tags: Hashable metadata tags.
            **kwargs: Extra keyword arguments for the parent initialiser.
        """
        super().__init__(
            value=value,
            name=name,
            lower=lower,
            upper=upper,
            frozen=frozen,
            transform=transform,
            tags=tags,
            **kwargs,
        )

    @property
    def prior(self) -> BasePDF:
        """Standard normal prior of this parameter.

        Returns:
            BasePDF: ``Normal`` with mean 0.0 and width 1.0.
        """
        from evermore.pdf import Normal

        mean = float_array(0.0)
        width = float_array(1.0)
        return Normal(mean=mean, width=width)

    def scale_log_asymmetric(self, up: ArrayLike, down: ArrayLike) -> Modifier:
        """Build an asymmetric log-normal modifier from the current value.

        Args:
            up: Scaling factor applied to upward deviations.
            down: Scaling factor applied to downward deviations.

        Returns:
            Modifier: Modifier carrying an ``AsymmetricExponential`` effect.
        """
        from evermore.binned.effect import AsymmetricExponential
        from evermore.binned.modifier import Modifier

        current_value = self.get_value()
        # Both factors are coerced to float arrays before building the effect.
        effect = AsymmetricExponential(up=float_array(up), down=float_array(down))
        return Modifier(value=current_value, effect=effect)

    def scale_log_symmetric(self, kappa: ArrayLike) -> Modifier:
        """Build a symmetric log-normal modifier from the current value.

        Args:
            kappa: Scaling factor; coerced to a float array.

        Returns:
            Modifier: Modifier carrying a ``SymmetricExponential`` effect.
        """
        from evermore.binned.effect import SymmetricExponential
        from evermore.binned.modifier import Modifier

        current_value = self.get_value()
        effect = SymmetricExponential(kappa=float_array(kappa))
        return Modifier(value=current_value, effect=effect)

    def morphing(
        self,
        up_template: H,
        down_template: H,
    ) -> Modifier:
        """Build a vertical template morphing modifier from the current value.

        Args:
            up_template: Template used for upward variations.
            down_template: Template used for downward variations.

        Returns:
            Modifier: Modifier carrying a ``VerticalTemplateMorphing`` effect.
        """
        from evermore.binned.effect import VerticalTemplateMorphing
        from evermore.binned.modifier import Modifier

        current_value = self.get_value()
        effect = VerticalTemplateMorphing(
            up_template=up_template,
            down_template=down_template,
        )
        return Modifier(value=current_value, effect=effect)
# Type variable for PyTrees of parameters (jaxtyping ``PyTree[BaseParameter]``).
# Must stay below the ``BaseParameter`` definition: the bound expression is
# evaluated when this line runs.
PT = tp.TypeVar("PT", bound=PyTree[BaseParameter])