-
Notifications
You must be signed in to change notification settings - Fork 871
/
Copy pathmodel_handler_generalized.py
73 lines (62 loc) · 2.33 KB
/
model_handler_generalized.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
from ts.torch_handler.base_handler import BaseHandler
from fairseq.models.transformer import TransformerModel
import torch
import json
import os
import logging
logger = logging.getLogger(__name__)
class LanguageTranslationHandler(BaseHandler):
    """TorchServe handler that translates text using a fairseq TransformerModel.

    Expects an optional ``setup_config.json`` in the model directory with keys
    such as ``bpe`` (BPE scheme passed to fairseq) and ``translated_output``
    (the key name used for the translation in the response payload).
    """

    def __init__(self):
        self._context = None
        self.initialized = False
        self.model = None
        self.device = None
        # Default to an empty config so a missing setup_config.json does not
        # cause an AttributeError later in initialize()/inference().
        self.setup_config = {}

    def initialize(self, context):
        """Load the fairseq model and configuration for this worker.

        :param context: TorchServe context carrying the manifest and system
            properties (model_dir, gpu_id, ...).
        """
        self._context = context
        self.manifest = context.manifest
        properties = context.system_properties
        model_dir = properties.get("model_dir")
        # Pin the model to the assigned GPU when one is available, else CPU.
        self.device = torch.device(
            "cuda:" + str(properties.get("gpu_id"))
            if torch.cuda.is_available() and properties.get("gpu_id") is not None
            else "cpu")
        # Read configs for the model_name, bpe etc. from setup_config.json.
        setup_config_path = os.path.join(model_dir, "setup_config.json")
        if os.path.isfile(setup_config_path):
            with open(setup_config_path) as setup_config_file:
                self.setup_config = json.load(setup_config_file)
        else:
            logger.warning('Missing the setup_config.json file.')
        # Load the model; bpe falls back to None (fairseq default) when the
        # config file is absent or does not define it.
        self.model = TransformerModel.from_pretrained(
            model_dir,
            checkpoint_file='model.pt',
            data_name_or_path=model_dir,
            tokenizer='moses',
            bpe=self.setup_config.get("bpe")
        )
        self.model.to(self.device)
        self.model.eval()
        # Only mark the handler ready once the model has actually loaded.
        self.initialized = True

    def preprocess(self, data):
        """Extract one UTF-8 text string per request row.

        :param data: list of request dicts with the payload under "data"
            or "body"; the payload may arrive as bytes or as str.
        :return: list of input strings, one per request.
        """
        text_input = []
        for row in data:
            text = row.get("data") or row.get("body")
            # The payload may already be a str (e.g. JSON/text requests);
            # only decode when it is raw bytes.
            if isinstance(text, (bytes, bytearray)):
                text = text.decode('utf-8')
            text_input.append(text)
        return text_input

    def inference(self, data, *args, **kwargs):
        """Translate each input string and pair it with its source text.

        :param data: list of source strings from preprocess().
        :return: list of JSON strings, each holding the input and its
            translation under the configured output key.
        """
        inference_output = []
        # Key name for the translation in the response; defaults to
        # "translation" when setup_config.json is missing or incomplete.
        output_key = self.setup_config.get("translated_output", "translation")
        with torch.no_grad():
            translation = self.model.translate(data, beam=5)
        logger.info("Model translated: %s", translation)
        for source, translated in zip(data, translation):
            output = {
                "input": source,
                output_key: translated
            }
            inference_output.append(json.dumps(output))
        return inference_output

    def postprocess(self, data):
        """Return inference output unchanged (already JSON-serialized)."""
        return data