From 8ce16b0c6c621e206edc4d0e50b264bf33bb9789 Mon Sep 17 00:00:00 2001 From: VictorS67 <185000048@qq.com> Date: Fri, 17 Jun 2022 16:32:24 -0400 Subject: [PATCH] fix duplicate for update batch --- mooclet_engine/engine/utils/data_downloader_utils.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/mooclet_engine/engine/utils/data_downloader_utils.py b/mooclet_engine/engine/utils/data_downloader_utils.py index e3a6ecc..e0daa67 100644 --- a/mooclet_engine/engine/utils/data_downloader_utils.py +++ b/mooclet_engine/engine/utils/data_downloader_utils.py @@ -152,7 +152,6 @@ def without_keys(d, keys): if parameters and not policy_params and "update_record" in parameters: record_df = pd.concat([record_df, pd.DataFrame(parameters["update_record"])]) - update_datapoint_count = 0 for version_index, version_row in version_df.iterrows(): datapoint_dict = { "study": version_row["study"], @@ -210,7 +209,7 @@ def without_keys(d, keys): reward_datapoint.update(context_datapoint) # Check if reward is used for updating parameters - if update_datapoint_count < len(record_df.index): + if len(record_df.index) > 0: # Get all checking conditions check_update = record_df["user_id"] == version_row["learner_id"] check_update &= record_df[reward_row["name"]] == reward_row["value"] @@ -220,7 +219,7 @@ def without_keys(d, keys): if len(record_df[check_update].index) != 0: # We have a updating datapoint reward_datapoint["update_group"] = update_group - update_datapoint_count += 1 + record_df = record_df.iloc[1:, :].reset_index(drop=True) reward_datapoint.update(datapoint_dict)