From f4d4474e8be96a989d75f50df153cab21262d063 Mon Sep 17 00:00:00 2001 From: Miles Olson Date: Tue, 10 Feb 2026 13:24:13 -0800 Subject: [PATCH 1/2] Upstream MultiTypeExperiment features into Experiment (#4873) Summary: These changes will enable us to deprecate multitypeexperment, simplifying the Ax data model ahead of storage changes. 1. In Experiment make the default_trial_type a new Key.DEFAULT_TRIAL_TYPE value instead of None 2. Move over logic for bookkeeping metric -> trial_type and runner -> trial_type mappings 3. Treat LONG_ and SHORT_RUN trial types as special cases which map to DEFAULT_TRIAL_TYPE (i.e. if a Trial has trial_type=LONG_RUN then use whichever metrics and runners are mapped to DEFAULT_TRIAL_TYPE 4. Fix tests which expect the default_trial_type of an Experiment to be None This diff allows us to remove all isinstance(foo, MultiTypeExperiment) checks in Ax in the next diff, then to deprecate MultiTypeExperiment entirely. Differential Revision: D91618283 --- ax/core/base_trial.py | 9 +- ax/core/experiment.py | 237 +++++++++++++++++++++++++----- ax/core/tests/test_batch_trial.py | 3 +- ax/core/tests/test_experiment.py | 7 +- ax/core/tests/test_observation.py | 10 ++ ax/core/tests/test_trial.py | 4 +- ax/orchestration/orchestrator.py | 28 ++-- ax/storage/json_store/decoder.py | 14 +- ax/storage/json_store/decoders.py | 9 +- ax/storage/sqa_store/decoder.py | 31 +++- ax/storage/sqa_store/encoder.py | 7 +- ax/utils/common/constants.py | 1 + 12 files changed, 292 insertions(+), 68 deletions(-) diff --git a/ax/core/base_trial.py b/ax/core/base_trial.py index ee201603067..77b8160cfae 100644 --- a/ax/core/base_trial.py +++ b/ax/core/base_trial.py @@ -98,16 +98,15 @@ def __init__( self._ttl_seconds: int | None = ttl_seconds self._index: int = self._experiment._attach_trial(self, index=index) - trial_type = ( + self._trial_type: str = ( trial_type if trial_type is not None else self._experiment.default_trial_type ) - if not self._experiment.supports_trial_type(trial_type): + if not self._experiment.supports_trial_type(self._trial_type): raise ValueError( - f"Trial type {trial_type} is not supported by the experiment." + f"Trial type {self._trial_type} is not supported by the experiment." ) - self._trial_type = trial_type self.__status: TrialStatus | None = None # Uses `_status` setter, which updates trial statuses to trial indices @@ -285,7 +284,7 @@ def stop_metadata(self) -> dict[str, Any]: return self._stop_metadata @property - def trial_type(self) -> str | None: + def trial_type(self) -> str: """The type of the trial. Relevant for experiments containing different kinds of trials diff --git a/ax/core/experiment.py b/ax/core/experiment.py index 071fe83f802..39b466faf02 100644 --- a/ax/core/experiment.py +++ b/ax/core/experiment.py @@ -100,7 +100,7 @@ def __init__( default_data_type: Any = None, auxiliary_experiments_by_purpose: None | (dict[AuxiliaryExperimentPurpose, list[AuxiliaryExperiment]]) = None, - default_trial_type: str | None = None, + default_trial_type: str = Keys.DEFAULT_TRIAL_TYPE.value, ) -> None: """Inits Experiment. @@ -123,6 +123,8 @@ def __init__( default_data_type: Deprecated and ignored. auxiliary_experiments_by_purpose: Dictionary of auxiliary experiments for different purposes (e.g., transfer learning). + default_trial_type: Default trial type for trials on this experiment. + Defaults to Keys.DEFAULT_TRIAL_TYPE. """ if default_data_type is not None: warnings.warn( @@ -150,10 +152,16 @@ def __init__( self._properties: dict[str, Any] = properties or {} # Initialize trial type to runner mapping - self._default_trial_type = default_trial_type - self._trial_type_to_runner: dict[str | None, Runner | None] = { - default_trial_type: runner + self._default_trial_type: str = ( + default_trial_type or Keys.DEFAULT_TRIAL_TYPE.value + ) + self._trial_type_to_runner: dict[str, Runner | None] = { + self._default_trial_type: runner } + + # Maps metric names to their trial types. Every metric must have an entry. + self._metric_to_trial_type: dict[str, str] = {} + # Used to keep track of whether any trials on the experiment # specify a TTL. Since trials need to be checked for their TTL's # expiration often, having this attribute helps avoid unnecessary @@ -413,16 +421,46 @@ def runner(self) -> Runner | None: def runner(self, runner: Runner | None) -> None: """Set the default runner and update trial type mapping.""" self._runner = runner - if runner is not None: - self._trial_type_to_runner[self._default_trial_type] = runner - else: - self._trial_type_to_runner = {None: None} + self._trial_type_to_runner[self._default_trial_type] = runner @runner.deleter def runner(self) -> None: """Delete the runner.""" self._runner = None - self._trial_type_to_runner = {None: None} + self._trial_type_to_runner[self._default_trial_type] = None + + def add_trial_type(self, trial_type: str, runner: Runner) -> "Experiment": + """Add a new trial_type to be supported by this experiment. + + Args: + trial_type: The new trial_type to be added. + runner: The default runner for trials of this type. + + Returns: + The experiment with the new trial type added. + """ + if self.supports_trial_type(trial_type): + raise ValueError(f"Experiment already contains trial_type `{trial_type}`") + + self._trial_type_to_runner[trial_type] = runner + return self + + def update_runner(self, trial_type: str, runner: Runner) -> "Experiment": + """Update the default runner for an existing trial_type. + + Args: + trial_type: The trial_type to update. + runner: The new runner for trials of this type. + + Returns: + The experiment with the updated runner. + """ + if not self.supports_trial_type(trial_type): + raise ValueError(f"Experiment does not contain trial_type `{trial_type}`") + + self._trial_type_to_runner[trial_type] = runner + self._runner = runner + return self @property def parameters(self) -> dict[str, Parameter]: @@ -489,13 +527,25 @@ def optimization_config(self, optimization_config: OptimizationConfig) -> None: f"`{Keys.IMMUTABLE_SEARCH_SPACE_AND_OPT_CONF.value}` " "property that is set to `True` on this experiment." ) + + # Remove old OC metrics from trial type mapping + prev_optimization_config = self._optimization_config + if prev_optimization_config is not None: + for metric_name in prev_optimization_config.metrics.keys(): + self._metric_to_trial_type.pop(metric_name, None) + for metric_name in optimization_config.metrics.keys(): if metric_name in self._tracking_metrics: self.remove_tracking_metric(metric_name) + # add metrics from the previous optimization config that are not in the new # optimization config as tracking metrics - prev_optimization_config = self._optimization_config self._optimization_config = optimization_config + + # Map new OC metrics to default trial type + for metric_name in optimization_config.metrics.keys(): + self._metric_to_trial_type[metric_name] = self._default_trial_type + if prev_optimization_config is not None: metrics_to_track = ( set(prev_optimization_config.metrics.keys()) @@ -505,6 +555,16 @@ def optimization_config(self, optimization_config: OptimizationConfig) -> None: for metric_name in metrics_to_track: self.add_tracking_metric(prev_optimization_config.metrics[metric_name]) + # Clean up any stale entries in _metric_to_trial_type that don't correspond + # to actual metrics (can happen when same optimization_config object is + # mutated and reassigned). + current_metric_names = set(self.metrics.keys()) + stale_metric_names = ( + set(self._metric_to_trial_type.keys()) - current_metric_names + ) + for metric_name in stale_metric_names: + self._metric_to_trial_type.pop(metric_name, None) + @property def is_moo_problem(self) -> bool: """Whether the experiment's optimization config contains multiple objectives.""" @@ -553,12 +613,25 @@ def immutable_search_space_and_opt_config(self) -> bool: def tracking_metrics(self) -> list[Metric]: return list(self._tracking_metrics.values()) - def add_tracking_metric(self, metric: Metric) -> Self: + def add_tracking_metric( + self, + metric: Metric, + trial_type: str | None = None, + ) -> Self: """Add a new metric to the experiment. Args: metric: Metric to be added. + trial_type: The trial type for which this metric is used. If not + provided, defaults to the experiment's default trial type. """ + effective_trial_type = ( + trial_type if trial_type is not None else self._default_trial_type + ) + + if not self.supports_trial_type(effective_trial_type): + raise ValueError(f"`{effective_trial_type}` is not a supported trial type.") + if metric.name in self._tracking_metrics: raise ValueError( f"Metric `{metric.name}` already defined on experiment. " @@ -574,9 +647,14 @@ def add_tracking_metric(self, metric: Metric) -> Self: ) self._tracking_metrics[metric.name] = metric + self._metric_to_trial_type[metric.name] = effective_trial_type return self - def add_tracking_metrics(self, metrics: list[Metric]) -> Self: + def add_tracking_metrics( + self, + metrics: list[Metric], + metrics_to_trial_types: dict[str, str] | None = None, + ) -> Self: """Add a list of new metrics to the experiment. If any of the metrics are already defined on the experiment, @@ -584,23 +662,58 @@ def add_tracking_metrics(self, metrics: list[Metric]) -> Self: Args: metrics: Metrics to be added. + metrics_to_trial_types: Optional mapping from metric names to + corresponding trial types. If not provided for a metric, + the experiment's default trial type is used. """ - # Before setting any metrics, we validate none are already on - # the experiment + metrics_to_trial_types = metrics_to_trial_types or {} for metric in metrics: - self.add_tracking_metric(metric) + self.add_tracking_metric( + metric=metric, + trial_type=metrics_to_trial_types.get(metric.name), + ) return self - def update_tracking_metric(self, metric: Metric) -> Self: + def update_tracking_metric( + self, + metric: Metric, + trial_type: str | None = None, + ) -> Self: """Redefine a metric that already exists on the experiment. Args: metric: New metric definition. + trial_type: The trial type for which this metric is used. If not + provided, keeps the existing trial type mapping. """ if metric.name not in self._tracking_metrics: raise ValueError(f"Metric `{metric.name}` doesn't exist on experiment.") + # Validate trial type if provided + effective_trial_type = ( + trial_type + if trial_type is not None + else self._metric_to_trial_type.get(metric.name, self._default_trial_type) + ) + + # Check that optimization config metrics stay on default trial type + oc = self.optimization_config + oc_metrics = oc.metrics if oc else {} + if ( + metric.name in oc_metrics + and effective_trial_type != self._default_trial_type + ): + raise ValueError( + f"Metric `{metric.name}` must remain a " + f"`{self._default_trial_type}` metric because it is part of the " + "optimization_config." + ) + + if not self.supports_trial_type(effective_trial_type): + raise ValueError(f"`{effective_trial_type}` is not a supported trial type.") + self._tracking_metrics[metric.name] = metric + self._metric_to_trial_type[metric.name] = effective_trial_type return self def remove_tracking_metric(self, metric_name: str) -> Self: @@ -613,6 +726,7 @@ def remove_tracking_metric(self, metric_name: str) -> Self: raise ValueError(f"Metric `{metric_name}` doesn't exist on experiment.") del self._tracking_metrics[metric_name] + self._metric_to_trial_type.pop(metric_name, None) return self @property @@ -852,8 +966,21 @@ def _fetch_trial_data( ) -> dict[str, MetricFetchResult]: trial = self.trials[trial_index] + # If metrics are not provided, fetch all metrics on the experiment for the + # relevant trial type, or the default trial type as a fallback. Otherwise, + # fetch provided metrics. + if metrics is None: + resolved_metrics = [ + metric + for metric in list(self.metrics.values()) + if self._metric_to_trial_type.get(metric.name, self._default_trial_type) + == trial.trial_type + ] + else: + resolved_metrics = metrics + trial_data = self._lookup_or_fetch_trials_results( - trials=[trial], metrics=metrics, **kwargs + trials=[trial], metrics=resolved_metrics, **kwargs ) if trial_index in trial_data: @@ -1548,39 +1675,79 @@ def __repr__(self) -> str: # overridden in the MultiTypeExperiment class. @property - def default_trial_type(self) -> str | None: - """Default trial type assigned to trials in this experiment. - - In the base experiment class this is always None. For experiments - with multiple trial types, use the MultiTypeExperiment class. - """ + def default_trial_type(self) -> str: + """Default trial type assigned to trials in this experiment.""" return self._default_trial_type - def runner_for_trial_type(self, trial_type: str | None) -> Runner | None: + def runner_for_trial_type(self, trial_type: str) -> Runner | None: """The default runner to use for a given trial type. Looks up the appropriate runner for this trial type in the trial_type_to_runner. """ + # Special case for LONG_ and SHORT_RUN trial types, which we treat as "default" + # trial types for deployment. + if ( + trial_type == Keys.SHORT_RUN or trial_type == Keys.LONG_RUN + ) and self.supports_trial_type(trial_type=Keys.DEFAULT_TRIAL_TYPE): + return self._trial_type_to_runner[Keys.DEFAULT_TRIAL_TYPE] + if not self.supports_trial_type(trial_type): raise ValueError(f"Trial type `{trial_type}` is not supported.") if (runner := self._trial_type_to_runner.get(trial_type)) is None: return self.runner # return the default runner return runner - def supports_trial_type(self, trial_type: str | None) -> bool: + def supports_trial_type(self, trial_type: str) -> bool: """Whether this experiment allows trials of the given type. - The base experiment class only supports None. For experiments - with multiple trial types, use the MultiTypeExperiment class. + Checks if the trial type is registered in the trial_type_to_runner mapping. """ - return ( - trial_type is None - # We temporarily allow "short run" and "long run" trial - # types in single-type experiments during development of - # a new ``GenerationStrategy`` that needs them. - or trial_type == Keys.SHORT_RUN - or trial_type == Keys.LONG_RUN - ) + # Special case for LONG_ and SHORT_RUN trial types, which we treat as "default" + # trial types for deployment. + if ( + trial_type == Keys.SHORT_RUN or trial_type == Keys.LONG_RUN + ) and self.supports_trial_type(trial_type=Keys.DEFAULT_TRIAL_TYPE): + return True + + return trial_type in self._trial_type_to_runner + + @property + def is_multi_type(self) -> bool: + """Returns True if this experiment has multiple trial types registered.""" + return len(self._trial_type_to_runner) > 1 + + @property + def metric_to_trial_type(self) -> dict[str, str]: + """Read-only mapping of metric names to trial types.""" + return self._metric_to_trial_type.copy() + + def metrics_for_trial_type(self, trial_type: str) -> list[Metric]: + """Returns metrics associated with a specific trial type. + + Args: + trial_type: The trial type to get metrics for. + + Returns: + List of metrics associated with the given trial type. + """ + # Special case for LONG_ and SHORT_RUN trial types, which we treat as "default" + # trial types for deployment. + if ( + trial_type == Keys.SHORT_RUN or trial_type == Keys.LONG_RUN + ) and self.supports_trial_type(trial_type=Keys.DEFAULT_TRIAL_TYPE): + return [ + self.metrics[metric_name] + for metric_name, metric_trial_type in self._metric_to_trial_type.items() + if metric_trial_type == Keys.DEFAULT_TRIAL_TYPE + ] + + if not self.supports_trial_type(trial_type): + raise ValueError(f"Trial type `{trial_type}` is not supported.") + return [ + self.metrics[metric_name] + for metric_name, metric_trial_type in self._metric_to_trial_type.items() + if metric_trial_type == trial_type + ] def attach_trial( self, diff --git a/ax/core/tests/test_batch_trial.py b/ax/core/tests/test_batch_trial.py index 8a2c201f4b9..97901dd3be9 100644 --- a/ax/core/tests/test_batch_trial.py +++ b/ax/core/tests/test_batch_trial.py @@ -410,8 +410,9 @@ def test_clone_to(self, _) -> None: cloned_batch._time_created = batch._time_created self.assertEqual(cloned_batch, batch) # test cloning with clear_trial_type=True + # When clear_trial_type=True, uses experiment's default_trial_type cloned_batch = batch.clone_to(clear_trial_type=True) - self.assertIsNone(cloned_batch.trial_type) + self.assertEqual(cloned_batch.trial_type, self.experiment.default_trial_type) self.assertEqual( cloned_batch.generation_method_str, f"{MANUAL_GENERATION_METHOD_STR}, {STATUS_QUO_GENERATION_METHOD_STR}", diff --git a/ax/core/tests/test_experiment.py b/ax/core/tests/test_experiment.py index 9903fbeb578..c87ec407bb5 100644 --- a/ax/core/tests/test_experiment.py +++ b/ax/core/tests/test_experiment.py @@ -1417,7 +1417,7 @@ def test_clone_with(self) -> None: cloned_experiment._time_created = experiment._time_created self.assertEqual(cloned_experiment, experiment) - # test clear_trial_type + # test clear_trial_type - uses experiment's default_trial_type experiment = get_branin_experiment( with_batch=True, num_batch_trial=1, @@ -1427,7 +1427,10 @@ def test_clone_with(self) -> None: with self.assertRaisesRegex(ValueError, ".* foo is not supported by the exp"): experiment.clone_with() cloned_experiment = experiment.clone_with(clear_trial_type=True) - self.assertIsNone(cloned_experiment.trials[0].trial_type) + self.assertEqual( + cloned_experiment.trials[0].trial_type, + cloned_experiment.default_trial_type, + ) # Test cloning with specific properties to keep experiment_w_props = get_branin_experiment() diff --git a/ax/core/tests/test_observation.py b/ax/core/tests/test_observation.py index c18ccf25deb..0817cac286a 100644 --- a/ax/core/tests/test_observation.py +++ b/ax/core/tests/test_observation.py @@ -268,6 +268,8 @@ def test_ObservationsFromData(self) -> None: } experiment = Mock() experiment._trial_indices_by_status = {status: set() for status in TrialStatus} + experiment.default_trial_type = "default" + experiment.supports_trial_type = Mock(return_value=True) trials = { obs["trial_index"]: Trial( experiment, GeneratorRun(arms=[arms[obs["arm_name"]]]) @@ -525,6 +527,8 @@ def test_ObservationsFromDataAbandoned(self) -> None: } experiment = Mock() experiment._trial_indices_by_status = {status: set() for status in TrialStatus} + experiment.default_trial_type = "default" + experiment.supports_trial_type = Mock(return_value=True) trials = { obs["trial_index"]: ( Trial(experiment, GeneratorRun(arms=[arms[obs["arm_name"]]])) @@ -637,6 +641,8 @@ def test_ObservationsFromDataWithSomeMissingTimes(self) -> None: } experiment = Mock() experiment._trial_indices_by_status = {status: set() for status in TrialStatus} + experiment.default_trial_type = "default" + experiment.supports_trial_type = Mock(return_value=True) trials = { obs["trial_index"]: Trial( experiment, GeneratorRun(arms=[arms[obs["arm_name"]]]) @@ -744,6 +750,8 @@ def test_ObservationsFromDataWithDifferentTimesSingleTrial( } experiment = Mock() experiment._trial_indices_by_status = {status: set() for status in TrialStatus} + experiment.default_trial_type = "default" + experiment.supports_trial_type = Mock(return_value=True) trials = { 0: BatchTrial(experiment, GeneratorRun(arms=list(arms_by_name.values()))) } @@ -885,6 +893,8 @@ def test_ObservationsWithCandidateMetadata(self) -> None: } experiment = Mock() experiment._trial_indices_by_status = {status: set() for status in TrialStatus} + experiment.default_trial_type = "default" + experiment.supports_trial_type = Mock(return_value=True) trials = { obs["trial_index"]: Trial( experiment, diff --git a/ax/core/tests/test_trial.py b/ax/core/tests/test_trial.py index 33eef439201..0fa7d3aecfe 100644 --- a/ax/core/tests/test_trial.py +++ b/ax/core/tests/test_trial.py @@ -460,9 +460,9 @@ def test_clone_to(self) -> None: # check that trial_type is cloned correctly self.assertEqual(new_trial.trial_type, "foo") - # test clear_trial_type + # test clear_trial_type - uses experiment's default_trial_type new_trial = self.trial.clone_to(clear_trial_type=True) - self.assertIsNone(new_trial.trial_type) + self.assertEqual(new_trial.trial_type, new_experiment.default_trial_type) def test_update_trial_status_on_clone(self) -> None: for status in [ diff --git a/ax/orchestration/orchestrator.py b/ax/orchestration/orchestrator.py index 8d4804893b7..a54e5d56496 100644 --- a/ax/orchestration/orchestrator.py +++ b/ax/orchestration/orchestrator.py @@ -58,7 +58,7 @@ set_ax_logger_levels, ) from ax.utils.common.timeutils import current_timestamp_in_millis -from pyre_extensions import assert_is_instance, none_throws +from pyre_extensions import none_throws NOT_IMPLEMENTED_IN_BASE_CLASS_MSG = """ \ @@ -364,7 +364,7 @@ def options(self, options: OrchestratorOptions) -> None: self._validate_runner_and_implemented_metrics(experiment=self.experiment) @property - def trial_type(self) -> str | None: + def trial_type(self) -> str: """Trial type for the experiment this Orchestrator is running. This returns None if the experiment is not a MultitypeExperiment @@ -374,8 +374,10 @@ def trial_type(self) -> str | None: experiment is a MultiTypeExperiment and None otherwise. """ if isinstance(self.experiment, MultiTypeExperiment): - return self.options.mt_experiment_trial_type - return None + return ( + self.options.mt_experiment_trial_type or Keys.DEFAULT_TRIAL_TYPE.value + ) + return Keys.DEFAULT_TRIAL_TYPE.value @property def running_trials(self) -> list[BaseTrial]: @@ -487,12 +489,8 @@ def runner(self) -> Runner: """``Runner`` specified on the experiment associated with this ``Orchestrator`` instance. """ - if self.trial_type is not None: - runner = assert_is_instance( - self.experiment, MultiTypeExperiment - ).runner_for_trial_type(trial_type=none_throws(self.trial_type)) - else: - runner = self.experiment.runner + runner = self.experiment.runner_for_trial_type(trial_type=self.trial_type) + if runner is None: raise UnsupportedError( "`Orchestrator` requires that experiment specifies a `Runner`." @@ -2034,11 +2032,11 @@ def _fetch_and_process_trials_data_results( try: kwargs = deepcopy(self.options.fetch_kwargs) - if self.trial_type is not None: - metrics = assert_is_instance( - self.experiment, MultiTypeExperiment - ).metrics_for_trial_type(trial_type=none_throws(self.trial_type)) - kwargs["metrics"] = metrics + metrics = self.experiment.metrics_for_trial_type( + trial_type=none_throws(self.trial_type) + ) + kwargs["metrics"] = metrics + results = self.experiment.fetch_trials_data_results( trial_indices=trial_indices, **kwargs, diff --git a/ax/storage/json_store/decoder.py b/ax/storage/json_store/decoder.py index 0beccd47c09..213994b6f06 100644 --- a/ax/storage/json_store/decoder.py +++ b/ax/storage/json_store/decoder.py @@ -59,6 +59,7 @@ CORE_DECODER_REGISTRY, ) from ax.storage.utils import data_by_trial_to_data +from ax.utils.common.constants import Keys from ax.utils.common.logger import get_logger from ax.utils.common.serialization import ( extract_init_args, @@ -727,8 +728,19 @@ def experiment_from_json( } ) experiment._arms_by_name = {} - if _trial_type_to_runner is not None: + + # Handle backwards compatibility issue where some Experiments support None + # trial types. + if ( + _trial_type_to_runner is not None + and len(_trial_type_to_runner) > 0 + and ({*_trial_type_to_runner.keys()} != {None}) + ): experiment._trial_type_to_runner = _trial_type_to_runner + else: + experiment._trial_type_to_runner = { + Keys.DEFAULT_TRIAL_TYPE.value: experiment.runner + } _load_experiment_info( exp=experiment, diff --git a/ax/storage/json_store/decoders.py b/ax/storage/json_store/decoders.py index 5d65a745c97..26e30e10074 100644 --- a/ax/storage/json_store/decoders.py +++ b/ax/storage/json_store/decoders.py @@ -48,6 +48,7 @@ REMOVED_TRANSFORMS, REVERSE_TRANSFORM_REGISTRY, ) +from ax.utils.common.constants import Keys from ax.utils.common.kwargs import warn_on_kwargs from ax.utils.common.logger import get_logger from ax.utils.common.typeutils_torch import torch_type_from_str @@ -158,7 +159,9 @@ def batch_trial_from_json( # the SQ at the end of this function. ) batch._index = index - batch._trial_type = trial_type + batch._trial_type = ( + trial_type if trial_type is not None else Keys.DEFAULT_TRIAL_TYPE.value + ) batch._time_created = time_created batch._time_completed = time_completed batch._time_staged = time_staged @@ -219,7 +222,9 @@ def trial_from_json( experiment=experiment, generator_run=generator_run, ttl_seconds=ttl_seconds ) trial._index = index - trial._trial_type = trial_type + trial._trial_type = ( + trial_type if trial_type is not None else Keys.DEFAULT_TRIAL_TYPE.value + ) # Swap `DISPATCHED` for `RUNNING`, since `DISPATCHED` is deprecated and nearly # equivalent to `RUNNING`. trial._status = status if status != TrialStatus.DISPATCHED else TrialStatus.RUNNING diff --git a/ax/storage/sqa_store/decoder.py b/ax/storage/sqa_store/decoder.py index 92628b093f3..40faa2f6485 100644 --- a/ax/storage/sqa_store/decoder.py +++ b/ax/storage/sqa_store/decoder.py @@ -386,11 +386,15 @@ def experiment_from_sqa( experiment.data = data_by_trial_to_data(data_by_trial=data_by_trial) trial_type_to_runner = { - sqa_runner.trial_type: self.runner_from_sqa(sqa_runner) + ( + sqa_runner.trial_type + if sqa_runner.trial_type is not None + else Keys.DEFAULT_TRIAL_TYPE.value + ): self.runner_from_sqa(sqa_runner) for sqa_runner in experiment_sqa.runners } if len(trial_type_to_runner) == 0: - trial_type_to_runner = {None: None} + trial_type_to_runner = {Keys.DEFAULT_TRIAL_TYPE.value: None} experiment._trials = {trial.index: trial for trial in trials} experiment._arms_by_name = {} @@ -415,9 +419,24 @@ def experiment_from_sqa( # `_trial_type_to_runner` is set in _init_mt_experiment_from_sqa if subclass != "MultiTypeExperiment": experiment._trial_type_to_runner = cast( - dict[str | None, Runner | None], trial_type_to_runner + dict[str, Runner | None], trial_type_to_runner ) experiment.db_id = experiment_sqa.id + + # For non-MultiTypeExperiment, populate _metric_to_trial_type + # This is needed because the metrics were added directly to the experiment + # without going through the setters that populate this field. + if subclass != "MultiTypeExperiment": + default_trial_type = Keys.DEFAULT_TRIAL_TYPE.value + # Add OC metrics + oc = experiment.optimization_config + if oc is not None: + for metric_name in oc.metrics.keys(): + experiment._metric_to_trial_type[metric_name] = default_trial_type + # Add tracking metrics + for metric_name in experiment._tracking_metrics.keys(): + experiment._metric_to_trial_type[metric_name] = default_trial_type + return experiment def parameter_from_sqa(self, parameter_sqa: SQAParameter) -> Parameter: @@ -996,7 +1015,11 @@ def trial_from_sqa( reduced_state=reduced_state, immutable_search_space_and_opt_config=immutable_ss_and_oc, ) - trial._trial_type = trial_sqa.trial_type + trial._trial_type = ( + trial_sqa.trial_type + if trial_sqa.trial_type is not None + else Keys.DEFAULT_TRIAL_TYPE.value + ) # Swap `DISPATCHED` for `RUNNING`, since `DISPATCHED` is deprecated and nearly # equivalent to `RUNNING`. trial._status = ( diff --git a/ax/storage/sqa_store/encoder.py b/ax/storage/sqa_store/encoder.py index 00a9568caa7..cba1364ee7d 100644 --- a/ax/storage/sqa_store/encoder.py +++ b/ax/storage/sqa_store/encoder.py @@ -226,7 +226,12 @@ def experiment_to_sqa(self, experiment: Experiment) -> SQAExperiment: metric.name ] elif experiment.runner: - runners.append(self.runner_to_sqa(none_throws(experiment.runner))) + runners.append( + self.runner_to_sqa( + none_throws(experiment.runner), + trial_type=experiment.default_trial_type, + ) + ) properties = experiment._properties.copy() if ( oc := experiment.optimization_config diff --git a/ax/utils/common/constants.py b/ax/utils/common/constants.py index 2f6a6d5b4d5..a1f9fdf531b 100644 --- a/ax/utils/common/constants.py +++ b/ax/utils/common/constants.py @@ -53,6 +53,7 @@ class Keys(StrEnum): COST_INTERCEPT = "cost_intercept" CURRENT_VALUE = "current_value" DEFAULT_OBJECTIVE_NAME = "objective" + DEFAULT_TRIAL_TYPE = "default" EXPAND = "expand" EXPECTED_ACQF_VAL = "expected_acquisition_value" EXPERIMENT_TOTAL_CONCURRENT_ARMS = "total_concurrent_arms" From a172262a780be0c904cfee8e0a17559d7f86409c Mon Sep 17 00:00:00 2001 From: Miles Olson Date: Tue, 10 Feb 2026 13:24:13 -0800 Subject: [PATCH 2/2] Remove bifrucation around MultiTypeExperiment (#4874) Summary: With recent changes to experiment we no longer need this bifructation. Next diff will remove places where we construct MultiTypeExperiment, and the one after will deprecate the class entirely Differential Revision: D91920991 --- ax/orchestration/orchestrator.py | 24 ++++----------------- ax/orchestration/tests/test_orchestrator.py | 16 ++++++-------- ax/service/ax_client.py | 15 +++++-------- ax/service/tests/test_report_utils.py | 4 ++-- ax/service/utils/report_utils.py | 8 ++++--- 5 files changed, 23 insertions(+), 44 deletions(-) diff --git a/ax/orchestration/orchestrator.py b/ax/orchestration/orchestrator.py index a54e5d56496..a7a3401c215 100644 --- a/ax/orchestration/orchestrator.py +++ b/ax/orchestration/orchestrator.py @@ -27,7 +27,6 @@ from ax.core.multi_type_experiment import ( filter_trials_by_type, get_trial_indices_for_statuses, - MultiTypeExperiment, ) from ax.core.runner import Runner from ax.core.trial import Trial @@ -367,17 +366,11 @@ def options(self, options: OrchestratorOptions) -> None: def trial_type(self) -> str: """Trial type for the experiment this Orchestrator is running. - This returns None if the experiment is not a MultitypeExperiment - Returns: - Trial type for the experiment this Orchestrator is running if the - experiment is a MultiTypeExperiment and None otherwise. + Trial type for the experiment this Orchestrator is running. + Defaults to Keys.DEFAULT_TRIAL_TYPE if not specified. """ - if isinstance(self.experiment, MultiTypeExperiment): - return ( - self.options.mt_experiment_trial_type or Keys.DEFAULT_TRIAL_TYPE.value - ) - return Keys.DEFAULT_TRIAL_TYPE.value + return self.options.mt_experiment_trial_type or Keys.DEFAULT_TRIAL_TYPE.value @property def running_trials(self) -> list[BaseTrial]: @@ -1619,11 +1612,7 @@ def _validate_options(self, options: OrchestratorOptions) -> None: "will be unable to fetch intermediate results with which to " "evaluate early stopping criteria." ) - if isinstance(self.experiment, MultiTypeExperiment): - if options.mt_experiment_trial_type is None: - raise UserInputError( - "Must specify `mt_experiment_trial_type` for MultiTypeExperiment." - ) + if options.mt_experiment_trial_type is not None: if not self.experiment.supports_trial_type( options.mt_experiment_trial_type ): @@ -1631,11 +1620,6 @@ def _validate_options(self, options: OrchestratorOptions) -> None: "Experiment does not support trial type " f"{options.mt_experiment_trial_type}." ) - elif options.mt_experiment_trial_type is not None: - raise UserInputError( - "`mt_experiment_trial_type` must be None unless the experiment is a " - "MultiTypeExperiment." - ) def _get_max_pending_trials(self) -> int: """Returns the maximum number of pending trials specified in the options, or diff --git a/ax/orchestration/tests/test_orchestrator.py b/ax/orchestration/tests/test_orchestrator.py index 67f982ca1a1..8417fd031d7 100644 --- a/ax/orchestration/tests/test_orchestrator.py +++ b/ax/orchestration/tests/test_orchestrator.py @@ -2736,12 +2736,9 @@ def test_validate_options_not_none_mt_trial_type( self, msg: str | None = None ) -> None: # test that error is raised if `mt_experiment_trial_type` is not - # compatible with the type of experiment (single or multi-type) + # a supported trial type for this experiment if msg is None: - msg = ( - "`mt_experiment_trial_type` must be None unless the experiment is a " - "MultiTypeExperiment." - ) + msg = "Experiment does not support trial type type1." options = OrchestratorOptions( init_seconds_between_polls=0, # No wait bw polls so test is fast. batch_size=10, @@ -2752,7 +2749,7 @@ def test_validate_options_not_none_mt_trial_type( ), ) gs = self.two_sobol_steps_GS - with self.assertRaisesRegex(UserInputError, msg): + with self.assertRaisesRegex(ValueError, msg): Orchestrator( experiment=self.branin_experiment, generation_strategy=gs, @@ -3010,10 +3007,11 @@ def test_fetch_and_process_trials_data_results_failed_non_objective( def test_validate_options_not_none_mt_trial_type( self, msg: str | None = None ) -> None: - # test if a MultiTypeExperiment with `mt_experiment_trial_type=None` - self.orchestrator_options_kwargs["mt_experiment_trial_type"] = None + # test that error is raised if `mt_experiment_trial_type` is not + # a supported trial type for this experiment (using an invalid type) + self.orchestrator_options_kwargs["mt_experiment_trial_type"] = "invalid_type" super().test_validate_options_not_none_mt_trial_type( - msg="Must specify `mt_experiment_trial_type` for MultiTypeExperiment." + msg="Experiment does not support trial type invalid_type." ) def test_run_n_trials_single_step_existing_experiment( diff --git a/ax/service/ax_client.py b/ax/service/ax_client.py index 592c316f037..fd1d4794886 100644 --- a/ax/service/ax_client.py +++ b/ax/service/ax_client.py @@ -24,7 +24,6 @@ from ax.core.evaluations_to_data import raw_evaluations_to_data from ax.core.experiment import Experiment from ax.core.generator_run import GeneratorRun -from ax.core.multi_type_experiment import MultiTypeExperiment from ax.core.objective import MultiObjective, Objective from ax.core.observation import ObservationFeatures from ax.core.runner import Runner @@ -458,15 +457,11 @@ def add_tracking_metrics( for metric_name in metric_names ] - if isinstance(self.experiment, MultiTypeExperiment): - experiment = assert_is_instance(self.experiment, MultiTypeExperiment) - experiment.add_tracking_metrics( - metrics=metric_objects, - metrics_to_trial_types=metrics_to_trial_types, - canonical_names=canonical_names, - ) - else: - self.experiment.add_tracking_metrics(metrics=metric_objects) + self.experiment.add_tracking_metrics( + metrics=metric_objects, + metrics_to_trial_types=metrics_to_trial_types, + **({"canonical_names": canonical_names} if canonical_names else {}), + ) @copy_doc(Experiment.remove_tracking_metric) def remove_tracking_metric(self, metric_name: str) -> None: diff --git a/ax/service/tests/test_report_utils.py b/ax/service/tests/test_report_utils.py index e857215be50..044c9fca78c 100644 --- a/ax/service/tests/test_report_utils.py +++ b/ax/service/tests/test_report_utils.py @@ -199,9 +199,9 @@ def test_exp_to_df_with_failure(self) -> None: self.assertEqual(f"{fail_reason}...", df["reason"].iloc[0]) def test_exp_to_df(self) -> None: - # MultiTypeExperiment should fail + # Experiments with multiple trial types should fail exp = get_multi_type_experiment() - with self.assertRaisesRegex(ValueError, "MultiTypeExperiment"): + with self.assertRaisesRegex(ValueError, "multiple trial types"): exp_to_df(exp=exp) # exp with no trials should return empty results diff --git a/ax/service/utils/report_utils.py b/ax/service/utils/report_utils.py index 4ddabe4c211..3f434504041 100644 --- a/ax/service/utils/report_utils.py +++ b/ax/service/utils/report_utils.py @@ -32,7 +32,6 @@ from ax.core.generator_run import GeneratorRunType from ax.core.map_metric import MapMetric from ax.core.metric import Metric -from ax.core.multi_type_experiment import MultiTypeExperiment from ax.core.objective import MultiObjective, ScalarizedObjective from ax.core.optimization_config import ( MultiObjectiveOptimizationConfig, @@ -787,8 +786,11 @@ def exp_to_df( ) # Accept Experiment and SimpleExperiment - if isinstance(exp, MultiTypeExperiment): - raise ValueError("Cannot transform MultiTypeExperiments to DataFrames.") + # Reject experiments with multiple trial types as they need special handling + if len(exp._trial_type_to_runner) > 1: + raise ValueError( + "Cannot transform experiments with multiple trial types to DataFrames." + ) key_components = ["trial_index", "arm_name"]