Skip to content

Commit

Permalink
Clean the examples and the API
Browse files Browse the repository at this point in the history
  • Loading branch information
teilomillet committed Nov 21, 2024
1 parent d221946 commit 29c9be3
Show file tree
Hide file tree
Showing 15 changed files with 435 additions and 122 deletions.
40 changes: 20 additions & 20 deletions embedder.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,13 @@ type EmbeddedChunk = rag.EmbeddedChunk
// flexible configuration API.
//
// Common options include:
// - WithEmbedderProvider: Choose the embedding service provider
// - WithEmbedderModel: Select the specific embedding model
// - WithEmbedderAPIKey: Configure authentication
// - SetEmbedderProvider: Choose the embedding service provider
// - SetEmbedderModel: Select the specific embedding model
// - SetEmbedderAPIKey: Configure authentication
// - SetOption: Set custom provider-specific options
type EmbedderOption = rag.EmbedderOption

// WithEmbedderProvider sets the provider for the Embedder.
// SetEmbedderProvider sets the provider for the Embedder.
// Supported providers include:
// - "openai": OpenAI's text-embedding-ada-002 and other models
// - "cohere": Cohere's embedding models
Expand All @@ -54,14 +54,14 @@ type EmbedderOption = rag.EmbedderOption
// Example:
//
// embedder, err := NewEmbedder(
// WithEmbedderProvider("openai"),
// WithEmbedderModel("text-embedding-ada-002"),
// SetEmbedderProvider("openai"),
// SetEmbedderModel("text-embedding-ada-002"),
// )
func WithEmbedderProvider(provider string) EmbedderOption {
func SetEmbedderProvider(provider string) EmbedderOption {
return rag.SetProvider(provider)
}

// WithEmbedderModel sets the specific model to use for embedding.
// SetEmbedderModel sets the specific model to use for embedding.
// Available models depend on the chosen provider:
// - OpenAI: "text-embedding-ada-002" (recommended)
// - Cohere: "embed-multilingual-v2.0"
Expand All @@ -70,14 +70,14 @@ func WithEmbedderProvider(provider string) EmbedderOption {
// Example:
//
// embedder, err := NewEmbedder(
// WithEmbedderProvider("openai"),
// WithEmbedderModel("text-embedding-ada-002"),
// SetEmbedderProvider("openai"),
// SetEmbedderModel("text-embedding-ada-002"),
// )
func WithEmbedderModel(model string) EmbedderOption {
func SetEmbedderModel(model string) EmbedderOption {
return rag.SetModel(model)
}

// WithEmbedderAPIKey sets the authentication key for the embedding service.
// SetEmbedderAPIKey sets the authentication key for the embedding service.
// This is required for most cloud-based embedding providers.
//
// Security Note: Store API keys securely and never commit them to version control.
Expand All @@ -86,10 +86,10 @@ func WithEmbedderModel(model string) EmbedderOption {
// Example:
//
// embedder, err := NewEmbedder(
// WithEmbedderProvider("openai"),
// WithEmbedderAPIKey(os.Getenv("OPENAI_API_KEY")),
// SetEmbedderProvider("openai"),
// SetEmbedderAPIKey(os.Getenv("OPENAI_API_KEY")),
// )
func WithEmbedderAPIKey(apiKey string) EmbedderOption {
func SetEmbedderAPIKey(apiKey string) EmbedderOption {
return rag.SetAPIKey(apiKey)
}

Expand All @@ -100,7 +100,7 @@ func WithEmbedderAPIKey(apiKey string) EmbedderOption {
// Example:
//
// embedder, err := NewEmbedder(
// WithEmbedderProvider("openai"),
// SetEmbedderProvider("openai"),
// SetOption("timeout", 30*time.Second),
// SetOption("max_retries", 3),
// )
Expand All @@ -125,9 +125,9 @@ type Embedder = providers.Embedder
// Example:
//
// embedder, err := NewEmbedder(
// WithEmbedderProvider("openai"),
// WithEmbedderModel("text-embedding-ada-002"),
// WithEmbedderAPIKey(os.Getenv("OPENAI_API_KEY")),
// SetEmbedderProvider("openai"),
// SetEmbedderModel("text-embedding-ada-002"),
// SetEmbedderAPIKey(os.Getenv("OPENAI_API_KEY")),
// )
// if err != nil {
// log.Fatal(err)
Expand All @@ -148,7 +148,7 @@ type EmbeddingService struct {
//
// Example:
//
// embedder, _ := NewEmbedder(WithEmbedderProvider("openai"))
// embedder, _ := NewEmbedder(SetEmbedderProvider("openai"))
// service := NewEmbeddingService(embedder)
func NewEmbeddingService(embedder Embedder) *EmbeddingService {
return &EmbeddingService{
Expand Down
3 changes: 1 addition & 2 deletions examples/concurrent_loader_example.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ func main() {

// Create a new ConcurrentPDFLoader with custom options
loader := raggo.NewConcurrentPDFLoader(
raggo.SetTimeout(1*time.Minute),
raggo.SetLoaderTimeout(1*time.Minute),
raggo.SetTempDir(os.TempDir()),
)

Expand Down Expand Up @@ -51,4 +51,3 @@ func main() {
fmt.Printf("%d. %s\n", i+1, file)
}
}

8 changes: 4 additions & 4 deletions examples/embedding_example.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ func main() {

// Create a new Embedder
embedder, err := raggo.NewEmbedder(
raggo.SetProvider("openai"),
raggo.SetAPIKey(os.Getenv("OPENAI_API_KEY")), // Make sure to set this environment variable
raggo.SetModel("text-embedding-3-small"),
raggo.SetEmbedderProvider("openai"),
raggo.SetEmbedderAPIKey(os.Getenv("OPENAI_API_KEY")), // Make sure to set this environment variable
raggo.SetEmbedderModel("text-embedding-3-small"),
)
if err != nil {
log.Fatalf("Failed to create embedder: %v", err)
Expand Down Expand Up @@ -76,7 +76,7 @@ func main() {
}
fmt.Printf("Embedded Chunk %d:\n", i+1)
fmt.Printf(" Text: %s\n", truncateString(chunk.Text, 50))
fmt.Printf(" Embedding Vector Length: %d\n", len(chunk.Embedding))
fmt.Printf(" Embedding Vector Length: %d\n", len(chunk.Embeddings["default"]))
fmt.Printf(" Metadata: %v\n", chunk.Metadata)
fmt.Println()
}
Expand Down
6 changes: 3 additions & 3 deletions examples/full_process.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,9 @@ func main() {
}

embedder, err := raggo.NewEmbedder(
raggo.WithEmbedderProvider("openai"),
raggo.WithEmbedderAPIKey(os.Getenv("OPENAI_API_KEY")),
raggo.WithEmbedderModel("text-embedding-3-small"),
raggo.SetEmbedderProvider("openai"),
raggo.SetEmbedderAPIKey(os.Getenv("OPENAI_API_KEY")),
raggo.SetEmbedderModel("text-embedding-3-small"),
)
if err != nil {
log.Fatalf("Failed to create embedder: %v", err)
Expand Down
3 changes: 1 addition & 2 deletions examples/loader_example.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ func main() {

// Create a new Loader with custom options
loader := raggo.NewLoader(
raggo.SetTimeout(1*time.Minute),
raggo.SetLoaderTimeout(1*time.Minute),
raggo.SetTempDir(os.TempDir()),
)

Expand Down Expand Up @@ -99,4 +99,3 @@ func dirExample(loader raggo.Loader) {
}
}
}

37 changes: 13 additions & 24 deletions examples/parser_example.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package main

import (
"context"
"fmt"
"log"
"os"
Expand All @@ -14,21 +15,22 @@ func main() {
raggo.SetLogLevel(raggo.LogLevelInfo)

parser := raggo.NewParser()
loader := raggo.NewLoader()

fmt.Println("Running examples with INFO level logging:")
runExamples(parser)
runExamples(parser, loader)

}

func runExamples(parser raggo.Parser) {
func runExamples(parser raggo.Parser, loader raggo.Loader) {
// Example 1: Parse PDF file
pdfExample(parser)

// Example 2: Parse text file
textExample(parser)

// Example 3: Parse directory
dirExample(parser)
dirExample(loader)
}

func pdfExample(parser raggo.Parser) {
Expand Down Expand Up @@ -71,34 +73,21 @@ func textExample(parser raggo.Parser) {
fmt.Printf("Text file parsed. Content length: %d\n", len(doc.Content))
}

func dirExample(parser raggo.Parser) {
fmt.Println("Example 3: Parsing directory")
func dirExample(loader raggo.Loader) {
fmt.Println("Example 3: Loading directory")
wd, err := os.Getwd()
if err != nil {
log.Fatalf("Failed to get working directory: %v", err)
}
testDataDir := filepath.Join(wd, "testdata")

fileCount := 0
err = filepath.Walk(testDataDir, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
if !info.IsDir() {
_, err := parser.Parse(path)
if err != nil {
fmt.Printf("Error parsing %s: %v\n", path, err)
} else {
fileCount++
}
}
return nil
})

paths, err := loader.LoadDir(context.Background(), testDataDir)
if err != nil {
log.Printf("Error walking directory: %v\n", err)
log.Printf("Error loading directory: %v\n", err)
} else {
fmt.Printf("Parsed %d files in directory\n", fileCount)
fmt.Printf("Loaded %d files from directory\n", len(paths))
for i, path := range paths {
fmt.Printf("%d: %s\n", i+1, path)
}
}
}

112 changes: 101 additions & 11 deletions examples/process_embedding_benchmark.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ func main() {
}

embedder, err := raggo.NewEmbedder(
raggo.WithEmbedderProvider("openai"),
raggo.WithEmbedderAPIKey(os.Getenv("OPENAI_API_KEY")),
raggo.WithEmbedderModel("text-embedding-3-small"),
raggo.SetEmbedderProvider("openai"),
raggo.SetEmbedderAPIKey(os.Getenv("OPENAI_API_KEY")),
raggo.SetEmbedderModel("text-embedding-3-small"),
)
if err != nil {
log.Fatalf("Failed to create embedder: %v", err)
Expand All @@ -53,10 +53,78 @@ func main() {
log.Fatalf("Failed to create LLM: %v", err)
}

benchmarkPDFProcessing(parser, chunker, embedder, llm, targetDir)
// Create VectorDB instance
vectorDB, err := raggo.NewVectorDB(
raggo.WithType("milvus"),
raggo.WithAddress("localhost:19530"),
raggo.WithTimeout(30*time.Second),
)
if err != nil {
log.Fatalf("Failed to create vector database: %v", err)
}
defer vectorDB.Close()

// Connect to the database
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()

if err := vectorDB.Connect(ctx); err != nil {
log.Fatalf("Failed to connect to vector database: %v", err)
}

collectionName := "benchmark_docs"

// Create collection with schema if it doesn't exist
exists, err := vectorDB.HasCollection(ctx, collectionName)
if err != nil {
log.Fatalf("Failed to check collection existence: %v", err)
}

if exists {
err = vectorDB.DropCollection(ctx, collectionName)
if err != nil {
log.Fatalf("Failed to drop existing collection: %v", err)
}
}

schema := raggo.Schema{
Name: collectionName,
Fields: []raggo.Field{
{Name: "ID", DataType: "int64", PrimaryKey: true, AutoID: false},
{Name: "Embedding", DataType: "float_vector", Dimension: 1536},
{Name: "Text", DataType: "varchar", MaxLength: 65535},
{Name: "Metadata", DataType: "json", MaxLength: 65535},
},
}

err = vectorDB.CreateCollection(ctx, collectionName, schema)
if err != nil {
log.Fatalf("Failed to create collection: %v", err)
}

// Create index for vector search
err = vectorDB.CreateIndex(ctx, collectionName, "Embedding", raggo.Index{
Type: "HNSW",
Metric: "L2",
Parameters: map[string]interface{}{
"M": 16,
"efConstruction": 256,
},
})
if err != nil {
log.Fatalf("Failed to create index: %v", err)
}

// Load the collection
err = vectorDB.LoadCollection(ctx, collectionName)
if err != nil {
log.Fatalf("Failed to load collection: %v", err)
}

benchmarkPDFProcessing(parser, chunker, embedder, llm, vectorDB, targetDir, collectionName)
}

func benchmarkPDFProcessing(parser raggo.Parser, chunker raggo.Chunker, embedder raggo.Embedder, llm gollm.LLM, targetDir string) {
func benchmarkPDFProcessing(parser raggo.Parser, chunker raggo.Chunker, embedder raggo.Embedder, llm gollm.LLM, vectorDB *raggo.VectorDB, targetDir, collectionName string) {
files, err := filepath.Glob(filepath.Join(targetDir, "*.pdf"))
if err != nil {
log.Fatalf("Failed to list PDF files: %v", err)
Expand All @@ -82,7 +150,7 @@ func benchmarkPDFProcessing(parser raggo.Parser, chunker raggo.Chunker, embedder
wg.Add(1)
go func(filePath string) {
defer wg.Done()
tokens, embeds, summaries, err := processAndEmbedPDF(parser, chunker, embedder, llm, filePath)
tokens, embeds, summaries, err := processAndEmbedPDF(parser, chunker, embedder, llm, vectorDB, filePath, collectionName)
mu.Lock()
defer mu.Unlock()
if err != nil {
Expand Down Expand Up @@ -116,7 +184,7 @@ func benchmarkPDFProcessing(parser raggo.Parser, chunker raggo.Chunker, embedder
fmt.Printf("Average summaries per second: %.2f\n", float64(summaryCount)/duration.Seconds())
}

func processAndEmbedPDF(parser raggo.Parser, chunker raggo.Chunker, embedder raggo.Embedder, llm gollm.LLM, filePath string) (int, int, int, error) {
func processAndEmbedPDF(parser raggo.Parser, chunker raggo.Chunker, embedder raggo.Embedder, llm gollm.LLM, vectorDB *raggo.VectorDB, filePath, collectionName string) (int, int, int, error) {
log.Printf("Processing file: %s", filePath)

doc, err := parser.Parse(filePath)
Expand All @@ -138,19 +206,41 @@ func processAndEmbedPDF(parser raggo.Parser, chunker raggo.Chunker, embedder rag

log.Printf("Generated summary for %s, Summary length: %d", filePath, len(summary))

chunks := chunker.Chunk(summary)

chunks := chunker.Chunk(doc.Content)
totalTokens := 0
embedCount := 0

for _, chunk := range chunks {
// Create records for batch insertion
var records []raggo.Record
for i, chunk := range chunks {
totalTokens += len(chunk.Text) // Simple approximation

_, err := embedder.Embed(context.Context(context.Background()), chunk.Text)
embedding, err := embedder.Embed(context.Background(), chunk.Text)
if err != nil {
return totalTokens, embedCount, 1, fmt.Errorf("error embedding chunk: %w", err)
}
embedCount++

// Create record with metadata
records = append(records, raggo.Record{
Fields: map[string]interface{}{
"ID": int64(i),
"Embedding": embedding,
"Text": chunk.Text,
"Metadata": map[string]interface{}{
"source": filePath,
"chunk": i,
"summary": summary,
"timestamp": time.Now().Unix(),
},
},
})
}

// Batch insert records
err = vectorDB.Insert(context.Background(), collectionName, records)
if err != nil {
return totalTokens, embedCount, 1, fmt.Errorf("error inserting into vector database: %w", err)
}

log.Printf("Successfully processed %s: %d tokens, %d embeddings", filePath, totalTokens, embedCount)
Expand Down
Loading

0 comments on commit 29c9be3

Please sign in to comment.