Skip to content

Commit

Permalink
support extra migration packages
Browse files Browse the repository at this point in the history
  • Loading branch information
c9s committed Jan 24, 2024
1 parent ea70544 commit 729afcd
Show file tree
Hide file tree
Showing 8 changed files with 50 additions and 33 deletions.
4 changes: 2 additions & 2 deletions pkg/bbgo/bootstrap.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func BootstrapEnvironmentLightweight(ctx context.Context, environ *Environment,
}

func BootstrapEnvironment(ctx context.Context, environ *Environment, userConfig *Config) error {
if err := environ.ConfigureDatabase(ctx); err != nil {
if err := environ.ConfigureDatabase(ctx, userConfig); err != nil {
return err
}

Expand Down Expand Up @@ -66,5 +66,5 @@ func BootstrapEnvironment(ctx context.Context, environ *Environment, userConfig
}

func BootstrapBacktestEnvironment(ctx context.Context, environ *Environment) error {
return environ.ConfigureDatabase(ctx)
return environ.ConfigureDatabase(ctx, nil)
}
9 changes: 9 additions & 0 deletions pkg/bbgo/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,13 @@ type ServiceConfig struct {
GoogleSpreadSheetService *GoogleSpreadSheetServiceConfig `json:"googleSpreadSheet" yaml:"googleSpreadSheet"`
}

type DatabaseConfig struct {
Driver string `json:"driver"`
DSN string `json:"dsn"`

ExtraMigrationPackages []string `json:"extraMigrationPackages"`
}

type EnvironmentConfig struct {
DisableDefaultKLineSubscription bool `json:"disableDefaultKLineSubscription"`
DisableHistoryKLinePreload bool `json:"disableHistoryKLinePreload"`
Expand Down Expand Up @@ -358,6 +365,8 @@ type Config struct {

Service *ServiceConfig `json:"services,omitempty" yaml:"services,omitempty"`

DatabaseConfig *DatabaseConfig `json:"database,omitempty" yaml:"database,omitempty"`

Environment *EnvironmentConfig `json:"environment,omitempty" yaml:"environment,omitempty"`

Sessions map[string]*ExchangeSession `json:"sessions,omitempty" yaml:"sessions,omitempty"`
Expand Down
40 changes: 27 additions & 13 deletions pkg/bbgo/environment.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,34 +162,48 @@ func (environ *Environment) SelectSessions(names ...string) map[string]*Exchange
return sessions
}

func (environ *Environment) ConfigureDatabase(ctx context.Context) error {
func (environ *Environment) ConfigureDatabase(ctx context.Context, config *Config) error {
// configureDB configures the database service based on the environment variable
if driver, ok := os.LookupEnv("DB_DRIVER"); ok {
var dbDriver string
var dbDSN string
var extraPkgNames []string

if dsn, ok := os.LookupEnv("DB_DSN"); ok {
return environ.ConfigureDatabaseDriver(ctx, driver, dsn)
}

} else if dsn, ok := os.LookupEnv("SQLITE3_DSN"); ok {

return environ.ConfigureDatabaseDriver(ctx, "sqlite3", dsn)
if config != nil && config.DatabaseConfig != nil {
dbDriver = config.DatabaseConfig.Driver
dbDSN = config.DatabaseConfig.DSN
extraPkgNames = config.DatabaseConfig.ExtraMigrationPackages
}

} else if dsn, ok := os.LookupEnv("MYSQL_URL"); ok {
if val, ok := os.LookupEnv("DB_DRIVER"); ok {
dbDriver = val
}

return environ.ConfigureDatabaseDriver(ctx, "mysql", dsn)
if val, ok := os.LookupEnv("DB_DSN"); ok {
dbDSN = val
} else if val, ok := os.LookupEnv("SQLITE3_DSN"); ok && (dbDriver == "" || dbDriver == "sqlite3") {
dbDSN = val
dbDriver = "sqlite3"
} else if val, ok := os.LookupEnv("MYSQL_URL"); ok && (dbDriver == "" || dbDriver == "mysql") {
dbDSN = val
dbDriver = "mysql"
}

if dbDriver == "" {
return fmt.Errorf("either env DB_DRIVER or config.Driver is not set")
}

return nil
return environ.ConfigureDatabaseDriver(ctx, dbDriver, dbDSN, extraPkgNames...)
}

func (environ *Environment) ConfigureDatabaseDriver(ctx context.Context, driver string, dsn string) error {
func (environ *Environment) ConfigureDatabaseDriver(ctx context.Context, driver string, dsn string, extraPkgNames ...string) error {
environ.DatabaseService = service.NewDatabaseService(driver, dsn)
err := environ.DatabaseService.Connect()
if err != nil {
return err
}

environ.DatabaseService.AddMigrationPackages(extraPkgNames...)

if err := environ.DatabaseService.Upgrade(ctx); err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/cmd/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ var accountCmd = &cobra.Command{
}

environ := bbgo.NewEnvironment()
if err := environ.ConfigureDatabase(ctx); err != nil {
if err := environ.ConfigureDatabase(ctx, userConfig); err != nil {
return err
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/cmd/market.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ var marketCmd = &cobra.Command{
}

environ := bbgo.NewEnvironment()
if err := environ.ConfigureDatabase(ctx); err != nil {
if err := environ.ConfigureDatabase(ctx, userConfig); err != nil {
return err
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/cmd/pnl.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ var PnLCmd = &cobra.Command{

environ := bbgo.NewEnvironment()

if err := environ.ConfigureDatabase(ctx); err != nil {
if err := environ.ConfigureDatabase(ctx, userConfig); err != nil {
return err
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/cmd/sync.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ var SyncCmd = &cobra.Command{
}

environ := bbgo.NewEnvironment()
if err := environ.ConfigureDatabase(ctx); err != nil {
if err := environ.ConfigureDatabase(ctx, userConfig); err != nil {
return err
}

Expand Down
22 changes: 8 additions & 14 deletions pkg/service/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ type DatabaseService struct {
Driver string
DSN string
DB *sqlx.DB

migrationPackages []string
}

func NewDatabaseService(driver, dsn string) *DatabaseService {
Expand All @@ -35,7 +37,6 @@ func NewDatabaseService(driver, dsn string) *DatabaseService {
Driver: driver,
DSN: dsn,
}

}

func (s *DatabaseService) Connect() error {
Expand All @@ -50,6 +51,10 @@ func (s *DatabaseService) Insert(record interface{}) error {
return err
}

func (s *DatabaseService) AddMigrationPackages(pkgNames ...string) {
s.migrationPackages = append(s.migrationPackages, pkgNames...)
}

func (s *DatabaseService) Close() error {
return s.DB.Close()
}
Expand Down Expand Up @@ -77,23 +82,12 @@ func (s *DatabaseService) Upgrade(ctx context.Context) error {
return err
}

migrations = migrations.FilterPackage([]string{"main"}).SortAndConnect()
if len(migrations) == 0 {
return nil
}

_, lastAppliedMigration, err := rh.FindLastAppliedMigration(ctx, migrations)
if err != nil {
return err
}

if lastAppliedMigration != nil {
return rockhopper.Up(ctx, rh, lastAppliedMigration.Next, 0)
}

// TODO: use align in the next major version
// return rockhopper.Align(ctx, rh, 20231123125402, migrations)
return rockhopper.Up(ctx, rh, migrations.Head(), 0)
pkgNames := append([]string{rockhopper.DefaultPackageName}, s.migrationPackages...)
return rockhopper.Upgrade(ctx, rh, migrations.FilterPackage(pkgNames))

Check failure on line 90 in pkg/service/database.go

View workflow job for this annotation

GitHub Actions / build (6.2, 1.20)

undefined: rockhopper.Upgrade
}

func ReformatMysqlDSN(dsn string) (string, error) {
Expand Down

0 comments on commit 729afcd

Please sign in to comment.