diff --git a/ax/core/base_trial.py b/ax/core/base_trial.py
index ee201603067..77b8160cfae 100644
--- a/ax/core/base_trial.py
+++ b/ax/core/base_trial.py
@@ -98,16 +98,15 @@ def __init__(
         self._ttl_seconds: int | None = ttl_seconds
         self._index: int = self._experiment._attach_trial(self, index=index)
 
-        trial_type = (
+        self._trial_type: str = (
             trial_type
             if trial_type is not None
             else self._experiment.default_trial_type
         )
-        if not self._experiment.supports_trial_type(trial_type):
+        if not self._experiment.supports_trial_type(self._trial_type):
             raise ValueError(
-                f"Trial type {trial_type} is not supported by the experiment."
+                f"Trial type {self._trial_type} is not supported by the experiment."
             )
-        self._trial_type = trial_type
 
         self.__status: TrialStatus | None = None
         # Uses `_status` setter, which updates trial statuses to trial indices
@@ -285,7 +284,7 @@ def stop_metadata(self) -> dict[str, Any]:
         return self._stop_metadata
 
     @property
-    def trial_type(self) -> str | None:
+    def trial_type(self) -> str:
         """The type of the trial.
 
         Relevant for experiments containing different kinds of trials
diff --git a/ax/core/experiment.py b/ax/core/experiment.py
index 071fe83f802..39b466faf02 100644
--- a/ax/core/experiment.py
+++ b/ax/core/experiment.py
@@ -100,7 +100,7 @@ def __init__(
         default_data_type: Any = None,
         auxiliary_experiments_by_purpose: None
         | (dict[AuxiliaryExperimentPurpose, list[AuxiliaryExperiment]]) = None,
-        default_trial_type: str | None = None,
+        default_trial_type: str = Keys.DEFAULT_TRIAL_TYPE.value,
     ) -> None:
         """Inits Experiment.
 
@@ -123,6 +123,8 @@ def __init__(
             default_data_type: Deprecated and ignored.
             auxiliary_experiments_by_purpose: Dictionary of auxiliary experiments
                 for different purposes (e.g., transfer learning).
+            default_trial_type: Default trial type for trials on this experiment.
+                Defaults to Keys.DEFAULT_TRIAL_TYPE.
         """
         if default_data_type is not None:
             warnings.warn(
@@ -150,10 +152,16 @@ def __init__(
         self._properties: dict[str, Any] = properties or {}
 
         # Initialize trial type to runner mapping
-        self._default_trial_type = default_trial_type
-        self._trial_type_to_runner: dict[str | None, Runner | None] = {
-            default_trial_type: runner
+        self._default_trial_type: str = (
+            default_trial_type or Keys.DEFAULT_TRIAL_TYPE.value
+        )
+        self._trial_type_to_runner: dict[str, Runner | None] = {
+            self._default_trial_type: runner
         }
+
+        # Maps metric names to their trial types. Every metric must have an entry.
+        self._metric_to_trial_type: dict[str, str] = {}
+
         # Used to keep track of whether any trials on the experiment
         # specify a TTL. Since trials need to be checked for their TTL's
         # expiration often, having this attribute helps avoid unnecessary
@@ -413,16 +421,46 @@ def runner(self) -> Runner | None:
     def runner(self, runner: Runner | None) -> None:
         """Set the default runner and update trial type mapping."""
         self._runner = runner
-        if runner is not None:
-            self._trial_type_to_runner[self._default_trial_type] = runner
-        else:
-            self._trial_type_to_runner = {None: None}
+        self._trial_type_to_runner[self._default_trial_type] = runner
 
     @runner.deleter
     def runner(self) -> None:
         """Delete the runner."""
         self._runner = None
-        self._trial_type_to_runner = {None: None}
+        self._trial_type_to_runner[self._default_trial_type] = None
+
+    def add_trial_type(self, trial_type: str, runner: Runner) -> "Experiment":
+        """Add a new trial_type to be supported by this experiment.
+
+        Args:
+            trial_type: The new trial_type to be added.
+            runner: The default runner for trials of this type.
+
+        Returns:
+            The experiment with the new trial type added.
+        """
+        if self.supports_trial_type(trial_type):
+            raise ValueError(f"Experiment already contains trial_type `{trial_type}`")
+
+        self._trial_type_to_runner[trial_type] = runner
+        return self
+
+    def update_runner(self, trial_type: str, runner: Runner) -> "Experiment":
+        """Update the default runner for an existing trial_type.
+
+        Args:
+            trial_type: The trial_type to update.
+            runner: The new runner for trials of this type.
+
+        Returns:
+            The experiment with the updated runner.
+        """
+        if not self.supports_trial_type(trial_type):
+            raise ValueError(f"Experiment does not contain trial_type `{trial_type}`")
+
+        self._trial_type_to_runner[trial_type] = runner
+        self._runner = runner
+        return self
 
     @property
     def parameters(self) -> dict[str, Parameter]:
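A minimal usage sketch of the new `add_trial_type` / `update_runner` surface, assuming a toy search space and Ax's `SyntheticRunner`; the experiment setup and the `"simulator"` type name are illustrative, not part of this diff:

```python
from ax.core.experiment import Experiment
from ax.core.parameter import ParameterType, RangeParameter
from ax.core.search_space import SearchSpace
from ax.runners.synthetic import SyntheticRunner

# The runner passed at construction is registered under the default
# trial type ("default", unless default_trial_type says otherwise).
experiment = Experiment(
    name="trial_type_demo",
    search_space=SearchSpace(
        parameters=[
            RangeParameter(
                name="x", parameter_type=ParameterType.FLOAT, lower=0.0, upper=1.0
            )
        ]
    ),
    runner=SyntheticRunner(),
)

# Register a second trial type with its own runner; re-adding an
# existing type raises ValueError.
experiment.add_trial_type(trial_type="simulator", runner=SyntheticRunner())

# Swap the runner for an already-registered type.
experiment.update_runner(trial_type="simulator", runner=SyntheticRunner())
```

Note that `update_runner` also overwrites the experiment-level `_runner`, so updating a non-default trial type changes the fallback runner that `runner_for_trial_type` returns for unmapped types as well.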
@@ -489,13 +527,25 @@ def optimization_config(self, optimization_config: OptimizationConfig) -> None:
                 f"`{Keys.IMMUTABLE_SEARCH_SPACE_AND_OPT_CONF.value}` "
                 "property that is set to `True` on this experiment."
             )
+
+        # Remove old OC metrics from trial type mapping
+        prev_optimization_config = self._optimization_config
+        if prev_optimization_config is not None:
+            for metric_name in prev_optimization_config.metrics.keys():
+                self._metric_to_trial_type.pop(metric_name, None)
+
         for metric_name in optimization_config.metrics.keys():
             if metric_name in self._tracking_metrics:
                 self.remove_tracking_metric(metric_name)
+
         # add metrics from the previous optimization config that are not in the new
         # optimization config as tracking metrics
-        prev_optimization_config = self._optimization_config
         self._optimization_config = optimization_config
+
+        # Map new OC metrics to default trial type
+        for metric_name in optimization_config.metrics.keys():
+            self._metric_to_trial_type[metric_name] = self._default_trial_type
+
         if prev_optimization_config is not None:
             metrics_to_track = (
                 set(prev_optimization_config.metrics.keys())
@@ -505,6 +555,16 @@ def optimization_config(self, optimization_config: OptimizationConfig) -> None:
             for metric_name in metrics_to_track:
                 self.add_tracking_metric(prev_optimization_config.metrics[metric_name])
 
+        # Clean up any stale entries in _metric_to_trial_type that don't correspond
+        # to actual metrics (this can happen when the same optimization_config
+        # object is mutated and reassigned).
+        current_metric_names = set(self.metrics.keys())
+        stale_metric_names = (
+            set(self._metric_to_trial_type.keys()) - current_metric_names
+        )
+        for metric_name in stale_metric_names:
+            self._metric_to_trial_type.pop(metric_name, None)
+
     @property
     def is_moo_problem(self) -> bool:
         """Whether the experiment's optimization config contains multiple objectives."""
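Continuing the sketch above, a minimal check of the new metric-to-trial-type bookkeeping on assignment of an optimization config (the `metric_to_trial_type` property is added later in this diff):

```python
from ax.core.metric import Metric
from ax.core.objective import Objective
from ax.core.optimization_config import OptimizationConfig

# Objective metrics are always pinned to the default trial type when an
# optimization config is (re)assigned; stale entries are pruned.
experiment.optimization_config = OptimizationConfig(
    objective=Objective(metric=Metric(name="objective_metric"), minimize=True)
)
assert (
    experiment.metric_to_trial_type["objective_metric"]
    == experiment.default_trial_type
)
```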
""" + effective_trial_type = ( + trial_type if trial_type is not None else self._default_trial_type + ) + + if not self.supports_trial_type(effective_trial_type): + raise ValueError(f"`{effective_trial_type}` is not a supported trial type.") + if metric.name in self._tracking_metrics: raise ValueError( f"Metric `{metric.name}` already defined on experiment. " @@ -574,9 +647,14 @@ def add_tracking_metric(self, metric: Metric) -> Self: ) self._tracking_metrics[metric.name] = metric + self._metric_to_trial_type[metric.name] = effective_trial_type return self - def add_tracking_metrics(self, metrics: list[Metric]) -> Self: + def add_tracking_metrics( + self, + metrics: list[Metric], + metrics_to_trial_types: dict[str, str] | None = None, + ) -> Self: """Add a list of new metrics to the experiment. If any of the metrics are already defined on the experiment, @@ -584,23 +662,58 @@ def add_tracking_metrics(self, metrics: list[Metric]) -> Self: Args: metrics: Metrics to be added. + metrics_to_trial_types: Optional mapping from metric names to + corresponding trial types. If not provided for a metric, + the experiment's default trial type is used. """ - # Before setting any metrics, we validate none are already on - # the experiment + metrics_to_trial_types = metrics_to_trial_types or {} for metric in metrics: - self.add_tracking_metric(metric) + self.add_tracking_metric( + metric=metric, + trial_type=metrics_to_trial_types.get(metric.name), + ) return self - def update_tracking_metric(self, metric: Metric) -> Self: + def update_tracking_metric( + self, + metric: Metric, + trial_type: str | None = None, + ) -> Self: """Redefine a metric that already exists on the experiment. Args: metric: New metric definition. + trial_type: The trial type for which this metric is used. If not + provided, keeps the existing trial type mapping. """ if metric.name not in self._tracking_metrics: raise ValueError(f"Metric `{metric.name}` doesn't exist on experiment.") + # Validate trial type if provided + effective_trial_type = ( + trial_type + if trial_type is not None + else self._metric_to_trial_type.get(metric.name, self._default_trial_type) + ) + + # Check that optimization config metrics stay on default trial type + oc = self.optimization_config + oc_metrics = oc.metrics if oc else {} + if ( + metric.name in oc_metrics + and effective_trial_type != self._default_trial_type + ): + raise ValueError( + f"Metric `{metric.name}` must remain a " + f"`{self._default_trial_type}` metric because it is part of the " + "optimization_config." + ) + + if not self.supports_trial_type(effective_trial_type): + raise ValueError(f"`{effective_trial_type}` is not a supported trial type.") + self._tracking_metrics[metric.name] = metric + self._metric_to_trial_type[metric.name] = effective_trial_type return self def remove_tracking_metric(self, metric_name: str) -> Self: @@ -613,6 +726,7 @@ def remove_tracking_metric(self, metric_name: str) -> Self: raise ValueError(f"Metric `{metric_name}` doesn't exist on experiment.") del self._tracking_metrics[metric_name] + self._metric_to_trial_type.pop(metric_name, None) return self @property @@ -852,8 +966,21 @@ def _fetch_trial_data( ) -> dict[str, MetricFetchResult]: trial = self.trials[trial_index] + # If metrics are not provided, fetch all metrics on the experiment for the + # relevant trial type, or the default trial type as a fallback. Otherwise, + # fetch provided metrics. 
@@ -852,8 +966,21 @@ def _fetch_trial_data(
     ) -> dict[str, MetricFetchResult]:
         trial = self.trials[trial_index]
 
+        # If metrics are not provided, fetch all metrics on the experiment for the
+        # relevant trial type, or the default trial type as a fallback. Otherwise,
+        # fetch provided metrics.
+        if metrics is None:
+            resolved_metrics = [
+                metric
+                for metric in list(self.metrics.values())
+                if self._metric_to_trial_type.get(metric.name, self._default_trial_type)
+                == trial.trial_type
+            ]
+        else:
+            resolved_metrics = metrics
+
         trial_data = self._lookup_or_fetch_trials_results(
-            trials=[trial], metrics=metrics, **kwargs
+            trials=[trial], metrics=resolved_metrics, **kwargs
         )
 
         if trial_index in trial_data:
@@ -1548,39 +1675,79 @@ def __repr__(self) -> str:
     # overridden in the MultiTypeExperiment class.
 
     @property
-    def default_trial_type(self) -> str | None:
-        """Default trial type assigned to trials in this experiment.
-
-        In the base experiment class this is always None. For experiments
-        with multiple trial types, use the MultiTypeExperiment class.
-        """
+    def default_trial_type(self) -> str:
+        """Default trial type assigned to trials in this experiment."""
         return self._default_trial_type
 
-    def runner_for_trial_type(self, trial_type: str | None) -> Runner | None:
+    def runner_for_trial_type(self, trial_type: str) -> Runner | None:
         """The default runner to use for a given trial type.
 
         Looks up the appropriate runner for this trial type in the
         trial_type_to_runner.
         """
+        # Special case for LONG_ and SHORT_RUN trial types, which we treat as "default"
+        # trial types for deployment.
+        if (
+            trial_type == Keys.SHORT_RUN or trial_type == Keys.LONG_RUN
+        ) and self.supports_trial_type(trial_type=Keys.DEFAULT_TRIAL_TYPE):
+            return self._trial_type_to_runner[Keys.DEFAULT_TRIAL_TYPE]
+
         if not self.supports_trial_type(trial_type):
             raise ValueError(f"Trial type `{trial_type}` is not supported.")
         if (runner := self._trial_type_to_runner.get(trial_type)) is None:
             return self.runner  # return the default runner
         return runner
 
-    def supports_trial_type(self, trial_type: str | None) -> bool:
+    def supports_trial_type(self, trial_type: str) -> bool:
         """Whether this experiment allows trials of the given type.
 
-        The base experiment class only supports None. For experiments
-        with multiple trial types, use the MultiTypeExperiment class.
+        Checks if the trial type is registered in the trial_type_to_runner mapping.
         """
-        return (
-            trial_type is None
-            # We temporarily allow "short run" and "long run" trial
-            # types in single-type experiments during development of
-            # a new ``GenerationStrategy`` that needs them.
-            or trial_type == Keys.SHORT_RUN
-            or trial_type == Keys.LONG_RUN
-        )
+        # Special case for LONG_ and SHORT_RUN trial types, which we treat as "default"
+        # trial types for deployment.
+        if (
+            trial_type == Keys.SHORT_RUN or trial_type == Keys.LONG_RUN
+        ) and self.supports_trial_type(trial_type=Keys.DEFAULT_TRIAL_TYPE):
+            return True
+
+        return trial_type in self._trial_type_to_runner
+
+    @property
+    def is_multi_type(self) -> bool:
+        """Returns True if this experiment has multiple trial types registered."""
+        return len(self._trial_type_to_runner) > 1
+
+    @property
+    def metric_to_trial_type(self) -> dict[str, str]:
+        """Read-only mapping of metric names to trial types."""
+        return self._metric_to_trial_type.copy()
+
+    def metrics_for_trial_type(self, trial_type: str) -> list[Metric]:
+        """Returns metrics associated with a specific trial type.
+
+        Args:
+            trial_type: The trial type to get metrics for.
+
+        Returns:
+            List of metrics associated with the given trial type.
+        """
+        # Special case for LONG_ and SHORT_RUN trial types, which we treat as "default"
+        # trial types for deployment.
+        if (
+            trial_type == Keys.SHORT_RUN or trial_type == Keys.LONG_RUN
+        ) and self.supports_trial_type(trial_type=Keys.DEFAULT_TRIAL_TYPE):
+            return [
+                self.metrics[metric_name]
+                for metric_name, metric_trial_type in self._metric_to_trial_type.items()
+                if metric_trial_type == Keys.DEFAULT_TRIAL_TYPE
+            ]
+
+        if not self.supports_trial_type(trial_type):
+            raise ValueError(f"Trial type `{trial_type}` is not supported.")
+        return [
+            self.metrics[metric_name]
+            for metric_name, metric_trial_type in self._metric_to_trial_type.items()
+            if metric_trial_type == trial_type
+        ]
 
     def attach_trial(
         self,
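To round out the sketch, the per-type metric and runner lookups, including the short/long-run special case (per the code above, `Keys.SHORT_RUN` and `Keys.LONG_RUN` resolve against the default type whenever the default type is registered):

```python
from ax.utils.common.constants import Keys

# Metric lookup is filtered by trial type.
assert [m.name for m in experiment.metrics_for_trial_type("simulator")] == ["latency"]

# Two registered trial types make this a multi-type experiment.
assert experiment.is_multi_type

# SHORT_RUN/LONG_RUN piggyback on the default trial type's runner (and
# metrics) as long as the default type is registered.
assert experiment.runner_for_trial_type(Keys.SHORT_RUN) is (
    experiment.runner_for_trial_type(experiment.default_trial_type)
)
```

One consequence of this ordering: the special case is checked before the registration lookup, so a type literally named like `Keys.SHORT_RUN`/`Keys.LONG_RUN` resolves to the default type's runner even if it was registered explicitly.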
diff --git a/ax/core/tests/test_batch_trial.py b/ax/core/tests/test_batch_trial.py
index 8a2c201f4b9..97901dd3be9 100644
--- a/ax/core/tests/test_batch_trial.py
+++ b/ax/core/tests/test_batch_trial.py
@@ -410,8 +410,9 @@ def test_clone_to(self, _) -> None:
         cloned_batch._time_created = batch._time_created
         self.assertEqual(cloned_batch, batch)
         # test cloning with clear_trial_type=True
+        # When clear_trial_type=True, uses experiment's default_trial_type
         cloned_batch = batch.clone_to(clear_trial_type=True)
-        self.assertIsNone(cloned_batch.trial_type)
+        self.assertEqual(cloned_batch.trial_type, self.experiment.default_trial_type)
         self.assertEqual(
             cloned_batch.generation_method_str,
             f"{MANUAL_GENERATION_METHOD_STR}, {STATUS_QUO_GENERATION_METHOD_STR}",
diff --git a/ax/core/tests/test_experiment.py b/ax/core/tests/test_experiment.py
index 9903fbeb578..c87ec407bb5 100644
--- a/ax/core/tests/test_experiment.py
+++ b/ax/core/tests/test_experiment.py
@@ -1417,7 +1417,7 @@ def test_clone_with(self) -> None:
         cloned_experiment._time_created = experiment._time_created
         self.assertEqual(cloned_experiment, experiment)
 
-        # test clear_trial_type
+        # test clear_trial_type - uses experiment's default_trial_type
         experiment = get_branin_experiment(
             with_batch=True,
             num_batch_trial=1,
@@ -1427,7 +1427,10 @@ def test_clone_with(self) -> None:
         with self.assertRaisesRegex(ValueError, ".* foo is not supported by the exp"):
             experiment.clone_with()
         cloned_experiment = experiment.clone_with(clear_trial_type=True)
-        self.assertIsNone(cloned_experiment.trials[0].trial_type)
+        self.assertEqual(
+            cloned_experiment.trials[0].trial_type,
+            cloned_experiment.default_trial_type,
+        )
 
         # Test cloning with specific properties to keep
         experiment_w_props = get_branin_experiment()
diff --git a/ax/core/tests/test_observation.py b/ax/core/tests/test_observation.py
index c18ccf25deb..0817cac286a 100644
--- a/ax/core/tests/test_observation.py
+++ b/ax/core/tests/test_observation.py
@@ -268,6 +268,8 @@ def test_ObservationsFromData(self) -> None:
         }
         experiment = Mock()
         experiment._trial_indices_by_status = {status: set() for status in TrialStatus}
+        experiment.default_trial_type = "default"
+        experiment.supports_trial_type = Mock(return_value=True)
         trials = {
             obs["trial_index"]: Trial(
                 experiment, GeneratorRun(arms=[arms[obs["arm_name"]]])
@@ -525,6 +527,8 @@ def test_ObservationsFromDataAbandoned(self) -> None:
         }
         experiment = Mock()
         experiment._trial_indices_by_status = {status: set() for status in TrialStatus}
+        experiment.default_trial_type = "default"
+        experiment.supports_trial_type = Mock(return_value=True)
         trials = {
             obs["trial_index"]: (
                 Trial(experiment, GeneratorRun(arms=[arms[obs["arm_name"]]]))
@@ -637,6 +641,8 @@ def test_ObservationsFromDataWithSomeMissingTimes(self) -> None:
         }
         experiment = Mock()
         experiment._trial_indices_by_status = {status: set() for status in TrialStatus}
+        experiment.default_trial_type = "default"
+        experiment.supports_trial_type = Mock(return_value=True)
         trials = {
             obs["trial_index"]: Trial(
                 experiment, GeneratorRun(arms=[arms[obs["arm_name"]]])
@@ -744,6 +750,8 @@ def test_ObservationsFromDataWithDifferentTimesSingleTrial(
         }
         experiment = Mock()
         experiment._trial_indices_by_status = {status: set() for status in TrialStatus}
+        experiment.default_trial_type = "default"
+        experiment.supports_trial_type = Mock(return_value=True)
         trials = {
             0: BatchTrial(experiment, GeneratorRun(arms=list(arms_by_name.values())))
         }
@@ -885,6 +893,8 @@ def test_ObservationsWithCandidateMetadata(self) -> None:
         }
         experiment = Mock()
         experiment._trial_indices_by_status = {status: set() for status in TrialStatus}
+        experiment.default_trial_type = "default"
+        experiment.supports_trial_type = Mock(return_value=True)
         trials = {
             obs["trial_index"]: Trial(
                 experiment,
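The clone-semantics change these tests pin down, as a sketch continuing the experiment from above (this assumes `clone_to` without an explicit target experiment preserves the same default trial type, which both test updates rely on):

```python
# Previously, clear_trial_type=True left trial_type as None; now the
# clone falls back to the experiment's default trial type.
trial = experiment.new_trial()
cloned = trial.clone_to(clear_trial_type=True)
assert cloned.trial_type == experiment.default_trial_type
```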
diff --git a/ax/core/tests/test_trial.py b/ax/core/tests/test_trial.py
index 33eef439201..0fa7d3aecfe 100644
--- a/ax/core/tests/test_trial.py
+++ b/ax/core/tests/test_trial.py
@@ -460,9 +460,9 @@ def test_clone_to(self) -> None:
         # check that trial_type is cloned correctly
         self.assertEqual(new_trial.trial_type, "foo")
 
-        # test clear_trial_type
+        # test clear_trial_type - uses experiment's default_trial_type
         new_trial = self.trial.clone_to(clear_trial_type=True)
-        self.assertIsNone(new_trial.trial_type)
+        self.assertEqual(new_trial.trial_type, new_experiment.default_trial_type)
 
     def test_update_trial_status_on_clone(self) -> None:
         for status in [
diff --git a/ax/orchestration/orchestrator.py b/ax/orchestration/orchestrator.py
index 8d4804893b7..a54e5d56496 100644
--- a/ax/orchestration/orchestrator.py
+++ b/ax/orchestration/orchestrator.py
@@ -58,7 +58,7 @@
     set_ax_logger_levels,
 )
 from ax.utils.common.timeutils import current_timestamp_in_millis
-from pyre_extensions import assert_is_instance, none_throws
+from pyre_extensions import none_throws
 
 NOT_IMPLEMENTED_IN_BASE_CLASS_MSG = """ \
@@ -364,7 +364,7 @@ def options(self, options: OrchestratorOptions) -> None:
         self._validate_runner_and_implemented_metrics(experiment=self.experiment)
 
     @property
-    def trial_type(self) -> str | None:
+    def trial_type(self) -> str:
         """Trial type for the experiment this Orchestrator is running.
 
-        This returns None if the experiment is not a MultitypeExperiment
+        This returns the default trial type if the experiment is not a
+        MultiTypeExperiment.
@@ -374,8 +374,10 @@ def trial_type(self) -> str | None:
-            experiment is a MultiTypeExperiment and None otherwise.
+            experiment is a MultiTypeExperiment, and the default trial type
+            otherwise.
         """
         if isinstance(self.experiment, MultiTypeExperiment):
-            return self.options.mt_experiment_trial_type
-        return None
+            return (
+                self.options.mt_experiment_trial_type or Keys.DEFAULT_TRIAL_TYPE.value
+            )
+        return Keys.DEFAULT_TRIAL_TYPE.value
 
     @property
     def running_trials(self) -> list[BaseTrial]:
@@ -487,12 +489,8 @@ def runner(self) -> Runner:
         """``Runner`` specified on the experiment associated with this
         ``Orchestrator`` instance.
         """
-        if self.trial_type is not None:
-            runner = assert_is_instance(
-                self.experiment, MultiTypeExperiment
-            ).runner_for_trial_type(trial_type=none_throws(self.trial_type))
-        else:
-            runner = self.experiment.runner
+        runner = self.experiment.runner_for_trial_type(trial_type=self.trial_type)
+
         if runner is None:
             raise UnsupportedError(
                 "`Orchestrator` requires that experiment specifies a `Runner`."
@@ -2034,11 +2032,11 @@ def _fetch_and_process_trials_data_results(
         try:
             kwargs = deepcopy(self.options.fetch_kwargs)
-            if self.trial_type is not None:
-                metrics = assert_is_instance(
-                    self.experiment, MultiTypeExperiment
-                ).metrics_for_trial_type(trial_type=none_throws(self.trial_type))
-                kwargs["metrics"] = metrics
+            metrics = self.experiment.metrics_for_trial_type(
+                trial_type=self.trial_type
+            )
+            kwargs["metrics"] = metrics
+
             results = self.experiment.fetch_trials_data_results(
                 trial_indices=trial_indices,
                 **kwargs,
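A hypothetical, standalone restatement of the new `Orchestrator.trial_type` resolution rule; the `resolve_trial_type` helper is illustrative only and does not exist in the diff:

```python
from ax.utils.common.constants import Keys


def resolve_trial_type(
    mt_experiment_trial_type: str | None, is_multi_type: bool
) -> str:
    # Mirrors Orchestrator.trial_type: multi-type experiments use the
    # configured option (falling back to the default trial type);
    # single-type experiments always use the default trial type.
    if is_multi_type:
        return mt_experiment_trial_type or Keys.DEFAULT_TRIAL_TYPE.value
    return Keys.DEFAULT_TRIAL_TYPE.value


assert resolve_trial_type(None, is_multi_type=False) == "default"
assert resolve_trial_type(None, is_multi_type=True) == "default"
assert resolve_trial_type("simulator", is_multi_type=True) == "simulator"
```

Because the property can no longer return None, the downstream call sites (`runner` and `_fetch_and_process_trials_data_results`) drop their `assert_is_instance`/`none_throws` branching.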
diff --git a/ax/storage/json_store/decoder.py b/ax/storage/json_store/decoder.py
index 0beccd47c09..213994b6f06 100644
--- a/ax/storage/json_store/decoder.py
+++ b/ax/storage/json_store/decoder.py
@@ -59,6 +59,7 @@
     CORE_DECODER_REGISTRY,
 )
 from ax.storage.utils import data_by_trial_to_data
+from ax.utils.common.constants import Keys
 from ax.utils.common.logger import get_logger
 from ax.utils.common.serialization import (
     extract_init_args,
@@ -727,8 +728,19 @@ def experiment_from_json(
         }
     )
     experiment._arms_by_name = {}
-    if _trial_type_to_runner is not None:
+
+    # Handle backwards compatibility issue where some Experiments support None
+    # trial types.
+    if (
+        _trial_type_to_runner is not None
+        and len(_trial_type_to_runner) > 0
+        and ({*_trial_type_to_runner.keys()} != {None})
+    ):
         experiment._trial_type_to_runner = _trial_type_to_runner
+    else:
+        experiment._trial_type_to_runner = {
+            Keys.DEFAULT_TRIAL_TYPE.value: experiment.runner
+        }
 
     _load_experiment_info(
         exp=experiment,
diff --git a/ax/storage/json_store/decoders.py b/ax/storage/json_store/decoders.py
index 5d65a745c97..26e30e10074 100644
--- a/ax/storage/json_store/decoders.py
+++ b/ax/storage/json_store/decoders.py
@@ -48,6 +48,7 @@
     REMOVED_TRANSFORMS,
     REVERSE_TRANSFORM_REGISTRY,
 )
+from ax.utils.common.constants import Keys
 from ax.utils.common.kwargs import warn_on_kwargs
 from ax.utils.common.logger import get_logger
 from ax.utils.common.typeutils_torch import torch_type_from_str
@@ -158,7 +159,9 @@ def batch_trial_from_json(
         # the SQ at the end of this function.
     )
     batch._index = index
-    batch._trial_type = trial_type
+    batch._trial_type = (
+        trial_type if trial_type is not None else Keys.DEFAULT_TRIAL_TYPE.value
+    )
     batch._time_created = time_created
     batch._time_completed = time_completed
     batch._time_staged = time_staged
@@ -219,7 +222,9 @@ def trial_from_json(
         experiment=experiment, generator_run=generator_run, ttl_seconds=ttl_seconds
     )
     trial._index = index
-    trial._trial_type = trial_type
+    trial._trial_type = (
+        trial_type if trial_type is not None else Keys.DEFAULT_TRIAL_TYPE.value
+    )
     # Swap `DISPATCHED` for `RUNNING`, since `DISPATCHED` is deprecated and nearly
     # equivalent to `RUNNING`.
     trial._status = status if status != TrialStatus.DISPATCHED else TrialStatus.RUNNING
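The backwards-compatibility rule in `experiment_from_json`, restated as a hypothetical standalone helper for clarity (name and signature are illustrative):

```python
from ax.core.runner import Runner
from ax.utils.common.constants import Keys


def normalize_trial_type_to_runner(
    stored: dict[str | None, Runner | None] | None,
    default_runner: Runner | None,
) -> dict[str, Runner | None]:
    # Legacy payloads stored {None: runner} (or nothing at all); both
    # collapse onto the new string-keyed default trial type.
    if stored and set(stored.keys()) != {None}:
        return stored  # already string-keyed
    return {Keys.DEFAULT_TRIAL_TYPE.value: default_runner}


assert normalize_trial_type_to_runner(None, None) == {"default": None}
assert normalize_trial_type_to_runner({None: None}, None) == {"default": None}
assert normalize_trial_type_to_runner({"simulator": None}, None) == {
    "simulator": None
}
```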
diff --git a/ax/storage/sqa_store/decoder.py b/ax/storage/sqa_store/decoder.py
index 92628b093f3..40faa2f6485 100644
--- a/ax/storage/sqa_store/decoder.py
+++ b/ax/storage/sqa_store/decoder.py
@@ -386,11 +386,15 @@ def experiment_from_sqa(
             experiment.data = data_by_trial_to_data(data_by_trial=data_by_trial)
 
         trial_type_to_runner = {
-            sqa_runner.trial_type: self.runner_from_sqa(sqa_runner)
+            (
+                sqa_runner.trial_type
+                if sqa_runner.trial_type is not None
+                else Keys.DEFAULT_TRIAL_TYPE.value
+            ): self.runner_from_sqa(sqa_runner)
             for sqa_runner in experiment_sqa.runners
         }
         if len(trial_type_to_runner) == 0:
-            trial_type_to_runner = {None: None}
+            trial_type_to_runner = {Keys.DEFAULT_TRIAL_TYPE.value: None}
 
         experiment._trials = {trial.index: trial for trial in trials}
         experiment._arms_by_name = {}
@@ -415,9 +419,24 @@ def experiment_from_sqa(
         # `_trial_type_to_runner` is set in _init_mt_experiment_from_sqa
         if subclass != "MultiTypeExperiment":
             experiment._trial_type_to_runner = cast(
-                dict[str | None, Runner | None], trial_type_to_runner
+                dict[str, Runner | None], trial_type_to_runner
             )
         experiment.db_id = experiment_sqa.id
+
+        # For non-MultiTypeExperiment, populate _metric_to_trial_type.
+        # This is needed because the metrics were added directly to the experiment
+        # without going through the setters that populate this field.
+        if subclass != "MultiTypeExperiment":
+            default_trial_type = Keys.DEFAULT_TRIAL_TYPE.value
+            # Add OC metrics
+            oc = experiment.optimization_config
+            if oc is not None:
+                for metric_name in oc.metrics.keys():
+                    experiment._metric_to_trial_type[metric_name] = default_trial_type
+            # Add tracking metrics
+            for metric_name in experiment._tracking_metrics.keys():
+                experiment._metric_to_trial_type[metric_name] = default_trial_type
+
         return experiment
 
     def parameter_from_sqa(self, parameter_sqa: SQAParameter) -> Parameter:
@@ -996,7 +1015,11 @@ def trial_from_sqa(
             reduced_state=reduced_state,
             immutable_search_space_and_opt_config=immutable_ss_and_oc,
         )
-        trial._trial_type = trial_sqa.trial_type
+        trial._trial_type = (
+            trial_sqa.trial_type
+            if trial_sqa.trial_type is not None
+            else Keys.DEFAULT_TRIAL_TYPE.value
+        )
         # Swap `DISPATCHED` for `RUNNING`, since `DISPATCHED` is deprecated and nearly
         # equivalent to `RUNNING`.
         trial._status = (
diff --git a/ax/storage/sqa_store/encoder.py b/ax/storage/sqa_store/encoder.py
index 00a9568caa7..cba1364ee7d 100644
--- a/ax/storage/sqa_store/encoder.py
+++ b/ax/storage/sqa_store/encoder.py
@@ -226,7 +226,12 @@ def experiment_to_sqa(self, experiment: Experiment) -> SQAExperiment:
                     metric.name
                 ]
         elif experiment.runner:
-            runners.append(self.runner_to_sqa(none_throws(experiment.runner)))
+            runners.append(
+                self.runner_to_sqa(
+                    none_throws(experiment.runner),
+                    trial_type=experiment.default_trial_type,
+                )
+            )
         properties = experiment._properties.copy()
         if (
             oc := experiment.optimization_config
diff --git a/ax/utils/common/constants.py b/ax/utils/common/constants.py
index 2f6a6d5b4d5..a1f9fdf531b 100644
--- a/ax/utils/common/constants.py
+++ b/ax/utils/common/constants.py
@@ -53,6 +53,7 @@ class Keys(StrEnum):
     COST_INTERCEPT = "cost_intercept"
     CURRENT_VALUE = "current_value"
     DEFAULT_OBJECTIVE_NAME = "objective"
+    DEFAULT_TRIAL_TYPE = "default"
     EXPAND = "expand"
     EXPECTED_ACQF_VAL = "expected_acquisition_value"
     EXPERIMENT_TOTAL_CONCURRENT_ARMS = "total_concurrent_arms"
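Finally, the reason plain strings and `Keys` members can be mixed freely above: `Keys` is a `StrEnum`, so members compare and hash like their raw string values. A small sanity sketch:

```python
from ax.utils.common.constants import Keys

# StrEnum members are interchangeable with their string values, which is
# why lookups like _trial_type_to_runner[Keys.DEFAULT_TRIAL_TYPE] work
# against a dict keyed by the plain string "default".
assert Keys.DEFAULT_TRIAL_TYPE == "default"
assert {Keys.DEFAULT_TRIAL_TYPE: 1}["default"] == 1
assert {"default": 2}[Keys.DEFAULT_TRIAL_TYPE] == 2
```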