generated from guardrails-ai/validator-template
-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #20 from guardrails-ai/feat/serve
Added Inference Spec & CI
Showing
7 changed files
with
358 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
name: Deploy Inference Apps

on:
  push:
    branches:
      - main
  workflow_dispatch:

# OIDC token is required so jobs can assume the AWS CI roles.
permissions:
  id-token: write
  contents: read

jobs:
  deploy_dev:
    name: Deploy Inference Apps (Development)
    runs-on: ubuntu-latest
    env:
      ENV: dev
      AWS_REGION: us-east-1
      AWS_CI_ROLE: ${{ secrets.AWS_INFER_CI_ROLE__DEV }}
    steps:
      - name: Configure AWS credentials
        uses: aws-actions/configure-aws-credentials@v4
        with:
          aws-region: ${{ env.AWS_REGION }}
          role-to-assume: ${{ env.AWS_CI_ROLE }}

      - name: Deploy Ray Serve
        shell: bash
        run: |
          RAY_CLUSTER_NAME=ray-cluster-$ENV
          # Resolve the head node (private) and bastion (public) IPs by Name tag.
          # Use AWS_REGION, which this job defines (the original referenced
          # AWS_DEFAULT_REGION, which the job env does not set).
          RAY_PRIVATE_IP=$(aws ec2 describe-instances --region $AWS_REGION --filters "Name=tag:Name,Values=ray-cluster-$ENV-head" --query "Reservations[*].Instances[*].PrivateIpAddress" --output text)
          RAY_BASTION_PUBLIC_IP=$(aws ec2 describe-instances --region $AWS_REGION --filters "Name=tag:Name,Values=ray-cluster-$ENV-bastion" --query "Reservations[*].Instances[*].PublicIpAddress" --output text)
          RAY_CLUSTER_KEY_PAIR_FILE=$RAY_CLUSTER_NAME
          RAY_CLUSTER_SECRET_KEY_PAIR_NAME=$RAY_CLUSTER_NAME-key-pair-secret
          # Fetch the bastion SSH key from Secrets Manager and lock down its mode.
          aws secretsmanager get-secret-value --region $AWS_REGION --secret-id $RAY_CLUSTER_SECRET_KEY_PAIR_NAME --query SecretString --output text > ./${RAY_CLUSTER_KEY_PAIR_FILE}.pem
          chmod 400 ./${RAY_CLUSTER_KEY_PAIR_FILE}.pem
          echo "Deploying Ray Serve on $RAY_CLUSTER_NAME..."
          # NOTE(review): remote output is suppressed; drop the redirects if
          # deploy logs are needed for debugging.
          if ssh -o StrictHostKeyChecking=no -i ./${RAY_CLUSTER_KEY_PAIR_FILE}.pem ubuntu@$RAY_BASTION_PUBLIC_IP "source ~/.profile && bash bastion-ray-serve-deploy.sh $ENV" >/dev/null 2>&1; then
            echo "Deployment succeeded."
          else
            echo "Deployment failed."
            exit 1
          fi

  deploy_prod:
    name: Deploy Inference Apps (Production)
    runs-on: ubuntu-latest
    # `depends-on` is not a valid GitHub Actions job key; `needs` is what
    # gates the production deploy on a successful development deploy.
    needs: deploy_dev
    env:
      ENV: prod
      AWS_REGION: us-east-1
      AWS_CI_ROLE: ${{ secrets.AWS_INFER_CI_ROLE__PROD }}
    steps:
      - name: Configure AWS credentials
        uses: aws-actions/configure-aws-credentials@v4
        with:
          aws-region: ${{ env.AWS_REGION }}
          role-to-assume: ${{ env.AWS_CI_ROLE }}

      - name: Deploy Ray Serve
        shell: bash
        run: |
          RAY_CLUSTER_NAME=ray-cluster-$ENV
          # Resolve the head node (private) and bastion (public) IPs by Name tag.
          RAY_PRIVATE_IP=$(aws ec2 describe-instances --region $AWS_REGION --filters "Name=tag:Name,Values=ray-cluster-$ENV-head" --query "Reservations[*].Instances[*].PrivateIpAddress" --output text)
          RAY_BASTION_PUBLIC_IP=$(aws ec2 describe-instances --region $AWS_REGION --filters "Name=tag:Name,Values=ray-cluster-$ENV-bastion" --query "Reservations[*].Instances[*].PublicIpAddress" --output text)
          RAY_CLUSTER_KEY_PAIR_FILE=$RAY_CLUSTER_NAME
          RAY_CLUSTER_SECRET_KEY_PAIR_NAME=$RAY_CLUSTER_NAME-key-pair-secret
          # Fetch the bastion SSH key from Secrets Manager and lock down its mode.
          aws secretsmanager get-secret-value --region $AWS_REGION --secret-id $RAY_CLUSTER_SECRET_KEY_PAIR_NAME --query SecretString --output text > ./${RAY_CLUSTER_KEY_PAIR_FILE}.pem
          chmod 400 ./${RAY_CLUSTER_KEY_PAIR_FILE}.pem
          echo "Deploying Ray Serve on $RAY_CLUSTER_NAME..."
          if ssh -o StrictHostKeyChecking=no -i ./${RAY_CLUSTER_KEY_PAIR_FILE}.pem ubuntu@$RAY_BASTION_PUBLIC_IP "source ~/.profile && bash bastion-ray-serve-deploy.sh $ENV" >/dev/null 2>&1; then
            echo "Deployment succeeded."
          else
            echo "Deployment failed."
            exit 1
          fi
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
name: Sagemaker ECR Publish (RC)

