diff --git a/deepaas/__init__.py b/deepaas/__init__.py
index 07f79744..c45bf43d 100644
--- a/deepaas/__init__.py
+++ b/deepaas/__init__.py
@@ -18,7 +18,7 @@
import importlib.metadata
from pathlib import Path
-__version__ = "2.4.0"
+__version__ = "3.0.0"
def extract_version() -> str:
diff --git a/deepaas/api/__init__.py b/deepaas/api/__init__.py
index be07064f..9b020c04 100644
--- a/deepaas/api/__init__.py
+++ b/deepaas/api/__init__.py
@@ -14,38 +14,39 @@
# License for the specific language governing permissions and limitations
# under the License.
-import pathlib
+import json
-from aiohttp import web
-import aiohttp_apispec
+import fastapi
+import fastapi.responses
from oslo_config import cfg
import deepaas
from deepaas.api import v2
-from deepaas.api import versions
+from deepaas.api.v2 import responses
from deepaas import log
from deepaas import model
LOG = log.getLogger(__name__)
APP = None
+VERSIONS = {}
CONF = cfg.CONF
LINKS = """
-- [Project website](https://deep-hybrid.datacloud.eu).
-- [Project documentation](https://docs.deep-hybrid.datacloud.eu).
-- [Model marketplace](https://marketplace.deep-hybrid.datacloud.eu).
+- [AI4EOSC Project website](https://ai4eosc.eu).
+- [Project documentation](https://docs.ai4eosc.eu).
+- [API documentation](https://docs.ai4os.eu/deepaas).
+- [AI4EOSC Model marketplace](https://dashboard.cloud.ai4eosc.eu/marketplace).
"""
API_DESCRIPTION = (
"
"
"\n\nThis is a REST API that is focused on providing access "
- "to machine learning models. By using the DEEPaaS API "
- "users can easily run a REST API in front of their model, "
- "thus accessing its functionality via HTTP calls. "
+ "to machine learning models. "
"\n\nCurrently you are browsing the "
"[Swagger UI](https://swagger.io/tools/swagger-ui/) "
"for this API, a tool that allows you to visualize and interact with the "
@@ -53,75 +54,83 @@
) + LINKS
-async def get_app(
- swagger=True,
- enable_doc=True,
- doc="/api",
- prefix="",
- static_path="/static/swagger",
- base_path="",
- enable_train=True,
- enable_predict=True,
-):
- """Get the main app."""
+def get_fastapi_app(
+ enable_doc: bool = True,
+ enable_train: bool = True, # FIXME(aloga): not handled yet
+ enable_predict: bool = True,
+ base_path: str = "",
+) -> fastapi.FastAPI:
+ """Get the main app, based on FastAPI."""
global APP
+ global VERSIONS
if APP:
return APP
- APP = web.Application(debug=CONF.debug, client_max_size=CONF.client_max_size)
+ APP = fastapi.FastAPI(
+ title="Model serving API endpoint",
+ description=API_DESCRIPTION,
+ version=deepaas.extract_version(),
+ docs_url=f"{base_path}/docs" if enable_doc else None, # NOTE(aloga): changed
+ redoc_url=f"{base_path}/redoc" if enable_doc else None, # NOTE(aloga): new
+ openapi_url=f"{base_path}/openapi.json", # NOTE(aloga): changed
+ )
- APP.middlewares.append(web.normalize_path_middleware())
+ model.load_v2_model()
+ LOG.info("Serving loaded V2 model: %s", model.V2_MODEL_NAME)
- model.register_v2_models(APP)
+ if CONF.warm:
+ LOG.debug("Warming models...")
+ model.V2_MODEL.warm()
+
+ v2app = v2.get_app(
+ # FIXME(aloga): these have no effect now, remove.
+ enable_train=enable_train,
+ enable_predict=enable_predict,
+ )
+
+ APP.include_router(v2app, prefix=f"{base_path}/v2", tags=["v2"])
+ VERSIONS["v2"] = v2.get_v2_version
+
+ APP.add_api_route(
+ f"{base_path}/",
+ get_root,
+ methods=["GET"],
+ summary="Get API version information",
+ tags=["version"],
+ response_model=responses.VersionsAndLinks,
+ )
- v2app = v2.get_app(enable_train=enable_train, enable_predict=enable_predict)
- if base_path:
- path = str(pathlib.Path(base_path) / "v2")
- else:
- path = "/v2"
- APP.add_subapp(path, v2app)
- versions.register_version("stable", v2.get_version)
+ return APP
- if base_path:
- # Get versions.routes, and transform them to have the base_path, as we cannot
- # directly modify the routes already created and stored in the RouteTableDef
- for route in versions.routes:
- APP.router.add_route(
- route.method, str(pathlib.Path(base_path + route.path)), route.handler
- )
- else:
- APP.add_routes(versions.routes)
- LOG.info("Serving loaded V2 models: %s", list(model.V2_MODELS.keys()))
+async def get_root(request: fastapi.Request) -> fastapi.responses.JSONResponse:
+ versions = []
+ for _ver, info in VERSIONS.items():
+ resp = await info(request)
+ versions.append(json.loads(resp.body))
- if CONF.warm:
- for _, m in model.V2_MODELS.items():
- LOG.debug("Warming models...")
- await m.warm()
-
- if swagger:
- doc = str(pathlib.Path(base_path + doc))
- swagger = str(pathlib.Path(base_path + "/swagger.json"))
- static_path = str(pathlib.Path(base_path + static_path))
-
- # init docs with all parameters, usual for ApiSpec
- aiohttp_apispec.setup_aiohttp_apispec(
- app=APP,
- title="DEEP as a Service API endpoint",
- info={
- "description": API_DESCRIPTION,
- },
- externalDocs={
- "description": "API documentation",
- "url": "https://deepaas.readthedocs.org/",
- },
- version=deepaas.extract_version(),
- url=swagger,
- swagger_path=doc if enable_doc else None,
- prefix=prefix,
- static_path=static_path,
- in_place=True,
- )
+ root = str(request.url_for("get_root"))
- return APP
+ response = {"versions": versions, "links": []}
+
+ doc = APP.docs_url.strip("/")
+ if doc:
+ doc = {"rel": "help", "type": "text/html", "href": f"{root}{doc}"}
+ response["links"].append(doc)
+
+ redoc = APP.redoc_url.strip("/")
+ if redoc:
+ redoc = {"rel": "help", "type": "text/html", "href": f"{root}{redoc}"}
+ response["links"].append(redoc)
+
+ spec = APP.openapi_url.strip("/")
+ if spec:
+ spec = {
+ "rel": "describedby",
+ "type": "application/json",
+ "href": f"{root}{spec}",
+ }
+ response["links"].append(spec)
+
+ return fastapi.responses.JSONResponse(content=response)
diff --git a/deepaas/api/v2/__init__.py b/deepaas/api/v2/__init__.py
index 8064681c..40b80e58 100644
--- a/deepaas/api/v2/__init__.py
+++ b/deepaas/api/v2/__init__.py
@@ -14,50 +14,57 @@
# License for the specific language governing permissions and limitations
# under the License.
-from aiohttp import web
-import aiohttp_apispec
+import fastapi
+import fastapi.responses
from oslo_config import cfg
from deepaas.api.v2 import debug as v2_debug
from deepaas.api.v2 import models as v2_model
from deepaas.api.v2 import predict as v2_predict
from deepaas.api.v2 import responses
-from deepaas.api.v2 import train as v2_train
+
+# from deepaas.api.v2 import train as v2_train
from deepaas import log
CONF = cfg.CONF
LOG = log.getLogger("deepaas.api.v2")
+# XXX
APP = None
def get_app(enable_train=True, enable_predict=True):
global APP
- APP = web.Application()
+ # FIXME(aloga): check we cat get rid of global variables
+ APP = fastapi.APIRouter()
v2_debug.setup_debug()
- APP.router.add_get("/", get_version, name="v2", allow_head=False)
- v2_debug.setup_routes(APP)
- v2_model.setup_routes(APP)
- v2_train.setup_routes(APP, enable=enable_train)
- v2_predict.setup_routes(APP, enable=enable_predict)
+ APP.include_router(v2_debug.get_router(), tags=["debug"])
+ APP.include_router(v2_model.get_router(), tags=["models"])
+ if enable_predict:
+ APP.include_router(v2_predict.get_router(), tags=["predict"])
+
+ # APP.router.add_get("/", get_version, name="v2", allow_head=False)
+ # v2_debug.setup_routes(APP)
+ # v2_model.setup_routes(APP)
+ # v2_train.setup_routes(APP, enable=enable_train)
+ # v2_predict.setup_routes(APP, enable=enable_predict)
+
+ APP.add_api_route(
+ "/",
+ get_v2_version,
+ methods=["GET"],
+ tags=["version"],
+ response_model=responses.Versions,
+ )
return APP
-@aiohttp_apispec.docs(
- tags=["versions"],
- summary="Get V2 API version information",
-)
-@aiohttp_apispec.response_schema(responses.Version(), 200)
-@aiohttp_apispec.response_schema(responses.Failure(), 400)
-async def get_version(request):
- # NOTE(aloga): we use the router table from this application (i.e. the
- # global APP in this module) to be able to build the correct url, as it can
- # be prefixed outside of this module (in an add_subapp() call)
- root = APP.router["v2"].url_for()
+async def get_v2_version(request: fastapi.Request) -> fastapi.responses.JSONResponse:
+ root = str(request.url_for("get_v2_version"))
version = {
"version": "stable",
"id": "v2",
@@ -65,9 +72,8 @@ async def get_version(request):
{
"rel": "self",
"type": "application/json",
- "href": "%s" % root,
+ "href": f"{root}",
}
],
}
-
- return web.json_response(version)
+ return fastapi.responses.JSONResponse(content=version)
diff --git a/deepaas/api/v2/debug.py b/deepaas/api/v2/debug.py
index e380c04d..e5e37d56 100644
--- a/deepaas/api/v2/debug.py
+++ b/deepaas/api/v2/debug.py
@@ -20,16 +20,15 @@
import sys
import warnings
-from aiohttp import web
-import aiohttp_apispec
+import fastapi
from oslo_config import cfg
from deepaas import log
CONF = cfg.CONF
-app = web.Application()
-routes = web.RouteTableDef()
+router = fastapi.APIRouter(prefix="/debug")
+
# Ugly global variable to provide a string stream to read the DEBUG output
# if it is enabled
@@ -52,6 +51,9 @@ def close(self):
for f in self.handles:
f.close()
+ def isatty(self):
+ return all(f.isatty() for f in self.handles)
+
def setup_debug():
global DEBUG_STREAM
@@ -59,7 +61,7 @@ def setup_debug():
if CONF.debug_endpoint:
DEBUG_STREAM = io.StringIO()
- logger = log.getLogger("deepaas").logger
+ logger = log.getLogger("deepaas")
hdlr = logging.StreamHandler(DEBUG_STREAM)
hdlr.setFormatter(
logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
@@ -78,23 +80,33 @@ def setup_debug():
sys.stderr = MultiOut(DEBUG_STREAM, sys.stderr)
-@aiohttp_apispec.docs(
+@router.get(
+ "/",
+ summary="Return debug information if enabled by API.",
+ description="Return debug information if enabled by API.",
tags=["debug"],
- summary="""Return debug information if enabled by API.""",
- description="""Return debug information if enabled by API.""",
- produces=["text/plain"],
+ response_class=fastapi.responses.PlainTextResponse,
responses={
- 200: {"description": "Debug information if debug endpoint is enabled"},
- 204: {"description": "Debug endpoint not enabled"},
+ "200": {
+ "content": {"text/plain": {}},
+ "description": "Debug information if debug endpoint is enabled",
+ },
+ "204": {"description": "Debug endpoint not enabled"},
},
)
-async def get(request):
+async def get():
if DEBUG_STREAM is not None:
print("--- DEBUG MARKER %s ---" % datetime.datetime.now())
resp = DEBUG_STREAM.getvalue()
- return web.Response(text=resp)
- return web.HTTPNoContent()
+ return fastapi.responses.PlainTextResponse(resp)
+ else:
+ return fastapi.responses.Response(status_code=204)
+
+def get_router() -> fastapi.APIRouter:
+ """Auxiliary function to get the router.
-def setup_routes(app):
- app.router.add_get("/debug/", get, allow_head=False)
+ We use this function to be able to include the router in the main
+ application and do things before it gets included.
+ """
+ return router
diff --git a/deepaas/api/v2/models.py b/deepaas/api/v2/models.py
index 406f2642..93b0b8ae 100644
--- a/deepaas/api/v2/models.py
+++ b/deepaas/api/v2/models.py
@@ -14,52 +14,55 @@
# License for the specific language governing permissions and limitations
# under the License.
-import urllib.parse
-
-from aiohttp import web
-import aiohttp_apispec
+import fastapi
from deepaas.api.v2 import responses
from deepaas import model
-@aiohttp_apispec.docs(
- tags=["models"],
+router = fastapi.APIRouter(prefix="/models")
+
+
+@router.get(
+ "/",
summary="Return loaded models and its information",
- description="DEEPaaS can load several models and server them on the same "
- "endpoint, making a call to the root of the models namespace "
- "will return the loaded models, as long as their basic "
- "metadata.",
+ description="Return list of DEEPaaS loaded models. In previous versions, DEEPaaS "
+ "could load several models and serve them on the same endpoint.",
+ tags=["models"],
+ response_model=responses.ModelList,
)
-@aiohttp_apispec.response_schema(responses.ModelMeta(), 200)
-async def index(request):
- """Return loaded models and its information.
+async def index_models(
+ request: fastapi.Request,
+):
+ """Return loaded models and its information."""
+
+ name = model.V2_MODEL_NAME
+ model_obj = model.V2_MODEL
+ m = {
+ "id": name,
+ "name": name,
+ "links": [
+ {
+ "rel": "self",
+ "href": str(request.url_for("get_model/" + name)),
+ }
+ ],
+ }
+ meta = model_obj.get_metadata()
+ m.update(meta)
+ return {"models": [m]}
+
+
+def _get_handler_for_model(model_name, model_obj):
+ """Auxiliary function to get the handler for a model.
- DEEPaaS can load several models and server them on the same endpoint,
- making a call to the root of the models namespace will return the
- loaded models, as long as their basic metadata.
+ This function returns a handler for a model that can be used to
+ register the routes in the router.
"""
- models = []
- for name, obj in model.V2_MODELS.items():
- m = {
- "id": name,
- "name": name,
- "links": [
- {
- "rel": "self",
- "href": urllib.parse.urljoin("%s/" % request.path, name),
- }
- ],
- }
- meta = obj.get_metadata()
- m.update(meta)
- models.append(m)
- return web.json_response({"models": models})
-
-
-def _get_handler(model_name, model_obj):
class Handler(object):
+ """Class to handle the model metadata endpoints."""
+
model_name = None
model_obj = None
@@ -67,36 +70,50 @@ def __init__(self, model_name, model_obj):
self.model_name = model_name
self.model_obj = model_obj
- @aiohttp_apispec.docs(
- tags=["models"],
- summary="Return model's metadata",
- )
- @aiohttp_apispec.response_schema(responses.ModelMeta(), 200)
- async def get(self, request):
+ async def get(self, request: fastapi.Request):
+ """Return model's metadata."""
m = {
"id": self.model_name,
"name": self.model_name,
"links": [
{
"rel": "self",
- "href": request.path.rstrip("/"),
+ "href": str(request.url),
}
],
}
meta = self.model_obj.get_metadata()
m.update(meta)
- return web.json_response(m)
+ return m
+
+ def register_routes(self, router):
+ """Register routes for the model in the router."""
+ router.add_api_route(
+ f"/{self.model_name}",
+ self.get,
+ name="get_model/" + self.model_name,
+ summary="Return model's metadata",
+ tags=["models"],
+ response_model=responses.ModelMeta,
+ )
return Handler(model_name, model_obj)
-def setup_routes(app):
- app.router.add_get("/models/", index, allow_head=False)
+def get_router() -> fastapi.APIRouter:
+ """Auxiliary function to get the router.
+
+ We use this function to be able to include the router in the main
+ application and do things before it gets included.
+
+ In this case we explicitly include the model's endpoints.
+
+ """
+ model_name = model.V2_MODEL_NAME
+ model_obj = model.V2_MODEL
+
+ hdlr = _get_handler_for_model(model_name, model_obj)
+ hdlr.register_routes(router)
- # In the next lines we iterate over the loaded models and create the
- # different resources for each model. This way we can also load the
- # expected parameters if needed (as in the training method).
- for model_name, model_obj in model.V2_MODELS.items():
- hdlr = _get_handler(model_name, model_obj)
- app.router.add_get("/models/%s/" % model_name, hdlr.get, allow_head=False)
+ return router
diff --git a/deepaas/api/v2/predict.py b/deepaas/api/v2/predict.py
index 8141469e..e6af37ae 100644
--- a/deepaas/api/v2/predict.py
+++ b/deepaas/api/v2/predict.py
@@ -14,10 +14,9 @@
# License for the specific language governing permissions and limitations
# under the License.
-from aiohttp import web
-import aiohttp_apispec
-from webargs import aiohttpparser
-import webargs.core
+import fastapi
+import fastapi.encoders
+import fastapi.exceptions
from deepaas.api.v2 import responses
from deepaas.api.v2 import utils
@@ -33,20 +32,26 @@ def _get_model_response(model_name, model_obj):
return responses.Prediction
-def _get_handler(model_name, model_obj):
- aux = model_obj.get_predict_args()
- accept = aux.get("accept", None)
- if accept:
- accept.validate.choices.append("*/*")
- accept.load_default = accept.validate.choices[0]
- accept.location = "headers"
+router = fastapi.APIRouter(prefix="/models")
- handler_args = webargs.core.dict2schema(aux)
- handler_args.opts.ordered = True
- response = _get_model_response(model_name, model_obj)
+def _get_handler_for_model(model_name, model_obj):
+ """Auxiliary function to get the handler for a model.
+
+ This function returns a handler for a model that can be used to
+ register the routes in the router.
+
+ """
+
+ user_declared_args = model_obj.get_predict_args()
+ pydantic_schema = utils.get_pydantic_schema_from_marshmallow_fields(
+ "PydanticSchema",
+ user_declared_args,
+ )
class Handler(object):
+ """Class to handle the model metadata endpoints."""
+
model_name = None
model_obj = None
@@ -54,47 +59,54 @@ def __init__(self, model_name, model_obj):
self.model_name = model_name
self.model_obj = model_obj
- @aiohttp_apispec.docs(
- tags=["models"],
- summary="Make a prediction given the input data",
- produces=accept.validate.choices if accept else None,
- )
- @aiohttp_apispec.querystring_schema(handler_args)
- @aiohttp_apispec.response_schema(response(), 200)
- @aiohttp_apispec.response_schema(responses.Failure(), 400)
- async def post(self, request):
- args = await aiohttpparser.parser.parse(handler_args, request)
- task = self.model_obj.predict(**args)
- await task
-
- ret = task.result()["output"]
+ async def predict(self, args: pydantic_schema = fastapi.Depends()):
+ """Make a prediction given the input data."""
+ dict_args = args.model_dump(by_alias=True)
+
+ ret = await self.model_obj.predict(**args.model_dump(by_alias=True))
if isinstance(ret, model.v2.wrapper.ReturnedFile):
ret = open(ret.filename, "rb")
- accept = args.get("accept", "application/json")
- if accept not in ["application/json", "*/*"]:
- response = web.Response(
- body=ret,
- content_type=accept,
- )
- return response
if self.model_obj.has_schema:
- self.model_obj.validate_response(ret)
- return web.json_response(ret)
+ # FIXME(aloga): Validation does not work, as we are converting from
+ # Marshmallow to Pydantic, check this as son as possible.
+ # self.model_obj.validate_response(ret)
+ return fastapi.responses.JSONResponse(ret)
+
+ return fastapi.responses.JSONResponse(
+ content={"status": "OK", "predictions": ret}
+ )
+
+ def register_routes(self, router):
+ """Register the routes in the router."""
- return web.json_response({"status": "OK", "predictions": ret})
+ response = _get_model_response(self.model_name, self.model_obj)
+
+ router.add_api_route(
+ f"/{self.model_name}/predict",
+ self.predict,
+ methods=["POST"],
+ response_model=response,
+ tags=["models", "predict"],
+ )
return Handler(model_name, model_obj)
-def setup_routes(app, enable=True):
- # In the next lines we iterate over the loaded models and create the
- # different resources for each model. This way we can also load the
- # expected parameters if needed (as in the training method).
- for model_name, model_obj in model.V2_MODELS.items():
- if enable:
- hdlr = _get_handler(model_name, model_obj)
- else:
- hdlr = utils.NotEnabledHandler()
- app.router.add_post("/models/%s/predict/" % model_name, hdlr.post)
+def get_router():
+ """Auxiliary function to get the router.
+
+ We use this function to be able to include the router in the main
+ application and do things before it gets included.
+
+ In this case we explicitly include the model precit endpoint.
+
+ """
+ model_name = model.V2_MODEL_NAME
+ model_obj = model.V2_MODEL
+
+ hdlr = _get_handler_for_model(model_name, model_obj)
+ hdlr.register_routes(router)
+
+ return router
diff --git a/deepaas/api/v2/responses.py b/deepaas/api/v2/responses.py
index 9bbff1de..aceb6c43 100644
--- a/deepaas/api/v2/responses.py
+++ b/deepaas/api/v2/responses.py
@@ -14,59 +14,92 @@
# License for the specific language governing permissions and limitations
# under the License.
-import marshmallow
-from marshmallow import fields
-from marshmallow import validate
+import typing
+import pydantic
-class Location(marshmallow.Schema):
- rel = fields.Str(required=True)
- href = fields.Url(required=True)
- type = fields.Str(required=True)
+# class Training(marshmallow.Schema):
+# uuid = fields.UUID(required=True, description="Training identifier")
+# date = fields.DateTime(required=True, description="Training start time")
+# status = fields.Str(
+# required=True,
+# description="Training status",
+# enum=["running", "error", "completed", "cancelled"],
+# validate=validate.OneOf(["running", "error", "completed", "cancelled"]),
+# )
+# message = fields.Str(description="Optional message explaining status")
-class Version(marshmallow.Schema):
- version = fields.Str(required="True")
- id = fields.Str(required="True")
- links = fields.Nested(Location)
- type = fields.Str()
+# class TrainingList(marshmallow.Schema):
+# trainings = fields.List(fields.Nested(Training))
-class Versions(marshmallow.Schema):
- versions = fields.List(fields.Nested(Version))
+# Pydantic models for the API
-class Failure(marshmallow.Schema):
- message = fields.Str(required=True, description="Failure message")
+class Version(pydantic.BaseModel):
+ version: str
+ id: str
+ type: str = "application/json"
-class Prediction(marshmallow.Schema):
- status = fields.String(required=True, description="Response status message")
- predictions = fields.Str(required=True, description="String containing predictions")
+class Versions(pydantic.BaseModel):
+ versions: typing.List[Version]
-class ModelMeta(marshmallow.Schema):
- id = fields.Str(required=True, description="Model identifier") # noqa
- name = fields.Str(required=True, description="Model name")
- description = fields.Str(required=True, description="Model description")
- license = fields.Str(required=False, description="Model license")
- author = fields.Str(required=False, description="Model author")
- version = fields.Str(required=False, description="Model version")
- url = fields.Str(required=False, description="Model url")
- links = fields.List(fields.Nested(Location))
+class Location(pydantic.BaseModel):
+ rel: str
+ href: pydantic.AnyHttpUrl
+ type: str = "application/json"
-class Training(marshmallow.Schema):
- uuid = fields.UUID(required=True, description="Training identifier")
- date = fields.DateTime(required=True, description="Training start time")
- status = fields.Str(
- required=True,
- description="Training status",
- enum=["running", "error", "completed", "cancelled"],
- validate=validate.OneOf(["running", "error", "completed", "cancelled"]),
+
+class VersionsAndLinks(pydantic.BaseModel):
+ versions: typing.List[Version]
+ links: typing.List[Location]
+
+
+class ModelMeta(pydantic.BaseModel):
+ """ "V2 model metadata.
+
+ This class is used to represent the metadata of a model in the V2 API, as we were
+ doing in previous versions.
+ """
+
+ id: str = pydantic.Field(..., description="Model identifier") # noqa
+ name: str = pydantic.Field(..., description="Model name")
+ description: typing.Optional[str] = pydantic.Field(
+ description="Model description", default=None
+ )
+ summary: typing.Optional[str] = pydantic.Field(
+ description="Model summary", default=None
+ )
+ license: typing.Optional[str] = pydantic.Field(
+ description="Model license", default=None
+ )
+ author: typing.Optional[str] = pydantic.Field(
+ description="Model author", default=None
)
- message = fields.Str(description="Optional message explaining status")
+ version: typing.Optional[str] = pydantic.Field(
+ description="Model version", default=None
+ )
+ url: typing.Optional[str] = pydantic.Field(description="Model url", default=None)
+ # Links can be alist of Locations, or an empty list
+ links: typing.List[Location] = pydantic.Field(
+ description="Model links",
+ )
+
+
+class ModelList(pydantic.BaseModel):
+ models: typing.List[ModelMeta] = pydantic.Field(
+ ..., description="List of loaded models"
+ )
+
+
+class Prediction(pydantic.BaseModel):
+ status: str = pydantic.Field(description="Response status message")
+ predictions: str = pydantic.Field(description="String containing predictions")
-class TrainingList(marshmallow.Schema):
- trainings = fields.List(fields.Nested(Training))
+class Failure(pydantic.BaseModel):
+ message: str = pydantic.Field(description="Failure message")
diff --git a/deepaas/api/v2/utils.py b/deepaas/api/v2/utils.py
index b66b5cc1..cb4b746a 100644
--- a/deepaas/api/v2/utils.py
+++ b/deepaas/api/v2/utils.py
@@ -14,7 +14,16 @@
# License for the specific language governing permissions and limitations
# under the License.
+import datetime
+import decimal
+import typing
+
+import fastapi
from aiohttp import web
+import marshmallow
+import marshmallow.fields
+import pydantic
+import pydantic.utils
class NotEnabledHandler(object):
@@ -23,3 +32,204 @@ async def f(*args, **kwargs):
raise web.HTTPPaymentRequired()
return f
+
+
+# Convert marshmallow fields to pydantic fields
+
+
+CUSTOM_FIELD_DEFAULT = typing.Any
+
+
+def get_dict_type(x):
+ """For dicts we need to look at the key and value type"""
+ key_type = get_pydantic_type(x.key_field)
+ if x.value_field:
+ value_type = get_pydantic_type(x.value_field)
+ return typing.Dict[key_type, value_type]
+ return typing.Dict[key_type, typing.Any]
+
+
+def get_list_type(x):
+ """For lists we need to look at the value type"""
+ if x.inner:
+ c_type = get_pydantic_type(x.inner, optional=False)
+ return typing.List[c_type]
+ return typing.List
+
+
+# def get_nested_model(x):
+# """Return a model from a nested marshmallow schema"""
+# return pydantic_from_marshmallow(x.schema)
+
+
+FIELD_CONVERTERS = {
+ marshmallow.fields.Bool: bool,
+ marshmallow.fields.Boolean: bool,
+ marshmallow.fields.Date: datetime.date,
+ marshmallow.fields.DateTime: datetime.datetime,
+ marshmallow.fields.Decimal: decimal.Decimal,
+ marshmallow.fields.Dict: get_dict_type,
+ marshmallow.fields.Email: pydantic.EmailStr,
+ marshmallow.fields.Float: float,
+ marshmallow.fields.Function: typing.Callable,
+ marshmallow.fields.Int: int,
+ marshmallow.fields.Integer: int,
+ marshmallow.fields.List: get_list_type,
+ marshmallow.fields.Mapping: typing.Mapping,
+ marshmallow.fields.Method: typing.Callable,
+ # marshmallow.fields.Nested: get_nested_model,
+ marshmallow.fields.Number: typing.Union[pydantic.StrictFloat, pydantic.StrictInt],
+ marshmallow.fields.Str: str,
+ marshmallow.fields.String: str,
+ marshmallow.fields.Time: datetime.time,
+ marshmallow.fields.TimeDelta: datetime.timedelta,
+ marshmallow.fields.URL: pydantic.AnyUrl,
+ marshmallow.fields.Url: pydantic.AnyUrl,
+ marshmallow.fields.UUID: str,
+}
+
+
+def is_custom_field(field):
+ """If this is a subclass of marshmallow's Field and not in our list, we
+ assume its a custom field"""
+ ftype = type(field)
+ if issubclass(ftype, marshmallow.fields.Field) and ftype not in FIELD_CONVERTERS:
+ print(" Custom field")
+ return True
+ return False
+
+
+def is_file_field(field):
+ """If this is a file field, we need to handle it differently."""
+ if field is not None and field.metadata.get("type") == "file":
+ print(" File field")
+ return True
+ return False
+
+
+def get_pydantic_type(field, optional=True):
+ """Get pydantic type from a marshmallow field"""
+ if field is None:
+ return typing.Any
+ elif is_file_field(field):
+ conv = fastapi.UploadFile
+ elif is_custom_field(field):
+ conv = typing.Any
+ else:
+ conv = FIELD_CONVERTERS[type(field)]
+
+ # TODO: Is there a cleaner way to check for annotation types?
+ if isinstance(conv, type) or conv.__module__ == "typing":
+ pyd_type = conv
+ else:
+ pyd_type = conv(field)
+
+ if optional and not field.required:
+ if is_file_field(field):
+ # If we have a file field, do not wrap with Optional, as FastAPI does not
+ # handle it correctly. Instead, we put None as default value later in the
+ # outer function.
+ pass
+ else:
+ pyd_type = typing.Optional[pyd_type]
+
+ # FIXME(aloga): we need to handle enums
+ return pyd_type
+
+
+def sanitize_field_name(field_name):
+ field_name = field_name.replace("-", "_")
+ field_name = field_name.replace(" ", "_")
+ field_name = field_name.replace(".", "_")
+ field_name = field_name.replace(":", "_")
+ field_name = field_name.replace("/", "_")
+ field_name = field_name.replace("\\", "_")
+ return field_name
+
+
+def check_for_file_fields(fields):
+ for field_name, field in fields.items():
+ if is_file_field(field):
+ return True
+ return False
+
+
+def pydantic_from_marshmallow(
+ name: str, schema: marshmallow.Schema
+) -> pydantic.BaseModel:
+ """Convert a marshmallow schema to a pydantic model.
+
+ May only work for fairly simple cases. Barely tested. Enjoy.
+ """
+
+ pyd_fields = {}
+ have_file_fields = check_for_file_fields(schema._declared_fields)
+
+ for field_name, field in schema._declared_fields.items():
+ pyd_type = get_pydantic_type(field)
+
+ description = field.metadata.get("description")
+
+ if field.default:
+ default = field.default
+ elif field.missing:
+ default = field.missing
+ else:
+ default = None
+
+ if is_file_field(field):
+ field_cls = fastapi.File
+ elif have_file_fields:
+ field_cls = fastapi.Form
+ else:
+ field_cls = pydantic.Field
+
+ if field.required and not default:
+ default = field_cls(
+ ...,
+ description=description,
+ title=field_name,
+ serialization_alias=field_name,
+ )
+ elif default is None:
+ if is_file_field(field):
+ # If we have a file field, it is not wraped with Optional, as FastAPI
+ # does not handle it correctly (c.f. get_pydantic_type function above).
+ # Instead, we put None as default value here, and FastAPI will handle it
+ # correctly.
+ default = None
+ else:
+ default = field_cls(
+ description=description,
+ title=field_name,
+ serialization_alias=field_name,
+ )
+ else:
+ default = field_cls(
+ description=description,
+ default=default,
+ title=field_name,
+ serialization_alias=field_name,
+ )
+
+ field_name = sanitize_field_name(field_name)
+
+ pyd_fields[field_name] = (pyd_type, default)
+
+ ret = pydantic.create_model(
+ name,
+ **pyd_fields,
+ )
+ return ret
+
+
+def get_pydantic_schema_from_marshmallow_fields(
+ name: str,
+ fields: dict,
+) -> pydantic.BaseModel:
+
+ model = marshmallow.Schema.from_dict(fields)
+
+ pydantic_model = pydantic_from_marshmallow(name, model())
+
+ return pydantic_model
diff --git a/deepaas/api/versions.py b/deepaas/api/versions.py
deleted file mode 100644
index 48df6c8d..00000000
--- a/deepaas/api/versions.py
+++ /dev/null
@@ -1,76 +0,0 @@
-# -*- coding: utf-8 -*-
-
-# Copyright 2019 Spanish National Research Council (CSIC)
-#
-# Licensed under the Apache License, Version 2.0 (the "License"); you may
-# not use this file except in compliance with the License. You may obtain
-# a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
-# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
-# License for the specific language governing permissions and limitations
-# under the License.
-
-import json
-
-from aiohttp import web
-import aiohttp_apispec
-
-from deepaas.api.v2 import responses
-
-app = web.Application()
-routes = web.RouteTableDef()
-
-
-@routes.view("/", name="versions")
-class Versions(web.View):
- versions = {}
-
- def __init__(self, *args, **kwargs):
- super(Versions, self).__init__(*args, **kwargs)
-
- @aiohttp_apispec.docs(
- tags=["versions"],
- summary="Get available API versions",
- )
- @aiohttp_apispec.response_schema(responses.Versions(), 200)
- @aiohttp_apispec.response_schema(responses.Failure(), 400)
- async def get(self):
- versions = []
- for _ver, info in self.versions.items():
- resp = await info(self.request)
- versions.append(json.loads(resp.body))
-
- response = {"versions": versions, "links": []}
- # But here we use the global app coming in the request
- doc = self.request.app.router.named_resources().get("swagger.docs")
- if doc:
- doc = {"rel": "help", "type": "text/html", "href": "%s" % doc.url_for()}
- response["links"].append(doc)
-
- spec = self.request.app.router.named_resources().get("swagger.spec")
- if spec:
- spec = {
- "rel": "describedby",
- "type": "application/json",
- "href": "%s" % spec.url_for(),
- }
- response["links"].append(spec)
-
- return web.json_response(response)
-
-
-@routes.get("/ui")
-async def redirect_ui(request):
- doc_url = request.app.router.named_resources().get("swagger.docs").url_for()
- return web.HTTPFound(doc_url)
-
-
-def register_version(version, func):
- # NOTE(aloga): we could use a @classmethod on Versions, but it fails
- # with a TypeError: 'classmethod' object is not callable since the function
- # is decorated.
- Versions.versions[version] = func
diff --git a/deepaas/cmd/run.py b/deepaas/cmd/run.py
index d4e37d5a..69cc95bf 100644
--- a/deepaas/cmd/run.py
+++ b/deepaas/cmd/run.py
@@ -18,8 +18,8 @@
import pathlib
import sys
-from aiohttp import web
from oslo_config import cfg
+import uvicorn
import deepaas
from deepaas import api
@@ -103,6 +103,7 @@ def main():
else:
base_path = ""
+ # FIXME(aloga): ensure that these paths are correct
base = "http://{}:{}{}".format(CONF.listen_ip, CONF.listen_port, base_path)
spec = "{}/swagger.json".format(base)
docs = "{}/api".format(base)
@@ -111,19 +112,21 @@ def main():
print(INTRO)
print(BANNER.format(docs, spec, v2))
- log.info("Starting DEEPaaS version %s", deepaas.extract_version())
+ log.info(
+ "Starting DEEPaaS version %s with FastAPI backend",
+ deepaas.extract_version(),
+ )
- app = api.get_app(
+ print("FastAPI backend is still experimental.")
+ print("Press Ctrl+C to stop the server.")
+ app = api.get_fastapi_app(
enable_doc=CONF.doc_endpoint,
enable_train=CONF.train_endpoint,
enable_predict=CONF.predict_endpoint,
- base_path=CONF.base_path,
- )
- web.run_app(
- app,
- host=CONF.listen_ip,
- port=CONF.listen_port,
+ base_path=base_path,
)
+ uvicorn.run(app, host=CONF.listen_ip, port=CONF.listen_port)
+ log.debug("Shutting down")
if __name__ == "__main__":
diff --git a/deepaas/model/__init__.py b/deepaas/model/__init__.py
index ffb393c3..891cbc47 100644
--- a/deepaas/model/__init__.py
+++ b/deepaas/model/__init__.py
@@ -16,14 +16,22 @@
from deepaas.model import v2
-V2_MODELS = v2.MODELS
+# FIXME(aloga): this is extremely ugly
+V2_MODEL = None
+V2_MODEL_NAME = None
-def register_v2_models(app):
+def load_v2_model():
"""Register V2 models.
This method has to be called before the API is spawned, so that we
can look up the correct entry points and load the defined models.
"""
- return v2.register_models(app)
+ global V2_MODEL
+ global V2_MODEL_NAME
+
+ v2.load_model()
+
+ V2_MODEL = v2.MODEL
+ V2_MODEL_NAME = v2.MODEL_NAME
diff --git a/deepaas/model/v2/__init__.py b/deepaas/model/v2/__init__.py
index e6e9b88b..22e801b7 100644
--- a/deepaas/model/v2/__init__.py
+++ b/deepaas/model/v2/__init__.py
@@ -25,15 +25,15 @@
CONF = config.CONF
# Model registry
-MODELS = {}
-MODELS_LOADED = False
+MODEL = None
+MODEL_NAME = ""
-def register_models(app):
- global MODELS
- global MODELS_LOADED
+def load_model():
+ global MODEL
+ global MODEL_NAME
- if MODELS_LOADED:
+ if MODEL:
return
if CONF.model_name:
@@ -53,16 +53,14 @@ def register_models(app):
raise exceptions.MultipleModelsFound()
try:
- MODELS[model_name] = wrapper.ModelWrapper(
+ MODEL = wrapper.ModelWrapper(
model_name,
loading.get_model_by_name(model_name, "v2"),
- app,
)
+ MODEL_NAME = model_name
except exceptions.ModuleNotFoundError:
LOG.error("Model not found: %s", model_name)
raise
except Exception as e:
LOG.exception("Error loading model: %s", e)
raise e
-
- MODELS_LOADED = True
diff --git a/deepaas/model/v2/wrapper.py b/deepaas/model/v2/wrapper.py
index 196abd8b..e6d09676 100644
--- a/deepaas/model/v2/wrapper.py
+++ b/deepaas/model/v2/wrapper.py
@@ -17,19 +17,16 @@
import asyncio
import collections
import contextlib
-import datetime
import functools
import io
-import multiprocessing
-import multiprocessing.pool
import os
-import signal
import tempfile
from aiohttp import web
import marshmallow
from oslo_config import cfg
+from deepaas.api.v2 import utils
from deepaas import log
LOG = log.getLogger(__name__)
@@ -100,18 +97,9 @@ class ModelWrapper(object):
a response schema that is not JSON schema valid (DRAFT 4)
"""
- def __init__(self, name, model_obj, app=None):
+ def __init__(self, name, model_obj):
self.name = name
self.model_obj = model_obj
- self._app = app
-
- self._loop = asyncio.get_event_loop()
-
- self._workers = CONF.workers
- self._executor = self._init_executor()
-
- if self._app is not None:
- self._setup_cleanup()
schema = getattr(self.model_obj, "schema", None)
@@ -123,6 +111,7 @@ def __init__(self, name, model_obj, app=None):
self.has_schema = True
except Exception as e:
LOG.exception(e)
+ # FIXME(aloga): do not use web exception here
raise web.HTTPInternalServerError(
reason=("Model defined schema is invalid, " "check server logs.")
)
@@ -131,24 +120,21 @@ def __init__(self, name, model_obj, app=None):
if issubclass(schema, marshmallow.Schema):
self.has_schema = True
except TypeError:
+ # FIXME(aloga): do not use web exception here
raise web.HTTPInternalServerError(
reason=("Model defined schema is invalid, " "check server logs.")
)
else:
self.has_schema = False
- self.response_schema = schema
-
- def _setup_cleanup(self):
- self._app.on_cleanup.append(self._close_executors)
-
- async def _close_executors(self, app):
- self._executor.shutdown()
-
- def _init_executor(self):
- n = self._workers
- executor = CancellablePool(max_workers=n)
- return executor
+ # Now convert to pydantic schema...
+ # FIXME(aloga): use try except
+ if schema is not None:
+ self.response_schema = utils.pydantic_from_marshmallow(
+ "ModelPredictionResponse", schema
+ )
+ else:
+ self.response_schema = None
@contextlib.contextmanager
def _catch_error(self):
@@ -269,7 +255,7 @@ def predict_wrap(predict_func, *args, **kwargs):
return ret
- def predict(self, *args, **kwargs):
+ async def predict(self, *args, **kwargs):
"""Perform a prediction on wrapped model's ``predict`` method.
:raises HTTPNotImplemented: If the method is not
@@ -295,9 +281,7 @@ def predict(self, *args, **kwargs):
# FIXME(aloga); cleanup of tmpfile here
with self._catch_error():
- return self._run_in_pool(
- self.predict_wrap, self.model_obj.predict, *args, **kwargs
- )
+ return self.predict_wrap(self.model_obj.predict, *args, **kwargs)
def train(self, *args, **kwargs):
"""Perform a training on wrapped model's ``train`` method.
@@ -311,7 +295,7 @@ def train(self, *args, **kwargs):
"""
with self._catch_error():
- return self._run_in_pool(self.model_obj.train, *args, **kwargs)
+ return self.model_obj.train(*args, **kwargs)
def get_train_args(self):
"""Add training arguments into the training parser.
@@ -338,86 +322,3 @@ def get_predict_args(self):
except (NotImplementedError, AttributeError):
args = {}
return args
-
-
-class NonDaemonProcess(multiprocessing.context.SpawnProcess):
- """Processes must use 'spawn' instead of 'fork' (which is the default
- in Linux) in order to work CUDA [1] or Tensorflow [2].
-
- [1] https://pytorch.org/docs/stable/notes/multiprocessing.html
- #cuda-in-multiprocessing
- [2] https://github.com/tensorflow/tensorflow/issues/5448
- #issuecomment-258934405
- """
-
- @property
- def daemon(self):
- return False
-
- @daemon.setter
- def daemon(self, value):
- pass
-
-
-class NonDaemonPool(multiprocessing.pool.Pool):
- # Based on https://stackoverflow.com/questions/6974695/
- def Process(self, *args, **kwds): # noqa
- proc = super(NonDaemonPool, self).Process(*args, **kwds)
- proc.__class__ = NonDaemonProcess
-
- return proc
-
-
-class CancellablePool(object):
- def __init__(self, max_workers=None):
- self._free = {self._new_pool() for _ in range(max_workers)}
- self._working = set()
- self._change = asyncio.Event()
-
- def _new_pool(self):
- return NonDaemonPool(1, context=multiprocessing.get_context("spawn"))
-
- async def apply(self, fn, *args):
- """
- Like multiprocessing.Pool.apply_async, but:
- * is an asyncio coroutine
- * terminates the process if cancelled
- """
- while not self._free:
- await self._change.wait()
- self._change.clear()
- pool = usable_pool = self._free.pop()
- self._working.add(pool)
-
- loop = asyncio.get_event_loop()
- fut = loop.create_future()
-
- def _on_done(obj):
- ret = {"output": obj, "finish_date": str(datetime.datetime.now())}
- loop.call_soon_threadsafe(fut.set_result, ret)
-
- def _on_err(err):
- loop.call_soon_threadsafe(fut.set_exception, err)
-
- pool.apply_async(fn, args, callback=_on_done, error_callback=_on_err)
-
- try:
- return await fut
- except asyncio.CancelledError:
- # This is ugly, but since our pools only have one slot we can
- # kill the process before termination
- try:
- pool._pool[0].kill()
- except AttributeError:
- os.kill(pool._pool[0].pid, signal.SIGKILL)
- pool.terminate()
- usable_pool = self._new_pool()
- finally:
- self._working.remove(pool)
- self._free.add(usable_pool)
- self._change.set()
-
- def shutdown(self):
- for p in self._working:
- p.terminate()
- self._free.clear()
diff --git a/deepaas/tests/out_test/out_file.json b/deepaas/tests/out_test/out_file.json
new file mode 100644
index 00000000..ac3114f0
--- /dev/null
+++ b/deepaas/tests/out_test/out_file.json
@@ -0,0 +1 @@
+[{'value1': {'pred': 1}, 'value2': {'pred': 0.9}}]
diff --git a/deepaas/tests/out_test/out_xxxxxxxxxx.json b/deepaas/tests/out_test/out_xxxxxxxxxx.json
new file mode 100644
index 00000000..ac3114f0
--- /dev/null
+++ b/deepaas/tests/out_test/out_xxxxxxxxxx.json
@@ -0,0 +1 @@
+[{'value1': {'pred': 1}, 'value2': {'pred': 0.9}}]
diff --git a/deepaas/tests/out_test/tmp_dir.zip b/deepaas/tests/out_test/tmp_dir.zip
new file mode 100644
index 00000000..d91884e6
Binary files /dev/null and b/deepaas/tests/out_test/tmp_dir.zip differ
diff --git a/pyproject.toml b/pyproject.toml
index 1c8c0249..2d41d0b8 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "deepaas"
-version = "2.4.0"
+version = "3.0.0"
description = "DEEPaaS is a REST API to expose a machine learning model."
authors = ["Alvaro Lopez Garcia "]
license = "Apache-2"
@@ -56,6 +56,7 @@ aiohttp-apispec = "^2.2.3"
Werkzeug = "^3.0.3"
marshmallow = "^3.21.3"
webargs = "<6.0.0"
+fastapi = "^0.111.0"
[tool.poetry.group.dev.dependencies]
@@ -94,7 +95,7 @@ mypy = "^1.10.0"
[tool.poetry.group.test-pypi.dependencies]
-twine = "^5.1.0"
+poetry = "^1.8.3"
[tool.poetry.group.test-pip-missing-reqs.dependencies]
diff --git a/tox.ini b/tox.ini
index 89ab2787..0036778f 100644
--- a/tox.ini
+++ b/tox.ini
@@ -90,14 +90,6 @@ exclude =
[testenv:flake8]
basepython = {[base]python}
-deps =
- flake8>=4.0,<4.1
- flake8-bugbear>=22.3,<22.4
- ; flake8-docstrings>=1.6,<1.7
- flake8-typing-imports>=1.12,<1.13
- flake8-colors>=0.1,<0.2
- pep8-naming>=0.12,<0.13
- pydocstyle>=6.1,<6.2
commands =
poetry run flake8 {[base]package}