From 0837423165e83b525a83e1fcbef84f5f8f800f67 Mon Sep 17 00:00:00 2001 From: Jesse Schwartzentruber Date: Thu, 30 May 2024 11:07:43 -0400 Subject: [PATCH 1/2] Use dataclass instead of namedtuple for reduction status. This fixes a few mypy errors. --- grizzly/common/status.py | 17 ++++---- grizzly/common/status_reporter.py | 64 ++++++++++++++++--------------- grizzly/common/test_status.py | 9 +++-- 3 files changed, 47 insertions(+), 43 deletions(-) diff --git a/grizzly/common/status.py b/grizzly/common/status.py index 28db5268..c8e3de2b 100644 --- a/grizzly/common/status.py +++ b/grizzly/common/status.py @@ -6,7 +6,7 @@ from collections import defaultdict from contextlib import closing, contextmanager from copy import deepcopy -from dataclasses import dataclass +from dataclasses import astuple, dataclass from json import dumps, loads from logging import getLogger from os import getpid @@ -19,7 +19,6 @@ Dict, Generator, List, - NamedTuple, Optional, Set, Tuple, @@ -826,7 +825,8 @@ def start( return status -class ReductionStep(NamedTuple): +@dataclass(frozen=True) +class ReductionStep: name: str duration: Optional[float] successes: Optional[int] @@ -835,7 +835,8 @@ class ReductionStep(NamedTuple): iterations: Optional[int] -class _MilestoneTimer(NamedTuple): +@dataclass(frozen=True) +class _MilestoneTimer: name: str start: float attempts: int @@ -980,8 +981,8 @@ def report(self, force: bool = False, report_rate: float = REPORT_RATE) -> bool: analysis = dumps(self.analysis) run_params = dumps(self.run_params) sig_info = dumps(self.signature_info) - finished = dumps(self.finished_steps) - in_prog = dumps(self._in_progress_steps) + finished = dumps([astuple(step) for step in self.finished_steps]) + in_prog = dumps([astuple(step) for step in self._in_progress_steps]) strategies = dumps(self.strategies) last_reports = dumps(self.last_reports) @@ -1126,9 +1127,7 @@ def load_all( status.run_params = loads(entry[4]) status.signature_info = loads(entry[5]) status.successes = entry[6] - status.finished_steps = [ - ReductionStep._make(step) for step in loads(entry[8]) - ] + status.finished_steps = [ReductionStep(*step) for step in loads(entry[8])] status._in_progress_steps = [ _MilestoneTimer(*step) for step in loads(entry[9]) ] diff --git a/grizzly/common/status_reporter.py b/grizzly/common/status_reporter.py index 6fe8d484..c40a4c32 100644 --- a/grizzly/common/status_reporter.py +++ b/grizzly/common/status_reporter.py @@ -5,6 +5,7 @@ """Manage Grizzly status reports.""" from argparse import ArgumentParser from collections import defaultdict +from dataclasses import astuple, fields from datetime import timedelta from functools import partial from itertools import zip_longest @@ -14,9 +15,8 @@ from pathlib import Path from platform import system from re import match -from re import sub as re_sub from time import gmtime, localtime, strftime -from typing import Callable, Dict, Generator, List, Optional, Set, Tuple, Type +from typing import Callable, Dict, Generator, Iterable, List, Optional, Set, Tuple, Type from psutil import cpu_count, cpu_percent, disk_usage, getloadavg, virtual_memory @@ -529,34 +529,34 @@ class _TableFormatter: def __init__( self, - columns: Tuple[str, ...], - formatters: Tuple[Optional[Callable[..., str]]], + column_names: Tuple[str, ...], + formatters: Tuple[Optional[Callable[..., str]], ...], vsep: str = " | ", hsep: str = "-", ) -> None: """Initialize a TableFormatter instance. Arguments: - columns: List of column names for the table header. + column_names: List of column names for the table header. formatters: List of format functions for each column. None will result in hiding that column. vsep: Vertical separation between columns. hsep: Horizontal separation between header and data. """ - assert len(columns) == len(formatters) + assert len(column_names) == len(formatters) self._columns = tuple( - column for (column, fmt) in zip(columns, formatters) if fmt is not None + column for (column, fmt) in zip(column_names, formatters) if fmt is not None ) self._formatters = formatters self._vsep = vsep self._hsep = hsep - def format_rows(self, rows: List[ReductionStep]) -> Generator[str, None, None]: + def format_rows(self, rows: Iterable[ReductionStep]) -> Generator[str, None, None]: """Format rows as a table and return a line generator. Arguments: rows: Tabular data. Each row must be the same length as - `columns` passed to `__init__`. + `column_names` passed to `__init__`. Yields: Each line of formatted tabular data. @@ -564,16 +564,17 @@ def format_rows(self, rows: List[ReductionStep]) -> Generator[str, None, None]: max_width = [len(col) for col in self._columns] formatted: List[List[str]] = [] for row in rows: - assert len(row) == len(self._formatters) + data = astuple(row) + assert len(data) == len(self._formatters) formatted.append([]) offset = 0 - for idx, (data, formatter) in enumerate(zip(row, self._formatters)): + for idx, (datum, formatter) in enumerate(zip(data, self._formatters)): if formatter is None: offset += 1 continue - data = formatter(data) - max_width[idx - offset] = max(max_width[idx - offset], len(data)) - formatted[-1].append(data) + datum_str = formatter(datum) + max_width[idx - offset] = max(max_width[idx - offset], len(datum_str)) + formatted[-1].append(datum_str) # build a format_str to space out the columns with separators using `max_width` # the first column is left-aligned, and other fields are right-aligned. @@ -588,17 +589,15 @@ def format_rows(self, rows: List[ReductionStep]) -> Generator[str, None, None]: def _format_seconds(duration: float) -> str: - # format H:M:S, and then remove all leading zeros with regex + # format H:M:S, without leading zeros minutes, seconds = divmod(int(duration), 60) hours, minutes = divmod(minutes, 60) - result = re_sub("^[0:]*", "", f"{hours}:{minutes:02d}:{seconds:02d}") - # if the result is all zeroes, ensure one zero is output - if not result: - result = "0" + if hours: + return f"{hours}:{minutes:02d}:{seconds:02d}" + if minutes: + return f"{minutes}:{seconds:02d}" # a bare number is ambiguous. output 's' for seconds - if ":" not in result: - result += "s" - return result + return f"{seconds}s" def _format_duration(duration: Optional[int], total: float = 0) -> str: @@ -823,15 +822,18 @@ def summary( # pylint: disable=arguments-differ entries.append(self._last_reports_entry(report)) if report.total and report.original: tabulator = _TableFormatter( - ReductionStep._fields, - ReductionStep( - name=str, - # duration and attempts are % of total/last, size % of init/1st - duration=partial(_format_duration, total=report.total.duration), - attempts=partial(_format_number, total=report.total.attempts), - successes=partial(_format_number, total=report.total.successes), - iterations=None, # hide - size=partial(_format_number, total=report.original.size), + tuple(f.name for f in fields(ReductionStep)), + # this tuple must match the order of fields + # defined on ReductionStep! + ( + str, # name + # duration/successes/attempts are % of total/last + partial(_format_duration, total=report.total.duration), + partial(_format_number, total=report.total.successes), + partial(_format_number, total=report.total.attempts), + # size is % of init/1st + partial(_format_number, total=report.original.size), + None, # iterations (hidden) ), ) lines.extend(tabulator.format_rows(report.finished_steps)) diff --git a/grizzly/common/test_status.py b/grizzly/common/test_status.py index 5edfdb2a..7c5abe57 100644 --- a/grizzly/common/test_status.py +++ b/grizzly/common/test_status.py @@ -5,6 +5,7 @@ # pylint: disable=protected-access from contextlib import closing +from dataclasses import fields from itertools import count from multiprocessing import Event, Process from sqlite3 import connect @@ -523,10 +524,12 @@ def test_reduce_status_06(mocker, tmp_path): assert len(loaded_status.finished_steps) == 2 assert len(loaded_status._in_progress_steps) == 0 assert loaded_status.original == status.original - for field in ReductionStep._fields: - if field == "size": + for field in fields(ReductionStep): + if field.name == "size": continue - assert getattr(loaded_status.total, field) == getattr(status.total, field) + assert getattr(loaded_status.total, field.name) == getattr( + status.total, field.name + ) assert loaded_status.total.size is None From f5b65be8a33831076b4a021219f1fe91d34e75e2 Mon Sep 17 00:00:00 2001 From: Jesse Schwartzentruber Date: Thu, 30 May 2024 11:39:34 -0400 Subject: [PATCH 2/2] Add specific tests for reducer status format helpers --- grizzly/common/status_reporter.py | 2 +- grizzly/common/test_status_reporter.py | 46 ++++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 1 deletion(-) diff --git a/grizzly/common/status_reporter.py b/grizzly/common/status_reporter.py index c40a4c32..9b12695a 100644 --- a/grizzly/common/status_reporter.py +++ b/grizzly/common/status_reporter.py @@ -604,7 +604,7 @@ def _format_duration(duration: Optional[int], total: float = 0) -> str: result = "" if duration is not None: if total == 0: - percent = 0 # pragma: no cover + percent = 0 else: percent = int(100 * duration / total) result = _format_seconds(duration) diff --git a/grizzly/common/test_status_reporter.py b/grizzly/common/test_status_reporter.py index f6a850e0..df56bd4b 100644 --- a/grizzly/common/test_status_reporter.py +++ b/grizzly/common/test_status_reporter.py @@ -17,6 +17,9 @@ ReductionStatusReporter, StatusReporter, TracebackReport, + _format_duration, + _format_number, + _format_seconds, main, ) @@ -773,3 +776,46 @@ def test_main_04(mocker, tmp_path, report_type): assert b"Runtime" not in dump_file.read_bytes() else: assert b"Timestamp" not in dump_file.read_bytes() + + +@mark.parametrize( + "value, expected", + [ + (0, "0s"), + (100, "1:40"), + (3600, "1:00:00"), + ], +) +def test_format_seconds(value, expected): + """test _format_seconds used by TableFormatter""" + assert _format_seconds(value) == expected + + +@mark.parametrize( + "value, total, expected", + [ + (None, 0, ""), + (0, 0, "0s ( 0%)"), + (100, 0, "1:40 ( 0%)"), + (100, 200, "1:40 ( 50%)"), + (3600, 3600, "1:00:00 (100%)"), + ], +) +def test_format_duration(value, total, expected): + """test _format_duration used by TableFormatter""" + assert _format_duration(value, total) == expected + + +@mark.parametrize( + "value, total, expected", + [ + (None, 0, ""), + (0, 0, "0 ( 0%)"), + (100, 0, "100 ( 0%)"), + (100, 200, "100 ( 50%)"), + (3600, 3600, "3600 (100%)"), + ], +) +def test_format_number(value, total, expected): + """test _format_number used by TableFormatter""" + assert _format_number(value, total) == expected