This repository has been archived by the owner on May 24, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
96 additions
and
57 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
# Some basic setup: | ||
# Setup detectron2 logger | ||
import detectron2 | ||
import sys | ||
import os | ||
import argparse | ||
|
||
sys.path.insert(0, os.path.abspath('./detectron2')) | ||
|
||
|
||
from detectron2.utils.logger import setup_logger | ||
setup_logger() | ||
|
||
# import some common libraries | ||
import numpy as np | ||
import os, cv2 | ||
import matplotlib.pyplot as plt | ||
|
||
# import some common detectron2 utilities | ||
from detectron2 import model_zoo | ||
from detectron2.engine import DefaultPredictor | ||
from detectron2.config import get_cfg | ||
from detectron2.utils.visualizer import Visualizer | ||
|
||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument("--path2image", type=str) | ||
parser.add_argument("--path2save", type=str) | ||
args = parser.parse_args() | ||
|
||
cfg = get_cfg() | ||
cfg.merge_from_file(model_zoo.get_config_file("COCO-Detection/faster_rcnn_R_50_FPN_1x.yaml")) | ||
cfg.MODEL.WEIGHTS = "model_final.pth" # path to the model we just trained | ||
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.001 # set a custom testing threshold | ||
cfg.SOLVER.BASE_LR = 0.00025 | ||
cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128 | ||
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1 | ||
|
||
predictor = DefaultPredictor(cfg) | ||
|
||
im = cv2.imread(args.path2image) | ||
outputs = predictor(im) | ||
v = Visualizer( | ||
im[:, :, ::-1], | ||
) | ||
for box in outputs["instances"].pred_boxes.to('cpu'): | ||
v.draw_box(box) | ||
v.draw_text(str(box[:2].numpy()), tuple(box[:2].numpy())) | ||
v = v.get_output() | ||
img = v.get_image()[:, :, ::-1] | ||
plt.imshow(img) | ||
cv2.imwrite(args.path2save) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
import os | ||
|
||
os.system("python -m pip install pyyaml==5.1") | ||
os.system("python -m pip install fvcore") | ||
import sys, os, distutils.core | ||
os.system("git clone 'https://github.com/facebookresearch/detectron2'") | ||
dist = distutils.core.run_setup("./detectron2/setup.py") | ||
|
||
os.system("python -m pip install -e detectron2") | ||
|
||
import torch, detectron2 | ||
|
||
TORCH_VERSION = ".".join(torch.__version__.split(".")[:2]) | ||
CUDA_VERSION = torch.__version__.split("+")[-1] | ||
print("torch: ", TORCH_VERSION, "; cuda: ", CUDA_VERSION) | ||
print("detectron2:", detectron2.__version__) | ||
|
||
os.system("python inference.py") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters