From ec457ebf7c881f465a4ffe794d141e8e66580b13 Mon Sep 17 00:00:00 2001 From: Cristian Lara Date: Wed, 11 Feb 2026 14:56:36 -0800 Subject: [PATCH 1/6] Add suggested_experiment_status field (#4902) Summary: This change introduces an optional `suggested_experiment_state` field to the `GenerationNode` class that allows tracking what experiment status is suggested for a given generation node. This is part of a larger effort to add status tracking to experiments. The field is: - Optional (defaults to None for backward compatibility) - Advisory only (does not automatically update experiment.status) - Configurable per GenerationNode instance - Serialized automatically via SerializationMixin - Displayed in __repr__ when set This is Phase 1 where we are just adding the column but not yet doing anything with it. In the next diffs in this stack we will propagate this through the orchestrator and eventually set this status on the experiment. Reviewed By: mgarrard Differential Revision: D88089767 --- .../center_generation_node.py | 15 +++++++- .../external_generation_node.py | 5 +++ ax/generation_strategy/generation_node.py | 15 ++++++++ .../tests/test_generation_node.py | 34 +++++++++++++++++-- ax/storage/json_store/decoder.py | 9 +++++ ax/storage/json_store/encoders.py | 1 + ax/storage/json_store/registry.py | 2 ++ 7 files changed, 78 insertions(+), 3 deletions(-) diff --git a/ax/generation_strategy/center_generation_node.py b/ax/generation_strategy/center_generation_node.py index d8f1e8ee90e..201cf4ed646 100644 --- a/ax/generation_strategy/center_generation_node.py +++ b/ax/generation_strategy/center_generation_node.py @@ -13,6 +13,7 @@ from ax.core.arm import Arm from ax.core.data import Data from ax.core.experiment import Experiment +from ax.core.experiment_status import ExperimentStatus from ax.core.generator_run import GeneratorRun from ax.core.observation import ObservationFeatures from ax.core.parameter import DerivedParameter @@ -29,7 +30,12 @@ class CenterGenerationNode(ExternalGenerationNode): next_node_name: str - def __init__(self, next_node_name: str) -> None: + def __init__( + self, + next_node_name: str, + suggested_experiment_status: ExperimentStatus + | None = ExperimentStatus.INITIALIZATION, + ) -> None: """A generation node that samples the center of the search space. This generation node is only used to generate the first point of the experiment. After one point is generated, it will transition to `next_node_name`. @@ -37,9 +43,16 @@ def __init__(self, next_node_name: str) -> None: If the generated point is a duplicate of an arm already attached to the experiment, this will fallback to Sobol through the use of ``GenerationNode`` deduplication logic. + + Args: + next_node_name: The name of the node to transition to after generating + the center point. + suggested_experiment_status: Optional suggested experiment status for this + node. """ super().__init__( name="CenterOfSearchSpace", + suggested_experiment_status=suggested_experiment_status, transition_criteria=[ AutoTransitionAfterGen( transition_to=next_node_name, diff --git a/ax/generation_strategy/external_generation_node.py b/ax/generation_strategy/external_generation_node.py index b5192a97b9d..6e153848895 100644 --- a/ax/generation_strategy/external_generation_node.py +++ b/ax/generation_strategy/external_generation_node.py @@ -14,6 +14,7 @@ from ax.core.arm import Arm from ax.core.data import Data from ax.core.experiment import Experiment +from ax.core.experiment_status import ExperimentStatus from ax.core.generator_run import GeneratorRun from ax.core.observation import ObservationFeatures from ax.core.types import TParameterization @@ -48,6 +49,7 @@ class ExternalGenerationNode(GenerationNode, ABC): def __init__( self, name: str, + suggested_experiment_status: ExperimentStatus | None = None, should_deduplicate: bool = True, transition_criteria: Sequence[TransitionCriterion] | None = None, ) -> None: @@ -59,6 +61,8 @@ def __init__( Args: name: Name of the generation node. + suggested_experiment_status: Optional suggested experiment status for this + node. Defaults to None if not specified. should_deduplicate: Whether to deduplicate the generated points against the existing trials on the experiment. If True, the duplicate points will be discarded and re-generated up to 5 times, after which a @@ -73,6 +77,7 @@ def __init__( super().__init__( name=name, generator_specs=[], + suggested_experiment_status=suggested_experiment_status, best_model_selector=None, should_deduplicate=should_deduplicate, transition_criteria=transition_criteria, diff --git a/ax/generation_strategy/generation_node.py b/ax/generation_strategy/generation_node.py index 89b9a0ec1ae..70b64d641d8 100644 --- a/ax/generation_strategy/generation_node.py +++ b/ax/generation_strategy/generation_node.py @@ -15,6 +15,7 @@ from ax.core.data import Data from ax.core.experiment import Experiment +from ax.core.experiment_status import ExperimentStatus from ax.core.generator_run import GeneratorRun from ax.core.observation import ObservationFeatures from ax.core.trial_status import TrialStatus @@ -112,6 +113,10 @@ class GenerationNode(SerializationMixin, SortableBase): store the most recent previous ``GenerationNode`` name. should_skip: Whether to skip this node during generation time. Defaults to False, and can only currently be set to True via ``NodeInputConstructors`` + suggested_experiment_status: Optional ``ExperimentStatus`` that indicates + what the experiment's status should be once the experiment adds trials + using ``GeneratorRun``-s produced from this node. This is advisory only + and does not automatically update the experiment's status. fallback_specs: Optional dict mapping expected exception types to `ModelSpec` fallbacks used when gen fails. @@ -134,6 +139,7 @@ class GenerationNode(SerializationMixin, SortableBase): _previous_node_name: str | None = None _trial_type: str | None = None _should_skip: bool = False + suggested_experiment_status: ExperimentStatus | None = None fallback_specs: dict[type[Exception], GeneratorSpec] # [TODO] Handle experiment passing more eloquently by enforcing experiment @@ -155,6 +161,7 @@ def __init__( previous_node_name: str | None = None, trial_type: str | None = None, should_skip: bool = False, + suggested_experiment_status: ExperimentStatus | None = None, fallback_specs: dict[type[Exception], GeneratorSpec] | None = None, ) -> None: self._name = name @@ -187,6 +194,7 @@ def __init__( self._previous_node_name = previous_node_name self._trial_type = trial_type self._should_skip = should_skip + self.suggested_experiment_status = suggested_experiment_status self.fallback_specs = ( fallback_specs if fallback_specs is not None else DEFAULT_FALLBACK ) @@ -354,6 +362,10 @@ def __repr__(self) -> str: str_rep += ( f", transition_criteria={str(self._brief_transition_criteria_repr())}" ) + if self.suggested_experiment_status is not None: + str_rep += ( + f", suggested_experiment_status={self.suggested_experiment_status!r}" + ) return f"{str_rep})" def _fit( @@ -999,6 +1011,7 @@ class GenerationStep: whether to transition to the next step. If False, `num_trials` and `min_trials_observed` will only count trials generatd by this step. If True, they will count all trials in the experiment (of corresponding statuses). + suggested_experiment_status: The suggested experiment status for this step. Note for developers: by "generator" here we really mean an ``Adapter`` object, which contains a ``Generator`` under the hood. We call it "generator" here to simplify and @@ -1019,6 +1032,7 @@ def __new__( use_all_trials_in_exp: bool = False, use_update: bool = False, # DEPRECATED. index: int = -1, # Index of this step, set internally. + suggested_experiment_status: ExperimentStatus | None = None, # Deprecated arguments for backwards compatibility. model_kwargs: dict[str, Any] | None = None, model_gen_kwargs: dict[str, Any] | None = None, @@ -1135,6 +1149,7 @@ def __new__( step_index=index, generator_name=resolved_generator_name ), generator_specs=[generator_spec], + suggested_experiment_status=suggested_experiment_status, should_deduplicate=should_deduplicate, transition_criteria=transition_criteria, ) diff --git a/ax/generation_strategy/tests/test_generation_node.py b/ax/generation_strategy/tests/test_generation_node.py index 3cae54f0b27..266873cf5ad 100644 --- a/ax/generation_strategy/tests/test_generation_node.py +++ b/ax/generation_strategy/tests/test_generation_node.py @@ -11,6 +11,7 @@ import torch from ax.adapter.factory import get_sobol from ax.adapter.registry import Generators +from ax.core.experiment_status import ExperimentStatus from ax.core.observation import ObservationFeatures from ax.core.trial_status import TrialStatus from ax.exceptions.core import UserInputError @@ -53,13 +54,16 @@ def setUp(self) -> None: generator_gen_kwargs={}, ) self.sobol_generation_node = GenerationNode( - name="test", generator_specs=[self.sobol_generator_spec] + name="test", + generator_specs=[self.sobol_generator_spec], + suggested_experiment_status=ExperimentStatus.INITIALIZATION, ) self.branin_experiment = get_branin_experiment(with_completed_trial=True) self.branin_data = self.branin_experiment.lookup_data() self.node_short = GenerationNode( name="test", generator_specs=[self.sobol_generator_spec], + suggested_experiment_status=ExperimentStatus.INITIALIZATION, trial_type=Keys.SHORT_RUN, ) @@ -97,6 +101,30 @@ def test_init(self) -> None: self.assertEqual(node.generator_specs, mbm_specs) self.assertIs(node.best_model_selector, model_selector) + def test_suggested_experiment_status(self) -> None: + """Test that suggested_experiment_status is properly set and accessible.""" + with self.subTest("initialization set"): + self.assertEqual( + self.sobol_generation_node.suggested_experiment_status, + ExperimentStatus.INITIALIZATION, + ) + + with self.subTest("default None when not provided"): + node_without_state = GenerationNode( + name="test", + generator_specs=[self.sobol_generator_spec], + ) + self.assertIsNone(node_without_state.suggested_experiment_status) + + with self.subTest("__repr__ includes status when set"): + repr_str = repr(self.sobol_generation_node) + self.assertIn("suggested_experiment_status", repr_str) + self.assertIn("INITIALIZATION", repr_str) + + with self.subTest("__repr__ excludes status when None"): + repr_str_without = repr(node_without_state) + self.assertNotIn("suggested_experiment_status", repr_str_without) + def test_input_constructor_none(self) -> None: self.assertEqual(self.sobol_generation_node._input_constructors, {}) self.assertEqual(self.sobol_generation_node.input_constructors, {}) @@ -320,6 +348,7 @@ def test_node_string_representation(self) -> None: generator_specs=[ self.mbm_generator_spec, ], + suggested_experiment_status=ExperimentStatus.OPTIMIZATION, transition_criteria=[ MinTrials( threshold=5, @@ -335,7 +364,8 @@ def test_node_string_representation(self) -> None: "GenerationNode(name='test', " "generator_specs=[GeneratorSpec(generator_enum=BoTorch, " "generator_key_override=None)], " - "transition_criteria=[MinTrials(transition_to='next_node')])", + "transition_criteria=[MinTrials(transition_to='next_node')], " + "suggested_experiment_status=ExperimentStatus.OPTIMIZATION)", ) def test_single_fixed_features(self) -> None: diff --git a/ax/storage/json_store/decoder.py b/ax/storage/json_store/decoder.py index 0beccd47c09..7c617316f15 100644 --- a/ax/storage/json_store/decoder.py +++ b/ax/storage/json_store/decoder.py @@ -889,6 +889,15 @@ def generation_node_from_json( if "trial_type" in generation_node_json.keys() else None ), + suggested_experiment_status=( + object_from_json( + object_json=generation_node_json.pop("suggested_experiment_status"), + decoder_registry=decoder_registry, + class_decoder_registry=class_decoder_registry, + ) + if "suggested_experiment_status" in generation_node_json.keys() + else None # Default for old records without the field + ), ) diff --git a/ax/storage/json_store/encoders.py b/ax/storage/json_store/encoders.py index bfd6157f129..ed17c193c63 100644 --- a/ax/storage/json_store/encoders.py +++ b/ax/storage/json_store/encoders.py @@ -411,6 +411,7 @@ def generation_node_to_dict(generation_node: GenerationNode) -> dict[str, Any]: "generator_spec_to_gen_from": generation_node._generator_spec_to_gen_from, "previous_node_name": generation_node._previous_node_name, "trial_type": generation_node._trial_type, + "suggested_experiment_status": generation_node.suggested_experiment_status, # need to manually encode input constructors because the key is an enum. # Our encoding and decoding logic in object_to_json and object_from_json # doesn't recursively encode/decode the keys of dictionaries. diff --git a/ax/storage/json_store/registry.py b/ax/storage/json_store/registry.py index 6f5ab41aee9..25912543bd9 100644 --- a/ax/storage/json_store/registry.py +++ b/ax/storage/json_store/registry.py @@ -29,6 +29,7 @@ from ax.core.batch_trial import AbandonedArm, BatchTrial from ax.core.data import Data from ax.core.evaluations_to_data import DataType +from ax.core.experiment_status import ExperimentStatus from ax.core.generator_run import GeneratorRun from ax.core.llm_provider import LLMMessage from ax.core.map_metric import MapMetric @@ -316,6 +317,7 @@ "DerivedParameter": DerivedParameter, "DomainType": DomainType, "Experiment": Experiment, + "ExperimentStatus": ExperimentStatus, "FactorialMetric": FactorialMetric, "FilterFeatures": FilterFeatures, "FixedParameter": fixed_parameter_from_json, From 290f9bab16f6f3648c5051827b48cb72f1dca3ab Mon Sep 17 00:00:00 2001 From: Cristian Lara Date: Wed, 11 Feb 2026 14:56:36 -0800 Subject: [PATCH 2/6] Set default suggested_experiment_status for nodes (#4903) Summary: Set the `suggested_experiment_status` field on generation nodes throughout the codebase. This will be used in the next diffs in stack to allow the orchestrator to automatically update experiment status based on the optimization phase. **Why this change:** This builds on the infrastructure added in previous diffs which added the status field to Experiment, and GenerationNode. Now we're actually setting the field on nodes so that experiments will automatically transition through their lifecycle based on what GS node they're in. Reviewed By: lena-kashtelyan, mgarrard Differential Revision: D88214256 --- ax/api/utils/generation_strategy_dispatch.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/ax/api/utils/generation_strategy_dispatch.py b/ax/api/utils/generation_strategy_dispatch.py index d85811019ba..780fba7db56 100644 --- a/ax/api/utils/generation_strategy_dispatch.py +++ b/ax/api/utils/generation_strategy_dispatch.py @@ -12,6 +12,7 @@ import torch from ax.adapter.registry import Generators from ax.api.utils.structs import GenerationStrategyDispatchStruct +from ax.core.experiment_status import ExperimentStatus from ax.core.trial_status import TrialStatus from ax.exceptions.core import UnsupportedError, UserInputError from ax.generation_strategy.center_generation_node import CenterGenerationNode @@ -95,6 +96,7 @@ def _get_sobol_node( ], transition_criteria=transition_criteria, should_deduplicate=True, + suggested_experiment_status=ExperimentStatus.INITIALIZATION, ) @@ -175,6 +177,7 @@ def _get_mbm_node( ) ], should_deduplicate=True, + suggested_experiment_status=ExperimentStatus.OPTIMIZATION, ), mbm_name @@ -225,6 +228,7 @@ def choose_generation_strategy( generator_kwargs={"seed": struct.initialization_random_seed}, ) ], + suggested_experiment_status=ExperimentStatus.INITIALIZATION, ) ] gs_name = "QuasiRandomSearch" From 7db9ed1e8f56bbbe570266ca1c3c719e393c28ea Mon Sep 17 00:00:00 2001 From: Cristian Lara Date: Wed, 11 Feb 2026 14:56:36 -0800 Subject: [PATCH 3/6] Add suggested_experiment_status column to GeneratorRun (#4886) Summary: ## Summary Add `suggested_experiment_status` column to `GeneratorRun`. Some benefits: 1. We don't need to modify the GS.gen() or Orchestrator methods to pass along a suggested status via tuple, instead it's baked into the GeneratorRuns that are already being passed along 2. The suggested status are more clearly stored in the database for historical tracking Prior to this approach I tried changing `GS.gen()` to return a tuple including the `suggested_experiment_status` but that over-complicated callsites. ## AOSC DIFF D92476170 Differential Revision: D88091530 --- ax/core/generator_run.py | 8 ++++++ ax/core/tests/test_generator_run.py | 20 ++++++++++++++ ax/storage/json_store/encoders.py | 1 + ax/storage/sqa_store/decoder.py | 1 + ax/storage/sqa_store/encoder.py | 1 + ax/storage/sqa_store/sqa_classes.py | 3 ++ ax/storage/sqa_store/tests/test_sqa_store.py | 29 ++++++++++++++++++++ ax/utils/testing/core_stubs.py | 2 ++ 8 files changed, 65 insertions(+) diff --git a/ax/core/generator_run.py b/ax/core/generator_run.py index 2235e2c8e9a..e7048173088 100644 --- a/ax/core/generator_run.py +++ b/ax/core/generator_run.py @@ -18,6 +18,7 @@ import pandas as pd from ax.core.arm import Arm +from ax.core.experiment_status import ExperimentStatus from ax.core.optimization_config import OptimizationConfig from ax.core.search_space import SearchSpace from ax.core.types import ( @@ -100,6 +101,7 @@ def __init__( candidate_metadata_by_arm_signature: None | (dict[str, TCandidateMetadata]) = None, generation_node_name: str | None = None, + suggested_experiment_status: ExperimentStatus | None = None, ) -> None: """Inits GeneratorRun. @@ -142,6 +144,10 @@ def __init__( via a generation strategy (in which case this name should reflect the name of the generation node in a generation strategy) or a standalone generation node (in which case this name should be ``-1``). + suggested_experiment_status: Optional ``ExperimentStatus`` that indicates + what the experiment's status should be once this generator run is + added to a trial. This is propagated from the generation node's + suggested_experiment_status field and is advisory only. """ self._arm_weight_table: OrderedDict[str, ArmWeight] = OrderedDict() if weights is None: @@ -191,6 +197,7 @@ def __init__( ) self._candidate_metadata_by_arm_signature = candidate_metadata_by_arm_signature self._generation_node_name = generation_node_name + self.suggested_experiment_status = suggested_experiment_status @property def arms(self) -> list[Arm]: @@ -327,6 +334,7 @@ def clone(self) -> GeneratorRun: generator_state_after_gen=self._generator_state_after_gen, candidate_metadata_by_arm_signature=cand_metadata, generation_node_name=self._generation_node_name, + suggested_experiment_status=self.suggested_experiment_status, ) generator_run._time_created = self._time_created generator_run._generator_key = self._generator_key diff --git a/ax/core/tests/test_generator_run.py b/ax/core/tests/test_generator_run.py index 8a372be2984..9360c93293f 100644 --- a/ax/core/tests/test_generator_run.py +++ b/ax/core/tests/test_generator_run.py @@ -7,6 +7,7 @@ # pyre-strict from ax.core.arm import Arm +from ax.core.experiment_status import ExperimentStatus from ax.core.generator_run import GeneratorRun from ax.utils.common.testutils import TestCase from ax.utils.testing.core_stubs import ( @@ -48,6 +49,10 @@ def setUp(self) -> None: search_space=self.search_space, model_predictions=self.model_predictions, ) + self.run_with_suggested_status = GeneratorRun( + arms=self.arms, + suggested_experiment_status=ExperimentStatus.INITIALIZATION, + ) def test_Init(self) -> None: self.assertEqual( @@ -184,3 +189,18 @@ def test_Sortable(self) -> None: weights=self.weights, ) self.assertTrue(generator_run1 < generator_run2) + + def test_SuggestedExperimentStatus(self) -> None: + self.assertEqual( + self.run_with_suggested_status.suggested_experiment_status, + ExperimentStatus.INITIALIZATION, + ) + + def test_SuggestedExperimentStatusDefaultNone(self) -> None: + self.assertIsNone(self.unweighted_run.suggested_experiment_status) + + def test_ClonePreservesSuggestedExperimentStatus(self) -> None: + cloned = self.run_with_suggested_status.clone() + self.assertEqual( + cloned.suggested_experiment_status, ExperimentStatus.INITIALIZATION + ) diff --git a/ax/storage/json_store/encoders.py b/ax/storage/json_store/encoders.py index ed17c193c63..87b6b44c045 100644 --- a/ax/storage/json_store/encoders.py +++ b/ax/storage/json_store/encoders.py @@ -367,6 +367,7 @@ def generator_run_to_dict(generator_run: GeneratorRun) -> dict[str, Any]: "generator_state_after_gen": gr._generator_state_after_gen, "candidate_metadata_by_arm_signature": cand_metadata, "generation_node_name": gr._generation_node_name, + "suggested_experiment_status": gr.suggested_experiment_status, } diff --git a/ax/storage/sqa_store/decoder.py b/ax/storage/sqa_store/decoder.py index 612f440f716..63b83add043 100644 --- a/ax/storage/sqa_store/decoder.py +++ b/ax/storage/sqa_store/decoder.py @@ -813,6 +813,7 @@ def generator_run_from_sqa( class_decoder_registry=self.config.json_class_decoder_registry, ), generation_node_name=generator_run_sqa.generation_node_name, + suggested_experiment_status=generator_run_sqa.suggested_experiment_status, ) # Remove deprecated kwargs from generator kwargs & adapter kwargs. if generator_run._generator_kwargs is not None: diff --git a/ax/storage/sqa_store/encoder.py b/ax/storage/sqa_store/encoder.py index b96877a228f..87bf58cb88f 100644 --- a/ax/storage/sqa_store/encoder.py +++ b/ax/storage/sqa_store/encoder.py @@ -897,6 +897,7 @@ def generator_run_to_sqa( class_encoder_registry=self.config.json_class_encoder_registry, ), generation_node_name=generator_run._generation_node_name, + suggested_experiment_status=generator_run.suggested_experiment_status, ) return gr_sqa diff --git a/ax/storage/sqa_store/sqa_classes.py b/ax/storage/sqa_store/sqa_classes.py index dff59cb9e7a..1176634b9e2 100644 --- a/ax/storage/sqa_store/sqa_classes.py +++ b/ax/storage/sqa_store/sqa_classes.py @@ -215,6 +215,9 @@ class SQAGeneratorRun(Base): JSONEncodedTextDict ) generation_node_name: Column[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) + suggested_experiment_status: Column[ExperimentStatus | None] = Column( + IntEnum(ExperimentStatus), nullable=True + ) # relationships # Use selectin loading for collections to prevent idle timeout errors diff --git a/ax/storage/sqa_store/tests/test_sqa_store.py b/ax/storage/sqa_store/tests/test_sqa_store.py index 9d72ea425dc..2bc73a7ed88 100644 --- a/ax/storage/sqa_store/tests/test_sqa_store.py +++ b/ax/storage/sqa_store/tests/test_sqa_store.py @@ -2539,6 +2539,35 @@ def test_generator_run_gen_metadata(self) -> None: ) self.assertEqual(decoded_gr.gen_metadata, gen_metadata) + def test_generator_run_suggested_experiment_status(self) -> None: + # Test round-trip with a status set. + gr = GeneratorRun( + arms=[], + suggested_experiment_status=ExperimentStatus.OPTIMIZATION, + ) + generator_run_sqa = self.encoder.generator_run_to_sqa(gr) + self.assertEqual( + generator_run_sqa.suggested_experiment_status, + ExperimentStatus.OPTIMIZATION, + ) + decoded_gr = self.decoder.generator_run_from_sqa( + generator_run_sqa, False, False + ) + self.assertEqual( + decoded_gr.suggested_experiment_status, + ExperimentStatus.OPTIMIZATION, + ) + + def test_generator_run_suggested_experiment_status_none(self) -> None: + # Test round-trip with None (default). + gr = GeneratorRun(arms=[]) + generator_run_sqa = self.encoder.generator_run_to_sqa(gr) + self.assertIsNone(generator_run_sqa.suggested_experiment_status) + decoded_gr = self.decoder.generator_run_from_sqa( + generator_run_sqa, False, False + ) + self.assertIsNone(decoded_gr.suggested_experiment_status) + def test_update_generation_strategy_incrementally(self) -> None: experiment = get_branin_experiment() generation_strategy = choose_generation_strategy( diff --git a/ax/utils/testing/core_stubs.py b/ax/utils/testing/core_stubs.py index 62b36197304..9c17c956603 100644 --- a/ax/utils/testing/core_stubs.py +++ b/ax/utils/testing/core_stubs.py @@ -31,6 +31,7 @@ from ax.core.data import Data from ax.core.evaluations_to_data import raw_evaluations_to_data from ax.core.experiment import Experiment +from ax.core.experiment_status import ExperimentStatus from ax.core.generator_run import GeneratorRun from ax.core.map_metric import MapMetric from ax.core.metric import Metric @@ -2383,6 +2384,7 @@ def get_generator_run() -> GeneratorRun: candidate_metadata_by_arm_signature={ a.signature: {"md_key": f"md_val_{a.signature}"} for a in arms }, + suggested_experiment_status=ExperimentStatus.OPTIMIZATION, ) From 86cdbae3d4615e3fcb73698c23d2c59ef863010d Mon Sep 17 00:00:00 2001 From: Cristian Lara Date: Wed, 11 Feb 2026 14:56:36 -0800 Subject: [PATCH 4/6] Propagate suggested_experiment_status from GenerationNode to GeneratorRun (#4885) Summary: In the previous diff (D88091530) we added `suggested_experiment_status` the column to GeneratorRun, now we populate it during creation from GenerationNode. Differential Revision: D92555215 --- ax/generation_strategy/generation_node.py | 1 + .../tests/test_generation_node.py | 29 +++++++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/ax/generation_strategy/generation_node.py b/ax/generation_strategy/generation_node.py index 70b64d641d8..26cf4cb29a5 100644 --- a/ax/generation_strategy/generation_node.py +++ b/ax/generation_strategy/generation_node.py @@ -517,6 +517,7 @@ def gen( ) gr._generation_node_name = self.name + gr.suggested_experiment_status = self.suggested_experiment_status # TODO: When we start using `trial_type` more commonly, give it a dedicated # field on the `GeneratorRun` (or start creating trials from GS directly). if self._trial_type is not None: diff --git a/ax/generation_strategy/tests/test_generation_node.py b/ax/generation_strategy/tests/test_generation_node.py index 266873cf5ad..b5444b85e3e 100644 --- a/ax/generation_strategy/tests/test_generation_node.py +++ b/ax/generation_strategy/tests/test_generation_node.py @@ -58,6 +58,10 @@ def setUp(self) -> None: generator_specs=[self.sobol_generator_spec], suggested_experiment_status=ExperimentStatus.INITIALIZATION, ) + self.generation_node_without_state = GenerationNode( + name="test", + generator_specs=[self.sobol_generator_spec], + ) self.branin_experiment = get_branin_experiment(with_completed_trial=True) self.branin_data = self.branin_experiment.lookup_data() self.node_short = GenerationNode( @@ -209,6 +213,31 @@ def test_gen(self) -> None: fixed_features=None, ) + def test_suggested_experiment_status_propagation(self) -> None: + """Test that suggested_experiment_status propagates from node to GR.""" + with self.subTest("with_suggested_experiment_status"): + gr = self.sobol_generation_node.gen( + experiment=self.branin_experiment, + data=self.branin_experiment.lookup_data(), + n=1, + pending_observations={"branin": []}, + ) + self.assertIsNotNone(gr) + self.assertEqual( + gr.suggested_experiment_status, + ExperimentStatus.INITIALIZATION, + ) + + with self.subTest("without_suggested_experiment_status"): + gr_without = self.generation_node_without_state.gen( + experiment=self.branin_experiment, + data=self.branin_experiment.lookup_data(), + n=1, + pending_observations={"branin": []}, + ) + self.assertIsNotNone(gr_without) + self.assertIsNone(gr_without.suggested_experiment_status) + @mock_botorch_optimize def test_gen_with_trial_type(self) -> None: mbm_short = GenerationNode( From fd4c2623b892c7e55dae1d25f4e24a5075288d3e Mon Sep 17 00:00:00 2001 From: Cristian Lara Date: Wed, 11 Feb 2026 14:56:36 -0800 Subject: [PATCH 5/6] Method to consolidate Experiment.status from generator runs (#4900) Summary: Add a new static method `experiment_status_from_generator_runs()` to `GenerationStrategy` that extracts and validates a suggested ExperimentStatus from a list of GeneratorRun objects. It collects all unique suggested_experiment_status values from the runs and: - Returns None with a warning if there are conflicting statuses across runs - Returns None with an info log if no statuses are found - Returns the single agreed-upon status otherwise Differential Revision: D92985915 --- ax/generation_strategy/generation_strategy.py | 38 ++++++++++ .../tests/test_generation_strategy.py | 69 +++++++++++++++++++ 2 files changed, 107 insertions(+) diff --git a/ax/generation_strategy/generation_strategy.py b/ax/generation_strategy/generation_strategy.py index 2076605ec43..48ec33d895e 100644 --- a/ax/generation_strategy/generation_strategy.py +++ b/ax/generation_strategy/generation_strategy.py @@ -16,6 +16,7 @@ from ax.adapter.base import Adapter from ax.core.data import Data from ax.core.experiment import Experiment +from ax.core.experiment_status import ExperimentStatus from ax.core.generator_run import GeneratorRun from ax.core.observation import ObservationFeatures from ax.core.utils import extend_pending_observations, extract_pending_observations @@ -312,6 +313,43 @@ def gen( ) return grs_for_multiple_trials + @staticmethod + def experiment_status_from_generator_runs( + generator_runs: list[GeneratorRun], + ) -> ExperimentStatus | None: + """Extract and validate suggested experiment status from generator runs. + + Collects the suggested_experiment_status directly from the GeneratorRun + objects, validates that all runs suggest the same status, and returns + that status. + + Args: + generator_runs: List of generator runs to extract statuses from. + + Returns: + The suggested experiment status that all generator runs agree on, + or None if no statuses were found or if there are conflicting statuses. + """ + suggested_statuses: set[ExperimentStatus] = set() + for gr in generator_runs: + if gr.suggested_experiment_status is not None: + suggested_statuses.add(gr.suggested_experiment_status) + + if len(suggested_statuses) > 1: + logger.warning( + "Multiple different suggested experiment statuses found: " + f"{suggested_statuses}. " + "All generator runs used in a single gen() call should suggest the " + "same experiment status. Skipping updating experiment status." + ) + return None + + if len(suggested_statuses) == 0: + logger.info("No suggested_experiment_status found on any generator runs.") + return None + + return suggested_statuses.pop() + def current_generator_run_limit( self, ) -> tuple[int, bool]: diff --git a/ax/generation_strategy/tests/test_generation_strategy.py b/ax/generation_strategy/tests/test_generation_strategy.py index cb954845614..b3795a711df 100644 --- a/ax/generation_strategy/tests/test_generation_strategy.py +++ b/ax/generation_strategy/tests/test_generation_strategy.py @@ -24,6 +24,7 @@ from ax.adapter.torch import TorchAdapter from ax.core.arm import Arm from ax.core.experiment import Experiment +from ax.core.experiment_status import ExperimentStatus from ax.core.generator_run import GeneratorRun from ax.core.observation import ObservationFeatures from ax.core.parameter import ChoiceParameter, FixedParameter, Parameter, ParameterType @@ -2022,6 +2023,74 @@ def test_optimization_complete_single_node_no_criteria(self) -> None: self.assertFalse(gs.optimization_complete) + def test_experiment_status_from_generation_strategy(self) -> None: + """Test that experiment status is correctly propagated through + generator runs and extracted via experiment_status_from_generator_runs.""" + + with self.subTest("gen returns GRs with correct suggested_experiment_status"): + for status in [ + ExperimentStatus.INITIALIZATION, + ExperimentStatus.OPTIMIZATION, + ]: + with self.subTest(status=status): + exp = get_branin_experiment() + node_with_status = GenerationNode( + name="test_node", + generator_specs=[self.sobol_generator_spec], + suggested_experiment_status=status, + ) + gs = GenerationStrategy(nodes=[node_with_status]) + gs.experiment = exp + + grs = gs.gen(experiment=exp, num_trials=1) + flat_grs = [gr for trial_grs in grs for gr in trial_grs] + + extracted_status = ( + GenerationStrategy.experiment_status_from_generator_runs( + flat_grs + ) + ) + self.assertEqual(extracted_status, status) + + with self.subTest("conflicting statuses return None"): + gr1 = GeneratorRun( + arms=[Arm(name="0_0", parameters={"x1": 0.0, "x2": 0.0})], + suggested_experiment_status=ExperimentStatus.INITIALIZATION, + ) + gr2 = GeneratorRun( + arms=[Arm(name="0_1", parameters={"x1": 1.0, "x2": 1.0})], + suggested_experiment_status=ExperimentStatus.OPTIMIZATION, + ) + mixed_grs = [gr1, gr2] + + result = GenerationStrategy.experiment_status_from_generator_runs(mixed_grs) + self.assertIsNone(result) + + with self.subTest("multiple trials all carry experiment status"): + exp = get_branin_experiment() + node_with_status = GenerationNode( + name="multi_trial_node", + generator_specs=[self.sobol_generator_spec], + suggested_experiment_status=ExperimentStatus.INITIALIZATION, + ) + gs = GenerationStrategy(nodes=[node_with_status]) + gs.experiment = exp + + grs = gs.gen(experiment=exp, num_trials=3) + + self.assertEqual(len(grs), 3) + for gr_list in grs: + self.assertEqual(len(gr_list), 1) + self.assertEqual(gr_list[0]._generation_node_name, "multi_trial_node") + self.assertEqual( + gr_list[0].suggested_experiment_status, + ExperimentStatus.INITIALIZATION, + ) + extracted_status = GenerationStrategy.experiment_status_from_generator_runs( + [gr for trial_grs in grs for gr in trial_grs] + ) + self.assertEqual(extracted_status, ExperimentStatus.INITIALIZATION) + # ------------- Testing helpers (put tests above this line) ------------- def _run_GS_for_N_rounds( From 5b61d6e4bb907fc253d5b97f422f50404ee3a3e7 Mon Sep 17 00:00:00 2001 From: Cristian Lara Date: Wed, 11 Feb 2026 14:56:36 -0800 Subject: [PATCH 6/6] Scheduler updates experiment status from generation strategy (#4891) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: Connect the experiment lifecycle tracking system to the Ax Scheduler (Orchestrator), enabling automatic experiment status updates based on which optimization phase the generation strategy is in. Key Changes 🔧 1. Orchestrator (orchestrator.py): After generating new trials, extracts suggested experiment status from generator runs and updates experiment.status 2. DB Persistence (save.py + with_db_settings_base.py): Added update_experiment_status() function and wired it to save status changes to DB after trial generation 3. Tests: Added test_generate_candidates_updates_experiment_status() to verify the whole flow works Differential Revision: D87589267 --- ax/orchestration/orchestrator.py | 5 +++ ax/orchestration/tests/test_orchestrator.py | 43 +++++++++++++++++++ ax/storage/sqa_store/save.py | 29 +++++++++++++ ax/storage/sqa_store/with_db_settings_base.py | 40 +++++++++++++++++ 4 files changed, 117 insertions(+) diff --git a/ax/orchestration/orchestrator.py b/ax/orchestration/orchestrator.py index a35cfe8aa51..9c82a2175b0 100644 --- a/ax/orchestration/orchestrator.py +++ b/ax/orchestration/orchestrator.py @@ -540,6 +540,11 @@ def generate_candidates( if len(new_trials) > 0: new_generator_runs = [gr for t in new_trials for gr in t.generator_runs] + suggested_status = GenerationStrategy.experiment_status_from_generator_runs( + new_generator_runs + ) + if suggested_status is not None: + self.experiment.status = suggested_status self._save_or_update_trials_and_generation_strategy_if_possible( experiment=self.experiment, trials=new_trials + self.experiment.trials_by_status[TrialStatus.STALE], diff --git a/ax/orchestration/tests/test_orchestrator.py b/ax/orchestration/tests/test_orchestrator.py index 67f982ca1a1..fd858a30d0f 100644 --- a/ax/orchestration/tests/test_orchestrator.py +++ b/ax/orchestration/tests/test_orchestrator.py @@ -24,6 +24,7 @@ from ax.core.batch_trial import BatchTrial from ax.core.data import Data, MAP_KEY from ax.core.experiment import Experiment +from ax.core.experiment_status import ExperimentStatus from ax.core.generator_run import GeneratorRun from ax.core.metric import Metric from ax.core.multi_type_experiment import MultiTypeExperiment @@ -49,6 +50,7 @@ GenerationStep, GenerationStrategy, ) +from ax.generation_strategy.generator_spec import GeneratorSpec from ax.generation_strategy.transition_criterion import MaxGenerationParallelism from ax.metrics.branin import BraninMetric from ax.metrics.branin_map import BraninTimestampMapMetric @@ -2667,6 +2669,47 @@ def test_generate_candidates_works_for_iteration(self) -> None: len(candidate_trial.arms), none_throws(orchestrator.options.batch_size) ) + def test_generate_candidates_updates_experiment_status(self) -> None: + init_test_engine_and_session_factory(force_init=True) + node_with_status = GenerationNode( + name="test_node", + generator_specs=[ + GeneratorSpec( + generator_enum=Generators.SOBOL, + model_kwargs={}, + ) + ], + suggested_experiment_status=ExperimentStatus.INITIALIZATION, + ) + gs = GenerationStrategy(nodes=[node_with_status]) + + # Create orchestrator with this generation strategy + self.branin_experiment.runner = InfinitePollRunner() + orchestrator = Orchestrator( + experiment=self.branin_experiment, + generation_strategy=gs, + options=OrchestratorOptions( + init_seconds_between_polls=0, + batch_size=1, + trial_type=TrialType.BATCH_TRIAL, + **self.orchestrator_options_kwargs, + ), + db_settings=self.db_settings, + ) + + # Verify the experiment status is not currently ExperimentStatus.INITIALIZATION + self.assertNotEqual( + orchestrator.experiment.status, ExperimentStatus.INITIALIZATION + ) + + # Execute: generate candidates + orchestrator.generate_candidates(num_trials=1) + + # Assert: verify experiment status was updated + self.assertEqual( + orchestrator.experiment.status, ExperimentStatus.INITIALIZATION + ) + def test_generate_candidates_does_not_generate_if_missing_data(self) -> None: # GIVEN a orchestrator that can't fetch data self.branin_experiment.optimization_config = OptimizationConfig( diff --git a/ax/storage/sqa_store/save.py b/ax/storage/sqa_store/save.py index cd3f6ba2eb3..9866be7688c 100644 --- a/ax/storage/sqa_store/save.py +++ b/ax/storage/sqa_store/save.py @@ -531,6 +531,35 @@ def update_properties_on_experiment( ) +def update_experiment_status( + experiment: Experiment, + config: SQAConfig | None = None, +) -> None: + """Update experiment status in the database. + + This function provides an efficient way to update only the experiment's status + field without re-saving the entire experiment. Use this when you need to persist + status changes immediately after calling status transition methods + (e.g., mark_initialization(), mark_optimization()). + + Note: save_experiment() already handles status updates, so this function is + optional. Use it when you need status-only updates for efficiency. + """ + config = SQAConfig() if config is None else config + exp_sqa_class = config.class_to_sqa_class[Experiment] + + exp_id = experiment.db_id + if exp_id is None: + raise UserInputError("Experiment must be saved before being updated.") + + with session_scope() as session: + session.query(exp_sqa_class).filter_by(id=exp_id).update( + { + "status": experiment.status, + } + ) + + def update_properties_on_trial( trial_with_updated_properties: BaseTrial, config: SQAConfig | None = None, diff --git a/ax/storage/sqa_store/with_db_settings_base.py b/ax/storage/sqa_store/with_db_settings_base.py index 0defee1c634..826ae469fe7 100644 --- a/ax/storage/sqa_store/with_db_settings_base.py +++ b/ax/storage/sqa_store/with_db_settings_base.py @@ -65,6 +65,7 @@ _save_or_update_trials, _update_generation_strategy, save_analysis_card, + update_experiment_status, update_properties_on_experiment, update_runner_on_experiment, ) @@ -325,6 +326,8 @@ def _save_or_update_trials_and_generation_strategy_if_possible( new_generator_runs=new_generator_runs, reduce_state_generator_runs=reduce_state_generator_runs, ) + if experiment.status is not None: + self._update_experiment_status_in_db_if_possible(experiment) return # No retries needed, covered in `self._save_or_update_trials_in_db_if_possible` @@ -467,6 +470,27 @@ def _update_experiment_properties_in_db( return True return False + def _update_experiment_status_in_db_if_possible( + self, + experiment: Experiment, + ) -> bool: + """Update experiment status in the database if DB settings are configured. + + Args: + experiment: Experiment with updated status. + + Returns: + True if the update was performed, False if DB settings are not configured. + """ + if self.db_settings_set: + _update_experiment_status_in_db( + experiment=experiment, + sqa_config=self.db_settings.encoder.config, + suppress_all_errors=self._suppress_all_errors, + ) + return True + return False + def _save_analysis_card_to_db_if_possible( self, experiment: Experiment, @@ -624,6 +648,22 @@ def _update_experiment_properties_in_db( ) +@retry_on_exception( + retries=3, + default_return_on_suppression=False, + exception_types=RETRY_EXCEPTION_TYPES, +) +def _update_experiment_status_in_db( + experiment: Experiment, + sqa_config: SQAConfig, + suppress_all_errors: bool, # Used by the decorator. +) -> None: + update_experiment_status( + experiment=experiment, + config=sqa_config, + ) + + @retry_on_exception( retries=3, default_return_on_suppression=False,