Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft: Enable new model id's #886

Draft
wants to merge 14 commits into
base: main
Choose a base branch
from
1 change: 1 addition & 0 deletions inference/core/entities/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
TaskType = str
ModelType = str
WorkspaceID = str
ModelID = str
27 changes: 22 additions & 5 deletions inference/core/models/roboflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
ModelEndpointType,
get_from_url,
get_roboflow_model_data,
get_roboflow_workspace,
)
from inference.core.utils.image_utils import load_image
from inference.core.utils.onnx import get_onnxruntime_execution_providers
Expand Down Expand Up @@ -116,7 +117,13 @@ def __init__(
self.metrics = {"num_inferences": 0, "avg_inference_time": 0.0}
self.api_key = api_key if api_key else API_KEY
model_id = resolve_roboflow_model_alias(model_id=model_id)
self.dataset_id, self.version_id = model_id.split("/")
# TODO:
# Is this really all we had to do here?, think we don't even need it?
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To answer this question, it seems like the version ID is primarily used in models like SAM where there are multiple versions of the foundation model that can be called. Everywhere else, it's mainly used for providing clear debug info.

if "/" in model_id:
self.dataset_id, self.version_id = model_id.split("/")
else:
self.model_id = model_id
# Model ID is only unique for a workspace
self.endpoint = model_id
self.device_id = GLOBAL_DEVICE_ID
self.cache_dir = os.path.join(cache_dir_root, self.endpoint)
Expand Down Expand Up @@ -274,10 +281,19 @@ def download_model_artifacts_from_roboflow_api(self) -> None:
"Could not find `model` key in roboflow API model description response."
)
if "environment" not in api_data:
raise ModelArtefactError(
"Could not find `environment` key in roboflow API model description response."
)
environment = get_from_url(api_data["environment"])
# Create default environment if not provided
environment = {
"PREPROCESSING": api_data.get("preprocessing", {}),
"MULTICLASS": api_data.get("multilabel", False),
#don't think we actually need this
"MODEL_NAME": api_data.get("modelName", ""),

# ClASS_MAP might be the only other thing that we would need
# "CLASS_MAP": api_data.get("classes", {}),
}
else:
# TODO: do we need to load the environment from the url or can we safely remove?
environment = get_from_url(api_data["environment"])
model_weights_response = get_from_url(api_data["model"], json_response=False)
save_bytes_in_cache(
content=model_weights_response.content,
Expand Down Expand Up @@ -308,6 +324,7 @@ def load_model_artifacts_from_cache(self) -> None:
model_id=self.endpoint,
object_pairs_hook=OrderedDict,
)

