Skip to content

Commit

Permalink
Modify vector database scripts to save and serve images directly from…
Browse files Browse the repository at this point in the history
… filesystem

- Update compute_vectors.py to save images to mounted bucket instead of base64 encoding
- Modify serve_vectordb.py to serve images from filesystem via new endpoint
- Update YAML configs to mount images directory consistently across scripts
- Refactor image handling to use file paths instead of base64 representations
  • Loading branch information
KeplerC committed Jan 31, 2025
1 parent 4cd1027 commit 9e349f6
Show file tree
Hide file tree
Showing 6 changed files with 82 additions and 26 deletions.
7 changes: 6 additions & 1 deletion examples/vector_database/build_vectordb.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ workdir: ~/skypilot/examples/vector_database

file_mounts:
/clip_embeddings:
name: sky-embeddings
name: sky-embedding
# this needs to be the same as the source in the compute_vectors.yaml
mode: MOUNT

Expand All @@ -13,6 +13,11 @@ file_mounts:
# this needs to be the same as the source in the serve_vectordb.yaml
mode: MOUNT

/images:
name: sky-images
# this needs to be the same as the source in compute_vectors.yaml
mode: MOUNT

setup: |
pip install chromadb pandas tqdm pyarrow
Expand Down
7 changes: 6 additions & 1 deletion examples/vector_database/compute_vectors.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,12 @@ num_nodes: 1

file_mounts:
/output:
name: sky-embeddings
name: sky-embedding
# this needs to be the same as the source in the build_vectordb.yaml
mode: MOUNT
/images:
name: sky-images
# this needs to be the same as the source in the build_vectordb.yaml
mode: MOUNT

envs:
Expand Down
6 changes: 3 additions & 3 deletions examples/vector_database/scripts/build_vectordb.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def main():
type=str,
default='',
help='Prefix path within mounted bucket to search for parquet files')

args = parser.parse_args()

# Create a temporary directory for building the database
Expand Down Expand Up @@ -121,11 +121,11 @@ def main():
try:
results = future.result()
if results:
for ids, embeddings, images_base64 in results:
for ids, embeddings, images_paths in results:
collection.add(
ids=list(ids),
embeddings=list(embeddings),
documents=list(images_base64)
documents=list(images_paths)
)
except Exception as e:
logger.error(f"Error processing file {file}: {str(e)}")
Expand Down
41 changes: 32 additions & 9 deletions examples/vector_database/scripts/compute_vectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class BatchProcessor():

