Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added the base_route and crud_base #17

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,4 @@ venv.bak/
site

.bento/
.vscode/
4 changes: 4 additions & 0 deletions fastapi_utils/crud/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .base import CRUDBase
from .route import CRUDRoute

__all__ = ["CRUDBase", "CRUDRoute"]
158 changes: 158 additions & 0 deletions fastapi_utils/crud/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
from decimal import Decimal
from enum import Enum
from typing import Dict, Generic, List, Optional, Type, TypeVar, Union

from fastapi.encoders import jsonable_encoder
from pydantic import BaseModel
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session

from fastapi_utils.camelcase import snake2camel
from sqlalchemy_filters import apply_filters, apply_sort

ModelType = TypeVar("ModelType")
MultiSchemaType = TypeVar("MultiSchemaType", bound=BaseModel)
CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)
IDType = TypeVar("IDType")


class SortDirectionEnum(str, Enum):
ASC = "asc"
DESC = "desc"


class FilterOpEnum(str, Enum):
IS_NULL = "is_null"
IS_NOT_NULL = "is_not_null"
EQ_SYM = "=="
EQ = "eq"
NE_SYM = "!="
NE = "ne"
GT_SYM = ">"
GT = "gt"
LT_SYM = "<"
LT = "lt"
GE_SYM = ">="
GE = "ge"
LE_SYM = "<="
LE = "le"
LIKE = "like"
ILIKE = "ilike"
IN = "in"
NOT_IN = "not_in"


class SortField(BaseModel):
field: str
model: Optional[str] = None
direction: SortDirectionEnum = SortDirectionEnum.DESC


class FilterField(BaseModel):
field: str
model: Optional[str] = None
op: FilterOpEnum
value: Union[str, int, Decimal]


def get_filter_field(field: str, field_name: str, split_character: str = ":") -> FilterField:
model = None
op, value = field.split(":")
if "__" in field_name:
model, field_name = field_name.split("__")
model = snake2camel(model, start_lower=False)
filter_field = FilterField(field=field_name, model=model, op=op, value=value)
return filter_field


def get_filter_fields(fields: Optional[Dict[str, str]], split_character: str = ":") -> List[FilterField]:
filter_fields = []
if fields:
for field_name in fields:
if fields[field_name]:
filter_fields.append(get_filter_field(field=fields[field_name], field_name=field_name))
return filter_fields


def get_sort_field(field: str) -> SortField:
model = None
field_name, direction = field.split(":")
if "__" in field_name:
model, field_name = field_name.split("__")
sort_field = SortField(model=model, field=field_name, direction=direction)
return sort_field


def get_sort_fields(sort_string: Optional[str], split_character: str = ",") -> List[SortField]:
sort_fields = []
# There could be many sort fields
if sort_string:
sort_by_fields = sort_string.split(",")
for _to_sort in sort_by_fields:
sort_fields.append(get_sort_field(_to_sort))
return sort_fields


class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
def __init__(self, model: Type[ModelType]):
"""
CRUD object with default methods to Create, Read, Update, Delete (CRUD).
**Parameters**
* `model`: A SQLAlchemy model class
* `schema`: A Pydantic model (schema) class
"""
self.model = model

def get(self, db_session: Session, id: IDType) -> Optional[ModelType]:
return db_session.query(self.model).filter(self.model.id == id).first() # type: ignore

def get_many(
self,
db_session: Session,
*,
skip: int = 0,
limit: int = 100,
filter_by: Optional[Dict[str, str]] = None,
sort_by: Optional[str] = None,
) -> List[ModelType]:

sort_spec_pydantic = get_sort_fields(sort_by)
filter_spec_pydantic = get_filter_fields(filter_by)

sort_spec = [x.dict(exclude_none=True) for x in sort_spec_pydantic]
filter_spec = [x.dict(exclude_none=True) for x in filter_spec_pydantic]

query = db_session.query(self.model)
query = apply_filters(query, filter_spec)
query = apply_sort(query, sort_spec)

# count = query.count()
query = query.offset(skip).limit(limit)

return query.all()

def create(self, db_session: Session, *, obj_in: CreateSchemaType) -> ModelType:
obj_in_data = jsonable_encoder(obj_in)
db_obj = self.model(**obj_in_data)
db_session.add(db_obj)
db_session.commit()
db_session.refresh(db_obj)
return db_obj

def update(self, db_session: Session, *, db_obj: ModelType, obj_in: UpdateSchemaType) -> ModelType:
obj_data = jsonable_encoder(db_obj)
update_data = obj_in.dict(skip_defaults=True)
for field in obj_data:
if field in update_data:
setattr(db_obj, field, update_data[field])
db_session.add(db_obj)
db_session.commit()
db_session.refresh(db_obj)
return db_obj

