9 changes: 4 additions & 5 deletions ax/core/base_trial.py
@@ -98,16 +98,15 @@ def __init__(
         self._ttl_seconds: int | None = ttl_seconds
         self._index: int = self._experiment._attach_trial(self, index=index)
 
-        trial_type = (
+        self._trial_type: str = (
             trial_type
             if trial_type is not None
             else self._experiment.default_trial_type
         )
-        if not self._experiment.supports_trial_type(trial_type):
+        if not self._experiment.supports_trial_type(self._trial_type):
             raise ValueError(
-                f"Trial type {trial_type} is not supported by the experiment."
+                f"Trial type {self._trial_type} is not supported by the experiment."
             )
-        self._trial_type = trial_type
 
         self.__status: TrialStatus | None = None
         # Uses `_status` setter, which updates trial statuses to trial indices
@@ -285,7 +284,7 @@ def stop_metadata(self) -> dict[str, Any]:
         return self._stop_metadata
 
     @property
-    def trial_type(self) -> str | None:
+    def trial_type(self) -> str:
         """The type of the trial.
 
         Relevant for experiments containing different kinds of trials
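
With this change, `BaseTrial.trial_type` always resolves to a concrete string: a trial constructed without an explicit type inherits the experiment's default type at init time instead of carrying `None`, and unsupported types are rejected immediately. A minimal sketch of the resulting behavior, using standard Ax constructors (the search space and names are illustrative):

    from ax.core.experiment import Experiment
    from ax.core.parameter import ParameterType, RangeParameter
    from ax.core.search_space import SearchSpace

    # Illustrative one-parameter search space.
    search_space = SearchSpace(
        parameters=[
            RangeParameter(
                name="x", parameter_type=ParameterType.FLOAT, lower=0.0, upper=1.0
            )
        ]
    )
    experiment = Experiment(name="demo", search_space=search_space)

    # No trial_type passed: the trial falls back to the experiment's default,
    # so trial.trial_type is a str, never None.
    trial = experiment.new_trial()
    assert trial.trial_type == experiment.default_trial_type
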
237 changes: 202 additions & 35 deletions ax/core/experiment.py
@@ -100,7 +100,7 @@ def __init__(
         default_data_type: Any = None,
         auxiliary_experiments_by_purpose: None
         | (dict[AuxiliaryExperimentPurpose, list[AuxiliaryExperiment]]) = None,
-        default_trial_type: str | None = None,
+        default_trial_type: str = Keys.DEFAULT_TRIAL_TYPE.value,
     ) -> None:
         """Inits Experiment.
 
@@ -123,6 +123,8 @@ def __init__(
             default_data_type: Deprecated and ignored.
             auxiliary_experiments_by_purpose: Dictionary of auxiliary experiments
                 for different purposes (e.g., transfer learning).
+            default_trial_type: Default trial type for trials on this experiment.
+                Defaults to Keys.DEFAULT_TRIAL_TYPE.
         """
         if default_data_type is not None:
             warnings.warn(
@@ -150,10 +152,16 @@ def __init__(
         self._properties: dict[str, Any] = properties or {}
 
         # Initialize trial type to runner mapping
-        self._default_trial_type = default_trial_type
-        self._trial_type_to_runner: dict[str | None, Runner | None] = {
-            default_trial_type: runner
+        self._default_trial_type: str = (
+            default_trial_type or Keys.DEFAULT_TRIAL_TYPE.value
+        )
+        self._trial_type_to_runner: dict[str, Runner | None] = {
+            self._default_trial_type: runner
         }
 
+        # Maps metric names to their trial types. Every metric must have an entry.
+        self._metric_to_trial_type: dict[str, str] = {}
+
         # Used to keep track of whether any trials on the experiment
         # specify a TTL. Since trials need to be checked for their TTL's
         # expiration often, having this attribute helps avoid unnecessary
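
The constructor now guarantees a string-typed default: omitting `default_trial_type` (or passing `None`) falls back to `Keys.DEFAULT_TRIAL_TYPE.value`. A hedged sketch of the observable effect, reusing the illustrative `search_space` from the snippet above and assuming `Keys.DEFAULT_TRIAL_TYPE` is defined in `ax.utils.common.constants` as this diff implies ("bandit" is a hypothetical type name):

    from ax.utils.common.constants import Keys

    exp_default = Experiment(name="exp_a", search_space=search_space)
    exp_custom = Experiment(
        name="exp_b",
        search_space=search_space,
        default_trial_type="bandit",  # hypothetical custom trial type
    )

    # Both experiments expose a concrete str default, never None.
    assert exp_default.default_trial_type == Keys.DEFAULT_TRIAL_TYPE.value
    assert exp_custom.default_trial_type == "bandit"
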
@@ -413,16 +421,46 @@ def runner(self) -> Runner | None:
     def runner(self, runner: Runner | None) -> None:
         """Set the default runner and update trial type mapping."""
         self._runner = runner
-        if runner is not None:
-            self._trial_type_to_runner[self._default_trial_type] = runner
-        else:
-            self._trial_type_to_runner = {None: None}
+        self._trial_type_to_runner[self._default_trial_type] = runner
 
     @runner.deleter
     def runner(self) -> None:
         """Delete the runner."""
         self._runner = None
-        self._trial_type_to_runner = {None: None}
+        self._trial_type_to_runner[self._default_trial_type] = None
+
+    def add_trial_type(self, trial_type: str, runner: Runner) -> "Experiment":
+        """Add a new trial_type to be supported by this experiment.
+
+        Args:
+            trial_type: The new trial_type to be added.
+            runner: The default runner for trials of this type.
+
+        Returns:
+            The experiment with the new trial type added.
+        """
+        if self.supports_trial_type(trial_type):
+            raise ValueError(f"Experiment already contains trial_type `{trial_type}`")
+
+        self._trial_type_to_runner[trial_type] = runner
+        return self
+
+    def update_runner(self, trial_type: str, runner: Runner) -> "Experiment":
+        """Update the default runner for an existing trial_type.
+
+        Args:
+            trial_type: The trial_type to update.
+            runner: The new runner for trials of this type.
+
+        Returns:
+            The experiment with the updated runner.
+        """
+        if not self.supports_trial_type(trial_type):
+            raise ValueError(f"Experiment does not contain trial_type `{trial_type}`")
+
+        self._trial_type_to_runner[trial_type] = runner
+        self._runner = runner
+        return self
 
     @property
     def parameters(self) -> dict[str, Parameter]:
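
Together, `add_trial_type` and `update_runner` let a plain `Experiment` host several trial types, each with its own runner. A hedged usage sketch continuing the illustrative setup above (`SyntheticRunner` is Ax's stock no-op runner; "long_run" is a hypothetical type name):

    from ax.runners.synthetic import SyntheticRunner

    experiment = Experiment(
        name="multi_type_demo",
        search_space=search_space,
        runner=SyntheticRunner(),
    )

    # Register a second trial type with its own default runner;
    # re-adding an existing type raises ValueError.
    experiment.add_trial_type(trial_type="long_run", runner=SyntheticRunner())
    assert experiment.is_multi_type

    # Replace the runner for an existing type; unknown types raise ValueError.
    experiment.update_runner(trial_type="long_run", runner=SyntheticRunner())

Note that `update_runner` also reassigns the experiment-level `_runner`, so updating a secondary type's runner changes the fallback that `runner_for_trial_type` returns for types whose registered runner is `None`.
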
@@ -489,13 +527,25 @@ def optimization_config(self, optimization_config: OptimizationConfig) -> None:
                 f"`{Keys.IMMUTABLE_SEARCH_SPACE_AND_OPT_CONF.value}` "
                 "property that is set to `True` on this experiment."
             )
+
+        # Remove old OC metrics from trial type mapping
+        prev_optimization_config = self._optimization_config
+        if prev_optimization_config is not None:
+            for metric_name in prev_optimization_config.metrics.keys():
+                self._metric_to_trial_type.pop(metric_name, None)
+
         for metric_name in optimization_config.metrics.keys():
             if metric_name in self._tracking_metrics:
                 self.remove_tracking_metric(metric_name)
 
         # add metrics from the previous optimization config that are not in the new
         # optimization config as tracking metrics
-        prev_optimization_config = self._optimization_config
         self._optimization_config = optimization_config
 
+        # Map new OC metrics to default trial type
+        for metric_name in optimization_config.metrics.keys():
+            self._metric_to_trial_type[metric_name] = self._default_trial_type
+
         if prev_optimization_config is not None:
             metrics_to_track = (
                 set(prev_optimization_config.metrics.keys())
@@ -505,6 +555,16 @@ def optimization_config(self, optimization_config: OptimizationConfig) -> None:
             for metric_name in metrics_to_track:
                 self.add_tracking_metric(prev_optimization_config.metrics[metric_name])
 
+        # Clean up any stale entries in _metric_to_trial_type that don't correspond
+        # to actual metrics (can happen when same optimization_config object is
+        # mutated and reassigned).
+        current_metric_names = set(self.metrics.keys())
+        stale_metric_names = (
+            set(self._metric_to_trial_type.keys()) - current_metric_names
+        )
+        for metric_name in stale_metric_names:
+            self._metric_to_trial_type.pop(metric_name, None)
+
     @property
     def is_moo_problem(self) -> bool:
         """Whether the experiment's optimization config contains multiple objectives."""
@@ -553,12 +613,25 @@ def immutable_search_space_and_opt_config(self) -> bool:
     def tracking_metrics(self) -> list[Metric]:
         return list(self._tracking_metrics.values())
 
-    def add_tracking_metric(self, metric: Metric) -> Self:
+    def add_tracking_metric(
+        self,
+        metric: Metric,
+        trial_type: str | None = None,
+    ) -> Self:
         """Add a new metric to the experiment.
 
         Args:
             metric: Metric to be added.
+            trial_type: The trial type for which this metric is used. If not
+                provided, defaults to the experiment's default trial type.
         """
+        effective_trial_type = (
+            trial_type if trial_type is not None else self._default_trial_type
+        )
+
+        if not self.supports_trial_type(effective_trial_type):
+            raise ValueError(f"`{effective_trial_type}` is not a supported trial type.")
+
         if metric.name in self._tracking_metrics:
             raise ValueError(
                 f"Metric `{metric.name}` already defined on experiment. "
@@ -574,33 +647,73 @@ def add_tracking_metric(self, metric: Metric) -> Self:
             )
 
         self._tracking_metrics[metric.name] = metric
+        self._metric_to_trial_type[metric.name] = effective_trial_type
         return self
 
-    def add_tracking_metrics(self, metrics: list[Metric]) -> Self:
+    def add_tracking_metrics(
+        self,
+        metrics: list[Metric],
+        metrics_to_trial_types: dict[str, str] | None = None,
+    ) -> Self:
         """Add a list of new metrics to the experiment.
 
         If any of the metrics are already defined on the experiment,
         we raise an error and don't add any of them to the experiment
 
         Args:
             metrics: Metrics to be added.
+            metrics_to_trial_types: Optional mapping from metric names to
+                corresponding trial types. If not provided for a metric,
+                the experiment's default trial type is used.
         """
         # Before setting any metrics, we validate none are already on
         # the experiment
+        metrics_to_trial_types = metrics_to_trial_types or {}
         for metric in metrics:
-            self.add_tracking_metric(metric)
+            self.add_tracking_metric(
+                metric=metric,
+                trial_type=metrics_to_trial_types.get(metric.name),
+            )
         return self
 
-    def update_tracking_metric(self, metric: Metric) -> Self:
+    def update_tracking_metric(
+        self,
+        metric: Metric,
+        trial_type: str | None = None,
+    ) -> Self:
         """Redefine a metric that already exists on the experiment.
 
         Args:
             metric: New metric definition.
+            trial_type: The trial type for which this metric is used. If not
+                provided, keeps the existing trial type mapping.
         """
         if metric.name not in self._tracking_metrics:
             raise ValueError(f"Metric `{metric.name}` doesn't exist on experiment.")
 
+        # Validate trial type if provided
+        effective_trial_type = (
+            trial_type
+            if trial_type is not None
+            else self._metric_to_trial_type.get(metric.name, self._default_trial_type)
+        )
+
+        # Check that optimization config metrics stay on default trial type
+        oc = self.optimization_config
+        oc_metrics = oc.metrics if oc else {}
+        if (
+            metric.name in oc_metrics
+            and effective_trial_type != self._default_trial_type
+        ):
+            raise ValueError(
+                f"Metric `{metric.name}` must remain a "
+                f"`{self._default_trial_type}` metric because it is part of the "
+                "optimization_config."
+            )
+
+        if not self.supports_trial_type(effective_trial_type):
+            raise ValueError(f"`{effective_trial_type}` is not a supported trial type.")
+
         self._tracking_metrics[metric.name] = metric
+        self._metric_to_trial_type[metric.name] = effective_trial_type
         return self
 
     def remove_tracking_metric(self, metric_name: str) -> Self:
@@ -613,6 +726,7 @@ def remove_tracking_metric(self, metric_name: str) -> Self:
             raise ValueError(f"Metric `{metric_name}` doesn't exist on experiment.")
 
         del self._tracking_metrics[metric_name]
+        self._metric_to_trial_type.pop(metric_name, None)
         return self
 
     @property
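
Tracking metrics can now be registered against any supported trial type, one at a time or in bulk via the new `metrics_to_trial_types` argument. A hedged sketch continuing the running example ("latency" and "memory" are hypothetical metric names; "long_run" was registered earlier):

    experiment.add_tracking_metrics(
        metrics=[Metric(name="latency"), Metric(name="memory")],
        metrics_to_trial_types={"latency": "long_run"},
    )

    # "latency" maps to "long_run"; "memory" falls back to the default type.
    assert experiment.metric_to_trial_type["latency"] == "long_run"
    assert experiment.metric_to_trial_type["memory"] == experiment.default_trial_type

One invariant worth noting: `update_tracking_metric` refuses to move a metric that belongs to the optimization config off the default trial type, raising a ValueError instead.
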
@@ -852,8 +966,21 @@ def _fetch_trial_data(
     ) -> dict[str, MetricFetchResult]:
         trial = self.trials[trial_index]
 
+        # If metrics are not provided, fetch all metrics on the experiment for the
+        # relevant trial type, or the default trial type as a fallback. Otherwise,
+        # fetch provided metrics.
+        if metrics is None:
+            resolved_metrics = [
+                metric
+                for metric in list(self.metrics.values())
+                if self._metric_to_trial_type.get(metric.name, self._default_trial_type)
+                == trial.trial_type
+            ]
+        else:
+            resolved_metrics = metrics
+
         trial_data = self._lookup_or_fetch_trials_results(
-            trials=[trial], metrics=metrics, **kwargs
+            trials=[trial], metrics=resolved_metrics, **kwargs
        )
 
         if trial_index in trial_data:
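
Data fetching becomes trial-type aware: when no explicit metric list is passed, only metrics whose mapped trial type matches the trial's own type are fetched. A hedged sketch of the behavior through the public `fetch_data` path, continuing the running example (illustrative only, since bare `Metric` objects do not fetch real data):

    # "loss" and "memory" are mapped to the default trial type,
    # "latency" to "long_run" (see the snippets above).
    trial = experiment.new_trial()  # gets the default trial type

    # With no explicit metric list, only default-type metrics are resolved
    # for this trial; the "long_run"-typed "latency" is skipped.
    data = trial.fetch_data()
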
@@ -1548,39 +1675,79 @@ def __repr__(self) -> str:
     # overridden in the MultiTypeExperiment class.
 
     @property
-    def default_trial_type(self) -> str | None:
-        """Default trial type assigned to trials in this experiment.
-
-        In the base experiment class this is always None. For experiments
-        with multiple trial types, use the MultiTypeExperiment class.
-        """
+    def default_trial_type(self) -> str:
+        """Default trial type assigned to trials in this experiment."""
         return self._default_trial_type
 
-    def runner_for_trial_type(self, trial_type: str | None) -> Runner | None:
+    def runner_for_trial_type(self, trial_type: str) -> Runner | None:
         """The default runner to use for a given trial type.
 
         Looks up the appropriate runner for this trial type in the trial_type_to_runner.
         """
+        # Special case for LONG_ and SHORT_RUN trial types, which we treat as "default"
+        # trial types for deployment.
+        if (
+            trial_type == Keys.SHORT_RUN or trial_type == Keys.LONG_RUN
+        ) and self.supports_trial_type(trial_type=Keys.DEFAULT_TRIAL_TYPE):
+            return self._trial_type_to_runner[Keys.DEFAULT_TRIAL_TYPE]
+
         if not self.supports_trial_type(trial_type):
             raise ValueError(f"Trial type `{trial_type}` is not supported.")
         if (runner := self._trial_type_to_runner.get(trial_type)) is None:
             return self.runner  # return the default runner
         return runner
 
-    def supports_trial_type(self, trial_type: str | None) -> bool:
+    def supports_trial_type(self, trial_type: str) -> bool:
         """Whether this experiment allows trials of the given type.
 
-        The base experiment class only supports None. For experiments
-        with multiple trial types, use the MultiTypeExperiment class.
+        Checks if the trial type is registered in the trial_type_to_runner mapping.
         """
-        return (
-            trial_type is None
-            # We temporarily allow "short run" and "long run" trial
-            # types in single-type experiments during development of
-            # a new ``GenerationStrategy`` that needs them.
-            or trial_type == Keys.SHORT_RUN
-            or trial_type == Keys.LONG_RUN
-        )
+        # Special case for LONG_ and SHORT_RUN trial types, which we treat as "default"
+        # trial types for deployment.
+        if (
+            trial_type == Keys.SHORT_RUN or trial_type == Keys.LONG_RUN
+        ) and self.supports_trial_type(trial_type=Keys.DEFAULT_TRIAL_TYPE):
+            return True
+
+        return trial_type in self._trial_type_to_runner
+
+    @property
+    def is_multi_type(self) -> bool:
+        """Returns True if this experiment has multiple trial types registered."""
+        return len(self._trial_type_to_runner) > 1
+
+    @property
+    def metric_to_trial_type(self) -> dict[str, str]:
+        """Read-only mapping of metric names to trial types."""
+        return self._metric_to_trial_type.copy()
+
+    def metrics_for_trial_type(self, trial_type: str) -> list[Metric]:
+        """Returns metrics associated with a specific trial type.
+
+        Args:
+            trial_type: The trial type to get metrics for.
+
+        Returns:
+            List of metrics associated with the given trial type.
+        """
+        # Special case for LONG_ and SHORT_RUN trial types, which we treat as "default"
+        # trial types for deployment.
+        if (
+            trial_type == Keys.SHORT_RUN or trial_type == Keys.LONG_RUN
+        ) and self.supports_trial_type(trial_type=Keys.DEFAULT_TRIAL_TYPE):
+            return [
+                self.metrics[metric_name]
+                for metric_name, metric_trial_type in self._metric_to_trial_type.items()
+                if metric_trial_type == Keys.DEFAULT_TRIAL_TYPE
+            ]
+
+        if not self.supports_trial_type(trial_type):
+            raise ValueError(f"Trial type `{trial_type}` is not supported.")
+        return [
+            self.metrics[metric_name]
+            for metric_name, metric_trial_type in self._metric_to_trial_type.items()
+            if metric_trial_type == trial_type
+        ]
 
     def attach_trial(
         self,
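
With `supports_trial_type`, `is_multi_type`, `metric_to_trial_type`, and `metrics_for_trial_type` now on the base class, per-type views can be queried directly. A hedged sketch wrapping up the running example:

    assert experiment.supports_trial_type("long_run")
    assert not experiment.supports_trial_type("nonexistent_type")

    # Per-type metric views reflect the tracking-metric registrations above.
    assert [m.name for m in experiment.metrics_for_trial_type("long_run")] == [
        "latency"
    ]

    # Types whose registered runner is None fall back to the default runner.
    runner = experiment.runner_for_trial_type(experiment.default_trial_type)

The SHORT_RUN/LONG_RUN branches repeated across these methods route those two types onto the default trial type whenever it is registered; per the comment deleted from `supports_trial_type`, this is a temporary bridge for a new `GenerationStrategy` under development.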