-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathretrain.py
133 lines (108 loc) · 4.58 KB
/
retrain.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
"""
Module with functions for the retraining and sampling steps in the pipeline
"""
import os
import copy
import userdefs
from retrain import sampling as sample
from retrain import utils, train
from retrain.dataloader import LabeledSet
import yolov3.utils as yoloutils
from yolov3 import parallelize
import analysis.benchmark as bench
import analysis.results as resloader
def benchmark_sample(sample_method, imgs, config, batch_num, last_epoch):
    """Simulate edge-side benchmarking and sampling for one batch.

    Runs (or reuses, if the CSV already exists) a benchmark of *imgs* for the
    given sampling method, then applies the method's sampling function to the
    benchmark results under the configured bandwidth budget.

    Returns:
        The list of sample files selected by the sampling function.
    """
    method_name, (sample_func, sample_kwargs) = sample_method

    bench_file = (
        f"{config['output']}/{method_name}{batch_num}"
        f"_benchmark_avg_1_{last_epoch}.csv"
    )

    # Only benchmark if no prior run produced this file.
    if not os.path.exists(bench_file):
        results_df = bench.benchmark_avg(
            imgs, method_name, 1, last_epoch, config["conf_check_num"], config,
        )
        bench.save_results(results_df, bench_file)

    # Create samples from the benchmark
    results, _ = resloader.load_data(bench_file, by_actual=False)
    print(f"===== {method_name} ======")
    return sample.create_sample(
        results, config["bandwidth"], sample_func, **sample_kwargs
    )
def sample_retrain(
    sample_method, batches, config, last_epoch, seen_images, label_func, device=None,
):
    """Run the sampling and retraining pipeline for a particular sampling function.

    For each batch of images: label it, sample from it (or reload a previous
    sample), mix in a proportion of previously seen images, augment, and
    retrain from the latest checkpoint.

    Args:
        sample_method: ``(name, (sample_func, kwargs))`` pair for the method.
        batches: iterable of image folders, one per simulated batch.
        config: pipeline configuration dictionary.
        last_epoch: epoch of the checkpoint to resume from; updated per batch.
        seen_images: previously seen images (deep-copied, caller's set is
            not mutated).
        label_func: callable used to label the sampled image set.
        device: optional device to train on (passed through to train.train).
    """
    name, _ = sample_method
    classes = utils.load_classes(config["class_list"])

    # Deep-copy so each sampling method starts from the same initial set
    # and this run's additions don't leak back to the caller.
    seen_images = copy.deepcopy(seen_images)

    for i, sample_folder in enumerate(batches):
        sample_folder.label(classes, label_func)

        sample_labeled = LabeledSet(
            sample_folder.imgs, len(classes), config["img_size"],
        )

        sample_filename = f"{config['output']}/{name}{i}_sample_{last_epoch}.txt"

        if os.path.exists(sample_filename):
            print("Loading existing samples")
            # Use a context manager so the handle is closed deterministically
            # (the previous code left the file open until garbage collection).
            with open(sample_filename, "r") as sample_file:
                retrain_files = sample_file.read().split("\n")
        else:
            retrain_files = benchmark_sample(
                sample_method, sample_labeled, config, i, last_epoch
            )

            # When deploying at the edge, this would be where data is
            # sent from nodes to the Beehive, along with the benchmark file
            with open(sample_filename, "w+") as out:
                out.write("\n".join(retrain_files))

        # Receive raw sampled data in the cloud
        # This process simulates manually labeling/verifying all inferences
        retrain_obj = LabeledSet(
            retrain_files, len(classes), config["img_size"], prefix=f"{name}{i}"
        )

        new_splits_made = retrain_obj.load_or_split(
            config["output"],
            config["train_sample"],
            config["valid_sample"],
            save=False,
            sample_dir=config["sample_set"],
        )

        if new_splits_made:
            # If reloaded, splits have old images already incorporated
            for set_name in retrain_obj.sets:
                # Calculate proportion of old examples needed so that new
                # images make up config["retrain_new"] of the training data.
                number_desired = (1 / config["retrain_new"] - 1) * len(
                    getattr(retrain_obj, set_name)
                )
                if round(number_desired) == 0:
                    continue
                print(set_name, number_desired)
                extra_images = getattr(seen_images, set_name).split_batch(
                    round(number_desired)
                )[0]
                orig_set = getattr(retrain_obj, set_name)
                orig_set += extra_images

        seen_images += retrain_obj

        retrain_obj.save_splits(config["output"])
        retrain_obj.train.augment(config["images_per_class"])

        # Resume training from the epoch after the latest checkpoint.
        config["start_epoch"] = last_epoch + 1
        checkpoint = utils.find_checkpoint(config, name, last_epoch)
        last_epoch = train.train(retrain_obj, config, checkpoint, device=device)
def retrain(config, sample_methods, sample_batches, base_epoch, init_imgs):
    """Sample images and retrain for all sample methods given.

    Assigns each sampling method a GPU round-robin from the free pool, then
    either runs each method serially as its argument tuple is built, or hands
    the full argument list to the parallel runner when config["parallel"].
    """
    gpu_pool = yoloutils.get_free_gpus(yoloutils.get_memory_needed(config))
    run_in_parallel = config["parallel"]

    all_method_args = []
    for idx, method in enumerate(sample_methods.items()):
        args = (
            method,
            sample_batches,
            config,
            base_epoch,
            init_imgs,
            userdefs.label_sample_set,
            gpu_pool[idx % len(gpu_pool)],  # round-robin GPU assignment
        )
        all_method_args.append(args)

        if not run_in_parallel:
            sample_retrain(*args)

    if run_in_parallel:
        parallelize.run_parallel(sample_retrain, all_method_args)