Migrate to fastapi and add model (#250)
* migrate to fastapi and refactor
* fix requirements
* Refactor and add models
* Refactor and cleanup
Showing 20 changed files with 206 additions and 164 deletions.
@@ -0,0 +1,52 @@
```python
from fastapi import FastAPI
from fastapi.responses import HTMLResponse
from routers import classify, summarize, question, model
from utils.logging_config import setup_logging

setup_logging()

app = FastAPI()

app.include_router(classify.router, prefix="/classify", tags=["Classification"])
app.include_router(summarize.router, prefix="/summarize", tags=["Summarization"])
app.include_router(question.router, prefix="/question", tags=["Question Answering"])
app.include_router(model.router, prefix="/model", tags=["Models"])

@app.get("/", response_class=HTMLResponse, include_in_schema=False)
async def index():
    return """
    <html>
        <head>
            <style>
                body {
                    display: flex;
                    justify-content: center;
                    align-items: center;
                    height: 100vh;
                    margin: 0;
                    font-family: 'Apple System', sans-serif;
                    text-align: center;
                }
                h1 {
                    font-size: 4em;
                    margin-bottom: 20px;
                }
                a {
                    font-size: 1.5em;
                    color: #007bff;
                    text-decoration: none;
                    font-weight: bold;
                }
                a:hover {
                    text-decoration: underline;
                }
            </style>
        </head>
        <body>
            <div>
                <h1>Machine Learning API</h1>
                <a href="/docs">Go to API Documentation</a>
            </div>
        </body>
    </html>
    """
```
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
@@ -0,0 +1,27 @@
```python
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from utils.device import get_device
import logging

class TransformersModel:
    def __init__(self, model_name: str):
        self.device = get_device()
        logging.info("Loading model and tokenizer...")

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        logging.info("Tokenizer loaded")

        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        self.model.to(self.device)
        logging.info("Model loaded")

    def generate_response(self, text: str, max_length=32, num_beams=4):
        inputs = self.tokenizer.encode(text, return_tensors="pt").to(self.device)
        outputs = self.model.generate(inputs, max_length=max_length, num_beams=num_beams, early_stopping=True)
        response = self.tokenizer.decode(outputs[0].to("cpu"))
        return response.replace("<pad>", "").replace("<s>", "").replace("</s>", "").strip()

models = {
    "bigscience/T0_3B": TransformersModel("bigscience/T0_3B"),
    "google/flan-t5-small": TransformersModel("google/flan-t5-small"),
    "geektech/flan-t5-base-gpt4-relation": TransformersModel("geektech/flan-t5-base-gpt4-relation"),
}
```
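A hedged usage sketch of the registry above. Note that `models` is built at import time, so importing this module downloads and loads all three checkpoints up front (T0_3B alone is several gigabytes); the prompt text below is made up:

```python
# Illustrative only: the registry keys are the Hugging Face model IDs
# defined above; importing triggers loading of every registered model.
from external_models.transformers_model import models

flan = models["google/flan-t5-small"]
print(flan.generate_response("Answer yes or no: is water wet?"))
```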
This file was deleted.
@@ -1,9 +1,11 @@
```
transformers==4.34.0
flask==3.0.0
torch==2.5.1
requests==2.31.0
sentencepiece==0.1.97
gunicorn==21.2.0
protobuf==4.24.4
pytest==8.3.3
openai==1.55.1
scikit-learn==1.3.2
numpy==1.24.4
fastapi[standard]==0.115.6
```
File renamed without changes.
@@ -0,0 +1,33 @@
```python
import logging
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
from typing import Dict
from classifiers.binary.fasttext_classifier import FastTextClassifier

router = APIRouter()

class ClassifyInput(BaseModel):
    xy_train: Dict[str, Dict[str, str]]
    x_pred: Dict[str, Dict[str, str]]

class ClassifyOutput(BaseModel):
    y_pred: Dict[str, str]
    algorithm_id: str

@router.post("/")
async def classify(data: ClassifyInput) -> ClassifyOutput:
    try:
        xy_train = data.xy_train
        x_pred = data.x_pred

        algorithm = FastTextClassifier()
        algorithm.train(
            input_data=[x["title"] for x in xy_train.values()],
            true_labels=[x["decision"] for x in xy_train.values()],
        )

        predictions = algorithm.predict([x["title"] for x in x_pred.values()])
        return ClassifyOutput(y_pred=predictions, algorithm_id="FastText")
    except Exception as e:
        logging.error(f"Error during classification: {e}")
        raise HTTPException(status_code=503, detail="Service Unavailable: Unable to process the request.")
```
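For reference, a hedged sketch of a request payload matching the `ClassifyInput` schema above (the IDs, titles, and labels are invented; the handler only reads the `title` and `decision` keys):

```python
import httpx  # any HTTP client works; httpx is bundled with fastapi[standard]

# Hypothetical payload: train on labeled titles, predict on unlabeled ones.
payload = {
    "xy_train": {
        "doc1": {"title": "A study of screening methods", "decision": "include"},
        "doc2": {"title": "Cooking recipes weekly", "decision": "exclude"},
    },
    "x_pred": {
        "doc3": {"title": "Screening methods revisited"},
    },
}
resp = httpx.post("http://localhost:8000/classify/", json=payload)
print(resp.json())  # expected shape: {"y_pred": {...}, "algorithm_id": "FastText"}
```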
@@ -0,0 +1,11 @@
```python
from fastapi import APIRouter
from pydantic import BaseModel
from external_models.transformers_model import models

router = APIRouter()

class GetModelsResponse(BaseModel):
    models: list

@router.get("/")
async def get_models() -> GetModelsResponse:
    return GetModelsResponse(models=list(models.keys()))
```
@@ -0,0 +1,25 @@
```python
import logging
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
from external_models.transformers_model import models
from typing import Dict

router = APIRouter()

class QuestionInput(BaseModel):
    text: str
    model: str

class QuestionResponse(BaseModel):
    response: str

@router.post("/")
async def question(data: QuestionInput) -> QuestionResponse:
    try:
        model = models[data.model]
        response = model.generate_response(data.text)
        logging.debug("text: %s, response: %s", data.text, response)
        return QuestionResponse(response=response)
    except Exception as e:
        logging.error(f"Error during question answering: {e}")
        raise HTTPException(status_code=503, detail="Service Unavailable: Unable to process the request.")
```
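The generation endpoints can be exercised the same way; below is a hedged sketch that first discovers the registered model IDs via the `/model/` route and then posts a question. The host and port assume a default local run, and `/summarize/` accepts the same `{"text", "model"}` payload:

```python
import httpx

base = "http://localhost:8000"

# Discover registered model IDs from the model router above.
model_id = httpx.get(f"{base}/model/").json()["models"][0]

# Ask a question; the summarize endpoint takes the same request shape.
resp = httpx.post(
    f"{base}/question/",
    json={"text": "What is the capital of France?", "model": model_id},
    timeout=120,  # generation on CPU can be slow
)
print(resp.json()["response"])
```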
@@ -0,0 +1,25 @@
```python
import logging
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
from external_models.transformers_model import models
from typing import Dict

router = APIRouter()

class SummarizeInputRequest(BaseModel):
    text: str
    model: str

class SummarizeOutputResponse(BaseModel):
    response: str

@router.post("/")
async def summarize(data: SummarizeInputRequest) -> SummarizeOutputResponse:
    try:
        model = models[data.model]
        summary = model.generate_response(f"Summarize the following text: {data.text}")
        return SummarizeOutputResponse(response=summary)
    except Exception as e:
        logging.error(f"Error during summarization: {e}")
        raise HTTPException(status_code=503, detail="Service Unavailable: Unable to process the request.")
```
Empty file.
@@ -0,0 +1,12 @@
```python
import torch
import logging

def get_device():
    if torch.cuda.is_available():
        logging.info("CUDA is available.")
        return torch.device("cuda")
    elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
        logging.info("MPS is available.")
        return torch.device("mps")
    logging.info("CUDA and MPS are not available. Using CPU.")
    return torch.device("cpu")
```