Skip to content

Commit

Permalink
Merge pull request #1644 from c9s/narumi/fee-budget
Browse files Browse the repository at this point in the history
REFACTOR: Extract and move FeeBudget from xgap
  • Loading branch information
narumiruna authored Jun 20, 2024
2 parents 9bf635d + 9cbf8a0 commit 396ee68
Show file tree
Hide file tree
Showing 4 changed files with 195 additions and 89 deletions.
92 changes: 92 additions & 0 deletions pkg/strategy/common/fee_budget.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
package common

import (
"sync"
"time"

"github.com/c9s/bbgo/pkg/fixedpoint"
"github.com/c9s/bbgo/pkg/types"
log "github.com/sirupsen/logrus"
)

type FeeBudget struct {
DailyFeeBudgets map[string]fixedpoint.Value `json:"dailyFeeBudgets,omitempty"`
State *State `persistence:"state"`

mu sync.Mutex
}

func (f *FeeBudget) Initialize() {
if f.State == nil {
f.State = &State{}
f.State.Reset()
}

if f.State.IsOver24Hours() {
log.Warn("[FeeBudget] state is over 24 hours, resetting to zero")
f.State.Reset()
}
}

func (f *FeeBudget) IsBudgetAllowed() bool {
if f.DailyFeeBudgets == nil {
return true
}

if f.State.AccumulatedFees == nil {
return true
}

if f.State.IsOver24Hours() {
f.State.Reset()
return true
}

for asset, budget := range f.DailyFeeBudgets {
if fee, ok := f.State.AccumulatedFees[asset]; ok {
if fee.Compare(budget) >= 0 {
log.Warnf("[FeeBudget] accumulative fee %s exceeded the fee budget %s, skipping...", fee.String(), budget.String())
return false
}
}
}

return true
}

func (f *FeeBudget) HandleTradeUpdate(trade types.Trade) {
log.Infof("[FeeBudget] received trade %s", trade.String())

if f.State.IsOver24Hours() {
f.State.Reset()
}

// safe check
if f.State.AccumulatedFees == nil {
f.mu.Lock()
f.State.AccumulatedFees = make(map[string]fixedpoint.Value)
f.mu.Unlock()
}

f.State.AccumulatedFees[trade.FeeCurrency] = f.State.AccumulatedFees[trade.FeeCurrency].Add(trade.Fee)
log.Infof("[FeeBudget] accumulated fee: %s %s", f.State.AccumulatedFees[trade.FeeCurrency].String(), trade.FeeCurrency)
}

type State struct {
AccumulatedFeeStartedAt time.Time `json:"accumulatedFeeStartedAt,omitempty"`
AccumulatedFees map[string]fixedpoint.Value `json:"accumulatedFees,omitempty"`
}

func (s *State) IsOver24Hours() bool {
return time.Since(s.AccumulatedFeeStartedAt) >= 24*time.Hour
}

func (s *State) Reset() {
t := time.Now()
dateTime := time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, t.Location())

log.Infof("[State] resetting accumulated started time to: %s", dateTime)

s.AccumulatedFeeStartedAt = dateTime
s.AccumulatedFees = make(map[string]fixedpoint.Value)
}
56 changes: 56 additions & 0 deletions pkg/strategy/common/fee_budget_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package common

import (
"testing"
"time"

"github.com/c9s/bbgo/pkg/fixedpoint"
"github.com/c9s/bbgo/pkg/types"
"github.com/stretchr/testify/assert"
)

