From 24f88d4b02c7841fcef913f72196117248d6e2bd Mon Sep 17 00:00:00 2001 From: Ruslan Kuprieiev Date: Sun, 12 Jan 2025 14:28:11 +0200 Subject: [PATCH] feat: support pydantic v2 --- projects/fal/src/fal/toolkit/types.py | 129 +++++++++++++++++--------- 1 file changed, 85 insertions(+), 44 deletions(-) diff --git a/projects/fal/src/fal/toolkit/types.py b/projects/fal/src/fal/toolkit/types.py index e6226849..7cb89998 100644 --- a/projects/fal/src/fal/toolkit/types.py +++ b/projects/fal/src/fal/toolkit/types.py @@ -1,69 +1,110 @@ import re from typing import Any, Dict, Union +import pydantic from pydantic.utils import update_not_none -from pydantic.validators import str_validator + +# https://github.com/pydantic/pydantic/pull/2573 +if not hasattr(pydantic, "__version__") or pydantic.__version__.startswith("1."): + IS_PYDANTIC_V2 = False +else: + IS_PYDANTIC_V2 = True MAX_DATA_URI_LENGTH = 10 * 1024 * 1024 MAX_HTTPS_URL_LENGTH = 2048 +HTTP_URL_REGEX = ( + r"^https:\/\/(?:[a-zA-Z0-9-]+\.)+[a-zA-Z]{2,}(?::\d{1,5})?(?:\/[^\s]*)?$" +) + class DataUri(str): - @classmethod - def __get_validators__(cls): - yield cls.validate + if IS_PYDANTIC_V2: - @classmethod - def validate(cls, value: Any) -> "DataUri": - value = str_validator(value) - value = value.strip() + @classmethod + def __get_pydantic_core_schema__(cls, source_type: Any, handler) -> Any: + from pydantic_core.core_schema import str_schema - if not value.startswith("data:"): - raise ValueError("Data URI must start with 'data:'") + return str_schema(pattern="^data:", max_length=MAX_DATA_URI_LENGTH) - if len(value) > MAX_DATA_URI_LENGTH: - raise ValueError( - f"Data URI is too long. Max length is {MAX_DATA_URI_LENGTH} bytes." - ) + def __get_pydantic_json_schema__(cls, core_schema, handler) -> Dict[str, Any]: + json_schema = handler(core_schema) + json_schema.update(format="data-uri") + return json_schema + else: + + @classmethod + def __get_validators__(cls): + yield cls.validate + + @classmethod + def validate(cls, value: Any) -> "DataUri": + from pydantic.validators import str_validator - return cls(value) + value = str_validator(value) + value = value.strip() - @classmethod - def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: - update_not_none(field_schema, format="data-uri") + if not value.startswith("data:"): + raise ValueError("Data URI must start with 'data:'") + + if len(value) > MAX_DATA_URI_LENGTH: + raise ValueError( + f"Data URI is too long. Max length is {MAX_DATA_URI_LENGTH} bytes." + ) + + return cls(value) + + @classmethod + def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: + update_not_none(field_schema, format="data-uri") class HttpsUrl(str): - @classmethod - def __get_validators__(cls): - yield cls.validate - - @classmethod - def validate(cls, value: Any) -> "HttpsUrl": - value = str_validator(value) - value = value.strip() - - # Regular expression for validating HTTPS URL format - https_url_regex = ( - r"^https:\/\/(?:[a-zA-Z0-9-]+\.)+[a-zA-Z]{2,}(?::\d{1,5})?(?:\/[^\s]*)?$" - ) - - if not re.match(https_url_regex, value): - raise ValueError( - "URL must start with 'https://' and follow the correct format." - ) + if IS_PYDANTIC_V2: - if len(value) > MAX_HTTPS_URL_LENGTH: - raise ValueError( - f"HTTPS URL is too long. Max length is " - f"{MAX_HTTPS_URL_LENGTH} characters." + @classmethod + def __get_pydantic_core_schema__(cls, source_type: Any, handler) -> Any: + from pydantic_core.core_schema import str_schema + + return str_schema( + pattern=HTTP_URL_REGEX, + max_length=MAX_HTTPS_URL_LENGTH, ) - return cls(value) + def __get_pydantic_json_schema__(cls, core_schema, handler) -> Dict[str, Any]: + json_schema = handler(core_schema) + json_schema.update(format="https-url") + return json_schema + + else: + + @classmethod + def __get_validators__(cls): + yield cls.validate + + @classmethod + def validate(cls, value: Any) -> "HttpsUrl": + from pydantic.validators import str_validator + + value = str_validator(value) + value = value.strip() + + if not re.match(HTTP_URL_REGEX, value): + raise ValueError( + "URL must start with 'https://' and follow the correct format." + ) + + if len(value) > MAX_HTTPS_URL_LENGTH: + raise ValueError( + f"HTTPS URL is too long. Max length is " + f"{MAX_HTTPS_URL_LENGTH} characters." + ) + + return cls(value) - @classmethod - def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: - update_not_none(field_schema, format="https-url") + @classmethod + def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: + update_not_none(field_schema, format="https-url") FileInput = Union[HttpsUrl, DataUri]