Skip to content

Commit

Permalink
Code: add tracking requests by source IP
Browse files Browse the repository at this point in the history
  • Loading branch information
trader-payne committed Oct 25, 2024
1 parent 20eec1a commit e3c8caa
Showing 1 changed file with 50 additions and 8 deletions.
58 changes: 50 additions & 8 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,14 +91,15 @@ type LoadBalancer struct {
ProxyCacheMutex sync.RWMutex

// Prometheus metrics
latencyGauge *prometheus.GaugeVec
chainheadGauge *prometheus.GaugeVec
blocksBehindGauge *prometheus.GaugeVec
loadGauge *prometheus.GaugeVec
errorCounter *prometheus.CounterVec
bestEndpointGauge *prometheus.GaugeVec
requestCounter *prometheus.CounterVec
requestDuration *prometheus.HistogramVec
latencyGauge *prometheus.GaugeVec
chainheadGauge *prometheus.GaugeVec
blocksBehindGauge *prometheus.GaugeVec
loadGauge *prometheus.GaugeVec
errorCounter *prometheus.CounterVec
bestEndpointGauge *prometheus.GaugeVec
requestCounter *prometheus.CounterVec
requestDuration *prometheus.HistogramVec
requestByIPCounter *prometheus.CounterVec // New metric

// Custom Prometheus registry
promRegistry *prometheus.Registry
Expand All @@ -109,6 +110,7 @@ func logMessage(globalLevel string, level string, format string, args ...interfa
"ERROR": 0,
"INFO": 1,
"DEBUG": 2,
"TRACE": 3,
}

if levels[level] <= levels[globalLevel] {
Expand Down Expand Up @@ -237,6 +239,15 @@ func (lb *LoadBalancer) initMetrics() {
[]string{"network"},
)

// New metric for requests by client IP
lb.requestByIPCounter = prometheus.NewCounterVec(
prometheus.CounterOpts{
Name: "loadbalancer_requests_by_ip_total",
Help: "Total number of requests by client IP",
},
[]string{"network", "client_ip"},
)

// Register only the custom metrics with the custom registry
lb.promRegistry.MustRegister(lb.latencyGauge)
lb.promRegistry.MustRegister(lb.chainheadGauge)
Expand All @@ -246,6 +257,7 @@ func (lb *LoadBalancer) initMetrics() {
lb.promRegistry.MustRegister(lb.bestEndpointGauge)
lb.promRegistry.MustRegister(lb.requestCounter)
lb.promRegistry.MustRegister(lb.requestDuration)
lb.promRegistry.MustRegister(lb.requestByIPCounter) // Register new metric
}

func (lb *LoadBalancer) logRateLimited(level string, key string, format string, args ...interface{}) {
Expand Down Expand Up @@ -573,6 +585,27 @@ func (lb *LoadBalancer) getValidEndpoints(nodes []*NodeStatus, endpointChainhead
return validEndpoints
}

// getClientIP extracts the client IP address from the request.
func getClientIP(r *http.Request) string {
// Try to get the IP from the X-Forwarded-For header
xForwardedFor := r.Header.Get("X-Forwarded-For")
if xForwardedFor != "" {
// X-Forwarded-For can be a comma-separated list of IPs
ips := strings.Split(xForwardedFor, ",")
// Return the first IP
ip := strings.TrimSpace(ips[0])
if ip != "" {
return ip
}
}
// Fallback to RemoteAddr
ip, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
return r.RemoteAddr
}
return ip
}

func (lb *LoadBalancer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
startTime := time.Now()
pathSegments := strings.Split(strings.Trim(r.URL.Path, "/"), "/")
Expand Down Expand Up @@ -604,9 +637,18 @@ func (lb *LoadBalancer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}

// Extract the client IP
clientIP := getClientIP(r)

// Log the client IP at TRACE level, obeying the rate limit
lb.logRateLimited("TRACE", "client_ip_"+clientIP+"_"+networkName, "Client IP %s made a request to network %s", clientIP, networkName)

// Increment request counter
lb.requestCounter.WithLabelValues(networkName).Inc()

// Increment request counter by client IP
lb.requestByIPCounter.WithLabelValues(networkName, clientIP).Inc()

// Serve the request using the proxy
proxy.ServeHTTP(w, r)

Expand Down

0 comments on commit e3c8caa

Please sign in to comment.