Skip to content

Commit

Permalink
Add type annotations
Browse files (browse the repository at this point in the history)
  • Loading branch information
jwodder committed Jun 11, 2024
1 parent 1c0bfee commit de5929d
Show file tree
Hide file tree
Showing 11 changed files with 152 additions and 118 deletions.
17 changes: 16 additions & 1 deletion .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,14 @@ jobs:
- 'pypy-3.8'
- 'pypy-3.9'
- 'pypy-3.10'
toxenv: [py]
include:
- python-version: '3.8'
toxenv: lint
os: ubuntu-latest
- python-version: '3.8'
toxenv: typing
os: ubuntu-latest
steps:
- name: Check out repository
uses: actions/checkout@v4
Expand All @@ -53,13 +61,20 @@ jobs:
python -m pip install --upgrade pip wheel
python -m pip install --upgrade --upgrade-strategy=eager tox
- name: Run tests
- name: Run tests with coverage
if: matrix.toxenv == 'py'
run: tox -e py -- -vv --cov-report=xml

- name: Run generic tests
if: matrix.toxenv != 'py'
run: tox -e ${{ matrix.toxenv }}

- name: Upload coverage to Codecov
if: matrix.toxenv == 'py'
uses: codecov/codecov-action@v4
with:
fail_ci_if_error: false
token: ${{ secrets.CODECOV_TOKEN }}
name: ${{ matrix.python-version }}

# vim:set et sts=2:
31 changes: 15 additions & 16 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -56,19 +56,18 @@ python_requires = >= 3.8
console_scripts =
duct = duct:main

# TODO(asmacdo)
# [mypy]
# ignore_missing_imports = False
# disallow_untyped_defs = True
# disallow_incomplete_defs = True
# no_implicit_optional = True
# warn_redundant_casts = True
# warn_return_any = True
# warn_unreachable = True
# local_partial_types = True
# no_implicit_reexport = True
# strict_equality = True
# show_error_codes = True
# show_traceback = True
# pretty = True
#
[mypy]
allow_incomplete_defs = False
allow_untyped_defs = False
ignore_missing_imports = False
# <https://github.com/python/mypy/issues/7773>:
no_implicit_optional = True
implicit_reexport = False
local_partial_types = True
pretty = True
show_error_codes = True
show_traceback = True
strict_equality = True
warn_redundant_casts = True
warn_return_any = True
warn_unreachable = True
111 changes: 61 additions & 50 deletions src/duct.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from __future__ import annotations
import argparse
from collections import defaultdict
from collections.abc import Iterable
from dataclasses import dataclass
from datetime import datetime
import json
Expand All @@ -12,7 +13,7 @@
import sys
import threading
import time
from typing import Any, Dict, Optional, TextIO, Tuple, Union
from typing import IO, Any, TextIO

__version__ = "0.0.1"
ENV_PREFIXES = ("PBS_", "SLURM_", "OSG")
Expand All @@ -33,50 +34,43 @@ class Colors:
class Report:
"""Top level report"""

start_time: float
command: str
session_id: int
gpus: Optional[list]
number: int
system_info: Dict[str, Any] # Use more specific types if possible

def __init__(
self,
command: str,
arguments,
session_id: int,
arguments: list[str],
session_id: int | None,
output_prefix: str,
process,
datetime_filesafe,
process: subprocess.Popen,
datetime_filesafe: str,
) -> None:
self.start_time = time.time()
self._command = command
self.arguments = arguments
self.session_id = session_id
self.gpus = []
self.env = None
self.gpus: list | None = []
self.env: dict[str, str] | None = None
self.number = 0
self.system_info = {}
self.system_info: dict[str, Any] = {} # Use more specific types if possible
self.output_prefix = output_prefix
self.max_values = defaultdict(dict)
self.max_values: dict[str, dict[str, Any]] = defaultdict(dict)
self.process = process
self._sample = defaultdict(dict)
self._sample: dict[str, dict[str, Any]] = defaultdict(dict)
self.datetime_filesafe = datetime_filesafe
self.end_time: float | None = None
self.run_time_seconds: str | None = None

@property
def command(self):
def command(self) -> str:
return " ".join([self._command] + self.arguments)

@property
def elapsed_time(self):
def elapsed_time(self) -> float:
return time.time() - self.start_time

def collect_environment(self):
self.env = (
{k: v for k, v in os.environ.items() if k.startswith(ENV_PREFIXES)},
)
def collect_environment(self) -> None:
self.env = {k: v for k, v in os.environ.items() if k.startswith(ENV_PREFIXES)}

