From 372f7e10d4ddc681f11e7acdcc5fea902a56e540 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Te=C3=AFlo=20M?= Date: Wed, 20 Nov 2024 16:26:25 +0100 Subject: [PATCH] add some comments --- .gitignore | 2 + README.md | 417 +++++++++++++++++++++++----- chunker.go | 74 ++++- concurrentloader.go | 81 +++++- config/config.go | 156 +++++++---- contextual.go | 65 ++++- contextual_rag.go | 191 ++++++++++--- duplicate_pdfs.sh | 30 -- embed.go | 105 ------- embedder.go | 225 +++++++++++++++ examples/contextual/custom_llm.go | 2 +- examples/contextual/main.go | 83 ++++-- examples/memory_enhancer_example.go | 46 +++ loader.go | 125 ++++++++- logger.go | 101 ++++++- memory_context.go | 169 +++++++++-- parser.go | 89 +++++- rag.go | 285 ++++++++++++++++--- rag/chromem.go | 107 ++++++- rag/chunk.go | 88 ++++-- rag/embed.go | 73 ++++- rag/example_vectordb.go | 178 ++++++++++++ rag/load.go | 75 ++++- rag/log.go | 55 +++- rag/memory.go | 74 ++++- rag/milvus.go | 70 ++++- rag/parse.go | 92 +++--- rag/providers/example_provider.go | 137 +++++++++ rag/providers/openai.go | 56 +++- rag/providers/register.go | 101 +++++++ rag/reranker.go | 41 ++- rag/sparse_index.go | 95 +++++-- rag/vector_interface.go | 66 ++++- register.go | 272 +++++++++++++++--- retriever.go | 346 ++++++++++++++++------- simple_rag.go | 110 ++++++-- tsne_results.json | 1 - vectordb.go | 97 ++++++- 38 files changed, 3623 insertions(+), 757 deletions(-) delete mode 100755 duplicate_pdfs.sh delete mode 100644 embed.go create mode 100644 embedder.go create mode 100644 rag/example_vectordb.go create mode 100644 rag/providers/example_provider.go delete mode 100644 tsne_results.json diff --git a/.gitignore b/.gitignore index e06ac38..9fae985 100644 --- a/.gitignore +++ b/.gitignore @@ -22,3 +22,5 @@ testdata/ go.work go.work.sum data/chromem.db/ + +*.json \ No newline at end of file diff --git a/README.md b/README.md index 5b7f1f4..2c6c355 100644 --- a/README.md +++ b/README.md @@ -1,164 +1,427 @@ -# Raggo +# Raggo - Retrieval Augmented Generation Library -> A lightweight, production-ready RAG (Retrieval Augmented Generation) library in Go. Built for simplicity and performance with Milvus vector database integration. +A powerful and flexible RAG (Retrieval Augmented Generation) library for Go, designed to make document processing and context-aware AI interactions simple and efficient. 
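+
+Install the library with `go get`:
+
+```bash
+go get github.com/teilomillet/raggo
+```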
+ +## Quick Start + +```go +package main + +import ( + "context" + "fmt" + "github.com/teilomillet/raggo" +) + +func main() { + // Initialize RAG with default settings + rag, err := raggo.NewSimpleRAG(raggo.DefaultConfig()) + if err != nil { + fmt.Printf("Error: %v\n", err) + return + } + defer rag.Close() + + // Add documents from a directory + err = rag.AddDocuments(context.Background(), "./docs") + if err != nil { + fmt.Printf("Error: %v\n", err) + return + } + + // Search with natural language + response, _ := rag.Search(context.Background(), "What are the key features?") + fmt.Printf("Answer: %s\n", response) +} +``` + +## Configuration + +Raggo provides a flexible configuration system that can be loaded from multiple sources (environment variables, JSON files, or programmatic defaults): + +```go +// Load configuration (automatically checks standard paths) +cfg, err := config.LoadConfig() +if err != nil { + log.Fatal(err) +} + +// Or create a custom configuration +cfg := &config.Config{ + Provider: "milvus", // Vector store provider + Model: "text-embedding-3-small", + Collection: "my_documents", + + // Search settings + DefaultTopK: 5, // Number of similar chunks to retrieve + DefaultMinScore: 0.7, // Similarity threshold + + // Document processing + DefaultChunkSize: 300, // Size of text chunks + DefaultChunkOverlap: 50, // Overlap between chunks +} + +// Create RAG instance with config +rag, err := raggo.NewSimpleRAG(cfg) +``` + +Configuration can be saved for reuse: +```go +err := cfg.Save("~/.raggo/config.json") +``` + +Environment variables (take precedence over config files): +- `RAGGO_PROVIDER`: Service provider +- `RAGGO_MODEL`: Model identifier +- `RAGGO_COLLECTION`: Collection name +- `RAGGO_API_KEY`: Default API key + +> A powerful, production-ready RAG (Retrieval-Augmented Generation) library in Go.

- 🔍 Search Documents • 💬 Ask Questions • 🤖 Get Smart Answers + 🔍 Smart Document Search • 💬 Context-Aware Responses • 🤖 Intelligent RAG

[![Go Reference](https://pkg.go.dev/badge/github.com/teilomillet/raggo.svg)](https://pkg.go.dev/github.com/teilomillet/raggo) [![Go Report Card](https://goreportcard.com/badge/github.com/teilomillet/raggo)](https://goreportcard.com/report/github.com/teilomillet/raggo) [![License](https://img.shields.io/github/license/teilomillet/raggo)](https://github.com/teilomillet/raggo/blob/main/LICENSE) +## Table of Contents + +### Part 1: Core Components +1. [Quick Start](#quick-start) +2. [Building Blocks](#building-blocks) + - [Document Loading](#document-loading) + - [Text Parsing](#text-parsing) + - [Text Chunking](#text-chunking) + - [Embeddings](#embeddings) + - [Vector Storage](#vector-storage) + +### Part 2: RAG Implementations +1. [Simple RAG](#simple-rag) + - [Basic Usage](#basic-usage) + - [Document Q&A](#document-qa) + - [Configuration](#configuration) +2. [Contextual RAG](#contextual-rag) + - [Advanced Features](#advanced-features) + - [Context Window](#context-window) + - [Hybrid Search](#hybrid-search) +3. [Memory Context](#memory-context) + - [Chat Applications](#chat-applications) + - [Memory Management](#memory-management) + - [Context Enhancement](#context-enhancement) +4. [Advanced Use Cases](#advanced-use-cases) + - [Full Processing Pipeline](#full-processing-pipeline) + - [Concurrent Processing](#concurrent-processing) + - [Rate Limiting](#rate-limiting) + +## Part 1: Core Components + +### Quick Start + +#### Prerequisites +```bash +# Install Milvus +docker run -d --name milvus -p 19530:19530 milvusdb/milvus:latest -Raggo helps your Go programs answer questions by looking through documents. Think of it as a smart assistant that reads your documents and answers questions about them. - -## Getting Started +# Set OpenAI API key +export OPENAI_API_KEY=your-api-key -### 1. Install Raggo -```bash +# Install Raggo go get github.com/teilomillet/raggo ``` -### 2. Set up Milvus -Raggo uses [Milvus](https://milvus.io/) as its vector database (required). Follow the [Milvus installation guide](https://milvus.io/docs/install_standalone-docker.md) to set it up. Once installed, Raggo will automatically connect to Milvus at `localhost:19530`. +### Building Blocks -### 3. Set your OpenAI API key -```bash -export OPENAI_API_KEY=your-api-key +#### Document Loading +```go +loader := raggo.NewLoader(raggo.SetTimeout(1*time.Minute)) +doc, err := loader.LoadURL(context.Background(), "https://example.com/doc.pdf") ``` -## Simple Examples +#### Text Parsing +```go +parser := raggo.NewParser() +doc, err := parser.Parse("document.pdf") +``` + +#### Text Chunking +```go +chunker := raggo.NewChunker(raggo.ChunkSize(100)) +chunks := chunker.Chunk(doc.Content) +``` -### 1. Ask a Question -This is the simplest way to use Raggo. 
Just create a RAG and ask a question: +#### Embeddings +```go +embedder := raggo.NewEmbedder( + raggo.SetProvider("openai"), + raggo.SetModel("text-embedding-3-small"), +) +``` + +#### Vector Storage +```go +db := raggo.NewVectorDB(raggo.WithMilvus("collection")) +``` + +## Part 2: RAG Implementations + +### Simple RAG +Best for straightforward document Q&A: ```go package main import ( "context" - "fmt" "log" "github.com/teilomillet/raggo" ) func main() { - // Create a new RAG - rag, err := raggo.NewDefaultSimpleRAG("my_first_rag") + // Initialize SimpleRAG + rag, err := raggo.NewSimpleRAG(raggo.SimpleRAGConfig{ + Collection: "docs", + Model: "text-embedding-3-small", + ChunkSize: 300, + TopK: 3, + }) if err != nil { log.Fatal(err) } + defer rag.Close() - // Ask a question - answer, err := rag.Search(context.Background(), "What is a RAG system?") + // Add documents + err = rag.AddDocuments(context.Background(), "./documents") if err != nil { log.Fatal(err) } - // Print the answer - fmt.Println(answer) + // Search with different strategies + basicResponse, _ := rag.Search(context.Background(), "What is the main feature?") + hybridResponse, _ := rag.SearchHybrid(context.Background(), "How does it work?", 0.7) + + log.Printf("Basic Search: %s\n", basicResponse) + log.Printf("Hybrid Search: %s\n", hybridResponse) +} +``` + +### Contextual RAG +For complex document understanding and context-aware responses: + +```go +package main + +import ( + "context" + "fmt" + "os" + "path/filepath" + + "github.com/teilomillet/raggo" +) + +func main() { + // Initialize RAG with default settings + rag, err := raggo.NewDefaultContextualRAG("basic_contextual_docs") + if err != nil { + fmt.Printf("Failed to initialize RAG: %v\n", err) + os.Exit(1) + } + defer rag.Close() + + // Add documents - the system will automatically: + // - Split documents into semantic chunks + // - Generate rich context for each chunk + // - Store embeddings with contextual information + docsPath := filepath.Join("examples", "docs") + if err := rag.AddDocuments(context.Background(), docsPath); err != nil { + fmt.Printf("Failed to add documents: %v\n", err) + os.Exit(1) + } + + // Simple search with automatic context enhancement + query := "What are the key features of the product?" + response, err := rag.Search(context.Background(), query) + if err != nil { + fmt.Printf("Failed to search: %v\n", err) + os.Exit(1) + } + + fmt.Printf("\nQuery: %s\nResponse: %s\n", query, response) +} +``` + +### Advanced Configuration + +```go +// Create a custom configuration +config := &raggo.ContextualRAGConfig{ + Collection: "advanced_contextual_docs", + Model: "text-embedding-3-small", // Embedding model + LLMModel: "gpt-4o-mini", // Model for context generation + ChunkSize: 300, // Larger chunks for more context + ChunkOverlap: 75, // 25% overlap for better continuity + TopK: 5, // Number of similar chunks to retrieve + MinScore: 0.7, // Higher threshold for better relevance +} + +// Initialize RAG with custom configuration +rag, err := raggo.NewContextualRAG(config) +if err != nil { + log.Fatalf("Failed to initialize RAG: %v", err) } +defer rag.Close() ``` -### 2. 
Add Your Own Documents -Now let's add some documents and ask questions about them: +### Memory Context +For chat applications and long-term context retention: ```go package main import ( "context" - "fmt" "log" "github.com/teilomillet/raggo" + "github.com/teilomillet/gollm" ) func main() { - // Create a new RAG - rag, err := raggo.NewDefaultSimpleRAG("my_docs_rag") + // Initialize Memory Context + memoryCtx, err := raggo.NewMemoryContext( + os.Getenv("OPENAI_API_KEY"), + raggo.MemoryTopK(5), + raggo.MemoryCollection("chat"), + raggo.MemoryStoreLastN(100), + raggo.MemoryMinScore(0.7), + ) if err != nil { log.Fatal(err) } + defer memoryCtx.Close() - // Add documents from a folder - err = rag.AddDocuments(context.Background(), "./my_documents") + // Initialize Contextual RAG + rag, err := raggo.NewContextualRAG(&raggo.ContextualRAGConfig{ + Collection: "docs", + Model: "text-embedding-3-small", + }) if err != nil { log.Fatal(err) } + defer rag.Close() - // Ask a question about your documents - answer, err := rag.Search(context.Background(), "What do my documents say about project deadlines?") + // Example chat interaction + messages := []gollm.MemoryMessage{ + {Role: "user", Content: "How does the authentication system work?"}, + } + + // Store conversation + err = memoryCtx.StoreMemory(context.Background(), messages) if err != nil { log.Fatal(err) } - - fmt.Println(answer) + + // Get enhanced response with context + prompt := &gollm.Prompt{Messages: messages} + enhanced, _ := memoryCtx.EnhancePrompt(context.Background(), prompt, messages) + response, _ := rag.Search(context.Background(), enhanced.Messages[0].Content) + + log.Printf("Response: %s\n", response) } ``` -### 3. Smart Answers with Context -For better answers that remember context: +### Advanced Use Cases + +#### Full Processing Pipeline +Process large document sets with rate limiting and concurrent processing: ```go package main import ( "context" - "fmt" "log" + "sync" + "time" "github.com/teilomillet/raggo" + "golang.org/x/time/rate" ) -func main() { - // Create a RAG that remembers context - rag, err := raggo.NewDefaultContextualRAG("my_smart_rag") - if err != nil { - log.Fatal(err) - } - - // Add your documents - err = rag.AddDocuments(context.Background(), "./my_documents") - if err != nil { - log.Fatal(err) - } +const ( + GPT_RPM_LIMIT = 5000 // Requests per minute + GPT_TPM_LIMIT = 4000000 // Tokens per minute + MAX_CONCURRENT = 10 // Max concurrent goroutines +) - // Ask related questions - ctx := context.Background() +func main() { + // Initialize components + parser := raggo.NewParser() + chunker := raggo.NewChunker(raggo.ChunkSize(500)) + embedder := raggo.NewEmbedder( + raggo.SetProvider("openai"), + raggo.SetModel("text-embedding-3-small"), + ) + + // Create rate limiters + limiter := rate.NewLimiter(rate.Limit(GPT_RPM_LIMIT/60), GPT_RPM_LIMIT) - answer1, _ := rag.Search(ctx, "What is our company's return policy?") - fmt.Println("First answer:", answer1) + // Process documents concurrently + var wg sync.WaitGroup + semaphore := make(chan struct{}, MAX_CONCURRENT) + + files, _ := filepath.Glob("./documents/*.pdf") + for _, file := range files { + wg.Add(1) + semaphore <- struct{}{} // Acquire semaphore + + go func(file string) { + defer wg.Done() + defer func() { <-semaphore }() // Release semaphore + + // Wait for rate limit + limiter.Wait(context.Background()) + + // Process document + doc, _ := parser.Parse(file) + chunks := chunker.Chunk(doc.Content) + embeddings, _ := embedder.CreateEmbeddings(chunks) + + 
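+			// NOTE: a complete pipeline would insert these embeddings into a
+			// vector store at this point; referencing them here also keeps the
+			// example compiling (Go rejects unused variables).
+			_ = embeddings
+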
log.Printf("Processed %s: %d chunks\n", file, len(chunks)) + }(file) + } - answer2, _ := rag.Search(ctx, "What about for electronics?") - fmt.Println("Second answer:", answer2) + wg.Wait() } ``` -## What's Next? +## Best Practices + +### Resource Management +- Always use `defer Close()` +- Monitor memory usage +- Clean up old data -Once you're comfortable with these basics, you can: -- Use custom language models -- Process documents in parallel -- Add vector database support -- Customize document processing -- Add visualization tools +### Performance +- Use concurrent processing for large datasets +- Configure appropriate chunk sizes +- Enable hybrid search when needed -Check out our [Advanced Examples](#advanced-examples) to learn more. +### Context Management +- Use Memory Context for chat applications +- Configure context window size +- Clean up old memories periodically -## Advanced Examples +## Examples -### Basic Examples -- `simple/`: Basic RAG implementation -- `contextual/`: Contextual RAG with custom LLM -- `chat/`: Chat-based RAG system +Check `/examples` for more: +- Basic usage: `/examples/simple/` +- Context-aware: `/examples/contextual/` +- Chat applications: `/examples/chat/` +- Memory usage: `/examples/memory_enhancer_example.go` +- Full pipeline: `/examples/full_process.go` +- Benchmarks: `/examples/process_embedding_benchmark.go` -### Advanced Examples -- `full_process.go`: Complete production pipeline -- `recruit_example.go`: Resume processing system -- `vectordb_example.go`: Vector database integration -- `tsne_example.go`: Embedding visualization +## License -### Performance Examples -- `process_embedding_benchmark.go`: Performance testing -- `concurrent_loader_example.go`: Concurrent processing -- `rate_limiting_example.go`: API rate limiting +MIT License - see [LICENSE](LICENSE) file diff --git a/chunker.go b/chunker.go index 4286d14..099133d 100644 --- a/chunker.go +++ b/chunker.go @@ -1,74 +1,122 @@ +// Package raggo provides a high-level interface for text chunking and token management, +// designed for use in retrieval-augmented generation (RAG) applications. package raggo import ( "github.com/teilomillet/raggo/rag" ) -// Chunk represents a piece of text with metadata +// Chunk represents a piece of text with associated metadata including its content, +// token count, and position within the original document. It tracks: +// - The actual text content +// - Number of tokens in the chunk +// - Starting and ending sentence indices type Chunk = rag.Chunk -// Chunker defines the interface for text chunking +// Chunker defines the interface for text chunking implementations. +// Implementations of this interface provide strategies for splitting text +// into semantically meaningful chunks while preserving context. type Chunker interface { + // Chunk splits the input text into a slice of Chunks according to the + // implementation's strategy. Chunk(text string) []Chunk } -// TokenCounter defines the interface for counting tokens in a string +// TokenCounter defines the interface for counting tokens in text. +// Different implementations can provide various tokenization strategies, +// from simple word-based counting to model-specific subword tokenization. type TokenCounter interface { + // Count returns the number of tokens in the given text according to + // the implementation's tokenization strategy. Count(text string) int } -// ChunkerOption is a function type for configuring Chunker +// ChunkerOption is a function type for configuring Chunker instances. 
+// It follows the functional options pattern for clean and flexible configuration. type ChunkerOption = rag.TextChunkerOption -// NewChunker creates a new Chunker with the given options +// NewChunker creates a new Chunker with the given options. +// By default, it creates a TextChunker with: +// - Chunk size: 200 tokens +// - Chunk overlap: 50 tokens +// - Default word-based token counter +// - Basic sentence splitter +// +// Use the provided option functions to customize these settings. func NewChunker(options ...ChunkerOption) (Chunker, error) { return rag.NewTextChunker(options...) } -// ChunkSize sets the chunk size +// ChunkSize sets the target size of each chunk in tokens. +// This determines how much text will be included in each chunk +// before starting a new one. func ChunkSize(size int) ChunkerOption { return func(tc *rag.TextChunker) { tc.ChunkSize = size } } -// ChunkOverlap sets the chunk overlap +// ChunkOverlap sets the number of tokens that should overlap between +// adjacent chunks. This helps maintain context across chunk boundaries +// and improves retrieval quality. func ChunkOverlap(overlap int) ChunkerOption { return func(tc *rag.TextChunker) { tc.ChunkOverlap = overlap } } -// WithTokenCounter sets a custom token counter +// WithTokenCounter sets a custom token counter implementation. +// This allows you to use different tokenization strategies, such as: +// - Word-based counting (DefaultTokenCounter) +// - Model-specific tokenization (TikTokenCounter) +// - Custom tokenization schemes func WithTokenCounter(counter TokenCounter) ChunkerOption { return func(tc *rag.TextChunker) { tc.TokenCounter = counter } } -// WithSentenceSplitter sets a custom sentence splitter function +// WithSentenceSplitter sets a custom sentence splitter function. +// The function should take a string and return a slice of strings, +// where each string is a sentence. This allows for: +// - Custom sentence boundary detection +// - Language-specific splitting rules +// - Special handling of abbreviations or formatting func WithSentenceSplitter(splitter func(string) []string) ChunkerOption { return func(tc *rag.TextChunker) { tc.SentenceSplitter = splitter } } -// DefaultSentenceSplitter returns the default sentence splitter function +// DefaultSentenceSplitter returns the basic sentence splitter function +// that splits text on common punctuation marks (., !, ?). +// Suitable for simple English text without complex formatting. func DefaultSentenceSplitter() func(string) []string { return rag.DefaultSentenceSplitter } -// SmartSentenceSplitter returns the smart sentence splitter function +// SmartSentenceSplitter returns an advanced sentence splitter that handles: +// - Multiple punctuation marks +// - Quoted sentences +// - Parenthetical content +// - Lists and enumerations +// +// Recommended for complex text with varied formatting and structure. func SmartSentenceSplitter() func(string) []string { return rag.SmartSentenceSplitter } -// NewDefaultTokenCounter creates a new default token counter +// NewDefaultTokenCounter creates a simple word-based token counter +// that splits text on whitespace. Suitable for basic use cases +// where exact token counts aren't critical. func NewDefaultTokenCounter() TokenCounter { return &rag.DefaultTokenCounter{} } -// NewTikTokenCounter creates a new TikToken counter with the specified encoding +// NewTikTokenCounter creates a token counter using the tiktoken library, +// which implements the same tokenization used by OpenAI models. 
+// The encoding parameter specifies which tokenization model to use +// (e.g., "cl100k_base" for GPT-4, "p50k_base" for GPT-3). func NewTikTokenCounter(encoding string) (TokenCounter, error) { return rag.NewTikTokenCounter(encoding) } diff --git a/concurrentloader.go b/concurrentloader.go index 787c16a..168f044 100644 --- a/concurrentloader.go +++ b/concurrentloader.go @@ -1,3 +1,4 @@ +// Package raggo provides utilities for concurrent document loading and processing. package raggo import ( @@ -13,24 +14,69 @@ import ( "github.com/teilomillet/raggo/rag" ) -// ConcurrentPDFLoader extends the basic Loader interface +// ConcurrentPDFLoader extends the basic Loader interface with concurrent PDF processing +// capabilities. It provides efficient handling of multiple PDF files by: +// - Loading files in parallel using goroutines +// - Managing concurrent file operations safely +// - Handling file duplication when needed +// - Providing progress tracking and error handling type ConcurrentPDFLoader interface { + // Embeds the basic Loader interface Loader + + // LoadPDFsConcurrent loads a specified number of PDF files concurrently from a source directory. + // If the source directory contains fewer files than the requested count, it automatically + // duplicates existing PDFs to reach the desired number. + // + // Parameters: + // - ctx: Context for cancellation and timeout + // - sourceDir: Directory containing source PDF files + // - targetDir: Directory where duplicated PDFs will be stored + // - count: Desired number of PDF files to load + // + // Returns: + // - []string: Paths to all successfully loaded files + // - error: Any error encountered during the process + // + // Example usage: + // loader := raggo.NewConcurrentPDFLoader(raggo.SetTimeout(1*time.Minute)) + // files, err := loader.LoadPDFsConcurrent(ctx, "source", "target", 10) LoadPDFsConcurrent(ctx context.Context, sourceDir string, targetDir string, count int) ([]string, error) } -// concurrentPDFLoaderWrapper wraps the internal loader and adds concurrent PDF loading capability +// concurrentPDFLoaderWrapper wraps the internal loader and adds concurrent PDF loading capability. +// It implements thread-safe operations and efficient resource management. type concurrentPDFLoaderWrapper struct { internal *rag.Loader } -// NewConcurrentPDFLoader creates a new ConcurrentPDFLoader with the given options +// NewConcurrentPDFLoader creates a new ConcurrentPDFLoader with the given options. +// It supports all standard loader options plus concurrent processing capabilities. +// +// Options can include: +// - SetTimeout: Maximum time for loading operations +// - SetTempDir: Directory for temporary files +// - SetRetryCount: Number of retries for failed operations +// +// Example: +// loader := raggo.NewConcurrentPDFLoader( +// raggo.SetTimeout(1*time.Minute), +// raggo.SetTempDir(os.TempDir()), +// ) func NewConcurrentPDFLoader(opts ...LoaderOption) ConcurrentPDFLoader { return &concurrentPDFLoaderWrapper{internal: rag.NewLoader(opts...)} } -// LoadPDFsConcurrent loads 'count' number of PDF files from the specified directory, -// duplicating files if necessary to reach the desired count +// LoadPDFsConcurrent implements the concurrent PDF loading strategy. +// It performs the following steps: +// 1. Lists all PDF files in the source directory +// 2. Creates the target directory if it doesn't exist +// 3. Duplicates PDFs if necessary to reach the desired count +// 4. Loads files concurrently using goroutines +// 5. 
Collects results and errors from concurrent operations +// +// The function uses channels for thread-safe communication and a WaitGroup +// to ensure all operations complete before returning. func (clw *concurrentPDFLoaderWrapper) LoadPDFsConcurrent(ctx context.Context, sourceDir string, targetDir string, count int) ([]string, error) { pdfs, err := listPDFFiles(sourceDir) if err != nil { @@ -94,19 +140,27 @@ func (clw *concurrentPDFLoaderWrapper) LoadPDFsConcurrent(ctx context.Context, s return loadedFiles, nil } +// LoadURL implements the Loader interface by loading a document from a URL. +// This method is inherited from the basic Loader interface. func (clw *concurrentPDFLoaderWrapper) LoadURL(ctx context.Context, url string) (string, error) { return clw.internal.LoadURL(ctx, url) } +// LoadFile implements the Loader interface by loading a single file. +// This method is inherited from the basic Loader interface. func (clw *concurrentPDFLoaderWrapper) LoadFile(ctx context.Context, path string) (string, error) { return clw.internal.LoadFile(ctx, path) } +// LoadDir implements the Loader interface by loading all files in a directory. +// This method is inherited from the basic Loader interface. func (clw *concurrentPDFLoaderWrapper) LoadDir(ctx context.Context, dir string) ([]string, error) { return clw.internal.LoadDir(ctx, dir) } -// listPDFFiles returns a list of all PDF files in the given directory +// listPDFFiles returns a list of all PDF files in the given directory. +// It recursively walks through the directory tree and identifies files +// with a .pdf extension (case-insensitive). func listPDFFiles(dir string) ([]string, error) { var pdfs []string err := filepath.Walk(dir, func(path string, info os.FileInfo, err error) error { @@ -121,7 +175,14 @@ func listPDFFiles(dir string) ([]string, error) { return pdfs, err } -// duplicatePDFs duplicates the given PDF files to reach the desired count +// duplicatePDFs duplicates the given PDF files to reach the desired count. +// If the number of source PDFs is less than the desired count, it creates +// copies with unique names by appending a counter to the original filename. +// +// The function ensures that: +// - Each copy has a unique name +// - The total number of files matches the desired count +// - File copying is performed safely func duplicatePDFs(pdfs []string, targetDir string, desiredCount int) ([]string, error) { var duplicatedPDFs []string numOriginalPDFs := len(pdfs) @@ -152,7 +213,11 @@ func duplicatePDFs(pdfs []string, targetDir string, desiredCount int) ([]string, return duplicatedPDFs, nil } -// copyFile copies a file from src to dst +// copyFile performs a safe copy of a file from src to dst. +// It handles: +// - Opening source and destination files +// - Proper resource cleanup with defer +// - Efficient copying with io.Copy func copyFile(src, dst string) error { sourceFile, err := os.Open(src) if err != nil { diff --git a/config/config.go b/config/config.go index 56e12ce..7866cf1 100644 --- a/config/config.go +++ b/config/config.go @@ -1,3 +1,15 @@ +// Package config provides a flexible configuration management system for the Raggo +// Retrieval-Augmented Generation (RAG) framework. 
It handles configuration loading, +// validation, and persistence with support for multiple sources: +// - Configuration files (JSON) +// - Environment variables +// - Programmatic defaults +// +// The package implements a hierarchical configuration system where settings can be +// overridden in the following order (highest to lowest precedence): +// 1. Environment variables +// 2. Configuration file +// 3. Default values package config import ( @@ -7,62 +19,93 @@ import ( "time" ) -// Config holds all configuration for the RAG system +// Config holds all configuration for the RAG system. It provides a centralized +// way to manage settings across different components of the system. +// +// Configuration categories: +// - Provider settings: Embedding and service providers +// - Collection settings: Vector database collections +// - Search settings: Retrieval and ranking parameters +// - Document processing: Text chunking and batching +// - Vector store: Database-specific configuration +// - System settings: Timeouts, retries, and headers type Config struct { - // Provider settings - Provider string - Model string - APIKeys map[string]string - - // Collection settings - Collection string - - // Search settings - SearchStrategy string - DefaultTopK int - DefaultMinScore float64 - DefaultSearchParams map[string]interface{} - EnableReRanking bool - RRFConstant float64 - - // Document processing settings - DefaultChunkSize int - DefaultChunkOverlap int - DefaultBatchSize int - DefaultIndexType string - - // Vector store settings - VectorDBConfig map[string]interface{} - - // Timeouts and retries - Timeout time.Duration - MaxRetries int - - // Additional settings - ExtraHeaders map[string]string + // Provider settings configure the embedding and service providers + Provider string // Service provider (e.g., "milvus", "openai") + Model string // Model identifier for embeddings + APIKeys map[string]string // API keys for different providers + + // Collection settings define the vector database structure + Collection string // Name of the vector collection + + // Search settings control retrieval behavior and ranking + SearchStrategy string // Search method (e.g., "dense", "hybrid") + DefaultTopK int // Default number of results to return + DefaultMinScore float64 // Minimum similarity score threshold + DefaultSearchParams map[string]interface{} // Additional search parameters + EnableReRanking bool // Enable result re-ranking + RRFConstant float64 // Reciprocal Rank Fusion constant + + // Document processing settings for text handling + DefaultChunkSize int // Size of text chunks + DefaultChunkOverlap int // Overlap between consecutive chunks + DefaultBatchSize int // Number of items per processing batch + DefaultIndexType string // Type of vector index (e.g., "HNSW") + + // Vector store settings for database configuration + VectorDBConfig map[string]interface{} // Database-specific settings + + // Timeouts and retries for system operations + Timeout time.Duration // Operation timeout + MaxRetries int // Maximum retry attempts + + // Additional settings for extended functionality + ExtraHeaders map[string]string // Additional HTTP headers } -// LoadConfig loads configuration from a file or environment +// LoadConfig loads configuration from multiple sources, combining them according +// to the precedence rules. It automatically searches for configuration files in +// standard locations and applies environment variable overrides. +// +// Configuration file search paths: +// 1. 
$RAGGO_CONFIG environment variable +// 2. ~/.raggo/config.json +// 3. ~/.config/raggo/config.json +// 4. ./raggo.json +// +// Environment variable overrides: +// - RAGGO_PROVIDER: Service provider +// - RAGGO_MODEL: Model identifier +// - RAGGO_COLLECTION: Collection name +// - RAGGO_API_KEY: Default API key +// +// Example usage: +// +// cfg, err := config.LoadConfig() +// if err != nil { +// log.Fatal(err) +// } +// fmt.Printf("Using provider: %s\n", cfg.Provider) func LoadConfig() (*Config, error) { - // Default configuration + // Default configuration with production-ready settings cfg := &Config{ - Provider: "milvus", - Model: "text-embedding-3-small", - Collection: "documents", - SearchStrategy: "dense", - DefaultTopK: 5, - DefaultMinScore: 0.7, - DefaultChunkSize: 512, - DefaultChunkOverlap: 50, - DefaultBatchSize: 100, - DefaultIndexType: "HNSW", + Provider: "milvus", // Fast, open-source vector database + Model: "text-embedding-3-small", // Latest OpenAI embedding model + Collection: "documents", // Default collection name + SearchStrategy: "dense", // Pure vector similarity search + DefaultTopK: 5, // Conservative number of results + DefaultMinScore: 0.7, // High confidence threshold + DefaultChunkSize: 512, // Balanced chunk size + DefaultChunkOverlap: 50, // Moderate overlap + DefaultBatchSize: 100, // Efficient batch size + DefaultIndexType: "HNSW", // Fast approximate search DefaultSearchParams: map[string]interface{}{ - "ef": 64, + "ef": 64, // HNSW search depth }, - EnableReRanking: false, - RRFConstant: 60, - Timeout: 30 * time.Second, - MaxRetries: 3, + EnableReRanking: false, // Disabled by default + RRFConstant: 60, // Standard RRF constant + Timeout: 30 * time.Second, // Conservative timeout + MaxRetries: 3, // Reasonable retry count APIKeys: make(map[string]string), ExtraHeaders: make(map[string]string), VectorDBConfig: make(map[string]interface{}), @@ -115,7 +158,20 @@ func LoadConfig() (*Config, error) { return cfg, nil } -// Save saves the configuration to a file +// Save persists the configuration to a JSON file at the specified path. +// It creates any necessary parent directories and sets appropriate file +// permissions. +// +// Example usage: +// +// cfg := &Config{ +// Provider: "milvus", +// Model: "text-embedding-3-small", +// } +// err := cfg.Save("~/.raggo/config.json") +// if err != nil { +// log.Fatal(err) +// } func (c *Config) Save(path string) error { data, err := json.MarshalIndent(c, "", " ") if err != nil { diff --git a/contextual.go b/contextual.go index 33898b2..acf04e4 100644 --- a/contextual.go +++ b/contextual.go @@ -1,3 +1,5 @@ +// Package raggo provides advanced Retrieval-Augmented Generation (RAG) capabilities +// with contextual awareness and memory management. package raggo import ( @@ -5,17 +7,62 @@ import ( "fmt" ) -// ContextualStoreOptions holds settings for contextual document processing +// ContextualStoreOptions configures how documents are processed and stored with +// contextual information. It provides settings for: +// - Vector database collection management +// - Document chunking and processing +// - Embedding model configuration +// - Batch processing controls +// +// This configuration is designed to optimize the balance between processing +// efficiency and context preservation. 
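+//
+// Example (illustrative values; StoreWithContext fills in defaults for any
+// fields left at their zero value):
+//
+//	opts := raggo.ContextualStoreOptions{
+//	    Collection:   "my_docs",
+//	    APIKey:       os.Getenv("OPENAI_API_KEY"),
+//	    ChunkSize:    512,
+//	    ChunkOverlap: 64,
+//	    BatchSize:    100,
+//	    ModelName:    "gpt-4o-mini",
+//	}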
type ContextualStoreOptions struct { - Collection string - APIKey string - ChunkSize int + // Collection specifies the vector database collection name + Collection string + + // APIKey is the authentication key for the embedding provider + APIKey string + + // ChunkSize determines the size of text chunks in tokens + // Larger chunks preserve more context but use more memory + ChunkSize int + + // ChunkOverlap controls how much text overlaps between chunks + // More overlap helps maintain context across chunk boundaries ChunkOverlap int - BatchSize int - ModelName string + + // BatchSize sets how many documents to process simultaneously + // Higher values increase throughput but use more memory + BatchSize int + + // ModelName specifies which language model to use for context generation + // This model enriches chunks with additional contextual information + ModelName string } -// StoreWithContext processes documents and stores them with contextual information +// StoreWithContext processes documents and stores them with enhanced contextual information. +// It uses a combination of: +// - Semantic chunking for document segmentation +// - Language model enrichment for context generation +// - Vector embedding for efficient retrieval +// - Batch processing for performance +// +// The function automatically handles: +// - Default configuration values +// - Resource management +// - Error handling and reporting +// - Context-aware processing +// +// Example usage: +// +// opts := raggo.ContextualStoreOptions{ +// Collection: "my_docs", +// APIKey: os.Getenv("OPENAI_API_KEY"), +// ChunkSize: 512, +// BatchSize: 100, +// } +// +// err := raggo.StoreWithContext(ctx, "path/to/docs", opts) func StoreWithContext(ctx context.Context, source string, opts ContextualStoreOptions) error { // Use default values if not specified if opts.ChunkSize == 0 { @@ -31,7 +78,7 @@ func StoreWithContext(ctx context.Context, source string, opts ContextualStoreOp opts.ModelName = "gpt-4o-mini" } - // Initialize RAG + // Initialize RAG with context-aware configuration rag, err := NewRAG( WithMilvus(opts.Collection), WithOpenAI(opts.APIKey), @@ -46,6 +93,6 @@ func StoreWithContext(ctx context.Context, source string, opts ContextualStoreOp } defer rag.Close() - // Process and store documents with context + // Process and store documents with enhanced context return rag.ProcessWithContext(ctx, source, opts.ModelName) } diff --git a/contextual_rag.go b/contextual_rag.go index 8d72ce2..1bb5f1b 100644 --- a/contextual_rag.go +++ b/contextual_rag.go @@ -1,3 +1,5 @@ +// Package raggo provides advanced Retrieval-Augmented Generation (RAG) capabilities +// with contextual awareness and memory management. package raggo import ( @@ -12,7 +14,27 @@ import ( "github.com/teilomillet/gollm" ) -// ContextualRAG provides a simplified interface for contextual RAG operations +// ContextualRAG provides a high-level interface for context-aware document processing +// and retrieval. 
It enhances traditional RAG systems by: +// - Maintaining semantic relationships between document chunks +// - Generating rich contextual metadata for improved retrieval +// - Supporting customizable chunking and embedding strategies +// - Providing flexible LLM integration for response generation +// +// Example usage: +// +// // Create with default settings +// rag, err := raggo.NewDefaultContextualRAG("my_docs") +// if err != nil { +// log.Fatal(err) +// } +// defer rag.Close() +// +// // Add documents with automatic context generation +// err = rag.AddDocuments(context.Background(), "path/to/docs") +// +// // Perform context-aware search +// response, err := rag.Search(context.Background(), "How does feature X work?") type ContextualRAG struct { rag *RAG retriever *Retriever @@ -20,34 +42,95 @@ type ContextualRAG struct { llmModel string } -// ContextualRAGConfig holds configuration for ContextualRAG +// ContextualRAGConfig provides fine-grained control over the RAG system's behavior. +// It allows customization of: +// - Document processing (chunk size, overlap) +// - Embedding generation (model selection) +// - Retrieval strategy (top-k, similarity threshold) +// - LLM integration (custom instance, model selection) +// +// Example configuration: +// +// config := &raggo.ContextualRAGConfig{ +// Collection: "technical_docs", +// Model: "text-embedding-3-small", +// ChunkSize: 300, // Larger chunks for more context +// ChunkOverlap: 50, // Overlap for context continuity +// TopK: 5, // Number of relevant chunks +// MinScore: 0.7, // Similarity threshold +// } type ContextualRAGConfig struct { - Collection string - APIKey string - Model string // Embedding model - LLMModel string // LLM model for context generation - LLM gollm.LLM // Optional custom LLM instance - ChunkSize int + // Collection specifies the vector database collection name + Collection string + + // APIKey for authentication with the embedding/LLM provider + APIKey string + + // Model specifies the embedding model for vector generation + Model string + + // LLMModel specifies the language model for context generation + LLMModel string + + // LLM allows using a custom LLM instance with specific configuration + LLM gollm.LLM + + // ChunkSize controls the size of document segments (in tokens) + // Larger values preserve more context but increase processing time + ChunkSize int + + // ChunkOverlap determines how much text overlaps between chunks + // Higher values help maintain context across chunk boundaries ChunkOverlap int - TopK int - MinScore float64 + + // TopK specifies how many similar chunks to retrieve + // Adjust based on needed context breadth + TopK int + + // MinScore sets the minimum similarity threshold for retrieval + // Higher values increase precision but may reduce recall + MinScore float64 } -// DefaultContextualConfig returns a default configuration +// DefaultContextualConfig returns a balanced configuration suitable for +// most use cases. 
It provides: +// - Reasonable chunk sizes for context preservation +// - Modern embedding model selection +// - Conservative similarity thresholds +// - Efficient batch processing settings func DefaultContextualConfig() ContextualRAGConfig { return ContextualRAGConfig{ Collection: "contextual_docs", - APIKey: "", Model: "text-embedding-3-small", LLMModel: "gpt-4o-mini", - ChunkSize: 200, // Increased for better context - ChunkOverlap: 50, // Reasonable overlap - TopK: 10, // Number of results to return - MinScore: 0.0, // No minimum score to allow for more flexible matching + ChunkSize: 200, // Balanced for most documents + ChunkOverlap: 50, // 25% overlap for context + TopK: 10, // Reasonable number of results + MinScore: 0.0, // No minimum for flexible matching } } -// NewContextualRAG creates a new ContextualRAG instance with custom configuration +// NewContextualRAG creates a new ContextualRAG instance with custom configuration. +// It provides advanced control over: +// - Document processing behavior +// - Embedding generation +// - Retrieval strategies +// - LLM integration +// +// The function will: +// - Merge provided config with defaults +// - Validate settings +// - Initialize vector store +// - Set up LLM integration +// +// Example: +// +// config := &raggo.ContextualRAGConfig{ +// Collection: "my_docs", +// ChunkSize: 300, +// TopK: 5, +// } +// rag, err := raggo.NewContextualRAG(config) func NewContextualRAG(config *ContextualRAGConfig) (*ContextualRAG, error) { // Start with default configuration defaultConfig := DefaultContextualConfig() @@ -90,7 +173,20 @@ func NewContextualRAG(config *ContextualRAGConfig) (*ContextualRAG, error) { return initializeRAG(config) } -// NewDefaultContextualRAG creates a new ContextualRAG instance with minimal configuration +// NewDefaultContextualRAG creates a new instance with production-ready defaults. +// It's ideal for quick setup while maintaining good performance. +// +// The function: +// - Uses environment variables for API keys +// - Sets optimal processing parameters +// - Configures reliable retrieval settings +// +// Example: +// +// rag, err := raggo.NewDefaultContextualRAG("my_collection") +// if err != nil { +// log.Fatal(err) +// } func NewDefaultContextualRAG(collection string) (*ContextualRAG, error) { config := DefaultContextualConfig() config.Collection = collection @@ -206,8 +302,21 @@ func initializeRAG(config *ContextualRAGConfig) (*ContextualRAG, error) { }, nil } -// AddDocuments adds documents to the vector database with contextual information -func (c *ContextualRAG) AddDocuments(ctx context.Context, source string) error { +// AddDocuments processes and stores documents with contextual awareness. 
+// The function: +// - Splits documents into semantic chunks +// - Generates rich contextual metadata +// - Creates and stores embeddings +// - Maintains relationships between chunks +// +// Parameters: +// - ctx: Context for cancellation and timeouts +// - source: Path to document or directory +// +// Example: +// +// err := rag.AddDocuments(ctx, "path/to/docs") +func (r *ContextualRAG) AddDocuments(ctx context.Context, source string) error { if ctx == nil { ctx = context.Background() } @@ -234,7 +343,7 @@ func (c *ContextualRAG) AddDocuments(ctx context.Context, source string) error { filePath := filepath.Join(source, file.Name()) // Process file with context - if err := c.rag.ProcessWithContext(ctx, filePath, c.llmModel); err != nil { + if err := r.rag.ProcessWithContext(ctx, filePath, r.llmModel); err != nil { return fmt.Errorf("failed to process file %s: %w", file.Name(), err) } log.Printf("Successfully processed file: %s", file.Name()) @@ -242,13 +351,13 @@ func (c *ContextualRAG) AddDocuments(ctx context.Context, source string) error { } } else { // Process single file with context - if err := c.rag.ProcessWithContext(ctx, source, c.llmModel); err != nil { + if err := r.rag.ProcessWithContext(ctx, source, r.llmModel); err != nil { return fmt.Errorf("failed to process file: %w", err) } } // Create and load index - err = c.rag.db.CreateIndex(ctx, c.rag.config.Collection, "Embedding", Index{ + err = r.rag.db.CreateIndex(ctx, r.rag.config.Collection, "Embedding", Index{ Type: "HNSW", Metric: "L2", Parameters: map[string]interface{}{ @@ -261,7 +370,7 @@ func (c *ContextualRAG) AddDocuments(ctx context.Context, source string) error { } // Load the collection - err = c.rag.db.LoadCollection(ctx, c.rag.config.Collection) + err = r.rag.db.LoadCollection(ctx, r.rag.config.Collection) if err != nil { return fmt.Errorf("failed to load collection: %w", err) } @@ -270,8 +379,21 @@ func (c *ContextualRAG) AddDocuments(ctx context.Context, source string) error { return nil } -// Search performs a semantic search query and returns a natural language response -func (c *ContextualRAG) Search(ctx context.Context, query string) (string, error) { +// Search performs context-aware retrieval and generates a natural language response. +// The process: +// 1. Analyzes query for context requirements +// 2. Retrieves relevant document chunks +// 3. Synthesizes information with context preservation +// 4. 
Generates a coherent response +// +// Parameters: +// - ctx: Context for cancellation and timeouts +// - query: Natural language query string +// +// Example: +// +// response, err := rag.Search(ctx, "How does the system handle errors?") +func (r *ContextualRAG) Search(ctx context.Context, query string) (string, error) { if ctx == nil { ctx = context.Background() } @@ -279,7 +401,7 @@ func (c *ContextualRAG) Search(ctx context.Context, query string) (string, error log.Printf("Searching for: %s", query) // Generate context for the query to improve search relevance - queryContext, err := c.generateContext(ctx, query) + queryContext, err := r.generateContext(ctx, query) if err != nil { log.Printf("Warning: Failed to generate query context: %v", err) // Continue with original query if context generation fails @@ -289,7 +411,7 @@ func (c *ContextualRAG) Search(ctx context.Context, query string) (string, error } // Get search results using retriever with both original query and context - results, err := c.retriever.Retrieve(ctx, queryContext) + results, err := r.retriever.Retrieve(ctx, queryContext) if err != nil { return "", fmt.Errorf("search failed: %w", err) } @@ -325,7 +447,7 @@ func (c *ContextualRAG) Search(ctx context.Context, query string) (string, error prompt := gollm.NewPrompt(contextBuilder.String()) // Generate response using LLM - response, err := c.llm.Generate(ctx, prompt) + response, err := r.llm.Generate(ctx, prompt) if err != nil { return "", fmt.Errorf("failed to generate response: %w", err) } @@ -334,14 +456,14 @@ func (c *ContextualRAG) Search(ctx context.Context, query string) (string, error } // generateContext uses the LLM to generate a richer context for the query -func (c *ContextualRAG) generateContext(ctx context.Context, query string) (string, error) { +func (r *ContextualRAG) generateContext(ctx context.Context, query string) (string, error) { prompt := gollm.NewPrompt(fmt.Sprintf( "Given this search query: '%s'\n"+ "Generate a more detailed version that includes relevant context and related terms "+ "to improve semantic search. Keep the enhanced query concise but comprehensive.", query)) - enhancedQuery, err := c.llm.Generate(ctx, prompt) + enhancedQuery, err := r.llm.Generate(ctx, prompt) if err != nil { return "", err } @@ -349,7 +471,8 @@ func (c *ContextualRAG) generateContext(ctx context.Context, query string) (stri return enhancedQuery, nil } -// Close releases resources -func (c *ContextualRAG) Close() error { - return c.rag.Close() +// Close releases all resources held by the ContextualRAG instance. +// Always defer Close() after creating a new instance. +func (r *ContextualRAG) Close() error { + return r.rag.Close() } diff --git a/duplicate_pdfs.sh b/duplicate_pdfs.sh deleted file mode 100755 index 711e057..0000000 --- a/duplicate_pdfs.sh +++ /dev/null @@ -1,30 +0,0 @@ -#!/bin/bash - -SOURCE_DIR="testdata" -TARGET_DIR="benchmark_data" -DESIRED_COUNT=5000 - -# Clean the target directory -rm -rf "$TARGET_DIR" -mkdir -p "$TARGET_DIR" - -echo "Cleaned $TARGET_DIR" - -# Get list of PDF files in source directory -pdf_files=($(ls "$SOURCE_DIR"/*.pdf)) -num_pdfs=${#pdf_files[@]} - -if [ $num_pdfs -eq 0 ]; then - echo "No PDF files found in $SOURCE_DIR" - exit 1 -fi - -# Duplicate PDFs -for ((i=0; i 5*time.Minute } -// StoreMemory explicitly stores messages in the memory context +// StoreMemory explicitly stores messages in the memory context. +// It processes and indexes the messages for later retrieval. 
+// +// Example: +// +// err := ctx.StoreMemory(context.Background(), []gollm.MemoryMessage{ +// {Role: "user", Content: "How does feature X work?"}, +// {Role: "assistant", Content: "Feature X works by..."}, +// }) func (m *MemoryContext) StoreMemory(ctx context.Context, messages []gollm.MemoryMessage) error { if len(messages) == 0 { return nil @@ -181,7 +282,12 @@ func (m *MemoryContext) StoreMemory(ctx context.Context, messages []gollm.Memory return nil } -// StoreLastN stores the last N messages from the memory +// StoreLastN stores only the most recent N messages from the memory. +// This helps maintain a sliding window of relevant context. +// +// Example: +// +// err := ctx.StoreLastN(context.Background(), messages, 10) // Keep last 10 messages func (m *MemoryContext) StoreLastN(ctx context.Context, memory []gollm.MemoryMessage, n int) error { if !m.shouldStore(memory) { return nil @@ -200,7 +306,12 @@ func (m *MemoryContext) StoreLastN(ctx context.Context, memory []gollm.MemoryMes return err } -// EnhancePrompt enriches a prompt with relevant context from memory +// EnhancePrompt enriches a prompt with relevant context from memory. +// It retrieves and integrates past interactions to provide better context. +// +// Example: +// +// enhanced, err := ctx.EnhancePrompt(context.Background(), prompt, messages) func (m *MemoryContext) EnhancePrompt(ctx context.Context, prompt *gollm.Prompt, memory []gollm.MemoryMessage) (*gollm.Prompt, error) { relevantContext, err := m.retrieveContext(ctx, prompt.Input) if err != nil { @@ -224,7 +335,8 @@ func (m *MemoryContext) EnhancePrompt(ctx context.Context, prompt *gollm.Prompt, return enhancedPrompt, nil } -// retrieveContext retrieves relevant context from RAG +// retrieveContext retrieves relevant context from stored memories. +// It uses vector similarity search to find the most relevant past interactions. func (m *MemoryContext) retrieveContext(ctx context.Context, input string) ([]string, error) { results, err := m.retriever.Retrieve(ctx, input) if err != nil { @@ -245,22 +357,33 @@ func (m *MemoryContext) retrieveContext(ctx context.Context, input string) ([]st return relevantContext, nil } -// Close releases resources +// Close releases resources held by the memory context. +// Always defer Close() after creating a new memory context. func (m *MemoryContext) Close() error { return m.retriever.Close() } -// GetRetriever returns the underlying retriever instance for advanced configuration +// GetRetriever returns the underlying retriever for advanced configuration. +// This provides access to low-level retrieval settings and operations. func (m *MemoryContext) GetRetriever() *Retriever { return m.retriever } -// GetOptions returns the current context options +// GetOptions returns the current context options configuration. +// Useful for inspecting or copying the current settings. func (m *MemoryContext) GetOptions() MemoryContextOptions { return m.options } -// UpdateOptions allows updating context options at runtime +// UpdateOptions allows updating context options at runtime. +// This enables dynamic reconfiguration of memory management behavior. 
+// +// Example: +// +// ctx.UpdateOptions( +// raggo.MemoryTopK(10), // Increase context breadth +// raggo.MemoryMinScore(0.8), // Raise relevance threshold +// ) func (m *MemoryContext) UpdateOptions(opts ...func(*MemoryContextOptions)) { options := m.GetOptions() for _, opt := range opts { diff --git a/parser.go b/parser.go index 277ccd4..ed76b4f 100644 --- a/parser.go +++ b/parser.go @@ -1,42 +1,117 @@ +// Package raggo provides a flexible and extensible document parsing system +// for RAG (Retrieval-Augmented Generation) applications. The system supports +// multiple file formats and can be extended with custom parsers. package raggo import ( "github.com/teilomillet/raggo/rag" ) -// Document represents a parsed document +// Document represents a parsed document with its content and metadata. +// The structure includes: +// - Content: The extracted text from the document +// - Metadata: Additional information about the document +// +// Example: +// +// doc := Document{ +// Content: "Extracted text content...", +// Metadata: map[string]string{ +// "file_type": "pdf", +// "file_path": "/path/to/doc.pdf", +// }, +// } type Document = rag.Document -// Parser defines the interface for parsing documents +// Parser defines the interface for document parsing implementations. +// Any type implementing this interface can be registered to handle +// specific file types. The interface is designed to be simple yet +// powerful enough to support various parsing strategies. +// +// Implementations must handle: +// - File access and reading +// - Content extraction +// - Metadata collection +// - Error handling type Parser interface { + // Parse processes a file and returns its content and metadata. + // Returns an error if the parsing operation fails. Parse(filePath string) (Document, error) } -// NewParser creates a new Parser with default settings +// NewParser creates a new Parser with default settings and handlers. +// The default configuration includes: +// - PDF document support +// - Plain text file support +// - Extension-based file type detection +// +// Example: +// +// parser := NewParser() +// doc, err := parser.Parse("document.pdf") func NewParser() Parser { return rag.NewParserManager() } -// SetFileTypeDetector sets a custom file type detector +// SetFileTypeDetector customizes how file types are detected. +// This allows for sophisticated file type detection beyond simple +// extension matching. +// +// Example: +// +// SetFileTypeDetector(parser, func(path string) string { +// // Custom logic to determine file type +// if strings.HasSuffix(path, ".md") { +// return "markdown" +// } +// return "unknown" +// }) func SetFileTypeDetector(p Parser, detector func(string) string) { if pm, ok := p.(*rag.ParserManager); ok { pm.SetFileTypeDetector(detector) } } -// WithParser adds a parser for a specific file type +// WithParser adds a custom parser for a specific file type. +// This enables the parsing system to handle additional file formats +// through custom implementations. +// +// Example: +// +// // Add support for markdown files +// WithParser(parser, "markdown", &MarkdownParser{}) func WithParser(p Parser, fileType string, parser Parser) { if pm, ok := p.(*rag.ParserManager); ok { pm.AddParser(fileType, parser) } } -// TextParser returns a new text parser +// TextParser returns a new parser for plain text files. 
+// The text parser: +// - Reads the entire file content +// - Preserves text formatting +// - Handles various encodings +// - Provides basic metadata +// +// Example: +// +// parser := TextParser() +// doc, err := parser.Parse("document.txt") func TextParser() Parser { return rag.NewTextParser() } -// PDFParser returns a new PDF parser +// PDFParser returns a new parser for PDF documents. +// The PDF parser: +// - Extracts text content from all pages +// - Maintains text order +// - Handles complex PDF structures +// - Provides document metadata +// +// Example: +// +// parser := PDFParser() +// doc, err := parser.Parse("document.pdf") func PDFParser() Parser { return rag.NewPDFParser() } diff --git a/rag.go b/rag.go index c8f6c4a..548543f 100644 --- a/rag.go +++ b/rag.go @@ -1,3 +1,44 @@ +// Package raggo implements a comprehensive Retrieval-Augmented Generation (RAG) system +// that enhances language models with the ability to access and reason over external +// knowledge. The system seamlessly integrates vector similarity search with natural +// language processing to provide accurate and contextually relevant responses. +// +// The package offers two main interfaces: +// - RAG: A full-featured implementation with extensive configuration options +// - SimpleRAG: A streamlined interface for basic use cases +// +// The RAG system works by: +// 1. Processing documents into semantic chunks +// 2. Storing document vectors in a configurable database +// 3. Finding relevant context through similarity search +// 4. Generating responses that combine retrieved context with queries +// +// Key Features: +// - Multiple vector database support (Milvus, in-memory, Chrome) +// - Intelligent document chunking and embedding +// - Hybrid search capabilities +// - Context-aware retrieval +// - Configurable LLM integration +// +// Example Usage: +// +// config := raggo.DefaultRAGConfig() +// config.APIKey = os.Getenv("OPENAI_API_KEY") +// +// rag, err := raggo.NewRAG( +// raggo.SetProvider("openai"), +// raggo.SetModel("text-embedding-3-small"), +// raggo.WithMilvus("my_documents"), +// ) +// if err != nil { +// log.Fatal(err) +// } +// +// // Add documents +// err = rag.LoadDocuments(context.Background(), "path/to/docs") +// +// // Query the system +// results, err := rag.Query(context.Background(), "your question here") package raggo import ( @@ -11,51 +52,68 @@ import ( "github.com/teilomillet/raggo/rag" ) -// RAGConfig holds all RAG settings to avoid name collision with VectorDB Config +// RAGConfig holds the complete configuration for a RAG system. It provides +// fine-grained control over all aspects of the system's operation, from database +// settings to search parameters. The configuration is designed to be flexible +// enough to accommodate various use cases while maintaining sensible defaults. 
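+//
+// A typical starting point is DefaultRAGConfig, overriding fields as needed:
+//
+//	cfg := raggo.DefaultRAGConfig()
+//	cfg.APIKey = os.Getenv("OPENAI_API_KEY")
+//	cfg.Collection = "my_documents"
+//	cfg.TopK = 5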
type RAGConfig struct { - // Database settings - DBType string - DBAddress string - Collection string - AutoCreate bool - IndexType string - IndexMetric string - - // Processing settings - ChunkSize int - ChunkOverlap int - BatchSize int - - // Embedding settings - Provider string - Model string // For embeddings - LLMModel string // For LLM operations - APIKey string - - // Search settings - TopK int - MinScore float64 - UseHybrid bool - - // System settings - Timeout time.Duration - TempDir string - Debug bool - - // Search parameters - SearchParams map[string]interface{} // Add this field + // Database settings control how documents are stored and indexed + DBType string // Vector database type (e.g., "milvus", "memory") + DBAddress string // Database connection address + Collection string // Name of the vector collection + AutoCreate bool // Automatically create collection if it doesn't exist + IndexType string // Type of vector index (e.g., "HNSW", "IVF") + IndexMetric string // Distance metric for similarity (e.g., "L2", "IP") + + // Processing settings determine how documents are handled + ChunkSize int // Size of text chunks in tokens + ChunkOverlap int // Overlap between consecutive chunks + BatchSize int // Number of documents to process in parallel + + // Embedding settings configure vector generation + Provider string // Embedding provider (e.g., "openai", "cohere") + Model string // Embedding model name + LLMModel string // Language model for text generation + APIKey string // API key for the provider + + // Search settings control retrieval behavior + TopK int // Number of results to retrieve + MinScore float64 // Minimum similarity score threshold + UseHybrid bool // Whether to use hybrid search + + // System settings affect operational behavior + Timeout time.Duration // Operation timeout + TempDir string // Directory for temporary files + Debug bool // Enable debug logging + + // Search parameters for fine-tuning + SearchParams map[string]interface{} // Provider-specific search parameters } -// RAGOption modifies RAGConfig +// RAGOption is a function that modifies RAGConfig. +// It follows the functional options pattern for clean and flexible configuration. type RAGOption func(*RAGConfig) -// RAG provides a unified interface for document processing and retrieval +// RAG provides a comprehensive interface for document processing and retrieval. +// It coordinates the interaction between multiple components: +// - Vector database for efficient similarity search +// - Embedding service for semantic vector generation +// - Document processor for text chunking and enrichment +// - Language model for context-aware response generation +// +// The system is designed to be: +// - Thread-safe for concurrent operations +// - Memory-efficient when processing large documents +// - Extensible through custom implementations +// - Configurable for different use cases type RAG struct { - db *VectorDB - embedder *EmbeddingService - config *RAGConfig + db *VectorDB // Vector database connection + embedder *EmbeddingService // Service for generating embeddings + config *RAGConfig // System configuration } +// DefaultRAGConfig returns a default RAG configuration. +// It provides a reasonable set of default values for most use cases. func DefaultRAGConfig() *RAGConfig { return &RAGConfig{ DBType: "milvus", @@ -85,78 +143,183 @@ func DefaultRAGConfig() *RAGConfig { } // Common options +// SetProvider sets the embedding provider for the RAG system. 
+// Supported providers include "openai", "cohere", and others depending on implementation. +// +// Example: +// +// rag, err := raggo.NewRAG( +// raggo.SetProvider("openai"), +// ) func SetProvider(provider string) RAGOption { return func(c *RAGConfig) { c.Provider = provider } } +// SetModel specifies the embedding model to use for vector generation. +// The model should be compatible with the chosen provider. +// +// Example: +// +// rag, err := raggo.NewRAG( +// raggo.SetModel("text-embedding-3-small"), +// ) func SetModel(model string) RAGOption { return func(c *RAGConfig) { c.Model = model } } +// SetAPIKey configures the API key for the chosen provider. +// This key should have appropriate permissions for embedding and LLM operations. +// +// Example: +// +// rag, err := raggo.NewRAG( +// raggo.SetAPIKey(os.Getenv("OPENAI_API_KEY")), +// ) func SetAPIKey(key string) RAGOption { return func(c *RAGConfig) { c.APIKey = key } } +// SetCollection specifies the name of the vector collection to use. +// This collection will store document embeddings and metadata. +// +// Example: +// +// rag, err := raggo.NewRAG( +// raggo.SetCollection("my_documents"), +// ) func SetCollection(name string) RAGOption { return func(c *RAGConfig) { c.Collection = name } } +// SetSearchStrategy configures the search approach for document retrieval. +// Supported strategies include "simple" for pure vector search and +// "hybrid" for combined vector and keyword search. +// +// Example: +// +// rag, err := raggo.NewRAG( +// raggo.SetSearchStrategy("hybrid"), +// ) func SetSearchStrategy(strategy string) RAGOption { return func(c *RAGConfig) { c.UseHybrid = strategy == "hybrid" } } +// SetDBAddress configures the connection address for the vector database. +// Format depends on the database type (e.g., "localhost:19530" for Milvus). +// +// Example: +// +// rag, err := raggo.NewRAG( +// raggo.SetDBAddress("localhost:19530"), +// ) func SetDBAddress(address string) RAGOption { return func(c *RAGConfig) { c.DBAddress = address } } +// SetChunkSize configures the size of text chunks in tokens. +// Larger chunks provide more context but may reduce retrieval precision. +// +// Example: +// +// rag, err := raggo.NewRAG( +// raggo.SetChunkSize(512), +// ) func SetChunkSize(size int) RAGOption { return func(c *RAGConfig) { c.ChunkSize = size } } +// SetChunkOverlap specifies the overlap between consecutive chunks in tokens. +// Overlap helps maintain context across chunk boundaries. +// +// Example: +// +// rag, err := raggo.NewRAG( +// raggo.SetChunkOverlap(50), +// ) func SetChunkOverlap(overlap int) RAGOption { return func(c *RAGConfig) { c.ChunkOverlap = overlap } } +// SetTopK configures the number of similar documents to retrieve. +// Higher values provide more context but may introduce noise. +// +// Example: +// +// rag, err := raggo.NewRAG( +// raggo.SetTopK(5), +// ) func SetTopK(k int) RAGOption { return func(c *RAGConfig) { c.TopK = k } } +// SetMinScore sets the minimum similarity score threshold for retrieval. +// Documents with scores below this threshold are filtered out. +// +// Example: +// +// rag, err := raggo.NewRAG( +// raggo.SetMinScore(0.7), +// ) func SetMinScore(score float64) RAGOption { return func(c *RAGConfig) { c.MinScore = score } } +// SetTimeout configures the maximum duration for operations. +// This affects database operations, embedding generation, and LLM calls. 
+// +// Example: +// +// rag, err := raggo.NewRAG( +// raggo.SetTimeout(30 * time.Second), +// ) func SetTimeout(timeout time.Duration) RAGOption { return func(c *RAGConfig) { c.Timeout = timeout } } +// SetDebug enables or disables debug logging. +// When enabled, the system will output detailed operation information. +// +// Example: +// +// rag, err := raggo.NewRAG( +// raggo.SetDebug(true), +// ) func SetDebug(debug bool) RAGOption { return func(c *RAGConfig) { c.Debug = debug } } +// WithOpenAI is a convenience function that configures the RAG system +// to use OpenAI's embedding and language models. +// +// Example: +// +// rag, err := raggo.NewRAG( +// raggo.WithOpenAI(os.Getenv("OPENAI_API_KEY")), +// ) func WithOpenAI(apiKey string) RAGOption { return func(c *RAGConfig) { c.Provider = "openai" @@ -165,6 +328,14 @@ func WithOpenAI(apiKey string) RAGOption { } } +// WithMilvus is a convenience function that configures the RAG system +// to use Milvus as the vector database with the specified collection. +// +// Example: +// +// rag, err := raggo.NewRAG( +// raggo.WithMilvus("my_documents"), +// ) func WithMilvus(collection string) RAGOption { return func(c *RAGConfig) { c.DBType = "milvus" @@ -173,7 +344,8 @@ func WithMilvus(collection string) RAGOption { } } -// NewRAG creates a new RAG instance +// NewRAG creates a new RAG instance. +// It takes a variable number of RAGOption functions to configure the system. func NewRAG(opts ...RAGOption) (*RAG, error) { cfg := DefaultRAGConfig() for _, opt := range opts { @@ -215,7 +387,16 @@ func (r *RAG) initialize() error { return r.db.Connect(context.Background()) } -// LoadDocuments processes and stores documents +// LoadDocuments processes and stores documents in the vector database. +// It handles various document formats and automatically chunks text +// based on the configured chunk size and overlap. +// +// The source parameter can be a file path or directory. When a directory +// is provided, all supported documents within it are processed recursively. +// +// Example: +// +// err := rag.LoadDocuments(ctx, "path/to/docs") func (r *RAG) LoadDocuments(ctx context.Context, source string) error { loader := NewLoader(SetTempDir(r.config.TempDir)) chunker, err := NewChunker( @@ -261,7 +442,8 @@ func (r *RAG) LoadDocuments(ctx context.Context, source string) error { return nil } -// storeEnrichedChunks stores chunks with their context +// storeEnrichedChunks stores chunks with their context. +// It takes a context, enriched chunks, and a source path as input. func (r *RAG) storeEnrichedChunks(ctx context.Context, enrichedChunks []string, source string) error { // Convert strings to Chunks chunks := make([]rag.Chunk, len(enrichedChunks)) @@ -299,7 +481,8 @@ func (r *RAG) storeEnrichedChunks(ctx context.Context, enrichedChunks []string, return r.db.Insert(ctx, r.config.Collection, records) } -// ProcessWithContext processes and stores documents with additional contextual information +// ProcessWithContext processes and stores documents with additional contextual information. +// It takes a context, source path, and an optional LLM model as input. 
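+//
+// Example (the model name is illustrative):
+//
+//	err := rag.ProcessWithContext(ctx, "path/to/docs", "gpt-4o-mini")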
func (r *RAG) ProcessWithContext(ctx context.Context, source string, llmModel string) error { Debug("Processing source:", source) @@ -475,7 +658,16 @@ Provide only the context, without any introductory phrases or explanations.`, ch return llm.Generate(ctx, gollm.NewPrompt(prompt)) } -// Query performs RAG retrieval and returns results +// Query performs a retrieval operation using the configured search strategy. +// It returns a slice of RetrieverResult containing relevant document chunks +// and their similarity scores. +// +// The query parameter should be a natural language question or statement. +// The system will convert it to a vector and find similar documents. +// +// Example: +// +// results, err := rag.Query(ctx, "How does feature X work?") func (r *RAG) Query(ctx context.Context, query string) ([]RetrieverResult, error) { if !r.config.UseHybrid { return r.simpleSearch(ctx, query) @@ -483,7 +675,14 @@ func (r *RAG) Query(ctx context.Context, query string) ([]RetrieverResult, error return r.hybridSearch(ctx, query) } -// Close releases resources +// Close releases all resources held by the RAG system, including +// database connections and embedding service clients. +// +// It should be called when the RAG system is no longer needed. +// +// Example: +// +// defer rag.Close() func (r *RAG) Close() error { if r.db != nil { return r.db.Close() diff --git a/rag/chromem.go b/rag/chromem.go index 597cd5d..63c6835 100644 --- a/rag/chromem.go +++ b/rag/chromem.go @@ -1,5 +1,4 @@ -// File: chromem.go - +// Package rag provides retrieval-augmented generation capabilities. package rag import ( @@ -13,14 +12,36 @@ import ( "github.com/philippgille/chromem-go" ) +// ChromemDB implements a vector database interface using ChromeM. +// ChromeM is a lightweight, embedded vector database that supports: +// - In-memory and persistent storage modes +// - Multiple embedding providers (OpenAI, Cohere, etc.) +// - Collection-based organization +// - Approximate nearest neighbor search +// +// This implementation features: +// - Thread-safe operations with mutex protection +// - Automatic collection management +// - OpenAI embedding integration +// - Configurable vector dimensions type ChromemDB struct { - db *chromem.DB - collections map[string]*chromem.Collection - mu sync.RWMutex - columnNames []string - dimension int + db *chromem.DB // Underlying ChromeM database + collections map[string]*chromem.Collection // Cache of active collections + mu sync.RWMutex // Protects concurrent access to collections + columnNames []string // Names of columns to retrieve in search results + dimension int // Vector dimension for embeddings } +// newChromemDB creates a new ChromemDB instance with the given configuration. +// It supports both in-memory and persistent storage modes: +// - If cfg.Address is empty: Creates an in-memory database +// - If cfg.Address is set: Creates a persistent database at the specified path +// +// The function performs initial setup and validation: +// 1. Configures vector dimension (default: 1536 for OpenAI embeddings) +// 2. Creates necessary directories for persistent storage +// 3. Tests database functionality with a temporary collection +// 4. Verifies OpenAI API key availability func newChromemDB(cfg *Config) (*ChromemDB, error) { log.Printf("Creating new ChromemDB with config: %+v", cfg) @@ -108,6 +129,9 @@ func newChromemDB(cfg *Config) (*ChromemDB, error) { }, nil } +// Connect establishes a connection to the ChromeM database. 
+// This is a no-op for ChromeM as it's an embedded database, +// but implemented to satisfy the database interface. func (c *ChromemDB) Connect(ctx context.Context) error { log.Printf("Connecting to ChromemDB") // No explicit connect needed for chromem @@ -115,11 +139,21 @@ func (c *ChromemDB) Connect(ctx context.Context) error { return nil } +// Close releases any resources held by the ChromeM database. +// This is a no-op for ChromeM as it's an embedded database, +// but implemented to satisfy the database interface. func (c *ChromemDB) Close() error { // No explicit close in chromem return nil } +// HasCollection checks if a collection exists in the database. +// The function: +// 1. Checks the local collection cache first +// 2. Falls back to querying the database directly +// 3. Updates the cache if the collection is found +// +// Thread-safe: Protected by read lock. func (c *ChromemDB) HasCollection(ctx context.Context, name string) (bool, error) { c.mu.RLock() defer c.mu.RUnlock() @@ -156,6 +190,12 @@ func (c *ChromemDB) HasCollection(ctx context.Context, name string) (bool, error return exists, nil } +// DropCollection removes a collection from the database. +// Note: In ChromeM, collections are not explicitly dropped, +// but rather overwritten when created again. This implementation +// removes the collection from the local cache. +// +// Thread-safe: Protected by write lock. func (c *ChromemDB) DropCollection(ctx context.Context, name string) error { c.mu.Lock() defer c.mu.Unlock() @@ -164,6 +204,18 @@ func (c *ChromemDB) DropCollection(ctx context.Context, name string) error { return nil } +// CreateCollection initializes a new collection in the database. +// The function: +// 1. Verifies collection doesn't exist in cache +// 2. Configures OpenAI embeddings (requires OPENAI_API_KEY) +// 3. Creates collection with empty metadata +// 4. Verifies successful creation +// 5. Updates local cache +// +// Note: ChromeM doesn't use schema information, so the schema +// parameter is effectively ignored. +// +// Thread-safe: Protected by write lock. func (c *ChromemDB) CreateCollection(ctx context.Context, name string, schema Schema) error { c.mu.Lock() defer c.mu.Unlock() @@ -206,6 +258,16 @@ func (c *ChromemDB) CreateCollection(ctx context.Context, name string, schema Sc return nil } +// Insert adds new records to a collection. +// The function: +// 1. Retrieves or creates the target collection +// 2. Processes records in parallel for efficiency +// 3. Creates ChromeM documents with: +// - Unique IDs +// - Metadata from record fields +// - Content from specified text field +// +// Thread-safe: Uses ChromeM's internal synchronization. func (c *ChromemDB) Insert(ctx context.Context, collectionName string, data []Record) error { c.mu.Lock() defer c.mu.Unlock() @@ -306,16 +368,29 @@ func (c *ChromemDB) Insert(ctx context.Context, collectionName string, data []Re return nil } +// Flush ensures all data is persisted to storage. +// This is a no-op for ChromeM as it handles persistence automatically, +// but implemented to satisfy the database interface. func (c *ChromemDB) Flush(ctx context.Context, collectionName string) error { // No explicit flush in chromem return nil } +// CreateIndex creates an index for the specified field. +// This is a no-op for ChromeM as it manages its own indexing, +// but implemented to satisfy the database interface. 
func (c *ChromemDB) CreateIndex(ctx context.Context, collectionName, field string, index Index) error { // No explicit index creation in chromem return nil } +// LoadCollection prepares a collection for searching. +// The function: +// 1. Verifies collection exists +// 2. Configures OpenAI embeddings +// 3. Updates local collection cache +// +// Thread-safe: Protected by write lock. func (c *ChromemDB) LoadCollection(ctx context.Context, name string) error { c.mu.Lock() defer c.mu.Unlock() @@ -345,6 +420,16 @@ func (c *ChromemDB) LoadCollection(ctx context.Context, name string) error { return nil } +// Search performs vector similarity search on a collection. +// The function: +// 1. Retrieves the target collection +// 2. Converts query vectors to float32 +// 3. Executes search with specified parameters +// 4. Formats results to match interface requirements +// +// Note: ChromeM only supports cosine similarity for distance metric. +// +// Thread-safe: Uses ChromeM's internal synchronization. func (c *ChromemDB) Search(ctx context.Context, collectionName string, vectors map[string]Vector, topK int, metricType string, searchParams map[string]interface{}) ([]SearchResult, error) { c.mu.RLock() @@ -415,16 +500,22 @@ func (c *ChromemDB) Search(ctx context.Context, collectionName string, vectors m return searchResults, nil } +// HybridSearch performs combined vector and keyword search. +// Currently not implemented for ChromeM - returns an error. func (c *ChromemDB) HybridSearch(ctx context.Context, collectionName string, vectors map[string]Vector, topK int, metricType string, searchParams map[string]interface{}, reranker interface{}) ([]SearchResult, error) { // Not implemented for chromem return nil, fmt.Errorf("hybrid search not implemented for chromem") } +// SetColumnNames sets the list of columns to retrieve in search results. +// This affects which fields are included in search result metadata. func (c *ChromemDB) SetColumnNames(names []string) { c.columnNames = names } -// Helper function to convert Vector to []float32 +// toFloat32Slice converts a Vector ([]float64) to []float32. +// This is necessary because ChromeM uses float32 for vector operations, +// while our interface uses float64 for broader compatibility. func toFloat32Slice(v Vector) []float32 { result := make([]float32, len(v)) for i, val := range v { diff --git a/rag/chunk.go b/rag/chunk.go index 46cb2a5..0f56b4a 100644 --- a/rag/chunk.go +++ b/rag/chunk.go @@ -1,3 +1,5 @@ +// Package rag provides text chunking capabilities for processing documents into +// manageable pieces suitable for vector embedding and retrieval. package rag import ( @@ -7,33 +9,55 @@ import ( "github.com/pkoukk/tiktoken-go" ) -// Chunk represents a piece of text with metadata +// Chunk represents a piece of text with associated metadata for tracking its position +// and size within the original document. type Chunk struct { - Text string - TokenSize int + // Text contains the actual content of the chunk + Text string + // TokenSize represents the number of tokens in this chunk + TokenSize int + // StartSentence is the index of the first sentence in this chunk StartSentence int - EndSentence int + // EndSentence is the index of the last sentence in this chunk (exclusive) + EndSentence int } -// Chunker defines the interface for text chunking +// Chunker defines the interface for text chunking implementations. +// Different implementations can provide various strategies for splitting text +// while maintaining context and semantic meaning. 
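+//
+// A usage sketch with the default TextChunker defined below, where
+// documentText stands for any input string:
+//
+//	chunker, _ := NewTextChunker()
+//	chunks := chunker.Chunk(documentText)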
type Chunker interface { + // Chunk splits the input text into a slice of Chunks according to the + // implementation's strategy. Chunk(text string) []Chunk } -// TokenCounter defines the interface for counting tokens in a string +// TokenCounter defines the interface for counting tokens in a string. +// This abstraction allows for different tokenization strategies (e.g., words, subwords). type TokenCounter interface { + // Count returns the number of tokens in the given text according to the + // implementation's tokenization strategy. Count(text string) int } -// TextChunker is an implementation of Chunker with advanced features +// TextChunker provides an advanced implementation of the Chunker interface +// with support for overlapping chunks and custom tokenization. type TextChunker struct { - ChunkSize int - ChunkOverlap int - TokenCounter TokenCounter + // ChunkSize is the target size of each chunk in tokens + ChunkSize int + // ChunkOverlap is the number of tokens that should overlap between adjacent chunks + ChunkOverlap int + // TokenCounter is used to count tokens in text segments + TokenCounter TokenCounter + // SentenceSplitter is a function that splits text into sentences SentenceSplitter func(string) []string } -// NewTextChunker creates a new TextChunker with the given options +// NewTextChunker creates a new TextChunker with the given options. +// It uses sensible defaults if no options are provided: +// - ChunkSize: 200 tokens +// - ChunkOverlap: 50 tokens +// - TokenCounter: DefaultTokenCounter +// - SentenceSplitter: DefaultSentenceSplitter func NewTextChunker(options ...TextChunkerOption) (*TextChunker, error) { tc := &TextChunker{ ChunkSize: 200, @@ -49,10 +73,16 @@ func NewTextChunker(options ...TextChunkerOption) (*TextChunker, error) { return tc, nil } -// TextChunkerOption is a function type for configuring TextChunker +// TextChunkerOption is a function type for configuring TextChunker instances. +// This follows the functional options pattern for clean and flexible configuration. type TextChunkerOption func(*TextChunker) -// Chunk splits the input text into chunks +// Chunk splits the input text into chunks while preserving sentence boundaries +// and maintaining the specified overlap between chunks. The algorithm: +// 1. Splits the text into sentences +// 2. Builds chunks by adding sentences until the chunk size limit is reached +// 3. Creates overlap with previous chunk when starting a new chunk +// 4. Tracks token counts and sentence indices for each chunk func (tc *TextChunker) Chunk(text string) []Chunk { sentences := tc.SentenceSplitter(text) var chunks []Chunk @@ -94,7 +124,9 @@ func (tc *TextChunker) Chunk(text string) []Chunk { return chunks } -// estimateOverlapSentences estimates the number of sentences needed for the desired token overlap +// estimateOverlapSentences calculates how many sentences from the end of the +// previous chunk should be included in the next chunk to achieve the desired +// token overlap. func (tc *TextChunker) estimateOverlapSentences(sentences []string, endSentence, desiredOverlap int) int { overlapTokens := 0 overlapSentences := 0 @@ -105,14 +137,20 @@ func (tc *TextChunker) estimateOverlapSentences(sentences []string, endSentence, return overlapSentences } -// DefaultSentenceSplitter splits text into sentences (simplified version) +// DefaultSentenceSplitter provides a basic implementation for splitting text into sentences. +// It uses common punctuation marks (., !, ?) as sentence boundaries. 
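+//
+// For example, "Go is fun. Really!" yields ["Go is fun", " Really"]; the
+// delimiter runes themselves are dropped by strings.FieldsFunc.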
func DefaultSentenceSplitter(text string) []string {
	return strings.FieldsFunc(text, func(r rune) bool {
		return r == '.' || r == '!' || r == '?'
	})
}

-// SmartSentenceSplitter is a more advanced sentence splitter that handles various punctuation and edge cases
+// SmartSentenceSplitter provides an advanced sentence splitting implementation that handles:
+// - Multiple punctuation marks (., !, ?)
+// - Common abbreviations
+// - Quoted sentences
+// - Parenthetical sentences
+// - Lists and enumerations
func SmartSentenceSplitter(text string) []string {
	var sentences []string
	var currentSentence strings.Builder
@@ -142,19 +180,28 @@ func SmartSentenceSplitter(text string) []string {
	return sentences
}

-// DefaultTokenCounter is a simple word-based token counter
+// DefaultTokenCounter provides a simple word-based token counting implementation.
+// It splits text on whitespace to approximate token counts. This is suitable
+// for basic use cases but may not accurately reflect subword tokenization
+// used by language models.
type DefaultTokenCounter struct{}

+// Count returns the number of words in the text, using whitespace as a delimiter.
func (dtc *DefaultTokenCounter) Count(text string) int {
	return len(strings.Fields(text))
}

-// TikTokenCounter is a token counter that uses the tiktoken library
+// TikTokenCounter provides accurate token counting using the tiktoken library,
+// which implements the tokenization schemes used by OpenAI models.
type TikTokenCounter struct {
	tke *tiktoken.Tiktoken
}

-// NewTikTokenCounter creates a new TikTokenCounter with the specified encoding
+// NewTikTokenCounter creates a new TikTokenCounter using the specified encoding.
+// Common encodings include:
+// - "cl100k_base" (GPT-4, ChatGPT)
+// - "p50k_base" (Codex, text-davinci-002/003)
+// - "r50k_base" (earlier GPT-3 models)
func NewTikTokenCounter(encoding string) (*TikTokenCounter, error) {
	tke, err := tiktoken.GetEncoding(encoding)
	if err != nil {
@@ -163,10 +210,13 @@ func NewTikTokenCounter(encoding string) (*TikTokenCounter, error) {
	return &TikTokenCounter{tke: tke}, nil
}

+// Count returns the exact number of tokens in the text according to the
+// specified tiktoken encoding.
func (ttc *TikTokenCounter) Count(text string) int {
	return len(ttc.tke.Encode(text, nil, nil))
}

+// max returns the larger of two integers.
func max(a, b int) int {
	if a > b {
		return a
diff --git a/rag/embed.go b/rag/embed.go
index ca486c3..c400b1d 100644
--- a/rag/embed.go
+++ b/rag/embed.go
@@ -1,3 +1,5 @@
+// Package rag provides functionality for converting text into vector embeddings
+// using various embedding providers (e.g., OpenAI, Cohere, local models).
package rag

import (
@@ -7,44 +9,64 @@ import (
	"github.com/teilomillet/raggo/rag/providers"
)

-// EmbedderConfig holds the configuration for creating an Embedder
+// EmbedderConfig holds the configuration for creating an Embedder instance.
+// It supports multiple embedding providers and their specific options.
type EmbedderConfig struct {
+	// Provider specifies the embedding service to use (e.g., "openai", "cohere")
	Provider string
-	Options  map[string]interface{}
+	// Options contains provider-specific configuration parameters
+	Options map[string]interface{}
}

-// EmbedderOption is a function type for configuring the EmbedderConfig
+// EmbedderOption is a function type for configuring the EmbedderConfig.
+// It follows the functional options pattern for clean and flexible configuration.
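+//
+// A configuration sketch combining the options below (reading the API key
+// from the environment is illustrative, not required):
+//
+//	embedder, err := NewEmbedder(
+//		SetProvider("openai"),
+//		SetModel("text-embedding-3-small"),
+//		SetAPIKey(os.Getenv("OPENAI_API_KEY")),
+//	)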
type EmbedderOption func(*EmbedderConfig) -// SetProvider sets the provider for the Embedder +// SetProvider sets the provider for the Embedder. +// Common providers include: +// - "openai": OpenAI's text-embedding-ada-002 and other models +// - "cohere": Cohere's embedding models +// - "local": Local embedding models func SetProvider(provider string) EmbedderOption { return func(c *EmbedderConfig) { c.Provider = provider } } -// SetModel sets the model for the Embedder +// SetModel sets the specific model to use for embedding. +// The available models depend on the chosen provider. +// Examples: +// - OpenAI: "text-embedding-ada-002" +// - Cohere: "embed-multilingual-v2.0" func SetModel(model string) EmbedderOption { return func(c *EmbedderConfig) { c.Options["model"] = model } } -// SetAPIKey sets the API key for the Embedder +// SetAPIKey sets the authentication key for the embedding service. +// This is required for most cloud-based embedding providers. func SetAPIKey(apiKey string) EmbedderOption { return func(c *EmbedderConfig) { c.Options["api_key"] = apiKey } } -// SetOption sets a custom option for the Embedder +// SetOption sets a custom option for the Embedder. +// This allows for provider-specific configuration options +// that aren't covered by the standard options. func SetOption(key string, value interface{}) EmbedderOption { return func(c *EmbedderConfig) { c.Options[key] = value } } -// NewEmbedder creates a new Embedder instance based on the provided options +// NewEmbedder creates a new Embedder instance based on the provided options. +// It uses the provider factory system to instantiate the appropriate embedder +// implementation. Returns an error if: +// - No provider is specified +// - The specified provider is not registered +// - The provider factory fails to create an embedder func NewEmbedder(opts ...EmbedderOption) (providers.Embedder, error) { config := &EmbedderConfig{ Options: make(map[string]interface{}), @@ -62,24 +84,42 @@ func NewEmbedder(opts ...EmbedderOption) (providers.Embedder, error) { return factory(config.Options) } -// EmbeddedChunk represents a chunk of text with its embeddings and metadata +// EmbeddedChunk represents a chunk of text along with its vector embeddings +// and associated metadata. This is the core data structure for storing +// and retrieving embedded content. type EmbeddedChunk struct { - Text string `json:"text"` - Embeddings map[string][]float64 `json:"embeddings"` - Metadata map[string]interface{} `json:"metadata"` + // Text is the original text content that was embedded + Text string `json:"text"` + // Embeddings maps embedding types to their vector representations + // Multiple embeddings can exist for different models or purposes + Embeddings map[string][]float64 `json:"embeddings"` + // Metadata stores additional information about the chunk + // This can include source document info, timestamps, etc. + Metadata map[string]interface{} `json:"metadata"` } -// EmbeddingService handles the embedding process +// EmbeddingService handles the process of converting text chunks into +// vector embeddings. It encapsulates the embedding provider and provides +// a high-level interface for embedding operations. type EmbeddingService struct { embedder providers.Embedder } -// NewEmbeddingService creates a new embedding service with a single embedder +// NewEmbeddingService creates a new embedding service with the specified embedder. +// The embedder must be properly configured and ready to generate embeddings. 
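+//
+// A typical flow pairs it with NewEmbedder:
+//
+//	service := NewEmbeddingService(embedder)
+//	embedded, err := service.EmbedChunks(ctx, chunks)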
func NewEmbeddingService(embedder providers.Embedder) *EmbeddingService { return &EmbeddingService{embedder: embedder} } -// EmbedChunks embeds a slice of chunks +// EmbedChunks processes a slice of text chunks and generates embeddings for each one. +// It handles the embedding process in sequence, with debug output for monitoring. +// The function: +// 1. Allocates space for the results +// 2. Processes each chunk through the embedder +// 3. Creates EmbeddedChunk instances with the results +// 4. Provides progress information via debug output +// +// Returns an error if any chunk fails to embed properly. func (s *EmbeddingService) EmbedChunks(ctx context.Context, chunks []Chunk) ([]EmbeddedChunk, error) { embeddedChunks := make([]EmbeddedChunk, 0, len(chunks)) @@ -117,6 +157,9 @@ func (s *EmbeddingService) EmbedChunks(ctx context.Context, chunks []Chunk) ([]E return embeddedChunks, nil } +// truncateString shortens a string to the specified length, adding an ellipsis +// if the string was truncated. This is used for debug output to keep log +// messages readable. func truncateString(s string, n int) string { if len(s) <= n { return s diff --git a/rag/example_vectordb.go b/rag/example_vectordb.go new file mode 100644 index 0000000..2f0d9c3 --- /dev/null +++ b/rag/example_vectordb.go @@ -0,0 +1,178 @@ +// Package rag provides retrieval-augmented generation capabilities. +package rag + +import ( + "context" + "fmt" + "sync" +) + +// ExampleDB demonstrates how to implement a new vector database provider. +// This template shows the minimum required functionality and common patterns. +// +// Key Features to Implement: +// - Thread-safe operations +// - Vector similarity search +// - Collection management +// - Data persistence (if applicable) +// - Error handling and logging +type ExampleDB struct { + // Configuration + config *Config + + // Connection state + isConnected bool + + // Collection management + collections map[string]interface{} // Replace interface{} with your collection type + mu sync.RWMutex // Protects concurrent access to collections + + // Search configuration + columnNames []string // Names of columns to retrieve in search results + dimension int // Vector dimension for embeddings +} + +// newExampleDB creates a new ExampleDB instance with the given configuration. +// Initialize your database connection and any required resources here. +func newExampleDB(cfg *Config) (*ExampleDB, error) { + // Get dimension from config parameters (example) + dimension, ok := cfg.Parameters["dimension"].(int) + if !ok { + dimension = 1536 // Default dimension + } + + return &ExampleDB{ + config: cfg, + collections: make(map[string]interface{}), + dimension: dimension, + }, nil +} + +// Connect establishes a connection to the database. +// Implement your connection logic here. +func (db *ExampleDB) Connect(ctx context.Context) error { + GlobalLogger.Debug("Connecting to example database", "address", db.config.Address) + + // Add your connection logic here + // Example: + // client, err := yourdb.Connect(db.config.Address) + // if err != nil { + // return fmt.Errorf("failed to connect: %w", err) + // } + + db.isConnected = true + return nil +} + +// Close terminates the database connection. +// Clean up any resources here. +func (db *ExampleDB) Close() error { + if !db.isConnected { + return nil + } + + // Add your cleanup logic here + db.isConnected = false + return nil +} + +// HasCollection checks if a collection exists. 
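+// Thread-safe: Protected by read lock.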
+func (db *ExampleDB) HasCollection(ctx context.Context, name string) (bool, error) { + db.mu.RLock() + defer db.mu.RUnlock() + + _, exists := db.collections[name] + return exists, nil +} + +// CreateCollection initializes a new collection. +func (db *ExampleDB) CreateCollection(ctx context.Context, name string, schema Schema) error { + db.mu.Lock() + defer db.mu.Unlock() + + // Validate collection doesn't exist + if _, exists := db.collections[name]; exists { + return fmt.Errorf("collection %s already exists", name) + } + + // Initialize your collection here + // Example: + // collection, err := db.client.CreateCollection(name, schema) + // if err != nil { + // return fmt.Errorf("failed to create collection: %w", err) + // } + // db.collections[name] = collection + + return nil +} + +// DropCollection removes a collection. +func (db *ExampleDB) DropCollection(ctx context.Context, name string) error { + db.mu.Lock() + defer db.mu.Unlock() + + delete(db.collections, name) + return nil +} + +// Insert adds new records to a collection. +func (db *ExampleDB) Insert(ctx context.Context, collectionName string, data []Record) error { + // Example vector conversion if needed: + // vectors := make([][]float32, len(data)) + // for i, record := range data { + // vectors[i] = toFloat32Slice(record.Vector) + // } + + // Add your insert logic here + return nil +} + +// Search performs vector similarity search. +func (db *ExampleDB) Search(ctx context.Context, collectionName string, vectors map[string]Vector, topK int, metricType string, searchParams map[string]interface{}) ([]SearchResult, error) { + // Example implementation steps: + // 1. Convert vectors if needed + // 2. Perform search + // 3. Format results + + return nil, fmt.Errorf("not implemented") +} + +// HybridSearch combines vector and keyword search (optional). +func (db *ExampleDB) HybridSearch(ctx context.Context, collectionName string, vectors map[string]Vector, topK int, metricType string, searchParams map[string]interface{}, reranker interface{}) ([]SearchResult, error) { + return nil, fmt.Errorf("hybrid search not supported") +} + +// Additional optional methods: + +// Flush ensures data persistence (if applicable). +func (db *ExampleDB) Flush(ctx context.Context, collectionName string) error { + return nil +} + +// CreateIndex builds search indexes (if applicable). +func (db *ExampleDB) CreateIndex(ctx context.Context, collectionName, field string, index Index) error { + return nil +} + +// LoadCollection prepares a collection for searching (if needed). +func (db *ExampleDB) LoadCollection(ctx context.Context, name string) error { + return nil +} + +// SetColumnNames configures which fields to return in search results. +func (db *ExampleDB) SetColumnNames(names []string) { + db.columnNames = names +} + +// Helper functions + +// exampleToFloat32Slice converts vectors if your database needs a different format. +// Note: If you need float32 conversion, consider using the existing toFloat32Slice +// function from the rag package instead of implementing your own. +func exampleToFloat32Slice(v Vector) []float32 { + result := make([]float32, len(v)) + for i, val := range v { + result[i] = float32(val) + } + return result +} diff --git a/rag/load.go b/rag/load.go index e60a704..d3b9914 100644 --- a/rag/load.go +++ b/rag/load.go @@ -1,4 +1,7 @@ -// rag/load.go +// Package rag provides document loading functionality for the Raggo framework. 
+// The loader component handles various input sources including local files, +// directories, and URLs, with support for concurrent operations and +// configurable timeouts. package rag import ( @@ -10,15 +13,25 @@ import ( "time" ) -// Loader represents the internal loader implementation +// Loader represents the internal loader implementation. +// It provides methods for loading documents from various sources +// with configurable HTTP client, timeout settings, and temporary +// storage management. The loader is designed to be thread-safe +// and can handle concurrent loading operations. type Loader struct { - client *http.Client - timeout time.Duration - tempDir string - logger Logger + client *http.Client // HTTP client for URL downloads + timeout time.Duration // Timeout for operations + tempDir string // Directory for temporary files + logger Logger // Logger for operation tracking } -// NewLoader creates a new Loader with the given options +// NewLoader creates a new Loader with the given options. +// It initializes a loader with default settings and applies +// any provided options. Default settings include: +// - Standard HTTP client +// - 30-second timeout +// - System temporary directory +// - Global logger instance func NewLoader(opts ...LoaderOption) *Loader { l := &Loader{ client: http.DefaultClient, @@ -34,37 +47,60 @@ func NewLoader(opts ...LoaderOption) *Loader { return l } -// LoaderOption is a functional option for configuring a Loader +// LoaderOption is a functional option for configuring a Loader. +// It follows the functional options pattern to provide a clean +// and extensible way to configure the loader. type LoaderOption func(*Loader) -// WithHTTPClient sets a custom HTTP client for the Loader +// WithHTTPClient sets a custom HTTP client for the Loader. +// This allows customization of the HTTP client used for URL downloads, +// enabling features like custom transport settings, proxies, or +// authentication mechanisms. func WithHTTPClient(client *http.Client) LoaderOption { return func(l *Loader) { l.client = client } } -// WithTimeout sets a custom timeout for the Loader +// WithTimeout sets a custom timeout for the Loader. +// This timeout applies to all operations including: +// - URL downloads +// - File operations +// - Directory traversal func WithTimeout(timeout time.Duration) LoaderOption { return func(l *Loader) { l.timeout = timeout } } -// WithTempDir sets the temporary directory for downloaded files +// WithTempDir sets the temporary directory for downloaded files. +// This directory is used to store: +// - Downloaded files from URLs +// - Copies of local files for processing +// - Temporary files during directory operations func WithTempDir(dir string) LoaderOption { return func(l *Loader) { l.tempDir = dir } } -// WithLogger sets a custom logger for the Loader +// WithLogger sets a custom logger for the Loader. +// The logger is used to track operations and debug issues +// across all loading operations. func WithLogger(logger Logger) LoaderOption { return func(l *Loader) { l.logger = logger } } +// LoadURL downloads a file from the given URL and stores it in the temporary directory. +// The function: +// 1. Creates a context with the configured timeout +// 2. Downloads the file using the HTTP client +// 3. Stores the file in the temporary directory +// 4. Returns the path to the downloaded file +// +// The downloaded file's name is derived from the URL's base name. 
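+//
+// Example (the URL is illustrative):
+//
+//	path, err := loader.LoadURL(ctx, "https://example.com/document.pdf")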
func (l *Loader) LoadURL(ctx context.Context, url string) (string, error) { l.logger.Debug("Starting LoadURL", "url", url) ctx, cancel := context.WithTimeout(ctx, l.timeout) @@ -103,6 +139,13 @@ func (l *Loader) LoadURL(ctx context.Context, url string) (string, error) { return destPath, nil } +// LoadFile copies a file to the temporary directory and returns its path. +// The function: +// 1. Verifies the source file exists +// 2. Creates a copy in the temporary directory +// 3. Returns the path to the copied file +// +// This ensures that the original file remains unchanged during processing. func (l *Loader) LoadFile(ctx context.Context, path string) (string, error) { l.logger.Debug("Starting LoadFile", "path", path) @@ -139,6 +182,14 @@ func (l *Loader) LoadFile(ctx context.Context, path string) (string, error) { return destPath, nil } +// LoadDir recursively processes all files in a directory. +// The function: +// 1. Walks through the directory tree +// 2. Processes each file encountered +// 3. Returns paths to all processed files +// +// Files that fail to load are logged but don't stop the process. +// The function continues with the next file on error. func (l *Loader) LoadDir(ctx context.Context, dir string) ([]string, error) { l.logger.Debug("Starting LoadDir", "dir", dir) diff --git a/rag/log.go b/rag/log.go index 1ecde3c..9c37892 100644 --- a/rag/log.go +++ b/rag/log.go @@ -1,3 +1,6 @@ +// Package rag provides a flexible logging system for the Raggo framework. +// It supports multiple log levels, structured logging with key-value pairs, +// and can be easily extended with custom logger implementations. package rag import ( @@ -7,29 +10,53 @@ import ( "strings" ) +// LogLevel represents the severity level of a log message. +// Higher values indicate more verbose logging. type LogLevel int const ( + // LogLevelOff disables all logging LogLevelOff LogLevel = iota + // LogLevelError enables only error messages LogLevelError + // LogLevelWarn enables error and warning messages LogLevelWarn + // LogLevelInfo enables error, warning, and info messages LogLevelInfo + // LogLevelDebug enables all messages including debug LogLevelDebug ) +// Logger defines the interface for logging operations. +// Implementations must support multiple severity levels and +// structured logging with key-value pairs. type Logger interface { + // Debug logs a message at debug level with optional key-value pairs Debug(msg string, keysAndValues ...interface{}) + // Info logs a message at info level with optional key-value pairs Info(msg string, keysAndValues ...interface{}) + // Warn logs a message at warning level with optional key-value pairs Warn(msg string, keysAndValues ...interface{}) + // Error logs a message at error level with optional key-value pairs Error(msg string, keysAndValues ...interface{}) + // SetLevel changes the current logging level SetLevel(level LogLevel) } +// DefaultLogger provides a basic implementation of the Logger interface +// using the standard library's log package. It supports: +// - Multiple log levels +// - Structured logging with key-value pairs +// - Output to os.Stderr by default +// - Standard timestamp format type DefaultLogger struct { logger *log.Logger level LogLevel } +// NewLogger creates a new DefaultLogger instance with the specified log level. +// The logger writes to os.Stderr using the standard log package format: +// timestamp + message + key-value pairs. 
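+//
+// Example:
+//
+//	logger := NewLogger(LogLevelDebug)
+//	logger.Debug("loaded file", "path", "/tmp/doc.pdf")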
func NewLogger(level LogLevel) Logger { return &DefaultLogger{ logger: log.New(os.Stderr, "", log.LstdFlags), @@ -37,36 +64,53 @@ func NewLogger(level LogLevel) Logger { } } +// SetLevel updates the logging level of the DefaultLogger. +// Messages below this level will not be logged. func (l *DefaultLogger) SetLevel(level LogLevel) { l.level = level } +// log is an internal helper that handles the actual logging operation. +// It checks the log level and formats the message with key-value pairs. func (l *DefaultLogger) log(level LogLevel, msg string, keysAndValues ...interface{}) { if level <= l.level { l.logger.Printf("%s: %s %v", level, msg, keysAndValues) } } +// Debug logs a message at debug level. This level should be used for +// detailed information needed for debugging purposes. func (l *DefaultLogger) Debug(msg string, keysAndValues ...interface{}) { l.log(LogLevelDebug, msg, keysAndValues...) } +// Info logs a message at info level. This level should be used for +// general operational information. func (l *DefaultLogger) Info(msg string, keysAndValues ...interface{}) { l.log(LogLevelInfo, msg, keysAndValues...) } +// Warn logs a message at warning level. This level should be used for +// potentially harmful situations that don't prevent normal operation. func (l *DefaultLogger) Warn(msg string, keysAndValues ...interface{}) { l.log(LogLevelWarn, msg, keysAndValues...) } +// Error logs a message at error level. This level should be used for +// error conditions that affect normal operation. func (l *DefaultLogger) Error(msg string, keysAndValues ...interface{}) { l.log(LogLevelError, msg, keysAndValues...) } +// String returns the string representation of a LogLevel. +// This is used for formatting log messages and configuration. func (l LogLevel) String() string { return [...]string{"OFF", "ERROR", "WARN", "INFO", "DEBUG"}[l] } +// UnmarshalText implements the encoding.TextUnmarshaler interface. +// It allows LogLevel to be configured from string values in configuration +// files or environment variables. func (l *LogLevel) UnmarshalText(text []byte) error { switch strings.ToUpper(string(text)) { case "OFF": @@ -85,16 +129,19 @@ func (l *LogLevel) UnmarshalText(text []byte) error { return nil } -// Global logger instance +// GlobalLogger is the package-level logger instance used by default. +// It can be accessed and modified by other packages using the rag framework. var GlobalLogger Logger -// Initialize the global logger +// init initializes the global logger with a default configuration. +// By default, it logs at INFO level to os.Stderr. func init() { GlobalLogger = NewLogger(LogLevelInfo) } -// SetGlobalLogLevel sets the log level for the global logger +// SetGlobalLogLevel sets the log level for the global logger instance. +// This function provides a convenient way to control logging verbosity +// across the entire application. func SetGlobalLogLevel(level LogLevel) { GlobalLogger.SetLevel(level) } - diff --git a/rag/memory.go b/rag/memory.go index 3913d97..d796f0d 100644 --- a/rag/memory.go +++ b/rag/memory.go @@ -1,5 +1,6 @@ -// File: memory.go - +// Package rag provides an in-memory vector database implementation that serves +// as a lightweight solution for vector similarity search. It's ideal for testing, +// prototyping, and small-scale applications that don't require persistence. package rag import ( @@ -10,31 +11,49 @@ import ( "sync" ) +// MemoryDB implements the VectorDB interface using in-memory storage. 
+// It provides thread-safe operations for managing collections and performing +// vector similarity searches without the need for external database systems. type MemoryDB struct { + // collections stores all vector collections in memory collections map[string]*Collection - mu sync.RWMutex + // mu provides thread-safety for concurrent operations + mu sync.RWMutex + // columnNames specifies which fields to include in search results columnNames []string } +// Collection represents a named set of records with a defined schema. +// It's the basic unit of organization in the memory database. type Collection struct { + // Schema defines the structure of records in this collection Schema Schema - Data []Record + // Data holds the actual records in the collection + Data []Record } +// newMemoryDB creates a new in-memory vector database instance. +// It initializes an empty collection map and returns a ready-to-use database. func newMemoryDB(cfg *Config) (*MemoryDB, error) { return &MemoryDB{ collections: make(map[string]*Collection), }, nil } +// Connect is a no-op for the in-memory database as no connection is needed. +// It's implemented to satisfy the VectorDB interface. func (m *MemoryDB) Connect(ctx context.Context) error { - return nil // No-op for in-memory database + return nil } +// Close is a no-op for the in-memory database as no cleanup is needed. +// It's implemented to satisfy the VectorDB interface. func (m *MemoryDB) Close() error { - return nil // No-op for in-memory database + return nil } +// HasCollection checks if a collection with the given name exists in the database. +// This operation is thread-safe and uses a read lock. func (m *MemoryDB) HasCollection(ctx context.Context, name string) (bool, error) { m.mu.RLock() defer m.mu.RUnlock() @@ -42,6 +61,8 @@ func (m *MemoryDB) HasCollection(ctx context.Context, name string) (bool, error) return exists, nil } +// DropCollection removes a collection and all its data from the database. +// This operation is thread-safe and uses a write lock. func (m *MemoryDB) DropCollection(ctx context.Context, name string) error { m.mu.Lock() defer m.mu.Unlock() @@ -49,6 +70,9 @@ func (m *MemoryDB) DropCollection(ctx context.Context, name string) error { return nil } +// CreateCollection creates a new collection with the specified schema. +// Returns an error if a collection with the same name already exists. +// This operation is thread-safe and uses a write lock. func (m *MemoryDB) CreateCollection(ctx context.Context, name string, schema Schema) error { m.mu.Lock() defer m.mu.Unlock() @@ -59,6 +83,9 @@ func (m *MemoryDB) CreateCollection(ctx context.Context, name string, schema Sch return nil } +// Insert adds new records to the specified collection. +// Returns an error if the collection doesn't exist. +// This operation is thread-safe and uses a write lock. func (m *MemoryDB) Insert(ctx context.Context, collectionName string, data []Record) error { m.mu.Lock() defer m.mu.Unlock() @@ -70,18 +97,31 @@ func (m *MemoryDB) Insert(ctx context.Context, collectionName string, data []Rec return nil } +// Flush is a no-op for the in-memory database as all operations are immediate. +// It's implemented to satisfy the VectorDB interface. func (m *MemoryDB) Flush(ctx context.Context, collectionName string) error { - return nil // No-op for in-memory database + return nil } +// CreateIndex is a no-op for the in-memory database as it uses linear search. +// Future implementations could add indexing for better performance. 
func (m *MemoryDB) CreateIndex(ctx context.Context, collectionName, field string, index Index) error { - return nil // No-op for in-memory database, we'll use linear search + return nil } +// LoadCollection is a no-op for the in-memory database as all data is always loaded. +// It's implemented to satisfy the VectorDB interface. func (m *MemoryDB) LoadCollection(ctx context.Context, name string) error { - return nil // No-op for in-memory database + return nil } +// Search performs vector similarity search in the specified collection. +// It supports different distance metrics and returns the top K most similar vectors. +// The search process: +// 1. Validates the collection exists +// 2. Computes distances between query vectors and stored vectors +// 3. Sorts results by similarity score +// 4. Returns the top K results with specified fields func (m *MemoryDB) Search(ctx context.Context, collectionName string, vectors map[string]Vector, topK int, metricType string, searchParams map[string]interface{}) ([]SearchResult, error) { // The implementation remains largely the same, but now we can use metricType and searchParams // For simplicity, we'll ignore these new parameters in this example @@ -123,6 +163,13 @@ func (m *MemoryDB) Search(ctx context.Context, collectionName string, vectors ma return results, nil } +// HybridSearch performs a multi-vector similarity search with optional reranking. +// It's similar to Search but supports searching across multiple vector fields +// and combining the results. The process: +// 1. Validates the collection exists +// 2. Computes distances for each vector field +// 3. Combines distances using average +// 4. Sorts and returns top K results func (m *MemoryDB) HybridSearch(ctx context.Context, collectionName string, vectors map[string]Vector, topK int, metricType string, searchParams map[string]interface{}, reranker interface{}) ([]SearchResult, error) { // The implementation remains largely the same, but now we can use metricType, searchParams, and reranker // For simplicity, we'll ignore these new parameters in this example @@ -168,6 +215,11 @@ func (m *MemoryDB) HybridSearch(ctx context.Context, collectionName string, vect return results, nil } +// calculateDistance computes the distance between two vectors using the specified metric. +// Supported metrics: +// - "L2": Euclidean distance (default) +// - "IP": Inner product (negative, as larger means more similar) +// Returns a float64 representing the distance/similarity score. func (m *MemoryDB) calculateDistance(a, b Vector, metricType string) float64 { var sum float64 switch metricType { @@ -192,6 +244,8 @@ func (m *MemoryDB) calculateDistance(a, b Vector, metricType string) float64 { } } +// euclideanDistance computes the L2 (Euclidean) distance between two vectors. +// This is a helper function used by calculateDistance when metricType is "L2". func euclideanDistance(a, b Vector) float64 { var sum float64 for i := range a { @@ -201,6 +255,8 @@ func euclideanDistance(a, b Vector) float64 { return math.Sqrt(sum) } +// SetColumnNames configures which fields should be included in search results. +// This allows for selective field retrieval to optimize response size. func (m *MemoryDB) SetColumnNames(names []string) { m.columnNames = names } diff --git a/rag/milvus.go b/rag/milvus.go index 3df3b3d..bc3bad0 100644 --- a/rag/milvus.go +++ b/rag/milvus.go @@ -1,5 +1,4 @@ -// File: milvus.go - +// Package rag provides retrieval-augmented generation capabilities. 
package rag import ( @@ -11,16 +10,27 @@ import ( "github.com/milvus-io/milvus-sdk-go/v2/entity" ) +// MilvusDB implements a vector database interface using Milvus. +// It provides high-performance vector similarity search with: +// - HNSW indexing for fast approximate nearest neighbor search +// - Hybrid search combining multiple vector fields +// - Flexible schema definition and data types +// - Batch operations for efficient data management type MilvusDB struct { - client client.Client - config *Config - columnNames []string + client client.Client // Milvus client connection + config *Config // Database configuration + columnNames []string // Names of columns to retrieve in search results } +// newMilvusDB creates a new MilvusDB instance with the given configuration. +// Note: This doesn't establish the connection - call Connect() separately. func newMilvusDB(cfg *Config) (*MilvusDB, error) { return &MilvusDB{config: cfg}, nil } +// Connect establishes a connection to the Milvus server. +// It uses the address specified in the configuration and returns an error +// if the connection cannot be established. func (m *MilvusDB) Connect(ctx context.Context) error { GlobalLogger.Debug("Attempting to connect to Milvus", "address", m.config.Address) @@ -37,18 +47,30 @@ func (m *MilvusDB) Connect(ctx context.Context) error { return nil } +// Close terminates the connection to the Milvus server. +// It should be called when the database is no longer needed. func (m *MilvusDB) Close() error { return m.client.Close() } +// HasCollection checks if a collection with the given name exists. func (m *MilvusDB) HasCollection(ctx context.Context, name string) (bool, error) { return m.client.HasCollection(ctx, name) } +// DropCollection removes a collection and all its data from the database. +// Warning: This operation is irreversible. func (m *MilvusDB) DropCollection(ctx context.Context, name string) error { return m.client.DropCollection(ctx, name) } +// CreateCollection creates a new collection with the specified schema. +// The schema defines: +// - Field names and types (including vector fields) +// - Primary key configuration +// - Auto-ID settings +// - Vector dimensions +// - VARCHAR field lengths func (m *MilvusDB) CreateCollection(ctx context.Context, name string, schema Schema) error { milvusSchema := entity.NewSchema().WithName(name).WithDescription(schema.Description) for _, field := range schema.Fields { @@ -73,6 +95,12 @@ func (m *MilvusDB) CreateCollection(ctx context.Context, name string, schema Sch return m.client.CreateCollection(ctx, milvusSchema, entity.DefaultShardNumber) } +// Insert adds new records to a collection. +// It handles multiple data types and automatically creates appropriate columns. +// The function: +// 1. Creates columns for each field type +// 2. Appends values to respective columns +// 3. Performs batch insertion for efficiency func (m *MilvusDB) Insert(ctx context.Context, collectionName string, data []Record) error { columns := make(map[string]entity.Column) for _, record := range data { @@ -99,10 +127,16 @@ func (m *MilvusDB) Insert(ctx context.Context, collectionName string, data []Rec return err } +// Flush ensures all inserted data is persisted to disk. +// This is important to call before searching newly inserted data. func (m *MilvusDB) Flush(ctx context.Context, collectionName string) error { return m.client.Flush(ctx, collectionName, false) } +// CreateIndex builds an index on a specified field to optimize search performance. 
+// Currently supports: +// - HNSW index type with configurable M and efConstruction parameters +// - Different metric types (L2, IP, etc.) func (m *MilvusDB) CreateIndex(ctx context.Context, collectionName, field string, index Index) error { var idx entity.Index var err error @@ -121,10 +155,18 @@ func (m *MilvusDB) CreateIndex(ctx context.Context, collectionName, field string return m.client.CreateIndex(ctx, collectionName, field, idx, false) } +// LoadCollection loads a collection into memory for searching. +// This must be called before performing searches on the collection. func (m *MilvusDB) LoadCollection(ctx context.Context, name string) error { return m.client.LoadCollection(ctx, name, false) } +// Search performs vector similarity search on a single field. +// Parameters: +// - vectors: Map of field name to vector values +// - topK: Number of results to return +// - metricType: Distance metric (L2, IP, etc.) +// - searchParams: Index-specific search parameters func (m *MilvusDB) Search(ctx context.Context, collectionName string, vectors map[string]Vector, topK int, metricType string, searchParams map[string]interface{}) ([]SearchResult, error) { // Assume we're searching only one field for simplicity var fieldName string @@ -155,6 +197,11 @@ func (m *MilvusDB) Search(ctx context.Context, collectionName string, vectors ma return m.wrapSearchResults(result), nil } +// HybridSearch performs search across multiple vector fields with reranking. +// It combines results using: +// 1. Individual ANN searches on each vector field +// 2. Reranking of combined results (default: RRF reranker) +// 3. Final top-K selection func (m *MilvusDB) HybridSearch(ctx context.Context, collectionName string, vectors map[string]Vector, topK int, metricType string, searchParams map[string]interface{}, reranker interface{}) ([]SearchResult, error) { limit := topK subRequests := make([]*client.ANNSearchRequest, 0, len(vectors)) @@ -192,6 +239,8 @@ func (m *MilvusDB) HybridSearch(ctx context.Context, collectionName string, vect return m.wrapSearchResults(result), nil } +// createSearchParam creates search parameters for the specified index type. +// Currently supports HNSW index with 'ef' parameter for search-time optimization. func (m *MilvusDB) createSearchParam(params map[string]interface{}) (entity.SearchParam, error) { if params["type"] == "HNSW" { ef, ok := params["ef"].(int) @@ -204,6 +253,8 @@ func (m *MilvusDB) createSearchParam(params map[string]interface{}) (entity.Sear return nil, fmt.Errorf("unsupported search param type") } +// convertMetricType converts string metric types to Milvus entity.MetricType. +// Supported types: L2, IP (Inner Product), COSINE, etc. func (m *MilvusDB) convertMetricType(metricType string) entity.MetricType { switch metricType { case "L2": @@ -215,6 +266,8 @@ func (m *MilvusDB) convertMetricType(metricType string) entity.MetricType { } } +// convertDataType converts string data types to Milvus entity.FieldType. +// Supports: int64, float, string, float_vector, etc. func (m *MilvusDB) convertDataType(dataType string) entity.FieldType { switch dataType { case "int64": @@ -228,6 +281,8 @@ func (m *MilvusDB) convertDataType(dataType string) entity.FieldType { } } +// createColumn creates a new column with appropriate type based on the field value. +// Handles: Int64, Float32, String, FloatVector, etc. 
func (m *MilvusDB) createColumn(fieldName string, fieldValue interface{}) entity.Column { switch v := fieldValue.(type) { case int64: @@ -247,10 +302,13 @@ func (m *MilvusDB) createColumn(fieldName string, fieldValue interface{}) entity } } +// SetColumnNames sets the list of column names to retrieve in search results. func (m *MilvusDB) SetColumnNames(names []string) { m.columnNames = names } +// appendToColumn adds a value to the appropriate type of column. +// Handles type conversion and validation for different field types. func (m *MilvusDB) appendToColumn(col entity.Column, value interface{}) { switch c := col.(type) { case *entity.ColumnInt64: @@ -288,6 +346,8 @@ func (m *MilvusDB) appendToColumn(col entity.Column, value interface{}) { } } +// wrapSearchResults converts Milvus search results to the internal SearchResult format. +// It extracts scores, IDs, and field values from the Milvus results. func (m *MilvusDB) wrapSearchResults(result []client.SearchResult) []SearchResult { var searchResults []SearchResult for _, rs := range result { diff --git a/rag/parse.go b/rag/parse.go index 8e832e8..fb49d72 100644 --- a/rag/parse.go +++ b/rag/parse.go @@ -1,3 +1,6 @@ +// Package rag provides document parsing capabilities for various file formats. +// The parsing system is designed to be extensible, allowing users to add custom parsers +// for different file types while maintaining a consistent interface. package rag import ( @@ -9,24 +12,34 @@ import ( "github.com/ledongthuc/pdf" ) -// Document represents a parsed document +// Document represents a parsed document with its content and associated metadata. +// The Content field contains the extracted text, while Metadata stores additional +// information about the document such as file type and path. type Document struct { - Content string - Metadata map[string]string + Content string // The extracted text content of the document + Metadata map[string]string // Additional metadata about the document } -// Parser defines the interface for parsing documents +// Parser defines the interface for document parsing implementations. +// Any type that implements this interface can be registered with the ParserManager +// to handle specific file types. type Parser interface { + // Parse processes a file at the given path and returns a Document. + // It returns an error if the parsing operation fails. Parse(filePath string) (Document, error) } -// ParserManager is responsible for managing different parsers +// ParserManager coordinates document parsing by managing different Parser implementations +// and routing files to the appropriate parser based on their type. type ParserManager struct { + // fileTypeDetector determines the file type based on the file path. fileTypeDetector func(string) string - parsers map[string]Parser + // parsers stores the registered parsers for different file types. + parsers map[string]Parser } -// NewParserManager creates a new ParserManager with default settings +// NewParserManager creates a new ParserManager initialized with default settings +// and parsers for common file types (PDF and text files). func NewParserManager() *ParserManager { pm := &ParserManager{ fileTypeDetector: defaultFileTypeDetector, @@ -40,7 +53,10 @@ func NewParserManager() *ParserManager { return pm } -// Parse parses a document using the appropriate parser based on file type +// Parse processes a document using the appropriate parser based on the file type. 
+// It uses the configured fileTypeDetector to determine the file type and then +// delegates to the corresponding parser. Returns an error if no suitable parser +// is found or if parsing fails. func (pm *ParserManager) Parse(filePath string) (Document, error) { GlobalLogger.Debug("Starting to parse file", "path", filePath) fileType := pm.fileTypeDetector(filePath) @@ -58,7 +74,8 @@ func (pm *ParserManager) Parse(filePath string) (Document, error) { return doc, nil } -// defaultFileTypeDetector is a simple file type detector based on file extension +// defaultFileTypeDetector determines file type based on file extension. +// Currently supports .pdf and .txt files, returning "unknown" for other extensions. func defaultFileTypeDetector(filePath string) string { ext := strings.ToLower(filepath.Ext(filePath)) switch ext { @@ -71,7 +88,32 @@ func defaultFileTypeDetector(filePath string) string { } } -// Parse parses a PDF file and returns its content +// SetFileTypeDetector allows customization of how file types are detected. +// This can be used to implement more sophisticated file type detection beyond +// simple extension matching. +func (pm *ParserManager) SetFileTypeDetector(detector func(string) string) { + pm.fileTypeDetector = detector +} + +// AddParser registers a new parser for a specific file type. +// This allows users to extend the system with custom parsers for additional +// file formats. +func (pm *ParserManager) AddParser(fileType string, parser Parser) { + pm.parsers[fileType] = parser +} + +// PDFParser implements the Parser interface for PDF files using the +// ledongthuc/pdf library for text extraction. +type PDFParser struct{} + +// NewPDFParser creates a new PDFParser instance. +func NewPDFParser() *PDFParser { + return &PDFParser{} +} + +// Parse implements the Parser interface for PDF files. +// It extracts text content from the PDF and returns it along with basic metadata. +// Returns an error if the PDF cannot be processed. func (p *PDFParser) Parse(filePath string) (Document, error) { GlobalLogger.Debug("Starting to parse PDF", "path", filePath) content, err := p.extractText(filePath) @@ -89,25 +131,9 @@ func (p *PDFParser) Parse(filePath string) (Document, error) { }, nil } -// SetFileTypeDetector sets a custom file type detector -func (pm *ParserManager) SetFileTypeDetector(detector func(string) string) { - pm.fileTypeDetector = detector -} - -// AddParser adds a parser for a specific file type -func (pm *ParserManager) AddParser(fileType string, parser Parser) { - pm.parsers[fileType] = parser -} - -// PDFParser is the implementation of Parser for PDF files -type PDFParser struct{} - -// NewPDFParser creates a new PDFParser -func NewPDFParser() *PDFParser { - return &PDFParser{} -} - -// extractText extracts plain text from a PDF file +// extractText performs the actual text extraction from a PDF file. +// It processes the PDF page by page, concatenating the extracted text. +// Returns an error if any part of the extraction process fails. func (p *PDFParser) extractText(filePath string) (string, error) { file, err := os.Open(filePath) if err != nil { @@ -143,15 +169,17 @@ func (p *PDFParser) extractText(filePath string) (string, error) { return textBuilder.String(), nil } -// TextParser is the implementation of Parser for text files +// TextParser implements the Parser interface for plain text files. type TextParser struct{} -// NewTextParser creates a new TextParser +// NewTextParser creates a new TextParser instance. 
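+//
+// As a sketch, the text parser can also back additional plain-text formats;
+// the ".md" to "markdown" mapping below is an assumption, not a built-in:
+//
+//	pm := NewParserManager()
+//	pm.AddParser("markdown", NewTextParser())
+//	pm.SetFileTypeDetector(func(path string) string {
+//		if strings.HasSuffix(strings.ToLower(path), ".md") {
+//			return "markdown"
+//		}
+//		return defaultFileTypeDetector(path)
+//	})
+//	doc, err := pm.Parse("README.md")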
func NewTextParser() *TextParser { return &TextParser{} } -// Parse parses a text file and returns its content +// Parse implements the Parser interface for text files. +// It reads the entire file content and returns it along with basic metadata. +// Returns an error if the file cannot be read. func (p *TextParser) Parse(filePath string) (Document, error) { GlobalLogger.Debug("Starting to parse text file", "path", filePath) content, err := os.ReadFile(filePath) diff --git a/rag/providers/example_provider.go b/rag/providers/example_provider.go new file mode 100644 index 0000000..6e64b70 --- /dev/null +++ b/rag/providers/example_provider.go @@ -0,0 +1,137 @@ +// Package providers includes this example to demonstrate how to implement +// new embedding providers for the Raggo framework. This file shows the +// recommended patterns and best practices for creating a provider that +// integrates seamlessly with the system. +package providers + +import ( + "fmt" +) + +// ExampleProvider demonstrates how to implement a new embedding provider. +// Replace this with your actual provider implementation. Your provider +// should handle: +// - Connection management +// - Resource cleanup +// - Error handling +// - Rate limiting (if applicable) +// - Batching (if supported) +type ExampleProvider struct { + // Add fields needed by your provider + apiKey string + model string + dimension int + // Add any connection or state management fields + client interface{} +} + +// NewExampleProvider shows how to create a new provider instance. +// Your initialization function should: +// 1. Validate the configuration +// 2. Set up any connections or resources +// 3. Initialize internal state +// 4. Return a fully configured provider +func NewExampleProvider(cfg *Config) (*ExampleProvider, error) { + // Validate required configuration + if cfg.APIKey == "" { + return nil, fmt.Errorf("API key is required") + } + + // Initialize your provider + provider := &ExampleProvider{ + apiKey: cfg.APIKey, + model: cfg.Model, + dimension: cfg.Dimension, + } + + // Set up any connections or resources + // Example: + // client, err := yourapi.NewClient(cfg.APIKey) + // if err != nil { + // return nil, fmt.Errorf("failed to create client: %w", err) + // } + // provider.client = client + + return provider, nil +} + +// Embed generates embeddings for a batch of input texts. +// Your implementation should: +// 1. Validate the inputs +// 2. Prepare the batch request +// 3. Call your embedding service +// 4. Handle errors appropriately +// 5. Return the vector representations as float32 arrays +// +// Note: The method accepts a slice of strings and returns a slice of float32 vectors +// to match the Provider interface requirements. 
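+//
+// A minimal sketch of exercising this example provider end to end
+// (all values are illustrative):
+//
+//	provider, err := NewExampleProvider(&Config{
+//		APIKey:    "test-key",
+//		Model:     "example-model",
+//		Dimension: 128,
+//	})
+//	if err != nil {
+//		log.Fatal(err)
+//	}
+//	vectors, err := provider.Embed([]string{"hello", "world"})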
+func (p *ExampleProvider) Embed(texts []string) ([][]float32, error) { + // Validate input + if len(texts) == 0 { + return nil, fmt.Errorf("empty input texts") + } + + // Initialize result slice + result := make([][]float32, len(texts)) + + // Process each text in the batch + for i, text := range texts { + if text == "" { + return nil, fmt.Errorf("empty text at position %d", i) + } + + // Call your embedding service + // Example: + // response, err := p.client.CreateEmbedding(&Request{ + // Text: text, + // Model: p.model, + // }) + // if err != nil { + // return nil, fmt.Errorf("embedding creation failed for text %d: %w", i, err) + // } + + // For this example, return a mock vector + mockVector := make([]float32, p.dimension) + for j := range mockVector { + mockVector[j] = 0.1 // Replace with actual embedding values + } + result[i] = mockVector + } + + return result, nil +} + +// GetDimension demonstrates how to implement the dimension reporting method. +// Your implementation should: +// 1. Return the correct dimension for your model +// 2. Handle any model-specific variations +// 3. Return an error if the dimension cannot be determined +func (p *ExampleProvider) GetDimension() (int, error) { + if p.dimension == 0 { + return 0, fmt.Errorf("dimension not set") + } + return p.dimension, nil +} + +// Close demonstrates how to implement resource cleanup. +// Your implementation should: +// 1. Close any open connections +// 2. Release any held resources +// 3. Handle cleanup errors appropriately +func (p *ExampleProvider) Close() error { + // Clean up your resources + // Example: + // if p.client != nil { + // if err := p.client.Close(); err != nil { + // return fmt.Errorf("failed to close client: %w", err) + // } + // } + return nil +} + +func init() { + // Register your provider with a unique name + Register("example", func(cfg *Config) (Provider, error) { + return NewExampleProvider(cfg) + }) +} diff --git a/rag/providers/openai.go b/rag/providers/openai.go index bce6d12..7514d91 100644 --- a/rag/providers/openai.go +++ b/rag/providers/openai.go @@ -1,3 +1,6 @@ +// Package providers implements embedding service providers for the Raggo framework. +// The OpenAI provider offers high-quality text embeddings through OpenAI's API, +// supporting models like text-embedding-3-small and text-embedding-3-large. package providers import ( @@ -11,21 +14,42 @@ import ( ) func init() { + // Register the OpenAI provider when the package is initialized RegisterEmbedder("openai", NewOpenAIEmbedder) } +// Default settings for the OpenAI embedder const ( + // defaultEmbeddingAPI is the endpoint for OpenAI's embedding service defaultEmbeddingAPI = "https://api.openai.com/v1/embeddings" + // defaultModelName is the recommended model for most use cases defaultModelName = "text-embedding-3-small" ) +// OpenAIEmbedder implements the Embedder interface using OpenAI's API. +// It supports various embedding models and handles API communication, +// rate limiting, and error recovery. The embedder is designed to be +// thread-safe and can be used concurrently. type OpenAIEmbedder struct { - apiKey string - client *http.Client - apiURL string - modelName string + apiKey string // API key for authentication + client *http.Client // HTTP client with timeout + apiURL string // API endpoint URL + modelName string // Selected embedding model } +// NewOpenAIEmbedder creates a new OpenAI embedding provider with the given +// configuration. 
The provider requires an API key and optionally accepts: +// - model: The embedding model to use (defaults to text-embedding-3-small) +// - api_url: Custom API endpoint URL +// - timeout: Custom timeout duration +// +// Example config: +// +// config := map[string]interface{}{ +// "api_key": "your-api-key", +// "model": "text-embedding-3-small", +// "timeout": 30 * time.Second, +// } func NewOpenAIEmbedder(config map[string]interface{}) (Embedder, error) { apiKey, ok := config["api_key"].(string) if !ok || apiKey == "" { @@ -54,17 +78,27 @@ func NewOpenAIEmbedder(config map[string]interface{}) (Embedder, error) { return e, nil } +// embeddingRequest represents the JSON structure for API requests type embeddingRequest struct { - Input string `json:"input"` - Model string `json:"model"` + Input string `json:"input"` // Text to embed + Model string `json:"model"` // Model to use } +// embeddingResponse represents the JSON structure for API responses type embeddingResponse struct { Data []struct { - Embedding []float64 `json:"embedding"` + Embedding []float64 `json:"embedding"` // Vector representation } `json:"data"` } +// Embed converts the input text into a vector representation using the +// configured OpenAI model. The method handles: +// - Request preparation and validation +// - API communication with retry logic +// - Response parsing and error handling +// +// The resulting vector captures the semantic meaning of the input text +// and can be used for similarity search operations. func (e *OpenAIEmbedder) Embed(ctx context.Context, text string) ([]float64, error) { reqBody, err := json.Marshal(embeddingRequest{ Input: text, @@ -110,6 +144,14 @@ func (e *OpenAIEmbedder) Embed(ctx context.Context, text string) ([]float64, err return embeddingResp.Data[0].Embedding, nil } +// GetDimension returns the output dimension for the current embedding model. +// Each model produces vectors of a fixed size: +// - text-embedding-3-small: 1536 dimensions +// - text-embedding-3-large: 3072 dimensions +// - text-embedding-ada-002: 1536 dimensions +// +// This information is crucial for configuring vector databases and ensuring +// compatibility across the system. func (e *OpenAIEmbedder) GetDimension() (int, error) { switch e.modelName { case "text-embedding-3-small": diff --git a/rag/providers/register.go b/rag/providers/register.go index 002dcff..0c6bf97 100644 --- a/rag/providers/register.go +++ b/rag/providers/register.go @@ -1,3 +1,8 @@ +// Package providers implements a flexible system for managing different embedding +// service providers in the Raggo framework. Each provider offers unique capabilities +// for converting text into vector representations that capture semantic meaning. +// The registration system allows new providers to be easily added and configured +// while maintaining a consistent interface for the rest of the system. package providers import ( @@ -40,3 +45,99 @@ type Embedder interface { // GetDimension returns the dimension of the embeddings for the current model GetDimension() (int, error) } + +// Provider defines the interface that all embedding providers must implement. +// This abstraction ensures that different providers can be used interchangeably +// while providing their own specific implementation details. A provider is +// responsible for converting text into vector representations that can be used +// for semantic similarity search. +type Provider interface { + // Embed converts a slice of text inputs into their vector representations. 
+ // The method should handle batching and rate limiting internally. It returns + // a slice of vectors, where each vector corresponds to the input text at the + // same index. An error is returned if the embedding process fails. + Embed(inputs []string) ([][]float32, error) + + // Close releases any resources held by the provider, such as API connections + // or cached data. This method should be called when the provider is no longer + // needed to prevent resource leaks. + Close() error +} + +// Config holds the configuration settings for an embedding provider. +// Different providers may use different subsets of these settings, but +// the configuration structure remains consistent to simplify provider +// management and initialization. +type Config struct { + // APIKey is used for authentication with the provider's service. + // For local models, this may be left empty. + APIKey string + + // Model specifies which embedding model to use. Each provider may + // offer multiple models with different characteristics. + Model string + + // BatchSize determines how many texts can be embedded in a single API call. + // This helps optimize performance and manage rate limits. + BatchSize int + + // Dimension specifies the size of the output vectors. This must match + // the chosen model's output dimension. + Dimension int + + // Additional provider-specific settings can be added here + Settings map[string]interface{} +} + +// registry maintains a thread-safe map of provider factories. Each factory +// is a function that creates a new instance of a specific provider type +// using the provided configuration. +type registry struct { + mu sync.RWMutex + factories map[string]func(cfg *Config) (Provider, error) +} + +// The global registry instance that maintains all registered provider factories. +var globalRegistry = ®istry{ + factories: make(map[string]func(cfg *Config) (Provider, error)), +} + +// Register adds a new provider factory to the global registry. The factory +// function should create and configure a new instance of the provider when +// called. If a provider with the same name already exists, it will be +// overwritten, allowing for provider updates and replacements. +func Register(name string, factory func(cfg *Config) (Provider, error)) { + globalRegistry.mu.Lock() + defer globalRegistry.mu.Unlock() + globalRegistry.factories[name] = factory +} + +// Get retrieves a provider factory from the registry and creates a new provider +// instance using the supplied configuration. If the requested provider is not +// found in the registry, an error is returned. This method is thread-safe and +// can be called from multiple goroutines. +func Get(name string, cfg *Config) (Provider, error) { + globalRegistry.mu.RLock() + factory, ok := globalRegistry.factories[name] + globalRegistry.mu.RUnlock() + + if !ok { + return nil, fmt.Errorf("provider not found: %s", name) + } + + return factory(cfg) +} + +// List returns the names of all registered providers. This is useful for +// discovering available providers and validating provider names before +// attempting to create instances. 
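+//
+// A minimal sketch of runtime discovery:
+//
+//	for _, name := range List() {
+//		fmt.Printf("available provider: %s\n", name)
+//	}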
+func List() []string { + globalRegistry.mu.RLock() + defer globalRegistry.mu.RUnlock() + + providers := make([]string, 0, len(globalRegistry.factories)) + for name := range globalRegistry.factories { + providers = append(providers, name) + } + return providers +} diff --git a/rag/reranker.go b/rag/reranker.go index f4c1d9e..a0b8bb9 100644 --- a/rag/reranker.go +++ b/rag/reranker.go @@ -1,3 +1,4 @@ +// Package rag provides retrieval-augmented generation capabilities. package rag import ( @@ -5,12 +6,22 @@ import ( "sort" ) -// RRFReranker implements Reciprocal Rank Fusion for result reranking +// RRFReranker implements Reciprocal Rank Fusion (RRF) for combining and reranking search results. +// RRF is a robust rank fusion method that effectively combines results from different retrieval systems +// without requiring score normalization. It uses the formula: RRF(d) = Σ 1/(k + r(d)) +// where d is a document, k is a constant, and r(d) is the rank of document d in each result list. type RRFReranker struct { - k float64 // Constant to prevent division by zero and control ranking influence + k float64 // k is a constant that prevents division by zero and controls the influence of high-ranked items } -// NewRRFReranker creates a new RRF reranker with the given k parameter +// NewRRFReranker creates a new RRF reranker with the specified k parameter. +// The k parameter controls ranking influence - higher values of k decrease the +// influence of high-ranked items. If k <= 0, it defaults to 60 (from the original RRF paper). +// +// Typical k values: +// - k = 60: Standard value from RRF literature, good general-purpose setting +// - k < 60: Increases influence of top-ranked items +// - k > 60: More weight to lower-ranked items, smoother ranking distribution func NewRRFReranker(k float64) *RRFReranker { if k <= 0 { k = 60 // Default value from RRF paper @@ -18,7 +29,29 @@ func NewRRFReranker(k float64) *RRFReranker { return &RRFReranker{k: k} } -// Rerank combines and reranks results using Reciprocal Rank Fusion +// Rerank combines and reranks results using Reciprocal Rank Fusion. +// It takes results from dense (semantic) and sparse (lexical) search and combines +// them using weighted RRF scores. The method handles cases where documents appear +// in both result sets by combining their weighted scores. +// +// Parameters: +// - ctx: Context for potential future extensions (e.g., timeouts, cancellation) +// - query: The original search query (reserved for future extensions) +// - denseResults: Results from dense/semantic search +// - sparseResults: Results from sparse/lexical search +// - denseWeight: Weight for dense search results (normalized internally) +// - sparseWeight: Weight for sparse search results (normalized internally) +// +// Returns: +// - []SearchResult: Reranked results sorted by combined score +// - error: Currently always nil, reserved for future extensions +// +// The reranking process: +// 1. Normalizes weights to sum to 1.0 +// 2. Calculates RRF scores for each result based on rank +// 3. Applies weights to scores based on result source (dense/sparse) +// 4. Combines scores for documents appearing in both result sets +// 5. Sorts final results by combined score func (r *RRFReranker) Rerank( ctx context.Context, query string, diff --git a/rag/sparse_index.go b/rag/sparse_index.go index 43a3f2e..1904a86 100644 --- a/rag/sparse_index.go +++ b/rag/sparse_index.go @@ -1,3 +1,4 @@ +// Package rag provides retrieval-augmented generation capabilities. 
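+//
+// This file supplies the sparse (lexical) side of hybrid retrieval. A minimal
+// sketch of using the BM25 index directly (ctx is assumed to be a
+// context.Context; IDs and text are illustrative):
+//
+//	idx := NewBM25Index()
+//	_ = idx.Add(ctx, 1, "the quick brown fox", nil)
+//	_ = idx.Add(ctx, 2, "a lazy dog sleeps", nil)
+//	results, err := idx.Search(ctx, "quick fox", 10)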
package rag import ( @@ -8,13 +9,20 @@ import ( "sync" ) -// BM25Parameters holds the parameters for BM25 scoring +// BM25Parameters holds the parameters for BM25 scoring algorithm. +// BM25 (Best Match 25) is a probabilistic ranking function that estimates +// the relevance of documents to a given search query based on term frequency, +// inverse document frequency, and document length normalization. type BM25Parameters struct { - K1 float64 // Term saturation parameter (typically 1.2-2.0) - B float64 // Length normalization parameter (typically 0.75) + K1 float64 // K1 controls term frequency saturation (1.2-2.0 typical) + B float64 // B controls document length normalization (0.75 typical) } -// DefaultBM25Parameters returns default BM25 parameters +// DefaultBM25Parameters returns recommended BM25 parameters based on +// empirical research. These values work well for most general-purpose +// text search applications: +// - K1 = 1.5: Balanced term frequency saturation +// - B = 0.75: Standard length normalization func DefaultBM25Parameters() BM25Parameters { return BM25Parameters{ K1: 1.5, @@ -22,21 +30,31 @@ func DefaultBM25Parameters() BM25Parameters { } } -// BM25Index implements a sparse index using BM25 scoring +// BM25Index implements a sparse retrieval index using the BM25 ranking algorithm. +// It provides thread-safe document indexing and retrieval with the following features: +// - Efficient term-based document scoring +// - Document length normalization +// - Configurable text preprocessing +// - Metadata storage and retrieval +// - Thread-safe operations type BM25Index struct { - mu sync.RWMutex - docs map[int64]string // Document content by ID - metadata map[int64]map[string]interface{} // Document metadata by ID + mu sync.RWMutex // Protects concurrent access to index + docs map[int64]string // Stores original document content + metadata map[int64]map[string]interface{} // Stores document metadata termFreq map[int64]map[string]int // Term frequency per document docFreq map[string]int // Document frequency per term docLength map[int64]int // Length of each document avgDocLength float64 // Average document length totalDocs int // Total number of documents - params BM25Parameters // BM25 parameters + params BM25Parameters // BM25 scoring parameters preprocessor func(string) []string // Text preprocessing function } -// NewBM25Index creates a new BM25 index with default parameters +// NewBM25Index creates a new BM25 index with default parameters. +// The index is initialized with: +// - Default BM25 parameters (K1=1.5, B=0.75) +// - Basic preprocessor (lowercase, whitespace tokenization) +// - Empty document store and statistics func NewBM25Index() *BM25Index { return &BM25Index{ docs: make(map[int64]string), @@ -49,14 +67,28 @@ func NewBM25Index() *BM25Index { } } -// defaultPreprocessor implements basic text preprocessing +// defaultPreprocessor implements basic text preprocessing by: +// 1. Converting text to lowercase +// 2. Splitting on whitespace +// Users can replace this with custom preprocessing via SetPreprocessor func defaultPreprocessor(text string) []string { // Convert to lowercase and split into words words := strings.Fields(strings.ToLower(text)) return words } -// Add adds a document to the index +// Add indexes a new document with the given ID, content, and metadata. +// This operation is thread-safe and automatically updates all relevant +// index statistics including term frequencies, document lengths, and +// collection-wide averages. 
+// +// Parameters: +// - ctx: Context for potential future extensions +// - id: Unique document identifier +// - content: Document text content +// - metadata: Optional document metadata +// +// Returns error if the operation fails (currently always nil). func (idx *BM25Index) Add(ctx context.Context, id int64, content string, metadata map[string]interface{}) error { idx.mu.Lock() defer idx.mu.Unlock() @@ -90,7 +122,17 @@ func (idx *BM25Index) Add(ctx context.Context, id int64, content string, metadat return nil } -// Remove removes a document from the index +// Remove deletes a document from the index and updates all relevant statistics. +// This operation is thread-safe and maintains index consistency by: +// - Updating document frequencies +// - Removing document data +// - Recalculating collection statistics +// +// Parameters: +// - ctx: Context for potential future extensions +// - id: ID of document to remove +// +// Returns error if the operation fails (currently always nil). func (idx *BM25Index) Remove(ctx context.Context, id int64) error { idx.mu.Lock() defer idx.mu.Unlock() @@ -126,7 +168,18 @@ func (idx *BM25Index) Remove(ctx context.Context, id int64) error { return nil } -// Search performs BM25 search on the index +// Search performs BM25-based retrieval on the index. +// The BM25 score for a document D and query Q is calculated as: +// score(D,Q) = Σ IDF(qi) * (f(qi,D) * (k1 + 1)) / (f(qi,D) + k1 * (1 - b + b * |D|/avgdl)) +// +// Parameters: +// - ctx: Context for potential future extensions +// - query: Search query text +// - topK: Maximum number of results to return +// +// Returns: +// - []SearchResult: Sorted results by BM25 score +// - error: Error if search fails (currently always nil) func (idx *BM25Index) Search(ctx context.Context, query string, topK int) ([]SearchResult, error) { idx.mu.RLock() defer idx.mu.RUnlock() @@ -177,14 +230,24 @@ func (idx *BM25Index) Search(ctx context.Context, query string, topK int) ([]Sea return results, nil } -// SetParameters updates the BM25 parameters +// SetParameters updates the BM25 scoring parameters. +// This operation is thread-safe and affects all subsequent searches. +// Typical values: +// - K1: 1.2-2.0 (higher values increase term frequency influence) +// - B: 0.75 (lower values reduce length normalization effect) func (idx *BM25Index) SetParameters(params BM25Parameters) { idx.mu.Lock() defer idx.mu.Unlock() idx.params = params } -// SetPreprocessor sets a custom text preprocessing function +// SetPreprocessor sets a custom text preprocessing function. +// The preprocessor converts raw text into terms for indexing and searching. +// Custom preprocessors can implement: +// - Stopword removal +// - Stemming/lemmatization +// - N-gram generation +// - Special character handling func (idx *BM25Index) SetPreprocessor(preprocessor func(string) []string) { idx.mu.Lock() defer idx.mu.Unlock() diff --git a/rag/vector_interface.go b/rag/vector_interface.go index c7b5b51..3e207c8 100644 --- a/rag/vector_interface.go +++ b/rag/vector_interface.go @@ -1,5 +1,5 @@ -// File: vectordb.go - +// Package rag provides a unified interface for interacting with vector databases, +// offering a clean abstraction layer for vector similarity search operations. package rag import ( @@ -8,89 +8,151 @@ import ( "time" ) +// VectorDB defines the standard interface that all vector database implementations must implement. +// It provides operations for managing collections, inserting data, and performing vector similarity searches. 
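+//
+// A minimal sketch of a search against an already-populated collection
+// (the "Embedding" field name and vector values are illustrative, and ctx
+// is assumed to be a context.Context):
+//
+//	query := map[string]Vector{"Embedding": {0.1, 0.2, 0.3}}
+//	results, err := db.Search(ctx, "documents", query, 10, "L2",
+//		map[string]interface{}{"type": "HNSW", "ef": 64})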
type VectorDB interface { + // Connect establishes a connection to the vector database. Connect(ctx context.Context) error + + // Close terminates the connection to the vector database. Close() error + + // HasCollection checks if a collection with the given name exists. HasCollection(ctx context.Context, name string) (bool, error) + + // DropCollection removes a collection and all its data. DropCollection(ctx context.Context, name string) error + + // CreateCollection creates a new collection with the specified schema. CreateCollection(ctx context.Context, name string, schema Schema) error + + // Insert adds new records to the specified collection. Insert(ctx context.Context, collectionName string, data []Record) error + + // Flush ensures all pending writes are committed to storage. Flush(ctx context.Context, collectionName string) error + + // CreateIndex builds an index on the specified field to optimize search operations. CreateIndex(ctx context.Context, collectionName, field string, index Index) error + + // LoadCollection loads a collection into memory for faster access. LoadCollection(ctx context.Context, name string) error + + // Search performs a vector similarity search in the specified collection. Search(ctx context.Context, collectionName string, vectors map[string]Vector, topK int, metricType string, searchParams map[string]interface{}) ([]SearchResult, error) + + // HybridSearch combines vector similarity search with additional filtering or reranking. HybridSearch(ctx context.Context, collectionName string, vectors map[string]Vector, topK int, metricType string, searchParams map[string]interface{}, reranker interface{}) ([]SearchResult, error) + + // SetColumnNames configures the column names for the database operations. SetColumnNames(names []string) } +// SearchParam defines the parameters for vector similarity search operations. type SearchParam struct { + // MetricType specifies the distance metric to use (e.g., "L2", "IP", "COSINE") MetricType string + // Params contains additional search parameters specific to the database implementation Params map[string]interface{} } +// Schema defines the structure of a collection in the vector database. type Schema struct { + // Name is the identifier for the schema Name string + // Description provides additional information about the schema Description string + // Fields defines the structure of the data in the collection Fields []Field } +// Field represents a single field in a schema, which can be a vector or scalar value. type Field struct { + // Name is the identifier for the field Name string + // DataType specifies the type of data stored in the field DataType string + // PrimaryKey indicates if this field is the primary key PrimaryKey bool + // AutoID indicates if the field value should be automatically generated AutoID bool + // Dimension specifies the size of the vector (for vector fields) Dimension int + // MaxLength specifies the maximum length for variable-length fields MaxLength int } +// Record represents a single data entry in the vector database. type Record struct { + // Fields maps field names to their values Fields map[string]interface{} } +// Vector represents a mathematical vector as a slice of float64 values. type Vector []float64 +// Index defines the parameters for building an index on a field. 
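+//
+// For example, an HNSW index as accepted by the Milvus backend; the parameter
+// keys mirror the M and efConstruction settings described in that backend's
+// CreateIndex documentation, though exact keys may vary by backend, and the
+// values here are illustrative:
+//
+//	idx := Index{
+//		Type:   "HNSW",
+//		Metric: "L2",
+//		Parameters: map[string]interface{}{
+//			"M":              16,
+//			"efConstruction": 256,
+//		},
+//	}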
type Index struct { + // Type specifies the type of index to build (e.g., "IVF", "IVFPQ") Type string + // Metric specifies the distance metric to use for the index Metric string + // Parameters contains additional index parameters specific to the database implementation Parameters map[string]interface{} } +// SearchResult represents a single result from a vector similarity search. type SearchResult struct { + // ID is the identifier for the result ID int64 + // Score is the similarity score for the result Score float64 + // Fields contains additional information about the result Fields map[string]interface{} } +// Config defines the configuration for a vector database connection. type Config struct { + // Type specifies the type of vector database to connect to (e.g., "milvus", "memory") Type string + // Address specifies the address of the vector database Address string + // MaxPoolSize specifies the maximum number of connections to the database MaxPoolSize int + // Timeout specifies the timeout for database operations Timeout time.Duration + // Parameters contains additional configuration parameters specific to the database implementation Parameters map[string]interface{} } +// Option defines a function that can be used to configure a Config. type Option func(*Config) +// SetType sets the type of vector database to connect to. func (c *Config) SetType(dbType string) *Config { c.Type = dbType return c } +// SetAddress sets the address of the vector database. func (c *Config) SetAddress(address string) *Config { c.Address = address return c } +// SetMaxPoolSize sets the maximum number of connections to the database. func (c *Config) SetMaxPoolSize(size int) *Config { c.MaxPoolSize = size return c } +// SetTimeout sets the timeout for database operations. func (c *Config) SetTimeout(timeout time.Duration) *Config { c.Timeout = timeout return c } +// NewVectorDB creates a new VectorDB instance based on the provided configuration. func NewVectorDB(cfg *Config) (VectorDB, error) { switch cfg.Type { case "milvus": diff --git a/register.go b/register.go index 25244b4..459e18b 100644 --- a/register.go +++ b/register.go @@ -1,3 +1,8 @@ +// Package raggo provides a comprehensive registration system for vector database +// implementations in RAG (Retrieval-Augmented Generation) applications. This +// package enables dynamic registration and management of vector databases with +// support for concurrent operations, configurable processing, and extensible +// architecture. package raggo import ( @@ -5,36 +10,51 @@ import ( "fmt" "os" "strconv" + "sync" "time" + + "github.com/teilomillet/raggo/rag" ) -// RegisterConfig holds all configuration for the registration process +// RegisterConfig holds the complete configuration for document registration +// and vector database setup. It provides fine-grained control over all aspects +// of the registration process. 
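+//
+// The RegisterOption helpers below mutate this struct in place, so a custom
+// option is just a function; a sketch for a field without a built-in helper:
+//
+//	withTempDir := func(dir string) RegisterOption {
+//		return func(cfg *RegisterConfig) { cfg.TempDir = dir }
+//	}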
type RegisterConfig struct { - // Storage settings - VectorDBType string // e.g., "milvus" - VectorDBConfig map[string]string // Connection settings - CollectionName string // Target collection - AutoCreate bool // Create collection if missing - - // Processing settings - ChunkSize int - ChunkOverlap int - BatchSize int - TempDir string - MaxConcurrency int - Timeout time.Duration - - // Embedding settings - EmbeddingProvider string // e.g., "openai" - EmbeddingModel string // e.g., "text-embedding-3-small" - EmbeddingKey string - - // Callbacks - OnProgress func(processed, total int) - OnError func(error) + // Storage settings control vector database configuration + VectorDBType string // Type of vector database (e.g., "milvus", "pinecone") + VectorDBConfig map[string]string // Database-specific configuration parameters + CollectionName string // Name of the collection to store vectors + AutoCreate bool // Automatically create collection if missing + + // Processing settings define how documents are handled + ChunkSize int // Size of text chunks for processing + ChunkOverlap int // Overlap between consecutive chunks + BatchSize int // Number of items to process in each batch + TempDir string // Directory for temporary files + MaxConcurrency int // Maximum number of concurrent operations + Timeout time.Duration // Operation timeout duration + + // Embedding settings configure the embedding generation + EmbeddingProvider string // Embedding service provider (e.g., "openai") + EmbeddingModel string // Specific model to use for embeddings + EmbeddingKey string // Authentication key for embedding service + + // Callbacks for monitoring and error handling + OnProgress func(processed, total int) // Called to report progress + OnError func(error) // Called when errors occur } -// defaultConfig returns a configuration with sensible defaults +// defaultConfig returns a RegisterConfig initialized with production-ready +// default values. These defaults are carefully chosen to provide good +// performance while being conservative with resource usage. +// +// Default settings include: +// - Milvus vector database on localhost +// - 512-token chunks with 64-token overlap +// - 100 items per batch +// - 4 concurrent operations +// - 5-minute timeout +// - OpenAI's text-embedding-3-small model func defaultConfig() *RegisterConfig { return &RegisterConfig{ VectorDBType: "milvus", @@ -55,11 +75,32 @@ func defaultConfig() *RegisterConfig { } } -// RegisterOption is a function that modifies RegisterConfig +// RegisterOption is a function type for modifying RegisterConfig. +// It follows the functional options pattern to provide a clean and +// extensible way to configure the registration process. type RegisterOption func(*RegisterConfig) -// Register processes documents and stores them in a vector database. -// It accepts various sources: file paths, directory paths, or URLs. +// Register processes documents from various sources and stores them in a vector +// database. It handles the entire pipeline from document loading to vector storage: +// 1. Document loading from files, directories, or URLs +// 2. Text chunking and preprocessing +// 3. Embedding generation +// 4. Vector database storage +// +// The process is highly configurable through RegisterOptions and supports +// progress monitoring and error handling through callbacks. 
+// +// Example: +// +// err := Register(ctx, "docs/", +// WithVectorDB("milvus", map[string]string{ +// "address": "localhost:19530", +// }), +// WithCollection("technical_docs", true), +// WithChunking(512, 64), +// WithEmbedding("openai", "text-embedding-3-small", os.Getenv("OPENAI_API_KEY")), +// WithConcurrency(4), +// ) func Register(ctx context.Context, source string, opts ...RegisterOption) error { // Initialize configuration cfg := defaultConfig() @@ -283,11 +324,26 @@ func Register(ctx context.Context, source string, opts ...RegisterOption) error return nil } -// Configuration options - +// WithVectorDB configures the vector database settings for registration. +// It specifies the database type and its configuration parameters. +// +// Supported database types include: +// - "milvus": Milvus vector database +// - "pinecone": Pinecone vector database +// - "qdrant": Qdrant vector database +// +// Example: +// +// Register(ctx, "docs/", +// WithVectorDB("milvus", map[string]string{ +// "address": "localhost:19530", +// "user": "default", +// "password": "password", +// }), +// ) func WithVectorDB(dbType string, config map[string]string) RegisterOption { - return func(c *RegisterConfig) { - c.VectorDBType = dbType + return func(cfg *RegisterConfig) { + cfg.VectorDBType = dbType if config == nil { config = make(map[string]string) } @@ -295,40 +351,164 @@ func WithVectorDB(dbType string, config map[string]string) RegisterOption { if config["dimension"] == "" { config["dimension"] = "1536" // Default dimension } - c.VectorDBConfig = config + cfg.VectorDBConfig = config } } +// WithCollection sets the collection name and auto-creation behavior. +// When autoCreate is true, the collection will be created if it doesn't +// exist, including appropriate indexes for vector similarity search. +// +// Example: +// +// Register(ctx, "docs/", +// WithCollection("technical_docs", true), +// ) func WithCollection(name string, autoCreate bool) RegisterOption { - return func(c *RegisterConfig) { - c.CollectionName = name - c.AutoCreate = autoCreate + return func(cfg *RegisterConfig) { + cfg.CollectionName = name + cfg.AutoCreate = autoCreate } } +// WithChunking configures the text chunking parameters for document processing. +// The size parameter determines the length of each chunk, while overlap +// specifies how much text should be shared between consecutive chunks. +// +// Example: +// +// Register(ctx, "docs/", +// WithChunking(512, 64), // 512-token chunks with 64-token overlap +// ) func WithChunking(size, overlap int) RegisterOption { - return func(c *RegisterConfig) { - c.ChunkSize = size - c.ChunkOverlap = overlap + return func(cfg *RegisterConfig) { + cfg.ChunkSize = size + cfg.ChunkOverlap = overlap } } +// WithEmbedding configures the embedding generation settings. +// It specifies the provider, model, and authentication key for +// generating vector embeddings from text. 
+//
+// Supported providers:
+// - "openai": OpenAI's embedding models
+// - "cohere": Cohere's embedding models
+// - "local": Local embedding models
+//
+// Example:
+//
+//	Register(ctx, "docs/",
+//		WithEmbedding("openai",
+//			"text-embedding-3-small",
+//			os.Getenv("OPENAI_API_KEY"),
+//		),
+//	)
 func WithEmbedding(provider, model, key string) RegisterOption {
-	return func(c *RegisterConfig) {
-		c.EmbeddingProvider = provider
-		c.EmbeddingModel = model
-		c.EmbeddingKey = key
+	return func(cfg *RegisterConfig) {
+		cfg.EmbeddingProvider = provider
+		cfg.EmbeddingModel = model
+		cfg.EmbeddingKey = key
 	}
 }
 
+// WithConcurrency sets the maximum number of concurrent operations
+// during document processing. This affects:
+// - Document loading
+// - Chunk processing
+// - Embedding generation
+// - Vector storage
+//
+// Example:
+//
+//	Register(ctx, "docs/",
+//		WithConcurrency(8), // Process up to 8 items concurrently
+//	)
 func WithConcurrency(max int) RegisterOption {
-	return func(c *RegisterConfig) {
-		c.MaxConcurrency = max
+	return func(cfg *RegisterConfig) {
+		cfg.MaxConcurrency = max
 	}
 }
 
-// Helper functions
-
+// isURL reports whether a string should be treated as a remote source.
+// It currently recognizes only the http and https schemes.
 func isURL(s string) bool {
 	return len(s) > 8 && (s[:7] == "http://" || s[:8] == "https://")
 }
+
+// dbRegistry maintains a thread-safe registry of vector database implementations.
+// It provides a central location for registering and retrieving database
+// implementations, ensuring thread-safe access in concurrent environments.
+//
+// The registry supports:
+// - Dynamic registration of new implementations
+// - Thread-safe access to implementations
+// - Runtime discovery of available databases
+type dbRegistry struct {
+	mu        sync.RWMutex
+	factories map[string]func(cfg *Config) (rag.VectorDB, error)
+}
+
+// registry is the global instance of dbRegistry that maintains all registered
+// vector database implementations. It is initialized when the package is loaded
+// and should be accessed through the package-level registration functions.
+var registry = &dbRegistry{
+	factories: make(map[string]func(cfg *Config) (rag.VectorDB, error)),
+}
+
+// RegisterVectorDB registers a new vector database implementation.
+// It allows third-party implementations to be integrated into the
+// Raggo ecosystem at runtime.
+//
+// Example:
+//
+//	RegisterVectorDB("custom_db", func(cfg *Config) (rag.VectorDB, error) {
+//		return NewCustomDB(cfg)
+//	})
+func RegisterVectorDB(dbType string, factory func(cfg *Config) (rag.VectorDB, error)) {
+	registry.mu.Lock()
+	defer registry.mu.Unlock()
+	registry.factories[dbType] = factory
+}
+
+// GetVectorDB retrieves a vector database implementation from the registry.
+// It returns an error if the requested implementation is not found or
+// if creation fails.
+//
+// Example:
+//
+//	db, err := GetVectorDB("milvus", &Config{
+//		Address: "localhost:19530",
+//	})
+func GetVectorDB(dbType string, cfg *Config) (rag.VectorDB, error) {
+	registry.mu.RLock()
+	factory, ok := registry.factories[dbType]
+	registry.mu.RUnlock()
+
+	if !ok {
+		return nil, fmt.Errorf("vector database type not registered: %s", dbType)
+	}
+
+	return factory(cfg)
+}
+
+// ListRegisteredDBs returns a list of all registered vector database types.
+// This is useful for discovering available implementations and validating
+// configuration options.
+// +// Example: +// +// dbs := ListRegisteredDBs() +// for _, db := range dbs { +// fmt.Printf("Supported database: %s\n", db) +// } +func ListRegisteredDBs() []string { + registry.mu.RLock() + defer registry.mu.RUnlock() + + types := make([]string, 0, len(registry.factories)) + for dbType := range registry.factories { + types = append(types, dbType) + } + return types +} diff --git a/retriever.go b/retriever.go index 4d8a2af..c8e3e7d 100644 --- a/retriever.go +++ b/retriever.go @@ -1,3 +1,14 @@ +// Package raggo implements a sophisticated document retrieval system that combines +// vector similarity search with optional reranking strategies. The retriever +// component serves as the core engine for finding and ranking relevant documents +// based on semantic similarity and other configurable criteria. +// +// Key features: +// - Semantic similarity search using vector embeddings +// - Hybrid search combining vector and keyword matching +// - Configurable reranking strategies +// - Flexible result filtering and scoring +// - Extensible callback system for result processing package raggo import ( @@ -7,76 +18,69 @@ import ( "time" ) -// Retriever handles semantic search operations with a reusable configuration +// Retriever handles semantic search operations with a reusable configuration. +// It provides a high-level interface for performing vector similarity searches +// and managing search results. The Retriever maintains connections to the +// vector database and embedding service throughout its lifecycle. type Retriever struct { - config *RetrieverConfig - vectorDB *VectorDB - embedder Embedder - ready bool + config *RetrieverConfig // Configuration for retrieval operations + vectorDB *VectorDB // Connection to vector database + embedder Embedder // Embedding service client + ready bool // Initialization status } -// RetrieverConfig holds settings for the retrieval process +// RetrieverConfig holds settings for the retrieval process. It provides +// fine-grained control over search behavior, database connections, and +// result processing. 
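+//
+// Since a RetrieverOption is a plain function over this struct, callers can
+// sketch their own options for fields without built-in helpers, for example:
+//
+//	withMetric := func(metric string) RetrieverOption {
+//		return func(c *RetrieverConfig) { c.MetricType = metric }
+//	}
+//	retriever, err := NewRetriever(withMetric("IP"))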
type RetrieverConfig struct { - // Core settings - Collection string - TopK int - MinScore float64 - UseHybrid bool - Columns []string - - // Vector DB settings - DBType string - DBAddress string - Dimension int - - // Embedding settings - Provider string - Model string - APIKey string - - // Advanced settings - MetricType string - Timeout time.Duration - SearchParams map[string]interface{} - OnResult func(SearchResult) - OnError func(error) + // Core settings define the basic search behavior + Collection string // Name of the vector collection to search + TopK int // Maximum number of results to return + MinScore float64 // Minimum similarity score threshold + UseHybrid bool // Enable hybrid search (vector + keyword) + Columns []string // Columns to retrieve from the database + + // Vector DB settings configure the database connection + DBType string // Type of vector database (e.g., "milvus") + DBAddress string // Database connection address + Dimension int // Embedding vector dimension + + // Embedding settings configure the embedding service + Provider string // Embedding provider (e.g., "openai") + Model string // Model name for embeddings + APIKey string // Authentication key + + // Advanced settings provide additional control + MetricType string // Distance metric (e.g., "L2", "IP") + Timeout time.Duration // Operation timeout + SearchParams map[string]interface{} // Additional search parameters + OnResult func(SearchResult) // Callback for each result + OnError func(error) // Error handling callback } -// RetrieverResult represents a single retrieved result +// RetrieverResult represents a single retrieved result with its metadata +// and relevance information. It provides a structured way to access both +// the content and context of each search result. type RetrieverResult struct { - Content string `json:"content"` - Score float64 `json:"score"` - Metadata map[string]interface{} `json:"metadata"` - Source string `json:"source"` - ChunkIndex int `json:"chunk_index"` + Content string `json:"content"` // Retrieved text content + Score float64 `json:"score"` // Similarity score + Metadata map[string]interface{} `json:"metadata"` // Associated metadata + Source string `json:"source"` // Source identifier + ChunkIndex int `json:"chunk_index"` // Position in source } -func defaultRetrieverConfig() *RetrieverConfig { - return &RetrieverConfig{ - Collection: "documents", - TopK: 5, - MinScore: 0.7, - UseHybrid: true, - Columns: []string{"Text", "Metadata"}, - DBType: "milvus", - DBAddress: "localhost:19530", - Dimension: 128, - Provider: "openai", - Model: "text-embedding-3-small", - APIKey: os.Getenv("OPENAI_API_KEY"), - MetricType: "L2", - Timeout: 30 * time.Second, - SearchParams: map[string]interface{}{ - "type": "HNSW", - "ef": 64, - }, - } -} - -// RetrieverOption configures the retriever -type RetrieverOption func(*RetrieverConfig) - -// NewRetriever creates a new Retriever with the given options +// NewRetriever creates a new Retriever with the given options. It initializes +// the necessary connections and validates the configuration. 
+// +// Example: +// +// retriever, err := NewRetriever( +// WithRetrieveCollection("documents"), +// WithTopK(5), +// WithMinScore(0.7), +// WithRetrieveDB("milvus", "localhost:19530"), +// WithRetrieveEmbedding("openai", "text-embedding-3-small", os.Getenv("OPENAI_API_KEY")), +// ) func NewRetriever(opts ...RetrieverOption) (*Retriever, error) { cfg := defaultRetrieverConfig() for _, opt := range opts { @@ -91,47 +95,27 @@ func NewRetriever(opts ...RetrieverOption) (*Retriever, error) { return r, nil } -func (r *Retriever) initialize() error { - var err error - - r.vectorDB, err = NewVectorDB( - WithType(r.config.DBType), - WithAddress(r.config.DBAddress), - WithTimeout(r.config.Timeout), - WithDimension(r.config.Dimension), - ) - if err != nil { - return fmt.Errorf("failed to create vector store: %w", err) - } - - ctx, cancel := context.WithTimeout(context.Background(), r.config.Timeout) - defer cancel() - - if err := r.vectorDB.Connect(ctx); err != nil { - return fmt.Errorf("failed to connect to vector store: %w", err) - } - - r.embedder, err = NewEmbedder( - SetProvider(r.config.Provider), - SetModel(r.config.Model), - SetAPIKey(r.config.APIKey), - ) - if err != nil { - return fmt.Errorf("failed to create embedder: %w", err) - } - - r.ready = true - return nil -} - -func (r *Retriever) Close() error { - if r.vectorDB != nil { - return r.vectorDB.Close() - } - return nil -} +// RetrieverOption configures the retriever using the functional options pattern. +// This allows for flexible and extensible configuration while maintaining +// backward compatibility. +type RetrieverOption func(*RetrieverConfig) -// Retrieve finds similar content for the given query +// Retrieve finds similar content for the given query using vector similarity +// search. It handles the complete retrieval pipeline: +// 1. Query embedding generation +// 2. Vector similarity search +// 3. Result filtering and processing +// 4. Metadata enrichment +// +// Example: +// +// results, err := retriever.Retrieve(ctx, "How does photosynthesis work?") +// if err != nil { +// log.Fatal(err) +// } +// for _, result := range results { +// fmt.Printf("Score: %.2f, Content: %s\n", result.Score, result.Content) +// } func (r *Retriever) Retrieve(ctx context.Context, query string) ([]RetrieverResult, error) { if !r.ready { return nil, fmt.Errorf("retriever not properly initialized") @@ -203,31 +187,62 @@ func (r *Retriever) Retrieve(ctx context.Context, query string) ([]RetrieverResu return results, nil } -// GetVectorDB returns the underlying vector database instance +// GetVectorDB returns the underlying vector database instance. +// This provides access to lower-level database operations when needed. func (r *Retriever) GetVectorDB() *VectorDB { return r.vectorDB } -// Configuration options - +// WithRetrieveCollection sets the collection name for retrieval operations. +// The collection must exist in the vector database. +// +// Example: +// +// retriever, err := NewRetriever( +// WithRetrieveCollection("scientific_papers"), +// ) func WithRetrieveCollection(name string) RetrieverOption { return func(c *RetrieverConfig) { c.Collection = name } } +// WithTopK sets the maximum number of results to return. +// The actual number of results may be less if MinScore filtering is applied. 
+// +// Example: +// +// retriever, err := NewRetriever( +// WithTopK(10), // Return top 10 results +// ) func WithTopK(k int) RetrieverOption { return func(c *RetrieverConfig) { c.TopK = k } } +// WithMinScore sets the minimum similarity score threshold. +// Results with scores below this threshold will be filtered out. +// +// Example: +// +// retriever, err := NewRetriever( +// WithMinScore(0.8), // Only return high-confidence matches +// ) func WithMinScore(score float64) RetrieverOption { return func(c *RetrieverConfig) { c.MinScore = score } } +// WithRetrieveDB configures the vector database connection. +// Supports various vector database implementations. +// +// Example: +// +// retriever, err := NewRetriever( +// WithRetrieveDB("milvus", "localhost:19530"), +// ) func WithRetrieveDB(dbType, address string) RetrieverOption { return func(c *RetrieverConfig) { c.DBType = dbType @@ -235,6 +250,18 @@ func WithRetrieveDB(dbType, address string) RetrieverOption { } } +// WithRetrieveEmbedding configures the embedding service. +// Supports multiple embedding providers and models. +// +// Example: +// +// retriever, err := NewRetriever( +// WithRetrieveEmbedding( +// "openai", +// "text-embedding-3-small", +// os.Getenv("OPENAI_API_KEY"), +// ), +// ) func WithRetrieveEmbedding(provider, model, key string) RetrieverOption { return func(c *RetrieverConfig) { c.Provider = provider @@ -243,27 +270,138 @@ func WithRetrieveEmbedding(provider, model, key string) RetrieverOption { } } +// WithHybrid enables or disables hybrid search. +// Hybrid search combines vector similarity with keyword matching. +// +// Example: +// +// retriever, err := NewRetriever( +// WithHybrid(true), // Enable hybrid search +// ) func WithHybrid(enabled bool) RetrieverOption { return func(c *RetrieverConfig) { c.UseHybrid = enabled } } +// WithColumns specifies which columns to retrieve from the database. +// This can optimize performance by only fetching needed fields. +// +// Example: +// +// retriever, err := NewRetriever( +// WithColumns("Text", "Metadata", "Source"), +// ) func WithColumns(columns ...string) RetrieverOption { return func(c *RetrieverConfig) { c.Columns = columns } } +// WithRetrieveDimension sets the embedding vector dimension. +// This must match the dimension of your chosen embedding model. +// +// Example: +// +// retriever, err := NewRetriever( +// WithRetrieveDimension(1536), // OpenAI embedding dimension +// ) func WithRetrieveDimension(dimension int) RetrieverOption { return func(c *RetrieverConfig) { c.Dimension = dimension } } +// WithRetrieveCallbacks sets result and error handling callbacks. +// These callbacks are called during the retrieval process. +// +// Example: +// +// retriever, err := NewRetriever( +// WithRetrieveCallbacks( +// func(result SearchResult) { +// log.Printf("Found result: %v\n", result) +// }, +// func(err error) { +// log.Printf("Error: %v\n", err) +// }, +// ), +// ) func WithRetrieveCallbacks(onResult func(SearchResult), onError func(error)) RetrieverOption { return func(c *RetrieverConfig) { c.OnResult = onResult c.OnError = onError } } + +// defaultRetrieverConfig returns a RetrieverConfig with production-ready defaults. +// These defaults are chosen to provide good performance while being +// conservative with resource usage. 
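+// Each default can be overridden through the corresponding With* option.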
+//
+// Default settings include:
+// - Top 5 results
+// - Minimum score of 0.7
+// - L2 distance metric
+// - 30-second timeout
+// - Standard column set (Text, Metadata)
+func defaultRetrieverConfig() *RetrieverConfig {
+	return &RetrieverConfig{
+		Collection: "documents",
+		TopK:       5,
+		MinScore:   0.7,
+		UseHybrid:  true,
+		Columns:    []string{"Text", "Metadata"},
+		DBType:     "milvus",
+		DBAddress:  "localhost:19530",
+		Dimension:  128,
+		Provider:   "openai",
+		Model:      "text-embedding-3-small",
+		APIKey:     os.Getenv("OPENAI_API_KEY"),
+		MetricType: "L2",
+		Timeout:    30 * time.Second,
+		SearchParams: map[string]interface{}{
+			"type": "HNSW",
+			"ef":   64,
+		},
+	}
+}
+
+func (r *Retriever) initialize() error {
+	var err error
+
+	r.vectorDB, err = NewVectorDB(
+		WithType(r.config.DBType),
+		WithAddress(r.config.DBAddress),
+		WithTimeout(r.config.Timeout),
+		WithDimension(r.config.Dimension),
+	)
+	if err != nil {
+		return fmt.Errorf("failed to create vector store: %w", err)
+	}
+
+	ctx, cancel := context.WithTimeout(context.Background(), r.config.Timeout)
+	defer cancel()
+
+	if err := r.vectorDB.Connect(ctx); err != nil {
+		return fmt.Errorf("failed to connect to vector store: %w", err)
+	}
+
+	r.embedder, err = NewEmbedder(
+		SetProvider(r.config.Provider),
+		SetModel(r.config.Model),
+		SetAPIKey(r.config.APIKey),
+	)
+	if err != nil {
+		return fmt.Errorf("failed to create embedder: %w", err)
+	}
+
+	r.ready = true
+	return nil
+}
+
+func (r *Retriever) Close() error {
+	if r.vectorDB != nil {
+		return r.vectorDB.Close()
+	}
+	return nil
+}
diff --git a/simple_rag.go b/simple_rag.go
index 1c3e325..6b28ae8 100644
--- a/simple_rag.go
+++ b/simple_rag.go
@@ -1,3 +1,25 @@
+// SimpleRAG provides a minimal, easy-to-use interface for RAG operations.
+// It simplifies the configuration and usage of the RAG system while maintaining
+// core functionality. This implementation is ideal for:
+// - Quick prototyping
+// - Simple document retrieval needs
+// - Learning the RAG system
+//
+// Example usage:
+//
+//	config := raggo.DefaultConfig()
+//	config.APIKey = "your-api-key"
+//
+//	rag, err := raggo.NewSimpleRAG(config)
+//	if err != nil {
+//		log.Fatal(err)
+//	}
+//
+//	// Add documents
+//	err = rag.AddDocuments(context.Background(), "path/to/docs")
+//
+//	// Search
+//	response, err := rag.Search(context.Background(), "your query")
 package raggo
 
 import (
@@ -12,32 +34,40 @@ import (
 	"github.com/teilomillet/gollm"
 )
 
-// SimpleRAG provides a minimal interface for RAG operations
+// SimpleRAG provides a minimal interface for RAG operations.
+// It encapsulates the core functionality while hiding complexity.
 type SimpleRAG struct {
-	retriever  *Retriever
-	collection string
-	apiKey     string
-	model      string
-	vectorDB   *VectorDB
-	llm        gollm.LLM
+	retriever  *Retriever // Handles document retrieval
+	collection string     // Name of the vector collection
+	apiKey     string     // API key for services
+	model      string     // Embedding model name
+	vectorDB   *VectorDB  // Vector database connection
+	llm        gollm.LLM  // Language model interface
 }
 
-// SimpleRAGConfig holds configuration for SimpleRAG
+// SimpleRAGConfig holds configuration for SimpleRAG.
+// It provides essential configuration options while using
+// sensible defaults for other settings.
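+//
+// Example (a sketch; adjust the values for your own deployment):
+//
+//	cfg := DefaultConfig()
+//	cfg.Collection = "my_documents"
+//	cfg.DBType = "chromem"
+//	cfg.DBAddress = "./data/chromem.db"
+//
+//	rag, err := NewSimpleRAG(cfg)
+//	if err != nil {
+//		log.Fatal(err)
+//	}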
type SimpleRAGConfig struct { - Collection string - APIKey string - Model string - ChunkSize int - ChunkOverlap int - TopK int - MinScore float64 - LLMModel string - DBType string // Type of vector database to use (e.g., "milvus", "chromem") - DBAddress string // Address for the vector database (e.g., "localhost:19530" for Milvus, "./data/chromem.db" for Chromem) - Dimension int // Dimension of the embedding vectors (e.g., 1536 for text-embedding-3-small) + Collection string // Name of the vector collection + APIKey string // API key for services (e.g., OpenAI) + Model string // Embedding model name + ChunkSize int // Size of text chunks in tokens + ChunkOverlap int // Overlap between consecutive chunks + TopK int // Number of results to retrieve + MinScore float64 // Minimum similarity score threshold + LLMModel string // Language model for text generation + DBType string // Type of vector database (e.g., "milvus", "chromem") + DBAddress string // Address for the vector database + Dimension int // Dimension of embedding vectors } -// DefaultConfig returns a default configuration +// DefaultConfig returns a default configuration for SimpleRAG. +// It provides reasonable defaults for all settings: +// - OpenAI's text-embedding-3-small for embeddings +// - Milvus as the vector database +// - Balanced chunk size and overlap +// - Conservative similarity threshold func DefaultConfig() SimpleRAGConfig { return SimpleRAGConfig{ Collection: "documents", @@ -53,7 +83,17 @@ func DefaultConfig() SimpleRAGConfig { } } -// NewSimpleRAG creates a new SimpleRAG instance with minimal configuration +// NewSimpleRAG creates a new SimpleRAG instance with minimal configuration. +// It performs the following setup: +// 1. Validates and applies configuration +// 2. Initializes the language model +// 3. Sets up the vector database connection +// 4. Prepares the retrieval system +// +// Returns an error if: +// - API key is missing +// - LLM initialization fails +// - Vector database connection fails func NewSimpleRAG(config SimpleRAGConfig) (*SimpleRAG, error) { if config.APIKey == "" { config.APIKey = os.Getenv("OPENAI_API_KEY") @@ -155,7 +195,17 @@ func NewSimpleRAG(config SimpleRAGConfig) (*SimpleRAG, error) { }, nil } -// AddDocuments adds documents to the vector database +// AddDocuments processes and stores documents in the vector database. +// The function: +// 1. Validates the source path +// 2. Processes documents into chunks +// 3. Generates embeddings +// 4. Stores vectors in the database +// +// The source parameter can be: +// - A single file path +// - A directory path (all documents will be processed) +// - A glob pattern (e.g., "docs/*.pdf") func (s *SimpleRAG) AddDocuments(ctx context.Context, source string) error { if ctx == nil { ctx = context.Background() @@ -234,7 +284,15 @@ func (s *SimpleRAG) AddDocuments(ctx context.Context, source string) error { return nil } -// Search performs a semantic search query and generates a response +// Search performs a semantic search query and generates a response. +// The process: +// 1. Embeds the query into a vector +// 2. Finds similar documents in the vector database +// 3. 
Uses the LLM to generate a response based on retrieved context +// +// Returns: +// - A natural language response incorporating retrieved information +// - An error if the search or response generation fails func (s *SimpleRAG) Search(ctx context.Context, query string) (string, error) { if ctx == nil { ctx = context.Background() @@ -306,7 +364,11 @@ If the information isn't found in the provided context, please say so clearly.`, return resp, nil } -// Close releases resources +// Close releases all resources held by the SimpleRAG instance. +// This includes: +// - Vector database connection +// - Language model resources +// - Any temporary files func (s *SimpleRAG) Close() error { if s.vectorDB != nil { s.vectorDB.Close() diff --git a/tsne_results.json b/tsne_results.json deleted file mode 100644 index c7340e7..0000000 --- a/tsne_results.json +++ /dev/null @@ -1 +0,0 @@ -[{"name":"Benjamin_Rogojan_Resume.pdf","embedding":[0.000053166590786199366,0.000001923948379172014],"isJob":false},{"name":"Business Developer TM.pdf","embedding":[0.00006841330494942138,-0.000040458984637829094],"isJob":false},{"name":"CV (1).pdf","embedding":[-0.00008465113888932352,-7.225412811525745e-7],"isJob":false},{"name":"CV2020 ENGLISH.pdf","embedding":[-0.000060612736181188786,-0.000020575358418761378],"isJob":false},{"name":"Copy of CV TM.pdf","embedding":[0.00008276252694307001,-0.000002989903088998866],"isJob":false},{"name":"Document sans titre.pdf","embedding":[-0.00003352308666096765,0.00009395005115081649],"isJob":false},{"name":"Teilo_Millet_CV (1).pdf","embedding":[-0.00008559146614979102,-0.000006675870322238315],"isJob":false},{"name":"Teilo_Millet_CV.pdf","embedding":[0.000023275161078110824,-0.00003359874132821692],"isJob":false},{"name":"_CV.pdf","embedding":[0.0000033747164644530018,0.00003518695906637328],"isJob":false},{"name":"cv_ryckebusch_laurene_agencegrenadine.pdf","embedding":[-0.00003399982438035844,-0.00005199277480691526],"isJob":false},{"name":"Job Description","embedding":[0.00006738595204037482,0.000025953215287750633],"isJob":true}] diff --git a/vectordb.go b/vectordb.go index 0030605..4035026 100644 --- a/vectordb.go +++ b/vectordb.go @@ -1,5 +1,6 @@ -// File: vectordb.go - +// Package raggo provides a high-level abstraction over various vector database +// implementations. This file defines the VectorDB type, which wraps the lower-level +// rag.VectorDB interface with additional functionality and type safety. package raggo import ( @@ -10,53 +11,96 @@ import ( "github.com/teilomillet/raggo/rag" ) +// VectorDB provides a high-level interface for vector database operations. +// It wraps the lower-level rag.VectorDB interface and adds: +// - Type-safe configuration +// - Connection management +// - Dimension tracking +// - Database type information type VectorDB struct { - db rag.VectorDB - dbType string - address string - dimension int + db rag.VectorDB // Underlying vector database implementation + dbType string // Type of database (e.g., "milvus", "memory") + address string // Connection address + dimension int // Vector dimension } +// Config holds configuration options for VectorDB instances. +// It provides a clean way to configure database connections +// without exposing implementation details. 
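+//
+// A Config is normally assembled through the Option functions below rather
+// than constructed by hand. For example (a minimal sketch):
+//
+//	db, err := NewVectorDB(
+//		WithType("milvus"),
+//		WithAddress("localhost:19530"),
+//		WithDimension(1536),
+//		WithTimeout(30*time.Second),
+//	)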
 type Config struct {
-	Type        string
-	Address     string
-	MaxPoolSize int
-	Timeout     time.Duration
-	Dimension   int
+	Type        string        // Database type (e.g., "milvus", "memory")
+	Address     string        // Connection address
+	MaxPoolSize int           // Maximum number of connections
+	Timeout     time.Duration // Operation timeout
+	Dimension   int           // Vector dimension
 }
 
+// Option is a function type for configuring VectorDB instances.
+// It follows the functional options pattern for clean and flexible configuration.
 type Option func(*Config)
 
+// WithType sets the database type.
+// Supported types:
+// - "milvus": Production-grade vector database
+// - "memory": In-memory database for testing
+// - "chromem": Embedded, file-backed vector store (chromem-go)
 func WithType(dbType string) Option {
 	return func(c *Config) {
 		c.Type = dbType
 	}
 }
 
+// WithAddress sets the database connection address.
+// Examples:
+// - Milvus: "localhost:19530"
+// - Memory: "" (no address needed)
+// - chromem: "./data/vectors.db"
 func WithAddress(address string) Option {
 	return func(c *Config) {
 		c.Address = address
 	}
 }
 
+// WithMaxPoolSize sets the maximum number of database connections.
+// This is particularly relevant for Milvus and other client-server databases.
 func WithMaxPoolSize(size int) Option {
 	return func(c *Config) {
 		c.MaxPoolSize = size
 	}
 }
 
+// WithTimeout sets the operation timeout duration.
+// This affects all database operations including:
+// - Connection attempts
+// - Search operations
+// - Insert operations
 func WithTimeout(timeout time.Duration) Option {
 	return func(c *Config) {
 		c.Timeout = timeout
 	}
 }
 
+// WithDimension sets the dimension of vectors to be stored.
+// This must match the dimension of your embedding model:
+// - text-embedding-3-small: 1536
+// - text-embedding-ada-002: 1536
+// - Cohere embed-multilingual-v3.0: 1024
 func WithDimension(dimension int) Option {
 	return func(c *Config) {
 		c.Dimension = dimension
 	}
 }
 
+// NewVectorDB creates a new vector database connection with the specified options.
+// The function:
+// 1. Applies all configuration options
+// 2. Creates the appropriate database implementation
+// 3. Sets up the connection (but doesn't connect yet)
+//
+// Returns an error if:
+// - The database type is unsupported
+// - The configuration is invalid
+// - The database implementation fails to initialize
 func NewVectorDB(opts ...Option) (*VectorDB, error) {
 	cfg := &Config{}
 	for _, opt := range opts {
@@ -82,26 +126,38 @@ func NewVectorDB(opts ...Option) (*VectorDB, error) {
 	}, nil
 }
 
+// Connect establishes a connection to the vector database.
+// This method must be called before any database operations.
 func (vdb *VectorDB) Connect(ctx context.Context) error {
 	return vdb.db.Connect(ctx)
 }
 
+// Close closes the vector database connection.
+// This method should be called when the database is no longer needed.
 func (vdb *VectorDB) Close() error {
 	return vdb.db.Close()
 }
 
+// HasCollection checks if a collection exists in the database.
+// Returns true if the collection exists, false otherwise.
 func (vdb *VectorDB) HasCollection(ctx context.Context, name string) (bool, error) {
 	return vdb.db.HasCollection(ctx, name)
 }
 
+// CreateCollection creates a new collection in the database.
+// The schema defines the structure of the collection.
 func (vdb *VectorDB) CreateCollection(ctx context.Context, name string, schema Schema) error {
 	return vdb.db.CreateCollection(ctx, name, rag.Schema(schema))
 }
 
+// DropCollection drops a collection from the database.
+// Returns an error if the collection does not exist.
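+//
+// Example (a sketch using a hypothetical collection name):
+//
+//	exists, err := db.HasCollection(ctx, "old_documents")
+//	if err == nil && exists {
+//		err = db.DropCollection(ctx, "old_documents")
+//	}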
 func (vdb *VectorDB) DropCollection(ctx context.Context, name string) error {
 	return vdb.db.DropCollection(ctx, name)
 }
 
+// Insert inserts a batch of records into a collection.
+// The records must match the schema of the collection.
 func (vdb *VectorDB) Insert(ctx context.Context, collectionName string, data []Record) error {
 	fmt.Printf("Inserting %d records into collection: %s\n", len(data), collectionName)
@@ -112,18 +168,27 @@ func (vdb *VectorDB) Insert(ctx context.Context, collectionName string, data []R
 	return vdb.db.Insert(ctx, collectionName, ragRecords)
 }
 
+// Flush forces any pending write operations on a collection to be
+// persisted to disk.
 func (vdb *VectorDB) Flush(ctx context.Context, collectionName string) error {
 	return vdb.db.Flush(ctx, collectionName)
 }
 
+// CreateIndex builds an index on a field in a collection.
+// The index argument specifies the index type and its build parameters.
 func (vdb *VectorDB) CreateIndex(ctx context.Context, collectionName, field string, index Index) error {
 	return vdb.db.CreateIndex(ctx, collectionName, field, rag.Index(index))
 }
 
+// LoadCollection loads a previously created collection into memory so
+// that it can be queried.
 func (vdb *VectorDB) LoadCollection(ctx context.Context, name string) error {
 	return vdb.db.LoadCollection(ctx, name)
 }
 
+// Search performs a vector similarity search in a collection, returning up
+// to topK results ranked by the given metric type (e.g., "L2", "IP") and
+// tuned by the provided search parameters.
 func (vdb *VectorDB) Search(ctx context.Context, collectionName string, vectors map[string]Vector, topK int, metricType string, searchParams map[string]interface{}) ([]SearchResult, error) {
 	fmt.Printf("Searching in collection %s for top %d results with metric type %s\n", collectionName, topK, metricType)
@@ -134,6 +199,10 @@ func (vdb *VectorDB) Search(ctx context.Context, collectionName string, vectors
 	return convertSearchResults(results), nil
 }
 
+// HybridSearch combines vector similarity search with keyword matching in a
+// collection, using the same metric type and search parameters as Search.
+// The reranker, if supplied, is used to reorder the combined results before
+// they are returned.
 func (vdb *VectorDB) HybridSearch(ctx context.Context, collectionName string, vectors map[string]Vector, topK int, metricType string, searchParams map[string]interface{}, reranker interface{}) ([]SearchResult, error) {
 	fmt.Printf("Performing hybrid search in collection %s for top %d results with metric type %s\n", collectionName, topK, metricType)
@@ -144,6 +213,7 @@ func (vdb *VectorDB) HybridSearch(ctx context.Context, collectionName string, ve
 	return convertSearchResults(results), nil
 }
 
+// convertSearchResults converts a list of rag.SearchResult to a list of SearchResult.
 func convertSearchResults(ragResults []rag.SearchResult) []SearchResult {
 	results := make([]SearchResult, len(ragResults))
 	for i, r := range ragResults {
@@ -152,18 +222,23 @@ func convertSearchResults(ragResults []rag.SearchResult) []SearchResult {
 	}
 	return results
 }
 
+// SetColumnNames sets the column names used when retrieving records from
+// the database.
 func (vdb *VectorDB) SetColumnNames(names []string) {
 	vdb.db.SetColumnNames(names)
 }
 
+// Type returns the type of the vector database.
 func (vdb *VectorDB) Type() string {
 	return vdb.dbType
 }
 
+// Address returns the address of the vector database.
 func (vdb *VectorDB) Address() string {
 	return vdb.address
 }
 
+// Dimension returns the dimension of the vectors in the database.
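+// The value is fixed at construction time via WithDimension and should match
+// the embedding model in use (e.g., 1536 for text-embedding-3-small).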
func (vdb *VectorDB) Dimension() int { return vdb.dimension }