From 518c60bafa9af02ed25814a032baf744acb2fd1a Mon Sep 17 00:00:00 2001 From: Cristian Lara Date: Wed, 11 Feb 2026 15:16:01 -0800 Subject: [PATCH] Add suggested_experiment_status field (#4902) Summary: This change introduces an optional `suggested_experiment_state` field to the `GenerationNode` class that allows tracking what experiment status is suggested for a given generation node. This is part of a larger effort to add status tracking to experiments. The field is: - Optional (defaults to None for backward compatibility) - Advisory only (does not automatically update experiment.status) - Configurable per GenerationNode instance - Serialized automatically via SerializationMixin - Displayed in __repr__ when set This is Phase 1 where we are just adding the column but not yet doing anything with it. In the next diffs in this stack we will propagate this through the orchestrator and eventually set this status on the experiment. Reviewed By: mgarrard Differential Revision: D88089767 --- .../center_generation_node.py | 15 +++++++- .../external_generation_node.py | 5 +++ ax/generation_strategy/generation_node.py | 15 ++++++++ .../tests/test_generation_node.py | 34 +++++++++++++++++-- ax/storage/json_store/decoder.py | 9 +++++ ax/storage/json_store/encoders.py | 1 + ax/storage/json_store/registry.py | 2 ++ 7 files changed, 78 insertions(+), 3 deletions(-) diff --git a/ax/generation_strategy/center_generation_node.py b/ax/generation_strategy/center_generation_node.py index d8f1e8ee90e..201cf4ed646 100644 --- a/ax/generation_strategy/center_generation_node.py +++ b/ax/generation_strategy/center_generation_node.py @@ -13,6 +13,7 @@ from ax.core.arm import Arm from ax.core.data import Data from ax.core.experiment import Experiment +from ax.core.experiment_status import ExperimentStatus from ax.core.generator_run import GeneratorRun from ax.core.observation import ObservationFeatures from ax.core.parameter import DerivedParameter @@ -29,7 +30,12 @@ class CenterGenerationNode(ExternalGenerationNode): next_node_name: str - def __init__(self, next_node_name: str) -> None: + def __init__( + self, + next_node_name: str, + suggested_experiment_status: ExperimentStatus + | None = ExperimentStatus.INITIALIZATION, + ) -> None: """A generation node that samples the center of the search space. This generation node is only used to generate the first point of the experiment. After one point is generated, it will transition to `next_node_name`. @@ -37,9 +43,16 @@ def __init__(self, next_node_name: str) -> None: If the generated point is a duplicate of an arm already attached to the experiment, this will fallback to Sobol through the use of ``GenerationNode`` deduplication logic. + + Args: + next_node_name: The name of the node to transition to after generating + the center point. + suggested_experiment_status: Optional suggested experiment status for this + node. """ super().__init__( name="CenterOfSearchSpace", + suggested_experiment_status=suggested_experiment_status, transition_criteria=[ AutoTransitionAfterGen( transition_to=next_node_name, diff --git a/ax/generation_strategy/external_generation_node.py b/ax/generation_strategy/external_generation_node.py index b5192a97b9d..6e153848895 100644 --- a/ax/generation_strategy/external_generation_node.py +++ b/ax/generation_strategy/external_generation_node.py @@ -14,6 +14,7 @@ from ax.core.arm import Arm from ax.core.data import Data from ax.core.experiment import Experiment +from ax.core.experiment_status import ExperimentStatus from ax.core.generator_run import GeneratorRun from ax.core.observation import ObservationFeatures from ax.core.types import TParameterization @@ -48,6 +49,7 @@ class ExternalGenerationNode(GenerationNode, ABC): def __init__( self, name: str, + suggested_experiment_status: ExperimentStatus | None = None, should_deduplicate: bool = True, transition_criteria: Sequence[TransitionCriterion] | None = None, ) -> None: @@ -59,6 +61,8 @@ def __init__( Args: name: Name of the generation node. + suggested_experiment_status: Optional suggested experiment status for this + node. Defaults to None if not specified. should_deduplicate: Whether to deduplicate the generated points against the existing trials on the experiment. If True, the duplicate points will be discarded and re-generated up to 5 times, after which a @@ -73,6 +77,7 @@ def __init__( super().__init__( name=name, generator_specs=[], + suggested_experiment_status=suggested_experiment_status, best_model_selector=None, should_deduplicate=should_deduplicate, transition_criteria=transition_criteria, diff --git a/ax/generation_strategy/generation_node.py b/ax/generation_strategy/generation_node.py index 89b9a0ec1ae..70b64d641d8 100644 --- a/ax/generation_strategy/generation_node.py +++ b/ax/generation_strategy/generation_node.py @@ -15,6 +15,7 @@ from ax.core.data import Data from ax.core.experiment import Experiment +from ax.core.experiment_status import ExperimentStatus from ax.core.generator_run import GeneratorRun from ax.core.observation import ObservationFeatures from ax.core.trial_status import TrialStatus @@ -112,6 +113,10 @@ class GenerationNode(SerializationMixin, SortableBase): store the most recent previous ``GenerationNode`` name. should_skip: Whether to skip this node during generation time. Defaults to False, and can only currently be set to True via ``NodeInputConstructors`` + suggested_experiment_status: Optional ``ExperimentStatus`` that indicates + what the experiment's status should be once the experiment adds trials + using ``GeneratorRun``-s produced from this node. This is advisory only + and does not automatically update the experiment's status. fallback_specs: Optional dict mapping expected exception types to `ModelSpec` fallbacks used when gen fails. @@ -134,6 +139,7 @@ class GenerationNode(SerializationMixin, SortableBase): _previous_node_name: str | None = None _trial_type: str | None = None _should_skip: bool = False + suggested_experiment_status: ExperimentStatus | None = None fallback_specs: dict[type[Exception], GeneratorSpec] # [TODO] Handle experiment passing more eloquently by enforcing experiment @@ -155,6 +161,7 @@ def __init__( previous_node_name: str | None = None, trial_type: str | None = None, should_skip: bool = False, + suggested_experiment_status: ExperimentStatus | None = None, fallback_specs: dict[type[Exception], GeneratorSpec] | None = None, ) -> None: self._name = name @@ -187,6 +194,7 @@ def __init__( self._previous_node_name = previous_node_name self._trial_type = trial_type self._should_skip = should_skip + self.suggested_experiment_status = suggested_experiment_status self.fallback_specs = ( fallback_specs if fallback_specs is not None else DEFAULT_FALLBACK ) @@ -354,6 +362,10 @@ def __repr__(self) -> str: str_rep += ( f", transition_criteria={str(self._brief_transition_criteria_repr())}" ) + if self.suggested_experiment_status is not None: + str_rep += ( + f", suggested_experiment_status={self.suggested_experiment_status!r}" + ) return f"{str_rep})" def _fit( @@ -999,6 +1011,7 @@ class GenerationStep: whether to transition to the next step. If False, `num_trials` and `min_trials_observed` will only count trials generatd by this step. If True, they will count all trials in the experiment (of corresponding statuses). + suggested_experiment_status: The suggested experiment status for this step. Note for developers: by "generator" here we really mean an ``Adapter`` object, which contains a ``Generator`` under the hood. We call it "generator" here to simplify and @@ -1019,6 +1032,7 @@ def __new__( use_all_trials_in_exp: bool = False, use_update: bool = False, # DEPRECATED. index: int = -1, # Index of this step, set internally. + suggested_experiment_status: ExperimentStatus | None = None, # Deprecated arguments for backwards compatibility. model_kwargs: dict[str, Any] | None = None, model_gen_kwargs: dict[str, Any] | None = None, @@ -1135,6 +1149,7 @@ def __new__( step_index=index, generator_name=resolved_generator_name ), generator_specs=[generator_spec], + suggested_experiment_status=suggested_experiment_status, should_deduplicate=should_deduplicate, transition_criteria=transition_criteria, ) diff --git a/ax/generation_strategy/tests/test_generation_node.py b/ax/generation_strategy/tests/test_generation_node.py index 3cae54f0b27..266873cf5ad 100644 --- a/ax/generation_strategy/tests/test_generation_node.py +++ b/ax/generation_strategy/tests/test_generation_node.py @@ -11,6 +11,7 @@ import torch from ax.adapter.factory import get_sobol from ax.adapter.registry import Generators +from ax.core.experiment_status import ExperimentStatus from ax.core.observation import ObservationFeatures from ax.core.trial_status import TrialStatus from ax.exceptions.core import UserInputError @@ -53,13 +54,16 @@ def setUp(self) -> None: generator_gen_kwargs={}, ) self.sobol_generation_node = GenerationNode( - name="test", generator_specs=[self.sobol_generator_spec] + name="test", + generator_specs=[self.sobol_generator_spec], + suggested_experiment_status=ExperimentStatus.INITIALIZATION, ) self.branin_experiment = get_branin_experiment(with_completed_trial=True) self.branin_data = self.branin_experiment.lookup_data() self.node_short = GenerationNode( name="test", generator_specs=[self.sobol_generator_spec], + suggested_experiment_status=ExperimentStatus.INITIALIZATION, trial_type=Keys.SHORT_RUN, ) @@ -97,6 +101,30 @@ def test_init(self) -> None: self.assertEqual(node.generator_specs, mbm_specs) self.assertIs(node.best_model_selector, model_selector) + def test_suggested_experiment_status(self) -> None: + """Test that suggested_experiment_status is properly set and accessible.""" + with self.subTest("initialization set"): + self.assertEqual( + self.sobol_generation_node.suggested_experiment_status, + ExperimentStatus.INITIALIZATION, + ) + + with self.subTest("default None when not provided"): + node_without_state = GenerationNode( + name="test", + generator_specs=[self.sobol_generator_spec], + ) + self.assertIsNone(node_without_state.suggested_experiment_status) + + with self.subTest("__repr__ includes status when set"): + repr_str = repr(self.sobol_generation_node) + self.assertIn("suggested_experiment_status", repr_str) + self.assertIn("INITIALIZATION", repr_str) + + with self.subTest("__repr__ excludes status when None"): + repr_str_without = repr(node_without_state) + self.assertNotIn("suggested_experiment_status", repr_str_without) + def test_input_constructor_none(self) -> None: self.assertEqual(self.sobol_generation_node._input_constructors, {}) self.assertEqual(self.sobol_generation_node.input_constructors, {}) @@ -320,6 +348,7 @@ def test_node_string_representation(self) -> None: generator_specs=[ self.mbm_generator_spec, ], + suggested_experiment_status=ExperimentStatus.OPTIMIZATION, transition_criteria=[ MinTrials( threshold=5, @@ -335,7 +364,8 @@ def test_node_string_representation(self) -> None: "GenerationNode(name='test', " "generator_specs=[GeneratorSpec(generator_enum=BoTorch, " "generator_key_override=None)], " - "transition_criteria=[MinTrials(transition_to='next_node')])", + "transition_criteria=[MinTrials(transition_to='next_node')], " + "suggested_experiment_status=ExperimentStatus.OPTIMIZATION)", ) def test_single_fixed_features(self) -> None: diff --git a/ax/storage/json_store/decoder.py b/ax/storage/json_store/decoder.py index 0beccd47c09..7c617316f15 100644 --- a/ax/storage/json_store/decoder.py +++ b/ax/storage/json_store/decoder.py @@ -889,6 +889,15 @@ def generation_node_from_json( if "trial_type" in generation_node_json.keys() else None ), + suggested_experiment_status=( + object_from_json( + object_json=generation_node_json.pop("suggested_experiment_status"), + decoder_registry=decoder_registry, + class_decoder_registry=class_decoder_registry, + ) + if "suggested_experiment_status" in generation_node_json.keys() + else None # Default for old records without the field + ), ) diff --git a/ax/storage/json_store/encoders.py b/ax/storage/json_store/encoders.py index bfd6157f129..ed17c193c63 100644 --- a/ax/storage/json_store/encoders.py +++ b/ax/storage/json_store/encoders.py @@ -411,6 +411,7 @@ def generation_node_to_dict(generation_node: GenerationNode) -> dict[str, Any]: "generator_spec_to_gen_from": generation_node._generator_spec_to_gen_from, "previous_node_name": generation_node._previous_node_name, "trial_type": generation_node._trial_type, + "suggested_experiment_status": generation_node.suggested_experiment_status, # need to manually encode input constructors because the key is an enum. # Our encoding and decoding logic in object_to_json and object_from_json # doesn't recursively encode/decode the keys of dictionaries. diff --git a/ax/storage/json_store/registry.py b/ax/storage/json_store/registry.py index 6f5ab41aee9..25912543bd9 100644 --- a/ax/storage/json_store/registry.py +++ b/ax/storage/json_store/registry.py @@ -29,6 +29,7 @@ from ax.core.batch_trial import AbandonedArm, BatchTrial from ax.core.data import Data from ax.core.evaluations_to_data import DataType +from ax.core.experiment_status import ExperimentStatus from ax.core.generator_run import GeneratorRun from ax.core.llm_provider import LLMMessage from ax.core.map_metric import MapMetric @@ -316,6 +317,7 @@ "DerivedParameter": DerivedParameter, "DomainType": DomainType, "Experiment": Experiment, + "ExperimentStatus": ExperimentStatus, "FactorialMetric": FactorialMetric, "FilterFeatures": FilterFeatures, "FixedParameter": fixed_parameter_from_json,