Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added the base_route and crud_base #17

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,4 @@ venv.bak/
site

.bento/
.vscode/
109 changes: 109 additions & 0 deletions fastapi_utils/base_route.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
from typing import Dict, Generic, List, TypeVar

from fastapi import Depends, HTTPException
from pydantic import BaseModel
from sqlalchemy.orm import Session

from fastapi_utils.crud_base import Base, CRUDBase

ResponseModelType = TypeVar("ResponseModelType", bound=BaseModel)
ResponseModelManyType = TypeVar("ResponseModelManyType", bound=BaseModel)
CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)
IDType = TypeVar("IDType")


def get_filter_fields(self) -> List[str]:
"""This would need to get overridden for each BaseRoute where the filter fields are defined.

Returns:
List[str] -- List of fields to filter by
"""
return []


class BaseRoute(Generic[ResponseModelType, ResponseModelManyType, CreateSchemaType, UpdateSchemaType, IDType]):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you think it would make sense to change the name here to CRUDRoute rather than BaseRoute? BaseRoute seems a little overly generic to me.

"""A base route that has the basic CRUD endpoints.

For read_many

"""

filter_fields: List[str] = Depends(get_filter_fields)
Copy link
Collaborator

@dmontagu dmontagu Mar 7, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a benefit to this being a dependency rather than just a class var? That is to say, why not just use

    ...
    filter_fields: typing.ClassVar[List[str]] = []

and then you can override the list in a subclass if desired? Also, if this approach makes sense, it may be better to use a tuple rather than a list for this to ensure immutability.

I think that approach should work fine, and would have a bit less overhead too. And you could still override with a dependency if you wanted to. (That could maybe go into docs.)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh that's exactly what I was looking for. Didn't know how to stop the class variables showing up in the openapi spec

crud_base = CRUDBase(Base) # type: ignore
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it should be possible for the user to specify their own base class, not the one provided by the library. Otherwise I think it would be difficult/impossible to override things like the table naming scheme or similar.

Moreover, it could add some annoyances when using alembic for migrations.

I'm not sure how hard it would be to refactor to allow a user-specified Base, but I think we should try to support that if possible.

db: Session = Depends(None)
object_name = "Base"

def read_many(self, skip: int = 0, limit: int = 100, sort_by: str = None, **kwargs) -> ResponseModelManyType:
"""Reads many from the database with the provided filter and sort parameters.

Filter parameters need to be specified by overriding this read_many method and calling it like:

@router.get("/", response_model=List[Person])
def read_persons(
self, skip: int = 0, limit: int = 100, sort_by: str = None, name: str = None,
) -> List[Person]:
return super().read_many(skip=skip, limit=limit, sort_by=sort_by, name=name)

Where the filter fields are defined as parameters. In this case "name" is a filter field

Keyword Arguments:
skip {int} -- [description] (default: {0})
limit {int} -- [description] (default: {100})
sort_by {str} -- Expected in the form "model__field_name:asc,field_name:desc" (default: {None})

**kwargs {str} -- Filter field names expected in the form field_name or model__field_name if filtering through
a join. The filter is defined as op:value. For example ==:paul or eq:paul

The filter op is specified in the crud_base FilterOpEnum.

Returns:
ResponseModelManyType -- [description]
"""
filter_fields: Dict[str, str] = {}
for field in self.filter_fields:
filter_fields[field] = kwargs.pop(field, None)
results = self.crud_base.get_multi(self.db, skip=skip, limit=limit, filter_by=filter_fields, sort_by=sort_by)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this should probably raise some kind of validation error if invalid filter fields are specified (i.e., kwargs contains things that aren't used).

In general, I really prefer to avoid using **kwargs as an input for type safety considerations, but it feels like it may be appropriate here. But I think we should add some validation of the specified filter field values.

return results

def create(self, *, obj_in: CreateSchemaType,) -> ResponseModelType:
"""
Create new object.
"""
result = self.crud_base.create(db_session=self.db, obj_in=obj_in)
return result

def update(self, *, id: IDType, obj_in: UpdateSchemaType,) -> ResponseModelType:
"""
Update an object.
"""
result = self.crud_base.get(db_session=self.db, id=id)
if not result:
raise HTTPException(status_code=404, detail=f"{self.object_name} not found")
# if not crud.user.is_superuser(current_user) and (object.owner_id != current_user.id):
# raise HTTPException(status_code=400, detail="Not enough permissions")
result = self.crud_base.update(db_session=self.db, db_obj=result, obj_in=obj_in)
return result

def read(self, *, id: IDType,) -> ResponseModelType:
"""
Get object by ID.
"""
result = self.crud_base.get(db_session=self.db, id=id)
if not result:
raise HTTPException(status_code=404, detail=f"{self.object_name} not found")
# if not crud.user.is_superuser(current_user) and (object.owner_id != current_user.id):
# raise HTTPException(status_code=400, detail="Not enough permissions")
return result

def delete(self, *, id: IDType,) -> ResponseModelType:
"""
Delete an object.
"""
result = self.crud_base.get(db_session=self.db, id=id)
if not result:
raise HTTPException(status_code=404, detail=f"{self.object_name} not found")
# if not crud.user.is_superuser(current_user) and (object.owner_id != current_user.id):
# raise HTTPException(status_code=400, detail="Not enough permissions")
result = self.crud_base.remove(db_session=self.db, id=id)
return result
160 changes: 160 additions & 0 deletions fastapi_utils/crud_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
from decimal import Decimal
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any chance we could put this file (crud_base.py) and (base_route.py) into a crud subfolder/package? It could be fastapi_utils.crud.base and fastapi_utils.crud.route.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure

from enum import Enum
from typing import Dict, Generic, List, Optional, Type, TypeVar, Union

from fastapi.encoders import jsonable_encoder
from pydantic import BaseModel
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session
from sqlalchemy_filters import apply_filters, apply_sort

from fastapi_utils.camelcase import snake2camel

Base = declarative_base()
Copy link
Collaborator

@dmontagu dmontagu Mar 7, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As noted above, for multiple reasons I would much prefer if it were possible for the user to provide their own declarative_base(), rather than being forced into using the one defined in this module.

I think it may be challenging to get the type hinting right, but I think the functionality benefits are worth the associated challenge with the ModelType TypeVar.

I suspect even in the worst case there's probably something we can do with an if TYPE_CHECKING: block to make it work nicely for most/all cases.


ModelType = TypeVar("ModelType", bound=Base)
MultiSchemaType = TypeVar("MultiSchemaType", bound=BaseModel)
CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)
IDType = TypeVar("IDType")


