diff --git a/cylc/flow/task_pool.py b/cylc/flow/task_pool.py index c0a4e89a0c2..a10adda7d98 100644 --- a/cylc/flow/task_pool.py +++ b/cylc/flow/task_pool.py @@ -88,6 +88,7 @@ from cylc.flow.task_events_mgr import TaskEventsManager from cylc.flow.workflow_db_mgr import WorkflowDatabaseManager from cylc.flow.flow_mgr import FlowMgr, FlowNums + from typing_extensions import Literal Pool = Dict['PointBase', Dict[str, TaskProxy]] @@ -1719,10 +1720,14 @@ def set_prereqs_and_outputs( # Illegal flow command opts return - _prereqs: List[Tokens] = [ - Tokens(prereq, relative=True) - for prereq in (prereqs or []) - ] + _prereqs: 'Union[List[Tokens], Literal["all"]]' + if prereqs == ['all']: + _prereqs = 'all' + else: + _prereqs = [ + Tokens(prereq, relative=True) + for prereq in (prereqs or []) + ] # Get matching pool tasks and future task definitions. itasks, future_tasks, unmatched = self.filter_task_proxies( @@ -1786,7 +1791,7 @@ def _set_outputs_itask( def _set_prereqs_itask( self, itask: 'TaskProxy', - prereqs: List[Tokens], + prereqs: 'Union[List[Tokens], Literal["all"]]', flow_nums: Set[int], flow_wait: bool ) -> None: @@ -1795,7 +1800,7 @@ def _set_prereqs_itask( Prerequisite format: "cycle/task:message" or "all". """ - if prereqs == ["all"]: + if prereqs == "all": itask.state.set_all_satisfied() else: itask.satisfy_me(prereqs) diff --git a/tests/integration/scripts/test_set.py b/tests/integration/scripts/test_set.py index e364384a09a..17c1555d6b7 100644 --- a/tests/integration/scripts/test_set.py +++ b/tests/integration/scripts/test_set.py @@ -149,3 +149,15 @@ async def test_incomplete_detection( async with start(schd) as log: schd.pool.set_prereqs_and_outputs(['1/one'], ['failed'], None, ['1']) assert log_filter(log, contains='1/one did not complete') + + +async def test_pre_all(flow, scheduler, run): + """Ensure that --pre=all is interpreted as a special case + and _not_ tokenized. + """ + id_ = flow({'scheduling': {'graph': {'R1': 'a => z'}}}) + schd = scheduler(id_, paused_start=False) + async with run(schd) as log: + schd.pool.set_prereqs_and_outputs(['1/z'], [], ['all'], ['all']) + warn_or_higher = [i for i in log.records if i.levelno > 20] + assert warn_or_higher == [] diff --git a/tests/integration/test_task_pool.py b/tests/integration/test_task_pool.py index c2036d8882d..facfe1f5c5a 100644 --- a/tests/integration/test_task_pool.py +++ b/tests/integration/test_task_pool.py @@ -1393,8 +1393,8 @@ async def test_set_outputs_live( 'runtime': { 'foo': { 'outputs': { - 'x': 'x', - 'y': 'y' + 'x': 'xylophone', + 'y': 'yacht' } } } @@ -1502,8 +1502,8 @@ async def test_prereq_satisfaction( 'runtime': { 'a': { 'outputs': { - 'x': 'x', - 'y': 'y' + 'x': 'xylophone', + 'y': 'yacht' } } }