diff --git a/docs/my-website/docs/proxy/jwt_auth_arch.md b/docs/my-website/docs/proxy/jwt_auth_arch.md
new file mode 100644
index 000000000000..e48fa71f8be5
--- /dev/null
+++ b/docs/my-website/docs/proxy/jwt_auth_arch.md
@@ -0,0 +1,116 @@
+import Image from '@theme/IdealImage';
+import Tabs from '@theme/Tabs';
+import TabItem from '@theme/TabItem';
+
+# Control Model Access with SSO (Azure AD/Keycloak/etc.)
+
+:::info
+
+✨ JWT Auth is a LiteLLM Enterprise feature
+
+[Enterprise Pricing](https://www.litellm.ai/#pricing)
+
+[Get free 7-day trial key](https://www.litellm.ai/#trial)
+
+:::
+
+## Example Token
+
+**Azure AD (flat `roles` claim)**
+
+```bash
+{
+  "sub": "1234567890",
+  "name": "John Doe",
+  "email": "john.doe@example.com",
+  "roles": ["basic_user"] # 👈 ROLE
+}
+```
+
+**Keycloak (`resource_access` claim)**
+
+```bash
+{
+  "sub": "1234567890",
+  "name": "John Doe",
+  "email": "john.doe@example.com",
+  "resource_access": {
+    "litellm-test-client-id": {
+      "roles": ["basic_user"] # 👈 ROLE
+    }
+  }
+}
+```
+
+## Proxy Configuration
+
+**Azure AD**
+
+```yaml
+general_settings:
+  enable_jwt_auth: True
+  litellm_jwtauth:
+    user_roles_jwt_field: "roles" # the field in the JWT that contains the roles
+    user_allowed_roles: ["basic_user"] # roles that map to the 'internal_user' role on LiteLLM
+    enforce_rbac: true # if true, checks that the user has the correct role to access the model
+
+  role_permissions: # control which models are allowed for each role
+    - role: internal_user
+      models: ["anthropic-claude"]
+
+model_list:
+  - model_name: anthropic-claude
+    litellm_params:
+      model: claude-3-5-haiku-20241022
+  - model_name: openai-gpt-4o
+    litellm_params:
+      model: gpt-4o
+```
+
+**Keycloak**
+
+```yaml
+general_settings:
+  enable_jwt_auth: True
+  litellm_jwtauth:
+    user_roles_jwt_field: "resource_access.litellm-test-client-id.roles" # the field in the JWT that contains the roles
+    user_allowed_roles: ["basic_user"] # roles that map to the 'internal_user' role on LiteLLM
+    enforce_rbac: true # if true, checks that the user has the correct role to access the model
+
+  role_permissions: # control which models are allowed for each role
+    - role: internal_user
+      models: ["anthropic-claude"]
+
+model_list:
+  - model_name: anthropic-claude
+    litellm_params:
+      model: claude-3-5-haiku-20241022
+  - model_name: openai-gpt-4o
+    litellm_params:
+      model: gpt-4o
+```
+
+## How it works
+
+1. Specify `JWT_PUBLIC_KEY_URL` - this is the public keys endpoint of your OpenID provider. For Azure AD it's `https://login.microsoftonline.com/{tenant_id}/discovery/v2.0/keys`. For Keycloak it's `{keycloak_base_url}/realms/{your-realm}/protocol/openid-connect/certs`.
+2. Map JWT roles to LiteLLM roles - done via `user_roles_jwt_field` and `user_allowed_roles`.
+   - Currently only `internal_user` is supported for role mapping.
+3. Specify model access:
+   - `role_permissions`: control which models are allowed for each role.
+     - `role`: the LiteLLM role to control access for. Allowed roles = ["internal_user", "proxy_admin", "team"].
+     - `models`: list of models that the role is allowed to access.
+   - `model_list`: parent list of models on the proxy. [Learn more](./configs.md#llm-configs-model_list)
+4. Model checks: the proxy runs validation checks on the received JWT. [Code](https://github.com/BerriAI/litellm/blob/3a4f5b23b5025b87b6d969f2485cc9bc741f9ba6/litellm/proxy/auth/user_api_key_auth.py#L284)
\ No newline at end of file
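To make the documented flow concrete, here is a minimal client-side sketch (not part of this diff) against the Keycloak-style config above. It assumes the proxy runs on `http://localhost:4000`, that the incoming JWT is exported in a hypothetical `JWT` environment variable, and that the `openai` Python SDK is used as the client; the exact exception raised may differ depending on which auth check rejects the request.

```python
import os

import openai

# The JWT is sent as the Bearer token / API key (assumption: it is exported as JWT).
client = openai.OpenAI(
    api_key=os.environ["JWT"],
    base_url="http://localhost:4000",  # assumed proxy address
)

# Allowed: `internal_user` is granted `anthropic-claude` via role_permissions.
resp = client.chat.completions.create(
    model="anthropic-claude",
    messages=[{"role": "user", "content": "Hello"}],
)
print(resp.choices[0].message.content)

# Blocked: `openai-gpt-4o` is not in the role's model list, so the proxy rejects
# the call before routing it (a 403 surfaces as PermissionDeniedError in the SDK;
# other checks may return 401 / AuthenticationError instead).
try:
    client.chat.completions.create(
        model="openai-gpt-4o",
        messages=[{"role": "user", "content": "Hello"}],
    )
except openai.PermissionDeniedError as e:
    print("blocked:", e)
```

The allowed/blocked split mirrors the `role_permissions` entry above: `internal_user` is limited to `anthropic-claude`, so any other model on the proxy is rejected.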
diff --git a/docs/my-website/docs/proxy/model_access.md b/docs/my-website/docs/proxy/model_access.md
index 545d74865bbc..854baa2edbf3 100644
--- a/docs/my-website/docs/proxy/model_access.md
+++ b/docs/my-website/docs/proxy/model_access.md
@@ -344,3 +344,6 @@ curl -i http://localhost:4000/v1/chat/completions \
+
+
+## [Role Based Access Control (RBAC)](./jwt_auth_arch)
\ No newline at end of file
diff --git a/docs/my-website/docs/proxy/token_auth.md b/docs/my-website/docs/proxy/token_auth.md
index ffff2694fe5f..df57cadd3b34 100644
--- a/docs/my-website/docs/proxy/token_auth.md
+++ b/docs/my-website/docs/proxy/token_auth.md
@@ -1,7 +1,7 @@
 import Tabs from '@theme/Tabs';
 import TabItem from '@theme/TabItem';
 
-# JWT-based Auth
+# SSO - JWT-based Auth
 
 Use JWT's to auth admins / projects into the proxy.
 
@@ -183,6 +183,24 @@ Expected Scope in JWT:
 }
 ```
 
+### Control Model Access
+
+```yaml
+general_settings:
+  enable_jwt_auth: True
+  litellm_jwtauth:
+    user_roles_jwt_field: "resource_access.litellm-test-client-id.roles"
+    user_allowed_roles: ["basic_user"] # roles that map to the 'internal_user' role on LiteLLM
+    enforce_rbac: true # if true, checks that the user has the correct role to access the model + endpoint
+
+  role_permissions: # control what models + endpoints are allowed for each role
+    - role: internal_user
+      models: ["anthropic-claude"]
+```
+
+
+**[Architecture Diagram (Control Model Access)](./jwt_auth_arch)**
+
 ## Advanced - Allowed Routes
 
 Configure which routes a JWT can access via the config.
diff --git a/docs/my-website/img/control_model_access_jwt.png b/docs/my-website/img/control_model_access_jwt.png
new file mode 100644
index 000000000000..ab6cda53961f
Binary files /dev/null and b/docs/my-website/img/control_model_access_jwt.png differ
diff --git a/docs/my-website/sidebars.js b/docs/my-website/sidebars.js
index 29d674f3f450..d20f2a73e4da 100644
--- a/docs/my-website/sidebars.js
+++ b/docs/my-website/sidebars.js
@@ -51,7 +51,7 @@ const sidebars = {
     {
       type: "category",
       label: "Architecture",
-      items: ["proxy/architecture", "proxy/db_info", "router_architecture", "proxy/user_management_heirarchy"],
+      items: ["proxy/architecture", "proxy/db_info", "router_architecture", "proxy/user_management_heirarchy", "proxy/jwt_auth_arch"],
     },
     {
       type: "link",
diff --git a/litellm/litellm_core_utils/dot_notation_indexing.py b/litellm/litellm_core_utils/dot_notation_indexing.py
new file mode 100644
index 000000000000..fda37f65007d
--- /dev/null
+++ b/litellm/litellm_core_utils/dot_notation_indexing.py
@@ -0,0 +1,59 @@
+"""
+This file contains the logic for dot notation indexing.
+
+Used by JWT Auth to get the user role from the token.
+"""
+
+from typing import Any, Dict, Optional, TypeVar
+
+T = TypeVar("T")
+
+
+def get_nested_value(
+    data: Dict[str, Any], key_path: str, default: Optional[T] = None
+) -> Optional[T]:
+    """
+    Retrieves a value from a nested dictionary using dot notation.
+ + Args: + data: The dictionary to search in + key_path: The path to the value using dot notation (e.g., "a.b.c") + default: The default value to return if the path is not found + + Returns: + The value at the specified path, or the default value if not found + + Example: + >>> data = {"a": {"b": {"c": "value"}}} + >>> get_nested_value(data, "a.b.c") + 'value' + >>> get_nested_value(data, "a.b.d", "default") + 'default' + """ + if not key_path: + return default + + # Remove metadata. prefix if it exists + key_path = ( + key_path.replace("metadata.", "", 1) + if key_path.startswith("metadata.") + else key_path + ) + + # Split the key path into parts + parts = key_path.split(".") + + # Traverse through the dictionary + current: Any = data + for part in parts: + try: + current = current[part] + except (KeyError, TypeError): + return default + + # If default is None, we can return any type + if default is None: + return current + + # Otherwise, ensure the type matches the default + return current if isinstance(current, type(default)) else default diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml index 4182ba86feff..423032ac86fd 100644 --- a/litellm/proxy/_new_secret_config.yaml +++ b/litellm/proxy/_new_secret_config.yaml @@ -1,18 +1,16 @@ model_list: - - model_name: gpt-3.5-turbo-end-user-test - litellm_params: - model: gpt-3.5-turbo - region_name: "eu" - model_info: - id: "1" - model_name: gpt-3.5-turbo litellm_params: model: gpt-3.5-turbo - timeout: 2 - num_retries: 0 - model_name: anthropic-claude litellm_params: - model: anthropic.claude-3-sonnet-20240229-v1:0 - -litellm_settings: - callbacks: ["langsmith"] \ No newline at end of file + model: claude-3-5-haiku-20241022 + - model_name: groq/* + litellm_params: + model: groq/* + api_key: os.environ/GROQ_API_KEY + mock_response: Hi! + - model_name: deepseek/* + litellm_params: + model: deepseek/* + api_key: os.environ/DEEPSEEK_API_KEY diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 5a456aec9715..bf3f6b6543a5 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -445,6 +445,8 @@ class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase): user_id_jwt_field: Optional[str] = None user_email_jwt_field: Optional[str] = None user_allowed_email_domain: Optional[str] = None + user_roles_jwt_field: Optional[str] = None + user_allowed_roles: Optional[List[str]] = None user_id_upsert: bool = Field( default=False, description="If user doesn't exist, upsert them into the db." ) @@ -458,11 +460,19 @@ def __init__(self, **kwargs: Any) -> None: allowed_keys = self.__annotations__.keys() invalid_keys = set(kwargs.keys()) - allowed_keys + user_roles_jwt_field = kwargs.get("user_roles_jwt_field") + user_allowed_roles = kwargs.get("user_allowed_roles") if invalid_keys: raise ValueError( f"Invalid arguments provided: {', '.join(invalid_keys)}. Allowed arguments are: {', '.join(allowed_keys)}." ) + if (user_roles_jwt_field is not None and user_allowed_roles is None) or ( + user_roles_jwt_field is None and user_allowed_roles is not None + ): + raise ValueError( + "user_allowed_roles must be provided if user_roles_jwt_field is set." 
+ ) super().__init__(**kwargs) @@ -2335,3 +2345,15 @@ class ClientSideFallbackModel(TypedDict, total=False): ALL_FALLBACK_MODEL_VALUES = Union[str, ClientSideFallbackModel] + + +RBAC_ROLES = Literal[ + LitellmUserRoles.PROXY_ADMIN, + LitellmUserRoles.TEAM, + LitellmUserRoles.INTERNAL_USER, +] + + +class RoleBasedPermissions(TypedDict): + role: Required[RBAC_ROLES] + models: Required[List[str]] diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py index d6bbf760bde9..8d0132709c18 100644 --- a/litellm/proxy/auth/auth_checks.py +++ b/litellm/proxy/auth/auth_checks.py @@ -12,7 +12,7 @@ import re import time import traceback -from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, cast from fastapi import status from pydantic import BaseModel @@ -24,6 +24,7 @@ from litellm.litellm_core_utils.get_llm_provider_logic import get_llm_provider from litellm.proxy._types import ( DB_CONNECTION_ERROR_TYPES, + RBAC_ROLES, CallInfo, LiteLLM_EndUserTable, LiteLLM_JWTAuth, @@ -35,6 +36,7 @@ LitellmUserRoles, ProxyErrorTypes, ProxyException, + RoleBasedPermissions, UserAPIKeyAuth, ) from litellm.proxy.auth.route_checks import RouteChecks @@ -100,6 +102,14 @@ async def common_checks( llm_router=llm_router, ) + ## 2.1 If user can call model (if personal key) + if team_object is None and user_object is not None: + await can_user_call_model( + model=_model, + llm_router=llm_router, + user_object=user_object, + ) + # 3. If team is in budget await _team_max_budget_check( team_object=team_object, @@ -391,6 +401,30 @@ def _update_last_db_access_time( last_db_access_time[key] = (value, time.time()) +def get_role_based_models( + rbac_role: RBAC_ROLES, + general_settings: dict, +) -> Optional[List[str]]: + """ + Get the models allowed for a user role. + + Used by JWT Auth. 
+ """ + + role_based_permissions = cast( + Optional[List[RoleBasedPermissions]], + general_settings.get("role_permissions", []), + ) + if role_based_permissions is None: + return None + + for role_based_permission in role_based_permissions: + if role_based_permission["role"] == rbac_role: + return role_based_permission["models"] + + return None + + @log_db_metrics async def get_user_object( user_id: str, @@ -836,11 +870,10 @@ async def get_org_object( ) -async def can_key_call_model( +async def _can_object_call_model( model: str, - llm_model_list: Optional[list], - valid_token: UserAPIKeyAuth, - llm_router: Optional[litellm.Router], + llm_router: Optional[Router], + models: List[str], ) -> Literal[True]: """ Checks if token can call a given model @@ -855,9 +888,6 @@ async def can_key_call_model( model = litellm.model_alias_map[model] ## check if model in allowed model names - verbose_proxy_logger.debug( - f"LLM Model List pre access group check: {llm_model_list}" - ) from collections import defaultdict access_groups: Dict[str, List[str]] = defaultdict(list) @@ -868,13 +898,13 @@ async def can_key_call_model( len(access_groups) > 0 and llm_router is not None ): # check if token contains any model access groups for idx, m in enumerate( - valid_token.models + models ): # loop token models, if any of them are an access group add the access group if m in access_groups: return True # Filter out models that are access_groups - filtered_models = [m for m in valid_token.models if m not in access_groups] + filtered_models = [m for m in models if m not in access_groups] verbose_proxy_logger.debug(f"model: {model}; allowed_models: {filtered_models}") @@ -885,25 +915,61 @@ async def can_key_call_model( all_model_access: bool = False - if ( - len(filtered_models) == 0 and len(valid_token.models) == 0 - ) or "*" in filtered_models: + if (len(filtered_models) == 0 and len(models) == 0) or "*" in filtered_models: all_model_access = True if model is not None and model not in filtered_models and all_model_access is False: raise ProxyException( - message=f"API Key not allowed to access model. This token can only access models={valid_token.models}. Tried to access {model}", + message=f"API Key not allowed to access model. This token can only access models={models}. 
Tried to access {model}", type=ProxyErrorTypes.key_model_access_denied, param="model", code=status.HTTP_401_UNAUTHORIZED, ) - valid_token.models = filtered_models + verbose_proxy_logger.debug( - f"filtered allowed_models: {filtered_models}; valid_token.models: {valid_token.models}" + f"filtered allowed_models: {filtered_models}; models: {models}" ) return True +async def can_key_call_model( + model: str, + llm_model_list: Optional[list], + valid_token: UserAPIKeyAuth, + llm_router: Optional[litellm.Router], +) -> Literal[True]: + """ + Checks if token can call a given model + + Returns: + - True: if token allowed to call model + + Raises: + - Exception: If token not allowed to call model + """ + return await _can_object_call_model( + model=model, + llm_router=llm_router, + models=valid_token.models, + ) + + +async def can_user_call_model( + model: str, + llm_router: Optional[Router], + user_object: Optional[LiteLLM_UserTable], +) -> Literal[True]: + + if user_object is None: + return True + + return await _can_object_call_model( + model=model, + llm_router=llm_router, + models=user_object.models, + ) + + async def is_valid_fallback_model( model: str, llm_router: Optional[Router], @@ -1161,7 +1227,11 @@ def _model_custom_llm_provider_matches_wildcard_pattern( - `model=claude-3-5-sonnet-20240620` - `allowed_model_pattern=anthropic/*` """ - model, custom_llm_provider, _, _ = get_llm_provider(model=model) + try: + model, custom_llm_provider, _, _ = get_llm_provider(model=model) + except Exception: + return False + return is_model_allowed_by_pattern( model=f"{custom_llm_provider}/{model}", allowed_model_pattern=allowed_model_pattern, diff --git a/litellm/proxy/auth/handle_jwt.py b/litellm/proxy/auth/handle_jwt.py index cf5701154676..bcda413b68be 100644 --- a/litellm/proxy/auth/handle_jwt.py +++ b/litellm/proxy/auth/handle_jwt.py @@ -16,8 +16,10 @@ from litellm._logging import verbose_proxy_logger from litellm.caching.caching import DualCache +from litellm.litellm_core_utils.dot_notation_indexing import get_nested_value from litellm.llms.custom_httpx.httpx_handler import HTTPHandler from litellm.proxy._types import ( + RBAC_ROLES, JWKKeyValue, JWTKeyItem, LiteLLM_JWTAuth, @@ -59,7 +61,7 @@ def is_jwt(self, token: str): parts = token.split(".") return len(parts) == 3 - def get_rbac_role(self, token: dict) -> Optional[LitellmUserRoles]: + def get_rbac_role(self, token: dict) -> Optional[RBAC_ROLES]: """ Returns the RBAC role the token 'belongs' to. @@ -78,12 +80,18 @@ def get_rbac_role(self, token: dict) -> Optional[LitellmUserRoles]: """ scopes = self.get_scopes(token=token) is_admin = self.is_admin(scopes=scopes) + user_roles = self.get_user_roles(token=token, default_value=None) + if is_admin: return LitellmUserRoles.PROXY_ADMIN elif self.get_team_id(token=token, default_value=None) is not None: return LitellmUserRoles.TEAM elif self.get_user_id(token=token, default_value=None) is not None: return LitellmUserRoles.INTERNAL_USER + elif user_roles is not None and self.is_allowed_user_role( + user_roles=user_roles + ): + return LitellmUserRoles.INTERNAL_USER return None @@ -166,6 +174,43 @@ def get_user_id(self, token: dict, default_value: Optional[str]) -> Optional[str user_id = default_value return user_id + def get_user_roles( + self, token: dict, default_value: Optional[List[str]] + ) -> Optional[List[str]]: + """ + Returns the user role from the token. + + Set via 'user_roles_jwt_field' in the config. 
+ """ + try: + if self.litellm_jwtauth.user_roles_jwt_field is not None: + user_roles = get_nested_value( + data=token, + key_path=self.litellm_jwtauth.user_roles_jwt_field, + default=default_value, + ) + else: + user_roles = default_value + except KeyError: + user_roles = default_value + return user_roles + + def is_allowed_user_role(self, user_roles: Optional[List[str]]) -> bool: + """ + Returns the user role from the token. + + Set via 'user_allowed_roles' in the config. + """ + if ( + user_roles is not None + and self.litellm_jwtauth.user_allowed_roles is not None + and any( + role in self.litellm_jwtauth.user_allowed_roles for role in user_roles + ) + ): + return True + return False + def get_user_email( self, token: dict, default_value: Optional[str] ) -> Optional[str]: diff --git a/litellm/proxy/auth/user_api_key_auth.py b/litellm/proxy/auth/user_api_key_auth.py index 33247308f647..7d499af5b266 100644 --- a/litellm/proxy/auth/user_api_key_auth.py +++ b/litellm/proxy/auth/user_api_key_auth.py @@ -33,6 +33,7 @@ get_end_user_object, get_key_object, get_org_object, + get_role_based_models, get_team_object, get_user_object, is_valid_fallback_model, @@ -281,9 +282,34 @@ def get_rbac_role(jwt_handler: JWTHandler, scopes: List[str]) -> str: return LitellmUserRoles.TEAM +def can_rbac_role_call_model( + rbac_role: RBAC_ROLES, + general_settings: dict, + model: Optional[str], +) -> Literal[True]: + """ + Checks if user is allowed to access the model, based on their role. + """ + role_based_models = get_role_based_models( + rbac_role=rbac_role, general_settings=general_settings + ) + if role_based_models is None or model is None: + return True + + if model not in role_based_models: + raise HTTPException( + status_code=403, + detail=f"User role={rbac_role} not allowed to call model={model}. Allowed models={role_based_models}", + ) + + return True + + async def _jwt_auth_user_api_key_auth_builder( api_key: str, jwt_handler: JWTHandler, + request_data: dict, + general_settings: dict, route: str, prisma_client: Optional[PrismaClient], user_api_key_cache: DualCache, @@ -295,14 +321,20 @@ async def _jwt_auth_user_api_key_auth_builder( jwt_valid_token: dict = await jwt_handler.auth_jwt(token=api_key) # check if unmatched token and enforce_rbac is true - if ( - jwt_handler.litellm_jwtauth.enforce_rbac is True - and jwt_handler.get_rbac_role(token=jwt_valid_token) is None - ): - raise HTTPException( - status_code=403, - detail="Unmatched token passed in. enforce_rbac is set to True. Token must belong to a proxy admin, team, or user. See how to set roles in config here: https://docs.litellm.ai/docs/proxy/token_auth#advanced---spend-tracking-end-users--internal-users--team--org", - ) + if jwt_handler.litellm_jwtauth.enforce_rbac is True: + rbac_role = jwt_handler.get_rbac_role(token=jwt_valid_token) + if rbac_role is None: + raise HTTPException( + status_code=403, + detail="Unmatched token passed in. enforce_rbac is set to True. Token must belong to a proxy admin, team, or user. 
See how to set roles in config here: https://docs.litellm.ai/docs/proxy/token_auth#advanced---spend-tracking-end-users--internal-users--team--org", + ) + else: + # run rbac validation checks + can_rbac_role_call_model( + rbac_role=rbac_role, + general_settings=general_settings, + model=request_data.get("model"), + ) # get scopes scopes = jwt_handler.get_scopes(token=jwt_valid_token) @@ -431,18 +463,18 @@ async def _jwt_auth_user_api_key_auth_builder( proxy_logging_obj=proxy_logging_obj, ) - return { - "is_proxy_admin": False, - "team_id": team_id, - "team_object": team_object, - "user_id": user_id, - "user_object": user_object, - "org_id": org_id, - "org_object": org_object, - "end_user_id": end_user_id, - "end_user_object": end_user_object, - "token": api_key, - } + return JWTAuthBuilderResult( + is_proxy_admin=False, + team_id=team_id, + team_object=team_object, + user_id=user_id, + user_object=user_object, + org_id=org_id, + org_object=org_object, + end_user_id=end_user_id, + end_user_object=end_user_object, + token=api_key, + ) async def _user_api_key_auth_builder( # noqa: PLR0915 @@ -581,6 +613,8 @@ async def _user_api_key_auth_builder( # noqa: PLR0915 verbose_proxy_logger.debug("is_jwt: %s", is_jwt) if is_jwt: result = await _jwt_auth_user_api_key_auth_builder( + request_data=request_data, + general_settings=general_settings, api_key=api_key, jwt_handler=jwt_handler, route=route, diff --git a/litellm/utils.py b/litellm/utils.py index 976c8e2e4acf..92d6dc37dbe4 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -474,6 +474,11 @@ def function_setup( # noqa: PLR0915 if inspect.iscoroutinefunction(callback): litellm._async_failure_callback.append(callback) removed_async_items.append(index) + elif ( + callback in litellm._known_custom_logger_compatible_callbacks + and isinstance(callback, str) + ): + _add_custom_logger_callback_to_specific_event(callback, "failure") # Pop the async items from failure_callback in reverse order to avoid index issues for index in reversed(removed_async_items): @@ -1385,30 +1390,33 @@ def _select_tokenizer( @lru_cache(maxsize=128) def _select_tokenizer_helper(model: str): - if model in litellm.cohere_models and "command-r" in model: - # cohere - cohere_tokenizer = Tokenizer.from_pretrained( - "Xenova/c4ai-command-r-v01-tokenizer" - ) - return {"type": "huggingface_tokenizer", "tokenizer": cohere_tokenizer} - # anthropic - elif model in litellm.anthropic_models and "claude-3" not in model: - claude_tokenizer = Tokenizer.from_str(claude_json_str) - return {"type": "huggingface_tokenizer", "tokenizer": claude_tokenizer} - # llama2 - elif "llama-2" in model.lower() or "replicate" in model.lower(): - tokenizer = Tokenizer.from_pretrained("hf-internal-testing/llama-tokenizer") - return {"type": "huggingface_tokenizer", "tokenizer": tokenizer} - # llama3 - elif "llama-3" in model.lower(): - tokenizer = Tokenizer.from_pretrained("Xenova/llama-3-tokenizer") - return {"type": "huggingface_tokenizer", "tokenizer": tokenizer} + try: + if model in litellm.cohere_models and "command-r" in model: + # cohere + cohere_tokenizer = Tokenizer.from_pretrained( + "Xenova/c4ai-command-r-v01-tokenizer" + ) + return {"type": "huggingface_tokenizer", "tokenizer": cohere_tokenizer} + # anthropic + elif model in litellm.anthropic_models and "claude-3" not in model: + claude_tokenizer = Tokenizer.from_str(claude_json_str) + return {"type": "huggingface_tokenizer", "tokenizer": claude_tokenizer} + # llama2 + elif "llama-2" in model.lower() or "replicate" in model.lower(): + tokenizer 
= Tokenizer.from_pretrained("hf-internal-testing/llama-tokenizer") + return {"type": "huggingface_tokenizer", "tokenizer": tokenizer} + # llama3 + elif "llama-3" in model.lower(): + tokenizer = Tokenizer.from_pretrained("Xenova/llama-3-tokenizer") + return {"type": "huggingface_tokenizer", "tokenizer": tokenizer} + except Exception as e: + verbose_logger.debug(f"Error selecting tokenizer: {e}") + # default - tiktoken - else: - return { - "type": "openai_tokenizer", - "tokenizer": encoding, - } # default to openai tokenizer + return { + "type": "openai_tokenizer", + "tokenizer": encoding, + } # default to openai tokenizer def encode(model="", text="", custom_tokenizer: Optional[dict] = None): diff --git a/tests/local_testing/test_token_counter.py b/tests/local_testing/test_token_counter.py index ef9cc9194594..e1e2c36e9fd2 100644 --- a/tests/local_testing/test_token_counter.py +++ b/tests/local_testing/test_token_counter.py @@ -382,3 +382,80 @@ def test_img_url_token_counter(img_url): def test_token_encode_disallowed_special(): encode(model="gpt-3.5-turbo", text="Hello, world! <|endoftext|>") + + +import unittest +from unittest.mock import patch, MagicMock +from litellm.utils import encoding, _select_tokenizer_helper, claude_json_str + + +class TestTokenizerSelection(unittest.TestCase): + @patch("litellm.utils.Tokenizer.from_pretrained") + def test_llama3_tokenizer_api_failure(self, mock_from_pretrained): + # Setup mock to raise an error + mock_from_pretrained.side_effect = Exception("Failed to load tokenizer") + + # Test with llama-3 model + result = _select_tokenizer_helper("llama-3-7b") + + # Verify the attempt to load Llama-3 tokenizer + mock_from_pretrained.assert_called_once_with("Xenova/llama-3-tokenizer") + + # Verify fallback to OpenAI tokenizer + self.assertEqual(result["type"], "openai_tokenizer") + self.assertEqual(result["tokenizer"], encoding) + + @patch("litellm.utils.Tokenizer.from_pretrained") + def test_cohere_tokenizer_api_failure(self, mock_from_pretrained): + # Setup mock to raise an error + mock_from_pretrained.side_effect = Exception("Failed to load tokenizer") + + # Add Cohere model to the list for testing + litellm.cohere_models = ["command-r-v1"] + + # Test with Cohere model + result = _select_tokenizer_helper("command-r-v1") + + # Verify the attempt to load Cohere tokenizer + mock_from_pretrained.assert_called_once_with( + "Xenova/c4ai-command-r-v01-tokenizer" + ) + + # Verify fallback to OpenAI tokenizer + self.assertEqual(result["type"], "openai_tokenizer") + self.assertEqual(result["tokenizer"], encoding) + + @patch("litellm.utils.Tokenizer.from_str") + def test_claude_tokenizer_api_failure(self, mock_from_str): + # Setup mock to raise an error + mock_from_str.side_effect = Exception("Failed to load tokenizer") + + # Add Claude model to the list for testing + litellm.anthropic_models = ["claude-2"] + + # Test with Claude model + result = _select_tokenizer_helper("claude-2") + + # Verify the attempt to load Claude tokenizer + mock_from_str.assert_called_once_with(claude_json_str) + + # Verify fallback to OpenAI tokenizer + self.assertEqual(result["type"], "openai_tokenizer") + self.assertEqual(result["tokenizer"], encoding) + + @patch("litellm.utils.Tokenizer.from_pretrained") + def test_llama2_tokenizer_api_failure(self, mock_from_pretrained): + # Setup mock to raise an error + mock_from_pretrained.side_effect = Exception("Failed to load tokenizer") + + # Test with Llama-2 model + result = _select_tokenizer_helper("llama-2-7b") + + # Verify the attempt to load 
Llama-2 tokenizer + mock_from_pretrained.assert_called_once_with( + "hf-internal-testing/llama-tokenizer" + ) + + # Verify fallback to OpenAI tokenizer + self.assertEqual(result["type"], "openai_tokenizer") + self.assertEqual(result["tokenizer"], encoding) diff --git a/tests/local_testing/test_utils.py b/tests/local_testing/test_utils.py index c651d84fa33a..866577c69aad 100644 --- a/tests/local_testing/test_utils.py +++ b/tests/local_testing/test_utils.py @@ -1529,6 +1529,34 @@ def test_add_custom_logger_callback_to_specific_event_e2e(monkeypatch): assert len(litellm.failure_callback) == curr_len_failure_callback +def test_add_custom_logger_callback_to_specific_event_e2e_failure(monkeypatch): + from litellm.integrations.openmeter import OpenMeterLogger + + monkeypatch.setattr(litellm, "success_callback", []) + monkeypatch.setattr(litellm, "failure_callback", []) + monkeypatch.setattr(litellm, "callbacks", []) + monkeypatch.setenv("OPENMETER_API_KEY", "wedlwe") + monkeypatch.setenv("OPENMETER_API_URL", "https://openmeter.dev") + + litellm.failure_callback = ["openmeter"] + + curr_len_success_callback = len(litellm.success_callback) + curr_len_failure_callback = len(litellm.failure_callback) + + litellm.completion( + model="gpt-4o-mini", + messages=[{"role": "user", "content": "Hello, world!"}], + mock_response="Testing langfuse", + ) + + assert len(litellm.success_callback) == curr_len_success_callback + assert len(litellm.failure_callback) == curr_len_failure_callback + + assert any( + isinstance(callback, OpenMeterLogger) for callback in litellm.failure_callback + ) + + @pytest.mark.asyncio async def test_wrapper_kwargs_passthrough(): from litellm.utils import client diff --git a/tests/proxy_unit_tests/test_auth_checks.py b/tests/proxy_unit_tests/test_auth_checks.py index 85b5b216a564..04af3d6e299a 100644 --- a/tests/proxy_unit_tests/test_auth_checks.py +++ b/tests/proxy_unit_tests/test_auth_checks.py @@ -508,3 +508,43 @@ async def budget_alerts(self, type, user_info): assert ( alert_triggered == expect_alert ), f"Expected alert_triggered to be {expect_alert} for spend={spend}, soft_budget={soft_budget}" + + +@pytest.mark.asyncio +async def test_can_user_call_model(): + from litellm.proxy.auth.auth_checks import can_user_call_model + from litellm.proxy._types import ProxyException + from litellm import Router + + router = Router( + model_list=[ + { + "model_name": "anthropic-claude", + "litellm_params": {"model": "anthropic/anthropic-claude"}, + }, + { + "model_name": "gpt-3.5-turbo", + "litellm_params": {"model": "gpt-3.5-turbo", "api_key": "test-api-key"}, + }, + ] + ) + + args = { + "model": "anthropic-claude", + "llm_router": router, + "user_object": LiteLLM_UserTable( + user_id="testuser21@mycompany.com", + max_budget=None, + spend=0.0042295, + model_max_budget={}, + model_spend={}, + user_email="testuser@mycompany.com", + models=["gpt-3.5-turbo"], + ), + } + + with pytest.raises(ProxyException) as e: + await can_user_call_model(**args) + + args["model"] = "gpt-3.5-turbo" + await can_user_call_model(**args) diff --git a/tests/proxy_unit_tests/test_user_api_key_auth.py b/tests/proxy_unit_tests/test_user_api_key_auth.py index a428a29c6341..3e9ba17889e4 100644 --- a/tests/proxy_unit_tests/test_user_api_key_auth.py +++ b/tests/proxy_unit_tests/test_user_api_key_auth.py @@ -855,6 +855,8 @@ async def test_jwt_user_api_key_auth_builder_enforce_rbac(enforce_rbac, monkeypa "user_api_key_cache": Mock(), "parent_otel_span": None, "proxy_logging_obj": Mock(), + "request_data": {}, + 
"general_settings": {}, } if enforce_rbac: @@ -877,3 +879,55 @@ def test_user_api_key_auth_end_user_str(): user_api_key_auth = UserAPIKeyAuth(**user_api_key_args) assert user_api_key_auth.end_user_id == "1" + + +def test_can_rbac_role_call_model(): + from litellm.proxy.auth.user_api_key_auth import can_rbac_role_call_model + from litellm.proxy._types import RoleBasedPermissions + + roles_based_permissions = [ + RoleBasedPermissions( + role=LitellmUserRoles.INTERNAL_USER, + models=["gpt-4"], + ), + RoleBasedPermissions( + role=LitellmUserRoles.PROXY_ADMIN, + models=["anthropic-claude"], + ), + ] + + assert can_rbac_role_call_model( + rbac_role=LitellmUserRoles.INTERNAL_USER, + general_settings={"role_permissions": roles_based_permissions}, + model="gpt-4", + ) + + with pytest.raises(HTTPException): + can_rbac_role_call_model( + rbac_role=LitellmUserRoles.INTERNAL_USER, + general_settings={"role_permissions": roles_based_permissions}, + model="gpt-4o", + ) + + with pytest.raises(HTTPException): + can_rbac_role_call_model( + rbac_role=LitellmUserRoles.PROXY_ADMIN, + general_settings={"role_permissions": roles_based_permissions}, + model="gpt-4o", + ) + + +def test_can_rbac_role_call_model_no_role_permissions(): + from litellm.proxy.auth.user_api_key_auth import can_rbac_role_call_model + + assert can_rbac_role_call_model( + rbac_role=LitellmUserRoles.INTERNAL_USER, + general_settings={}, + model="gpt-4", + ) + + assert can_rbac_role_call_model( + rbac_role=LitellmUserRoles.PROXY_ADMIN, + general_settings={"role_permissions": []}, + model="anthropic-claude", + ) diff --git a/tests/test_users.py b/tests/test_users.py index 7e267ac4df14..812783681c29 100644 --- a/tests/test_users.py +++ b/tests/test_users.py @@ -7,13 +7,17 @@ from openai import AsyncOpenAI from test_team import list_teams from typing import Optional +from test_keys import generate_key +from fastapi import HTTPException -async def new_user(session, i, user_id=None, budget=None, budget_duration=None): +async def new_user( + session, i, user_id=None, budget=None, budget_duration=None, models=None +): url = "http://0.0.0.0:4000/user/new" headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"} data = { - "models": ["azure-models"], + "models": models or ["azure-models"], "aliases": {"mistral-7b": "gpt-3.5-turbo"}, "duration": None, "max_budget": budget, @@ -37,6 +41,51 @@ async def new_user(session, i, user_id=None, budget=None, budget_duration=None): return await response.json() +async def generate_key( + session, + i, + budget=None, + budget_duration=None, + models=["azure-models", "gpt-4", "dall-e-3"], + max_parallel_requests: Optional[int] = None, + user_id: Optional[str] = None, + team_id: Optional[str] = None, + metadata: Optional[dict] = None, + calling_key="sk-1234", +): + url = "http://0.0.0.0:4000/key/generate" + headers = { + "Authorization": f"Bearer {calling_key}", + "Content-Type": "application/json", + } + data = { + "models": models, + "aliases": {"mistral-7b": "gpt-3.5-turbo"}, + "duration": None, + "max_budget": budget, + "budget_duration": budget_duration, + "max_parallel_requests": max_parallel_requests, + "user_id": user_id, + "team_id": team_id, + "metadata": metadata, + } + + print(f"data: {data}") + + async with session.post(url, headers=headers, json=data) as response: + status = response.status + response_text = await response.text() + + print(f"Response {i} (Status code: {status}):") + print(response_text) + print() + + if status != 200: + raise Exception(f"Request {i} did not 
return a 200 status code: {status}") + + return await response.json() + + @pytest.mark.asyncio async def test_user_new(): """ @@ -210,3 +259,59 @@ async def test_global_proxy_budget_update(): new_new_spend = user_info["user_info"]["spend"] print(f"new_spend: {new_spend}; original_spend: {original_spend}") assert new_new_spend > new_spend + + +@pytest.mark.asyncio +async def test_user_model_access(): + """ + - Create user with model access + - Create key with user + - Call model that user has access to -> should work + - Call wildcard model that user has access to -> should work + - Call model that user does not have access to -> should fail + - Call wildcard model that user does not have access to -> should fail + """ + import openai + + async with aiohttp.ClientSession() as session: + get_user = f"krrish_{time.time()}@berri.ai" + await new_user( + session=session, + i=0, + user_id=get_user, + models=["good-model", "anthropic/*"], + ) + + result = await generate_key( + session=session, + i=0, + user_id=get_user, + models=[], # assign no models. Allow inheritance from user + ) + key = result["key"] + + await chat_completion( + session=session, + key=key, + model="anthropic/claude-3-5-haiku-20241022", + ) + + await chat_completion( + session=session, + key=key, + model="good-model", + ) + + with pytest.raises(openai.AuthenticationError): + await chat_completion( + session=session, + key=key, + model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0", + ) + + with pytest.raises(openai.AuthenticationError): + await chat_completion( + session=session, + key=key, + model="groq/claude-3-5-haiku-20241022", + )
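As a quick sanity check on the new dot-notation helper introduced in this diff, here is a minimal sketch (not part of the test suite above) that exercises `get_nested_value` with a Keycloak-shaped payload, mirroring the `user_roles_jwt_field` path used in the docs; the token claims below are made up.

```python
from litellm.litellm_core_utils.dot_notation_indexing import get_nested_value

# A Keycloak-style decoded JWT payload (illustrative values only).
token = {
    "sub": "1234567890",
    "resource_access": {
        "litellm-test-client-id": {"roles": ["basic_user"]},
    },
}

# Same path as `user_roles_jwt_field` in the example config.
roles = get_nested_value(token, "resource_access.litellm-test-client-id.roles")
assert roles == ["basic_user"]

# A missing path falls back to the provided default instead of raising.
assert get_nested_value(token, "resource_access.other-client.roles", []) == []
```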