diff --git a/src/hssm/__init__.py b/src/hssm/__init__.py index 2f234d086..171da9906 100644 --- a/src/hssm/__init__.py +++ b/src/hssm/__init__.py @@ -19,7 +19,7 @@ from .param import UserParam as Param from .prior import Prior from .register import register_model -from .rl import RLSSM +from .rl import RLSSM, register_rlssm_model from .simulator import simulate_data from .utils import check_data_for_rl, set_floatX @@ -40,6 +40,7 @@ "Prior", "check_data_for_rl", "register_model", + "register_rlssm_model", "simulate_data", "set_floatX", "show_defaults", diff --git a/src/hssm/hssm.py b/src/hssm/hssm.py index a817401f9..0a7746ba2 100644 --- a/src/hssm/hssm.py +++ b/src/hssm/hssm.py @@ -51,34 +51,6 @@ } -class classproperty: - """A decorator that combines the behavior of @property and @classmethod. - - This decorator allows you to define a property that can be accessed on the class - itself, rather than on instances of the class. It is useful for defining class-level - properties that need to perform some computation or access class-level data. - - This implementation is provided for compatibility with Python versions 3.10 through - 3.12, as one cannot combine the @property and @classmethod decorators is across all - these versions. - - Example - ------- - class MyClass: - @classproperty - def my_class_property(cls): - return "This is a class property" - - print(MyClass.my_class_property) # Output: This is a class property - """ - - def __init__(self, fget): - self.fget = fget - - def __get__(self, instance, owner): # noqa: D105 - return self.fget(owner) - - class HSSM(HSSMBase): """The basic Hierarchical Sequential Sampling Model (HSSM) class. diff --git a/src/hssm/rl/__init__.py b/src/hssm/rl/__init__.py index 64e17bc41..bb6a4cb1a 100644 --- a/src/hssm/rl/__init__.py +++ b/src/hssm/rl/__init__.py @@ -5,9 +5,12 @@ Public API (import from ``hssm.rl``): -- ``RLSSM``: the RL + SSM model class implemented in :mod:`hssm.rl.rlssm`. -- ``RLSSMConfig``: the config class for RL + SSM models, implemented in - :mod:`hssm.rl.config`. +- ``RLSSM``: the public RL + SSM model class in :mod:`hssm.rl.rlssm`. +- ``_RLSSM``: the internal base class that requires a fully built config. +- ``RLSSMConfig``: the config class for RL + SSM models in :mod:`hssm.rl.config`. +- ``get_rlssm_model_config``: factory that builds a config from a named model. +- ``register_rlssm_model``: register a custom named RLSSM model. +- ``register_ssm``: register a custom SSM base logp function. - ``validate_balanced_panel``: panel-balance utility in :mod:`hssm.rl.utils`. RL likelihood builders live in :mod:`hssm.rl.likelihoods.builder` and include @@ -17,11 +20,20 @@ """ from .config import RLSSMConfig -from .rlssm import RLSSM +from .registry import ( + get_rlssm_model_config, + register_rlssm_model, + register_ssm, +) +from .rlssm import _RLSSM, RLSSM from .utils import validate_balanced_panel __all__ = [ "RLSSM", + "_RLSSM", "RLSSMConfig", + "get_rlssm_model_config", + "register_rlssm_model", + "register_ssm", "validate_balanced_panel", ] diff --git a/src/hssm/rl/registry.py b/src/hssm/rl/registry.py new file mode 100644 index 000000000..df4c88970 --- /dev/null +++ b/src/hssm/rl/registry.py @@ -0,0 +1,598 @@ +"""Registry for named RLSSM models and SSM base log-likelihood functions. + +This module provides: + +- :data:`_SSM_REGISTRY` — holds *custom* SSM entries added via + :func:`register_ssm`. Built-in HSSM models (``"ddm"``, ``"angle"``, + ``"weibull"``, and any other model in :mod:`hssm.modelconfig` that exposes an + ``approx_differentiable`` likelihood) are resolved automatically from + :func:`hssm.modelconfig.get_default_model_config` and do **not** need to be + pre-registered here. +- :data:`_RLSSM_REGISTRY` — maps named RLSSM model strings (e.g. + ``"2AB_RescorlaWagner_ddm"``) to their default decision process, + learning process, and parameter info. +- :func:`get_rlssm_model_config` — builds a :class:`~hssm.rl.config.RLSSMConfig` + from a named model string with optional overrides. +- :func:`register_rlssm_model` — register a custom named RLSSM model. +- :func:`register_ssm` — register a custom SSM base logp function. +""" + +from __future__ import annotations + +import logging +from copy import deepcopy +from typing import Any + +from hssm.distribution_utils.onnx import make_jax_matrix_logp_funcs_from_onnx +from hssm.rl.likelihoods.two_armed_bandit import compute_v_subject_wise +from hssm.utils import annotate_function + +from .config import RLSSMConfig + +_logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Default annotated Rescorla-Wagner learning function +# --------------------------------------------------------------------------- + +_compute_v_annotated = annotate_function( + inputs=["rl_alpha", "scaler", "response", "feedback"], + outputs=["v"], +)(compute_v_subject_wise) + +# --------------------------------------------------------------------------- +# SSM base log-likelihood registry +# --------------------------------------------------------------------------- +# This dict holds only *custom* SSM entries added at runtime via register_ssm(). +# Built-in HSSM models are resolved on demand from hssm.modelconfig — see +# _build_ssm_spec_from_modelconfig() and _get_decision_process_spec(). +# +# Each entry (custom or derived) provides: +# ssm_base_logp_func - annotated JAX function (inputs + outputs, no computed) +# list_params_ssm - ordered SSM parameter names (including computed ones) +# bounds_ssm - bounds for all SSM params +# params_default_ssm - default values aligned with list_params_ssm +# response - data column names + +_SSM_REGISTRY: dict[str, dict[str, Any]] = {} + +# Cache for resolved SSM base logp functions (populated on first use by +# _get_ssm_logp, or immediately by register_ssm when a pre-built func is +# supplied by the caller). +_SSM_LOGP_CACHE: dict[str, Any] = {} + + +def _make_ssm_base_logp_from_onnx( + onnx_file: str, + list_params_ssm: list[str], + response: list[str], +) -> Any: + """Build and annotate a JAX log-likelihood function from an ONNX model file.""" + _raw = make_jax_matrix_logp_funcs_from_onnx(model=onnx_file) + return annotate_function( + inputs=list_params_ssm + response, + outputs=["logp"], + )(_raw) + + +def _build_ssm_spec_from_modelconfig(name: str) -> dict[str, Any]: + """Build an SSM registry-compatible spec from HSSM's modelconfig system. + + This allows any built-in HSSM model with an ``approx_differentiable`` + likelihood to be used as an RLSSM decision process without re-registering + it in ``_SSM_REGISTRY``. Parameter defaults are computed as midpoints of + the model's parameter bounds. + + Raises + ------ + ValueError + If *name* is not a supported HSSM model or it has no + ``approx_differentiable`` likelihood. + """ + # Local import to avoid circular dependencies at module level. + from hssm.modelconfig import get_default_model_config # noqa: PLC0415 + + try: + mc = get_default_model_config(name) # type: ignore[arg-type] + except ValueError as exc: + raise ValueError( + f"Decision process '{name}' is not a registered custom SSM and is not a " + "supported HSSM model. " + f"Custom SSMs in registry: {list(_SSM_REGISTRY.keys())}. " + "Use register_ssm() to add a custom decision process." + ) from exc + + ad = mc["likelihoods"].get("approx_differentiable") + if ad is None: + raise ValueError( + f"Model '{name}' has no approx_differentiable likelihood and cannot be " + "used as an RLSSM decision process." + ) + + list_params_ssm: list[str] = list(mc["list_params"]) + bounds_ssm: dict[str, tuple[float, float]] = dict(ad["bounds"]) + response: list[str] = list(mc["response"]) + onnx_file: str = str(ad["loglik"]) + + # Derive parameter defaults as midpoints of their respective bounds. + params_default_ssm = [ + sum(bounds_ssm[p]) / 2 if p in bounds_ssm else 0.0 for p in list_params_ssm + ] + + # Capture loop variables explicitly to avoid closure-over-variable issues. + def _factory( + _onnx_file: str = onnx_file, + _params: list[str] = list_params_ssm, + _response: list[str] = response, + ) -> Any: + return _make_ssm_base_logp_from_onnx(_onnx_file, _params, _response) + + return { + "ssm_base_logp_func_factory": _factory, + "list_params_ssm": list_params_ssm, + "bounds_ssm": bounds_ssm, + "params_default_ssm": params_default_ssm, + "response": response, + "name": name, + } + + +def _get_decision_process_spec( + decision_process: str | dict[str, Any], +) -> dict[str, Any]: + """Return a defensive copy of a decision-process specification. + + Custom SSMs (registered via :func:`register_ssm`) take precedence. For + everything else the spec is derived on the fly from HSSM's modelconfig + system, meaning any built-in model with an ``approx_differentiable`` + likelihood (e.g. ``"ddm"``, ``"angle"``, ``"weibull"``) works out of the + box without explicit registration. + """ + if isinstance(decision_process, dict): + return deepcopy(decision_process) + + # Custom registry takes precedence over modelconfig. + if decision_process in _SSM_REGISTRY: + spec = deepcopy(_SSM_REGISTRY[decision_process]) + spec["name"] = decision_process + return spec + + # Fall back to HSSM's modelconfig for built-in SSMs. + return _build_ssm_spec_from_modelconfig(decision_process) + + +def _get_ssm_logp(name: str) -> Any: + """Return the annotated SSM base logp function, building it on first use. + + ONNX models are downloaded / loaded only when first called (lazy + initialisation). Subsequent calls return the cached object. + """ + if name in _SSM_LOGP_CACHE: + return _SSM_LOGP_CACHE[name] + + if name not in _SSM_REGISTRY: + # Build from HSSM's modelconfig for built-in SSMs. + spec = _build_ssm_spec_from_modelconfig(name) + _SSM_LOGP_CACHE[name] = spec["ssm_base_logp_func_factory"]() + return _SSM_LOGP_CACHE[name] + + entry = _SSM_REGISTRY[name] + if "ssm_base_logp_func_factory" in entry: + _SSM_LOGP_CACHE[name] = entry["ssm_base_logp_func_factory"]() + else: + # Pre-built function registered via register_ssm(). + _SSM_LOGP_CACHE[name] = entry["ssm_base_logp_func"] + return _SSM_LOGP_CACHE[name] + + +# --------------------------------------------------------------------------- +# RLSSM named model registry +# --------------------------------------------------------------------------- +# Each entry provides: +# decision_process - key into _SSM_REGISTRY +# learning_process - {param: annotated_func} +# learning_process_params - ordered list of sampled RL parameter names +# learning_process_bounds - {param: (lo, hi)} for RL params +# learning_process_params_default +# - default values aligned with learning_process_params +# extra_fields - extra data column names required by LP +# choices - response choice values +# description - human-readable description +# decision_process_loglik_kind +# learning_process_kind + +_RLSSM_REGISTRY: dict[str, dict[str, Any]] = { + "2AB_RescorlaWagner_DDM": { + "decision_process": _get_decision_process_spec("ddm"), + "learning_process": {"v": _compute_v_annotated}, + "learning_process_params": ["rl_alpha", "scaler"], + "learning_process_bounds": { + "rl_alpha": (0.0, 1.0), + "scaler": (0.0, 10.0), + }, + "learning_process_params_default": [0.1, 1.0], + "extra_fields": ["feedback"], + "choices": [0, 1], + "description": ( + "RLSSM model with Rescorla-Wagner Q-learning and the " + "standard DDM as decision process." + ), + "decision_process_loglik_kind": "approx_differentiable", + "learning_process_kind": "blackbox", + }, + "2AB_RescorlaWagner_Angle": { + "decision_process": _get_decision_process_spec("angle"), + "learning_process": {"v": _compute_v_annotated}, + "learning_process_params": ["rl_alpha", "scaler"], + "learning_process_bounds": { + "rl_alpha": (0.0, 1.0), + "scaler": (0.0, 10.0), + }, + "learning_process_params_default": [0.1, 1.0], + "extra_fields": ["feedback"], + "choices": [0, 1], + "description": ( + "RLSSM model with Rescorla-Wagner Q-learning and a " + "collapsing-bound DDM (angle model) as decision process." + ), + "decision_process_loglik_kind": "approx_differentiable", + "learning_process_kind": "blackbox", + }, + "2AB_RescorlaWagner_Weibull": { + "decision_process": _get_decision_process_spec("weibull"), + "learning_process": {"v": _compute_v_annotated}, + "learning_process_params": ["rl_alpha", "scaler"], + "learning_process_bounds": { + "rl_alpha": (0.0, 1.0), + "scaler": (0.0, 10.0), + }, + "learning_process_params_default": [0.1, 1.0], + "extra_fields": ["feedback"], + "choices": [0, 1], + "description": ( + "RLSSM model with Rescorla-Wagner Q-learning and a " + "Weibull-bound DDM as decision process." + ), + "decision_process_loglik_kind": "approx_differentiable", + "learning_process_kind": "blackbox", + }, +} + + +# --------------------------------------------------------------------------- +# Internal helpers +# --------------------------------------------------------------------------- + + +def _build_ssm_logp_func(ssm_base_logp_func: Any, learning_process: dict) -> Any: + """Re-annotate *ssm_base_logp_func* adding ``computed=learning_process``. + + Creates a new wrapper that carries the same ``.inputs`` and ``.outputs`` as + the base function but adds the ``computed`` dict so that + :func:`~hssm.rl.likelihoods.builder.make_rl_logp_op` can resolve which + parameters are produced by the RL learning rule at runtime. + """ + existing = getattr(ssm_base_logp_func, "computed", None) + if existing: + raise ValueError( + "ssm_base_logp_func already has a non-empty .computed attribute. " + "Pass the raw (base) annotated function without .computed instead." + ) + return annotate_function( + inputs=ssm_base_logp_func.inputs, + outputs=ssm_base_logp_func.outputs, + computed=learning_process, + )(ssm_base_logp_func) + + +def _derive_lp_params( + learning_process: dict[str, Any], + response: list[str], + extra_fields: list[str], +) -> list[str]: + """Return sampled RL parameter names inferred from *learning_process*. + + Iterates over each LP function's ``.inputs`` and collects names that are + neither response columns nor extra fields. + """ + exclude = set(response) | set(extra_fields) + learning_process_params: list[str] = [] + seen: set[str] = set() + for param_name, lp_func in learning_process.items(): + if not hasattr(lp_func, "inputs"): + _logger.warning( + "Learning process function for '%s' has no .inputs attribute; " + "cannot derive RL parameters from it. " + "Ensure it is decorated with @annotate_function.", + param_name, + ) + continue + for inp in lp_func.inputs: + if inp not in exclude and inp not in seen: + learning_process_params.append(inp) + seen.add(inp) + return learning_process_params + + +# --------------------------------------------------------------------------- +# Public factory +# --------------------------------------------------------------------------- + + +def get_rlssm_model_config( + model: str = "2AB_RescorlaWagner_DDM", + choices: list[int] | None = None, + learning_process: dict[str, Any] | None = None, + decision_process: str | None = None, +) -> RLSSMConfig: + """Build an :class:`~hssm.rl.config.RLSSMConfig` from a named model. + + Parameters + ---------- + model: + Name of a registered RLSSM model (e.g. ``"2AB_RescorlaWagner_DDM"``). + choices: + Override the response choice values stored in the registry. + learning_process: + Override the learning process dict stored in the registry. + decision_process: + Override the SSM name stored in the registry. + + Returns + ------- + RLSSMConfig + Fully populated configuration ready to be passed to :class:`_RLSSM`. + + Raises + ------ + ValueError + If *model* or the resolved *decision_process* is not registered. + """ + if model not in _RLSSM_REGISTRY: + available = list(_RLSSM_REGISTRY.keys()) + raise ValueError( + f"Model '{model}' not found in the RLSSM registry. " + f"Available models: {available}. " + "To add a custom model, use register_rlssm_model() " + "(and register_ssm() for custom decision processes), " + "or pass 'model_config=' directly to RLSSM()." + ) + + # Shallow-copy so overrides don't mutate the registry entry. + entry = dict(_RLSSM_REGISTRY[model]) + + if learning_process is not None: + entry["learning_process"] = learning_process + if decision_process is not None: + entry["decision_process"] = _get_decision_process_spec(decision_process) + if choices is not None: + entry["choices"] = choices + + ssm_entry = _get_decision_process_spec(entry["decision_process"]) + dp: str = ssm_entry["name"] + ssm_base = _get_ssm_logp(dp) + # Defensive copy so callers mutating config.learning_process don't corrupt + # the registry entry (entry is only a shallow copy of _RLSSM_REGISTRY[model]). + lp: dict[str, Any] = dict(entry["learning_process"]) + + # Compose the full ssm_logp_func with .computed = learning_process. + ssm_logp_func = _build_ssm_logp_func(ssm_base, lp) + + # list_params = [sampled RL params] + [sampled SSM params (non-computed)] + ssm_sampled = [p for p in ssm_entry["list_params_ssm"] if p not in lp] + + # Defensive copy of response to prevent downstream mutation of registry. + response = list(ssm_entry["response"]) + + # Use `is None` checks so that explicitly empty containers ([], {}) are + # respected as valid "no RL params" configuration and not overridden by + # the fallback derivation logic. + _rl_params = entry.get("learning_process_params") + learning_process_params: list[str] = ( + _derive_lp_params(lp, response, entry.get("extra_fields") or []) + if _rl_params is None + else _rl_params + ) + list_params = learning_process_params + ssm_sampled + + # bounds: RL bounds ∪ SSM sampled bounds + missing_bounds = [p for p in ssm_sampled if p not in ssm_entry["bounds_ssm"]] + if missing_bounds: + raise ValueError( + f"SSM parameter(s) {missing_bounds} are included in list_params but have " + f"no entry in bounds_ssm for decision process '{dp}'. " + "Provide bounds for all sampled parameters via register_ssm() or ensure " + "the built-in modelconfig includes them." + ) + _rl_bounds = entry.get("learning_process_bounds") + bounds: dict[str, tuple[float, float]] = dict( + _rl_bounds if _rl_bounds is not None else {} + ) + for p in ssm_sampled: + bounds[p] = ssm_entry["bounds_ssm"][p] + + # params_default aligned with list_params + _rl_defaults = entry.get("learning_process_params_default") + rl_defaults: list[float] = list(_rl_defaults if _rl_defaults is not None else []) + ssm_all_defaults: list[float] = list(ssm_entry["params_default_ssm"]) + ssm_sampled_defaults = [ + ssm_all_defaults[i] + for i, p in enumerate(ssm_entry["list_params_ssm"]) + if p not in lp + ] + params_default = rl_defaults + ssm_sampled_defaults + + return RLSSMConfig( + model_name=entry.get("model_name", model), + description=entry.get("description"), + decision_process=dp, + decision_process_loglik_kind=entry["decision_process_loglik_kind"], + learning_process_kind=entry["learning_process_kind"], + learning_process=lp, + ssm_logp_func=ssm_logp_func, + list_params=list_params, + bounds=bounds, + params_default=params_default, + response=response, + choices=tuple(entry["choices"]), + extra_fields=list(entry["extra_fields"]) + if entry.get("extra_fields") is not None + else None, + ) + + +# --------------------------------------------------------------------------- +# Public query helpers +# --------------------------------------------------------------------------- + + +def list_models() -> dict[str, str | None]: + """Return the names and descriptions of all registered RLSSM models. + + This is the recommended starting point for new users who want to discover + which models are available out of the box. + + Returns + ------- + dict[str, str | None] + Mapping of model name → description string (or ``None`` if no + description was provided at registration time). + + Examples + -------- + >>> import hssm + >>> hssm.rl.list_models() + {'2AB_RescorlaWagner_DDM': 'RLSSM model with Rescorla-Wagner ...', ...} + """ + return {name: entry.get("description") for name, entry in _RLSSM_REGISTRY.items()} + + +# --------------------------------------------------------------------------- +# Public registration helpers +# --------------------------------------------------------------------------- + + +def register_rlssm_model( + name: str, + decision_process: str, + learning_process: dict[str, Any], + learning_process_params: list[str], + learning_process_bounds: dict[str, tuple[float, float]], + learning_process_params_default: list[float], + extra_fields: list[str] | None = None, + choices: list[int] | None = None, + description: str | None = None, + decision_process_loglik_kind: str = "approx_differentiable", + learning_process_kind: str = "blackbox", +) -> None: + """Register a named RLSSM model in the global registry. + + Parameters + ---------- + name: + Registry key (e.g. ``"my_rldm"``). + decision_process: + Name of the SSM to use. This may be either a custom SSM already registered + in the SSM registry via :func:`register_ssm`, or a built-in HSSM modelconfig SSM + name such as ``"ddm"``, ``"angle"``, or ``"weibull"``. + learning_process: + Dict mapping computed parameter name → annotated learning function. + learning_process_params: + Ordered list of sampled RL parameter names. + learning_process_bounds: + Parameter bounds for the RL parameters. + learning_process_params_default: + Default values aligned with *learning_process_params*. + extra_fields: + Data column names required by the learning process (e.g. ``["feedback"]``). + choices: + Response choice values. Defaults to ``[0, 1]``. + description: + Optional human-readable description. + decision_process_loglik_kind: + Loglik kind tag. Defaults to ``"approx_differentiable"``. + learning_process_kind: + Learning process kind tag. Defaults to ``"blackbox"``. + """ + if name in _RLSSM_REGISTRY: + _logger.warning( + "Model '%s' is already in the RLSSM registry and will be overwritten.", + name, + ) + _RLSSM_REGISTRY[name] = { + "decision_process": _get_decision_process_spec(decision_process), + # Shallow-copy all mutable caller-supplied collections so that later + # mutations of the originals do not silently corrupt the registry entry. + "learning_process": dict(learning_process), + "learning_process_params": list(learning_process_params), + "learning_process_bounds": dict(learning_process_bounds), + "learning_process_params_default": list(learning_process_params_default), + "extra_fields": list(extra_fields) if extra_fields is not None else [], + "choices": list(choices) if choices is not None else [0, 1], + "description": description, + "decision_process_loglik_kind": decision_process_loglik_kind, + "learning_process_kind": learning_process_kind, + } + + +def register_ssm( + name: str, + ssm_base_logp_func: Any, + list_params_ssm: list[str], + bounds_ssm: dict[str, tuple[float, float]], + params_default_ssm: list[float], + response: list[str] | None = None, +) -> None: + """Register an SSM base log-likelihood function in the SSM registry. + + Parameters + ---------- + name: + Registry key (e.g. ``"ddm"``). + ssm_base_logp_func: + An annotated JAX function (created with ``@annotate_function``) that + computes the SSM log-likelihood from a parameter matrix. Must carry + ``.inputs`` and ``.outputs`` attributes but should **not** have a + ``.computed`` key — that is injected by the factory at config-build time. + list_params_ssm: + Ordered list of all SSM parameter names (including any that will be + computed by the learning process). + bounds_ssm: + Bounds for the non-computed SSM parameters. + params_default_ssm: + Default values aligned with *list_params_ssm*. + response: + Data column names. Defaults to ``["rt", "response"]``. + """ + if not callable(ssm_base_logp_func): + raise ValueError( + f"ssm_base_logp_func must be callable, got {type(ssm_base_logp_func)!r}." + ) + if not hasattr(ssm_base_logp_func, "inputs") or not hasattr( + ssm_base_logp_func, "outputs" + ): + raise ValueError( + "ssm_base_logp_func must be decorated with @annotate_function " + "(missing .inputs or .outputs attribute)." + ) + existing_computed = getattr(ssm_base_logp_func, "computed", None) + if existing_computed: + raise ValueError( + "ssm_base_logp_func should not have a non-empty .computed attribute " + "at registration time. The .computed dict is injected later by " + "get_rlssm_model_config() when composing the learning process. " + "Pass the raw base function instead." + ) + if name in _SSM_REGISTRY: + _logger.warning( + "SSM '%s' is already in the SSM registry and will be overwritten.", name + ) + _SSM_REGISTRY[name] = { + "ssm_base_logp_func": ssm_base_logp_func, + "list_params_ssm": list(list_params_ssm), + "bounds_ssm": dict(bounds_ssm), + "params_default_ssm": list(params_default_ssm), + "response": list(response) if response is not None else ["rt", "response"], + } + # Pre-built: cache immediately so _get_ssm_logp never calls a factory. + _SSM_LOGP_CACHE[name] = ssm_base_logp_func diff --git a/src/hssm/rl/rlssm.py b/src/hssm/rl/rlssm.py index 180cd12c8..bdb12413a 100644 --- a/src/hssm/rl/rlssm.py +++ b/src/hssm/rl/rlssm.py @@ -1,14 +1,19 @@ """RLSSM: Reinforcement Learning Sequential Sampling Model. -This module defines the :class:`RLSSM` class, a subclass of :class:`HSSMBase` -for models that couple a reinforcement learning (RL) learning process with a -sequential sampling decision model (SSM). +This module defines: -The key difference from :class:`HSSM` is the likelihood: +- :class:`_RLSSM` — the internal base class (previously ``RLSSM``) that requires + a fully populated :class:`~hssm.rl.config.RLSSMConfig` to be passed directly. +- :class:`RLSSM` — the public-facing subclass with a simplified constructor that + accepts a *model* name string, optional *learning_process* / *decision_process* + overrides, and an optional *model_config* override. Config construction is + delegated to :func:`~hssm.rl.registry.get_rlssm_model_config`. + +The key difference from :class:`~hssm.hssm.HSSM` is the likelihood: - ``HSSM`` wraps an analytical / ONNX / blackbox callable via :func:`~hssm.distribution_utils.make_likelihood_callable`. - - ``RLSSM`` builds a differentiable pytensor ``Op`` directly from an - :class:`~hssm.rl.likelihoods.builder.AnnotatedFunction` via + - ``_RLSSM`` / ``RLSSM`` build a differentiable pytensor ``Op`` directly from + an :class:`~hssm.rl.likelihoods.builder.AnnotatedFunction` via :func:`~hssm.rl.likelihoods.builder.make_rl_logp_op`, which internally handles the RL learning rule and per-participant trial structure. This Op is then passed straight to @@ -16,6 +21,7 @@ standard ``loglik`` / ``loglik_kind`` wrapping pipeline. """ +import logging from dataclasses import replace from typing import TYPE_CHECKING, Any, Callable, Literal, cast @@ -34,12 +40,19 @@ from hssm.rl.likelihoods.builder import make_rl_logp_op from hssm.rl.utils import validate_balanced_panel -from ..base import HSSMBase +from ..base import HSSMBase, classproperty from .config import RLSSMConfig +from .registry import get_rlssm_model_config + +_logger = logging.getLogger("hssm") + +class _RLSSM(HSSMBase): + """Internal Reinforcement Learning Sequential Sampling Model. -class RLSSM(HSSMBase): - """Reinforcement Learning Sequential Sampling Model. + Requires a fully populated :class:`RLSSMConfig` (with ``ssm_logp_func`` set) + to be passed directly. End users should use :class:`RLSSM` instead, which + provides a simplified interface backed by the named-model registry. Combines a reinforcement learning (RL) process with a sequential sampling model (SSM) inside a single differentiable likelihood. The RL component @@ -158,7 +171,6 @@ def __init__( # Store RL-specific state on self BEFORE super().__init__() so that # _make_model_distribution() (called from super) can access them. - self.config = model_config self.n_participants = n_participants self.n_trials = n_trials @@ -189,6 +201,8 @@ def __init__( # _make_model_distribution for details. model_config = replace(model_config, loglik=loglik_op, backend="jax") + # missing_data and deadline are guaranteed False at this point (guards + # above reject any other value). Pass them explicitly for clarity. super().__init__( data=data, model_config=model_config, @@ -198,8 +212,8 @@ def __init__( link_settings=link_settings, prior_settings=prior_settings, extra_namespace=extra_namespace, - missing_data=missing_data, - deadline=deadline, + missing_data=False, + deadline=False, loglik_missing_data=loglik_missing_data, process_initvals=process_initvals, initval_jitter=initval_jitter, @@ -228,10 +242,12 @@ def _make_model_distribution(self) -> type[pm.Distribution]: # has_lapse=True) rather than self.model_config.list_params (the original # config list, never mutated by HSSMBase). list_params = self.list_params - assert list_params is not None, "list_params must be set" - assert isinstance(list_params, list), ( - "list_params must be a list" - ) # for type checker + if list_params is None: + raise RuntimeError( + "list_params must be set before _make_model_distribution is called." + ) + if not isinstance(list_params, list): + raise TypeError(f"list_params must be a list, got {type(list_params)!r}.") # Every RLSSM distribution parameter is trialwise (the Op receives one # value per trial). p_outlier is excluded to match the contract of @@ -247,7 +263,11 @@ def _make_model_distribution(self) -> type[pm.Distribution]: # The differentiable pytensor Op was stored on model_config.loglik during # __init__; ensure it's present and cast for typing. - assert self.model_config.loglik is not None, "model_config.loglik must be set" + if self.model_config.loglik is None: + raise RuntimeError( + "model_config.loglik must be set before _make_model_distribution " + "is called. This indicates the Op was not built in __init__." + ) loglik_op = cast("Callable[..., Any] | Op", self.model_config.loglik) # RLSSMConfig carries no `rv` field; use model_name as the rv identifier. @@ -262,3 +282,211 @@ def _make_model_distribution(self) -> type[pm.Distribution]: extra_fields=extra_fields_data, params_is_trialwise=params_is_trialwise, ) + + +class _BlockedAttribute: + """Data descriptor that blocks read access with NotImplementedError. + + During initialisation, writes are stored and reads return the stored + value so that :meth:`MissingDataMixin._process_missing_data_and_deadline` + (which both writes and reads ``self.missing_data`` / ``self.deadline`` / + ``self.loglik_missing_data``) can complete without error. + + Once ``instance.__dict__['_rlssm_fully_initialized']`` is set to ``True`` + at the end of :meth:`RLSSM.__init__`, any read raises + :exc:`NotImplementedError`. + + Using a data descriptor (one with both ``__get__`` and ``__set__``) is + necessary because data descriptors take priority over instance ``__dict__`` + entries, so the descriptor's ``__get__`` fires even after a write. + """ + + def __init__(self, name: str, message: str) -> None: + self._name = name + self._message = message + self._storage_key = f"_ba_{name}" + + def __set_name__(self, owner: type, name: str) -> None: # noqa: D105 + self._name = name + self._storage_key = f"_ba_{name}" + + def __get__(self, obj: Any, objtype: type | None = None) -> Any: # noqa: D105 + if obj is None: + # Class-level access — return the descriptor itself. + return self + if not obj.__dict__.get("_rlssm_fully_initialized", False): + # During __init__: return the stored value so internal code works. + return obj.__dict__.get(self._storage_key, False) + raise NotImplementedError(self._message) + + def __set__(self, obj: Any, value: Any) -> None: # noqa: D105 + # Store the value so internal reads during __init__ work correctly. + obj.__dict__[self._storage_key] = value + + +class RLSSM(_RLSSM): + """Reinforcement Learning Sequential Sampling Model — simplified public API. + + This class wraps :class:`_RLSSM` with a user-friendly constructor that + accepts a *model* name string (looked up in the named-model registry) and + optional overrides for *learning_process*, *decision_process*, and + *choices*. Advanced users can bypass the registry entirely by supplying a + pre-built *model_config*. + + ``missing_data``, ``deadline``, and ``loglik_missing_data`` are not + supported for RLSSM models and raise :exc:`NotImplementedError` if accessed. + + Parameters + ---------- + data : pd.DataFrame + Trial-level data (balanced panel required). + model : str | None, optional + Name of a registered RLSSM model. Defaults to ``"2AB_RescorlaWagner_DDM"``. + choices : list[int] | None, optional + Override the choice values in the registry. ``None`` uses the registry + default. + include : list | None, optional + Parameter specifications forwarded to :class:`~hssm.base.HSSMBase`. + model_config : RLSSMConfig | None, optional + Fully built config. When provided, *model*, + *learning_process*, *decision_process*, and *choices* are ignored + (a warning is emitted if they are non-default). + learning_process : dict | None, optional + Override the learning-process dict in the registry. ``None`` uses the + registry default. + decision_process : str | None, optional + Override the SSM name in the registry. ``None`` uses the registry + default. + participant_col : str, optional + Column identifying participants. Defaults to ``"participant_id"``. + p_outlier : float | dict | bmb.Prior | None, optional + Lapse probability. Defaults to ``0.05``. + lapse : dict | bmb.Prior | None, optional + Lapse distribution. Defaults to ``Uniform(0, 20)``. + link_settings : Literal["log_logit"] | None, optional + Link-function preset. Defaults to ``None``. + prior_settings : Literal["safe"] | None, optional + Prior preset. Defaults to ``"safe"``. + extra_namespace : dict | None, optional + Extra variables for formula evaluation. Defaults to ``None``. + process_initvals : bool, optional + Whether to post-process initial values. Defaults to ``True``. + initval_jitter : float, optional + Jitter magnitude for initial values. + **kwargs + Additional keyword arguments forwarded to :class:`bmb.Model`. + """ + + # Block read access to the three missing-data attributes while silently + # accepting any writes made by the base-class initialisation path. + missing_data = _BlockedAttribute( # type: ignore[assignment] + "missing_data", + "RLSSM does not support 'missing_data'. " + "The RL log-likelihood Op relies on strict row order; rearranging rows " + "for missing RT values would corrupt the RL learning dynamics. " + "Please remove missing trials from the data before passing it to RLSSM.", + ) + deadline = _BlockedAttribute( # type: ignore[assignment] + "deadline", + "RLSSM does not support 'deadline'. " + "The RL log-likelihood Op relies on strict row order; rearranging rows " + "for deadline trials would corrupt the RL learning dynamics. " + "Please remove deadline trials from the data before passing it to RLSSM.", + ) + loglik_missing_data = _BlockedAttribute( # type: ignore[assignment] + "loglik_missing_data", + "RLSSM does not support 'loglik_missing_data'. " + "Missing-data network assembly (OPN / CPN) is not implemented for RLSSM.", + ) + + def __init__( + self, + data: pd.DataFrame, + model: str | None = None, + choices: list[int] | None = None, + include: list[dict[str, Any] | Any] | None = None, + model_config: RLSSMConfig | None = None, + learning_process: dict[str, Any] | None = None, + decision_process: str | None = None, + participant_col: str = "participant_id", + p_outlier: float | dict | bmb.Prior | None = 0.05, + lapse: dict | bmb.Prior | None = bmb.Prior("Uniform", lower=0.0, upper=20.0), + link_settings: Literal["log_logit"] | None = None, + prior_settings: Literal["safe"] | None = "safe", + extra_namespace: dict[str, Any] | None = None, + process_initvals: bool = True, + initval_jitter: float = INITVAL_JITTER_SETTINGS["jitter_epsilon"], + **kwargs: Any, + ) -> None: + # Capture simplified args BEFORE calling super so they can be + # restored afterwards for save/load serialisation. + # NOTE: _store_init_args only operates on its arguments, not on self. + _my_init_args = self._store_init_args(locals(), kwargs) + + if model_config is not None and any( + x is not None for x in [model, learning_process, decision_process, choices] + ): + _logger.warning( + "model_config was provided; ignoring model, learning_process, " + "decision_process, and choices arguments." + ) + + model_config = model_config or get_rlssm_model_config( + model=model or "2AB_RescorlaWagner_DDM", + choices=choices, + learning_process=learning_process, + decision_process=decision_process, + ) + + # missing_data / deadline are intentionally omitted — _RLSSM defaults + # them to False. The _BlockedAttribute descriptors on this class + # silently accept writes from the base-class init path and block reads + # after _rlssm_fully_initialized is set below. + super().__init__( + data=data, + model_config=model_config, + participant_col=participant_col, + include=include, + p_outlier=p_outlier, + lapse=lapse, + link_settings=link_settings, + prior_settings=prior_settings, + extra_namespace=extra_namespace, + process_initvals=process_initvals, + initval_jitter=initval_jitter, + **kwargs, + ) + + # Restore the simplified constructor args so that save/load round-trips + # reconstruct the model via RLSSM(model=...) rather than + # _RLSSM(model_config=...). + self._init_args = _my_init_args + + # Mark initialisation complete — after this point _BlockedAttribute + # raises NotImplementedError on any read of missing_data / deadline / + # loglik_missing_data. + # IMPORTANT: This MUST be the last statement in __init__. Any code + # added after this line cannot read self.missing_data / self.deadline. + self.__dict__["_rlssm_fully_initialized"] = True + + @classproperty + def list_models(cls) -> dict[str, str | None]: + """All registered RLSSM models and their descriptions. + + This is the recommended entry point for newcomers to discover which + models are available without constructing a full model instance. + + Returns + ------- + dict[str, str | None] + Mapping of model name → description (``None`` if not provided). + + Examples + -------- + >>> from hssm.rl import RLSSM + >>> RLSSM.list_models + {'2AB_RescorlaWagner_DDM': 'RLSSM model with ...', ...} + """ + from .registry import list_models # noqa: PLC0415 + + return list_models() diff --git a/tests/rl/test_registry.py b/tests/rl/test_registry.py new file mode 100644 index 000000000..f8bbcf293 --- /dev/null +++ b/tests/rl/test_registry.py @@ -0,0 +1,594 @@ +"""Unit tests for the RL registry helpers. + +These tests exercise the registry module directly without constructing full +RLSSM model instances, so regressions in lazy SSM resolution, config +composition, and registration validation are caught at the module boundary. +""" + +from __future__ import annotations + +import logging +from copy import deepcopy +from typing import Any + +import pytest + +from hssm.rl import registry +from hssm.rl.config import RLSSMConfig +from hssm.utils import annotate_function + + +@pytest.fixture(autouse=True) +def isolated_registries(monkeypatch: pytest.MonkeyPatch) -> None: + """Isolate global registries so tests do not leak state.""" + monkeypatch.setattr(registry, "_SSM_REGISTRY", deepcopy(registry._SSM_REGISTRY)) + monkeypatch.setattr(registry, "_RLSSM_REGISTRY", deepcopy(registry._RLSSM_REGISTRY)) + monkeypatch.setattr(registry, "_SSM_LOGP_CACHE", dict(registry._SSM_LOGP_CACHE)) + + +@pytest.fixture +def learning_process() -> dict[str, Any]: + """Return an annotated RL learning rule for test models.""" + + @annotate_function( + inputs=["rl_alpha", "response", "feedback"], + outputs=["v"], + ) + def compute_v(rl_alpha, response, feedback): + return rl_alpha + + return {"v": compute_v} + + +@pytest.fixture +def annotated_ssm_base_logp() -> Any: + """Return an annotated SSM base log-likelihood function.""" + + @annotate_function( + inputs=["v", "a", "rt", "response"], + outputs=["logp"], + ) + def base_logp(v, a, rt, response): + return a + + return base_logp + + +class TestBuildSsmSpecFromModelconfig: + def test_unknown_model_raises(self) -> None: + """Completely unknown names should raise ValueError from the modelconfig bridge.""" + with pytest.raises(ValueError, match="not a registered custom SSM"): + registry._build_ssm_spec_from_modelconfig("totally_unknown_model") + + def test_no_approx_differentiable_raises( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + """Models without an approx_differentiable likelihood must be rejected.""" + import hssm.modelconfig as mc_module + + def _fake_get(name): # type: ignore[no-untyped-def] + return { + "list_params": ["v", "a"], + "response": ["rt", "response"], + "likelihoods": { + "analytical": { + "loglik": lambda: None, + "bounds": {}, + "backend": None, + }, + }, + } + + monkeypatch.setattr(mc_module, "get_default_model_config", _fake_get) + with pytest.raises(ValueError, match="no approx_differentiable likelihood"): + registry._build_ssm_spec_from_modelconfig("ddm") + + def test_factory_calls_onnx_loader( + self, + monkeypatch: pytest.MonkeyPatch, + annotated_ssm_base_logp: Any, + ) -> None: + """The lazy factory should call make_jax_matrix_logp_funcs_from_onnx with the + correct filename.""" + import hssm.distribution_utils.onnx as onnx_module + + called_with: list[str] = [] + + def _fake_onnx(model: str) -> Any: + called_with.append(model) + return annotated_ssm_base_logp + + monkeypatch.setattr( + onnx_module, "make_jax_matrix_logp_funcs_from_onnx", _fake_onnx + ) + monkeypatch.setattr( + registry, "make_jax_matrix_logp_funcs_from_onnx", _fake_onnx + ) + + spec = registry._build_ssm_spec_from_modelconfig("angle") + result = spec["ssm_base_logp_func_factory"]() + + assert called_with == ["angle.onnx"] + assert callable(result) + assert result.inputs == ["v", "a", "z", "t", "theta", "rt", "response"] + assert result.outputs == ["logp"] + + +class TestGetSsmLogp: + def test_builds_lazy_factory_once(self, annotated_ssm_base_logp: Any) -> None: + """Lazy SSM factories should only build and cache one function instance.""" + call_count = 0 + + def factory() -> Any: + nonlocal call_count + call_count += 1 + return annotated_ssm_base_logp + + registry._SSM_REGISTRY["lazy_unit_test_ssm"] = { + "ssm_base_logp_func_factory": factory, + "list_params_ssm": ["v", "a"], + "bounds_ssm": {"a": (0.3, 3.0)}, + "params_default_ssm": [0.0, 1.5], + "response": ["rt", "response"], + } + + first = registry._get_ssm_logp("lazy_unit_test_ssm") + second = registry._get_ssm_logp("lazy_unit_test_ssm") + + assert first is annotated_ssm_base_logp + assert second is first + assert call_count == 1 + + def test_resolves_builtin_via_modelconfig( + self, + monkeypatch: pytest.MonkeyPatch, + annotated_ssm_base_logp: Any, + ) -> None: + """_get_ssm_logp should build and cache the logp for a built-in SSM.""" + import hssm.distribution_utils.onnx as onnx_module + + monkeypatch.setattr( + onnx_module, + "make_jax_matrix_logp_funcs_from_onnx", + lambda **_: annotated_ssm_base_logp, + ) + monkeypatch.setattr( + registry, + "make_jax_matrix_logp_funcs_from_onnx", + lambda **_: annotated_ssm_base_logp, + ) + + result = registry._get_ssm_logp("angle") + + assert callable(result) + assert "angle" in registry._SSM_LOGP_CACHE + # Second call returns cached value without re-building. + assert registry._get_ssm_logp("angle") is result + + +class TestBuildSsmLogpFunc: + def test_raises_if_already_computed( + self, + annotated_ssm_base_logp: Any, + learning_process: dict[str, Any], + ) -> None: + """_build_ssm_logp_func must refuse functions that already have .computed set.""" + precomputed = annotate_function( + inputs=annotated_ssm_base_logp.inputs, + outputs=annotated_ssm_base_logp.outputs, + computed=learning_process, + )(annotated_ssm_base_logp) + + with pytest.raises(ValueError, match="already has a non-empty .computed"): + registry._build_ssm_logp_func(precomputed, learning_process) + + +class TestDeriveRlParams: + def test_excludes_response_and_extra_fields( + self, + learning_process: dict[str, Any], + ) -> None: + """Derived RL params should ignore response columns and extra fields.""" + derived = registry._derive_lp_params( + learning_process=learning_process, + response=["rt", "response"], + extra_fields=["feedback"], + ) + + assert derived == ["rl_alpha"] + + def test_warns_for_unannotated_lp_func(self) -> None: + """A learning-process function without .inputs should log a warning and be skipped.""" + + def unannotated_func(x): # type: ignore[no-untyped-def] + return x + + lp = {"v": unannotated_func} + result = registry._derive_lp_params( + learning_process=lp, + response=["rt", "response"], + extra_fields=["feedback"], + ) + # The unannotated function contributes no params. + assert result == [] + + +class TestRegisterSsm: + def test_caches_prebuilt_function(self, annotated_ssm_base_logp: Any) -> None: + """Registering a pre-built SSM should populate the cache immediately.""" + registry.register_ssm( + name="cached_ssm", + ssm_base_logp_func=annotated_ssm_base_logp, + list_params_ssm=["v", "a"], + bounds_ssm={"a": (0.3, 3.0)}, + params_default_ssm=[0.0, 1.5], + ) + + assert registry._SSM_LOGP_CACHE["cached_ssm"] is annotated_ssm_base_logp + assert registry._SSM_REGISTRY["cached_ssm"]["response"] == ["rt", "response"] + + def test_rejects_precomputed_function( + self, + annotated_ssm_base_logp: Any, + learning_process: dict[str, Any], + ) -> None: + """SSM registration should reject functions that already carry computed params.""" + precomputed_logp = annotate_function( + inputs=annotated_ssm_base_logp.inputs, + outputs=annotated_ssm_base_logp.outputs, + computed=learning_process, + )(annotated_ssm_base_logp) + + with pytest.raises(ValueError, match="should not have a non-empty .computed"): + registry.register_ssm( + name="invalid_ssm", + ssm_base_logp_func=precomputed_logp, + list_params_ssm=["v", "a"], + bounds_ssm={"a": (0.3, 3.0)}, + params_default_ssm=[0.0, 1.5], + ) + + def test_rejects_non_callable(self) -> None: + """register_ssm must raise when ssm_base_logp_func is not callable.""" + with pytest.raises(ValueError, match="must be callable"): + registry.register_ssm( + name="bad_ssm", + ssm_base_logp_func="not_a_function", # type: ignore[arg-type] + list_params_ssm=["v"], + bounds_ssm={}, + params_default_ssm=[0.0], + ) + + def test_rejects_unannotated_callable(self) -> None: + """register_ssm must raise when the callable lacks .inputs or .outputs.""" + + def plain(x): # type: ignore[no-untyped-def] + return x + + with pytest.raises( + ValueError, match="must be decorated with @annotate_function" + ): + registry.register_ssm( + name="unannotated_ssm", + ssm_base_logp_func=plain, + list_params_ssm=["v"], + bounds_ssm={}, + params_default_ssm=[0.0], + ) + + def test_warns_on_overwrite( + self, + annotated_ssm_base_logp: Any, + caplog: pytest.LogCaptureFixture, + ) -> None: + """Re-registering an existing SSM name should emit a warning.""" + registry.register_ssm( + name="dup_ssm", + ssm_base_logp_func=annotated_ssm_base_logp, + list_params_ssm=["v", "a"], + bounds_ssm={"a": (0.3, 3.0)}, + params_default_ssm=[0.0, 1.5], + ) + + with caplog.at_level(logging.WARNING, logger="hssm"): + registry.register_ssm( + name="dup_ssm", + ssm_base_logp_func=annotated_ssm_base_logp, + list_params_ssm=["v", "a"], + bounds_ssm={"a": (0.3, 3.0)}, + params_default_ssm=[0.0, 1.5], + ) + + assert any("dup_ssm" in r.message for r in caplog.records) + + +class TestRegisterRlssmModel: + def test_copies_mutable_inputs( + self, + learning_process: dict[str, Any], + ) -> None: + """Caller mutations after registration must not alter the stored model.""" + learning_process_params = ["rl_alpha"] + learning_process_bounds = {"rl_alpha": (0.0, 1.0)} + rl_defaults = [0.2] + extra_fields = ["feedback"] + choices = [0, 1] + + registry.register_rlssm_model( + name="copy_test_model", + decision_process="angle", + learning_process=learning_process, + learning_process_params=learning_process_params, + learning_process_bounds=learning_process_bounds, + learning_process_params_default=rl_defaults, + extra_fields=extra_fields, + choices=choices, + ) + + learning_process_params.append("scaler") + learning_process_bounds["scaler"] = (0.0, 10.0) + rl_defaults.append(1.0) + extra_fields.append("trial") + choices.append(2) + # Built-in SSMs (e.g. "angle") are derived from modelconfig on each call + # and are never stored as shared mutable state in _SSM_REGISTRY, so there + # is no registry entry to corrupt here. + learning_process["other"] = next(iter(learning_process.values())) + + stored = registry._RLSSM_REGISTRY["copy_test_model"] + + assert stored["decision_process"]["name"] == "angle" + assert stored["decision_process"]["response"] == ["rt", "response"] + assert stored["learning_process_params"] == ["rl_alpha"] + assert stored["learning_process_bounds"] == {"rl_alpha": (0.0, 1.0)} + assert stored["learning_process_params_default"] == [0.2] + assert stored["extra_fields"] == ["feedback"] + assert stored["choices"] == [0, 1] + assert list(stored["learning_process"]) == ["v"] + + def test_warns_on_overwrite( + self, + learning_process: dict[str, Any], + annotated_ssm_base_logp: Any, + caplog: pytest.LogCaptureFixture, + ) -> None: + """Re-registering an existing RLSSM model name should emit a warning.""" + registry.register_ssm( + name="overwrite_ssm", + ssm_base_logp_func=annotated_ssm_base_logp, + list_params_ssm=["v", "a"], + bounds_ssm={"a": (0.3, 3.0)}, + params_default_ssm=[0.0, 1.5], + ) + registry.register_rlssm_model( + name="overwrite_rlssm", + decision_process="overwrite_ssm", + learning_process=learning_process, + learning_process_params=["rl_alpha"], + learning_process_bounds={"rl_alpha": (0.0, 1.0)}, + learning_process_params_default=[0.2], + ) + + with caplog.at_level(logging.WARNING, logger="hssm"): + registry.register_rlssm_model( + name="overwrite_rlssm", + decision_process="overwrite_ssm", + learning_process=learning_process, + learning_process_params=["rl_alpha"], + learning_process_bounds={"rl_alpha": (0.0, 1.0)}, + learning_process_params_default=[0.2], + ) + + assert any("overwrite_rlssm" in r.message for r in caplog.records) + + +class TestGetRlssmModelConfig: + def test_builds_expected_config( + self, + annotated_ssm_base_logp: Any, + learning_process: dict[str, Any], + ) -> None: + """Registry config composition should exclude computed SSM params.""" + registry.register_ssm( + name="unit_test_ssm", + ssm_base_logp_func=annotated_ssm_base_logp, + list_params_ssm=["v", "a"], + bounds_ssm={"a": (0.3, 3.0)}, + params_default_ssm=[0.0, 1.5], + response=["rt", "response"], + ) + registry.register_rlssm_model( + name="unit_test_model", + decision_process="unit_test_ssm", + learning_process=learning_process, + learning_process_params=["rl_alpha"], + learning_process_bounds={"rl_alpha": (0.0, 1.0)}, + learning_process_params_default=[0.2], + extra_fields=["feedback"], + choices=[0, 1], + ) + + config = registry.get_rlssm_model_config("unit_test_model") + + assert isinstance(config, RLSSMConfig) + assert config.list_params == ["rl_alpha", "a"] + assert config.bounds == {"rl_alpha": (0.0, 1.0), "a": (0.3, 3.0)} + assert config.params_default == [0.2, 1.5] + assert config.response == ["rt", "response"] + assert config.ssm_logp_func.computed == learning_process + + # Mutating the returned config must not corrupt the registry's stored list. + config.response.append("mutated") + assert registry._SSM_REGISTRY["unit_test_ssm"]["response"] == ["rt", "response"] + + def test_respects_explicit_empty_rl_fields( + self, + annotated_ssm_base_logp: Any, + learning_process: dict[str, Any], + ) -> None: + """Explicit empty RL collections must not trigger fallback derivation.""" + registry.register_ssm( + name="empty_rl_ssm", + ssm_base_logp_func=annotated_ssm_base_logp, + list_params_ssm=["v", "a"], + bounds_ssm={"a": (0.3, 3.0)}, + params_default_ssm=[0.0, 1.5], + response=["rt", "response"], + ) + registry._RLSSM_REGISTRY["empty_rl_model"] = { + "decision_process": "empty_rl_ssm", + "learning_process": learning_process, + "learning_process_params": [], + "learning_process_bounds": {}, + "learning_process_params_default": [], + "extra_fields": ["feedback"], + "choices": [0, 1], + "description": "test model", + "decision_process_loglik_kind": "approx_differentiable", + "learning_process_kind": "blackbox", + } + + config = registry.get_rlssm_model_config("empty_rl_model") + + assert config.list_params == ["a"] + assert config.bounds == {"a": (0.3, 3.0)} + assert config.params_default == [1.5] + + def test_unknown_model_raises(self) -> None: + """Unknown RLSSM model names should fail with a clear error.""" + with pytest.raises(ValueError, match="not found in the RLSSM registry"): + registry.get_rlssm_model_config("does_not_exist") + + def test_derives_rl_params_when_absent( + self, + annotated_ssm_base_logp: Any, + learning_process: dict[str, Any], + ) -> None: + """When learning_process_params is absent from the registry entry, params are derived + from the learning_process .inputs.""" + registry.register_ssm( + name="derive_params_ssm", + ssm_base_logp_func=annotated_ssm_base_logp, + list_params_ssm=["v", "a"], + bounds_ssm={"a": (0.3, 3.0)}, + params_default_ssm=[0.0, 1.5], + response=["rt", "response"], + ) + # Inject an entry without learning_process_params so the fallback derivation runs. + registry._RLSSM_REGISTRY["derive_params_model"] = { + "decision_process": "derive_params_ssm", + "learning_process": learning_process, + # learning_process_params deliberately absent + "learning_process_bounds": {}, + "learning_process_params_default": [], + "extra_fields": ["feedback"], + "choices": [0, 1], + "description": None, + "decision_process_loglik_kind": "approx_differentiable", + "learning_process_kind": "blackbox", + } + + config = registry.get_rlssm_model_config("derive_params_model") + + # "rl_alpha" is the only input to learning_process that isn't response/extra. + assert "rl_alpha" in config.list_params + + def test_raises_for_missing_bounds( + self, + annotated_ssm_base_logp: Any, + learning_process: dict[str, Any], + ) -> None: + """SSM params missing from bounds_ssm must raise immediately with a clear error. + + Previously the factory silently skipped such params, producing a broken + RLSSMConfig that only failed later inside _RLSSM.__init__. Now it must + raise at the factory boundary so the error message points at the root cause. + """ + registry.register_ssm( + name="no_bounds_ssm", + ssm_base_logp_func=annotated_ssm_base_logp, + list_params_ssm=["v", "a"], + # "a" intentionally has no bounds entry + bounds_ssm={}, + params_default_ssm=[0.0, 1.5], + response=["rt", "response"], + ) + registry.register_rlssm_model( + name="no_bounds_model", + decision_process="no_bounds_ssm", + learning_process=learning_process, + learning_process_params=["rl_alpha"], + learning_process_bounds={"rl_alpha": (0.0, 1.0)}, + learning_process_params_default=[0.2], + extra_fields=["feedback"], + choices=[0, 1], + ) + + with pytest.raises(ValueError, match="no entry in bounds_ssm"): + registry.get_rlssm_model_config("no_bounds_model") + + +class TestBuiltinModels: + @pytest.mark.parametrize( + "model_name, expected_dp", + [ + ("2AB_RescorlaWagner_DDM", "ddm"), + ("2AB_RescorlaWagner_Weibull", "weibull"), + ], + ) + def test_are_registered(self, model_name: str, expected_dp: str) -> None: + """2AB_RescorlaWagner_DDM and _Weibull must be present in the RLSSM registry.""" + assert model_name in registry._RLSSM_REGISTRY + entry = registry._RLSSM_REGISTRY[model_name] + assert entry["decision_process"]["name"] == expected_dp + assert entry["learning_process_params"] == ["rl_alpha", "scaler"] + assert entry["extra_fields"] == ["feedback"] + assert entry["choices"] == [0, 1] + assert entry["decision_process_loglik_kind"] == "approx_differentiable" + assert entry["learning_process_kind"] == "blackbox" + + @pytest.mark.parametrize( + "model_name", + ["2AB_RescorlaWagner_DDM", "2AB_RescorlaWagner_Weibull"], + ) + def test_config_structure( + self, + monkeypatch: pytest.MonkeyPatch, + annotated_ssm_base_logp: Any, + model_name: str, + ) -> None: + """get_rlssm_model_config should produce a well-formed RLSSMConfig for + both the DDM and Weibull starter-pack models.""" + import hssm.distribution_utils.onnx as onnx_module + + monkeypatch.setattr( + onnx_module, + "make_jax_matrix_logp_funcs_from_onnx", + lambda model: annotated_ssm_base_logp, + ) + monkeypatch.setattr( + registry, + "make_jax_matrix_logp_funcs_from_onnx", + lambda model: annotated_ssm_base_logp, + ) + + config = registry.get_rlssm_model_config(model_name) + + assert isinstance(config, RLSSMConfig) + # RL params come first + assert config.list_params[:2] == ["rl_alpha", "scaler"] + assert "rl_alpha" in config.bounds + assert "scaler" in config.bounds + assert config.choices == (0, 1) + assert config.extra_fields == ["feedback"] + assert config.ssm_logp_func.computed == {"v": registry._compute_v_annotated} + + +class TestListModels: + def test_returns_all_names(self) -> None: + """list_models should return every key in _RLSSM_REGISTRY with its description.""" + result = registry.list_models() + + assert set(result.keys()) == set(registry._RLSSM_REGISTRY.keys()) + for name, desc in result.items(): + assert desc == registry._RLSSM_REGISTRY[name].get("description") diff --git a/tests/test_rl_builder_output_shape.py b/tests/rl/test_rl_builder_output_shape.py similarity index 100% rename from tests/test_rl_builder_output_shape.py rename to tests/rl/test_rl_builder_output_shape.py diff --git a/tests/test_rl_utils.py b/tests/rl/test_rl_utils.py similarity index 100% rename from tests/test_rl_utils.py rename to tests/rl/test_rl_utils.py diff --git a/tests/test_rldm_likelihood.py b/tests/rl/test_rldm_likelihood.py similarity index 97% rename from tests/test_rldm_likelihood.py rename to tests/rl/test_rldm_likelihood.py index e679f22c7..54fd544cc 100644 --- a/tests/test_rldm_likelihood.py +++ b/tests/rl/test_rldm_likelihood.py @@ -14,7 +14,7 @@ @pytest.fixture def fixture_path(): - return Path(__file__).parent / "fixtures" + return Path(__file__).parent.parent / "fixtures" def test_make_rldm_logp_func(fixture_path): diff --git a/tests/rl/test_rlssm.py b/tests/rl/test_rlssm.py new file mode 100644 index 000000000..e5fe341c0 --- /dev/null +++ b/tests/rl/test_rlssm.py @@ -0,0 +1,436 @@ +"""Tests for the RLSSM class. + +Mirrors the structure of tests/test_hssm.py, covering initialisation, +config validation, param keys, balanced-panel enforcement, the no-lapse +variant, bambi / PyMC model construction, and a sampling smoke test. +""" + +import logging +from collections.abc import Generator +from pathlib import Path +from unittest.mock import patch + +import cloudpickle +import jax.numpy as jnp +import numpy as np +import pandas as pd +import pytensor +import pytest + +import hssm +from hssm.distribution_utils import make_distribution as real_make_distribution +from hssm.rl import _RLSSM, RLSSM, RLSSMConfig, register_rlssm_model +from hssm.rl.likelihoods.two_armed_bandit import compute_v_subject_wise +from hssm.utils import annotate_function + +# Annotate the RL learning function: maps +# (rl_alpha, scaler, response, feedback) -> v +_compute_v_annotated = annotate_function( + inputs=["rl_alpha", "scaler", "response", "feedback"], + outputs=["v"], +)(compute_v_subject_wise) + + +@annotate_function( + inputs=["v", "a", "z", "t", "theta", "rt", "response"], + outputs=["logp"], + computed={"v": _compute_v_annotated}, +) +def _dummy_ssm_logp(lan_matrix: jnp.ndarray) -> jnp.ndarray: + """Return per-trial log-probabilities (column-sum); structural tests only.""" + # Return 1D (N,) — PyTensor declares the Op output as pt.vector(), so + # gradients arrive as (N,). A (N,1) return causes a VJP shape mismatch. + return jnp.sum(lan_matrix, axis=1) + + +@pytest.fixture(scope="module", autouse=True) +def _set_floatx_float32() -> Generator[None, None, None]: + """Ensure float32 is used for this module's tests, then restore previous setting.""" + prev_floatx = pytensor.config.floatX + hssm.set_floatX("float32", update_jax=True) + try: + yield + finally: + hssm.set_floatX(prev_floatx, update_jax=True) + + +@pytest.fixture(scope="module") +def rldm_data() -> pd.DataFrame: + """Load the RLDM fixture dataset (balanced panel).""" + raw = np.load( + Path(__file__).parent.parent / "fixtures" / "rldm_data.npy", allow_pickle=True + ).item() + return pd.DataFrame(raw["data"]) + + +@pytest.fixture(scope="module") +def rlssm_config() -> RLSSMConfig: + """Minimal but valid RLSSMConfig for the RLDM fixture dataset.""" + return RLSSMConfig( + model_name="rldm_test", + loglik_kind="approx_differentiable", + decision_process="angle", + decision_process_loglik_kind="approx_differentiable", + learning_process_kind="blackbox", + list_params=["rl_alpha", "scaler", "a", "theta", "t", "z"], + params_default=[0.1, 1.0, 1.0, 0.0, 0.3, 0.5], + bounds={ + "rl_alpha": (0.0, 1.0), + "scaler": (0.0, 10.0), + "a": (0.1, 3.0), + "theta": (-0.1, 0.1), + "t": (0.001, 1.0), + "z": (0.1, 0.9), + }, + learning_process={"v": _compute_v_annotated}, + response=["rt", "response"], + choices=[0, 1], + extra_fields=["feedback"], + ssm_logp_func=_dummy_ssm_logp, + ) + + +class TestRLSSMInit: + """Basic construction, attribute checks, and invalid-input guards at construction time.""" + + def test_rlssm_init(self, rldm_data, rlssm_config) -> None: + """Basic RLSSM initialisation should succeed and return an RLSSM instance.""" + model = RLSSM(data=rldm_data, model_config=rlssm_config) + assert isinstance(model, RLSSM) + assert model.model_config.model_name == "rldm_test" + + def test_rlssm_panel_attrs(self, rldm_data, rlssm_config) -> None: + """n_participants and n_trials should match the fixture data structure.""" + model = RLSSM(data=rldm_data, model_config=rlssm_config) + + n_participants = rldm_data["participant_id"].nunique() + n_trials = len(rldm_data) // n_participants + + assert model.n_participants == n_participants + assert model.n_trials == n_trials + + def test_rlssm_params_keys(self, rldm_data, rlssm_config) -> None: + """model.params should contain exactly list_params + p_outlier.""" + model = RLSSM(data=rldm_data, model_config=rlssm_config) + expected = set(rlssm_config.list_params) | {"p_outlier"} + assert set(model.params.keys()) == expected + + def test_rlssm_unbalanced_raises(self, rldm_data, rlssm_config) -> None: + """Dropping one row should make the panel unbalanced → ValueError.""" + unbalanced = rldm_data.iloc[:-1].copy() + with pytest.raises(ValueError, match="balanced panels"): + RLSSM(data=unbalanced, model_config=rlssm_config) + + def test_rlssm_nan_participant_id_raises(self, rldm_data, rlssm_config) -> None: + """NaN in participant_id column should raise ValueError before groupby silently drops rows.""" + nan_data = rldm_data.copy() + nan_data.loc[nan_data.index[0], "participant_id"] = float("nan") + with pytest.raises(ValueError, match="NaN"): + RLSSM(data=nan_data, model_config=rlssm_config) + + def test_rlssm_missing_ssm_logp_func_raises(self, rldm_data, rlssm_config) -> None: + """RLSSMConfig without ssm_logp_func should raise ValueError on init.""" + bad_config = RLSSMConfig( + model_name="rldm_bad", + loglik_kind="approx_differentiable", + decision_process="angle", + decision_process_loglik_kind="approx_differentiable", + learning_process_kind="blackbox", + list_params=rlssm_config.list_params, + params_default=rlssm_config.params_default, + bounds=rlssm_config.bounds, + learning_process=rlssm_config.learning_process, + response=list(rlssm_config.response), + choices=list(rlssm_config.choices), + extra_fields=list(rlssm_config.extra_fields), + # ssm_logp_func intentionally omitted → defaults to None + ) + with pytest.raises(ValueError, match="ssm_logp_func"): + RLSSM(data=rldm_data, model_config=bad_config) + + def test_rlssm_unannotated_ssm_logp_func_raises( + self, rldm_data, rlssm_config + ) -> None: + """A plain callable without @annotate_function attrs should raise ValueError.""" + bad_config = RLSSMConfig( + model_name="rldm_bad", + loglik_kind="approx_differentiable", + decision_process="angle", + decision_process_loglik_kind="approx_differentiable", + learning_process_kind="blackbox", + list_params=rlssm_config.list_params, + params_default=rlssm_config.params_default, + bounds=rlssm_config.bounds, + learning_process=rlssm_config.learning_process, + response=list(rlssm_config.response), + choices=list(rlssm_config.choices), + extra_fields=list(rlssm_config.extra_fields), + ssm_logp_func=lambda x: x, # callable but no .inputs/.outputs/.computed + ) + with pytest.raises(ValueError, match="annotate_function"): + RLSSM(data=rldm_data, model_config=bad_config) + + def test_rlssm_missing_data_raises(self, rldm_data, rlssm_config) -> None: + """Passing missing_data!=False should raise NotImplementedError with 'missing_data' in msg.""" + with pytest.raises(NotImplementedError, match="missing_data"): + RLSSM(data=rldm_data, model_config=rlssm_config, missing_data=True) + + def test_rlssm_deadline_raises(self, rldm_data, rlssm_config) -> None: + """Passing deadline!=False should raise NotImplementedError with 'deadline' in msg.""" + with pytest.raises(NotImplementedError, match="deadline"): + RLSSM(data=rldm_data, model_config=rlssm_config, deadline=True) + + +class TestRLSSMModelStructure: + """Internal model anatomy after construction: params, prefix, lapse, bambi/pymc, extra_fields.""" + + def test_rlssm_params_is_trialwise_aligned(self, rldm_data, rlssm_config) -> None: + """params_is_trialwise must align with list_params (same length, p_outlier=False).""" + model = RLSSM(data=rldm_data, model_config=rlssm_config) + assert model.model_config.list_params is not None + params_is_trialwise = [ + name != "p_outlier" for name in model.model_config.list_params + ] + assert len(params_is_trialwise) == len(model.model_config.list_params) + for name, is_tw in zip(model.model_config.list_params, params_is_trialwise): + if name == "p_outlier": + assert not is_tw, "p_outlier must be non-trialwise" + else: + assert is_tw, f"{name} must be trialwise" + + def test_rlssm_get_prefix(self, rldm_data, rlssm_config) -> None: + """_get_prefix must use token-based matching, not substring search. + + - 'rl_alpha_Intercept' → 'rl_alpha' (underscore-containing RL param) + - 'p_outlier_log__' → 'p_outlier' (lapse param via token loop, not substring) + - 'a_Intercept' → 'a' (single-token standard param) + """ + model = RLSSM(data=rldm_data, model_config=rlssm_config) + assert model._get_prefix("rl_alpha_Intercept") == "rl_alpha" + assert model._get_prefix("p_outlier_log__") == "p_outlier" + assert model._get_prefix("p_outlier") == "p_outlier" + assert model._get_prefix("a_Intercept") == "a" + # Fallback: not in params + assert model._get_prefix("unknown_param") == "unknown_param" + + def test_rlssm_no_lapse(self, rldm_data, rlssm_config) -> None: + """Setting p_outlier=None should remove p_outlier from params.""" + model = RLSSM(data=rldm_data, model_config=rlssm_config, p_outlier=None) + assert "p_outlier" not in model.params + + def test_rlssm_model_built(self, rldm_data, rlssm_config) -> None: + """The bambi model should be built and the computed param 'v' absent from params.""" + model = RLSSM(data=rldm_data, model_config=rlssm_config) + assert model.model is not None + # rl_alpha is a free (sampled) parameter + assert "rl_alpha" in model.params + # v is computed inside the Op; it must NOT appear as a free parameter + assert "v" not in model.params + + def test_rlssm_extra_fields_are_copies(self, rldm_data, rlssm_config) -> None: + """extra_fields passed to make_distribution must be independent numpy copies. + + to_numpy(copy=True) should return a new buffer; if it returned a view, + in-place mutations of the DataFrame would silently corrupt the distribution. + """ + model = RLSSM(data=rldm_data, model_config=rlssm_config) + captured: dict = {} + + def capturing_make_distribution(*args, **kwargs): + captured["extra_fields"] = kwargs.get("extra_fields") + return real_make_distribution(*args, **kwargs) + + with patch( + "hssm.rl.rlssm.make_distribution", side_effect=capturing_make_distribution + ): + model._make_model_distribution() + + assert captured.get("extra_fields") is not None + for field_name, arr in zip(rlssm_config.extra_fields, captured["extra_fields"]): + original = model.data[field_name].to_numpy() + assert not np.shares_memory(arr, original), ( + f"extra_fields['{field_name}'] shares memory with the DataFrame — " + "it is a view, not a copy" + ) + + def test_rlssm_pymc_model(self, rldm_data, rlssm_config) -> None: + """pymc_model should be accessible after model construction.""" + model = RLSSM(data=rldm_data, model_config=rlssm_config) + assert model.pymc_model is not None + + def test_rlssm_no_extra_fields_none_passed_to_make_distribution( + self, rldm_data, rlssm_config + ) -> None: + """When extra_fields is empty, make_distribution receives extra_fields=None.""" + config_no_extra = RLSSMConfig( + model_name="rldm_no_extra", + loglik_kind="approx_differentiable", + decision_process="angle", + decision_process_loglik_kind="approx_differentiable", + learning_process_kind="blackbox", + list_params=rlssm_config.list_params, + params_default=rlssm_config.params_default, + bounds=rlssm_config.bounds, + learning_process=rlssm_config.learning_process, + response=list(rlssm_config.response), + choices=list(rlssm_config.choices), + extra_fields=[], # empty → should pass None to make_distribution + ssm_logp_func=_dummy_ssm_logp, + ) + model = RLSSM(data=rldm_data, model_config=config_no_extra) + captured: dict = {} + + def capturing_make_distribution(*args, **kwargs): + captured["extra_fields"] = kwargs.get("extra_fields") + return real_make_distribution(*args, **kwargs) + + with patch( + "hssm.rl.rlssm.make_distribution", side_effect=capturing_make_distribution + ): + model._make_model_distribution() + + assert captured.get("extra_fields") is None + + +class TestRLSSMSerialization: + """Cloudpickle serialisation and deserialisation.""" + + def test_rlssm_pickle_round_trip( + self, rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig + ) -> None: + """Cloudpickle round-trip must reconstruct an equivalent RLSSM. + + Verifies that __getstate__ / __setstate__ survive serialisation: + - The reconstructed object is a fresh RLSSM (not the same instance). + - n_participants and n_trials are preserved. + - list_params (including p_outlier) are preserved. + - model_config.model_name is preserved. + - model.model (bambi model) is rebuilt, confirming full re-initialisation. + """ + model = RLSSM(data=rldm_data, model_config=rlssm_config) + blob = cloudpickle.dumps(model) + restored = cloudpickle.loads(blob) + + assert restored is not model + assert isinstance(restored, RLSSM) + assert restored.n_participants == model.n_participants + assert restored.n_trials == model.n_trials + assert restored.list_params == model.list_params + assert restored.model_config.model_name == model.model_config.model_name + assert restored.model is not None + + +class TestRLSSMSampling: + """Slow sampling smoke tests.""" + + @pytest.mark.slow + def test_rlssm_sample_smoke(self, rldm_data, rlssm_config) -> None: + """Minimal sampling run should return an InferenceData object.""" + model = RLSSM(data=rldm_data, model_config=rlssm_config) + trace = model.sample( + draws=4, tune=50, chains=1, cores=1, sampler="numpyro", target_accept=0.9 + ) + assert trace is not None + + +class TestRLSSMSimplifiedInterface: + """Public model= kwarg API: registry lookups, register_rlssm_model, unsupported-feature properties.""" + + def test_rlssm_is_subclass_of_internal(self) -> None: + """RLSSM must be a subclass of _RLSSM.""" + assert issubclass(RLSSM, _RLSSM) + + @pytest.mark.parametrize( + "model_name, expected_dp", + [ + ("2AB_RescorlaWagner_DDM", "ddm"), + ("2AB_RescorlaWagner_Angle", "angle"), + ("2AB_RescorlaWagner_Weibull", "weibull"), + ], + ) + def test_rlssm_builtin_models_instantiate( + self, rldm_data, model_name: str, expected_dp: str + ) -> None: + """All built-in 2AB_RescorlaWagner_* models should instantiate correctly.""" + model = RLSSM(data=rldm_data, model=model_name) + assert isinstance(model, RLSSM) + assert model.model_config.decision_process == expected_dp + assert "rl_alpha" in model.params + assert "scaler" in model.params + assert "a" in model.params + assert "v" not in model.params + + def test_rlssm_default_model_is_ddm(self, rldm_data) -> None: + """Omitting model should default to '2AB_RescorlaWagner_DDM'.""" + model = RLSSM(data=rldm_data) + assert isinstance(model, RLSSM) + assert model.model_config.decision_process == "ddm" + + def test_rlssm_model_config_provided(self, rldm_data, rlssm_config) -> None: + """Passing model_config= directly should bypass the registry.""" + model = RLSSM(data=rldm_data, model_config=rlssm_config) + assert model.model_config.model_name == rlssm_config.model_name + + def test_rlssm_unregistered_model_raises(self, rldm_data) -> None: + """Using an unknown model name should raise ValueError.""" + with pytest.raises(ValueError, match="not found in the RLSSM registry"): + RLSSM(data=rldm_data, model="model_not_in_registry") + + def test_rlssm_missing_data_property_raises(self, rldm_data) -> None: + """Accessing .missing_data on a built RLSSM instance must raise NotImplementedError.""" + model = RLSSM(data=rldm_data, model="2AB_RescorlaWagner_DDM") + with pytest.raises(NotImplementedError, match="missing_data"): + _ = model.missing_data + + def test_rlssm_deadline_property_raises(self, rldm_data) -> None: + """Accessing .deadline on a built RLSSM instance must raise NotImplementedError.""" + model = RLSSM(data=rldm_data, model="2AB_RescorlaWagner_DDM") + with pytest.raises(NotImplementedError, match="deadline"): + _ = model.deadline + + def test_rlssm_loglik_missing_data_property_raises(self, rldm_data) -> None: + """Accessing .loglik_missing_data on a built RLSSM instance must raise NotImplementedError.""" + model = RLSSM(data=rldm_data, model="2AB_RescorlaWagner_DDM") + with pytest.raises(NotImplementedError, match="loglik_missing_data"): + _ = model.loglik_missing_data + + def test_register_rlssm_model(self, rldm_data) -> None: + """A user-registered model should be instantiable via the simplified interface.""" + # Re-use the existing annotated learning function and ssm logp from the + # module-level helpers defined at the top of this test file. + register_rlssm_model( + name="rldm_custom_test", + decision_process="angle", + learning_process={"v": _compute_v_annotated}, + learning_process_params=["rl_alpha", "scaler"], + learning_process_bounds={"rl_alpha": (0.0, 1.0), "scaler": (0.0, 10.0)}, + learning_process_params_default=[0.1, 1.0], + extra_fields=["feedback"], + choices=[0, 1], + ) + model = RLSSM(data=rldm_data, model="rldm_custom_test") + assert isinstance(model, RLSSM) + assert "rl_alpha" in model.params + assert "v" not in model.params + + def test_rlssm_init_args_uses_simplified_interface(self, rldm_data) -> None: + """_init_args should reflect the simplified constructor, not model_config.""" + model = RLSSM(data=rldm_data, model="2AB_RescorlaWagner_DDM") + assert "model" in model._init_args + assert model._init_args["model"] == "2AB_RescorlaWagner_DDM" + # model_config should not be baked in as a hard reference + assert "model_config" in model._init_args + assert model._init_args["model_config"] is None + + def test_rlssm_model_config_with_overrides_warns( + self, rldm_data, rlssm_config, caplog + ) -> None: + """Warn when model_config is given alongside model/overrides. + + The extra arguments (model, learning_process, decision_process, choices) + should be ignored and a warning emitted. + """ + with caplog.at_level(logging.WARNING, logger="hssm"): + RLSSM(data=rldm_data, model_config=rlssm_config, model="some_other_model") + + assert any("ignoring" in r.message for r in caplog.records) diff --git a/tests/test_rlssm_config.py b/tests/rl/test_rlssm_config.py similarity index 100% rename from tests/test_rlssm_config.py rename to tests/rl/test_rlssm_config.py diff --git a/tests/test_rlssm.py b/tests/test_rlssm.py deleted file mode 100644 index 973061c73..000000000 --- a/tests/test_rlssm.py +++ /dev/null @@ -1,329 +0,0 @@ -"""Tests for the RLSSM class. - -Mirrors the structure of tests/test_hssm.py, covering initialisation, -config validation, param keys, balanced-panel enforcement, the no-lapse -variant, bambi / PyMC model construction, and a sampling smoke test. -""" - -from collections.abc import Generator -from pathlib import Path - -import jax.numpy as jnp -import numpy as np -import pandas as pd -import pytensor -import pytest - -import hssm -from hssm.rl import RLSSM, RLSSMConfig -from hssm.rl.likelihoods.two_armed_bandit import compute_v_subject_wise -from hssm.utils import annotate_function - -# --------------------------------------------------------------------------- -# Module-level annotated helpers (shared by all tests) -# --------------------------------------------------------------------------- - -# Annotate the RL learning function: maps -# (rl_alpha, scaler, response, feedback) -> v -_compute_v_annotated = annotate_function( - inputs=["rl_alpha", "scaler", "response", "feedback"], - outputs=["v"], -)(compute_v_subject_wise) - - -# Annotated SSM log-likelihood function (simplified for testing). -# It receives a 2-D lan_matrix whose columns correspond to -# [v, a, z, t, theta, rt, response] -# and returns per-trial log-probabilities of shape (n_total_trials,). -@annotate_function( - inputs=["v", "a", "z", "t", "theta", "rt", "response"], - outputs=["logp"], - computed={"v": _compute_v_annotated}, -) -def _dummy_ssm_logp(lan_matrix: jnp.ndarray) -> jnp.ndarray: - """Return per-trial log-probabilities (column-sum); structural tests only.""" - # Return 1D (N,) — PyTensor declares the Op output as pt.vector(), so - # gradients arrive as (N,). A (N,1) return causes a VJP shape mismatch. - return jnp.sum(lan_matrix, axis=1) - - -# --------------------------------------------------------------------------- -# Fixtures -# --------------------------------------------------------------------------- - - -@pytest.fixture(scope="module", autouse=True) -def _set_floatx_float32() -> Generator[None, None, None]: - """Ensure float32 is used for this module's tests, then restore previous setting.""" - prev_floatx = pytensor.config.floatX - hssm.set_floatX("float32", update_jax=True) - try: - yield - finally: - hssm.set_floatX(prev_floatx, update_jax=True) - - -@pytest.fixture(scope="module") -def rldm_data() -> pd.DataFrame: - """Load the RLDM fixture dataset (balanced panel).""" - raw = np.load( - Path(__file__).parent / "fixtures" / "rldm_data.npy", allow_pickle=True - ).item() - return pd.DataFrame(raw["data"]) - - -@pytest.fixture(scope="module") -def rlssm_config() -> RLSSMConfig: - """Minimal but valid RLSSMConfig for the RLDM fixture dataset.""" - return RLSSMConfig( - model_name="rldm_test", - loglik_kind="approx_differentiable", - decision_process="angle", - decision_process_loglik_kind="approx_differentiable", - learning_process_kind="blackbox", - list_params=["rl_alpha", "scaler", "a", "theta", "t", "z"], - params_default=[0.1, 1.0, 1.0, 0.0, 0.3, 0.5], - bounds={ - "rl_alpha": (0.0, 1.0), - "scaler": (0.0, 10.0), - "a": (0.1, 3.0), - "theta": (-0.1, 0.1), - "t": (0.001, 1.0), - "z": (0.1, 0.9), - }, - learning_process={"v": _compute_v_annotated}, - response=["rt", "response"], - choices=[0, 1], - extra_fields=["feedback"], - ssm_logp_func=_dummy_ssm_logp, - ) - - -# --------------------------------------------------------------------------- -# Initialisation & config-validation tests -# --------------------------------------------------------------------------- - - -def test_rlssm_init(rldm_data, rlssm_config) -> None: - """Basic RLSSM initialisation should succeed and return an RLSSM instance.""" - model = RLSSM(data=rldm_data, model_config=rlssm_config) - assert isinstance(model, RLSSM) - assert model.model_config.model_name == "rldm_test" - - -def test_rlssm_panel_attrs(rldm_data, rlssm_config) -> None: - """n_participants and n_trials should match the fixture data structure.""" - model = RLSSM(data=rldm_data, model_config=rlssm_config) - - n_participants = rldm_data["participant_id"].nunique() - n_trials = len(rldm_data) // n_participants - - assert model.n_participants == n_participants - assert model.n_trials == n_trials - - -def test_rlssm_params_keys(rldm_data, rlssm_config) -> None: - """model.params should contain exactly list_params + p_outlier.""" - model = RLSSM(data=rldm_data, model_config=rlssm_config) - expected = set(rlssm_config.list_params) | {"p_outlier"} - assert set(model.params.keys()) == expected - - -def test_rlssm_unbalanced_raises(rldm_data, rlssm_config) -> None: - """Dropping one row should make the panel unbalanced → ValueError.""" - unbalanced = rldm_data.iloc[:-1].copy() - with pytest.raises(ValueError, match="balanced panels"): - RLSSM(data=unbalanced, model_config=rlssm_config) - - -def test_rlssm_nan_participant_id_raises(rldm_data, rlssm_config) -> None: - """NaN in participant_id column should raise ValueError before groupby silently drops rows.""" - nan_data = rldm_data.copy() - nan_data.loc[nan_data.index[0], "participant_id"] = float("nan") - with pytest.raises(ValueError, match="NaN"): - RLSSM(data=nan_data, model_config=rlssm_config) - - -def test_rlssm_missing_ssm_logp_func_raises(rldm_data, rlssm_config) -> None: - """RLSSMConfig without ssm_logp_func should raise ValueError on init.""" - bad_config = RLSSMConfig( - model_name="rldm_bad", - loglik_kind="approx_differentiable", - decision_process="angle", - decision_process_loglik_kind="approx_differentiable", - learning_process_kind="blackbox", - list_params=rlssm_config.list_params, - params_default=rlssm_config.params_default, - bounds=rlssm_config.bounds, - learning_process=rlssm_config.learning_process, - response=list(rlssm_config.response), - choices=list(rlssm_config.choices), - extra_fields=list(rlssm_config.extra_fields), - # ssm_logp_func intentionally omitted → defaults to None - ) - with pytest.raises(ValueError, match="ssm_logp_func"): - RLSSM(data=rldm_data, model_config=bad_config) - - -def test_rlssm_unannotated_ssm_logp_func_raises(rldm_data, rlssm_config) -> None: - """A plain callable without @annotate_function attrs should raise ValueError.""" - bad_config = RLSSMConfig( - model_name="rldm_bad", - loglik_kind="approx_differentiable", - decision_process="angle", - decision_process_loglik_kind="approx_differentiable", - learning_process_kind="blackbox", - list_params=rlssm_config.list_params, - params_default=rlssm_config.params_default, - bounds=rlssm_config.bounds, - learning_process=rlssm_config.learning_process, - response=list(rlssm_config.response), - choices=list(rlssm_config.choices), - extra_fields=list(rlssm_config.extra_fields), - ssm_logp_func=lambda x: x, # callable but no .inputs/.outputs/.computed - ) - with pytest.raises(ValueError, match="annotate_function"): - RLSSM(data=rldm_data, model_config=bad_config) - - -def test_rlssm_missing_data_raises(rldm_data, rlssm_config) -> None: - """Passing missing_data!=False should raise NotImplementedError with 'missing_data' in msg.""" - with pytest.raises(NotImplementedError, match="missing_data"): - RLSSM(data=rldm_data, model_config=rlssm_config, missing_data=True) - - -def test_rlssm_deadline_raises(rldm_data, rlssm_config) -> None: - """Passing deadline!=False should raise NotImplementedError with 'deadline' in msg.""" - with pytest.raises(NotImplementedError, match="deadline"): - RLSSM(data=rldm_data, model_config=rlssm_config, deadline=True) - - -# --------------------------------------------------------------------------- -# Model-structure tests -# --------------------------------------------------------------------------- - - -def test_rlssm_params_is_trialwise_aligned(rldm_data, rlssm_config) -> None: - """params_is_trialwise must align with list_params (same length, p_outlier=False).""" - model = RLSSM(data=rldm_data, model_config=rlssm_config) - assert model.model_config.list_params is not None - params_is_trialwise = [ - name != "p_outlier" for name in model.model_config.list_params - ] - assert len(params_is_trialwise) == len(model.model_config.list_params) - for name, is_tw in zip(model.model_config.list_params, params_is_trialwise): - if name == "p_outlier": - assert not is_tw, "p_outlier must be non-trialwise" - else: - assert is_tw, f"{name} must be trialwise" - - -def test_rlssm_get_prefix(rldm_data, rlssm_config) -> None: - """_get_prefix must use token-based matching, not substring search. - - - 'rl_alpha_Intercept' → 'rl_alpha' (underscore-containing RL param) - - 'p_outlier_log__' → 'p_outlier' (lapse param via token loop, not substring) - - 'a_Intercept' → 'a' (single-token standard param) - """ - model = RLSSM(data=rldm_data, model_config=rlssm_config) - assert model._get_prefix("rl_alpha_Intercept") == "rl_alpha" - assert model._get_prefix("p_outlier_log__") == "p_outlier" - assert model._get_prefix("p_outlier") == "p_outlier" - assert model._get_prefix("a_Intercept") == "a" - # Fallback: not in params - assert model._get_prefix("unknown_param") == "unknown_param" - - -def test_rlssm_no_lapse(rldm_data, rlssm_config) -> None: - """Setting p_outlier=None should remove p_outlier from params.""" - model = RLSSM(data=rldm_data, model_config=rlssm_config, p_outlier=None) - assert "p_outlier" not in model.params - - -def test_rlssm_model_built(rldm_data, rlssm_config) -> None: - """The bambi model should be built and the computed param 'v' absent from params.""" - model = RLSSM(data=rldm_data, model_config=rlssm_config) - assert model.model is not None - # rl_alpha is a free (sampled) parameter - assert "rl_alpha" in model.params - # v is computed inside the Op; it must NOT appear as a free parameter - assert "v" not in model.params - - -def test_rlssm_extra_fields_are_copies(rldm_data, rlssm_config) -> None: - """extra_fields passed to make_distribution must be independent numpy copies. - - to_numpy(copy=True) should return a new buffer; if it returned a view, - in-place mutations of the DataFrame would silently corrupt the distribution. - """ - from unittest.mock import patch - - from hssm.distribution_utils import make_distribution as real_make_distribution - - model = RLSSM(data=rldm_data, model_config=rlssm_config) - captured: dict = {} - - def capturing_make_distribution(*args, **kwargs): - captured["extra_fields"] = kwargs.get("extra_fields") - return real_make_distribution(*args, **kwargs) - - with patch( - "hssm.rl.rlssm.make_distribution", side_effect=capturing_make_distribution - ): - model._make_model_distribution() - - assert captured.get("extra_fields") is not None - for field_name, arr in zip(rlssm_config.extra_fields, captured["extra_fields"]): - original = model.data[field_name].to_numpy() - assert not np.shares_memory(arr, original), ( - f"extra_fields['{field_name}'] shares memory with the DataFrame — " - "it is a view, not a copy" - ) - - -def test_rlssm_pymc_model(rldm_data, rlssm_config) -> None: - """pymc_model should be accessible after model construction.""" - model = RLSSM(data=rldm_data, model_config=rlssm_config) - assert model.pymc_model is not None - - -# --------------------------------------------------------------------------- -# Slow sampling smoke test -# --------------------------------------------------------------------------- - - -@pytest.mark.slow -def test_rlssm_sample_smoke(rldm_data, rlssm_config) -> None: - """Minimal sampling run should return an InferenceData object.""" - model = RLSSM(data=rldm_data, model_config=rlssm_config) - trace = model.sample( - draws=4, tune=50, chains=1, cores=1, sampler="numpyro", target_accept=0.9 - ) - assert trace is not None - - -def test_rlssm_pickle_round_trip( - rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig -) -> None: - """cloudpickle round-trip must reconstruct an equivalent RLSSM. - - Verifies that __getstate__ / __setstate__ survive serialisation: - - The reconstructed object is a fresh RLSSM (not the same instance). - - n_participants and n_trials are preserved. - - list_params (including p_outlier) are preserved. - - model_config.model_name is preserved. - - model.model (bambi model) is rebuilt, confirming full re-initialisation. - """ - import cloudpickle - - model = RLSSM(data=rldm_data, model_config=rlssm_config) - blob = cloudpickle.dumps(model) - restored = cloudpickle.loads(blob) - - assert restored is not model - assert isinstance(restored, RLSSM) - assert restored.n_participants == model.n_participants - assert restored.n_trials == model.n_trials - assert restored.list_params == model.list_params - assert restored.model_config.model_name == model.model_config.model_name - assert restored.model is not None