class SortDirectionEnum(str, Enum):
ASC = "asc"
DESC = "desc"


class FilterOpEnum(str, Enum):
IS_NULL = "is_null"
IS_NOT_NULL = "is_not_null"
EQ_SYM = "=="
EQ = "eq"
NE_SYM = "!="
NE = "ne"
GT_SYM = ">"
GT = "gt"
LT_SYM = "<"
LT = "lt"
GE_SYM = ">="
GE = "ge"
LE_SYM = "<="
LE = "le"
LIKE = "like"
ILIKE = "ilike"
IN = "in"
NOT_IN = "not_in"


class SortField(BaseModel):
field: str
model: Optional[str] = None
direction: SortDirectionEnum = SortDirectionEnum.DESC


class FilterField(BaseModel):
field: str
model: Optional[str] = None
op: FilterOpEnum
value: Union[str, int, Decimal]


def get_filter_field(field: str, field_name: str, split_character: str = ":") -> FilterField:
model = None
op, value = field.split(":")
if "__" in field_name:
model, field_name = field_name.split("__")
model = snake2camel(model, start_lower=False)
filter_field = FilterField(field=field_name, model=model, op=op, value=value)
return filter_field


def get_filter_fields(fields: Optional[Dict[str, str]], split_character: str = ":") -> List[FilterField]:
filter_fields = []
if fields:
for field_name in fields:
if fields[field_name]:
filter_fields.append(get_filter_field(field=fields[field_name], field_name=field_name))
return filter_fields


def get_sort_field(field: str) -> SortField:
model = None
field_name, direction = field.split(":")
if "__" in field_name:
model, field_name = field_name.split("__")
sort_field = SortField(model=model, field=field_name, direction=direction)
return sort_field


def get_sort_fields(sort_string: str, split_character: str = ",") -> List[SortField]:
sort_fields = []
# There could be many sort fields
if sort_string:
sort_by_fields = sort_string.split(",")
for _to_sort in sort_by_fields:
sort_fields.append(get_sort_field(_to_sort))
return sort_fields


class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
def __init__(self, model: Type[ModelType]):
"""
CRUD object with default methods to Create, Read, Update, Delete (CRUD).
**Parameters**
* `model`: A SQLAlchemy model class
* `schema`: A Pydantic model (schema) class
"""
self.model = model

def get(self, db_session: Session, id: IDType) -> Optional[ModelType]:
return db_session.query(self.model).filter(self.model.id == id).first() # type: ignore

def get_multi(
self,
db_session: Session,
*,
skip: int = 0,
limit: int = 100,
sort_by: Optional[str] = None,
filter_by: Optional[Dict[str, str]] = None,
) -> List[ModelType]:

sort_spec_pydantic = get_sort_fields(sort_by)
filter_spec_pydantic = get_filter_fields(filter_by)

sort_spec = [x.dict(exclude_none=True) for x in sort_spec_pydantic]
filter_spec = [x.dict(exclude_none=True) for x in filter_spec_pydantic]

query = db_session.query(self.model)
query = apply_filters(query, filter_spec)
query = apply_sort(query, sort_spec)

count = query.count()
query = query.offset(skip).limit(limit)

return query.all()

def create(self, db_session: Session, *, obj_in: CreateSchemaType) -> ModelType:
obj_in_data = jsonable_encoder(obj_in)
db_obj = self.model(**obj_in_data)
db_session.add(db_obj)
db_session.commit()
db_session.refresh(db_obj)
return db_obj

def update(self, db_session: Session, *, db_obj: ModelType, obj_in: UpdateSchemaType) -> ModelType:
obj_data = jsonable_encoder(db_obj)
update_data = obj_in.dict(skip_defaults=True)
for field in obj_data:
if field in update_data:
setattr(db_obj, field, update_data[field])
db_session.add(db_obj)
db_session.commit()
db_session.refresh(db_obj)
return db_obj

def remove(self, db_session: Session, *, id: IDType) -> ModelType:
obj = db_session.query(self.model).get(id)
db_session.delete(obj)
db_session.commit()
return obj
31 changes: 24 additions & 7 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ python = "^3.6"
fastapi = "*"
pydantic = "^1.0"
sqlalchemy = "^1.3.12"
sqlalchemy-filters = "^0.10.0"

[tool.poetry.dev-dependencies]
# Starlette features
Expand Down