-
Notifications
You must be signed in to change notification settings - Fork 84
/
Copy pathworker.py
106 lines (81 loc) · 3.7 KB
/
worker.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
#! /usr/bin/env python
# -*- coding: utf-8 -*-
# vim:fenc=utf-8
# pylint: disable=E1137,E1136,E0110
import sys
import os
import time
import datetime
import collections
import argparse
import numpy as np
import logging
LOGGER = logging.getLogger("server")
import rpyc
from rpyc.utils.server import OneShotServer, ThreadedServer
from web_tool.ModelSessionKerasExample import KerasDenseFineTune
from web_tool.ModelSessionPytorchSolar import SolarFineTuning
from web_tool.ModelSessionPyTorchExample import TorchFineTuning
from web_tool.ModelSessionRandomForest import ModelSessionRandomForest
from web_tool.ModelSessionPyTorchSegmentationModel import TorchSegmentationModel
from web_tool.Utils import setup_logging, serialize, deserialize
from web_tool.Models import load_models
class MyService(rpyc.Service):
def __init__(self, model):
self.model = model
def on_connect(self, conn):
pass
def on_disconnect(self, conn):
pass
def exposed_last_tile(self):
return serialize(self.model.last_tile)
def exposed_run(self, tile, inference_mode=False):
tile = deserialize(tile) # need to serialize/deserialize numpy arrays
output = self.model.run(tile, inference_mode)
return serialize(output) # need to serialize/deserialize numpy arrays
def exposed_retrain(self):
return self.model.retrain()
def exposed_add_sample_point(self, row, col, class_idx):
return self.model.add_sample_point(row, col, class_idx)
def exposed_undo(self):
return self.model.undo()
def exposed_reset(self):
return self.model.reset()
def exposed_save_state_to(self, directory):
return self.model.save_state_to(directory)
def exposed_load_state_from(self, directory):
return self.model.load_state_from(directory)
def main():
parser = argparse.ArgumentParser(description="AI for Earth Land Cover Worker")
parser.add_argument("-v", "--verbose", action="store_true", help="Enable verbose debugging", default=False)
parser.add_argument("--port", action="store", type=int, help="Port we are listenning on", default=0)
parser.add_argument("--gpu_id", action="store", dest="gpu_id", type=int, help="GPU to use", required=False)
parser.add_argument("--model_key", action="store", dest="model_key", type=str, help="Model key from models.json to use")
args = parser.parse_args(sys.argv[1:])
# Setup logging
log_path = os.path.join(os.getcwd(), "tmp/logs/")
setup_logging(log_path, "worker")
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "" if args.gpu_id is None else str(args.gpu_id)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
model_configs = load_models()
if not args.model_key in model_configs:
LOGGER.error("'%s' is not recognized as a valid model, exiting..." % (args.model_key))
return
model_type = model_configs[args.model_key]["type"]
if model_type == "keras_example":
model = KerasDenseFineTune(args.gpu_id, **model_configs[args.model_key])
elif model_type == "pytorch_example":
model = TorchFineTuning(args.gpu_id, **model_configs[args.model_key])
elif model_type == "pytorch_solar":
model = SolarFineTuning(args.gpu_id, **model_configs[args.model_key])
elif model_type == "random_forest":
model = ModelSessionRandomForest(**model_configs[args.model_key])
elif model_type == "pytorch_segmodel":
model = TorchSegmentationModel(args.gpu_id, **model_configs[args.model_key])
else:
raise NotImplementedError("The given model type is not implemented yet.")
t = OneShotServer(MyService(model), port=args.port)
t.start()
if __name__ == "__main__":
main()