-
-
Notifications
You must be signed in to change notification settings - Fork 302
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
220 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
package detector | ||
|
||
import ( | ||
"math" | ||
"sync" | ||
"time" | ||
) | ||
|
||
type Record[T any] struct { | ||
Value T | ||
Time time.Time | ||
} | ||
|
||
type DeviationDetector[T any] struct { | ||
mu sync.Mutex | ||
records map[string][]Record[T] // Stores records for different keys | ||
expectedValue T // Expected value for comparison | ||
tolerance float64 // Tolerance percentage (e.g., 0.01 for 1%) | ||
duration time.Duration // Time limit for sustained deviation | ||
toFloat64 func(T) float64 // Function to convert T to float64 | ||
} | ||
|
||
// NewDeviationDetector creates a new instance of DeviationDetector | ||
func NewDeviationDetector[T any]( | ||
expectedValue T, tolerance float64, duration time.Duration, toFloat64 func(T) float64, | ||
) *DeviationDetector[T] { | ||
// If no conversion function is provided and T is float64, use the default converter | ||
if toFloat64 == nil { | ||
if _, ok := any(expectedValue).(float64); ok { | ||
toFloat64 = func(value T) float64 { | ||
return any(value).(float64) | ||
} | ||
} else { | ||
panic("No conversion function provided for non-float64 type") | ||
} | ||
} | ||
|
||
return &DeviationDetector[T]{ | ||
records: make(map[string][]Record[T]), | ||
expectedValue: expectedValue, | ||
tolerance: tolerance, | ||
duration: duration, | ||
toFloat64: toFloat64, | ||
} | ||
} | ||
|
||
// AddRecord adds a new record and checks deviation status | ||
func (d *DeviationDetector[T]) AddRecord(key string, value T, at time.Time) (bool, time.Duration) { | ||
d.mu.Lock() | ||
defer d.mu.Unlock() | ||
|
||
// Calculate deviation percentage | ||
expected := d.toFloat64(d.expectedValue) | ||
current := d.toFloat64(value) | ||
deviationPercentage := math.Abs((current - expected) / expected) | ||
|
||
// Reset records if deviation is within tolerance | ||
if deviationPercentage <= d.tolerance { | ||
delete(d.records, key) | ||
return false, 0 | ||
} | ||
|
||
// If deviation exceeds tolerance, track the record | ||
records, exists := d.records[key] | ||
if !exists { | ||
// No prior deviation, start tracking | ||
d.records[key] = []Record[T]{{Value: value, Time: at}} | ||
return false, 0 | ||
} | ||
|
||
// If deviation already being tracked, append the new record | ||
d.records[key] = append(records, Record[T]{Value: value, Time: at}) | ||
|
||
// Calculate the duration of sustained deviation | ||
firstRecord := records[0] | ||
sustainedDuration := at.Sub(firstRecord.Time) | ||
return sustainedDuration >= d.duration, sustainedDuration | ||
} | ||
|
||
// GetRecords retrieves all records associated with the specified key | ||
func (d *DeviationDetector[T]) GetRecords(key string) []Record[T] { | ||
d.mu.Lock() | ||
defer d.mu.Unlock() | ||
|
||
if records, exists := d.records[key]; exists { | ||
return records | ||
} | ||
return nil | ||
} | ||
|
||
// ClearRecords removes all records associated with the specified key | ||
func (d *DeviationDetector[T]) ClearRecords(key string) { | ||
d.mu.Lock() | ||
defer d.mu.Unlock() | ||
|
||
delete(d.records, key) | ||
} | ||
|
||
// PruneOldRecords removes records that are older than the specified duration | ||
func (d *DeviationDetector[T]) PruneOldRecords(now time.Time) { | ||
d.mu.Lock() | ||
defer d.mu.Unlock() | ||
|
||
for key, records := range d.records { | ||
prunedRecords := make([]Record[T], 0) | ||
for _, record := range records { | ||
if now.Sub(record.Time) <= d.duration { | ||
prunedRecords = append(prunedRecords, record) | ||
} | ||
} | ||
d.records[key] = prunedRecords | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
package detector | ||
|
||
import ( | ||
"testing" | ||
"time" | ||
|
||
"github.com/c9s/bbgo/pkg/fixedpoint" | ||
"github.com/c9s/bbgo/pkg/types" | ||
) | ||
|
||
func TestDeviationWithTolerancePercentage(t *testing.T) { | ||
// Initialize DeviationDetector with float64 values | ||
detector := NewDeviationDetector( | ||
100.0, // Expected value | ||
0.01, // Tolerance percentage (1%) | ||
time.Minute*5, // Duration for sustained deviation | ||
nil, // Use default conversion for float64 | ||
) | ||
|
||
// Define timestamps for testing | ||
t1 := time.Date(2023, 1, 1, 10, 0, 0, 0, time.UTC) | ||
t2 := t1.Add(4 * time.Minute) | ||
t3 := t1.Add(6 * time.Minute) | ||
|
||
// Add a record within tolerance (1%) | ||
reset, sustainedDuration := detector.AddRecord("BTC", 101.0, t1) | ||
if len(detector.GetRecords("BTC")) != 0 || reset { | ||
t.Errorf("Expected records to reset when value is within tolerance") | ||
} | ||
|
||
// Add a record outside tolerance | ||
reset, sustainedDuration = detector.AddRecord("BTC", 110.0, t1) | ||
if reset || sustainedDuration != 0 { | ||
t.Errorf("Expected no sustained deviation initially") | ||
} | ||
|
||
// Add another record within duration | ||
reset, sustainedDuration = detector.AddRecord("BTC", 112.0, t2) | ||
if reset || sustainedDuration != 4*time.Minute { | ||
t.Errorf("Expected sustained deviation to be less than threshold") | ||
} | ||
|
||
// Add another record exceeding duration | ||
reset, sustainedDuration = detector.AddRecord("BTC", 112.0, t3) | ||
if !reset || sustainedDuration != 6*time.Minute { | ||
t.Errorf("Expected sustained deviation to exceed threshold") | ||
} | ||
} | ||
|
||
func TestDefaultToFloat(t *testing.T) { | ||
// Test default toFloat64 for float64 type | ||
detector := NewDeviationDetector( | ||
100.0, // Expected value | ||
0.01, // Tolerance percentage (1%) | ||
time.Minute*5, // Duration for sustained deviation | ||
nil, // Use default conversion for float64 | ||
) | ||
|
||
// Define timestamps for testing | ||
t1 := time.Now() | ||
|
||
// Add a record within tolerance | ||
reset, _ := detector.AddRecord("BTC", 100.5, t1) | ||
if reset { | ||
t.Errorf("Expected no sustained deviation for value within tolerance") | ||
} | ||
|
||
// Add a record outside tolerance | ||
reset, _ = detector.AddRecord("BTC", 105.0, t1.Add(2*time.Minute)) | ||
if reset { | ||
t.Errorf("Expected no sustained deviation initially") | ||
} | ||
} | ||
|
||
func TestBalanceDeviationDetector(t *testing.T) { | ||
detector := NewDeviationDetector( | ||
types.Balance{Currency: "BTC", NetAsset: fixedpoint.NewFromFloat(10.0)}, // Expected value | ||
0.01, // Tolerance (1%) | ||
time.Minute*5, // Duration for sustained deviation | ||
func(b types.Balance) float64 { | ||
return b.Net().Float64() | ||
}, | ||
) | ||
|
||
now := time.Now() | ||
|
||
// Add a balance record within tolerance | ||
reset, _ := detector.AddRecord("BTC", types.Balance{Currency: "BTC", NetAsset: fixedpoint.NewFromFloat(10.05)}, now) | ||
if reset { | ||
t.Errorf("Expected no sustained deviation for value within tolerance") | ||
} | ||
|
||
// Add a balance record outside tolerance | ||
reset, _ = detector.AddRecord("BTC", types.Balance{Currency: "BTC", NetAsset: fixedpoint.NewFromFloat(9.5)}, now.Add(2*time.Minute)) | ||
if reset { | ||
t.Errorf("Expected no sustained deviation initially") | ||
} | ||
|
||
// Add another record exceeding duration | ||
reset, sustainedDuration := detector.AddRecord("BTC", types.Balance{Currency: "BTC", NetAsset: fixedpoint.NewFromFloat(9.0)}, now.Add(6*time.Minute)) | ||
if !reset || sustainedDuration != 6*time.Minute { | ||
t.Errorf("Expected sustained deviation to exceed threshold, got %v", sustainedDuration) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters