-
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
c102ff2
commit a0b6ac0
Showing
7 changed files
with
125 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
yolov3 | ||
yolov3.weights | ||
yolov3.cfg |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
from .process_rss_text import process_rss_text, ParsedRssText | ||
from .lambdas import _extract_images_from_description | ||
from .yolov3_tf import yolov3_tf | ||
from .abstract_expensive_rss_lambda import abstract_expensive_rss_lambda | ||
from .rss_image_utils import _create_item_element_with_image, _extract_link | ||
|
||
def _image_recognition_tf(rss_text: str, class_id: int) -> str: | ||
def processor(parsed_rss_text: ParsedRssText): | ||
root = parsed_rss_text.root | ||
parent = parsed_rss_text.parent | ||
items = parsed_rss_text.items | ||
|
||
matched_images = [] | ||
for item in items: | ||
images = _extract_images_from_description(item, root.nsmap) | ||
for image in images: | ||
img_src = image.get('src') | ||
if yolov3_tf(img_src, class_id): | ||
matched_images.append(_create_item_element_with_image( | ||
img_src, | ||
item.tag, | ||
_extract_link(item, root.nsmap))) | ||
|
||
# remove all items and appended kept items | ||
for item in items: | ||
parent.remove(item) | ||
for item in matched_images: | ||
parent.append(item) | ||
|
||
return process_rss_text(rss_text, processor) | ||
|
||
def rss_image_recognition_tf(rss_text: str, class_id: int, url: str) -> str: | ||
hash = url + ":" + str(class_id) + ":tf" | ||
|
||
return abstract_expensive_rss_lambda( | ||
rss_text, | ||
_image_recognition_tf, | ||
hash, | ||
[class_id]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
import os | ||
import logging | ||
import tensorflow as tf | ||
import time | ||
import requests | ||
import json | ||
from .rss_image_utils import _download_image | ||
from .file_cache import file_cache | ||
|
||
tfserving_root = os.getenv("TFSERVING_ROOT", "http://localhost:8501") | ||
size = 320 | ||
|
||
|
||
@file_cache(verbose=True) | ||
def yolov3_tf(image_url: str, desired_class_id: int) -> bool: | ||
start_time = time.time() | ||
|
||
# Downlaod image | ||
image_path = _download_image(image_url) | ||
if image_path is None: | ||
logging.error(f"failed to download image from {image.get('src')}") | ||
return False | ||
|
||
# Decode image | ||
image = tf.image.decode_image(open(image_path, 'rb').read(), channels=3) | ||
image = tf.expand_dims(image, axis=0) | ||
image = tf.image.resize(image, (size, size)) | ||
image = image / 255 | ||
|
||
# Make request | ||
data = { | ||
"signature_name": "serving_default", | ||
"instances": image.numpy().tolist() | ||
} | ||
resp = requests.post(f"{tfserving_root}/v1/models/yolov3:predict", json=data) | ||
resp = json.loads(resp.content.decode('utf-8'))['predictions'][0] | ||
|
||
res = False | ||
valid_predictions = resp['yolo_nms_3'] | ||
for i in range(valid_predictions): | ||
clazz = resp['yolo_nms_2'][i] | ||
if clazz == desired_class_id: | ||
res = True | ||
break | ||
|
||
logging.info(f"yolov3 tf took {time.time() - start_time} seconds") | ||
|
||
return res |