diff --git a/spectree/response.py b/spectree/response.py index 15716b2a..48a04289 100644 --- a/spectree/response.py +++ b/spectree/response.py @@ -4,7 +4,7 @@ from pydantic import BaseModel from ._types import ModelType, NamingStrategy, OptionalModelType -from .utils import gen_list_model, get_model_key, parse_code +from .utils import gen_list_model, get_model_key, has_examples, parse_code # according to https://tools.ietf.org/html/rfc2616#section-10 # https://tools.ietf.org/html/rfc7231#section-6.1 @@ -160,4 +160,10 @@ def generate_spec( }, } + schema_extra = getattr(model.__config__, "schema_extra", None) + if schema_extra and has_examples(schema_extra): + responses[parse_code(code)]["content"]["application/json"][ + "examples" + ] = schema_extra + return responses diff --git a/spectree/spec.py b/spectree/spec.py index 66836bbc..54722521 100644 --- a/spectree/spec.py +++ b/spectree/spec.py @@ -233,6 +233,7 @@ async def async_validate(*args: Any, **kwargs: Any): if model is not None: model_key = self._add_model(model=model) setattr(validation, name, model_key) + validation.json_model = model if resp: # Make sure that the endpoint specific status code and data model for @@ -321,7 +322,8 @@ def _generate_spec(self) -> Dict[str, Any]: if deprecated: routes[path][method.lower()]["deprecated"] = deprecated - request_body = parse_request(func) + json_model = getattr(func, "json_model", None) + request_body = parse_request(func, json_model) if request_body: routes[path][method.lower()]["requestBody"] = request_body diff --git a/spectree/utils.py b/spectree/utils.py index 52af7c74..71223462 100644 --- a/spectree/utils.py +++ b/spectree/utils.py @@ -76,7 +76,14 @@ def parse_comments(func: Callable[..., Any]) -> Tuple[Optional[str], Optional[st return summary, description -def parse_request(func: Any) -> Dict[str, Any]: +def has_examples(schema_exta: dict) -> bool: + for _, v in schema_exta.items(): + if isinstance(v, dict) and "value" in v.keys(): + return True + return False + + +def parse_request(func: Any, model: Optional[Any] = None) -> Dict[str, Any]: """ get json spec """ @@ -86,6 +93,11 @@ def parse_request(func: Any) -> Dict[str, Any]: "schema": {"$ref": f"#/components/schemas/{func.json}"} } + if model: + schema_extra = getattr(model.__config__, "schema_extra", None) + if schema_extra and has_examples(schema_extra): + content_items["application/json"]["examples"] = schema_extra + if hasattr(func, "form"): content_items["multipart/form-data"] = { "schema": {"$ref": f"#/components/schemas/{func.form}"} diff --git a/tests/common.py b/tests/common.py index 7af87452..87c0cc8b 100644 --- a/tests/common.py +++ b/tests/common.py @@ -70,6 +70,11 @@ class DemoModel(BaseModel): name: str = Field(..., description="user name") +class DemoModelWithSchemaExtra(BaseModel): + class Config: + schema_extra = {"example1": {"value": {"key1": "value1", "key2": "value2"}}} + + class DemoQuery(BaseModel): names1: List[str] = Field(...) names2: List[str] = Field(..., style="matrix", explode=True, non_keyword="dummy") diff --git a/tests/test_response.py b/tests/test_response.py index 103ee3f0..df50dc90 100644 --- a/tests/test_response.py +++ b/tests/test_response.py @@ -7,7 +7,7 @@ from spectree.response import DEFAULT_CODE_DESC, Response from spectree.utils import gen_list_model -from .common import JSON, DemoModel, get_model_path_key +from .common import JSON, DemoModel, DemoModelWithSchemaExtra, get_model_path_key class NormalClass: @@ -107,6 +107,34 @@ def test_response_spec(): assert spec.get(404) is None +def test_response_spec_with_schema_extra(): + resp = Response( + "HTTP_200", + HTTP_201=DemoModelWithSchemaExtra, + HTTP_401=(DemoModelWithSchemaExtra, "custom code description"), + HTTP_402=(None, "custom code description"), + ) + resp.add_model(422, ValidationError) + spec = resp.generate_spec() + assert spec["200"]["description"] == DEFAULT_CODE_DESC["HTTP_200"] + assert spec["201"]["description"] == DEFAULT_CODE_DESC["HTTP_201"] + assert spec["422"]["description"] == DEFAULT_CODE_DESC["HTTP_422"] + assert spec["401"]["description"] == "custom code description" + assert spec["402"]["description"] == "custom code description" + assert spec["201"]["content"]["application/json"]["schema"]["$ref"].split("/")[ + -1 + ] == get_model_path_key("tests.common.DemoModelWithSchemaExtra") + assert spec["201"]["content"]["application/json"]["examples"] == { + "example1": {"value": {"key1": "value1", "key2": "value2"}} + } + assert spec["422"]["content"]["application/json"]["schema"]["$ref"].split("/")[ + -1 + ] == get_model_path_key("spectree.models.ValidationError") + + assert spec.get(200) is None + assert spec.get(404) is None + + def test_list_model(): resp = Response(HTTP_200=List[JSON]) model = resp.find_model(200) diff --git a/tests/test_utils.py b/tests/test_utils.py index 802c0215..d3f4b3cb 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -12,7 +12,7 @@ parse_resp, ) -from .common import DemoModel, DemoQuery, get_model_path_key +from .common import DemoModel, DemoModelWithSchemaExtra, DemoQuery, get_model_path_key api = SpecTree() @@ -31,6 +31,14 @@ def demo_func(): description""" +@api.validate(json=DemoModelWithSchemaExtra, resp=Response(HTTP_200=DemoModel)) +def demo_with_schema_extra_func(): + """ + summary + + description""" + + @api.validate(query=DemoQuery) def demo_func_with_query(): """ @@ -230,6 +238,18 @@ def test_parse_request(): assert parse_request(demo_class.demo_method) == {} +def test_parse_request_with_schema_extra(): + model_path_key = get_model_path_key("tests.common.DemoModelWithSchemaExtra") + + assert parse_request( + demo_with_schema_extra_func, demo_with_schema_extra_func.json_model + )["content"]["application/json"] == { + "schema": {"$ref": f"#/components/schemas/{model_path_key}"}, + "examples": {"example1": {"value": {"key1": "value1", "key2": "value2"}}}, + } + assert parse_request(demo_class.demo_method) == {} + + def test_parse_params(): models = { get_model_path_key("tests.common.DemoModel"): DemoModel.schema(