Skip to content

Commit

Permalink
init
Browse files Browse the repository at this point in the history
  • Loading branch information
haochen authored and haochen committed Apr 11, 2023
0 parents commit 56730e6
Show file tree
Hide file tree
Showing 22 changed files with 4,780 additions and 0 deletions.
1,027 changes: 1,027 additions & 0 deletions annnotator.py

Large diffs are not rendered by default.

1,091 changes: 1,091 additions & 0 deletions canvas.py

Large diffs are not rendered by default.

1,617 changes: 1,617 additions & 0 deletions categories.txt

Large diffs are not rendered by default.

197 changes: 197 additions & 0 deletions mask_predictor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
from typing import Any, Dict, List, Optional

import cv2
import numpy as np
import torch
from tqdm import tqdm

from metaseg import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry
from metaseg.utils import download_model, load_image, load_video


class SegAutoMaskPredictor:
def __init__(self):
self.model = None
self.device = "cuda" if torch.cuda.is_available() else "cpu"

def load_model(self, model_type):
if self.model is None:
self.model_path = download_model(model_type)
self.model = sam_model_registry[model_type](checkpoint=self.model_path)
self.model.to(device=self.device)

return self.model

def predict(self, frame, points_per_side, points_per_batch, min_area):
frame = load_image(frame)
mask_generator = SamAutomaticMaskGenerator(
self.model, points_per_side=points_per_side, points_per_batch=points_per_batch, min_mask_region_area=min_area
)

masks = mask_generator.generate(frame)

return frame, masks

def save_image(self, source, model_type, points_per_side, points_per_batch, min_area):
read_image = load_image(source)
image, anns = self.predict(read_image, model_type, points_per_side, points_per_batch, min_area)
if len(anns) == 0:
return

sorted_anns = sorted(anns, key=(lambda x: x["area"]), reverse=True)
mask_image = np.zeros((anns[0]["segmentation"].shape[0], anns[0]["segmentation"].shape[1], 3), dtype=np.uint8)
colors = np.random.randint(0, 255, size=(256, 3), dtype=np.uint8)
for i, ann in enumerate(sorted_anns):
m = ann["segmentation"]
img = np.ones((m.shape[0], m.shape[1], 3), dtype=np.uint8)
color = colors[i % 256]
for i in range(3):
img[:, :, 0] = color[0]
img[:, :, 1] = color[1]
img[:, :, 2] = color[2]
img = cv2.bitwise_and(img, img, mask=m.astype(np.uint8))
img = cv2.addWeighted(img, 0.35, np.zeros_like(img), 0.65, 0)
mask_image = cv2.add(mask_image, img)

combined_mask = cv2.add(image, mask_image)
cv2.imwrite("output.jpg", combined_mask)

return "output.jpg"

def save_video(self, source, model_type, points_per_side, points_per_batch, min_area):
cap, out = load_video(source)
length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
colors = np.random.randint(0, 255, size=(256, 3), dtype=np.uint8)

for _ in tqdm(range(length)):
ret, frame = cap.read()
if not ret:
break

image, anns = self.predict(frame, model_type, points_per_side, points_per_batch, min_area)
if len(anns) == 0:
continue

sorted_anns = sorted(anns, key=(lambda x: x["area"]), reverse=True)
mask_image = np.zeros(
(anns[0]["segmentation"].shape[0], anns[0]["segmentation"].shape[1], 3), dtype=np.uint8
)

for i, ann in enumerate(sorted_anns):
m = ann["segmentation"]
color = colors[i % 256] # Her nesne için farklı bir renk kullan
img = np.zeros((m.shape[0], m.shape[1], 3), dtype=np.uint8)
img[:, :, 0] = color[0]
img[:, :, 1] = color[1]
img[:, :, 2] = color[2]
img = cv2.bitwise_and(img, img, mask=m.astype(np.uint8))
img = cv2.addWeighted(img, 0.35, np.zeros_like(img), 0.65, 0)
mask_image = cv2.add(mask_image, img)

combined_mask = cv2.add(frame, mask_image)
out.write(combined_mask)

out.release()
cap.release()
cv2.destroyAllWindows()

return "output.mp4"


class SegManualMaskPredictor:
def __init__(self):
self.model = None
self.device = "cuda" if torch.cuda.is_available() else "cpu"

def load_model(self, model_type):
if self.model is None:
self.model_path = download_model(model_type)
self.model = sam_model_registry[model_type](checkpoint=self.model_path)
self.model.to(device=self.device)

