Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions ax/api/utils/generation_strategy_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import torch
from ax.adapter.registry import Generators
from ax.api.utils.structs import GenerationStrategyDispatchStruct
from ax.core.experiment_status import ExperimentStatus
from ax.core.trial_status import TrialStatus
from ax.exceptions.core import UnsupportedError, UserInputError
from ax.generation_strategy.center_generation_node import CenterGenerationNode
Expand Down Expand Up @@ -95,6 +96,7 @@ def _get_sobol_node(
],
transition_criteria=transition_criteria,
should_deduplicate=True,
suggested_experiment_status=ExperimentStatus.INITIALIZATION,
)


Expand Down Expand Up @@ -175,6 +177,7 @@ def _get_mbm_node(
)
],
should_deduplicate=True,
suggested_experiment_status=ExperimentStatus.OPTIMIZATION,
), mbm_name


Expand Down Expand Up @@ -225,6 +228,7 @@ def choose_generation_strategy(
generator_kwargs={"seed": struct.initialization_random_seed},
)
],
suggested_experiment_status=ExperimentStatus.INITIALIZATION,
)
]
gs_name = "QuasiRandomSearch"
Expand Down
8 changes: 8 additions & 0 deletions ax/core/generator_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import pandas as pd
from ax.core.arm import Arm
from ax.core.experiment_status import ExperimentStatus
from ax.core.optimization_config import OptimizationConfig
from ax.core.search_space import SearchSpace
from ax.core.types import (
Expand Down Expand Up @@ -100,6 +101,7 @@ def __init__(
candidate_metadata_by_arm_signature: None
| (dict[str, TCandidateMetadata]) = None,
generation_node_name: str | None = None,
suggested_experiment_status: ExperimentStatus | None = None,
) -> None:
"""Inits GeneratorRun.

Expand Down Expand Up @@ -142,6 +144,10 @@ def __init__(
via a generation strategy (in which case this name should reflect the
name of the generation node in a generation strategy) or a standalone
generation node (in which case this name should be ``-1``).
suggested_experiment_status: Optional ``ExperimentStatus`` that indicates
what the experiment's status should be once this generator run is
added to a trial. This is propagated from the generation node's
suggested_experiment_status field and is advisory only.
"""
self._arm_weight_table: OrderedDict[str, ArmWeight] = OrderedDict()
if weights is None:
Expand Down Expand Up @@ -191,6 +197,7 @@ def __init__(
)
self._candidate_metadata_by_arm_signature = candidate_metadata_by_arm_signature
self._generation_node_name = generation_node_name
self.suggested_experiment_status = suggested_experiment_status

@property
def arms(self) -> list[Arm]:
Expand Down Expand Up @@ -327,6 +334,7 @@ def clone(self) -> GeneratorRun:
generator_state_after_gen=self._generator_state_after_gen,
candidate_metadata_by_arm_signature=cand_metadata,
generation_node_name=self._generation_node_name,
suggested_experiment_status=self.suggested_experiment_status,
)
generator_run._time_created = self._time_created
generator_run._generator_key = self._generator_key
Expand Down
20 changes: 20 additions & 0 deletions ax/core/tests/test_generator_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
# pyre-strict

from ax.core.arm import Arm
from ax.core.experiment_status import ExperimentStatus
from ax.core.generator_run import GeneratorRun
from ax.utils.common.testutils import TestCase
from ax.utils.testing.core_stubs import (
Expand Down Expand Up @@ -48,6 +49,10 @@ def setUp(self) -> None:
search_space=self.search_space,
model_predictions=self.model_predictions,
)
self.run_with_suggested_status = GeneratorRun(
arms=self.arms,
suggested_experiment_status=ExperimentStatus.INITIALIZATION,
)

def test_Init(self) -> None:
self.assertEqual(
Expand Down Expand Up @@ -184,3 +189,18 @@ def test_Sortable(self) -> None:
weights=self.weights,
)
self.assertTrue(generator_run1 < generator_run2)

