Skip to content

Commit

Permalink
Merge pull request #703 from adjeiv/save_sys_attrs
Browse files Browse the repository at this point in the history
Ensure trial notes are transferred when renamed
  • Loading branch information
c-bata authored Nov 30, 2023
2 parents c8dfa45 + caa2d0c commit 9eeda40
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 0 deletions.
1 change: 1 addition & 0 deletions optuna_dashboard/_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ def rename_study(study_id: int) -> dict[str, Any]:
storage=storage, study_name=dst_study_name, directions=src_study.directions
)
dst_study.add_trials(src_study.get_trials(deepcopy=False))
note.copy_notes(storage, src_study, dst_study)
except DuplicatedStudyError:
response.status = 400 # Bad request
return {"reason": f"study_name={dst_study_name} is duplicaated"}
Expand Down
13 changes: 13 additions & 0 deletions optuna_dashboard/_note.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,19 @@ def note_str_key_prefix(trial_id: Optional[int]) -> str:
return f"dashboard:{trial_id}:note_str:"


def copy_notes(storage: BaseStorage, src_study: optuna.Study, dst_study: optuna.Study) -> None:
system_attrs = storage.get_study_system_attrs(study_id=src_study._study_id)

# Copy individual trial notes
for src_trial, dst_trial in zip(src_study.get_trials(), dst_study.get_trials()):
note = get_note_from_system_attrs(system_attrs, src_trial._trial_id)["body"]
save_note_with_version(storage, dst_study._study_id, dst_trial._trial_id, 0, note)

# Copy study note
note = get_note_from_system_attrs(system_attrs, None)["body"]
save_note_with_version(storage, dst_study._study_id, None, 0, note)


def get_note_from_system_attrs(system_attrs: dict[str, Any], trial_id: Optional[int]) -> NoteType:
if note_ver_key(trial_id) not in system_attrs:
return {
Expand Down
22 changes: 22 additions & 0 deletions python_tests/test_note.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,25 @@ def test_save_and_get_trial_note(self) -> None:
note_dict = note.get_note_from_system_attrs(system_attrs, trial._trial_id)
self.assertEqual(note_dict["body"], body)
self.assertEqual(note_dict["version"], expected_ver)

def test_copy_notes(self) -> None:
old_study = optuna.create_study()
old_trials = [
old_study.ask({"x1": optuna.distributions.FloatDistribution(0, 10)}) for _ in range(2)
]
storage = old_study._storage

notes = ["trial 0", "trial 1"]
for trial, body in zip(old_trials, notes):
save_note(trial, body)
save_note(old_study, "Study")

new_study = optuna.create_study(storage=storage, directions=old_study.directions)
new_study.add_trials(old_study.get_trials(deepcopy=False))

note.copy_notes(storage, old_study, new_study)
system_attrs = new_study._storage.get_study_system_attrs(new_study._study_id)
for new_trial, body in zip(new_study.get_trials(), notes):
actual = note.get_note_from_system_attrs(system_attrs, new_trial._trial_id)
self.assertEqual(actual["body"], body)
self.assertEqual(get_note(new_study), "Study")

0 comments on commit 9eeda40

Please sign in to comment.