diff --git a/data/leaves.txt b/data/leaves.txt new file mode 100644 index 0000000..843cbaa --- /dev/null +++ b/data/leaves.txt @@ -0,0 +1 @@ +Leaves are green because chlorophyll absorbs red and blue light. \ No newline at end of file diff --git a/data/sky.txt b/data/sky.txt new file mode 100644 index 0000000..00303ef --- /dev/null +++ b/data/sky.txt @@ -0,0 +1 @@ +The sky is blue because of Rayleigh scattering. \ No newline at end of file diff --git a/examples/chromem/README.md b/examples/chromem/README.md new file mode 100644 index 0000000..645563f --- /dev/null +++ b/examples/chromem/README.md @@ -0,0 +1,53 @@ +# Chromem Example + +This example demonstrates how to use the Chromem vector database with Raggo's SimpleRAG interface. + +## Prerequisites + +1. Go 1.16 or later +2. OpenAI API key (set as environment variable `OPENAI_API_KEY`) + +## Running the Example + +1. Set your OpenAI API key: +```bash +export OPENAI_API_KEY='your-api-key' +``` + +2. Run the example: +```bash +go run main.go +``` + +## What it Does + +1. Creates a new SimpleRAG instance with Chromem as the vector database +2. Creates sample documents about natural phenomena +3. Adds the documents to the database +4. Performs a semantic search using the query "Why is the sky blue?" +5. Prints the response based on the relevant documents found + +## Expected Output + +``` +Question: Why is the sky blue? + +Answer: The sky appears blue because of a phenomenon called Rayleigh scattering. When sunlight travels through Earth's atmosphere, it collides with gas molecules. These molecules scatter blue wavelengths of light more strongly than red wavelengths, which is why we see the sky as blue. +``` + +## Configuration + +The example uses the following configuration: +- Vector Database: Chromem (persistent mode) +- Collection Name: knowledge-base +- Embedding Model: text-embedding-3-small +- Chunk Size: 200 characters +- Chunk Overlap: 50 characters +- Top K Results: 1 +- Minimum Score: 0.1 + +## Notes + +- The database is stored in `./data/chromem.db` +- Sample documents are created in the `./data` directory +- The example uses persistent storage mode for Chromem diff --git a/examples/chromem/main.go b/examples/chromem/main.go new file mode 100644 index 0000000..207242c --- /dev/null +++ b/examples/chromem/main.go @@ -0,0 +1,77 @@ +package main + +import ( + "context" + "fmt" + "os" + "path/filepath" + + "github.com/teilomillet/raggo" +) + +func main() { + // Enable debug logging + raggo.SetLogLevel(raggo.LogLevelDebug) + + // Create a temporary directory for our documents + tmpDir := "./data" + err := os.MkdirAll(tmpDir, 0755) + if err != nil { + fmt.Printf("Error creating temp directory: %v\n", err) + os.Exit(1) + } + + // Create sample documents + docs := map[string]string{ + "sky.txt": "The sky is blue because of Rayleigh scattering.", + "leaves.txt": "Leaves are green because chlorophyll absorbs red and blue light.", + } + + for filename, content := range docs { + err := os.WriteFile(filepath.Join(tmpDir, filename), []byte(content), 0644) + if err != nil { + fmt.Printf("Error writing file %s: %v\n", filename, err) + os.Exit(1) + } + } + + // Initialize RAG with Chromem + config := raggo.SimpleRAGConfig{ + Collection: "knowledge-base", + DBType: "chromem", + DBAddress: "./data/chromem.db", + Model: "text-embedding-3-small", // OpenAI embedding model + APIKey: os.Getenv("OPENAI_API_KEY"), + Dimension: 1536, // text-embedding-3-small dimension + // TopK is determined dynamically by the number of documents + } + + raggo.Debug("Creating SimpleRAG with config", "config", config) + + rag, err := raggo.NewSimpleRAG(config) + if err != nil { + fmt.Printf("Error creating SimpleRAG: %v\n", err) + os.Exit(1) + } + defer rag.Close() + + ctx := context.Background() + + // Add documents from the directory + raggo.Debug("Adding documents from directory", "dir", tmpDir) + err = rag.AddDocuments(ctx, tmpDir) + if err != nil { + fmt.Printf("Error adding documents: %v\n", err) + os.Exit(1) + } + + // Search for documents + raggo.Debug("Searching for documents", "query", "Why is the sky blue?") + response, err := rag.Search(ctx, "Why is the sky blue?") + if err != nil { + fmt.Printf("Error searching: %v\n", err) + os.Exit(1) + } + + fmt.Printf("Response: %s\n", response) +} diff --git a/go.mod b/go.mod index f963232..14f9806 100644 --- a/go.mod +++ b/go.mod @@ -35,6 +35,7 @@ require ( github.com/leodido/go-urn v1.4.0 // indirect github.com/mailru/easyjson v0.7.7 // indirect github.com/milvus-io/milvus-proto/go-api/v2 v2.4.6 // indirect + github.com/philippgille/chromem-go v0.7.0 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rogpeppe/go-internal v1.12.0 // indirect diff --git a/go.sum b/go.sum index 99c2f5f..93aca44 100644 --- a/go.sum +++ b/go.sum @@ -84,6 +84,8 @@ github.com/milvus-io/milvus-proto/go-api/v2 v2.4.6/go.mod h1:1OIl0v5PQeNxIJhCvY+ github.com/milvus-io/milvus-sdk-go/v2 v2.4.1 h1:KhqjmaJE4mSxj1a88XtkGaqgH4duGiHs1sjnvSXkwE0= github.com/milvus-io/milvus-sdk-go/v2 v2.4.1/go.mod h1:7SJxshlnVhNLksS73tLPtHYY9DiX7lyL43Rv41HCPCw= github.com/opentracing/opentracing-go v1.1.0/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o= +github.com/philippgille/chromem-go v0.7.0 h1:4jfvfyKymjKNfGxBUhHUcj1kp7B17NL/I1P+vGh1RvY= +github.com/philippgille/chromem-go v0.7.0/go.mod h1:hTd+wGEm/fFPQl7ilfCwQXkgEUxceYh86iIdoKMolPo= github.com/pingcap/errors v0.11.4 h1:lFuQV/oaUMGcD2tqt+01ROSmJs75VG1ToEOkZIZ4nE4= github.com/pingcap/errors v0.11.4/go.mod h1:Oi8TUi2kEtXXLMJk9l1cGmz20kV3TaQ0usTwv5KuLY8= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= diff --git a/rag/chromem.go b/rag/chromem.go new file mode 100644 index 0000000..597cd5d --- /dev/null +++ b/rag/chromem.go @@ -0,0 +1,434 @@ +// File: chromem.go + +package rag + +import ( + "context" + "fmt" + "log" + "os" + "path/filepath" + "sync" + + "github.com/philippgille/chromem-go" +) + +type ChromemDB struct { + db *chromem.DB + collections map[string]*chromem.Collection + mu sync.RWMutex + columnNames []string + dimension int +} + +func newChromemDB(cfg *Config) (*ChromemDB, error) { + log.Printf("Creating new ChromemDB with config: %+v", cfg) + + // Get dimension from config parameters + dimension, ok := cfg.Parameters["dimension"].(int) + if !ok { + log.Printf("No dimension found in config parameters, using default 1536") + dimension = 1536 + } + log.Printf("Using dimension: %d", dimension) + + // Create DB + var db *chromem.DB + var err error + if cfg.Address != "" { + // Ensure directory exists + dir := filepath.Dir(cfg.Address) + if err := os.MkdirAll(dir, 0755); err != nil { + return nil, fmt.Errorf("failed to create directory for ChromemDB: %w", err) + } + + log.Printf("Creating persistent ChromemDB at %s", cfg.Address) + db, err = chromem.NewPersistentDB(cfg.Address, false) // Don't truncate existing DB + if err != nil { + log.Printf("Failed to create persistent ChromemDB: %v", err) + return nil, fmt.Errorf("failed to create persistent ChromemDB: %w", err) + } + + // Verify database file exists + if _, err := os.Stat(cfg.Address); os.IsNotExist(err) { + log.Printf("Warning: ChromemDB file %s does not exist after creation", cfg.Address) + return nil, fmt.Errorf("ChromemDB file %s does not exist after creation", cfg.Address) + } + } else { + log.Printf("Creating in-memory ChromemDB") + db = chromem.NewDB() + } + + if db == nil { + log.Printf("ChromemDB is nil after creation") + return nil, fmt.Errorf("ChromemDB is nil after creation") + } + + // Test database by creating and removing a test collection + testCol := "test_collection" + log.Printf("Testing database by creating test collection %s", testCol) + + // Create test collection + apiKey := os.Getenv("OPENAI_API_KEY") + if apiKey == "" { + return nil, fmt.Errorf("OPENAI_API_KEY environment variable not set") + } + embeddingFunc := chromem.NewEmbeddingFuncOpenAI(apiKey, "text-embedding-3-small") + + col, err := db.CreateCollection(testCol, map[string]string{}, embeddingFunc) + if err != nil { + log.Printf("Failed to create test collection: %v", err) + return nil, fmt.Errorf("failed to create test collection: %w", err) + } + + if col == nil { + log.Printf("Test collection is nil after creation") + return nil, fmt.Errorf("test collection is nil after creation") + } + + // Get collection to verify it exists + if col = db.GetCollection(testCol, embeddingFunc); col == nil { + log.Printf("Test collection not found after creation") + return nil, fmt.Errorf("test collection not found after creation") + } + + // Drop test collection by creating a new one with truncate=true + col, err = db.CreateCollection(testCol, map[string]string{}, embeddingFunc) + if err != nil { + log.Printf("Failed to drop test collection: %v", err) + return nil, fmt.Errorf("failed to drop test collection: %w", err) + } + + log.Printf("Successfully created and tested ChromemDB") + + return &ChromemDB{ + db: db, + collections: make(map[string]*chromem.Collection), + dimension: dimension, + }, nil +} + +func (c *ChromemDB) Connect(ctx context.Context) error { + log.Printf("Connecting to ChromemDB") + // No explicit connect needed for chromem + log.Printf("ChromemDB connected (no-op)") + return nil +} + +func (c *ChromemDB) Close() error { + // No explicit close in chromem + return nil +} + +func (c *ChromemDB) HasCollection(ctx context.Context, name string) (bool, error) { + c.mu.RLock() + defer c.mu.RUnlock() + + log.Printf("Checking if collection %s exists", name) + + // First check our local map + if _, exists := c.collections[name]; exists { + log.Printf("Collection %s found in local map", name) + return true, nil + } + + // Get OpenAI API key from environment + apiKey := os.Getenv("OPENAI_API_KEY") + if apiKey == "" { + return false, fmt.Errorf("OPENAI_API_KEY environment variable not set") + } + + // Create embedding function using OpenAI's text-embedding-3-small + embeddingFunc := chromem.NewEmbeddingFuncOpenAI(apiKey, "text-embedding-3-small") + + // Try to get the collection + col := c.db.GetCollection(name, embeddingFunc) + exists := col != nil + + if exists { + log.Printf("Collection %s found in database", name) + // Cache the collection in our map + c.collections[name] = col + } else { + log.Printf("Collection %s not found in database", name) + } + + return exists, nil +} + +func (c *ChromemDB) DropCollection(ctx context.Context, name string) error { + c.mu.Lock() + defer c.mu.Unlock() + + delete(c.collections, name) + return nil +} + +func (c *ChromemDB) CreateCollection(ctx context.Context, name string, schema Schema) error { + c.mu.Lock() + defer c.mu.Unlock() + + log.Printf("Creating collection: %s (ignoring schema as Chromem doesn't use it)", name) + + // Check if collection already exists in our map + if _, exists := c.collections[name]; exists { + log.Printf("Collection %s already exists in our map", name) + return nil + } + + // Get OpenAI API key from environment + apiKey := os.Getenv("OPENAI_API_KEY") + if apiKey == "" { + return fmt.Errorf("OPENAI_API_KEY environment variable not set") + } + + // Create embedding function using OpenAI's text-embedding-3-small + embeddingFunc := chromem.NewEmbeddingFuncOpenAI(apiKey, "text-embedding-3-small") + + // Create collection in ChromemDB with empty metadata + col, err := c.db.CreateCollection(name, map[string]string{}, embeddingFunc) + if err != nil { + log.Printf("Failed to create collection %s: %v", name, err) + return fmt.Errorf("failed to create collection %s: %w", name, err) + } + + // Store collection in our map + c.collections[name] = col + log.Printf("Successfully created collection %s with dimension %d and embedding function %T", name, c.dimension, embeddingFunc) + + // Verify collection was created + verifyCol := c.db.GetCollection(name, embeddingFunc) + if verifyCol == nil { + log.Printf("Warning: Collection %s was not properly created", name) + return fmt.Errorf("collection %s was not properly created", name) + } + + return nil +} + +func (c *ChromemDB) Insert(ctx context.Context, collectionName string, data []Record) error { + c.mu.Lock() + defer c.mu.Unlock() + + log.Printf("Inserting %d records into collection %s", len(data), collectionName) + + // Get collection from our map + col, exists := c.collections[collectionName] + if !exists { + return fmt.Errorf("collection %s does not exist in collections map", collectionName) + } + + // Convert records to chromem documents + docs := make([]chromem.Document, len(data)) + validCount := 0 + + for i, record := range data { + // Extract content and metadata + content, ok := record.Fields["Text"].(string) + if !ok { + log.Printf("Warning: Record %d has no 'Text' field or it's not a string, skipping", i) + continue + } + + metadata := make(map[string]string) + if metaField, ok := record.Fields["Metadata"]; ok { + if meta, ok := metaField.(map[string]interface{}); ok { + for k, v := range meta { + if str, ok := v.(string); ok { + metadata[k] = str + } + } + } + } + + // Get embedding and convert to []float32 if needed + var embedding []float32 + if embField, ok := record.Fields["Embedding"]; ok { + switch e := embField.(type) { + case []float32: + embedding = e + case Vector: + embedding = toFloat32Slice(e) + case []float64: + embedding = make([]float32, len(e)) + for j, v := range e { + embedding[j] = float32(v) + } + default: + log.Printf("Warning: Record %d has invalid embedding type %T, skipping", i, embField) + continue + } + } else { + log.Printf("Warning: Record %d has no 'Embedding' field, skipping", i) + continue + } + + // Create document + docs[validCount] = chromem.Document{ + ID: fmt.Sprintf("%d", i), + Content: content, + Metadata: metadata, + Embedding: embedding, + } + validCount++ + } + + // Trim docs to valid count + docs = docs[:validCount] + + if validCount == 0 { + log.Printf("Warning: No valid documents to insert into collection %s", collectionName) + return nil + } + + log.Printf("Converted %d/%d records to valid documents for collection %s", validCount, len(data), collectionName) + + // Insert documents in batches to avoid memory issues + batchSize := 100 + for i := 0; i < len(docs); i += batchSize { + end := i + batchSize + if end > len(docs) { + end = len(docs) + } + batch := docs[i:end] + + log.Printf("Inserting batch of %d documents (batch %d/%d) into collection %s", len(batch), (i/batchSize)+1, (len(docs)+batchSize-1)/batchSize, collectionName) + for _, doc := range batch { + err := col.AddDocument(ctx, doc) + if err != nil { + return fmt.Errorf("failed to insert document: %w", err) + } + } + } + + log.Printf("Successfully inserted %d documents into collection %s", validCount, collectionName) + + return nil +} + +func (c *ChromemDB) Flush(ctx context.Context, collectionName string) error { + // No explicit flush in chromem + return nil +} + +func (c *ChromemDB) CreateIndex(ctx context.Context, collectionName, field string, index Index) error { + // No explicit index creation in chromem + return nil +} + +func (c *ChromemDB) LoadCollection(ctx context.Context, name string) error { + c.mu.Lock() + defer c.mu.Unlock() + + log.Printf("Loading collection: %s", name) + + // Get OpenAI API key from environment + apiKey := os.Getenv("OPENAI_API_KEY") + if apiKey == "" { + return fmt.Errorf("OPENAI_API_KEY environment variable not set") + } + + // Create embedding function using OpenAI's text-embedding-3-small + embeddingFunc := chromem.NewEmbeddingFuncOpenAI(apiKey, "text-embedding-3-small") + + // Get collection from ChromemDB + col := c.db.GetCollection(name, embeddingFunc) + if col == nil { + log.Printf("Collection %s not found", name) + return fmt.Errorf("collection %s not found", name) + } + + // Store collection in our map + c.collections[name] = col + log.Printf("Successfully loaded collection %s", name) + + return nil +} + +func (c *ChromemDB) Search(ctx context.Context, collectionName string, vectors map[string]Vector, topK int, metricType string, searchParams map[string]interface{}) ([]SearchResult, error) { + c.mu.RLock() + + // First check if collection exists in our map + col, exists := c.collections[collectionName] + c.mu.RUnlock() + + if !exists { + // Try to load the collection + err := c.LoadCollection(ctx, collectionName) + if err != nil { + return nil, fmt.Errorf("failed to load collection: %w", err) + } + + // Get the collection again + c.mu.RLock() + col = c.collections[collectionName] + c.mu.RUnlock() + } + + // We only support single vector search for now + if len(vectors) != 1 { + return nil, fmt.Errorf("chromem only supports single vector search") + } + + // Get the first vector + var queryVector Vector + for _, v := range vectors { + queryVector = v + break + } + + // Convert query vector to float32 + query := toFloat32Slice(queryVector) + + log.Printf("Searching collection %s with query vector of length %d", collectionName, len(query)) + + // Search documents using empty filters for where and whereDocument + results, err := col.QueryEmbedding(ctx, query, topK, make(map[string]string), make(map[string]string)) + if err != nil { + return nil, fmt.Errorf("failed to search documents: %w", err) + } + + log.Printf("Found %d results (requested topK=%d)", len(results), topK) + + if len(results) == 0 { + log.Printf("Warning: No results found in collection %s. This could indicate that either: (1) the collection is empty, (2) no similar documents were found, or (3) the collection was not properly loaded.", collectionName) + return []SearchResult{}, nil + } + + // Convert results + searchResults := make([]SearchResult, len(results)) + for i, result := range results { + fields := make(map[string]interface{}) + fields["Text"] = result.Content + if len(result.Metadata) > 0 { + fields["Metadata"] = result.Metadata + } + + searchResults[i] = SearchResult{ + ID: int64(i), + Score: float64(result.Similarity), + Fields: fields, + } + log.Printf("Result %d: score=%f, content=%s", i, result.Similarity, result.Content) + } + + return searchResults, nil +} + +func (c *ChromemDB) HybridSearch(ctx context.Context, collectionName string, vectors map[string]Vector, topK int, metricType string, searchParams map[string]interface{}, reranker interface{}) ([]SearchResult, error) { + // Not implemented for chromem + return nil, fmt.Errorf("hybrid search not implemented for chromem") +} + +func (c *ChromemDB) SetColumnNames(names []string) { + c.columnNames = names +} + +// Helper function to convert Vector to []float32 +func toFloat32Slice(v Vector) []float32 { + result := make([]float32, len(v)) + for i, val := range v { + result[i] = float32(val) + } + return result +} diff --git a/rag/milvus.go b/rag/milvus.go index 05fd99c..3df3b3d 100644 --- a/rag/milvus.go +++ b/rag/milvus.go @@ -234,6 +234,8 @@ func (m *MilvusDB) createColumn(fieldName string, fieldValue interface{}) entity return entity.NewColumnInt64(fieldName, []int64{}) case []float64: return entity.NewColumnFloatVector(fieldName, len(v), [][]float32{}) + case []float32: + return entity.NewColumnFloatVector(fieldName, len(v), [][]float32{}) case string: return entity.NewColumnVarChar(fieldName, []string{}) case map[string]interface{}: @@ -254,9 +256,17 @@ func (m *MilvusDB) appendToColumn(col entity.Column, value interface{}) { case *entity.ColumnInt64: c.AppendValue(value.(int64)) case *entity.ColumnFloatVector: - floatVector := make([]float32, len(value.([]float64))) - for i, v := range value.([]float64) { - floatVector[i] = float32(v) + var floatVector []float32 + switch v := value.(type) { + case []float64: + floatVector = make([]float32, len(v)) + for i, val := range v { + floatVector[i] = float32(val) + } + case []float32: + floatVector = v + default: + panic(fmt.Sprintf("unsupported vector type: %T", value)) } c.AppendValue(floatVector) case *entity.ColumnVarChar: diff --git a/rag/providers/openai.go b/rag/providers/openai.go index cc0de9b..bce6d12 100644 --- a/rag/providers/openai.go +++ b/rag/providers/openai.go @@ -109,3 +109,16 @@ func (e *OpenAIEmbedder) Embed(ctx context.Context, text string) ([]float64, err return embeddingResp.Data[0].Embedding, nil } + +func (e *OpenAIEmbedder) GetDimension() (int, error) { + switch e.modelName { + case "text-embedding-3-small": + return 1536, nil + case "text-embedding-3-large": + return 3072, nil + case "text-embedding-ada-002": + return 1536, nil + default: + return 0, fmt.Errorf("unknown model: %s", e.modelName) + } +} diff --git a/rag/providers/register.go b/rag/providers/register.go index d591484..002dcff 100644 --- a/rag/providers/register.go +++ b/rag/providers/register.go @@ -34,5 +34,9 @@ func GetEmbedderFactory(name string) (EmbedderFactory, error) { // Embedder interface defines the contract for embedding implementations type Embedder interface { + // Embed generates embeddings for the given text Embed(ctx context.Context, text string) ([]float64, error) + + // GetDimension returns the dimension of the embeddings for the current model + GetDimension() (int, error) } diff --git a/rag/vector_interface.go b/rag/vector_interface.go index 4a43d62..c7b5b51 100644 --- a/rag/vector_interface.go +++ b/rag/vector_interface.go @@ -66,6 +66,7 @@ type Config struct { Address string MaxPoolSize int Timeout time.Duration + Parameters map[string]interface{} } type Option func(*Config) @@ -96,6 +97,8 @@ func NewVectorDB(cfg *Config) (VectorDB, error) { return newMilvusDB(cfg) case "memory": return newMemoryDB(cfg) + case "chromem": + return newChromemDB(cfg) default: return nil, fmt.Errorf("unsupported database type: %s", cfg.Type) } diff --git a/register.go b/register.go index a1fe0cc..25244b4 100644 --- a/register.go +++ b/register.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "os" + "strconv" "time" ) @@ -49,12 +50,8 @@ func defaultConfig() *RegisterConfig { EmbeddingProvider: "openai", EmbeddingModel: "text-embedding-3-small", EmbeddingKey: os.Getenv("OPENAI_API_KEY"), - OnProgress: func(processed, total int) { - Info(fmt.Sprintf("Progress: %d/%d", processed, total)) - }, - OnError: func(err error) { - Error(fmt.Sprintf("Error: %v", err)) - }, + OnProgress: func(processed, total int) { Debug("Progress", "processed", processed, "total", total) }, + OnError: func(err error) { Error("Error during registration", "error", err) }, } } @@ -70,13 +67,17 @@ func Register(ctx context.Context, source string, opts ...RegisterOption) error opt(cfg) } + Debug("Initializing registration", "source", source, "config", cfg) + // Create loader + Debug("Creating loader") loader := NewLoader( SetTempDir(cfg.TempDir), SetTimeout(cfg.Timeout), ) // Create chunker + Debug("Creating chunker") chunker, err := NewChunker( ChunkSize(cfg.ChunkSize), ChunkOverlap(cfg.ChunkOverlap), @@ -86,6 +87,7 @@ func Register(ctx context.Context, source string, opts ...RegisterOption) error } // Create embedder + Debug("Creating embedder") embedder, err := NewEmbedder( SetProvider(cfg.EmbeddingProvider), SetModel(cfg.EmbeddingModel), @@ -95,32 +97,57 @@ func Register(ctx context.Context, source string, opts ...RegisterOption) error return fmt.Errorf("failed to create embedder: %w", err) } + // Get embedding dimension + Debug("Getting embedding dimension") + dimension, err := embedder.GetDimension() + if err != nil { + return fmt.Errorf("failed to get embedding dimension: %w", err) + } + Debug("Embedding dimension", "dimension", dimension) + // Create vector store + Debug("Creating vector store") + // Get dimension from config or use the one from embedder + configDimension := 0 + if dimStr := cfg.VectorDBConfig["dimension"]; dimStr != "" { + if dim, err := strconv.Atoi(dimStr); err == nil { + configDimension = dim + } + } + if configDimension == 0 { + configDimension = dimension + } + Debug("Using dimension", "dimension", configDimension) + vectorDB, err := NewVectorDB( WithType(cfg.VectorDBType), WithAddress(cfg.VectorDBConfig["address"]), WithTimeout(cfg.Timeout), + WithDimension(configDimension), ) if err != nil { return fmt.Errorf("failed to create vector store: %w", err) } defer vectorDB.Close() + Debug("Connecting to vector store") if err := vectorDB.Connect(ctx); err != nil { return fmt.Errorf("failed to connect to vector store: %w", err) } // Create collection if needed if cfg.AutoCreate { + Debug("Checking collection existence", "collection", cfg.CollectionName) exists, _ := vectorDB.HasCollection(ctx, cfg.CollectionName) if !exists { + Debug("Creating collection", "collection", cfg.CollectionName) schema := Schema{ Name: cfg.CollectionName, Fields: []Field{ {Name: "ID", DataType: "int64", PrimaryKey: true, AutoID: true}, - {Name: "Embedding", DataType: "float_vector", Dimension: 1536}, - {Name: "Text", DataType: "varchar", MaxLength: 65535}, // Added MaxLength - {Name: "Metadata", DataType: "varchar", MaxLength: 65535}, // Added MaxLength + {Name: "Embedding", DataType: "float_vector", Dimension: dimension}, + {Name: "Text", DataType: "varchar", MaxLength: 65535}, + {Name: "Metadata", DataType: "varchar", MaxLength: 65535}, }, } @@ -130,6 +157,7 @@ func Register(ctx context.Context, source string, opts ...RegisterOption) error } // Create index for vector field + Debug("Creating index") index := Index{ Type: "HNSW", Metric: "L2", @@ -143,18 +171,23 @@ func Register(ctx context.Context, source string, opts ...RegisterOption) error } // Load collection + Debug("Loading collection") if err := vectorDB.LoadCollection(ctx, cfg.CollectionName); err != nil { return fmt.Errorf("failed to load collection: %w", err) } } } + // Process source + Debug("Processing source", "source", source) var paths []string - var loadErr error // Changed to use a different variable name + var loadErr error if info, err := os.Stat(source); err == nil { if info.IsDir() { + Debug("Loading directory") paths, loadErr = loader.LoadDir(ctx, source) } else { + Debug("Loading file") var path string path, loadErr = loader.LoadFile(ctx, source) if loadErr == nil { @@ -165,6 +198,7 @@ func Register(ctx context.Context, source string, opts ...RegisterOption) error return fmt.Errorf("failed to load source: %w", loadErr) } } else if isURL(source) { + Debug("Loading URL") path, loadErr := loader.LoadURL(ctx, source) if loadErr != nil { return fmt.Errorf("failed to load URL: %w", loadErr) @@ -175,10 +209,14 @@ func Register(ctx context.Context, source string, opts ...RegisterOption) error } // Create embedding service + Debug("Creating embedding service") embeddingService := NewEmbeddingService(embedder) // Process files + Debug("Processing files", "count", len(paths)) for i, path := range paths { + Debug("Processing file", "path", path, "index", i+1, "total", len(paths)) + // Parse content parser := NewParser() doc, err := parser.Parse(path) @@ -188,21 +226,38 @@ func Register(ctx context.Context, source string, opts ...RegisterOption) error } // Create chunks + Debug("Creating chunks") chunks := chunker.Chunk(doc.Content) + Debug("Created chunks", "count", len(chunks)) // Create embeddings + Debug("Creating embeddings") embeddedChunks, err := embeddingService.EmbedChunks(ctx, chunks) if err != nil { cfg.OnError(fmt.Errorf("failed to embed chunks from %s: %w", path, err)) continue } + Debug("Created embeddings", "count", len(embeddedChunks)) // Convert to records + Debug("Converting to records") records := make([]Record, len(embeddedChunks)) for j, chunk := range embeddedChunks { + embedding, ok := chunk.Embeddings["default"] + if !ok || len(embedding) == 0 { + cfg.OnError(fmt.Errorf("missing or empty embedding for chunk %d in %s", j, path)) + continue + } + + // Convert []float64 to []float32 for ChromemDB + embedding32 := make([]float32, len(embedding)) + for i, v := range embedding { + embedding32[i] = float32(v) + } + records[j] = Record{ Fields: map[string]interface{}{ - "Embedding": chunk.Embeddings["default"], + "Embedding": embedding32, "Text": chunk.Text, "Metadata": map[string]interface{}{ "source": path, @@ -215,6 +270,7 @@ func Register(ctx context.Context, source string, opts ...RegisterOption) error } // Insert into vector store + Debug("Inserting records", "count", len(records)) if err := vectorDB.Insert(ctx, cfg.CollectionName, records); err != nil { cfg.OnError(fmt.Errorf("failed to insert records from %s: %w", path, err)) continue @@ -223,6 +279,7 @@ func Register(ctx context.Context, source string, opts ...RegisterOption) error cfg.OnProgress(i+1, len(paths)) } + Debug("Registration complete") return nil } @@ -231,6 +288,13 @@ func Register(ctx context.Context, source string, opts ...RegisterOption) error func WithVectorDB(dbType string, config map[string]string) RegisterOption { return func(c *RegisterConfig) { c.VectorDBType = dbType + if config == nil { + config = make(map[string]string) + } + // Ensure dimension is preserved in VectorDBConfig + if config["dimension"] == "" { + config["dimension"] = "1536" // Default dimension + } c.VectorDBConfig = config } } diff --git a/retriever.go b/retriever.go index 15551cc..4d8a2af 100644 --- a/retriever.go +++ b/retriever.go @@ -27,6 +27,7 @@ type RetrieverConfig struct { // Vector DB settings DBType string DBAddress string + Dimension int // Embedding settings Provider string @@ -59,6 +60,7 @@ func defaultRetrieverConfig() *RetrieverConfig { Columns: []string{"Text", "Metadata"}, DBType: "milvus", DBAddress: "localhost:19530", + Dimension: 128, Provider: "openai", Model: "text-embedding-3-small", APIKey: os.Getenv("OPENAI_API_KEY"), @@ -96,6 +98,7 @@ func (r *Retriever) initialize() error { WithType(r.config.DBType), WithAddress(r.config.DBAddress), WithTimeout(r.config.Timeout), + WithDimension(r.config.Dimension), ) if err != nil { return fmt.Errorf("failed to create vector store: %w", err) @@ -252,6 +255,12 @@ func WithColumns(columns ...string) RetrieverOption { } } +func WithRetrieveDimension(dimension int) RetrieverOption { + return func(c *RetrieverConfig) { + c.Dimension = dimension + } +} + func WithRetrieveCallbacks(onResult func(SearchResult), onError func(error)) RetrieverOption { return func(c *RetrieverConfig) { c.OnResult = onResult diff --git a/simple_rag.go b/simple_rag.go index 0d9cce9..1c3e325 100644 --- a/simple_rag.go +++ b/simple_rag.go @@ -32,6 +32,9 @@ type SimpleRAGConfig struct { TopK int MinScore float64 LLMModel string + DBType string // Type of vector database to use (e.g., "milvus", "chromem") + DBAddress string // Address for the vector database (e.g., "localhost:19530" for Milvus, "./data/chromem.db" for Chromem) + Dimension int // Dimension of the embedding vectors (e.g., 1536 for text-embedding-3-small) } // DefaultConfig returns a default configuration @@ -44,6 +47,9 @@ func DefaultConfig() SimpleRAGConfig { TopK: 5, MinScore: 0.1, LLMModel: "gpt-4o-mini", + DBType: "milvus", + DBAddress: "localhost:19530", + Dimension: 1536, // Default dimension for text-embedding-3-small } } @@ -68,6 +74,18 @@ func NewSimpleRAG(config SimpleRAGConfig) (*SimpleRAG, error) { config.LLMModel = DefaultConfig().LLMModel } + if config.DBType == "" { + config.DBType = DefaultConfig().DBType + } + + if config.DBAddress == "" { + config.DBAddress = DefaultConfig().DBAddress + } + + if config.Dimension == 0 { + config.Dimension = DefaultConfig().Dimension + } + // Initialize LLM llm, err := gollm.NewLLM( gollm.SetProvider("openai"), @@ -80,8 +98,9 @@ func NewSimpleRAG(config SimpleRAGConfig) (*SimpleRAG, error) { // Initialize vector database vectorDB, err := NewVectorDB( - WithType("milvus"), - WithAddress("localhost:19530"), + WithType(config.DBType), + WithAddress(config.DBAddress), + WithDimension(config.Dimension), WithTimeout(5*time.Minute), ) if err != nil { @@ -110,7 +129,7 @@ func NewSimpleRAG(config SimpleRAGConfig) (*SimpleRAG, error) { // Create retriever with configured options retriever, err := NewRetriever( - WithRetrieveDB("milvus", "localhost:19530"), + WithRetrieveDB(config.DBType, config.DBAddress), WithRetrieveCollection(config.Collection), WithTopK(config.TopK), WithMinScore(config.MinScore), @@ -120,6 +139,7 @@ func NewSimpleRAG(config SimpleRAGConfig) (*SimpleRAG, error) { config.Model, config.APIKey, ), + WithRetrieveDimension(config.Dimension), ) if err != nil { return nil, fmt.Errorf("failed to create retriever: %w", err) @@ -164,6 +184,10 @@ func (s *SimpleRAG) AddDocuments(ctx context.Context, source string) error { WithCollection(s.collection, true), WithChunking(DefaultConfig().ChunkSize, DefaultConfig().ChunkOverlap), WithEmbedding("openai", s.model, s.apiKey), + WithVectorDB(s.vectorDB.Type(), map[string]string{ + "address": s.vectorDB.Address(), + "dimension": fmt.Sprintf("%d", s.vectorDB.Dimension()), + }), ) if err != nil { return fmt.Errorf("failed to add document %s: %w", file.Name(), err) @@ -177,13 +201,17 @@ func (s *SimpleRAG) AddDocuments(ctx context.Context, source string) error { WithCollection(s.collection, true), WithChunking(DefaultConfig().ChunkSize, DefaultConfig().ChunkOverlap), WithEmbedding("openai", s.model, s.apiKey), + WithVectorDB(s.vectorDB.Type(), map[string]string{ + "address": s.vectorDB.Address(), + "dimension": fmt.Sprintf("%d", s.vectorDB.Dimension()), + }), ) if err != nil { return fmt.Errorf("failed to add document: %w", err) } } - // Create and load index + // Create and load index only once after all documents are processed err = s.vectorDB.CreateIndex(ctx, s.collection, "Embedding", Index{ Type: "HNSW", Metric: "L2", @@ -214,6 +242,37 @@ func (s *SimpleRAG) Search(ctx context.Context, query string) (string, error) { log.Printf("Performing search with query: %s", query) + // Get the total number of documents in the collection + hasCollection, err := s.vectorDB.HasCollection(ctx, s.collection) + if err != nil { + return "", fmt.Errorf("failed to check collection: %w", err) + } + if !hasCollection { + return "", fmt.Errorf("collection %s does not exist", s.collection) + } + + // Load collection to ensure it's ready for search + err = s.vectorDB.LoadCollection(ctx, s.collection) + if err != nil { + return "", fmt.Errorf("failed to load collection: %w", err) + } + + // Set the retriever's TopK based on the config or dynamically + if s.retriever.config.TopK <= 0 { + // Do a test search with topK=1 to get number of documents + testResults, err := s.vectorDB.Search(ctx, s.collection, map[string]Vector{"test": make(Vector, s.vectorDB.Dimension())}, 1, "L2", nil) + if err != nil { + return "", fmt.Errorf("failed to get collection size: %w", err) + } + // Set TopK to min(20, numDocs) if not specified + s.retriever.config.TopK = 20 + if len(testResults) < 20 { + s.retriever.config.TopK = len(testResults) + } + } + + log.Printf("Using TopK=%d for search", s.retriever.config.TopK) + results, err := s.retriever.Retrieve(ctx, query) if err != nil { return "", fmt.Errorf("failed to search: %w", err) diff --git a/vectordb.go b/vectordb.go index 3b4745a..0030605 100644 --- a/vectordb.go +++ b/vectordb.go @@ -11,7 +11,10 @@ import ( ) type VectorDB struct { - db rag.VectorDB + db rag.VectorDB + dbType string + address string + dimension int } type Config struct { @@ -19,6 +22,7 @@ type Config struct { Address string MaxPoolSize int Timeout time.Duration + Dimension int } type Option func(*Config) @@ -47,6 +51,12 @@ func WithTimeout(timeout time.Duration) Option { } } +func WithDimension(dimension int) Option { + return func(c *Config) { + c.Dimension = dimension + } +} + func NewVectorDB(opts ...Option) (*VectorDB, error) { cfg := &Config{} for _, opt := range opts { @@ -57,11 +67,19 @@ func NewVectorDB(opts ...Option) (*VectorDB, error) { Address: cfg.Address, MaxPoolSize: cfg.MaxPoolSize, Timeout: cfg.Timeout, + Parameters: map[string]interface{}{ + "dimension": cfg.Dimension, + }, }) if err != nil { return nil, err } - return &VectorDB{db: ragDB}, nil + return &VectorDB{ + db: ragDB, + dbType: cfg.Type, + address: cfg.Address, + dimension: cfg.Dimension, + }, nil } func (vdb *VectorDB) Connect(ctx context.Context) error { @@ -138,6 +156,18 @@ func (vdb *VectorDB) SetColumnNames(names []string) { vdb.db.SetColumnNames(names) } +func (vdb *VectorDB) Type() string { + return vdb.dbType +} + +func (vdb *VectorDB) Address() string { + return vdb.address +} + +func (vdb *VectorDB) Dimension() int { + return vdb.dimension +} + // Types to match the internal rag package type Schema = rag.Schema type Field = rag.Field