from __future__ import annotations
import abc
from collections.abc import Callable
from typing import Literal, Protocol, runtime_checkable
import jax
import jax.numpy as jnp
from flax import nnx
from jax._src.random import Shape
from jax.scipy.special import digamma, gammaln, xlogy
from jaxtyping import Array, Float, PRNGKeyArray
from evermore.parameters.parameter import V
from evermore.util import float_array
__all__ = [
"BasePDF",
"Normal",
"PoissonBase",
"PoissonContinuous",
"PoissonDiscrete",
]
def __dir__():
return __all__
@runtime_checkable
class ImplementsFromUnitNormalConversion(Protocol):
def __evermore_from_unit_normal__(self, x: V) -> V: ...
[docs]
class BasePDF(nnx.Pytree):
@abc.abstractmethod
def log_prob(self, x: V) -> V: ...
@abc.abstractmethod
def cdf(self, x: V) -> V: ...
@abc.abstractmethod
def inv_cdf(self, x: V) -> V: ...
@abc.abstractmethod
def sample(self, key: PRNGKeyArray, shape: Shape) -> Float[Array, ...]: ...
def prob(self, x: V, **kwargs) -> V:
return jnp.exp(self.log_prob(x, **kwargs)) # ty:ignore[invalid-return-type]
[docs]
class Normal(BasePDF):
def __init__(self, mean: V, width: V):
self.mean = float_array(mean)
self.width = float_array(width)
def log_prob(self, x: V) -> V:
logpdf_max = jax.scipy.stats.norm.logpdf(
self.mean, loc=self.mean, scale=self.width
)
unnormalized = jax.scipy.stats.norm.logpdf(x, loc=self.mean, scale=self.width)
return unnormalized - logpdf_max # ty:ignore[invalid-return-type]
def cdf(self, x: V) -> V:
return jax.scipy.stats.norm.cdf(x, loc=self.mean, scale=self.width) # ty:ignore[invalid-return-type]
def inv_cdf(self, x: V) -> V:
return jax.scipy.stats.norm.ppf(x, loc=self.mean, scale=self.width) # ty:ignore[invalid-return-type]
def __evermore_from_unit_normal__(self, x: V) -> V:
return self.mean + self.width * x # ty:ignore[invalid-return-type]
def sample(self, key: PRNGKeyArray, shape: Shape) -> Float[Array, ...]:
# sample parameter from pdf
return self.__evermore_from_unit_normal__(jax.random.normal(key, shape=shape))
[docs]
class PoissonBase(BasePDF):
def __init__(self, lamb: V):
self.lamb = float_array(lamb)
[docs]
class PoissonDiscrete(PoissonBase):
"""Poisson distribution with discrete support.
Float inputs are floored to the nearest integer, matching the behaviour of
libraries such as SciPy or RooFit.
"""
def log_prob(
self,
x: V,
normalize: bool = True,
) -> V:
k = jnp.floor(x)
# plain evaluation of the pmf
unnormalized = jax.scipy.stats.poisson.logpmf(k, self.lamb)
if not normalize:
return unnormalized # ty:ignore[invalid-return-type]
# when normalizing, divide (subtract in log space) by maximum over k range
logpdf_max = jax.scipy.stats.poisson.logpmf(k, k)
return unnormalized - logpdf_max # ty:ignore[invalid-return-type]
def cdf(self, x: V) -> V:
# no need to round x to k, already done by cdf library function
return jax.scipy.stats.poisson.cdf(x, self.lamb) # ty:ignore[invalid-return-type]
def inv_cdf(self, x: V, rounding: DiscreteRounding = "floor") -> V:
# define starting point for search from normal approximation
def start_fn(x: V) -> V:
return jnp.floor(
self.lamb + jax.scipy.stats.norm.ppf(x) * jnp.sqrt(self.lamb)
) # ty:ignore[invalid-return-type]
# define the cdf function
def cdf_fn(k: V) -> V:
return jax.scipy.stats.poisson.cdf(k, self.lamb) # ty:ignore[invalid-return-type]
return discrete_inv_cdf_search(
x,
cdf_fn=cdf_fn,
start_fn=start_fn,
rounding=rounding,
)
def sample(self, key: PRNGKeyArray, shape: Shape) -> Float[Array, ...]:
return jax.random.poisson(key, self.lamb, shape=shape)
[docs]
class PoissonContinuous(PoissonBase):
def log_prob(
self,
x: V,
normalize: bool = True,
shift_mode: bool = False,
) -> V:
# optionally adjust lambda to a higher value such that the new mode is the current lambda
lamb = jnp.exp(digamma(self.lamb + 1)) if shift_mode else self.lamb
def _log_prob(x, lamb):
x = jnp.array(x, jnp.result_type(float))
return xlogy(x, lamb) - lamb - gammaln(x + 1)
# plain evaluation of the pdf
unnormalized = _log_prob(x, lamb)
if not normalize:
return unnormalized
# when normalizing, divide (subtract in log space) by maximum over a range
# that depends on whether the mode is shifted
args = (self.lamb, lamb) if shift_mode else (x, x)
logpdf_max = _log_prob(*args)
return unnormalized - logpdf_max
def cdf(self, x: V) -> V:
err = f"{self.__class__.__name__} does not support cdf"
raise Exception(err)
def inv_cdf(self, x: V) -> V:
err = f"{self.__class__.__name__} does not support inv_cdf"
raise Exception(err)
def sample(
self, key: PRNGKeyArray, shape: Shape | None = None
) -> Float[Array, ...]:
msg = f"{self.__class__.__name__} does not support sampling, use PoissonDiscrete instead"
raise Exception(msg)
# alias for rounding literals
DiscreteRounding = Literal["floor", "ceil", "closest"]
known_roundings = frozenset(DiscreteRounding.__args__)
def discrete_inv_cdf_search(
x: V,
cdf_fn: Callable[[V], V],
start_fn: Callable[[V], V],
rounding: DiscreteRounding,
) -> V:
"""Computes an inverse CDF for discrete distributions via iterative search.
Args:
x: Values between 0 and 1 for which the inverse CDF should be evaluated.
cdf_fn: Callable returning the cumulative distribution evaluated at a given integer.
start_fn: Callable providing an initial guess for the search (usually via an approximation).
rounding: Strategy used when the target value lies between two integers.
Returns:
V: Integral values with the same shape as ``x`` that correspond to the requested quantiles.
Examples:
>>> import jax.numpy as jnp
>>> import jax.scipy.stats
>>> lamb = 5.0
>>> def start_fn(q):
... return jnp.floor(lamb + jax.scipy.stats.norm.ppf(q) * jnp.sqrt(lamb))
>>> def cdf_fn(k):
... return jax.scipy.stats.poisson.cdf(k, lamb)
>>> discrete_inv_cdf_search(jnp.array(0.9), cdf_fn, start_fn, \"floor\")
Array(7., dtype=float32)
"""
# store masks for injecting exact values for known edge cases later on
# inject 0 for x == 0
zero_mask = x == 0.0
# inject inf for x == 1
inf_mask = x == 1.0
# inject nan for ~(0 < x < 1) or non-finite values
nan_mask = (x < 0.0) | (x > 1.0) | ~jnp.isfinite(x)
# setup stopping condition and iteration body for the iterative search
# note: functions are defined for scalar values and then vmap'd, with results being reshaped
def cond_fn(val):
*_, stop = val
return ~jnp.any(stop)
def body_fn(val):
k, target_itg, prev_itg, stop = val
# compute the current integral
itg = cdf_fn(k)
# special case: itg is the exact solution
stop |= itg == target_itg
# if no previous integral is available or if we have not yet "cornered" the target value
# with the current and previous integrals, make a step in the right direction
make_step = (
(prev_itg < 0)
| ((prev_itg < itg) & (itg < target_itg))
| ((target_itg < itg) & (itg < prev_itg))
)
step = jnp.where(~stop & make_step, jnp.sign(target_itg - itg), 0)
k += step
# if target_itg is between the computed integrals we can now find the correct k
# note: k might be subject to a shift by +1 or -1, depending on the stride and rounding
k_found = ~stop & ~make_step
# we're using python >=3.11 :)
match rounding:
case "floor":
k_shift = jnp.where(k_found & (itg > target_itg), -1, 0)
case "ceil":
k_shift = jnp.where(k_found & (prev_itg > target_itg), 1, 0)
case "closest":
k_shift = jnp.where(
k_found & (abs(itg - target_itg) > abs(prev_itg - target_itg)),
jnp.sign(prev_itg - itg),
0,
)
case _:
msg = f"unknown rounding '{rounding}' mode, expected one of {', '.join(known_roundings)}"
raise ValueError(msg)
k += k_shift
# update the stop flag and end
stop |= k_found
return (k, target_itg, itg, stop)
def search(start_k, target_itg, stop):
prev_itg = -jnp.ones_like(target_itg)
val = (start_k, target_itg, prev_itg, stop)
return jax.lax.while_loop(cond_fn, body_fn, val)[0]
# jnp.vectorize is auto-vmapping over all axes of its arguments,
# see: https://docs.jax.dev/en/latest/_autosummary/jax.numpy.vectorize.html#jax.numpy.vectorize
vsearch = jnp.vectorize(search)
# define starting point and stop flag (eagerly skipping edge cases), then search
start_k = start_fn(x)
stop = zero_mask | inf_mask | nan_mask
k = vsearch(start_k, x, stop)
# inject known values for edge cases
k = jnp.where(zero_mask, 0.0, k)
k = jnp.where(inf_mask, jnp.inf, k)
k = jnp.where(nan_mask, jnp.nan, k)
return k # noqa: RET504