Skip to content

Commit

Permalink
xalign: integrate deviation detector into xalign
Browse files Browse the repository at this point in the history
  • Loading branch information
c9s committed Dec 2, 2024
1 parent 6d93429 commit 6292118
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 108 deletions.
62 changes: 20 additions & 42 deletions pkg/strategy/xalign/detector/deviation.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,17 @@ type Record[T any] struct {

type DeviationDetector[T any] struct {
mu sync.Mutex
records map[string][]Record[T] // Stores records for different keys
expectedValue T // Expected value for comparison
tolerance float64 // Tolerance percentage (e.g., 0.01 for 1%)
duration time.Duration // Time limit for sustained deviation
toFloat64 func(T) float64 // Function to convert T to float64
expectedValue T // Expected value for comparison
tolerance float64 // Tolerance percentage (e.g., 0.01 for 1%)
duration time.Duration // Time limit for sustained deviation
toFloat64 func(T) float64 // Function to convert T to float64
records []Record[T] // Tracks deviation records
}

// NewDeviationDetector creates a new instance of DeviationDetector
func NewDeviationDetector[T any](
expectedValue T, tolerance float64, duration time.Duration, toFloat64 func(T) float64,
) *DeviationDetector[T] {
// If no conversion function is provided and T is float64, use the default converter
if toFloat64 == nil {
if _, ok := any(expectedValue).(float64); ok {
toFloat64 = func(value T) float64 {
Expand All @@ -36,16 +35,15 @@ func NewDeviationDetector[T any](
}

return &DeviationDetector[T]{
records: make(map[string][]Record[T]),
expectedValue: expectedValue,
tolerance: tolerance,
duration: duration,
toFloat64: toFloat64,
records: nil,
}
}

// AddRecord adds a new record and checks deviation status
func (d *DeviationDetector[T]) AddRecord(key string, value T, at time.Time) (bool, time.Duration) {
func (d *DeviationDetector[T]) AddRecord(value T, at time.Time) (bool, time.Duration) {
d.mu.Lock()
defer d.mu.Unlock()

Expand All @@ -56,58 +54,38 @@ func (d *DeviationDetector[T]) AddRecord(key string, value T, at time.Time) (boo

// Reset records if deviation is within tolerance
if deviationPercentage <= d.tolerance {
delete(d.records, key)
d.records = nil
return false, 0
}

// If deviation exceeds tolerance, track the record
records, exists := d.records[key]
if !exists {
if len(d.records) == 0 {
// No prior deviation, start tracking
d.records[key] = []Record[T]{{Value: value, Time: at}}
d.records = []Record[T]{{Value: value, Time: at}}
return false, 0
}

// If deviation already being tracked, append the new record
d.records[key] = append(records, Record[T]{Value: value, Time: at})
// Append new record
d.records = append(d.records, Record[T]{Value: value, Time: at})

// Calculate the duration of sustained deviation
firstRecord := records[0]
// Calculate the sustained duration
firstRecord := d.records[0]
sustainedDuration := at.Sub(firstRecord.Time)
return sustainedDuration >= d.duration, sustainedDuration
}

// GetRecords retrieves all records associated with the specified key
func (d *DeviationDetector[T]) GetRecords(key string) []Record[T] {
// GetRecords retrieves all deviation records
func (d *DeviationDetector[T]) GetRecords() []Record[T] {
d.mu.Lock()
defer d.mu.Unlock()

if records, exists := d.records[key]; exists {
return records
}
return nil
return append([]Record[T](nil), d.records...) // Return a copy of the records
}

// ClearRecords removes all records associated with the specified key
func (d *DeviationDetector[T]) ClearRecords(key string) {
// ClearRecords clears all deviation records
func (d *DeviationDetector[T]) ClearRecords() {
d.mu.Lock()
defer d.mu.Unlock()

delete(d.records, key)
}

// PruneOldRecords removes records that are older than the specified duration
func (d *DeviationDetector[T]) PruneOldRecords(now time.Time) {
d.mu.Lock()
defer d.mu.Unlock()

for key, records := range d.records {
prunedRecords := make([]Record[T], 0)
for _, record := range records {
if now.Sub(record.Time) <= d.duration {
prunedRecords = append(prunedRecords, record)
}
}
d.records[key] = prunedRecords
}
d.records = nil
}
125 changes: 63 additions & 62 deletions pkg/strategy/xalign/detector/deviation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,74 +8,61 @@ import (
"github.com/c9s/bbgo/pkg/types"
)

func TestDeviationWithTolerancePercentage(t *testing.T) {
// Initialize DeviationDetector with float64 values
func TestBalanceDeviationDetector(t *testing.T) {
// Initialize DeviationDetector for types.Balance
detector := NewDeviationDetector(
100.0, // Expected value
types.Balance{Currency: "BTC", NetAsset: fixedpoint.NewFromFloat(10.0)}, // Expected balance
0.01, // Tolerance percentage (1%)
time.Minute*5, // Duration for sustained deviation
nil, // Use default conversion for float64
time.Minute*4, // Duration for sustained deviation
func(b types.Balance) float64 {
return b.Net().Float64() // Use Net() as the base for deviation detection
},
)

// Define timestamps for testing
t1 := time.Date(2023, 1, 1, 10, 0, 0, 0, time.UTC)
t2 := t1.Add(4 * time.Minute)
t3 := t1.Add(6 * time.Minute)
now := time.Now()

// Add a record within tolerance (1%)
reset, sustainedDuration := detector.AddRecord("BTC", 101.0, t1)
if len(detector.GetRecords("BTC")) != 0 || reset {
t.Errorf("Expected records to reset when value is within tolerance")
// Add a balance record within tolerance
reset, sustainedDuration := detector.AddRecord(
types.Balance{Currency: "BTC", NetAsset: fixedpoint.NewFromFloat(10.05)},
now,
)
if reset {
t.Errorf("Expected no sustained deviation for value within tolerance")
}
if sustainedDuration != 0 {
t.Errorf("Expected sustained duration to be 0 for value within tolerance, got %v", sustainedDuration)
}

// Add a record outside tolerance
reset, sustainedDuration = detector.AddRecord("BTC", 110.0, t1)
if reset || sustainedDuration != 0 {
// Add a balance record outside tolerance
reset, sustainedDuration = detector.AddRecord(
types.Balance{Currency: "BTC", NetAsset: fixedpoint.NewFromFloat(11.0)},
now.Add(2*time.Minute),
)
if reset {
t.Errorf("Expected no sustained deviation initially")
}

// Add another record within duration
reset, sustainedDuration = detector.AddRecord("BTC", 112.0, t2)
if reset || sustainedDuration != 4*time.Minute {
t.Errorf("Expected sustained deviation to be less than threshold")
if sustainedDuration != 0 {
t.Errorf("Expected sustained duration to be 0 initially, got %v", sustainedDuration)
}

// Add another record exceeding duration
reset, sustainedDuration = detector.AddRecord("BTC", 112.0, t3)
if !reset || sustainedDuration != 6*time.Minute {
t.Errorf("Expected sustained deviation to exceed threshold")
}
}

func TestDefaultToFloat(t *testing.T) {
// Test default toFloat64 for float64 type
detector := NewDeviationDetector(
100.0, // Expected value
0.01, // Tolerance percentage (1%)
time.Minute*5, // Duration for sustained deviation
nil, // Use default conversion for float64
reset, sustainedDuration = detector.AddRecord(
types.Balance{Currency: "BTC", NetAsset: fixedpoint.NewFromFloat(11.5)},
now.Add(6*time.Minute),
)

// Define timestamps for testing
t1 := time.Now()

// Add a record within tolerance
reset, _ := detector.AddRecord("BTC", 100.5, t1)
if reset {
t.Errorf("Expected no sustained deviation for value within tolerance")
if !reset {
t.Errorf("Expected reset to be true")
}

// Add a record outside tolerance
reset, _ = detector.AddRecord("BTC", 105.0, t1.Add(2*time.Minute))
if reset {
t.Errorf("Expected no sustained deviation initially")
if sustainedDuration != 4*time.Minute {
t.Errorf("Expected sustained deviation to exceed threshold, got %v", sustainedDuration)
}
}

func TestBalanceDeviationDetector(t *testing.T) {
func TestBalanceRecordTracking(t *testing.T) {
// Initialize DeviationDetector for types.Balance
detector := NewDeviationDetector(
types.Balance{Currency: "BTC", NetAsset: fixedpoint.NewFromFloat(10.0)}, // Expected value
0.01, // Tolerance (1%)
types.Balance{Currency: "BTC", NetAsset: fixedpoint.NewFromFloat(10.0)}, // Expected balance
0.01, // Tolerance percentage (1%)
time.Minute*5, // Duration for sustained deviation
func(b types.Balance) float64 {
return b.Net().Float64()
Expand All @@ -84,21 +71,35 @@ func TestBalanceDeviationDetector(t *testing.T) {

now := time.Now()

// Add a balance record within tolerance
reset, _ := detector.AddRecord("BTC", types.Balance{Currency: "BTC", NetAsset: fixedpoint.NewFromFloat(10.05)}, now)
if reset {
t.Errorf("Expected no sustained deviation for value within tolerance")
// Add a balance record outside tolerance
_, _ = detector.AddRecord(
types.Balance{Currency: "BTC", NetAsset: fixedpoint.NewFromFloat(11.0)},
now,
)

// Check if record is being tracked
records := detector.GetRecords()
if len(records) != 1 {
t.Errorf("Expected 1 record, got %d", len(records))
}

// Add a balance record outside tolerance
reset, _ = detector.AddRecord("BTC", types.Balance{Currency: "BTC", NetAsset: fixedpoint.NewFromFloat(9.5)}, now.Add(2*time.Minute))
if reset {
t.Errorf("Expected no sustained deviation initially")
// Add another record
_, _ = detector.AddRecord(
types.Balance{Currency: "BTC", NetAsset: fixedpoint.NewFromFloat(11.5)},
now.Add(2*time.Minute),
)
records = detector.GetRecords()
if len(records) != 2 {
t.Errorf("Expected 2 records, got %d", len(records))
}

// Add another record exceeding duration
reset, sustainedDuration := detector.AddRecord("BTC", types.Balance{Currency: "BTC", NetAsset: fixedpoint.NewFromFloat(9.0)}, now.Add(6*time.Minute))
if !reset || sustainedDuration != 6*time.Minute {
t.Errorf("Expected sustained deviation to exceed threshold, got %v", sustainedDuration)
// Add a balance record within tolerance to reset
_, _ = detector.AddRecord(
types.Balance{Currency: "BTC", NetAsset: fixedpoint.NewFromFloat(10.05)},
now.Add(4*time.Minute),
)
records = detector.GetRecords()
if len(records) != 0 {
t.Errorf("Expected records to be cleared, got %d", len(records))
}
}
23 changes: 21 additions & 2 deletions pkg/strategy/xalign/strategy.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ type Strategy struct {

faultBalanceRecords map[string][]TimeBalance

detector detector.Record[fixedpoint.Value]
deviationDetectors map[string]*detector.DeviationDetector[types.Balance]

priceResolver *pricesolver.SimplePriceSolver

Expand Down Expand Up @@ -141,12 +141,20 @@ func (s *Strategy) CrossSubscribe(sessions map[string]*bbgo.ExchangeSession) {
}

func (s *Strategy) Defaults() error {
s.BalanceToleranceRange = fixedpoint.NewFromFloat(0.01)
if s.BalanceToleranceRange == fixedpoint.Zero {
s.BalanceToleranceRange = fixedpoint.NewFromFloat(0.01)
}

if s.Duration == 0 {
s.Duration = types.Duration(15 * time.Minute)
}

return nil
}

func (s *Strategy) Initialize() error {
s.activeTransferNotificationLimiter = rate.NewLimiter(rate.Every(5*time.Minute), 1)
s.deviationDetectors = make(map[string]*detector.DeviationDetector[types.Balance])
return nil
}

Expand Down Expand Up @@ -444,6 +452,17 @@ func (s *Strategy) CrossRun(ctx context.Context, _ bbgo.OrderExecutionRouter, se

s.orderStore = core.NewOrderStore("")

for currency, expectedValue := range s.ExpectedBalances {
s.deviationDetectors[currency] = detector.NewDeviationDetector(
types.Balance{Currency: currency, NetAsset: expectedValue}, // Expected value
s.BalanceToleranceRange.Float64(), // Tolerance (1%)
s.Duration.Duration(), // Duration for sustained deviation
func(b types.Balance) float64 {
return b.Net().Float64()
},
)
}

markets := types.MarketMap{}
for _, sessionName := range s.PreferredSessions {
session, ok := sessions[sessionName]
Expand Down
7 changes: 5 additions & 2 deletions pkg/types/balance.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,11 @@ func (b Balance) Total() fixedpoint.Value {

// Net returns the net asset value (total - debt)
func (b Balance) Net() fixedpoint.Value {
total := b.Total()
return total.Sub(b.Debt())
if !b.NetAsset.IsZero() {
return b.NetAsset
}

return b.Total().Sub(b.Debt())
}

func (b Balance) Debt() fixedpoint.Value {
Expand Down

0 comments on commit 6292118

Please sign in to comment.