From 29779a80256fab0d79df872d14b3f22deb676ab7 Mon Sep 17 00:00:00 2001 From: Geoffrey Wilson Date: Wed, 29 Nov 2023 12:29:11 -0500 Subject: [PATCH] =?UTF-8?q?New=20migration=20to=20remove=20the=20old=20`ex?= =?UTF-8?q?periments=5Fname=5Fkey`=20constraint=20fro=E2=80=A6=20(#663)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * New migration to remove the old `experiments_name_key` constraint from PG --- pkg/database/migrate.go | 12 +- pkg/database/migrations/v_0007/migrate.go | 41 ++++ pkg/database/migrations/v_0007/model.go | 238 ++++++++++++++++++++++ 3 files changed, 289 insertions(+), 2 deletions(-) create mode 100644 pkg/database/migrations/v_0007/migrate.go create mode 100644 pkg/database/migrations/v_0007/model.go diff --git a/pkg/database/migrate.go b/pkg/database/migrate.go index 063d694a3..de8ec42eb 100644 --- a/pkg/database/migrate.go +++ b/pkg/database/migrate.go @@ -19,6 +19,7 @@ import ( "github.com/G-Research/fasttrackml/pkg/database/migrations/v_0004" "github.com/G-Research/fasttrackml/pkg/database/migrations/v_0005" "github.com/G-Research/fasttrackml/pkg/database/migrations/v_0006" + "github.com/G-Research/fasttrackml/pkg/database/migrations/v_0007" ) var supportedAlembicVersions = []string{ @@ -40,7 +41,7 @@ func CheckAndMigrateDB(migrate bool, db *gorm.DB) error { tx.First(&schemaVersion) } - if !slices.Contains(supportedAlembicVersions, alembicVersion.Version) || schemaVersion.Version != "e0d125c68d9a" { + if !slices.Contains(supportedAlembicVersions, alembicVersion.Version) || schemaVersion.Version != v_0007.Version { if !migrate && alembicVersion.Version != "" { return fmt.Errorf( "unsupported database schema versions alembic %s, FastTrackML %s", @@ -164,6 +165,13 @@ func CheckAndMigrateDB(migrate bool, db *gorm.DB) error { if err := v_0006.Migrate(db); err != nil { return fmt.Errorf("error migrating database to FastTrackML schema %s: %w", v_0006.Version, err) } + fallthrough + + case v_0006.Version: + log.Infof("Migrating database to FastTrackML schema %s", v_0007.Version) + if err := v_0007.Migrate(db); err != nil { + return fmt.Errorf("error migrating database to FastTrackML schema %s: %w", v_0007.Version, err) + } default: return fmt.Errorf("unsupported database FastTrackML schema version %s", schemaVersion.Version) @@ -194,7 +202,7 @@ func CheckAndMigrateDB(migrate bool, db *gorm.DB) error { Version: "97727af70f4d", }) tx.Create(&SchemaVersion{ - Version: v_0006.Version, + Version: v_0007.Version, }) tx.Commit() if tx.Error != nil { diff --git a/pkg/database/migrations/v_0007/migrate.go b/pkg/database/migrations/v_0007/migrate.go new file mode 100644 index 000000000..857cc6603 --- /dev/null +++ b/pkg/database/migrations/v_0007/migrate.go @@ -0,0 +1,41 @@ +package v_0007 + +import ( + "fmt" + + "gorm.io/driver/postgres" + "gorm.io/driver/sqlite" + "gorm.io/gorm" + + "github.com/G-Research/fasttrackml/pkg/database/migrations" +) + +const Version = "cbc41c0f4fc5" + +func Migrate(db *gorm.DB) error { + // We need to run this migration without foreign key constraints to avoid + // the cascading delete to kick in and delete all the runs. + return migrations.RunWithoutForeignKeyIfNeeded(db, func() error { + return db.Transaction(func(tx *gorm.DB) error { + switch tx.Dialector.Name() { + case sqlite.Dialector{}.Name(): + // SQLite no action needed + case postgres.Dialector{}.Name(): + // Postgres needs to remove this constraint + constraint := "experiments_name_key" + if tx.Migrator().HasConstraint("experiments", constraint) { + if err := tx.Migrator().DropConstraint("experiments", constraint); err != nil { + return err + } + } + default: + return fmt.Errorf("unsupported database dialect %s", tx.Dialector.Name()) + } + + return tx.Model(&SchemaVersion{}). + Where("1 = 1"). + Update("Version", Version). + Error + }) + }) +} diff --git a/pkg/database/migrations/v_0007/model.go b/pkg/database/migrations/v_0007/model.go new file mode 100644 index 000000000..c966199fa --- /dev/null +++ b/pkg/database/migrations/v_0007/model.go @@ -0,0 +1,238 @@ +package v_0007 + +import ( + "context" + "database/sql" + "database/sql/driver" + "encoding/hex" + "encoding/json" + "time" + + "github.com/google/uuid" + "gorm.io/gorm" + "gorm.io/gorm/clause" +) + +type Status string + +const ( + StatusRunning Status = "RUNNING" + StatusScheduled Status = "SCHEDULED" + StatusFinished Status = "FINISHED" + StatusFailed Status = "FAILED" + StatusKilled Status = "KILLED" +) + +type LifecycleStage string + +const ( + LifecycleStageActive LifecycleStage = "active" + LifecycleStageDeleted LifecycleStage = "deleted" +) + +type Namespace struct { + ID uint `gorm:"primaryKey;autoIncrement" json:"id"` + Apps []App `gorm:"constraint:OnDelete:CASCADE" json:"apps"` + Code string `gorm:"unique;index;not null" json:"code"` + Description string `json:"description"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + DeletedAt gorm.DeletedAt `gorm:"index" json:"deleted_at"` + DefaultExperimentID *int32 `gorm:"not null" json:"default_experiment_id"` + Experiments []Experiment `gorm:"constraint:OnDelete:CASCADE" json:"experiments"` +} + +type Experiment struct { + ID *int32 `gorm:"column:experiment_id;not null;primaryKey"` + Name string `gorm:"type:varchar(256);not null;index:,unique,composite:name"` + ArtifactLocation string `gorm:"type:varchar(256)"` + LifecycleStage LifecycleStage `gorm:"type:varchar(32);check:lifecycle_stage IN ('active', 'deleted')"` + CreationTime sql.NullInt64 `gorm:"type:bigint"` + LastUpdateTime sql.NullInt64 `gorm:"type:bigint"` + NamespaceID uint `gorm:"index:,unique,composite:name"` + Namespace Namespace + Tags []ExperimentTag `gorm:"constraint:OnDelete:CASCADE"` + Runs []Run `gorm:"constraint:OnDelete:CASCADE"` +} + +type ExperimentTag struct { + Key string `gorm:"type:varchar(250);not null;primaryKey"` + Value string `gorm:"type:varchar(5000)"` + ExperimentID int32 `gorm:"not null;primaryKey"` +} + +//nolint:lll +type Run struct { + ID string `gorm:"<-:create;column:run_uuid;type:varchar(32);not null;primaryKey"` + Name string `gorm:"type:varchar(250)"` + SourceType string `gorm:"<-:create;type:varchar(20);check:source_type IN ('NOTEBOOK', 'JOB', 'LOCAL', 'UNKNOWN', 'PROJECT')"` + SourceName string `gorm:"<-:create;type:varchar(500)"` + EntryPointName string `gorm:"<-:create;type:varchar(50)"` + UserID string `gorm:"<-:create;type:varchar(256)"` + Status Status `gorm:"type:varchar(9);check:status IN ('SCHEDULED', 'FAILED', 'FINISHED', 'RUNNING', 'KILLED')"` + StartTime sql.NullInt64 `gorm:"<-:create;type:bigint"` + EndTime sql.NullInt64 `gorm:"type:bigint"` + SourceVersion string `gorm:"<-:create;type:varchar(50)"` + LifecycleStage LifecycleStage `gorm:"type:varchar(20);check:lifecycle_stage IN ('active', 'deleted')"` + ArtifactURI string `gorm:"<-:create;type:varchar(200)"` + ExperimentID int32 + Experiment Experiment + DeletedTime sql.NullInt64 `gorm:"type:bigint"` + RowNum RowNum `gorm:"<-:create;index"` + Params []Param `gorm:"constraint:OnDelete:CASCADE"` + Tags []Tag `gorm:"constraint:OnDelete:CASCADE"` + Metrics []Metric `gorm:"constraint:OnDelete:CASCADE"` + LatestMetrics []LatestMetric `gorm:"constraint:OnDelete:CASCADE"` +} + +type RowNum int64 + +func (rn *RowNum) Scan(v interface{}) error { + nullInt := sql.NullInt64{} + if err := nullInt.Scan(v); err != nil { + return err + } + *rn = RowNum(nullInt.Int64) + return nil +} + +func (rn RowNum) GormDataType() string { + return "bigint" +} + +func (rn RowNum) GormValue(ctx context.Context, db *gorm.DB) clause.Expr { + if rn == 0 { + return clause.Expr{ + SQL: "(SELECT COALESCE(MAX(row_num), -1) FROM runs) + 1", + } + } + return clause.Expr{ + SQL: "?", + Vars: []interface{}{int64(rn)}, + } +} + +type Param struct { + Key string `gorm:"type:varchar(250);not null;primaryKey"` + Value string `gorm:"type:varchar(500);not null"` + RunID string `gorm:"column:run_uuid;not null;primaryKey;index"` +} + +type Tag struct { + Key string `gorm:"type:varchar(250);not null;primaryKey"` + Value string `gorm:"type:varchar(5000)"` + RunID string `gorm:"column:run_uuid;not null;primaryKey;index"` +} + +type Metric struct { + Key string `gorm:"type:varchar(250);not null;primaryKey"` + Value float64 `gorm:"type:double precision;not null;primaryKey"` + Timestamp int64 `gorm:"not null;primaryKey"` + RunID string `gorm:"column:run_uuid;not null;primaryKey;index"` + Step int64 `gorm:"default:0;not null;primaryKey"` + IsNan bool `gorm:"default:false;not null;primaryKey"` + Iter int64 `gorm:"index"` +} + +type LatestMetric struct { + Key string `gorm:"type:varchar(250);not null;primaryKey"` + Value float64 `gorm:"type:double precision;not null"` + Timestamp int64 + Step int64 `gorm:"not null"` + IsNan bool `gorm:"not null"` + RunID string `gorm:"column:run_uuid;not null;primaryKey;index"` + LastIter int64 +} + +type AlembicVersion struct { + Version string `gorm:"column:version_num;type:varchar(32);not null;primaryKey"` +} + +func (AlembicVersion) TableName() string { + return "alembic_version" +} + +type SchemaVersion struct { + Version string `gorm:"not null;primaryKey"` +} + +func (SchemaVersion) TableName() string { + return "schema_version" +} + +type Base struct { + ID uuid.UUID `gorm:"type:uuid;primaryKey" json:"id"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + IsArchived bool `json:"-"` +} + +func (b *Base) BeforeCreate(tx *gorm.DB) error { + b.ID = uuid.New() + return nil +} + +type Dashboard struct { + Base + Name string `json:"name"` + Description string `json:"description"` + AppID *uuid.UUID `gorm:"type:uuid" json:"app_id"` + App App `json:"-"` +} + +func (d Dashboard) MarshalJSON() ([]byte, error) { + type localDashboard Dashboard + type jsonDashboard struct { + localDashboard + AppType *string `json:"app_type"` + } + jd := jsonDashboard{ + localDashboard: localDashboard(d), + } + if d.App.IsArchived { + jd.AppID = nil + } else { + jd.AppType = &d.App.Type + } + return json.Marshal(jd) +} + +type App struct { + Base + Type string `gorm:"not null" json:"type"` + State AppState `json:"state"` + Namespace Namespace `json:"-"` + NamespaceID uint `gorm:"column:namespace_id" json:"-"` +} + +type AppState map[string]any + +func (s AppState) Value() (driver.Value, error) { + v, err := json.Marshal(s) + if err != nil { + return nil, err + } + return string(v), nil +} + +func (s *AppState) Scan(v interface{}) error { + var nullS sql.NullString + if err := nullS.Scan(v); err != nil { + return err + } + if nullS.Valid { + return json.Unmarshal([]byte(nullS.String), s) + } + return nil +} + +func (s AppState) GormDataType() string { + return "text" +} + +func NewUUID() string { + var r [32]byte + u := uuid.New() + hex.Encode(r[:], u[:]) + return string(r[:]) +}