Cpaniaguam/rlssm simplified interface#955
Conversation
There was a problem hiding this comment.
Pull request overview
This PR introduces a simplified public interface for RLSSM by adding a named-model registry and wrapping the existing RLSSM implementation behind a friendlier constructor that can build an RLSSMConfig from a model name (defaulting to "rldm").
Changes:
- Split the existing RLSSM implementation into an internal base class (
_RLSSM) and a public wrapper (RLSSM) with a simplified constructor. - Added an RLSSM/SSM registry and factory (
get_rlssm_model_config,register_rlssm_model,register_ssm) to construct configs from named models. - Expanded the RLSSM test suite to cover the simplified interface, registry behavior, and the new wrapper semantics.
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 2 comments.
Show a summary per file
| File | Description |
|---|---|
tests/test_rlssm.py |
Adds coverage for the new simplified RLSSM constructor and registry-based config creation. |
src/hssm/rl/rlssm.py |
Renames the prior implementation to _RLSSM and adds the public RLSSM wrapper + blocked-attribute behavior. |
src/hssm/rl/registry.py |
New registry/factory module for named RLSSM models and SSM base logp functions. |
src/hssm/rl/__init__.py |
Exposes _RLSSM and registry helpers in the hssm.rl public API. |
src/hssm/__init__.py |
Exposes register_rlssm_model at the top-level hssm API. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
- Added KeyError for missing SSM in the registry. - Improved ValueError messages for non-callable logp functions. - Implemented runtime checks for list_params and loglik in RLSSM. - Ensured defensive copying of mutable parameters in registry functions.
…rs for RL parameters, bounds, and defaults
…y .computed attribute
…l registration process
AlexanderFengler
left a comment
There was a problem hiding this comment.
Left a few comments, mostly looks good already.
@krishnbera can you take a look too?
| - ``RLSSMConfig``: the config class for RL + SSM models in :mod:`hssm.rl.config`. | ||
| - ``get_rlssm_model_config``: factory that builds a config from a named model. | ||
| - ``register_rlssm_model``: register a custom named RLSSM model. | ||
| - ``register_ssm``: register a custom SSM base logp function. |
There was a problem hiding this comment.
would we want/need a register_learning_process equivalent here? ( @krishnbera )
There was a problem hiding this comment.
yes we should have a way of registering the learning process separately.
…SSM model registration
…od function handling
flowchart TD
USER["User: RLSSM(model='rldm')"]
RLSSM_REG["_RLSSM_REGISTRY\nnamed models\ne.g. 'rldm'"]
SSM_REG["_SSM_REGISTRY\ncustom SSMs only\n(empty by default)"]
MODELCONFIG["hssm.modelconfig\nbuilt-in SSMs\nddm · angle · weibull · ornstein · …"]
CACHE["_SSM_LOGP_CACHE\nlazy ONNX → JAX fn"]
OUTPUT["RLSSMConfig"]
USER --> RLSSM_REG
RLSSM_REG -- decision_process name --> LOOKUP
LOOKUP{"registered\ncustom SSM?"}
LOOKUP -- yes --> SSM_REG
LOOKUP -- no --> MODELCONFIG
SSM_REG & MODELCONFIG --> CACHE
CACHE -- annotated JAX fn --> OUTPUT
RLSSM_REG -- rl params / bounds / learning process --> OUTPUT
|
…_models_config_structure
… organizing test cases
krishnbera
left a comment
There was a problem hiding this comment.
looks good overall. added minor comments.
| Examples | ||
| -------- | ||
| >>> import hssm | ||
| >>> hssm.rl.list_models() |
flowchart TD subgraph SSM_Registry A1["_SSM_REGISTRY (dict)"] A2["_SSM_LOGP_CACHE (dict)"] end subgraph RLSSM_Registry B1["_RLSSM_REGISTRY (dict)"] end subgraph User_API C1[register_ssm] C2[register_rlssm_model] C3[get_rlssm_model_config] end C1-->|adds entry|A1 C1-->|adds entry|A2 C2-->|adds entry|B1 C3-->|reads entry|B1 C3-->|reads entry|A1 C3-->|calls _get_ssm_logp|A2 A2-->|lazy build if needed|A1 C3-->|returns RLSSMConfig|D1["RLSSMConfig"] style D1 fill:#1e7a1e,stroke:#fff,stroke-width:2px,color:#fff style A1 fill:#1e3a7a,stroke:#fff,stroke-width:2px,color:#fff style A2 fill:#1e3a7a,stroke:#fff,stroke-width:2px,color:#fff style B1 fill:#7a1e1e,stroke:#fff,stroke-width:2px,color:#fff style C1 fill:#7a7a1e,stroke:#fff,stroke-width:2px,color:#fff style C2 fill:#7a7a1e,stroke:#fff,stroke-width:2px,color:#fff style C3 fill:#7a7a1e,stroke:#fff,stroke-width:2px,color:#fff