Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Raise error if enum not provided for a value we're trying to encode #2587

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ax/storage/sqa_store/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def __init__(self, config: SQAConfig) -> None:
self.config = config

def get_enum_name(
self, value: Optional[int], enum: Optional[Enum]
self, value: Optional[int], enum: Optional[Union[Enum, Type[Enum]]]
) -> Optional[str]:
"""Given an enum value (int) and an enum (of ints), return the
corresponding enum name. If the value is not present in the enum,
Expand Down
19 changes: 14 additions & 5 deletions ax/storage/sqa_store/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from enum import Enum

from logging import Logger
from typing import Any, cast, Dict, List, Optional, Tuple, Type
from typing import Any, cast, Dict, List, Optional, Tuple, Type, Union

import plotly
import plotly.io as pio
Expand Down Expand Up @@ -123,19 +123,28 @@ def validate_experiment_metadata(
)

def get_enum_value(
self, value: Optional[str], enum: Optional[Enum]
self, value: Optional[str], enum: Optional[Union[Enum, Type[Enum]]]
) -> Optional[int]:
"""Given an enum name (string) and an enum (of ints), return the
corresponding enum value. If the name is not present in the enum,
throw an error.
"""
if value is None or enum is None:
if value is None:
return None

error = SQAEncodeError(
f"Value {value} is invalid for enum {enum}. You may be "
"using a registry or config that doesn't support the value "
"you are trying to save."
)
if enum is None:
raise error

try:
return enum[value].value # pyre-ignore T29651755
# pyre-ignore[16]: `Enum` has no attribute `__getitem__`. T29651755
return enum[value].value
except KeyError:
raise SQAEncodeError(f"Value {value} is invalid for enum {enum}.")
raise error

def experiment_to_sqa(self, experiment: Experiment) -> SQAExperiment:
"""Convert Ax Experiment to SQLAlchemy.
Expand Down
6 changes: 3 additions & 3 deletions ax/storage/sqa_store/sqa_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Callable, Dict, Optional, Type
from typing import Any, Callable, Dict, Optional, Type, Union

from ax.analysis.analysis_report import AnalysisReport
from ax.analysis.base_analysis import BaseAnalysis
Expand Down Expand Up @@ -87,8 +87,8 @@ def _default_class_to_sqa_class(self=None) -> Dict[Type[Base], Type[SQABase]]:
class_to_sqa_class: Dict[Type[Base], Type[SQABase]] = field(
default_factory=_default_class_to_sqa_class
)
experiment_type_enum: Optional[Enum] = None
generator_run_type_enum: Optional[Enum] = GeneratorRunType # pyre-ignore [8]
experiment_type_enum: Optional[Union[Enum, Type[Enum]]] = None
generator_run_type_enum: Optional[Union[Enum, Type[Enum]]] = GeneratorRunType

# pyre-fixme[4]: Attribute annotation cannot contain `Any`.
# pyre-fixme[24]: Generic type `type` expects 1 type parameter, use
Expand Down
30 changes: 30 additions & 0 deletions ax/storage/sqa_store/tests/test_sqa_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import logging
from datetime import datetime
from enum import Enum
from logging import Logger
from typing import Any
from unittest import mock
Expand Down Expand Up @@ -202,6 +203,35 @@ def test_ExperimentSaveAndLoad(self) -> None:
loaded_experiment = load_experiment(exp.name)
self.assertEqual(loaded_experiment, exp)

def test_saving_an_experiment_with_type_requires_an_enum(self) -> None:
self.experiment.experiment_type = "TEST"
with self.assertRaises(SQAEncodeError):
save_experiment(self.experiment)

def test_saving_an_experiment_with_type_works_with_an_enum(self) -> None:
class TestExperimentTypeEnum(Enum):
TEST = 0

self.experiment.experiment_type = "TEST"
save_experiment(
self.experiment,
config=SQAConfig(experiment_type_enum=TestExperimentTypeEnum),
)
self.assertIsNotNone(self.experiment.db_id)

def test_saving_an_experiment_with_type_errors_with_missing_enum_value(
self,
) -> None:
class TestExperimentTypeEnum(Enum):
NOT_TEST = 0

self.experiment.experiment_type = "TEST"
with self.assertRaises(SQAEncodeError):
save_experiment(
self.experiment,
config=SQAConfig(experiment_type_enum=TestExperimentTypeEnum),
)

def test_LoadExperimentTrialsInBatches(self) -> None:
for _ in range(4):
self.experiment.new_trial()
Expand Down