Skip to content

Commit

Permalink
xalign: implement new detector
Browse files Browse the repository at this point in the history
  • Loading branch information
c9s committed Dec 2, 2024
1 parent e8dde75 commit 6d93429
Show file tree
Hide file tree
Showing 3 changed files with 220 additions and 0 deletions.
113 changes: 113 additions & 0 deletions pkg/strategy/xalign/detector/deviation.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
package detector

import (
"math"
"sync"
"time"
)

type Record[T any] struct {
Value T
Time time.Time
}

type DeviationDetector[T any] struct {
mu sync.Mutex
records map[string][]Record[T] // Stores records for different keys
expectedValue T // Expected value for comparison
tolerance float64 // Tolerance percentage (e.g., 0.01 for 1%)
duration time.Duration // Time limit for sustained deviation
toFloat64 func(T) float64 // Function to convert T to float64
}

// NewDeviationDetector creates a new instance of DeviationDetector
func NewDeviationDetector[T any](
expectedValue T, tolerance float64, duration time.Duration, toFloat64 func(T) float64,
) *DeviationDetector[T] {
// If no conversion function is provided and T is float64, use the default converter
if toFloat64 == nil {
if _, ok := any(expectedValue).(float64); ok {
toFloat64 = func(value T) float64 {
return any(value).(float64)
}
} else {
panic("No conversion function provided for non-float64 type")
}
}

return &DeviationDetector[T]{
records: make(map[string][]Record[T]),
expectedValue: expectedValue,
tolerance: tolerance,
duration: duration,
toFloat64: toFloat64,
}
}

// AddRecord adds a new record and checks deviation status
func (d *DeviationDetector[T]) AddRecord(key string, value T, at time.Time) (bool, time.Duration) {
d.mu.Lock()
defer d.mu.Unlock()

// Calculate deviation percentage
expected := d.toFloat64(d.expectedValue)
current := d.toFloat64(value)
deviationPercentage := math.Abs((current - expected) / expected)

// Reset records if deviation is within tolerance
if deviationPercentage <= d.tolerance {
delete(d.records, key)
return false, 0
}

// If deviation exceeds tolerance, track the record
records, exists := d.records[key]
if !exists {
// No prior deviation, start tracking
d.records[key] = []Record[T]{{Value: value, Time: at}}
return false, 0
}

// If deviation already being tracked, append the new record
d.records[key] = append(records, Record[T]{Value: value, Time: at})

// Calculate the duration of sustained deviation
firstRecord := records[0]
sustainedDuration := at.Sub(firstRecord.Time)
return sustainedDuration >= d.duration, sustainedDuration
}

// GetRecords retrieves all records associated with the specified key
func (d *DeviationDetector[T]) GetRecords(key string) []Record[T] {
d.mu.Lock()
defer d.mu.Unlock()

if records, exists := d.records[key]; exists {
return records
}
return nil
}

// ClearRecords removes all records associated with the specified key
func (d *DeviationDetector[T]) ClearRecords(key string) {
d.mu.Lock()
defer d.mu.Unlock()

delete(d.records, key)
}

// PruneOldRecords removes records that are older than the specified duration
func (d *DeviationDetector[T]) PruneOldRecords(now time.Time) {
d.mu.Lock()
defer d.mu.Unlock()

for key, records := range d.records {
prunedRecords := make([]Record[T], 0)
for _, record := range records {
if now.Sub(record.Time) <= d.duration {
prunedRecords = append(prunedRecords, record)
}
}
d.records[key] = prunedRecords
}
}
104 changes: 104 additions & 0 deletions pkg/strategy/xalign/detector/deviation_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
package detector

import (
"testing"
"time"

"github.com/c9s/bbgo/pkg/fixedpoint"
"github.com/c9s/bbgo/pkg/types"
)

