From 4dd1813e18f4f7d92c965b42a3be67c14bb38106 Mon Sep 17 00:00:00 2001
From: tangy5 <58751975+tangy5@users.noreply.github.com>
Date: Wed, 8 May 2024 15:27:13 -0700
Subject: [PATCH] Fix cv2 (#1688)

* fix cv2

Signed-off-by: tangy5

* fix cv2

Signed-off-by: tangy5

* fix cv2

Signed-off-by: tangy5

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* optional import and fallback

Signed-off-by: tangy5

* optional import and fallback

Signed-off-by: tangy5

* optional import and fallback

Signed-off-by: tangy5

* optional import and fallback

Signed-off-by: tangy5

* optional import and fallback

Signed-off-by: tangy5

* optional import and fallback

Signed-off-by: tangy5

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: tangy5
Signed-off-by: tangy5
Co-authored-by: tangy5
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 monailabel/transform/post.py                  | 102 +++++++++++-------
 requirements.txt                              |   1 -
 .../pathology/lib/trainers/hovernet_nuclei.py |  18 +++-
 sample-apps/pathology/lib/utils.py            |  98 ++++++++++++++---
 setup.cfg                                     |   1 -
 tests/unit/transform/test_post.py             |   2 +-
 6 files changed, 159 insertions(+), 63 deletions(-)

diff --git a/monailabel/transform/post.py b/monailabel/transform/post.py
index 106f65860..cf4abdf96 100644
--- a/monailabel/transform/post.py
+++ b/monailabel/transform/post.py
@@ -12,7 +12,6 @@
 import logging
 from typing import Dict, Hashable, Mapping, Optional, Sequence, Union
 
-import cv2
 import nibabel as nib
 import numpy as np
 import skimage.measure as measure
@@ -27,8 +26,9 @@
     generate_spatial_bounding_box,
     get_extreme_points,
 )
-from monai.utils import InterpolateMode, convert_to_numpy, ensure_tuple_rep
+from monai.utils import InterpolateMode, convert_to_numpy, ensure_tuple_rep, optional_import
 from shapely.geometry import Point, Polygon
+from skimage.measure import approximate_polygon, find_contours
 from torchvision.utils import make_grid, save_image
 
 from monailabel.utils.others.label_colors import get_color
@@ -36,9 +36,6 @@
 logger = logging.getLogger(__name__)
 
 
-# TODO:: Move to MONAI ??
-
-
 class LargestCCd(MapTransform):
     def __init__(self, keys: KeysCollection, has_channel: bool = True):
         super().__init__(keys)
@@ -183,7 +180,6 @@ def __init__(
         colormap=None,
     ):
         super().__init__(keys)
-
         self.min_positive = min_positive
         self.min_poly_area = min_poly_area
         self.max_poly_area = max_poly_area
@@ -208,9 +204,7 @@ def __call__(self, data):
         min_poly_area = d.get("min_poly_area", self.min_poly_area)
         max_poly_area = d.get("max_poly_area", self.max_poly_area)
         color_map = d.get(self.key_label_colors) if self.colormap is None else self.colormap
-
-        foreground_points = d.get(self.key_foreground_points, []) if self.key_foreground_points else []
-        foreground_points = [Point(pt[0], pt[1]) for pt in foreground_points]  # polygons in (x, y) format
+        foreground_points = [Point(pt) for pt in d.get(self.key_foreground_points, [])]
 
         elements = []
         label_names = set()
@@ -220,43 +214,73 @@ def __call__(self, data):
                 continue
 
             labels = [label for label in np.unique(p).tolist() if label > 0]
-
             logger.debug(f"Total Unique Masks (excluding background): {labels}")
             for label_idx in labels:
                 p = convert_to_numpy(d[key]) if isinstance(d[key], torch.Tensor) else d[key]
                 p = np.where(p == label_idx, 1, 0).astype(np.uint8)
-                p = np.moveaxis(p, 0, 1)  # for cv2
+                p = np.moveaxis(p, 0, 1)
+                if label_idx == 0:
+                    continue
 
                 label_name = self.labels.get(label_idx, label_idx)
                 label_names.add(label_name)
 
-                polygons = []
-                contours, _ = cv2.findContours(p, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
-                for contour in contours:
-                    if len(contour) < 3:
-                        continue
-
-                    contour = np.squeeze(contour)
-                    area = cv2.contourArea(contour)
-                    if area < min_poly_area:  # Ignore poly with lesser area
-                        continue
-                    if 0 < max_poly_area < area:  # Ignore very large poly (e.g. in case of nuclei)
-                        continue
-
-                    contour[:, 0] += location[0]  # X
-                    contour[:, 1] += location[1]  # Y
-
-                    coords = contour.astype(int).tolist()
-                    if foreground_points:
-                        for pt in foreground_points:
-                            if Polygon(coords).contains(pt):
-                                polygons.append(coords)
-                                break
-                    else:
-                        polygons.append(coords)
-
-                if len(polygons):
-                    logger.debug(f"+++++ {label_idx} => Total Polygons Found: {len(polygons)}")
-                    elements.append({"label": label_name, "contours": polygons})
+                cv2, has_cv2 = optional_import("cv2")
+                if has_cv2:
+                    polygons = []
+                    contours, _ = cv2.findContours(p, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
+                    for contour in contours:
+                        if len(contour) < 3:
+                            continue
+
+                        contour = np.squeeze(contour)
+                        area = cv2.contourArea(contour)
+                        if area < min_poly_area:
+                            continue
+                        if 0 < max_poly_area < area:
+                            continue
+
+                        contour[:, 0] += location[0]
+                        contour[:, 1] += location[1]
+
+                        coords = contour.astype(int).tolist()
+                        if foreground_points:
+                            for pt in foreground_points:
+                                if Polygon(coords).contains(pt):
+                                    polygons.append(coords)
+                                    break
+                        else:
+                            polygons.append(coords)
+
+                    if len(polygons):
+                        logger.debug(f"+++++ {label_idx} => Total Polygons Found: {len(polygons)}")
+                        elements.append({"label": label_name, "contours": polygons})
+                else:
+                    contours = find_contours(p, 0.5)
+                    contours = [np.round(contour).astype(int) for contour in contours]
+                    for contour in contours:
+                        if not np.array_equal(contour[0], contour[-1]):
+                            contour = np.append(contour, [contour[0]], axis=0)
+
+                        simplified_contour = approximate_polygon(contour, tolerance=0.5)
+                        if len(simplified_contour) < 4:
+                            continue
+
+                        simplified_contour = np.flip(simplified_contour, axis=1)
+                        simplified_contour += location
+                        simplified_contour = simplified_contour.astype(int)
+
+                        polygon = Polygon(simplified_contour)
+                        if (
+                            polygon.is_valid
+                            and polygon.area >= min_poly_area
+                            and (max_poly_area <= 0 or polygon.area <= max_poly_area)
+                        ):
+                            formatted_contour = [simplified_contour.tolist()]
+                            if foreground_points:
+                                if any(polygon.contains(point) for point in foreground_points):
+                                    elements.append({"label": label_name, "contours": formatted_contour})
+                            else:
+                                elements.append({"label": label_name, "contours": formatted_contour})
 
         if elements:
             if d.get(self.result) is None:
diff --git a/requirements.txt b/requirements.txt
index 5fdb145ff..4887249dc 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -28,7 +28,6 @@ pydicom==2.4.4
 pydicom-seg==0.4.1
 pynetdicom==2.0.2
 pynrrd==1.0.0
-opencv-python-headless==4.9.0.80
 numpymaxflow==0.0.6
 girder-client==3.2.3
 ninja==1.11.1.1
diff --git a/sample-apps/pathology/lib/trainers/hovernet_nuclei.py b/sample-apps/pathology/lib/trainers/hovernet_nuclei.py
index 154d2215c..e089c2a12 100644
--- a/sample-apps/pathology/lib/trainers/hovernet_nuclei.py
+++ b/sample-apps/pathology/lib/trainers/hovernet_nuclei.py
@@ -14,11 +14,12 @@
 import pathlib
 from typing import Dict, Optional
 
-import cv2
 import numpy as np
 from lib.hovernet import PatchExtractor
 from lib.utils import split_dataset
+from monai.utils import optional_import
 from PIL import Image
+from scipy.ndimage import label
 from tqdm import tqdm
 
 from monailabel.interfaces.datastore import Datastore
@@ -36,7 +37,11 @@ def __init__(self, path: str, conf: Dict[str, str], const: Optional[BundleConsta
         self.step_size = (164, 164)
         self.extract_type = "mirror"
 
-    def _fetch_datalist(self, request, datastore: Datastore):
+    def remove_file(path):
+        if os.path.exists(path):
+            os.remove(path)
+
+    def _fetch_datalist(self, request, datastore):
         cache_dir = os.path.join(self.bundle_path, "cache", "train_ds")
         remove_file(cache_dir)
 
@@ -71,13 +76,18 @@ def _fetch_datalist(self, request, datastore: Datastore):
             img = np.array(Image.open(d["image"]).convert("RGB"))
             ann_type = np.array(Image.open(d["label"]))
 
-            numLabels, ann_inst, _, _ = cv2.connectedComponentsWithStats(ann_type, 4, cv2.CV_32S)
+            cv2, has_cv2 = optional_import("cv2")
+            if has_cv2:
+                numLabels, ann_inst, _, _ = cv2.connectedComponentsWithStats(ann_type, 4, cv2.CV_32S)
+            else:
+                ann_inst, numLabels = label(ann_type)
+
             ann = np.dstack([ann_inst, ann_type])
             img = np.concatenate([img, ann], axis=-1)
 
             sub_patches = xtractor.extract(img, self.extract_type)
 
-            pbar_format = "Extracting : |{bar}| {n_fmt}/{total_fmt}[{elapsed}<{remaining},{rate_fmt}]"
+            pbar_format = "Extracting: |{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]"
             pbar = tqdm(total=len(sub_patches), leave=False, bar_format=pbar_format, ascii=True, position=1)
 
             for idx, patch in enumerate(sub_patches):
diff --git a/sample-apps/pathology/lib/utils.py b/sample-apps/pathology/lib/utils.py
index 3eef2aa86..9c2aeb6c9 100644
--- a/sample-apps/pathology/lib/utils.py
+++ b/sample-apps/pathology/lib/utils.py
@@ -18,11 +18,12 @@
 from io import BytesIO
 from math import ceil
 
-import cv2
 import numpy as np
 import openslide
 import scipy
-from PIL import Image
+from monai.utils import optional_import
+from PIL import Image, ImageDraw
+from scipy.ndimage import center_of_mass, find_objects, label
 from tqdm import tqdm
 
 from monailabel.datastore.dsa import DSADatastore
@@ -273,6 +274,7 @@ def split_consep_dataset(
     crop_size=256,
 ):
     dataset_json = []
+
     # logger.debug(f"Process Image: {d['image']} => Label: {d['label']}")
 
     images_dir = output_dir
@@ -477,7 +479,21 @@ def split_nuclei_dataset(
         mask = Image.open(d["label"])
         mask_np = np.array(mask)
 
-        numLabels, instances, stats, centroids = cv2.connectedComponentsWithStats(mask_np, 4, cv2.CV_32S)
+        cv2, has_cv2 = optional_import("cv2")
+        if has_cv2:
+            numLabels, instances, stats, centroids = cv2.connectedComponentsWithStats(mask_np, 4, cv2.CV_32S)
+        else:
+            instances, numLabels = label(mask_np)
+            stats = []
+            centroids = center_of_mass(mask_np, instances, range(numLabels))
+
+            objects = find_objects(instances)
+            for i, slice_tuple in enumerate(objects):
+                if slice_tuple is not None:
+                    dx, dy = slice_tuple
+                    area = (dx.stop - dx.start) * (dy.stop - dy.start)
+                    stats.append([dy.start, dx.start, dy.stop - dy.start, dx.stop - dx.start, area])
+
         logger.info("-------------------------------------------------------------------------------")
         logger.info(f"Image/Label ========> {d['image']} =====> {d['label']}")
         logger.info(f"Total Labels: {numLabels}")
@@ -486,11 +502,11 @@
         logger.info(f"Total Centroids: {len(centroids)}")
         logger.info(f"Total Classes in Mask: {np.unique(mask_np)}")
 
-        for nuclei_id, (x, y) in enumerate(centroids):
+        for nuclei_id, centroid in enumerate(centroids):
             if nuclei_id == 0:
                 continue
 
-            x, y = (int(x), int(y))
+            x, y = int(centroid[1]), int(centroid[0])
             this_instance = np.where(instances == nuclei_id, mask_np, 0)
             class_id = int(np.max(this_instance))
 
@@ -556,9 +572,36 @@ def _group_item(groups, d, output_dir):
     return groups, item_id
 
 
+def calculate_bounding_rect(points):
+    points = np.array(points, dtype=int)
+    x_min, y_min = np.min(points, axis=0)
+    x_max, y_max = np.max(points, axis=0)
+    w = x_max - x_min + 1
+    h = y_max - y_min + 1
+    return int(x_min), int(y_min), int(w), int(h)
+
+
+def fill_poly(image_size, polygons, color, mode="L"):
+    if mode.upper() == "RGB":
+        img = Image.new("RGB", image_size, (0, 0, 0))
+    else:
+        img = Image.new("L", image_size, 0)
+
+    draw = ImageDraw.Draw(img)
+    for polygon in polygons:
+        draw.polygon([tuple(p) for p in polygon], fill=color)
+    return np.array(img)
+
+
 def _to_roi(points, max_region, polygons, annotation_id):
     logger.info(f"Total Points: {len(points)}")
-    x, y, w, h = cv2.boundingRect(np.array(points))
+
+    cv2, has_cv2 = optional_import("cv2")
+    if has_cv2:
+        x, y, w, h = cv2.boundingRect(np.array(points))
+    else:
+        x, y, w, h = calculate_bounding_rect(points)
+
     logger.info(f"ID: {annotation_id} => Groups: {polygons.keys()}; Location: ({x}, {y}); Size: {w} x {h}")
 
     if w > max_region[0]:
@@ -584,25 +627,46 @@ def _to_dataset(item_id, x, y, w, h, img, tile_size, polygons, groups, output_di
     logger.debug(f"Image NP: {image_np.shape}; sum: {np.sum(image_np)}")
     tiled_images = _region_to_tiles(name, w, h, image_np, tile_size, output_dir, "Image")
 
-    label_np = np.zeros((h, w), dtype=np.uint8)  # Transposed
-    for group, contours in polygons.items():
-        color = groups.get(group, 1)
-        contours = [np.array([[p[0] - x, p[1] - y] for p in contour]) for contour in contours]
-
-        cv2.fillPoly(label_np, pts=contours, color=color)
-        logger.info(f"{group} => p: {len(contours)}; c: {color}; unique: {np.unique(label_np, return_counts=True)}")
-
-        if debug:
-            regions_dir = os.path.join(output_dir, "regions")
-            label_path = os.path.realpath(os.path.join(regions_dir, "labels", group, f"{name}.png"))
-            os.makedirs(os.path.dirname(label_path), exist_ok=True)
-            cv2.imwrite(label_path, label_np)
+    cv2, has_cv2 = optional_import("cv2")
+    if has_cv2:
+        label_np = np.zeros((h, w), dtype=np.uint8)  # Transposed
+        for group, contours in polygons.items():
+            color = groups.get(group, 1)
+            contours = [np.array([[p[0] - x, p[1] - y] for p in contour]) for contour in contours]
+
+            cv2.fillPoly(label_np, pts=contours, color=color)
+            logger.info(f"{group} => p: {len(contours)}; c: {color}; unique: {np.unique(label_np, return_counts=True)}")
+
+            if debug:
+                regions_dir = os.path.join(output_dir, "regions")
+                label_path = os.path.realpath(os.path.join(regions_dir, "labels", group, f"{name}.png"))
+                os.makedirs(os.path.dirname(label_path), exist_ok=True)
+                cv2.imwrite(label_path, label_np)
+    else:
+        label_img = Image.new("L", (w, h), 0)
+        draw = ImageDraw.Draw(label_img)
+
+        for group, contours in polygons.items():
+            color = groups.get(group, 1)
+            pil_contours = [tuple((p[0] - x, p[1] - y) for p in contour) for contour in contours]
+
+            for contour in pil_contours:
+                draw.polygon(contour, outline=color, fill=color)
+
+            if debug:
+                regions_dir = os.path.join(output_dir, "regions")
+                label_path = os.path.realpath(os.path.join(regions_dir, "labels", group, f"{name}.png"))
+                os.makedirs(os.path.dirname(label_path), exist_ok=True)
+                label_img.save(label_path)
+
+        label_np = np.array(label_img)
 
     tiled_labels = _region_to_tiles(
         name, w, h, label_np, tile_size, os.path.join(output_dir, "labels", "final"), "Label"
     )
+
     for k in tiled_images:
         dataset_json.append({"image": tiled_images[k], "label": tiled_labels[k]})
+
     return dataset_json
diff --git a/setup.cfg b/setup.cfg
index 1ec39c442..f360a6fce 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -54,7 +54,6 @@ install_requires =
     pydicom-seg==0.4.1
     pynetdicom==2.0.2
     pynrrd==1.0.0
-    opencv-python-headless==4.9.0.80
     numpymaxflow==0.0.6
     girder-client==3.2.3
     ninja==1.11.1.1
diff --git a/tests/unit/transform/test_post.py b/tests/unit/transform/test_post.py
index e21802afa..de03d1b89 100644
--- a/tests/unit/transform/test_post.py
+++ b/tests/unit/transform/test_post.py
@@ -61,7 +61,7 @@
     {
         "pred": np.array([[0, 0, 0, 0, 0], [0, 1, 1, 1, 0], [0, 1, 0, 1, 0], [0, 1, 1, 1, 0], [0, 0, 0, 0, 0]]),
     },
-    [[[1, 2], [2, 1], [3, 2], [2, 3]], [[1, 1], [1, 3], [3, 3], [3, 1]]],
+    [[[3, 4], [1, 4], [0, 3], [0, 1], [1, 0], [3, 0], [4, 1], [4, 3], [3, 4]]],
 ]
 
 DUMPIMAGEPREDICTION2DD_DATA = [
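
Note on the fallback pattern this patch applies everywhere: monai.utils.optional_import
returns a (module, flag) pair, where the module is a lazy stub that raises only if it is
actually used, so the import can be attempted unconditionally and the flag branched on at
call time. A minimal sketch of the guard, with a hypothetical mask_to_contours helper
(not part of this patch) standing in for the FindContoursd logic above:

    import numpy as np
    from monai.utils import optional_import
    from skimage.measure import approximate_polygon, find_contours

    # (module, flag): the module is a lazy stub that raises only when touched
    cv2, has_cv2 = optional_import("cv2")

    def mask_to_contours(mask: np.ndarray):
        """Hypothetical helper mirroring the two FindContoursd branches above."""
        if has_cv2:
            found, _ = cv2.findContours(mask, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
            # cv2 contours are (N, 1, 2) int arrays in (x, y) order
            return [np.squeeze(c) for c in found if len(c) >= 3]
        # skimage yields float (row, col) points; round, simplify, flip to (x, y)
        contours = []
        for c in find_contours(mask, 0.5):
            c = approximate_polygon(c, tolerance=0.5)
            contours.append(np.flip(np.round(c).astype(int), axis=1))
        return contours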
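The non-cv2 branch of split_nuclei_dataset rebuilds the three outputs of
cv2.connectedComponentsWithStats from scipy.ndimage primitives. A small self-contained
sketch of that correspondence (the mask and values are illustrative, not taken from the
repo tests):

    import numpy as np
    from scipy.ndimage import center_of_mass, find_objects, label

    mask = np.array(
        [[1, 1, 0, 0],
         [1, 0, 0, 1],
         [0, 0, 1, 1]],
        dtype=np.uint8,
    )

    # scipy returns (labeled_array, count) -- the reverse order of
    # cv2.connectedComponentsWithStats, whose first return value is the count
    instances, num_labels = label(mask)  # default structure = 4-connectivity
    assert num_labels == 2

    # centroids come back as (row, col) pairs, one per requested label
    centroids = center_of_mass(mask, instances, range(1, num_labels + 1))

    # find_objects gives one (row_slice, col_slice) per label; rebuild cv2-style
    # [x, y, w, h, area] stats from them as the patch does (area here is the
    # bounding-box area, matching the patch; cv2 reports the pixel count)
    stats = []
    for rows, cols in find_objects(instances):
        w, h = cols.stop - cols.start, rows.stop - rows.start
        stats.append([cols.start, rows.start, w, h, w * h])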