return self.model

def load_mask(self, mask, random_color):
if random_color:
color = np.random.rand(3) * 255
else:
color = np.array([100, 50, 0])

h, w = mask.shape[-2:]
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
mask_image = mask_image.astype(np.uint8)
return mask_image

def load_box(self, box, image):
x, y, w, h = int(box[0]), int(box[1]), int(box[2]), int(box[3])
cv2.rectangle(image, (x, y), (w, h), (0, 255, 0), 2)
return image

def multi_boxes(self, boxes, predictor, image):
input_boxes = torch.tensor(boxes, device=predictor.device)
transformed_boxes = predictor.transform.apply_boxes_torch(input_boxes, image.shape[:2])
return input_boxes, transformed_boxes

def predict(
self,
frame,
input_box=None,
input_point=None,
input_label=None,
multimask_output=False,
):
predictor = SamPredictor(self.model)
predictor.set_image(frame)

if type(input_box[0]) == list:
input_boxes, new_boxes = self.multi_boxes(input_box, predictor, frame)

masks, _, _ = predictor.predict_torch(
point_coords=None,
point_labels=None,
boxes=new_boxes,
multimask_output=False,
)

elif type(input_box[0]) == int:
input_boxes = np.array(input_box)[None, :]

masks, _, _ = predictor.predict(
point_coords=input_point,
point_labels=input_label,
box=input_boxes,
multimask_output=multimask_output,
)

return frame, masks, input_boxes

def save_image(
self,
source,
model_type,
input_box=None,
input_point=None,
input_label=None,
multimask_output=False,
output_path="v0.jpg",
):
read_image = load_image(source)
image, anns, boxes = self.predict(read_image, model_type, input_box, input_point, input_label, multimask_output)
if len(anns) == 0:
return

if type(input_box[0]) == list:
for mask in anns:
mask_image = self.load_mask(mask.cpu().numpy(), False)

for box in boxes:
image = self.load_box(box.cpu().numpy(), image)

elif type(input_box[0]) == int:
mask_image = self.load_mask(anns, True)
image = self.load_box(input_box, image)

combined_mask = cv2.add(image, mask_image)
cv2.imwrite(output_path, combined_mask)

