From e30bc1023f2c0612643eccfea55c47952a217699 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Te=C3=AFlo=20M?= Date: Fri, 20 Dec 2024 16:12:20 +0100 Subject: [PATCH] fix: circuit breaker recovery and error handling - Fixed circuit breaker reset timeout configuration - Added circuit breaker config to Config struct - Updated version to v0.0.11 --- Dockerfile | 2 +- README.md | 2 +- ROADMAP.md | 2 +- config.example.yaml | 31 ++- config/config.go | 15 ++ server/circuitbreaker/circuitbreaker.go | 173 +++++++++++++ server/circuitbreaker/errors.go | 8 + server/provider/manager.go | 308 ++++++++++++++++++++++++ server/provider/provider.go | 182 -------------- tests/circuitbreaker_test.go | 250 +++++++++++++++++++ 10 files changed, 782 insertions(+), 191 deletions(-) create mode 100644 server/circuitbreaker/circuitbreaker.go create mode 100644 server/circuitbreaker/errors.go create mode 100644 server/provider/manager.go delete mode 100644 server/provider/provider.go create mode 100644 tests/circuitbreaker_test.go diff --git a/Dockerfile b/Dockerfile index 3d04b06..010ff11 100644 --- a/Dockerfile +++ b/Dockerfile @@ -35,7 +35,7 @@ WORKDIR /app COPY --from=builder /app/hapax . # Copy default config file -COPY config.yaml ./config.yaml +COPY config.example.yaml ./config.yaml # Use non-root user USER hapax diff --git a/README.md b/README.md index c805283..a9ff305 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@ A lightweight HTTP server for Large Language Model (LLM) interactions, built with Go. ## Version -v0.0.9 +v0.0.11 ## Features diff --git a/ROADMAP.md b/ROADMAP.md index ee8abae..10569ac 100644 --- a/ROADMAP.md +++ b/ROADMAP.md @@ -60,7 +60,7 @@ Focus: Enhance reliability, scalability, and deployability for production enviro - Latency tracking - Error monitoring - Resource utilization -- [ ] Docker support +- [x] Docker support - Multi-stage build optimization - Production-ready Dockerfile - Docker Compose configuration diff --git a/config.example.yaml b/config.example.yaml index 94a7638..42c2ea2 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -2,15 +2,34 @@ server: port: 8080 read_timeout: 30s write_timeout: 30s + max_header_bytes: 1048576 + shutdown_timeout: 5s -llm: - provider: "ollama" # Change to your preferred provider (ollama, openai, etc.) - model: "llama2" # Change to your preferred model - endpoint: "" # Set your provider's endpoint if needed +providers: + openai: + type: openai + model: gpt-4 + api_key: ${OPENAI_API_KEY} + anthropic: + type: anthropic + model: claude-2 + api_key: ${ANTHROPIC_API_KEY} + ollama: + type: ollama + model: llama2 + api_key: "" + +# Order of provider preference for failover +provider_preference: + - openai + - anthropic + - ollama logging: - level: "info" - format: "json" + level: info + format: json metrics: enabled: true + prometheus: + enabled: true diff --git a/config/config.go b/config/config.go index 43d0bc2..d99cfca 100644 --- a/config/config.go +++ b/config/config.go @@ -21,6 +21,9 @@ type Config struct { LLM LLMConfig `yaml:"llm"` Logging LoggingConfig `yaml:"logging"` Routes []RouteConfig `yaml:"routes"` + Providers map[string]ProviderConfig `yaml:"providers"` + ProviderPreference []string `yaml:"provider_preference"` // Order of provider preference + CircuitBreaker CircuitBreakerConfig `yaml:"circuit_breaker"` } // ServerConfig holds server-specific configuration for the HTTP server. @@ -166,6 +169,13 @@ type RetryConfig struct { RetryableErrors []string `yaml:"retryable_errors"` } +// ProviderConfig holds configuration for an LLM provider +type ProviderConfig struct { + Type string `yaml:"type"` // Provider type (e.g., openai, anthropic) + Model string `yaml:"model"` // Model name + APIKey string `yaml:"api_key"` // API key for authentication +} + // LoggingConfig holds logging-specific configuration. type LoggingConfig struct { // Level sets logging verbosity: debug, info, warn, error @@ -217,6 +227,11 @@ type HealthCheck struct { Checks map[string]string `yaml:"checks"` } +// CircuitBreakerConfig holds circuit breaker settings +type CircuitBreakerConfig struct { + ResetTimeout time.Duration `yaml:"reset_timeout"` +} + // DefaultConfig returns a configuration with sensible defaults func DefaultConfig() *Config { return &Config{ diff --git a/server/circuitbreaker/circuitbreaker.go b/server/circuitbreaker/circuitbreaker.go new file mode 100644 index 0000000..a12b68f --- /dev/null +++ b/server/circuitbreaker/circuitbreaker.go @@ -0,0 +1,173 @@ +package circuitbreaker + +import ( + "sync" + "time" + + "github.com/prometheus/client_golang/prometheus" + "go.uber.org/zap" +) + +// State represents the current state of the circuit breaker +type State int + +const ( + StateClosed State = iota // Circuit is closed (allowing requests) + StateOpen // Circuit is open (blocking requests) + StateHalfOpen // Circuit is half-open (testing if service is healthy) +) + +// Config holds configuration for the circuit breaker +type Config struct { + FailureThreshold int // Number of failures before opening circuit + ResetTimeout time.Duration // Time to wait before attempting reset + HalfOpenRequests int // Number of requests to allow in half-open state +} + +// CircuitBreaker implements the circuit breaker pattern +type CircuitBreaker struct { + name string + config Config + state State + failures int + lastFailure time.Time + halfOpen int + mu sync.RWMutex + logger *zap.Logger + + // Metrics + stateGauge prometheus.Gauge + failuresCount prometheus.Counter + tripsTotal prometheus.Counter +} + +// NewCircuitBreaker creates a new circuit breaker +func NewCircuitBreaker(name string, config Config, logger *zap.Logger, registry *prometheus.Registry) *CircuitBreaker { + cb := &CircuitBreaker{ + name: name, + config: config, + state: StateClosed, + logger: logger, + } + + // Initialize Prometheus metrics + cb.stateGauge = prometheus.NewGauge(prometheus.GaugeOpts{ + Name: "hapax_circuit_breaker_state", + Help: "Current state of the circuit breaker (0=closed, 1=open, 2=half-open)", + ConstLabels: prometheus.Labels{ + "name": name, + }, + }) + + cb.failuresCount = prometheus.NewCounter(prometheus.CounterOpts{ + Name: "hapax_circuit_breaker_failures_total", + Help: "Total number of failures recorded by the circuit breaker", + ConstLabels: prometheus.Labels{ + "name": name, + }, + }) + + cb.tripsTotal = prometheus.NewCounter(prometheus.CounterOpts{ + Name: "hapax_circuit_breaker_trips_total", + Help: "Total number of times the circuit breaker has tripped", + ConstLabels: prometheus.Labels{ + "name": name, + }, + }) + + // Register metrics with Prometheus + registry.MustRegister(cb.stateGauge) + registry.MustRegister(cb.failuresCount) + registry.MustRegister(cb.tripsTotal) + + return cb +} + +// Execute runs the given function if the circuit breaker allows it +func (cb *CircuitBreaker) Execute(f func() error) error { + if !cb.AllowRequest() { + return ErrCircuitOpen + } + + err := f() + cb.RecordResult(err) + return err +} + +// AllowRequest checks if a request should be allowed through +func (cb *CircuitBreaker) AllowRequest() bool { + cb.mu.Lock() + defer cb.mu.Unlock() + + switch cb.state { + case StateClosed: + return true + case StateOpen: + // Check if enough time has passed to try half-open + if time.Since(cb.lastFailure) > cb.config.ResetTimeout { + cb.setState(StateHalfOpen) + cb.halfOpen = 0 + return true + } + return false + case StateHalfOpen: + // Allow one request in half-open state + if cb.halfOpen < cb.config.HalfOpenRequests { + cb.halfOpen++ + return true + } + return false + default: + return false + } +} + +// RecordResult records the result of a request +func (cb *CircuitBreaker) RecordResult(err error) { + cb.mu.Lock() + defer cb.mu.Unlock() + + if err != nil { + cb.failures++ + cb.failuresCount.Inc() + cb.lastFailure = time.Now() + + // Trip breaker if failure threshold reached + if cb.failures >= cb.config.FailureThreshold { + cb.tripBreaker() + } + } else { + // Reset on success + if cb.state == StateHalfOpen { + cb.setState(StateClosed) + cb.failures = 0 + cb.halfOpen = 0 + } else if cb.state == StateClosed { + cb.failures = 0 + } + } +} + +// tripBreaker moves the circuit breaker to the open state +func (cb *CircuitBreaker) tripBreaker() { + cb.setState(StateOpen) + cb.tripsTotal.Inc() + cb.logger.Warn("Circuit breaker tripped", + zap.String("name", cb.name), + zap.Int("failures", cb.failures), + zap.Time("last_failure", cb.lastFailure), + ) +} + +// setState updates the circuit breaker state and metrics +func (cb *CircuitBreaker) setState(state State) { + cb.state = state + cb.stateGauge.Set(float64(state)) +} + +// GetState returns the current state of the circuit breaker +func (cb *CircuitBreaker) GetState() State { + cb.mu.RLock() + defer cb.mu.RUnlock() + return cb.state +} diff --git a/server/circuitbreaker/errors.go b/server/circuitbreaker/errors.go new file mode 100644 index 0000000..396451d --- /dev/null +++ b/server/circuitbreaker/errors.go @@ -0,0 +1,8 @@ +package circuitbreaker + +import "errors" + +var ( + // ErrCircuitOpen is returned when the circuit breaker is open + ErrCircuitOpen = errors.New("circuit breaker is open") +) diff --git a/server/provider/manager.go b/server/provider/manager.go new file mode 100644 index 0000000..99b80e3 --- /dev/null +++ b/server/provider/manager.go @@ -0,0 +1,308 @@ +package provider + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/teilomillet/hapax/config" + "github.com/teilomillet/hapax/server/circuitbreaker" + "github.com/prometheus/client_golang/prometheus" + "github.com/teilomillet/gollm" + "go.uber.org/zap" +) + +// HealthStatus represents the current health state of a provider +type HealthStatus struct { + Healthy bool + LastCheck time.Time + ConsecutiveFails int + Latency time.Duration + ErrorCount int64 + RequestCount int64 +} + +// Manager handles LLM provider management and selection +type Manager struct { + providers map[string]gollm.LLM + breakers map[string]*circuitbreaker.CircuitBreaker + healthStates sync.Map // map[string]HealthStatus + logger *zap.Logger + cfg *config.Config + mu sync.RWMutex + + // Metrics + healthCheckDuration prometheus.Histogram + healthCheckErrors *prometheus.CounterVec + requestLatency *prometheus.HistogramVec +} + +// NewManager creates a new provider manager +func NewManager(cfg *config.Config, logger *zap.Logger, registry *prometheus.Registry) (*Manager, error) { + m := &Manager{ + providers: make(map[string]gollm.LLM), + breakers: make(map[string]*circuitbreaker.CircuitBreaker), + logger: logger, + cfg: cfg, + } + + // Initialize metrics + m.initializeMetrics(registry) + + // Create circuit breaker config + cbConfig := circuitbreaker.Config{ + FailureThreshold: 3, // Trip after 3 failures + ResetTimeout: cfg.CircuitBreaker.ResetTimeout, // Use configured timeout + HalfOpenRequests: 1, // Allow 1 request in half-open state + } + + if cbConfig.ResetTimeout == 0 { + cbConfig.ResetTimeout = time.Minute // Default to 1 minute + } + + // Initialize providers from both new and legacy configs + if err := m.initializeProviders(cbConfig, registry); err != nil { + return nil, err + } + + // Start health checks if enabled + if cfg.LLM.HealthCheck != nil && cfg.LLM.HealthCheck.Enabled { + go m.startHealthChecks(context.Background()) + } + + return m, nil +} + +// initializeMetrics sets up Prometheus metrics +func (m *Manager) initializeMetrics(registry *prometheus.Registry) { + m.healthCheckDuration = prometheus.NewHistogram(prometheus.HistogramOpts{ + Name: "hapax_provider_health_check_duration_seconds", + Help: "Duration of provider health checks", + }) + + m.healthCheckErrors = prometheus.NewCounterVec(prometheus.CounterOpts{ + Name: "hapax_provider_health_check_errors_total", + Help: "Number of health check errors by provider", + }, []string{"provider"}) + + m.requestLatency = prometheus.NewHistogramVec(prometheus.HistogramOpts{ + Name: "hapax_provider_request_latency_seconds", + Help: "Latency of provider requests", + }, []string{"provider"}) + + registry.MustRegister(m.healthCheckDuration) + registry.MustRegister(m.healthCheckErrors) + registry.MustRegister(m.requestLatency) +} + +// initializeProviders sets up LLM providers based on configuration +func (m *Manager) initializeProviders(cbConfig circuitbreaker.Config, registry *prometheus.Registry) error { + // Initialize from new provider config + for name, providerCfg := range m.cfg.Providers { + llm, err := gollm.NewLLM( + gollm.SetProvider(providerCfg.Type), + gollm.SetModel(providerCfg.Model), + gollm.SetAPIKey(providerCfg.APIKey), + ) + if err != nil { + return fmt.Errorf("failed to initialize provider %s: %w", name, err) + } + + m.providers[name] = llm + m.breakers[name] = circuitbreaker.NewCircuitBreaker( + name, + cbConfig, + m.logger.With(zap.String("provider", name)), + registry, + ) + } + + // Initialize from legacy config if no new providers configured + if len(m.providers) == 0 && m.cfg.LLM.Provider != "" { + primary, err := gollm.NewLLM( + gollm.SetProvider(m.cfg.LLM.Provider), + gollm.SetModel(m.cfg.LLM.Model), + ) + if err != nil { + return fmt.Errorf("failed to initialize legacy provider: %w", err) + } + + name := m.cfg.LLM.Provider + m.providers[name] = primary + m.breakers[name] = circuitbreaker.NewCircuitBreaker( + name, + cbConfig, + m.logger.With(zap.String("provider", name)), + registry, + ) + + // Initialize legacy backup providers + for _, backup := range m.cfg.LLM.BackupProviders { + llm, err := gollm.NewLLM( + gollm.SetProvider(backup.Provider), + gollm.SetModel(backup.Model), + ) + if err != nil { + m.logger.Warn("Failed to initialize backup provider", + zap.String("provider", backup.Provider), + zap.Error(err)) + continue + } + + m.providers[backup.Provider] = llm + m.breakers[backup.Provider] = circuitbreaker.NewCircuitBreaker( + backup.Provider, + cbConfig, + m.logger.With(zap.String("provider", backup.Provider)), + registry, + ) + } + } + + return nil +} + +// startHealthChecks begins monitoring all providers +func (m *Manager) startHealthChecks(ctx context.Context) { + interval := time.Minute + if m.cfg.LLM.HealthCheck != nil { + interval = m.cfg.LLM.HealthCheck.Interval + } + + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + m.checkAllProviders() + } + } +} + +// checkAllProviders performs health checks on all providers +func (m *Manager) checkAllProviders() { + for name, provider := range m.providers { + status := m.checkProviderHealth(name, provider) + m.updateHealthStatus(name, status) + } +} + +// checkProviderHealth performs a health check on a provider +func (m *Manager) checkProviderHealth(name string, llm gollm.LLM) HealthStatus { + start := time.Now() + status := HealthStatus{ + LastCheck: start, + Healthy: true, + } + + // Simple health check prompt + prompt := &gollm.Prompt{ + Messages: []gollm.PromptMessage{ + {Role: "user", Content: "health check"}, + }, + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + _, err := llm.Generate(ctx, prompt) + status.Latency = time.Since(start) + m.healthCheckDuration.Observe(status.Latency.Seconds()) + + if err != nil { + status.Healthy = false + status.ErrorCount++ + m.healthCheckErrors.WithLabelValues(name).Inc() + m.logger.Warn("Provider health check failed", + zap.String("provider", name), + zap.Error(err), + zap.Duration("latency", status.Latency), + ) + } + + return status +} + +// updateHealthStatus updates the health status for a provider +func (m *Manager) updateHealthStatus(name string, status HealthStatus) { + m.healthStates.Store(name, status) +} + +// GetProvider returns a healthy provider or error if none available +func (m *Manager) GetProvider() (gollm.LLM, error) { + m.mu.RLock() + defer m.mu.RUnlock() + + // Try providers in order of preference + for _, name := range m.cfg.ProviderPreference { + provider, ok := m.providers[name] + if !ok { + continue + } + + breaker := m.breakers[name] + if breaker.AllowRequest() { + return provider, nil + } + } + + return nil, ErrNoHealthyProvider +} + +// Execute runs an LLM operation with circuit breaker protection +func (m *Manager) Execute(ctx context.Context, op func(gollm.LLM) error) error { + provider, err := m.GetProvider() + if err != nil { + return fmt.Errorf("failed to get provider: %w", err) + } + + name := m.getProviderName(provider) + breaker := m.breakers[name] + + start := time.Now() + err = breaker.Execute(func() error { + return op(provider) + }) + m.requestLatency.WithLabelValues(name).Observe(time.Since(start).Seconds()) + + // If all providers are failing, wrap with ErrNoHealthyProvider + if err != nil { + allFailing := true + for _, b := range m.breakers { + if b.GetState() != circuitbreaker.StateOpen { + allFailing = false + break + } + } + if allFailing { + return fmt.Errorf("%w: %v", ErrNoHealthyProvider, err) + } + } + + return err +} + +// getProviderName returns the name of a provider instance +func (m *Manager) getProviderName(provider gollm.LLM) string { + m.mu.RLock() + defer m.mu.RUnlock() + + for name, p := range m.providers { + if p == provider { + return name + } + } + return "unknown" +} + +// SetProviders replaces the current providers with new ones (for testing) +func (m *Manager) SetProviders(providers map[string]gollm.LLM) { + m.mu.Lock() + defer m.mu.Unlock() + + m.providers = providers +} diff --git a/server/provider/provider.go b/server/provider/provider.go deleted file mode 100644 index bc5a525..0000000 --- a/server/provider/provider.go +++ /dev/null @@ -1,182 +0,0 @@ -// Package provider implements LLM provider management functionality. -package provider - -import ( - "context" - "sync" - "time" - - "github.com/teilomillet/gollm" - "github.com/teilomillet/hapax/config" - "go.uber.org/zap" -) - -// HealthStatus represents the current health state of a provider -type HealthStatus struct { - Healthy bool - LastCheck time.Time - ConsecutiveFails int - Latency time.Duration - ErrorCount int64 - RequestCount int64 -} - -// Manager handles LLM provider management, including: -// - Health monitoring -// - Failover handling -// - Provider configuration -type Manager struct { - providers map[string]gollm.LLM - healthStates sync.Map - logger *zap.Logger - cfg *config.Config - mu sync.RWMutex -} - -// NewManager creates a new provider manager instance -func NewManager(cfg *config.Config, logger *zap.Logger) (*Manager, error) { - m := &Manager{ - providers: make(map[string]gollm.LLM), - logger: logger, - cfg: cfg, - } - - // Skip provider initialization if health check is disabled - if cfg.LLM.HealthCheck == nil || !cfg.LLM.HealthCheck.Enabled { - return m, nil - } - - // Initialize providers from config - if err := m.initializeProviders(); err != nil { - return nil, err - } - - return m, nil -} - -// NewManagerWithProviders creates a manager with pre-configured providers (for testing) -func NewManagerWithProviders(cfg *config.Config, logger *zap.Logger, providers map[string]gollm.LLM) *Manager { - return &Manager{ - providers: providers, - logger: logger, - cfg: cfg, - } -} - -// initializeProviders sets up LLM providers based on configuration -func (m *Manager) initializeProviders() error { - // Initialize primary provider - primary, err := gollm.NewLLM( - gollm.SetProvider(m.cfg.LLM.Provider), - gollm.SetModel(m.cfg.LLM.Model), - ) - if err != nil { - return err - } - m.providers[m.cfg.LLM.Provider] = primary - - // Initialize backup providers if configured - for _, backup := range m.cfg.LLM.BackupProviders { - llm, err := gollm.NewLLM( - gollm.SetProvider(backup.Provider), - gollm.SetModel(backup.Model), - ) - if err != nil { - m.logger.Warn("Failed to initialize backup provider", - zap.String("provider", backup.Provider), - zap.Error(err)) - continue - } - m.providers[backup.Provider] = llm - } - - return nil -} - -// StartHealthChecks begins monitoring all providers -func (m *Manager) StartHealthChecks(ctx context.Context) { - for providerName, llm := range m.providers { - go m.monitorProvider(ctx, providerName, llm) - } -} - -// monitorProvider continuously monitors a single provider's health -func (m *Manager) monitorProvider(ctx context.Context, providerName string, llm gollm.LLM) { - ticker := time.NewTicker(m.cfg.LLM.HealthCheck.Interval) - defer ticker.Stop() - - for { - select { - case <-ctx.Done(): - return - case <-ticker.C: - status := m.checkProviderHealth(providerName, llm) - m.updateHealthStatus(providerName, status) - } - } -} - -// checkProviderHealth performs a health check on a provider -func (m *Manager) checkProviderHealth(providerName string, llm gollm.LLM) HealthStatus { - start := time.Now() - - // Simple health check prompt - prompt := llm.NewPrompt("health check") - llm.SetSystemPrompt("Respond with 'ok' for health check.", gollm.CacheTypeEphemeral) - - ctx, cancel := context.WithTimeout(context.Background(), m.cfg.LLM.HealthCheck.Timeout) - defer cancel() - - _, err := llm.Generate(ctx, prompt) - - status := HealthStatus{ - LastCheck: time.Now(), - Latency: time.Since(start), - } - - if err != nil { - m.logger.Warn("Provider health check failed", - zap.String("provider", providerName), - zap.Error(err)) - status.Healthy = false - status.ConsecutiveFails++ - } else { - status.Healthy = true - status.ConsecutiveFails = 0 - } - - return status -} - -// updateHealthStatus updates the health status for a provider -func (m *Manager) updateHealthStatus(providerName string, status HealthStatus) { - m.healthStates.Store(providerName, status) -} - -// GetHealthyProvider returns a healthy provider, implementing failover if needed -func (m *Manager) GetHealthyProvider() (gollm.LLM, error) { - m.mu.RLock() - defer m.mu.RUnlock() - - // Try primary provider first - if status, ok := m.getProviderStatus(m.cfg.LLM.Provider); ok && status.Healthy { - return m.providers[m.cfg.LLM.Provider], nil - } - - // Try backup providers in order - for _, backup := range m.cfg.LLM.BackupProviders { - if status, ok := m.getProviderStatus(backup.Provider); ok && status.Healthy { - return m.providers[backup.Provider], nil - } - } - - return nil, ErrNoHealthyProvider -} - -// getProviderStatus retrieves the current health status for a provider -func (m *Manager) getProviderStatus(provider string) (HealthStatus, bool) { - if status, ok := m.healthStates.Load(provider); ok { - return status.(HealthStatus), true - } - return HealthStatus{}, false -} diff --git a/tests/circuitbreaker_test.go b/tests/circuitbreaker_test.go new file mode 100644 index 0000000..d645a28 --- /dev/null +++ b/tests/circuitbreaker_test.go @@ -0,0 +1,250 @@ +package tests + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/prometheus/client_golang/prometheus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/teilomillet/hapax/config" + "github.com/teilomillet/hapax/server/circuitbreaker" + "github.com/teilomillet/hapax/server/mocks" + "github.com/teilomillet/hapax/server/provider" + "github.com/teilomillet/gollm" + "go.uber.org/zap" +) + +func TestCircuitBreaker(t *testing.T) { + logger, _ := zap.NewDevelopment() + registry := prometheus.NewRegistry() + + config := circuitbreaker.Config{ + FailureThreshold: 2, // Trip after 2 failures + ResetTimeout: time.Second, // Short timeout for testing + HalfOpenRequests: 1, + } + + cb := circuitbreaker.NewCircuitBreaker("test", config, logger, registry) + + t.Run("Initially Closed", func(t *testing.T) { + assert.True(t, cb.AllowRequest()) + assert.Equal(t, circuitbreaker.StateClosed, cb.GetState()) + }) + + t.Run("Opens After Failures", func(t *testing.T) { + // Record two failures + err := cb.Execute(func() error { + return errors.New("error 1") + }) + assert.Error(t, err) + assert.True(t, cb.AllowRequest(), "Should still allow requests after one failure") + + err = cb.Execute(func() error { + return errors.New("error 2") + }) + assert.Error(t, err) + assert.False(t, cb.AllowRequest(), "Should not allow requests after threshold reached") + assert.Equal(t, circuitbreaker.StateOpen, cb.GetState()) + }) + + t.Run("Transitions to Half-Open", func(t *testing.T) { + // Wait for reset timeout + time.Sleep(2 * time.Second) + + assert.True(t, cb.AllowRequest(), "Should allow one request in half-open state") + assert.Equal(t, circuitbreaker.StateHalfOpen, cb.GetState()) + }) + + t.Run("Closes After Success", func(t *testing.T) { + err := cb.Execute(func() error { + return nil + }) + assert.NoError(t, err) + assert.True(t, cb.AllowRequest()) + assert.Equal(t, circuitbreaker.StateClosed, cb.GetState()) + }) +} + +func TestProviderManagerWithCircuitBreaker(t *testing.T) { + logger, _ := zap.NewDevelopment() + registry := prometheus.NewRegistry() + + // Create test configuration + cfg := &config.Config{} + cfg.Providers = map[string]config.ProviderConfig{ + "primary": { + Type: "openai", + Model: "gpt-3.5-turbo", + }, + "backup": { + Type: "openai", + Model: "gpt-3.5-turbo", + }, + } + cfg.ProviderPreference = []string{"primary", "backup"} + cfg.CircuitBreaker = config.CircuitBreakerConfig{ + ResetTimeout: 1 * time.Second, // Short timeout for testing + } + + var primaryCallCount, backupCallCount int + + // Create mock providers + primaryProvider := mocks.NewMockLLMWithConfig("primary", "gpt-3.5-turbo", func(ctx context.Context, prompt *gollm.Prompt) (string, error) { + primaryCallCount++ + return "primary response", nil + }) + + backupProvider := mocks.NewMockLLMWithConfig("backup", "gpt-3.5-turbo", func(ctx context.Context, prompt *gollm.Prompt) (string, error) { + backupCallCount++ + return "backup response", nil + }) + + providers := map[string]gollm.LLM{ + "primary": primaryProvider, + "backup": backupProvider, + } + + manager, err := provider.NewManager(cfg, logger, registry) + require.NoError(t, err) + require.NotNil(t, manager) + + // Replace providers with mocks + manager.SetProviders(providers) + + t.Run("Uses Primary Provider", func(t *testing.T) { + err := manager.Execute(context.Background(), func(llm gollm.LLM) error { + assert.Equal(t, primaryProvider, llm) + prompt := &gollm.Prompt{ + Messages: []gollm.PromptMessage{ + {Role: "user", Content: "test"}, + }, + } + _, err := llm.Generate(context.Background(), prompt) + return err + }) + require.NoError(t, err) + assert.Equal(t, 1, primaryCallCount) + assert.Equal(t, 0, backupCallCount) + }) + + t.Run("Fails Over to Backup", func(t *testing.T) { + // Make primary provider fail + primaryProvider.GenerateFunc = func(ctx context.Context, prompt *gollm.Prompt) (string, error) { + primaryCallCount++ + return "", errors.New("mock error") + } + + // Execute enough requests to trip the circuit breaker + for i := 0; i < 3; i++ { + manager.Execute(context.Background(), func(llm gollm.LLM) error { + prompt := &gollm.Prompt{ + Messages: []gollm.PromptMessage{ + {Role: "user", Content: "test"}, + }, + } + _, err := llm.Generate(context.Background(), prompt) + return err + }) + } + + // Next request should use backup provider + err := manager.Execute(context.Background(), func(llm gollm.LLM) error { + assert.Equal(t, backupProvider, llm) + prompt := &gollm.Prompt{ + Messages: []gollm.PromptMessage{ + {Role: "user", Content: "test"}, + }, + } + _, err := llm.Generate(context.Background(), prompt) + return err + }) + require.NoError(t, err) + assert.True(t, backupCallCount > 0) + }) + + t.Run("Recovers Primary Provider", func(t *testing.T) { + // Wait for circuit breaker timeout + time.Sleep(2 * time.Second) + + // Fix primary provider + primaryProvider.GenerateFunc = func(ctx context.Context, prompt *gollm.Prompt) (string, error) { + primaryCallCount++ + return "primary response", nil + } + + // Should eventually switch back to primary + err := manager.Execute(context.Background(), func(llm gollm.LLM) error { + assert.Equal(t, primaryProvider, llm) + prompt := &gollm.Prompt{ + Messages: []gollm.PromptMessage{ + {Role: "user", Content: "test"}, + }, + } + _, err := llm.Generate(context.Background(), prompt) + return err + }) + require.NoError(t, err) + }) +} + +func TestProviderManagerAllProvidersFailing(t *testing.T) { + logger, _ := zap.NewDevelopment() + registry := prometheus.NewRegistry() + + cfg := &config.Config{} + cfg.Providers = map[string]config.ProviderConfig{ + "primary": {Type: "openai", Model: "gpt-3.5-turbo"}, + "backup": {Type: "openai", Model: "gpt-3.5-turbo"}, + } + cfg.ProviderPreference = []string{"primary", "backup"} + + // Create failing providers + primaryProvider := mocks.NewMockLLMWithConfig("primary", "gpt-3.5-turbo", func(ctx context.Context, prompt *gollm.Prompt) (string, error) { + return "", errors.New("mock error") + }) + + backupProvider := mocks.NewMockLLMWithConfig("backup", "gpt-3.5-turbo", func(ctx context.Context, prompt *gollm.Prompt) (string, error) { + return "", errors.New("mock error") + }) + + providers := map[string]gollm.LLM{ + "primary": primaryProvider, + "backup": backupProvider, + } + + manager, err := provider.NewManager(cfg, logger, registry) + require.NoError(t, err) + require.NotNil(t, manager) + + // Replace providers with mocks + manager.SetProviders(providers) + + // Trip circuit breakers for both providers + for i := 0; i < 5; i++ { + err := manager.Execute(context.Background(), func(llm gollm.LLM) error { + prompt := &gollm.Prompt{ + Messages: []gollm.PromptMessage{ + {Role: "user", Content: "test"}, + }, + } + _, err := llm.Generate(context.Background(), prompt) + return err + }) + assert.Error(t, err) + } + + // Should return ErrNoHealthyProvider + err = manager.Execute(context.Background(), func(llm gollm.LLM) error { + prompt := &gollm.Prompt{ + Messages: []gollm.PromptMessage{ + {Role: "user", Content: "test"}, + }, + } + _, err := llm.Generate(context.Background(), prompt) + return err + }) + assert.ErrorIs(t, err, provider.ErrNoHealthyProvider) +}