-
Notifications
You must be signed in to change notification settings - Fork 173
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added the base_route and crud_base #17
base: master
Are you sure you want to change the base?
Changes from 3 commits
d34b334
a825b7d
ff18fc5
8c03449
6a25266
5fd2015
0e24812
8aaa3d3
d7f4f36
a08659b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -87,3 +87,4 @@ venv.bak/ | |
site | ||
|
||
.bento/ | ||
.vscode/ |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
from typing import Dict, Generic, List, TypeVar | ||
|
||
from fastapi import Depends, HTTPException | ||
from pydantic import BaseModel | ||
from sqlalchemy.orm import Session | ||
|
||
from fastapi_utils.crud_base import Base, CRUDBase | ||
|
||
ResponseModelType = TypeVar("ResponseModelType", bound=BaseModel) | ||
ResponseModelManyType = TypeVar("ResponseModelManyType", bound=BaseModel) | ||
CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel) | ||
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel) | ||
IDType = TypeVar("IDType") | ||
|
||
|
||
def get_filter_fields(self) -> List[str]: | ||
"""This would need to get overridden for each BaseRoute where the filter fields are defined. | ||
|
||
Returns: | ||
List[str] -- List of fields to filter by | ||
""" | ||
return [] | ||
|
||
|
||
class BaseRoute(Generic[ResponseModelType, ResponseModelManyType, CreateSchemaType, UpdateSchemaType, IDType]): | ||
"""A base route that has the basic CRUD endpoints. | ||
|
||
For read_many | ||
|
||
""" | ||
|
||
filter_fields: List[str] = Depends(get_filter_fields) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is there a benefit to this being a dependency rather than just a class var? That is to say, why not just use ...
filter_fields: typing.ClassVar[List[str]] = [] and then you can override the list in a subclass if desired? Also, if this approach makes sense, it may be better to use a tuple rather than a list for this to ensure immutability. I think that approach should work fine, and would have a bit less overhead too. And you could still override with a dependency if you wanted to. (That could maybe go into docs.) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh that's exactly what I was looking for. Didn't know how to stop the class variables showing up in the openapi spec |
||
crud_base = CRUDBase(Base) # type: ignore | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it should be possible for the user to specify their own base class, not the one provided by the library. Otherwise I think it would be difficult/impossible to override things like the table naming scheme or similar. Moreover, it could add some annoyances when using alembic for migrations. I'm not sure how hard it would be to refactor to allow a user-specified |
||
db: Session = Depends(None) | ||
object_name = "Base" | ||
|
||
def read_many(self, skip: int = 0, limit: int = 100, sort_by: str = None, **kwargs) -> ResponseModelManyType: | ||
"""Reads many from the database with the provided filter and sort parameters. | ||
|
||
Filter parameters need to be specified by overriding this read_many method and calling it like: | ||
|
||
@router.get("/", response_model=List[Person]) | ||
def read_persons( | ||
self, skip: int = 0, limit: int = 100, sort_by: str = None, name: str = None, | ||
) -> List[Person]: | ||
return super().read_many(skip=skip, limit=limit, sort_by=sort_by, name=name) | ||
|
||
Where the filter fields are defined as parameters. In this case "name" is a filter field | ||
|
||
Keyword Arguments: | ||
skip {int} -- [description] (default: {0}) | ||
limit {int} -- [description] (default: {100}) | ||
sort_by {str} -- Expected in the form "model__field_name:asc,field_name:desc" (default: {None}) | ||
|
||
**kwargs {str} -- Filter field names expected in the form field_name or model__field_name if filtering through | ||
a join. The filter is defined as op:value. For example ==:paul or eq:paul | ||
|
||
The filter op is specified in the crud_base FilterOpEnum. | ||
|
||
Returns: | ||
ResponseModelManyType -- [description] | ||
""" | ||
filter_fields: Dict[str, str] = {} | ||
for field in self.filter_fields: | ||
filter_fields[field] = kwargs.pop(field, None) | ||
results = self.crud_base.get_multi(self.db, skip=skip, limit=limit, filter_by=filter_fields, sort_by=sort_by) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this should probably raise some kind of validation error if invalid filter fields are specified (i.e., In general, I really prefer to avoid using |
||
return results | ||
|
||
def create(self, *, obj_in: CreateSchemaType,) -> ResponseModelType: | ||
""" | ||
Create new object. | ||
""" | ||
result = self.crud_base.create(db_session=self.db, obj_in=obj_in) | ||
return result | ||
|
||
def update(self, *, id: IDType, obj_in: UpdateSchemaType,) -> ResponseModelType: | ||
""" | ||
Update an object. | ||
""" | ||
result = self.crud_base.get(db_session=self.db, id=id) | ||
if not result: | ||
raise HTTPException(status_code=404, detail=f"{self.object_name} not found") | ||
# if not crud.user.is_superuser(current_user) and (object.owner_id != current_user.id): | ||
# raise HTTPException(status_code=400, detail="Not enough permissions") | ||
result = self.crud_base.update(db_session=self.db, db_obj=result, obj_in=obj_in) | ||
return result | ||
|
||
def read(self, *, id: IDType,) -> ResponseModelType: | ||
""" | ||
Get object by ID. | ||
""" | ||
result = self.crud_base.get(db_session=self.db, id=id) | ||
if not result: | ||
raise HTTPException(status_code=404, detail=f"{self.object_name} not found") | ||
# if not crud.user.is_superuser(current_user) and (object.owner_id != current_user.id): | ||
# raise HTTPException(status_code=400, detail="Not enough permissions") | ||
return result | ||
|
||
def delete(self, *, id: IDType,) -> ResponseModelType: | ||
""" | ||
Delete an object. | ||
""" | ||
result = self.crud_base.get(db_session=self.db, id=id) | ||
if not result: | ||
raise HTTPException(status_code=404, detail=f"{self.object_name} not found") | ||
# if not crud.user.is_superuser(current_user) and (object.owner_id != current_user.id): | ||
# raise HTTPException(status_code=400, detail="Not enough permissions") | ||
result = self.crud_base.remove(db_session=self.db, id=id) | ||
return result |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,160 @@ | ||
from decimal import Decimal | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Any chance we could put this file ( There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sure |
||
from enum import Enum | ||
from typing import Dict, Generic, List, Optional, Type, TypeVar, Union | ||
|
||
from fastapi.encoders import jsonable_encoder | ||
from pydantic import BaseModel | ||
from sqlalchemy.ext.declarative import declarative_base | ||
from sqlalchemy.orm import Session | ||
from sqlalchemy_filters import apply_filters, apply_sort | ||
|
||
from fastapi_utils.camelcase import snake2camel | ||
|
||
Base = declarative_base() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As noted above, for multiple reasons I would much prefer if it were possible for the user to provide their own I think it may be challenging to get the type hinting right, but I think the functionality benefits are worth the associated challenge with the I suspect even in the worst case there's probably something we can do with an |
||
|
||
ModelType = TypeVar("ModelType", bound=Base) | ||
MultiSchemaType = TypeVar("MultiSchemaType", bound=BaseModel) | ||
CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel) | ||
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel) | ||
IDType = TypeVar("IDType") | ||
|
||
|
||
class SortDirectionEnum(str, Enum): | ||
ASC = "asc" | ||
DESC = "desc" | ||
|
||
|
||
class FilterOpEnum(str, Enum): | ||
IS_NULL = "is_null" | ||
IS_NOT_NULL = "is_not_null" | ||
EQ_SYM = "==" | ||
EQ = "eq" | ||
NE_SYM = "!=" | ||
NE = "ne" | ||
GT_SYM = ">" | ||
GT = "gt" | ||
LT_SYM = "<" | ||
LT = "lt" | ||
GE_SYM = ">=" | ||
GE = "ge" | ||
LE_SYM = "<=" | ||
LE = "le" | ||
LIKE = "like" | ||
ILIKE = "ilike" | ||
IN = "in" | ||
NOT_IN = "not_in" | ||
|
||
|
||
class SortField(BaseModel): | ||
field: str | ||
model: Optional[str] = None | ||
direction: SortDirectionEnum = SortDirectionEnum.DESC | ||
|
||
|
||
class FilterField(BaseModel): | ||
field: str | ||
model: Optional[str] = None | ||
op: FilterOpEnum | ||
value: Union[str, int, Decimal] | ||
|
||
|
||
def get_filter_field(field: str, field_name: str, split_character: str = ":") -> FilterField: | ||
model = None | ||
op, value = field.split(":") | ||
if "__" in field_name: | ||
model, field_name = field_name.split("__") | ||
model = snake2camel(model, start_lower=False) | ||
filter_field = FilterField(field=field_name, model=model, op=op, value=value) | ||
return filter_field | ||
|
||
|
||
def get_filter_fields(fields: Optional[Dict[str, str]], split_character: str = ":") -> List[FilterField]: | ||
filter_fields = [] | ||
if fields: | ||
for field_name in fields: | ||
if fields[field_name]: | ||
filter_fields.append(get_filter_field(field=fields[field_name], field_name=field_name)) | ||
return filter_fields | ||
|
||
|
||
def get_sort_field(field: str) -> SortField: | ||
model = None | ||
field_name, direction = field.split(":") | ||
if "__" in field_name: | ||
model, field_name = field_name.split("__") | ||
sort_field = SortField(model=model, field=field_name, direction=direction) | ||
return sort_field | ||
|
||
|
||
def get_sort_fields(sort_string: str, split_character: str = ",") -> List[SortField]: | ||
sort_fields = [] | ||
# There could be many sort fields | ||
if sort_string: | ||
sort_by_fields = sort_string.split(",") | ||
for _to_sort in sort_by_fields: | ||
sort_fields.append(get_sort_field(_to_sort)) | ||
return sort_fields | ||
|
||
|
||
class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]): | ||
def __init__(self, model: Type[ModelType]): | ||
""" | ||
CRUD object with default methods to Create, Read, Update, Delete (CRUD). | ||
**Parameters** | ||
* `model`: A SQLAlchemy model class | ||
* `schema`: A Pydantic model (schema) class | ||
""" | ||
self.model = model | ||
|
||
def get(self, db_session: Session, id: IDType) -> Optional[ModelType]: | ||
return db_session.query(self.model).filter(self.model.id == id).first() # type: ignore | ||
|
||
def get_multi( | ||
self, | ||
db_session: Session, | ||
*, | ||
skip: int = 0, | ||
limit: int = 100, | ||
sort_by: Optional[str] = None, | ||
filter_by: Optional[Dict[str, str]] = None, | ||
) -> List[ModelType]: | ||
|
||
sort_spec_pydantic = get_sort_fields(sort_by) | ||
filter_spec_pydantic = get_filter_fields(filter_by) | ||
|
||
sort_spec = [x.dict(exclude_none=True) for x in sort_spec_pydantic] | ||
filter_spec = [x.dict(exclude_none=True) for x in filter_spec_pydantic] | ||
|
||
query = db_session.query(self.model) | ||
query = apply_filters(query, filter_spec) | ||
query = apply_sort(query, sort_spec) | ||
|
||
count = query.count() | ||
query = query.offset(skip).limit(limit) | ||
|
||
return query.all() | ||
|
||
def create(self, db_session: Session, *, obj_in: CreateSchemaType) -> ModelType: | ||
obj_in_data = jsonable_encoder(obj_in) | ||
db_obj = self.model(**obj_in_data) | ||
db_session.add(db_obj) | ||
db_session.commit() | ||
db_session.refresh(db_obj) | ||
return db_obj | ||
|
||
def update(self, db_session: Session, *, db_obj: ModelType, obj_in: UpdateSchemaType) -> ModelType: | ||
obj_data = jsonable_encoder(db_obj) | ||
update_data = obj_in.dict(skip_defaults=True) | ||
for field in obj_data: | ||
if field in update_data: | ||
setattr(db_obj, field, update_data[field]) | ||
db_session.add(db_obj) | ||
db_session.commit() | ||
db_session.refresh(db_obj) | ||
return db_obj | ||
|
||
def remove(self, db_session: Session, *, id: IDType) -> ModelType: | ||
obj = db_session.query(self.model).get(id) | ||
db_session.delete(obj) | ||
db_session.commit() | ||
return obj |
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you think it would make sense to change the name here to
CRUDRoute
rather thanBaseRoute
?BaseRoute
seems a little overly generic to me.