Skip to content
This repository has been archived by the owner on May 24, 2024. It is now read-only.

Commit

Permalink
faster_rcnn
Browse files Browse the repository at this point in the history
  • Loading branch information
stazizov committed Dec 11, 2022
1 parent 17a5085 commit e09e590
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 57 deletions.
53 changes: 53 additions & 0 deletions faster_rcnn/inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Some basic setup:
# Setup detectron2 logger
import detectron2
import sys
import os
import argparse

sys.path.insert(0, os.path.abspath('./detectron2'))


from detectron2.utils.logger import setup_logger
setup_logger()

# import some common libraries
import numpy as np
import os, cv2
import matplotlib.pyplot as plt

# import some common detectron2 utilities
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--path2image", type=str)
parser.add_argument("--path2save", type=str)
args = parser.parse_args()

cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-Detection/faster_rcnn_R_50_FPN_1x.yaml"))
cfg.MODEL.WEIGHTS = "model_final.pth" # path to the model we just trained
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.001 # set a custom testing threshold
cfg.SOLVER.BASE_LR = 0.00025
cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1

predictor = DefaultPredictor(cfg)

im = cv2.imread(args.path2image)
outputs = predictor(im)
v = Visualizer(
im[:, :, ::-1],
)
for box in outputs["instances"].pred_boxes.to('cpu'):
v.draw_box(box)
v.draw_text(str(box[:2].numpy()), tuple(box[:2].numpy()))
v = v.get_output()
img = v.get_image()[:, :, ::-1]
plt.imshow(img)
cv2.imwrite(args.path2save)
18 changes: 18 additions & 0 deletions faster_rcnn/setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import os

os.system("python -m pip install pyyaml==5.1")
os.system("python -m pip install fvcore")
import sys, os, distutils.core
os.system("git clone 'https://github.com/facebookresearch/detectron2'")
dist = distutils.core.run_setup("./detectron2/setup.py")

os.system("python -m pip install -e detectron2")

import torch, detectron2

TORCH_VERSION = ".".join(torch.__version__.split(".")[:2])
CUDA_VERSION = torch.__version__.split("+")[-1]
print("torch: ", TORCH_VERSION, "; cuda: ", CUDA_VERSION)
print("detectron2:", detectron2.__version__)

