-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit b3096c6
Showing
14 changed files
with
550 additions
and
0 deletions.
There are no files selected for viewing
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
version: '3' | ||
|
||
services: | ||
ml-service: | ||
build: ./ml-service | ||
ports: | ||
- 5000:5000 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
FROM python:3.11.7 | ||
|
||
# Set the working directory | ||
WORKDIR /app | ||
|
||
# Copy the current directory contents into the container at /app | ||
COPY . /app/ | ||
|
||
# Install any needed packages specified in requirements.txt | ||
RUN pip install --no-cache-dir -r requirements.txt | ||
RUN pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121 | ||
#RUN pip install flash-attn --no-build-isolation | ||
|
||
# Make port 80 available to the world outside this container | ||
EXPOSE 8000 | ||
|
||
# Run app.py when the container launches | ||
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8000"] |
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
from fastapi import FastAPI, APIRouter, UploadFile, File, Form, Query | ||
import uvicorn | ||
from inference import VisionInference,SpeechInference | ||
import logging | ||
import os | ||
from utils import extract_json_from_string | ||
|
||
logging.basicConfig(level=logging.INFO) | ||
|
||
app = FastAPI() | ||
router = APIRouter() | ||
|
||
vision_inference = VisionInference() | ||
#speech_inference = SpeechInference() | ||
|
||
@router.get("/") | ||
async def home(): | ||
return {"message": "Deep Learning services"} | ||
|
||
|
||
@router.post("/vision") | ||
async def process_vision( | ||
image: UploadFile = File(...), | ||
instruction: str = Form(...), | ||
json_output: bool = Query(False) | ||
): | ||
# Save the image file to a temporary directory | ||
file_path = f"tmp/{image.filename}" | ||
with open(file_path, "wb") as f: | ||
f.write(await image.read()) | ||
|
||
# Process the image based on the instruction | ||
response = vision_inference.predict(file_path, instruction) | ||
|
||
# Remove the temporary image file | ||
os.remove(file_path) | ||
|
||
if json_output: | ||
return extract_json_from_string(response) | ||
else: | ||
return {"response": response} | ||
|
||
@router.post("/speech") | ||
async def process_speech( | ||
audio: UploadFile = File(...) | ||
): | ||
# Save the audio file to a temporary directory | ||
file_path = f"tmp/{audio.filename}" | ||
with open(file_path, "wb") as f: | ||
f.write(await audio.read()) | ||
|
||
# Process the audio file | ||
response = speech_inference.predict(file_path) | ||
|
||
# Remove the temporary audio file | ||
os.remove(file_path) | ||
|
||
return {"response": response} | ||
|
||
|
||
# @router.post("/sql") | ||
# async def process_prmpt( | ||
# prompt: str = Form(...) | ||
# ): | ||
|
||
# response = SQLInference.predict(prompt) | ||
|
||
# return {"response": response} | ||
|
||
app.include_router(router) | ||
|
||
if __name__ == "__main__": | ||
uvicorn.run("app:app", reload=False, port=5000, host="0.0.0.0") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
from model import VisionModel,SpeechModel | ||
from transformers import pipeline | ||
from PIL import Image | ||
|
||
class VisionInference: | ||
def __init__(self): | ||
model_obj = VisionModel() | ||
self.model = model_obj.load_model() | ||
self.processor = model_obj.load_processor() | ||
|
||
def predict(self, image_path: str, instruction: str) -> str: | ||
image = Image.open(image_path) | ||
|
||
messages = [ | ||
{"role": "user", "content": f"<|image_1|>\n{instruction}"} | ||
] | ||
prompt = self.processor.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) | ||
inputs = self.processor(prompt, [image], return_tensors="pt").to("cuda:0") | ||
|
||
generation_args = { | ||
"max_new_tokens": 2048, | ||
"temperature": 0.0 | ||
} | ||
|
||
generate_ids = self.model.generate(**inputs, eos_token_id=self.processor.tokenizer.eos_token_id, **generation_args) | ||
|
||
# remove input tokens | ||
generate_ids = generate_ids[:, inputs['input_ids'].shape[1]:] | ||
response = self.processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] | ||
|
||
return response | ||
|
||
|
||
class SpeechInference: | ||
def __init__(self): | ||
model_obj = SpeechModel() | ||
self.model = model_obj.load_model() | ||
self.processor = model_obj.load_processor() | ||
self.pipeline = pipe = pipeline( | ||
"automatic-speech-recognition", | ||
model=self.model, | ||
tokenizer=self.processor.tokenizer, | ||
feature_extractor=self.processor.feature_extractor, | ||
max_new_tokens=128, | ||
chunk_length_s=30, | ||
batch_size=16, | ||
return_timestamps=True, | ||
# torch_dtype=torch_dtype, | ||
# device=device, | ||
) | ||
|
||
def predict(self, audio_path: str) -> str: | ||
|
||
transcription = self.pipeline(audio_path) | ||
|
||
return transcription | ||
|
||
|
||
class SQLInference: | ||
def __init__(self): | ||
model_obj = SQLModel() | ||
self.model = model_obj.load_model() | ||
self.tokenizer = model_obj.load_tokenizer() | ||
|
||
def predict(self, prompt: str) -> str: | ||
messages = [ | ||
{"role": "user", "content": prompt} | ||
] | ||
|
||
input_ids = self.tokenizer.apply_chat_template(conversation=messages, tokenize=True, add_generation_prompt=True, return_tensors='pt') | ||
output_ids = self.model.generate(input_ids.to('cuda'), max_new_tokens=256) | ||
response = self.tokenizer.decode(output_ids[0][input_ids.shape[1]:], skip_special_tokens=True) | ||
|
||
return response |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
from transformers import AutoProcessor,AutoModelForCausalLM,AutoModelForSpeechSeq2Seq | ||
import torch | ||
from peft import PeftModel | ||
|
||
CACHE_DIR = 'models' | ||
|
||
class VisionModel: | ||
|
||
def __init__(self): | ||
self.model_id = "microsoft/Phi-3-vision-128k-instruct" | ||
|
||
def load_model(self): | ||
try: | ||
model = AutoModelForCausalLM.from_pretrained(self.model_id, | ||
device_map="cuda", | ||
trust_remote_code=True, | ||
torch_dtype=torch.float16, | ||
_attn_implementation='flash_attention_2') | ||
except Exception as e: | ||
print(e) | ||
model = AutoModelForCausalLM.from_pretrained(self.model_id, | ||
device_map="cuda", | ||
trust_remote_code=True, | ||
torch_dtype=torch.float16, | ||
_attn_implementation='eager') | ||
return model | ||
|
||
def load_processor(self): | ||
processor = AutoProcessor.from_pretrained(self.model_id, trust_remote_code=True, cache_dir=CACHE_DIR) | ||
return processor | ||
|
||
|
||
class SpeechModel: | ||
|
||
def __init__(self): | ||
self.model_id = "openai/whisper-small" | ||
|
||
def load_model(self): | ||
model = AutoModelForSpeechSeq2Seq.from_pretrained(self.model_id, | ||
device_map="cuda", | ||
torch_dtype=torch.float16, | ||
low_cpu_mem_usage=True, | ||
use_safetensors=True) | ||
return model | ||
|
||
def load_processor(self): | ||
processor = AutoProcessor.from_pretrained(self.model_id, trust_remote_code=True, cache_dir=CACHE_DIR) | ||
return processor | ||
|
||
class SQLModel: | ||
|
||
def __init__(self): | ||
self.base_model_id = "defog/sqlcoder-7b-2" | ||
self.adapter_model_id = "manishdighore/intersystems-sql-coder" | ||
|
||
def load_model(self): | ||
model = AutoModelForCausalLM.from_pretrained( | ||
self.base_model_id, | ||
device_map="cuda", | ||
torch_dtype=torch.float16 | ||
) | ||
model = PeftModel.from_pretrained(model,self.adapter_model_id) | ||
model = model.to("cuda") | ||
model.eval() | ||
return model | ||
|
||
def load_tokenizer(self): | ||
tokenizer = AutoTokenizer.from_pretrained(self.base_model_id) | ||
return tokenizer |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
accelerate==0.32.1 | ||
annotated-types==0.7.0 | ||
anyio==4.4.0 | ||
certifi==2024.7.4 | ||
charset-normalizer==3.3.2 | ||
click==8.1.7 | ||
colorama==0.4.6 | ||
dnspython==2.6.1 | ||
email_validator==2.2.0 | ||
fastapi==0.111.1 | ||
fastapi-cli==0.0.4 | ||
filelock==3.13.1 | ||
fsspec==2024.2.0 | ||
h11==0.14.0 | ||
httpcore==1.0.5 | ||
httptools==0.6.1 | ||
httpx==0.27.0 | ||
huggingface-hub==0.23.5 | ||
idna==3.7 | ||
intel-openmp==2021.4.0 | ||
Jinja2==3.1.4 | ||
markdown-it-py==3.0.0 | ||
MarkupSafe==2.1.5 | ||
mdurl==0.1.2 | ||
mkl==2021.4.0 | ||
mpmath==1.3.0 | ||
networkx==3.2.1 | ||
numpy==1.26.3 | ||
packaging==24.1 | ||
pillow==10.2.0 | ||
psutil==6.0.0 | ||
pydantic==2.8.2 | ||
pydantic_core==2.20.1 | ||
Pygments==2.18.0 | ||
python-dotenv==1.0.1 | ||
python-multipart==0.0.9 | ||
PyYAML==6.0.1 | ||
regex==2024.5.15 | ||
requests==2.32.3 | ||
rich==13.7.1 | ||
safetensors==0.4.3 | ||
shellingham==1.5.4 | ||
sniffio==1.3.1 | ||
starlette==0.37.2 | ||
sympy==1.12 | ||
tbb==2021.11.0 | ||
tokenizers==0.19.1 | ||
# torch==2.3.1+cu121 | ||
# torchaudio==2.3.1+cu121 | ||
# torchvision==0.18.1+cu121 | ||
tqdm==4.66.4 | ||
transformers==4.42.4 | ||
typer==0.12.3 | ||
typing_extensions==4.12.2 | ||
urllib3==2.2.2 | ||
uvicorn==0.30.1 | ||
watchfiles==0.22.0 | ||
websockets==12.0 |
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
import re | ||
import json | ||
|
||
def extract_json_from_string(string_data): | ||
""" | ||
Extracts JSON data from a given string. | ||
Parameters: | ||
string_data (str): The input string containing JSON data. | ||
Returns: | ||
dict or list: The extracted JSON data as a Python object (dict or list). | ||
""" | ||
# Use regular expressions to find the JSON data within the string | ||
json_data_match = re.search(r'```json(.*?)```', string_data, re.DOTALL) | ||
|
||
# Check if JSON data was found | ||
if json_data_match: | ||
json_data = json_data_match.group(1).strip() | ||
# Convert the JSON string to a Python object | ||
json_object = json.loads(json_data) | ||
return json_object | ||
else: | ||
raise ValueError("No JSON data found in the string.") |
Oops, something went wrong.