Skip to content

Commit

Permalink
Extract analysis base to be used by client and scheduler (facebook#3136)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: facebook#3136

This will allow us to also use clients to generate analyses.

Reviewed By: jelena-markovic

Differential Revision: D66706329

fbshipit-source-id: 7c3bf5438ffe74a0783534f92a8313bea9049d8b
  • Loading branch information
Daniel Cohen authored and facebook-github-bot committed Dec 6, 2024
1 parent e0ff0b4 commit d709c5d
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 74 deletions.
5 changes: 3 additions & 2 deletions ax/service/ax_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,14 +66,15 @@
from ax.plot.feature_importances import plot_feature_importance_by_feature
from ax.plot.helper import _format_dict
from ax.plot.trace import optimization_trace_single_method
from ax.service.utils.analysis_base import AnalysisBase
from ax.service.utils.best_point_mixin import BestPointMixin
from ax.service.utils.instantiation import (
FixedFeatures,
InstantiationBase,
ObjectiveProperties,
)
from ax.service.utils.report_utils import exp_to_df
from ax.service.utils.with_db_settings_base import DBSettings, WithDBSettingsBase
from ax.service.utils.with_db_settings_base import DBSettings
from ax.storage.json_store.decoder import (
generation_strategy_from_json,
object_from_json,
Expand Down Expand Up @@ -108,7 +109,7 @@
)


class AxClient(WithDBSettingsBase, BestPointMixin, InstantiationBase):
class AxClient(AnalysisBase, BestPointMixin, InstantiationBase):
"""
Convenience handler for management of experimentation cycle through a
service-like API. External system manages scheduling of the cycle and makes
Expand Down
74 changes: 2 additions & 72 deletions ax/service/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@

from __future__ import annotations

import traceback

from collections.abc import Callable, Generator, Iterable, Mapping
from copy import deepcopy
from dataclasses import dataclass
Expand All @@ -20,10 +18,6 @@
from typing import Any, cast, NamedTuple, Optional

import ax.service.utils.early_stopping as early_stopping_utils
import pandas as pd
from ax.analysis.analysis import Analysis, AnalysisCard, AnalysisCardLevel, AnalysisE
from ax.analysis.markdown.markdown_analysis import MarkdownAnalysisCard
from ax.analysis.plotly.parallel_coordinates import ParallelCoordinatesPlot
from ax.core.base_trial import BaseTrial, TrialStatus
from ax.core.experiment import Experiment
from ax.core.generation_strategy_interface import GenerationStrategyInterface
Expand Down Expand Up @@ -57,6 +51,7 @@
from ax.modelbridge.base import ModelBridge
from ax.modelbridge.generation_strategy import GenerationStrategy
from ax.modelbridge.modelbridge_utils import get_fixed_features_from_experiment
from ax.service.utils.analysis_base import AnalysisBase
from ax.service.utils.best_point_mixin import BestPointMixin
from ax.service.utils.scheduler_options import SchedulerOptions, TrialType
from ax.service.utils.with_db_settings_base import DBSettings, WithDBSettingsBase
Expand All @@ -70,7 +65,6 @@
set_ax_logger_levels,
)
from ax.utils.common.timeutils import current_timestamp_in_millis
from ax.utils.common.typeutils import checked_cast
from pyre_extensions import assert_is_instance, none_throws


Expand Down Expand Up @@ -151,7 +145,7 @@ def append(self, text: str) -> None:
self.text += text


class Scheduler(WithDBSettingsBase, BestPointMixin):
class Scheduler(AnalysisBase, BestPointMixin):
"""Closed-loop manager class for Ax optimization.
Attributes:
Expand Down Expand Up @@ -679,62 +673,6 @@ def run_all_trials(
idle_callback=idle_callback,
)

def compute_analyses(
self, analyses: Iterable[Analysis] | None = None
) -> list[AnalysisCard]:
"""
Compute Analyses for the Experiment and GenerationStrategy associated with this
Scheduler instance and save them to the DB if possible. If an Analysis fails to
compute (e.g. due to a missing metric), it will be skipped and a warning will
be logged.
Args:
analyses: Analyses to compute. If None, the Scheduler will choose a set of
Analyses to compute based on the Experiment and GenerationStrategy.
"""
analyses = analyses if analyses is not None else self._choose_analyses()

results = [
analysis.compute_result(
experiment=self.experiment, generation_strategy=self.generation_strategy
)
for analysis in analyses
]

# TODO Accumulate Es into their own card, perhaps via unwrap_or_else
cards = [result.unwrap() for result in results if result.is_ok()]

for result in results:
if result.is_err():
e = checked_cast(AnalysisE, result.err)
traceback_str = "".join(
traceback.format_exception(
type(result.err.exception),
e.exception,
e.exception.__traceback__,
)
)
cards.append(
MarkdownAnalysisCard(
name=e.analysis.name,
# It would be better if we could reliably compute the title
# without risking another error
title=f"{e.analysis.name} Error",
subtitle=f"An error occurred while computing {e.analysis}",
attributes=e.analysis.attributes,
blob=traceback_str,
df=pd.DataFrame(),
level=AnalysisCardLevel.DEBUG,
)
)

self._save_analysis_cards_to_db_if_possible(
analysis_cards=cards,
experiment=self.experiment,
)

return cards

def run_trials_and_yield_results(
self,
max_trials: int,
Expand Down Expand Up @@ -1882,14 +1820,6 @@ def _get_next_trials(
trials.append(trial)
return trials, None

def _choose_analyses(self) -> list[Analysis]:
"""
Choose Analyses to compute based on the Experiment, GenerationStrategy, etc.
"""

# TODO Create a useful heuristic for choosing analyses
return [ParallelCoordinatesPlot()]

def _gen_new_trials_from_generation_strategy(
self,
num_trials: int,
Expand Down
97 changes: 97 additions & 0 deletions ax/service/utils/analysis_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict
import traceback
from typing import Iterable

import pandas as pd

from ax.analysis.analysis import Analysis, AnalysisCard, AnalysisCardLevel, AnalysisE
from ax.analysis.markdown.markdown_analysis import MarkdownAnalysisCard
from ax.analysis.plotly.parallel_coordinates import ParallelCoordinatesPlot
from ax.core.experiment import Experiment
from ax.core.generation_strategy_interface import GenerationStrategyInterface
from ax.service.utils.with_db_settings_base import WithDBSettingsBase
from ax.utils.common.typeutils import checked_cast


class AnalysisBase(WithDBSettingsBase):
"""
Base class for analysis functionality shared between AxClient and Scheduler.
"""

# pyre-fixme[13]: Attribute `experiment` is declared in class
# `AnalysisBase` to have type `Experiment` but is never initialized
experiment: Experiment
# pyre-fixme[13]: Attribute `generation_strategy` is declared in class
# `AnalysisBase` to have type `GenerationStrategyInterface` but
# is never initialized
generation_strategy: GenerationStrategyInterface

def _choose_analyses(self) -> list[Analysis]:
"""
Choose Analyses to compute based on the Experiment, GenerationStrategy, etc.
"""

# TODO Create a useful heuristic for choosing analyses
return [ParallelCoordinatesPlot()]

def compute_analyses(
self, analyses: Iterable[Analysis] | None = None
) -> list[AnalysisCard]:
"""
Compute Analyses for the Experiment and GenerationStrategy associated with this
Scheduler instance and save them to the DB if possible. If an Analysis fails to
compute (e.g. due to a missing metric), it will be skipped and a warning will
be logged.
Args:
analyses: Analyses to compute. If None, the Scheduler will choose a set of
Analyses to compute based on the Experiment and GenerationStrategy.
"""
analyses = analyses if analyses is not None else self._choose_analyses()

results = [
analysis.compute_result(
experiment=self.experiment,
generation_strategy=self.generation_strategy,
)
for analysis in analyses
]

# TODO Accumulate Es into their own card, perhaps via unwrap_or_else
cards = [result.unwrap() for result in results if result.is_ok()]

for result in results:
if result.is_err():
e = checked_cast(AnalysisE, result.err)
traceback_str = "".join(
traceback.format_exception(
type(result.err.exception),
e.exception,
e.exception.__traceback__,
)
)
cards.append(
MarkdownAnalysisCard(
name=e.analysis.name,
# It would be better if we could reliably compute the title
# without risking another error
title=f"{e.analysis.name} Error",
subtitle=f"An error occurred while computing {e.analysis}",
attributes=e.analysis.attributes,
blob=traceback_str,
df=pd.DataFrame(),
level=AnalysisCardLevel.DEBUG,
)
)

self._save_analysis_cards_to_db_if_possible(
analysis_cards=cards,
experiment=self.experiment,
)

return cards
9 changes: 9 additions & 0 deletions sphinx/source/service.rst
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,15 @@ Scheduler
Utils
-----

Analysis
~~~~~~~~

.. automodule:: ax.service.utils.analysis_base
:members:
:undoc-members:
:show-inheritance:


Best Point Identification
~~~~~~~~~~~~~~~~~~~~~~~~~

Expand Down

0 comments on commit d709c5d

Please sign in to comment.