From ceb2293a8bab6672cf82d4401bdb976b7ca2ba17 Mon Sep 17 00:00:00 2001 From: Louis Tiao Date: Wed, 30 Oct 2024 20:55:45 -0700 Subject: [PATCH] Include progression information as metadata when transforming Data to Observations Differential Revision: D65255312 --- ax/core/observation.py | 27 +++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/ax/core/observation.py b/ax/core/observation.py index 7dd6a7bd768..284f0848a00 100644 --- a/ax/core/observation.py +++ b/ax/core/observation.py @@ -439,6 +439,7 @@ def observations_from_data( data: Data, statuses_to_include: set[TrialStatus] | None = None, statuses_to_include_map_metric: set[TrialStatus] | None = None, + map_keys_as_parameters: bool = False, ) -> list[Observation]: """Convert Data to observations. @@ -455,46 +456,55 @@ def observations_from_data( with statuses in this set. Defaults to all statuses except abandoned. statuses_to_include_map_metric: data from MapMetrics will only be included for trials with statuses in this set. Defaults to completed status only. + map_keys_as_parameters: Whether map_keys should be returned as part of + the parameters of the Observation objects. Returns: List of Observation objects. """ + is_map_data = isinstance(data, MapData) + df = data.df if not is_map_data else data.map_df + if statuses_to_include is None: statuses_to_include = NON_ABANDONED_STATUSES if statuses_to_include_map_metric is None: statuses_to_include_map_metric = {TrialStatus.COMPLETED} - feature_cols = get_feature_cols(data) - observations = [] + + feature_cols = get_feature_cols(data, is_map_data=is_map_data) + arm_name_only = len(feature_cols) == 1 # there will always be an arm name # One DataFrame where all rows have all features. - isnull = data.df[feature_cols].isnull() + isnull = df[feature_cols].isnull() isnull_any = isnull.any(axis=1) incomplete_df_cols = isnull[isnull_any].any() # Get the incomplete_df columns that are complete, and usable as groupby keys. + obs_cols = OBS_COLS if not is_map_data else OBS_COLS.union(data.map_keys) complete_feature_cols = list( - OBS_COLS.intersection(incomplete_df_cols.index[~incomplete_df_cols]) + obs_cols.intersection(incomplete_df_cols.index[~incomplete_df_cols]) ) if set(feature_cols) == set(complete_feature_cols): - complete_df = data.df + complete_df = df incomplete_df = None else: # The groupby and filter is expensive, so do it only if we have to. - grouped = data.df.groupby(by=complete_feature_cols) + grouped = df.groupby(by=complete_feature_cols) complete_df = grouped.filter(lambda r: ~r[feature_cols].isnull().any().any()) incomplete_df = grouped.filter(lambda r: r[feature_cols].isnull().any().any()) # Get Observations from complete_df + observations = [] observations.extend( _observations_from_dataframe( experiment=experiment, df=complete_df, cols=feature_cols, arm_name_only=arm_name_only, + map_keys=[] if not is_map_data else data.map_keys, statuses_to_include=statuses_to_include, statuses_to_include_map_metric=statuses_to_include_map_metric, - map_keys=[], + map_keys_as_parameters=map_keys_as_parameters, ) ) if incomplete_df is not None: @@ -505,9 +515,10 @@ def observations_from_data( df=incomplete_df, cols=complete_feature_cols, arm_name_only=arm_name_only, + map_keys=[] if not is_map_data else data.map_keys, statuses_to_include=statuses_to_include, statuses_to_include_map_metric=statuses_to_include_map_metric, - map_keys=[], + map_keys_as_parameters=map_keys_as_parameters, ) ) return observations