from __future__ import annotations
import abc
import typing as tp
from collections.abc import Callable
import jax.numpy as jnp
from flax import nnx
from jaxtyping import Array, ArrayLike, Float
from evermore.parameters.parameter import V
from evermore.util import float_array
# Public API of this module. The same list is returned by the module-level
# ``__dir__`` hook defined below, so ``dir()`` on the module shows only
# these names.
__all__ = [
"AsymmetricExponential",
"BaseEffect",
"Identity",
"Lambda",
"Linear",
"OffsetAndScale",
"SymmetricExponential",
"VerticalTemplateMorphing",
]
def __dir__() -> list[str]:
    """Return the module's public names, i.e. the contents of ``__all__``."""
    return __all__
# Type variable used to annotate the ``hist`` arguments and templates below;
# it is bounded by a jaxtyping float array of arbitrary shape ("...").
H = tp.TypeVar("H", bound=Float[Array, "..."])
[docs]
class OffsetAndScale(nnx.Pytree):
    """Container pairing an ``offset`` and a ``scale`` array.

    Both values are passed through ``float_array`` on construction and are
    stored as attributes of the same name.

    Args:
        offset: Offset term. Defaults to ``0.0``.
        scale: Scale term. Defaults to ``1.0``.
    """

    def __init__(self, offset: ArrayLike = 0.0, scale: ArrayLike = 1.0):
        self.offset = float_array(offset)
        self.scale = float_array(scale)

    def broadcast(self) -> OffsetAndScale:
        """Return a new instance with both fields expanded to a common shape.

        The target shape is the broadcast of the current ``offset`` and
        ``scale`` shapes; the instance itself is left unchanged.

        Returns:
            A new object of the same type whose ``offset`` and ``scale``
            share one shape.
        """
        common_shape = jnp.broadcast_shapes(self.offset.shape, self.scale.shape)
        expanded_offset = jnp.broadcast_to(self.offset, common_shape)
        expanded_scale = jnp.broadcast_to(self.scale, common_shape)
        return type(self)(offset=expanded_offset, scale=expanded_scale)
[docs]
class BaseEffect(nnx.Module):
    """Common interface of all effects in this module.

    An effect is a callable that receives a ``value`` and a ``hist`` and
    answers with an ``OffsetAndScale``.
    """

    @abc.abstractmethod
    def __call__(self, value: V, hist: H) -> OffsetAndScale:
        """Compute the offset and scale for ``hist`` given ``value``.

        Args:
            value: The value the effect is evaluated at.
            hist: The array the effect refers to.

        Returns:
            The resulting ``OffsetAndScale``; subclasses provide the
            actual computation.
        """
[docs]
class Identity(BaseEffect):
    """Effect that leaves ``hist`` untouched: zero offset and unit scale."""

    def __call__(self, value: V, hist: H) -> OffsetAndScale:
        """Return an all-zeros offset and all-ones scale shaped like ``hist``.

        Args:
            value: Ignored.
            hist: Array whose shape and dtype the result mirrors.

        Returns:
            A broadcast ``OffsetAndScale`` of zeros (offset) and ones (scale).
        """
        del value  # unused
        no_offset = jnp.zeros_like(hist)
        unit_scale = jnp.ones_like(hist)
        neutral = OffsetAndScale(offset=no_offset, scale=unit_scale)
        return neutral.broadcast()
