diff --git a/multi_object_tracking/yolo_sam/.dockerignore b/multi_object_tracking/yolo_sam/.dockerignore new file mode 100644 index 000000000..894795019 --- /dev/null +++ b/multi_object_tracking/yolo_sam/.dockerignore @@ -0,0 +1,18 @@ +# Exclude everything +_wsgi.py + +# Include Dockerfile and docker-compose for reference (optional, decide based on your use case) +!Dockerfile +!docker-compose.yml + +# Include Python application files +!*.py + +# Include requirements files +!requirements*.txt + +# Include script +!*.sh + +# Exclude specific requirements if necessary +# requirements-test.txt (Uncomment if you decide to exclude this) diff --git a/multi_object_tracking/yolo_sam/Dockerfile b/multi_object_tracking/yolo_sam/Dockerfile new file mode 100644 index 000000000..033a3263d --- /dev/null +++ b/multi_object_tracking/yolo_sam/Dockerfile @@ -0,0 +1,73 @@ +FROM pytorch/pytorch:2.1.2-cuda12.1-cudnn8-runtime + +ARG DEBIAN_FRONTEND=noninteractive +ARG TEST_ENV + +WORKDIR /app + +# Update Conda +RUN conda update conda -y + +# Install system dependencies +RUN --mount=type=cache,target="/var/cache/apt",sharing=locked \ + --mount=type=cache,target="/var/lib/apt/lists",sharing=locked \ + apt-get -y update \ + && apt-get install -y git wget g++ freeglut3-dev build-essential \ + libx11-dev libxmu-dev libxi-dev libglu1-mesa libglu1-mesa-dev \ + libfreeimage-dev ffmpeg libsm6 libxext6 libffi-dev python3-dev \ + python3-pip gcc + +# Environment variables +ENV PYTHONUNBUFFERED=1 \ + PYTHONDONTWRITEBYTECODE=1 \ + PIP_CACHE_DIR=/.cache \ + PORT=9090 \ + WORKERS=2 \ + THREADS=4 \ + CUDA_HOME=/usr/local/cuda \ + TORCH_CUDA_ARCH_LIST="6.0;6.1;7.0;7.5;8.0;8.6+PTX;8.9;9.0" \ + SEGMENT_ANYTHING_2_REPO_PATH=/segment-anything-2 \ + PYTHONPATH=/app + +# Install CUDA toolkit via Conda +RUN conda install -c "nvidia/label/cuda-12.1.1" cuda -y + +# Install Python dependencies +COPY requirements-base.txt . 
+RUN --mount=type=cache,target=${PIP_CACHE_DIR},sharing=locked \ + pip install -r requirements-base.txt + +COPY requirements.txt . +RUN --mount=type=cache,target=${PIP_CACHE_DIR},sharing=locked \ + pip install -r requirements.txt + +# Install segment-anything-2 +RUN cd / && git clone --depth 1 --branch main --single-branch https://github.com/facebookresearch/sam2.git +WORKDIR /sam2 +RUN --mount=type=cache,target=${PIP_CACHE_DIR},sharing=locked \ + pip install -e . +RUN cd checkpoints && ./download_ckpts.sh + +# Return to app working directory +WORKDIR /app + +# Install test dependencies (optional) +COPY requirements-test.txt . +RUN --mount=type=cache,target=${PIP_CACHE_DIR},sharing=locked \ + if [ "$TEST_ENV" = "true" ]; then \ + pip install -r requirements-test.txt; \ + fi + +# Download YOLO models +RUN /bin/sh -c 'if [ ! -f /app/models/yolov8m.pt ]; then \ + yolo predict model=/app/models/yolov8m.pt source=/app/tests/car.jpg \ + && yolo predict model=/app/models/yolov8n.pt source=/app/tests/car.jpg \ + && yolo predict model=/app/models/yolov8n-cls.pt source=/app/tests/car.jpg \ + && yolo predict model=/app/models/yolov8n-seg.pt source=/app/tests/car.jpg; \ + fi' + +# Copy app files +COPY . ./ + +# Default command +CMD ["/app/start.sh"] diff --git a/multi_object_tracking/yolo_sam/README.md b/multi_object_tracking/yolo_sam/README.md new file mode 100644 index 000000000..0bcfbca74 --- /dev/null +++ b/multi_object_tracking/yolo_sam/README.md @@ -0,0 +1,58 @@ +This guide describes the simplest way to start using ML backend with Label Studio. + +## Running with Docker (Recommended) + +1. Start Machine Learning backend on `http://localhost:9090` with prebuilt image: + +```bash +docker-compose up +``` + +2. Validate that backend is running + +```bash +$ curl http://localhost:9090/ +{"status":"UP"} +``` + +3. 
Connect to the backend from Label Studio running on the same host: go to your project `Settings -> Machine Learning -> Add Model` and specify `http://localhost:9090` as a URL. + + +## Building from source (Advanced) + +To build the ML backend from source, you have to clone the repository and build the Docker image: + +```bash +docker-compose build +``` + +## Running without Docker (Advanced) + +To run the ML backend without Docker, you have to clone the repository and install all dependencies using pip: + +```bash +python -m venv ml-backend +source ml-backend/bin/activate +pip install -r requirements.txt +``` + +Then you can start the ML backend: + +```bash +label-studio-ml start ./dir_with_your_model +``` + +# Configuration +Parameters can be set in `docker-compose.yml` before running the container. + + +The following common parameters are available: +- `BASIC_AUTH_USER` - specify the basic auth user for the model server +- `BASIC_AUTH_PASS` - specify the basic auth password for the model server +- `LOG_LEVEL` - set the log level for the model server +- `WORKERS` - specify the number of workers for the model server +- `THREADS` - specify the number of threads for the model server + +# Customization + +The ML backend can be customized by adding your own models and logic inside the `./dir_with_your_model` directory. 
import os
import argparse
import json
import logging
import logging.config

