# Source code for evermore.parameters.filter
from __future__ import annotations
import dataclasses
from collections.abc import Hashable
from flax import nnx
from evermore.parameters.parameter import BaseParameter
# Names exported by this module for ``from <module> import *``.
# NOTE(review): "HasTags" is listed here, but its definition is not visible in
# this view of the file — confirm that the class is actually defined in the module.
__all__ = [
"HasName",
"HasTags",
"IsFrozen",
"is_dynamic_parameter",
"is_frozen",
"is_not_frozen",
"is_parameter",
]
[docs]
@dataclasses.dataclass(frozen=True)
class IsFrozen:
    """Callable predicate that accepts objects whose ``frozen`` attribute is truthy.

    Objects without a ``frozen`` attribute are rejected.
    """

    def __call__(self, path, x):
        """Evaluate the predicate for a single candidate.

        Args:
            path: Ignored by this filter.
            x: Object to inspect.

        Returns:
            The value of ``x.frozen`` if ``x`` has that attribute, otherwise ``False``.
        """
        del path  # the decision depends on ``x`` alone
        return x.frozen if hasattr(x, "frozen") else False
[docs]
@dataclasses.dataclass(frozen=True)
class HasName:
    """Callable predicate that accepts objects whose ``name`` equals a given name.

    Attributes:
        name: The name an object must carry to be accepted.
    """

    name: str

    def __call__(self, path, x: BaseParameter):
        """Evaluate the predicate for a single candidate.

        Args:
            path: Ignored by this filter.
            x: Object to inspect.

        Returns:
            ``False`` if ``x`` has no ``name`` attribute, otherwise the result
            of ``x.name == self.name``.
        """
        del path  # the decision depends on ``x`` alone
        if not hasattr(x, "name"):
            return False
        return x.name == self.name
# Ready-made filter instances. Order matters: each later filter is built from
# the ones defined above it.

# Type-based filter built from ``BaseParameter``.
is_parameter = nnx.OfType(BaseParameter)
"""Filter that keeps only instances of ``BaseParameter``."""
# Shared instance of the ``IsFrozen`` predicate.
is_frozen = IsFrozen()
"""Filter that keeps parameters with ``frozen=True``."""
# Logical negation of ``is_frozen``.
is_not_frozen = nnx.Not(is_frozen)
"""Filter that excludes frozen parameters."""
# Conjunction of ``is_parameter`` and ``is_not_frozen``: a value must pass both
# filters, i.e. be a ``BaseParameter`` that is not frozen.
is_dynamic_parameter = nnx.All(is_parameter, is_not_frozen)