def remove(self, db_session: Session, *, id: IDType) -> ModelType:
obj = db_session.query(self.model).get(id)
db_session.delete(obj)
db_session.commit()
return obj
107 changes: 107 additions & 0 deletions fastapi_utils/crud/route.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
from typing import ClassVar, Dict, Generic, Tuple, TypeVar

from fastapi import Depends, HTTPException
from pydantic import BaseModel
from sqlalchemy.orm import Session

from fastapi_utils.crud import CRUDBase

ResponseModelType = TypeVar("ResponseModelType", bound=BaseModel)
ResponseModelManyType = TypeVar("ResponseModelManyType", bound=BaseModel)
CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)
CRUDBaseType = TypeVar("CRUDBaseType", bound=CRUDBase)
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you know how to get mypy to detect the crud_base would have the get_many, get, delete methods?

IDType = TypeVar("IDType")


class CRUDRoute(Generic[ResponseModelType, ResponseModelManyType, CreateSchemaType, UpdateSchemaType, IDType]):
"""A base route that has the basic CRUD endpoints.

For read_many

"""

crud_base: ClassVar[CRUDBaseType]
filter_fields: ClassVar[Tuple[str]] = ()
db: Session = Depends(None)
object_name: ClassVar[str] = "CRUDBase"

def read_many(self, skip: int = 0, limit: int = 100, sort_by: str = None, **kwargs) -> ResponseModelManyType:
"""Reads many from the database with the provided filter and sort parameters.

Filter parameters need to be specified by overriding this read_many method and calling it like:

@router.get("/", response_model=List[Person])
def read_persons(
self, skip: int = 0, limit: int = 100, sort_by: str = None, name: str = None,
) -> List[Person]:
return super().read_many(skip=skip, limit=limit, sort_by=sort_by, name=name)

Where the filter fields are defined as parameters. In this case "name" is a filter field

Keyword Arguments:
skip {int} -- [description] (default: {0})
limit {int} -- [description] (default: {100})
sort_by {str} -- Expected in the form "model__field_name:asc,field_name:desc" (default: {None})

**kwargs {str} -- Filter field names expected in the form field_name or model__field_name if
filtering through
a join. The filter is defined as op:value. For example ==:paul or eq:paul

The filter op is specified in the crud_base FilterOpEnum.

Returns:
ResponseModelManyType -- [description]
"""
filter_fields: Dict[str, str] = {}

for field in self.filter_fields:
filter_fields[field] = kwargs.pop(field, None)

if len(kwargs) != 0:
raise ValueError(f"Method parameters have not been added to class filter fields {kwargs.keys()}")

results = self.crud_base.get_many(self.db, skip=skip, limit=limit, filter_by=filter_fields, sort_by=sort_by)
return results

def create(self, *, obj_in: CreateSchemaType,) -> ResponseModelType:
"""
Create new object.
"""
result = self.crud_base.create(db_session=self.db, obj_in=obj_in)
return result

def update(self, *, id: IDType, obj_in: UpdateSchemaType,) -> ResponseModelType:
"""
Update an object.
"""
result = self.crud_base.get(db_session=self.db, id=id)
if not result:
raise HTTPException(status_code=404, detail=f"{self.object_name} not found")
# if not crud.user.is_superuser(current_user) and (object.owner_id != current_user.id):
# raise HTTPException(status_code=400, detail="Not enough permissions")
result = self.crud_base.update(db_session=self.db, db_obj=result, obj_in=obj_in)
return result

def read(self, *, id: IDType,) -> ResponseModelType:
"""
Get object by ID.
"""
result = self.crud_base.get(db_session=self.db, id=id)
if not result:
raise HTTPException(status_code=404, detail=f"{self.object_name} not found")
# if not crud.user.is_superuser(current_user) and (object.owner_id != current_user.id):
# raise HTTPException(status_code=400, detail="Not enough permissions")
return result

def delete(self, *, id: IDType,) -> ResponseModelType:
"""
Delete an object.
"""
result = self.crud_base.get(db_session=self.db, id=id)
if not result:
raise HTTPException(status_code=404, detail=f"{self.object_name} not found")
# if not crud.user.is_superuser(current_user) and (object.owner_id != current_user.id):
# raise HTTPException(status_code=400, detail="Not enough permissions")
result = self.crud_base.remove(db_session=self.db, id=id)
return result
31 changes: 24 additions & 7 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ python = "^3.6"
fastapi = "*"
pydantic = "^1.0"
sqlalchemy = "^1.3.12"
sqlalchemy-filters = "^0.10.0"

[tool.poetry.dev-dependencies]
# Starlette features
Expand Down