on:
  push:
    branches:
      - main
  workflow_dispatch:
    inputs:
      is_release_candidate:
        description: 'Is this a release candidate?'
        required: true
        default: 'true'

# Needed for OIDC / assume role
permissions:
  id-token: write
  contents: read

jobs:
  publish_image:
    name: Publish Sagemaker Image (Release Candidate)
    runs-on: ubuntu-latest
    env:
      VALIDATOR_TAG_NAME: competitorcheck
      AWS_REGION: us-east-1
      WORKING_DIR: "./"
      AWS_CI_ROLE__PROD: ${{ secrets.AWS_CI_ROLE__PROD }}
      # Push events carry no inputs, so default pushes to release-candidate.
      AWS_ECR_RELEASE_CANDIDATE: ${{ inputs.is_release_candidate || 'true' }}
    steps:
      - name: Check out head
        uses: actions/checkout@v3
        with:
          persist-credentials: false

      - name: Set ECR Tag
        id: set-ecr-tag
        run: |
          # Quote both sides so the test stays well-formed even if the
          # expression ever expands to an empty string.
          if [ "${{ env.AWS_ECR_RELEASE_CANDIDATE }}" == "true" ]; then
            echo "This is a release candidate."
            echo "Setting tag to -rc"
            ECR_TAG=$VALIDATOR_TAG_NAME-rc
          else
            echo "This is a production image."
            ECR_TAG=$VALIDATOR_TAG_NAME
          fi
          echo "Setting ECR tag to $ECR_TAG"
          echo "ECR_TAG=$ECR_TAG" >> "$GITHUB_OUTPUT"

      - name: Set up QEMU
        uses: docker/setup-qemu-action@master
        with:
          platforms: linux/amd64

      - name: Set up Docker Buildx
        # The id is required: the build step below references
        # steps.buildx.outputs.name, which resolved to nothing without it.
        id: buildx
        uses: docker/setup-buildx-action@master
        with:
          platforms: linux/amd64

      - name: Configure AWS credentials
        uses: aws-actions/configure-aws-credentials@v4
        with:
          aws-region: ${{ env.AWS_REGION }}
          role-to-assume: ${{ env.AWS_CI_ROLE__PROD }}

      - name: Login to Amazon ECR
        id: login-ecr
        uses: aws-actions/amazon-ecr-login@v2
        with:
          mask-password: 'true'

      - name: Build & Push ECR Image
        uses: docker/build-push-action@v2
        with:
          builder: ${{ steps.buildx.outputs.name }}
          context: ${{ env.WORKING_DIR }}
          platforms: linux/amd64
          cache-from: type=gha
          cache-to: type=gha,mode=max
          push: true
          tags: 064852979926.dkr.ecr.us-east-1.amazonaws.com/gr-sagemaker-validator-images-prod:${{ steps.set-ecr-tag.outputs.ECR_TAG }}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