# Set a default log level if LOG_LEVEL is not defined
log_level = os.getenv("LOG_LEVEL", "INFO")

# Logging is configured before the project imports below so that loggers
# created while those modules are imported already use this configuration.
logging.config.dictConfig({
    "version": 1,
    "disable_existing_loggers": False,  # Prevent overriding existing loggers
    "formatters": {
        "standard": {
            "format": "[%(asctime)s] [%(levelname)s] [%(name)s::%(funcName)s::%(lineno)d] %(message)s"
        }
    },
    "handlers": {
        "console": {
            "class": "logging.StreamHandler",
            "level": log_level,
            "stream": "ext://sys.stdout",
            "formatter": "standard"
        }
    },
    "root": {
        "level": log_level,
        "handlers": [
            "console"
        ],
        "propagate": True
    }
})

from label_studio_ml.api import init_app
from model import NewModel


_DEFAULT_CONFIG_PATH = os.path.join(os.path.dirname(__file__), 'config.json')


def get_kwargs_from_config(config_path=_DEFAULT_CONFIG_PATH):
    """Load model initialization kwargs from a JSON config file.

    Args:
        config_path: Path to the JSON file; defaults to ``config.json``
            located next to this module.

    Returns:
        The decoded JSON object, or an empty dict when the file does not exist.

    Raises:
        AssertionError: If the top-level JSON value is not an object.
    """
    if not os.path.exists(config_path):
        return dict()
    with open(config_path) as f:
        config = json.load(f)
    assert isinstance(config, dict), "config.json must contain a JSON object"
    return config


def _split_kwarg(kv):
    """Split one ``KEY=VAL`` command line token into ``[key, value]``.

    Only the first ``=`` separates key from value, so values may themselves
    contain ``=`` (e.g. ``url=http://host/?a=b``). Tokens without a key or
    without ``=`` are rejected through argparse instead of failing later
    with an unpacking error.
    """
    key, sep, value = kv.partition('=')
    if not sep or not key:
        raise argparse.ArgumentTypeError(f"expected KEY=VAL, got {kv!r}")
    return [key, value]


