Commit b7deee1

format

KeplerC committed Feb 3, 2025
1 parent f5288ac commit b7deee1
Showing 3 changed files with 28 additions and 25 deletions.

examples/vector_database/scripts/build_vectordb.py (36 changes: 18 additions & 18 deletions)

@@ -5,14 +5,15 @@
 
 import argparse
 import base64
+from concurrent.futures import as_completed
+from concurrent.futures import ProcessPoolExecutor
 import glob
 import logging
+import multiprocessing
 import os
 import pickle
 import shutil
 import tempfile
-import multiprocessing
-from concurrent.futures import ProcessPoolExecutor, as_completed
 
 import chromadb
 import numpy as np
@@ -36,7 +37,7 @@ def process_parquet_file(args):
     try:
         results = []
         df = pd.read_parquet(parquet_file)
-        
+
         # Process in batches
         for i in range(0, len(df), batch_size):
             batch_df = df.iloc[i:i + batch_size]
@@ -45,7 +46,7 @@ def process_parquet_file(args):
             unpacked_data = [pickle.loads(row) for row in batch_df['output']]
             images_base64, embeddings = zip(*unpacked_data)
             results.append((ids, embeddings, images_base64))
-            
+
         return results
     except Exception as e:
         logger.error(f"Error processing file {parquet_file}: {str(e)}")
@@ -77,7 +78,7 @@ def main():
                         type=str,
                         default='',
                         help='Prefix path within mounted bucket to search for parquet files')
-    
+
     args = parser.parse_args()
 
     # Create a temporary directory for building the database
@@ -103,30 +104,29 @@ def main():
     logger.info(f"Found {len(parquet_files)} parquet files")
 
     # Process files in parallel
-    max_workers = max(1, multiprocessing.cpu_count() - 1)  # Leave one CPU free
+    max_workers = max(1,
+                      multiprocessing.cpu_count() - 1)  # Leave one CPU free
     logger.info(f"Processing files using {max_workers} workers")
 
     with ProcessPoolExecutor(max_workers=max_workers) as executor:
         # Submit all files for processing
         future_to_file = {
-            executor.submit(process_parquet_file, (file, args.batch_size)): file
-            for file in parquet_files
+            executor.submit(process_parquet_file, (file, args.batch_size)):
+                file for file in parquet_files
         }
 
         # Process results as they complete
-        for future in tqdm(as_completed(future_to_file),
-                           total=len(parquet_files),
-                           desc="Processing files"):
+        for future in tqdm(as_completed(future_to_file),
+                           total=len(parquet_files),
+                           desc="Processing files"):
             file = future_to_file[future]
             try:
                 results = future.result()
                 if results:
                     for ids, embeddings, images_paths in results:
-                        collection.add(
-                            ids=list(ids),
-                            embeddings=list(embeddings),
-                            documents=list(images_paths)
-                        )
+                        collection.add(ids=list(ids),
+                                       embeddings=list(embeddings),
+                                       documents=list(images_paths))
             except Exception as e:
                 logger.error(f"Error processing file {file}: {str(e)}")
                 continue

examples/vector_database/scripts/compute_vectors.py (11 changes: 6 additions & 5 deletions)

@@ -21,6 +21,7 @@
 import torch
 from tqdm import tqdm
 
+
 class BatchProcessor():
     """Process ImageNet images with CLIP.
@@ -114,11 +115,11 @@ def save_image(self, idx: int, image: Image.Image) -> str:
         subdir = str(idx // 100000).zfill(4)
         save_dir = self.images_path / subdir
         save_dir.mkdir(parents=True, exist_ok=True)
-        
+
         # Save image with index as filename
         image_path = save_dir / f"{idx}.jpg"
         image.save(image_path, format="JPEG", quality=95)
-        
+
         # Return relative path from images root
         return str(Path(subdir) / f"{idx}.jpg")
@@ -188,7 +189,7 @@ async def find_existing_progress(self) -> Tuple[int, int]:
                 logging.warning(f"Error processing file {file}: {e}")
 
         return max_idx, max_partition + 1
-    
+
     def save_results_to_parquet(self, results: list):
         """Save results to a parquet file with atomic write."""
         if not results:
@@ -209,7 +210,7 @@ def save_results_to_parquet(self, results: list):
             f"Saved partition {self.partition_counter} to {final_path} with {len(df)} rows"
         )
         self.partition_counter += 1
-    
+
     async def run(self):
         """
         Run the batch processing pipeline with recovery support.
@@ -277,7 +278,7 @@ async def main():
                         type=str,
                         default='ViT-bigG-14',
                         help='CLIP model name')
-    
+
     parser.add_argument('--images-path',
                         type=str,
                         default='/images',

examples/vector_database/scripts/serve_vectordb.py (6 changes: 4 additions & 2 deletions)

@@ -10,8 +10,10 @@
 from typing import List, Optional
 
 import chromadb
-from fastapi import FastAPI, HTTPException
-from fastapi.responses import HTMLResponse, FileResponse
+from fastapi import FastAPI
+from fastapi import HTTPException
+from fastapi.responses import FileResponse
+from fastapi.responses import HTMLResponse
 from fastapi.staticfiles import StaticFiles
 import numpy as np
 import open_clip