os.system("python inference.py")
82 changes: 25 additions & 57 deletions notebooks/faster_rcnn.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,7 @@
"base_uri": "https://localhost:8080/"
},
"id": "FsePPpwZSmqt",
"outputId": "f7751403-b2be-4bec-c03e-77476b25863e",
"vscode": {
"languageId": "python"
}
"outputId": "f7751403-b2be-4bec-c03e-77476b25863e"
},
"outputs": [],
"source": [
Expand All @@ -40,10 +37,7 @@
"base_uri": "https://localhost:8080/"
},
"id": "0d288Z2mF5dC",
"outputId": "a015861b-c9e5-4a56-c8f5-ed9a1f30d80d",
"vscode": {
"languageId": "python"
}
"outputId": "a015861b-c9e5-4a56-c8f5-ed9a1f30d80d"
},
"outputs": [],
"source": [
Expand All @@ -59,10 +53,7 @@
"cell_type": "code",
"execution_count": 4,
"metadata": {
"id": "ZyAvNCJMmvFF",
"vscode": {
"languageId": "python"
}
"id": "ZyAvNCJMmvFF"
},
"outputs": [],
"source": [
Expand Down Expand Up @@ -112,10 +103,7 @@
"height": 496
},
"id": "dq9GY37ml1kr",
"outputId": "a1eebd6e-e572-4817-c682-9bd9cdb7f1fe",
"vscode": {
"languageId": "python"
}
"outputId": "a1eebd6e-e572-4817-c682-9bd9cdb7f1fe"
},
"outputs": [],
"source": [
Expand All @@ -141,10 +129,7 @@
"base_uri": "https://localhost:8080/"
},
"id": "HUjkwRsOn1O0",
"outputId": "750bfc51-c8d1-4f71-f605-9549abad7792",
"vscode": {
"languageId": "python"
}
"outputId": "750bfc51-c8d1-4f71-f605-9549abad7792"
},
"outputs": [],
"source": [
Expand Down Expand Up @@ -173,10 +158,7 @@
"base_uri": "https://localhost:8080/"
},
"id": "4Qg7zSVOulkb",
"outputId": "b3c3e5b1-44f0-4402-bd63-076af70ef442",
"vscode": {
"languageId": "python"
}
"outputId": "b3c3e5b1-44f0-4402-bd63-076af70ef442"
},
"outputs": [],
"source": [
Expand Down Expand Up @@ -204,10 +186,7 @@
"height": 1000
},
"id": "FDJJSo0yMgja",
"outputId": "2e945f2a-e35b-43ee-d691-33e1e8e88a30",
"vscode": {
"languageId": "python"
}
"outputId": "2e945f2a-e35b-43ee-d691-33e1e8e88a30"
},
"outputs": [],
"source": [
Expand All @@ -223,10 +202,7 @@
"cell_type": "code",
"execution_count": 10,
"metadata": {
"id": "PIbAM2pv-urF",
"vscode": {
"languageId": "python"
}
"id": "PIbAM2pv-urF"
},
"outputs": [],
"source": [
Expand All @@ -250,10 +226,7 @@
"cell_type": "code",
"execution_count": 11,
"metadata": {
"id": "2rR5wBAkO21S",
"vscode": {
"languageId": "python"
}
"id": "2rR5wBAkO21S"
},
"outputs": [],
"source": [
Expand All @@ -264,10 +237,7 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "UkNbUzUOLYf0",
"vscode": {
"languageId": "python"
}
"id": "UkNbUzUOLYf0"
},
"outputs": [],
"source": [
Expand Down Expand Up @@ -298,10 +268,7 @@
"base_uri": "https://localhost:8080/"
},
"id": "7unkuuiqLdqd",
"outputId": "bb9fd121-e584-4cf9-f748-b12b30731738",
"vscode": {
"languageId": "python"
}
"outputId": "bb9fd121-e584-4cf9-f748-b12b30731738"
},
"outputs": [],
"source": [
Expand Down Expand Up @@ -626,10 +593,7 @@
}
},
"id": "hBXeH8UXFcqU",
"outputId": "e68c28e1-14e0-4411-ffe9-37ebde681469",
"vscode": {
"languageId": "python"
}
"outputId": "e68c28e1-14e0-4411-ffe9-37ebde681469"
},
"outputs": [
{
Expand Down Expand Up @@ -677,10 +641,7 @@
"base_uri": "https://localhost:8080/"
},
"id": "Ya5nEuMELeq8",
"outputId": "47400704-a137-4b7d-ddbb-e844738126ec",
"vscode": {
"languageId": "python"
}
"outputId": "47400704-a137-4b7d-ddbb-e844738126ec"
},
"outputs": [],
"source": [
Expand All @@ -707,10 +668,7 @@
"height": 1000
},
"id": "U5LhISJqWXgM",
"outputId": "48a0f5b8-9093-444a-cb56-e4528e22bfbc",
"vscode": {
"languageId": "python"
}
"outputId": "48a0f5b8-9093-444a-cb56-e4528e22bfbc"
},
"outputs": [
{
Expand Down Expand Up @@ -749,8 +707,18 @@
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"display_name": "proton",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.10.6 | packaged by conda-forge | (main, Aug 22 2022, 20:38:29) [Clang 13.0.1 ]"
},
"vscode": {
"interpreter": {
"hash": "ae1f3d19af4a627e6fd227f47e41f2bdb26008c49e5df6c5bb5291576095ff68"
}
}
},
"nbformat": 4,
Expand Down

0 comments on commit e09e590

Please sign in to comment.