forked from haochenheheda/segment-anything-annotator
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
haochen
authored and
haochen
committed
Apr 11, 2023
0 parents
commit 56730e6
Showing
22 changed files
with
4,780 additions
and
0 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,197 @@ | ||
from typing import Any, Dict, List, Optional | ||
|
||
import cv2 | ||
import numpy as np | ||
import torch | ||
from tqdm import tqdm | ||
|
||
from metaseg import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry | ||
from metaseg.utils import download_model, load_image, load_video | ||
|
||
|
||
class SegAutoMaskPredictor:
    """Segment whole images or videos with SAM's automatic mask generator.

    The SAM model is created lazily by :meth:`load_model`; ``save_image`` and
    ``save_video`` load it on demand from their ``model_type`` argument.
    """

    def __init__(self):
        # The model is built on first use, see load_model().
        self.model = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def load_model(self, model_type):
        """Return the SAM model, building and caching it on the first call.

        Args:
            model_type: Key into ``sam_model_registry``; also passed to
                ``download_model`` to obtain the checkpoint path.

        Returns:
            The cached model. Once a model is cached, ``model_type`` is
            ignored and the existing model is returned unchanged.
        """
        if self.model is None:
            self.model_path = download_model(model_type)
            self.model = sam_model_registry[model_type](checkpoint=self.model_path)
            self.model.to(device=self.device)

        return self.model

    def predict(self, frame, points_per_side, points_per_batch, min_area):
        """Generate masks for one frame using the already loaded model.

        Args:
            frame: Input handed to ``load_image`` before mask generation.
            points_per_side: Forwarded to ``SamAutomaticMaskGenerator``.
            points_per_batch: Forwarded to ``SamAutomaticMaskGenerator``.
            min_area: Forwarded as ``min_mask_region_area``.

        Returns:
            Tuple ``(frame, masks)``: the result of ``load_image`` and the
            output of ``mask_generator.generate``.
        """
        # NOTE(review): self.model must already be set (see load_model);
        # this method does not load it itself.
        frame = load_image(frame)
        mask_generator = SamAutomaticMaskGenerator(
            self.model,
            points_per_side=points_per_side,
            points_per_batch=points_per_batch,
            min_mask_region_area=min_area,
        )

        masks = mask_generator.generate(frame)

        return frame, masks

    @staticmethod
    def _render_overlay(anns, colors):
        """Blend all masks in ``anns`` into one dimmed, color-coded uint8 image."""
        # Largest masks are drawn first and get the first palette entries.
        ordered = sorted(anns, key=lambda ann: ann["area"], reverse=True)
        first = anns[0]["segmentation"]
        overlay = np.zeros((first.shape[0], first.shape[1], 3), dtype=np.uint8)

        for index, ann in enumerate(ordered):
            mask = ann["segmentation"]
            layer = np.zeros((mask.shape[0], mask.shape[1], 3), dtype=np.uint8)
            layer[:, :] = colors[index % len(colors)]
            # Keep the color only where the mask is set, then dim it to 35%.
            layer = cv2.bitwise_and(layer, layer, mask=mask.astype(np.uint8))
            layer = cv2.addWeighted(layer, 0.35, np.zeros_like(layer), 0.65, 0)
            overlay = cv2.add(overlay, layer)

        return overlay

    def save_image(self, source, model_type, points_per_side, points_per_batch, min_area):
        """Segment ``source`` and write the overlay to ``output.jpg``.

        Args:
            source: Image input handed to ``load_image``.
            model_type: Model to load if none is cached yet.
            points_per_side: Forwarded to :meth:`predict`.
            points_per_batch: Forwarded to :meth:`predict`.
            min_area: Forwarded to :meth:`predict`.

        Returns:
            ``"output.jpg"``, or ``None`` when no masks were produced.
        """
        # predict() relies on self.model, so make sure it exists first.
        self.load_model(model_type)

        # NOTE(review): predict() calls load_image() again on this value,
        # exactly as the original code did — confirm load_image accepts an
        # already loaded image.
        read_image = load_image(source)
        image, anns = self.predict(read_image, points_per_side, points_per_batch, min_area)
        if len(anns) == 0:
            return None

        colors = np.random.randint(0, 255, size=(256, 3), dtype=np.uint8)
        mask_image = self._render_overlay(anns, colors)

        combined_mask = cv2.add(image, mask_image)
        cv2.imwrite("output.jpg", combined_mask)

        return "output.jpg"

    def save_video(self, source, model_type, points_per_side, points_per_batch, min_area):
        """Segment every frame of ``source`` and write the overlays to the video writer.

        Args:
            source: Video input handed to ``load_video``, which is used here
                as returning a ``(capture, writer)`` pair.
            model_type: Model to load if none is cached yet.
            points_per_side: Forwarded to :meth:`predict`.
            points_per_batch: Forwarded to :meth:`predict`.
            min_area: Forwarded to :meth:`predict`.

        Returns:
            ``"output.mp4"``.
        """
        # predict() relies on self.model, so make sure it exists first.
        self.load_model(model_type)

        cap, out = load_video(source)
        try:
            length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            # One palette for the whole video so colors are picked per rank.
            colors = np.random.randint(0, 255, size=(256, 3), dtype=np.uint8)

            for _ in tqdm(range(length)):
                ret, frame = cap.read()
                if not ret:
                    break

                image, anns = self.predict(frame, points_per_side, points_per_batch, min_area)
                if len(anns) == 0:
                    # NOTE(review): frames without masks are not written, so
                    # the output can be shorter than the input.
                    continue

                # Use a different color for each object.
                mask_image = self._render_overlay(anns, colors)

                # The raw captured frame (not predict()'s returned image) is
                # what gets blended and written.
                combined_mask = cv2.add(frame, mask_image)
                out.write(combined_mask)
        finally:
            # Release the handles even if segmentation fails midway.
            out.release()
            cap.release()
            cv2.destroyAllWindows()

        return "output.mp4"
