Skip to content

Commit

Permalink
Added more detail on expected return from functions in cylc/flow/dbst…
Browse files Browse the repository at this point in the history
…atecheck.py
  • Loading branch information
wxtim committed Apr 30, 2024
1 parent b514fb8 commit 551f0bd
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 12 deletions.
18 changes: 13 additions & 5 deletions cylc/flow/dbstatecheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import os
import sqlite3
import sys
from typing import Optional
from typing import Optional, Union
from textwrap import dedent

from cylc.flow.pathutil import expand_path
Expand Down Expand Up @@ -89,8 +89,11 @@ def display_maps(res):
for row in res:
sys.stdout.write((", ").join([str(s) for s in row]) + "\n")

def _get_pt_fmt(self):
"""Query a workflow database for a 'cycle point format' entry"""
def _get_pt_fmt(self) -> Union[None, str]:
"""Query a workflow database for a 'cycle point format' entry
Returns: None if Cycle point is integer, else a format string.
"""
for row in self.conn.execute(dedent(
rf'''
SELECT
Expand All @@ -103,9 +106,13 @@ def _get_pt_fmt(self):
['cycle_point_format']
):
return row[0]
return None

def _get_pt_fmt_compat(self):
"""Query a Cylc 7 suite database for 'cycle point format'."""
def _get_pt_fmt_compat(self) -> Union[None, str]:
"""Query a Cylc 7 suite database for 'cycle point format'.
Returns: None if Cycle point is integer, else a format string.
"""
# BACK COMPAT: Cylc 7 DB
# Workflows parameters table name change.
# from:
Expand All @@ -126,6 +133,7 @@ def _get_pt_fmt_compat(self):
['cycle_point_format']
):
return row[0]
return None

def state_lookup(self, state):
"""Allows for multiple states to be searched via a status alias."""
Expand Down
14 changes: 10 additions & 4 deletions cylc/flow/xtriggers/workflow_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def workflow_state(
point = str(add_offset(point, offset))

# Failure to connect to DB will raise exceptions here.
# It could mean the target workflow as not started yet,
# It could mean the target workflow has not started yet,
# but it could also mean a typo in the workflow ID, so
# so don't hide the error.
checker = CylcWorkflowDBChecker(cylc_run_dir, workflow)
Expand Down Expand Up @@ -128,7 +128,7 @@ def validate(args: Dict[str, Any]):
The rules for are:
* output/status: one at most (defaults to succeeded status)
* flow_num: Must be an integer
* flow_num: Must be a positive integer
* status: Must be a valid status
"""
Expand All @@ -146,7 +146,13 @@ def validate(args: Dict[str, Any]):
f"Invalid tasks status '{status}'"
)

if flow_num is not None and not isinstance(flow_num, int):
if (
flow_num is not None
and (
not isinstance(flow_num, int)
or flow_num < 0
)
):
raise WorkflowConfigError(
"flow_num must be an integer"
"flow_num must be a positive integer"
)
32 changes: 29 additions & 3 deletions tests/unit/xtriggers/test_workflow_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,9 @@
from unittest.mock import Mock
from shutil import copytree, rmtree

from cylc.flow.exceptions import InputError
from cylc.flow.pathutil import get_cylc_run_dir
from cylc.flow.exceptions import InputError, WorkflowConfigError
from cylc.flow.workflow_files import WorkflowFiles
from cylc.flow.xtriggers.workflow_state import workflow_state
from cylc.flow.xtriggers.workflow_state import workflow_state, validate
from ..conftest import MonkeyMock


Expand Down Expand Up @@ -116,3 +115,30 @@ def test_back_compat(tmp_run_dir, caplog):
assert satisfied
satisfied, _ = suite_state(suite=id_, task='arkenstone', point='2012')
assert not satisfied


@pytest.mark.parametrize(
'args',
(
('foo', None, 1),
(None, 'failed', 42),
)
)
def test_validate_ok(args):
args = dict(zip(['output', 'status', 'flow_num'], args))
validate(args)


@pytest.mark.parametrize(
'args, error_re',
(
(('foo', 'failed', 2), r'not both$'),
((None, 'fried', 3), r"tasks status 'fried'$"),
((None, 'failed', 2.030481542), r'positive integer'),
((None, 'failed', -1), r'positive integer'),
)
)
def test_validate_fail(args, error_re):
args = dict(zip(['output', 'status', 'flow_num'], args))
with pytest.raises(WorkflowConfigError, match=error_re):
validate(args)

0 comments on commit 551f0bd

Please sign in to comment.