
[ENH] Add _reset parameter to set_params in skbase #554

@SimonBlanke

Description

Problem

set_params in skbase unconditionally calls self.reset() after setting new parameter values (line 389 in skbase/base/_base.py). The reset() method deletes all instance attributes and re-runs __init__, which destroys any pretrained state including model weights, the _state flag, and the _pretrained_attrs list. This makes the natural workflow of "pretrain a model, adjust hyperparameters for fine-tuning, then continue training" impossible.
The reset() call exists to ensure consistency between parameter values and derived state. In the general case this is correct: if you change a parameter that influenced how __init__ set up internal structures, those structures need rebuilding. But for the fine-tuning (pretraining) use case, the user explicitly wants to keep the derived state (the pretrained network weights) and only change parameters that affect the next training iteration, such as learning rate, batch size, or number of epochs. These parameters are consumed only when training resumes and don't invalidate existing weights.
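A minimal, self-contained sketch illustrates the loss of derived state; the classes below are toy stand-ins for skbase's `BaseObject` (not the real implementation), mimicking only the reset-on-set_params behavior described above:

```python
import inspect

class BaseObjectMock:
    """Toy stand-in for skbase's BaseObject; for illustration only."""

    def get_params(self):
        # hyperparameters are the __init__ arguments, mirrored as attributes
        names = inspect.signature(self.__init__).parameters
        return {name: getattr(self, name) for name in names}

    def reset(self):
        # mimic skbase reset: drop all attributes, then re-run __init__
        params = self.get_params()
        self.__dict__.clear()
        self.__init__(**params)

    def set_params(self, **params):
        for name, value in params.items():
            setattr(self, name, value)
        self.reset()  # unconditional reset: the behavior this issue targets
        return self

class ToyForecaster(BaseObjectMock):
    def __init__(self, learning_rate=1e-3):
        self.learning_rate = learning_rate

    def pretrain(self):
        self.weights_ = [0.5, 0.25]  # stands in for pretrained network weights

model = ToyForecaster()
model.pretrain()
model.set_params(learning_rate=1e-4)
print(hasattr(model, "weights_"))  # False: pretrained state was destroyed
```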

Proposed Solution

Add a _reset parameter (default True) to set_params in skbase/base/_base.py. When _reset=False, parameter values are set via setattr as usual, but reset() is not called, so all instance state (including pretrained attributes) survives. The parameter propagates through nested component calls so that pipeline.set_params(forecaster__lr=0.001, _reset=False) preserves state at every level.
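A simplified sketch of the proposed logic, including propagation to nested components; the mock classes below are illustrative only (skbase's real `set_params` also validates keys against `get_params` and handles deeper nesting):

```python
class MockBase:
    """Simplified stand-in for a skbase base class, for illustration only."""

    def reset(self):
        # stand-in for skbase reset: drop derived state (attrs ending in "_")
        for name in [n for n in vars(self) if n.endswith("_")]:
            delattr(self, name)

    def set_params(self, _reset=True, **params):
        nested = {}
        for key, value in params.items():
            if "__" in key:
                # collect nested parameters, e.g. forecaster__lr -> forecaster
                comp, _, sub = key.partition("__")
                nested.setdefault(comp, {})[sub] = value
            else:
                setattr(self, key, value)
        # forward _reset so nested components also skip (or perform) reset
        for comp, sub_params in nested.items():
            getattr(self, comp).set_params(_reset=_reset, **sub_params)
        if _reset:
            self.reset()
        return self

class Forecaster(MockBase):
    def __init__(self, lr=1e-3):
        self.lr = lr

    def pretrain(self):
        self.weights_ = [0.1, 0.2]  # stands in for pretrained weights

class Pipeline(MockBase):
    def __init__(self, forecaster):
        self.forecaster = forecaster

pipe = Pipeline(Forecaster())
pipe.forecaster.pretrain()
pipe.set_params(forecaster__lr=1e-4, _reset=False)
print(pipe.forecaster.lr)                    # 0.0001
print(hasattr(pipe.forecaster, "weights_"))  # True: weights preserved
```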

Design Decisions

User takes responsibility. We deliberately do not validate which parameters are safe to change without reset. Building a validation layer for "finetuning-compatible parameters" would require per-estimator annotation of every parameter and still could not cover value-dependent cases (some parameters are safe to change to certain values but not others, e.g. increasing num_epochs is fine but changing context_length breaks weight dimensions). If a user changes a structural parameter without reset and the model can't handle it, it will fail at the next training or prediction call with a clear error (typically a tensor shape mismatch). This "when it fails, it fails" approach avoids a complex validation system that would be incomplete by nature and burdensome to maintain.

Underscore prefix on _reset. The parameter uses _reset rather than reset to avoid collisions with estimator parameters that might be named reset. The underscore prefix also signals "use with care" to the user, which aligns with the responsibility model described above. A more specific name such as model_reset would also be an option.

Propagation to nested components. When called on a composite estimator, _reset=False propagates through all nested set_params calls. This is essential for the pipeline use case where a user wants to adjust a finetuning parameter on a nested pretrained component without triggering resets anywhere in the hierarchy.

Example Usage

# Pretrain, adjust LR, continue training
model = TTMForecaster(model_path="ibm/TTM", learning_rate=1e-3)
model.pretrain(X=X_panel, y=y_panel)
model.set_params(learning_rate=1e-4, num_epochs=5, _reset=False)
model.pretrain(X=X_panel2, y=y_panel2)  # continues with lower LR, weights preserved

# Inside a pipeline
pipe = make_pipeline(StandardScaler(), TTMForecaster(...))
pipe.pretrain(X=X_panel, y=y_panel)
pipe.set_params(ttmforecaster__learning_rate=1e-4, _reset=False)
pipe.fit(X=X_test, y=y_test)  # fine-tunes pretrained model on target

What Happens on Misuse

When a user changes a structural parameter without resetting, the stored network has architecture expectations (input dimensions, layer counts, embedding sizes) that no longer match the new configuration. The next training or prediction call will fail with an explicit error from the underlying framework (PyTorch dimension mismatch, shape error, etc.). This is acceptable and expected. The user changed the model's structural identity while keeping its old weights; the inconsistency surfaces immediately.

model.pretrain(X=X_panel, y=y_panel)
model.set_params(context_length=1024, _reset=False)  # structural change
model.pretrain(X=X_panel2, y=y_panel2)
# RuntimeError: mat1 and mat2 shapes cannot be multiplied (512x64 and 1024x64)

Acceptance Criteria

- set_params(_reset=False, **params) sets parameter values without calling reset().
- Pretrained state (model weights, _state, _pretrained_attrs) must survive a set_params(_reset=False) call.
- Calling set_params() without specifying _reset must behave exactly as before (full backward compatibility; _reset=True is the default).
- Nested set_params on composite estimators must propagate _reset to component estimators.
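The acceptance criteria could be captured as pytest-style tests; `Estimator` below is a hypothetical minimal stand-in assuming the proposed `_reset` API, not a real skbase class:

```python
class Estimator:
    """Hypothetical estimator assuming the proposed set_params(_reset=...) API."""

    def __init__(self, lr=1e-3):
        self.lr = lr

    def reset(self):
        # drop derived state (attrs ending in "_"), keep hyperparameters
        for name in [n for n in vars(self) if n.endswith("_")]:
            delattr(self, name)

    def set_params(self, _reset=True, **params):
        for name, value in params.items():
            setattr(self, name, value)
        if _reset:
            self.reset()
        return self

def test_reset_false_preserves_state():
    est = Estimator()
    est.weights_ = [1.0]  # stands in for pretrained state
    est.set_params(lr=1e-4, _reset=False)
    assert est.lr == 1e-4 and est.weights_ == [1.0]

def test_default_is_backward_compatible():
    est = Estimator()
    est.weights_ = [1.0]
    est.set_params(lr=1e-4)  # _reset defaults to True, behaves as before
    assert not hasattr(est, "weights_")
```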

Part of

Tracking issue: sktime/sktime/issues/10151


Labels

API design, enhancement, implementing framework
