Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

vista3d_ohifv3 support #1771

Closed
wants to merge 12 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 38 additions & 19 deletions monailabel/tasks/infer/basic_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,19 @@
import torch
from monai.data import decollate_batch
from monai.inferers import Inferer, SimpleInferer, SlidingWindowInferer
from monai.utils import deprecated
from monai.utils import deprecated, optional_import

from monailabel.interfaces.exception import MONAILabelError, MONAILabelException
from monailabel.interfaces.tasks.infer_v2 import InferTask, InferType
from monailabel.interfaces.utils.transform import dump_data, run_transforms
from monailabel.tasks.infer.prompt_utils import prompt_run_inferer
from monailabel.transform.cache import CacheTransformDatad
from monailabel.transform.writer import ClassificationWriter, DetectionWriter, Writer
from monailabel.utils.others.generic import device_list, device_map, name_to_device

rearrange, _ = optional_import("einops", name="rearrange")


logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -368,6 +372,7 @@ def __call__(

if result_file_name is not None and isinstance(result_file_name, str):
logger.info(f"Result File: {result_file_name}")

logger.info(f"Result Json Keys: {list(result_json.keys())}")
return result_file_name, result_json

Expand Down Expand Up @@ -396,7 +401,6 @@ def __call__(self, data):
if d is None:
return run_transforms(data, transforms, log_prefix="PRE", use_compose=False)
return run_transforms(d, post_cache, log_prefix="PRE", use_compose=False) if post_cache else d

return run_transforms(data, transforms, log_prefix="PRE", use_compose=False)

def run_invert_transforms(self, data: Dict[str, Any], pre_transforms, names):
Expand Down Expand Up @@ -503,26 +507,41 @@ def run_inferer(self, data: Dict[str, Any], convert_to_batch=True, device="cuda"
logger.info(f"Inferer:: {device} => {inferer.__class__.__name__} => {inferer.__dict__}")

network = self._get_network(device, data)
modelname = data.get("model", None)
if network:
inputs = data[self.input_key]
inputs = inputs if torch.is_tensor(inputs) else torch.from_numpy(inputs)
inputs = inputs[None] if convert_to_batch else inputs
inputs = inputs.to(torch.device(device))

with torch.no_grad():
outputs = inferer(inputs, network)

if device.startswith("cuda"):
torch.cuda.empty_cache()
if "vista" in modelname:
return prompt_run_inferer(
data,
inferer,
network,
input_key=self.input_key,
output_label_key=self.output_label_key,
device=device,
)
else:
inputs = data[self.input_key]
inputs = inputs if torch.is_tensor(inputs) else torch.from_numpy(inputs)
inputs = inputs[None] if convert_to_batch else inputs
inputs = inputs.to(torch.device(device))

with torch.no_grad():
outputs = inferer(
inputs,
network,
)

if device.startswith("cuda"):
torch.cuda.empty_cache()

if convert_to_batch:
if isinstance(outputs, dict):
outputs_d = decollate_batch(outputs)
outputs = outputs_d[0]
else:
outputs = outputs[0]

if convert_to_batch:
if isinstance(outputs, dict):
outputs_d = decollate_batch(outputs)
outputs = outputs_d[0]
else:
outputs = outputs[0]
data[self.output_label_key] = outputs

data[self.output_label_key] = outputs
else:
# consider them as callable transforms
data = run_transforms(data, inferer, log_prefix="INF", log_name="Inferer")
Expand Down
6 changes: 3 additions & 3 deletions monailabel/tasks/infer/bundle.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from monai.bundle import ConfigItem, ConfigParser
from monai.inferers import Inferer, SimpleInferer
from monai.transforms import Compose, LoadImaged, SaveImaged
from monai.transforms import Compose, CropForegroundd, Invertd, LoadImaged, SaveImaged

from monailabel.interfaces.tasks.infer_v2 import InferType
from monailabel.tasks.infer.basic_infer import BasicInferTask
Expand Down Expand Up @@ -82,8 +82,8 @@ def __init__(
conf: Dict[str, str],
const: Optional[BundleConstants] = None,
type: Union[str, InferType] = "",
pre_filter: Optional[Sequence] = None,
post_filter: Optional[Sequence] = [SaveImaged],
pre_filter: Optional[Sequence] = [CropForegroundd],
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

any reasons why we needed these changes? which is applicable to everyone

post_filter: Optional[Sequence] = [SaveImaged, Invertd],
extend_load_image: bool = True,
add_post_restore: bool = True,
dropout: float = 0.0,
Expand Down
136 changes: 136 additions & 0 deletions monailabel/tasks/infer/prompt_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
from typing import Any, Dict
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

when vista3d bundle is used.. why we need to have this related code in monailabel?


import numpy as np
import torch
from monai.data import decollate_batch
from monai.utils import optional_import

rearrange, _ = optional_import("einops", name="rearrange")


def transform_points(point, affine):
    """Apply a 4x4 affine transform to a batch of 3D points.

    Args:
        point: array-like of shape (bs, n, 3) — batched (x, y, z) coordinates.
        affine: (4, 4) affine matrix applied in homogeneous coordinates.

    Returns:
        ndarray of shape (bs, n, 3) with the transformed coordinates.
    """
    point = np.asarray(point)
    bs, n = point.shape[:2]
    # Promote to homogeneous coordinates: (bs, n, 3) -> (bs, n, 4).
    homogeneous = np.concatenate((point, np.ones((bs, n, 1))), axis=-1)
    # Flatten so each point is a column vector: (bs, n, 4) -> (4, bs*n).
    # Equivalent to einops rearrange("b n d -> d (b n)") but avoids the
    # optional einops dependency, which may not be installed.
    columns = homogeneous.transpose(2, 0, 1).reshape(4, bs * n)
    transformed = affine @ columns
    # Back to (bs, n, 4) — rearrange("d (b n) -> b n d", b=bs) — then drop
    # the homogeneous coordinate.
    return transformed.reshape(4, bs, n).transpose(1, 2, 0)[:, :, :3]


def check_prompts_format(label_prompt, points, point_labels):
    """Validate user prompts for VISTA-3D style inference.

    Args:
        label_prompt: [1,2,3,4,...,B] list of class ids (ints or 0-d tensors).
        points: [[x,y,z], [x,y,z], ...] coordinates of a single object.
        point_labels: [1,1,0,...] one scalar per point; -1/0/1 are regular
            labels, 2/3 are special flags.

    Returns:
        The validated ``(label_prompt, points, point_labels)`` triple. When no
        prompt is given at all, ``label_prompt`` falls back to the
        "segment everything" class list.

    Raises:
        ValueError: on any malformed or inconsistent prompt combination.
    """
    # No prompt at all: fall back to segmenting every supported class.
    if label_prompt is None and points is None:
        # Classes 1..132 minus the ids this bundle cannot segment directly.
        # NOTE: the original guarded this list with "if ... is not None", which
        # was always true — the dead else-branch has been removed.
        everything_labels = list(
            {i + 1 for i in range(132)} - {2, 16, 18, 20, 21, 23, 24, 25, 26, 27, 128, 129, 130, 131, 132}
        )
        label_prompt = [torch.tensor(_) for _ in everything_labels]
        return label_prompt, points, point_labels

    # check label_prompt
    if label_prompt is not None:
        if not isinstance(label_prompt, list):
            raise ValueError("Label prompt must be a list, [1,2,3,4,...,].")
        # Works for both plain ints and 0-d tensors (the original called
        # .item() on the comparison, which broke for plain ints).
        if not np.all([x < 255 for x in label_prompt]):
            raise ValueError("Current bundle only supports label prompt smaller than 255.")
        if points is None:
            # Zero-shot (unsupported class) segmentation requires point prompts.
            supported_list = list({i + 1 for i in range(132)} - {16, 18, 129, 130, 131})
            if not np.all([x in supported_list for x in label_prompt]):
                raise ValueError("Undefined label prompt detected. Provide point prompts for zero-shot.")

    # check points
    if points is not None:
        if point_labels is None:
            raise ValueError("Point labels must be given if points are given.")
        if not np.all([len(_) == 3 for _ in points]):
            raise ValueError("Points must be three dimensional (x,y,z) in the shape of [[x,y,z],...,[x,y,z]].")
        if len(points) != len(point_labels):
            raise ValueError("Points must match point labels.")
        if not np.all([_ in [-1, 0, 1, 2, 3] for _ in point_labels]):
            raise ValueError("Point labels can only be -1,0,1 and 2,3 for special flags.")
        if label_prompt is not None and len(label_prompt) != 1:
            raise ValueError("Label prompt can only be a single object if provided with point prompts.")

    # check point_labels
    if point_labels is not None and points is None:
        raise ValueError("Points must be given if point labels are given.")
    return label_prompt, points, point_labels


def prompt_run_inferer(
    data: Dict[str, Any],
    inferer,
    network,
    input_key="image",
    output_label_key="pred",
    device="cuda",
    convert_to_batch=True,
):
    """Run a prompt-driven (VISTA-3D style) inferer over ``data[input_key]``.

    Reads ``label_prompt``, ``points`` and ``point_labels`` from ``data``,
    validates them, rescales point coordinates from the original image grid to
    the resized network grid, runs the inferer, and stores the prediction in
    ``data[output_label_key]``.

    Args:
        data: transform dict; must hold the image under ``input_key`` and, when
            points are given, an ``image_meta_dict`` with ``spatial_shape``.
        inferer: callable accepting (inputs, network, point_coords=...,
            point_labels=..., class_vector=...).
        network: model passed through to the inferer.
        input_key: key of the (channel-first) image in ``data``.
        output_label_key: key under which the prediction is stored.
        device: torch device string, e.g. "cuda" or "cpu".
        convert_to_batch: if True, decollate the batched output and keep the
            first item. Note the input is always batched regardless.

    Returns:
        The ``data`` dict with the prediction added.
    """
    label_prompt = data.get("label_prompt", None)
    points = data.get("points", None)
    point_labels = data.get("point_labels", None)

    if label_prompt is not None:
        label_prompt = [torch.tensor(_) for _ in label_prompt]
    # Treat an empty prompt as "no prompt" so the default class list kicks in.
    if isinstance(label_prompt, torch.Tensor):
        if label_prompt.numel() == 0:
            label_prompt = None
    elif isinstance(label_prompt, list) and len(label_prompt) == 0:
        label_prompt = None

    label_prompt, points, point_labels = check_prompts_format(label_prompt, points, point_labels)
    # Shape expected by the network: (num_classes, 1) column vector on device.
    label_prompt = (
        torch.as_tensor([label_prompt]).to(torch.device(device))[0].unsqueeze(-1) if label_prompt is not None else None
    )
    data["label_prompt"] = label_prompt

    # Rescale points from the original image grid to the resized network grid.
    if points is not None:
        points = torch.as_tensor([points])
        # assumes image_meta_dict carries the pre-resize spatial shape — TODO confirm
        original_spatial_shape = np.array(data["image_meta_dict"]["spatial_shape"])
        resized_spatial_shape = np.array(data[input_key].shape[1:])
        scaling_factors = resized_spatial_shape / original_spatial_shape
        points = np.round(points * scaling_factors).to(torch.device(device))

    point_labels = torch.as_tensor([point_labels]).to(torch.device(device)) if point_labels is not None else None
    data["points"] = points
    data["point_labels"] = point_labels

    inputs = data[input_key]
    inputs = inputs if torch.is_tensor(inputs) else torch.from_numpy(inputs)
    # Single batching + device transfer (the original moved the tensor to the
    # device twice in a row).
    inputs = inputs[None].to(torch.device(device))

    with torch.no_grad():
        outputs = inferer(inputs, network, point_coords=points, point_labels=point_labels, class_vector=label_prompt)

    if device.startswith("cuda"):
        torch.cuda.empty_cache()

    if convert_to_batch:
        if isinstance(outputs, dict):
            outputs = decollate_batch(outputs)[0]
        else:
            outputs = outputs[0]

    data[output_label_key] = outputs[0] if isinstance(outputs, list) else outputs
    return data
1 change: 1 addition & 0 deletions monailabel/utils/others/modelzoo_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,5 @@
"lung_nodule_ct_detection", # The first lung nodule detection task can be used for MONAI Label. Added Dec 2022
"wholeBody_ct_segmentation", # The SegResNet trained TotalSegmentator dataset with 104 tissues. Added Feb 2023
"vista2d", # The VISTA segmentation trained on a collection of 15K public microscopy images. Added Aug 2024
"vista3d",
]
25 changes: 25 additions & 0 deletions plugins/ohifv3/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,33 @@ To run the OHIF plugin, update DICOM and MONAI Server Endpoints in configs/nginx
```bash
sh run.sh
```

You can also open dev mode in OHIF Viewer:
```bash
# cd into the ohifv3/Viewers folder after yarn build has completed.
# use cors friendly browser such as Chrome:
# google-chrome --disable-web-security --user-data-dir="/tmp/chrome_dev"

yarn run dev:orthanc
```

You can then visit http://127.0.0.1:3000/ohif/ in your browser to see the running OHIF.


### VISTA Interactive Models example

```bash
# use a local monailabel build if monailabel is not the latest tagged release.

# git clone https://github.com/Project-MONAI/MONAILabel.git

# git switch to the branch needed.

# command to start monailabel server with vista3d
'path/to/MONAILabel/monailabel/scripts/monailabel' start_server --app sample-apps/monaibundle --studies http://127.0.0.1:8042/dicom-web --conf models vista3d
```


### Installing Orthanc (DICOMWeb)

#### Ubuntu 20.x
Expand Down
4 changes: 2 additions & 2 deletions plugins/ohifv3/build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,9 @@ yarn config set workspaces-experimental true
yarn install
APP_CONFIG=config/monai_label.js PUBLIC_URL=/ohif/ QUICK_BUILD=true yarn run build

rm -rf ${install_dir}
# rm -rf ${install_dir}
cp -r platform/app/dist/ ${install_dir}
echo "Copied OHIF to ${install_dir}"
rm -rf ../Viewers
# rm -rf ../Viewers

cd ${curr_dir}
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,13 @@
}
.modelSelector .selectBox {
width: 100%;
color: #000;
}
.modelSelector .actionButton {
border: 2px solid #000;
border-radius: 15px;
background-color: #add8e6;
color: var(--ui-gray-dark);
color: #000;
line-height: 25px;
padding: 0 15px;
outline: none;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
@import url("w3.css");

.monaiLabelPanel {
background-color: #789;
background-color: rgb(0, 0, 0);
height: 100%;
width: 100%;
display: flex;
flex-direction: column;
color: var(--text-primary-color);
color: #fff;
padding: 2px;
overflow-y: scroll; /* Make the panel scrollable vertically */
/* Accordion styles */
Expand All @@ -15,7 +15,7 @@
font-size: 14px;
text-decoration: underline;
font-weight: 500;
color: #000;
color: #ffffff;
margin: 1px;
text-align: center;
}
Expand All @@ -40,7 +40,7 @@
display: flex;
justify-content: space-between;
padding: 0.4em;
background: #16202b;
background: #3e5975;
border-right: 1px dotted #3c5d80;
color: #fff;
font-size: 12px;
Expand All @@ -61,11 +61,11 @@
.monaiLabelPanel .tab-content {
max-height: 0;
padding: 0 1em;
background: #808080;
background: #000000;
transition: all 0.35s;
width: 90%;
font-size: small;
color: #000;
color: #fff;
}
.monaiLabelPanel .tab-close {
display: flex;
Expand All @@ -76,10 +76,10 @@
cursor: pointer;
}
.monaiLabelPanel .tab-close:hover {
background: #1a252f;
background: #3e5975;
}
.monaiLabelPanel input:checked + .tab-label {
background: #000;
background: #009bd1;
}
.monaiLabelPanel input:checked + .tab-label::after {
transform: rotate(90deg);
Expand Down
Loading
Loading