Skip to content

Commit

Permalink
Raise error if enum not provided for a value we're trying to encode (#…
Browse files Browse the repository at this point in the history
…2587)

Summary:
Pull Request resolved: #2587

We probably want to let the user know if they're trying to encode/save a value we're going to drop.  It turns out prior to this diff, a `None` for the enum that would be used to encode the value is treated differently than an actual enum that just happens to be missing the value.  So a missing enum for experiment type would not raise an error if experiment type was passed.

This should debatably only be done in test mode though.  Silent failures are generally bad though and we could cause more problems down the road by not saving data we need to.

Reviewed By: Cesar-Cardoso

Differential Revision: D59925061
  • Loading branch information
Daniel Cohen authored and facebook-github-bot committed Jul 19, 2024
1 parent 9679f7b commit 643e0c1
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 3 deletions.
15 changes: 12 additions & 3 deletions ax/storage/sqa_store/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,13 +129,22 @@ def get_enum_value(
corresponding enum value. If the name is not present in the enum,
throw an error.
"""
if value is None or enum is None:
if value is None:
return None

error = SQAEncodeError(
f"Value {value} is invalid for enum {enum}. You may be "
"using a registry or config that doesn't support the value "
"you are trying to save."
)
if enum is None:
raise error

try:
return enum[value].value # pyre-ignore T29651755
# pyre-ignore[16]: `Enum` has no attribute `__getitem__`. T29651755
return enum[value].value
except KeyError:
raise SQAEncodeError(f"Value {value} is invalid for enum {enum}.")
raise error

def experiment_to_sqa(self, experiment: Experiment) -> SQAExperiment:
"""Convert Ax Experiment to SQLAlchemy.
Expand Down
36 changes: 36 additions & 0 deletions ax/storage/sqa_store/tests/test_sqa_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import logging
from datetime import datetime
from enum import Enum
from logging import Logger
from typing import Any
from unittest import mock
Expand Down Expand Up @@ -202,6 +203,41 @@ def test_ExperimentSaveAndLoad(self) -> None:
loaded_experiment = load_experiment(exp.name)
self.assertEqual(loaded_experiment, exp)

def test_saving_an_experiment_with_type_requires_an_enum(self) -> None:
self.experiment.experiment_type = "TEST"
with self.assertRaises(SQAEncodeError):
save_experiment(self.experiment)

def test_saving_an_experiment_with_type_works_with_an_enum(self) -> None:
class TestExperimentTypeEnum(Enum):
TEST = 0

self.experiment.experiment_type = "TEST"
save_experiment(
self.experiment,
# pyre-fixme[6]: In call `SQAConfig.__init__`, for argument
# `experiment_type_enum`, expected `Optional[Enum]` but got
# `Type[TestExperimentTypeEnum]`.
config=SQAConfig(experiment_type_enum=TestExperimentTypeEnum),
)
self.assertIsNotNone(self.experiment.db_id)

def test_saving_an_experiment_with_type_errors_with_missing_enum_value(
self,
) -> None:
class TestExperimentTypeEnum(Enum):
NOT_TEST = 0

self.experiment.experiment_type = "TEST"
with self.assertRaises(SQAEncodeError):
save_experiment(
self.experiment,
# pyre-fixme[6]: In call `SQAConfig.__init__`, for argument
# `experiment_type_enum`, expected `Optional[Enum]` but got
# `Type[TestExperimentTypeEnum]`.
config=SQAConfig(experiment_type_enum=TestExperimentTypeEnum),
)

def test_LoadExperimentTrialsInBatches(self) -> None:
for _ in range(4):
self.experiment.new_trial()
Expand Down

0 comments on commit 643e0c1

Please sign in to comment.