Skip to content

Commit

Permalink
poc for yolov3 tf
Browse files Browse the repository at this point in the history
  • Loading branch information
KTachibanaM committed Oct 21, 2024
1 parent c102ff2 commit a0b6ac0
Show file tree
Hide file tree
Showing 7 changed files with 125 additions and 4 deletions.
4 changes: 2 additions & 2 deletions .devcontainer/devcontainer.json
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
// "forwardPorts": [],

// Use 'postCreateCommand' to run commands after the container is created.
"postCreateCommand": "pip3 install --user -r requirements.txt && ollama pull llava",
"postCreateCommand": "pip3 install --user -r requirements.txt",
"features": {
"ghcr.io/prulloac/devcontainer-features/ollama:1": {}
"ghcr.io/devcontainers/features/docker-in-docker:2": {}
}

// Configure tool-specific properties.
Expand Down
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ services:
- ./file_cache:/app/file_cache
```

TODO: `docker run -p 8501:8501 --mount type=bind,source=./blobs/yolov3,target=/models/yolov3 -e MODEL_NAME=yolov3 -t tensorflow/serving`

## Development

Open in VSCode, then run
Expand Down
32 changes: 32 additions & 0 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from rss_lambda.rss_image_recognition import rss_image_recognition
from rss_lambda.llava import is_llava_available
from rss_lambda.rss_image_gender import rss_image_gender
from rss_lambda.rss_image_recognition_tf import rss_image_recognition_tf


if os.getenv('SENTRY_DSN'):
Expand Down Expand Up @@ -78,6 +79,37 @@ def _rss_image_recog():
return e.message, 500


@app.route("/rss_image_recog_tf")
def rss_image_recog_tf():
# parse url
url = request.args.get('url', default=None)
if not url:
return "No url provided", 400
url = unquote(url)
parsed_url = urlparse(url)
if not all([parsed_url.scheme, parsed_url.netloc]):
return "Invalid url", 400
rss_text_or_res = download_feed(url, request.headers)

# parse class_id
class_id = request.args.get('class_id', default=None)
if not class_id:
return "No class_id provided", 400

# Hack for Reeder (iOS)
if class_id.endswith("/rss"):
class_id = class_id[:-4]
if class_id.endswith("/feed"):
class_id = class_id[:-5]

class_id = int(class_id)

try:
return Response(rss_image_recognition_tf(rss_text_or_res, class_id, url), mimetype='application/xml')
except RSSLambdaError as e:
return e.message, 500


@app.route("/rss_image_gender")
def _rss_image_gender():
if not is_llava_available():
Expand Down
1 change: 1 addition & 0 deletions blobs/.gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
yolov3
yolov3.weights
yolov3.cfg
3 changes: 1 addition & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,4 @@ sentry-sdk[flask]==1.43.0
opencv-python-headless==4.10.0.84
ollama==0.3.3
pillow==11.0.0
cachew==0.17.20241017
colorlog==6.8.2
tensorflow==2.17.0
39 changes: 39 additions & 0 deletions rss_lambda/rss_image_recognition_tf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from .process_rss_text import process_rss_text, ParsedRssText
from .lambdas import _extract_images_from_description
from .yolov3_tf import yolov3_tf
from .abstract_expensive_rss_lambda import abstract_expensive_rss_lambda
from .rss_image_utils import _create_item_element_with_image, _extract_link

def _image_recognition_tf(rss_text: str, class_id: int) -> str:
def processor(parsed_rss_text: ParsedRssText):
root = parsed_rss_text.root
parent = parsed_rss_text.parent
items = parsed_rss_text.items

matched_images = []
for item in items:
images = _extract_images_from_description(item, root.nsmap)
for image in images:
img_src = image.get('src')
if yolov3_tf(img_src, class_id):
matched_images.append(_create_item_element_with_image(
img_src,
item.tag,
_extract_link(item, root.nsmap)))

# remove all items and appended kept items
for item in items:
parent.remove(item)
for item in matched_images:
parent.append(item)

return process_rss_text(rss_text, processor)

def rss_image_recognition_tf(rss_text: str, class_id: int, url: str) -> str:
hash = url + ":" + str(class_id) + ":tf"

return abstract_expensive_rss_lambda(
rss_text,
_image_recognition_tf,
hash,
[class_id])
48 changes: 48 additions & 0 deletions rss_lambda/yolov3_tf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import os
import logging
import tensorflow as tf
import time
import requests
import json
from .rss_image_utils import _download_image
from .file_cache import file_cache

tfserving_root = os.getenv("TFSERVING_ROOT", "http://localhost:8501")
size = 320


@file_cache(verbose=True)
def yolov3_tf(image_url: str, desired_class_id: int) -> bool:
start_time = time.time()

# Downlaod image
image_path = _download_image(image_url)
if image_path is None:
logging.error(f"failed to download image from {image.get('src')}")
return False

# Decode image
image = tf.image.decode_image(open(image_path, 'rb').read(), channels=3)
image = tf.expand_dims(image, axis=0)
image = tf.image.resize(image, (size, size))
image = image / 255

# Make request
data = {
"signature_name": "serving_default",
"instances": image.numpy().tolist()
}
resp = requests.post(f"{tfserving_root}/v1/models/yolov3:predict", json=data)
resp = json.loads(resp.content.decode('utf-8'))['predictions'][0]

res = False
valid_predictions = resp['yolo_nms_3']
for i in range(valid_predictions):
clazz = resp['yolo_nms_2'][i]
if clazz == desired_class_id:
res = True
break

logging.info(f"yolov3 tf took {time.time() - start_time} seconds")

return res

0 comments on commit a0b6ac0

Please sign in to comment.