def _coerce_kwarg_value(value):
    """Convert a command line string to int, bool or float when it looks like one."""
    if value.isdigit():
        return int(value)
    if value in ('True', 'true'):
        return True
    if value in ('False', 'false'):
        return False
    try:
        return float(value)
    except ValueError:
        # not a number: keep the raw string
        return value


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Label studio')
    parser.add_argument(
        '-p', '--port', dest='port', type=int, default=9090,
        help='Server port')
    parser.add_argument(
        '--host', dest='host', type=str, default='0.0.0.0',
        help='Server host')
    parser.add_argument(
        '--kwargs', '--with', dest='kwargs', metavar='KEY=VAL', nargs='+', type=_split_kwarg,
        help='Additional LabelStudioMLBase model initialization kwargs')
    parser.add_argument(
        '-d', '--debug', dest='debug', action='store_true',
        help='Switch debug mode')
    parser.add_argument(
        '--log-level', dest='log_level', choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'], default=log_level,
        help='Logging level')
    parser.add_argument(
        '--model-dir', dest='model_dir', default=os.path.dirname(__file__),
        help='Directory where models are stored (relative to the project directory)')
    parser.add_argument(
        '--check', dest='check', action='store_true',
        help='Validate model instance before launching server')
    parser.add_argument('--basic-auth-user',
                        default=os.environ.get('ML_SERVER_BASIC_AUTH_USER', None),
                        help='Basic auth user')

    parser.add_argument('--basic-auth-pass',
                        default=os.environ.get('ML_SERVER_BASIC_AUTH_PASS', None),
                        help='Basic auth pass')

    args = parser.parse_args()

    # setup logging level
    if args.log_level:
        logging.root.setLevel(args.log_level)

    # kwargs from config.json first, command line values override them
    kwargs = get_kwargs_from_config()

    if args.kwargs:
        kwargs.update(
            {key: _coerce_kwarg_value(value) for key, value in args.kwargs}
        )

    if args.check:
        # instantiate once so that construction errors surface before serving
        print('Check "' + NewModel.__name__ + '" instance creation..')
        model = NewModel(**kwargs)

    app = init_app(model_class=NewModel, basic_auth_user=args.basic_auth_user, basic_auth_pass=args.basic_auth_pass)

    app.run(host=args.host, port=args.port, debug=args.debug)

else:
    # for uWSGI use
    app = init_app(model_class=NewModel)
import os
import logging

from pydantic import BaseModel
from typing import Optional, List, Dict, ClassVar
from ultralytics import YOLO

from label_studio_ml.model import LabelStudioMLBase
from label_studio_ml.utils import DATA_UNDEFINED_NAME
from label_studio_sdk._extensions.label_studio_tools.core.utils.io import get_local_path
from label_studio_sdk.label_interface.control_tags import ControlTag
from label_studio_sdk.label_interface import LabelInterface


# show predictions with matplotlib while debugging
DEBUG_PLOT = os.getenv("DEBUG_PLOT", "false").lower() in ("1", "true")
# fallback score threshold when the control tag does not define one
MODEL_SCORE_THRESHOLD = float(os.getenv("MODEL_SCORE_THRESHOLD", 0.5))
# by default models live in "<parent of this package>/models"
DEFAULT_MODEL_ROOT = os.path.join(os.path.dirname(os.path.dirname(__file__)), "models")
MODEL_ROOT = os.getenv("MODEL_ROOT", DEFAULT_MODEL_ROOT)
os.makedirs(MODEL_ROOT, exist_ok=True)
# when enabled, a control tag may point to its own model via `model_path`
ALLOW_CUSTOM_MODEL_PATH = os.getenv("ALLOW_CUSTOM_MODEL_PATH", "true").lower() in (
    "1",
    "true",
)

# Loaded YOLO models, keyed by the path they were requested with
_model_cache = {}
logger = logging.getLogger(__name__)


def get_bool(attr, attr_name, default="false"):
    """Read `attr_name` from the mapping `attr` and interpret it as a boolean.

    The value (or `default` when the key is missing) is lower-cased and
    treated as true when it is one of "1", "true" or "yes".
    """
    raw_value = attr.get(attr_name, default)
    return raw_value.lower() in ("1", "true", "yes")


