@@ -114,6 +115,7 @@ In addition, you can find a table of the basic supported fields, modalities, vie
NuClick
Segmentation
Classification
+
SAM2 (2D)
@@ -143,6 +145,7 @@ In addition, you can find a table of the basic supported fields, modalities, vie
DeepEdit
Tooltracking
InBody/OutBody
+
SAM2 (2D)
@@ -210,6 +213,19 @@ To install the _**latest features**_ using one of the following options:
docker run --gpus all --rm -ti --ipc=host --net=host projectmonai/monailabel:latest bash
+### SAM-2
+
+> By default, the [**SAM2**](https://github.com/facebookresearch/sam2/) model is included for all Apps when **_python >= 3.10_**.
+> - **sam_2d**: interactive segmentation of any organ or tissue on a given slice/2D image.
+> - **sam_3d**: SAM2 propagation over multiple slices (Radiology/MONAI-Bundle).
+
+If you install via `pip install monailabel`, it uses the [SAM-2](https://huggingface.co/facebook/sam2-hiera-large) models by default.
+
+To use [SAM-2.1](https://huggingface.co/facebook/sam2.1-hiera-large) instead, use one of the following options:
+ - Use the monailabel [Docker](https://hub.docker.com/r/projectmonai/monailabel) image instead of the pip package
+ - Run monailabel in dev mode (git checkout)
+ - If you installed monailabel via pip, uninstall the **_sam2_** package (`pip uninstall sam2`) and then run `pip install -r requirements.txt`, or install the latest **SAM-2** from its [github](https://github.com/facebookresearch/sam2/tree/main?tab=readme-ov-file#installation).
+
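+To quickly verify that the SAM2 models are registered (a minimal sketch; the app and studies paths below are placeholders):
+
+```bash
+# Download the radiology sample app and start the server
+monailabel apps --download --name radiology --output apps
+monailabel start_server --app apps/radiology --studies datasets --conf models segmentation
+
+# With python >= 3.10, sam_2d and sam_3d should appear in the server info
+curl -s http://127.0.0.1:8000/info/
+```
+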
## Step 2 MONAI Label Sample Applications
Radiology
diff --git a/docs/images/dsa.jpg b/docs/images/dsa.jpg
index 696763273..bfbbf741d 100644
Binary files a/docs/images/dsa.jpg and b/docs/images/dsa.jpg differ
diff --git a/docs/images/quickstart/monai-label-plugin-favorite-modules-1.png b/docs/images/quickstart/monai-label-plugin-favorite-modules-1.png
index 41919275a..d80e55b57 100644
Binary files a/docs/images/quickstart/monai-label-plugin-favorite-modules-1.png and b/docs/images/quickstart/monai-label-plugin-favorite-modules-1.png differ
diff --git a/docs/images/qupath.jpg b/docs/images/qupath.jpg
index 4d493a8fe..4525a1049 100644
Binary files a/docs/images/qupath.jpg and b/docs/images/qupath.jpg differ
diff --git a/docs/images/sample-apps/deepedit_brain_tumor.png b/docs/images/sample-apps/deepedit_brain_tumor.png
deleted file mode 100644
index ecb476635..000000000
Binary files a/docs/images/sample-apps/deepedit_brain_tumor.png and /dev/null differ
diff --git a/docs/images/sample-apps/deepedit_left_atrium.png b/docs/images/sample-apps/deepedit_left_atrium.png
deleted file mode 100644
index d4f15730c..000000000
Binary files a/docs/images/sample-apps/deepedit_left_atrium.png and /dev/null differ
diff --git a/docs/images/sample-apps/deepedit_left_ventricle.png b/docs/images/sample-apps/deepedit_left_ventricle.png
deleted file mode 100644
index 7fc509fbd..000000000
Binary files a/docs/images/sample-apps/deepedit_left_ventricle.png and /dev/null differ
diff --git a/docs/images/sample-apps/deepedit_lungs.png b/docs/images/sample-apps/deepedit_lungs.png
deleted file mode 100644
index 5ce16a961..000000000
Binary files a/docs/images/sample-apps/deepedit_lungs.png and /dev/null differ
diff --git a/docs/images/sample-apps/deepedit_spleen.png b/docs/images/sample-apps/deepedit_spleen.png
deleted file mode 100644
index fdbd5ab7c..000000000
Binary files a/docs/images/sample-apps/deepedit_spleen.png and /dev/null differ
diff --git a/docs/images/sample-apps/deepedit_vertebra.png b/docs/images/sample-apps/deepedit_vertebra.png
deleted file mode 100644
index 4e78b3478..000000000
Binary files a/docs/images/sample-apps/deepedit_vertebra.png and /dev/null differ
diff --git a/docs/images/sample-apps/segmentation_heart_ventricles.png b/docs/images/sample-apps/segmentation_heart_ventricles.png
deleted file mode 100644
index 75bbd3a41..000000000
Binary files a/docs/images/sample-apps/segmentation_heart_ventricles.png and /dev/null differ
diff --git a/monailabel/client/client.py b/monailabel/client/client.py
index 37082fe69..91d172abc 100644
--- a/monailabel/client/client.py
+++ b/monailabel/client/client.py
@@ -336,6 +336,7 @@ def infer(self, model, image_id, params, label_in=None, file=None, session_id=No
fields = {"params": json.dumps(params) if params else "{}"}
files = {"label": label_in} if label_in else {}
files.update({"file": file} if file and not session_id else {})
+ logger.debug(f"Files: {files}")
status, form, files, _ = MONAILabelUtils.http_multipart(
"POST", self._server_url, selector, fields, files, headers=self._headers
@@ -584,6 +585,9 @@ def send_response(conn, content_type="application/json"):
@staticmethod
def save_result(files, tmpdir):
+ if not files:
+ return None
+
for name in files:
data = files[name]
result_file = os.path.join(tmpdir, name)
diff --git a/monailabel/config.py b/monailabel/config.py
index bf195c7bc..25589a400 100644
--- a/monailabel/config.py
+++ b/monailabel/config.py
@@ -8,14 +8,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import os
+from importlib.metadata import distributions
from typing import Any, Dict, List, Optional
from pydantic import AnyHttpUrl
from pydantic_settings import BaseSettings, SettingsConfigDict
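+# Installing SAM-2 from the sam2 GitHub repo registers the distribution name "SAM-2",
+# while the default pip dependency is published as "sam2"; this check picks the matching checkpoints below.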
+def is_package_installed(name):
+ return name in {x.name for x in distributions()}
+
+
class Settings(BaseSettings):
MONAI_LABEL_API_STR: str = ""
MONAI_LABEL_PROJECT_NAME: str = "MONAILabel"
@@ -98,6 +102,19 @@ class Settings(BaseSettings):
MONAI_ZOO_REPO: str = "Project-MONAI/model-zoo/hosting_storage_v1"
MONAI_ZOO_AUTH_TOKEN: str = ""
+ # Refer: https://github.com/facebookresearch/sam2?tab=readme-ov-file#model-description
+ # Refer: https://huggingface.co/facebook/sam2-hiera-large
+ MONAI_SAM_MODEL_PT: str = (
+ "https://huggingface.co/facebook/sam2.1-hiera-large/resolve/main/sam2.1_hiera_large.pt"
+ if is_package_installed("SAM-2")
+ else "https://huggingface.co/facebook/sam2-hiera-large/resolve/main/sam2_hiera_large.pt"
+ )
+ MONAI_SAM_MODEL_CFG: str = (
+ "https://huggingface.co/facebook/sam2.1-hiera-large/resolve/main/sam2.1_hiera_l.yaml"
+ if is_package_installed("SAM-2")
+ else "https://huggingface.co/facebook/sam2-hiera-large/resolve/main/sam2_hiera_l.yaml"
+ )
+
model_config = SettingsConfigDict(
env_file=".env",
case_sensitive=True,
diff --git a/monailabel/endpoints/infer.py b/monailabel/endpoints/infer.py
index cecb4d7dd..abf2cc4cc 100644
--- a/monailabel/endpoints/infer.py
+++ b/monailabel/endpoints/infer.py
@@ -91,10 +91,8 @@ def send_response(datastore, result, output, background_tasks):
if output == "json":
return res_json
- m_type = get_mime_type(res_img)
-
if output == "image":
- return FileResponse(res_img, media_type=m_type, filename=os.path.basename(res_img))
+ return FileResponse(res_img, media_type=get_mime_type(res_img), filename=os.path.basename(res_img))
if output == "dicom_seg":
res_dicom_seg = result.get("dicom_seg")
@@ -106,7 +104,7 @@ def send_response(datastore, result, output, background_tasks):
res_fields = dict()
res_fields["params"] = (None, json.dumps(res_json), "application/json")
if res_img and os.path.exists(res_img):
- res_fields["image"] = (os.path.basename(res_img), open(res_img, "rb"), m_type)
+ res_fields["image"] = (os.path.basename(res_img), open(res_img, "rb"), get_mime_type(res_img))
else:
logger.info(f"Return only Result Json as Result Image is not available: {res_img}")
return res_json
diff --git a/monailabel/interfaces/tasks/infer_v2.py b/monailabel/interfaces/tasks/infer_v2.py
index 234452822..f87b8db7b 100644
--- a/monailabel/interfaces/tasks/infer_v2.py
+++ b/monailabel/interfaces/tasks/infer_v2.py
@@ -92,5 +92,5 @@ def is_valid(self) -> bool:
pass
@abstractmethod
- def __call__(self, request) -> Union[Dict, Tuple[str, Dict[str, Any]]]:
+ def __call__(self, request) -> Tuple[Union[str, None], Dict]:
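+ """Run inference for the given request; returns a tuple of (result file path or None, result JSON)."""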
pass
diff --git a/monailabel/sam2/__init__.py b/monailabel/sam2/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/monailabel/sam2/infer.py b/monailabel/sam2/infer.py
new file mode 100644
index 000000000..1ddca8ed3
--- /dev/null
+++ b/monailabel/sam2/infer.py
@@ -0,0 +1,464 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import copy
+import logging
+import os
+import pathlib
+import shutil
+import tempfile
+from datetime import timedelta
+from time import time
+from typing import Any, Dict, Tuple, Union
+
+import numpy as np
+import pylab
+import schedule
+import torch
+from hydra import initialize_config_dir
+from hydra.core.global_hydra import GlobalHydra
+from monai.transforms import KeepLargestConnectedComponent, LoadImaged
+from PIL import Image
+from sam2.build_sam import build_sam2, build_sam2_video_predictor
+from sam2.sam2_image_predictor import SAM2ImagePredictor
+from skimage.util import img_as_ubyte
+from timeloop import Timeloop
+from tqdm import tqdm
+
+from monailabel.config import settings
+from monailabel.interfaces.tasks.infer_v2 import InferTask, InferType
+from monailabel.interfaces.utils.transform import run_transforms
+from monailabel.transform.writer import Writer
+from monailabel.utils.others.generic import (
+ device_list,
+ download_file,
+ get_basename_no_ext,
+ md5_digest,
+ name_to_device,
+ remove_file,
+ strtobool,
+)
+
+logger = logging.getLogger(__name__)
+
+
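+# Caches per-image directories of JPEG slices for the SAM2 video predictor;
+# entries expire after 10 minutes and are swept by a background scheduler.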
+class ImageCache:
+ def __init__(self):
+ cache_path = settings.MONAI_LABEL_DATASTORE_CACHE_PATH
+ self.cache_path = (
+ os.path.join(cache_path, "sam2")
+ if cache_path
+ else os.path.join(pathlib.Path.home(), ".cache", "monailabel", "sam2")
+ )
+ self.cached_dirs = {}
+ self.cache_expiry_sec = 10 * 60
+
+ remove_file(self.cache_path)
+ os.makedirs(self.cache_path, exist_ok=True)
+ logger.info(f"Image Cache Initialized: {self.cache_path}")
+
+ def cleanup(self):
+ ts = time()
+ expired = {k: v for k, v in self.cached_dirs.items() if v < ts}
+ for k, v in expired.items():
+ self.cached_dirs.pop(k)
+ logger.info(f"Remove Expired Image: {k}; ExpiryTs: {v}; CurrentTs: {ts}")
+ remove_file(k)
+
+ def monitor(self):
+ self.cleanup()
+ time_loop = Timeloop()
+ schedule.every(1).minutes.do(self.cleanup)
+
+ @time_loop.job(interval=timedelta(seconds=60))
+ def run_scheduler():
+ schedule.run_pending()
+
+ time_loop.start(block=False)
+
+
+image_cache = ImageCache()
+image_cache.monitor()
+
+
+class Sam2InferTask(InferTask):
+ def __init__(
+ self,
+ model_dir,
+ type=InferType.ANNOTATION,
+ dimension=2,
+ labels=None,
+ additional_info=None,
+ image_loader=LoadImaged(keys="image"),
+ post_trans=None,
+ writer=Writer(ref_image="image"),
+ config=None,
+ ):
+ super().__init__(
+ type=type,
+ dimension=dimension,
+ labels=labels,
+ description="SAM2 (Segment Anything Model)",
+ config={"device": device_list(), "reset_state": False, "largest_cc": False, "pylab": False},
+ )
+ self.additional_info = additional_info
+ self.image_loader = image_loader
+ self.post_trans = post_trans
+ self.writer = writer
+ if config:
+ self._config.update(config)
+
+ # Download PreTrained Model
+ pt_url = settings.MONAI_SAM_MODEL_PT
+ conf_url = settings.MONAI_SAM_MODEL_CFG
+ sam_pt = pt_url.split("/")[-1]
+ sam_conf = conf_url.split("/")[-1]
+
+ self.path = os.path.join(model_dir, sam_pt)
+ self.config_path = os.path.join(model_dir, sam_conf)
+
+ GlobalHydra.instance().clear()
+ initialize_config_dir(config_dir=model_dir)
+
+ download_file(pt_url, self.path)
+ download_file(conf_url, self.config_path)
+ self.config_path = sam_conf  # hydra resolves the config by name, relative to the config dir initialized above
+
+ self.predictors = {}
+ self.image_cache = {}
+ self.inference_state = None
+
+ def info(self) -> Dict[str, Any]:
+ d = super().info()
+ if self.additional_info:
+ d.update(self.additional_info)
+ return d
+
+ def is_valid(self) -> bool:
+ return True
+
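+ # 2D path: pick the working slice, render it to RGB, then run SAM2ImagePredictor
+ # on the user's clicks/ROI (coordinates are flipped from x,y to y,x for SAM2)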
+ def run2d(self, image_tensor, request, debug=False):
+ device = name_to_device(request.get("device", "cuda"))
+ predictor = self.predictors.get(device)
+ if predictor is None:
+ logger.info(f"Using Device: {device}")
+ device_t = torch.device(device)
+ if device_t.type == "cuda":
+ torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
+ if torch.cuda.get_device_properties(0).major >= 8:
+ torch.backends.cuda.matmul.allow_tf32 = True
+ torch.backends.cudnn.allow_tf32 = True
+
+ sam2_model = build_sam2(self.config_path, self.path, device=device)
+ predictor = SAM2ImagePredictor(sam2_model)
+ self.predictors[device] = predictor
+
+ slice_idx = request.get("slice")
+ if slice_idx is None or slice_idx < 0:
+ slices = {p[2] for p in request["foreground"] if len(p) > 2}
+ slices.update({p[2] for p in request["background"] if len(p) > 2})
+ slices = list(slices)
+ slice_idx = slices[0] if len(slices) else -1
+ else:
+ slices = {slice_idx}
+
+ if slice_idx < 0 and len(request["roi"]) == 6:
+ slice_idx = round(request["roi"][4] + (request["roi"][5] - request["roi"][4]) // 2)
+ slices = {slice_idx}
+ logger.info(f"Slices: {slices}; Slice Index: {slice_idx}")
+
+ if slice_idx < 0:
+ slice_np = image_tensor.cpu().numpy()
+ slice_rgb_np = slice_np.astype(np.uint8) if np.max(slice_np) > 1 else img_as_ubyte(slice_np)
+ else:
+ slice_np = image_tensor[:, :, slice_idx].cpu().numpy()
+
+ if strtobool(request.get("pylab")):
+ slice_rgb_file = tempfile.NamedTemporaryFile(suffix=".jpg").name
+ pylab.imsave(slice_rgb_file, slice_np, format="jpg", cmap="Greys_r")
+ slice_rgb_np = np.array(Image.open(slice_rgb_file))
+ remove_file(slice_rgb_file)
+ else:
+ slice_rgb_np = np.array(Image.fromarray(slice_np).convert("RGB"))
+
+ logger.info(f"Slice Index:{slice_idx}; (Image) Slice Shape: {slice_np.shape}")
+ if debug:
+ logger.info(f"Slice {slice_np.shape} Type: {slice_np.dtype}; Max: {np.max(slice_np)}")
+ logger.info(f"Slice RGB {slice_rgb_np.shape} Type: {slice_rgb_np.dtype}; Max: {np.max(slice_rgb_np)}")
+ if slice_idx < 0 and image_tensor.meta.get("filename_or_obj"):
+ shutil.copy(image_tensor.meta["filename_or_obj"], "image.jpg")
+ else:
+ pylab.imsave("image.jpg", slice_np, format="jpg", cmap="Greys_r")
+ Image.fromarray(slice_rgb_np).save("slice.jpg")
+
+ predictor.reset_predictor()
+ predictor.set_image(slice_rgb_np)
+
+ location = request.get("location", (0, 0))
+ tx, ty = location[0], location[1]
+ fp = [[p[0] - tx, p[1] - ty] for p in request["foreground"]]
+ bp = [[p[0] - tx, p[1] - ty] for p in request["background"]]
+ roi = request.get("roi")
+ roi = [roi[0] - tx, roi[1] - ty, roi[2] - tx, roi[3] - ty] if roi else None
+
+ if debug:
+ slice_rgb_np_p = np.copy(slice_rgb_np)
+ if roi:
+ slice_rgb_np_p[roi[0] : roi[2], roi[1] : roi[3], 2] = 255
+ for k, ps in {1: fp, 0: bp}.items():
+ for p in ps:
+ slice_rgb_np_p[p[0] - 2 : p[0] + 2, p[1] - 2 : p[1] + 2, k] = 255
+ Image.fromarray(slice_rgb_np_p).save("slice_p.jpg")
+
+ point_coords = fp + bp
+ point_coords = [[p[1], p[0]] for p in point_coords] # Flip x,y => y,x
+ box = [roi[1], roi[0], roi[3], roi[2]] if roi else None
+
+ point_labels = [1] * len(fp) + [0] * len(bp)
+ logger.info(f"Coords: {point_coords}; Labels: {point_labels}; Box: {box}")
+
+ masks, scores, _ = predictor.predict(
+ point_coords=np.array(point_coords) if point_coords else None,
+ point_labels=np.array(point_labels) if point_labels else None,
+ multimask_output=False,
+ box=np.array(box) if box else None,
+ )
+ # sorted_ind = np.argsort(scores)[::-1]
+ # masks = masks[sorted_ind]
+ # scores = scores[sorted_ind]
+ if strtobool(request.get("largest_cc", False)):
+ masks = KeepLargestConnectedComponent()(masks).cpu().numpy()
+
+ logger.info(f"Masks Shape: {masks.shape}; Scores: {scores}")
+ if self.post_trans is None:
+ if slice_idx < 0:
+ pred = masks[0]
+ else:
+ pred = np.zeros(tuple(image_tensor.shape))
+ pred[:, :, slice_idx] = masks[0]
+
+ data = copy.copy(request)
+ data.update({"image_path": request["image"], "pred": pred, "image": image_tensor})
+ else:
+ data = copy.copy(request)
+ data.update({"image_path": request["image"], "pred": masks[0], "image": image_tensor})
+ data = run_transforms(data, self.post_trans, log_prefix="POST", use_compose=False)
+
+ if debug:
+ # pylab.imsave("mask.jpg", masks[0], format="jpg", cmap="Greys_r")
+ Image.fromarray(masks[0] > 0).save("mask.jpg")
+
+ return self.writer(data)
+
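+ # 3D path: flatten the volume into per-slice JPEG frames, initialize the video
+ # predictor state, seed clicks/ROI per slice, then propagate the mask across frames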
+ def run_3d(self, image_tensor, set_image_state, request, debug=False):
+ device = name_to_device(request.get("device", "cuda"))
+ reset_state = strtobool(request.get("reset_state", "false"))
+ predictor = self.predictors.get(device)
+ if predictor is None:
+ logger.info(f"Using Device: {device}")
+ device_t = torch.device(device)
+ if device_t.type == "cuda":
+ torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
+ if torch.cuda.get_device_properties(0).major >= 8:
+ torch.backends.cuda.matmul.allow_tf32 = True
+ torch.backends.cudnn.allow_tf32 = True
+
+ predictor = build_sam2_video_predictor(self.config_path, self.path, device=device)
+ self.predictors[device] = predictor
+
+ image_path = request["image"]
+ video_dir = os.path.join(
+ image_cache.cache_path, get_basename_no_ext(image_path) if debug else md5_digest(image_path)
+ )
+ if not os.path.isdir(video_dir):
+ os.makedirs(video_dir, exist_ok=True)
+ for slice_idx in tqdm(range(image_tensor.shape[-1])):
+ slice_np = image_tensor[:, :, slice_idx].numpy()
+ slice_file = os.path.join(video_dir, f"{str(slice_idx).zfill(5)}.jpg")
+
+ if strtobool(request.get("pylab")):
+ pylab.imsave(slice_file, slice_np, format="jpg", cmap="Greys_r")
+ else:
+ Image.fromarray(slice_np).convert("RGB").save(slice_file)
+ logger.info(f"Image (Flattened): {image_tensor.shape[-1]} slices; {video_dir}")
+
+ # Set Expiry Time
+ image_cache.cached_dirs[video_dir] = time() + image_cache.cache_expiry_sec
+
+ if reset_state or set_image_state:
+ if self.inference_state:
+ predictor.reset_state(self.inference_state)
+ self.inference_state = predictor.init_state(video_path=video_dir)
+
+ logger.info(f"Image Shape: {image_tensor.shape}")
+ fps: dict[int, Any] = {}
+ bps: dict[int, Any] = {}
+ sids = set()
+ for key in {"foreground", "background"}:
+ for p in request[key]:
+ sid = p[2]
+ sids.add(sid)
+ kps = fps if key == "foreground" else bps
+ if kps.get(sid):
+ kps[sid].append([p[0], p[1]])
+ else:
+ kps[sid] = [[p[0], p[1]]]
+
+ box = None
+ roi = request.get("roi")
+ if roi:
+ box = [roi[1], roi[0], roi[3], roi[2]]
+ sids.update(range(roi[4], roi[5]))
+
+ pred = np.zeros(tuple(image_tensor.shape))
+ for sid in sorted(sids):
+ fp = fps.get(sid, [])
+ bp = bps.get(sid, [])
+
+ point_coords = fp + bp
+ point_coords = [[p[1], p[0]] for p in point_coords] # Flip x,y => y,x
+ point_labels = [1] * len(fp) + [0] * len(bp)
+ # logger.info(f"{sid} - Coords: {point_coords}; Labels: {point_labels}; Box: {box}")
+
+ o_frame_ids, o_obj_ids, o_mask_logits = predictor.add_new_points_or_box(
+ inference_state=self.inference_state,
+ frame_idx=sid,
+ obj_id=1,
+ points=np.array(point_coords) if point_coords else None,
+ labels=np.array(point_labels) if point_labels else None,
+ box=np.array(box) if box else None,
+ )
+
+ # logger.info(f"{sid} - mask_logits: {o_mask_logits.shape}; frame_ids: {o_frame_ids}; obj_ids: {o_obj_ids}")
+ pred[:, :, sid] = (o_mask_logits[0][0] > 0.0).cpu().numpy()
+
+ for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(self.inference_state):
+ # logger.info(f"propagate: {out_frame_idx} - mask_logits: {out_mask_logits.shape}; obj_ids: {out_obj_ids}")
+ pred[:, :, out_frame_idx] = (out_mask_logits[0][0] > 0.0).cpu().numpy()
+
+ writer = Writer(ref_image="image")
+ data = copy.copy(request)
+ data.update({"image_path": request["image"], "pred": pred, "image": image_tensor})
+ return writer(data)
+
+ def __call__(self, request, debug=False) -> Tuple[Union[str, None], Dict]:
+ start_ts = time()
+
+ logger.info(f"Infer Request: {request}")
+ image_path = request["image"]
+ image_tensor = self.image_cache.get(image_path)
+ set_image_state = False
+ cache_image = request.get("cache_image", True)
+
+ if "foreground" not in request:
+ request["foreground"] = []
+ if "background" not in request:
+ request["background"] = []
+ if "roi" not in request:
+ request["roi"] = []
+
+ if not cache_image or image_tensor is None:
+ # TODO:: Fix this to cache more than one image session
+ self.image_cache.clear()
+ image_tensor = self.image_loader(request)["image"]
+ if debug:
+ logger.info(f"Image Meta: {image_tensor.meta}")
+ self.image_cache[image_path] = image_tensor
+ set_image_state = True
+
+ logger.info(f"Image Shape: {image_tensor.shape}; cached: {cache_image}")
+ if self.dimension == 2:
+ mask_file, result_json = self.run2d(image_tensor, request, debug)
+ else:
+ mask_file, result_json = self.run_3d(image_tensor, set_image_state, request)
+
+ logger.info(f"Mask File: {mask_file}; Latency: {round(time() - start_ts, 4)} sec")
+ result_json["latencies"] = {
+ "pre": 0,
+ "infer": 0,
+ "invert": 0,
+ "post": 0,
+ "write": 0,
+ "total": round(time() - start_ts, 2),
+ "transform": None,
+ }
+ return mask_file, result_json
+
+
+"""
+def main():
+ import shutil
+
+ logging.basicConfig(
+ level=logging.INFO,
+ format="[%(asctime)s] [%(process)s] [%(threadName)s] [%(levelname)s] (%(name)s:%(lineno)d) - %(message)s",
+ datefmt="%Y-%m-%d %H:%M:%S",
+ force=True,
+ )
+
+ app_name = "radiology"
+ app_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "sample-apps", app_name))
+ model_dir = os.path.join(app_dir, "model")
+ logger.info(f"Model Dir: {model_dir}")
+ if app_name == "pathology":
+ from lib.transforms import LoadImagePatchd
+
+ from monailabel.transform.post import FindContoursd
+ from monailabel.transform.writer import PolygonWriter
+
+ task = Sam2InferTask(
+ model_dir=model_dir,
+ dimension=2,
+ additional_info={"nuclick": True, "pathology": True},
+ image_loader=LoadImagePatchd(keys="image", padding=False),
+ post_trans=[FindContoursd(keys="pred")],
+ writer=PolygonWriter(),
+ )
+ request = {
+ "device": "cuda:1",
+ "reset_state": False,
+ "model": "sam2",
+ "image": "/home/sachi/Datasets/wsi/JP2K-33003-1.svs",
+ "output": "asap",
+ "level": 0,
+ "location": (2183, 4873),
+ "size": (128, 128),
+ "tile_size": [128, 128],
+ "min_poly_area": 30,
+ "foreground": [[2247, 4937]],
+ "background": [],
+ # "roi": [2220, 4900, 2320, 5000],
+ "max_workers": 1,
+ "id": 0,
+ "logging": "INFO",
+ "result_write_to_file": False,
+ "description": "SAM2 (Segment Anything Model)",
+ "save_label": False,
+ }
+ else:
+ task = Sam2InferTask(model_dir)
+ request = {
+ "image": "/home/sachi/Datasets/SAM2/image.nii.gz",
+ "foreground": [[71, 175, 105]], # [199, 129, 47], [200, 100, 41]],
+ # "background": [[286, 175, 105]],
+ "roi": [44, 110, 113, 239, 72, 178],
+ "largest_cc": True,
+ }
+
+ result = task(request, debug=True)
+ if app_name == "pathology":
+ print(result)
+ else:
+ shutil.move(result[0], "mask.nii.gz")
+
+
+if __name__ == "__main__":
+ main()
+"""
diff --git a/monailabel/sam2/utils.py b/monailabel/sam2/utils.py
new file mode 100644
index 000000000..8abaef244
--- /dev/null
+++ b/monailabel/sam2/utils.py
@@ -0,0 +1,9 @@
+from monai.utils import optional_import
+
+
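+# Guard for optional SAM2 support: True only if the "sam2" module can be imported.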
+def is_sam2_module_available():
+ try:
+ _, flag = optional_import("sam2")
+ return flag
+ except ImportError:
+ return False
diff --git a/monailabel/tasks/infer/basic_infer.py b/monailabel/tasks/infer/basic_infer.py
index 930bf6b85..7fef2d644 100644
--- a/monailabel/tasks/infer/basic_infer.py
+++ b/monailabel/tasks/infer/basic_infer.py
@@ -27,7 +27,7 @@
from monailabel.interfaces.utils.transform import dump_data, run_transforms
from monailabel.transform.cache import CacheTransformDatad
from monailabel.transform.writer import ClassificationWriter, DetectionWriter, Writer
-from monailabel.utils.others.generic import device_list, device_map, name_to_device
+from monailabel.utils.others.generic import device_list, device_map, name_to_device, strtobool
logger = logging.getLogger(__name__)
@@ -253,7 +253,7 @@ def detector(self, data=None) -> Optional[Callable]:
def __call__(
self, request, callbacks: Union[Dict[CallBackTypes, Any], None] = None
- ) -> Union[Dict, Tuple[str, Dict[str, Any]]]:
+ ) -> Tuple[Union[str, None], Dict]:
"""
It provides basic implementation to run the following in order
- Run Pre Transforms
@@ -322,8 +322,8 @@ def __call__(
data = callback_run_post_transforms(data)
latency_post = time.time() - start
- if self.skip_writer:
- return dict(data)
+ if self.skip_writer or strtobool(data.get("skip_writer")):
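+ # writer is skipped: return None for the file to keep the (mask_file, result_json) contract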
+ return None, dict(data)
start = time.time()
result_file_name, result_json = self.writer(data)
diff --git a/monailabel/tasks/infer/bundle.py b/monailabel/tasks/infer/bundle.py
index 120426bd5..3023f0fc5 100644
--- a/monailabel/tasks/infer/bundle.py
+++ b/monailabel/tasks/infer/bundle.py
@@ -137,15 +137,14 @@ def __init__(
# labels = ({v.lower(): int(k) for k, v in pred.get("channel_def", {}).items() if v.lower() != "background"})
labels = {}
+ type = self._get_type(os.path.basename(path), type)
for k, v in pred.get("channel_def", {}).items():
- if (not type.lower() == "deepedit") and (v.lower() != "background"):
- labels[v.lower()] = int(k)
- else:
+ logger.info(f"Model: {os.path.basename(path)}; Type: {type}; Label: {v} => {k}")
+ if v.lower() != "background" or type.lower() == "deepedit":
labels[v.lower()] = int(k)
description = metadata.get("description")
spatial_shape = image.get("spatial_shape")
dimension = len(spatial_shape) if spatial_shape else 3
- type = self._get_type(os.path.basename(path), type)
# if detection task, set post restore to False by default.
self.add_post_restore = False if type == "detection" else add_post_restore
@@ -273,28 +272,21 @@ def post_transforms(self, data=None) -> Sequence[Callable]:
return post
def _get_type(self, name, type):
+ if type:
+ return type
+
name = name.lower() if name else ""
- return (
- (
- InferType.DEEPEDIT
- if "deepedit" in name
- else (
- InferType.DEEPGROW
- if "deepgrow" in name
- else (
- InferType.DETECTION
- if "detection" in name
- else (
- InferType.SEGMENTATION
- if "segmentation" in name
- else InferType.CLASSIFICATION if "classification" in name else InferType.SEGMENTATION
- )
- )
- )
- )
- if not type
- else type
- )
+ if "deepedit" in name:
+ return InferType.DEEPEDIT
+ if "deepgrow" in name:
+ return InferType.DEEPGROW
+ if "detection" in name:
+ return InferType.DETECTION
+ if "segmentation" in name:
+ return InferType.SEGMENTATION
+ if "classification" in name:
+ return InferType.CLASSIFICATION
+ return InferType.SEGMENTATION
def _filter_transforms(self, transforms, filters):
if not filters or not transforms:
diff --git a/monailabel/tasks/scoring/epistemic_v2.py b/monailabel/tasks/scoring/epistemic_v2.py
index ee1849033..c040b844b 100644
--- a/monailabel/tasks/scoring/epistemic_v2.py
+++ b/monailabel/tasks/scoring/epistemic_v2.py
@@ -173,7 +173,7 @@ def run_scoring(self, image_id, simulation_size, model_ts, datastore):
accum_unl_outputs = []
for i in range(simulation_size):
- data = self.infer_task(request=request)
+ _, data = self.infer_task(request=request)
pred = data[self.infer_task.output_label_key] if isinstance(data, dict) else None
if pred is not None:
logger.debug(f"EPISTEMIC:: {image_id} => {i} => pred: {pred.shape}; sum: {np.sum(pred)}")
diff --git a/monailabel/transform/post.py b/monailabel/transform/post.py
index 615040474..cf4abdf96 100644
--- a/monailabel/transform/post.py
+++ b/monailabel/transform/post.py
@@ -129,7 +129,7 @@ def __call__(self, data):
spatial_size = spatial_shape[-len(current_size) :]
# Undo Spacing
- if torch.any(torch.Tensor(np.not_equal(current_size, spatial_size))):
+ if np.any(np.not_equal(current_size, spatial_size)):
resizer = Resize(spatial_size=spatial_size, mode=self.mode[idx])
result = resizer(result, mode=self.mode[idx], align_corners=self.align_corners[idx])
diff --git a/monailabel/utils/others/generic.py b/monailabel/utils/others/generic.py
index 0d741598e..a26edec63 100644
--- a/monailabel/utils/others/generic.py
+++ b/monailabel/utils/others/generic.py
@@ -241,8 +241,8 @@ def _list_files(d, ext):
]
-def strtobool(str):
- return bool(distutils.util.strtobool(str))
+def strtobool(s):
+    if s is None:
+        return False
+    if isinstance(s, bool):
+        return s
+    return bool(distutils.util.strtobool(s))
def is_openslide_supported(name):
@@ -336,7 +336,7 @@ def get_bundle_models(app_dir, conf, conf_key="models"):
zoo_source = conf.get("zoo_source", settings.MONAI_ZOO_SOURCE)
models = conf.get(conf_key)
- models = models.split(",")
+ models = models.split(",") if models else []
models = [m.strip() for m in models]
if zoo_source == "monaihosting": # if in github env, access model zoo
diff --git a/monailabel/utils/others/pathology.py b/monailabel/utils/others/pathology.py
index 78ecb922d..03e6317eb 100644
--- a/monailabel/utils/others/pathology.py
+++ b/monailabel/utils/others/pathology.py
@@ -129,7 +129,7 @@ def create_asap_annotations_xml(json_data, loglevel="INFO"):
label = element["label"]
color = to_hex(color_map.get(label))
- logger.info(f"Adding Contours for label: {label}; color: {color}; color_map: {color_map}")
+ logger.debug(f"Adding Contours for label: {label}; color: {color}; color_map: {color_map}")
labels[label] = color
contours = element["contours"]
diff --git a/plugins/cvat/README.md b/plugins/cvat/README.md
index 587b0e4cb..38b934309 100644
--- a/plugins/cvat/README.md
+++ b/plugins/cvat/README.md
@@ -30,40 +30,51 @@ To install CVAT and enable Semi-Automatic and Automatic Annotation, follow these
```bash
git clone https://github.com/opencv/cvat
cd cvat
-git checkout v2.1.0 # MONAI Label requires tag v2.1.0
# Use your external IP instead of localhost to make the CVAT projects sharable
-export CVAT_HOST=127.0.0.1
-export CVAT_VERSION=v2.1.0
+export CVAT_HOST=`hostname -I | awk '{print $1}'`
# Start CVAT from docker-compose, make sure the IP and port are available.
+# Refer: https://docs.cvat.ai/docs/administration/advanced/installation_automatic_annotation/
docker-compose -f docker-compose.yml -f components/serverless/docker-compose.serverless.yml up -d
# Create a CVAT superuser account
-docker exec -it cvat bash -ic 'python3 ~/manage.py createsuperuser'
-
+docker exec -it cvat_server bash -ic 'python3 ~/manage.py createsuperuser'
```
**Note:** The setup process uses ports 8070, 8080, and 8090. If alternative ports are preferred, please refer to the [CVAT Guide](https://opencv.github.io/cvat/docs/administration/basics/installation/). For more information on installation steps, see the CVAT [Documentation for Semi-automatic and Automatic Annotation](https://opencv.github.io/cvat/docs/administration/advanced/installation_automatic_annotation/).
After completing these steps, CVAT should be accessible via http://127.0.0.1:8080 in Chrome. Use the superuser account created during installation to log in.
+
+
#### Setup Nuclio Container Platform
```bash
# Get Nuclio dashboard
-wget https://github.com/nuclio/nuclio/releases/download/1.5.16/nuctl-1.5.16-linux-amd64
-chmod +x nuctl-1.5.16-linux-amd64
-ln -sf $(pwd)/nuctl-1.5.16-linux-amd64 /usr/local/bin/nuctl
+export NUCLIO_VERSION=1.13.0
+wget https://github.com/nuclio/nuclio/releases/download/$NUCLIO_VERSION/nuctl-$NUCLIO_VERSION-linux-amd64
+chmod +x nuctl-$NUCLIO_VERSION-linux-amd64
+ln -sf $(pwd)/nuctl-$NUCLIO_VERSION-linux-amd64 /usr/local/bin/nuctl
```
-#### Deployment of Endoscopy Models
+#### Deployment of Endoscopy/SAM2 Models
This step is to deploy MONAI Label plugin with endoscopic models using Nuclio tool.
+> **Prerequisite:** The MONAI Label Server is up and running with the _**endoscopy**_ app.
```bash
+# MONAI Label Server endpoint (make sure this host/IP is accessible from inside the docker containers)
+export MONAI_LABEL_SERVER=http://`hostname -I | awk '{print $1}'`:8000
+
git clone https://github.com/Project-MONAI/MONAILabel.git
+
# Deploy all endoscopy models
./plugins/cvat/deploy.sh endoscopy
+
# Or to deploy specific function and model, e.g., tooltracking
./plugins/cvat/deploy.sh endoscopy tooltracking
+
+# Deploy SAM2 Interactor
+./plugins/cvat/deploy.sh sam2 interactor
+
```
After model deployment, you can see the model names in the `Models` page of CVAT.
@@ -77,15 +88,4 @@ To check or monitor the status of deployed function containers, you can open the
That's it! With these steps, you should have successfully installed CVAT with the MONAI Label extension and deployed endoscopic models using the Nuclio tool.
### Publish Latest Model to CVAT/Nuclio
-Once you've fine-tuned the model and confirmed that it meets all the necessary conditions, you can push the updated model to the CVAT/Nuclio function container. This will allow you to use the latest version of the model in your workflows and applications.
-
-```bash
-workspace/endoscopy/update_cvat_model.sh
-
-# Bundle Example: publish tool tracking bundle trained model (run this command on the node where cvat/nuclio containers are running)
-workspace/endoscopy/update_cvat_model.sh tootracking
-# Bundle Example: publish inbody trained model
-workspace/endoscopy/update_cvat_model.sh inbody
-# DeepEdit Example: publish deepedit trained model (Not from bundle)
-workspace/endoscopy/update_cvat_model.sh deepedit
-```
+> Publishing the model to CVAT is not needed; the model is always served via the MONAI Label Server.
diff --git a/plugins/cvat/deploy.sh b/plugins/cvat/deploy.sh
index d3b91af6a..6b9f5c243 100755
--- a/plugins/cvat/deploy.sh
+++ b/plugins/cvat/deploy.sh
@@ -18,6 +18,7 @@ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" >/dev/null 2>&1 && pwd)"
FUNCTION=${1:-**}
MODEL=${2:-*}
FUNCTIONS_DIR=${3:-$SCRIPT_DIR}
+MONAI_LABEL_SERVER="${MONAI_LABEL_SERVER:-http://`hostname -I | awk '{print $1}'`:8000}"
nuctl create project cvat
@@ -26,8 +27,12 @@ shopt -s globstar
for func_config in "$FUNCTIONS_DIR"/$FUNCTION/${MODEL}.yaml
do
func_root="$FUNCTIONS_DIR"
+ echo "Using MONAI Label Server: $MONAI_LABEL_SERVER"
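+ # replace the placeholder server URL in the function config; the backup is restored after deploy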
+ cp $func_config ${func_config}.bak
+ sed -i "s|http://monailabel.com|$MONAI_LABEL_SERVER|g" $func_config
echo "Deploying $func_config..."
nuctl deploy --project-name cvat --path "$func_root" --file "$func_config" --platform local
+ mv ${func_config}.bak $func_config
done
nuctl get function
diff --git a/plugins/cvat/detector.py b/plugins/cvat/detector.py
new file mode 100644
index 000000000..e099a31c5
--- /dev/null
+++ b/plugins/cvat/detector.py
@@ -0,0 +1,145 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import base64
+import io
+import json
+import logging
+import os
+import tempfile
+
+import numpy as np
+from PIL import Image
+
+from monailabel.client import MONAILabelClient
+
+logging.basicConfig(
+ level=logging.INFO,
+ format="[%(asctime)s] [%(process)s] [%(threadName)s] [%(levelname)s] (%(name)s:%(lineno)d) - %(message)s",
+ datefmt="%Y-%m-%d %H:%M:%S",
+)
+
+
+def init_context(context):
+ context.logger.info("Init context... 0%")
+ server = os.environ.get("MONAI_LABEL_SERVER", "http://0.0.0.0:8000")
+ model = os.environ.get("MONAI_LABEL_MODEL", "tooltracking")
+ client = MONAILabelClient(server)
+
+ info = client.info()
+ model_info = info["models"][model] if info and info["models"] else None
+ context.logger.info(f"Monai Label Info: {model_info}")
+ assert model_info
+
+ context.user_data.model = model
+ context.user_data.model_handler = client
+ context.logger.info("Init context...100%")
+
+
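+# Decode the base64 image from CVAT, run MONAI Label inference (output=json),
+# and convert the prediction (classification tag or polygon contours) into CVAT results.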
+def handler(context, event):
+ model: str = context.user_data.model
+ client: MONAILabelClient = context.user_data.model_handler
+ context.logger.info(f"Run model: {model}")
+
+ data = event.body
+ image = Image.open(io.BytesIO(base64.b64decode(data["image"])))
+ context.logger.info(f"Image: {image.size}")
+
+ image_file = tempfile.NamedTemporaryFile(suffix=".jpg").name
+ image.save(image_file)
+
+ params = {"output": "json"}
+ _, output_json = client.infer(model=model, image_id="", file=image_file, params=params)
+ if isinstance(output_json, (str, bytes)):
+ output_json = json.loads(output_json)
+
+ results = []
+ prediction = output_json.get("prediction")
+ if prediction:
+ context.logger.info(f"(Classification) Prediction: {prediction}")
+ # CVAT Limitation:: tag is not yet supported https://github.com/opencv/cvat/issues/4212
+ # CVAT Limitation:: select highest score and create bbox to represent as tag
+ e = None
+ for element in prediction:
+ if element["score"] > 0:
+ e = element if e is None or element["score"] > e["score"] else e
+ context.logger.info(f"New Max Element: {e}")
+
+ context.logger.info(f"Final Element with Max Score: {e}")
+ if e:
+ results.append(
+ {
+ "label": e["label"],
+ "confidence": e["score"],
+ "type": "rectangle",
+ "points": [0, 0, image.size[0] - 1, image.size[1] - 1],
+ }
+ )
+ context.logger.info(f"(Classification) Results: {results}")
+ else:
+ annotations = output_json.get("annotations")
+ for a in annotations:
+ annotation = a.get("annotation", {})
+ if not annotation:
+ continue
+
+ elements = annotation.get("elements", [])
+ for element in elements:
+ label = element["label"]
+ contours = element["contours"]
+ for contour in contours:
+ points = np.array(contour, int)
+ results.append(
+ {
+ "label": label,
+ "points": points.flatten().tolist(),
+ "type": "polygon",
+ }
+ )
+
+ context.logger.info("=============================================================================\n")
+ return context.Response(
+ body=json.dumps(results),
+ headers={},
+ content_type="application/json",
+ status_code=200,
+ )
+
+
+if __name__ == "__main__":
+ import logging
+ from argparse import Namespace
+
+ logging.basicConfig(
+ level=logging.INFO,
+ format="[%(asctime)s] [%(process)s] [%(threadName)s] [%(levelname)s] (%(name)s:%(lineno)d) - %(message)s",
+ datefmt="%Y-%m-%d %H:%M:%S",
+ )
+
+ def print_all(*args, **kwargs):
+ return {"args": args, **kwargs}
+
+ with open("/home/sachi/Datasets/endo/frame001.jpg", "rb") as fp:
+ image = base64.b64encode(fp.read())
+
+ event = {"body": {"image": image}}
+ event = Namespace(**event)
+
+ context = Namespace(
+ **{
+ "logger": logging.getLogger(__name__),
+ "user_data": Namespace(**{"model": None, "model_handler": None}),
+ "Response": print_all,
+ }
+ )
+ init_context(context)
+ response = handler(context, event)
+ logging.info(response)
diff --git a/plugins/cvat/endoscopy/deepedit.yaml b/plugins/cvat/endoscopy/deepedit.yaml
index ba383f0fb..494e8a37a 100644
--- a/plugins/cvat/endoscopy/deepedit.yaml
+++ b/plugins/cvat/endoscopy/deepedit.yaml
@@ -14,8 +14,8 @@ metadata:
namespace: cvat
annotations:
name: DeepEdit
+ version: 2
type: interactor
- framework: pytorch
spec:
min_pos_points: 1
min_neg_points: 0
@@ -24,7 +24,7 @@ metadata:
spec:
description: A pre-trained DeepEdit model for interactive model for Endoscopy
runtime: 'python:3.8'
- handler: main:handler
+ handler: interactor:handler
eventTimeout: 30s
build:
@@ -34,17 +34,9 @@ spec:
directives:
preCopy:
- kind: ENV
- value: MONAI_LABEL_APP_DIR=/usr/local/monailabel/sample-apps/endoscopy
+ value: MONAI_LABEL_SERVER=http://monailabel.com
- kind: ENV
- value: MONAI_LABEL_MODELS=deepedit
- - kind: ENV
- value: PYTHONPATH=/usr/local/monailabel/sample-apps/endoscopy
- - kind: ENV
- value: MONAI_PRETRAINED_PATH=https://github.com/Project-MONAI/MONAILabel/releases/download/data
- - kind: ENV
- value: INTERACTOR_MODEL=true
- - kind: ENV
- value: MONAI_LABEL_FLIP_INPUT_POINTS=false
+ value: MONAI_LABEL_MODEL=deepedit
triggers:
myHttpTrigger:
@@ -53,11 +45,6 @@ spec:
workerAvailabilityTimeoutMilliseconds: 10000
attributes:
maxRequestBodySize: 33554432 # 32MB
- port: 8902
-
- resources:
- limits:
- nvidia.com/gpu: 1
platform:
attributes:
diff --git a/plugins/cvat/endoscopy/inbody.yaml b/plugins/cvat/endoscopy/inbody.yaml
index 076242cf5..48586b49c 100644
--- a/plugins/cvat/endoscopy/inbody.yaml
+++ b/plugins/cvat/endoscopy/inbody.yaml
@@ -25,7 +25,7 @@ metadata:
spec:
description: A pre-trained classification model for Endoscopy to flag if image follows InBody or OutBody
runtime: 'python:3.8'
- handler: main:handler
+ handler: detector:handler
eventTimeout: 30s
build:
@@ -35,13 +35,9 @@ spec:
directives:
preCopy:
- kind: ENV
- value: MONAI_LABEL_APP_DIR=/usr/local/monailabel/sample-apps/endoscopy
+ value: MONAI_LABEL_SERVER=http://monailabel.com
- kind: ENV
- value: MONAI_LABEL_MODELS=inbody
- - kind: ENV
- value: PYTHONPATH=/usr/local/monailabel/sample-apps/endoscopy
- - kind: ENV
- value: MONAI_PRETRAINED_PATH=https://github.com/Project-MONAI/MONAILabel/releases/download/data
+ value: MONAI_LABEL_MODEL=inbody
triggers:
myHttpTrigger:
@@ -50,11 +46,6 @@ spec:
workerAvailabilityTimeoutMilliseconds: 10000
attributes:
maxRequestBodySize: 33554432 # 32MB
- port: 8901
-
- resources:
- limits:
- nvidia.com/gpu: 1
platform:
attributes:
diff --git a/plugins/cvat/endoscopy/tooltracking.yaml b/plugins/cvat/endoscopy/tooltracking.yaml
index 584053f45..1cac8673e 100644
--- a/plugins/cvat/endoscopy/tooltracking.yaml
+++ b/plugins/cvat/endoscopy/tooltracking.yaml
@@ -24,7 +24,7 @@ metadata:
spec:
description: A pre-trained tool tracking model for Endoscopy
runtime: 'python:3.8'
- handler: main:handler
+ handler: detector:handler
eventTimeout: 30s
build:
@@ -34,13 +34,9 @@ spec:
directives:
preCopy:
- kind: ENV
- value: MONAI_LABEL_APP_DIR=/usr/local/monailabel/sample-apps/endoscopy
+ value: MONAI_LABEL_SERVER=http://monailabel.com
- kind: ENV
- value: MONAI_LABEL_MODELS=tooltracking
- - kind: ENV
- value: PYTHONPATH=/usr/local/monailabel/sample-apps/endoscopy
- - kind: ENV
- value: MONAI_PRETRAINED_PATH=https://github.com/Project-MONAI/MONAILabel/releases/download/data
+ value: MONAI_LABEL_MODEL=tooltracking
triggers:
myHttpTrigger:
@@ -49,11 +45,6 @@ spec:
workerAvailabilityTimeoutMilliseconds: 10000
attributes:
maxRequestBodySize: 33554432 # 32MB
- port: 8900
-
- resources:
- limits:
- nvidia.com/gpu: 1
platform:
attributes:
diff --git a/plugins/cvat/interactor.py b/plugins/cvat/interactor.py
new file mode 100644
index 000000000..19ac8d636
--- /dev/null
+++ b/plugins/cvat/interactor.py
@@ -0,0 +1,122 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import base64
+import io
+import json
+import logging
+import os
+import tempfile
+
+import numpy as np
+from PIL import Image
+
+from monailabel.client import MONAILabelClient
+
+logging.basicConfig(
+ level=logging.INFO,
+ format="[%(asctime)s] [%(process)s] [%(threadName)s] [%(levelname)s] (%(name)s:%(lineno)d) - %(message)s",
+ datefmt="%Y-%m-%d %H:%M:%S",
+)
+
+
+def init_context(context):
+ context.logger.info("Init context... 0%")
+ server = os.environ.get("MONAI_LABEL_SERVER", "http://0.0.0.0:8000")
+ model = os.environ.get("MONAI_LABEL_MODEL", "sam2")
+ client = MONAILabelClient(server)
+
+ info = client.info()
+ model_info = info["models"][model] if info and info["models"] else None
+ context.logger.info(f"Monai Label Info: {model_info}")
+ assert model_info
+
+ context.user_data.model = model
+ context.user_data.model_handler = client
+ context.logger.info("Init context...100%")
+
+
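+# Decode the image and the positive/negative clicks from CVAT, run MONAI Label
+# inference (output=mask), and return the mask as a nested list per the CVAT interactor API.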
+def handler(context, event):
+ model: str = context.user_data.model
+ client: MONAILabelClient = context.user_data.model_handler
+ context.logger.info(f"Run model: {model}")
+
+ data = event.body
+ image = Image.open(io.BytesIO(base64.b64decode(data["image"])))
+ foreground = data.get("pos_points")
+ background = data.get("neg_points")
+ roi = data.get("obj_bbox", None)
+ context.logger.info(f"Image: {image.size}; Foreground: {foreground}; Background: {background}")
+
+ image_file = tempfile.NamedTemporaryFile(suffix=".jpg").name
+ image.save(image_file)
+
+ params = {
+ "output": "mask",
+ "foreground": np.asarray(foreground, dtype=int).tolist() if foreground else [],
+ "background": np.asarray(background, dtype=int).tolist() if background else [],
+ # "largest_cc": True,
+ }
+ if roi:
+ roi = np.asarray(roi, dtype=int).flatten().tolist()
+ params["roi"] = roi
+
+ context.logger.info(f"Model:{model}; Params: {params}")
+ output_mask, output_json = client.infer(model=model, image_id="", file=image_file, params=params)
+ if isinstance(output_json, (str, bytes)):
+ output_json = json.loads(output_json)
+ # context.logger.info(f"Mask File: {output_mask}")
+
+ mask_im = Image.open(output_mask)
+ mask_np = np.array(mask_im).astype(np.uint8)
+ os.remove(output_mask)
+ os.remove(image_file)
+
+ resp = {"mask": mask_np.tolist()}
+ context.logger.info(f"Image: {image.size}; Mask: {mask_im.size} vs {mask_np.shape}; JSON: {output_json}")
+
+ context.logger.info("=============================================================================\n")
+ return context.Response(
+ body=json.dumps(resp),
+ headers={},
+ content_type="application/json",
+ status_code=200,
+ )
+
+
+if __name__ == "__main__":
+ import logging
+ from argparse import Namespace
+
+ logging.basicConfig(
+ level=logging.INFO,
+ format="[%(asctime)s] [%(process)s] [%(threadName)s] [%(levelname)s] (%(name)s:%(lineno)d) - %(message)s",
+ datefmt="%Y-%m-%d %H:%M:%S",
+ )
+
+ with open("/home/sachi/Datasets/endo/frame001.jpg", "rb") as fp:
+ image = base64.b64encode(fp.read())
+
+ event = Namespace(**{"body": {"image": image, "pos_points": [[1209, 493]]}})
+
+ def print_all(*args, **kwargs):
+ return {"args": args, **kwargs}
+
+ context = Namespace(
+ **{
+ "logger": logging.getLogger(__name__),
+ "user_data": Namespace(**{"model": None, "model_handler": None}),
+ "Response": print_all,
+ }
+ )
+ init_context(context)
+ response = handler(context, event)
+ # logging.info(response)
diff --git a/plugins/cvat/main.py b/plugins/cvat/main.py
deleted file mode 100644
index aa14d6d14..000000000
--- a/plugins/cvat/main.py
+++ /dev/null
@@ -1,184 +0,0 @@
-# Copyright (c) MONAI Consortium
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-# http://www.apache.org/licenses/LICENSE-2.0
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import base64
-import io
-import json
-import logging
-import os
-from distutils.util import strtobool
-
-import numpy as np
-from PIL import Image
-
-from monailabel.interfaces.utils.app import app_instance
-
-logging.basicConfig(
- level=logging.INFO,
- format="[%(asctime)s] [%(process)s] [%(threadName)s] [%(levelname)s] (%(name)s:%(lineno)d) - %(message)s",
- datefmt="%Y-%m-%d %H:%M:%S",
-)
-
-
-def init_context(context):
- context.logger.info("Init context... 0%")
-
- app_dir = os.environ.get("MONAI_LABEL_APP_DIR", "/opt/conda/monailabel/sample-apps/pathology")
- studies = os.environ.get("MONAI_LABEL_STUDIES", "/opt/monailabel/studies")
- model = os.environ.get("MONAI_LABEL_MODELS", "segmentation_nuclei")
- pretrained_path = os.environ.get(
- "MONAI_PRETRAINED_PATH", "https://github.com/Project-MONAI/MONAILabel/releases/download/data"
- )
- conf = {"preload": "true", "models": model, "pretrained_path": pretrained_path}
-
- root_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
- app_dir = app_dir if os.path.exists(app_dir) else os.path.join(root_dir, "sample-apps", "pathology")
- studies = studies if os.path.exists(os.path.dirname(studies)) else os.path.join(root_dir, "studies")
-
- app = app_instance(app_dir, studies, conf)
-
- context.user_data.model = model
- context.user_data.model_handler = app
- context.logger.info("Init context...100%")
-
-
-def handler(context, event):
- context.logger.info(f"Run model: {context.user_data.model}")
- data = event.body
-
- image = Image.open(io.BytesIO(base64.b64decode(data["image"])))
- image_np = np.asarray(image.convert("RGB"), dtype=np.uint8)
-
- flip_image = strtobool(os.environ.get("MONAI_LABEL_FLIP_INPUT_IMAGE", "true"))
- flip_points = strtobool(os.environ.get("MONAI_LABEL_FLIP_INPUT_POINTS", "true"))
- flip_output = strtobool(os.environ.get("MONAI_LABEL_FLIP_OUTPUT_POINTS", "false"))
-
- if flip_image:
- image_np = np.moveaxis(image_np, 0, 1)
-
- pos_points = data.get("pos_points")
- neg_points = data.get("neg_points")
- if flip_points:
- foreground = np.flip(np.array(pos_points, int), 1).tolist() if pos_points else pos_points
- background = np.flip(np.array(neg_points, int), 1).tolist() if neg_points else neg_points
- else:
- foreground = np.array(pos_points, int).tolist() if pos_points else pos_points
- background = np.array(neg_points, int).tolist() if neg_points else neg_points
-
- context.logger.info(f"Image: {image_np.shape}; Foreground: {foreground}; Background: {background}")
-
- json_data = context.user_data.model_handler.infer(
- request={
- "model": context.user_data.model,
- "image": image_np,
- "foreground": foreground,
- "background": background,
- "output": "json",
- }
- )
-
- results = []
- prediction = json_data["params"].get("prediction")
- if prediction:
- context.logger.info(f"(Classification) Prediction: {prediction}")
-
- # CVAT Limitation:: tag is not yet supported https://github.com/opencv/cvat/issues/4212
- # CVAT Limitation:: select highest score and create bbox to represent as tag
- e = None
- for element in prediction:
- if element["score"] > 0:
- e = element if e is None or element["score"] > e["score"] else e
- context.logger.info(f"New Max Element: {e}")
-
- context.logger.info(f"Final Element with Max Score: {e}")
- if e:
- results.append(
- {
- "label": e["label"],
- "confidence": e["score"],
- "type": "rectangle",
- "points": [0, 0, image_np.shape[0] - 1, image_np.shape[1] - 1],
- }
- )
- context.logger.info(f"(Classification) Results: {results}")
- else:
- interactor = strtobool(os.environ.get("INTERACTOR_MODEL", "false"))
- annotations = json_data["params"].get("annotations")
- for a in annotations:
- annotation = a.get("annotation", {})
- if not annotation:
- continue
-
- elements = annotation.get("elements", [])
- for element in elements:
- label = element["label"]
- contours = element["contours"]
- for contour in contours:
- points = np.array(contour, int)
- if flip_output:
- points = np.flip(points, axis=None)
-
- # CVAT limitation:: only one polygon result for interactor
- if interactor and contour:
- return context.Response(
- body=json.dumps(points.tolist()),
- headers={},
- content_type="application/json",
- status_code=200,
- )
-
- results.append(
- {
- "label": label,
- "points": points.flatten().tolist(),
- "type": "polygon",
- }
- )
-
- return context.Response(
- body=json.dumps(results),
- headers={},
- content_type="application/json",
- status_code=200,
- )
-
-
-"""
-if __name__ == "__main__":
- import logging
- from argparse import Namespace
-
- logging.basicConfig(
- level=logging.INFO,
- format="[%(asctime)s] [%(process)s] [%(threadName)s] [%(levelname)s] (%(name)s:%(lineno)d) - %(message)s",
- datefmt="%Y-%m-%d %H:%M:%S",
- )
-
- context = {
- "logger": logging.getLogger(__name__),
- "user_data": Namespace(**{"model": None, "model_handler": None}),
- }
- context = Namespace(**context)
-
- with open("test.jpg", "rb") as fp:
- image = base64.b64encode(fp.read())
-
- event = {
- "body": {
- "image": image,
- }
- }
- event = Namespace(**event)
-
- init_context(context)
- response = handler(context, event)
- print(response)
-"""
diff --git a/plugins/cvat/pathology/deepedit_nuclei.yaml b/plugins/cvat/pathology/deepedit_nuclei.yaml
index d46afbe20..3084c67e2 100644
--- a/plugins/cvat/pathology/deepedit_nuclei.yaml
+++ b/plugins/cvat/pathology/deepedit_nuclei.yaml
@@ -24,7 +24,7 @@ metadata:
spec:
description: A pre-trained interaction/deepedit model for Pathology
runtime: 'python:3.8'
- handler: main:handler
+ handler: detector:handler
eventTimeout: 30s
build:
@@ -34,13 +34,9 @@ spec:
directives:
preCopy:
- kind: ENV
- value: MONAI_LABEL_APP_DIR=/opt/conda/monailabel/sample-apps/pathology
+ value: MONAI_LABEL_SERVER=http://monailabel.com
- kind: ENV
- value: MONAI_LABEL_MODELS=deepedit_nuclei
- - kind: ENV
- value: PYTHONPATH=/opt/conda/monailabel/sample-apps/pathology
- - kind: ENV
- value: MONAI_PRETRAINED_PATH=https://github.com/Project-MONAI/MONAILabel/releases/download/data
+ value: MONAI_LABEL_MODEL=deepedit_nuclei
triggers:
myHttpTrigger:
@@ -50,10 +46,6 @@ spec:
attributes:
maxRequestBodySize: 33554432 # 32MB
- resources:
- limits:
- nvidia.com/gpu: 1
-
platform:
attributes:
restartPolicy:
diff --git a/plugins/cvat/pathology/nuclick.yaml b/plugins/cvat/pathology/nuclick.yaml
index f08edede5..82976e241 100644
--- a/plugins/cvat/pathology/nuclick.yaml
+++ b/plugins/cvat/pathology/nuclick.yaml
@@ -14,6 +14,7 @@ metadata:
namespace: cvat
annotations:
name: Nuclick
+ version: 2
type: interactor
framework: pytorch
spec:
@@ -25,7 +26,7 @@ metadata:
spec:
description: A pre-trained NuClick model for interactive cell segmentation for Pathology
runtime: 'python:3.8'
- handler: main:handler
+ handler: interactor:handler
eventTimeout: 30s
build:
@@ -35,15 +36,9 @@ spec:
directives:
preCopy:
- kind: ENV
- value: MONAI_LABEL_APP_DIR=/opt/conda/monailabel/sample-apps/pathology
+ value: MONAI_LABEL_SERVER=http://monailabel.com
- kind: ENV
- value: MONAI_LABEL_MODELS=nuclick
- - kind: ENV
- value: PYTHONPATH=/opt/conda/monailabel/sample-apps/pathology
- - kind: ENV
- value: MONAI_PRETRAINED_PATH=https://github.com/Project-MONAI/MONAILabel/releases/download/data
- - kind: ENV
- value: INTERACTOR_MODEL=true
+ value: MONAI_LABEL_MODEL=nuclick
triggers:
myHttpTrigger:
@@ -53,10 +48,6 @@ spec:
attributes:
maxRequestBodySize: 33554432 # 32MB
- resources:
- limits:
- nvidia.com/gpu: 1
-
platform:
attributes:
restartPolicy:
diff --git a/plugins/cvat/pathology/segmentation_nuclei.yaml b/plugins/cvat/pathology/segmentation_nuclei.yaml
index 0583afd50..f42a8a661 100644
--- a/plugins/cvat/pathology/segmentation_nuclei.yaml
+++ b/plugins/cvat/pathology/segmentation_nuclei.yaml
@@ -28,7 +28,7 @@ metadata:
spec:
description: A pre-trained semantic segmentation model for Pathology
runtime: 'python:3.8'
- handler: main:handler
+ handler: detector:handler
eventTimeout: 30s
build:
@@ -38,13 +38,9 @@ spec:
directives:
preCopy:
- kind: ENV
- value: MONAI_LABEL_APP_DIR=/opt/conda/monailabel/sample-apps/pathology
+ value: MONAI_LABEL_SERVER=http://monailabel.com
- kind: ENV
- value: MONAI_LABEL_MODELS=segmentation_nuclei
- - kind: ENV
- value: PYTHONPATH=/opt/conda/monailabel/sample-apps/pathology
- - kind: ENV
- value: MONAI_PRETRAINED_PATH=https://github.com/Project-MONAI/MONAILabel/releases/download/data
+ value: MONAI_LABEL_MODEL=segmentation_nuclei
triggers:
myHttpTrigger:
@@ -54,10 +50,6 @@ spec:
attributes:
maxRequestBodySize: 33554432 # 32MB
- resources:
- limits:
- nvidia.com/gpu: 1
-
platform:
attributes:
restartPolicy:
diff --git a/plugins/cvat/sam2/interactor.yaml b/plugins/cvat/sam2/interactor.yaml
new file mode 100644
index 000000000..6787c300e
--- /dev/null
+++ b/plugins/cvat/sam2/interactor.yaml
@@ -0,0 +1,56 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+metadata:
+ name: monailabel.sam2.interactor
+ namespace: cvat
+ annotations:
+ name: SAM2
+ version: 2
+ type: interactor
+ spec:
+ min_pos_points: 0
+ min_neg_points: 0
+ startswith_box_optional: true
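+      # Presumably: zero required clicks plus an optional starting box lets users prompt SAM2 with points, a box, or both.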
+      help_message: The interactor allows annotating an object using the SAM2 model
+
+spec:
+  description: A pre-trained SAM2 model for interactive segmentation
+ runtime: 'python:3.8'
+ handler: interactor:handler
+ eventTimeout: 30s
+
+ build:
+ image: cvat/monailabel.sam2.interactor
+ baseImage: projectmonai/monailabel:latest
+
+ directives:
+ preCopy:
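+      # MONAI_LABEL_SERVER is a placeholder here; point it at your running MONAI Label server.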
+ - kind: ENV
+ value: MONAI_LABEL_SERVER=http://monailabel.com
+ - kind: ENV
+ value: MONAI_LABEL_MODEL=sam_2d
+
+ triggers:
+ myHttpTrigger:
+ maxWorkers: 1
+ kind: 'http'
+ workerAvailabilityTimeoutMilliseconds: 10000
+ attributes:
+ maxRequestBodySize: 33554432 # 32MB
+
+ platform:
+ attributes:
+ restartPolicy:
+ name: always
+ maximumRetryCount: 1
+ mountMode: volume
+ network: cvat_cvat
diff --git a/plugins/cvat/sam2/tracker.yaml b/plugins/cvat/sam2/tracker.yaml
new file mode 100644
index 000000000..c6842f9bf
--- /dev/null
+++ b/plugins/cvat/sam2/tracker.yaml
@@ -0,0 +1,51 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+metadata:
+ name: monailabel.sam2.tracker
+ namespace: cvat
+ annotations:
+ name: SAM2T
+ type: tracker
+ spec:
+
+spec:
+  description: A pre-trained SAM2 model for tracking
+ runtime: 'python:3.8'
+ handler: tracker:handler
+ eventTimeout: 30s
+
+ build:
+ image: cvat/monailabel.sam2.tracker
+ baseImage: projectmonai/monailabel:latest
+
+ directives:
+ preCopy:
+ - kind: ENV
+ value: MONAI_LABEL_SERVER=http://monailabel.com
+ - kind: ENV
+ value: MONAI_LABEL_MODEL=sam_2d
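+        # sam_2d is applied frame by frame; true SAM2 video propagation is still a TODO (see plugins/cvat/tracker.py).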
+
+ triggers:
+ myHttpTrigger:
+ maxWorkers: 1
+ kind: 'http'
+ workerAvailabilityTimeoutMilliseconds: 10000
+ attributes:
+ maxRequestBodySize: 33554432 # 32MB
+
+ platform:
+ attributes:
+ restartPolicy:
+ name: always
+ maximumRetryCount: 1
+ mountMode: volume
+ network: cvat_cvat
diff --git a/plugins/cvat/tracker.py b/plugins/cvat/tracker.py
new file mode 100644
index 000000000..60e75d4c4
--- /dev/null
+++ b/plugins/cvat/tracker.py
@@ -0,0 +1,149 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import base64
+import io
+import json
+import logging
+import os
+import tempfile
+
+import numpy as np
+from PIL import Image
+
+from monailabel.client import MONAILabelClient
+from monailabel.transform.post import FindContoursd
+
+logging.basicConfig(
+ level=logging.INFO,
+ format="[%(asctime)s] [%(process)s] [%(threadName)s] [%(levelname)s] (%(name)s:%(lineno)d) - %(message)s",
+ datefmt="%Y-%m-%d %H:%M:%S",
+)
+
+
+def init_context(context):
+ context.logger.info("Init context... 0%")
+ server = os.environ.get("MONAI_LABEL_SERVER", "http://0.0.0.0:8000")
+ model = os.environ.get("MONAI_LABEL_MODEL", "sam2")
+ client = MONAILabelClient(server)
+
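+    # Query the server and fail fast if the configured model is not deployed.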
+ info = client.info()
+    model_info = info["models"].get(model) if info and info["models"] else None
+    context.logger.info(f"MONAI Label Info: {model_info}")
+ assert model_info
+
+ context.user_data.model = model
+ context.user_data.model_handler = client
+ context.logger.info("Init context...100%")
+
+
+def handler(context, event):
+ model: str = context.user_data.model
+ client: MONAILabelClient = context.user_data.model_handler
+ context.logger.info(f"Run model: {model}")
+    # TODO: This is not yet a real tracker; it needs to accumulate previous frames + ROIs and run actual SAM2 propagation.
+
+ data = event.body
+ image = Image.open(io.BytesIO(base64.b64decode(data["image"])))
+ context.logger.info(f"Image: {image.size}")
+ context.logger.info(f"Event Data Keys: {data.keys()}")
+
+    image_file = tempfile.NamedTemporaryFile(suffix=".jpg", delete=False).name  # delete=False so the file survives until os.remove() below
+ image.save(image_file)
+
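+    # CVAT sends the tracked shapes as flat [x1, y1, x2, y2] boxes (see the sample event under __main__ below).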
+ shapes = data.get("shapes")
+ states = data.get("states")
+ context.logger.info(f"Shapes: {shapes}; States: {states}")
+
+ rois = []
+ for i, shape in enumerate(shapes):
+ roi = np.array(shape).astype(int).tolist()
+ context.logger.info(f"{i} => Shape: {shape}; roi: {roi}")
+ rois.append(roi)
+
+ roi = rois[-1] # Pick the last
+ params = {"output": "json", "roi": roi}
+
+ # context.logger.info(f"Model:{model}; Params: {params}")
+ output_mask, output_json = client.infer(model=model, image_id="", file=image_file, params=params)
+    if isinstance(output_json, (str, bytes)):
+ output_json = json.loads(output_json)
+ # context.logger.info(f"Mask: {output_mask}; Output JSON: {output_json}")
+
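+    # The server returns the predicted mask as an image file; load it, then clean up both temp files.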
+ mask_np = np.array(Image.open(output_mask)).astype(np.uint8)
+ os.remove(output_mask)
+ os.remove(image_file)
+ context.logger.info(f"Image: {image.size}; Mask: {mask_np.shape}; JSON: {output_json}")
+
+ results = {"shapes": [], "states": []}
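+    # FindContoursd extracts contour elements from the mask; each element's bounding box becomes the tracked shape.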
+ d = FindContoursd(keys="pred")({"pred": mask_np})
+ annotation = d.get("result", {}).get("annotation")
+ for element in annotation.get("elements", []):
+ contours = element["contours"]
+ all_points = []
+ for contour in contours:
+ points = np.flip(np.array(contour, int))
+ all_points.append(points.flatten().tolist())
+
+ def bounding_box(pts):
+ x, y = zip(*pts)
+ return [min(x), min(y), max(x), max(y)]
+
+ bbox = bounding_box(np.array(all_points).astype(int).reshape(-1, 2).tolist())
+ context.logger.info(f"Input Box: {roi}; Output Box: {bbox}")
+ results["shapes"].append(bbox)
+
+ context.logger.info("=============================================================================\n")
+ return context.Response(
+ body=json.dumps(results),
+ headers={},
+ content_type="application/json",
+ status_code=200,
+ )
+
+
+if __name__ == "__main__":
+    from argparse import Namespace
+
+ with open("/home/sachi/Datasets/endo/frame001.jpg", "rb") as fp:
+ image = base64.b64encode(fp.read())
+
+ event = Namespace(
+ **{
+ "body": {
+ "image": image,
+ "shapes": [[327, 352, 1152, 803]],
+ "states": [],
+ }
+ }
+ )
+
+ def print_all(*args, **kwargs):
+ return {"args": args, **kwargs}
+
+ context = Namespace(
+ **{
+ "logger": logging.getLogger(__name__),
+ "user_data": Namespace(**{"model": None, "model_handler": None}),
+ "Response": print_all,
+ }
+ )
+
+ init_context(context)
+ response = handler(context, event)
+ print(response)
diff --git a/plugins/ohifv3/build.sh b/plugins/ohifv3/build.sh
index 8c4679a6b..5c0e53aec 100755
--- a/plugins/ohifv3/build.sh
+++ b/plugins/ohifv3/build.sh
@@ -50,6 +50,9 @@ APP_CONFIG=config/monai_label.js PUBLIC_URL=/ohif/ QUICK_BUILD=true yarn run bui
rm -rf ${install_dir}
cp -r platform/app/dist/ ${install_dir}
echo "Copied OHIF to ${install_dir}"
-rm -rf ../Viewers
+
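+# Clean up the cloned Viewers repo and all node_modules directories to reclaim disk space.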
+cd ..
+rm -rf Viewers
+find . -type d -name "node_modules" -exec rm -rf "{}" +
cd ${curr_dir}
diff --git a/plugins/ohifv3/extensions/monai-label/src/components/ModelSelector.tsx b/plugins/ohifv3/extensions/monai-label/src/components/ModelSelector.tsx
index c2371f336..b70475148 100644
--- a/plugins/ohifv3/extensions/monai-label/src/components/ModelSelector.tsx
+++ b/plugins/ohifv3/extensions/monai-label/src/components/ModelSelector.tsx
@@ -1,4 +1,3 @@
-
import React, { Component } from 'react';
import PropTypes from 'prop-types';
@@ -41,9 +40,11 @@ export default class ModelSelector extends Component {
return null;
}
- onChangeModel = evt => {
+ onChangeModel = (evt) => {
this.setState({ currentModel: evt.target.value });
- if (this.props.onSelectModel) this.props.onSelectModel(evt.target.value);
+ if (this.props.onSelectModel) {
+ this.props.onSelectModel(evt.target.value);
+ }
};
currentModel = () => {
@@ -84,7 +85,7 @@ export default class ModelSelector extends Component {
onChange={this.onChangeModel}
value={currentModel}
>
- {this.props.models.map(model => (
+ {this.props.models.map((model) => (
diff --git a/plugins/ohifv3/extensions/monai-label/src/components/MonaiLabelPanel.tsx b/plugins/ohifv3/extensions/monai-label/src/components/MonaiLabelPanel.tsx
index f62e3f5d6..5a13d1f54 100644
--- a/plugins/ohifv3/extensions/monai-label/src/components/MonaiLabelPanel.tsx
+++ b/plugins/ohifv3/extensions/monai-label/src/components/MonaiLabelPanel.tsx
@@ -121,7 +121,7 @@ export default class MonaiLabelPanel extends Component {
const response = await this.client().info();
// remove the background
- const labels = response.data.labels.splice(1)
+ const labels = response.data.labels.splice(1);
const segmentations = [
{
@@ -129,7 +129,7 @@ export default class MonaiLabelPanel extends Component {
label: 'Segmentations',
segments: labels.map((label, index) => ({
segmentIndex: index + 1,
- label
+ label,
})),
isActive: true,
activeSegmentIndex: 1,
@@ -137,7 +137,7 @@ export default class MonaiLabelPanel extends Component {
];
this.props.commandsManager.runCommand('loadSegmentationsForViewport', {
- segmentations
+ segmentations,
});
if (response.status !== 200) {
@@ -163,16 +163,18 @@ export default class MonaiLabelPanel extends Component {
// Leave Event
for (const action of Object.keys(this.actions)) {
if (this.state.action === action) {
- if (this.actions[action].current)
+ if (this.actions[action].current) {
this.actions[action].current.onLeaveActionTab();
+ }
}
}
// Enter Event
for (const action of Object.keys(this.actions)) {
if (name === action) {
- if (this.actions[action].current)
+ if (this.actions[action].current) {
this.actions[action].current.onEnterActionTab();
+ }
}
}
this.setState({ action: name });
@@ -193,11 +195,10 @@ export default class MonaiLabelPanel extends Component {
console.info('These are the predicted labels');
console.info(onInfoLabelNames);
- if (onInfoLabelNames.hasOwnProperty('background')){
-delete onInfoLabelNames.background;
+ if (onInfoLabelNames.hasOwnProperty('background')) {
+ delete onInfoLabelNames.background;
}
-
const ret = SegmentationReader.parseNrrdData(response.data);
if (!ret) {
@@ -248,7 +249,7 @@ delete onInfoLabelNames.background;
};
_debug = async () => {
- let nrrdFetch = await fetch('http://localhost:3000/pred2.nrrd');
+ const nrrdFetch = await fetch('http://localhost:3000/pred2.nrrd');
const info = {
spleen: 1,
diff --git a/plugins/ohifv3/extensions/monai-label/src/components/SegmentationToolbox.tsx b/plugins/ohifv3/extensions/monai-label/src/components/SegmentationToolbox.tsx
index 4577a56b7..338b389ab 100644
--- a/plugins/ohifv3/extensions/monai-label/src/components/SegmentationToolbox.tsx
+++ b/plugins/ohifv3/extensions/monai-label/src/components/SegmentationToolbox.tsx
@@ -83,11 +83,11 @@ function SegmentationToolbox({ servicesManager, extensionManager }) {
const unsubscriptions = [];
- events.forEach(event => {
+ events.forEach((event) => {
const { unsubscribe } = segmentationService.subscribe(event, () => {
const segmentations = segmentationService.getSegmentations();
- const activeSegmentation = segmentations?.find(seg => seg.isActive);
+ const activeSegmentation = segmentations?.find((seg) => seg.isActive);
setToolsEnabled(activeSegmentation?.segmentCount > 0);
});
@@ -96,7 +96,7 @@ function SegmentationToolbox({ servicesManager, extensionManager }) {
});
return () => {
- unsubscriptions.forEach(unsubscribe => unsubscribe());
+ unsubscriptions.forEach((unsubscribe) => unsubscribe());
};
}, [activeViewportId, viewports, segmentationService]);
@@ -117,7 +117,7 @@ function SegmentationToolbox({ servicesManager, extensionManager }) {
}, [toolbarService, updateActiveTool]);
const setToolActive = useCallback(
- toolName => {
+ (toolName) => {
toolbarService.recordInteraction({
groupId: 'SegmentationTools',
itemId: 'Brush',
@@ -139,8 +139,12 @@ function SegmentationToolbox({ servicesManager, extensionManager }) {
const updateBrushSize = useCallback(
(toolName, brushSize) => {
- toolGroupService.getToolGroupIds()?.forEach(toolGroupId => {
- segmentationUtils.setBrushSizeForToolGroup(toolGroupId, brushSize, toolName);
+ toolGroupService.getToolGroupIds()?.forEach((toolGroupId) => {
+ segmentationUtils.setBrushSizeForToolGroup(
+ toolGroupId,
+ brushSize,
+ toolName
+ );
});
},
[toolGroupService]
@@ -150,7 +154,7 @@ function SegmentationToolbox({ servicesManager, extensionManager }) {
(valueAsStringOrNumber, toolCategory) => {
const value = Number(valueAsStringOrNumber);
- _getToolNamesFromCategory(toolCategory).forEach(toolName => {
+ _getToolNamesFromCategory(toolCategory).forEach((toolName) => {
updateBrushSize(toolName, value);
});
@@ -166,7 +170,7 @@ function SegmentationToolbox({ servicesManager, extensionManager }) {
);
const handleRangeChange = useCallback(
- newRange => {
+ (newRange) => {
if (
newRange[0] === state.ThresholdBrush.thresholdRange[0] &&
newRange[1] === state.ThresholdBrush.thresholdRange[1]
@@ -176,8 +180,8 @@ function SegmentationToolbox({ servicesManager, extensionManager }) {
const toolNames = _getToolNamesFromCategory('ThresholdBrush');
- toolNames.forEach(toolName => {
- toolGroupService.getToolGroupIds()?.forEach(toolGroupId => {
+ toolNames.forEach((toolName) => {
+ toolGroupService.getToolGroupIds()?.forEach((toolGroupId) => {
const toolGroup = toolGroupService.getToolGroup(toolGroupId);
toolGroup.setToolConfiguration(toolName, {
strategySpecificConfiguration: {
@@ -208,7 +212,9 @@ function SegmentationToolbox({ servicesManager, extensionManager }) {
name: 'Brush',
icon: 'icon-tool-brush',
disabled: !toolsEnabled,
- active: state.activeTool === 'CircularBrush' || state.activeTool === 'SphereBrush',
+ active:
+ state.activeTool === 'CircularBrush' ||
+ state.activeTool === 'SphereBrush',
onClick: () => setToolActive('CircularBrush'),
options: [
{
@@ -219,7 +225,7 @@ function SegmentationToolbox({ servicesManager, extensionManager }) {
max: 100,
value: state.Brush.brushSize,
step: 0.5,
- onChange: value => onBrushSizeChange(value, 'Brush'),
+ onChange: (value) => onBrushSizeChange(value, 'Brush'),
},
{
name: 'Mode',
@@ -230,7 +236,7 @@ function SegmentationToolbox({ servicesManager, extensionManager }) {
{ value: 'CircularBrush', label: 'Circle' },
{ value: 'SphereBrush', label: 'Sphere' },
],
- onChange: value => setToolActive(value),
+ onChange: (value) => setToolActive(value),
},
],
},
@@ -238,7 +244,9 @@ function SegmentationToolbox({ servicesManager, extensionManager }) {
name: 'Eraser',
icon: 'icon-tool-eraser',
disabled: !toolsEnabled,
- active: state.activeTool === 'CircularEraser' || state.activeTool === 'SphereEraser',
+ active:
+ state.activeTool === 'CircularEraser' ||
+ state.activeTool === 'SphereEraser',
onClick: () => setToolActive('CircularEraser'),
options: [
{
@@ -249,7 +257,7 @@ function SegmentationToolbox({ servicesManager, extensionManager }) {
max: 100,
value: state.Eraser.brushSize,
step: 0.5,
- onChange: value => onBrushSizeChange(value, 'Eraser'),
+ onChange: (value) => onBrushSizeChange(value, 'Eraser'),
},
{
name: 'Mode',
@@ -260,7 +268,7 @@ function SegmentationToolbox({ servicesManager, extensionManager }) {
{ value: 'CircularEraser', label: 'Circle' },
{ value: 'SphereEraser', label: 'Sphere' },
],
- onChange: value => setToolActive(value),
+ onChange: (value) => setToolActive(value),
},
],
},
@@ -284,7 +292,7 @@ function SegmentationToolbox({ servicesManager, extensionManager }) {
{ value: 'RectangleScissor', label: 'Rectangle' },
{ value: 'SphereScissor', label: 'Sphere' },
],
- onChange: value => setToolActive(value),
+ onChange: (value) => setToolActive(value),
},
],
},
@@ -305,7 +313,7 @@ function SegmentationToolbox({ servicesManager, extensionManager }) {
max: 100,
value: state.ThresholdBrush.brushSize,
step: 0.5,
- onChange: value => onBrushSizeChange(value, 'ThresholdBrush'),
+ onChange: (value) => onBrushSizeChange(value, 'ThresholdBrush'),
},
{
name: 'Mode',
@@ -316,7 +324,7 @@ function SegmentationToolbox({ servicesManager, extensionManager }) {
{ value: 'ThresholdCircularBrush', label: 'Circle' },
{ value: 'ThresholdSphereBrush', label: 'Sphere' },
],
- onChange: value => setToolActive(value),
+ onChange: (value) => setToolActive(value),
},
{
type: 'custom',
diff --git a/plugins/ohifv3/extensions/monai-label/src/components/SettingsTable.tsx b/plugins/ohifv3/extensions/monai-label/src/components/SettingsTable.tsx
index d5db4bc9a..12fe1a284 100644
--- a/plugins/ohifv3/extensions/monai-label/src/components/SettingsTable.tsx
+++ b/plugins/ohifv3/extensions/monai-label/src/components/SettingsTable.tsx
@@ -8,8 +8,8 @@ export default class SettingsTable extends Component {
constructor(props) {
super(props);
- const onInfo = props.onInfo
- this.onInfo = onInfo
+ const onInfo = props.onInfo;
+ this.onInfo = onInfo;
this.state = this.getSettings();
}
@@ -35,8 +35,8 @@ export default class SettingsTable extends Component {
};
};
- onBlurSeverURL = evt => {
- let url = evt.target.value;
+ onBlurSeverURL = (evt) => {
+ const url = evt.target.value;
this.setState({ url: url });
CookieUtils.setCookie('MONAILABEL_SERVER_URL', url);
};
diff --git a/plugins/ohifv3/extensions/monai-label/src/components/Toolbox/ThresholdSettingsPreset.tsx b/plugins/ohifv3/extensions/monai-label/src/components/Toolbox/ThresholdSettingsPreset.tsx
index f708233c3..63670c9f6 100644
--- a/plugins/ohifv3/extensions/monai-label/src/components/Toolbox/ThresholdSettingsPreset.tsx
+++ b/plugins/ohifv3/extensions/monai-label/src/components/Toolbox/ThresholdSettingsPreset.tsx
@@ -33,16 +33,19 @@ function ThresholdSettings({ onRangeChange }) {
const [options, setOptions] = useState(defaultOptions);
const [selectedPreset, setSelectedPreset] = useState(defaultOptions[0].value);
- const handleRangeChange = newRange => {
- const selectedOption = options.find(o => o.value === selectedPreset);
+ const handleRangeChange = (newRange) => {
+ const selectedOption = options.find((o) => o.value === selectedPreset);
- if (newRange[0] === selectedOption.range[0] && newRange[1] === selectedOption.range[1]) {
+ if (
+ newRange[0] === selectedOption.range[0] &&
+ newRange[1] === selectedOption.range[1]
+ ) {
return;
}
onRangeChange(newRange);
- const updatedOptions = options.map(o => {
+ const updatedOptions = options.map((o) => {
if (o.value === selectedPreset) {
return {
...o,
@@ -55,7 +58,9 @@ function ThresholdSettings({ onRangeChange }) {
setOptions(updatedOptions);
};
- const selectedPresetRange = options.find(ds => ds.value === selectedPreset).range;
+ const selectedPresetRange = options.find(
+ (ds) => ds.value === selectedPreset
+ ).range;
return (