Commit 3f08a65
Merge pull request #929 from roboflow/new-model-ids-2
Enable new type of models - Roboflow instant models
grzegorz-roboflow authored Jan 16, 2025
2 parents 779cc29 + 5b9663d commit 3f08a65
Showing 14 changed files with 212 additions and 48 deletions.
inference/core/entities/types.py (3 changes: 2 additions & 1 deletion)

@@ -1,5 +1,6 @@
 DatasetID = str
-VersionID = str
+ModelID = str
+VersionID = int
 TaskType = str
 ModelType = str
 WorkspaceID = str
inference/core/managers/active_learning.py (13 changes: 13 additions & 0 deletions)

@@ -11,6 +11,7 @@
 from inference.core.env import DISABLE_PREPROC_AUTO_ORIENT
 from inference.core.managers.base import ModelManager
 from inference.core.registries.base import ModelRegistry
+from inference.core.utils.roboflow import get_model_id_chunks
 from inference.models.aliases import resolve_roboflow_model_alias

 ACTIVE_LEARNING_ELIGIBLE_PARAM = "active_learning_eligible"
@@ -39,10 +40,14 @@ async def infer_from_request(
         active_learning_disabled_for_request = getattr(
             request, DISABLE_ACTIVE_LEARNING_PARAM, False
         )
+        # TODO: active learning is disabled for instant models; to be enabled in the future
+        _, version_id = get_model_id_chunks(model_id=model_id)
+        roboflow_instant_model = version_id is None
         if (
             not active_learning_eligible
             or active_learning_disabled_for_request
             or request.api_key is None
+            or roboflow_instant_model
         ):
             return prediction
         self.register(prediction=prediction, model_id=model_id, request=request)
@@ -58,10 +63,14 @@ def infer_from_request_sync(
         active_learning_disabled_for_request = getattr(
             request, DISABLE_ACTIVE_LEARNING_PARAM, False
         )
+        # TODO: active learning is disabled for instant models; to be enabled in the future
+        _, version_id = get_model_id_chunks(model_id=model_id)
+        roboflow_instant_model = version_id is None
         if (
             not active_learning_eligible
             or active_learning_disabled_for_request
             or request.api_key is None
+            or roboflow_instant_model
         ):
             return prediction
         self.register(prediction=prediction, model_id=model_id, request=request)
@@ -196,10 +205,14 @@ def infer_from_request_sync(
         prediction = super().infer_from_request_sync(
             model_id=model_id, request=request, **kwargs
         )
+        # TODO: active learning is disabled for instant models; to be enabled in the future
+        _, version_id = get_model_id_chunks(model_id=model_id)
+        roboflow_instant_model = version_id is None
         if (
             not active_learning_eligible
             or active_learning_disabled_for_request
             or request.api_key is None
+            or roboflow_instant_model
         ):
             return prediction
         if BACKGROUND_TASKS_PARAM not in kwargs:
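The gate added to all three code paths is the same three-line pattern, so it reads as one predicate: models whose ID carries no numeric version (Roboflow instant models) skip active learning registration. A minimal standalone sketch of that check, assuming only the get_model_id_chunks helper from this commit (the model IDs below are hypothetical):

from inference.core.utils.roboflow import get_model_id_chunks

def is_roboflow_instant_model(model_id: str) -> bool:
    # Instant model IDs have no numeric version chunk, so the helper
    # returns None for the version component.
    _, version_id = get_model_id_chunks(model_id=model_id)
    return version_id is None

assert is_roboflow_instant_model("my-workspace/my-instant-model") is True
assert is_roboflow_instant_model("my-dataset/3") is False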
inference/core/models/roboflow.py (3 changes: 2 additions & 1 deletion)

@@ -59,6 +59,7 @@
 from inference.core.utils.image_utils import load_image
 from inference.core.utils.onnx import get_onnxruntime_execution_providers
 from inference.core.utils.preprocess import letterbox_image, prepare
+from inference.core.utils.roboflow import get_model_id_chunks
 from inference.core.utils.visualisation import draw_detection_predictions
 from inference.models.aliases import resolve_roboflow_model_alias

@@ -116,7 +117,7 @@ def __init__(
         self.metrics = {"num_inferences": 0, "avg_inference_time": 0.0}
         self.api_key = api_key if api_key else API_KEY
         model_id = resolve_roboflow_model_alias(model_id=model_id)
-        self.dataset_id, self.version_id = model_id.split("/")
+        self.dataset_id, self.version_id = get_model_id_chunks(model_id=model_id)
         self.endpoint = model_id
         self.device_id = GLOBAL_DEVICE_ID
         self.cache_dir = os.path.join(cache_dir_root, self.endpoint)
inference/core/models/stubs.py (3 changes: 2 additions & 1 deletion)

@@ -10,14 +10,15 @@
 from inference.core.models.base import Model
 from inference.core.models.types import PreprocessReturnMetadata
 from inference.core.utils.image_utils import encode_image_to_jpeg_bytes
+from inference.core.utils.roboflow import get_model_id_chunks


 class ModelStub(Model):
     def __init__(self, model_id: str, api_key: str):
         super().__init__()
         self.model_id = model_id
         self.api_key = api_key
-        self.dataset_id, self.version_id = model_id.split("/")
+        self.dataset_id, self.version_id = get_model_id_chunks(model_id=model_id)
         self.metrics = {"num_inferences": 0, "avg_inference_time": 0.0}
         initialise_cache(model_id=model_id)
inference/core/registries/roboflow.py (56 changes: 38 additions & 18 deletions)

@@ -3,7 +3,13 @@
 from inference.core.cache import cache
 from inference.core.devices.utils import GLOBAL_DEVICE_ID
-from inference.core.entities.types import DatasetID, ModelType, TaskType, VersionID
+from inference.core.entities.types import (
+    DatasetID,
+    ModelID,
+    ModelType,
+    TaskType,
+    VersionID,
+)
 from inference.core.env import LAMBDA, MODEL_CACHE_DIR
 from inference.core.exceptions import (
     MissingApiKeyError,
@@ -19,6 +25,7 @@
     PROJECT_TASK_TYPE_KEY,
     ModelEndpointType,
     get_roboflow_dataset_type,
+    get_roboflow_instant_model_data,
     get_roboflow_model_data,
     get_roboflow_workspace,
 )
@@ -49,7 +56,7 @@ class RoboflowModelRegistry(ModelRegistry):
     then returns a model class based on the model type.
     """

-    def get_model(self, model_id: str, api_key: str) -> Model:
+    def get_model(self, model_id: ModelID, api_key: str) -> Model:
         """Returns the model class based on the given model id and API key.

         Args:
@@ -70,7 +77,7 @@ def get_model(self, model_id: str, api_key: str) -> Model:


 def get_model_type(
-    model_id: str,
+    model_id: ModelID,
     api_key: Optional[str] = None,
 ) -> Tuple[TaskType, ModelType]:
     """Retrieves the model type based on the given model ID and API key.
@@ -115,16 +122,24 @@ def get_model_type(
             model_type=model_type,
         )
         return project_task_type, model_type
-    api_data = get_roboflow_model_data(
-        api_key=api_key,
-        model_id=model_id,
-        endpoint_type=ModelEndpointType.ORT,
-        device_id=GLOBAL_DEVICE_ID,
-    ).get("ort")
+
+    if version_id is not None:
+        api_data = get_roboflow_model_data(
+            api_key=api_key,
+            model_id=model_id,
+            endpoint_type=ModelEndpointType.ORT,
+            device_id=GLOBAL_DEVICE_ID,
+        ).get("ort")
+        project_task_type = api_data.get("type", "object-detection")
+    else:
+        api_data = get_roboflow_instant_model_data(
+            api_key=api_key,
+            model_id=model_id,
+        )
+        project_task_type = api_data.get("taskType", "object-detection")
     if api_data is None:
         raise ModelArtefactError("Error loading model artifacts from Roboflow API.")
     # some older projects do not have type field - hence defaulting
-    project_task_type = api_data.get("type", "object-detection")
     model_type = api_data.get("modelType")
     if model_type is None or model_type == "ort":
         # some very old model versions do not have modelType reported - and API respond in a generic way -
@@ -143,7 +158,8 @@


 def get_model_metadata_from_cache(
-    dataset_id: str, version_id: str
+    dataset_id: Union[DatasetID, ModelID],
+    version_id: Optional[VersionID],
 ) -> Optional[Tuple[TaskType, ModelType]]:
     if LAMBDA:
         return _get_model_metadata_from_cache(
@@ -158,7 +174,7 @@ def get_model_metadata_from_cache(


 def _get_model_metadata_from_cache(
-    dataset_id: str, version_id: str
+    dataset_id: Union[DatasetID, ModelID], version_id: Optional[VersionID]
 ) -> Optional[Tuple[TaskType, ModelType]]:
     model_type_cache_path = construct_model_type_cache_path(
         dataset_id=dataset_id, version_id=version_id
@@ -193,8 +209,8 @@ def model_metadata_content_is_invalid(content: Optional[Union[list, dict]]) -> bool:


 def save_model_metadata_in_cache(
-    dataset_id: DatasetID,
-    version_id: VersionID,
+    dataset_id: Union[DatasetID, ModelID],
+    version_id: Optional[VersionID],
     project_task_type: TaskType,
     model_type: ModelType,
 ) -> None:
@@ -219,8 +235,8 @@ def save_model_metadata_in_cache(


 def _save_model_metadata_in_cache(
-    dataset_id: DatasetID,
-    version_id: VersionID,
+    dataset_id: Union[DatasetID, ModelID],
+    version_id: Optional[VersionID],
     project_task_type: TaskType,
     model_type: ModelType,
 ) -> None:
@@ -236,6 +252,10 @@ def _save_model_metadata_in_cache(
     )


-def construct_model_type_cache_path(dataset_id: str, version_id: str) -> str:
-    cache_dir = os.path.join(MODEL_CACHE_DIR, dataset_id, version_id)
+def construct_model_type_cache_path(
+    dataset_id: Union[DatasetID, ModelID], version_id: Optional[VersionID]
+) -> str:
+    cache_dir = os.path.join(
+        MODEL_CACHE_DIR, dataset_id, version_id if version_id else ""
+    )
     return os.path.join(cache_dir, "model_type.json")
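To make the cache layout change concrete, here is a small sketch of what the reworked construct_model_type_cache_path produces for each ID shape. It assumes MODEL_CACHE_DIR is "/tmp/cache"; the IDs are hypothetical:

# Versioned model: dataset id plus numeric version nest as directories.
construct_model_type_cache_path(dataset_id="my-dataset", version_id="3")
# -> "/tmp/cache/my-dataset/3/model_type.json"

# Instant model: the full model id is the directory; the version segment
# collapses to an empty string, so no extra path level is added.
construct_model_type_cache_path(
    dataset_id="my-workspace/my-instant-model", version_id=None
)
# -> "/tmp/cache/my-workspace/my-instant-model/model_type.json"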
inference/core/roboflow_api.py (34 changes: 34 additions & 0 deletions)

@@ -14,6 +14,7 @@
 from inference.core.cache.base import BaseCache
 from inference.core.entities.types import (
     DatasetID,
+    ModelID,
     ModelType,
     TaskType,
     VersionID,
@@ -246,6 +247,39 @@ def get_roboflow_model_data(
     return api_data


+@wrap_roboflow_api_errors()
+def get_roboflow_instant_model_data(
+    api_key: str,
+    model_id: ModelID,
+    cache_prefix: str = "roboflow_api_data",
+) -> dict:
+    api_data_cache_key = f"{cache_prefix}:{model_id}"
+    api_data = cache.get(api_data_cache_key)
+    if api_data is not None:
+        logger.debug(f"Loaded model data from cache with key: {api_data_cache_key}.")
+        return api_data
+    else:
+        params = [
+            ("model", model_id),
+        ]
+        if api_key is not None:
+            params.append(("api_key", api_key))
+        api_url = _add_params_to_url(
+            url=f"{API_BASE_URL}/getWeights",
+            params=params,
+        )
+        api_data = _get_from_url(url=api_url)
+        cache.set(
+            api_data_cache_key,
+            api_data,
+            expire=10,
+        )
+        logger.debug(
+            f"Loaded model data from Roboflow API and saved to cache with key: {api_data_cache_key}."
+        )
+        return api_data
+
+
 @wrap_roboflow_api_errors()
 def get_roboflow_base_lora(
     api_key: str, repo: str, revision: str, device_id: str
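A short usage sketch of the new helper; it assumes a valid API key and network access, and the model ID is a hypothetical placeholder:

from inference.core.roboflow_api import get_roboflow_instant_model_data

# Resolves weights metadata via {API_BASE_URL}/getWeights?model=<model_id>&api_key=...;
# the response is cached for 10 seconds under "roboflow_api_data:<model_id>".
api_data = get_roboflow_instant_model_data(
    api_key="YOUR_ROBOFLOW_API_KEY",  # placeholder
    model_id="my-workspace/my-instant-model",  # hypothetical instant model ID
)
task_type = api_data.get("taskType", "object-detection")  # same fallback as get_model_type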
inference/core/utils/roboflow.py (29 changes: 25 additions & 4 deletions)

@@ -1,11 +1,32 @@
-from typing import Tuple
+from typing import Optional, Tuple, Union

-from inference.core.entities.types import DatasetID, VersionID
+from inference.core.entities.types import DatasetID, ModelID, VersionID
 from inference.core.exceptions import InvalidModelIDError


-def get_model_id_chunks(model_id: str) -> Tuple[DatasetID, VersionID]:
+def get_model_id_chunks(
+    model_id: str,
+) -> Tuple[Union[DatasetID, ModelID], Optional[VersionID]]:
     model_id_chunks = model_id.split("/")
     if len(model_id_chunks) != 2:
         raise InvalidModelIDError(f"Model ID: `{model_id}` is invalid.")
-    return model_id_chunks[0], model_id_chunks[1]
+    dataset_id, version_id = model_id_chunks[0], model_id_chunks[1]
+    if dataset_id.lower() in {
+        "clip",
+        "cogvlm",
+        "doctr",
+        "doctr_rec",
+        "doctr_det",
+        "gaze",
+        "grounding_dino",
+        "sam",
+        "sam2",
+        "owlv2",
+        "trocr",
+        "yolo_world",
+    }:
+        return dataset_id, version_id
+    try:
+        return dataset_id, str(int(version_id))
+    except Exception:
+        return model_id, None
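The branching above is easiest to see through its outputs. A few illustrative calls, following the code as committed (the IDs are hypothetical except the foundation-model prefix, which comes from the set above):

from inference.core.utils.roboflow import get_model_id_chunks
from inference.core.exceptions import InvalidModelIDError

get_model_id_chunks("my-dataset/3")
# -> ("my-dataset", "3"): numeric version, classic versioned model

get_model_id_chunks("sam2/hiera_large")
# -> ("sam2", "hiera_large"): foundation models bypass the numeric check

get_model_id_chunks("my-workspace/my-instant-model")
# -> ("my-workspace/my-instant-model", None): instant model, whole ID kept, no version

get_model_id_chunks("no-slash-here")
# raises InvalidModelIDError: the ID must split into exactly two chunks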
(changed file, path not shown in the loaded view)

@@ -1,3 +1,4 @@
+import threading
 from typing import List, Literal, Optional, Type, Union

 import paho.mqtt.client as mqtt
@@ -93,10 +94,23 @@ def describe_outputs(cls) -> List[OutputDefinition]:
             OutputDefinition(name="message", kind=[STRING_KIND]),
         ]

+    @classmethod
+    def get_execution_engine_compatibility(cls) -> Optional[str]:
+        return ">=1.3.0,<2.0.0"
+

 class MQTTWriterSinkBlockV1(WorkflowBlock):
     def __init__(self):
         self.mqtt_client: Optional[mqtt.Client] = None
+        self._connected = threading.Event()
+
+    def __del__(self):
+        try:
+            if self.mqtt_client is not None:
+                self.mqtt_client.disconnect()
+                self.mqtt_client.loop_stop()
+        except Exception as e:
+            logger.error("Failed to disconnect MQTT client: %s", e)

     @classmethod
     def get_manifest(cls) -> Type[WorkflowBlockManifest]:
@@ -125,7 +139,12 @@ def run(
             )
         try:
             # TODO: blocking, consider adding fire_and_forget like in OPC writer
+            print("Connecting")
             self.mqtt_client.connect(host, port)
+            self.mqtt_client.loop_start()
+
+            if not self._connected.wait(timeout=timeout):
+                raise Exception("Connection timeout")
         except Exception as e:
             logger.error("Failed to connect to MQTT broker: %s", e)
             return {
@@ -136,7 +155,10 @@ def run(
         if not self.mqtt_client.is_connected():
             try:
                 # TODO: blocking
+                print("Reconnecting")
                 self.mqtt_client.reconnect()
+                if not self._connected.wait(timeout=timeout):
+                    raise Exception("Connection timeout")
             except Exception as e:
                 logger.error("Failed to connect to MQTT broker: %s", e)
                 return {
@@ -163,8 +185,10 @@ def run(

     def mqtt_on_connect(self, client, userdata, flags, reason_code, properties=None):
         logger.info("Connected with result code %s", reason_code)
+        self._connected.set()

     def mqtt_on_connect_fail(
         self, client, userdata, flags, reason_code, properties=None
     ):
         logger.error(f"Failed to connect with result code %s", reason_code)
+        self._connected.clear()
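The threading.Event handshake added to this block is a standard paho-mqtt idiom: connect() and reconnect() only begin the network handshake, and loop_start() drives it on a background thread, so the caller must wait for the CONNACK callback before publishing. A minimal standalone sketch of the same pattern, assuming paho-mqtt 2.x; the broker host, port, topic, and timeout are placeholders, not values from this commit:

import threading

import paho.mqtt.client as mqtt

connected = threading.Event()

def on_connect(client, userdata, flags, reason_code, properties=None):
    # Fired by the network loop once the broker acknowledges the connection.
    connected.set()

client = mqtt.Client(mqtt.CallbackAPIVersion.VERSION2)
client.on_connect = on_connect
client.connect("broker.example.com", 1883)  # placeholder broker address
client.loop_start()  # background thread completes the handshake

# Mirror the wait added in run(): block until CONNACK or give up.
if not connected.wait(timeout=5.0):
    raise TimeoutError("MQTT connection timeout")
client.publish("example/topic", "hello")
client.loop_stop()
client.disconnect()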
(The remaining changed files in this commit were not loaded in this view.)