[docs]
class Lambda(BaseEffect):
    """Effect defined by a user-supplied function.

    The function is called as ``fun(value, hist)`` and may return either an
    ``OffsetAndScale`` (used as is) or an array, which is converted to an
    ``OffsetAndScale`` according to ``normalize_by``.

    Args:
        fun: Callable taking ``(value, hist)``.
        normalize_by: How an array result is converted. ``"offset"`` stores
            ``result - hist`` as the offset with unit scale; ``"scale"``
            stores ``result / hist`` as the scale with zero offset. Must be
            ``None`` when ``fun`` returns an ``OffsetAndScale``.
    """

    def __init__(
        self,
        fun: Callable[[V, H], OffsetAndScale | H],
        normalize_by: tp.Literal["offset", "scale"] | None = None,
    ):
        self.fun = fun
        self.normalize_by = normalize_by

    def __call__(self, value: V, hist: H) -> OffsetAndScale:
        """Evaluate the wrapped function and normalise its result.

        Args:
            value: Forwarded to ``fun`` as first argument.
            hist: Forwarded to ``fun`` as second argument and used as the
                reference for normalisation.

        Returns:
            The ``OffsetAndScale`` returned by ``fun``, or one derived from
            its array result.

        Raises:
            ValueError: If ``fun`` returns an ``OffsetAndScale`` while
                ``normalize_by`` is set, or if ``fun`` returns an array and
                ``normalize_by`` is neither ``"offset"`` nor ``"scale"``.
        """
        res = self.fun(value, hist)  # ty:ignore[invalid-argument-type]
        mode = self.normalize_by

        # A ready-made OffsetAndScale is passed through; asking for an extra
        # normalisation on top of it is rejected.
        if isinstance(res, OffsetAndScale):
            if mode is None:
                return res
            msg = f"Normalization not supported for OffsetAndScale return value from {self.fun}"
            raise ValueError(msg)

        if mode == "offset":
            effect = OffsetAndScale(offset=(res - hist), scale=jnp.ones_like(hist))
        elif mode == "scale":
            effect = OffsetAndScale(offset=jnp.zeros_like(hist), scale=(res / hist))
        else:
            msg = f"Unknown normalization type '{self.normalize_by}' for '{res}'"
            raise ValueError(msg)
        return effect.broadcast()
[docs]
class Linear(BaseEffect):
    """Effect whose scale is a linear function of ``value``.

    The scale is ``value * slope + offset``; the additive offset of the
    returned ``OffsetAndScale`` is always zero.

    Args:
        offset: Constant term of the linear function.
        slope: Coefficient multiplying ``value``.
    """

    def __init__(self, offset: ArrayLike, slope: ArrayLike):
        self.offset = float_array(offset)
        self.slope = float_array(slope)

    def __call__(self, value: V, hist: H) -> OffsetAndScale:
        """Return zero offset and the scale ``value * slope + offset``.

        Args:
            value: Point at which the linear function is evaluated.
            hist: Array whose shape the zero offset mirrors.

        Returns:
            A broadcast ``OffsetAndScale``.
        """
        slope_term = value * self.slope
        scale_factor = slope_term + self.offset
        no_offset = jnp.zeros_like(hist)
        return OffsetAndScale(offset=no_offset, scale=scale_factor).broadcast()
[docs]
class VerticalTemplateMorphing(BaseEffect):
    """Additive effect interpolating between an up and a down template.

    Args:
        up_template: Template for the +1 sigma variation.
        down_template: Template for the -1 sigma variation.
    """

    def __init__(self, up_template: H, down_template: H):
        # + 1 sigma
        self.up_template: H = float_array(up_template)
        # - 1 sigma
        self.down_template: H = float_array(down_template)

    def vshift(self, value: V, hist: H) -> H:
        """Compute the additive shift of ``hist`` at ``value``.

        The shift is ``0.5 * (diff * value + total * f(value))`` where
        ``diff = up - down``, ``total = up + down - 2 * hist`` and ``f`` is
        ``|value|`` for ``|value| > 1`` and the polynomial
        ``(3 x**3 - 10 x**2 + 15 x) / 8`` in ``x = value**2`` otherwise.

        Args:
            value: Point at which the morphing is evaluated.
            hist: Nominal array the templates are compared against.

        Returns:
            The shift to be added to ``hist``.
        """
        up, down = self.up_template, self.down_template
        total = up + down - 2 * hist
        diff = up - down
        # taken from https://github.com/nsmith-/jaxfit/blob/8479cd73e733ba35462287753fab44c0c560037b/src/jaxfit/roofit/combine.py#L173C6-L192
        coefficients = jnp.array([3.0, -10.0, 15.0, 0.0]) / 8.0
        magnitude = jnp.abs(value)
        # Polynomial inside |value| <= 1, plain |value| outside.
        smoothed_magnitude = jnp.where(
            magnitude > 1.0,
            magnitude,
            jnp.polyval(coefficients, value * value),
        )
        return jnp.array(0.5) * (diff * value + total * smoothed_magnitude)  # ty:ignore[invalid-return-type]

    def __call__(self, value: V, hist: H) -> OffsetAndScale:
        """Return the shift from ``vshift`` as offset, with unit scale.

        Args:
            value: Point at which the morphing is evaluated.
            hist: Nominal array.

        Returns:
            A broadcast ``OffsetAndScale``.
        """
        shift = self.vshift(value, hist=hist)
        unit_scale = jnp.ones_like(hist)
        return OffsetAndScale(offset=shift, scale=unit_scale).broadcast()