def get_system_info(self):
def get_system_info(self) -> None:
"""Gathers system information related to CPU, GPU, memory, and environment variables."""
self.system_info["uid"] = os.environ.get("USER")
self.system_info["memory_total"] = os.sysconf("SC_PAGE_SIZE") * os.sysconf(
Expand Down Expand Up @@ -106,7 +100,9 @@ def get_system_info(self):
except subprocess.CalledProcessError:
self.gpus = ["Failed to query GPU info"]

def calculate_total_usage(self, sample):
def calculate_total_usage(
self, sample: dict[str, dict[str, Any]]
) -> dict[str, dict[str, float]]:
pmem = 0.0
pcpu = 0.0
for _pid, pinfo in sample.items():
Expand All @@ -116,16 +112,19 @@ def calculate_total_usage(self, sample):
return totals

@staticmethod
def update_max_resources(maxes, sample):
def update_max_resources(
maxes: dict[str, dict[str, Any]], sample: dict[str, Any]
) -> None:
for pid in sample:
if pid in maxes:
for key, value in sample[pid].items():
maxes[pid][key] = max(maxes[pid].get(key, value), value)
else:
maxes[pid] = sample[pid].copy()

def collect_sample(self):
process_data = {}
def collect_sample(self) -> dict[str, dict[str, int | float | str]]:
assert self.session_id is not None
process_data: dict[str, dict[str, int | float | str]] = {}
try:
output = subprocess.check_output(
[
Expand All @@ -140,7 +139,6 @@ def collect_sample(self):
for line in output.splitlines()[1:]:
if line:
pid, pcpu, pmem, rss, vsz, etime, cmd = line.split(maxsplit=6)

process_data[pid] = {
# %CPU
"pcpu": float(pcpu),
Expand All @@ -156,16 +154,16 @@ def collect_sample(self):
pass
return process_data

def write_pid_samples(self):
def write_pid_samples(self) -> None:
resource_stats_log_path = f"{self.output_prefix}usage.json"
with open(resource_stats_log_path, "a") as resource_statistics_log:
resource_statistics_log.write(json.dumps(self._sample) + "\n")

def print_max_values(self):
def print_max_values(self) -> None:
for pid, maxes in self.max_values.items():
print(f"PID {pid} Maximum Values: {maxes}")

def finalize(self):
def finalize(self) -> None:
if not self.process.returncode:
print(Colors.OKGREEN)
else:
Expand All @@ -181,7 +179,7 @@ def finalize(self):
f"CPU Peak Usage: {self.max_values.get('totals', {}).get('pcpu', 'unknown')}%"
)

def __repr__(self):
def __repr__(self) -> str:
return json.dumps(
{
"command": self.command,
Expand Down Expand Up @@ -288,7 +286,13 @@ def from_argv(cls) -> Arguments:
)


def monitor_process(report, process, report_interval, sample_interval, stop_event):
def monitor_process(
report: Report,
process: subprocess.Popen,
report_interval: float,
sample_interval: float,
stop_event: threading.Event,
) -> None:
while not stop_event.wait(timeout=sample_interval):
while True:
if process.poll() is not None: # the passthrough command has finished
Expand All @@ -310,53 +314,58 @@ class TailPipe:

TAIL_CYCLE_TIME = 0.01

def __init__(self, file_path, buffer):
def __init__(self, file_path: str, buffer: IO[bytes]) -> None:
self.file_path = file_path
self.buffer = buffer
self.stop_event = None
self.infile = None
self.thread = None
self.stop_event: threading.Event | None = None
self.infile: IO[bytes] | None = None
self.thread: threading.Thread | None = None

def start(self):
def start(self) -> None:
Path(self.file_path).touch()
self.stop_event = threading.Event()
self.infile = open(self.file_path, "rb")
self.thread = threading.Thread(target=self._tail, daemon=True)
self.thread.start()

def fileno(self):
def fileno(self) -> int:
assert self.infile is not None
return self.infile.fileno()

def _catch_up(self):
def _catch_up(self) -> None:
assert self.infile is not None
data = self.infile.read()
if data:
self.buffer.write(data)
self.buffer.flush()

def _tail(self):
def _tail(self) -> None:
assert self.stop_event is not None
try:
while not self.stop_event.is_set():
self._catch_up()
time.sleep(TailPipe.TAIL_CYCLE_TIME)

# After the stop event is set, collect and pass through any remaining data one last time
self._catch_up()
except Exception:
raise
finally:
self.buffer.flush()

def close(self):
def close(self) -> None:
assert self.stop_event is not None
assert self.thread is not None
assert self.infile is not None
self.stop_event.set()
self.thread.join()
self.infile.close()


def prepare_outputs(
capture_outputs: str, outputs: str, output_prefix: str
) -> Tuple[Union[TextIO, TailPipe, int], Union[TextIO, TailPipe, int]]:
stdout: Union[TextIO, TailPipe, int]
stderr: Union[TextIO, TailPipe, int]
) -> tuple[TextIO | TailPipe | int | None, TextIO | TailPipe | int | None]:
stdout: TextIO | TailPipe | int | None
stderr: TextIO | TailPipe | int | None

if capture_outputs in ["all", "stdout"] and outputs in ["all", "stdout"]:
stdout = TailPipe(f"{output_prefix}stdout", buffer=sys.stdout.buffer)
Expand All @@ -380,7 +389,7 @@ def prepare_outputs(
return stdout, stderr


def safe_close_files(file_list):
def safe_close_files(file_list: Iterable[Any]) -> None:
for f in file_list:
try:
f.close()
Expand All @@ -398,12 +407,12 @@ def ensure_directories(path: str) -> None:
os.makedirs(directory, exist_ok=True)


def main():
def main() -> None:
args = Arguments.from_argv()
execute(args)


def execute(args):
def execute(args: Arguments) -> None:
"""A wrapper to execute a command, monitor and log the process details."""
datetime_filesafe = datetime.now().strftime("%Y.%m.%dT%H.%M.%S")
duct_pid = os.getpid()
Expand All @@ -414,10 +423,12 @@ def execute(args):
stdout, stderr = prepare_outputs(
args.capture_outputs, args.outputs, formatted_output_prefix
)
stdout_file: TextIO | IO[bytes] | int | None
if isinstance(stdout, TailPipe):
stdout_file = open(stdout.file_path, "wb")
else:
stdout_file = stdout
stderr_file: TextIO | IO[bytes] | int | None
if isinstance(stderr, TailPipe):
stderr_file = open(stderr.file_path, "wb")
else:
Expand Down
4 changes: 3 additions & 1 deletion test/data/cat_to_err.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
#!/usr/bin/env python3
from __future__ import annotations
import argparse
import sys
from typing import IO


def cat_to_stream(path, buffer):
def cat_to_stream(path: str, buffer: IO[bytes]) -> None:
with open(path, "rb") as infile:
buffer.write(infile.read())

Expand Down
Loading

0 comments on commit de5929d

Please sign in to comment.