Skip to content

Commit

Permalink
Improve vector database script initialization and model setup
Browse files Browse the repository at this point in the history
- Update build_vectordb.py to handle temporary directory for read/write buckets
- Modify collection creation metadata description
- Refactor compute_vectors.py to initialize model before batch processing
- Move model initialization to ensure it's set up before data loading
  • Loading branch information
KeplerC committed Feb 4, 2025
1 parent d2136a1 commit d88113d
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 4 deletions.
9 changes: 7 additions & 2 deletions examples/vector_database/scripts/build_vectordb.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,17 +82,22 @@ def main():
args = parser.parse_args()

    # Create a temporary directory for building the database.
    # The mounted buckets have restricted write semantics (presumably no
    # in-place/random writes — confirm against the mount configuration),
    # so we build the database in a local temporary directory first and
    # then copy it to the final bucket location.
    with tempfile.TemporaryDirectory() as temp_dir:
with tempfile.TemporaryDirectory() as temp_dir:
logger.info(f"Using temporary directory: {temp_dir}")

# Initialize ChromaDB in temporary directory
client = chromadb.PersistentClient(path=temp_dir)

# Create or get collection
        # Create or get the ChromaDB collection: first attempt to create a
        # collection with the given name; if one already exists, fall back
        # to fetching the existing collection instead.
try:
collection = client.create_collection(
name=args.collection_name,
metadata={"description": "CLIP embeddings from LAION dataset"})
metadata={"description": "CLIP embeddings from dataset"})
logger.info(f"Created new collection: {args.collection_name}")
except ValueError:
collection = client.get_collection(name=args.collection_name)
Expand Down
6 changes: 4 additions & 2 deletions examples/vector_database/scripts/compute_vectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,6 @@ async def get_dataset_iterator(self) -> AsyncIterator[Tuple[int, Any]]:
async def do_data_loading(
self) -> AsyncIterator[Tuple[int, Tuple[torch.Tensor, Any]]]:
"""Load and preprocess ImageNet images."""
if self.model is None:
await self.setup_model()

async for idx, item in self.get_dataset_iterator():
try:
Expand Down Expand Up @@ -215,6 +213,10 @@ async def run(self):
"""
Run the batch processing pipeline with recovery support.
"""
# Initialize the model
if self.model is None:
await self.setup_model()

# Find existing progress
resume_idx, self.partition_counter = await self.find_existing_progress()
self.start_idx = max(self.start_idx, resume_idx + 1)
Expand Down

0 comments on commit d88113d

Please sign in to comment.