Skip to content

Commit

Permalink
fix(client): make model serving of the examples work (#2356)
Browse files Browse the repository at this point in the history
  • Loading branch information
jialeicui authored Jun 15, 2023
1 parent d6132ee commit 6193e14
Show file tree
Hide file tree
Showing 6 changed files with 9 additions and 14 deletions.
2 changes: 1 addition & 1 deletion example/PennFudanPed/pfp/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def handler(file: str):
with open(file, "rb") as f:
data = f.read()
img = Image(data, mime_type=MIMEType.PNG)
_, res = predict_mask_rcnn({"image": img}, 0)
_, res = predict_mask_rcnn({"image": img}, {"index": 0})

bbox = res["bbox"]
_img = PILImage.open(file)
Expand Down
3 changes: 1 addition & 2 deletions example/cifar10/cifar/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ def __init__(self) -> None:
self.model = self._load_model(self.device)

def predict(self, data, external):
print(f"index: {external['index']}")
data_tensor = self._pre(data["image"])
output = self.model(data_tensor)
return self._post(output)
Expand Down Expand Up @@ -86,5 +85,5 @@ def online_eval(self, img: PILImage.Image):
"ship",
"truck",
)
_, prob = self.predict(Image(fp=buf.getvalue()))
_, prob = self.predict({"image": Image(fp=buf.getvalue())}, {})
return {classes[i]: p for i, p in enumerate(prob[0])}
2 changes: 1 addition & 1 deletion example/nmt/nmt/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,4 +95,4 @@ def _load_decoder_model(self, device):
examples=["i m not afraid to die .", "i study mathematics ."],
)
def online_eval(self, content: str):
return self.ppl(Text(content))
return self.ppl({"english": Text(content)})
9 changes: 3 additions & 6 deletions example/speech_command/sc/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def get_m5_model():
@torch.no_grad()
@evaluation.predict(resources={"nvidia.com/gpu": 1}, replicas=2)
def predict_speech(data):
_audio = io.BytesIO(data.speech.to_bytes())
_audio = io.BytesIO(data["speech"].to_bytes())
waveform, _ = torchaudio.load(_audio)
waveform = torch.nn.utils.rnn.pad_sequence(
[waveform.t()], batch_first=True, padding_value=0.0
Expand Down Expand Up @@ -112,14 +112,11 @@ def evaluate_speech(ppl_result):


@api(
[
gradio.Audio(type="filepath"),
gradio.Audio(source="microphone", type="filepath"),
],
gradio.Audio(type="filepath"),
gradio.Label(),
)
def online_eval(file: str):
with open(file, "rb") as f:
data = f.read()
_, prob = predict_speech(Audio(fp=data))
_, prob = predict_speech({"speech": Audio(fp=data)})
return {ALL_LABELS[i]: p for i, p in enumerate(prob)}
5 changes: 2 additions & 3 deletions example/text_cls_AG_NEWS/tcan/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import gradio
from torchtext.data.utils import get_tokenizer, ngrams_iterator

from starwhale import Text, PipelineHandler, multi_classification
from starwhale import PipelineHandler, multi_classification
from starwhale.api.service import api

from .model import TextClassificationModel
Expand All @@ -24,7 +24,6 @@ def __init__(self) -> None:

@torch.no_grad()
def ppl(self, data: dict, **kw):
print(f"index: {kw['external']['index']}")
ngrams = list(ngrams_iterator(self.tokenizer(data["text"]), 2))
tensor = torch.tensor(self.vocab(ngrams)).to(self.device)
output = self.model(tensor, torch.tensor([0]).to(self.device))
Expand Down Expand Up @@ -69,5 +68,5 @@ def _load_vocab(self):
],
)
def online_eval(self, content: str):
_, prob = self.ppl(Text(content))
_, prob = self.ppl({"text": content})
return {_LABEL_NAMES[i]: p for i, p in enumerate(prob)}
2 changes: 1 addition & 1 deletion example/ucf101/ucf101/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,5 +240,5 @@ def cmp(self, ppl_result: t.Iterator) -> t.Any:
def online_eval(self, file: str):
with open(file, "rb") as f:
data = f.read()
prob = self.ppl([Video(fp=data)])[0]
prob = self.ppl([{"video": Video(fp=data)}])[0]
return {_LABELS[i]: p for i, p in enumerate(prob[1])}

0 comments on commit 6193e14

Please sign in to comment.