def test_SuggestedExperimentStatus(self) -> None:
self.assertEqual(
self.run_with_suggested_status.suggested_experiment_status,
ExperimentStatus.INITIALIZATION,
)

def test_SuggestedExperimentStatusDefaultNone(self) -> None:
self.assertIsNone(self.unweighted_run.suggested_experiment_status)

def test_ClonePreservesSuggestedExperimentStatus(self) -> None:
cloned = self.run_with_suggested_status.clone()
self.assertEqual(
cloned.suggested_experiment_status, ExperimentStatus.INITIALIZATION
)
15 changes: 14 additions & 1 deletion ax/generation_strategy/center_generation_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from ax.core.arm import Arm
from ax.core.data import Data
from ax.core.experiment import Experiment
from ax.core.experiment_status import ExperimentStatus
from ax.core.generator_run import GeneratorRun
from ax.core.observation import ObservationFeatures
from ax.core.parameter import DerivedParameter
Expand All @@ -29,17 +30,29 @@
class CenterGenerationNode(ExternalGenerationNode):
next_node_name: str

def __init__(self, next_node_name: str) -> None:
def __init__(
self,
next_node_name: str,
suggested_experiment_status: ExperimentStatus
| None = ExperimentStatus.INITIALIZATION,
) -> None:
"""A generation node that samples the center of the search space.
This generation node is only used to generate the first point of the experiment.
After one point is generated, it will transition to `next_node_name`.

If the generated point is a duplicate of an arm already attached to the
experiment, this will fallback to Sobol through the use of ``GenerationNode``
deduplication logic.

Args:
next_node_name: The name of the node to transition to after generating
the center point.
suggested_experiment_status: Optional suggested experiment status for this
node.
"""
super().__init__(
name="CenterOfSearchSpace",
suggested_experiment_status=suggested_experiment_status,
transition_criteria=[
AutoTransitionAfterGen(
transition_to=next_node_name,
Expand Down
5 changes: 5 additions & 0 deletions ax/generation_strategy/external_generation_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from ax.core.arm import Arm
from ax.core.data import Data
from ax.core.experiment import Experiment
from ax.core.experiment_status import ExperimentStatus
from ax.core.generator_run import GeneratorRun
from ax.core.observation import ObservationFeatures
from ax.core.types import TParameterization
Expand Down Expand Up @@ -48,6 +49,7 @@ class ExternalGenerationNode(GenerationNode, ABC):
def __init__(
self,
name: str,
suggested_experiment_status: ExperimentStatus | None = None,
should_deduplicate: bool = True,
transition_criteria: Sequence[TransitionCriterion] | None = None,
) -> None:
Expand All @@ -59,6 +61,8 @@ def __init__(

Args:
name: Name of the generation node.
suggested_experiment_status: Optional suggested experiment status for this
node. Defaults to None if not specified.
should_deduplicate: Whether to deduplicate the generated points against
the existing trials on the experiment. If True, the duplicate points
will be discarded and re-generated up to 5 times, after which a
Expand All @@ -73,6 +77,7 @@ def __init__(
super().__init__(
name=name,
generator_specs=[],
suggested_experiment_status=suggested_experiment_status,
best_model_selector=None,
should_deduplicate=should_deduplicate,
transition_criteria=transition_criteria,
Expand Down
16 changes: 16 additions & 0 deletions ax/generation_strategy/generation_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from ax.core.data import Data
from ax.core.experiment import Experiment
from ax.core.experiment_status import ExperimentStatus
from ax.core.generator_run import GeneratorRun
from ax.core.observation import ObservationFeatures
from ax.core.trial_status import TrialStatus
Expand Down Expand Up @@ -112,6 +113,10 @@ class GenerationNode(SerializationMixin, SortableBase):
store the most recent previous ``GenerationNode`` name.
should_skip: Whether to skip this node during generation time. Defaults to
False, and can only currently be set to True via ``NodeInputConstructors``
suggested_experiment_status: Optional ``ExperimentStatus`` that indicates
what the experiment's status should be once the experiment adds trials
using ``GeneratorRun``-s produced from this node. This is advisory only
and does not automatically update the experiment's status.
fallback_specs: Optional dict mapping expected exception types to `ModelSpec`
fallbacks used when gen fails.

Expand All @@ -134,6 +139,7 @@ class GenerationNode(SerializationMixin, SortableBase):
_previous_node_name: str | None = None
_trial_type: str | None = None
_should_skip: bool = False
suggested_experiment_status: ExperimentStatus | None = None
fallback_specs: dict[type[Exception], GeneratorSpec]

# [TODO] Handle experiment passing more eloquently by enforcing experiment
Expand All @@ -155,6 +161,7 @@ def __init__(
previous_node_name: str | None = None,
trial_type: str | None = None,
should_skip: bool = False,
suggested_experiment_status: ExperimentStatus | None = None,
fallback_specs: dict[type[Exception], GeneratorSpec] | None = None,
) -> None:
self._name = name
Expand Down Expand Up @@ -187,6 +194,7 @@ def __init__(
self._previous_node_name = previous_node_name
self._trial_type = trial_type
self._should_skip = should_skip
self.suggested_experiment_status = suggested_experiment_status
self.fallback_specs = (
fallback_specs if fallback_specs is not None else DEFAULT_FALLBACK
)
Expand Down Expand Up @@ -354,6 +362,10 @@ def __repr__(self) -> str:
str_rep += (
f", transition_criteria={str(self._brief_transition_criteria_repr())}"
)
if self.suggested_experiment_status is not None:
str_rep += (
f", suggested_experiment_status={self.suggested_experiment_status!r}"
)
return f"{str_rep})"

def _fit(
Expand Down Expand Up @@ -505,6 +517,7 @@ def gen(
)

gr._generation_node_name = self.name
gr.suggested_experiment_status = self.suggested_experiment_status
# TODO: When we start using `trial_type` more commonly, give it a dedicated
# field on the `GeneratorRun` (or start creating trials from GS directly).
if self._trial_type is not None:
Expand Down Expand Up @@ -999,6 +1012,7 @@ class GenerationStep:
whether to transition to the next step. If False, `num_trials` and
`min_trials_observed` will only count trials generatd by this step. If True,
they will count all trials in the experiment (of corresponding statuses).
suggested_experiment_status: The suggested experiment status for this step.

Note for developers: by "generator" here we really mean an ``Adapter`` object, which
contains a ``Generator`` under the hood. We call it "generator" here to simplify and
Expand All @@ -1019,6 +1033,7 @@ def __new__(
use_all_trials_in_exp: bool = False,
use_update: bool = False, # DEPRECATED.
index: int = -1, # Index of this step, set internally.
suggested_experiment_status: ExperimentStatus | None = None,
# Deprecated arguments for backwards compatibility.
model_kwargs: dict[str, Any] | None = None,
model_gen_kwargs: dict[str, Any] | None = None,
Expand Down Expand Up @@ -1135,6 +1150,7 @@ def __new__(
step_index=index, generator_name=resolved_generator_name
),
generator_specs=[generator_spec],
suggested_experiment_status=suggested_experiment_status,
should_deduplicate=should_deduplicate,
transition_criteria=transition_criteria,
)
Expand Down
63 changes: 61 additions & 2 deletions ax/generation_strategy/tests/test_generation_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import torch
from ax.adapter.factory import get_sobol
from ax.adapter.registry import Generators
from ax.core.experiment_status import ExperimentStatus
from ax.core.observation import ObservationFeatures
from ax.core.trial_status import TrialStatus
from ax.exceptions.core import UserInputError
Expand Down Expand Up @@ -53,13 +54,20 @@ def setUp(self) -> None:
generator_gen_kwargs={},
)
self.sobol_generation_node = GenerationNode(
name="test", generator_specs=[self.sobol_generator_spec]
name="test",
generator_specs=[self.sobol_generator_spec],
suggested_experiment_status=ExperimentStatus.INITIALIZATION,
)
self.generation_node_without_state = GenerationNode(
name="test",
generator_specs=[self.sobol_generator_spec],
)
self.branin_experiment = get_branin_experiment(with_completed_trial=True)
self.branin_data = self.branin_experiment.lookup_data()
self.node_short = GenerationNode(
name="test",
generator_specs=[self.sobol_generator_spec],
suggested_experiment_status=ExperimentStatus.INITIALIZATION,
trial_type=Keys.SHORT_RUN,
)

Expand Down Expand Up @@ -97,6 +105,30 @@ def test_init(self) -> None:
self.assertEqual(node.generator_specs, mbm_specs)
self.assertIs(node.best_model_selector, model_selector)

def test_suggested_experiment_status(self) -> None:
"""Test that suggested_experiment_status is properly set and accessible."""
with self.subTest("initialization set"):
self.assertEqual(
self.sobol_generation_node.suggested_experiment_status,
ExperimentStatus.INITIALIZATION,
)

with self.subTest("default None when not provided"):
node_without_state = GenerationNode(
name="test",
generator_specs=[self.sobol_generator_spec],
)
self.assertIsNone(node_without_state.suggested_experiment_status)

with self.subTest("__repr__ includes status when set"):
repr_str = repr(self.sobol_generation_node)
self.assertIn("suggested_experiment_status", repr_str)
self.assertIn("INITIALIZATION", repr_str)

with self.subTest("__repr__ excludes status when None"):
repr_str_without = repr(node_without_state)
self.assertNotIn("suggested_experiment_status", repr_str_without)

def test_input_constructor_none(self) -> None:
self.assertEqual(self.sobol_generation_node._input_constructors, {})
self.assertEqual(self.sobol_generation_node.input_constructors, {})
Expand Down Expand Up @@ -181,6 +213,31 @@ def test_gen(self) -> None:
fixed_features=None,
)

def test_suggested_experiment_status_propagation(self) -> None:
"""Test that suggested_experiment_status propagates from node to GR."""
with self.subTest("with_suggested_experiment_status"):
gr = self.sobol_generation_node.gen(
experiment=self.branin_experiment,
data=self.branin_experiment.lookup_data(),
n=1,
pending_observations={"branin": []},
)
self.assertIsNotNone(gr)
self.assertEqual(
gr.suggested_experiment_status,
ExperimentStatus.INITIALIZATION,
)

with self.subTest("without_suggested_experiment_status"):
gr_without = self.generation_node_without_state.gen(
experiment=self.branin_experiment,
data=self.branin_experiment.lookup_data(),
n=1,
pending_observations={"branin": []},
)
self.assertIsNotNone(gr_without)
self.assertIsNone(gr_without.suggested_experiment_status)

@mock_botorch_optimize
def test_gen_with_trial_type(self) -> None:
mbm_short = GenerationNode(
Expand Down Expand Up @@ -320,6 +377,7 @@ def test_node_string_representation(self) -> None:
generator_specs=[
self.mbm_generator_spec,
],
suggested_experiment_status=ExperimentStatus.OPTIMIZATION,
transition_criteria=[
MinTrials(
threshold=5,
Expand All @@ -335,7 +393,8 @@ def test_node_string_representation(self) -> None:
"GenerationNode(name='test', "
"generator_specs=[GeneratorSpec(generator_enum=BoTorch, "
"generator_key_override=None)], "
"transition_criteria=[MinTrials(transition_to='next_node')])",
"transition_criteria=[MinTrials(transition_to='next_node')], "
"suggested_experiment_status=ExperimentStatus.OPTIMIZATION)",
)

def test_single_fixed_features(self) -> None:
Expand Down
9 changes: 9 additions & 0 deletions ax/storage/json_store/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -889,6 +889,15 @@ def generation_node_from_json(
if "trial_type" in generation_node_json.keys()
else None
),
suggested_experiment_status=(
object_from_json(
object_json=generation_node_json.pop("suggested_experiment_status"),
decoder_registry=decoder_registry,
class_decoder_registry=class_decoder_registry,
)
if "suggested_experiment_status" in generation_node_json.keys()
else None # Default for old records without the field
),
)


Expand Down
Loading