|
||
|
||
class SegManualMaskPredictor:
    """Segment an image with SAM from user-supplied box and/or point prompts.

    The SAM model is created lazily by :meth:`load_model`; ``save_image``
    loads it on demand from its ``model_type`` argument.
    """

    def __init__(self):
        # The model is built on first use, see load_model().
        self.model = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def load_model(self, model_type):
        """Return the SAM model, building and caching it on the first call.

        Args:
            model_type: Key into ``sam_model_registry``; also passed to
                ``download_model`` to obtain the checkpoint path.

        Returns:
            The cached model. Once a model is cached, ``model_type`` is
            ignored and the existing model is returned unchanged.
        """
        if self.model is None:
            self.model_path = download_model(model_type)
            self.model = sam_model_registry[model_type](checkpoint=self.model_path)
            self.model.to(device=self.device)

        return self.model

    def load_mask(self, mask, random_color):
        """Turn a mask array into a 3-channel uint8 color image.

        Args:
            mask: Array whose last two dimensions are height and width and
                whose total size is ``height * width``.
            random_color: If true, pick a random color; otherwise use the
                fixed color ``[100, 50, 0]``.

        Returns:
            ``(height, width, 3)`` uint8 array: the color where the mask is
            non-zero, zeros elsewhere.
        """
        if random_color:
            color = np.random.rand(3) * 255
        else:
            color = np.array([100, 50, 0])

        h, w = mask.shape[-2:]
        mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
        return mask_image.astype(np.uint8)

    def load_box(self, box, image):
        """Draw ``box`` onto ``image`` as a green rectangle and return the image.

        ``box`` holds two corner points ``(x1, y1, x2, y2)``; ``image`` is
        modified in place by ``cv2.rectangle``.
        """
        x1, y1, x2, y2 = int(box[0]), int(box[1]), int(box[2]), int(box[3])
        cv2.rectangle(image, (x1, y1), (x2, y2), (0, 255, 0), 2)
        return image

    def multi_boxes(self, boxes, predictor, image):
        """Convert ``boxes`` to a tensor and map them into the predictor's input frame.

        Returns:
            Tuple ``(input_boxes, transformed_boxes)``: the boxes as a tensor
            on the predictor's device, and the same boxes after
            ``predictor.transform.apply_boxes_torch``.
        """
        input_boxes = torch.tensor(boxes, device=predictor.device)
        transformed_boxes = predictor.transform.apply_boxes_torch(input_boxes, image.shape[:2])
        return input_boxes, transformed_boxes

    @staticmethod
    def _is_multi_box(input_box):
        """Return True when ``input_box`` is a sequence of boxes rather than one box."""
        if input_box is None:
            return False
        if len(input_box) == 0:
            raise ValueError("input_box must not be empty")
        return isinstance(input_box[0], (list, tuple, np.ndarray))

    def predict(
        self,
        frame,
        input_box=None,
        input_point=None,
        input_label=None,
        multimask_output=False,
    ):
        """Predict masks for ``frame`` from box and/or point prompts.

        Args:
            frame: Image passed to ``SamPredictor.set_image``.
            input_box: One box ``[x1, y1, x2, y2]``, a sequence of such
                boxes, or ``None`` to prompt with points only.
            input_point: Point coordinates; used only on the single-box /
                point-only path.
            input_label: Point labels matching ``input_point``.
            multimask_output: Forwarded on the single-box / point-only path;
                the multi-box path always requests a single mask per box.

        Returns:
            Tuple ``(frame, masks, input_boxes)``. On the multi-box path
            ``masks`` and ``input_boxes`` come from ``predict_torch`` and
            ``multi_boxes``; otherwise ``masks`` comes from
            ``predictor.predict`` and ``input_boxes`` is a ``(1, 4)`` array,
            or ``None`` when no box was given.

        Raises:
            ValueError: If ``input_box`` is an empty sequence.
        """
        # NOTE(review): self.model must already be set (see load_model);
        # this method does not load it itself.
        multi = self._is_multi_box(input_box)

        predictor = SamPredictor(self.model)
        predictor.set_image(frame)

        if multi:
            input_boxes, new_boxes = self.multi_boxes(input_box, predictor, frame)

            masks, _, _ = predictor.predict_torch(
                point_coords=None,
                point_labels=None,
                boxes=new_boxes,
                multimask_output=False,
            )
        else:
            # Covers int and float coordinates alike; no box means the
            # prediction is driven by the point prompts only.
            input_boxes = None if input_box is None else np.array(input_box)[None, :]

            masks, _, _ = predictor.predict(
                point_coords=input_point,
                point_labels=input_label,
                box=input_boxes,
                multimask_output=multimask_output,
            )

        return frame, masks, input_boxes

    def save_image(
        self,
        source,
        model_type,
        input_box=None,
        input_point=None,
        input_label=None,
        multimask_output=False,
        output_path="v0.jpg",
    ):
        """Segment ``source`` from the given prompts and write the overlay image.

        Args:
            source: Image input handed to ``load_image``.
            model_type: Model to load if none is cached yet.
            input_box: See :meth:`predict`.
            input_point: See :meth:`predict`.
            input_label: See :meth:`predict`.
            multimask_output: See :meth:`predict`.
            output_path: Destination passed to ``cv2.imwrite``.

        Returns:
            ``output_path``, or ``None`` when no masks were produced.
        """
        # predict() relies on self.model, so make sure it exists first.
        self.load_model(model_type)

        read_image = load_image(source)
        image, anns, boxes = self.predict(read_image, input_box, input_point, input_label, multimask_output)
        if len(anns) == 0:
            return None

        if self._is_multi_box(input_box):
            # Union all per-box masks; they share one color, so an
            # element-wise maximum keeps overlaps at that same color.
            mask_image = None
            for mask in anns:
                layer = self.load_mask(mask.cpu().numpy(), False)
                mask_image = layer if mask_image is None else np.maximum(mask_image, layer)

            for box in boxes:
                image = self.load_box(box.cpu().numpy(), image)
        else:
            # NOTE(review): load_mask reshapes to a single (h, w) plane, so
            # this presumably only works when one mask is returned
            # (multimask_output=False) — confirm.
            mask_image = self.load_mask(anns, True)
            if input_box is not None:
                image = self.load_box(input_box, image)

        combined_mask = cv2.add(image, mask_image)
        cv2.imwrite(output_path, combined_mask)

        return output_path
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,178 @@ | ||
absl-py==1.3.0 | ||
addict==2.4.0 | ||
albumentations==1.3.0 | ||
alembic @ file:///home/conda/feedstock_root/build_artifacts/alembic_1675830701130/work | ||
antlr4-python3-runtime==4.9.3 | ||
argon2-cffi @ file:///home/conda/feedstock_root/build_artifacts/argon2-cffi_1640817743617/work | ||
argon2-cffi-bindings @ file:///D:/bld/argon2-cffi-bindings_1649500465986/work | ||
backcall @ file:///home/conda/feedstock_root/build_artifacts/backcall_1592338393461/work | ||
backports.functools-lru-cache @ file:///home/conda/feedstock_root/build_artifacts/backports.functools_lru_cache_1618230623929/work | ||
beautifulsoup4 @ file:///home/conda/feedstock_root/build_artifacts/beautifulsoup4_1675252249248/work | ||
black==22.12.0 | ||
bleach @ file:///home/conda/feedstock_root/build_artifacts/bleach_1674535352125/work | ||
brotlipy @ file:///D:/bld/brotlipy_1648854404735/work | ||
cachetools==5.2.0 | ||
certifi==2022.12.7 | ||
cffi @ file:///D:/bld/cffi_1656782956657/work | ||
charset-normalizer @ file:///home/conda/feedstock_root/build_artifacts/charset-normalizer_1661170624537/work | ||
click==7.1.2 | ||
clip @ git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1 | ||
cloudpickle==2.2.0 | ||
colorama @ file:///home/conda/feedstock_root/build_artifacts/colorama_1666700638685/work | ||
cryptography @ file:///D:/bld/cryptography_1657174118602/work | ||
cycler==0.11.0 | ||
Cython==0.29.32 | ||
dataclasses==0.6 | ||
debugpy @ file:///C:/ci/debugpy_1637091911212/work | ||
decorator @ file:///home/conda/feedstock_root/build_artifacts/decorator_1641555617451/work | ||
defusedxml @ file:///home/conda/feedstock_root/build_artifacts/defusedxml_1615232257335/work | ||
detectron2 @ git+https://github.com/facebookresearch/detectron2.git@857d5de21a7789d1bba46694cf608b1cb2ea128a | ||
easydict==1.10 | ||
entrypoints @ file:///home/conda/feedstock_root/build_artifacts/entrypoints_1643888246732/work | ||
fasttext==0.9.2 | ||
filelock==3.8.2 | ||
flit_core @ file:///home/conda/feedstock_root/build_artifacts/flit-core_1667734568827/work/source/flit_core | ||
fonttools==4.38.0 | ||
ftfy==6.1.1 | ||
future==0.18.2 | ||
fuzzywuzzy @ file:///home/conda/feedstock_root/build_artifacts/fuzzywuzzy_1605704184142/work | ||
fvcore==0.1.5.post20221221 | ||
google-auth==2.15.0 | ||
google-auth-oauthlib==0.4.6 | ||
graphviz==0.20.1 | ||
greenlet @ file:///D:/bld/greenlet_1661444189601/work | ||
grpcio==1.51.1 | ||
huggingface-hub==0.11.1 | ||
hydra-core==1.3.1 | ||
idna @ file:///home/conda/feedstock_root/build_artifacts/idna_1663625384323/work | ||
imageio==2.23.0 | ||
imgaug==0.4.0 | ||
imgviz==1.6.2 | ||
importlib-metadata==5.2.0 | ||
importlib-resources @ file:///home/conda/feedstock_root/build_artifacts/importlib_resources_1672681417544/work | ||
iopath==0.1.9 | ||
ipykernel @ file:///D:/bld/ipykernel_1655369313836/work | ||
ipython @ file:///D:/bld/ipython_1651240762678/work | ||
ipython-genutils==0.2.0 | ||
ipywidgets @ file:///home/conda/feedstock_root/build_artifacts/ipywidgets_1671720089366/work | ||
jedi @ file:///home/conda/feedstock_root/build_artifacts/jedi_1669134318875/work | ||
Jinja2 @ file:///home/conda/feedstock_root/build_artifacts/jinja2_1654302431367/work | ||
joblib==1.2.0 | ||
jsonschema==2.6.0 | ||
jupyter @ file:///D:/bld/jupyter_1637233295291/work | ||
jupyter-console @ file:///home/conda/feedstock_root/build_artifacts/jupyter_console_1655961255101/work | ||
jupyter_client @ file:///home/conda/feedstock_root/build_artifacts/jupyter_client_1673615989977/work | ||
jupyter_core @ file:///C:/b/abs_a9330r1z_i/croots/recipe/jupyter_core_1664917313457/work | ||
jupyterlab-pygments @ file:///home/conda/feedstock_root/build_artifacts/jupyterlab_pygments_1649936611996/work | ||
jupyterlab-widgets @ file:///home/conda/feedstock_root/build_artifacts/jupyterlab_widgets_1671722028097/work | ||
kiwisolver==1.4.4 | ||
labelme==5.1.1 | ||
lvis @ git+https://github.com/lvis-dataset/lvis-api.git@35f09cd7c5f313a9bf27b329ca80effe2b0c8a93 | ||
Mako @ file:///home/conda/feedstock_root/build_artifacts/mako_1668568582731/work | ||
Markdown==3.4.1 | ||
MarkupSafe @ file:///C:/ci/markupsafe_1654508076077/work | ||
matplotlib==3.5.3 | ||
matplotlib-inline @ file:///home/conda/feedstock_root/build_artifacts/matplotlib-inline_1660814786464/work | ||
mistune @ file:///home/conda/feedstock_root/build_artifacts/mistune_1675771498296/work | ||
-e d:\cvpr_2023_dataset\detpro | ||
mmpycocotools==12.0.3 | ||
mss==7.0.1 | ||
mypy-extensions==0.4.3 | ||
natsort==8.2.0 | ||
nbclient @ file:///home/conda/feedstock_root/build_artifacts/nbclient_1662750566673/work | ||
nbconvert @ file:///home/conda/feedstock_root/build_artifacts/nbconvert-meta_1674590374792/work | ||
nbformat @ file:///home/conda/feedstock_root/build_artifacts/nbformat_1617383142101/work | ||
nbgrader @ file:///D:/bld/nbgrader_1615406840332/work | ||
nest-asyncio @ file:///home/conda/feedstock_root/build_artifacts/nest-asyncio_1664684991461/work | ||
networkx==2.6.3 | ||
ninja==1.11.1 | ||
nltk==3.8 | ||
notebook @ file:///D:/bld/notebook_1616419450136/work | ||
numpy==1.21.6 | ||
oauthlib==3.2.2 | ||
omegaconf==2.3.0 | ||
opencv-python==4.6.0.66 | ||
opencv-python-headless==4.6.0.66 | ||
packaging @ file:///home/conda/feedstock_root/build_artifacts/packaging_1673482170163/work | ||
pandocfilters @ file:///home/conda/feedstock_root/build_artifacts/pandocfilters_1631603243851/work | ||
parso @ file:///home/conda/feedstock_root/build_artifacts/parso_1638334955874/work | ||
pathspec==0.10.3 | ||
pickleshare @ file:///home/conda/feedstock_root/build_artifacts/pickleshare_1602536217715/work | ||
Pillow==9.3.0 | ||
platformdirs==2.6.0 | ||
portalocker==2.6.0 | ||
prometheus-client @ file:///home/conda/feedstock_root/build_artifacts/prometheus_client_1674535637125/work | ||
prompt-toolkit @ file:///home/conda/feedstock_root/build_artifacts/prompt-toolkit_1670414775770/work | ||
protobuf==3.20.3 | ||
psutil @ file:///C:/Windows/Temp/abs_b2c2fd7f-9fd5-4756-95ea-8aed74d0039flsd9qufz/croots/recipe/psutil_1656431277748/work | ||
pyasn1==0.4.8 | ||
pyasn1-modules==0.2.8 | ||
pybind11==2.10.2 | ||
pycocotools==2.0.6 | ||
pycparser @ file:///home/conda/feedstock_root/build_artifacts/pycparser_1636257122734/work | ||
Pygments @ file:///home/conda/feedstock_root/build_artifacts/pygments_1672682006896/work | ||
pyOpenSSL @ file:///home/conda/feedstock_root/build_artifacts/pyopenssl_1663846997386/work | ||
pyparsing==3.0.9 | ||
PyQt5==5.15.4 | ||
pyqt5-plugins==5.15.4.2.2 | ||
PyQt5-Qt5==5.15.2 | ||
PyQt5-sip==12.11.0 | ||
pyqt5-tools==5.15.4.3.2 | ||
PyQtChart==5.12 | ||
PyQtWebEngine==5.12.1 | ||
PySocks @ file:///D:/bld/pysocks_1648857448477/work | ||
python-dateutil @ file:///home/conda/feedstock_root/build_artifacts/python-dateutil_1626286286081/work | ||
python-dotenv==0.21.0 | ||
python-Levenshtein @ file:///D:/bld/python-levenshtein_1649607057775/work | ||
PyWavelets==1.3.0 | ||
pywin32==305 | ||
pywinpty @ file:///C:/ci_310/pywinpty_1644230983541/work/target/wheels/pywinpty-2.0.2-cp37-none-win_amd64.whl | ||
PyYAML==6.0 | ||
pyzmq @ file:///C:/ci/pyzmq_1657597757589/work | ||
qt5-applications==5.15.2.2.2 | ||
qt5-tools==5.15.2.1.2 | ||
qtconsole @ file:///home/conda/feedstock_root/build_artifacts/qtconsole-base_1667404144336/work | ||
QtPy @ file:///home/conda/feedstock_root/build_artifacts/qtpy_1667873092748/work | ||
qudida==0.0.4 | ||
regex==2022.10.31 | ||
requests @ file:///home/conda/feedstock_root/build_artifacts/requests_1673863902341/work | ||
requests-oauthlib==1.3.1 | ||
rsa==4.9 | ||
scikit-image==0.19.3 | ||
scikit-learn==1.0.2 | ||
scipy==1.7.3 | ||
Send2Trash @ file:///home/conda/feedstock_root/build_artifacts/send2trash_1628511208346/work | ||
shapely==2.0.0 | ||
six @ file:///home/conda/feedstock_root/build_artifacts/six_1620240208055/work | ||
soupsieve @ file:///home/conda/feedstock_root/build_artifacts/soupsieve_1658207591808/work | ||
SQLAlchemy @ file:///D:/bld/sqlalchemy_1659998874450/work | ||
tabulate==0.9.0 | ||
tensorboard==2.11.0 | ||
tensorboard-data-server==0.6.1 | ||
tensorboard-plugin-wit==1.8.1 | ||
termcolor==2.1.1 | ||
terminado @ file:///D:/bld/terminado_1652790822513/work | ||
terminaltables==3.1.10 | ||
threadpoolctl==3.1.0 | ||
tifffile==2021.11.2 | ||
timm==0.6.12 | ||
tinycss2 @ file:///home/conda/feedstock_root/build_artifacts/tinycss2_1666100256010/work | ||
tomli==2.0.1 | ||
torch==1.13.1+cu116 | ||
torchaudio==0.13.1+cu116 | ||
torchvision==0.14.1+cu116 | ||
tornado @ file:///D:/bld/tornado_1656937938087/work | ||
tqdm==4.64.1 | ||
traitlets @ file:///home/conda/feedstock_root/build_artifacts/traitlets_1675110562325/work | ||
typed-ast==1.5.4 | ||
typing_extensions @ file:///home/conda/feedstock_root/build_artifacts/typing_extensions_1665144421445/work | ||
urllib3 @ file:///home/conda/feedstock_root/build_artifacts/urllib3_1673452138552/work | ||
wcwidth @ file:///home/conda/feedstock_root/build_artifacts/wcwidth_1673864653149/work | ||
webencodings==0.5.1 | ||
Werkzeug==2.2.2 | ||
widgetsnbextension @ file:///home/conda/feedstock_root/build_artifacts/widgetsnbextension_1672066693230/work | ||
win-inet-pton @ file:///D:/bld/win_inet_pton_1667051142467/work | ||
wincertstore==0.2 | ||
yacs==0.1.8 | ||
yapf==0.32.0 | ||
zipp @ file:///home/conda/feedstock_root/build_artifacts/zipp_1675982654259/work |
Oops, something went wrong.