# Use an official PyTorch image with CUDA support
FROM pytorch/pytorch:1.13.1-cuda11.6-cudnn8-runtime

# Set the working directory
WORKDIR /app

# Copy only dependency metadata first so the pip layers cache well
# when application code changes.
COPY pyproject.toml .
COPY README.md .
COPY LICENSE .

# Install the project and its dependencies from pyproject.toml.
# --no-cache-dir keeps the pip download cache out of the image layers.
RUN pip install --no-cache-dir --upgrade pip setuptools wheel \
 && pip install --no-cache-dir .

# Serving stack for the FastAPI app, plus spaCy with CuPy for CUDA 11.6
RUN pip install --no-cache-dir fastapi "uvicorn[standard]" gunicorn "spacy[cuda116]"

# Copy the entire project code into the container (after deps, to reuse
# the cached install layers)
COPY . /app

# Copy the serve script into the container and make it executable;
# SageMaker starts the container with the "serve" command.
COPY serve /usr/local/bin/serve
RUN chmod +x /usr/local/bin/serve

# Environment name read by the app — presumably selects cuda vs cpu
# behavior downstream (TODO confirm against app code)
ENV env=prod

# Expose the port that the FastAPI app will run on
EXPOSE 8080

# Set the entrypoint for SageMaker to the serve script
ENTRYPOINT ["serve"]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
from fastapi import HTTPException | ||
from pydantic import BaseModel | ||
from typing import List, Tuple | ||
import spacy | ||
from models_host.base_inference_spec import BaseInferenceSpec | ||
|
||
class InferenceData(BaseModel):
    """One named input/output payload in the inference request/response schema."""
    name: str  # identifier, e.g. "text" or "competitors" (see process_request)
    shape: List[int]  # declared logical shape of the payload
    data: List  # the payload values themselves
    datatype: str  # wire datatype tag, e.g. "BYTES"
|
||
|
||
class InputRequest(BaseModel):
    """Incoming request body: a list of named inputs (expects "text" and "competitors")."""
    inputs: List[InferenceData]
|
||
class OutputResponse(BaseModel):
    """Response body: one InferenceData entry per input text."""
    modelname: str  # name of the model that produced the outputs
    modelversion: str  # model version string (currently hard-coded to "1" by infer)
    outputs: List[InferenceData]