class ControlModel(BaseModel):
    """
    Base class binding one Label Studio control tag to a YOLO model that produces predictions for it.

    Attributes:
        type (str): Control tag type handled by the subclass, e.g., RectangleLabels, Choices, etc.
        control (ControlTag): The control tag taken from the labeling configuration.
        from_name (str): Name of the control tag.
        to_name (str): Name of the object tag this control points to.
        value (str): Value name of the object tag, used as the key to look up the media path in task data.
        model (YOLO): Loaded YOLO model used for predictions.
        model_path (str): Default model file of the subclass, resolved relative to MODEL_ROOT.
        model_score_threshold (float): Score threshold for predictions.
        label_map (Optional[Dict[str, str]]): Mapping of model labels to Label Studio labels.
        label_studio_ml_backend (LabelStudioMLBase): The ML backend that created this instance.
        project_id (Optional[str]): Project id copied from the ML backend.
    """

    type: ClassVar[str]
    control: ControlTag
    from_name: str
    to_name: str
    value: str
    model: YOLO
    model_path: ClassVar[str]
    model_score_threshold: float = 0.5
    label_map: Optional[Dict[str, str]] = {}
    label_studio_ml_backend: LabelStudioMLBase
    project_id: Optional[str] = None

    def __init__(self, **data):
        super().__init__(**data)

    @classmethod
    def is_control_matched(cls, control) -> bool:
        """Tell whether this class can handle the given control tag.
        Args:
            control (ControlTag): The control tag from the Label Studio Interface.
        """
        raise NotImplementedError("This method should be overridden in derived classes")

    @staticmethod
    def get_from_name_for_label_map(
        label_interface: LabelInterface, target_name: str
    ) -> str:
        """Return the control name used to build the label map; here it is `target_name` unchanged."""
        return target_name

    @classmethod
    def create(cls, mlbackend: LabelStudioMLBase, control: ControlTag):
        """Build an instance of this control model class for a control tag.
        Args:
            mlbackend (LabelStudioMLBase): The ML backend instance.
            control (ControlTag): The control tag from the Label Studio Interface.
        Returns:
            The new instance, or None when the tag sets `model_skip` to a true value.
        """
        from_name = control.name
        to_name = control.to_name[0]
        value = control.objects[0].value_name
        attrs = control.attr

        # tags marked with model_skip are not processed at all
        if get_bool(attrs, "model_skip", "false"):
            logger.info(
                f"Skipping control tag '{control.tag}' with name '{from_name}', model_skip=true found"
            )
            return None

        # threshold lookup order: `model_score_threshold`, then the legacy
        # `score_threshold` attribute, then the global default
        raw_threshold = attrs.get("model_score_threshold")
        if not raw_threshold:
            raw_threshold = attrs.get("score_threshold")
        if not raw_threshold:
            raw_threshold = MODEL_SCORE_THRESHOLD
        model_score_threshold = float(raw_threshold)

        # a custom `model_path` from the tag is honoured only when allowed
        custom_path = attrs.get("model_path") if ALLOW_CUSTOM_MODEL_PATH else None
        model_path = custom_path or cls.model_path

        model = cls.get_cached_model(model_path)
        model_names = model.names.values()  # class names from the model
        # the control used for label mapping may differ from control.name
        label_map_from_name = cls.get_from_name_for_label_map(
            mlbackend.label_interface, from_name
        )
        label_map = mlbackend.build_label_map(label_map_from_name, model_names)

        return cls(
            control=control,
            from_name=from_name,
            to_name=to_name,
            value=value,
            model=model,
            model_score_threshold=model_score_threshold,
            label_map=label_map,
            label_studio_ml_backend=mlbackend,
            project_id=mlbackend.project_id,
        )

    @classmethod
    def load_yolo_model(cls, filename) -> YOLO:
        """Load the YOLO model stored as `filename` under MODEL_ROOT."""
        path = os.path.join(MODEL_ROOT, filename)
        logger.info(f"Loading yolo model: {path}")
        model = YOLO(path)
        logger.info(f"Model {path} names:\n{model.names}")
        return model

    @classmethod
    def get_cached_model(cls, path: str) -> YOLO:
        """Return the model for `path`, loading and caching it on first use."""
        try:
            return _model_cache[path]
        except KeyError:
            loaded = cls.load_yolo_model(path)
            _model_cache[path] = loaded
            return loaded

    def debug_plot(self, image):
        """Display `image` with matplotlib; does nothing unless DEBUG_PLOT is enabled."""
        if not DEBUG_PLOT:
            return

        import matplotlib.pyplot as plt

        plt.figure(figsize=(10, 10))
        # the last axis is reversed before display (presumably BGR -> RGB; confirm)
        plt.imshow(image[..., ::-1])
        plt.axis("off")
        plt.title(self.type)
        plt.show()

    def predict_regions(self, path) -> List[Dict]:
        """Produce prediction regions for a media file; implemented by subclasses.
        Args:
            path (str): Path to the file with media
        """
        raise NotImplementedError("This method should be overridden in derived classes")

    def fit(self, event, data, **kwargs):
        """Training hook; the base implementation only logs a warning and returns False."""
        logger.warning("The fit method is not implemented for this control model")
        return False

    def get_path(self, task):
        """Resolve the media path of a task.

        The task data is read under `self.value`, falling back to
        DATA_UNDEFINED_NAME. An existing local path is returned as is;
        anything else is resolved through `get_local_path`.

        Raises:
            ValueError: If no value is found or the value is not a string.
        """
        task_data = task["data"]
        task_path = task_data.get(self.value) or task_data.get(DATA_UNDEFINED_NAME)
        if task_path is None:
            raise ValueError(
                f"Can't load path using key '{self.value}' from task {task}"
            )
        if not isinstance(task_path, str):
            raise ValueError(f"Path should be a string, but got {task_path}")

        if os.path.exists(task_path):
            path = task_path
        else:
            path = get_local_path(task_path, task_id=task.get("id"))
        logger.debug(f"load_image: {task_path} => {path}")
        return path

    def __str__(self):
        """Return a one-line summary: type, from_name, label map and score threshold."""
        parts = [
            f"{self.type} from_name={self.from_name}",
            f"label_map={self.label_map}",
            f"model_score_threshold={self.model_score_threshold}",
        ]
        return ", ".join(parts)

    class Config:
        arbitrary_types_allowed = True
        protected_namespaces = ("__.*__", "_.*")  # Excludes 'model_'