func TestFeeBudget(t *testing.T) {
cases := []struct {
budgets map[string]fixedpoint.Value
trades []types.Trade
expected bool
}{
{
budgets: map[string]fixedpoint.Value{
"MAX": fixedpoint.NewFromFloat(0.5),
},
trades: []types.Trade{
{FeeCurrency: "MAX", Fee: fixedpoint.NewFromFloat(0.1)},
{FeeCurrency: "USDT", Fee: fixedpoint.NewFromFloat(10.0)},
},
expected: true,
},
{
budgets: map[string]fixedpoint.Value{
"MAX": fixedpoint.NewFromFloat(0.5),
},
trades: []types.Trade{
{FeeCurrency: "MAX", Fee: fixedpoint.NewFromFloat(0.1)},
{FeeCurrency: "MAX", Fee: fixedpoint.NewFromFloat(0.5)},
{FeeCurrency: "USDT", Fee: fixedpoint.NewFromFloat(10.0)},
},
expected: false,
},
}

for _, c := range cases {
feeBudget := FeeBudget{
DailyFeeBudgets: c.budgets,
}
feeBudget.Initialize()

for _, trade := range c.trades {
feeBudget.HandleTradeUpdate(trade)
}
assert.Equal(t, c.expected, feeBudget.IsBudgetAllowed())

// test reset
feeBudget.State.AccumulatedFeeStartedAt = feeBudget.State.AccumulatedFeeStartedAt.Add(-24 * time.Hour)
assert.True(t, feeBudget.IsBudgetAllowed())
}
}
35 changes: 29 additions & 6 deletions pkg/strategy/random/strategy.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ func init() {

type Strategy struct {
*common.Strategy
*common.FeeBudget

Environment *bbgo.Environment
Market types.Market
Expand All @@ -45,6 +46,10 @@ func (s *Strategy) Initialize() error {
if s.Strategy == nil {
s.Strategy = &common.Strategy{}
}

if s.FeeBudget == nil {
s.FeeBudget = &common.FeeBudget{}
}
return nil
}

Expand All @@ -71,11 +76,25 @@ func (s *Strategy) Subscribe(session *bbgo.ExchangeSession) {}

func (s *Strategy) Run(ctx context.Context, _ bbgo.OrderExecutor, session *bbgo.ExchangeSession) error {
s.Strategy.Initialize(ctx, s.Environment, session, s.Market, s.ID(), s.InstanceID())
s.FeeBudget.Initialize()

session.UserDataStream.OnStart(func() {
if s.OnStart {
s.placeOrder()
if !s.OnStart {
return
}

if !s.FeeBudget.IsBudgetAllowed() {
return
}

s.placeOrder(ctx)
})

session.UserDataStream.OnTradeUpdate(func(trade types.Trade) {
if trade.Symbol != s.Symbol {
return
}
s.FeeBudget.HandleTradeUpdate(trade)
})

// the shutdown handler, you can cancel all orders
Expand All @@ -86,15 +105,19 @@ func (s *Strategy) Run(ctx context.Context, _ bbgo.OrderExecutor, session *bbgo.
})

s.cron = cron.New()
s.cron.AddFunc(s.Schedule, s.placeOrder)
s.cron.AddFunc(s.Schedule, func() {
if !s.FeeBudget.IsBudgetAllowed() {
return
}

s.placeOrder(ctx)
})
s.cron.Start()

return nil
}

func (s *Strategy) placeOrder() {
ctx := context.Background()

func (s *Strategy) placeOrder(ctx context.Context) {
baseBalance, ok := s.Session.GetAccount().Balance(s.Market.BaseCurrency)
if !ok {
log.Errorf("base balance not found")
Expand Down
101 changes: 18 additions & 83 deletions pkg/strategy/xgap/strategy.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,29 +37,9 @@ func (s *Strategy) InstanceID() string {
return fmt.Sprintf("%s:%s", ID, s.Symbol)
}

type State struct {
AccumulatedFeeStartedAt time.Time `json:"accumulatedFeeStartedAt,omitempty"`
AccumulatedFees map[string]fixedpoint.Value `json:"accumulatedFees,omitempty"`
AccumulatedVolume fixedpoint.Value `json:"accumulatedVolume,omitempty"`
}

func (s *State) IsOver24Hours() bool {
return time.Since(s.AccumulatedFeeStartedAt) >= 24*time.Hour
}

func (s *State) Reset() {
t := time.Now()
dateTime := time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, t.Location())

log.Infof("resetting accumulated started time to: %s", dateTime)

s.AccumulatedFeeStartedAt = dateTime
s.AccumulatedFees = make(map[string]fixedpoint.Value)
s.AccumulatedVolume = fixedpoint.Zero
}

type Strategy struct {
*common.Strategy
*common.FeeBudget

Environment *bbgo.Environment

Expand All @@ -70,18 +50,15 @@ type Strategy struct {
Quantity fixedpoint.Value `json:"quantity"`
DryRun bool `json:"dryRun"`

DailyFeeBudgets map[string]fixedpoint.Value `json:"dailyFeeBudgets,omitempty"`
DailyMaxVolume fixedpoint.Value `json:"dailyMaxVolume,omitempty"`
DailyTargetVolume fixedpoint.Value `json:"dailyTargetVolume,omitempty"`
UpdateInterval types.Duration `json:"updateInterval"`
SimulateVolume bool `json:"simulateVolume"`
SimulatePrice bool `json:"simulatePrice"`
DailyMaxVolume fixedpoint.Value `json:"dailyMaxVolume,omitempty"`
DailyTargetVolume fixedpoint.Value `json:"dailyTargetVolume,omitempty"`
UpdateInterval types.Duration `json:"updateInterval"`
SimulateVolume bool `json:"simulateVolume"`
SimulatePrice bool `json:"simulatePrice"`

sourceSession, tradingSession *bbgo.ExchangeSession
sourceMarket, tradingMarket types.Market

State *State `persistence:"state"`

mu sync.Mutex
lastSourceKLine, lastTradingKLine types.KLine
sourceBook, tradingBook *types.StreamOrderBook
Expand All @@ -93,6 +70,10 @@ func (s *Strategy) Initialize() error {
if s.Strategy == nil {
s.Strategy = &common.Strategy{}
}

if s.FeeBudget == nil {
s.FeeBudget = &common.FeeBudget{}
}
return nil
}

Expand All @@ -107,48 +88,6 @@ func (s *Strategy) Defaults() error {
return nil
}

func (s *Strategy) isBudgetAllowed() bool {
if s.DailyFeeBudgets == nil {
return true
}

if s.State.AccumulatedFees == nil {
return true
}

for asset, budget := range s.DailyFeeBudgets {
if fee, ok := s.State.AccumulatedFees[asset]; ok {
if fee.Compare(budget) >= 0 {
log.Warnf("accumulative fee %s exceeded the fee budget %s, skipping...", fee.String(), budget.String())
return false
}
}
}

return true
}

func (s *Strategy) handleTradeUpdate(trade types.Trade) {
log.Infof("received trade %s", trade.String())

if trade.Symbol != s.Symbol {
return
}

if s.State.IsOver24Hours() {
s.State.Reset()
}

// safe check
if s.State.AccumulatedFees == nil {
s.State.AccumulatedFees = make(map[string]fixedpoint.Value)
}

s.State.AccumulatedFees[trade.FeeCurrency] = s.State.AccumulatedFees[trade.FeeCurrency].Add(trade.Fee)
s.State.AccumulatedVolume = s.State.AccumulatedVolume.Add(trade.Quantity)
log.Infof("accumulated fee: %s %s", s.State.AccumulatedFees[trade.FeeCurrency].String(), trade.FeeCurrency)
}

func (s *Strategy) CrossSubscribe(sessions map[string]*bbgo.ExchangeSession) {
sourceSession, ok := sessions[s.SourceExchange]
if !ok {
Expand Down Expand Up @@ -191,19 +130,10 @@ func (s *Strategy) CrossRun(ctx context.Context, _ bbgo.OrderExecutionRouter, se
}

s.Strategy.Initialize(ctx, s.Environment, tradingSession, s.tradingMarket, ID, s.InstanceID())
s.FeeBudget.Initialize()

s.stopC = make(chan struct{})

if s.State == nil {
s.State = &State{}
s.State.Reset()
}

if s.State.IsOver24Hours() {
log.Warn("state is over 24 hours, resetting to zero")
s.State.Reset()
}

bbgo.OnShutdown(ctx, func(ctx context.Context, wg *sync.WaitGroup) {
defer wg.Done()
close(s.stopC)
Expand All @@ -230,7 +160,12 @@ func (s *Strategy) CrossRun(ctx context.Context, _ bbgo.OrderExecutionRouter, se
s.tradingBook = types.NewStreamBook(s.Symbol)
s.tradingBook.BindStream(s.tradingSession.MarketDataStream)

s.tradingSession.UserDataStream.OnTradeUpdate(s.handleTradeUpdate)
s.tradingSession.UserDataStream.OnTradeUpdate(func(trade types.Trade) {
if trade.Symbol != s.Symbol {
return
}
s.FeeBudget.HandleTradeUpdate(trade)
})

go func() {
ticker := time.NewTicker(
Expand All @@ -247,7 +182,7 @@ func (s *Strategy) CrossRun(ctx context.Context, _ bbgo.OrderExecutionRouter, se
return

case <-ticker.C:
if !s.isBudgetAllowed() {
if !s.IsBudgetAllowed() {
continue
}

Expand Down

0 comments on commit 396ee68

Please sign in to comment.