def __init__(self,
output_path: str,
images_path: str = "/images",
model_name: str = "ViT-bigG-14",
dataset_name: str = "ILSVRC/imagenet-1k",
pretrained: str = "laion2b_s39b_b160k",
Expand All @@ -37,12 +38,16 @@ def __init__(self,
start_idx: int = 0,
end_idx: Optional[int] = None):
self.output_path = Path(output_path) # Convert to Path object
self.images_path = Path(images_path) # Path to store images
self.batch_size = batch_size
self.checkpoint_size = checkpoint_size
self.start_idx = start_idx
self.end_idx = end_idx
self._current_batch = []

# Create images directory if it doesn't exist
self.images_path.mkdir(parents=True, exist_ok=True)

# CLIP-specific attributes
self.model_name = model_name
self.pretrained = pretrained
Expand Down Expand Up @@ -96,6 +101,20 @@ async def do_data_loading(
logging.debug(
f"Error preprocessing image at index {idx}: {str(e)}")

def save_image(self, idx: int, image: Image.Image) -> str:
"""Save image to the mounted bucket and return its path."""
# Create a subdirectory based on the first few digits of the index to avoid too many files in one directory
subdir = str(idx // 100000).zfill(4)
save_dir = self.images_path / subdir
save_dir.mkdir(parents=True, exist_ok=True)

# Save image with index as filename
image_path = save_dir / f"{idx}.jpg"
image.save(image_path, format="JPEG", quality=95)

# Return relative path from images root
return str(Path(subdir) / f"{idx}.jpg")

async def do_batch_processing(
self, batch: List[Tuple[int, Tuple[torch.Tensor, Any]]]
) -> List[Tuple[int, bytes]]:
Expand All @@ -118,16 +137,14 @@ async def do_batch_processing(
# Convert to numpy arrays
embeddings = features.cpu().numpy()

# Convert original images to base64
images_base64 = {}
# Save images and store their paths
image_paths = {}
for idx, img in zip(indices, original_images):
buffered = BytesIO()
img.save(buffered, format="JPEG")
img_str = base64.b64encode(buffered.getvalue()).decode()
images_base64[idx] = img_str
image_path = self.save_image(idx, img)
image_paths[idx] = image_path

# Return both embeddings and images
return [(idx, pickle.dumps((images_base64[idx], arr)))
# Return both embeddings and image paths
return [(idx, pickle.dumps((image_paths[idx], arr)))
for idx, arr in zip(indices, embeddings)]

async def find_existing_progress(self) -> Tuple[int, int]:
Expand Down Expand Up @@ -253,6 +270,11 @@ async def main():
type=str,
default='ViT-bigG-14',
help='CLIP model name')

parser.add_argument('--images-path',
type=str,
default='/images',
help='Path to store images')

args = parser.parse_args()

Expand All @@ -267,7 +289,8 @@ async def main():
end_idx=args.end_idx,
batch_size=args.batch_size,
checkpoint_size=args.checkpoint_size,
model_name=args.model_name)
model_name=args.model_name,
images_path=args.images_path)

# Run processing
await processor.run()
Expand Down
39 changes: 29 additions & 10 deletions examples/vector_database/scripts/serve_vectordb.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@
import argparse
import base64
import logging
import os
from pathlib import Path
from typing import List, Optional

import chromadb
from fastapi import FastAPI
from fastapi import HTTPException
from fastapi.responses import HTMLResponse
from fastapi import FastAPI, HTTPException
from fastapi.responses import HTMLResponse, FileResponse
from fastapi.staticfiles import StaticFiles
import numpy as np
import open_clip
from pydantic import BaseModel
Expand All @@ -28,6 +30,7 @@
tokenizer = None
collection = None
device = None
images_dir = None


class SearchQuery(BaseModel):
Expand All @@ -36,7 +39,7 @@ class SearchQuery(BaseModel):


class SearchResult(BaseModel):
image_base64: str
image_path: str
similarity: float


Expand Down Expand Up @@ -64,16 +67,16 @@ def query_collection(query_embedding: np.ndarray,
n_results=n_results,
include=["metadatas", "distances", "documents"])

# Get images and distances, images in documents
images = results['documents'][0]
# Get image paths and distances
image_paths = results['documents'][0]
distances = results['distances'][0]

# Convert distances to similarities (cosine similarity = 1 - distance/2)
similarities = [1 - (d / 2) for d in distances]

return [
SearchResult(image_base64=img, similarity=similarity)
for img, similarity in zip(images, similarities)
SearchResult(image_path=img_path, similarity=similarity)
for img_path, similarity in zip(image_paths, similarities)
]


Expand All @@ -93,6 +96,15 @@ async def search(query: SearchQuery):
raise HTTPException(status_code=500, detail=str(e))


@app.get("/image/{subpath:path}")
async def get_image(subpath: str):
"""Serve an image from the mounted bucket."""
image_path = os.path.join(images_dir, subpath)
if not os.path.exists(image_path):
raise HTTPException(status_code=404, detail="Image not found")
return FileResponse(image_path, media_type="image/jpeg")


@app.get("/health")
async def health_check():
"""Health check endpoint."""
Expand Down Expand Up @@ -251,7 +263,7 @@ async def get_search_page():
const results = await response.json();
resultsDiv.innerHTML = results.map(result => `
<div class="result">
<img src="data:image/jpeg;base64,${result.image_base64}"
<img src="/image/${result.image_path}"
alt="Search result">
<div class="result-info">
<p class="similarity-score">
Expand Down Expand Up @@ -295,6 +307,10 @@ def main():
type=str,
default='/vectordb/chroma',
help='Directory where ChromaDB is persisted')
parser.add_argument('--images-dir',
type=str,
default='/images',
help='Directory where images are stored')
parser.add_argument('--model-name',
type=str,
default='ViT-bigG-14',
Expand All @@ -303,12 +319,15 @@ def main():
args = parser.parse_args()

# Initialize global variables
global model, tokenizer, collection, device
global model, tokenizer, collection, device, images_dir

# Set device
device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Using device: {device}")

# Set images directory
images_dir = args.images_dir

# Load the model
import open_clip
model, _, _ = open_clip.create_model_and_transforms(
Expand Down
8 changes: 6 additions & 2 deletions examples/vector_database/serve_vectordb.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@ resources:
file_mounts:
/vectordb:
name: sky-vectordb
# this needs to be the same as the source in the build_vectordb.yaml
mode: MOUNT
/images:
name: sky-images
# this needs to be the same as the source in the build_vectordb.yaml
mode: MOUNT

setup: |
Expand All @@ -18,12 +23,11 @@ setup: |
pip install open_clip_torch chromadb pandas
pip install fastapi uvicorn pydantic
run: |
python scripts/serve_vectordb.py \
--collection-name clip_embeddings \
--persist-dir /vectordb/chroma \
--images-dir /images \
--host 0.0.0.0 \
--port 8000
Expand Down

0 comments on commit 9e349f6

Please sign in to comment.