From 5de6e6a417477daa76e786e3a8bcc3e98a2c66c0 Mon Sep 17 00:00:00 2001
From: Constantin Pape
Date: Fri, 12 Apr 2024 21:11:34 +0200
Subject: [PATCH 1/3] Add initial scripts for bioengine integration

---
 .../bioimageio/export_model_for_bioengine.py |   3 +
 examples/bioimageio/imjoy_test.py            |  40 +++
 micro_sam/bioimageio/bioengine_export.py     | 258 ++++++++++++++++++
 3 files changed, 301 insertions(+)
 create mode 100644 examples/bioimageio/export_model_for_bioengine.py
 create mode 100644 examples/bioimageio/imjoy_test.py
 create mode 100644 micro_sam/bioimageio/bioengine_export.py

diff --git a/examples/bioimageio/export_model_for_bioengine.py b/examples/bioimageio/export_model_for_bioengine.py
new file mode 100644
index 00000000..0b3f9762
--- /dev/null
+++ b/examples/bioimageio/export_model_for_bioengine.py
@@ -0,0 +1,3 @@
+from micro_sam.modelzoo.bioengine_export import export_bioengine_model
+
+export_bioengine_model("vit_b", "test-export", opset=12)
diff --git a/examples/bioimageio/imjoy_test.py b/examples/bioimageio/imjoy_test.py
new file mode 100644
index 00000000..2bd500b5
--- /dev/null
+++ b/examples/bioimageio/imjoy_test.py
@@ -0,0 +1,40 @@
+import time
+
+import numpy as np
+from imjoy_rpc.hypha import connect_to_server
+
+# Alternative servers for testing: "http://127.0.0.1:9520" (local) or "https://ai.imjoy.io".
+SERVER_URL = "https://hypha.bioimage.io"
+
+
+async def test_backbone(triton):
+    config = await triton.get_config(model_name="micro-sam-vit-b-backbone")
+    print(config)
+
+    # A random test image in the (batch, channel, height, width) layout expected by the encoder.
+    image = np.random.randint(0, 255, size=(1, 3, 1024, 1024), dtype=np.uint8).astype(
+        "float32"
+    )
+
+    start_time = time.time()
+    result = await triton.execute(
+        inputs=[image],
+        model_name="micro-sam-vit-b-backbone",
+    )
+    print("Backbone", result)
+    embedding = result["output0__0"]
+    print("Time taken: ", time.time() - start_time)
+    print("Test passed", embedding.shape)
+
+
+async def run():
+    server = await connect_to_server(
+        {"name": "test client", "server_url": SERVER_URL, "method_timeout": 100}
+    )
+    triton = await server.get_service("triton-client")
+    await test_backbone(triton)
+
+
+if __name__ == "__main__":
+    import asyncio
+    asyncio.run(run())
diff --git a/micro_sam/bioimageio/bioengine_export.py b/micro_sam/bioimageio/bioengine_export.py
new file mode 100644
index 00000000..9559f97d
--- /dev/null
+++ b/micro_sam/bioimageio/bioengine_export.py
@@ -0,0 +1,258 @@
+import os
+import warnings
+from typing import Optional, Union
+
+import torch
+from segment_anything.utils.onnx import SamOnnxModel
+
+try:
+    import onnxruntime
+    onnxruntime_exists = True
+except ImportError:
+    onnxruntime_exists = False
+
+from ..util import get_sam_model
+
+
+ENCODER_CONFIG = """name: "%s"
+backend: "pytorch"
+platform: "pytorch_libtorch"
+
+max_batch_size : 1
+input [
+  {
+    name: "input0__0"
+    data_type: TYPE_FP32
+    dims: [3, -1, -1]
+  }
+]
+output [
+  {
+    name: "output0__0"
+    data_type: TYPE_FP32
+    dims: [256, 64, 64]
+  }
+]
+
+parameters: {
+  key: "INFERENCE_MODE"
+  value: {
+    string_value: "true"
+  }
+}"""
+
+
+DECODER_CONFIG = """name: "%s"
+backend: "onnxruntime"
+platform: "onnxruntime_onnx"
+
+parameters: {
+  key: "INFERENCE_MODE"
+  value: {
+    string_value: "true"
+  }
+}
+
+instance_group {
+  count: 1
+  kind: KIND_CPU
+}"""
+
+
+def _to_numpy(tensor):
+    return tensor.cpu().numpy()
+
+
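+# The exported encoder follows the model repository layout of the Triton Inference Server,
+# on which the BioEngine runs: a top-level folder named after the model that contains a
+# `config.pbtxt` describing the input and output tensors, and a version subfolder ("1")
+# that holds the torchscript weights.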
+def export_image_encoder(
+    model_type: str,
+    output_root: Union[str, os.PathLike],
+    export_name: Optional[str] = None,
+    checkpoint_path: Optional[str] = None,
+) -> None:
+    """Export SAM image encoder to torchscript.
+
+    The torchscript image encoder can be used for predicting image embeddings
+    with a backend, e.g. with [the BioEngine](https://github.com/bioimage-io/bioengine-model-runner).
+
+    Args:
+        model_type: The SAM model type.
+        output_root: The output root directory where the exported model is saved.
+        export_name: The name of the exported model.
+        checkpoint_path: Optional checkpoint for loading the SAM model.
+    """
+    if export_name is None:
+        export_name = model_type
+    name = f"sam-{export_name}-encoder"
+
+    output_folder = os.path.join(output_root, name)
+    weight_output_folder = os.path.join(output_folder, "1")
+    os.makedirs(weight_output_folder, exist_ok=True)
+
+    predictor = get_sam_model(model_type=model_type, checkpoint_path=checkpoint_path)
+    encoder = predictor.model.image_encoder
+
+    encoder.eval()
+    input_ = torch.rand(1, 3, 1024, 1024)
+    traced_model = torch.jit.trace(encoder, input_)
+    weight_path = os.path.join(weight_output_folder, "model.pt")
+    traced_model.save(weight_path)
+
+    config_output_path = os.path.join(output_folder, "config.pbtxt")
+    with open(config_output_path, "w") as f:
+        f.write(ENCODER_CONFIG % name)
+
+
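+# The ONNX export below bundles the prompt encoder and the mask decoder into a single graph;
+# the number of input points is kept flexible via the `dynamic_axes` declaration.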
+def export_onnx_model(
+    model_type: str,
+    output_root: Union[str, os.PathLike],
+    opset: int,
+    export_name: Optional[str] = None,
+    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
+    return_single_mask: bool = True,
+    gelu_approximate: bool = False,
+    use_stability_score: bool = False,
+    return_extra_metrics: bool = False,
+) -> None:
+    """Export the SAM prompt encoder and mask decoder to ONNX.
+
+    The ONNX encoder and decoder can be used for interactive segmentation in the browser.
+    This code is adapted from
+    https://github.com/facebookresearch/segment-anything/blob/main/scripts/export_onnx_model.py
+
+    Args:
+        model_type: The SAM model type.
+        output_root: The output root directory where the exported model is saved.
+        opset: The ONNX opset version.
+        export_name: The name of the exported model.
+        checkpoint_path: Optional checkpoint for loading the SAM model.
+        return_single_mask: Whether the mask decoder returns a single or multiple masks.
+        gelu_approximate: Whether to use a GeLU approximation, in case the ONNX backend
+            does not have an efficient GeLU implementation.
+        use_stability_score: Whether to use the stability score instead of the predicted score.
+        return_extra_metrics: Whether to return a larger set of metrics.
+    """
+    if export_name is None:
+        export_name = model_type
+    name = f"sam-{export_name}-decoder"
+
+    output_folder = os.path.join(output_root, name)
+    weight_output_folder = os.path.join(output_folder, "1")
+    os.makedirs(weight_output_folder, exist_ok=True)
+
+    _, sam = get_sam_model(model_type=model_type, checkpoint_path=checkpoint_path, return_sam=True)
+    weight_path = os.path.join(weight_output_folder, "model.onnx")
+
+    onnx_model = SamOnnxModel(
+        model=sam,
+        return_single_mask=return_single_mask,
+        use_stability_score=use_stability_score,
+        return_extra_metrics=return_extra_metrics,
+    )
+
+    if gelu_approximate:
+        for _, module in onnx_model.named_modules():
+            if isinstance(module, torch.nn.GELU):
+                module.approximate = "tanh"
+
+    dynamic_axes = {
+        "point_coords": {1: "num_points"},
+        "point_labels": {1: "num_points"},
+    }
+
+    embed_dim = sam.prompt_encoder.embed_dim
+    embed_size = sam.prompt_encoder.image_embedding_size
+
+    mask_input_size = [4 * x for x in embed_size]
+    dummy_inputs = {
+        "image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float),
+        "point_coords": torch.randint(low=0, high=1024, size=(1, 5, 2), dtype=torch.float),
+        "point_labels": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float),
+        "mask_input": torch.randn(1, 1, *mask_input_size, dtype=torch.float),
+        "has_mask_input": torch.tensor([1], dtype=torch.float),
+        "orig_im_size": torch.tensor([1500, 2250], dtype=torch.float),
+    }
+
+    _ = onnx_model(**dummy_inputs)
+
+    output_names = ["masks", "iou_predictions", "low_res_masks"]
+
+    with warnings.catch_warnings():
+        warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
+        warnings.filterwarnings("ignore", category=UserWarning)
+        with open(weight_path, "wb") as f:
+            print(f"Exporting onnx model to {weight_path}...")
+            torch.onnx.export(
+                onnx_model,
+                tuple(dummy_inputs.values()),
+                f,
+                export_params=True,
+                verbose=False,
+                opset_version=opset,
+                do_constant_folding=True,
+                input_names=list(dummy_inputs.keys()),
+                output_names=output_names,
+                dynamic_axes=dynamic_axes,
+            )
+
+    if onnxruntime_exists:
+        ort_inputs = {k: _to_numpy(v) for k, v in dummy_inputs.items()}
+        # Use the CPU provider as default.
+        providers = ["CPUExecutionProvider"]
+        ort_session = onnxruntime.InferenceSession(weight_path, providers=providers)
+        _ = ort_session.run(None, ort_inputs)
+        print("Model has successfully been run with ONNXRuntime.")
+
+    config_output_path = os.path.join(output_folder, "config.pbtxt")
+    with open(config_output_path, "w") as f:
+        f.write(DECODER_CONFIG % name)
+
+
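+# For example, `export_bioengine_model("vit_b", "exported", opset=12)` creates the model folders
+# `exported/sam-vit_b-encoder` (torchscript) and `exported/sam-vit_b-decoder` (ONNX).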
+def export_bioengine_model(
+    model_type: str,
+    output_root: Union[str, os.PathLike],
+    opset: int,
+    export_name: Optional[str] = None,
+    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
+    return_single_mask: bool = True,
+    gelu_approximate: bool = False,
+    use_stability_score: bool = False,
+    return_extra_metrics: bool = False,
+) -> None:
+    """Export SAM model to a format compatible with the BioEngine.
+
+    [The BioEngine](https://github.com/bioimage-io/bioengine-model-runner) enables running the
+    image encoder on an online backend, so that SAM can be used in an online tool, or to compute
+    the image embeddings via the online backend rather than locally.
+
+    Args:
+        model_type: The SAM model type.
+        output_root: The output root directory where the exported model is saved.
+        opset: The ONNX opset version.
+        export_name: The name of the exported model.
+        checkpoint_path: Optional checkpoint for loading the SAM model.
+        return_single_mask: Whether the mask decoder returns a single or multiple masks.
+        gelu_approximate: Whether to use a GeLU approximation, in case the ONNX backend
+            does not have an efficient GeLU implementation.
+        use_stability_score: Whether to use the stability score instead of the predicted score.
+        return_extra_metrics: Whether to return a larger set of metrics.
+    """
+    export_image_encoder(model_type, output_root, export_name, checkpoint_path)
+    export_onnx_model(
+        model_type=model_type,
+        output_root=output_root,
+        opset=opset,
+        export_name=export_name,
+        checkpoint_path=checkpoint_path,
+        return_single_mask=return_single_mask,
+        gelu_approximate=gelu_approximate,
+        use_stability_score=use_stability_score,
+        return_extra_metrics=return_extra_metrics,
+    )

From 74af1b625374532b8898ec4d465d012dff075408 Mon Sep 17 00:00:00 2001
From: Constantin Pape
Date: Fri, 10 May 2024 10:20:54 +0200
Subject: [PATCH 2/3] Fix CI

---
 examples/bioimageio/export_model_for_bioengine.py   | 4 ++--
 examples/bioimageio/{imjoy_test.py => run_imjoy.py} | 0
 2 files changed, 2 insertions(+), 2 deletions(-)
 rename examples/bioimageio/{imjoy_test.py => run_imjoy.py} (100%)

diff --git a/examples/bioimageio/export_model_for_bioengine.py b/examples/bioimageio/export_model_for_bioengine.py
index 0b3f9762..e9f304d5 100644
--- a/examples/bioimageio/export_model_for_bioengine.py
+++ b/examples/bioimageio/export_model_for_bioengine.py
@@ -1,3 +1,3 @@
-from micro_sam.modelzoo.bioengine_export import export_bioengine_model
+from micro_sam.bioimageio.bioengine_export import export_bioengine_model
 
-export_bioengine_model("vit_b", "test-export", opset=12)
+export_bioengine_model("vit_t", "test-export", opset=12)
diff --git a/examples/bioimageio/imjoy_test.py b/examples/bioimageio/run_imjoy.py
similarity index 100%
rename from examples/bioimageio/imjoy_test.py
rename to examples/bioimageio/run_imjoy.py

From 4f7d41c3d017d67a5ed06659805ebf4a5efc4dfb Mon Sep 17 00:00:00 2001
From: Constantin Pape
Date: Sat, 11 May 2024 17:19:19 +0200
Subject: [PATCH 3/3] Add proof-of-concept scripts for SAM hypha server

---
 examples/bioimageio/hypha_data_store.py | 148 +++++++++++++++++++
 examples/bioimageio/sam_server.py       | 191 ++++++++++++++++++++++++
 examples/bioimageio/server_tests.py     |  28 ++++
 3 files changed, 367 insertions(+)
 create mode 100644 examples/bioimageio/hypha_data_store.py
 create mode 100644 examples/bioimageio/sam_server.py
 create mode 100644 examples/bioimageio/server_tests.py

diff --git a/examples/bioimageio/hypha_data_store.py b/examples/bioimageio/hypha_data_store.py
new file mode 100644
index 00000000..14a06a0f
--- /dev/null
+++ b/examples/bioimageio/hypha_data_store.py
@@ -0,0 +1,148 @@
+import json
+import mimetypes
+import os
+import uuid
+from typing import Any
+from urllib.parse import parse_qs
+
+
+class HyphaDataStore:
+    def __init__(self):
+        self.storage = {}
+        self._svc = None
+        self._server = None
+
+    async def setup(self, server, service_id="data-store", visibility="public"):
+        self._server = server
+        self._svc = await server.register_service({
+            "id": service_id,
+            "type": "functions",
+            "config": {
+                "visibility": visibility,
+                "require_context": False
+            },
+            "get": self.http_get,
+        }, overwrite=True)
+
+    def get_url(self, obj_id: str):
+        assert self._svc, "Service not initialized, call `setup()`"
+        assert obj_id in self.storage, "Object not found: " + obj_id
+        return f"{self._server.config.public_base_url}/{self._server.config.workspace}/apps/{self._svc.id.split(':')[1]}/get?id={obj_id}"
+
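+    # Note: objects are held in memory and served via the `get` endpoint registered in
+    # `setup`, so they are only available for the lifetime of this service.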
+    def put(self, obj_type: str, value: Any, name: str, comment: str = ""):
+        assert self._svc, "Please call `setup()` before using the store"
+        obj_id = str(uuid.uuid4())
+        if obj_type == 'file':
+            data = value
+            assert isinstance(data, (str, bytes)), "Value must be a string or bytes"
+            if isinstance(data, str) and data.startswith("file://"):
+                # File URL examples:
+                # Absolute URL: `file:///home/data/myfile.png`
+                # Relative URL: `file://./myimage.png`, or `file://myimage.png`
+                with open(data.replace("file://", ""), 'rb') as fil:
+                    data = fil.read()
+            mime_type, _ = mimetypes.guess_type(name)
+            self.storage[obj_id] = {
+                'type': obj_type,
+                'name': name,
+                'value': data,
+                'mime_type': mime_type or 'application/octet-stream',
+                'comment': comment
+            }
+        else:
+            self.storage[obj_id] = {
+                'type': obj_type,
+                'name': name,
+                'value': value,
+                'mime_type': 'application/json',
+                'comment': comment
+            }
+        return obj_id
+
+    def get(self, id: str):
+        assert self._svc, "Please call `setup()` before using the store"
+        obj = self.storage.get(id)
+        return obj
+
+    def http_get(self, scope, context=None):
+        query_string = scope['query_string']
+        id = parse_qs(query_string).get('id', [])[0]
+        obj = self.storage.get(id)
+        if obj is None:
+            return {'status': 404, 'headers': {}, 'body': "Not found: " + id}
+
+        if obj['type'] == 'file':
+            data = obj['value']
+            if isinstance(data, str):
+                if not os.path.isfile(data):
+                    return {
+                        "status": 404,
+                        'headers': {'Content-Type': 'text/plain'},
+                        "body": "File not found: " + data
+                    }
+                with open(data, 'rb') as fil:
+                    data = fil.read()
+            # Use `data` for the response, since `obj['value']` may still hold the file path.
+            headers = {
+                'Content-Type': obj['mime_type'],
+                'Content-Length': str(len(data)),
+                'Content-Disposition': f'inline; filename="{obj["name"].split("/")[-1]}"'
+            }
+
+            return {
+                'status': 200,
+                'headers': headers,
+                'body': data
+            }
+        else:
+            return {
+                'status': 200,
+                'headers': {'Content-Type': 'application/json'},
+                'body': json.dumps(obj['value'])
+            }
+
+    def http_list(self, scope, context=None):
+        query_string = scope.get('query_string', '')
+        kws = parse_qs(query_string).get('keyword', [])
+        keyword = kws[0] if kws else None
+        result = [value for key, value in self.storage.items() if not keyword or keyword in value['name']]
+        return {'status': 200, 'headers': {'Content-Type': 'application/json'}, 'body': json.dumps(result)}
+
+    def remove(self, obj_id: str):
+        assert self._svc, "Please call `setup()` before using the store"
+        if obj_id in self.storage:
+            del self.storage[obj_id]
+            return True
+        raise KeyError("Not found: " + obj_id)
+
+
+async def test_data_store(server_url="https://ai.imjoy.io"):
+    from imjoy_rpc.hypha import connect_to_server, login
+    token = await login({"server_url": server_url})
+    server = await connect_to_server({"server_url": server_url, "token": token})
+
+    ds = HyphaDataStore()
+    # Setup would need to be completed in an ASGI compatible environment
+    await ds.setup(server)
+
+    # Test PUT operation (the local file path is only an example and has to exist)
+    file_id = ds.put('file', 'file:///home/data.txt', 'data.txt')
+    binary_id = ds.put('file', b'Some binary content', 'example.bin')
+    json_id = ds.put('json', {'hello': 'world'}, 'example.json')
+
+    # Test GET operation
+    assert ds.get(file_id)['type'] == 'file'
+    assert ds.get(binary_id)['type'] == 'file'
+    assert ds.get(json_id)['type'] == 'json'
+
+    # Test GET URL generation
+    print("URL for getting file", ds.get_url(file_id))
+    print("URL for getting binary object", ds.get_url(binary_id))
+    print("URL for getting json object", ds.get_url(json_id))
+
+
+if __name__ == "__main__":
+    import asyncio
+    asyncio.run(test_data_store())
diff --git a/examples/bioimageio/sam_server.py b/examples/bioimageio/sam_server.py
new file mode 100644
index 00000000..0148f1cd
--- /dev/null
+++ b/examples/bioimageio/sam_server.py
@@ -0,0 +1,191 @@
+import os
+import warnings
+from functools import partial
+
+# import urllib
+import imageio.v3 as imageio
+import numpy as np
+import requests
+import torch
+
+from hypha_data_store import HyphaDataStore
+from segment_anything import sam_model_registry, SamPredictor
+from segment_anything.utils.onnx import SamOnnxModel
+
+image_url = "https://owncloud.gwdg.de/index.php/s/fSaOJIOYjmFBjPM/download"
+
+
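+# The model name selects the checkpoint: "vit_b" is the original SAM model, while
+# "vit_b_lm" and "vit_b_em_organelles" are micro_sam models finetuned for microscopy.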
+def get_sam_model(model_name):
+    models = {
+        "vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
+        "vit_b_lm": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/diplomatic-bug/staged/1/files/vit_b.pt",
+        # TODO
+        "vit_b_em_organelles": "",
+    }
+    model_url = models[model_name]
+    checkpoint_path = f"{model_name}.pt"
+
+    if not os.path.exists(checkpoint_path):
+        response = requests.get(model_url, stream=True)
+        response.raise_for_status()
+        with open(checkpoint_path, "wb") as f:
+            for chunk in response.iter_content(chunk_size=8192):
+                f.write(chunk)
+
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    # The first five characters of the model name determine the architecture, e.g. "vit_b_lm" -> "vit_b".
+    model_type = model_name[:5]
+    sam = sam_model_registry[model_type]()
+    ckpt = torch.load(checkpoint_path, map_location=device)
+    sam.load_state_dict(ckpt)
+    sam.to(device)
+    return sam
+
+
+def export_onnx_model(
+    sam,
+    output_path,
+    opset: int,
+    return_single_mask: bool = True,
+    gelu_approximate: bool = False,
+    use_stability_score: bool = False,
+    return_extra_metrics: bool = False,
+) -> None:
+
+    onnx_model = SamOnnxModel(
+        model=sam,
+        return_single_mask=return_single_mask,
+        use_stability_score=use_stability_score,
+        return_extra_metrics=return_extra_metrics,
+    )
+
+    if gelu_approximate:
+        for _, module in onnx_model.named_modules():
+            if isinstance(module, torch.nn.GELU):
+                module.approximate = "tanh"
+
+    dynamic_axes = {
+        "point_coords": {1: "num_points"},
+        "point_labels": {1: "num_points"},
+    }
+
+    embed_dim = sam.prompt_encoder.embed_dim
+    embed_size = sam.prompt_encoder.image_embedding_size
+
+    mask_input_size = [4 * x for x in embed_size]
+    dummy_inputs = {
+        "image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float),
+        "point_coords": torch.randint(low=0, high=1024, size=(1, 5, 2), dtype=torch.float),
+        "point_labels": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float),
+        "mask_input": torch.randn(1, 1, *mask_input_size, dtype=torch.float),
+        "has_mask_input": torch.tensor([1], dtype=torch.float),
+        "orig_im_size": torch.tensor([1500, 2250], dtype=torch.float),
+    }
+
+    _ = onnx_model(**dummy_inputs)
+
+    output_names = ["masks", "iou_predictions", "low_res_masks"]
+
+    with warnings.catch_warnings():
+        warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
+        warnings.filterwarnings("ignore", category=UserWarning)
+        with open(output_path, "wb") as f:
+            print(f"Exporting onnx model to {output_path}...")
+            torch.onnx.export(
+                onnx_model,
+                tuple(dummy_inputs.values()),
+                f,
+                export_params=True,
+                verbose=False,
+                opset_version=opset,
+                do_constant_folding=True,
+                input_names=list(dummy_inputs.keys()),
+                output_names=output_names,
+                dynamic_axes=dynamic_axes,
+            )
+
+
+def get_example_image():
+    image = imageio.imread(image_url)
+    return np.asarray(image)
+
+
+def _to_image(input_):
+    # we require the input to be uint8
+    if input_.dtype != np.dtype("uint8"):
+        # first normalize the input to [0, 1]
+        input_ = input_.astype("float32") - input_.min()
+        input_ = input_ / input_.max()
+        # then bring to [0, 255] and cast to uint8
+        input_ = (input_ * 255).astype("uint8")
+    if input_.ndim == 2:
+        image = np.concatenate([input_[..., None]] * 3, axis=-1)
+    elif input_.ndim == 3 and input_.shape[-1] == 3:
+        image = input_
+    else:
+        raise ValueError(f"Invalid input image of shape {input_.shape}. Expect either 2D grayscale or 3D RGB image.")
+    return image
+
+
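+# For SAM with a ViT-B image encoder the embeddings have the shape (1, 256, 64, 64),
+# independent of the original image size, since inputs are resized and padded to 1024 x 1024.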
+def compute_embeddings(model_name="vit_b"):
+    sam = get_sam_model(model_name)
+    predictor = SamPredictor(sam)
+    image = get_example_image()
+    predictor.reset_image()
+    predictor.set_image(_to_image(image))
+    image_embeddings = predictor.get_image_embedding().cpu().numpy()
+    return image_embeddings
+
+
+async def get_onnx(model_name="vit_b", opset_version=12, ds=None):
+    output_path = f"{model_name}.onnx"
+    if not os.path.exists(output_path):
+        sam = get_sam_model(model_name)
+        export_onnx_model(sam, output_path, opset=opset_version)
+
+    file_id = ds.put("file", f"file://{output_path}", output_path)
+    url = ds.get_url(file_id)
+    return url
+
+
+async def start_server():
+    from imjoy_rpc.hypha import connect_to_server, login
+
+    server_url = "https://ai.imjoy.io"
+
+    token = await login({"server_url": server_url})
+    server = await connect_to_server({"server_url": server_url, "token": token})
+
+    # Upload to hypha.
+    ds = HyphaDataStore()
+    await ds.setup(server)
+
+    svc = await server.register_service({
+        "name": "Sam Server",
+        "id": "bioimageio-colab",
+        "config": {
+            "visibility": "public"
+        },
+        "get_onnx": partial(get_onnx, ds=ds),
+        "compute_embeddings": compute_embeddings,
+        "get_example_image": get_example_image,
+        "ping": lambda: "pong"
+    })
+    sid = svc['id']
+    # config_str = f'{{"service_id": "{sid}", "server_url": "{server_url}"}}'
+    # encoded_config = urllib.parse.quote(config_str, safe='/', encoding=None, errors=None)
+    # annotator_url = 'https://imjoy.io/lite?plugin=https://raw.githubusercontent.com/bioimage-io/bioimageio-colab/main/plugins/bioimageio-colab.imjoy.html&config=' + encoded_config
+    print(sid)
+
+
+if __name__ == "__main__":
+    import asyncio
+
+    loop = asyncio.get_event_loop()
+    loop.create_task(start_server())
+
+    loop.run_forever()
diff --git a/examples/bioimageio/server_tests.py b/examples/bioimageio/server_tests.py
new file mode 100644
index 00000000..53a716cc
--- /dev/null
+++ b/examples/bioimageio/server_tests.py
@@ -0,0 +1,28 @@
+def test_example_image():
+    from sam_server import get_example_image
+
+    image = get_example_image()
+    print(image.shape)
+
+
+def test_onnx():
+    from sam_server import get_sam_model, export_onnx_model
+
+    print("Downloading the model...")
+    sam = get_sam_model("vit_b")
+    print("Exporting to ONNX...")
+    export_onnx_model(sam, "onnx-test.onnx", opset=12)
+    print("Done!")
+
+
+def test_embeddings():
+    from sam_server import compute_embeddings
+
+    embeds = compute_embeddings()
+    print(embeds.shape)
+
+
+if __name__ == "__main__":
+    test_example_image()
+    # test_onnx()
+    # test_embeddings()