-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathheuristic_worker.py
152 lines (129 loc) · 5.25 KB
/
heuristic_worker.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import argparse
import csv
import functools
import logging
import os
import pathlib
import tqdm
import tqdm.contrib.concurrent
import tqdm.contrib.logging
import heuristics
from dataset.utils import combat_dir_iterator, dataset_checksum, get_combat_dirs
# ===== argparsing =====
def _build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for this worker.

    The automatic help flag is disabled so that ``-h`` can be used as the short
    form of ``--heuristic``; help is instead exposed through ``--help`` only.
    """
    cli = argparse.ArgumentParser(description="Applies defined heuristics to a dataset.", add_help=False)
    cli.add_argument(
        "-d",
        "--data-dir",
        type=pathlib.Path,
        default="data/",
        help="the directory containing the raw data (default: data/)",
    )
    cli.add_argument(
        "-o",
        "--output-dir",
        type=pathlib.Path,
        default="heuristic_results/",
        help="the directory to save the heuristic results to (default: heuristic_results/)",
    )
    # repeatable flag: each occurrence appends one heuristic name; absent -> None
    cli.add_argument(
        "-h",
        "--heuristic",
        action="append",
        help="the heuristic(s) to run (defaults to all)",
    )
    cli.add_argument(
        "--force-recompute",
        action="store_true",
        help="forces the worker to recompute regardless of prior computation",
    )
    cli.add_argument("--help", action="help", help="displays CLI help")
    return cli


parser = _build_parser()
# ===== main =====
log = logging.getLogger(name=__name__)
def get_heuristic(name: str) -> heuristics.Heuristic:
    """Resolve a heuristic from the ``heuristics`` module by attribute name (CLI helper).

    Raises:
        AttributeError: if the ``heuristics`` module has no attribute ``name``.
    """
    found = getattr(heuristics, name)
    return found
def worker_entrypoint(heuristic: heuristics.Heuristic, combat_dir: str) -> tuple[str, int | float]:
    """Evaluate ``heuristic`` on a single combat directory (multiprocessing worker entrypoint).

    Args:
        heuristic: callable applied to the iterator produced by ``combat_dir_iterator``.
        combat_dir: path of the combat directory to evaluate.

    Returns:
        A ``(directory basename, heuristic result)`` pair.
    """
    dir_name = os.path.basename(combat_dir)
    score = heuristic(combat_dir_iterator(combat_dir))
    return dir_name, score
class Runner:
def __init__(
self,
data_dir_path: pathlib.Path,
result_dir_path: pathlib.Path,
compute_heuristics: list[str] | None = None,
force_recompute: bool = False,
):
self.data_dir_path = data_dir_path
self.result_dir_path = result_dir_path
self.heuristics = compute_heuristics
self.force_recompute = force_recompute
self.dataset_checksum = None
@classmethod
def from_args(cls, args: argparse.Namespace):
return cls(
data_dir_path=args.data_dir,
result_dir_path=args.output_dir,
compute_heuristics=args.heuristic,
force_recompute=args.force_recompute,
)
def init(self):
num_cores = os.cpu_count() or 1
log.info(f"Hashing dataset (with parallelization={num_cores})...")
self.dataset_checksum = dataset_checksum(self.data_dir_path)
log.info(f"checksum={self.dataset_checksum}")
os.makedirs(self.result_dir_path, exist_ok=True)
def run_one(self, heuristic_name: str):
log.info(f"Applying heuristic {heuristic_name!r}...")
result_file_path = self.result_dir_path / f"{heuristic_name}.csv"
heuristic = get_heuristic(heuristic_name)
entrypoint = functools.partial(worker_entrypoint, heuristic)
# if the results already exist for this dataset and heuristic, we can skip everything
try:
with open(result_file_path, newline="") as f:
reader = csv.reader(f)
_, existing_checksum = next(reader)
if self.force_recompute:
log.info(
f"A result for this dataset already exists at {os.path.relpath(result_file_path)} but recompute is"
" forced, overwriting..."
)
elif existing_checksum == self.dataset_checksum:
log.info(f"A result for this dataset already exists at {os.path.relpath(result_file_path)}!")
return
else:
log.info("An existing result was found but the checksum does not match, overwriting...")
except FileNotFoundError:
pass
# execution
results = tqdm.contrib.concurrent.process_map(entrypoint, get_combat_dirs(self.data_dir_path), chunksize=10)
results.sort(key=lambda pair: pair[1])
log.info(f"Application of {heuristic_name} complete, saving results...")
# save results
with open(result_file_path, "w", newline="") as f:
writer = csv.writer(f)
writer.writerow(("checksum", self.dataset_checksum))
writer.writerows(results)
def run_heuristics(self, heuristic_names: list[str]):
if not all(hasattr(heuristics, name) for name in heuristic_names):
raise RuntimeError(
f"Heuristic(s) were passed but not defined: {set(heuristic_names).difference(heuristics.__all__)}"
)
with tqdm.contrib.logging.logging_redirect_tqdm():
for heuristic_name in tqdm.tqdm(heuristic_names):
self.run_one(heuristic_name)
def run_all(self):
self.run_heuristics(heuristics.__all__)
def run_cli(self):
self.init()
if self.heuristics is None:
self.run_all()
elif not self.heuristics:
raise RuntimeError(
"At least one heuristic should be passed, or the argument should be omitted to run all heuristics."
)
else:
self.run_heuristics(self.heuristics)
log.info("Done!")
if __name__ == "__main__":
    # configure root logging first so the Runner's INFO-level progress messages are printed
    logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
    cli_args = parser.parse_args()
    runner = Runner.from_args(cli_args)
    runner.run_cli()