Cpaniaguam/add rlssm model registry#956
Pull request overview
Adds a dedicated registry module for RLSSM model specifications and SSM base log-likelihood functions, enabling named RLSSM configs to be composed on demand (including lazy ONNX-backed SSM logp resolution), with accompanying unit tests to validate registry behavior.
Changes:
- Introduces `hssm.rl.registry` with `_SSM_REGISTRY`/`_RLSSM_REGISTRY`, lazy SSM logp caching, and helpers to register and compose `RLSSMConfig`.
- Adds registry-focused unit tests covering lazy factory caching, RL param derivation, config composition, and registration validation.
- Provides public registration helpers `register_rlssm_model()` and `register_ssm()` for user-extensible registries.
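As a rough illustration of the lazy-caching idea in the first bullet (not the PR's actual code), a registry entry can store a zero-argument factory and resolve it on first access. The names `_DEMO_SSM_REGISTRY`, `register_demo_ssm`, and `resolve_logp` are hypothetical, and a plain Python function stands in for the ONNX-backed log-likelihood.

```python
from typing import Any, Callable

# Hypothetical stand-ins for the SSM registry and its cache; not the PR's code.
_DEMO_SSM_REGISTRY: dict[str, dict[str, Any]] = {}
_LOGP_CACHE: dict[str, Callable[..., float]] = {}


def register_demo_ssm(
    name: str,
    logp_factory: Callable[[], Callable[..., float]],
    list_params: list[str],
) -> None:
    """Store a factory instead of the logp itself so nothing loads eagerly."""
    if name in _DEMO_SSM_REGISTRY:
        raise ValueError(f"SSM {name!r} is already registered")
    _DEMO_SSM_REGISTRY[name] = {
        "logp_factory": logp_factory,
        "list_params": list(list_params),  # copy so callers cannot mutate the entry
    }


def resolve_logp(name: str) -> Callable[..., float]:
    """Build the logp on first use and reuse the cached object afterwards."""
    if name not in _LOGP_CACHE:
        _LOGP_CACHE[name] = _DEMO_SSM_REGISTRY[name]["logp_factory"]()
    return _LOGP_CACHE[name]


calls = {"n": 0}


def make_toy_logp() -> Callable[..., float]:
    calls["n"] += 1  # counts how often the expensive load would happen
    return lambda v, a: -(v ** 2) - a


register_demo_ssm("toy", make_toy_logp, ["v", "a"])
first = resolve_logp("toy")
second = resolve_logp("toy")  # served from the cache; the factory is not re-run
```

Registering is cheap because only the factory is stored; the cost of loading is paid once, at the first `resolve_logp` call.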
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| `src/hssm/rl/registry.py` | Implements RLSSM/SSM registries, lazy SSM logp resolution + caching, and config composition/registration helpers. |
| `tests/rl/test_registry.py` | Adds unit tests targeting registry behavior (lazy loading, composition rules, and defensive copying expectations). |
    inputs=["rl_alpha", "response", "feedback"],
    outputs=["v"],
)
def compute_v(rl_alpha, response, feedback):

# Shallow-copy so overrides don't mutate the registry entry.
entry = dict(_RLSSM_REGISTRY[model])

if learning_process is not None:
    entry["learning_process"] = learning_process
if decision_process is not None:
    entry["decision_process"] = _get_decision_process_spec(decision_process)
if choices is not None:
    entry["choices"] = choices

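One caveat with `dict(...)` worth keeping in mind for this hunk: it copies only the top-level mapping, so nested lists or dicts are still shared with the registry entry. The sketch below uses a made-up registry entry (the keys and values are illustrative, not taken from the PR) to contrast this with `copy.deepcopy`.

```python
import copy

# Illustrative entry only; the real registry contents may differ.
registry = {"demo": {"choices": [0, 1], "decision_process": "angle"}}

shallow = dict(registry["demo"])
shallow["decision_process"] = "ddm"  # rebinding a key: registry is unaffected
shallow["choices"].append(2)         # in-place mutation: leaks into the registry

leaked = list(registry["demo"]["choices"])  # snapshot of the leaked state

registry["demo"]["choices"] = [0, 1]  # reset the entry for the second case
deep = copy.deepcopy(registry["demo"])
deep["choices"].append(2)             # nested list was copied, so no leak
```

Key overrides like the ones in the hunk are safe under a shallow copy; only in-place mutation of nested values would reach the stored entry.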
ssm_entry = _get_decision_process_spec(entry["decision_process"])
dp: str = ssm_entry["name"]

# RLSSM named model registry
# ---------------------------------------------------------------------------
# Each entry provides:
#   decision_process - key into _SSM_REGISTRY

_SSM_REGISTRY[name] = {
    "ssm_base_logp_func": ssm_base_logp_func,
    "list_params_ssm": list(list_params_ssm),
    "bounds_ssm": dict(bounds_ssm),
    "params_default_ssm": list(params_default_ssm),

@@ -0,0 +1,462 @@
"""Registry for named RLSSM models and SSM base log-likelihood functions.
I think I reviewed this file via #955?
Possibly meant to be staggered PRs?
@AlexanderFengler This is a separate idea for how to do the RLSSM model registry. I think the one in #955 is probably overengineered. I will notify everyone when they are ready for review. Thanks!

This module provides:

- :data:`_SSM_REGISTRY` — maps SSM names (e.g. ``"angle"``) to their base
One complication that we need to account for here: we want to work with choice-only models as well. Therefore, design this so that we can flexibly combine the RL process with a choice-only model (e.g., the inverse-temperature softmax that Paul implemented).
@AlexanderFengler Will this require making additional nomenclature choices? I don't think it is a good idea to use "SSM" for a choice-only model.
A general name could just be "_DECISION_PROCESS_REGISTRY". If the specified decision process is an SSM, we can rely on the SSM_REGISTRY processing logic and define new processing logic for choice-only models?
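To make this suggestion concrete, here is a minimal sketch of a single decision-process registry whose entries carry a `kind` tag, with separate handling per kind. Everything in it is hypothetical: `_DECISION_PROCESS_REGISTRY_DEMO`, the `kind` field, `describe_decision_process`, and the toy softmax are assumptions for illustration, not existing HSSM APIs and not Paul's implementation.

```python
import math
from typing import Any

# Hypothetical unified registry: SSMs and choice-only models side by side.
_DECISION_PROCESS_REGISTRY_DEMO: dict[str, dict[str, Any]] = {
    "angle": {"kind": "ssm", "params": ["v", "a", "z", "t", "theta"]},
    "softmax": {"kind": "choice_only", "params": ["beta"]},
}


def softmax_choice_logp(q_values: list[float], choice: int, beta: float) -> float:
    """Log-probability of `choice` under an inverse-temperature softmax."""
    scaled = [beta * q for q in q_values]
    m = max(scaled)  # subtract the max for numerical stability
    log_norm = m + math.log(sum(math.exp(s - m) for s in scaled))
    return scaled[choice] - log_norm


def describe_decision_process(name: str) -> dict[str, Any]:
    """Dispatch on `kind` so each family gets its own processing logic."""
    entry = _DECISION_PROCESS_REGISTRY_DEMO[name]
    if entry["kind"] == "ssm":
        # An SSM models both reaction time and choice.
        response = ["rt", "response"]
    elif entry["kind"] == "choice_only":
        # A choice-only model has no reaction-time component.
        response = ["response"]
    else:
        raise ValueError(f"Unknown decision process kind: {entry['kind']!r}")
    return {"name": name, "response": response, "params": list(entry["params"])}
```

With a tag like this, the existing SSM path could stay untouched while choice-only entries take their own branch.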
    "a": (0.3, 3.0),
    "z": (0.1, 0.9),
    "t": (0.001, 2.0),
    "theta": (-0.1, 1.3),
theta lower bound should be 0.
In HSSM/src/hssm/modelconfig/angle_config.py, line 30 in a34a766
#   learning_process_kind

_RLSSM_REGISTRY: dict[str, dict[str, Any]] = {
    "rldm": {
We can just name this model "2AB_RescorlaWagner_Angle" (2AB stands for 2-armed bandit).
flowchart TD
    A[register_rlssm_model] --> B[REGISTERED_RL_MODELS]
    subgraph Registry
        B["REGISTERED_RL_MODELS\nname -> RLSSM spec"]
    end
    subgraph Stored_Spec
        C["decision_process spec\nname, lazy ssm logp factory,\nparams, bounds, defaults, response"]
        D["learning_process"]
        E["rl_params / rl_bounds /\nrl_params_default"]
        F["choices / extra_fields /\ndescription"]
    end
    B --> C
    B --> D
    B --> E
    B --> F
    G[get_rlssm_model_config] --> B
    G --> H[copy named spec]
    H --> I[resolve lazy SSM logp]
    I --> J[compose computed SSM logp]
    J --> K[merge RL and SSM metadata]
    K --> L[build fresh RLSSMConfig]
    style B fill:#1f2937,stroke:#cbd5e1,stroke-width:2px,color:#f9fafb
    style C fill:#0f766e,stroke:#ccfbf1,stroke-width:2px,color:#f9fafb
    style D fill:#0f766e,stroke:#ccfbf1,stroke-width:2px,color:#f9fafb
    style E fill:#0f766e,stroke:#ccfbf1,stroke-width:2px,color:#f9fafb
    style F fill:#0f766e,stroke:#ccfbf1,stroke-width:2px,color:#f9fafb
    style A fill:#7c2d12,stroke:#fed7aa,stroke-width:2px,color:#f9fafb
    style G fill:#7c2d12,stroke:#fed7aa,stroke-width:2px,color:#f9fafb
    style H fill:#1d4ed8,stroke:#bfdbfe,stroke-width:2px,color:#f9fafb
    style I fill:#1d4ed8,stroke:#bfdbfe,stroke-width:2px,color:#f9fafb
    style J fill:#1d4ed8,stroke:#bfdbfe,stroke-width:2px,color:#f9fafb
    style K fill:#1d4ed8,stroke:#bfdbfe,stroke-width:2px,color:#f9fafb
    style L fill:#166534,stroke:#bbf7d0,stroke-width:2px,color:#f9fafb

Registry for RLSSM model specs.
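The `get_rlssm_model_config` path in the diagram (copy the spec, resolve the lazy logp, merge metadata, build a fresh config) can be sketched roughly as follows. This is an assumption-laden toy, not the PR's implementation: `DemoConfig` stands in for `RLSSMConfig`, `DEMO_MODELS` and the `computed` key are hypothetical, and the rule that RL-computed parameters are dropped from the free SSM parameters is a guess at what "merge RL and SSM metadata" involves.

```python
import copy
from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class DemoConfig:  # hypothetical stand-in for RLSSMConfig
    model_name: str
    list_params: list[str]
    bounds: dict[str, tuple[float, float]]
    logp: Callable[..., float]


# Hypothetical stored spec mirroring the boxes in the diagram.
DEMO_MODELS: dict[str, dict[str, Any]] = {
    "demo": {
        "decision_process": {
            "name": "toy_ssm",
            "logp_factory": lambda: (lambda v, a: -((v - a) ** 2)),
            "params": ["v", "a"],
            "bounds": {"v": (-3.0, 3.0), "a": (0.3, 3.0)},
        },
        "rl_params": ["rl_alpha"],
        "rl_bounds": {"rl_alpha": (0.0, 1.0)},
        "computed": ["v"],  # assumed: parameters produced by the learning process
    }
}


def get_demo_config(name: str) -> DemoConfig:
    spec = copy.deepcopy(DEMO_MODELS[name])  # copy named spec
    dp = spec["decision_process"]
    logp = dp["logp_factory"]()              # resolve lazy logp
    # Merge metadata: RL params first, then SSM params not computed by the RL process.
    free_ssm = [p for p in dp["params"] if p not in spec["computed"]]
    list_params = spec["rl_params"] + free_ssm
    bounds = {**spec["rl_bounds"], **{p: dp["bounds"][p] for p in free_ssm}}
    return DemoConfig(name, list_params, bounds, logp)  # build fresh config


cfg = get_demo_config("demo")
cfg.list_params.append("mutated")    # does not affect the stored spec
fresh = get_demo_config("demo")
```

Because each call starts from a deep copy, mutating one returned config leaves later calls unaffected.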