if "class_names.txt" in infer_bucket_files:
self.class_names = load_text_file_from_cache(
file="class_names.txt",
Expand Down
99 changes: 52 additions & 47 deletions inference/core/registries/roboflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from inference.core.cache import cache
from inference.core.devices.utils import GLOBAL_DEVICE_ID
from inference.core.entities.types import DatasetID, ModelType, TaskType, VersionID
from inference.core.entities.types import ModelType, TaskType
from inference.core.env import LAMBDA, MODEL_CACHE_DIR
from inference.core.exceptions import (
MissingApiKeyError,
Expand Down Expand Up @@ -90,41 +90,51 @@ def get_model_type(
"""
model_id = resolve_roboflow_model_alias(model_id=model_id)
dataset_id, version_id = get_model_id_chunks(model_id=model_id)
lock_key, cache_path = determine_cache_paths(dataset_or_model_id=dataset_id, version_id=version_id)

if dataset_id in GENERIC_MODELS:
logger.debug(f"Loading generic model: {dataset_id}.")
return GENERIC_MODELS[dataset_id]

cached_metadata = get_model_metadata_from_cache(
dataset_id=dataset_id, version_id=version_id
cache_path=cache_path, lock_key=lock_key
)
if cached_metadata is not None:
return cached_metadata[0], cached_metadata[1]


# THis path will never be executed for a model ID
if version_id == STUB_VERSION_ID:
if api_key is None:
raise MissingApiKeyError(
"Stub model version provided but no API key was provided. API key is required to load stub models."
)
workspace_id = get_roboflow_workspace(api_key=api_key)

project_task_type = get_roboflow_dataset_type(
api_key=api_key, workspace_id=workspace_id, dataset_id=dataset_id
)
model_type = "stub"
save_model_metadata_in_cache(
dataset_id=dataset_id,
version_id=version_id,
cache_path=cache_path,
lock_key=lock_key,
project_task_type=project_task_type,
model_type=model_type,
# TODO: do we need to save the workspace_id here/for the cache path to be unique?
)
return project_task_type, model_type

api_data = get_roboflow_model_data(
api_key=api_key,
model_id=model_id,
endpoint_type=ModelEndpointType.ORT,
device_id=GLOBAL_DEVICE_ID,
).get("ort")

if api_data is None:
raise ModelArtefactError("Error loading model artifacts from Roboflow API.")
# some older projects do not have type field - hence defaulting
project_task_type = api_data.get("type", "object-detection")
project_task_type = api_data.get("taskType", "object-detection")
model_type = api_data.get("modelType")
if model_type is None or model_type == "ort":
# some very old model versions do not have modelType reported - and API respond in a generic way -
Expand All @@ -133,46 +143,46 @@ def get_model_type(
if model_type is None or project_task_type is None:
raise ModelArtefactError("Error loading model artifacts from Roboflow API.")
save_model_metadata_in_cache(
dataset_id=dataset_id,
version_id=version_id,
cache_path=cache_path,
lock_key=lock_key,
project_task_type=project_task_type,
model_type=model_type,
)

return project_task_type, model_type

def determine_cache_paths(dataset_or_model_id: str, version_id: Optional[str]) -> Tuple[str, str]:
if dataset_or_model_id and version_id:
# It's a dataset/version ID
lock_key = f"lock:metadata:dataset:{dataset_or_model_id}:{version_id}"
cache_path = construct_dataset_version_cache_path(dataset_or_model_id, version_id)
else:
# It's a model ID
lock_key = f"lock:metadata:model:{dataset_or_model_id}"
cache_path = construct_model_id_cache_path(dataset_or_model_id)

return lock_key, cache_path

def get_model_metadata_from_cache(
dataset_id: str, version_id: str
cache_path: str,
lock_key: str
) -> Optional[Tuple[TaskType, ModelType]]:
if LAMBDA:
return _get_model_metadata_from_cache(
dataset_id=dataset_id, version_id=version_id
)
with cache.lock(
f"lock:metadata:{dataset_id}:{version_id}", expire=CACHE_METADATA_LOCK_TIMEOUT
):
return _get_model_metadata_from_cache(
dataset_id=dataset_id, version_id=version_id
)
return _get_model_metadata_from_cache(cache_path=cache_path)

with cache.lock(lock_key, expire=CACHE_METADATA_LOCK_TIMEOUT):
return _get_model_metadata_from_cache(cache_path=cache_path)


def _get_model_metadata_from_cache(
dataset_id: str, version_id: str
) -> Optional[Tuple[TaskType, ModelType]]:
model_type_cache_path = construct_model_type_cache_path(
dataset_id=dataset_id, version_id=version_id
)
if not os.path.isfile(model_type_cache_path):
def _get_model_metadata_from_cache(cache_path: str) -> Optional[Tuple[TaskType, ModelType]]:
if not os.path.isfile(cache_path):
return None
try:
model_metadata = read_json(path=model_type_cache_path)
model_metadata = read_json(path=cache_path)
if model_metadata_content_is_invalid(content=model_metadata):
return None
return model_metadata[PROJECT_TASK_TYPE_KEY], model_metadata[MODEL_TYPE_KEY]
except ValueError as e:
logger.warning(
f"Could not load model description from cache under path: {model_type_cache_path} - decoding issue: {e}."
f"Could not load model description from cache under path: {cache_path} - decoding issue: {e}."
)
return None

Expand All @@ -193,49 +203,44 @@ def model_metadata_content_is_invalid(content: Optional[Union[list, dict]]) -> b


def save_model_metadata_in_cache(
dataset_id: DatasetID,
version_id: VersionID,
cache_path: str,
lock_key: str,
project_task_type: TaskType,
model_type: ModelType,
) -> None:
if LAMBDA:
_save_model_metadata_in_cache(
dataset_id=dataset_id,
version_id=version_id,
cache_path=cache_path,
project_task_type=project_task_type,
model_type=model_type,
)
return None
with cache.lock(
f"lock:metadata:{dataset_id}:{version_id}", expire=CACHE_METADATA_LOCK_TIMEOUT
):

with cache.lock(lock_key, expire=CACHE_METADATA_LOCK_TIMEOUT):
_save_model_metadata_in_cache(
dataset_id=dataset_id,
version_id=version_id,
cache_path=cache_path,
project_task_type=project_task_type,
model_type=model_type,
)
return None


def _save_model_metadata_in_cache(
dataset_id: DatasetID,
version_id: VersionID,
cache_path: str,
project_task_type: TaskType,
model_type: ModelType,
) -> None:
model_type_cache_path = construct_model_type_cache_path(
dataset_id=dataset_id, version_id=version_id
)
metadata = {
PROJECT_TASK_TYPE_KEY: project_task_type,
MODEL_TYPE_KEY: model_type,
}
dump_json(
path=model_type_cache_path, content=metadata, allow_override=True, indent=4
path=cache_path, content=metadata, allow_override=True, indent=4
)

def construct_model_id_cache_path(model_id: str) -> str:
"""Constructs the cache path for a given model ID."""
return os.path.join(MODEL_CACHE_DIR, "models", model_id, "model_type.json")

def construct_model_type_cache_path(dataset_id: str, version_id: str) -> str:
cache_dir = os.path.join(MODEL_CACHE_DIR, dataset_id, version_id)
return os.path.join(cache_dir, "model_type.json")
def construct_dataset_version_cache_path(dataset_id: str, version_id: str) -> str:
"""Constructs the cache path for a given dataset ID and version ID."""
return os.path.join(MODEL_CACHE_DIR, dataset_id, version_id, "model_type.json")
8 changes: 7 additions & 1 deletion inference/core/roboflow_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,13 +225,17 @@ def get_roboflow_model_data(
("nocache", "true"),
("device", device_id),
("dynamic", "true"),
("type", endpoint_type.value),
("model", model_id),
]
if api_key is not None:
params.append(("api_key", api_key))
api_url = _add_params_to_url(
url=f"{API_BASE_URL}/{endpoint_type.value}/{model_id}",
url=f"{API_BASE_URL}/getWeights",
params=params,
)
print("api_url", api_url)
Copy link
Contributor

@shantanubala shantanubala Jan 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
print("api_url", api_url)

Can delete this before merge.


api_data = _get_from_url(url=api_url)
cache.set(
api_data_cache_key,
Expand Down Expand Up @@ -596,7 +600,9 @@ def get_from_url(

def _get_from_url(url: str, json_response: bool = True) -> Union[Response, dict]:
response = requests.get(wrap_url(url))

api_key_safe_raise_for_status(response=response)

if json_response:
return response.json()
return response
Expand Down
30 changes: 27 additions & 3 deletions inference/core/utils/roboflow.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,35 @@
from typing import Tuple
from typing import Optional, Tuple, Union

from inference.core.entities.types import DatasetID, VersionID
from inference.core.exceptions import InvalidModelIDError


def get_model_id_chunks(model_id: str) -> Tuple[DatasetID, VersionID]:
def get_model_id_chunks(
model_id: str,
) -> Union[Tuple[DatasetID, VersionID], Tuple[str, None]]:
"""Parse a model ID into its components.

Args:
model_id (str): The model identifier, either in format "dataset/version"
or a plain string for the new model IDs

Returns:
Union[Tuple[DatasetID, VersionID], Tuple[str, None]]:
For traditional IDs: (dataset_id, version_id)
For new string IDs: (model_id, None)

Raises:
InvalidModelIDError: If traditional model ID format is invalid
"""
if "/" not in model_id:
# Handle new style model IDs that are just strings
return model_id, None

# Handle traditional dataset/version model IDs
model_id_chunks = model_id.split("/")
if len(model_id_chunks) != 2:
raise InvalidModelIDError(f"Model ID: `{model_id}` is invalid.")
raise InvalidModelIDError(
f"Model ID: `{model_id}` is invalid. Expected format: 'dataset/version' or 'model_name'"
)

return model_id_chunks[0], model_id_chunks[1]
1 change: 1 addition & 0 deletions inference/models/transformers/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
get_from_url,
get_roboflow_base_lora,
get_roboflow_model_data,
get_roboflow_workspace,
)
from inference.core.utils.image_utils import load_image_rgb

Expand Down
Loading
Loading