From 83beb34640e65e88f32bfb417a93da19dbe4ee29 Mon Sep 17 00:00:00 2001 From: "John T. Wodder II" Date: Tue, 11 Jun 2024 14:28:58 -0400 Subject: [PATCH] Replace record types magic strings with an enum --- src/duct.py | 25 ++++++++++++++++++++----- test/test_execution.py | 18 +++++++++--------- 2 files changed, 29 insertions(+), 14 deletions(-) diff --git a/src/duct.py b/src/duct.py index 7b40ad58..80c28d28 100755 --- a/src/duct.py +++ b/src/duct.py @@ -47,6 +47,21 @@ def has_stderr(self) -> bool: return self is Outputs.ALL or self is Outputs.STDERR +class RecordTypes(str, Enum): + ALL = "all" + SYSTEM_SUMMARY = "system-summary" + PROCESSES_SAMPLES = "processes-samples" + + def __str__(self) -> str: + return self.value + + def has_system_summary(self) -> bool: + return self is RecordTypes.ALL or self is RecordTypes.SYSTEM_SUMMARY + + def has_processes_samples(self) -> bool: + return self is RecordTypes.ALL or self is RecordTypes.PROCESSES_SAMPLES + + @dataclass class SystemInfo: uid: str | None @@ -258,7 +273,7 @@ class Arguments: report_interval: float capture_outputs: Outputs outputs: Outputs - record_types: str + record_types: RecordTypes @classmethod def from_argv(cls) -> Arguments: @@ -326,9 +341,9 @@ def from_argv(cls) -> Arguments: parser.add_argument( "-t", "--record-types", - type=str, default="all", - choices=["all", "system-summary", "processes-samples"], + choices=list(RecordTypes), + type=RecordTypes, help="Record system-summary, processes-samples, or all", ) args = parser.parse_args() @@ -517,7 +532,7 @@ def execute(args: Arguments) -> None: datetime_filesafe, ) stop_event = threading.Event() - if args.record_types in ["all", "processes-samples"]: + if args.record_types.has_processes_samples(): monitoring_args = [ report, process, @@ -532,7 +547,7 @@ def execute(args: Arguments) -> None: else: monitoring_thread = None - if args.record_types in ["all", "system-summary"]: + if args.record_types.has_system_summary(): report.collect_environment() report.get_system_info() system_info_path = f"{args.output_prefix}info.json".format( diff --git a/test/test_execution.py b/test/test_execution.py index 305c0515..7b67d65d 100644 --- a/test/test_execution.py +++ b/test/test_execution.py @@ -4,7 +4,7 @@ from unittest import mock import pytest from utils import assert_files -from duct import Arguments, Outputs, execute +from duct import Arguments, Outputs, RecordTypes, execute TEST_SCRIPT = str(Path(__file__).with_name("data") / "test_script.py") @@ -25,7 +25,7 @@ def test_sanity_green(temp_output_dir: str) -> None: report_interval=60.0, capture_outputs=Outputs.ALL, outputs=Outputs.ALL, - record_types="all", + record_types=RecordTypes.ALL, ) execute(args) # When runtime < sample_interval, we won't have a usage.json @@ -42,7 +42,7 @@ def test_sanity_red(temp_output_dir: str) -> None: report_interval=60.0, capture_outputs=Outputs.ALL, outputs=Outputs.ALL, - record_types="all", + record_types=RecordTypes.ALL, ) with mock.patch("sys.stdout", new_callable=mock.MagicMock) as mock_stdout: execute(args) @@ -65,7 +65,7 @@ def test_outputs_full(temp_output_dir: str) -> None: report_interval=0.1, capture_outputs=Outputs.ALL, outputs=Outputs.ALL, - record_types="all", + record_types=RecordTypes.ALL, ) execute(args) expected_files = ["stdout", "stderr", "info.json", "usage.json"] @@ -81,7 +81,7 @@ def test_outputs_passthrough(temp_output_dir: str) -> None: report_interval=0.1, capture_outputs=Outputs.NONE, outputs=Outputs.ALL, - record_types="all", + record_types=RecordTypes.ALL, ) execute(args) expected_files = ["info.json", "usage.json"] @@ -99,7 +99,7 @@ def test_outputs_capture(temp_output_dir: str) -> None: report_interval=0.1, capture_outputs=Outputs.ALL, outputs=Outputs.NONE, - record_types="all", + record_types=RecordTypes.ALL, ) execute(args) # TODO make this work assert mock.call("this is of test of STDOUT\n") not in mock_stdout.write.mock_calls @@ -117,7 +117,7 @@ def test_outputs_none(temp_output_dir: str) -> None: report_interval=0.1, capture_outputs=Outputs.NONE, outputs=Outputs.NONE, - record_types="all", + record_types=RecordTypes.ALL, ) execute(args) # assert mock.call("this is of test of STDOUT\n") not in mock_stdout.write.mock_calls @@ -138,7 +138,7 @@ def test_exit_before_first_sample(temp_output_dir: str) -> None: report_interval=0.1, capture_outputs=Outputs.ALL, outputs=Outputs.NONE, - record_types="all", + record_types=RecordTypes.ALL, ) execute(args) expected_files = ["stdout", "stderr", "info.json"] @@ -156,7 +156,7 @@ def test_run_less_than_report_interval(temp_output_dir: str) -> None: report_interval=0.1, capture_outputs=Outputs.ALL, outputs=Outputs.NONE, - record_types="all", + record_types=RecordTypes.ALL, ) execute(args) # Specifically we need to assert that usage.json gets written anyway.