|
||
class InferenceSpec(BaseInferenceSpec):
    """Inference spec that runs spaCy NER over texts and flags competitor mentions."""

    model_name = "en_core_web_trf"
    model = None  # populated by load()

    def load(self):
        """Load the spaCy model, downloading it first if it is not installed."""
        model_name = self.model_name
        print(f"Loading model {model_name}...")
        if not spacy.util.is_package(model_name):
            print(
                f"Spacy model {model_name} not installed. "
                "Download should start now and take a few minutes."
            )
            spacy.cli.download(model_name)  # type: ignore
        self.model = spacy.load(model_name)

    def process_request(self, input_request: InputRequest) -> Tuple[Tuple, dict]:
        """Extract (text_vals, competitors) positional args from the request.

        Raises:
            HTTPException: 400 when either the "text" or the "competitors"
                input is missing from the request.
        """
        # Initialize BOTH to None so a missing input is caught by the check
        # below. Previously `text_vals` was never initialized (a request
        # without a "text" input raised UnboundLocalError instead of a clean
        # 400), and `competitors` defaulted to [] which made the
        # `competitors is None` validation unreachable.
        text_vals = None
        competitors = None
        for inp in input_request.inputs:
            if inp.name == "text":
                text_vals = inp.data
            elif inp.name == "competitors":
                competitors = inp.data

        if text_vals is None or competitors is None:
            raise HTTPException(status_code=400, detail="Invalid input format")

        args = (text_vals, competitors)
        kwargs = {}
        return args, kwargs

    def infer(self, text_vals, competitors) -> OutputResponse:
        """Run NER on each text and collect entities whose text matches a competitor."""
        outputs = []
        for idx, text in enumerate(text_vals):
            doc = self.model(text)  # type: ignore

            # Keep only entity surface forms that exactly match a listed competitor.
            located_competitors = [
                ent.text for ent in doc.ents if ent.text in competitors
            ]

            outputs.append(
                InferenceData(
                    name=f"result{idx}",
                    datatype="BYTES",
                    shape=[1],
                    data=[located_competitors],
                )
            )

        return OutputResponse(
            modelname=self.model_name, modelversion="1", outputs=outputs
        )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
{ | ||
"inputs": [ | ||
{ | ||
"name": "text", | ||
"shape": [1], | ||
"data": ["Apple made the iphone"], | ||
"datatype": "BYTES" | ||
}, | ||
{ | ||
"name": "competitors", | ||
"shape": [1], | ||
"data": ["Apple"], | ||
"datatype": "BYTES" | ||
} | ||
] | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
#!/usr/bin/env python | ||
|
||
import multiprocessing | ||
import os | ||
import signal | ||
import subprocess | ||
import sys | ||
import spacy | ||
|
||
# Default worker count: half the available cores, but at least one, so each
# gunicorn worker has some CPU headroom.
cpu_count = multiprocessing.cpu_count()
default_worker_count = max(cpu_count // 2,1)

# Both knobs are overridable via the environment.
model_server_timeout = os.environ.get('MODEL_SERVER_TIMEOUT', 60)
model_server_workers = int(os.environ.get('MODEL_SERVER_WORKERS', default_worker_count))
|
||
def sigterm_handler(gunicorn_pid):
    """Relay SIGTERM to the gunicorn master, then exit this launcher with 0.

    If the master is already gone, os.kill raises OSError, which is ignored —
    either way the launcher terminates cleanly.
    """
    try:
        os.kill(gunicorn_pid, signal.SIGTERM)
    except OSError:
        pass  # gunicorn already exited; nothing to signal
    sys.exit(0)
|
||
def save_and_load_model():
    """Ensure the spaCy transformer model package is installed.

    No-op when the package is already present; otherwise enables the GPU and
    downloads it (the actual loading happens in the app workers).
    """
    model = "en_core_web_trf"
    if spacy.util.is_package(model):
        return  # already installed; nothing to do
    print(
        f"Spacy model {model} not installed. "
        "Download should start now and take a few minutes."
    )
    spacy.require_gpu()
    spacy.cli.download(model)  # type: ignore
|
||
def start_server():
    """Launch gunicorn serving the FastAPI app and block until it exits."""
    print(f'Starting the inference server with {model_server_workers} workers.')

    save_and_load_model()

    # Assemble the gunicorn command line up front for readability.
    command = [
        'gunicorn',
        '--timeout', str(model_server_timeout),
        '-k', 'uvicorn.workers.UvicornWorker',
        '-b', '0.0.0.0:8080',
        '-w', str(model_server_workers),
        'app:app',
    ]
    gunicorn = subprocess.Popen(command)

    # Relay SIGTERM (e.g. from the container runtime) to the gunicorn master.
    signal.signal(signal.SIGTERM, lambda a, b: sigterm_handler(gunicorn.pid))

    # Block until the Gunicorn process exits.
    gunicorn.wait()

    print('Inference server exiting')
|
||
# The main routine just invokes the start function.

if __name__ == '__main__':
    start_server()