+ """ + + type = "Choices" + model_path = "yolov8n-cls.pt" + + @classmethod + def is_control_matched(cls, control) -> bool: + # check object tag type + if control.objects[0].tag != "Image": + return False + # support both Choices and Taxonomy because of their similarity + return control.tag in [cls.type, "Taxonomy"] + + def predict_regions(self, path) -> List[Dict]: + results = self.model.predict(path) + self.debug_plot(results[0].plot()) + return self.create_choices(results, path) + + def create_choices(self, results, path): + logger.debug(f"create_choices: {self.from_name}") + mode = self.control.attr.get("choice", "single") + data = results[0].probs.data.cpu().numpy() + + # single + if mode in ["single", "single-radio"]: + # we must keep data items that matches label_map only, because we need to search among label_map only + indexes = [ + i for i, name in self.model.names.items() if name in self.label_map + ] + data = data[indexes] + model_names = [self.model.names[i] for i in indexes] + # find the best choice + index = np.argmax(data) + probs = [data[index]] + names = [model_names[index]] + # multi + else: + # get indexes of data where data >= self.model_score_threshold + indexes = np.where(data >= self.model_score_threshold) + probs = data[indexes].tolist() + names = [self.model.names[int(i)] for i in indexes[0]] + + if not probs: + logger.debug("No choices found") + return [] + + score = np.mean(probs) + logger.debug( + "----------------------\n" + f"task id > {path}\n" + f"control: {self.control}\n" + f"probs > {probs}\n" + f"score > {score}\n" + f"names > {names}\n" + ) + + if score < self.model_score_threshold: + logger.debug(f"Score is too low for single choice: {names[0]} = {probs[0]}") + return [] + + # map to Label Studio labels + output_labels = [ + self.label_map[name] for name in names if name in self.label_map + ] + + # add new region with rectangle + return [ + { + "from_name": self.from_name, + "to_name": self.to_name, + "type": "choices", + 
"value": {"choices": output_labels}, + "score": float(score), + } + ] + + +# pre-load and cache default model at startup +ChoicesModel.get_cached_model(ChoicesModel.model_path) diff --git a/multi_object_tracking/yolo_sam/control_models/keypoint_labels.py b/multi_object_tracking/yolo_sam/control_models/keypoint_labels.py new file mode 100644 index 000000000..86199d4d1 --- /dev/null +++ b/multi_object_tracking/yolo_sam/control_models/keypoint_labels.py @@ -0,0 +1,173 @@ +import logging +from control_models.base import ControlModel, get_bool +from typing import List, Dict + +logger = logging.getLogger(__name__) + + +class KeypointLabelsModel(ControlModel): + """ + Class representing a KeypointLabels control tag for YOLO model. + """ + + type = "KeyPointLabels" + model_path = ( + "yolov8n-pose.pt" # Adjust the model path to your keypoint detection model + ) + add_bboxes: bool = True + point_size: float = 1 + point_threshold: float = 0 + point_map: Dict = {} + + def __init__(self, **data): + super().__init__(**data) + + self.add_bboxes = get_bool(self.control.attr, "model_add_bboxes", "true") + self.point_size = float(self.control.attr.get("model_point_size", 1)) + self.point_threshold = float(self.control.attr.get("model_point_threshold", 0)) + self.point_map = self.build_point_mapping() + + @classmethod + def is_control_matched(cls, control) -> bool: + # Check object tag type + if control.objects[0].tag != "Image": + return False + return control.tag == cls.type + + def build_point_mapping(self): + """Build a mapping between points and Label Studio labels, e.g. +