Skip to content

Commit

Permalink
Replace record types magic strings with an enum
Browse files Browse the repository at this point in the history
  • Loading branch information
jwodder committed Jun 11, 2024
1 parent 9051a1b commit ddd348b
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 14 deletions.
25 changes: 20 additions & 5 deletions src/duct.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,21 @@ def has_stderr(self) -> bool:
return self is Outputs.ALL or self is Outputs.STDERR


class RecordTypes(str, Enum):
ALL = "all"
SYSTEM_SUMMARY = "system-summary"
PROCESSES_SAMPLES = "processes-samples"

def __str__(self) -> str:
return self.value

Check warning on line 56 in src/duct.py

View check run for this annotation

Codecov / codecov/patch

src/duct.py#L56

Added line #L56 was not covered by tests

def has_system_summary(self) -> bool:
return self is RecordTypes.ALL or self is RecordTypes.SYSTEM_SUMMARY

def has_processes_samples(self) -> bool:
return self is RecordTypes.ALL or self is RecordTypes.PROCESSES_SAMPLES


@dataclass
class SystemInfo:
uid: str | None
Expand Down Expand Up @@ -258,7 +273,7 @@ class Arguments:
report_interval: float
capture_outputs: Outputs
outputs: Outputs
record_types: str
record_types: RecordTypes

@classmethod
def from_argv(cls) -> Arguments:
Expand Down Expand Up @@ -326,9 +341,9 @@ def from_argv(cls) -> Arguments:
parser.add_argument(
"-t",
"--record-types",
type=str,
default="all",
choices=["all", "system-summary", "processes-samples"],
choices=list(RecordTypes),
type=RecordTypes,
help="Record system-summary, processes-samples, or all",
)
args = parser.parse_args()
Expand Down Expand Up @@ -517,7 +532,7 @@ def execute(args: Arguments) -> None:
datetime_filesafe,
)
stop_event = threading.Event()
if args.record_types in ["all", "processes-samples"]:
if args.record_types.has_processes_samples():
monitoring_args = [
report,
process,
Expand All @@ -532,7 +547,7 @@ def execute(args: Arguments) -> None:
else:
monitoring_thread = None

if args.record_types in ["all", "system-summary"]:
if args.record_types.has_system_summary():
report.collect_environment()
report.get_system_info()
system_info_path = f"{args.output_prefix}info.json".format(
Expand Down
18 changes: 9 additions & 9 deletions test/test_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from unittest import mock
import pytest
from utils import assert_files
from duct import Arguments, Outputs, execute
from duct import Arguments, Outputs, RecordTypes, execute

TEST_SCRIPT = str(Path(__file__).with_name("data") / "test_script.py")

Expand All @@ -25,7 +25,7 @@ def test_sanity_green(temp_output_dir: str) -> None:
report_interval=60.0,
capture_outputs=Outputs.ALL,
outputs=Outputs.ALL,
record_types="all",
record_types=RecordTypes.ALL,
)
execute(args)
# When runtime < sample_interval, we won't have a usage.json
Expand All @@ -42,7 +42,7 @@ def test_sanity_red(temp_output_dir: str) -> None:
report_interval=60.0,
capture_outputs=Outputs.ALL,
outputs=Outputs.ALL,
record_types="all",
record_types=RecordTypes.ALL,
)
with mock.patch("sys.stdout", new_callable=mock.MagicMock) as mock_stdout:
execute(args)
Expand All @@ -65,7 +65,7 @@ def test_outputs_full(temp_output_dir: str) -> None:
report_interval=0.1,
capture_outputs=Outputs.ALL,
outputs=Outputs.ALL,
record_types="all",
record_types=RecordTypes.ALL,
)
execute(args)
expected_files = ["stdout", "stderr", "info.json", "usage.json"]
Expand All @@ -81,7 +81,7 @@ def test_outputs_passthrough(temp_output_dir: str) -> None:
report_interval=0.1,
capture_outputs=Outputs.NONE,
outputs=Outputs.ALL,
record_types="all",
record_types=RecordTypes.ALL,
)
execute(args)
expected_files = ["info.json", "usage.json"]
Expand All @@ -99,7 +99,7 @@ def test_outputs_capture(temp_output_dir: str) -> None:
report_interval=0.1,
capture_outputs=Outputs.ALL,
outputs=Outputs.NONE,
record_types="all",
record_types=RecordTypes.ALL,
)
execute(args)
# TODO make this work assert mock.call("this is of test of STDOUT\n") not in mock_stdout.write.mock_calls
Expand All @@ -117,7 +117,7 @@ def test_outputs_none(temp_output_dir: str) -> None:
report_interval=0.1,
capture_outputs=Outputs.NONE,
outputs=Outputs.NONE,
record_types="all",
record_types=RecordTypes.ALL,
)
execute(args)
# assert mock.call("this is of test of STDOUT\n") not in mock_stdout.write.mock_calls
Expand All @@ -138,7 +138,7 @@ def test_exit_before_first_sample(temp_output_dir: str) -> None:
report_interval=0.1,
capture_outputs=Outputs.ALL,
outputs=Outputs.NONE,
record_types="all",
record_types=RecordTypes.ALL,
)
execute(args)
expected_files = ["stdout", "stderr", "info.json"]
Expand All @@ -156,7 +156,7 @@ def test_run_less_than_report_interval(temp_output_dir: str) -> None:
report_interval=0.1,
capture_outputs=Outputs.ALL,
outputs=Outputs.NONE,
record_types="all",
record_types=RecordTypes.ALL,
)
execute(args)
# Specifically we need to assert that usage.json gets written anyway.
Expand Down

0 comments on commit ddd348b

Please sign in to comment.