func TestDeviationWithTolerancePercentage(t *testing.T) {
// Initialize DeviationDetector with float64 values
detector := NewDeviationDetector(
100.0, // Expected value
0.01, // Tolerance percentage (1%)
time.Minute*5, // Duration for sustained deviation
nil, // Use default conversion for float64
)

// Define timestamps for testing
t1 := time.Date(2023, 1, 1, 10, 0, 0, 0, time.UTC)
t2 := t1.Add(4 * time.Minute)
t3 := t1.Add(6 * time.Minute)

// Add a record within tolerance (1%)
reset, sustainedDuration := detector.AddRecord("BTC", 101.0, t1)
if len(detector.GetRecords("BTC")) != 0 || reset {
t.Errorf("Expected records to reset when value is within tolerance")
}

// Add a record outside tolerance
reset, sustainedDuration = detector.AddRecord("BTC", 110.0, t1)
if reset || sustainedDuration != 0 {
t.Errorf("Expected no sustained deviation initially")
}

// Add another record within duration
reset, sustainedDuration = detector.AddRecord("BTC", 112.0, t2)
if reset || sustainedDuration != 4*time.Minute {
t.Errorf("Expected sustained deviation to be less than threshold")
}

// Add another record exceeding duration
reset, sustainedDuration = detector.AddRecord("BTC", 112.0, t3)
if !reset || sustainedDuration != 6*time.Minute {
t.Errorf("Expected sustained deviation to exceed threshold")
}
}

func TestDefaultToFloat(t *testing.T) {
// Test default toFloat64 for float64 type
detector := NewDeviationDetector(
100.0, // Expected value
0.01, // Tolerance percentage (1%)
time.Minute*5, // Duration for sustained deviation
nil, // Use default conversion for float64
)

// Define timestamps for testing
t1 := time.Now()

// Add a record within tolerance
reset, _ := detector.AddRecord("BTC", 100.5, t1)
if reset {
t.Errorf("Expected no sustained deviation for value within tolerance")
}

// Add a record outside tolerance
reset, _ = detector.AddRecord("BTC", 105.0, t1.Add(2*time.Minute))
if reset {
t.Errorf("Expected no sustained deviation initially")
}
}

func TestBalanceDeviationDetector(t *testing.T) {
detector := NewDeviationDetector(
types.Balance{Currency: "BTC", NetAsset: fixedpoint.NewFromFloat(10.0)}, // Expected value
0.01, // Tolerance (1%)
time.Minute*5, // Duration for sustained deviation
func(b types.Balance) float64 {
return b.Net().Float64()
},
)

now := time.Now()

// Add a balance record within tolerance
reset, _ := detector.AddRecord("BTC", types.Balance{Currency: "BTC", NetAsset: fixedpoint.NewFromFloat(10.05)}, now)
if reset {
t.Errorf("Expected no sustained deviation for value within tolerance")
}

// Add a balance record outside tolerance
reset, _ = detector.AddRecord("BTC", types.Balance{Currency: "BTC", NetAsset: fixedpoint.NewFromFloat(9.5)}, now.Add(2*time.Minute))
if reset {
t.Errorf("Expected no sustained deviation initially")
}

// Add another record exceeding duration
reset, sustainedDuration := detector.AddRecord("BTC", types.Balance{Currency: "BTC", NetAsset: fixedpoint.NewFromFloat(9.0)}, now.Add(6*time.Minute))
if !reset || sustainedDuration != 6*time.Minute {
t.Errorf("Expected sustained deviation to exceed threshold, got %v", sustainedDuration)
}
}
3 changes: 3 additions & 0 deletions pkg/strategy/xalign/strategy.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"github.com/c9s/bbgo/pkg/core"
"github.com/c9s/bbgo/pkg/fixedpoint"
"github.com/c9s/bbgo/pkg/pricesolver"
"github.com/c9s/bbgo/pkg/strategy/xalign/detector"
"github.com/c9s/bbgo/pkg/types"
)

Expand Down Expand Up @@ -106,6 +107,8 @@ type Strategy struct {

faultBalanceRecords map[string][]TimeBalance

detector detector.Record[fixedpoint.Value]

priceResolver *pricesolver.SimplePriceSolver

sessions map[string]*bbgo.ExchangeSession
Expand Down

0 comments on commit 6d93429

Please sign in to comment.