From 892c60708b3298cd535f89994ef974ca07656de8 Mon Sep 17 00:00:00 2001 From: narumi Date: Sun, 26 May 2024 22:38:17 +0800 Subject: [PATCH] add fee budget support to random strategy --- pkg/strategy/common/fee_budget.go | 8 +++--- pkg/strategy/common/fee_budget_test.go | 6 ++++- pkg/strategy/random/strategy.go | 35 +++++++++++++++++++++----- 3 files changed, 39 insertions(+), 10 deletions(-) diff --git a/pkg/strategy/common/fee_budget.go b/pkg/strategy/common/fee_budget.go index 57fd932563..ca999ea585 100644 --- a/pkg/strategy/common/fee_budget.go +++ b/pkg/strategy/common/fee_budget.go @@ -34,6 +34,11 @@ func (f *FeeBudget) IsBudgetAllowed() bool { return true } + if f.State.IsOver24Hours() { + f.State.Reset() + return true + } + for asset, budget := range f.DailyFeeBudgets { if fee, ok := f.State.AccumulatedFees[asset]; ok { if fee.Compare(budget) >= 0 { @@ -59,14 +64,12 @@ func (f *FeeBudget) HandleTradeUpdate(trade types.Trade) { } f.State.AccumulatedFees[trade.FeeCurrency] = f.State.AccumulatedFees[trade.FeeCurrency].Add(trade.Fee) - f.State.AccumulatedVolume = f.State.AccumulatedVolume.Add(trade.Quantity) log.Infof("[FeeBudget] accumulated fee: %s %s", f.State.AccumulatedFees[trade.FeeCurrency].String(), trade.FeeCurrency) } type State struct { AccumulatedFeeStartedAt time.Time `json:"accumulatedFeeStartedAt,omitempty"` AccumulatedFees map[string]fixedpoint.Value `json:"accumulatedFees,omitempty"` - AccumulatedVolume fixedpoint.Value `json:"accumulatedVolume,omitempty"` } func (s *State) IsOver24Hours() bool { @@ -81,5 +84,4 @@ func (s *State) Reset() { s.AccumulatedFeeStartedAt = dateTime s.AccumulatedFees = make(map[string]fixedpoint.Value) - s.AccumulatedVolume = fixedpoint.Zero } diff --git a/pkg/strategy/common/fee_budget_test.go b/pkg/strategy/common/fee_budget_test.go index 7ceb972b48..d57658b159 100644 --- a/pkg/strategy/common/fee_budget_test.go +++ b/pkg/strategy/common/fee_budget_test.go @@ -2,6 +2,7 @@ package common import ( "testing" + "time" "github.com/c9s/bbgo/pkg/fixedpoint" "github.com/c9s/bbgo/pkg/types" @@ -46,7 +47,10 @@ func TestFeeBudget(t *testing.T) { for _, trade := range c.trades { feeBudget.HandleTradeUpdate(trade) } - assert.Equal(t, c.expected, feeBudget.IsBudgetAllowed()) + + // test reset + feeBudget.State.AccumulatedFeeStartedAt = feeBudget.State.AccumulatedFeeStartedAt.Add(-24 * time.Hour) + assert.True(t, feeBudget.IsBudgetAllowed()) } } diff --git a/pkg/strategy/random/strategy.go b/pkg/strategy/random/strategy.go index 1d95440327..e0a0f7091d 100644 --- a/pkg/strategy/random/strategy.go +++ b/pkg/strategy/random/strategy.go @@ -24,6 +24,7 @@ func init() { type Strategy struct { *common.Strategy + *common.FeeBudget Environment *bbgo.Environment Market types.Market @@ -45,6 +46,10 @@ func (s *Strategy) Initialize() error { if s.Strategy == nil { s.Strategy = &common.Strategy{} } + + if s.FeeBudget == nil { + s.FeeBudget = &common.FeeBudget{} + } return nil } @@ -71,11 +76,25 @@ func (s *Strategy) Subscribe(session *bbgo.ExchangeSession) {} func (s *Strategy) Run(ctx context.Context, _ bbgo.OrderExecutor, session *bbgo.ExchangeSession) error { s.Strategy.Initialize(ctx, s.Environment, session, s.Market, s.ID(), s.InstanceID()) + s.FeeBudget.Initialize() session.UserDataStream.OnStart(func() { - if s.OnStart { - s.placeOrder() + if !s.OnStart { + return + } + + if !s.FeeBudget.IsBudgetAllowed() { + return + } + + s.placeOrder(ctx) + }) + + session.UserDataStream.OnTradeUpdate(func(trade types.Trade) { + if trade.Symbol != s.Symbol { + return } + s.FeeBudget.HandleTradeUpdate(trade) }) // the shutdown handler, you can cancel all orders @@ -86,15 +105,19 @@ func (s *Strategy) Run(ctx context.Context, _ bbgo.OrderExecutor, session *bbgo. }) s.cron = cron.New() - s.cron.AddFunc(s.Schedule, s.placeOrder) + s.cron.AddFunc(s.Schedule, func() { + if !s.FeeBudget.IsBudgetAllowed() { + return + } + + s.placeOrder(ctx) + }) s.cron.Start() return nil } -func (s *Strategy) placeOrder() { - ctx := context.Background() - +func (s *Strategy) placeOrder(ctx context.Context) { baseBalance, ok := s.Session.GetAccount().Balance(s.Market.BaseCurrency) if !ok { log.Errorf("base balance not found")