[docs]
class AsymmetricExponential(BaseEffect):
    """Multiplicative effect with different up and down factors.

    The scale is ``exp(value * interpolate(value))``; the additive offset is
    always zero.

    Args:
        up: Factor used on the positive side (its log is taken).
        down: Factor used on the negative side (its negated log is taken).
    """

    def __init__(self, up: H, down: H):
        self.up: H = float_array(up)
        self.down: H = float_array(down)

    def interpolate(self, value: V) -> V:
        """Return the effective log-factor at ``value``.

        For ``|value| >= 0.5`` this is ``log(up)`` when ``value >= 0`` and
        ``-log(down)`` otherwise. In between, the two are blended as
        ``avg + alpha * halfdiff`` with
        ``alpha = (3 t**5 - 10 t**3 + 15 t) / 8`` and ``t = 2 * value``.

        Args:
            value: Point at which to evaluate.

        Returns:
            The interpolated log-factor.
        """
        # https://github.com/cms-analysis/HiggsAnalysis-CombinedLimit/blob/be488af288361ef101859a398ae618131373cad7/src/ProcessNormalization.cc#L112-L129
        log_up = jnp.log(self.up)
        neg_log_down = -jnp.log(self.down)
        mean = 0.5 * (log_up + neg_log_down)
        half_gap = 0.5 * (log_up - neg_log_down)
        doubled = value + value
        doubled_sq = doubled * doubled
        blend = 0.125 * doubled * (doubled_sq * (3 * doubled_sq - 10.0) + 15.0)
        outer = jnp.where(value >= 0, log_up, neg_log_down)
        inner = mean + blend * half_gap
        return jnp.where(jnp.abs(value) >= 0.5, outer, inner)  # ty:ignore[invalid-return-type]

    def __call__(self, value: V, hist: H) -> OffsetAndScale:
        """Return zero offset and the scale ``exp(value * interpolate(value))``.

        Args:
            value: Point at which the effect is evaluated.
            hist: Array whose shape the zero offset mirrors.

        Returns:
            A broadcast ``OffsetAndScale``.
        """
        log_factor = self.interpolate(value)
        no_offset = jnp.zeros_like(hist)
        scale_factor = jnp.exp(value * log_factor)
        return OffsetAndScale(offset=no_offset, scale=scale_factor).broadcast()
[docs]
class SymmetricExponential(BaseEffect):
    """Multiplicative effect with scale ``kappa ** value`` and zero offset.

    Args:
        kappa: Base of the exponential.
    """

    def __init__(self, kappa: H):
        self.kappa = float_array(kappa)

    def __call__(self, value: V, hist: H) -> OffsetAndScale:
        """Return zero offset and the scale ``kappa ** value``.

        Args:
            value: Exponent applied to ``kappa``.
            hist: Array whose shape the zero offset mirrors.

        Returns:
            A broadcast ``OffsetAndScale``.
        """
        # https://github.com/cms-analysis/HiggsAnalysis-CombinedLimit/blob/be488af288361ef101859a398ae618131373cad7/src/ProcessNormalization.cc#L90
        # scale with exp(log(kappa) * theta) = kappa^theta
        no_offset = jnp.zeros_like(hist)
        scale_factor = self.kappa**value
        return OffsetAndScale(offset=no_offset, scale=scale_factor).broadcast()