return output_path
178 changes: 178 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
absl-py==1.3.0
addict==2.4.0
albumentations==1.3.0
alembic @ file:///home/conda/feedstock_root/build_artifacts/alembic_1675830701130/work
antlr4-python3-runtime==4.9.3
argon2-cffi @ file:///home/conda/feedstock_root/build_artifacts/argon2-cffi_1640817743617/work
argon2-cffi-bindings @ file:///D:/bld/argon2-cffi-bindings_1649500465986/work
backcall @ file:///home/conda/feedstock_root/build_artifacts/backcall_1592338393461/work
backports.functools-lru-cache @ file:///home/conda/feedstock_root/build_artifacts/backports.functools_lru_cache_1618230623929/work
beautifulsoup4 @ file:///home/conda/feedstock_root/build_artifacts/beautifulsoup4_1675252249248/work
black==22.12.0
bleach @ file:///home/conda/feedstock_root/build_artifacts/bleach_1674535352125/work
brotlipy @ file:///D:/bld/brotlipy_1648854404735/work
cachetools==5.2.0
certifi==2022.12.7
cffi @ file:///D:/bld/cffi_1656782956657/work
charset-normalizer @ file:///home/conda/feedstock_root/build_artifacts/charset-normalizer_1661170624537/work
click==7.1.2
clip @ git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1
cloudpickle==2.2.0
colorama @ file:///home/conda/feedstock_root/build_artifacts/colorama_1666700638685/work
cryptography @ file:///D:/bld/cryptography_1657174118602/work
cycler==0.11.0
Cython==0.29.32
dataclasses==0.6
debugpy @ file:///C:/ci/debugpy_1637091911212/work
decorator @ file:///home/conda/feedstock_root/build_artifacts/decorator_1641555617451/work
defusedxml @ file:///home/conda/feedstock_root/build_artifacts/defusedxml_1615232257335/work
detectron2 @ git+https://github.com/facebookresearch/detectron2.git@857d5de21a7789d1bba46694cf608b1cb2ea128a
easydict==1.10
entrypoints @ file:///home/conda/feedstock_root/build_artifacts/entrypoints_1643888246732/work
fasttext==0.9.2
filelock==3.8.2
flit_core @ file:///home/conda/feedstock_root/build_artifacts/flit-core_1667734568827/work/source/flit_core
fonttools==4.38.0
ftfy==6.1.1
future==0.18.2
fuzzywuzzy @ file:///home/conda/feedstock_root/build_artifacts/fuzzywuzzy_1605704184142/work
fvcore==0.1.5.post20221221
google-auth==2.15.0
google-auth-oauthlib==0.4.6
graphviz==0.20.1
greenlet @ file:///D:/bld/greenlet_1661444189601/work
grpcio==1.51.1
huggingface-hub==0.11.1
hydra-core==1.3.1
idna @ file:///home/conda/feedstock_root/build_artifacts/idna_1663625384323/work
imageio==2.23.0
imgaug==0.4.0
imgviz==1.6.2
importlib-metadata==5.2.0
importlib-resources @ file:///home/conda/feedstock_root/build_artifacts/importlib_resources_1672681417544/work
iopath==0.1.9
ipykernel @ file:///D:/bld/ipykernel_1655369313836/work
ipython @ file:///D:/bld/ipython_1651240762678/work
ipython-genutils==0.2.0
ipywidgets @ file:///home/conda/feedstock_root/build_artifacts/ipywidgets_1671720089366/work
jedi @ file:///home/conda/feedstock_root/build_artifacts/jedi_1669134318875/work
Jinja2 @ file:///home/conda/feedstock_root/build_artifacts/jinja2_1654302431367/work
joblib==1.2.0
jsonschema==2.6.0
jupyter @ file:///D:/bld/jupyter_1637233295291/work
jupyter-console @ file:///home/conda/feedstock_root/build_artifacts/jupyter_console_1655961255101/work
jupyter_client @ file:///home/conda/feedstock_root/build_artifacts/jupyter_client_1673615989977/work
jupyter_core @ file:///C:/b/abs_a9330r1z_i/croots/recipe/jupyter_core_1664917313457/work
jupyterlab-pygments @ file:///home/conda/feedstock_root/build_artifacts/jupyterlab_pygments_1649936611996/work
jupyterlab-widgets @ file:///home/conda/feedstock_root/build_artifacts/jupyterlab_widgets_1671722028097/work
kiwisolver==1.4.4
labelme==5.1.1
lvis @ git+https://github.com/lvis-dataset/lvis-api.git@35f09cd7c5f313a9bf27b329ca80effe2b0c8a93
Mako @ file:///home/conda/feedstock_root/build_artifacts/mako_1668568582731/work
Markdown==3.4.1
MarkupSafe @ file:///C:/ci/markupsafe_1654508076077/work
matplotlib==3.5.3
matplotlib-inline @ file:///home/conda/feedstock_root/build_artifacts/matplotlib-inline_1660814786464/work
mistune @ file:///home/conda/feedstock_root/build_artifacts/mistune_1675771498296/work
-e d:\cvpr_2023_dataset\detpro
mmpycocotools==12.0.3
mss==7.0.1
mypy-extensions==0.4.3
natsort==8.2.0
nbclient @ file:///home/conda/feedstock_root/build_artifacts/nbclient_1662750566673/work
nbconvert @ file:///home/conda/feedstock_root/build_artifacts/nbconvert-meta_1674590374792/work
nbformat @ file:///home/conda/feedstock_root/build_artifacts/nbformat_1617383142101/work
nbgrader @ file:///D:/bld/nbgrader_1615406840332/work
nest-asyncio @ file:///home/conda/feedstock_root/build_artifacts/nest-asyncio_1664684991461/work
networkx==2.6.3
ninja==1.11.1
nltk==3.8
notebook @ file:///D:/bld/notebook_1616419450136/work
numpy==1.21.6
oauthlib==3.2.2
omegaconf==2.3.0
opencv-python==4.6.0.66
opencv-python-headless==4.6.0.66
packaging @ file:///home/conda/feedstock_root/build_artifacts/packaging_1673482170163/work
pandocfilters @ file:///home/conda/feedstock_root/build_artifacts/pandocfilters_1631603243851/work
parso @ file:///home/conda/feedstock_root/build_artifacts/parso_1638334955874/work
pathspec==0.10.3
pickleshare @ file:///home/conda/feedstock_root/build_artifacts/pickleshare_1602536217715/work
Pillow==9.3.0
platformdirs==2.6.0
portalocker==2.6.0
prometheus-client @ file:///home/conda/feedstock_root/build_artifacts/prometheus_client_1674535637125/work
prompt-toolkit @ file:///home/conda/feedstock_root/build_artifacts/prompt-toolkit_1670414775770/work
protobuf==3.20.3
psutil @ file:///C:/Windows/Temp/abs_b2c2fd7f-9fd5-4756-95ea-8aed74d0039flsd9qufz/croots/recipe/psutil_1656431277748/work
pyasn1==0.4.8
pyasn1-modules==0.2.8
pybind11==2.10.2
pycocotools==2.0.6
pycparser @ file:///home/conda/feedstock_root/build_artifacts/pycparser_1636257122734/work
Pygments @ file:///home/conda/feedstock_root/build_artifacts/pygments_1672682006896/work
pyOpenSSL @ file:///home/conda/feedstock_root/build_artifacts/pyopenssl_1663846997386/work
pyparsing==3.0.9
PyQt5==5.15.4
pyqt5-plugins==5.15.4.2.2
PyQt5-Qt5==5.15.2
PyQt5-sip==12.11.0
pyqt5-tools==5.15.4.3.2
PyQtChart==5.12
PyQtWebEngine==5.12.1
PySocks @ file:///D:/bld/pysocks_1648857448477/work
python-dateutil @ file:///home/conda/feedstock_root/build_artifacts/python-dateutil_1626286286081/work
python-dotenv==0.21.0
python-Levenshtein @ file:///D:/bld/python-levenshtein_1649607057775/work
PyWavelets==1.3.0
pywin32==305
pywinpty @ file:///C:/ci_310/pywinpty_1644230983541/work/target/wheels/pywinpty-2.0.2-cp37-none-win_amd64.whl
PyYAML==6.0
pyzmq @ file:///C:/ci/pyzmq_1657597757589/work
qt5-applications==5.15.2.2.2
qt5-tools==5.15.2.1.2
qtconsole @ file:///home/conda/feedstock_root/build_artifacts/qtconsole-base_1667404144336/work
QtPy @ file:///home/conda/feedstock_root/build_artifacts/qtpy_1667873092748/work
qudida==0.0.4
regex==2022.10.31
requests @ file:///home/conda/feedstock_root/build_artifacts/requests_1673863902341/work
requests-oauthlib==1.3.1
rsa==4.9
scikit-image==0.19.3
scikit-learn==1.0.2
scipy==1.7.3
Send2Trash @ file:///home/conda/feedstock_root/build_artifacts/send2trash_1628511208346/work
shapely==2.0.0
six @ file:///home/conda/feedstock_root/build_artifacts/six_1620240208055/work
soupsieve @ file:///home/conda/feedstock_root/build_artifacts/soupsieve_1658207591808/work
SQLAlchemy @ file:///D:/bld/sqlalchemy_1659998874450/work
tabulate==0.9.0
tensorboard==2.11.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
termcolor==2.1.1
terminado @ file:///D:/bld/terminado_1652790822513/work
terminaltables==3.1.10
threadpoolctl==3.1.0
tifffile==2021.11.2
timm==0.6.12
tinycss2 @ file:///home/conda/feedstock_root/build_artifacts/tinycss2_1666100256010/work
tomli==2.0.1
torch==1.13.1+cu116
torchaudio==0.13.1+cu116
torchvision==0.14.1+cu116
tornado @ file:///D:/bld/tornado_1656937938087/work
tqdm==4.64.1
traitlets @ file:///home/conda/feedstock_root/build_artifacts/traitlets_1675110562325/work
typed-ast==1.5.4
typing_extensions @ file:///home/conda/feedstock_root/build_artifacts/typing_extensions_1665144421445/work
urllib3 @ file:///home/conda/feedstock_root/build_artifacts/urllib3_1673452138552/work
wcwidth @ file:///home/conda/feedstock_root/build_artifacts/wcwidth_1673864653149/work
webencodings==0.5.1
Werkzeug==2.2.2
widgetsnbextension @ file:///home/conda/feedstock_root/build_artifacts/widgetsnbextension_1672066693230/work
win-inet-pton @ file:///D:/bld/win_inet_pton_1667051142467/work
wincertstore==0.2
yacs==0.1.8
yapf==0.32.0
zipp @ file:///home/conda/feedstock_root/build_artifacts/zipp_1675982654259/work
Loading

0 comments on commit 56730e6

Please sign in to comment.