diff --git a/pkg/bbgo/bootstrap.go b/pkg/bbgo/bootstrap.go index bb297e2db4..cda97eaa4c 100644 --- a/pkg/bbgo/bootstrap.go +++ b/pkg/bbgo/bootstrap.go @@ -34,7 +34,7 @@ func BootstrapEnvironmentLightweight(ctx context.Context, environ *Environment, } func BootstrapEnvironment(ctx context.Context, environ *Environment, userConfig *Config) error { - if err := environ.ConfigureDatabase(ctx); err != nil { + if err := environ.ConfigureDatabase(ctx, userConfig); err != nil { return err } @@ -66,5 +66,5 @@ func BootstrapEnvironment(ctx context.Context, environ *Environment, userConfig } func BootstrapBacktestEnvironment(ctx context.Context, environ *Environment) error { - return environ.ConfigureDatabase(ctx) + return environ.ConfigureDatabase(ctx, nil) } diff --git a/pkg/bbgo/config.go b/pkg/bbgo/config.go index 1121717b90..2fce203da6 100644 --- a/pkg/bbgo/config.go +++ b/pkg/bbgo/config.go @@ -326,6 +326,13 @@ type ServiceConfig struct { GoogleSpreadSheetService *GoogleSpreadSheetServiceConfig `json:"googleSpreadSheet" yaml:"googleSpreadSheet"` } +type DatabaseConfig struct { + Driver string `json:"driver"` + DSN string `json:"dsn"` + + ExtraMigrationPackages []string `json:"extraMigrationPackages"` +} + type EnvironmentConfig struct { DisableDefaultKLineSubscription bool `json:"disableDefaultKLineSubscription"` DisableHistoryKLinePreload bool `json:"disableHistoryKLinePreload"` @@ -358,6 +365,8 @@ type Config struct { Service *ServiceConfig `json:"services,omitempty" yaml:"services,omitempty"` + DatabaseConfig *DatabaseConfig `json:"database,omitempty" yaml:"database,omitempty"` + Environment *EnvironmentConfig `json:"environment,omitempty" yaml:"environment,omitempty"` Sessions map[string]*ExchangeSession `json:"sessions,omitempty" yaml:"sessions,omitempty"` diff --git a/pkg/bbgo/environment.go b/pkg/bbgo/environment.go index e5c7189e2e..c1f8fc22d2 100644 --- a/pkg/bbgo/environment.go +++ b/pkg/bbgo/environment.go @@ -162,34 +162,48 @@ func (environ *Environment) SelectSessions(names ...string) map[string]*Exchange return sessions } -func (environ *Environment) ConfigureDatabase(ctx context.Context) error { +func (environ *Environment) ConfigureDatabase(ctx context.Context, config *Config) error { // configureDB configures the database service based on the environment variable - if driver, ok := os.LookupEnv("DB_DRIVER"); ok { + var dbDriver string + var dbDSN string + var extraPkgNames []string - if dsn, ok := os.LookupEnv("DB_DSN"); ok { - return environ.ConfigureDatabaseDriver(ctx, driver, dsn) - } - - } else if dsn, ok := os.LookupEnv("SQLITE3_DSN"); ok { - - return environ.ConfigureDatabaseDriver(ctx, "sqlite3", dsn) + if config != nil && config.DatabaseConfig != nil { + dbDriver = config.DatabaseConfig.Driver + dbDSN = config.DatabaseConfig.DSN + extraPkgNames = config.DatabaseConfig.ExtraMigrationPackages + } - } else if dsn, ok := os.LookupEnv("MYSQL_URL"); ok { + if val, ok := os.LookupEnv("DB_DRIVER"); ok { + dbDriver = val + } - return environ.ConfigureDatabaseDriver(ctx, "mysql", dsn) + if val, ok := os.LookupEnv("DB_DSN"); ok { + dbDSN = val + } else if val, ok := os.LookupEnv("SQLITE3_DSN"); ok && (dbDriver == "" || dbDriver == "sqlite3") { + dbDSN = val + dbDriver = "sqlite3" + } else if val, ok := os.LookupEnv("MYSQL_URL"); ok && (dbDriver == "" || dbDriver == "mysql") { + dbDSN = val + dbDriver = "mysql" + } + if dbDriver == "" { + return fmt.Errorf("either env DB_DRIVER or config.Driver is not set") } - return nil + return environ.ConfigureDatabaseDriver(ctx, dbDriver, dbDSN, extraPkgNames...) } -func (environ *Environment) ConfigureDatabaseDriver(ctx context.Context, driver string, dsn string) error { +func (environ *Environment) ConfigureDatabaseDriver(ctx context.Context, driver string, dsn string, extraPkgNames ...string) error { environ.DatabaseService = service.NewDatabaseService(driver, dsn) err := environ.DatabaseService.Connect() if err != nil { return err } + environ.DatabaseService.AddMigrationPackages(extraPkgNames...) + if err := environ.DatabaseService.Upgrade(ctx); err != nil { return err } diff --git a/pkg/cmd/account.go b/pkg/cmd/account.go index 18d5ff6f2a..ee6d27352d 100644 --- a/pkg/cmd/account.go +++ b/pkg/cmd/account.go @@ -38,7 +38,7 @@ var accountCmd = &cobra.Command{ } environ := bbgo.NewEnvironment() - if err := environ.ConfigureDatabase(ctx); err != nil { + if err := environ.ConfigureDatabase(ctx, userConfig); err != nil { return err } diff --git a/pkg/cmd/market.go b/pkg/cmd/market.go index 1ff22a25ca..05f0b5bdee 100644 --- a/pkg/cmd/market.go +++ b/pkg/cmd/market.go @@ -44,7 +44,7 @@ var marketCmd = &cobra.Command{ } environ := bbgo.NewEnvironment() - if err := environ.ConfigureDatabase(ctx); err != nil { + if err := environ.ConfigureDatabase(ctx, userConfig); err != nil { return err } diff --git a/pkg/cmd/pnl.go b/pkg/cmd/pnl.go index 8d3fd42317..d14fa8899c 100644 --- a/pkg/cmd/pnl.go +++ b/pkg/cmd/pnl.go @@ -88,7 +88,7 @@ var PnLCmd = &cobra.Command{ environ := bbgo.NewEnvironment() - if err := environ.ConfigureDatabase(ctx); err != nil { + if err := environ.ConfigureDatabase(ctx, userConfig); err != nil { return err } diff --git a/pkg/cmd/sync.go b/pkg/cmd/sync.go index a898c3fe4a..28f86b8b0f 100644 --- a/pkg/cmd/sync.go +++ b/pkg/cmd/sync.go @@ -49,7 +49,7 @@ var SyncCmd = &cobra.Command{ } environ := bbgo.NewEnvironment() - if err := environ.ConfigureDatabase(ctx); err != nil { + if err := environ.ConfigureDatabase(ctx, userConfig); err != nil { return err } diff --git a/pkg/service/database.go b/pkg/service/database.go index 5f02225b43..73433e1408 100644 --- a/pkg/service/database.go +++ b/pkg/service/database.go @@ -19,6 +19,8 @@ type DatabaseService struct { Driver string DSN string DB *sqlx.DB + + migrationPackages []string } func NewDatabaseService(driver, dsn string) *DatabaseService { @@ -35,7 +37,6 @@ func NewDatabaseService(driver, dsn string) *DatabaseService { Driver: driver, DSN: dsn, } - } func (s *DatabaseService) Connect() error { @@ -50,6 +51,10 @@ func (s *DatabaseService) Insert(record interface{}) error { return err } +func (s *DatabaseService) AddMigrationPackages(pkgNames ...string) { + s.migrationPackages = append(s.migrationPackages, pkgNames...) +} + func (s *DatabaseService) Close() error { return s.DB.Close() } @@ -77,23 +82,12 @@ func (s *DatabaseService) Upgrade(ctx context.Context) error { return err } - migrations = migrations.FilterPackage([]string{"main"}).SortAndConnect() if len(migrations) == 0 { return nil } - _, lastAppliedMigration, err := rh.FindLastAppliedMigration(ctx, migrations) - if err != nil { - return err - } - - if lastAppliedMigration != nil { - return rockhopper.Up(ctx, rh, lastAppliedMigration.Next, 0) - } - - // TODO: use align in the next major version - // return rockhopper.Align(ctx, rh, 20231123125402, migrations) - return rockhopper.Up(ctx, rh, migrations.Head(), 0) + pkgNames := append([]string{rockhopper.DefaultPackageName}, s.migrationPackages...) + return rockhopper.Upgrade(ctx, rh, migrations.FilterPackage(pkgNames)) } func ReformatMysqlDSN(dsn string) (string, error) {