diff --git a/deepaas/__init__.py b/deepaas/__init__.py index 07f79744..c45bf43d 100644 --- a/deepaas/__init__.py +++ b/deepaas/__init__.py @@ -18,7 +18,7 @@ import importlib.metadata from pathlib import Path -__version__ = "2.4.0" +__version__ = "3.0.0" def extract_version() -> str: diff --git a/deepaas/api/__init__.py b/deepaas/api/__init__.py index be07064f..9b020c04 100644 --- a/deepaas/api/__init__.py +++ b/deepaas/api/__init__.py @@ -14,38 +14,39 @@ # License for the specific language governing permissions and limitations # under the License. -import pathlib +import json -from aiohttp import web -import aiohttp_apispec +import fastapi +import fastapi.responses from oslo_config import cfg import deepaas from deepaas.api import v2 -from deepaas.api import versions +from deepaas.api.v2 import responses from deepaas import log from deepaas import model LOG = log.getLogger(__name__) APP = None +VERSIONS = {} CONF = cfg.CONF LINKS = """ -- [Project website](https://deep-hybrid.datacloud.eu). -- [Project documentation](https://docs.deep-hybrid.datacloud.eu). -- [Model marketplace](https://marketplace.deep-hybrid.datacloud.eu). +- [AI4EOSC Project website](https://ai4eosc.eu). +- [Project documentation](https://docs.ai4eosc.eu). +- [API documentation](https://docs.ai4os.eu/deepaas). +- [AI4EOSC Model marketplace](https://dashboard.cloud.ai4eosc.eu/marketplace). """ API_DESCRIPTION = ( "" "\n\nThis is a REST API that is focused on providing access " - "to machine learning models. By using the DEEPaaS API " - "users can easily run a REST API in front of their model, " - "thus accessing its functionality via HTTP calls. " + "to machine learning models. " "\n\nCurrently you are browsing the " "[Swagger UI](https://swagger.io/tools/swagger-ui/) " "for this API, a tool that allows you to visualize and interact with the " @@ -53,75 +54,83 @@ ) + LINKS -async def get_app( - swagger=True, - enable_doc=True, - doc="/api", - prefix="", - static_path="/static/swagger", - base_path="", - enable_train=True, - enable_predict=True, -): - """Get the main app.""" +def get_fastapi_app( + enable_doc: bool = True, + enable_train: bool = True, # FIXME(aloga): not handled yet + enable_predict: bool = True, + base_path: str = "", +) -> fastapi.FastAPI: + """Get the main app, based on FastAPI.""" global APP + global VERSIONS if APP: return APP - APP = web.Application(debug=CONF.debug, client_max_size=CONF.client_max_size) + APP = fastapi.FastAPI( + title="Model serving API endpoint", + description=API_DESCRIPTION, + version=deepaas.extract_version(), + docs_url=f"{base_path}/docs" if enable_doc else None, # NOTE(aloga): changed + redoc_url=f"{base_path}/redoc" if enable_doc else None, # NOTE(aloga): new + openapi_url=f"{base_path}/openapi.json", # NOTE(aloga): changed + ) - APP.middlewares.append(web.normalize_path_middleware()) + model.load_v2_model() + LOG.info("Serving loaded V2 model: %s", model.V2_MODEL_NAME) - model.register_v2_models(APP) + if CONF.warm: + LOG.debug("Warming models...") + model.V2_MODEL.warm() + + v2app = v2.get_app( + # FIXME(aloga): these have no effect now, remove. + enable_train=enable_train, + enable_predict=enable_predict, + ) + + APP.include_router(v2app, prefix=f"{base_path}/v2", tags=["v2"]) + VERSIONS["v2"] = v2.get_v2_version + + APP.add_api_route( + f"{base_path}/", + get_root, + methods=["GET"], + summary="Get API version information", + tags=["version"], + response_model=responses.VersionsAndLinks, + ) - v2app = v2.get_app(enable_train=enable_train, enable_predict=enable_predict) - if base_path: - path = str(pathlib.Path(base_path) / "v2") - else: - path = "/v2" - APP.add_subapp(path, v2app) - versions.register_version("stable", v2.get_version) + return APP - if base_path: - # Get versions.routes, and transform them to have the base_path, as we cannot - # directly modify the routes already created and stored in the RouteTableDef - for route in versions.routes: - APP.router.add_route( - route.method, str(pathlib.Path(base_path + route.path)), route.handler - ) - else: - APP.add_routes(versions.routes) - LOG.info("Serving loaded V2 models: %s", list(model.V2_MODELS.keys())) +async def get_root(request: fastapi.Request) -> fastapi.responses.JSONResponse: + versions = [] + for _ver, info in VERSIONS.items(): + resp = await info(request) + versions.append(json.loads(resp.body)) - if CONF.warm: - for _, m in model.V2_MODELS.items(): - LOG.debug("Warming models...") - await m.warm() - - if swagger: - doc = str(pathlib.Path(base_path + doc)) - swagger = str(pathlib.Path(base_path + "/swagger.json")) - static_path = str(pathlib.Path(base_path + static_path)) - - # init docs with all parameters, usual for ApiSpec - aiohttp_apispec.setup_aiohttp_apispec( - app=APP, - title="DEEP as a Service API endpoint", - info={ - "description": API_DESCRIPTION, - }, - externalDocs={ - "description": "API documentation", - "url": "https://deepaas.readthedocs.org/", - }, - version=deepaas.extract_version(), - url=swagger, - swagger_path=doc if enable_doc else None, - prefix=prefix, - static_path=static_path, - in_place=True, - ) + root = str(request.url_for("get_root")) - return APP + response = {"versions": versions, "links": []} + + doc = APP.docs_url.strip("/") + if doc: + doc = {"rel": "help", "type": "text/html", "href": f"{root}{doc}"} + response["links"].append(doc) + + redoc = APP.redoc_url.strip("/") + if redoc: + redoc = {"rel": "help", "type": "text/html", "href": f"{root}{redoc}"} + response["links"].append(redoc) + + spec = APP.openapi_url.strip("/") + if spec: + spec = { + "rel": "describedby", + "type": "application/json", + "href": f"{root}{spec}", + } + response["links"].append(spec) + + return fastapi.responses.JSONResponse(content=response) diff --git a/deepaas/api/v2/__init__.py b/deepaas/api/v2/__init__.py index 8064681c..40b80e58 100644 --- a/deepaas/api/v2/__init__.py +++ b/deepaas/api/v2/__init__.py @@ -14,50 +14,57 @@ # License for the specific language governing permissions and limitations # under the License. -from aiohttp import web -import aiohttp_apispec +import fastapi +import fastapi.responses from oslo_config import cfg from deepaas.api.v2 import debug as v2_debug from deepaas.api.v2 import models as v2_model from deepaas.api.v2 import predict as v2_predict from deepaas.api.v2 import responses -from deepaas.api.v2 import train as v2_train + +# from deepaas.api.v2 import train as v2_train from deepaas import log CONF = cfg.CONF LOG = log.getLogger("deepaas.api.v2") +# XXX APP = None def get_app(enable_train=True, enable_predict=True): global APP - APP = web.Application() + # FIXME(aloga): check we cat get rid of global variables + APP = fastapi.APIRouter() v2_debug.setup_debug() - APP.router.add_get("/", get_version, name="v2", allow_head=False) - v2_debug.setup_routes(APP) - v2_model.setup_routes(APP) - v2_train.setup_routes(APP, enable=enable_train) - v2_predict.setup_routes(APP, enable=enable_predict) + APP.include_router(v2_debug.get_router(), tags=["debug"]) + APP.include_router(v2_model.get_router(), tags=["models"]) + if enable_predict: + APP.include_router(v2_predict.get_router(), tags=["predict"]) + + # APP.router.add_get("/", get_version, name="v2", allow_head=False) + # v2_debug.setup_routes(APP) + # v2_model.setup_routes(APP) + # v2_train.setup_routes(APP, enable=enable_train) + # v2_predict.setup_routes(APP, enable=enable_predict) + + APP.add_api_route( + "/", + get_v2_version, + methods=["GET"], + tags=["version"], + response_model=responses.Versions, + ) return APP -@aiohttp_apispec.docs( - tags=["versions"], - summary="Get V2 API version information", -) -@aiohttp_apispec.response_schema(responses.Version(), 200) -@aiohttp_apispec.response_schema(responses.Failure(), 400) -async def get_version(request): - # NOTE(aloga): we use the router table from this application (i.e. the - # global APP in this module) to be able to build the correct url, as it can - # be prefixed outside of this module (in an add_subapp() call) - root = APP.router["v2"].url_for() +async def get_v2_version(request: fastapi.Request) -> fastapi.responses.JSONResponse: + root = str(request.url_for("get_v2_version")) version = { "version": "stable", "id": "v2", @@ -65,9 +72,8 @@ async def get_version(request): { "rel": "self", "type": "application/json", - "href": "%s" % root, + "href": f"{root}", } ], } - - return web.json_response(version) + return fastapi.responses.JSONResponse(content=version) diff --git a/deepaas/api/v2/debug.py b/deepaas/api/v2/debug.py index e380c04d..e5e37d56 100644 --- a/deepaas/api/v2/debug.py +++ b/deepaas/api/v2/debug.py @@ -20,16 +20,15 @@ import sys import warnings -from aiohttp import web -import aiohttp_apispec +import fastapi from oslo_config import cfg from deepaas import log CONF = cfg.CONF -app = web.Application() -routes = web.RouteTableDef() +router = fastapi.APIRouter(prefix="/debug") + # Ugly global variable to provide a string stream to read the DEBUG output # if it is enabled @@ -52,6 +51,9 @@ def close(self): for f in self.handles: f.close() + def isatty(self): + return all(f.isatty() for f in self.handles) + def setup_debug(): global DEBUG_STREAM @@ -59,7 +61,7 @@ def setup_debug(): if CONF.debug_endpoint: DEBUG_STREAM = io.StringIO() - logger = log.getLogger("deepaas").logger + logger = log.getLogger("deepaas") hdlr = logging.StreamHandler(DEBUG_STREAM) hdlr.setFormatter( logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") @@ -78,23 +80,33 @@ def setup_debug(): sys.stderr = MultiOut(DEBUG_STREAM, sys.stderr) -@aiohttp_apispec.docs( +@router.get( + "/", + summary="Return debug information if enabled by API.", + description="Return debug information if enabled by API.", tags=["debug"], - summary="""Return debug information if enabled by API.""", - description="""Return debug information if enabled by API.""", - produces=["text/plain"], + response_class=fastapi.responses.PlainTextResponse, responses={ - 200: {"description": "Debug information if debug endpoint is enabled"}, - 204: {"description": "Debug endpoint not enabled"}, + "200": { + "content": {"text/plain": {}}, + "description": "Debug information if debug endpoint is enabled", + }, + "204": {"description": "Debug endpoint not enabled"}, }, ) -async def get(request): +async def get(): if DEBUG_STREAM is not None: print("--- DEBUG MARKER %s ---" % datetime.datetime.now()) resp = DEBUG_STREAM.getvalue() - return web.Response(text=resp) - return web.HTTPNoContent() + return fastapi.responses.PlainTextResponse(resp) + else: + return fastapi.responses.Response(status_code=204) + +def get_router() -> fastapi.APIRouter: + """Auxiliary function to get the router. -def setup_routes(app): - app.router.add_get("/debug/", get, allow_head=False) + We use this function to be able to include the router in the main + application and do things before it gets included. + """ + return router diff --git a/deepaas/api/v2/models.py b/deepaas/api/v2/models.py index 406f2642..93b0b8ae 100644 --- a/deepaas/api/v2/models.py +++ b/deepaas/api/v2/models.py @@ -14,52 +14,55 @@ # License for the specific language governing permissions and limitations # under the License. -import urllib.parse - -from aiohttp import web -import aiohttp_apispec +import fastapi from deepaas.api.v2 import responses from deepaas import model -@aiohttp_apispec.docs( - tags=["models"], +router = fastapi.APIRouter(prefix="/models") + + +@router.get( + "/", summary="Return loaded models and its information", - description="DEEPaaS can load several models and server them on the same " - "endpoint, making a call to the root of the models namespace " - "will return the loaded models, as long as their basic " - "metadata.", + description="Return list of DEEPaaS loaded models. In previous versions, DEEPaaS " + "could load several models and serve them on the same endpoint.", + tags=["models"], + response_model=responses.ModelList, ) -@aiohttp_apispec.response_schema(responses.ModelMeta(), 200) -async def index(request): - """Return loaded models and its information. +async def index_models( + request: fastapi.Request, +): + """Return loaded models and its information.""" + + name = model.V2_MODEL_NAME + model_obj = model.V2_MODEL + m = { + "id": name, + "name": name, + "links": [ + { + "rel": "self", + "href": str(request.url_for("get_model/" + name)), + } + ], + } + meta = model_obj.get_metadata() + m.update(meta) + return {"models": [m]} + + +def _get_handler_for_model(model_name, model_obj): + """Auxiliary function to get the handler for a model. - DEEPaaS can load several models and server them on the same endpoint, - making a call to the root of the models namespace will return the - loaded models, as long as their basic metadata. + This function returns a handler for a model that can be used to + register the routes in the router. """ - models = [] - for name, obj in model.V2_MODELS.items(): - m = { - "id": name, - "name": name, - "links": [ - { - "rel": "self", - "href": urllib.parse.urljoin("%s/" % request.path, name), - } - ], - } - meta = obj.get_metadata() - m.update(meta) - models.append(m) - return web.json_response({"models": models}) - - -def _get_handler(model_name, model_obj): class Handler(object): + """Class to handle the model metadata endpoints.""" + model_name = None model_obj = None @@ -67,36 +70,50 @@ def __init__(self, model_name, model_obj): self.model_name = model_name self.model_obj = model_obj - @aiohttp_apispec.docs( - tags=["models"], - summary="Return model's metadata", - ) - @aiohttp_apispec.response_schema(responses.ModelMeta(), 200) - async def get(self, request): + async def get(self, request: fastapi.Request): + """Return model's metadata.""" m = { "id": self.model_name, "name": self.model_name, "links": [ { "rel": "self", - "href": request.path.rstrip("/"), + "href": str(request.url), } ], } meta = self.model_obj.get_metadata() m.update(meta) - return web.json_response(m) + return m + + def register_routes(self, router): + """Register routes for the model in the router.""" + router.add_api_route( + f"/{self.model_name}", + self.get, + name="get_model/" + self.model_name, + summary="Return model's metadata", + tags=["models"], + response_model=responses.ModelMeta, + ) return Handler(model_name, model_obj) -def setup_routes(app): - app.router.add_get("/models/", index, allow_head=False) +def get_router() -> fastapi.APIRouter: + """Auxiliary function to get the router. + + We use this function to be able to include the router in the main + application and do things before it gets included. + + In this case we explicitly include the model's endpoints. + + """ + model_name = model.V2_MODEL_NAME + model_obj = model.V2_MODEL + + hdlr = _get_handler_for_model(model_name, model_obj) + hdlr.register_routes(router) - # In the next lines we iterate over the loaded models and create the - # different resources for each model. This way we can also load the - # expected parameters if needed (as in the training method). - for model_name, model_obj in model.V2_MODELS.items(): - hdlr = _get_handler(model_name, model_obj) - app.router.add_get("/models/%s/" % model_name, hdlr.get, allow_head=False) + return router diff --git a/deepaas/api/v2/predict.py b/deepaas/api/v2/predict.py index 8141469e..e6af37ae 100644 --- a/deepaas/api/v2/predict.py +++ b/deepaas/api/v2/predict.py @@ -14,10 +14,9 @@ # License for the specific language governing permissions and limitations # under the License. -from aiohttp import web -import aiohttp_apispec -from webargs import aiohttpparser -import webargs.core +import fastapi +import fastapi.encoders +import fastapi.exceptions from deepaas.api.v2 import responses from deepaas.api.v2 import utils @@ -33,20 +32,26 @@ def _get_model_response(model_name, model_obj): return responses.Prediction -def _get_handler(model_name, model_obj): - aux = model_obj.get_predict_args() - accept = aux.get("accept", None) - if accept: - accept.validate.choices.append("*/*") - accept.load_default = accept.validate.choices[0] - accept.location = "headers" +router = fastapi.APIRouter(prefix="/models") - handler_args = webargs.core.dict2schema(aux) - handler_args.opts.ordered = True - response = _get_model_response(model_name, model_obj) +def _get_handler_for_model(model_name, model_obj): + """Auxiliary function to get the handler for a model. + + This function returns a handler for a model that can be used to + register the routes in the router. + + """ + + user_declared_args = model_obj.get_predict_args() + pydantic_schema = utils.get_pydantic_schema_from_marshmallow_fields( + "PydanticSchema", + user_declared_args, + ) class Handler(object): + """Class to handle the model metadata endpoints.""" + model_name = None model_obj = None @@ -54,47 +59,54 @@ def __init__(self, model_name, model_obj): self.model_name = model_name self.model_obj = model_obj - @aiohttp_apispec.docs( - tags=["models"], - summary="Make a prediction given the input data", - produces=accept.validate.choices if accept else None, - ) - @aiohttp_apispec.querystring_schema(handler_args) - @aiohttp_apispec.response_schema(response(), 200) - @aiohttp_apispec.response_schema(responses.Failure(), 400) - async def post(self, request): - args = await aiohttpparser.parser.parse(handler_args, request) - task = self.model_obj.predict(**args) - await task - - ret = task.result()["output"] + async def predict(self, args: pydantic_schema = fastapi.Depends()): + """Make a prediction given the input data.""" + dict_args = args.model_dump(by_alias=True) + + ret = await self.model_obj.predict(**args.model_dump(by_alias=True)) if isinstance(ret, model.v2.wrapper.ReturnedFile): ret = open(ret.filename, "rb") - accept = args.get("accept", "application/json") - if accept not in ["application/json", "*/*"]: - response = web.Response( - body=ret, - content_type=accept, - ) - return response if self.model_obj.has_schema: - self.model_obj.validate_response(ret) - return web.json_response(ret) + # FIXME(aloga): Validation does not work, as we are converting from + # Marshmallow to Pydantic, check this as son as possible. + # self.model_obj.validate_response(ret) + return fastapi.responses.JSONResponse(ret) + + return fastapi.responses.JSONResponse( + content={"status": "OK", "predictions": ret} + ) + + def register_routes(self, router): + """Register the routes in the router.""" - return web.json_response({"status": "OK", "predictions": ret}) + response = _get_model_response(self.model_name, self.model_obj) + + router.add_api_route( + f"/{self.model_name}/predict", + self.predict, + methods=["POST"], + response_model=response, + tags=["models", "predict"], + ) return Handler(model_name, model_obj) -def setup_routes(app, enable=True): - # In the next lines we iterate over the loaded models and create the - # different resources for each model. This way we can also load the - # expected parameters if needed (as in the training method). - for model_name, model_obj in model.V2_MODELS.items(): - if enable: - hdlr = _get_handler(model_name, model_obj) - else: - hdlr = utils.NotEnabledHandler() - app.router.add_post("/models/%s/predict/" % model_name, hdlr.post) +def get_router(): + """Auxiliary function to get the router. + + We use this function to be able to include the router in the main + application and do things before it gets included. + + In this case we explicitly include the model precit endpoint. + + """ + model_name = model.V2_MODEL_NAME + model_obj = model.V2_MODEL + + hdlr = _get_handler_for_model(model_name, model_obj) + hdlr.register_routes(router) + + return router diff --git a/deepaas/api/v2/responses.py b/deepaas/api/v2/responses.py index 9bbff1de..aceb6c43 100644 --- a/deepaas/api/v2/responses.py +++ b/deepaas/api/v2/responses.py @@ -14,59 +14,92 @@ # License for the specific language governing permissions and limitations # under the License. -import marshmallow -from marshmallow import fields -from marshmallow import validate +import typing +import pydantic -class Location(marshmallow.Schema): - rel = fields.Str(required=True) - href = fields.Url(required=True) - type = fields.Str(required=True) +# class Training(marshmallow.Schema): +# uuid = fields.UUID(required=True, description="Training identifier") +# date = fields.DateTime(required=True, description="Training start time") +# status = fields.Str( +# required=True, +# description="Training status", +# enum=["running", "error", "completed", "cancelled"], +# validate=validate.OneOf(["running", "error", "completed", "cancelled"]), +# ) +# message = fields.Str(description="Optional message explaining status") -class Version(marshmallow.Schema): - version = fields.Str(required="True") - id = fields.Str(required="True") - links = fields.Nested(Location) - type = fields.Str() +# class TrainingList(marshmallow.Schema): +# trainings = fields.List(fields.Nested(Training)) -class Versions(marshmallow.Schema): - versions = fields.List(fields.Nested(Version)) +# Pydantic models for the API -class Failure(marshmallow.Schema): - message = fields.Str(required=True, description="Failure message") +class Version(pydantic.BaseModel): + version: str + id: str + type: str = "application/json" -class Prediction(marshmallow.Schema): - status = fields.String(required=True, description="Response status message") - predictions = fields.Str(required=True, description="String containing predictions") +class Versions(pydantic.BaseModel): + versions: typing.List[Version] -class ModelMeta(marshmallow.Schema): - id = fields.Str(required=True, description="Model identifier") # noqa - name = fields.Str(required=True, description="Model name") - description = fields.Str(required=True, description="Model description") - license = fields.Str(required=False, description="Model license") - author = fields.Str(required=False, description="Model author") - version = fields.Str(required=False, description="Model version") - url = fields.Str(required=False, description="Model url") - links = fields.List(fields.Nested(Location)) +class Location(pydantic.BaseModel): + rel: str + href: pydantic.AnyHttpUrl + type: str = "application/json" -class Training(marshmallow.Schema): - uuid = fields.UUID(required=True, description="Training identifier") - date = fields.DateTime(required=True, description="Training start time") - status = fields.Str( - required=True, - description="Training status", - enum=["running", "error", "completed", "cancelled"], - validate=validate.OneOf(["running", "error", "completed", "cancelled"]), + +class VersionsAndLinks(pydantic.BaseModel): + versions: typing.List[Version] + links: typing.List[Location] + + +class ModelMeta(pydantic.BaseModel): + """ "V2 model metadata. + + This class is used to represent the metadata of a model in the V2 API, as we were + doing in previous versions. + """ + + id: str = pydantic.Field(..., description="Model identifier") # noqa + name: str = pydantic.Field(..., description="Model name") + description: typing.Optional[str] = pydantic.Field( + description="Model description", default=None + ) + summary: typing.Optional[str] = pydantic.Field( + description="Model summary", default=None + ) + license: typing.Optional[str] = pydantic.Field( + description="Model license", default=None + ) + author: typing.Optional[str] = pydantic.Field( + description="Model author", default=None ) - message = fields.Str(description="Optional message explaining status") + version: typing.Optional[str] = pydantic.Field( + description="Model version", default=None + ) + url: typing.Optional[str] = pydantic.Field(description="Model url", default=None) + # Links can be alist of Locations, or an empty list + links: typing.List[Location] = pydantic.Field( + description="Model links", + ) + + +class ModelList(pydantic.BaseModel): + models: typing.List[ModelMeta] = pydantic.Field( + ..., description="List of loaded models" + ) + + +class Prediction(pydantic.BaseModel): + status: str = pydantic.Field(description="Response status message") + predictions: str = pydantic.Field(description="String containing predictions") -class TrainingList(marshmallow.Schema): - trainings = fields.List(fields.Nested(Training)) +class Failure(pydantic.BaseModel): + message: str = pydantic.Field(description="Failure message") diff --git a/deepaas/api/v2/utils.py b/deepaas/api/v2/utils.py index b66b5cc1..cb4b746a 100644 --- a/deepaas/api/v2/utils.py +++ b/deepaas/api/v2/utils.py @@ -14,7 +14,16 @@ # License for the specific language governing permissions and limitations # under the License. +import datetime +import decimal +import typing + +import fastapi from aiohttp import web +import marshmallow +import marshmallow.fields +import pydantic +import pydantic.utils class NotEnabledHandler(object): @@ -23,3 +32,204 @@ async def f(*args, **kwargs): raise web.HTTPPaymentRequired() return f + + +# Convert marshmallow fields to pydantic fields + + +CUSTOM_FIELD_DEFAULT = typing.Any + + +def get_dict_type(x): + """For dicts we need to look at the key and value type""" + key_type = get_pydantic_type(x.key_field) + if x.value_field: + value_type = get_pydantic_type(x.value_field) + return typing.Dict[key_type, value_type] + return typing.Dict[key_type, typing.Any] + + +def get_list_type(x): + """For lists we need to look at the value type""" + if x.inner: + c_type = get_pydantic_type(x.inner, optional=False) + return typing.List[c_type] + return typing.List + + +# def get_nested_model(x): +# """Return a model from a nested marshmallow schema""" +# return pydantic_from_marshmallow(x.schema) + + +FIELD_CONVERTERS = { + marshmallow.fields.Bool: bool, + marshmallow.fields.Boolean: bool, + marshmallow.fields.Date: datetime.date, + marshmallow.fields.DateTime: datetime.datetime, + marshmallow.fields.Decimal: decimal.Decimal, + marshmallow.fields.Dict: get_dict_type, + marshmallow.fields.Email: pydantic.EmailStr, + marshmallow.fields.Float: float, + marshmallow.fields.Function: typing.Callable, + marshmallow.fields.Int: int, + marshmallow.fields.Integer: int, + marshmallow.fields.List: get_list_type, + marshmallow.fields.Mapping: typing.Mapping, + marshmallow.fields.Method: typing.Callable, + # marshmallow.fields.Nested: get_nested_model, + marshmallow.fields.Number: typing.Union[pydantic.StrictFloat, pydantic.StrictInt], + marshmallow.fields.Str: str, + marshmallow.fields.String: str, + marshmallow.fields.Time: datetime.time, + marshmallow.fields.TimeDelta: datetime.timedelta, + marshmallow.fields.URL: pydantic.AnyUrl, + marshmallow.fields.Url: pydantic.AnyUrl, + marshmallow.fields.UUID: str, +} + + +def is_custom_field(field): + """If this is a subclass of marshmallow's Field and not in our list, we + assume its a custom field""" + ftype = type(field) + if issubclass(ftype, marshmallow.fields.Field) and ftype not in FIELD_CONVERTERS: + print(" Custom field") + return True + return False + + +def is_file_field(field): + """If this is a file field, we need to handle it differently.""" + if field is not None and field.metadata.get("type") == "file": + print(" File field") + return True + return False + + +def get_pydantic_type(field, optional=True): + """Get pydantic type from a marshmallow field""" + if field is None: + return typing.Any + elif is_file_field(field): + conv = fastapi.UploadFile + elif is_custom_field(field): + conv = typing.Any + else: + conv = FIELD_CONVERTERS[type(field)] + + # TODO: Is there a cleaner way to check for annotation types? + if isinstance(conv, type) or conv.__module__ == "typing": + pyd_type = conv + else: + pyd_type = conv(field) + + if optional and not field.required: + if is_file_field(field): + # If we have a file field, do not wrap with Optional, as FastAPI does not + # handle it correctly. Instead, we put None as default value later in the + # outer function. + pass + else: + pyd_type = typing.Optional[pyd_type] + + # FIXME(aloga): we need to handle enums + return pyd_type + + +def sanitize_field_name(field_name): + field_name = field_name.replace("-", "_") + field_name = field_name.replace(" ", "_") + field_name = field_name.replace(".", "_") + field_name = field_name.replace(":", "_") + field_name = field_name.replace("/", "_") + field_name = field_name.replace("\\", "_") + return field_name + + +def check_for_file_fields(fields): + for field_name, field in fields.items(): + if is_file_field(field): + return True + return False + + +def pydantic_from_marshmallow( + name: str, schema: marshmallow.Schema +) -> pydantic.BaseModel: + """Convert a marshmallow schema to a pydantic model. + + May only work for fairly simple cases. Barely tested. Enjoy. + """ + + pyd_fields = {} + have_file_fields = check_for_file_fields(schema._declared_fields) + + for field_name, field in schema._declared_fields.items(): + pyd_type = get_pydantic_type(field) + + description = field.metadata.get("description") + + if field.default: + default = field.default + elif field.missing: + default = field.missing + else: + default = None + + if is_file_field(field): + field_cls = fastapi.File + elif have_file_fields: + field_cls = fastapi.Form + else: + field_cls = pydantic.Field + + if field.required and not default: + default = field_cls( + ..., + description=description, + title=field_name, + serialization_alias=field_name, + ) + elif default is None: + if is_file_field(field): + # If we have a file field, it is not wraped with Optional, as FastAPI + # does not handle it correctly (c.f. get_pydantic_type function above). + # Instead, we put None as default value here, and FastAPI will handle it + # correctly. + default = None + else: + default = field_cls( + description=description, + title=field_name, + serialization_alias=field_name, + ) + else: + default = field_cls( + description=description, + default=default, + title=field_name, + serialization_alias=field_name, + ) + + field_name = sanitize_field_name(field_name) + + pyd_fields[field_name] = (pyd_type, default) + + ret = pydantic.create_model( + name, + **pyd_fields, + ) + return ret + + +def get_pydantic_schema_from_marshmallow_fields( + name: str, + fields: dict, +) -> pydantic.BaseModel: + + model = marshmallow.Schema.from_dict(fields) + + pydantic_model = pydantic_from_marshmallow(name, model()) + + return pydantic_model diff --git a/deepaas/api/versions.py b/deepaas/api/versions.py deleted file mode 100644 index 48df6c8d..00000000 --- a/deepaas/api/versions.py +++ /dev/null @@ -1,76 +0,0 @@ -# -*- coding: utf-8 -*- - -# Copyright 2019 Spanish National Research Council (CSIC) -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. You may obtain -# a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. - -import json - -from aiohttp import web -import aiohttp_apispec - -from deepaas.api.v2 import responses - -app = web.Application() -routes = web.RouteTableDef() - - -@routes.view("/", name="versions") -class Versions(web.View): - versions = {} - - def __init__(self, *args, **kwargs): - super(Versions, self).__init__(*args, **kwargs) - - @aiohttp_apispec.docs( - tags=["versions"], - summary="Get available API versions", - ) - @aiohttp_apispec.response_schema(responses.Versions(), 200) - @aiohttp_apispec.response_schema(responses.Failure(), 400) - async def get(self): - versions = [] - for _ver, info in self.versions.items(): - resp = await info(self.request) - versions.append(json.loads(resp.body)) - - response = {"versions": versions, "links": []} - # But here we use the global app coming in the request - doc = self.request.app.router.named_resources().get("swagger.docs") - if doc: - doc = {"rel": "help", "type": "text/html", "href": "%s" % doc.url_for()} - response["links"].append(doc) - - spec = self.request.app.router.named_resources().get("swagger.spec") - if spec: - spec = { - "rel": "describedby", - "type": "application/json", - "href": "%s" % spec.url_for(), - } - response["links"].append(spec) - - return web.json_response(response) - - -@routes.get("/ui") -async def redirect_ui(request): - doc_url = request.app.router.named_resources().get("swagger.docs").url_for() - return web.HTTPFound(doc_url) - - -def register_version(version, func): - # NOTE(aloga): we could use a @classmethod on Versions, but it fails - # with a TypeError: 'classmethod' object is not callable since the function - # is decorated. - Versions.versions[version] = func diff --git a/deepaas/cmd/run.py b/deepaas/cmd/run.py index d4e37d5a..69cc95bf 100644 --- a/deepaas/cmd/run.py +++ b/deepaas/cmd/run.py @@ -18,8 +18,8 @@ import pathlib import sys -from aiohttp import web from oslo_config import cfg +import uvicorn import deepaas from deepaas import api @@ -103,6 +103,7 @@ def main(): else: base_path = "" + # FIXME(aloga): ensure that these paths are correct base = "http://{}:{}{}".format(CONF.listen_ip, CONF.listen_port, base_path) spec = "{}/swagger.json".format(base) docs = "{}/api".format(base) @@ -111,19 +112,21 @@ def main(): print(INTRO) print(BANNER.format(docs, spec, v2)) - log.info("Starting DEEPaaS version %s", deepaas.extract_version()) + log.info( + "Starting DEEPaaS version %s with FastAPI backend", + deepaas.extract_version(), + ) - app = api.get_app( + print("FastAPI backend is still experimental.") + print("Press Ctrl+C to stop the server.") + app = api.get_fastapi_app( enable_doc=CONF.doc_endpoint, enable_train=CONF.train_endpoint, enable_predict=CONF.predict_endpoint, - base_path=CONF.base_path, - ) - web.run_app( - app, - host=CONF.listen_ip, - port=CONF.listen_port, + base_path=base_path, ) + uvicorn.run(app, host=CONF.listen_ip, port=CONF.listen_port) + log.debug("Shutting down") if __name__ == "__main__": diff --git a/deepaas/model/__init__.py b/deepaas/model/__init__.py index ffb393c3..891cbc47 100644 --- a/deepaas/model/__init__.py +++ b/deepaas/model/__init__.py @@ -16,14 +16,22 @@ from deepaas.model import v2 -V2_MODELS = v2.MODELS +# FIXME(aloga): this is extremely ugly +V2_MODEL = None +V2_MODEL_NAME = None -def register_v2_models(app): +def load_v2_model(): """Register V2 models. This method has to be called before the API is spawned, so that we can look up the correct entry points and load the defined models. """ - return v2.register_models(app) + global V2_MODEL + global V2_MODEL_NAME + + v2.load_model() + + V2_MODEL = v2.MODEL + V2_MODEL_NAME = v2.MODEL_NAME diff --git a/deepaas/model/v2/__init__.py b/deepaas/model/v2/__init__.py index e6e9b88b..22e801b7 100644 --- a/deepaas/model/v2/__init__.py +++ b/deepaas/model/v2/__init__.py @@ -25,15 +25,15 @@ CONF = config.CONF # Model registry -MODELS = {} -MODELS_LOADED = False +MODEL = None +MODEL_NAME = "" -def register_models(app): - global MODELS - global MODELS_LOADED +def load_model(): + global MODEL + global MODEL_NAME - if MODELS_LOADED: + if MODEL: return if CONF.model_name: @@ -53,16 +53,14 @@ def register_models(app): raise exceptions.MultipleModelsFound() try: - MODELS[model_name] = wrapper.ModelWrapper( + MODEL = wrapper.ModelWrapper( model_name, loading.get_model_by_name(model_name, "v2"), - app, ) + MODEL_NAME = model_name except exceptions.ModuleNotFoundError: LOG.error("Model not found: %s", model_name) raise except Exception as e: LOG.exception("Error loading model: %s", e) raise e - - MODELS_LOADED = True diff --git a/deepaas/model/v2/wrapper.py b/deepaas/model/v2/wrapper.py index 196abd8b..e6d09676 100644 --- a/deepaas/model/v2/wrapper.py +++ b/deepaas/model/v2/wrapper.py @@ -17,19 +17,16 @@ import asyncio import collections import contextlib -import datetime import functools import io -import multiprocessing -import multiprocessing.pool import os -import signal import tempfile from aiohttp import web import marshmallow from oslo_config import cfg +from deepaas.api.v2 import utils from deepaas import log LOG = log.getLogger(__name__) @@ -100,18 +97,9 @@ class ModelWrapper(object): a response schema that is not JSON schema valid (DRAFT 4) """ - def __init__(self, name, model_obj, app=None): + def __init__(self, name, model_obj): self.name = name self.model_obj = model_obj - self._app = app - - self._loop = asyncio.get_event_loop() - - self._workers = CONF.workers - self._executor = self._init_executor() - - if self._app is not None: - self._setup_cleanup() schema = getattr(self.model_obj, "schema", None) @@ -123,6 +111,7 @@ def __init__(self, name, model_obj, app=None): self.has_schema = True except Exception as e: LOG.exception(e) + # FIXME(aloga): do not use web exception here raise web.HTTPInternalServerError( reason=("Model defined schema is invalid, " "check server logs.") ) @@ -131,24 +120,21 @@ def __init__(self, name, model_obj, app=None): if issubclass(schema, marshmallow.Schema): self.has_schema = True except TypeError: + # FIXME(aloga): do not use web exception here raise web.HTTPInternalServerError( reason=("Model defined schema is invalid, " "check server logs.") ) else: self.has_schema = False - self.response_schema = schema - - def _setup_cleanup(self): - self._app.on_cleanup.append(self._close_executors) - - async def _close_executors(self, app): - self._executor.shutdown() - - def _init_executor(self): - n = self._workers - executor = CancellablePool(max_workers=n) - return executor + # Now convert to pydantic schema... + # FIXME(aloga): use try except + if schema is not None: + self.response_schema = utils.pydantic_from_marshmallow( + "ModelPredictionResponse", schema + ) + else: + self.response_schema = None @contextlib.contextmanager def _catch_error(self): @@ -269,7 +255,7 @@ def predict_wrap(predict_func, *args, **kwargs): return ret - def predict(self, *args, **kwargs): + async def predict(self, *args, **kwargs): """Perform a prediction on wrapped model's ``predict`` method. :raises HTTPNotImplemented: If the method is not @@ -295,9 +281,7 @@ def predict(self, *args, **kwargs): # FIXME(aloga); cleanup of tmpfile here with self._catch_error(): - return self._run_in_pool( - self.predict_wrap, self.model_obj.predict, *args, **kwargs - ) + return self.predict_wrap(self.model_obj.predict, *args, **kwargs) def train(self, *args, **kwargs): """Perform a training on wrapped model's ``train`` method. @@ -311,7 +295,7 @@ def train(self, *args, **kwargs): """ with self._catch_error(): - return self._run_in_pool(self.model_obj.train, *args, **kwargs) + return self.model_obj.train(*args, **kwargs) def get_train_args(self): """Add training arguments into the training parser. @@ -338,86 +322,3 @@ def get_predict_args(self): except (NotImplementedError, AttributeError): args = {} return args - - -class NonDaemonProcess(multiprocessing.context.SpawnProcess): - """Processes must use 'spawn' instead of 'fork' (which is the default - in Linux) in order to work CUDA [1] or Tensorflow [2]. - - [1] https://pytorch.org/docs/stable/notes/multiprocessing.html - #cuda-in-multiprocessing - [2] https://github.com/tensorflow/tensorflow/issues/5448 - #issuecomment-258934405 - """ - - @property - def daemon(self): - return False - - @daemon.setter - def daemon(self, value): - pass - - -class NonDaemonPool(multiprocessing.pool.Pool): - # Based on https://stackoverflow.com/questions/6974695/ - def Process(self, *args, **kwds): # noqa - proc = super(NonDaemonPool, self).Process(*args, **kwds) - proc.__class__ = NonDaemonProcess - - return proc - - -class CancellablePool(object): - def __init__(self, max_workers=None): - self._free = {self._new_pool() for _ in range(max_workers)} - self._working = set() - self._change = asyncio.Event() - - def _new_pool(self): - return NonDaemonPool(1, context=multiprocessing.get_context("spawn")) - - async def apply(self, fn, *args): - """ - Like multiprocessing.Pool.apply_async, but: - * is an asyncio coroutine - * terminates the process if cancelled - """ - while not self._free: - await self._change.wait() - self._change.clear() - pool = usable_pool = self._free.pop() - self._working.add(pool) - - loop = asyncio.get_event_loop() - fut = loop.create_future() - - def _on_done(obj): - ret = {"output": obj, "finish_date": str(datetime.datetime.now())} - loop.call_soon_threadsafe(fut.set_result, ret) - - def _on_err(err): - loop.call_soon_threadsafe(fut.set_exception, err) - - pool.apply_async(fn, args, callback=_on_done, error_callback=_on_err) - - try: - return await fut - except asyncio.CancelledError: - # This is ugly, but since our pools only have one slot we can - # kill the process before termination - try: - pool._pool[0].kill() - except AttributeError: - os.kill(pool._pool[0].pid, signal.SIGKILL) - pool.terminate() - usable_pool = self._new_pool() - finally: - self._working.remove(pool) - self._free.add(usable_pool) - self._change.set() - - def shutdown(self): - for p in self._working: - p.terminate() - self._free.clear() diff --git a/deepaas/tests/out_test/out_file.json b/deepaas/tests/out_test/out_file.json new file mode 100644 index 00000000..ac3114f0 --- /dev/null +++ b/deepaas/tests/out_test/out_file.json @@ -0,0 +1 @@ +[{'value1': {'pred': 1}, 'value2': {'pred': 0.9}}] diff --git a/deepaas/tests/out_test/out_xxxxxxxxxx.json b/deepaas/tests/out_test/out_xxxxxxxxxx.json new file mode 100644 index 00000000..ac3114f0 --- /dev/null +++ b/deepaas/tests/out_test/out_xxxxxxxxxx.json @@ -0,0 +1 @@ +[{'value1': {'pred': 1}, 'value2': {'pred': 0.9}}] diff --git a/deepaas/tests/out_test/tmp_dir.zip b/deepaas/tests/out_test/tmp_dir.zip new file mode 100644 index 00000000..d91884e6 Binary files /dev/null and b/deepaas/tests/out_test/tmp_dir.zip differ diff --git a/pyproject.toml b/pyproject.toml index 1c8c0249..2d41d0b8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "deepaas" -version = "2.4.0" +version = "3.0.0" description = "DEEPaaS is a REST API to expose a machine learning model." authors = ["Alvaro Lopez Garcia "] license = "Apache-2" @@ -56,6 +56,7 @@ aiohttp-apispec = "^2.2.3" Werkzeug = "^3.0.3" marshmallow = "^3.21.3" webargs = "<6.0.0" +fastapi = "^0.111.0" [tool.poetry.group.dev.dependencies] @@ -94,7 +95,7 @@ mypy = "^1.10.0" [tool.poetry.group.test-pypi.dependencies] -twine = "^5.1.0" +poetry = "^1.8.3" [tool.poetry.group.test-pip-missing-reqs.dependencies] diff --git a/tox.ini b/tox.ini index 89ab2787..0036778f 100644 --- a/tox.ini +++ b/tox.ini @@ -90,14 +90,6 @@ exclude = [testenv:flake8] basepython = {[base]python} -deps = - flake8>=4.0,<4.1 - flake8-bugbear>=22.3,<22.4 - ; flake8-docstrings>=1.6,<1.7 - flake8-typing-imports>=1.12,<1.13 - flake8-colors>=0.1,<0.2 - pep8-naming>=0.12,<0.13 - pydocstyle>=6.1,<6.2 commands = poetry run flake8 {[base]package}