diff --git a/docs/my-website/docs/proxy/jwt_auth_arch.md b/docs/my-website/docs/proxy/jwt_auth_arch.md
new file mode 100644
index 000000000000..e48fa71f8be5
--- /dev/null
+++ b/docs/my-website/docs/proxy/jwt_auth_arch.md
@@ -0,0 +1,116 @@
+import Image from '@theme/IdealImage';
+import Tabs from '@theme/Tabs';
+import TabItem from '@theme/TabItem';
+
+# Control Model Access with SSO (Azure AD/Keycloak/etc.)
+
+:::info
+
+✨ JWT Auth is on LiteLLM Enterprise
+
+[Enterprise Pricing](https://www.litellm.ai/#pricing)
+
+[Get free 7-day trial key](https://www.litellm.ai/#trial)
+
+:::
+
+<Image img={require('../../img/control_model_access_jwt.png')} />
+
+## Example Token
+
+<Tabs>
+<TabItem value="azure_ad" label="Azure AD">
+
+```bash
+{
+ "sub": "1234567890",
+ "name": "John Doe",
+ "email": "john.doe@example.com",
+ "roles": ["basic_user"] # 👈 ROLE
+}
+```
+
+</TabItem>
+<TabItem value="keycloak" label="Keycloak">
+```bash
+{
+ "sub": "1234567890",
+ "name": "John Doe",
+ "email": "john.doe@example.com",
+ "resource_access": {
+ "litellm-test-client-id": {
+ "roles": ["basic_user"] # 👈 ROLE
+ }
+ }
+}
+```
+
+</TabItem>
+</Tabs>
+## Proxy Configuration
+
+<Tabs>
+<TabItem value="azure_ad" label="Azure AD">
+
+```yaml
+general_settings:
+ enable_jwt_auth: True
+ litellm_jwtauth:
+ user_roles_jwt_field: "roles" # the field in the JWT that contains the roles
+ user_allowed_roles: ["basic_user"] # roles that map to an 'internal_user' role on LiteLLM
+ enforce_rbac: true # if true, will check if the user has the correct role to access the model
+
+ role_permissions: # control what models are allowed for each role
+ - role: internal_user
+ models: ["anthropic-claude"]
+
+model_list:
+  - model_name: anthropic-claude
+ litellm_params:
+ model: claude-3-5-haiku-20241022
+  - model_name: openai-gpt-4o
+ litellm_params:
+ model: gpt-4o
+```
+
+</TabItem>
+<TabItem value="keycloak" label="Keycloak">
+
+```yaml
+general_settings:
+ enable_jwt_auth: True
+ litellm_jwtauth:
+ user_roles_jwt_field: "resource_access.litellm-test-client-id.roles" # the field in the JWT that contains the roles
+ user_allowed_roles: ["basic_user"] # roles that map to an 'internal_user' role on LiteLLM
+ enforce_rbac: true # if true, will check if the user has the correct role to access the model
+
+ role_permissions: # control what models are allowed for each role
+ - role: internal_user
+ models: ["anthropic-claude"]
+
+model_list:
+  - model_name: anthropic-claude
+ litellm_params:
+ model: claude-3-5-haiku-20241022
+  - model_name: openai-gpt-4o
+ litellm_params:
+ model: gpt-4o
+```
+
+</TabItem>
+</Tabs>
+
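+Note: the Keycloak variant reaches the nested role claim with dot notation in `user_roles_jwt_field`. As a rough illustration (a simplified stand-in for the proxy's internal `get_nested_value` helper, not the exact implementation), the lookup walks the decoded token like this:
+
+```python
+# Simplified illustration of dot-notation lookup on a decoded JWT payload.
+from typing import Any
+
+
+def resolve(payload: dict, path: str, default: Any = None) -> Any:
+    current: Any = payload
+    for part in path.split("."):
+        try:
+            current = current[part]
+        except (KeyError, TypeError):
+            return default
+    return current
+
+
+token = {"resource_access": {"litellm-test-client-id": {"roles": ["basic_user"]}}}
+print(resolve(token, "resource_access.litellm-test-client-id.roles"))
+# ['basic_user'] -> compared against `user_allowed_roles`
+```
+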
+## How it works
+
+1. Set the `JWT_PUBLIC_KEY_URL` environment variable - this is the public keys endpoint of your OpenID provider. For Azure AD it's `https://login.microsoftonline.com/{tenant_id}/discovery/v2.0/keys`. For Keycloak it's `{keycloak_base_url}/realms/{your-realm}/protocol/openid-connect/certs`.
+
+2. Map JWT roles to LiteLLM roles - done via `user_roles_jwt_field` and `user_allowed_roles`.
+   - Currently only `internal_user` is supported for role mapping.
+3. Specify model access:
+   - `role_permissions`: control which models are allowed for each role.
+     - `role`: the LiteLLM role to control access for. Allowed roles = ["internal_user", "proxy_admin", "team"].
+     - `models`: list of models the role is allowed to access.
+   - `model_list`: parent list of models on the proxy. [Learn more](./configs.md#llm-configs-model_list)
+
+4. Model checks: the proxy runs validation checks on the received JWT ([code](https://github.com/BerriAI/litellm/blob/3a4f5b23b5025b87b6d969f2485cc9bc741f9ba6/litellm/proxy/auth/user_api_key_auth.py#L284)). A minimal example request is sketched below.
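+
+Below is a minimal sketch of the end-to-end flow. The JWT value and proxy URL are placeholders - the client simply sends the IdP-issued JWT as the bearer token. With the configuration above, a token carrying the `basic_user` role maps to `internal_user`, which is only granted `anthropic-claude`:
+
+```python
+# Sketch: call the LiteLLM proxy with an IdP-issued JWT as the bearer token.
+# The token value and URL below are placeholders.
+import openai
+
+client = openai.OpenAI(
+    api_key="eyJhbGciOi...",           # JWT from Azure AD / Keycloak
+    base_url="http://localhost:4000",  # LiteLLM proxy
+)
+
+# Allowed: "basic_user" maps to internal_user, which role_permissions grants "anthropic-claude"
+response = client.chat.completions.create(
+    model="anthropic-claude",
+    messages=[{"role": "user", "content": "Hello"}],
+)
+print(response.choices[0].message.content)
+
+# A request for "openai-gpt-4o" would be rejected with HTTP 403,
+# since internal_user is not granted that model in role_permissions.
+```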
diff --git a/docs/my-website/docs/proxy/model_access.md b/docs/my-website/docs/proxy/model_access.md
index 545d74865bbc..854baa2edbf3 100644
--- a/docs/my-website/docs/proxy/model_access.md
+++ b/docs/my-website/docs/proxy/model_access.md
@@ -344,3 +344,6 @@ curl -i http://localhost:4000/v1/chat/completions \
+
+
+## [Role-Based Access Control (RBAC)](./jwt_auth_arch)
\ No newline at end of file
diff --git a/docs/my-website/docs/proxy/token_auth.md b/docs/my-website/docs/proxy/token_auth.md
index ffff2694fe5f..df57cadd3b34 100644
--- a/docs/my-website/docs/proxy/token_auth.md
+++ b/docs/my-website/docs/proxy/token_auth.md
@@ -1,7 +1,7 @@
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
-# JWT-based Auth
+# SSO - JWT-based Auth
Use JWT's to auth admins / projects into the proxy.
@@ -183,6 +183,24 @@ Expected Scope in JWT:
}
```
+### Control Model Access
+
+```yaml
+general_settings:
+ enable_jwt_auth: True
+ litellm_jwtauth:
+ user_roles_jwt_field: "resource_access.litellm-test-client-id.roles"
+ user_allowed_roles: ["basic_user"] # roles that map to an 'internal_user' role on LiteLLM
+ enforce_rbac: true # if true, will check if the user has the correct role to access the model + endpoint
+
+  role_permissions: # control what models + endpoints are allowed for each role
+ - role: internal_user
+ models: ["anthropic-claude"]
+```
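+
+With `enforce_rbac: true`, a request for a model outside the role's `role_permissions` is rejected before it is routed. A rough sketch of what a client sees (the JWT value is a placeholder, and `gpt-4o` stands in for any model the role is not granted):
+
+```python
+# Sketch: a model outside the role's `role_permissions` is rejected with HTTP 403.
+import openai
+
+client = openai.OpenAI(
+    api_key="eyJhbGciOi...",           # JWT issued by your OpenID provider (placeholder)
+    base_url="http://localhost:4000",  # LiteLLM proxy
+)
+
+try:
+    client.chat.completions.create(
+        model="gpt-4o",  # not in role_permissions for internal_user
+        messages=[{"role": "user", "content": "Hello"}],
+    )
+except openai.PermissionDeniedError as e:
+    # Expect a 403 along the lines of:
+    # "User role=internal_user not allowed to call model=gpt-4o. Allowed models=['anthropic-claude']"
+    print(e)
+```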
+
+
+**[Architecture Diagram (Control Model Access)](./jwt_auth_arch)**
+
## Advanced - Allowed Routes
Configure which routes a JWT can access via the config.
diff --git a/docs/my-website/img/control_model_access_jwt.png b/docs/my-website/img/control_model_access_jwt.png
new file mode 100644
index 000000000000..ab6cda53961f
Binary files /dev/null and b/docs/my-website/img/control_model_access_jwt.png differ
diff --git a/docs/my-website/sidebars.js b/docs/my-website/sidebars.js
index 29d674f3f450..d20f2a73e4da 100644
--- a/docs/my-website/sidebars.js
+++ b/docs/my-website/sidebars.js
@@ -51,7 +51,7 @@ const sidebars = {
{
type: "category",
label: "Architecture",
- items: ["proxy/architecture", "proxy/db_info", "router_architecture", "proxy/user_management_heirarchy"],
+ items: ["proxy/architecture", "proxy/db_info", "router_architecture", "proxy/user_management_heirarchy", "proxy/jwt_auth_arch"],
},
{
type: "link",
diff --git a/litellm/litellm_core_utils/dot_notation_indexing.py b/litellm/litellm_core_utils/dot_notation_indexing.py
new file mode 100644
index 000000000000..fda37f65007d
--- /dev/null
+++ b/litellm/litellm_core_utils/dot_notation_indexing.py
@@ -0,0 +1,59 @@
+"""
+This file contains the logic for dot notation indexing.
+
+Used by JWT Auth to get the user role from the token.
+"""
+
+from typing import Any, Dict, Optional, TypeVar
+
+T = TypeVar("T")
+
+
+def get_nested_value(
+ data: Dict[str, Any], key_path: str, default: Optional[T] = None
+) -> Optional[T]:
+ """
+ Retrieves a value from a nested dictionary using dot notation.
+
+ Args:
+ data: The dictionary to search in
+ key_path: The path to the value using dot notation (e.g., "a.b.c")
+ default: The default value to return if the path is not found
+
+ Returns:
+ The value at the specified path, or the default value if not found
+
+ Example:
+ >>> data = {"a": {"b": {"c": "value"}}}
+ >>> get_nested_value(data, "a.b.c")
+ 'value'
+ >>> get_nested_value(data, "a.b.d", "default")
+ 'default'
+ """
+ if not key_path:
+ return default
+
+ # Remove metadata. prefix if it exists
+ key_path = (
+ key_path.replace("metadata.", "", 1)
+ if key_path.startswith("metadata.")
+ else key_path
+ )
+
+ # Split the key path into parts
+ parts = key_path.split(".")
+
+ # Traverse through the dictionary
+ current: Any = data
+ for part in parts:
+ try:
+ current = current[part]
+ except (KeyError, TypeError):
+ return default
+
+ # If default is None, we can return any type
+ if default is None:
+ return current
+
+ # Otherwise, ensure the type matches the default
+ return current if isinstance(current, type(default)) else default
diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml
index 4182ba86feff..423032ac86fd 100644
--- a/litellm/proxy/_new_secret_config.yaml
+++ b/litellm/proxy/_new_secret_config.yaml
@@ -1,18 +1,16 @@
model_list:
- - model_name: gpt-3.5-turbo-end-user-test
- litellm_params:
- model: gpt-3.5-turbo
- region_name: "eu"
- model_info:
- id: "1"
- model_name: gpt-3.5-turbo
litellm_params:
model: gpt-3.5-turbo
- timeout: 2
- num_retries: 0
- model_name: anthropic-claude
litellm_params:
- model: anthropic.claude-3-sonnet-20240229-v1:0
-
-litellm_settings:
- callbacks: ["langsmith"]
\ No newline at end of file
+ model: claude-3-5-haiku-20241022
+ - model_name: groq/*
+ litellm_params:
+ model: groq/*
+ api_key: os.environ/GROQ_API_KEY
+ mock_response: Hi!
+ - model_name: deepseek/*
+ litellm_params:
+ model: deepseek/*
+ api_key: os.environ/DEEPSEEK_API_KEY
diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py
index 5a456aec9715..bf3f6b6543a5 100644
--- a/litellm/proxy/_types.py
+++ b/litellm/proxy/_types.py
@@ -445,6 +445,8 @@ class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase):
user_id_jwt_field: Optional[str] = None
user_email_jwt_field: Optional[str] = None
user_allowed_email_domain: Optional[str] = None
+ user_roles_jwt_field: Optional[str] = None
+ user_allowed_roles: Optional[List[str]] = None
user_id_upsert: bool = Field(
default=False, description="If user doesn't exist, upsert them into the db."
)
@@ -458,11 +460,19 @@ def __init__(self, **kwargs: Any) -> None:
allowed_keys = self.__annotations__.keys()
invalid_keys = set(kwargs.keys()) - allowed_keys
+ user_roles_jwt_field = kwargs.get("user_roles_jwt_field")
+ user_allowed_roles = kwargs.get("user_allowed_roles")
if invalid_keys:
raise ValueError(
f"Invalid arguments provided: {', '.join(invalid_keys)}. Allowed arguments are: {', '.join(allowed_keys)}."
)
+ if (user_roles_jwt_field is not None and user_allowed_roles is None) or (
+ user_roles_jwt_field is None and user_allowed_roles is not None
+ ):
+ raise ValueError(
+                "user_roles_jwt_field and user_allowed_roles must both be set if either is provided."
+ )
super().__init__(**kwargs)
@@ -2335,3 +2345,15 @@ class ClientSideFallbackModel(TypedDict, total=False):
ALL_FALLBACK_MODEL_VALUES = Union[str, ClientSideFallbackModel]
+
+
+RBAC_ROLES = Literal[
+ LitellmUserRoles.PROXY_ADMIN,
+ LitellmUserRoles.TEAM,
+ LitellmUserRoles.INTERNAL_USER,
+]
+
+
+class RoleBasedPermissions(TypedDict):
+ role: Required[RBAC_ROLES]
+ models: Required[List[str]]
diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py
index d6bbf760bde9..8d0132709c18 100644
--- a/litellm/proxy/auth/auth_checks.py
+++ b/litellm/proxy/auth/auth_checks.py
@@ -12,7 +12,7 @@
import re
import time
import traceback
-from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, cast
from fastapi import status
from pydantic import BaseModel
@@ -24,6 +24,7 @@
from litellm.litellm_core_utils.get_llm_provider_logic import get_llm_provider
from litellm.proxy._types import (
DB_CONNECTION_ERROR_TYPES,
+ RBAC_ROLES,
CallInfo,
LiteLLM_EndUserTable,
LiteLLM_JWTAuth,
@@ -35,6 +36,7 @@
LitellmUserRoles,
ProxyErrorTypes,
ProxyException,
+ RoleBasedPermissions,
UserAPIKeyAuth,
)
from litellm.proxy.auth.route_checks import RouteChecks
@@ -100,6 +102,14 @@ async def common_checks(
llm_router=llm_router,
)
+ ## 2.1 If user can call model (if personal key)
+ if team_object is None and user_object is not None:
+ await can_user_call_model(
+ model=_model,
+ llm_router=llm_router,
+ user_object=user_object,
+ )
+
# 3. If team is in budget
await _team_max_budget_check(
team_object=team_object,
@@ -391,6 +401,30 @@ def _update_last_db_access_time(
last_db_access_time[key] = (value, time.time())
+def get_role_based_models(
+ rbac_role: RBAC_ROLES,
+ general_settings: dict,
+) -> Optional[List[str]]:
+ """
+ Get the models allowed for a user role.
+
+ Used by JWT Auth.
+ """
+
+ role_based_permissions = cast(
+ Optional[List[RoleBasedPermissions]],
+ general_settings.get("role_permissions", []),
+ )
+ if role_based_permissions is None:
+ return None
+
+ for role_based_permission in role_based_permissions:
+ if role_based_permission["role"] == rbac_role:
+ return role_based_permission["models"]
+
+ return None
+
+
@log_db_metrics
async def get_user_object(
user_id: str,
@@ -836,11 +870,10 @@ async def get_org_object(
)
-async def can_key_call_model(
+async def _can_object_call_model(
model: str,
- llm_model_list: Optional[list],
- valid_token: UserAPIKeyAuth,
- llm_router: Optional[litellm.Router],
+ llm_router: Optional[Router],
+ models: List[str],
) -> Literal[True]:
"""
Checks if token can call a given model
@@ -855,9 +888,6 @@ async def can_key_call_model(
model = litellm.model_alias_map[model]
## check if model in allowed model names
- verbose_proxy_logger.debug(
- f"LLM Model List pre access group check: {llm_model_list}"
- )
from collections import defaultdict
access_groups: Dict[str, List[str]] = defaultdict(list)
@@ -868,13 +898,13 @@ async def can_key_call_model(
len(access_groups) > 0 and llm_router is not None
): # check if token contains any model access groups
for idx, m in enumerate(
- valid_token.models
+ models
): # loop token models, if any of them are an access group add the access group
if m in access_groups:
return True
# Filter out models that are access_groups
- filtered_models = [m for m in valid_token.models if m not in access_groups]
+ filtered_models = [m for m in models if m not in access_groups]
verbose_proxy_logger.debug(f"model: {model}; allowed_models: {filtered_models}")
@@ -885,25 +915,61 @@ async def can_key_call_model(
all_model_access: bool = False
- if (
- len(filtered_models) == 0 and len(valid_token.models) == 0
- ) or "*" in filtered_models:
+ if (len(filtered_models) == 0 and len(models) == 0) or "*" in filtered_models:
all_model_access = True
if model is not None and model not in filtered_models and all_model_access is False:
raise ProxyException(
- message=f"API Key not allowed to access model. This token can only access models={valid_token.models}. Tried to access {model}",
+ message=f"API Key not allowed to access model. This token can only access models={models}. Tried to access {model}",
type=ProxyErrorTypes.key_model_access_denied,
param="model",
code=status.HTTP_401_UNAUTHORIZED,
)
- valid_token.models = filtered_models
+
verbose_proxy_logger.debug(
- f"filtered allowed_models: {filtered_models}; valid_token.models: {valid_token.models}"
+ f"filtered allowed_models: {filtered_models}; models: {models}"
)
return True
+async def can_key_call_model(
+ model: str,
+ llm_model_list: Optional[list],
+ valid_token: UserAPIKeyAuth,
+ llm_router: Optional[litellm.Router],
+) -> Literal[True]:
+ """
+ Checks if token can call a given model
+
+ Returns:
+ - True: if token allowed to call model
+
+ Raises:
+ - Exception: If token not allowed to call model
+ """
+ return await _can_object_call_model(
+ model=model,
+ llm_router=llm_router,
+ models=valid_token.models,
+ )
+
+
+async def can_user_call_model(
+ model: str,
+ llm_router: Optional[Router],
+ user_object: Optional[LiteLLM_UserTable],
+) -> Literal[True]:
+
+ if user_object is None:
+ return True
+
+ return await _can_object_call_model(
+ model=model,
+ llm_router=llm_router,
+ models=user_object.models,
+ )
+
+
async def is_valid_fallback_model(
model: str,
llm_router: Optional[Router],
@@ -1161,7 +1227,11 @@ def _model_custom_llm_provider_matches_wildcard_pattern(
- `model=claude-3-5-sonnet-20240620`
- `allowed_model_pattern=anthropic/*`
"""
- model, custom_llm_provider, _, _ = get_llm_provider(model=model)
+ try:
+ model, custom_llm_provider, _, _ = get_llm_provider(model=model)
+ except Exception:
+ return False
+
return is_model_allowed_by_pattern(
model=f"{custom_llm_provider}/{model}",
allowed_model_pattern=allowed_model_pattern,
diff --git a/litellm/proxy/auth/handle_jwt.py b/litellm/proxy/auth/handle_jwt.py
index cf5701154676..bcda413b68be 100644
--- a/litellm/proxy/auth/handle_jwt.py
+++ b/litellm/proxy/auth/handle_jwt.py
@@ -16,8 +16,10 @@
from litellm._logging import verbose_proxy_logger
from litellm.caching.caching import DualCache
+from litellm.litellm_core_utils.dot_notation_indexing import get_nested_value
from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
from litellm.proxy._types import (
+ RBAC_ROLES,
JWKKeyValue,
JWTKeyItem,
LiteLLM_JWTAuth,
@@ -59,7 +61,7 @@ def is_jwt(self, token: str):
parts = token.split(".")
return len(parts) == 3
- def get_rbac_role(self, token: dict) -> Optional[LitellmUserRoles]:
+ def get_rbac_role(self, token: dict) -> Optional[RBAC_ROLES]:
"""
Returns the RBAC role the token 'belongs' to.
@@ -78,12 +80,18 @@ def get_rbac_role(self, token: dict) -> Optional[LitellmUserRoles]:
"""
scopes = self.get_scopes(token=token)
is_admin = self.is_admin(scopes=scopes)
+ user_roles = self.get_user_roles(token=token, default_value=None)
+
if is_admin:
return LitellmUserRoles.PROXY_ADMIN
elif self.get_team_id(token=token, default_value=None) is not None:
return LitellmUserRoles.TEAM
elif self.get_user_id(token=token, default_value=None) is not None:
return LitellmUserRoles.INTERNAL_USER
+ elif user_roles is not None and self.is_allowed_user_role(
+ user_roles=user_roles
+ ):
+ return LitellmUserRoles.INTERNAL_USER
return None
@@ -166,6 +174,43 @@ def get_user_id(self, token: dict, default_value: Optional[str]) -> Optional[str
user_id = default_value
return user_id
+ def get_user_roles(
+ self, token: dict, default_value: Optional[List[str]]
+ ) -> Optional[List[str]]:
+ """
+        Returns the user roles from the token.
+
+ Set via 'user_roles_jwt_field' in the config.
+ """
+ try:
+ if self.litellm_jwtauth.user_roles_jwt_field is not None:
+ user_roles = get_nested_value(
+ data=token,
+ key_path=self.litellm_jwtauth.user_roles_jwt_field,
+ default=default_value,
+ )
+ else:
+ user_roles = default_value
+ except KeyError:
+ user_roles = default_value
+ return user_roles
+
+ def is_allowed_user_role(self, user_roles: Optional[List[str]]) -> bool:
+ """
+        Returns True if any of the user's roles is in 'user_allowed_roles'.
+
+        'user_allowed_roles' is set in the config.
+ """
+ if (
+ user_roles is not None
+ and self.litellm_jwtauth.user_allowed_roles is not None
+ and any(
+ role in self.litellm_jwtauth.user_allowed_roles for role in user_roles
+ )
+ ):
+ return True
+ return False
+
def get_user_email(
self, token: dict, default_value: Optional[str]
) -> Optional[str]:
diff --git a/litellm/proxy/auth/user_api_key_auth.py b/litellm/proxy/auth/user_api_key_auth.py
index 33247308f647..7d499af5b266 100644
--- a/litellm/proxy/auth/user_api_key_auth.py
+++ b/litellm/proxy/auth/user_api_key_auth.py
@@ -33,6 +33,7 @@
get_end_user_object,
get_key_object,
get_org_object,
+ get_role_based_models,
get_team_object,
get_user_object,
is_valid_fallback_model,
@@ -281,9 +282,34 @@ def get_rbac_role(jwt_handler: JWTHandler, scopes: List[str]) -> str:
return LitellmUserRoles.TEAM
+def can_rbac_role_call_model(
+ rbac_role: RBAC_ROLES,
+ general_settings: dict,
+ model: Optional[str],
+) -> Literal[True]:
+ """
+ Checks if user is allowed to access the model, based on their role.
+ """
+ role_based_models = get_role_based_models(
+ rbac_role=rbac_role, general_settings=general_settings
+ )
+ if role_based_models is None or model is None:
+ return True
+
+ if model not in role_based_models:
+ raise HTTPException(
+ status_code=403,
+ detail=f"User role={rbac_role} not allowed to call model={model}. Allowed models={role_based_models}",
+ )
+
+ return True
+
+
async def _jwt_auth_user_api_key_auth_builder(
api_key: str,
jwt_handler: JWTHandler,
+ request_data: dict,
+ general_settings: dict,
route: str,
prisma_client: Optional[PrismaClient],
user_api_key_cache: DualCache,
@@ -295,14 +321,20 @@ async def _jwt_auth_user_api_key_auth_builder(
jwt_valid_token: dict = await jwt_handler.auth_jwt(token=api_key)
# check if unmatched token and enforce_rbac is true
- if (
- jwt_handler.litellm_jwtauth.enforce_rbac is True
- and jwt_handler.get_rbac_role(token=jwt_valid_token) is None
- ):
- raise HTTPException(
- status_code=403,
- detail="Unmatched token passed in. enforce_rbac is set to True. Token must belong to a proxy admin, team, or user. See how to set roles in config here: https://docs.litellm.ai/docs/proxy/token_auth#advanced---spend-tracking-end-users--internal-users--team--org",
- )
+ if jwt_handler.litellm_jwtauth.enforce_rbac is True:
+ rbac_role = jwt_handler.get_rbac_role(token=jwt_valid_token)
+ if rbac_role is None:
+ raise HTTPException(
+ status_code=403,
+ detail="Unmatched token passed in. enforce_rbac is set to True. Token must belong to a proxy admin, team, or user. See how to set roles in config here: https://docs.litellm.ai/docs/proxy/token_auth#advanced---spend-tracking-end-users--internal-users--team--org",
+ )
+ else:
+ # run rbac validation checks
+ can_rbac_role_call_model(
+ rbac_role=rbac_role,
+ general_settings=general_settings,
+ model=request_data.get("model"),
+ )
# get scopes
scopes = jwt_handler.get_scopes(token=jwt_valid_token)
@@ -431,18 +463,18 @@ async def _jwt_auth_user_api_key_auth_builder(
proxy_logging_obj=proxy_logging_obj,
)
- return {
- "is_proxy_admin": False,
- "team_id": team_id,
- "team_object": team_object,
- "user_id": user_id,
- "user_object": user_object,
- "org_id": org_id,
- "org_object": org_object,
- "end_user_id": end_user_id,
- "end_user_object": end_user_object,
- "token": api_key,
- }
+ return JWTAuthBuilderResult(
+ is_proxy_admin=False,
+ team_id=team_id,
+ team_object=team_object,
+ user_id=user_id,
+ user_object=user_object,
+ org_id=org_id,
+ org_object=org_object,
+ end_user_id=end_user_id,
+ end_user_object=end_user_object,
+ token=api_key,
+ )
async def _user_api_key_auth_builder( # noqa: PLR0915
@@ -581,6 +613,8 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
verbose_proxy_logger.debug("is_jwt: %s", is_jwt)
if is_jwt:
result = await _jwt_auth_user_api_key_auth_builder(
+ request_data=request_data,
+ general_settings=general_settings,
api_key=api_key,
jwt_handler=jwt_handler,
route=route,
diff --git a/litellm/utils.py b/litellm/utils.py
index 976c8e2e4acf..92d6dc37dbe4 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -474,6 +474,11 @@ def function_setup( # noqa: PLR0915
if inspect.iscoroutinefunction(callback):
litellm._async_failure_callback.append(callback)
removed_async_items.append(index)
+ elif (
+ callback in litellm._known_custom_logger_compatible_callbacks
+ and isinstance(callback, str)
+ ):
+ _add_custom_logger_callback_to_specific_event(callback, "failure")
# Pop the async items from failure_callback in reverse order to avoid index issues
for index in reversed(removed_async_items):
@@ -1385,30 +1390,33 @@ def _select_tokenizer(
@lru_cache(maxsize=128)
def _select_tokenizer_helper(model: str):
- if model in litellm.cohere_models and "command-r" in model:
- # cohere
- cohere_tokenizer = Tokenizer.from_pretrained(
- "Xenova/c4ai-command-r-v01-tokenizer"
- )
- return {"type": "huggingface_tokenizer", "tokenizer": cohere_tokenizer}
- # anthropic
- elif model in litellm.anthropic_models and "claude-3" not in model:
- claude_tokenizer = Tokenizer.from_str(claude_json_str)
- return {"type": "huggingface_tokenizer", "tokenizer": claude_tokenizer}
- # llama2
- elif "llama-2" in model.lower() or "replicate" in model.lower():
- tokenizer = Tokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
- return {"type": "huggingface_tokenizer", "tokenizer": tokenizer}
- # llama3
- elif "llama-3" in model.lower():
- tokenizer = Tokenizer.from_pretrained("Xenova/llama-3-tokenizer")
- return {"type": "huggingface_tokenizer", "tokenizer": tokenizer}
+ try:
+ if model in litellm.cohere_models and "command-r" in model:
+ # cohere
+ cohere_tokenizer = Tokenizer.from_pretrained(
+ "Xenova/c4ai-command-r-v01-tokenizer"
+ )
+ return {"type": "huggingface_tokenizer", "tokenizer": cohere_tokenizer}
+ # anthropic
+ elif model in litellm.anthropic_models and "claude-3" not in model:
+ claude_tokenizer = Tokenizer.from_str(claude_json_str)
+ return {"type": "huggingface_tokenizer", "tokenizer": claude_tokenizer}
+ # llama2
+ elif "llama-2" in model.lower() or "replicate" in model.lower():
+ tokenizer = Tokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
+ return {"type": "huggingface_tokenizer", "tokenizer": tokenizer}
+ # llama3
+ elif "llama-3" in model.lower():
+ tokenizer = Tokenizer.from_pretrained("Xenova/llama-3-tokenizer")
+ return {"type": "huggingface_tokenizer", "tokenizer": tokenizer}
+ except Exception as e:
+ verbose_logger.debug(f"Error selecting tokenizer: {e}")
+
# default - tiktoken
- else:
- return {
- "type": "openai_tokenizer",
- "tokenizer": encoding,
- } # default to openai tokenizer
+ return {
+ "type": "openai_tokenizer",
+ "tokenizer": encoding,
+ } # default to openai tokenizer
def encode(model="", text="", custom_tokenizer: Optional[dict] = None):
diff --git a/tests/local_testing/test_token_counter.py b/tests/local_testing/test_token_counter.py
index ef9cc9194594..e1e2c36e9fd2 100644
--- a/tests/local_testing/test_token_counter.py
+++ b/tests/local_testing/test_token_counter.py
@@ -382,3 +382,80 @@ def test_img_url_token_counter(img_url):
def test_token_encode_disallowed_special():
encode(model="gpt-3.5-turbo", text="Hello, world! <|endoftext|>")
+
+
+import unittest
+from unittest.mock import patch, MagicMock
+from litellm.utils import encoding, _select_tokenizer_helper, claude_json_str
+
+
+class TestTokenizerSelection(unittest.TestCase):
+ @patch("litellm.utils.Tokenizer.from_pretrained")
+ def test_llama3_tokenizer_api_failure(self, mock_from_pretrained):
+ # Setup mock to raise an error
+ mock_from_pretrained.side_effect = Exception("Failed to load tokenizer")
+
+ # Test with llama-3 model
+ result = _select_tokenizer_helper("llama-3-7b")
+
+ # Verify the attempt to load Llama-3 tokenizer
+ mock_from_pretrained.assert_called_once_with("Xenova/llama-3-tokenizer")
+
+ # Verify fallback to OpenAI tokenizer
+ self.assertEqual(result["type"], "openai_tokenizer")
+ self.assertEqual(result["tokenizer"], encoding)
+
+ @patch("litellm.utils.Tokenizer.from_pretrained")
+ def test_cohere_tokenizer_api_failure(self, mock_from_pretrained):
+ # Setup mock to raise an error
+ mock_from_pretrained.side_effect = Exception("Failed to load tokenizer")
+
+ # Add Cohere model to the list for testing
+ litellm.cohere_models = ["command-r-v1"]
+
+ # Test with Cohere model
+ result = _select_tokenizer_helper("command-r-v1")
+
+ # Verify the attempt to load Cohere tokenizer
+ mock_from_pretrained.assert_called_once_with(
+ "Xenova/c4ai-command-r-v01-tokenizer"
+ )
+
+ # Verify fallback to OpenAI tokenizer
+ self.assertEqual(result["type"], "openai_tokenizer")
+ self.assertEqual(result["tokenizer"], encoding)
+
+ @patch("litellm.utils.Tokenizer.from_str")
+ def test_claude_tokenizer_api_failure(self, mock_from_str):
+ # Setup mock to raise an error
+ mock_from_str.side_effect = Exception("Failed to load tokenizer")
+
+ # Add Claude model to the list for testing
+ litellm.anthropic_models = ["claude-2"]
+
+ # Test with Claude model
+ result = _select_tokenizer_helper("claude-2")
+
+ # Verify the attempt to load Claude tokenizer
+ mock_from_str.assert_called_once_with(claude_json_str)
+
+ # Verify fallback to OpenAI tokenizer
+ self.assertEqual(result["type"], "openai_tokenizer")
+ self.assertEqual(result["tokenizer"], encoding)
+
+ @patch("litellm.utils.Tokenizer.from_pretrained")
+ def test_llama2_tokenizer_api_failure(self, mock_from_pretrained):
+ # Setup mock to raise an error
+ mock_from_pretrained.side_effect = Exception("Failed to load tokenizer")
+
+ # Test with Llama-2 model
+ result = _select_tokenizer_helper("llama-2-7b")
+
+ # Verify the attempt to load Llama-2 tokenizer
+ mock_from_pretrained.assert_called_once_with(
+ "hf-internal-testing/llama-tokenizer"
+ )
+
+ # Verify fallback to OpenAI tokenizer
+ self.assertEqual(result["type"], "openai_tokenizer")
+ self.assertEqual(result["tokenizer"], encoding)
diff --git a/tests/local_testing/test_utils.py b/tests/local_testing/test_utils.py
index c651d84fa33a..866577c69aad 100644
--- a/tests/local_testing/test_utils.py
+++ b/tests/local_testing/test_utils.py
@@ -1529,6 +1529,34 @@ def test_add_custom_logger_callback_to_specific_event_e2e(monkeypatch):
assert len(litellm.failure_callback) == curr_len_failure_callback
+def test_add_custom_logger_callback_to_specific_event_e2e_failure(monkeypatch):
+ from litellm.integrations.openmeter import OpenMeterLogger
+
+ monkeypatch.setattr(litellm, "success_callback", [])
+ monkeypatch.setattr(litellm, "failure_callback", [])
+ monkeypatch.setattr(litellm, "callbacks", [])
+ monkeypatch.setenv("OPENMETER_API_KEY", "wedlwe")
+ monkeypatch.setenv("OPENMETER_API_URL", "https://openmeter.dev")
+
+ litellm.failure_callback = ["openmeter"]
+
+ curr_len_success_callback = len(litellm.success_callback)
+ curr_len_failure_callback = len(litellm.failure_callback)
+
+ litellm.completion(
+ model="gpt-4o-mini",
+ messages=[{"role": "user", "content": "Hello, world!"}],
+ mock_response="Testing langfuse",
+ )
+
+ assert len(litellm.success_callback) == curr_len_success_callback
+ assert len(litellm.failure_callback) == curr_len_failure_callback
+
+ assert any(
+ isinstance(callback, OpenMeterLogger) for callback in litellm.failure_callback
+ )
+
+
@pytest.mark.asyncio
async def test_wrapper_kwargs_passthrough():
from litellm.utils import client
diff --git a/tests/proxy_unit_tests/test_auth_checks.py b/tests/proxy_unit_tests/test_auth_checks.py
index 85b5b216a564..04af3d6e299a 100644
--- a/tests/proxy_unit_tests/test_auth_checks.py
+++ b/tests/proxy_unit_tests/test_auth_checks.py
@@ -508,3 +508,43 @@ async def budget_alerts(self, type, user_info):
assert (
alert_triggered == expect_alert
), f"Expected alert_triggered to be {expect_alert} for spend={spend}, soft_budget={soft_budget}"
+
+
+@pytest.mark.asyncio
+async def test_can_user_call_model():
+ from litellm.proxy.auth.auth_checks import can_user_call_model
+ from litellm.proxy._types import ProxyException
+ from litellm import Router
+
+ router = Router(
+ model_list=[
+ {
+ "model_name": "anthropic-claude",
+ "litellm_params": {"model": "anthropic/anthropic-claude"},
+ },
+ {
+ "model_name": "gpt-3.5-turbo",
+ "litellm_params": {"model": "gpt-3.5-turbo", "api_key": "test-api-key"},
+ },
+ ]
+ )
+
+ args = {
+ "model": "anthropic-claude",
+ "llm_router": router,
+ "user_object": LiteLLM_UserTable(
+ user_id="testuser21@mycompany.com",
+ max_budget=None,
+ spend=0.0042295,
+ model_max_budget={},
+ model_spend={},
+ user_email="testuser@mycompany.com",
+ models=["gpt-3.5-turbo"],
+ ),
+ }
+
+ with pytest.raises(ProxyException) as e:
+ await can_user_call_model(**args)
+
+ args["model"] = "gpt-3.5-turbo"
+ await can_user_call_model(**args)
diff --git a/tests/proxy_unit_tests/test_user_api_key_auth.py b/tests/proxy_unit_tests/test_user_api_key_auth.py
index a428a29c6341..3e9ba17889e4 100644
--- a/tests/proxy_unit_tests/test_user_api_key_auth.py
+++ b/tests/proxy_unit_tests/test_user_api_key_auth.py
@@ -855,6 +855,8 @@ async def test_jwt_user_api_key_auth_builder_enforce_rbac(enforce_rbac, monkeypa
"user_api_key_cache": Mock(),
"parent_otel_span": None,
"proxy_logging_obj": Mock(),
+ "request_data": {},
+ "general_settings": {},
}
if enforce_rbac:
@@ -877,3 +879,55 @@ def test_user_api_key_auth_end_user_str():
user_api_key_auth = UserAPIKeyAuth(**user_api_key_args)
assert user_api_key_auth.end_user_id == "1"
+
+
+def test_can_rbac_role_call_model():
+ from litellm.proxy.auth.user_api_key_auth import can_rbac_role_call_model
+ from litellm.proxy._types import RoleBasedPermissions
+
+ roles_based_permissions = [
+ RoleBasedPermissions(
+ role=LitellmUserRoles.INTERNAL_USER,
+ models=["gpt-4"],
+ ),
+ RoleBasedPermissions(
+ role=LitellmUserRoles.PROXY_ADMIN,
+ models=["anthropic-claude"],
+ ),
+ ]
+
+ assert can_rbac_role_call_model(
+ rbac_role=LitellmUserRoles.INTERNAL_USER,
+ general_settings={"role_permissions": roles_based_permissions},
+ model="gpt-4",
+ )
+
+ with pytest.raises(HTTPException):
+ can_rbac_role_call_model(
+ rbac_role=LitellmUserRoles.INTERNAL_USER,
+ general_settings={"role_permissions": roles_based_permissions},
+ model="gpt-4o",
+ )
+
+ with pytest.raises(HTTPException):
+ can_rbac_role_call_model(
+ rbac_role=LitellmUserRoles.PROXY_ADMIN,
+ general_settings={"role_permissions": roles_based_permissions},
+ model="gpt-4o",
+ )
+
+
+def test_can_rbac_role_call_model_no_role_permissions():
+ from litellm.proxy.auth.user_api_key_auth import can_rbac_role_call_model
+
+ assert can_rbac_role_call_model(
+ rbac_role=LitellmUserRoles.INTERNAL_USER,
+ general_settings={},
+ model="gpt-4",
+ )
+
+ assert can_rbac_role_call_model(
+ rbac_role=LitellmUserRoles.PROXY_ADMIN,
+ general_settings={"role_permissions": []},
+ model="anthropic-claude",
+ )
diff --git a/tests/test_users.py b/tests/test_users.py
index 7e267ac4df14..812783681c29 100644
--- a/tests/test_users.py
+++ b/tests/test_users.py
@@ -7,13 +7,17 @@
from openai import AsyncOpenAI
from test_team import list_teams
from typing import Optional
+from test_keys import generate_key
+from fastapi import HTTPException
-async def new_user(session, i, user_id=None, budget=None, budget_duration=None):
+async def new_user(
+ session, i, user_id=None, budget=None, budget_duration=None, models=None
+):
url = "http://0.0.0.0:4000/user/new"
headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
data = {
- "models": ["azure-models"],
+ "models": models or ["azure-models"],
"aliases": {"mistral-7b": "gpt-3.5-turbo"},
"duration": None,
"max_budget": budget,
@@ -37,6 +41,51 @@ async def new_user(session, i, user_id=None, budget=None, budget_duration=None):
return await response.json()
+async def generate_key(
+ session,
+ i,
+ budget=None,
+ budget_duration=None,
+ models=["azure-models", "gpt-4", "dall-e-3"],
+ max_parallel_requests: Optional[int] = None,
+ user_id: Optional[str] = None,
+ team_id: Optional[str] = None,
+ metadata: Optional[dict] = None,
+ calling_key="sk-1234",
+):
+ url = "http://0.0.0.0:4000/key/generate"
+ headers = {
+ "Authorization": f"Bearer {calling_key}",
+ "Content-Type": "application/json",
+ }
+ data = {
+ "models": models,
+ "aliases": {"mistral-7b": "gpt-3.5-turbo"},
+ "duration": None,
+ "max_budget": budget,
+ "budget_duration": budget_duration,
+ "max_parallel_requests": max_parallel_requests,
+ "user_id": user_id,
+ "team_id": team_id,
+ "metadata": metadata,
+ }
+
+ print(f"data: {data}")
+
+ async with session.post(url, headers=headers, json=data) as response:
+ status = response.status
+ response_text = await response.text()
+
+ print(f"Response {i} (Status code: {status}):")
+ print(response_text)
+ print()
+
+ if status != 200:
+ raise Exception(f"Request {i} did not return a 200 status code: {status}")
+
+ return await response.json()
+
+
@pytest.mark.asyncio
async def test_user_new():
"""
@@ -210,3 +259,59 @@ async def test_global_proxy_budget_update():
new_new_spend = user_info["user_info"]["spend"]
print(f"new_spend: {new_spend}; original_spend: {original_spend}")
assert new_new_spend > new_spend
+
+
+@pytest.mark.asyncio
+async def test_user_model_access():
+ """
+ - Create user with model access
+ - Create key with user
+ - Call model that user has access to -> should work
+ - Call wildcard model that user has access to -> should work
+ - Call model that user does not have access to -> should fail
+ - Call wildcard model that user does not have access to -> should fail
+ """
+ import openai
+
+ async with aiohttp.ClientSession() as session:
+ get_user = f"krrish_{time.time()}@berri.ai"
+ await new_user(
+ session=session,
+ i=0,
+ user_id=get_user,
+ models=["good-model", "anthropic/*"],
+ )
+
+ result = await generate_key(
+ session=session,
+ i=0,
+ user_id=get_user,
+ models=[], # assign no models. Allow inheritance from user
+ )
+ key = result["key"]
+
+ await chat_completion(
+ session=session,
+ key=key,
+ model="anthropic/claude-3-5-haiku-20241022",
+ )
+
+ await chat_completion(
+ session=session,
+ key=key,
+ model="good-model",
+ )
+
+ with pytest.raises(openai.AuthenticationError):
+ await chat_completion(
+ session=session,
+ key=key,
+ model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
+ )
+
+ with pytest.raises(openai.AuthenticationError):
+ await chat_completion(
+ session=session,
+ key=key,
+ model="groq/claude-3-5-haiku-20241022",
+ )