From 05ccb9d07b620bcfa348fc5d6ccdb9afeadd6101 Mon Sep 17 00:00:00 2001 From: Alex Lowe Date: Tue, 21 Jan 2025 18:23:21 -0500 Subject: [PATCH] fix(project): Use an annotated type for duration strings Inspired by: https://github.com/canonical/snapcraft/pull/5210 This makes an annotated type for duration strings and uses a stricter regex. --- snapcraft/models/project.py | 36 +++++++++++++++++----------- tests/unit/models/test_projects.py | 38 ++++++++---------------------- 2 files changed, 32 insertions(+), 42 deletions(-) diff --git a/snapcraft/models/project.py b/snapcraft/models/project.py index ccfdc64b70..b37dd42b87 100644 --- a/snapcraft/models/project.py +++ b/snapcraft/models/project.py @@ -50,6 +50,7 @@ is_architecture_supported, ) +TIME_DURATION_REGEX = re.compile(r"^([0-9]+(ns|us|ms|s|m)){1,5}$") ProjectName = Annotated[str, StringConstraints(max_length=40)] @@ -286,6 +287,23 @@ def _validate_mandatory_base(base: str | None, snap_type: str | None) -> None: ) +def _validate_duration_string(duration: str): + if not TIME_DURATION_REGEX.match(duration): + raise ValueError(f"{duration!r} is not a valid time value") + + return duration + + +DurationString = Annotated[ + str, + pydantic.Field( + examples=["1", "2s", "3m", "4ms", "5us", "6m7s8ms"], + pattern=TIME_DURATION_REGEX + ), + pydantic.BeforeValidator(_validate_duration_string), +] + + class Socket(models.CraftBaseModel): """Snapcraft app socket definition.""" @@ -378,11 +396,11 @@ class App(models.CraftBaseModel): completer: str | None = None stop_command: str | None = None post_stop_command: str | None = None - start_timeout: str | None = None - stop_timeout: str | None = None - watchdog_timeout: str | None = None + start_timeout: DurationString | None = None + stop_timeout: DurationString | None = None + watchdog_timeout: DurationString | None = None reload_command: str | None = None - restart_delay: str | None = None + restart_delay: DurationString | None = None timer: str | None = None daemon: Literal["simple", "forking", "oneshot", "notify", "dbus"] | None = None after: UniqueList[str] = pydantic.Field(default_factory=list) @@ -450,16 +468,6 @@ def _validate_apps_section_content(cls, command: str) -> str: return command - @pydantic.field_validator( - "start_timeout", "stop_timeout", "watchdog_timeout", "restart_delay" - ) - @classmethod - def _validate_time(cls, timeval): - if not re.match(r"^[0-9]+(ns|us|ms|s|m)*$", timeval): - raise ValueError(f"{timeval!r} is not a valid time value") - - return timeval - @pydantic.field_validator("command_chain") @classmethod def _validate_command_chain(cls, command_chains): diff --git a/tests/unit/models/test_projects.py b/tests/unit/models/test_projects.py index 56a91ad55b..11b5995bd0 100644 --- a/tests/unit/models/test_projects.py +++ b/tests/unit/models/test_projects.py @@ -40,6 +40,8 @@ # required project data for core24 snaps CORE24_DATA = {"base": "core24", "grade": "devel"} +VALID_DURATIONS = ["10ns", "10us", "10ms", "10s", "10m", "10m4s3us"] +INVALID_DURATIONS = ["10", "10 s", "10 seconds", "1:00", "invalid"] @pytest.fixture @@ -982,19 +984,14 @@ def test_app_post_stop_command(self, app_yaml_data): assert project.apps is not None assert project.apps["app1"].post_stop_command == "test-post-stop-command" - @pytest.mark.parametrize( - "start_timeout", ["10", "10ns", "10us", "10ms", "10s", "10m"] - ) + @pytest.mark.parametrize("start_timeout", VALID_DURATIONS) def test_app_start_timeout_valid(self, start_timeout, app_yaml_data): data = app_yaml_data(start_timeout=start_timeout) project = Project.unmarshal(data) assert project.apps is not None assert project.apps["app1"].start_timeout == start_timeout - @pytest.mark.parametrize( - "start_timeout", - ["10 s", "10 seconds", "1:00", "invalid"], - ) + @pytest.mark.parametrize("start_timeout", INVALID_DURATIONS) def test_app_start_timeout_invalid(self, start_timeout, app_yaml_data): data = app_yaml_data(start_timeout=start_timeout) @@ -1002,19 +999,14 @@ def test_app_start_timeout_invalid(self, start_timeout, app_yaml_data): with pytest.raises(pydantic.ValidationError, match=error): Project.unmarshal(data) - @pytest.mark.parametrize( - "stop_timeout", ["10", "10ns", "10us", "10ms", "10s", "10m"] - ) + @pytest.mark.parametrize("stop_timeout", VALID_DURATIONS) def test_app_stop_timeout_valid(self, stop_timeout, app_yaml_data): data = app_yaml_data(stop_timeout=stop_timeout) project = Project.unmarshal(data) assert project.apps is not None assert project.apps["app1"].stop_timeout == stop_timeout - @pytest.mark.parametrize( - "stop_timeout", - ["10 s", "10 seconds", "1:00", "invalid"], - ) + @pytest.mark.parametrize("stop_timeout", INVALID_DURATIONS) def test_app_stop_timeout_invalid(self, stop_timeout, app_yaml_data): data = app_yaml_data(stop_timeout=stop_timeout) @@ -1022,19 +1014,14 @@ def test_app_stop_timeout_invalid(self, stop_timeout, app_yaml_data): with pytest.raises(pydantic.ValidationError, match=error): Project.unmarshal(data) - @pytest.mark.parametrize( - "watchdog_timeout", ["10", "10ns", "10us", "10ms", "10s", "10m"] - ) + @pytest.mark.parametrize("watchdog_timeout", VALID_DURATIONS) def test_app_watchdog_timeout_valid(self, watchdog_timeout, app_yaml_data): data = app_yaml_data(watchdog_timeout=watchdog_timeout) project = Project.unmarshal(data) assert project.apps is not None assert project.apps["app1"].watchdog_timeout == watchdog_timeout - @pytest.mark.parametrize( - "watchdog_timeout", - ["10 s", "10 seconds", "1:00", "invalid"], - ) + @pytest.mark.parametrize("watchdog_timeout", INVALID_DURATIONS) def test_app_watchdog_timeout_invalid(self, watchdog_timeout, app_yaml_data): data = app_yaml_data(watchdog_timeout=watchdog_timeout) @@ -1048,19 +1035,14 @@ def test_app_reload_command(self, app_yaml_data): assert project.apps is not None assert project.apps["app1"].reload_command == "test-reload-command" - @pytest.mark.parametrize( - "restart_delay", ["10", "10ns", "10us", "10ms", "10s", "10m"] - ) + @pytest.mark.parametrize("restart_delay", VALID_DURATIONS) def test_app_restart_delay_valid(self, restart_delay, app_yaml_data): data = app_yaml_data(restart_delay=restart_delay) project = Project.unmarshal(data) assert project.apps is not None assert project.apps["app1"].restart_delay == restart_delay - @pytest.mark.parametrize( - "restart_delay", - ["10 s", "10 seconds", "1:00", "invalid"], - ) + @pytest.mark.parametrize("restart_delay", INVALID_DURATIONS) def test_app_restart_delay_invalid(self, restart_delay, app_yaml_data): data = app_yaml_data(restart_delay=restart_delay)