From 6d9342940d9491675e04f89cec3dbc5c814bf0cb Mon Sep 17 00:00:00 2001 From: c9s Date: Mon, 2 Dec 2024 16:33:03 +0800 Subject: [PATCH] xalign: implement new detector --- pkg/strategy/xalign/detector/deviation.go | 113 ++++++++++++++++++ .../xalign/detector/deviation_test.go | 104 ++++++++++++++++ pkg/strategy/xalign/strategy.go | 3 + 3 files changed, 220 insertions(+) create mode 100644 pkg/strategy/xalign/detector/deviation.go create mode 100644 pkg/strategy/xalign/detector/deviation_test.go diff --git a/pkg/strategy/xalign/detector/deviation.go b/pkg/strategy/xalign/detector/deviation.go new file mode 100644 index 0000000000..732c237ca8 --- /dev/null +++ b/pkg/strategy/xalign/detector/deviation.go @@ -0,0 +1,113 @@ +package detector + +import ( + "math" + "sync" + "time" +) + +type Record[T any] struct { + Value T + Time time.Time +} + +type DeviationDetector[T any] struct { + mu sync.Mutex + records map[string][]Record[T] // Stores records for different keys + expectedValue T // Expected value for comparison + tolerance float64 // Tolerance percentage (e.g., 0.01 for 1%) + duration time.Duration // Time limit for sustained deviation + toFloat64 func(T) float64 // Function to convert T to float64 +} + +// NewDeviationDetector creates a new instance of DeviationDetector +func NewDeviationDetector[T any]( + expectedValue T, tolerance float64, duration time.Duration, toFloat64 func(T) float64, +) *DeviationDetector[T] { + // If no conversion function is provided and T is float64, use the default converter + if toFloat64 == nil { + if _, ok := any(expectedValue).(float64); ok { + toFloat64 = func(value T) float64 { + return any(value).(float64) + } + } else { + panic("No conversion function provided for non-float64 type") + } + } + + return &DeviationDetector[T]{ + records: make(map[string][]Record[T]), + expectedValue: expectedValue, + tolerance: tolerance, + duration: duration, + toFloat64: toFloat64, + } +} + +// AddRecord adds a new record and checks deviation status +func (d *DeviationDetector[T]) AddRecord(key string, value T, at time.Time) (bool, time.Duration) { + d.mu.Lock() + defer d.mu.Unlock() + + // Calculate deviation percentage + expected := d.toFloat64(d.expectedValue) + current := d.toFloat64(value) + deviationPercentage := math.Abs((current - expected) / expected) + + // Reset records if deviation is within tolerance + if deviationPercentage <= d.tolerance { + delete(d.records, key) + return false, 0 + } + + // If deviation exceeds tolerance, track the record + records, exists := d.records[key] + if !exists { + // No prior deviation, start tracking + d.records[key] = []Record[T]{{Value: value, Time: at}} + return false, 0 + } + + // If deviation already being tracked, append the new record + d.records[key] = append(records, Record[T]{Value: value, Time: at}) + + // Calculate the duration of sustained deviation + firstRecord := records[0] + sustainedDuration := at.Sub(firstRecord.Time) + return sustainedDuration >= d.duration, sustainedDuration +} + +// GetRecords retrieves all records associated with the specified key +func (d *DeviationDetector[T]) GetRecords(key string) []Record[T] { + d.mu.Lock() + defer d.mu.Unlock() + + if records, exists := d.records[key]; exists { + return records + } + return nil +} + +// ClearRecords removes all records associated with the specified key +func (d *DeviationDetector[T]) ClearRecords(key string) { + d.mu.Lock() + defer d.mu.Unlock() + + delete(d.records, key) +} + +// PruneOldRecords removes records that are older than the specified duration +func (d *DeviationDetector[T]) PruneOldRecords(now time.Time) { + d.mu.Lock() + defer d.mu.Unlock() + + for key, records := range d.records { + prunedRecords := make([]Record[T], 0) + for _, record := range records { + if now.Sub(record.Time) <= d.duration { + prunedRecords = append(prunedRecords, record) + } + } + d.records[key] = prunedRecords + } +} diff --git a/pkg/strategy/xalign/detector/deviation_test.go b/pkg/strategy/xalign/detector/deviation_test.go new file mode 100644 index 0000000000..9e25802aed --- /dev/null +++ b/pkg/strategy/xalign/detector/deviation_test.go @@ -0,0 +1,104 @@ +package detector + +import ( + "testing" + "time" + + "github.com/c9s/bbgo/pkg/fixedpoint" + "github.com/c9s/bbgo/pkg/types" +) + +func TestDeviationWithTolerancePercentage(t *testing.T) { + // Initialize DeviationDetector with float64 values + detector := NewDeviationDetector( + 100.0, // Expected value + 0.01, // Tolerance percentage (1%) + time.Minute*5, // Duration for sustained deviation + nil, // Use default conversion for float64 + ) + + // Define timestamps for testing + t1 := time.Date(2023, 1, 1, 10, 0, 0, 0, time.UTC) + t2 := t1.Add(4 * time.Minute) + t3 := t1.Add(6 * time.Minute) + + // Add a record within tolerance (1%) + reset, sustainedDuration := detector.AddRecord("BTC", 101.0, t1) + if len(detector.GetRecords("BTC")) != 0 || reset { + t.Errorf("Expected records to reset when value is within tolerance") + } + + // Add a record outside tolerance + reset, sustainedDuration = detector.AddRecord("BTC", 110.0, t1) + if reset || sustainedDuration != 0 { + t.Errorf("Expected no sustained deviation initially") + } + + // Add another record within duration + reset, sustainedDuration = detector.AddRecord("BTC", 112.0, t2) + if reset || sustainedDuration != 4*time.Minute { + t.Errorf("Expected sustained deviation to be less than threshold") + } + + // Add another record exceeding duration + reset, sustainedDuration = detector.AddRecord("BTC", 112.0, t3) + if !reset || sustainedDuration != 6*time.Minute { + t.Errorf("Expected sustained deviation to exceed threshold") + } +} + +func TestDefaultToFloat(t *testing.T) { + // Test default toFloat64 for float64 type + detector := NewDeviationDetector( + 100.0, // Expected value + 0.01, // Tolerance percentage (1%) + time.Minute*5, // Duration for sustained deviation + nil, // Use default conversion for float64 + ) + + // Define timestamps for testing + t1 := time.Now() + + // Add a record within tolerance + reset, _ := detector.AddRecord("BTC", 100.5, t1) + if reset { + t.Errorf("Expected no sustained deviation for value within tolerance") + } + + // Add a record outside tolerance + reset, _ = detector.AddRecord("BTC", 105.0, t1.Add(2*time.Minute)) + if reset { + t.Errorf("Expected no sustained deviation initially") + } +} + +func TestBalanceDeviationDetector(t *testing.T) { + detector := NewDeviationDetector( + types.Balance{Currency: "BTC", NetAsset: fixedpoint.NewFromFloat(10.0)}, // Expected value + 0.01, // Tolerance (1%) + time.Minute*5, // Duration for sustained deviation + func(b types.Balance) float64 { + return b.Net().Float64() + }, + ) + + now := time.Now() + + // Add a balance record within tolerance + reset, _ := detector.AddRecord("BTC", types.Balance{Currency: "BTC", NetAsset: fixedpoint.NewFromFloat(10.05)}, now) + if reset { + t.Errorf("Expected no sustained deviation for value within tolerance") + } + + // Add a balance record outside tolerance + reset, _ = detector.AddRecord("BTC", types.Balance{Currency: "BTC", NetAsset: fixedpoint.NewFromFloat(9.5)}, now.Add(2*time.Minute)) + if reset { + t.Errorf("Expected no sustained deviation initially") + } + + // Add another record exceeding duration + reset, sustainedDuration := detector.AddRecord("BTC", types.Balance{Currency: "BTC", NetAsset: fixedpoint.NewFromFloat(9.0)}, now.Add(6*time.Minute)) + if !reset || sustainedDuration != 6*time.Minute { + t.Errorf("Expected sustained deviation to exceed threshold, got %v", sustainedDuration) + } +} diff --git a/pkg/strategy/xalign/strategy.go b/pkg/strategy/xalign/strategy.go index 92d797f69b..744221662d 100644 --- a/pkg/strategy/xalign/strategy.go +++ b/pkg/strategy/xalign/strategy.go @@ -16,6 +16,7 @@ import ( "github.com/c9s/bbgo/pkg/core" "github.com/c9s/bbgo/pkg/fixedpoint" "github.com/c9s/bbgo/pkg/pricesolver" + "github.com/c9s/bbgo/pkg/strategy/xalign/detector" "github.com/c9s/bbgo/pkg/types" ) @@ -106,6 +107,8 @@ type Strategy struct { faultBalanceRecords map[string][]TimeBalance + detector detector.Record[fixedpoint.Value] + priceResolver *pricesolver.SimplePriceSolver sessions map[string]*bbgo.ExchangeSession