From b152169612b90e323727d6d18011ecb1f2a0d695 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Wed, 6 May 2026 14:50:42 -0400 Subject: [PATCH 01/41] Update init files --- src/hssm/__init__.py | 3 ++- src/hssm/rl/__init__.py | 16 ++++++++++++---- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/src/hssm/__init__.py b/src/hssm/__init__.py index 2f234d08..171da990 100644 --- a/src/hssm/__init__.py +++ b/src/hssm/__init__.py @@ -19,7 +19,7 @@ from .param import UserParam as Param from .prior import Prior from .register import register_model -from .rl import RLSSM +from .rl import RLSSM, register_rlssm_model from .simulator import simulate_data from .utils import check_data_for_rl, set_floatX @@ -40,6 +40,7 @@ "Prior", "check_data_for_rl", "register_model", + "register_rlssm_model", "simulate_data", "set_floatX", "show_defaults", diff --git a/src/hssm/rl/__init__.py b/src/hssm/rl/__init__.py index 64e17bc4..78579888 100644 --- a/src/hssm/rl/__init__.py +++ b/src/hssm/rl/__init__.py @@ -5,9 +5,12 @@ Public API (import from ``hssm.rl``): -- ``RLSSM``: the RL + SSM model class implemented in :mod:`hssm.rl.rlssm`. -- ``RLSSMConfig``: the config class for RL + SSM models, implemented in - :mod:`hssm.rl.config`. +- ``RLSSM``: the public RL + SSM model class in :mod:`hssm.rl.rlssm`. +- ``_RLSSM``: the internal base class that requires a fully built config. +- ``RLSSMConfig``: the config class for RL + SSM models in :mod:`hssm.rl.config`. +- ``get_rlssm_model_config``: factory that builds a config from a named model. +- ``register_rlssm_model``: register a custom named RLSSM model. +- ``register_ssm``: register a custom SSM base logp function. - ``validate_balanced_panel``: panel-balance utility in :mod:`hssm.rl.utils`. RL likelihood builders live in :mod:`hssm.rl.likelihoods.builder` and include @@ -17,11 +20,16 @@ """ from .config import RLSSMConfig -from .rlssm import RLSSM +from .registry import get_rlssm_model_config, register_rlssm_model, register_ssm +from .rlssm import _RLSSM, RLSSM from .utils import validate_balanced_panel __all__ = [ "RLSSM", + "_RLSSM", "RLSSMConfig", + "get_rlssm_model_config", + "register_rlssm_model", + "register_ssm", "validate_balanced_panel", ] From c3a09bef93e67b5ee0f99a69a92bebceb1158eb8 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Wed, 6 May 2026 14:50:57 -0400 Subject: [PATCH 02/41] Add registry --- src/hssm/rl/registry.py | 367 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 367 insertions(+) create mode 100644 src/hssm/rl/registry.py diff --git a/src/hssm/rl/registry.py b/src/hssm/rl/registry.py new file mode 100644 index 00000000..9dd46834 --- /dev/null +++ b/src/hssm/rl/registry.py @@ -0,0 +1,367 @@ +"""Registry for named RLSSM models and SSM base log-likelihood functions. + +This module provides: + +- :data:`_SSM_REGISTRY` — maps SSM names (e.g. ``"angle"``) to their base + annotated JAX log-likelihood functions and parameter metadata. +- :data:`_RLSSM_REGISTRY` — maps named RLSSM model strings (e.g. ``"rldm"``) + to their default decision process, learning process, and parameter info. +- :func:`get_rlssm_model_config` — builds a :class:`~hssm.rl.config.RLSSMConfig` + from a named model string with optional overrides. +- :func:`register_rlssm_model` — register a custom named RLSSM model. +- :func:`register_ssm` — register a custom SSM base logp function. +""" + +from __future__ import annotations + +import logging +from typing import Any + +from hssm.distribution_utils.onnx import make_jax_matrix_logp_funcs_from_onnx +from hssm.rl.likelihoods.two_armed_bandit import compute_v_subject_wise +from hssm.utils import annotate_function + +from .config import RLSSMConfig + +_logger = logging.getLogger("hssm") + +# --------------------------------------------------------------------------- +# Default annotated Rescorla-Wagner learning function +# --------------------------------------------------------------------------- + +_compute_v_annotated = annotate_function( + inputs=["rl_alpha", "scaler", "response", "feedback"], + outputs=["v"], +)(compute_v_subject_wise) + +# --------------------------------------------------------------------------- +# SSM base log-likelihood registry +# --------------------------------------------------------------------------- +# Each entry provides: +# ssm_base_logp_func - annotated JAX function (inputs + outputs, no computed) +# list_params_ssm - ordered SSM parameter names (including computed ones) +# bounds_ssm - bounds for non-computed SSM params +# params_default_ssm - default values aligned with list_params_ssm +# response - data column names + + +def _make_angle_base_logp() -> Any: + """Build the annotated angle SSM base logp function from its ONNX file.""" + _raw = make_jax_matrix_logp_funcs_from_onnx(model="angle.onnx") + return annotate_function( + inputs=["v", "a", "z", "t", "theta", "rt", "response"], + outputs=["logp"], + )(_raw) + + +_SSM_REGISTRY: dict[str, dict[str, Any]] = { + "angle": { + "ssm_base_logp_func": _make_angle_base_logp(), + # All SSM params in the order the SSM expects them (includes computed). + "list_params_ssm": ["v", "a", "z", "t", "theta"], + # Bounds only for params that will be *sampled* (not RL-computed). + "bounds_ssm": { + "a": (0.3, 3.0), + "z": (0.1, 0.9), + "t": (0.001, 2.0), + "theta": (-0.1, 1.3), + }, + # Defaults aligned with list_params_ssm: v, a, z, t, theta + "params_default_ssm": [0.0, 1.5, 0.5, 0.5, 0.0], + "response": ["rt", "response"], + }, +} + +# --------------------------------------------------------------------------- +# RLSSM named model registry +# --------------------------------------------------------------------------- +# Each entry provides: +# decision_process - key into _SSM_REGISTRY +# learning_process - {param: annotated_func} +# rl_params - ordered list of sampled RL parameter names +# rl_bounds - {param: (lo, hi)} for RL params +# rl_params_default - default values aligned with rl_params +# extra_fields - extra data column names required by LP +# choices - response choice values +# description - human-readable description +# decision_process_loglik_kind +# learning_process_kind + +_RLSSM_REGISTRY: dict[str, dict[str, Any]] = { + "rldm": { + "decision_process": "angle", + "learning_process": {"v": _compute_v_annotated}, + "rl_params": ["rl_alpha", "scaler"], + "rl_bounds": { + "rl_alpha": (0.0, 1.0), + "scaler": (0.0, 10.0), + }, + "rl_params_default": [0.1, 1.0], + "extra_fields": ["feedback"], + "choices": [0, 1], + "description": ( + "RLSSM model with Rescorla-Wagner Q-learning and a " + "collapsing-bound DDM (angle model) as decision process." + ), + "decision_process_loglik_kind": "approx_differentiable", + "learning_process_kind": "blackbox", + }, +} + + +# --------------------------------------------------------------------------- +# Internal helpers +# --------------------------------------------------------------------------- + + +def _build_ssm_logp_func(ssm_base_logp_func: Any, learning_process: dict) -> Any: + """Re-annotate *ssm_base_logp_func* adding ``computed=learning_process``. + + Creates a new wrapper that carries the same ``.inputs`` and ``.outputs`` as + the base function but adds the ``computed`` dict so that + :func:`~hssm.rl.likelihoods.builder.make_rl_logp_op` can resolve which + parameters are produced by the RL learning rule at runtime. + """ + return annotate_function( + inputs=ssm_base_logp_func.inputs, + outputs=ssm_base_logp_func.outputs, + computed=learning_process, + )(ssm_base_logp_func) + + +def _derive_rl_params( + learning_process: dict[str, Any], + response: list[str], + extra_fields: list[str], +) -> list[str]: + """Return sampled RL parameter names inferred from *learning_process*. + + Iterates over each LP function's ``.inputs`` and collects names that are + neither response columns nor extra fields. + """ + exclude = set(response) | set(extra_fields) + rl_params: list[str] = [] + seen: set[str] = set() + for lp_func in learning_process.values(): + if not hasattr(lp_func, "inputs"): + continue + for inp in lp_func.inputs: + if inp not in exclude and inp not in seen: + rl_params.append(inp) + seen.add(inp) + return rl_params + + +# --------------------------------------------------------------------------- +# Public factory +# --------------------------------------------------------------------------- + + +def get_rlssm_model_config( + model: str = "rldm", + choices: list[int] | None = None, + learning_process: dict[str, Any] | None = None, + decision_process: str | None = None, +) -> RLSSMConfig: + """Build an :class:`~hssm.rl.config.RLSSMConfig` from a named model. + + Parameters + ---------- + model: + Name of a registered RLSSM model (e.g. ``"rldm"``). + choices: + Override the response choice values stored in the registry. + learning_process: + Override the learning process dict stored in the registry. + decision_process: + Override the SSM name stored in the registry. + + Returns + ------- + RLSSMConfig + Fully populated configuration ready to be passed to :class:`_RLSSM`. + + Raises + ------ + ValueError + If *model* or the resolved *decision_process* is not registered. + """ + if model not in _RLSSM_REGISTRY: + available = list(_RLSSM_REGISTRY.keys()) + raise ValueError( + f"Model '{model}' not found in the RLSSM registry. " + f"Available models: {available}. " + "To use a custom model, pass 'model_config=' directly." + ) + + # Shallow-copy so overrides don't mutate the registry entry. + entry = dict(_RLSSM_REGISTRY[model]) + + if learning_process is not None: + entry["learning_process"] = learning_process + if decision_process is not None: + entry["decision_process"] = decision_process + if choices is not None: + entry["choices"] = choices + + dp: str = entry["decision_process"] + if dp not in _SSM_REGISTRY: + available_ssms = list(_SSM_REGISTRY.keys()) + raise ValueError( + f"Decision process '{dp}' not found in the SSM registry. " + f"Available: {available_ssms}. Use register_ssm() to add it." + ) + + ssm_entry = _SSM_REGISTRY[dp] + ssm_base = ssm_entry["ssm_base_logp_func"] + lp: dict[str, Any] = entry["learning_process"] + + # Compose the full ssm_logp_func with .computed = learning_process. + ssm_logp_func = _build_ssm_logp_func(ssm_base, lp) + + # list_params = [sampled RL params] + [sampled SSM params (non-computed)] + computed_set = set(lp.keys()) + ssm_sampled = [p for p in ssm_entry["list_params_ssm"] if p not in computed_set] + + rl_params: list[str] = entry.get("rl_params") or _derive_rl_params( + lp, ssm_entry["response"], entry.get("extra_fields") or [] + ) + list_params = rl_params + ssm_sampled + + # bounds: RL bounds ∪ SSM sampled bounds + bounds: dict[str, tuple[float, float]] = dict(entry.get("rl_bounds") or {}) + for p in ssm_sampled: + if p in ssm_entry["bounds_ssm"]: + bounds[p] = ssm_entry["bounds_ssm"][p] + + # params_default aligned with list_params + rl_defaults: list[float] = list(entry.get("rl_params_default") or []) + ssm_all_defaults: list[float] = list(ssm_entry["params_default_ssm"]) + ssm_sampled_defaults = [ + ssm_all_defaults[i] + for i, p in enumerate(ssm_entry["list_params_ssm"]) + if p not in computed_set + ] + params_default = rl_defaults + ssm_sampled_defaults + + return RLSSMConfig( + model_name=entry.get("model_name", model), + description=entry.get("description"), + decision_process=dp, + decision_process_loglik_kind=entry["decision_process_loglik_kind"], + learning_process_kind=entry["learning_process_kind"], + learning_process=lp, + ssm_logp_func=ssm_logp_func, + list_params=list_params, + bounds=bounds, + params_default=params_default, + response=ssm_entry["response"], + choices=entry["choices"], + extra_fields=entry.get("extra_fields"), + ) + + +# --------------------------------------------------------------------------- +# Public registration helpers +# --------------------------------------------------------------------------- + + +def register_rlssm_model( + name: str, + decision_process: str, + learning_process: dict[str, Any], + rl_params: list[str], + rl_bounds: dict[str, tuple[float, float]], + rl_params_default: list[float], + extra_fields: list[str] | None = None, + choices: list[int] | None = None, + description: str | None = None, + decision_process_loglik_kind: str = "approx_differentiable", + learning_process_kind: str = "blackbox", +) -> None: + """Register a named RLSSM model in the global registry. + + Parameters + ---------- + name: + Registry key (e.g. ``"my_rldm"``). + decision_process: + Name of the SSM to use (must already be in the SSM registry). + learning_process: + Dict mapping computed parameter name → annotated learning function. + rl_params: + Ordered list of sampled RL parameter names. + rl_bounds: + Parameter bounds for the RL parameters. + rl_params_default: + Default values aligned with *rl_params*. + extra_fields: + Data column names required by the learning process (e.g. ``["feedback"]``). + choices: + Response choice values. Defaults to ``[0, 1]``. + description: + Optional human-readable description. + decision_process_loglik_kind: + Loglik kind tag. Defaults to ``"approx_differentiable"``. + learning_process_kind: + Learning process kind tag. Defaults to ``"blackbox"``. + """ + if name in _RLSSM_REGISTRY: + _logger.warning( + "Model '%s' is already in the RLSSM registry and will be overwritten.", + name, + ) + _RLSSM_REGISTRY[name] = { + "decision_process": decision_process, + "learning_process": learning_process, + "rl_params": rl_params, + "rl_bounds": rl_bounds, + "rl_params_default": rl_params_default, + "extra_fields": extra_fields or [], + "choices": choices if choices is not None else [0, 1], + "description": description, + "decision_process_loglik_kind": decision_process_loglik_kind, + "learning_process_kind": learning_process_kind, + } + + +def register_ssm( + name: str, + ssm_base_logp_func: Any, + list_params_ssm: list[str], + bounds_ssm: dict[str, tuple[float, float]], + params_default_ssm: list[float], + response: list[str] | None = None, +) -> None: + """Register an SSM base log-likelihood function in the SSM registry. + + Parameters + ---------- + name: + Registry key (e.g. ``"ddm"``). + ssm_base_logp_func: + An annotated JAX function (created with ``@annotate_function``) that + computes the SSM log-likelihood from a parameter matrix. Must carry + ``.inputs`` and ``.outputs`` attributes but should **not** have a + ``.computed`` key — that is injected by the factory at config-build time. + list_params_ssm: + Ordered list of all SSM parameter names (including any that will be + computed by the learning process). + bounds_ssm: + Bounds for the non-computed SSM parameters. + params_default_ssm: + Default values aligned with *list_params_ssm*. + response: + Data column names. Defaults to ``["rt", "response"]``. + """ + if name in _SSM_REGISTRY: + _logger.warning( + "SSM '%s' is already in the SSM registry and will be overwritten.", name + ) + _SSM_REGISTRY[name] = { + "ssm_base_logp_func": ssm_base_logp_func, + "list_params_ssm": list_params_ssm, + "bounds_ssm": bounds_ssm, + "params_default_ssm": params_default_ssm, + "response": response or ["rt", "response"], + } From e4fb4acae2fe26fe01d6486bc66a7f5ab9aff77c Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Wed, 6 May 2026 14:51:17 -0400 Subject: [PATCH 03/41] Update rlssm module --- src/hssm/rl/rlssm.py | 220 +++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 212 insertions(+), 8 deletions(-) diff --git a/src/hssm/rl/rlssm.py b/src/hssm/rl/rlssm.py index 180cd12c..8a252deb 100644 --- a/src/hssm/rl/rlssm.py +++ b/src/hssm/rl/rlssm.py @@ -1,14 +1,19 @@ """RLSSM: Reinforcement Learning Sequential Sampling Model. -This module defines the :class:`RLSSM` class, a subclass of :class:`HSSMBase` -for models that couple a reinforcement learning (RL) learning process with a -sequential sampling decision model (SSM). +This module defines: -The key difference from :class:`HSSM` is the likelihood: +- :class:`_RLSSM` — the internal base class (previously ``RLSSM``) that requires + a fully populated :class:`~hssm.rl.config.RLSSMConfig` to be passed directly. +- :class:`RLSSM` — the public-facing subclass with a simplified constructor that + accepts a *model* name string, optional *learning_process* / *decision_process* + overrides, and an optional *model_config* escape hatch. Config construction is + delegated to :func:`~hssm.rl.registry.get_rlssm_model_config`. + +The key difference from :class:`~hssm.hssm.HSSM` is the likelihood: - ``HSSM`` wraps an analytical / ONNX / blackbox callable via :func:`~hssm.distribution_utils.make_likelihood_callable`. - - ``RLSSM`` builds a differentiable pytensor ``Op`` directly from an - :class:`~hssm.rl.likelihoods.builder.AnnotatedFunction` via + - ``_RLSSM`` / ``RLSSM`` build a differentiable pytensor ``Op`` directly from + an :class:`~hssm.rl.likelihoods.builder.AnnotatedFunction` via :func:`~hssm.rl.likelihoods.builder.make_rl_logp_op`, which internally handles the RL learning rule and per-participant trial structure. This Op is then passed straight to @@ -16,6 +21,7 @@ standard ``loglik`` / ``loglik_kind`` wrapping pipeline. """ +import logging from dataclasses import replace from typing import TYPE_CHECKING, Any, Callable, Literal, cast @@ -36,10 +42,17 @@ from ..base import HSSMBase from .config import RLSSMConfig +from .registry import get_rlssm_model_config + +_logger = logging.getLogger("hssm") + +class _RLSSM(HSSMBase): + """Internal Reinforcement Learning Sequential Sampling Model. -class RLSSM(HSSMBase): - """Reinforcement Learning Sequential Sampling Model. + Requires a fully populated :class:`RLSSMConfig` (with ``ssm_logp_func`` set) + to be passed directly. End users should use :class:`RLSSM` instead, which + provides a simplified interface backed by the named-model registry. Combines a reinforcement learning (RL) process with a sequential sampling model (SSM) inside a single differentiable likelihood. The RL component @@ -262,3 +275,194 @@ def _make_model_distribution(self) -> type[pm.Distribution]: extra_fields=extra_fields_data, params_is_trialwise=params_is_trialwise, ) + + +# --------------------------------------------------------------------------- +# Blocked-attribute descriptor +# --------------------------------------------------------------------------- + + +class _BlockedAttribute: + """Data descriptor that blocks read access with NotImplementedError. + + During initialisation, writes are stored and reads return the stored + value so that :meth:`MissingDataMixin._process_missing_data_and_deadline` + (which both writes and reads ``self.missing_data`` / ``self.deadline`` / + ``self.loglik_missing_data``) can complete without error. + + Once ``instance.__dict__['_rlssm_fully_initialized']`` is set to ``True`` + at the end of :meth:`RLSSM.__init__`, any read raises + :exc:`NotImplementedError`. + + Using a data descriptor (one with both ``__get__`` and ``__set__``) is + necessary because data descriptors take priority over instance ``__dict__`` + entries, so the descriptor's ``__get__`` fires even after a write. + """ + + def __init__(self, name: str, message: str) -> None: + self._name = name + self._message = message + self._storage_key = f"_ba_{name}" + + def __set_name__(self, owner: type, name: str) -> None: # noqa: D105 + self._name = name + self._storage_key = f"_ba_{name}" + + def __get__(self, obj: Any, objtype: type | None = None) -> Any: # noqa: D105 + if obj is None: + # Class-level access — return the descriptor itself. + return self + if not obj.__dict__.get("_rlssm_fully_initialized", False): + # During __init__: return the stored value so internal code works. + return obj.__dict__.get(self._storage_key, False) + raise NotImplementedError(self._message) + + def __set__(self, obj: Any, value: Any) -> None: # noqa: D105 + # Store the value so internal reads during __init__ work correctly. + obj.__dict__[self._storage_key] = value + + +# --------------------------------------------------------------------------- +# Public wrapper +# --------------------------------------------------------------------------- + + +class RLSSM(_RLSSM): + """Reinforcement Learning Sequential Sampling Model — simplified public API. + + This class wraps :class:`_RLSSM` with a user-friendly constructor that + accepts a *model* name string (looked up in the named-model registry) and + optional overrides for *learning_process*, *decision_process*, and + *choices*. Advanced users can bypass the registry entirely by supplying a + pre-built *model_config*. + + ``missing_data``, ``deadline``, and ``loglik_missing_data`` are not + supported for RLSSM models and raise :exc:`NotImplementedError` if accessed. + + Parameters + ---------- + data : pd.DataFrame + Trial-level data (balanced panel required). + model : str, optional + Name of a registered RLSSM model. Defaults to ``"rldm"``. + choices : list[int] | None, optional + Override the choice values in the registry. ``None`` uses the registry + default. + include : list | None, optional + Parameter specifications forwarded to :class:`~hssm.base.HSSMBase`. + model_config : RLSSMConfig | None, optional + Fully built config (escape hatch). When provided, *model*, + *learning_process*, *decision_process*, and *choices* are ignored + (a warning is emitted if they are non-default). + learning_process : dict | None, optional + Override the learning-process dict in the registry. ``None`` uses the + registry default. + decision_process : str | None, optional + Override the SSM name in the registry. ``None`` uses the registry + default. + participant_col : str, optional + Column identifying participants. Defaults to ``"participant_id"``. + p_outlier : float | dict | bmb.Prior | None, optional + Lapse probability. Defaults to ``0.05``. + lapse : dict | bmb.Prior | None, optional + Lapse distribution. Defaults to ``Uniform(0, 20)``. + link_settings : Literal["log_logit"] | None, optional + Link-function preset. Defaults to ``None``. + prior_settings : Literal["safe"] | None, optional + Prior preset. Defaults to ``"safe"``. + extra_namespace : dict | None, optional + Extra variables for formula evaluation. Defaults to ``None``. + process_initvals : bool, optional + Whether to post-process initial values. Defaults to ``True``. + initval_jitter : float, optional + Jitter magnitude for initial values. + **kwargs + Additional keyword arguments forwarded to :class:`bmb.Model`. + """ + + # Block read access to the three missing-data attributes while silently + # accepting any writes made by the base-class initialisation path. + missing_data = _BlockedAttribute( # type: ignore[assignment] + "missing_data", + "RLSSM does not support 'missing_data'. " + "The RL log-likelihood Op relies on strict row order; rearranging rows " + "for missing RT values would corrupt the RL learning dynamics. " + "Please remove missing trials from the data before passing it to RLSSM.", + ) + deadline = _BlockedAttribute( # type: ignore[assignment] + "deadline", + "RLSSM does not support 'deadline'. " + "The RL log-likelihood Op relies on strict row order; rearranging rows " + "for deadline trials would corrupt the RL learning dynamics. " + "Please remove deadline trials from the data before passing it to RLSSM.", + ) + loglik_missing_data = _BlockedAttribute( # type: ignore[assignment] + "loglik_missing_data", + "RLSSM does not support 'loglik_missing_data'. " + "Missing-data network assembly (OPN / CPN) is not implemented for RLSSM.", + ) + + def __init__( + self, + data: pd.DataFrame, + model: str = "rldm", + choices: list[int] | None = None, + include: list[dict[str, Any] | Any] | None = None, + model_config: RLSSMConfig | None = None, + learning_process: dict[str, Any] | None = None, + decision_process: str | None = None, + participant_col: str = "participant_id", + p_outlier: float | dict | bmb.Prior | None = 0.05, + lapse: dict | bmb.Prior | None = bmb.Prior("Uniform", lower=0.0, upper=20.0), + link_settings: Literal["log_logit"] | None = None, + prior_settings: Literal["safe"] | None = "safe", + extra_namespace: dict[str, Any] | None = None, + process_initvals: bool = True, + initval_jitter: float = INITVAL_JITTER_SETTINGS["jitter_epsilon"], + **kwargs: Any, + ) -> None: + # Capture simplified args BEFORE calling super so they can be + # restored afterwards for save/load serialisation. + _my_init_args = self._store_init_args(locals(), kwargs) + + if model_config is not None: + # Escape-hatch path: caller supplied a fully built config. + if any( + x is not None for x in [learning_process, decision_process, choices] + ): + _logger.warning( + "model_config was provided; ignoring model, learning_process, " + "decision_process, and choices arguments." + ) + else: + model_config = get_rlssm_model_config( + model=model, + choices=choices, + learning_process=learning_process, + decision_process=decision_process, + ) + + super().__init__( + data=data, + model_config=model_config, + participant_col=participant_col, + include=include, + p_outlier=p_outlier, + lapse=lapse, + link_settings=link_settings, + prior_settings=prior_settings, + extra_namespace=extra_namespace, + process_initvals=process_initvals, + initval_jitter=initval_jitter, + **kwargs, + ) + + # Restore the simplified constructor args so that save/load round-trips + # reconstruct the model via RLSSM(model=...) rather than + # _RLSSM(model_config=...). + self._init_args = _my_init_args + + # Mark initialisation complete — after this point _BlockedAttribute + # raises NotImplementedError on any read of missing_data / deadline / + # loglik_missing_data. + self.__dict__["_rlssm_fully_initialized"] = True From 9dc18ac802d0c6c1f102f90890ed6be019dd8867 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Wed, 6 May 2026 14:51:24 -0400 Subject: [PATCH 04/41] Add tests for simplified RLSSM interface and model registration --- tests/test_rlssm.py | 94 ++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 93 insertions(+), 1 deletion(-) diff --git a/tests/test_rlssm.py b/tests/test_rlssm.py index 973061c7..1c8f3b72 100644 --- a/tests/test_rlssm.py +++ b/tests/test_rlssm.py @@ -15,7 +15,7 @@ import pytest import hssm -from hssm.rl import RLSSM, RLSSMConfig +from hssm.rl import RLSSM, RLSSMConfig, _RLSSM, register_rlssm_model from hssm.rl.likelihoods.two_armed_bandit import compute_v_subject_wise from hssm.utils import annotate_function @@ -327,3 +327,95 @@ def test_rlssm_pickle_round_trip( assert restored.list_params == model.list_params assert restored.model_config.model_name == model.model_config.model_name assert restored.model is not None + + +# --------------------------------------------------------------------------- +# Simplified public interface tests (new RLSSM wrapper) +# --------------------------------------------------------------------------- + + +def test_rlssm_is_subclass_of_internal() -> None: + """RLSSM must be a subclass of _RLSSM.""" + assert issubclass(RLSSM, _RLSSM) + + +def test_rlssm_simplified_init(rldm_data) -> None: + """RLSSM(data, model='rldm') should build without supplying model_config.""" + model = RLSSM(data=rldm_data, model="rldm") + assert isinstance(model, RLSSM) + # Registry-derived params: rl_alpha, scaler (RL) + a, z, t, theta (SSM) + assert "rl_alpha" in model.params + assert "scaler" in model.params + assert "a" in model.params + # v is computed inside the Op; must NOT appear as a free parameter + assert "v" not in model.params + + +def test_rlssm_default_model_is_rldm(rldm_data) -> None: + """Omitting model should default to 'rldm'.""" + model = RLSSM(data=rldm_data) + assert isinstance(model, RLSSM) + assert model.model_config.decision_process == "angle" + + +def test_rlssm_model_config_escape_hatch(rldm_data, rlssm_config) -> None: + """Passing model_config= directly should bypass the registry.""" + model = RLSSM(data=rldm_data, model_config=rlssm_config) + assert model.model_config.model_name == rlssm_config.model_name + + +def test_rlssm_unregistered_model_raises(rldm_data) -> None: + """Using an unknown model name should raise ValueError.""" + with pytest.raises(ValueError, match="not found in the RLSSM registry"): + RLSSM(data=rldm_data, model="model_not_in_registry") + + +def test_rlssm_missing_data_property_raises(rldm_data) -> None: + """Accessing .missing_data on a built RLSSM instance must raise NotImplementedError.""" + model = RLSSM(data=rldm_data, model="rldm") + with pytest.raises(NotImplementedError, match="missing_data"): + _ = model.missing_data + + +def test_rlssm_deadline_property_raises(rldm_data) -> None: + """Accessing .deadline on a built RLSSM instance must raise NotImplementedError.""" + model = RLSSM(data=rldm_data, model="rldm") + with pytest.raises(NotImplementedError, match="deadline"): + _ = model.deadline + + +def test_rlssm_loglik_missing_data_property_raises(rldm_data) -> None: + """Accessing .loglik_missing_data on a built RLSSM instance must raise NotImplementedError.""" + model = RLSSM(data=rldm_data, model="rldm") + with pytest.raises(NotImplementedError, match="loglik_missing_data"): + _ = model.loglik_missing_data + + +def test_register_rlssm_model(rldm_data) -> None: + """A user-registered model should be instantiable via the simplified interface.""" + # Re-use the existing annotated learning function and ssm logp from the + # module-level helpers defined at the top of this test file. + register_rlssm_model( + name="rldm_custom_test", + decision_process="angle", + learning_process={"v": _compute_v_annotated}, + rl_params=["rl_alpha", "scaler"], + rl_bounds={"rl_alpha": (0.0, 1.0), "scaler": (0.0, 10.0)}, + rl_params_default=[0.1, 1.0], + extra_fields=["feedback"], + choices=[0, 1], + ) + model = RLSSM(data=rldm_data, model="rldm_custom_test") + assert isinstance(model, RLSSM) + assert "rl_alpha" in model.params + assert "v" not in model.params + + +def test_rlssm_init_args_uses_simplified_interface(rldm_data) -> None: + """_init_args should reflect the simplified constructor, not model_config.""" + model = RLSSM(data=rldm_data, model="rldm") + assert "model" in model._init_args + assert model._init_args["model"] == "rldm" + # model_config should not be baked in as a hard reference + assert "model_config" in model._init_args + assert model._init_args["model_config"] is None From 413fe367d25006941864119a000b959b40a393d4 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Wed, 6 May 2026 15:01:52 -0400 Subject: [PATCH 05/41] Fix condition in RLSSM class to handle model configuration correctly --- src/hssm/rl/rlssm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/hssm/rl/rlssm.py b/src/hssm/rl/rlssm.py index 8a252deb..6aa0a531 100644 --- a/src/hssm/rl/rlssm.py +++ b/src/hssm/rl/rlssm.py @@ -427,7 +427,7 @@ def __init__( if model_config is not None: # Escape-hatch path: caller supplied a fully built config. - if any( + if model != "rldm" or any( x is not None for x in [learning_process, decision_process, choices] ): _logger.warning( From 19a39aa42fffe900c6b4c784ffb33753bde6a979 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Wed, 6 May 2026 15:04:23 -0400 Subject: [PATCH 06/41] Implement lazy loading for SSM base log-likelihood functions and add caching mechanism --- src/hssm/rl/registry.py | 32 ++++++++++++++++++++++++++++++-- 1 file changed, 30 insertions(+), 2 deletions(-) diff --git a/src/hssm/rl/registry.py b/src/hssm/rl/registry.py index 9dd46834..8d7f486f 100644 --- a/src/hssm/rl/registry.py +++ b/src/hssm/rl/registry.py @@ -56,7 +56,10 @@ def _make_angle_base_logp() -> Any: _SSM_REGISTRY: dict[str, dict[str, Any]] = { "angle": { - "ssm_base_logp_func": _make_angle_base_logp(), + # Factory callable — invoked lazily on first use via _get_ssm_logp(). + # Storing a callable (not the result) avoids loading angle.onnx at + # import time (which would trigger hf_hub_download in offline envs). + "ssm_base_logp_func_factory": _make_angle_base_logp, # All SSM params in the order the SSM expects them (includes computed). "list_params_ssm": ["v", "a", "z", "t", "theta"], # Bounds only for params that will be *sampled* (not RL-computed). @@ -72,6 +75,29 @@ def _make_angle_base_logp() -> Any: }, } +# Cache for resolved SSM base logp functions (populated on first use by +# _get_ssm_logp, or immediately by register_ssm when a pre-built func is +# supplied by the caller). +_SSM_LOGP_CACHE: dict[str, Any] = {} + + +def _get_ssm_logp(name: str) -> Any: + """Return the annotated SSM base logp function, building it on first use. + + For built-in SSMs the ONNX model is downloaded / loaded only when this + function is first called (lazy initialisation). Subsequent calls return + the cached object without any I/O. + """ + if name not in _SSM_LOGP_CACHE: + entry = _SSM_REGISTRY[name] + if "ssm_base_logp_func_factory" in entry: + _SSM_LOGP_CACHE[name] = entry["ssm_base_logp_func_factory"]() + else: + # Pre-built function registered via register_ssm(). + _SSM_LOGP_CACHE[name] = entry["ssm_base_logp_func"] + return _SSM_LOGP_CACHE[name] + + # --------------------------------------------------------------------------- # RLSSM named model registry # --------------------------------------------------------------------------- @@ -213,7 +239,7 @@ def get_rlssm_model_config( ) ssm_entry = _SSM_REGISTRY[dp] - ssm_base = ssm_entry["ssm_base_logp_func"] + ssm_base = _get_ssm_logp(dp) lp: dict[str, Any] = entry["learning_process"] # Compose the full ssm_logp_func with .computed = learning_process. @@ -365,3 +391,5 @@ def register_ssm( "params_default_ssm": params_default_ssm, "response": response or ["rt", "response"], } + # Pre-built: cache immediately so _get_ssm_logp never calls a factory. + _SSM_LOGP_CACHE[name] = ssm_base_logp_func From 301cd135622e98bae113e6091306518665b5b169 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Wed, 6 May 2026 15:54:47 -0400 Subject: [PATCH 07/41] Enhance error handling and validation in RLSSM and SSM registry - Added KeyError for missing SSM in the registry. - Improved ValueError messages for non-callable logp functions. - Implemented runtime checks for list_params and loglik in RLSSM. - Ensured defensive copying of mutable parameters in registry functions. --- src/hssm/rl/registry.py | 52 ++++++++++++++++++++++++++++++++++------- src/hssm/rl/rlssm.py | 30 +++++++++++++++++------- 2 files changed, 65 insertions(+), 17 deletions(-) diff --git a/src/hssm/rl/registry.py b/src/hssm/rl/registry.py index 8d7f486f..5c437dff 100644 --- a/src/hssm/rl/registry.py +++ b/src/hssm/rl/registry.py @@ -89,6 +89,12 @@ def _get_ssm_logp(name: str) -> Any: the cached object without any I/O. """ if name not in _SSM_LOGP_CACHE: + if name not in _SSM_REGISTRY: + raise KeyError( + f"SSM '{name}' not found in the SSM registry. " + f"Available: {list(_SSM_REGISTRY.keys())}. " + "Use register_ssm() to add it." + ) entry = _SSM_REGISTRY[name] if "ssm_base_logp_func_factory" in entry: _SSM_LOGP_CACHE[name] = entry["ssm_base_logp_func_factory"]() @@ -148,6 +154,12 @@ def _build_ssm_logp_func(ssm_base_logp_func: Any, learning_process: dict) -> Any :func:`~hssm.rl.likelihoods.builder.make_rl_logp_op` can resolve which parameters are produced by the RL learning rule at runtime. """ + existing = getattr(ssm_base_logp_func, "computed", None) + if existing: + raise ValueError( + "ssm_base_logp_func already has a non-empty .computed attribute. " + "Pass the raw (base) annotated function without .computed instead." + ) return annotate_function( inputs=ssm_base_logp_func.inputs, outputs=ssm_base_logp_func.outputs, @@ -168,8 +180,14 @@ def _derive_rl_params( exclude = set(response) | set(extra_fields) rl_params: list[str] = [] seen: set[str] = set() - for lp_func in learning_process.values(): + for param_name, lp_func in learning_process.items(): if not hasattr(lp_func, "inputs"): + _logger.warning( + "Learning process function for '%s' has no .inputs attribute; " + "cannot derive RL parameters from it. " + "Ensure it is decorated with @annotate_function.", + param_name, + ) continue for inp in lp_func.inputs: if inp not in exclude and inp not in seen: @@ -249,8 +267,11 @@ def get_rlssm_model_config( computed_set = set(lp.keys()) ssm_sampled = [p for p in ssm_entry["list_params_ssm"] if p not in computed_set] + # Defensive copy of response to prevent downstream mutation of registry. + response = list(ssm_entry["response"]) + rl_params: list[str] = entry.get("rl_params") or _derive_rl_params( - lp, ssm_entry["response"], entry.get("extra_fields") or [] + lp, response, entry.get("extra_fields") or [] ) list_params = rl_params + ssm_sampled @@ -281,7 +302,7 @@ def get_rlssm_model_config( list_params=list_params, bounds=bounds, params_default=params_default, - response=ssm_entry["response"], + response=response, choices=entry["choices"], extra_fields=entry.get("extra_fields"), ) @@ -339,12 +360,14 @@ def register_rlssm_model( ) _RLSSM_REGISTRY[name] = { "decision_process": decision_process, - "learning_process": learning_process, - "rl_params": rl_params, - "rl_bounds": rl_bounds, - "rl_params_default": rl_params_default, - "extra_fields": extra_fields or [], - "choices": choices if choices is not None else [0, 1], + # Shallow-copy all mutable caller-supplied collections so that later + # mutations of the originals do not silently corrupt the registry entry. + "learning_process": dict(learning_process), + "rl_params": list(rl_params), + "rl_bounds": dict(rl_bounds), + "rl_params_default": list(rl_params_default), + "extra_fields": list(extra_fields) if extra_fields is not None else [], + "choices": list(choices) if choices is not None else [0, 1], "description": description, "decision_process_loglik_kind": decision_process_loglik_kind, "learning_process_kind": learning_process_kind, @@ -380,6 +403,17 @@ def register_ssm( response: Data column names. Defaults to ``["rt", "response"]``. """ + if not callable(ssm_base_logp_func): + raise ValueError( + f"ssm_base_logp_func must be callable, got {type(ssm_base_logp_func)!r}." + ) + if not hasattr(ssm_base_logp_func, "inputs") or not hasattr( + ssm_base_logp_func, "outputs" + ): + raise ValueError( + "ssm_base_logp_func must be decorated with @annotate_function " + "(missing .inputs or .outputs attribute)." + ) if name in _SSM_REGISTRY: _logger.warning( "SSM '%s' is already in the SSM registry and will be overwritten.", name diff --git a/src/hssm/rl/rlssm.py b/src/hssm/rl/rlssm.py index 6aa0a531..1f318943 100644 --- a/src/hssm/rl/rlssm.py +++ b/src/hssm/rl/rlssm.py @@ -171,7 +171,6 @@ def __init__( # Store RL-specific state on self BEFORE super().__init__() so that # _make_model_distribution() (called from super) can access them. - self.config = model_config self.n_participants = n_participants self.n_trials = n_trials @@ -202,6 +201,8 @@ def __init__( # _make_model_distribution for details. model_config = replace(model_config, loglik=loglik_op, backend="jax") + # missing_data and deadline are guaranteed False at this point (guards + # above reject any other value). Pass them explicitly for clarity. super().__init__( data=data, model_config=model_config, @@ -211,8 +212,8 @@ def __init__( link_settings=link_settings, prior_settings=prior_settings, extra_namespace=extra_namespace, - missing_data=missing_data, - deadline=deadline, + missing_data=False, + deadline=False, loglik_missing_data=loglik_missing_data, process_initvals=process_initvals, initval_jitter=initval_jitter, @@ -241,10 +242,12 @@ def _make_model_distribution(self) -> type[pm.Distribution]: # has_lapse=True) rather than self.model_config.list_params (the original # config list, never mutated by HSSMBase). list_params = self.list_params - assert list_params is not None, "list_params must be set" - assert isinstance(list_params, list), ( - "list_params must be a list" - ) # for type checker + if list_params is None: + raise RuntimeError( + "list_params must be set before _make_model_distribution is called." + ) + if not isinstance(list_params, list): + raise TypeError(f"list_params must be a list, got {type(list_params)!r}.") # Every RLSSM distribution parameter is trialwise (the Op receives one # value per trial). p_outlier is excluded to match the contract of @@ -260,7 +263,11 @@ def _make_model_distribution(self) -> type[pm.Distribution]: # The differentiable pytensor Op was stored on model_config.loglik during # __init__; ensure it's present and cast for typing. - assert self.model_config.loglik is not None, "model_config.loglik must be set" + if self.model_config.loglik is None: + raise RuntimeError( + "model_config.loglik must be set before _make_model_distribution " + "is called. This indicates the Op was not built in __init__." + ) loglik_op = cast("Callable[..., Any] | Op", self.model_config.loglik) # RLSSMConfig carries no `rv` field; use model_name as the rv identifier. @@ -423,6 +430,7 @@ def __init__( ) -> None: # Capture simplified args BEFORE calling super so they can be # restored afterwards for save/load serialisation. + # NOTE: _store_init_args only operates on its arguments, not on self. _my_init_args = self._store_init_args(locals(), kwargs) if model_config is not None: @@ -442,6 +450,10 @@ def __init__( decision_process=decision_process, ) + # missing_data / deadline are intentionally omitted — _RLSSM defaults + # them to False. The _BlockedAttribute descriptors on this class + # silently accept writes from the base-class init path and block reads + # after _rlssm_fully_initialized is set below. super().__init__( data=data, model_config=model_config, @@ -465,4 +477,6 @@ def __init__( # Mark initialisation complete — after this point _BlockedAttribute # raises NotImplementedError on any read of missing_data / deadline / # loglik_missing_data. + # IMPORTANT: This MUST be the last statement in __init__. Any code + # added after this line cannot read self.missing_data / self.deadline. self.__dict__["_rlssm_fully_initialized"] = True From b23c8c26cc120b4a202553684206ab86ca329d31 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Wed, 6 May 2026 15:56:02 -0400 Subject: [PATCH 08/41] Refactor RLSSM model configuration to respect explicit empty containers for RL parameters, bounds, and defaults --- src/hssm/rl/registry.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/src/hssm/rl/registry.py b/src/hssm/rl/registry.py index 5c437dff..143ca7f1 100644 --- a/src/hssm/rl/registry.py +++ b/src/hssm/rl/registry.py @@ -270,19 +270,29 @@ def get_rlssm_model_config( # Defensive copy of response to prevent downstream mutation of registry. response = list(ssm_entry["response"]) - rl_params: list[str] = entry.get("rl_params") or _derive_rl_params( - lp, response, entry.get("extra_fields") or [] + # Use `is None` checks so that explicitly empty containers ([], {}) are + # respected as valid "no RL params" configuration and not overridden by + # the fallback derivation logic. + _rl_params = entry.get("rl_params") + rl_params: list[str] = ( + _derive_rl_params(lp, response, entry.get("extra_fields") or []) + if _rl_params is None + else _rl_params ) list_params = rl_params + ssm_sampled # bounds: RL bounds ∪ SSM sampled bounds - bounds: dict[str, tuple[float, float]] = dict(entry.get("rl_bounds") or {}) + _rl_bounds = entry.get("rl_bounds") + bounds: dict[str, tuple[float, float]] = dict( + _rl_bounds if _rl_bounds is not None else {} + ) for p in ssm_sampled: if p in ssm_entry["bounds_ssm"]: bounds[p] = ssm_entry["bounds_ssm"][p] # params_default aligned with list_params - rl_defaults: list[float] = list(entry.get("rl_params_default") or []) + _rl_defaults = entry.get("rl_params_default") + rl_defaults: list[float] = list(_rl_defaults if _rl_defaults is not None else []) ssm_all_defaults: list[float] = list(ssm_entry["params_default_ssm"]) ssm_sampled_defaults = [ ssm_all_defaults[i] From 78bd5c580cebd9ea46367fba247d832215dc87fb Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Wed, 6 May 2026 16:00:16 -0400 Subject: [PATCH 09/41] Add validation to prevent registration of SSM functions with non-empty .computed attribute --- src/hssm/rl/registry.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/hssm/rl/registry.py b/src/hssm/rl/registry.py index 143ca7f1..62058006 100644 --- a/src/hssm/rl/registry.py +++ b/src/hssm/rl/registry.py @@ -424,6 +424,14 @@ def register_ssm( "ssm_base_logp_func must be decorated with @annotate_function " "(missing .inputs or .outputs attribute)." ) + existing_computed = getattr(ssm_base_logp_func, "computed", None) + if existing_computed: + raise ValueError( + "ssm_base_logp_func should not have a non-empty .computed attribute " + "at registration time. The .computed dict is injected later by " + "get_rlssm_model_config() when composing the learning process. " + "Pass the raw base function instead." + ) if name in _SSM_REGISTRY: _logger.warning( "SSM '%s' is already in the SSM registry and will be overwritten.", name From 00638ca755ab5b73c1829fedc534c129ff7fae1c Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Thu, 7 May 2026 11:31:28 -0400 Subject: [PATCH 10/41] Refactor SSM registry to convert parameters to appropriate types for consistency --- src/hssm/rl/registry.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/hssm/rl/registry.py b/src/hssm/rl/registry.py index 62058006..33e34f5f 100644 --- a/src/hssm/rl/registry.py +++ b/src/hssm/rl/registry.py @@ -438,10 +438,10 @@ def register_ssm( ) _SSM_REGISTRY[name] = { "ssm_base_logp_func": ssm_base_logp_func, - "list_params_ssm": list_params_ssm, - "bounds_ssm": bounds_ssm, - "params_default_ssm": params_default_ssm, - "response": response or ["rt", "response"], + "list_params_ssm": list(list_params_ssm), + "bounds_ssm": dict(bounds_ssm), + "params_default_ssm": list(params_default_ssm), + "response": list(response) if response is not None else ["rt", "response"], } # Pre-built: cache immediately so _get_ssm_logp never calls a factory. _SSM_LOGP_CACHE[name] = ssm_base_logp_func From 101444efb4d15820f013c16077691350b8ca4965 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Thu, 7 May 2026 11:39:12 -0400 Subject: [PATCH 11/41] Update error message in get_rlssm_model_config to clarify custom model registration process --- src/hssm/rl/registry.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/hssm/rl/registry.py b/src/hssm/rl/registry.py index 33e34f5f..812a9ecb 100644 --- a/src/hssm/rl/registry.py +++ b/src/hssm/rl/registry.py @@ -235,7 +235,9 @@ def get_rlssm_model_config( raise ValueError( f"Model '{model}' not found in the RLSSM registry. " f"Available models: {available}. " - "To use a custom model, pass 'model_config=' directly." + "To add a custom model, use register_rlssm_model() " + "(and register_ssm() for custom decision processes), " + "or pass 'model_config=' directly to RLSSM()." ) # Shallow-copy so overrides don't mutate the registry entry. From 91b2572b286306b6748e96a28db94e734e861303 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Thu, 7 May 2026 12:32:10 -0400 Subject: [PATCH 12/41] Use more transparent terminology --- src/hssm/rl/rlssm.py | 5 ++--- tests/test_rlssm.py | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/hssm/rl/rlssm.py b/src/hssm/rl/rlssm.py index 1f318943..3a0c9ff8 100644 --- a/src/hssm/rl/rlssm.py +++ b/src/hssm/rl/rlssm.py @@ -6,7 +6,7 @@ a fully populated :class:`~hssm.rl.config.RLSSMConfig` to be passed directly. - :class:`RLSSM` — the public-facing subclass with a simplified constructor that accepts a *model* name string, optional *learning_process* / *decision_process* - overrides, and an optional *model_config* escape hatch. Config construction is + overrides, and an optional *model_config* override. Config construction is delegated to :func:`~hssm.rl.registry.get_rlssm_model_config`. The key difference from :class:`~hssm.hssm.HSSM` is the likelihood: @@ -358,7 +358,7 @@ class RLSSM(_RLSSM): include : list | None, optional Parameter specifications forwarded to :class:`~hssm.base.HSSMBase`. model_config : RLSSMConfig | None, optional - Fully built config (escape hatch). When provided, *model*, + Fully built config. When provided, *model*, *learning_process*, *decision_process*, and *choices* are ignored (a warning is emitted if they are non-default). learning_process : dict | None, optional @@ -434,7 +434,6 @@ def __init__( _my_init_args = self._store_init_args(locals(), kwargs) if model_config is not None: - # Escape-hatch path: caller supplied a fully built config. if model != "rldm" or any( x is not None for x in [learning_process, decision_process, choices] ): diff --git a/tests/test_rlssm.py b/tests/test_rlssm.py index 1c8f3b72..afd37365 100644 --- a/tests/test_rlssm.py +++ b/tests/test_rlssm.py @@ -358,7 +358,7 @@ def test_rlssm_default_model_is_rldm(rldm_data) -> None: assert model.model_config.decision_process == "angle" -def test_rlssm_model_config_escape_hatch(rldm_data, rlssm_config) -> None: +def test_rlssm_model_config_provided(rldm_data, rlssm_config) -> None: """Passing model_config= directly should bypass the registry.""" model = RLSSM(data=rldm_data, model_config=rlssm_config) assert model.model_config.model_name == rlssm_config.model_name From c82db43174ccf77424e8281e3c63ff5e8eec6e92 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Fri, 8 May 2026 11:52:51 -0400 Subject: [PATCH 13/41] Implement defensive copying for decision process specifications in RLSSM model registration --- src/hssm/rl/registry.py | 37 +++++++++++++++++++++++++------------ 1 file changed, 25 insertions(+), 12 deletions(-) diff --git a/src/hssm/rl/registry.py b/src/hssm/rl/registry.py index 812a9ecb..c2532162 100644 --- a/src/hssm/rl/registry.py +++ b/src/hssm/rl/registry.py @@ -15,6 +15,7 @@ from __future__ import annotations import logging +from copy import deepcopy from typing import Any from hssm.distribution_utils.onnx import make_jax_matrix_logp_funcs_from_onnx @@ -81,6 +82,25 @@ def _make_angle_base_logp() -> Any: _SSM_LOGP_CACHE: dict[str, Any] = {} +def _get_decision_process_spec( + decision_process: str | dict[str, Any], +) -> dict[str, Any]: + """Return a defensive copy of a registered decision-process specification.""" + if isinstance(decision_process, dict): + return deepcopy(decision_process) + + if decision_process not in _SSM_REGISTRY: + available_ssms = list(_SSM_REGISTRY.keys()) + raise ValueError( + f"Decision process '{decision_process}' not found in the SSM registry. " + f"Available: {available_ssms}. Use register_ssm() to add it." + ) + + spec = deepcopy(_SSM_REGISTRY[decision_process]) + spec["name"] = decision_process + return spec + + def _get_ssm_logp(name: str) -> Any: """Return the annotated SSM base logp function, building it on first use. @@ -121,7 +141,7 @@ def _get_ssm_logp(name: str) -> Any: _RLSSM_REGISTRY: dict[str, dict[str, Any]] = { "rldm": { - "decision_process": "angle", + "decision_process": _get_decision_process_spec("angle"), "learning_process": {"v": _compute_v_annotated}, "rl_params": ["rl_alpha", "scaler"], "rl_bounds": { @@ -246,19 +266,12 @@ def get_rlssm_model_config( if learning_process is not None: entry["learning_process"] = learning_process if decision_process is not None: - entry["decision_process"] = decision_process + entry["decision_process"] = _get_decision_process_spec(decision_process) if choices is not None: entry["choices"] = choices - dp: str = entry["decision_process"] - if dp not in _SSM_REGISTRY: - available_ssms = list(_SSM_REGISTRY.keys()) - raise ValueError( - f"Decision process '{dp}' not found in the SSM registry. " - f"Available: {available_ssms}. Use register_ssm() to add it." - ) - - ssm_entry = _SSM_REGISTRY[dp] + ssm_entry = _get_decision_process_spec(entry["decision_process"]) + dp: str = ssm_entry["name"] ssm_base = _get_ssm_logp(dp) lp: dict[str, Any] = entry["learning_process"] @@ -371,7 +384,7 @@ def register_rlssm_model( name, ) _RLSSM_REGISTRY[name] = { - "decision_process": decision_process, + "decision_process": _get_decision_process_spec(decision_process), # Shallow-copy all mutable caller-supplied collections so that later # mutations of the originals do not silently corrupt the registry entry. "learning_process": dict(learning_process), From 9310e885e601b061261dd902dec72874859d2d12 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Fri, 8 May 2026 13:56:16 -0400 Subject: [PATCH 14/41] Refactor SSM registry documentation and enhance SSM base log-likelihood function handling --- src/hssm/rl/registry.py | 170 +++++++++++++++++++++++++++------------- 1 file changed, 114 insertions(+), 56 deletions(-) diff --git a/src/hssm/rl/registry.py b/src/hssm/rl/registry.py index c2532162..80fd7f1a 100644 --- a/src/hssm/rl/registry.py +++ b/src/hssm/rl/registry.py @@ -2,8 +2,12 @@ This module provides: -- :data:`_SSM_REGISTRY` — maps SSM names (e.g. ``"angle"``) to their base - annotated JAX log-likelihood functions and parameter metadata. +- :data:`_SSM_REGISTRY` — holds *custom* SSM entries added via + :func:`register_ssm`. Built-in HSSM models (``"ddm"``, ``"angle"``, + ``"weibull"``, and any other model in :mod:`hssm.modelconfig` that exposes an + ``approx_differentiable`` likelihood) are resolved automatically from + :func:`hssm.modelconfig.get_default_model_config` and do **not** need to be + pre-registered here. - :data:`_RLSSM_REGISTRY` — maps named RLSSM model strings (e.g. ``"rldm"``) to their default decision process, learning process, and parameter info. - :func:`get_rlssm_model_config` — builds a :class:`~hssm.rl.config.RLSSMConfig` @@ -38,89 +42,143 @@ # --------------------------------------------------------------------------- # SSM base log-likelihood registry # --------------------------------------------------------------------------- -# Each entry provides: +# This dict holds only *custom* SSM entries added at runtime via register_ssm(). +# Built-in HSSM models are resolved on demand from hssm.modelconfig — see +# _build_ssm_spec_from_modelconfig() and _get_decision_process_spec(). +# +# Each entry (custom or derived) provides: # ssm_base_logp_func - annotated JAX function (inputs + outputs, no computed) # list_params_ssm - ordered SSM parameter names (including computed ones) -# bounds_ssm - bounds for non-computed SSM params +# bounds_ssm - bounds for all SSM params # params_default_ssm - default values aligned with list_params_ssm # response - data column names +_SSM_REGISTRY: dict[str, dict[str, Any]] = {} -def _make_angle_base_logp() -> Any: - """Build the annotated angle SSM base logp function from its ONNX file.""" - _raw = make_jax_matrix_logp_funcs_from_onnx(model="angle.onnx") +# Cache for resolved SSM base logp functions (populated on first use by +# _get_ssm_logp, or immediately by register_ssm when a pre-built func is +# supplied by the caller). +_SSM_LOGP_CACHE: dict[str, Any] = {} + + +def _make_ssm_base_logp_from_onnx( + onnx_file: str, + list_params_ssm: list[str], + response: list[str], +) -> Any: + """Build and annotate a JAX log-likelihood function from an ONNX model file.""" + _raw = make_jax_matrix_logp_funcs_from_onnx(model=onnx_file) return annotate_function( - inputs=["v", "a", "z", "t", "theta", "rt", "response"], + inputs=list_params_ssm + response, outputs=["logp"], )(_raw) -_SSM_REGISTRY: dict[str, dict[str, Any]] = { - "angle": { - # Factory callable — invoked lazily on first use via _get_ssm_logp(). - # Storing a callable (not the result) avoids loading angle.onnx at - # import time (which would trigger hf_hub_download in offline envs). - "ssm_base_logp_func_factory": _make_angle_base_logp, - # All SSM params in the order the SSM expects them (includes computed). - "list_params_ssm": ["v", "a", "z", "t", "theta"], - # Bounds only for params that will be *sampled* (not RL-computed). - "bounds_ssm": { - "a": (0.3, 3.0), - "z": (0.1, 0.9), - "t": (0.001, 2.0), - "theta": (-0.1, 1.3), - }, - # Defaults aligned with list_params_ssm: v, a, z, t, theta - "params_default_ssm": [0.0, 1.5, 0.5, 0.5, 0.0], - "response": ["rt", "response"], - }, -} +def _build_ssm_spec_from_modelconfig(name: str) -> dict[str, Any]: + """Build an SSM registry-compatible spec from HSSM's modelconfig system. -# Cache for resolved SSM base logp functions (populated on first use by -# _get_ssm_logp, or immediately by register_ssm when a pre-built func is -# supplied by the caller). -_SSM_LOGP_CACHE: dict[str, Any] = {} + This allows any built-in HSSM model with an ``approx_differentiable`` + likelihood to be used as an RLSSM decision process without re-registering + it in ``_SSM_REGISTRY``. Parameter defaults are computed as midpoints of + the model's parameter bounds. + + Raises + ------ + ValueError + If *name* is not a supported HSSM model or it has no + ``approx_differentiable`` likelihood. + """ + # Local import to avoid circular dependencies at module level. + from hssm.modelconfig import get_default_model_config # noqa: PLC0415 + + try: + mc = get_default_model_config(name) # type: ignore[arg-type] + except ValueError as exc: + raise ValueError( + f"Decision process '{name}' is not a registered custom SSM and is not a " + "supported HSSM model. " + f"Custom SSMs in registry: {list(_SSM_REGISTRY.keys())}. " + "Use register_ssm() to add a custom decision process." + ) from exc + + ad = mc["likelihoods"].get("approx_differentiable") + if ad is None: + raise ValueError( + f"Model '{name}' has no approx_differentiable likelihood and cannot be " + "used as an RLSSM decision process." + ) + + list_params_ssm: list[str] = list(mc["list_params"]) + bounds_ssm: dict[str, tuple[float, float]] = dict(ad["bounds"]) + response: list[str] = list(mc["response"]) + onnx_file: str = str(ad["loglik"]) + + # Derive parameter defaults as midpoints of their respective bounds. + params_default_ssm = [ + (bounds_ssm[p][0] + bounds_ssm[p][1]) / 2.0 if p in bounds_ssm else 0.0 + for p in list_params_ssm + ] + + # Capture loop variables explicitly to avoid closure-over-variable issues. + def _factory( + _onnx_file: str = onnx_file, + _params: list[str] = list_params_ssm, + _response: list[str] = response, + ) -> Any: + return _make_ssm_base_logp_from_onnx(_onnx_file, _params, _response) + + return { + "ssm_base_logp_func_factory": _factory, + "list_params_ssm": list_params_ssm, + "bounds_ssm": bounds_ssm, + "params_default_ssm": params_default_ssm, + "response": response, + "name": name, + } def _get_decision_process_spec( decision_process: str | dict[str, Any], ) -> dict[str, Any]: - """Return a defensive copy of a registered decision-process specification.""" + """Return a defensive copy of a decision-process specification. + + Custom SSMs (registered via :func:`register_ssm`) take precedence. For + everything else the spec is derived on the fly from HSSM's modelconfig + system, meaning any built-in model with an ``approx_differentiable`` + likelihood (e.g. ``"ddm"``, ``"angle"``, ``"weibull"``) works out of the + box without explicit registration. + """ if isinstance(decision_process, dict): return deepcopy(decision_process) - if decision_process not in _SSM_REGISTRY: - available_ssms = list(_SSM_REGISTRY.keys()) - raise ValueError( - f"Decision process '{decision_process}' not found in the SSM registry. " - f"Available: {available_ssms}. Use register_ssm() to add it." - ) + # Custom registry takes precedence over modelconfig. + if decision_process in _SSM_REGISTRY: + spec = deepcopy(_SSM_REGISTRY[decision_process]) + spec["name"] = decision_process + return spec - spec = deepcopy(_SSM_REGISTRY[decision_process]) - spec["name"] = decision_process - return spec + # Fall back to HSSM's modelconfig for built-in SSMs. + return _build_ssm_spec_from_modelconfig(decision_process) def _get_ssm_logp(name: str) -> Any: """Return the annotated SSM base logp function, building it on first use. - For built-in SSMs the ONNX model is downloaded / loaded only when this - function is first called (lazy initialisation). Subsequent calls return - the cached object without any I/O. + ONNX models are downloaded / loaded only when first called (lazy + initialisation). Subsequent calls return the cached object. """ if name not in _SSM_LOGP_CACHE: - if name not in _SSM_REGISTRY: - raise KeyError( - f"SSM '{name}' not found in the SSM registry. " - f"Available: {list(_SSM_REGISTRY.keys())}. " - "Use register_ssm() to add it." - ) - entry = _SSM_REGISTRY[name] - if "ssm_base_logp_func_factory" in entry: - _SSM_LOGP_CACHE[name] = entry["ssm_base_logp_func_factory"]() + if name in _SSM_REGISTRY: + entry = _SSM_REGISTRY[name] + if "ssm_base_logp_func_factory" in entry: + _SSM_LOGP_CACHE[name] = entry["ssm_base_logp_func_factory"]() + else: + # Pre-built function registered via register_ssm(). + _SSM_LOGP_CACHE[name] = entry["ssm_base_logp_func"] else: - # Pre-built function registered via register_ssm(). - _SSM_LOGP_CACHE[name] = entry["ssm_base_logp_func"] + # Build from HSSM's modelconfig for built-in SSMs. + spec = _build_ssm_spec_from_modelconfig(name) + _SSM_LOGP_CACHE[name] = spec["ssm_base_logp_func_factory"]() return _SSM_LOGP_CACHE[name] From 5b5b380d99d9def86d7c884d10e15cd44e31d61f Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Fri, 8 May 2026 14:21:28 -0400 Subject: [PATCH 15/41] Add unit tests for RL registry helpers to validate model registration and configuration --- tests/rl/test_registry.py | 558 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 558 insertions(+) create mode 100644 tests/rl/test_registry.py diff --git a/tests/rl/test_registry.py b/tests/rl/test_registry.py new file mode 100644 index 00000000..6c0b8d9b --- /dev/null +++ b/tests/rl/test_registry.py @@ -0,0 +1,558 @@ +"""Unit tests for the RL registry helpers. + +These tests exercise the registry module directly without constructing full +RLSSM model instances, so regressions in lazy SSM resolution, config +composition, and registration validation are caught at the module boundary. +""" + +from __future__ import annotations + +from copy import deepcopy +from typing import Any + +import pytest + +from hssm.rl import registry +from hssm.rl.config import RLSSMConfig +from hssm.utils import annotate_function + + +@pytest.fixture(autouse=True) +def isolated_registries(monkeypatch: pytest.MonkeyPatch) -> None: + """Isolate global registries so tests do not leak state.""" + monkeypatch.setattr(registry, "_SSM_REGISTRY", deepcopy(registry._SSM_REGISTRY)) + monkeypatch.setattr(registry, "_RLSSM_REGISTRY", deepcopy(registry._RLSSM_REGISTRY)) + monkeypatch.setattr(registry, "_SSM_LOGP_CACHE", dict(registry._SSM_LOGP_CACHE)) + + +@pytest.fixture +def learning_process() -> dict[str, Any]: + """Return an annotated RL learning rule for test models.""" + + @annotate_function( + inputs=["rl_alpha", "response", "feedback"], + outputs=["v"], + ) + def compute_v(rl_alpha, response, feedback): + return rl_alpha + + return {"v": compute_v} + + +@pytest.fixture +def annotated_ssm_base_logp() -> Any: + """Return an annotated SSM base log-likelihood function.""" + + @annotate_function( + inputs=["v", "a", "rt", "response"], + outputs=["logp"], + ) + def base_logp(v, a, rt, response): + return a + + return base_logp + + +def test_get_ssm_logp_builds_lazy_factory_once(annotated_ssm_base_logp: Any) -> None: + """Lazy SSM factories should only build and cache one function instance.""" + call_count = 0 + + def factory() -> Any: + nonlocal call_count + call_count += 1 + return annotated_ssm_base_logp + + registry._SSM_REGISTRY["lazy_unit_test_ssm"] = { + "ssm_base_logp_func_factory": factory, + "list_params_ssm": ["v", "a"], + "bounds_ssm": {"a": (0.3, 3.0)}, + "params_default_ssm": [0.0, 1.5], + "response": ["rt", "response"], + } + + first = registry._get_ssm_logp("lazy_unit_test_ssm") + second = registry._get_ssm_logp("lazy_unit_test_ssm") + + assert first is annotated_ssm_base_logp + assert second is first + assert call_count == 1 + + +def test_derive_rl_params_excludes_response_and_extra_fields( + learning_process: dict[str, Any], +) -> None: + """Derived RL params should ignore response columns and extra fields.""" + derived = registry._derive_rl_params( + learning_process=learning_process, + response=["rt", "response"], + extra_fields=["feedback"], + ) + + assert derived == ["rl_alpha"] + + +def test_get_rlssm_model_config_builds_expected_config( + annotated_ssm_base_logp: Any, + learning_process: dict[str, Any], +) -> None: + """Registry config composition should exclude computed SSM params.""" + registry.register_ssm( + name="unit_test_ssm", + ssm_base_logp_func=annotated_ssm_base_logp, + list_params_ssm=["v", "a"], + bounds_ssm={"a": (0.3, 3.0)}, + params_default_ssm=[0.0, 1.5], + response=["rt", "response"], + ) + registry.register_rlssm_model( + name="unit_test_model", + decision_process="unit_test_ssm", + learning_process=learning_process, + rl_params=["rl_alpha"], + rl_bounds={"rl_alpha": (0.0, 1.0)}, + rl_params_default=[0.2], + extra_fields=["feedback"], + choices=[0, 1], + ) + + config = registry.get_rlssm_model_config("unit_test_model") + + assert isinstance(config, RLSSMConfig) + assert config.list_params == ["rl_alpha", "a"] + assert config.bounds == {"rl_alpha": (0.0, 1.0), "a": (0.3, 3.0)} + assert config.params_default == [0.2, 1.5] + assert config.response == ["rt", "response"] + assert config.ssm_logp_func.computed == learning_process + + config.response.append("mutated") + assert registry._SSM_REGISTRY["unit_test_ssm"]["response"] == ["rt", "response"] + + +def test_get_rlssm_model_config_respects_explicit_empty_rl_fields( + annotated_ssm_base_logp: Any, + learning_process: dict[str, Any], +) -> None: + """Explicit empty RL collections must not trigger fallback derivation.""" + registry.register_ssm( + name="empty_rl_ssm", + ssm_base_logp_func=annotated_ssm_base_logp, + list_params_ssm=["v", "a"], + bounds_ssm={"a": (0.3, 3.0)}, + params_default_ssm=[0.0, 1.5], + response=["rt", "response"], + ) + registry._RLSSM_REGISTRY["empty_rl_model"] = { + "decision_process": "empty_rl_ssm", + "learning_process": learning_process, + "rl_params": [], + "rl_bounds": {}, + "rl_params_default": [], + "extra_fields": ["feedback"], + "choices": [0, 1], + "description": "test model", + "decision_process_loglik_kind": "approx_differentiable", + "learning_process_kind": "blackbox", + } + + config = registry.get_rlssm_model_config("empty_rl_model") + + assert config.list_params == ["a"] + assert config.bounds == {"a": (0.3, 3.0)} + assert config.params_default == [1.5] + + +def test_register_rlssm_model_copies_mutable_inputs( + learning_process: dict[str, Any], +) -> None: + """Caller mutations after registration must not alter the stored model.""" + rl_params = ["rl_alpha"] + rl_bounds = {"rl_alpha": (0.0, 1.0)} + rl_defaults = [0.2] + extra_fields = ["feedback"] + choices = [0, 1] + + registry.register_rlssm_model( + name="copy_test_model", + decision_process="angle", + learning_process=learning_process, + rl_params=rl_params, + rl_bounds=rl_bounds, + rl_params_default=rl_defaults, + extra_fields=extra_fields, + choices=choices, + ) + + rl_params.append("scaler") + rl_bounds["scaler"] = (0.0, 10.0) + rl_defaults.append(1.0) + extra_fields.append("trial") + choices.append(2) + # Built-in SSMs (e.g. "angle") are derived from modelconfig on each call + # and are never stored as shared mutable state in _SSM_REGISTRY, so there + # is no registry entry to corrupt here. + learning_process["other"] = next(iter(learning_process.values())) + + stored = registry._RLSSM_REGISTRY["copy_test_model"] + + assert stored["decision_process"]["name"] == "angle" + assert stored["decision_process"]["response"] == ["rt", "response"] + assert stored["rl_params"] == ["rl_alpha"] + assert stored["rl_bounds"] == {"rl_alpha": (0.0, 1.0)} + assert stored["rl_params_default"] == [0.2] + assert stored["extra_fields"] == ["feedback"] + assert stored["choices"] == [0, 1] + assert list(stored["learning_process"]) == ["v"] + + +def test_register_ssm_caches_prebuilt_function( + annotated_ssm_base_logp: Any, +) -> None: + """Registering a pre-built SSM should populate the cache immediately.""" + registry.register_ssm( + name="cached_ssm", + ssm_base_logp_func=annotated_ssm_base_logp, + list_params_ssm=["v", "a"], + bounds_ssm={"a": (0.3, 3.0)}, + params_default_ssm=[0.0, 1.5], + ) + + assert registry._SSM_LOGP_CACHE["cached_ssm"] is annotated_ssm_base_logp + assert registry._SSM_REGISTRY["cached_ssm"]["response"] == ["rt", "response"] + + +def test_register_ssm_rejects_precomputed_function( + annotated_ssm_base_logp: Any, + learning_process: dict[str, Any], +) -> None: + """SSM registration should reject functions that already carry computed params.""" + precomputed_logp = annotate_function( + inputs=annotated_ssm_base_logp.inputs, + outputs=annotated_ssm_base_logp.outputs, + computed=learning_process, + )(annotated_ssm_base_logp) + + with pytest.raises(ValueError, match="should not have a non-empty .computed"): + registry.register_ssm( + name="invalid_ssm", + ssm_base_logp_func=precomputed_logp, + list_params_ssm=["v", "a"], + bounds_ssm={"a": (0.3, 3.0)}, + params_default_ssm=[0.0, 1.5], + ) + + +def test_get_rlssm_model_config_unknown_model_raises() -> None: + """Unknown RLSSM model names should fail with a clear error.""" + with pytest.raises(ValueError, match="not found in the RLSSM registry"): + registry.get_rlssm_model_config("does_not_exist") + + +# --------------------------------------------------------------------------- +# _build_ssm_spec_from_modelconfig — error paths +# --------------------------------------------------------------------------- + + +def test_build_ssm_spec_unknown_model_raises() -> None: + """Completely unknown names should raise ValueError from the modelconfig bridge.""" + with pytest.raises(ValueError, match="not a registered custom SSM"): + registry._build_ssm_spec_from_modelconfig("totally_unknown_model") + + +def test_build_ssm_spec_no_approx_differentiable_raises( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Models without an approx_differentiable likelihood must be rejected.""" + import hssm.modelconfig as mc_module + + def _fake_get(name): # type: ignore[no-untyped-def] + return { + "list_params": ["v", "a"], + "response": ["rt", "response"], + "likelihoods": { + "analytical": {"loglik": lambda: None, "bounds": {}, "backend": None}, + }, + } + + monkeypatch.setattr(mc_module, "get_default_model_config", _fake_get) + with pytest.raises(ValueError, match="no approx_differentiable likelihood"): + registry._build_ssm_spec_from_modelconfig("ddm") + + +# --------------------------------------------------------------------------- +# _make_ssm_base_logp_from_onnx + _factory (ONNX paths, mocked) +# --------------------------------------------------------------------------- + + +def test_build_ssm_spec_factory_calls_onnx_loader( + monkeypatch: pytest.MonkeyPatch, + annotated_ssm_base_logp: Any, +) -> None: + """The lazy factory produced by _build_ssm_spec_from_modelconfig should + call make_jax_matrix_logp_funcs_from_onnx with the correct filename.""" + import hssm.distribution_utils.onnx as onnx_module + + called_with: list[str] = [] + + def _fake_onnx(model: str) -> Any: + called_with.append(model) + return annotated_ssm_base_logp + + monkeypatch.setattr(onnx_module, "make_jax_matrix_logp_funcs_from_onnx", _fake_onnx) + monkeypatch.setattr(registry, "make_jax_matrix_logp_funcs_from_onnx", _fake_onnx) + + spec = registry._build_ssm_spec_from_modelconfig("angle") + result = spec["ssm_base_logp_func_factory"]() + + assert called_with == ["angle.onnx"] + assert callable(result) + assert result.inputs == ["v", "a", "z", "t", "theta", "rt", "response"] + assert result.outputs == ["logp"] + + +# --------------------------------------------------------------------------- +# _get_ssm_logp — built-in SSM path (not in _SSM_REGISTRY) +# --------------------------------------------------------------------------- + + +def test_get_ssm_logp_resolves_builtin_via_modelconfig( + monkeypatch: pytest.MonkeyPatch, + annotated_ssm_base_logp: Any, +) -> None: + """_get_ssm_logp should build and cache the logp for a built-in SSM.""" + import hssm.distribution_utils.onnx as onnx_module + + monkeypatch.setattr( + onnx_module, + "make_jax_matrix_logp_funcs_from_onnx", + lambda **_: annotated_ssm_base_logp, + ) + monkeypatch.setattr( + registry, + "make_jax_matrix_logp_funcs_from_onnx", + lambda **_: annotated_ssm_base_logp, + ) + + result = registry._get_ssm_logp("angle") + + assert callable(result) + assert "angle" in registry._SSM_LOGP_CACHE + # Second call returns cached value without re-building. + assert registry._get_ssm_logp("angle") is result + + +# --------------------------------------------------------------------------- +# _build_ssm_logp_func — raises when func already carries .computed +# --------------------------------------------------------------------------- + + +def test_build_ssm_logp_func_raises_if_already_computed( + annotated_ssm_base_logp: Any, + learning_process: dict[str, Any], +) -> None: + """_build_ssm_logp_func must refuse functions that already have .computed set.""" + precomputed = annotate_function( + inputs=annotated_ssm_base_logp.inputs, + outputs=annotated_ssm_base_logp.outputs, + computed=learning_process, + )(annotated_ssm_base_logp) + + with pytest.raises(ValueError, match="already has a non-empty .computed"): + registry._build_ssm_logp_func(precomputed, learning_process) + + +# --------------------------------------------------------------------------- +# _derive_rl_params — LP function missing .inputs (warning branch) +# --------------------------------------------------------------------------- + + +def test_derive_rl_params_warns_for_unannotated_lp_func() -> None: + """A learning-process function without .inputs should log a warning and be skipped.""" + + def unannotated_func(x): # type: ignore[no-untyped-def] + return x + + lp = {"v": unannotated_func} + result = registry._derive_rl_params( + learning_process=lp, + response=["rt", "response"], + extra_fields=["feedback"], + ) + # The unannotated function contributes no params. + assert result == [] + + +# --------------------------------------------------------------------------- +# get_rlssm_model_config — rl_params=None fallback derivation +# --------------------------------------------------------------------------- + + +def test_get_rlssm_model_config_derives_rl_params_when_absent( + annotated_ssm_base_logp: Any, + learning_process: dict[str, Any], +) -> None: + """When rl_params is absent from the registry entry, params are derived + from the learning_process .inputs.""" + registry.register_ssm( + name="derive_params_ssm", + ssm_base_logp_func=annotated_ssm_base_logp, + list_params_ssm=["v", "a"], + bounds_ssm={"a": (0.3, 3.0)}, + params_default_ssm=[0.0, 1.5], + response=["rt", "response"], + ) + # Inject an entry without rl_params so the fallback derivation runs. + registry._RLSSM_REGISTRY["derive_params_model"] = { + "decision_process": "derive_params_ssm", + "learning_process": learning_process, + # rl_params deliberately absent + "rl_bounds": {}, + "rl_params_default": [], + "extra_fields": ["feedback"], + "choices": [0, 1], + "description": None, + "decision_process_loglik_kind": "approx_differentiable", + "learning_process_kind": "blackbox", + } + + config = registry.get_rlssm_model_config("derive_params_model") + + # "rl_alpha" is the only input to learning_process that isn't response/extra. + assert "rl_alpha" in config.list_params + + +# --------------------------------------------------------------------------- +# get_rlssm_model_config — SSM param absent from bounds_ssm is skipped +# --------------------------------------------------------------------------- + + +def test_get_rlssm_model_config_skips_missing_bounds( + annotated_ssm_base_logp: Any, + learning_process: dict[str, Any], +) -> None: + """SSM params not present in bounds_ssm must not appear in the output bounds.""" + registry.register_ssm( + name="no_bounds_ssm", + ssm_base_logp_func=annotated_ssm_base_logp, + list_params_ssm=["v", "a"], + # "a" intentionally has no bounds entry + bounds_ssm={}, + params_default_ssm=[0.0, 1.5], + response=["rt", "response"], + ) + registry.register_rlssm_model( + name="no_bounds_model", + decision_process="no_bounds_ssm", + learning_process=learning_process, + rl_params=["rl_alpha"], + rl_bounds={"rl_alpha": (0.0, 1.0)}, + rl_params_default=[0.2], + extra_fields=["feedback"], + choices=[0, 1], + ) + + config = registry.get_rlssm_model_config("no_bounds_model") + + assert "a" not in config.bounds + assert "rl_alpha" in config.bounds + + +# --------------------------------------------------------------------------- +# register_rlssm_model — overwrite warning +# --------------------------------------------------------------------------- + + +def test_register_rlssm_model_warns_on_overwrite( + learning_process: dict[str, Any], + annotated_ssm_base_logp: Any, + caplog: pytest.LogCaptureFixture, +) -> None: + """Re-registering an existing RLSSM model name should emit a warning.""" + registry.register_ssm( + name="overwrite_ssm", + ssm_base_logp_func=annotated_ssm_base_logp, + list_params_ssm=["v", "a"], + bounds_ssm={"a": (0.3, 3.0)}, + params_default_ssm=[0.0, 1.5], + ) + registry.register_rlssm_model( + name="overwrite_rlssm", + decision_process="overwrite_ssm", + learning_process=learning_process, + rl_params=["rl_alpha"], + rl_bounds={"rl_alpha": (0.0, 1.0)}, + rl_params_default=[0.2], + ) + + import logging + + with caplog.at_level(logging.WARNING, logger="hssm"): + registry.register_rlssm_model( + name="overwrite_rlssm", + decision_process="overwrite_ssm", + learning_process=learning_process, + rl_params=["rl_alpha"], + rl_bounds={"rl_alpha": (0.0, 1.0)}, + rl_params_default=[0.2], + ) + + assert any("overwrite_rlssm" in r.message for r in caplog.records) + + +# --------------------------------------------------------------------------- +# register_ssm — validation error paths +# --------------------------------------------------------------------------- + + +def test_register_ssm_rejects_non_callable() -> None: + """register_ssm must raise when ssm_base_logp_func is not callable.""" + with pytest.raises(ValueError, match="must be callable"): + registry.register_ssm( + name="bad_ssm", + ssm_base_logp_func="not_a_function", # type: ignore[arg-type] + list_params_ssm=["v"], + bounds_ssm={}, + params_default_ssm=[0.0], + ) + + +def test_register_ssm_rejects_unannotated_callable() -> None: + """register_ssm must raise when the callable lacks .inputs or .outputs.""" + + def plain(x): # type: ignore[no-untyped-def] + return x + + with pytest.raises(ValueError, match="must be decorated with @annotate_function"): + registry.register_ssm( + name="unannotated_ssm", + ssm_base_logp_func=plain, + list_params_ssm=["v"], + bounds_ssm={}, + params_default_ssm=[0.0], + ) + + +def test_register_ssm_warns_on_overwrite( + annotated_ssm_base_logp: Any, + caplog: pytest.LogCaptureFixture, +) -> None: + """Re-registering an existing SSM name should emit a warning.""" + registry.register_ssm( + name="dup_ssm", + ssm_base_logp_func=annotated_ssm_base_logp, + list_params_ssm=["v", "a"], + bounds_ssm={"a": (0.3, 3.0)}, + params_default_ssm=[0.0, 1.5], + ) + + import logging + + with caplog.at_level(logging.WARNING, logger="hssm"): + registry.register_ssm( + name="dup_ssm", + ssm_base_logp_func=annotated_ssm_base_logp, + list_params_ssm=["v", "a"], + bounds_ssm={"a": (0.3, 3.0)}, + params_default_ssm=[0.0, 1.5], + ) + + assert any("dup_ssm" in r.message for r in caplog.records) From 486076bcff2dffd0c3a3b95991a525e4a902c965 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Fri, 8 May 2026 14:29:33 -0400 Subject: [PATCH 16/41] Move rl-related tests to rl directory --- tests/{ => rl}/test_rl_builder_output_shape.py | 0 tests/{ => rl}/test_rl_utils.py | 0 tests/{ => rl}/test_rldm_likelihood.py | 2 +- tests/{ => rl}/test_rlssm.py | 2 +- tests/{ => rl}/test_rlssm_config.py | 0 5 files changed, 2 insertions(+), 2 deletions(-) rename tests/{ => rl}/test_rl_builder_output_shape.py (100%) rename tests/{ => rl}/test_rl_utils.py (100%) rename tests/{ => rl}/test_rldm_likelihood.py (97%) rename tests/{ => rl}/test_rlssm.py (99%) rename tests/{ => rl}/test_rlssm_config.py (100%) diff --git a/tests/test_rl_builder_output_shape.py b/tests/rl/test_rl_builder_output_shape.py similarity index 100% rename from tests/test_rl_builder_output_shape.py rename to tests/rl/test_rl_builder_output_shape.py diff --git a/tests/test_rl_utils.py b/tests/rl/test_rl_utils.py similarity index 100% rename from tests/test_rl_utils.py rename to tests/rl/test_rl_utils.py diff --git a/tests/test_rldm_likelihood.py b/tests/rl/test_rldm_likelihood.py similarity index 97% rename from tests/test_rldm_likelihood.py rename to tests/rl/test_rldm_likelihood.py index e679f22c..54fd544c 100644 --- a/tests/test_rldm_likelihood.py +++ b/tests/rl/test_rldm_likelihood.py @@ -14,7 +14,7 @@ @pytest.fixture def fixture_path(): - return Path(__file__).parent / "fixtures" + return Path(__file__).parent.parent / "fixtures" def test_make_rldm_logp_func(fixture_path): diff --git a/tests/test_rlssm.py b/tests/rl/test_rlssm.py similarity index 99% rename from tests/test_rlssm.py rename to tests/rl/test_rlssm.py index afd37365..62d612eb 100644 --- a/tests/test_rlssm.py +++ b/tests/rl/test_rlssm.py @@ -67,7 +67,7 @@ def _set_floatx_float32() -> Generator[None, None, None]: def rldm_data() -> pd.DataFrame: """Load the RLDM fixture dataset (balanced panel).""" raw = np.load( - Path(__file__).parent / "fixtures" / "rldm_data.npy", allow_pickle=True + Path(__file__).parent.parent / "fixtures" / "rldm_data.npy", allow_pickle=True ).item() return pd.DataFrame(raw["data"]) diff --git a/tests/test_rlssm_config.py b/tests/rl/test_rlssm_config.py similarity index 100% rename from tests/test_rlssm_config.py rename to tests/rl/test_rlssm_config.py From e29e731a6331f48b90a45927c5c3f07098509d56 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Fri, 8 May 2026 14:50:54 -0400 Subject: [PATCH 17/41] Add extra tests to increase test coverage --- tests/rl/test_rlssm.py | 63 ++++++++++++++++++++++++++++++++++++------ 1 file changed, 55 insertions(+), 8 deletions(-) diff --git a/tests/rl/test_rlssm.py b/tests/rl/test_rlssm.py index 62d612eb..8531ef9d 100644 --- a/tests/rl/test_rlssm.py +++ b/tests/rl/test_rlssm.py @@ -5,9 +5,12 @@ variant, bambi / PyMC model construction, and a sampling smoke test. """ +import logging from collections.abc import Generator from pathlib import Path +from unittest.mock import patch +import cloudpickle import jax.numpy as jnp import numpy as np import pandas as pd @@ -15,7 +18,8 @@ import pytest import hssm -from hssm.rl import RLSSM, RLSSMConfig, _RLSSM, register_rlssm_model +from hssm.distribution_utils import make_distribution as real_make_distribution +from hssm.rl import _RLSSM, RLSSM, RLSSMConfig, register_rlssm_model from hssm.rl.likelihoods.two_armed_bandit import compute_v_subject_wise from hssm.utils import annotate_function @@ -256,10 +260,6 @@ def test_rlssm_extra_fields_are_copies(rldm_data, rlssm_config) -> None: to_numpy(copy=True) should return a new buffer; if it returned a view, in-place mutations of the DataFrame would silently corrupt the distribution. """ - from unittest.mock import patch - - from hssm.distribution_utils import make_distribution as real_make_distribution - model = RLSSM(data=rldm_data, model_config=rlssm_config) captured: dict = {} @@ -305,7 +305,7 @@ def test_rlssm_sample_smoke(rldm_data, rlssm_config) -> None: def test_rlssm_pickle_round_trip( rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig ) -> None: - """cloudpickle round-trip must reconstruct an equivalent RLSSM. + """Cloudpickle round-trip must reconstruct an equivalent RLSSM. Verifies that __getstate__ / __setstate__ survive serialisation: - The reconstructed object is a fresh RLSSM (not the same instance). @@ -314,8 +314,6 @@ def test_rlssm_pickle_round_trip( - model_config.model_name is preserved. - model.model (bambi model) is rebuilt, confirming full re-initialisation. """ - import cloudpickle - model = RLSSM(data=rldm_data, model_config=rlssm_config) blob = cloudpickle.dumps(model) restored = cloudpickle.loads(blob) @@ -419,3 +417,52 @@ def test_rlssm_init_args_uses_simplified_interface(rldm_data) -> None: # model_config should not be baked in as a hard reference assert "model_config" in model._init_args assert model._init_args["model_config"] is None + + +def test_rlssm_model_config_with_overrides_warns( + rldm_data, rlssm_config, caplog +) -> None: + """Warn when model_config is given alongside model/overrides. + + The extra arguments (model, learning_process, decision_process, choices) + should be ignored and a warning emitted. + """ + with caplog.at_level(logging.WARNING, logger="hssm"): + RLSSM(data=rldm_data, model_config=rlssm_config, model="some_other_model") + + assert any("ignoring" in r.message for r in caplog.records) + + +def test_rlssm_no_extra_fields_none_passed_to_make_distribution( + rldm_data, rlssm_config +) -> None: + """When extra_fields is empty, make_distribution receives extra_fields=None.""" + # Build a config without extra_fields. + config_no_extra = RLSSMConfig( + model_name="rldm_no_extra", + loglik_kind="approx_differentiable", + decision_process="angle", + decision_process_loglik_kind="approx_differentiable", + learning_process_kind="blackbox", + list_params=rlssm_config.list_params, + params_default=rlssm_config.params_default, + bounds=rlssm_config.bounds, + learning_process=rlssm_config.learning_process, + response=list(rlssm_config.response), + choices=list(rlssm_config.choices), + extra_fields=[], # empty → should pass None to make_distribution + ssm_logp_func=_dummy_ssm_logp, + ) + model = RLSSM(data=rldm_data, model_config=config_no_extra) + captured: dict = {} + + def capturing_make_distribution(*args, **kwargs): + captured["extra_fields"] = kwargs.get("extra_fields") + return real_make_distribution(*args, **kwargs) + + with patch( + "hssm.rl.rlssm.make_distribution", side_effect=capturing_make_distribution + ): + model._make_model_distribution() + + assert captured.get("extra_fields") is None From d12594f4c8993c7bc2dcb3c23eb6a4f6225f93af Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Mon, 11 May 2026 14:41:33 -0400 Subject: [PATCH 18/41] Update RLSSM registry to include 2AB Rescorla-Wagner model --- src/hssm/rl/registry.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/hssm/rl/registry.py b/src/hssm/rl/registry.py index 80fd7f1a..9e6d496d 100644 --- a/src/hssm/rl/registry.py +++ b/src/hssm/rl/registry.py @@ -198,7 +198,7 @@ def _get_ssm_logp(name: str) -> Any: # learning_process_kind _RLSSM_REGISTRY: dict[str, dict[str, Any]] = { - "rldm": { + "2AB_RescorlaWagner_Angle": { "decision_process": _get_decision_process_spec("angle"), "learning_process": {"v": _compute_v_annotated}, "rl_params": ["rl_alpha", "scaler"], From 215f50f6672640956cf1d7d4da7a60d1d8eecabc Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Mon, 11 May 2026 17:57:14 -0400 Subject: [PATCH 19/41] Remove redundant classproperty decorator --- src/hssm/hssm.py | 28 ---------------------------- 1 file changed, 28 deletions(-) diff --git a/src/hssm/hssm.py b/src/hssm/hssm.py index 1e43582d..1952ec63 100644 --- a/src/hssm/hssm.py +++ b/src/hssm/hssm.py @@ -51,34 +51,6 @@ } -class classproperty: - """A decorator that combines the behavior of @property and @classmethod. - - This decorator allows you to define a property that can be accessed on the class - itself, rather than on instances of the class. It is useful for defining class-level - properties that need to perform some computation or access class-level data. - - This implementation is provided for compatibility with Python versions 3.10 through - 3.12, as one cannot combine the @property and @classmethod decorators is across all - these versions. - - Example - ------- - class MyClass: - @classproperty - def my_class_property(cls): - return "This is a class property" - - print(MyClass.my_class_property) # Output: This is a class property - """ - - def __init__(self, fget): - self.fget = fget - - def __get__(self, instance, owner): # noqa: D105 - return self.fget(owner) - - class HSSM(HSSMBase): """The basic Hierarchical Sequential Sampling Model (HSSM) class. From 434a27c376362ae02c8c2c58680cf69656e34eb3 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Mon, 11 May 2026 18:23:10 -0400 Subject: [PATCH 20/41] Add list_models to public API documentation in __init__.py --- src/hssm/rl/__init__.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/hssm/rl/__init__.py b/src/hssm/rl/__init__.py index 78579888..bb6a4cb1 100644 --- a/src/hssm/rl/__init__.py +++ b/src/hssm/rl/__init__.py @@ -20,7 +20,11 @@ """ from .config import RLSSMConfig -from .registry import get_rlssm_model_config, register_rlssm_model, register_ssm +from .registry import ( + get_rlssm_model_config, + register_rlssm_model, + register_ssm, +) from .rlssm import _RLSSM, RLSSM from .utils import validate_balanced_panel From a3d910b9ed6bd93ede995b11ce502203c5591341 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Mon, 11 May 2026 18:24:12 -0400 Subject: [PATCH 21/41] Add 2AB Rescorla-Wagner models to RLSSM registry and implement list_models function --- src/hssm/rl/registry.py | 66 ++++++++++++++++++++++++++++++++-- tests/rl/test_registry.py | 74 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 138 insertions(+), 2 deletions(-) diff --git a/src/hssm/rl/registry.py b/src/hssm/rl/registry.py index 9e6d496d..ded46949 100644 --- a/src/hssm/rl/registry.py +++ b/src/hssm/rl/registry.py @@ -198,6 +198,24 @@ def _get_ssm_logp(name: str) -> Any: # learning_process_kind _RLSSM_REGISTRY: dict[str, dict[str, Any]] = { + "2AB_RescorlaWagner_DDM": { + "decision_process": _get_decision_process_spec("ddm"), + "learning_process": {"v": _compute_v_annotated}, + "rl_params": ["rl_alpha", "scaler"], + "rl_bounds": { + "rl_alpha": (0.0, 1.0), + "scaler": (0.0, 10.0), + }, + "rl_params_default": [0.1, 1.0], + "extra_fields": ["feedback"], + "choices": [0, 1], + "description": ( + "RLSSM model with Rescorla-Wagner Q-learning and the " + "standard DDM as decision process." + ), + "decision_process_loglik_kind": "approx_differentiable", + "learning_process_kind": "blackbox", + }, "2AB_RescorlaWagner_Angle": { "decision_process": _get_decision_process_spec("angle"), "learning_process": {"v": _compute_v_annotated}, @@ -216,6 +234,24 @@ def _get_ssm_logp(name: str) -> Any: "decision_process_loglik_kind": "approx_differentiable", "learning_process_kind": "blackbox", }, + "2AB_RescorlaWagner_Weibull": { + "decision_process": _get_decision_process_spec("weibull"), + "learning_process": {"v": _compute_v_annotated}, + "rl_params": ["rl_alpha", "scaler"], + "rl_bounds": { + "rl_alpha": (0.0, 1.0), + "scaler": (0.0, 10.0), + }, + "rl_params_default": [0.1, 1.0], + "extra_fields": ["feedback"], + "choices": [0, 1], + "description": ( + "RLSSM model with Rescorla-Wagner Q-learning and a " + "Weibull-bound DDM as decision process." + ), + "decision_process_loglik_kind": "approx_differentiable", + "learning_process_kind": "blackbox", + }, } @@ -280,7 +316,7 @@ def _derive_rl_params( def get_rlssm_model_config( - model: str = "rldm", + model: str = "2AB_RescorlaWagner_DDM", choices: list[int] | None = None, learning_process: dict[str, Any] | None = None, decision_process: str | None = None, @@ -290,7 +326,7 @@ def get_rlssm_model_config( Parameters ---------- model: - Name of a registered RLSSM model (e.g. ``"rldm"``). + Name of a registered RLSSM model (e.g. ``"2AB_RescorlaWagner_DDM"``). choices: Override the response choice values stored in the registry. learning_process: @@ -391,6 +427,32 @@ def get_rlssm_model_config( ) +# --------------------------------------------------------------------------- +# Public query helpers +# --------------------------------------------------------------------------- + + +def list_models() -> dict[str, str | None]: + """Return the names and descriptions of all registered RLSSM models. + + This is the recommended starting point for new users who want to discover + which models are available out of the box. + + Returns + ------- + dict[str, str | None] + Mapping of model name → description string (or ``None`` if no + description was provided at registration time). + + Examples + -------- + >>> import hssm + >>> hssm.rl.list_rlssm_models() + {'2AB_RescorlaWagner_DDM': 'RLSSM model with Rescorla-Wagner ...', ...} + """ + return {name: entry.get("description") for name, entry in _RLSSM_REGISTRY.items()} + + # --------------------------------------------------------------------------- # Public registration helpers # --------------------------------------------------------------------------- diff --git a/tests/rl/test_registry.py b/tests/rl/test_registry.py index 6c0b8d9b..3a90b0e8 100644 --- a/tests/rl/test_registry.py +++ b/tests/rl/test_registry.py @@ -556,3 +556,77 @@ def test_register_ssm_warns_on_overwrite( ) assert any("dup_ssm" in r.message for r in caplog.records) + + +# --------------------------------------------------------------------------- +# Built-in starter-pack RLSSM models (DDM and Weibull variants) +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "model_name, expected_dp", + [ + ("2AB_RescorlaWagner_DDM", "ddm"), + ("2AB_RescorlaWagner_Weibull", "weibull"), + ], +) +def test_builtin_2ab_models_are_registered(model_name: str, expected_dp: str) -> None: + """2AB_RescorlaWagner_DDM and _Weibull must be present in the RLSSM registry.""" + assert model_name in registry._RLSSM_REGISTRY + entry = registry._RLSSM_REGISTRY[model_name] + assert entry["decision_process"]["name"] == expected_dp + assert entry["rl_params"] == ["rl_alpha", "scaler"] + assert entry["extra_fields"] == ["feedback"] + assert entry["choices"] == [0, 1] + assert entry["decision_process_loglik_kind"] == "approx_differentiable" + assert entry["learning_process_kind"] == "blackbox" + + +@pytest.mark.parametrize( + "model_name", + ["2AB_RescorlaWagner_DDM", "2AB_RescorlaWagner_Weibull"], +) +def test_builtin_2ab_models_config_structure( + monkeypatch: pytest.MonkeyPatch, + annotated_ssm_base_logp: Any, + model_name: str, +) -> None: + """get_rlssm_model_config should produce a well-formed RLSSMConfig for + both the DDM and Weibull starter-pack models.""" + import hssm.distribution_utils.onnx as onnx_module + + monkeypatch.setattr( + onnx_module, + "make_jax_matrix_logp_funcs_from_onnx", + lambda model: annotated_ssm_base_logp, + ) + monkeypatch.setattr( + registry, + "make_jax_matrix_logp_funcs_from_onnx", + lambda model: annotated_ssm_base_logp, + ) + + config = registry.get_rlssm_model_config(model_name) + + assert isinstance(config, RLSSMConfig) + # RL params come first + assert config.list_params[:2] == ["rl_alpha", "scaler"] + assert "rl_alpha" in config.bounds + assert "scaler" in config.bounds + assert config.choices == [0, 1] + assert config.extra_fields == ["feedback"] + assert config.ssm_logp_func.computed == {"v": registry._compute_v_annotated} + + +# --------------------------------------------------------------------------- +# list_rlssm_models +# --------------------------------------------------------------------------- + + +def test_list_rlssm_models_returns_all_names() -> None: + """list_rlssm_models should return every key in _RLSSM_REGISTRY with its description.""" + result = registry.list_models() + + assert set(result.keys()) == set(registry._RLSSM_REGISTRY.keys()) + for name, desc in result.items(): + assert desc == registry._RLSSM_REGISTRY[name].get("description") From 10cd59ee69cddec1282024d60e711745e867750e Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Mon, 11 May 2026 18:24:26 -0400 Subject: [PATCH 22/41] Update RLSSM model default and add classproperty for listing models --- src/hssm/rl/rlssm.py | 28 +++++++++++++++++++++++++--- 1 file changed, 25 insertions(+), 3 deletions(-) diff --git a/src/hssm/rl/rlssm.py b/src/hssm/rl/rlssm.py index 3a0c9ff8..2a760a0c 100644 --- a/src/hssm/rl/rlssm.py +++ b/src/hssm/rl/rlssm.py @@ -40,7 +40,7 @@ from hssm.rl.likelihoods.builder import make_rl_logp_op from hssm.rl.utils import validate_balanced_panel -from ..base import HSSMBase +from ..base import HSSMBase, classproperty from .config import RLSSMConfig from .registry import get_rlssm_model_config @@ -351,7 +351,7 @@ class RLSSM(_RLSSM): data : pd.DataFrame Trial-level data (balanced panel required). model : str, optional - Name of a registered RLSSM model. Defaults to ``"rldm"``. + Name of a registered RLSSM model. Defaults to ``"2AB_RescorlaWagner_DDM"``. choices : list[int] | None, optional Override the choice values in the registry. ``None`` uses the registry default. @@ -412,7 +412,7 @@ class RLSSM(_RLSSM): def __init__( self, data: pd.DataFrame, - model: str = "rldm", + model: str = "2AB_RescorlaWagner_DDM", choices: list[int] | None = None, include: list[dict[str, Any] | Any] | None = None, model_config: RLSSMConfig | None = None, @@ -479,3 +479,25 @@ def __init__( # IMPORTANT: This MUST be the last statement in __init__. Any code # added after this line cannot read self.missing_data / self.deadline. self.__dict__["_rlssm_fully_initialized"] = True + + @classproperty + def list_models(cls) -> dict[str, str | None]: + """All registered RLSSM models and their descriptions. + + This is the recommended entry point for newcomers to discover which + models are available without constructing a full model instance. + + Returns + ------- + dict[str, str | None] + Mapping of model name → description (``None`` if not provided). + + Examples + -------- + >>> from hssm.rl import RLSSM + >>> RLSSM.list_models + {'2AB_RescorlaWagner_DDM': 'RLSSM model with ...', ...} + """ + from .registry import list_models # noqa: PLC0415 + + return list_models() From a34a76694f34bc867f039c17309908fae273ac7d Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Mon, 11 May 2026 18:26:15 -0400 Subject: [PATCH 23/41] Update RLSSM tests to validate 2AB Rescorla-Wagner model instantiation and defaults --- tests/rl/test_rlssm.py | 35 ++++++++++++++++++++++------------- 1 file changed, 22 insertions(+), 13 deletions(-) diff --git a/tests/rl/test_rlssm.py b/tests/rl/test_rlssm.py index 8531ef9d..c52cc122 100644 --- a/tests/rl/test_rlssm.py +++ b/tests/rl/test_rlssm.py @@ -337,23 +337,32 @@ def test_rlssm_is_subclass_of_internal() -> None: assert issubclass(RLSSM, _RLSSM) -def test_rlssm_simplified_init(rldm_data) -> None: - """RLSSM(data, model='rldm') should build without supplying model_config.""" - model = RLSSM(data=rldm_data, model="rldm") +@pytest.mark.parametrize( + "model_name, expected_dp", + [ + ("2AB_RescorlaWagner_DDM", "ddm"), + ("2AB_RescorlaWagner_Angle", "angle"), + ("2AB_RescorlaWagner_Weibull", "weibull"), + ], +) +def test_rlssm_builtin_models_instantiate( + rldm_data, model_name: str, expected_dp: str +) -> None: + """All built-in 2AB_RescorlaWagner_* models should instantiate correctly.""" + model = RLSSM(data=rldm_data, model=model_name) assert isinstance(model, RLSSM) - # Registry-derived params: rl_alpha, scaler (RL) + a, z, t, theta (SSM) + assert model.model_config.decision_process == expected_dp assert "rl_alpha" in model.params assert "scaler" in model.params assert "a" in model.params - # v is computed inside the Op; must NOT appear as a free parameter assert "v" not in model.params -def test_rlssm_default_model_is_rldm(rldm_data) -> None: - """Omitting model should default to 'rldm'.""" +def test_rlssm_default_model_is_ddm(rldm_data) -> None: + """Omitting model should default to '2AB_RescorlaWagner_DDM'.""" model = RLSSM(data=rldm_data) assert isinstance(model, RLSSM) - assert model.model_config.decision_process == "angle" + assert model.model_config.decision_process == "ddm" def test_rlssm_model_config_provided(rldm_data, rlssm_config) -> None: @@ -370,21 +379,21 @@ def test_rlssm_unregistered_model_raises(rldm_data) -> None: def test_rlssm_missing_data_property_raises(rldm_data) -> None: """Accessing .missing_data on a built RLSSM instance must raise NotImplementedError.""" - model = RLSSM(data=rldm_data, model="rldm") + model = RLSSM(data=rldm_data, model="2AB_RescorlaWagner_DDM") with pytest.raises(NotImplementedError, match="missing_data"): _ = model.missing_data def test_rlssm_deadline_property_raises(rldm_data) -> None: """Accessing .deadline on a built RLSSM instance must raise NotImplementedError.""" - model = RLSSM(data=rldm_data, model="rldm") + model = RLSSM(data=rldm_data, model="2AB_RescorlaWagner_DDM") with pytest.raises(NotImplementedError, match="deadline"): _ = model.deadline def test_rlssm_loglik_missing_data_property_raises(rldm_data) -> None: """Accessing .loglik_missing_data on a built RLSSM instance must raise NotImplementedError.""" - model = RLSSM(data=rldm_data, model="rldm") + model = RLSSM(data=rldm_data, model="2AB_RescorlaWagner_DDM") with pytest.raises(NotImplementedError, match="loglik_missing_data"): _ = model.loglik_missing_data @@ -411,9 +420,9 @@ def test_register_rlssm_model(rldm_data) -> None: def test_rlssm_init_args_uses_simplified_interface(rldm_data) -> None: """_init_args should reflect the simplified constructor, not model_config.""" - model = RLSSM(data=rldm_data, model="rldm") + model = RLSSM(data=rldm_data, model="2AB_RescorlaWagner_DDM") assert "model" in model._init_args - assert model._init_args["model"] == "rldm" + assert model._init_args["model"] == "2AB_RescorlaWagner_DDM" # model_config should not be baked in as a hard reference assert "model_config" in model._init_args assert model._init_args["model_config"] is None From d8c9359f0f8b49a0eeb8182e00a0eaf855ef8d03 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Tue, 12 May 2026 18:21:25 -0400 Subject: [PATCH 24/41] Raise error for missing SSM bounds in get_rlssm_model_config to prevent silent failures --- src/hssm/rl/registry.py | 11 +++++++++-- tests/rl/test_registry.py | 17 ++++++++++------- 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/src/hssm/rl/registry.py b/src/hssm/rl/registry.py index ded46949..ad907620 100644 --- a/src/hssm/rl/registry.py +++ b/src/hssm/rl/registry.py @@ -391,13 +391,20 @@ def get_rlssm_model_config( list_params = rl_params + ssm_sampled # bounds: RL bounds ∪ SSM sampled bounds + missing_bounds = [p for p in ssm_sampled if p not in ssm_entry["bounds_ssm"]] + if missing_bounds: + raise ValueError( + f"SSM parameter(s) {missing_bounds} are included in list_params but have " + f"no entry in bounds_ssm for decision process '{dp}'. " + "Provide bounds for all sampled parameters via register_ssm() or ensure " + "the built-in modelconfig includes them." + ) _rl_bounds = entry.get("rl_bounds") bounds: dict[str, tuple[float, float]] = dict( _rl_bounds if _rl_bounds is not None else {} ) for p in ssm_sampled: - if p in ssm_entry["bounds_ssm"]: - bounds[p] = ssm_entry["bounds_ssm"][p] + bounds[p] = ssm_entry["bounds_ssm"][p] # params_default aligned with list_params _rl_defaults = entry.get("rl_params_default") diff --git a/tests/rl/test_registry.py b/tests/rl/test_registry.py index 3a90b0e8..26e3f6eb 100644 --- a/tests/rl/test_registry.py +++ b/tests/rl/test_registry.py @@ -421,15 +421,20 @@ def test_get_rlssm_model_config_derives_rl_params_when_absent( # --------------------------------------------------------------------------- -# get_rlssm_model_config — SSM param absent from bounds_ssm is skipped +# get_rlssm_model_config — SSM param absent from bounds_ssm raises early # --------------------------------------------------------------------------- -def test_get_rlssm_model_config_skips_missing_bounds( +def test_get_rlssm_model_config_raises_for_missing_bounds( annotated_ssm_base_logp: Any, learning_process: dict[str, Any], ) -> None: - """SSM params not present in bounds_ssm must not appear in the output bounds.""" + """SSM params missing from bounds_ssm must raise immediately with a clear error. + + Previously the factory silently skipped such params, producing a broken + RLSSMConfig that only failed later inside _RLSSM.__init__. Now it must + raise at the factory boundary so the error message points at the root cause. + """ registry.register_ssm( name="no_bounds_ssm", ssm_base_logp_func=annotated_ssm_base_logp, @@ -450,10 +455,8 @@ def test_get_rlssm_model_config_skips_missing_bounds( choices=[0, 1], ) - config = registry.get_rlssm_model_config("no_bounds_model") - - assert "a" not in config.bounds - assert "rl_alpha" in config.bounds + with pytest.raises(ValueError, match="no entry in bounds_ssm"): + registry.get_rlssm_model_config("no_bounds_model") # --------------------------------------------------------------------------- From f99a5336f72e57ce4ad722de4275a0d38cb6b3b4 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Tue, 12 May 2026 18:21:55 -0400 Subject: [PATCH 25/41] Clarify documentation for RLSSM registry and update example usage for list_models function --- src/hssm/rl/registry.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/hssm/rl/registry.py b/src/hssm/rl/registry.py index ad907620..7dd6cff8 100644 --- a/src/hssm/rl/registry.py +++ b/src/hssm/rl/registry.py @@ -8,8 +8,9 @@ ``approx_differentiable`` likelihood) are resolved automatically from :func:`hssm.modelconfig.get_default_model_config` and do **not** need to be pre-registered here. -- :data:`_RLSSM_REGISTRY` — maps named RLSSM model strings (e.g. ``"rldm"``) - to their default decision process, learning process, and parameter info. +- :data:`_RLSSM_REGISTRY` — maps named RLSSM model strings (e.g. + ``"2AB_RescorlaWagner_ddm"``) to their default decision process, + learning process, and parameter info. - :func:`get_rlssm_model_config` — builds a :class:`~hssm.rl.config.RLSSMConfig` from a named model string with optional overrides. - :func:`register_rlssm_model` — register a custom named RLSSM model. @@ -454,7 +455,7 @@ def list_models() -> dict[str, str | None]: Examples -------- >>> import hssm - >>> hssm.rl.list_rlssm_models() + >>> hssm.rl.list_models() {'2AB_RescorlaWagner_DDM': 'RLSSM model with Rescorla-Wagner ...', ...} """ return {name: entry.get("description") for name, entry in _RLSSM_REGISTRY.items()} @@ -485,7 +486,9 @@ def register_rlssm_model( name: Registry key (e.g. ``"my_rldm"``). decision_process: - Name of the SSM to use (must already be in the SSM registry). + Name of the SSM to use. This may be either a custom SSM already registered + in the SSM registry via :func:`register_ssm`, or a built-in HSSM modelconfig SSM + name such as ``"ddm"``, ``"angle"``, or ``"weibull"``. learning_process: Dict mapping computed parameter name → annotated learning function. rl_params: From 918548cbf0d30e9d6b843d416b4f82d453c457fb Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Tue, 12 May 2026 18:26:04 -0400 Subject: [PATCH 26/41] Update model check in RLSSM constructor to use specific model name for validation --- src/hssm/rl/rlssm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/hssm/rl/rlssm.py b/src/hssm/rl/rlssm.py index 2a760a0c..9a930590 100644 --- a/src/hssm/rl/rlssm.py +++ b/src/hssm/rl/rlssm.py @@ -434,7 +434,7 @@ def __init__( _my_init_args = self._store_init_args(locals(), kwargs) if model_config is not None: - if model != "rldm" or any( + if model != "2AB_RescorlaWagner_DDM" or any( x is not None for x in [learning_process, decision_process, choices] ): _logger.warning( From 3626f4d859f7a2328c2d84a27fe263fb67f72c43 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Thu, 14 May 2026 15:10:20 -0400 Subject: [PATCH 27/41] fix: ensure defensive copying of learning process and choices in RLSSM model config --- src/hssm/rl/registry.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/hssm/rl/registry.py b/src/hssm/rl/registry.py index 7dd6cff8..abdcfdc7 100644 --- a/src/hssm/rl/registry.py +++ b/src/hssm/rl/registry.py @@ -368,7 +368,9 @@ def get_rlssm_model_config( ssm_entry = _get_decision_process_spec(entry["decision_process"]) dp: str = ssm_entry["name"] ssm_base = _get_ssm_logp(dp) - lp: dict[str, Any] = entry["learning_process"] + # Defensive copy so callers mutating config.learning_process don't corrupt + # the registry entry (entry is only a shallow copy of _RLSSM_REGISTRY[model]). + lp: dict[str, Any] = dict(entry["learning_process"]) # Compose the full ssm_logp_func with .computed = learning_process. ssm_logp_func = _build_ssm_logp_func(ssm_base, lp) @@ -430,8 +432,10 @@ def get_rlssm_model_config( bounds=bounds, params_default=params_default, response=response, - choices=entry["choices"], - extra_fields=entry.get("extra_fields"), + choices=tuple(entry["choices"]), + extra_fields=list(entry["extra_fields"]) + if entry.get("extra_fields") is not None + else None, ) From f3a94d6d5ad9723fbd221bb47012d195f67d6615 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Thu, 14 May 2026 15:25:07 -0400 Subject: [PATCH 28/41] fix: update choices assertion to use tuple format in test_builtin_2ab_models_config_structure --- tests/rl/test_registry.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/rl/test_registry.py b/tests/rl/test_registry.py index 26e3f6eb..5a4c6bc7 100644 --- a/tests/rl/test_registry.py +++ b/tests/rl/test_registry.py @@ -616,7 +616,7 @@ def test_builtin_2ab_models_config_structure( assert config.list_params[:2] == ["rl_alpha", "scaler"] assert "rl_alpha" in config.bounds assert "scaler" in config.bounds - assert config.choices == [0, 1] + assert config.choices == (0, 1) assert config.extra_fields == ["feedback"] assert config.ssm_logp_func.computed == {"v": registry._compute_v_annotated} From 54cebfee954ea0ca81c6d2d4f3e44af9155e53b4 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Fri, 15 May 2026 14:29:45 -0400 Subject: [PATCH 29/41] chore: remove unnecessary comments --- src/hssm/rl/rlssm.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/src/hssm/rl/rlssm.py b/src/hssm/rl/rlssm.py index 9a930590..3e06db5e 100644 --- a/src/hssm/rl/rlssm.py +++ b/src/hssm/rl/rlssm.py @@ -284,11 +284,6 @@ def _make_model_distribution(self) -> type[pm.Distribution]: ) -# --------------------------------------------------------------------------- -# Blocked-attribute descriptor -# --------------------------------------------------------------------------- - - class _BlockedAttribute: """Data descriptor that blocks read access with NotImplementedError. @@ -329,11 +324,6 @@ def __set__(self, obj: Any, value: Any) -> None: # noqa: D105 obj.__dict__[self._storage_key] = value -# --------------------------------------------------------------------------- -# Public wrapper -# --------------------------------------------------------------------------- - - class RLSSM(_RLSSM): """Reinforcement Learning Sequential Sampling Model — simplified public API. From 08ba693025be0954dc7a6387096e771ede85f100 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Fri, 15 May 2026 14:46:59 -0400 Subject: [PATCH 30/41] refactor: clean up test_rlssm.py by removing unnecessary comments and organizing test cases --- tests/rl/test_rlssm.py | 717 +++++++++++++++++++---------------------- 1 file changed, 338 insertions(+), 379 deletions(-) diff --git a/tests/rl/test_rlssm.py b/tests/rl/test_rlssm.py index c52cc122..2d76ee70 100644 --- a/tests/rl/test_rlssm.py +++ b/tests/rl/test_rlssm.py @@ -23,10 +23,6 @@ from hssm.rl.likelihoods.two_armed_bandit import compute_v_subject_wise from hssm.utils import annotate_function -# --------------------------------------------------------------------------- -# Module-level annotated helpers (shared by all tests) -# --------------------------------------------------------------------------- - # Annotate the RL learning function: maps # (rl_alpha, scaler, response, feedback) -> v _compute_v_annotated = annotate_function( @@ -35,10 +31,6 @@ )(compute_v_subject_wise) -# Annotated SSM log-likelihood function (simplified for testing). -# It receives a 2-D lan_matrix whose columns correspond to -# [v, a, z, t, theta, rt, response] -# and returns per-trial log-probabilities of shape (n_total_trials,). @annotate_function( inputs=["v", "a", "z", "t", "theta", "rt", "response"], outputs=["logp"], @@ -51,11 +43,6 @@ def _dummy_ssm_logp(lan_matrix: jnp.ndarray) -> jnp.ndarray: return jnp.sum(lan_matrix, axis=1) -# --------------------------------------------------------------------------- -# Fixtures -# --------------------------------------------------------------------------- - - @pytest.fixture(scope="module", autouse=True) def _set_floatx_float32() -> Generator[None, None, None]: """Ensure float32 is used for this module's tests, then restore previous setting.""" @@ -103,375 +90,347 @@ def rlssm_config() -> RLSSMConfig: ) -# --------------------------------------------------------------------------- -# Initialisation & config-validation tests -# --------------------------------------------------------------------------- - - -def test_rlssm_init(rldm_data, rlssm_config) -> None: - """Basic RLSSM initialisation should succeed and return an RLSSM instance.""" - model = RLSSM(data=rldm_data, model_config=rlssm_config) - assert isinstance(model, RLSSM) - assert model.model_config.model_name == "rldm_test" - - -def test_rlssm_panel_attrs(rldm_data, rlssm_config) -> None: - """n_participants and n_trials should match the fixture data structure.""" - model = RLSSM(data=rldm_data, model_config=rlssm_config) - - n_participants = rldm_data["participant_id"].nunique() - n_trials = len(rldm_data) // n_participants - - assert model.n_participants == n_participants - assert model.n_trials == n_trials - - -def test_rlssm_params_keys(rldm_data, rlssm_config) -> None: - """model.params should contain exactly list_params + p_outlier.""" - model = RLSSM(data=rldm_data, model_config=rlssm_config) - expected = set(rlssm_config.list_params) | {"p_outlier"} - assert set(model.params.keys()) == expected - - -def test_rlssm_unbalanced_raises(rldm_data, rlssm_config) -> None: - """Dropping one row should make the panel unbalanced → ValueError.""" - unbalanced = rldm_data.iloc[:-1].copy() - with pytest.raises(ValueError, match="balanced panels"): - RLSSM(data=unbalanced, model_config=rlssm_config) - - -def test_rlssm_nan_participant_id_raises(rldm_data, rlssm_config) -> None: - """NaN in participant_id column should raise ValueError before groupby silently drops rows.""" - nan_data = rldm_data.copy() - nan_data.loc[nan_data.index[0], "participant_id"] = float("nan") - with pytest.raises(ValueError, match="NaN"): - RLSSM(data=nan_data, model_config=rlssm_config) - - -def test_rlssm_missing_ssm_logp_func_raises(rldm_data, rlssm_config) -> None: - """RLSSMConfig without ssm_logp_func should raise ValueError on init.""" - bad_config = RLSSMConfig( - model_name="rldm_bad", - loglik_kind="approx_differentiable", - decision_process="angle", - decision_process_loglik_kind="approx_differentiable", - learning_process_kind="blackbox", - list_params=rlssm_config.list_params, - params_default=rlssm_config.params_default, - bounds=rlssm_config.bounds, - learning_process=rlssm_config.learning_process, - response=list(rlssm_config.response), - choices=list(rlssm_config.choices), - extra_fields=list(rlssm_config.extra_fields), - # ssm_logp_func intentionally omitted → defaults to None - ) - with pytest.raises(ValueError, match="ssm_logp_func"): - RLSSM(data=rldm_data, model_config=bad_config) - - -def test_rlssm_unannotated_ssm_logp_func_raises(rldm_data, rlssm_config) -> None: - """A plain callable without @annotate_function attrs should raise ValueError.""" - bad_config = RLSSMConfig( - model_name="rldm_bad", - loglik_kind="approx_differentiable", - decision_process="angle", - decision_process_loglik_kind="approx_differentiable", - learning_process_kind="blackbox", - list_params=rlssm_config.list_params, - params_default=rlssm_config.params_default, - bounds=rlssm_config.bounds, - learning_process=rlssm_config.learning_process, - response=list(rlssm_config.response), - choices=list(rlssm_config.choices), - extra_fields=list(rlssm_config.extra_fields), - ssm_logp_func=lambda x: x, # callable but no .inputs/.outputs/.computed - ) - with pytest.raises(ValueError, match="annotate_function"): - RLSSM(data=rldm_data, model_config=bad_config) - - -def test_rlssm_missing_data_raises(rldm_data, rlssm_config) -> None: - """Passing missing_data!=False should raise NotImplementedError with 'missing_data' in msg.""" - with pytest.raises(NotImplementedError, match="missing_data"): - RLSSM(data=rldm_data, model_config=rlssm_config, missing_data=True) - - -def test_rlssm_deadline_raises(rldm_data, rlssm_config) -> None: - """Passing deadline!=False should raise NotImplementedError with 'deadline' in msg.""" - with pytest.raises(NotImplementedError, match="deadline"): - RLSSM(data=rldm_data, model_config=rlssm_config, deadline=True) - - -# --------------------------------------------------------------------------- -# Model-structure tests -# --------------------------------------------------------------------------- - - -def test_rlssm_params_is_trialwise_aligned(rldm_data, rlssm_config) -> None: - """params_is_trialwise must align with list_params (same length, p_outlier=False).""" - model = RLSSM(data=rldm_data, model_config=rlssm_config) - assert model.model_config.list_params is not None - params_is_trialwise = [ - name != "p_outlier" for name in model.model_config.list_params - ] - assert len(params_is_trialwise) == len(model.model_config.list_params) - for name, is_tw in zip(model.model_config.list_params, params_is_trialwise): - if name == "p_outlier": - assert not is_tw, "p_outlier must be non-trialwise" - else: - assert is_tw, f"{name} must be trialwise" - - -def test_rlssm_get_prefix(rldm_data, rlssm_config) -> None: - """_get_prefix must use token-based matching, not substring search. - - - 'rl_alpha_Intercept' → 'rl_alpha' (underscore-containing RL param) - - 'p_outlier_log__' → 'p_outlier' (lapse param via token loop, not substring) - - 'a_Intercept' → 'a' (single-token standard param) - """ - model = RLSSM(data=rldm_data, model_config=rlssm_config) - assert model._get_prefix("rl_alpha_Intercept") == "rl_alpha" - assert model._get_prefix("p_outlier_log__") == "p_outlier" - assert model._get_prefix("p_outlier") == "p_outlier" - assert model._get_prefix("a_Intercept") == "a" - # Fallback: not in params - assert model._get_prefix("unknown_param") == "unknown_param" - - -def test_rlssm_no_lapse(rldm_data, rlssm_config) -> None: - """Setting p_outlier=None should remove p_outlier from params.""" - model = RLSSM(data=rldm_data, model_config=rlssm_config, p_outlier=None) - assert "p_outlier" not in model.params - - -def test_rlssm_model_built(rldm_data, rlssm_config) -> None: - """The bambi model should be built and the computed param 'v' absent from params.""" - model = RLSSM(data=rldm_data, model_config=rlssm_config) - assert model.model is not None - # rl_alpha is a free (sampled) parameter - assert "rl_alpha" in model.params - # v is computed inside the Op; it must NOT appear as a free parameter - assert "v" not in model.params - - -def test_rlssm_extra_fields_are_copies(rldm_data, rlssm_config) -> None: - """extra_fields passed to make_distribution must be independent numpy copies. - - to_numpy(copy=True) should return a new buffer; if it returned a view, - in-place mutations of the DataFrame would silently corrupt the distribution. - """ - model = RLSSM(data=rldm_data, model_config=rlssm_config) - captured: dict = {} - - def capturing_make_distribution(*args, **kwargs): - captured["extra_fields"] = kwargs.get("extra_fields") - return real_make_distribution(*args, **kwargs) - - with patch( - "hssm.rl.rlssm.make_distribution", side_effect=capturing_make_distribution - ): - model._make_model_distribution() - - assert captured.get("extra_fields") is not None - for field_name, arr in zip(rlssm_config.extra_fields, captured["extra_fields"]): - original = model.data[field_name].to_numpy() - assert not np.shares_memory(arr, original), ( - f"extra_fields['{field_name}'] shares memory with the DataFrame — " - "it is a view, not a copy" +class TestRLSSMInit: + """Basic construction, attribute checks, and invalid-input guards at construction time.""" + + def test_rlssm_init(self, rldm_data, rlssm_config) -> None: + """Basic RLSSM initialisation should succeed and return an RLSSM instance.""" + model = RLSSM(data=rldm_data, model_config=rlssm_config) + assert isinstance(model, RLSSM) + assert model.model_config.model_name == "rldm_test" + + def test_rlssm_panel_attrs(self, rldm_data, rlssm_config) -> None: + """n_participants and n_trials should match the fixture data structure.""" + model = RLSSM(data=rldm_data, model_config=rlssm_config) + + n_participants = rldm_data["participant_id"].nunique() + n_trials = len(rldm_data) // n_participants + + assert model.n_participants == n_participants + assert model.n_trials == n_trials + + def test_rlssm_params_keys(self, rldm_data, rlssm_config) -> None: + """model.params should contain exactly list_params + p_outlier.""" + model = RLSSM(data=rldm_data, model_config=rlssm_config) + expected = set(rlssm_config.list_params) | {"p_outlier"} + assert set(model.params.keys()) == expected + + def test_rlssm_unbalanced_raises(self, rldm_data, rlssm_config) -> None: + """Dropping one row should make the panel unbalanced → ValueError.""" + unbalanced = rldm_data.iloc[:-1].copy() + with pytest.raises(ValueError, match="balanced panels"): + RLSSM(data=unbalanced, model_config=rlssm_config) + + def test_rlssm_nan_participant_id_raises(self, rldm_data, rlssm_config) -> None: + """NaN in participant_id column should raise ValueError before groupby silently drops rows.""" + nan_data = rldm_data.copy() + nan_data.loc[nan_data.index[0], "participant_id"] = float("nan") + with pytest.raises(ValueError, match="NaN"): + RLSSM(data=nan_data, model_config=rlssm_config) + + def test_rlssm_missing_ssm_logp_func_raises(self, rldm_data, rlssm_config) -> None: + """RLSSMConfig without ssm_logp_func should raise ValueError on init.""" + bad_config = RLSSMConfig( + model_name="rldm_bad", + loglik_kind="approx_differentiable", + decision_process="angle", + decision_process_loglik_kind="approx_differentiable", + learning_process_kind="blackbox", + list_params=rlssm_config.list_params, + params_default=rlssm_config.params_default, + bounds=rlssm_config.bounds, + learning_process=rlssm_config.learning_process, + response=list(rlssm_config.response), + choices=list(rlssm_config.choices), + extra_fields=list(rlssm_config.extra_fields), + # ssm_logp_func intentionally omitted → defaults to None ) + with pytest.raises(ValueError, match="ssm_logp_func"): + RLSSM(data=rldm_data, model_config=bad_config) + + def test_rlssm_unannotated_ssm_logp_func_raises( + self, rldm_data, rlssm_config + ) -> None: + """A plain callable without @annotate_function attrs should raise ValueError.""" + bad_config = RLSSMConfig( + model_name="rldm_bad", + loglik_kind="approx_differentiable", + decision_process="angle", + decision_process_loglik_kind="approx_differentiable", + learning_process_kind="blackbox", + list_params=rlssm_config.list_params, + params_default=rlssm_config.params_default, + bounds=rlssm_config.bounds, + learning_process=rlssm_config.learning_process, + response=list(rlssm_config.response), + choices=list(rlssm_config.choices), + extra_fields=list(rlssm_config.extra_fields), + ssm_logp_func=lambda x: x, # callable but no .inputs/.outputs/.computed + ) + with pytest.raises(ValueError, match="annotate_function"): + RLSSM(data=rldm_data, model_config=bad_config) + + def test_rlssm_missing_data_raises(self, rldm_data, rlssm_config) -> None: + """Passing missing_data!=False should raise NotImplementedError with 'missing_data' in msg.""" + with pytest.raises(NotImplementedError, match="missing_data"): + RLSSM(data=rldm_data, model_config=rlssm_config, missing_data=True) + + def test_rlssm_deadline_raises(self, rldm_data, rlssm_config) -> None: + """Passing deadline!=False should raise NotImplementedError with 'deadline' in msg.""" + with pytest.raises(NotImplementedError, match="deadline"): + RLSSM(data=rldm_data, model_config=rlssm_config, deadline=True) + + +class TestRLSSMModelStructure: + """Internal model anatomy after construction: params, prefix, lapse, bambi/pymc, extra_fields.""" + + def test_rlssm_params_is_trialwise_aligned(self, rldm_data, rlssm_config) -> None: + """params_is_trialwise must align with list_params (same length, p_outlier=False).""" + model = RLSSM(data=rldm_data, model_config=rlssm_config) + assert model.model_config.list_params is not None + params_is_trialwise = [ + name != "p_outlier" for name in model.model_config.list_params + ] + assert len(params_is_trialwise) == len(model.model_config.list_params) + for name, is_tw in zip(model.model_config.list_params, params_is_trialwise): + if name == "p_outlier": + assert not is_tw, "p_outlier must be non-trialwise" + else: + assert is_tw, f"{name} must be trialwise" + + def test_rlssm_get_prefix(self, rldm_data, rlssm_config) -> None: + """_get_prefix must use token-based matching, not substring search. + + - 'rl_alpha_Intercept' → 'rl_alpha' (underscore-containing RL param) + - 'p_outlier_log__' → 'p_outlier' (lapse param via token loop, not substring) + - 'a_Intercept' → 'a' (single-token standard param) + """ + model = RLSSM(data=rldm_data, model_config=rlssm_config) + assert model._get_prefix("rl_alpha_Intercept") == "rl_alpha" + assert model._get_prefix("p_outlier_log__") == "p_outlier" + assert model._get_prefix("p_outlier") == "p_outlier" + assert model._get_prefix("a_Intercept") == "a" + # Fallback: not in params + assert model._get_prefix("unknown_param") == "unknown_param" + + def test_rlssm_no_lapse(self, rldm_data, rlssm_config) -> None: + """Setting p_outlier=None should remove p_outlier from params.""" + model = RLSSM(data=rldm_data, model_config=rlssm_config, p_outlier=None) + assert "p_outlier" not in model.params + + def test_rlssm_model_built(self, rldm_data, rlssm_config) -> None: + """The bambi model should be built and the computed param 'v' absent from params.""" + model = RLSSM(data=rldm_data, model_config=rlssm_config) + assert model.model is not None + # rl_alpha is a free (sampled) parameter + assert "rl_alpha" in model.params + # v is computed inside the Op; it must NOT appear as a free parameter + assert "v" not in model.params + + def test_rlssm_extra_fields_are_copies(self, rldm_data, rlssm_config) -> None: + """extra_fields passed to make_distribution must be independent numpy copies. + + to_numpy(copy=True) should return a new buffer; if it returned a view, + in-place mutations of the DataFrame would silently corrupt the distribution. + """ + model = RLSSM(data=rldm_data, model_config=rlssm_config) + captured: dict = {} + + def capturing_make_distribution(*args, **kwargs): + captured["extra_fields"] = kwargs.get("extra_fields") + return real_make_distribution(*args, **kwargs) + + with patch( + "hssm.rl.rlssm.make_distribution", side_effect=capturing_make_distribution + ): + model._make_model_distribution() + + assert captured.get("extra_fields") is not None + for field_name, arr in zip(rlssm_config.extra_fields, captured["extra_fields"]): + original = model.data[field_name].to_numpy() + assert not np.shares_memory(arr, original), ( + f"extra_fields['{field_name}'] shares memory with the DataFrame — " + "it is a view, not a copy" + ) + + def test_rlssm_pymc_model(self, rldm_data, rlssm_config) -> None: + """pymc_model should be accessible after model construction.""" + model = RLSSM(data=rldm_data, model_config=rlssm_config) + assert model.pymc_model is not None + + def test_rlssm_no_extra_fields_none_passed_to_make_distribution( + self, rldm_data, rlssm_config + ) -> None: + """When extra_fields is empty, make_distribution receives extra_fields=None.""" + config_no_extra = RLSSMConfig( + model_name="rldm_no_extra", + loglik_kind="approx_differentiable", + decision_process="angle", + decision_process_loglik_kind="approx_differentiable", + learning_process_kind="blackbox", + list_params=rlssm_config.list_params, + params_default=rlssm_config.params_default, + bounds=rlssm_config.bounds, + learning_process=rlssm_config.learning_process, + response=list(rlssm_config.response), + choices=list(rlssm_config.choices), + extra_fields=[], # empty → should pass None to make_distribution + ssm_logp_func=_dummy_ssm_logp, + ) + model = RLSSM(data=rldm_data, model_config=config_no_extra) + captured: dict = {} + + def capturing_make_distribution(*args, **kwargs): + captured["extra_fields"] = kwargs.get("extra_fields") + return real_make_distribution(*args, **kwargs) + + with patch( + "hssm.rl.rlssm.make_distribution", side_effect=capturing_make_distribution + ): + model._make_model_distribution() + + assert captured.get("extra_fields") is None + + +class TestRLSSMSerialization: + """Cloudpickle serialisation and deserialisation.""" + + def test_rlssm_pickle_round_trip( + self, rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig + ) -> None: + """Cloudpickle round-trip must reconstruct an equivalent RLSSM. + + Verifies that __getstate__ / __setstate__ survive serialisation: + - The reconstructed object is a fresh RLSSM (not the same instance). + - n_participants and n_trials are preserved. + - list_params (including p_outlier) are preserved. + - model_config.model_name is preserved. + - model.model (bambi model) is rebuilt, confirming full re-initialisation. + """ + model = RLSSM(data=rldm_data, model_config=rlssm_config) + blob = cloudpickle.dumps(model) + restored = cloudpickle.loads(blob) + + assert restored is not model + assert isinstance(restored, RLSSM) + assert restored.n_participants == model.n_participants + assert restored.n_trials == model.n_trials + assert restored.list_params == model.list_params + assert restored.model_config.model_name == model.model_config.model_name + assert restored.model is not None + + +class TestRLSSMSampling: + """Slow sampling smoke tests.""" + + @pytest.mark.slow + def test_rlssm_sample_smoke(self, rldm_data, rlssm_config) -> None: + """Minimal sampling run should return an InferenceData object.""" + model = RLSSM(data=rldm_data, model_config=rlssm_config) + trace = model.sample( + draws=4, tune=50, chains=1, cores=1, sampler="numpyro", target_accept=0.9 + ) + assert trace is not None -def test_rlssm_pymc_model(rldm_data, rlssm_config) -> None: - """pymc_model should be accessible after model construction.""" - model = RLSSM(data=rldm_data, model_config=rlssm_config) - assert model.pymc_model is not None - - -# --------------------------------------------------------------------------- -# Slow sampling smoke test -# --------------------------------------------------------------------------- - - -@pytest.mark.slow -def test_rlssm_sample_smoke(rldm_data, rlssm_config) -> None: - """Minimal sampling run should return an InferenceData object.""" - model = RLSSM(data=rldm_data, model_config=rlssm_config) - trace = model.sample( - draws=4, tune=50, chains=1, cores=1, sampler="numpyro", target_accept=0.9 - ) - assert trace is not None - - -def test_rlssm_pickle_round_trip( - rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig -) -> None: - """Cloudpickle round-trip must reconstruct an equivalent RLSSM. - - Verifies that __getstate__ / __setstate__ survive serialisation: - - The reconstructed object is a fresh RLSSM (not the same instance). - - n_participants and n_trials are preserved. - - list_params (including p_outlier) are preserved. - - model_config.model_name is preserved. - - model.model (bambi model) is rebuilt, confirming full re-initialisation. - """ - model = RLSSM(data=rldm_data, model_config=rlssm_config) - blob = cloudpickle.dumps(model) - restored = cloudpickle.loads(blob) - - assert restored is not model - assert isinstance(restored, RLSSM) - assert restored.n_participants == model.n_participants - assert restored.n_trials == model.n_trials - assert restored.list_params == model.list_params - assert restored.model_config.model_name == model.model_config.model_name - assert restored.model is not None - - -# --------------------------------------------------------------------------- -# Simplified public interface tests (new RLSSM wrapper) -# --------------------------------------------------------------------------- - - -def test_rlssm_is_subclass_of_internal() -> None: - """RLSSM must be a subclass of _RLSSM.""" - assert issubclass(RLSSM, _RLSSM) - - -@pytest.mark.parametrize( - "model_name, expected_dp", - [ - ("2AB_RescorlaWagner_DDM", "ddm"), - ("2AB_RescorlaWagner_Angle", "angle"), - ("2AB_RescorlaWagner_Weibull", "weibull"), - ], -) -def test_rlssm_builtin_models_instantiate( - rldm_data, model_name: str, expected_dp: str -) -> None: - """All built-in 2AB_RescorlaWagner_* models should instantiate correctly.""" - model = RLSSM(data=rldm_data, model=model_name) - assert isinstance(model, RLSSM) - assert model.model_config.decision_process == expected_dp - assert "rl_alpha" in model.params - assert "scaler" in model.params - assert "a" in model.params - assert "v" not in model.params - - -def test_rlssm_default_model_is_ddm(rldm_data) -> None: - """Omitting model should default to '2AB_RescorlaWagner_DDM'.""" - model = RLSSM(data=rldm_data) - assert isinstance(model, RLSSM) - assert model.model_config.decision_process == "ddm" - - -def test_rlssm_model_config_provided(rldm_data, rlssm_config) -> None: - """Passing model_config= directly should bypass the registry.""" - model = RLSSM(data=rldm_data, model_config=rlssm_config) - assert model.model_config.model_name == rlssm_config.model_name - - -def test_rlssm_unregistered_model_raises(rldm_data) -> None: - """Using an unknown model name should raise ValueError.""" - with pytest.raises(ValueError, match="not found in the RLSSM registry"): - RLSSM(data=rldm_data, model="model_not_in_registry") - - -def test_rlssm_missing_data_property_raises(rldm_data) -> None: - """Accessing .missing_data on a built RLSSM instance must raise NotImplementedError.""" - model = RLSSM(data=rldm_data, model="2AB_RescorlaWagner_DDM") - with pytest.raises(NotImplementedError, match="missing_data"): - _ = model.missing_data - - -def test_rlssm_deadline_property_raises(rldm_data) -> None: - """Accessing .deadline on a built RLSSM instance must raise NotImplementedError.""" - model = RLSSM(data=rldm_data, model="2AB_RescorlaWagner_DDM") - with pytest.raises(NotImplementedError, match="deadline"): - _ = model.deadline - - -def test_rlssm_loglik_missing_data_property_raises(rldm_data) -> None: - """Accessing .loglik_missing_data on a built RLSSM instance must raise NotImplementedError.""" - model = RLSSM(data=rldm_data, model="2AB_RescorlaWagner_DDM") - with pytest.raises(NotImplementedError, match="loglik_missing_data"): - _ = model.loglik_missing_data +class TestRLSSMSimplifiedInterface: + """Public model= kwarg API: registry lookups, register_rlssm_model, unsupported-feature properties.""" + def test_rlssm_is_subclass_of_internal(self) -> None: + """RLSSM must be a subclass of _RLSSM.""" + assert issubclass(RLSSM, _RLSSM) -def test_register_rlssm_model(rldm_data) -> None: - """A user-registered model should be instantiable via the simplified interface.""" - # Re-use the existing annotated learning function and ssm logp from the - # module-level helpers defined at the top of this test file. - register_rlssm_model( - name="rldm_custom_test", - decision_process="angle", - learning_process={"v": _compute_v_annotated}, - rl_params=["rl_alpha", "scaler"], - rl_bounds={"rl_alpha": (0.0, 1.0), "scaler": (0.0, 10.0)}, - rl_params_default=[0.1, 1.0], - extra_fields=["feedback"], - choices=[0, 1], - ) - model = RLSSM(data=rldm_data, model="rldm_custom_test") - assert isinstance(model, RLSSM) - assert "rl_alpha" in model.params - assert "v" not in model.params - - -def test_rlssm_init_args_uses_simplified_interface(rldm_data) -> None: - """_init_args should reflect the simplified constructor, not model_config.""" - model = RLSSM(data=rldm_data, model="2AB_RescorlaWagner_DDM") - assert "model" in model._init_args - assert model._init_args["model"] == "2AB_RescorlaWagner_DDM" - # model_config should not be baked in as a hard reference - assert "model_config" in model._init_args - assert model._init_args["model_config"] is None - - -def test_rlssm_model_config_with_overrides_warns( - rldm_data, rlssm_config, caplog -) -> None: - """Warn when model_config is given alongside model/overrides. - - The extra arguments (model, learning_process, decision_process, choices) - should be ignored and a warning emitted. - """ - with caplog.at_level(logging.WARNING, logger="hssm"): - RLSSM(data=rldm_data, model_config=rlssm_config, model="some_other_model") - - assert any("ignoring" in r.message for r in caplog.records) - - -def test_rlssm_no_extra_fields_none_passed_to_make_distribution( - rldm_data, rlssm_config -) -> None: - """When extra_fields is empty, make_distribution receives extra_fields=None.""" - # Build a config without extra_fields. - config_no_extra = RLSSMConfig( - model_name="rldm_no_extra", - loglik_kind="approx_differentiable", - decision_process="angle", - decision_process_loglik_kind="approx_differentiable", - learning_process_kind="blackbox", - list_params=rlssm_config.list_params, - params_default=rlssm_config.params_default, - bounds=rlssm_config.bounds, - learning_process=rlssm_config.learning_process, - response=list(rlssm_config.response), - choices=list(rlssm_config.choices), - extra_fields=[], # empty → should pass None to make_distribution - ssm_logp_func=_dummy_ssm_logp, + @pytest.mark.parametrize( + "model_name, expected_dp", + [ + ("2AB_RescorlaWagner_DDM", "ddm"), + ("2AB_RescorlaWagner_Angle", "angle"), + ("2AB_RescorlaWagner_Weibull", "weibull"), + ], ) - model = RLSSM(data=rldm_data, model_config=config_no_extra) - captured: dict = {} - - def capturing_make_distribution(*args, **kwargs): - captured["extra_fields"] = kwargs.get("extra_fields") - return real_make_distribution(*args, **kwargs) - - with patch( - "hssm.rl.rlssm.make_distribution", side_effect=capturing_make_distribution - ): - model._make_model_distribution() - - assert captured.get("extra_fields") is None + def test_rlssm_builtin_models_instantiate( + self, rldm_data, model_name: str, expected_dp: str + ) -> None: + """All built-in 2AB_RescorlaWagner_* models should instantiate correctly.""" + model = RLSSM(data=rldm_data, model=model_name) + assert isinstance(model, RLSSM) + assert model.model_config.decision_process == expected_dp + assert "rl_alpha" in model.params + assert "scaler" in model.params + assert "a" in model.params + assert "v" not in model.params + + def test_rlssm_default_model_is_ddm(self, rldm_data) -> None: + """Omitting model should default to '2AB_RescorlaWagner_DDM'.""" + model = RLSSM(data=rldm_data) + assert isinstance(model, RLSSM) + assert model.model_config.decision_process == "ddm" + + def test_rlssm_model_config_provided(self, rldm_data, rlssm_config) -> None: + """Passing model_config= directly should bypass the registry.""" + model = RLSSM(data=rldm_data, model_config=rlssm_config) + assert model.model_config.model_name == rlssm_config.model_name + + def test_rlssm_unregistered_model_raises(self, rldm_data) -> None: + """Using an unknown model name should raise ValueError.""" + with pytest.raises(ValueError, match="not found in the RLSSM registry"): + RLSSM(data=rldm_data, model="model_not_in_registry") + + def test_rlssm_missing_data_property_raises(self, rldm_data) -> None: + """Accessing .missing_data on a built RLSSM instance must raise NotImplementedError.""" + model = RLSSM(data=rldm_data, model="2AB_RescorlaWagner_DDM") + with pytest.raises(NotImplementedError, match="missing_data"): + _ = model.missing_data + + def test_rlssm_deadline_property_raises(self, rldm_data) -> None: + """Accessing .deadline on a built RLSSM instance must raise NotImplementedError.""" + model = RLSSM(data=rldm_data, model="2AB_RescorlaWagner_DDM") + with pytest.raises(NotImplementedError, match="deadline"): + _ = model.deadline + + def test_rlssm_loglik_missing_data_property_raises(self, rldm_data) -> None: + """Accessing .loglik_missing_data on a built RLSSM instance must raise NotImplementedError.""" + model = RLSSM(data=rldm_data, model="2AB_RescorlaWagner_DDM") + with pytest.raises(NotImplementedError, match="loglik_missing_data"): + _ = model.loglik_missing_data + + def test_register_rlssm_model(self, rldm_data) -> None: + """A user-registered model should be instantiable via the simplified interface.""" + # Re-use the existing annotated learning function and ssm logp from the + # module-level helpers defined at the top of this test file. + register_rlssm_model( + name="rldm_custom_test", + decision_process="angle", + learning_process={"v": _compute_v_annotated}, + rl_params=["rl_alpha", "scaler"], + rl_bounds={"rl_alpha": (0.0, 1.0), "scaler": (0.0, 10.0)}, + rl_params_default=[0.1, 1.0], + extra_fields=["feedback"], + choices=[0, 1], + ) + model = RLSSM(data=rldm_data, model="rldm_custom_test") + assert isinstance(model, RLSSM) + assert "rl_alpha" in model.params + assert "v" not in model.params + + def test_rlssm_init_args_uses_simplified_interface(self, rldm_data) -> None: + """_init_args should reflect the simplified constructor, not model_config.""" + model = RLSSM(data=rldm_data, model="2AB_RescorlaWagner_DDM") + assert "model" in model._init_args + assert model._init_args["model"] == "2AB_RescorlaWagner_DDM" + # model_config should not be baked in as a hard reference + assert "model_config" in model._init_args + assert model._init_args["model_config"] is None + + def test_rlssm_model_config_with_overrides_warns( + self, rldm_data, rlssm_config, caplog + ) -> None: + """Warn when model_config is given alongside model/overrides. + + The extra arguments (model, learning_process, decision_process, choices) + should be ignored and a warning emitted. + """ + with caplog.at_level(logging.WARNING, logger="hssm"): + RLSSM(data=rldm_data, model_config=rlssm_config, model="some_other_model") + + assert any("ignoring" in r.message for r in caplog.records) From 1e987d5fa7f8fe9563500bab07fa7144638d17cc Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Fri, 15 May 2026 15:06:55 -0400 Subject: [PATCH 31/41] Improve model_config handling in RLSSM class --- src/hssm/rl/rlssm.py | 31 +++++++++++++++---------------- 1 file changed, 15 insertions(+), 16 deletions(-) diff --git a/src/hssm/rl/rlssm.py b/src/hssm/rl/rlssm.py index 3e06db5e..bdb12413 100644 --- a/src/hssm/rl/rlssm.py +++ b/src/hssm/rl/rlssm.py @@ -340,7 +340,7 @@ class RLSSM(_RLSSM): ---------- data : pd.DataFrame Trial-level data (balanced panel required). - model : str, optional + model : str | None, optional Name of a registered RLSSM model. Defaults to ``"2AB_RescorlaWagner_DDM"``. choices : list[int] | None, optional Override the choice values in the registry. ``None`` uses the registry @@ -402,7 +402,7 @@ class RLSSM(_RLSSM): def __init__( self, data: pd.DataFrame, - model: str = "2AB_RescorlaWagner_DDM", + model: str | None = None, choices: list[int] | None = None, include: list[dict[str, Any] | Any] | None = None, model_config: RLSSMConfig | None = None, @@ -423,22 +423,21 @@ def __init__( # NOTE: _store_init_args only operates on its arguments, not on self. _my_init_args = self._store_init_args(locals(), kwargs) - if model_config is not None: - if model != "2AB_RescorlaWagner_DDM" or any( - x is not None for x in [learning_process, decision_process, choices] - ): - _logger.warning( - "model_config was provided; ignoring model, learning_process, " - "decision_process, and choices arguments." - ) - else: - model_config = get_rlssm_model_config( - model=model, - choices=choices, - learning_process=learning_process, - decision_process=decision_process, + if model_config is not None and any( + x is not None for x in [model, learning_process, decision_process, choices] + ): + _logger.warning( + "model_config was provided; ignoring model, learning_process, " + "decision_process, and choices arguments." ) + model_config = model_config or get_rlssm_model_config( + model=model or "2AB_RescorlaWagner_DDM", + choices=choices, + learning_process=learning_process, + decision_process=decision_process, + ) + # missing_data / deadline are intentionally omitted — _RLSSM defaults # them to False. The _BlockedAttribute descriptors on this class # silently accept writes from the base-class init path and block reads From 80d3c98600c90a95ce13b107d272be98bbaf134c Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Fri, 15 May 2026 15:22:14 -0400 Subject: [PATCH 32/41] fix: update logger initialization to use module name for better context --- src/hssm/rl/registry.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/hssm/rl/registry.py b/src/hssm/rl/registry.py index abdcfdc7..e3e9e6bb 100644 --- a/src/hssm/rl/registry.py +++ b/src/hssm/rl/registry.py @@ -29,7 +29,7 @@ from .config import RLSSMConfig -_logger = logging.getLogger("hssm") +_logger = logging.getLogger(__name__) # --------------------------------------------------------------------------- # Default annotated Rescorla-Wagner learning function From abd86dc1628473b7ba4031a2bf7fd6f41b541f5d Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Fri, 15 May 2026 15:23:01 -0400 Subject: [PATCH 33/41] fix: simplify parameter default calculation in _build_ssm_spec_from_modelconfig --- src/hssm/rl/registry.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/hssm/rl/registry.py b/src/hssm/rl/registry.py index e3e9e6bb..ff8b2079 100644 --- a/src/hssm/rl/registry.py +++ b/src/hssm/rl/registry.py @@ -116,8 +116,7 @@ def _build_ssm_spec_from_modelconfig(name: str) -> dict[str, Any]: # Derive parameter defaults as midpoints of their respective bounds. params_default_ssm = [ - (bounds_ssm[p][0] + bounds_ssm[p][1]) / 2.0 if p in bounds_ssm else 0.0 - for p in list_params_ssm + sum(bounds_ssm[p]) / 2 if p in bounds_ssm else 0.0 for p in list_params_ssm ] # Capture loop variables explicitly to avoid closure-over-variable issues. From aac4f6fa2ee0954a5cd06de273dcb4b9beeb523a Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Fri, 15 May 2026 15:29:37 -0400 Subject: [PATCH 34/41] fix: optimize SSM logp retrieval by caching and restructuring conditionals --- src/hssm/rl/registry.py | 27 +++++++++++++++------------ 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/src/hssm/rl/registry.py b/src/hssm/rl/registry.py index ff8b2079..2a8f0457 100644 --- a/src/hssm/rl/registry.py +++ b/src/hssm/rl/registry.py @@ -167,18 +167,21 @@ def _get_ssm_logp(name: str) -> Any: ONNX models are downloaded / loaded only when first called (lazy initialisation). Subsequent calls return the cached object. """ - if name not in _SSM_LOGP_CACHE: - if name in _SSM_REGISTRY: - entry = _SSM_REGISTRY[name] - if "ssm_base_logp_func_factory" in entry: - _SSM_LOGP_CACHE[name] = entry["ssm_base_logp_func_factory"]() - else: - # Pre-built function registered via register_ssm(). - _SSM_LOGP_CACHE[name] = entry["ssm_base_logp_func"] - else: - # Build from HSSM's modelconfig for built-in SSMs. - spec = _build_ssm_spec_from_modelconfig(name) - _SSM_LOGP_CACHE[name] = spec["ssm_base_logp_func_factory"]() + if name in _SSM_LOGP_CACHE: + return _SSM_LOGP_CACHE[name] + + if name not in _SSM_REGISTRY: + # Build from HSSM's modelconfig for built-in SSMs. + spec = _build_ssm_spec_from_modelconfig(name) + _SSM_LOGP_CACHE[name] = spec["ssm_base_logp_func_factory"]() + return _SSM_LOGP_CACHE[name] + + entry = _SSM_REGISTRY[name] + if "ssm_base_logp_func_factory" in entry: + _SSM_LOGP_CACHE[name] = entry["ssm_base_logp_func_factory"]() + else: + # Pre-built function registered via register_ssm(). + _SSM_LOGP_CACHE[name] = entry["ssm_base_logp_func"] return _SSM_LOGP_CACHE[name] From 1042c70fef04ae8f8a8d6bb1915d510f63e55776 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Fri, 15 May 2026 15:35:50 -0400 Subject: [PATCH 35/41] fix: streamline parameter filtering in get_rlssm_model_config by using lp directly --- src/hssm/rl/registry.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/hssm/rl/registry.py b/src/hssm/rl/registry.py index 2a8f0457..8cc95e4c 100644 --- a/src/hssm/rl/registry.py +++ b/src/hssm/rl/registry.py @@ -378,8 +378,7 @@ def get_rlssm_model_config( ssm_logp_func = _build_ssm_logp_func(ssm_base, lp) # list_params = [sampled RL params] + [sampled SSM params (non-computed)] - computed_set = set(lp.keys()) - ssm_sampled = [p for p in ssm_entry["list_params_ssm"] if p not in computed_set] + ssm_sampled = [p for p in ssm_entry["list_params_ssm"] if p not in lp] # Defensive copy of response to prevent downstream mutation of registry. response = list(ssm_entry["response"]) @@ -418,7 +417,7 @@ def get_rlssm_model_config( ssm_sampled_defaults = [ ssm_all_defaults[i] for i, p in enumerate(ssm_entry["list_params_ssm"]) - if p not in computed_set + if p not in lp ] params_default = rl_defaults + ssm_sampled_defaults From 532062189de6f2785d62ce65937b88a4d7746b93 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Fri, 15 May 2026 16:11:39 -0400 Subject: [PATCH 36/41] fix: enhance RLSSM registry tests with improved error handling and caching logic --- tests/rl/test_registry.py | 1005 ++++++++++++++++++------------------- 1 file changed, 482 insertions(+), 523 deletions(-) diff --git a/tests/rl/test_registry.py b/tests/rl/test_registry.py index 5a4c6bc7..82273cab 100644 --- a/tests/rl/test_registry.py +++ b/tests/rl/test_registry.py @@ -7,6 +7,7 @@ from __future__ import annotations +import logging from copy import deepcopy from typing import Any @@ -53,442 +54,312 @@ def base_logp(v, a, rt, response): return base_logp -def test_get_ssm_logp_builds_lazy_factory_once(annotated_ssm_base_logp: Any) -> None: - """Lazy SSM factories should only build and cache one function instance.""" - call_count = 0 +class TestBuildSsmSpecFromModelconfig: + def test_unknown_model_raises(self) -> None: + """Completely unknown names should raise ValueError from the modelconfig bridge.""" + with pytest.raises(ValueError, match="not a registered custom SSM"): + registry._build_ssm_spec_from_modelconfig("totally_unknown_model") + + def test_no_approx_differentiable_raises( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + """Models without an approx_differentiable likelihood must be rejected.""" + import hssm.modelconfig as mc_module + + def _fake_get(name): # type: ignore[no-untyped-def] + return { + "list_params": ["v", "a"], + "response": ["rt", "response"], + "likelihoods": { + "analytical": { + "loglik": lambda: None, + "bounds": {}, + "backend": None, + }, + }, + } + + monkeypatch.setattr(mc_module, "get_default_model_config", _fake_get) + with pytest.raises(ValueError, match="no approx_differentiable likelihood"): + registry._build_ssm_spec_from_modelconfig("ddm") + + def test_factory_calls_onnx_loader( + self, + monkeypatch: pytest.MonkeyPatch, + annotated_ssm_base_logp: Any, + ) -> None: + """The lazy factory should call make_jax_matrix_logp_funcs_from_onnx with the + correct filename.""" + import hssm.distribution_utils.onnx as onnx_module + + called_with: list[str] = [] + + def _fake_onnx(model: str) -> Any: + called_with.append(model) + return annotated_ssm_base_logp + + monkeypatch.setattr( + onnx_module, "make_jax_matrix_logp_funcs_from_onnx", _fake_onnx + ) + monkeypatch.setattr( + registry, "make_jax_matrix_logp_funcs_from_onnx", _fake_onnx + ) - def factory() -> Any: - nonlocal call_count - call_count += 1 - return annotated_ssm_base_logp + spec = registry._build_ssm_spec_from_modelconfig("angle") + result = spec["ssm_base_logp_func_factory"]() - registry._SSM_REGISTRY["lazy_unit_test_ssm"] = { - "ssm_base_logp_func_factory": factory, - "list_params_ssm": ["v", "a"], - "bounds_ssm": {"a": (0.3, 3.0)}, - "params_default_ssm": [0.0, 1.5], - "response": ["rt", "response"], - } + assert called_with == ["angle.onnx"] + assert callable(result) + assert result.inputs == ["v", "a", "z", "t", "theta", "rt", "response"] + assert result.outputs == ["logp"] - first = registry._get_ssm_logp("lazy_unit_test_ssm") - second = registry._get_ssm_logp("lazy_unit_test_ssm") - assert first is annotated_ssm_base_logp - assert second is first - assert call_count == 1 +class TestGetSsmLogp: + def test_builds_lazy_factory_once(self, annotated_ssm_base_logp: Any) -> None: + """Lazy SSM factories should only build and cache one function instance.""" + call_count = 0 + def factory() -> Any: + nonlocal call_count + call_count += 1 + return annotated_ssm_base_logp -def test_derive_rl_params_excludes_response_and_extra_fields( - learning_process: dict[str, Any], -) -> None: - """Derived RL params should ignore response columns and extra fields.""" - derived = registry._derive_rl_params( - learning_process=learning_process, - response=["rt", "response"], - extra_fields=["feedback"], - ) + registry._SSM_REGISTRY["lazy_unit_test_ssm"] = { + "ssm_base_logp_func_factory": factory, + "list_params_ssm": ["v", "a"], + "bounds_ssm": {"a": (0.3, 3.0)}, + "params_default_ssm": [0.0, 1.5], + "response": ["rt", "response"], + } - assert derived == ["rl_alpha"] - - -def test_get_rlssm_model_config_builds_expected_config( - annotated_ssm_base_logp: Any, - learning_process: dict[str, Any], -) -> None: - """Registry config composition should exclude computed SSM params.""" - registry.register_ssm( - name="unit_test_ssm", - ssm_base_logp_func=annotated_ssm_base_logp, - list_params_ssm=["v", "a"], - bounds_ssm={"a": (0.3, 3.0)}, - params_default_ssm=[0.0, 1.5], - response=["rt", "response"], - ) - registry.register_rlssm_model( - name="unit_test_model", - decision_process="unit_test_ssm", - learning_process=learning_process, - rl_params=["rl_alpha"], - rl_bounds={"rl_alpha": (0.0, 1.0)}, - rl_params_default=[0.2], - extra_fields=["feedback"], - choices=[0, 1], - ) + first = registry._get_ssm_logp("lazy_unit_test_ssm") + second = registry._get_ssm_logp("lazy_unit_test_ssm") + + assert first is annotated_ssm_base_logp + assert second is first + assert call_count == 1 + + def test_resolves_builtin_via_modelconfig( + self, + monkeypatch: pytest.MonkeyPatch, + annotated_ssm_base_logp: Any, + ) -> None: + """_get_ssm_logp should build and cache the logp for a built-in SSM.""" + import hssm.distribution_utils.onnx as onnx_module + + monkeypatch.setattr( + onnx_module, + "make_jax_matrix_logp_funcs_from_onnx", + lambda **_: annotated_ssm_base_logp, + ) + monkeypatch.setattr( + registry, + "make_jax_matrix_logp_funcs_from_onnx", + lambda **_: annotated_ssm_base_logp, + ) - config = registry.get_rlssm_model_config("unit_test_model") - - assert isinstance(config, RLSSMConfig) - assert config.list_params == ["rl_alpha", "a"] - assert config.bounds == {"rl_alpha": (0.0, 1.0), "a": (0.3, 3.0)} - assert config.params_default == [0.2, 1.5] - assert config.response == ["rt", "response"] - assert config.ssm_logp_func.computed == learning_process - - config.response.append("mutated") - assert registry._SSM_REGISTRY["unit_test_ssm"]["response"] == ["rt", "response"] - - -def test_get_rlssm_model_config_respects_explicit_empty_rl_fields( - annotated_ssm_base_logp: Any, - learning_process: dict[str, Any], -) -> None: - """Explicit empty RL collections must not trigger fallback derivation.""" - registry.register_ssm( - name="empty_rl_ssm", - ssm_base_logp_func=annotated_ssm_base_logp, - list_params_ssm=["v", "a"], - bounds_ssm={"a": (0.3, 3.0)}, - params_default_ssm=[0.0, 1.5], - response=["rt", "response"], - ) - registry._RLSSM_REGISTRY["empty_rl_model"] = { - "decision_process": "empty_rl_ssm", - "learning_process": learning_process, - "rl_params": [], - "rl_bounds": {}, - "rl_params_default": [], - "extra_fields": ["feedback"], - "choices": [0, 1], - "description": "test model", - "decision_process_loglik_kind": "approx_differentiable", - "learning_process_kind": "blackbox", - } - - config = registry.get_rlssm_model_config("empty_rl_model") - - assert config.list_params == ["a"] - assert config.bounds == {"a": (0.3, 3.0)} - assert config.params_default == [1.5] - - -def test_register_rlssm_model_copies_mutable_inputs( - learning_process: dict[str, Any], -) -> None: - """Caller mutations after registration must not alter the stored model.""" - rl_params = ["rl_alpha"] - rl_bounds = {"rl_alpha": (0.0, 1.0)} - rl_defaults = [0.2] - extra_fields = ["feedback"] - choices = [0, 1] - - registry.register_rlssm_model( - name="copy_test_model", - decision_process="angle", - learning_process=learning_process, - rl_params=rl_params, - rl_bounds=rl_bounds, - rl_params_default=rl_defaults, - extra_fields=extra_fields, - choices=choices, - ) + result = registry._get_ssm_logp("angle") + + assert callable(result) + assert "angle" in registry._SSM_LOGP_CACHE + # Second call returns cached value without re-building. + assert registry._get_ssm_logp("angle") is result + + +class TestBuildSsmLogpFunc: + def test_raises_if_already_computed( + self, + annotated_ssm_base_logp: Any, + learning_process: dict[str, Any], + ) -> None: + """_build_ssm_logp_func must refuse functions that already have .computed set.""" + precomputed = annotate_function( + inputs=annotated_ssm_base_logp.inputs, + outputs=annotated_ssm_base_logp.outputs, + computed=learning_process, + )(annotated_ssm_base_logp) + + with pytest.raises(ValueError, match="already has a non-empty .computed"): + registry._build_ssm_logp_func(precomputed, learning_process) + + +class TestDeriveRlParams: + def test_excludes_response_and_extra_fields( + self, + learning_process: dict[str, Any], + ) -> None: + """Derived RL params should ignore response columns and extra fields.""" + derived = registry._derive_rl_params( + learning_process=learning_process, + response=["rt", "response"], + extra_fields=["feedback"], + ) - rl_params.append("scaler") - rl_bounds["scaler"] = (0.0, 10.0) - rl_defaults.append(1.0) - extra_fields.append("trial") - choices.append(2) - # Built-in SSMs (e.g. "angle") are derived from modelconfig on each call - # and are never stored as shared mutable state in _SSM_REGISTRY, so there - # is no registry entry to corrupt here. - learning_process["other"] = next(iter(learning_process.values())) - - stored = registry._RLSSM_REGISTRY["copy_test_model"] - - assert stored["decision_process"]["name"] == "angle" - assert stored["decision_process"]["response"] == ["rt", "response"] - assert stored["rl_params"] == ["rl_alpha"] - assert stored["rl_bounds"] == {"rl_alpha": (0.0, 1.0)} - assert stored["rl_params_default"] == [0.2] - assert stored["extra_fields"] == ["feedback"] - assert stored["choices"] == [0, 1] - assert list(stored["learning_process"]) == ["v"] - - -def test_register_ssm_caches_prebuilt_function( - annotated_ssm_base_logp: Any, -) -> None: - """Registering a pre-built SSM should populate the cache immediately.""" - registry.register_ssm( - name="cached_ssm", - ssm_base_logp_func=annotated_ssm_base_logp, - list_params_ssm=["v", "a"], - bounds_ssm={"a": (0.3, 3.0)}, - params_default_ssm=[0.0, 1.5], - ) + assert derived == ["rl_alpha"] - assert registry._SSM_LOGP_CACHE["cached_ssm"] is annotated_ssm_base_logp - assert registry._SSM_REGISTRY["cached_ssm"]["response"] == ["rt", "response"] + def test_warns_for_unannotated_lp_func(self) -> None: + """A learning-process function without .inputs should log a warning and be skipped.""" + def unannotated_func(x): # type: ignore[no-untyped-def] + return x + + lp = {"v": unannotated_func} + result = registry._derive_rl_params( + learning_process=lp, + response=["rt", "response"], + extra_fields=["feedback"], + ) + # The unannotated function contributes no params. + assert result == [] -def test_register_ssm_rejects_precomputed_function( - annotated_ssm_base_logp: Any, - learning_process: dict[str, Any], -) -> None: - """SSM registration should reject functions that already carry computed params.""" - precomputed_logp = annotate_function( - inputs=annotated_ssm_base_logp.inputs, - outputs=annotated_ssm_base_logp.outputs, - computed=learning_process, - )(annotated_ssm_base_logp) - with pytest.raises(ValueError, match="should not have a non-empty .computed"): +class TestRegisterSsm: + def test_caches_prebuilt_function(self, annotated_ssm_base_logp: Any) -> None: + """Registering a pre-built SSM should populate the cache immediately.""" registry.register_ssm( - name="invalid_ssm", - ssm_base_logp_func=precomputed_logp, + name="cached_ssm", + ssm_base_logp_func=annotated_ssm_base_logp, list_params_ssm=["v", "a"], bounds_ssm={"a": (0.3, 3.0)}, params_default_ssm=[0.0, 1.5], ) + assert registry._SSM_LOGP_CACHE["cached_ssm"] is annotated_ssm_base_logp + assert registry._SSM_REGISTRY["cached_ssm"]["response"] == ["rt", "response"] + + def test_rejects_precomputed_function( + self, + annotated_ssm_base_logp: Any, + learning_process: dict[str, Any], + ) -> None: + """SSM registration should reject functions that already carry computed params.""" + precomputed_logp = annotate_function( + inputs=annotated_ssm_base_logp.inputs, + outputs=annotated_ssm_base_logp.outputs, + computed=learning_process, + )(annotated_ssm_base_logp) + + with pytest.raises(ValueError, match="should not have a non-empty .computed"): + registry.register_ssm( + name="invalid_ssm", + ssm_base_logp_func=precomputed_logp, + list_params_ssm=["v", "a"], + bounds_ssm={"a": (0.3, 3.0)}, + params_default_ssm=[0.0, 1.5], + ) + + def test_rejects_non_callable(self) -> None: + """register_ssm must raise when ssm_base_logp_func is not callable.""" + with pytest.raises(ValueError, match="must be callable"): + registry.register_ssm( + name="bad_ssm", + ssm_base_logp_func="not_a_function", # type: ignore[arg-type] + list_params_ssm=["v"], + bounds_ssm={}, + params_default_ssm=[0.0], + ) + + def test_rejects_unannotated_callable(self) -> None: + """register_ssm must raise when the callable lacks .inputs or .outputs.""" + + def plain(x): # type: ignore[no-untyped-def] + return x + + with pytest.raises( + ValueError, match="must be decorated with @annotate_function" + ): + registry.register_ssm( + name="unannotated_ssm", + ssm_base_logp_func=plain, + list_params_ssm=["v"], + bounds_ssm={}, + params_default_ssm=[0.0], + ) + + def test_warns_on_overwrite( + self, + annotated_ssm_base_logp: Any, + caplog: pytest.LogCaptureFixture, + ) -> None: + """Re-registering an existing SSM name should emit a warning.""" + registry.register_ssm( + name="dup_ssm", + ssm_base_logp_func=annotated_ssm_base_logp, + list_params_ssm=["v", "a"], + bounds_ssm={"a": (0.3, 3.0)}, + params_default_ssm=[0.0, 1.5], + ) -def test_get_rlssm_model_config_unknown_model_raises() -> None: - """Unknown RLSSM model names should fail with a clear error.""" - with pytest.raises(ValueError, match="not found in the RLSSM registry"): - registry.get_rlssm_model_config("does_not_exist") - - -# --------------------------------------------------------------------------- -# _build_ssm_spec_from_modelconfig — error paths -# --------------------------------------------------------------------------- - - -def test_build_ssm_spec_unknown_model_raises() -> None: - """Completely unknown names should raise ValueError from the modelconfig bridge.""" - with pytest.raises(ValueError, match="not a registered custom SSM"): - registry._build_ssm_spec_from_modelconfig("totally_unknown_model") - - -def test_build_ssm_spec_no_approx_differentiable_raises( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """Models without an approx_differentiable likelihood must be rejected.""" - import hssm.modelconfig as mc_module - - def _fake_get(name): # type: ignore[no-untyped-def] - return { - "list_params": ["v", "a"], - "response": ["rt", "response"], - "likelihoods": { - "analytical": {"loglik": lambda: None, "bounds": {}, "backend": None}, - }, - } - - monkeypatch.setattr(mc_module, "get_default_model_config", _fake_get) - with pytest.raises(ValueError, match="no approx_differentiable likelihood"): - registry._build_ssm_spec_from_modelconfig("ddm") - - -# --------------------------------------------------------------------------- -# _make_ssm_base_logp_from_onnx + _factory (ONNX paths, mocked) -# --------------------------------------------------------------------------- - - -def test_build_ssm_spec_factory_calls_onnx_loader( - monkeypatch: pytest.MonkeyPatch, - annotated_ssm_base_logp: Any, -) -> None: - """The lazy factory produced by _build_ssm_spec_from_modelconfig should - call make_jax_matrix_logp_funcs_from_onnx with the correct filename.""" - import hssm.distribution_utils.onnx as onnx_module - - called_with: list[str] = [] - - def _fake_onnx(model: str) -> Any: - called_with.append(model) - return annotated_ssm_base_logp - - monkeypatch.setattr(onnx_module, "make_jax_matrix_logp_funcs_from_onnx", _fake_onnx) - monkeypatch.setattr(registry, "make_jax_matrix_logp_funcs_from_onnx", _fake_onnx) - - spec = registry._build_ssm_spec_from_modelconfig("angle") - result = spec["ssm_base_logp_func_factory"]() - - assert called_with == ["angle.onnx"] - assert callable(result) - assert result.inputs == ["v", "a", "z", "t", "theta", "rt", "response"] - assert result.outputs == ["logp"] - - -# --------------------------------------------------------------------------- -# _get_ssm_logp — built-in SSM path (not in _SSM_REGISTRY) -# --------------------------------------------------------------------------- - - -def test_get_ssm_logp_resolves_builtin_via_modelconfig( - monkeypatch: pytest.MonkeyPatch, - annotated_ssm_base_logp: Any, -) -> None: - """_get_ssm_logp should build and cache the logp for a built-in SSM.""" - import hssm.distribution_utils.onnx as onnx_module - - monkeypatch.setattr( - onnx_module, - "make_jax_matrix_logp_funcs_from_onnx", - lambda **_: annotated_ssm_base_logp, - ) - monkeypatch.setattr( - registry, - "make_jax_matrix_logp_funcs_from_onnx", - lambda **_: annotated_ssm_base_logp, - ) - - result = registry._get_ssm_logp("angle") - - assert callable(result) - assert "angle" in registry._SSM_LOGP_CACHE - # Second call returns cached value without re-building. - assert registry._get_ssm_logp("angle") is result - - -# --------------------------------------------------------------------------- -# _build_ssm_logp_func — raises when func already carries .computed -# --------------------------------------------------------------------------- - - -def test_build_ssm_logp_func_raises_if_already_computed( - annotated_ssm_base_logp: Any, - learning_process: dict[str, Any], -) -> None: - """_build_ssm_logp_func must refuse functions that already have .computed set.""" - precomputed = annotate_function( - inputs=annotated_ssm_base_logp.inputs, - outputs=annotated_ssm_base_logp.outputs, - computed=learning_process, - )(annotated_ssm_base_logp) - - with pytest.raises(ValueError, match="already has a non-empty .computed"): - registry._build_ssm_logp_func(precomputed, learning_process) - - -# --------------------------------------------------------------------------- -# _derive_rl_params — LP function missing .inputs (warning branch) -# --------------------------------------------------------------------------- - - -def test_derive_rl_params_warns_for_unannotated_lp_func() -> None: - """A learning-process function without .inputs should log a warning and be skipped.""" - - def unannotated_func(x): # type: ignore[no-untyped-def] - return x - - lp = {"v": unannotated_func} - result = registry._derive_rl_params( - learning_process=lp, - response=["rt", "response"], - extra_fields=["feedback"], - ) - # The unannotated function contributes no params. - assert result == [] - - -# --------------------------------------------------------------------------- -# get_rlssm_model_config — rl_params=None fallback derivation -# --------------------------------------------------------------------------- - - -def test_get_rlssm_model_config_derives_rl_params_when_absent( - annotated_ssm_base_logp: Any, - learning_process: dict[str, Any], -) -> None: - """When rl_params is absent from the registry entry, params are derived - from the learning_process .inputs.""" - registry.register_ssm( - name="derive_params_ssm", - ssm_base_logp_func=annotated_ssm_base_logp, - list_params_ssm=["v", "a"], - bounds_ssm={"a": (0.3, 3.0)}, - params_default_ssm=[0.0, 1.5], - response=["rt", "response"], - ) - # Inject an entry without rl_params so the fallback derivation runs. - registry._RLSSM_REGISTRY["derive_params_model"] = { - "decision_process": "derive_params_ssm", - "learning_process": learning_process, - # rl_params deliberately absent - "rl_bounds": {}, - "rl_params_default": [], - "extra_fields": ["feedback"], - "choices": [0, 1], - "description": None, - "decision_process_loglik_kind": "approx_differentiable", - "learning_process_kind": "blackbox", - } - - config = registry.get_rlssm_model_config("derive_params_model") - - # "rl_alpha" is the only input to learning_process that isn't response/extra. - assert "rl_alpha" in config.list_params - - -# --------------------------------------------------------------------------- -# get_rlssm_model_config — SSM param absent from bounds_ssm raises early -# --------------------------------------------------------------------------- - - -def test_get_rlssm_model_config_raises_for_missing_bounds( - annotated_ssm_base_logp: Any, - learning_process: dict[str, Any], -) -> None: - """SSM params missing from bounds_ssm must raise immediately with a clear error. - - Previously the factory silently skipped such params, producing a broken - RLSSMConfig that only failed later inside _RLSSM.__init__. Now it must - raise at the factory boundary so the error message points at the root cause. - """ - registry.register_ssm( - name="no_bounds_ssm", - ssm_base_logp_func=annotated_ssm_base_logp, - list_params_ssm=["v", "a"], - # "a" intentionally has no bounds entry - bounds_ssm={}, - params_default_ssm=[0.0, 1.5], - response=["rt", "response"], - ) - registry.register_rlssm_model( - name="no_bounds_model", - decision_process="no_bounds_ssm", - learning_process=learning_process, - rl_params=["rl_alpha"], - rl_bounds={"rl_alpha": (0.0, 1.0)}, - rl_params_default=[0.2], - extra_fields=["feedback"], - choices=[0, 1], - ) - - with pytest.raises(ValueError, match="no entry in bounds_ssm"): - registry.get_rlssm_model_config("no_bounds_model") - - -# --------------------------------------------------------------------------- -# register_rlssm_model — overwrite warning -# --------------------------------------------------------------------------- - - -def test_register_rlssm_model_warns_on_overwrite( - learning_process: dict[str, Any], - annotated_ssm_base_logp: Any, - caplog: pytest.LogCaptureFixture, -) -> None: - """Re-registering an existing RLSSM model name should emit a warning.""" - registry.register_ssm( - name="overwrite_ssm", - ssm_base_logp_func=annotated_ssm_base_logp, - list_params_ssm=["v", "a"], - bounds_ssm={"a": (0.3, 3.0)}, - params_default_ssm=[0.0, 1.5], - ) - registry.register_rlssm_model( - name="overwrite_rlssm", - decision_process="overwrite_ssm", - learning_process=learning_process, - rl_params=["rl_alpha"], - rl_bounds={"rl_alpha": (0.0, 1.0)}, - rl_params_default=[0.2], - ) + with caplog.at_level(logging.WARNING, logger="hssm"): + registry.register_ssm( + name="dup_ssm", + ssm_base_logp_func=annotated_ssm_base_logp, + list_params_ssm=["v", "a"], + bounds_ssm={"a": (0.3, 3.0)}, + params_default_ssm=[0.0, 1.5], + ) + + assert any("dup_ssm" in r.message for r in caplog.records) + + +class TestRegisterRlssmModel: + def test_copies_mutable_inputs( + self, + learning_process: dict[str, Any], + ) -> None: + """Caller mutations after registration must not alter the stored model.""" + rl_params = ["rl_alpha"] + rl_bounds = {"rl_alpha": (0.0, 1.0)} + rl_defaults = [0.2] + extra_fields = ["feedback"] + choices = [0, 1] - import logging + registry.register_rlssm_model( + name="copy_test_model", + decision_process="angle", + learning_process=learning_process, + rl_params=rl_params, + rl_bounds=rl_bounds, + rl_params_default=rl_defaults, + extra_fields=extra_fields, + choices=choices, + ) - with caplog.at_level(logging.WARNING, logger="hssm"): + rl_params.append("scaler") + rl_bounds["scaler"] = (0.0, 10.0) + rl_defaults.append(1.0) + extra_fields.append("trial") + choices.append(2) + # Built-in SSMs (e.g. "angle") are derived from modelconfig on each call + # and are never stored as shared mutable state in _SSM_REGISTRY, so there + # is no registry entry to corrupt here. + learning_process["other"] = next(iter(learning_process.values())) + + stored = registry._RLSSM_REGISTRY["copy_test_model"] + + assert stored["decision_process"]["name"] == "angle" + assert stored["decision_process"]["response"] == ["rt", "response"] + assert stored["rl_params"] == ["rl_alpha"] + assert stored["rl_bounds"] == {"rl_alpha": (0.0, 1.0)} + assert stored["rl_params_default"] == [0.2] + assert stored["extra_fields"] == ["feedback"] + assert stored["choices"] == [0, 1] + assert list(stored["learning_process"]) == ["v"] + + def test_warns_on_overwrite( + self, + learning_process: dict[str, Any], + annotated_ssm_base_logp: Any, + caplog: pytest.LogCaptureFixture, + ) -> None: + """Re-registering an existing RLSSM model name should emit a warning.""" + registry.register_ssm( + name="overwrite_ssm", + ssm_base_logp_func=annotated_ssm_base_logp, + list_params_ssm=["v", "a"], + bounds_ssm={"a": (0.3, 3.0)}, + params_default_ssm=[0.0, 1.5], + ) registry.register_rlssm_model( name="overwrite_rlssm", decision_process="overwrite_ssm", @@ -498,138 +369,226 @@ def test_register_rlssm_model_warns_on_overwrite( rl_params_default=[0.2], ) - assert any("overwrite_rlssm" in r.message for r in caplog.records) - - -# --------------------------------------------------------------------------- -# register_ssm — validation error paths -# --------------------------------------------------------------------------- - + with caplog.at_level(logging.WARNING, logger="hssm"): + registry.register_rlssm_model( + name="overwrite_rlssm", + decision_process="overwrite_ssm", + learning_process=learning_process, + rl_params=["rl_alpha"], + rl_bounds={"rl_alpha": (0.0, 1.0)}, + rl_params_default=[0.2], + ) + + assert any("overwrite_rlssm" in r.message for r in caplog.records) + + +class TestGetRlssmModelConfig: + def test_builds_expected_config( + self, + annotated_ssm_base_logp: Any, + learning_process: dict[str, Any], + ) -> None: + """Registry config composition should exclude computed SSM params.""" + registry.register_ssm( + name="unit_test_ssm", + ssm_base_logp_func=annotated_ssm_base_logp, + list_params_ssm=["v", "a"], + bounds_ssm={"a": (0.3, 3.0)}, + params_default_ssm=[0.0, 1.5], + response=["rt", "response"], + ) + registry.register_rlssm_model( + name="unit_test_model", + decision_process="unit_test_ssm", + learning_process=learning_process, + rl_params=["rl_alpha"], + rl_bounds={"rl_alpha": (0.0, 1.0)}, + rl_params_default=[0.2], + extra_fields=["feedback"], + choices=[0, 1], + ) -def test_register_ssm_rejects_non_callable() -> None: - """register_ssm must raise when ssm_base_logp_func is not callable.""" - with pytest.raises(ValueError, match="must be callable"): + config = registry.get_rlssm_model_config("unit_test_model") + + assert isinstance(config, RLSSMConfig) + assert config.list_params == ["rl_alpha", "a"] + assert config.bounds == {"rl_alpha": (0.0, 1.0), "a": (0.3, 3.0)} + assert config.params_default == [0.2, 1.5] + assert config.response == ["rt", "response"] + assert config.ssm_logp_func.computed == learning_process + + # Mutating the returned config must not corrupt the registry's stored list. + config.response.append("mutated") + assert registry._SSM_REGISTRY["unit_test_ssm"]["response"] == ["rt", "response"] + + def test_respects_explicit_empty_rl_fields( + self, + annotated_ssm_base_logp: Any, + learning_process: dict[str, Any], + ) -> None: + """Explicit empty RL collections must not trigger fallback derivation.""" registry.register_ssm( - name="bad_ssm", - ssm_base_logp_func="not_a_function", # type: ignore[arg-type] - list_params_ssm=["v"], - bounds_ssm={}, - params_default_ssm=[0.0], + name="empty_rl_ssm", + ssm_base_logp_func=annotated_ssm_base_logp, + list_params_ssm=["v", "a"], + bounds_ssm={"a": (0.3, 3.0)}, + params_default_ssm=[0.0, 1.5], + response=["rt", "response"], ) + registry._RLSSM_REGISTRY["empty_rl_model"] = { + "decision_process": "empty_rl_ssm", + "learning_process": learning_process, + "rl_params": [], + "rl_bounds": {}, + "rl_params_default": [], + "extra_fields": ["feedback"], + "choices": [0, 1], + "description": "test model", + "decision_process_loglik_kind": "approx_differentiable", + "learning_process_kind": "blackbox", + } + config = registry.get_rlssm_model_config("empty_rl_model") -def test_register_ssm_rejects_unannotated_callable() -> None: - """register_ssm must raise when the callable lacks .inputs or .outputs.""" + assert config.list_params == ["a"] + assert config.bounds == {"a": (0.3, 3.0)} + assert config.params_default == [1.5] - def plain(x): # type: ignore[no-untyped-def] - return x + def test_unknown_model_raises(self) -> None: + """Unknown RLSSM model names should fail with a clear error.""" + with pytest.raises(ValueError, match="not found in the RLSSM registry"): + registry.get_rlssm_model_config("does_not_exist") - with pytest.raises(ValueError, match="must be decorated with @annotate_function"): + def test_derives_rl_params_when_absent( + self, + annotated_ssm_base_logp: Any, + learning_process: dict[str, Any], + ) -> None: + """When rl_params is absent from the registry entry, params are derived + from the learning_process .inputs.""" registry.register_ssm( - name="unannotated_ssm", - ssm_base_logp_func=plain, - list_params_ssm=["v"], - bounds_ssm={}, - params_default_ssm=[0.0], + name="derive_params_ssm", + ssm_base_logp_func=annotated_ssm_base_logp, + list_params_ssm=["v", "a"], + bounds_ssm={"a": (0.3, 3.0)}, + params_default_ssm=[0.0, 1.5], + response=["rt", "response"], ) + # Inject an entry without rl_params so the fallback derivation runs. + registry._RLSSM_REGISTRY["derive_params_model"] = { + "decision_process": "derive_params_ssm", + "learning_process": learning_process, + # rl_params deliberately absent + "rl_bounds": {}, + "rl_params_default": [], + "extra_fields": ["feedback"], + "choices": [0, 1], + "description": None, + "decision_process_loglik_kind": "approx_differentiable", + "learning_process_kind": "blackbox", + } + config = registry.get_rlssm_model_config("derive_params_model") -def test_register_ssm_warns_on_overwrite( - annotated_ssm_base_logp: Any, - caplog: pytest.LogCaptureFixture, -) -> None: - """Re-registering an existing SSM name should emit a warning.""" - registry.register_ssm( - name="dup_ssm", - ssm_base_logp_func=annotated_ssm_base_logp, - list_params_ssm=["v", "a"], - bounds_ssm={"a": (0.3, 3.0)}, - params_default_ssm=[0.0, 1.5], - ) + # "rl_alpha" is the only input to learning_process that isn't response/extra. + assert "rl_alpha" in config.list_params - import logging + def test_raises_for_missing_bounds( + self, + annotated_ssm_base_logp: Any, + learning_process: dict[str, Any], + ) -> None: + """SSM params missing from bounds_ssm must raise immediately with a clear error. - with caplog.at_level(logging.WARNING, logger="hssm"): + Previously the factory silently skipped such params, producing a broken + RLSSMConfig that only failed later inside _RLSSM.__init__. Now it must + raise at the factory boundary so the error message points at the root cause. + """ registry.register_ssm( - name="dup_ssm", + name="no_bounds_ssm", ssm_base_logp_func=annotated_ssm_base_logp, list_params_ssm=["v", "a"], - bounds_ssm={"a": (0.3, 3.0)}, + # "a" intentionally has no bounds entry + bounds_ssm={}, params_default_ssm=[0.0, 1.5], + response=["rt", "response"], + ) + registry.register_rlssm_model( + name="no_bounds_model", + decision_process="no_bounds_ssm", + learning_process=learning_process, + rl_params=["rl_alpha"], + rl_bounds={"rl_alpha": (0.0, 1.0)}, + rl_params_default=[0.2], + extra_fields=["feedback"], + choices=[0, 1], ) - assert any("dup_ssm" in r.message for r in caplog.records) - - -# --------------------------------------------------------------------------- -# Built-in starter-pack RLSSM models (DDM and Weibull variants) -# --------------------------------------------------------------------------- - - -@pytest.mark.parametrize( - "model_name, expected_dp", - [ - ("2AB_RescorlaWagner_DDM", "ddm"), - ("2AB_RescorlaWagner_Weibull", "weibull"), - ], -) -def test_builtin_2ab_models_are_registered(model_name: str, expected_dp: str) -> None: - """2AB_RescorlaWagner_DDM and _Weibull must be present in the RLSSM registry.""" - assert model_name in registry._RLSSM_REGISTRY - entry = registry._RLSSM_REGISTRY[model_name] - assert entry["decision_process"]["name"] == expected_dp - assert entry["rl_params"] == ["rl_alpha", "scaler"] - assert entry["extra_fields"] == ["feedback"] - assert entry["choices"] == [0, 1] - assert entry["decision_process_loglik_kind"] == "approx_differentiable" - assert entry["learning_process_kind"] == "blackbox" - - -@pytest.mark.parametrize( - "model_name", - ["2AB_RescorlaWagner_DDM", "2AB_RescorlaWagner_Weibull"], -) -def test_builtin_2ab_models_config_structure( - monkeypatch: pytest.MonkeyPatch, - annotated_ssm_base_logp: Any, - model_name: str, -) -> None: - """get_rlssm_model_config should produce a well-formed RLSSMConfig for - both the DDM and Weibull starter-pack models.""" - import hssm.distribution_utils.onnx as onnx_module - - monkeypatch.setattr( - onnx_module, - "make_jax_matrix_logp_funcs_from_onnx", - lambda model: annotated_ssm_base_logp, - ) - monkeypatch.setattr( - registry, - "make_jax_matrix_logp_funcs_from_onnx", - lambda model: annotated_ssm_base_logp, - ) + with pytest.raises(ValueError, match="no entry in bounds_ssm"): + registry.get_rlssm_model_config("no_bounds_model") - config = registry.get_rlssm_model_config(model_name) - assert isinstance(config, RLSSMConfig) - # RL params come first - assert config.list_params[:2] == ["rl_alpha", "scaler"] - assert "rl_alpha" in config.bounds - assert "scaler" in config.bounds - assert config.choices == (0, 1) - assert config.extra_fields == ["feedback"] - assert config.ssm_logp_func.computed == {"v": registry._compute_v_annotated} +class TestBuiltinModels: + @pytest.mark.parametrize( + "model_name, expected_dp", + [ + ("2AB_RescorlaWagner_DDM", "ddm"), + ("2AB_RescorlaWagner_Weibull", "weibull"), + ], + ) + def test_are_registered(self, model_name: str, expected_dp: str) -> None: + """2AB_RescorlaWagner_DDM and _Weibull must be present in the RLSSM registry.""" + assert model_name in registry._RLSSM_REGISTRY + entry = registry._RLSSM_REGISTRY[model_name] + assert entry["decision_process"]["name"] == expected_dp + assert entry["rl_params"] == ["rl_alpha", "scaler"] + assert entry["extra_fields"] == ["feedback"] + assert entry["choices"] == [0, 1] + assert entry["decision_process_loglik_kind"] == "approx_differentiable" + assert entry["learning_process_kind"] == "blackbox" + + @pytest.mark.parametrize( + "model_name", + ["2AB_RescorlaWagner_DDM", "2AB_RescorlaWagner_Weibull"], + ) + def test_config_structure( + self, + monkeypatch: pytest.MonkeyPatch, + annotated_ssm_base_logp: Any, + model_name: str, + ) -> None: + """get_rlssm_model_config should produce a well-formed RLSSMConfig for + both the DDM and Weibull starter-pack models.""" + import hssm.distribution_utils.onnx as onnx_module + + monkeypatch.setattr( + onnx_module, + "make_jax_matrix_logp_funcs_from_onnx", + lambda model: annotated_ssm_base_logp, + ) + monkeypatch.setattr( + registry, + "make_jax_matrix_logp_funcs_from_onnx", + lambda model: annotated_ssm_base_logp, + ) + config = registry.get_rlssm_model_config(model_name) -# --------------------------------------------------------------------------- -# list_rlssm_models -# --------------------------------------------------------------------------- + assert isinstance(config, RLSSMConfig) + # RL params come first + assert config.list_params[:2] == ["rl_alpha", "scaler"] + assert "rl_alpha" in config.bounds + assert "scaler" in config.bounds + assert config.choices == (0, 1) + assert config.extra_fields == ["feedback"] + assert config.ssm_logp_func.computed == {"v": registry._compute_v_annotated} -def test_list_rlssm_models_returns_all_names() -> None: - """list_rlssm_models should return every key in _RLSSM_REGISTRY with its description.""" - result = registry.list_models() +class TestListModels: + def test_returns_all_names(self) -> None: + """list_models should return every key in _RLSSM_REGISTRY with its description.""" + result = registry.list_models() - assert set(result.keys()) == set(registry._RLSSM_REGISTRY.keys()) - for name, desc in result.items(): - assert desc == registry._RLSSM_REGISTRY[name].get("description") + assert set(result.keys()) == set(registry._RLSSM_REGISTRY.keys()) + for name, desc in result.items(): + assert desc == registry._RLSSM_REGISTRY[name].get("description") From bc9c88fdb39ed952556c8ee742407835c1c20f02 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Thu, 21 May 2026 14:03:23 -0400 Subject: [PATCH 37/41] fix: rename _derive_rl_params to _derive_lp_params --- src/hssm/rl/registry.py | 4 ++-- tests/rl/test_registry.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/hssm/rl/registry.py b/src/hssm/rl/registry.py index 8cc95e4c..ceda02f2 100644 --- a/src/hssm/rl/registry.py +++ b/src/hssm/rl/registry.py @@ -284,7 +284,7 @@ def _build_ssm_logp_func(ssm_base_logp_func: Any, learning_process: dict) -> Any )(ssm_base_logp_func) -def _derive_rl_params( +def _derive_lp_params( learning_process: dict[str, Any], response: list[str], extra_fields: list[str], @@ -388,7 +388,7 @@ def get_rlssm_model_config( # the fallback derivation logic. _rl_params = entry.get("rl_params") rl_params: list[str] = ( - _derive_rl_params(lp, response, entry.get("extra_fields") or []) + _derive_lp_params(lp, response, entry.get("extra_fields") or []) if _rl_params is None else _rl_params ) diff --git a/tests/rl/test_registry.py b/tests/rl/test_registry.py index 82273cab..de0ed2c4 100644 --- a/tests/rl/test_registry.py +++ b/tests/rl/test_registry.py @@ -189,7 +189,7 @@ def test_excludes_response_and_extra_fields( learning_process: dict[str, Any], ) -> None: """Derived RL params should ignore response columns and extra fields.""" - derived = registry._derive_rl_params( + derived = registry._derive_lp_params( learning_process=learning_process, response=["rt", "response"], extra_fields=["feedback"], @@ -204,7 +204,7 @@ def unannotated_func(x): # type: ignore[no-untyped-def] return x lp = {"v": unannotated_func} - result = registry._derive_rl_params( + result = registry._derive_lp_params( learning_process=lp, response=["rt", "response"], extra_fields=["feedback"], From b47d9dbc60ea89a95b9bb7ab113e8180156ea9a5 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Thu, 21 May 2026 14:26:27 -0400 Subject: [PATCH 38/41] Rename rl_params to learning_process_params --- src/hssm/rl/registry.py | 30 +++++++++++++++--------------- tests/rl/test_registry.py | 26 +++++++++++++------------- tests/rl/test_rlssm.py | 2 +- 3 files changed, 29 insertions(+), 29 deletions(-) diff --git a/src/hssm/rl/registry.py b/src/hssm/rl/registry.py index ceda02f2..587fca86 100644 --- a/src/hssm/rl/registry.py +++ b/src/hssm/rl/registry.py @@ -191,9 +191,9 @@ def _get_ssm_logp(name: str) -> Any: # Each entry provides: # decision_process - key into _SSM_REGISTRY # learning_process - {param: annotated_func} -# rl_params - ordered list of sampled RL parameter names +# learning_process_params - ordered list of sampled RL parameter names # rl_bounds - {param: (lo, hi)} for RL params -# rl_params_default - default values aligned with rl_params +# rl_params_default - default values aligned with learning_process_params # extra_fields - extra data column names required by LP # choices - response choice values # description - human-readable description @@ -204,7 +204,7 @@ def _get_ssm_logp(name: str) -> Any: "2AB_RescorlaWagner_DDM": { "decision_process": _get_decision_process_spec("ddm"), "learning_process": {"v": _compute_v_annotated}, - "rl_params": ["rl_alpha", "scaler"], + "learning_process_params": ["rl_alpha", "scaler"], "rl_bounds": { "rl_alpha": (0.0, 1.0), "scaler": (0.0, 10.0), @@ -222,7 +222,7 @@ def _get_ssm_logp(name: str) -> Any: "2AB_RescorlaWagner_Angle": { "decision_process": _get_decision_process_spec("angle"), "learning_process": {"v": _compute_v_annotated}, - "rl_params": ["rl_alpha", "scaler"], + "learning_process_params": ["rl_alpha", "scaler"], "rl_bounds": { "rl_alpha": (0.0, 1.0), "scaler": (0.0, 10.0), @@ -240,7 +240,7 @@ def _get_ssm_logp(name: str) -> Any: "2AB_RescorlaWagner_Weibull": { "decision_process": _get_decision_process_spec("weibull"), "learning_process": {"v": _compute_v_annotated}, - "rl_params": ["rl_alpha", "scaler"], + "learning_process_params": ["rl_alpha", "scaler"], "rl_bounds": { "rl_alpha": (0.0, 1.0), "scaler": (0.0, 10.0), @@ -295,7 +295,7 @@ def _derive_lp_params( neither response columns nor extra fields. """ exclude = set(response) | set(extra_fields) - rl_params: list[str] = [] + learning_process_params: list[str] = [] seen: set[str] = set() for param_name, lp_func in learning_process.items(): if not hasattr(lp_func, "inputs"): @@ -308,9 +308,9 @@ def _derive_lp_params( continue for inp in lp_func.inputs: if inp not in exclude and inp not in seen: - rl_params.append(inp) + learning_process_params.append(inp) seen.add(inp) - return rl_params + return learning_process_params # --------------------------------------------------------------------------- @@ -386,13 +386,13 @@ def get_rlssm_model_config( # Use `is None` checks so that explicitly empty containers ([], {}) are # respected as valid "no RL params" configuration and not overridden by # the fallback derivation logic. - _rl_params = entry.get("rl_params") - rl_params: list[str] = ( + _rl_params = entry.get("learning_process_params") + learning_process_params: list[str] = ( _derive_lp_params(lp, response, entry.get("extra_fields") or []) if _rl_params is None else _rl_params ) - list_params = rl_params + ssm_sampled + list_params = learning_process_params + ssm_sampled # bounds: RL bounds ∪ SSM sampled bounds missing_bounds = [p for p in ssm_sampled if p not in ssm_entry["bounds_ssm"]] @@ -475,7 +475,7 @@ def register_rlssm_model( name: str, decision_process: str, learning_process: dict[str, Any], - rl_params: list[str], + learning_process_params: list[str], rl_bounds: dict[str, tuple[float, float]], rl_params_default: list[float], extra_fields: list[str] | None = None, @@ -496,12 +496,12 @@ def register_rlssm_model( name such as ``"ddm"``, ``"angle"``, or ``"weibull"``. learning_process: Dict mapping computed parameter name → annotated learning function. - rl_params: + learning_process_params: Ordered list of sampled RL parameter names. rl_bounds: Parameter bounds for the RL parameters. rl_params_default: - Default values aligned with *rl_params*. + Default values aligned with *learning_process_params*. extra_fields: Data column names required by the learning process (e.g. ``["feedback"]``). choices: @@ -523,7 +523,7 @@ def register_rlssm_model( # Shallow-copy all mutable caller-supplied collections so that later # mutations of the originals do not silently corrupt the registry entry. "learning_process": dict(learning_process), - "rl_params": list(rl_params), + "learning_process_params": list(learning_process_params), "rl_bounds": dict(rl_bounds), "rl_params_default": list(rl_params_default), "extra_fields": list(extra_fields) if extra_fields is not None else [], diff --git a/tests/rl/test_registry.py b/tests/rl/test_registry.py index de0ed2c4..51f1c2cd 100644 --- a/tests/rl/test_registry.py +++ b/tests/rl/test_registry.py @@ -308,7 +308,7 @@ def test_copies_mutable_inputs( learning_process: dict[str, Any], ) -> None: """Caller mutations after registration must not alter the stored model.""" - rl_params = ["rl_alpha"] + learning_process_params = ["rl_alpha"] rl_bounds = {"rl_alpha": (0.0, 1.0)} rl_defaults = [0.2] extra_fields = ["feedback"] @@ -318,14 +318,14 @@ def test_copies_mutable_inputs( name="copy_test_model", decision_process="angle", learning_process=learning_process, - rl_params=rl_params, + learning_process_params=learning_process_params, rl_bounds=rl_bounds, rl_params_default=rl_defaults, extra_fields=extra_fields, choices=choices, ) - rl_params.append("scaler") + learning_process_params.append("scaler") rl_bounds["scaler"] = (0.0, 10.0) rl_defaults.append(1.0) extra_fields.append("trial") @@ -339,7 +339,7 @@ def test_copies_mutable_inputs( assert stored["decision_process"]["name"] == "angle" assert stored["decision_process"]["response"] == ["rt", "response"] - assert stored["rl_params"] == ["rl_alpha"] + assert stored["learning_process_params"] == ["rl_alpha"] assert stored["rl_bounds"] == {"rl_alpha": (0.0, 1.0)} assert stored["rl_params_default"] == [0.2] assert stored["extra_fields"] == ["feedback"] @@ -364,7 +364,7 @@ def test_warns_on_overwrite( name="overwrite_rlssm", decision_process="overwrite_ssm", learning_process=learning_process, - rl_params=["rl_alpha"], + learning_process_params=["rl_alpha"], rl_bounds={"rl_alpha": (0.0, 1.0)}, rl_params_default=[0.2], ) @@ -374,7 +374,7 @@ def test_warns_on_overwrite( name="overwrite_rlssm", decision_process="overwrite_ssm", learning_process=learning_process, - rl_params=["rl_alpha"], + learning_process_params=["rl_alpha"], rl_bounds={"rl_alpha": (0.0, 1.0)}, rl_params_default=[0.2], ) @@ -401,7 +401,7 @@ def test_builds_expected_config( name="unit_test_model", decision_process="unit_test_ssm", learning_process=learning_process, - rl_params=["rl_alpha"], + learning_process_params=["rl_alpha"], rl_bounds={"rl_alpha": (0.0, 1.0)}, rl_params_default=[0.2], extra_fields=["feedback"], @@ -438,7 +438,7 @@ def test_respects_explicit_empty_rl_fields( registry._RLSSM_REGISTRY["empty_rl_model"] = { "decision_process": "empty_rl_ssm", "learning_process": learning_process, - "rl_params": [], + "learning_process_params": [], "rl_bounds": {}, "rl_params_default": [], "extra_fields": ["feedback"], @@ -464,7 +464,7 @@ def test_derives_rl_params_when_absent( annotated_ssm_base_logp: Any, learning_process: dict[str, Any], ) -> None: - """When rl_params is absent from the registry entry, params are derived + """When learning_process_params is absent from the registry entry, params are derived from the learning_process .inputs.""" registry.register_ssm( name="derive_params_ssm", @@ -474,11 +474,11 @@ def test_derives_rl_params_when_absent( params_default_ssm=[0.0, 1.5], response=["rt", "response"], ) - # Inject an entry without rl_params so the fallback derivation runs. + # Inject an entry without learning_process_params so the fallback derivation runs. registry._RLSSM_REGISTRY["derive_params_model"] = { "decision_process": "derive_params_ssm", "learning_process": learning_process, - # rl_params deliberately absent + # learning_process_params deliberately absent "rl_bounds": {}, "rl_params_default": [], "extra_fields": ["feedback"], @@ -517,7 +517,7 @@ def test_raises_for_missing_bounds( name="no_bounds_model", decision_process="no_bounds_ssm", learning_process=learning_process, - rl_params=["rl_alpha"], + learning_process_params=["rl_alpha"], rl_bounds={"rl_alpha": (0.0, 1.0)}, rl_params_default=[0.2], extra_fields=["feedback"], @@ -541,7 +541,7 @@ def test_are_registered(self, model_name: str, expected_dp: str) -> None: assert model_name in registry._RLSSM_REGISTRY entry = registry._RLSSM_REGISTRY[model_name] assert entry["decision_process"]["name"] == expected_dp - assert entry["rl_params"] == ["rl_alpha", "scaler"] + assert entry["learning_process_params"] == ["rl_alpha", "scaler"] assert entry["extra_fields"] == ["feedback"] assert entry["choices"] == [0, 1] assert entry["decision_process_loglik_kind"] == "approx_differentiable" diff --git a/tests/rl/test_rlssm.py b/tests/rl/test_rlssm.py index 2d76ee70..2c2c82d6 100644 --- a/tests/rl/test_rlssm.py +++ b/tests/rl/test_rlssm.py @@ -402,7 +402,7 @@ def test_register_rlssm_model(self, rldm_data) -> None: name="rldm_custom_test", decision_process="angle", learning_process={"v": _compute_v_annotated}, - rl_params=["rl_alpha", "scaler"], + learning_process_params=["rl_alpha", "scaler"], rl_bounds={"rl_alpha": (0.0, 1.0), "scaler": (0.0, 10.0)}, rl_params_default=[0.1, 1.0], extra_fields=["feedback"], From 4973e8ad36ea600a724bd9f473d535f86fd1b93f Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Thu, 21 May 2026 14:39:03 -0400 Subject: [PATCH 39/41] Fix line too long --- src/hssm/rl/registry.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/hssm/rl/registry.py b/src/hssm/rl/registry.py index 587fca86..ff05308d 100644 --- a/src/hssm/rl/registry.py +++ b/src/hssm/rl/registry.py @@ -191,7 +191,7 @@ def _get_ssm_logp(name: str) -> Any: # Each entry provides: # decision_process - key into _SSM_REGISTRY # learning_process - {param: annotated_func} -# learning_process_params - ordered list of sampled RL parameter names +# learning_process_params - ordered list of sampled RL parameter names # rl_bounds - {param: (lo, hi)} for RL params # rl_params_default - default values aligned with learning_process_params # extra_fields - extra data column names required by LP From 52d5fd6cb03cdd7a9bb3d85e96447fa2eed61134 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Thu, 21 May 2026 14:50:24 -0400 Subject: [PATCH 40/41] Rename rl_bounds to learning_process_bounds --- src/hssm/rl/registry.py | 16 ++++++++-------- tests/rl/test_registry.py | 20 ++++++++++---------- tests/rl/test_rlssm.py | 2 +- 3 files changed, 19 insertions(+), 19 deletions(-) diff --git a/src/hssm/rl/registry.py b/src/hssm/rl/registry.py index ff05308d..68cb00a4 100644 --- a/src/hssm/rl/registry.py +++ b/src/hssm/rl/registry.py @@ -192,7 +192,7 @@ def _get_ssm_logp(name: str) -> Any: # decision_process - key into _SSM_REGISTRY # learning_process - {param: annotated_func} # learning_process_params - ordered list of sampled RL parameter names -# rl_bounds - {param: (lo, hi)} for RL params +# learning_process_bounds - {param: (lo, hi)} for RL params # rl_params_default - default values aligned with learning_process_params # extra_fields - extra data column names required by LP # choices - response choice values @@ -205,7 +205,7 @@ def _get_ssm_logp(name: str) -> Any: "decision_process": _get_decision_process_spec("ddm"), "learning_process": {"v": _compute_v_annotated}, "learning_process_params": ["rl_alpha", "scaler"], - "rl_bounds": { + "learning_process_bounds": { "rl_alpha": (0.0, 1.0), "scaler": (0.0, 10.0), }, @@ -223,7 +223,7 @@ def _get_ssm_logp(name: str) -> Any: "decision_process": _get_decision_process_spec("angle"), "learning_process": {"v": _compute_v_annotated}, "learning_process_params": ["rl_alpha", "scaler"], - "rl_bounds": { + "learning_process_bounds": { "rl_alpha": (0.0, 1.0), "scaler": (0.0, 10.0), }, @@ -241,7 +241,7 @@ def _get_ssm_logp(name: str) -> Any: "decision_process": _get_decision_process_spec("weibull"), "learning_process": {"v": _compute_v_annotated}, "learning_process_params": ["rl_alpha", "scaler"], - "rl_bounds": { + "learning_process_bounds": { "rl_alpha": (0.0, 1.0), "scaler": (0.0, 10.0), }, @@ -403,7 +403,7 @@ def get_rlssm_model_config( "Provide bounds for all sampled parameters via register_ssm() or ensure " "the built-in modelconfig includes them." ) - _rl_bounds = entry.get("rl_bounds") + _rl_bounds = entry.get("learning_process_bounds") bounds: dict[str, tuple[float, float]] = dict( _rl_bounds if _rl_bounds is not None else {} ) @@ -476,7 +476,7 @@ def register_rlssm_model( decision_process: str, learning_process: dict[str, Any], learning_process_params: list[str], - rl_bounds: dict[str, tuple[float, float]], + learning_process_bounds: dict[str, tuple[float, float]], rl_params_default: list[float], extra_fields: list[str] | None = None, choices: list[int] | None = None, @@ -498,7 +498,7 @@ def register_rlssm_model( Dict mapping computed parameter name → annotated learning function. learning_process_params: Ordered list of sampled RL parameter names. - rl_bounds: + learning_process_bounds: Parameter bounds for the RL parameters. rl_params_default: Default values aligned with *learning_process_params*. @@ -524,7 +524,7 @@ def register_rlssm_model( # mutations of the originals do not silently corrupt the registry entry. "learning_process": dict(learning_process), "learning_process_params": list(learning_process_params), - "rl_bounds": dict(rl_bounds), + "learning_process_bounds": dict(learning_process_bounds), "rl_params_default": list(rl_params_default), "extra_fields": list(extra_fields) if extra_fields is not None else [], "choices": list(choices) if choices is not None else [0, 1], diff --git a/tests/rl/test_registry.py b/tests/rl/test_registry.py index 51f1c2cd..02363876 100644 --- a/tests/rl/test_registry.py +++ b/tests/rl/test_registry.py @@ -309,7 +309,7 @@ def test_copies_mutable_inputs( ) -> None: """Caller mutations after registration must not alter the stored model.""" learning_process_params = ["rl_alpha"] - rl_bounds = {"rl_alpha": (0.0, 1.0)} + learning_process_bounds = {"rl_alpha": (0.0, 1.0)} rl_defaults = [0.2] extra_fields = ["feedback"] choices = [0, 1] @@ -319,14 +319,14 @@ def test_copies_mutable_inputs( decision_process="angle", learning_process=learning_process, learning_process_params=learning_process_params, - rl_bounds=rl_bounds, + learning_process_bounds=learning_process_bounds, rl_params_default=rl_defaults, extra_fields=extra_fields, choices=choices, ) learning_process_params.append("scaler") - rl_bounds["scaler"] = (0.0, 10.0) + learning_process_bounds["scaler"] = (0.0, 10.0) rl_defaults.append(1.0) extra_fields.append("trial") choices.append(2) @@ -340,7 +340,7 @@ def test_copies_mutable_inputs( assert stored["decision_process"]["name"] == "angle" assert stored["decision_process"]["response"] == ["rt", "response"] assert stored["learning_process_params"] == ["rl_alpha"] - assert stored["rl_bounds"] == {"rl_alpha": (0.0, 1.0)} + assert stored["learning_process_bounds"] == {"rl_alpha": (0.0, 1.0)} assert stored["rl_params_default"] == [0.2] assert stored["extra_fields"] == ["feedback"] assert stored["choices"] == [0, 1] @@ -365,7 +365,7 @@ def test_warns_on_overwrite( decision_process="overwrite_ssm", learning_process=learning_process, learning_process_params=["rl_alpha"], - rl_bounds={"rl_alpha": (0.0, 1.0)}, + learning_process_bounds={"rl_alpha": (0.0, 1.0)}, rl_params_default=[0.2], ) @@ -375,7 +375,7 @@ def test_warns_on_overwrite( decision_process="overwrite_ssm", learning_process=learning_process, learning_process_params=["rl_alpha"], - rl_bounds={"rl_alpha": (0.0, 1.0)}, + learning_process_bounds={"rl_alpha": (0.0, 1.0)}, rl_params_default=[0.2], ) @@ -402,7 +402,7 @@ def test_builds_expected_config( decision_process="unit_test_ssm", learning_process=learning_process, learning_process_params=["rl_alpha"], - rl_bounds={"rl_alpha": (0.0, 1.0)}, + learning_process_bounds={"rl_alpha": (0.0, 1.0)}, rl_params_default=[0.2], extra_fields=["feedback"], choices=[0, 1], @@ -439,7 +439,7 @@ def test_respects_explicit_empty_rl_fields( "decision_process": "empty_rl_ssm", "learning_process": learning_process, "learning_process_params": [], - "rl_bounds": {}, + "learning_process_bounds": {}, "rl_params_default": [], "extra_fields": ["feedback"], "choices": [0, 1], @@ -479,7 +479,7 @@ def test_derives_rl_params_when_absent( "decision_process": "derive_params_ssm", "learning_process": learning_process, # learning_process_params deliberately absent - "rl_bounds": {}, + "learning_process_bounds": {}, "rl_params_default": [], "extra_fields": ["feedback"], "choices": [0, 1], @@ -518,7 +518,7 @@ def test_raises_for_missing_bounds( decision_process="no_bounds_ssm", learning_process=learning_process, learning_process_params=["rl_alpha"], - rl_bounds={"rl_alpha": (0.0, 1.0)}, + learning_process_bounds={"rl_alpha": (0.0, 1.0)}, rl_params_default=[0.2], extra_fields=["feedback"], choices=[0, 1], diff --git a/tests/rl/test_rlssm.py b/tests/rl/test_rlssm.py index 2c2c82d6..c1840175 100644 --- a/tests/rl/test_rlssm.py +++ b/tests/rl/test_rlssm.py @@ -403,7 +403,7 @@ def test_register_rlssm_model(self, rldm_data) -> None: decision_process="angle", learning_process={"v": _compute_v_annotated}, learning_process_params=["rl_alpha", "scaler"], - rl_bounds={"rl_alpha": (0.0, 1.0), "scaler": (0.0, 10.0)}, + learning_process_bounds={"rl_alpha": (0.0, 1.0), "scaler": (0.0, 10.0)}, rl_params_default=[0.1, 1.0], extra_fields=["feedback"], choices=[0, 1], From fe0be816bfa45c14f0aad58a44c8846566c7c6a9 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Thu, 21 May 2026 14:53:22 -0400 Subject: [PATCH 41/41] Rename rl_params_default to learning_process_params_default --- src/hssm/rl/registry.py | 19 ++++++++++--------- tests/rl/test_registry.py | 16 ++++++++-------- tests/rl/test_rlssm.py | 2 +- 3 files changed, 19 insertions(+), 18 deletions(-) diff --git a/src/hssm/rl/registry.py b/src/hssm/rl/registry.py index 68cb00a4..df4c8897 100644 --- a/src/hssm/rl/registry.py +++ b/src/hssm/rl/registry.py @@ -192,8 +192,9 @@ def _get_ssm_logp(name: str) -> Any: # decision_process - key into _SSM_REGISTRY # learning_process - {param: annotated_func} # learning_process_params - ordered list of sampled RL parameter names -# learning_process_bounds - {param: (lo, hi)} for RL params -# rl_params_default - default values aligned with learning_process_params +# learning_process_bounds - {param: (lo, hi)} for RL params +# learning_process_params_default +# - default values aligned with learning_process_params # extra_fields - extra data column names required by LP # choices - response choice values # description - human-readable description @@ -209,7 +210,7 @@ def _get_ssm_logp(name: str) -> Any: "rl_alpha": (0.0, 1.0), "scaler": (0.0, 10.0), }, - "rl_params_default": [0.1, 1.0], + "learning_process_params_default": [0.1, 1.0], "extra_fields": ["feedback"], "choices": [0, 1], "description": ( @@ -227,7 +228,7 @@ def _get_ssm_logp(name: str) -> Any: "rl_alpha": (0.0, 1.0), "scaler": (0.0, 10.0), }, - "rl_params_default": [0.1, 1.0], + "learning_process_params_default": [0.1, 1.0], "extra_fields": ["feedback"], "choices": [0, 1], "description": ( @@ -245,7 +246,7 @@ def _get_ssm_logp(name: str) -> Any: "rl_alpha": (0.0, 1.0), "scaler": (0.0, 10.0), }, - "rl_params_default": [0.1, 1.0], + "learning_process_params_default": [0.1, 1.0], "extra_fields": ["feedback"], "choices": [0, 1], "description": ( @@ -411,7 +412,7 @@ def get_rlssm_model_config( bounds[p] = ssm_entry["bounds_ssm"][p] # params_default aligned with list_params - _rl_defaults = entry.get("rl_params_default") + _rl_defaults = entry.get("learning_process_params_default") rl_defaults: list[float] = list(_rl_defaults if _rl_defaults is not None else []) ssm_all_defaults: list[float] = list(ssm_entry["params_default_ssm"]) ssm_sampled_defaults = [ @@ -477,7 +478,7 @@ def register_rlssm_model( learning_process: dict[str, Any], learning_process_params: list[str], learning_process_bounds: dict[str, tuple[float, float]], - rl_params_default: list[float], + learning_process_params_default: list[float], extra_fields: list[str] | None = None, choices: list[int] | None = None, description: str | None = None, @@ -500,7 +501,7 @@ def register_rlssm_model( Ordered list of sampled RL parameter names. learning_process_bounds: Parameter bounds for the RL parameters. - rl_params_default: + learning_process_params_default: Default values aligned with *learning_process_params*. extra_fields: Data column names required by the learning process (e.g. ``["feedback"]``). @@ -525,7 +526,7 @@ def register_rlssm_model( "learning_process": dict(learning_process), "learning_process_params": list(learning_process_params), "learning_process_bounds": dict(learning_process_bounds), - "rl_params_default": list(rl_params_default), + "learning_process_params_default": list(learning_process_params_default), "extra_fields": list(extra_fields) if extra_fields is not None else [], "choices": list(choices) if choices is not None else [0, 1], "description": description, diff --git a/tests/rl/test_registry.py b/tests/rl/test_registry.py index 02363876..f8bbcf29 100644 --- a/tests/rl/test_registry.py +++ b/tests/rl/test_registry.py @@ -320,7 +320,7 @@ def test_copies_mutable_inputs( learning_process=learning_process, learning_process_params=learning_process_params, learning_process_bounds=learning_process_bounds, - rl_params_default=rl_defaults, + learning_process_params_default=rl_defaults, extra_fields=extra_fields, choices=choices, ) @@ -341,7 +341,7 @@ def test_copies_mutable_inputs( assert stored["decision_process"]["response"] == ["rt", "response"] assert stored["learning_process_params"] == ["rl_alpha"] assert stored["learning_process_bounds"] == {"rl_alpha": (0.0, 1.0)} - assert stored["rl_params_default"] == [0.2] + assert stored["learning_process_params_default"] == [0.2] assert stored["extra_fields"] == ["feedback"] assert stored["choices"] == [0, 1] assert list(stored["learning_process"]) == ["v"] @@ -366,7 +366,7 @@ def test_warns_on_overwrite( learning_process=learning_process, learning_process_params=["rl_alpha"], learning_process_bounds={"rl_alpha": (0.0, 1.0)}, - rl_params_default=[0.2], + learning_process_params_default=[0.2], ) with caplog.at_level(logging.WARNING, logger="hssm"): @@ -376,7 +376,7 @@ def test_warns_on_overwrite( learning_process=learning_process, learning_process_params=["rl_alpha"], learning_process_bounds={"rl_alpha": (0.0, 1.0)}, - rl_params_default=[0.2], + learning_process_params_default=[0.2], ) assert any("overwrite_rlssm" in r.message for r in caplog.records) @@ -403,7 +403,7 @@ def test_builds_expected_config( learning_process=learning_process, learning_process_params=["rl_alpha"], learning_process_bounds={"rl_alpha": (0.0, 1.0)}, - rl_params_default=[0.2], + learning_process_params_default=[0.2], extra_fields=["feedback"], choices=[0, 1], ) @@ -440,7 +440,7 @@ def test_respects_explicit_empty_rl_fields( "learning_process": learning_process, "learning_process_params": [], "learning_process_bounds": {}, - "rl_params_default": [], + "learning_process_params_default": [], "extra_fields": ["feedback"], "choices": [0, 1], "description": "test model", @@ -480,7 +480,7 @@ def test_derives_rl_params_when_absent( "learning_process": learning_process, # learning_process_params deliberately absent "learning_process_bounds": {}, - "rl_params_default": [], + "learning_process_params_default": [], "extra_fields": ["feedback"], "choices": [0, 1], "description": None, @@ -519,7 +519,7 @@ def test_raises_for_missing_bounds( learning_process=learning_process, learning_process_params=["rl_alpha"], learning_process_bounds={"rl_alpha": (0.0, 1.0)}, - rl_params_default=[0.2], + learning_process_params_default=[0.2], extra_fields=["feedback"], choices=[0, 1], ) diff --git a/tests/rl/test_rlssm.py b/tests/rl/test_rlssm.py index c1840175..e5fe341c 100644 --- a/tests/rl/test_rlssm.py +++ b/tests/rl/test_rlssm.py @@ -404,7 +404,7 @@ def test_register_rlssm_model(self, rldm_data) -> None: learning_process={"v": _compute_v_annotated}, learning_process_params=["rl_alpha", "scaler"], learning_process_bounds={"rl_alpha": (0.0, 1.0), "scaler": (0.0, 10.0)}, - rl_params_default=[0.1, 1.0], + learning_process_params_default=[0.1, 1.0], extra_fields=["feedback"], choices=[0, 1], )