Skip to content

Commit

Permalink
Migrate to fastapi and add model (#250)
Browse files Browse the repository at this point in the history
* migrate to fastapi and refactor

* fix requirements

* Refactor and add models

* Refactor and cleanup
  • Loading branch information
Tom2rec authored Jan 15, 2025
1 parent e786e34 commit cf1ef75
Show file tree
Hide file tree
Showing 20 changed files with 206 additions and 164 deletions.
9 changes: 6 additions & 3 deletions src/backend/ml_api/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,16 @@ conda install pytorch torchvision torchaudio pytorch-cuda=11.6 -c pytorch -c nvi
conda install conda-forge::fasttext
```

Run the app:
Run the app from the *ml_api* directory:

```
export FLASK_APP=app.py
flask run
uvicorn app:app
```

API will be available at [127.0.0.1:8000](http://127.0.0.1:8000)

Swagger UI will be available at [127.0.0.1:8000/docs](http://127.0.0.1:8000/docs)

Run tests:

```
Expand Down
52 changes: 52 additions & 0 deletions src/backend/ml_api/app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from fastapi import FastAPI
from fastapi.responses import HTMLResponse
from routers import classify, summarize, question, model
from utils.logging_config import setup_logging

# Logging is configured first, before the application object exists.
setup_logging()

app = FastAPI()

# Mount each feature router under its own URL prefix and Swagger tag,
# in the same order the sections should be registered.
for _feature, _prefix, _tag in (
    (classify, "/classify", "Classification"),
    (summarize, "/summarize", "Summarization"),
    (question, "/question", "Question Answering"),
    (model, "/model", "Models"),
):
    app.include_router(_feature.router, prefix=_prefix, tags=[_tag])

@app.get("/", response_class=HTMLResponse, include_in_schema=False)
async def index():
    # Static landing page served at the root URL; it only links to "/docs".
    # The route is excluded from the generated schema (include_in_schema=False).
    landing_page = """
<html>
<head>
<style>
body {
display: flex;
justify-content: center;
align-items: center;
height: 100vh;
margin: 0;
font-family: 'Apple System', sans-serif;
text-align: center;
}
h1 {
font-size: 4em;
margin-bottom: 20px;
}
a {
font-size: 1.5em;
color: #007bff;
text-decoration: none;
font-weight: bold;
}
a:hover {
text-decoration: underline;
}
</style>
</head>
<body>
<div>
<h1>Machine Learning API</h1>
<a href="/docs">Go to API Documentation</a>
</div>
</body>
</html>
"""
    return landing_page
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import re
import fasttext

from flaskr.classifiers.binary.base import BaseClassifier
from classifiers.binary.base import BaseClassifier


def write_temp_fasttext_train_file(
Expand Down
27 changes: 27 additions & 0 deletions src/backend/ml_api/external_models/transformers_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from utils.device import get_device
import logging

class TransformersModel:
    """Seq2seq model and tokenizer pair loaded via ``from_pretrained``.

    ``__init__`` sets three attributes: ``device`` (the result of
    ``get_device()``), ``tokenizer`` and ``model`` (moved to ``device``).
    """

    def __init__(self, model_name: str):
        """Load the tokenizer and the model weights for ``model_name``.

        Args:
            model_name: Identifier handed to ``AutoTokenizer.from_pretrained``
                and ``AutoModelForSeq2SeqLM.from_pretrained``.
        """
        self.device = get_device()
        logging.info("Loading model and tokenizer...")

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        logging.info("Tokenizer loaded")

        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        self.model.to(self.device)
        logging.info("Model loaded")

    def generate_response(self, text: str, max_length=32, num_beams=4):
        """Generate text for ``text`` with beam search and return it as a string.

        Args:
            text: Prompt to encode and feed to the model.
            max_length: Forwarded to ``model.generate`` as ``max_length``.
            num_beams: Forwarded to ``model.generate`` as ``num_beams``.

        Returns:
            The decoded first output sequence, without special tokens and with
            surrounding whitespace stripped.
        """
        inputs = self.tokenizer.encode(text, return_tensors="pt").to(self.device)
        outputs = self.model.generate(
            inputs,
            max_length=max_length,
            num_beams=num_beams,
            early_stopping=True,
        )
        # Let the tokenizer drop special tokens by id. Deleting the substrings
        # "<pad>", "<s>" and "</s>" after decoding would also eat legitimate
        # generated text that contains them, and misses other special tokens.
        response = self.tokenizer.decode(
            outputs[0].to("cpu"), skip_special_tokens=True
        )
        return response.strip()

# Registry of ready-to-use models keyed by model name. Every entry is
# instantiated right here, when this module is imported, in the listed order.
_MODEL_NAMES = (
    "bigscience/T0_3B",
    "google/flan-t5-small",
    "geektech/flan-t5-base-gpt4-relation",
)

models = {name: TransformersModel(name) for name in _MODEL_NAMES}
158 changes: 0 additions & 158 deletions src/backend/ml_api/flaskr/app.py

This file was deleted.

6 changes: 4 additions & 2 deletions src/backend/ml_api/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
transformers==4.34.0
flask==3.0.0
torch==2.5.1
requests==2.31.0
sentencepiece==0.1.97
gunicorn==21.2.0
protobuf==4.24.4
pytest==8.3.3
openai==1.55.1
scikit-learn==1.3.2
scikit-learn==1.3.2
numpy==1.24.4
fastapi[standard]==0.115.6
File renamed without changes.
33 changes: 33 additions & 0 deletions src/backend/ml_api/routers/classify.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import logging
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
from typing import Dict
from classifiers.binary.fasttext_classifier import FastTextClassifier

router = APIRouter()

# Shape shared by both request fields: outer key -> flat string record.
# NOTE(review): outer keys are presumably item identifiers; confirm with caller.
_Records = Dict[str, Dict[str, str]]


class ClassifyInput(BaseModel):
    # Labelled examples; the handler reads "title" and "decision" from each record.
    xy_train: _Records
    # Items to label; the handler reads "title" from each record.
    x_pred: _Records


class ClassifyOutput(BaseModel):
    # Predictions returned by the classifier.
    y_pred: Dict[str, str]
    # Name of the algorithm that produced the predictions.
    algorithm_id: str

@router.post("/")
async def classify(data: ClassifyInput) -> ClassifyOutput:
    """Train a FastText classifier on the labelled records and label the rest.

    Args:
        data: Request body. Each ``xy_train`` record must contain ``"title"``
            and ``"decision"``; each ``x_pred`` record must contain ``"title"``.

    Returns:
        A ``ClassifyOutput`` holding the predictions and the algorithm id
        ``"FastText"``.

    Raises:
        HTTPException: 422 when a record lacks a required field; 503 when
            training or prediction fails.
    """
    # Extract the columns up front so a malformed record is reported as a
    # client error instead of being swallowed into the generic 503 below.
    try:
        train_titles = [record["title"] for record in data.xy_train.values()]
        train_labels = [record["decision"] for record in data.xy_train.values()]
        pred_titles = [record["title"] for record in data.x_pred.values()]
    except KeyError as e:
        raise HTTPException(
            status_code=422,
            detail=f"Missing required field {e} in input record.",
        ) from e

    # NOTE(review): train/predict are synchronous calls inside an async
    # handler; confirm whether they should be offloaded to a worker thread.
    try:
        algorithm = FastTextClassifier()
        algorithm.train(input_data=train_titles, true_labels=train_labels)

        predictions = algorithm.predict(pred_titles)
        return ClassifyOutput(y_pred=predictions, algorithm_id="FastText")
    except Exception as e:
        # logging.exception keeps the traceback, which logging.error dropped.
        logging.exception("Error during classification: %s", e)
        raise HTTPException(
            status_code=503,
            detail="Service Unavailable: Unable to process the request.",
        ) from e
11 changes: 11 additions & 0 deletions src/backend/ml_api/routers/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from fastapi import APIRouter
from pydantic import BaseModel
from external_models.transformers_model import models

router = APIRouter()


class GetModelsResponse(BaseModel):
    # Names of the models that can be passed as "model" to the other endpoints.
    # Typed as list[str] (not bare list) so the elements are validated and the
    # generated OpenAPI schema documents them as strings.
    models: list[str]

@router.get("/")
async def get_models() -> GetModelsResponse:
    """Return the names of all entries in the ``models`` registry."""
    available_names = list(models)
    return GetModelsResponse(models=available_names)
25 changes: 25 additions & 0 deletions src/backend/ml_api/routers/question.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import logging
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
from external_models.transformers_model import models
from typing import Dict

router = APIRouter()


class QuestionInput(BaseModel):
    # Prompt passed unchanged to the model.
    text: str
    # Key used to look the model up in the `models` registry.
    model: str


class QuestionResponse(BaseModel):
    # Text generated by the model.
    response: str

@router.post("/")
async def question(data: QuestionInput) -> QuestionResponse:
    """Answer a prompt with the requested model.

    Args:
        data: Request body with the prompt ``text`` and the ``model`` key to
            look up in the ``models`` registry.

    Returns:
        The generated text wrapped in a ``QuestionResponse``.

    Raises:
        HTTPException: 404 when ``data.model`` is not a registered model;
            503 when generation fails.
    """
    # Checked outside the try block: an unknown model name is a client error
    # and must not be turned into a 503 by the broad handler below.
    if data.model not in models:
        raise HTTPException(status_code=404, detail=f"Unknown model: {data.model}")

    # NOTE(review): generate_response is a synchronous call inside an async
    # handler; confirm whether it should be offloaded to a worker thread.
    try:
        model = models[data.model]
        response = model.generate_response(data.text)
        logging.debug("text: %s, response: %s", data.text, response)
        return QuestionResponse(response=response)
    except Exception as e:
        # logging.exception keeps the traceback, which logging.error dropped.
        logging.exception("Error during question answering: %s", e)
        raise HTTPException(
            status_code=503,
            detail="Service Unavailable: Unable to process the request.",
        ) from e
25 changes: 25 additions & 0 deletions src/backend/ml_api/routers/summarize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import logging
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
from external_models.transformers_model import models
from typing import Dict

router = APIRouter()


class SummarizeInputRequest(BaseModel):
    # Text to be summarized.
    text: str
    # Key used to look the model up in the `models` registry.
    model: str


class SummarizeOutputResponse(BaseModel):
    # Summary generated by the model.
    response: str

@router.post("/")
async def summarize(data: SummarizeInputRequest) -> SummarizeOutputResponse:
    """Summarize ``data.text`` with the requested model.

    Args:
        data: Request body with the ``text`` to summarize and the ``model``
            key to look up in the ``models`` registry.

    Returns:
        The generated summary wrapped in a ``SummarizeOutputResponse``.

    Raises:
        HTTPException: 404 when ``data.model`` is not a registered model;
            503 when generation fails.
    """
    # Checked outside the try block: an unknown model name is a client error
    # and must not be turned into a 503 by the broad handler below.
    if data.model not in models:
        raise HTTPException(status_code=404, detail=f"Unknown model: {data.model}")

    # NOTE(review): generate_response is a synchronous call inside an async
    # handler; confirm whether it should be offloaded to a worker thread.
    try:
        model = models[data.model]
        summary = model.generate_response(f"Summarize the following text: {data.text}")
        return SummarizeOutputResponse(response=summary)
    except Exception as e:
        # logging.exception keeps the traceback, which logging.error dropped.
        logging.exception("Error during summarization: %s", e)
        raise HTTPException(
            status_code=503,
            detail="Service Unavailable: Unable to process the request.",
        ) from e

Empty file.
12 changes: 12 additions & 0 deletions src/backend/ml_api/utils/device.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import torch
import logging

def get_device():
    """Choose the torch device to run on.

    CUDA is preferred; MPS is used when it is both available and built;
    otherwise the CPU is returned. The choice is logged at INFO level.
    """
    if torch.cuda.is_available():
        device_name, note = "cuda", "CUDA is available."
    elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
        device_name, note = "mps", "MPS is available."
    else:
        device_name, note = "cpu", "CUDA and MPS are not available. Using CPU."

    logging.info(note)
    return torch.device(device_name)
Loading

0 comments on commit cf1ef75

Please sign in to comment.