diff --git a/go.mod b/go.mod index 1a818a0..a73a39f 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/stretchr/testify v1.8.4 gorm.io/driver/sqlite v1.5.4 gorm.io/gorm v1.25.5 - goyave.dev/goyave/v5 v5.0.0-preview5 + goyave.dev/goyave/v5 v5.0.0-preview6.0.20231201171501-32722c77ca39 ) require ( @@ -18,7 +18,7 @@ require ( github.com/jinzhu/now v1.1.5 // indirect github.com/mattn/go-sqlite3 v1.14.18 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - golang.org/x/exp v0.0.0-20231006140011-7918f672742d // indirect + golang.org/x/exp v0.0.0-20231108232855-2478ac86f678 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect goyave.dev/copier v0.4.3 // indirect ) diff --git a/go.sum b/go.sum index 0cfa179..68354da 100644 --- a/go.sum +++ b/go.sum @@ -16,8 +16,8 @@ github.com/samber/lo v1.38.1 h1:j2XEAqXKb09Am4ebOg31SpvzUTTs6EN3VfgeLUhPdXM= github.com/samber/lo v1.38.1/go.mod h1:+m/ZKRl6ClXCE2Lgf3MsQlWfh4bn1bz6CXEOxnEXnEA= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= -golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI= -golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo= +golang.org/x/exp v0.0.0-20231108232855-2478ac86f678 h1:mchzmB1XO2pMaKFRqk/+MV3mgGG96aqaPXaMifQU47w= +golang.org/x/exp v0.0.0-20231108232855-2478ac86f678/go.mod h1:zk2irFbV9DP96SEBUUAy67IdHUaZuSnrz1n472HUCLE= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= @@ -28,5 +28,5 @@ gorm.io/gorm v1.25.5 h1:zR9lOiiYf09VNh5Q1gphfyia1JpiClIWG9hQaxB/mls= gorm.io/gorm v1.25.5/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= goyave.dev/copier v0.4.3 h1:MxX2wBnhQUbv0mHPXEgw/zS4TZMtTVpzj/aYS3h4amk= goyave.dev/copier v0.4.3/go.mod h1:WJu0Ex81v29f5U0eMWzSNsMTGmuGY6lQ/q5yGlyLDsU= -goyave.dev/goyave/v5 v5.0.0-preview5 h1:mcX7mAzyxJGCMuNrLOAhAN3rLQNOqV4zdBBvGe+NPCs= -goyave.dev/goyave/v5 v5.0.0-preview5/go.mod h1:VgRY1EJfaQrW2UfDiXQYe3tK1YuUDGQW0RJOnlmbgic= +goyave.dev/goyave/v5 v5.0.0-preview6.0.20231201171501-32722c77ca39 h1:BZwps/AOhKBafAI9RUUt4UcEVTym938e7a2oM/xt7VE= +goyave.dev/goyave/v5 v5.0.0-preview6.0.20231201171501-32722c77ca39/go.mod h1:sCz3xtCWJhFQvDqIaRvcqO4rWhDLIEBb678zTxAxgHo= diff --git a/settings.go b/settings.go index a78ad5d..29eaf78 100644 --- a/settings.go +++ b/settings.go @@ -125,7 +125,7 @@ func parseModel(db *gorm.DB, model any) (*schema.Schema, error) { } // Scope using the default FilterSettings. See `FilterSettings.Scope()` for more details. -func Scope[T any](db *gorm.DB, request *Request, dest *[]T) (*database.Paginator[T], *gorm.DB) { +func Scope[T any](db *gorm.DB, request *Request, dest *[]T) (*database.Paginator[T], error) { return (&Settings[T]{}).Scope(db, request, dest) } @@ -135,26 +135,33 @@ func ScopeUnpaginated[T any](db *gorm.DB, request *Request, dest *[]T) *gorm.DB } // Scope apply all filters, sorts and joins defined in the request's data to the given `*gorm.DB` -// and process pagination. Returns the resulting `*database.Paginator` and the `*gorm.DB` result, -// which can be used to check for database errors. +// and process pagination. Returns the resulting `*database.Paginator`. // The given request is expected to be validated using `ApplyValidation`. -func (s *Settings[T]) Scope(db *gorm.DB, request *Request, dest *[]T) (*database.Paginator[T], *gorm.DB) { - db, schema, hasJoins := s.scopeCommon(db, request, dest) +func (s *Settings[T]) Scope(db *gorm.DB, request *Request, dest *[]T) (*database.Paginator[T], error) { page := request.Page.Default(1) pageSize := request.PerPage.Default(DefaultPageSize) - paginator := database.NewPaginator(db, page, pageSize, dest) - paginator.UpdatePageInfo() + var paginator *database.Paginator[T] + err := db.Transaction(func(tx *gorm.DB) error { + tx, schema, hasJoins := s.scopeCommon(tx, request, dest) - paginator.DB = s.scopeSort(paginator.DB, request, schema) - if fieldsDB := s.scopeFields(paginator.DB, request, schema, hasJoins); fieldsDB != nil { - paginator.DB = fieldsDB - } else { - return nil, paginator.DB - } + paginator = database.NewPaginator(tx, page, pageSize, dest) + err := paginator.UpdatePageInfo() + if err != nil { + return errors.New(err) + } + paginator.DB = s.scopeSort(paginator.DB, request, schema) + if fieldsDB := s.scopeFields(paginator.DB, request, schema, hasJoins); fieldsDB != nil { + paginator.DB = fieldsDB + } else { + return errors.New(paginator.DB.Error) + } + + return paginator.Find() + }) - return paginator, paginator.Find() + return paginator, err } // ScopeUnpaginated apply all filters, sorts and joins defined in the request's data to the given `*gorm.DB` diff --git a/settings_test.go b/settings_test.go index 904e8f5..b37cb48 100644 --- a/settings_test.go +++ b/settings_test.go @@ -11,6 +11,7 @@ import ( "gorm.io/gorm/clause" "gorm.io/gorm/schema" "goyave.dev/goyave/v5/database" + "goyave.dev/goyave/v5/util/typeutil" ) @@ -45,7 +46,7 @@ func openDryRunDB(t *testing.T) *gorm.DB { return db } -func prepareTestScope(t *testing.T, settings *Settings[*TestScopeModel]) (*database.Paginator[*TestScopeModel], *gorm.DB) { +func prepareTestScope(t *testing.T, settings *Settings[*TestScopeModel]) (*database.Paginator[*TestScopeModel], error) { request := &Request{ Filter: typeutil.NewUndefined([]*Filter{ {Field: "name", Args: []string{"val1"}, Operator: Operators["$cont"]}, @@ -98,7 +99,7 @@ func prepareTestScopeUnpaginated(t *testing.T, settings *Settings[*TestScopeMode } func TestScope(t *testing.T) { - paginator, db := prepareTestScope(t, &Settings[*TestScopeModel]{ + paginator, err := prepareTestScope(t, &Settings[*TestScopeModel]{ FieldsSearch: []string{"email"}, SearchOperator: &Operator{ Function: func(tx *gorm.DB, filter *Filter, column string, dataType DataType) *gorm.DB { @@ -108,6 +109,7 @@ func TestScope(t *testing.T) { }, }) assert.NotNil(t, paginator) + assert.NoError(t, err) expected := map[string]clause.Clause{ "WHERE": { @@ -185,9 +187,9 @@ func TestScope(t *testing.T) { }, }, } - assert.Equal(t, expected, db.Statement.Clauses) - assert.Contains(t, db.Statement.Preloads, "Relation") - assert.Equal(t, []string{"`test_scope_models`.`id`", "`test_scope_models`.`name`", "`test_scope_models`.`email`", "(UPPER(`test_scope_models`.name)) `computed`", "`test_scope_models`.`relation_id`"}, db.Statement.Selects) + assert.Equal(t, expected, paginator.DB.Statement.Clauses) + assert.Contains(t, paginator.DB.Statement.Preloads, "Relation") + assert.Equal(t, []string{"`test_scope_models`.`id`", "`test_scope_models`.`name`", "`test_scope_models`.`email`", "(UPPER(`test_scope_models`.name)) `computed`", "`test_scope_models`.`relation_id`"}, paginator.DB.Statement.Selects) } func TestScopeUnpaginated(t *testing.T) { @@ -278,8 +280,9 @@ func TestScopeUnpaginated(t *testing.T) { } func TestScopeDisableFields(t *testing.T) { - paginator, db := prepareTestScope(t, &Settings[*TestScopeModel]{DisableFields: true, FieldsSearch: []string{"email"}}) + paginator, err := prepareTestScope(t, &Settings[*TestScopeModel]{DisableFields: true, FieldsSearch: []string{"email"}}) assert.NotNil(t, paginator) + assert.NoError(t, err) expected := map[string]clause.Clause{ "WHERE": { @@ -357,8 +360,8 @@ func TestScopeDisableFields(t *testing.T) { }, }, } - assert.Equal(t, expected, db.Statement.Clauses) - assert.ElementsMatch(t, []string{"`test_scope_models`.`id`", "`test_scope_models`.`relation_id`", "`test_scope_models`.`name`", "`test_scope_models`.`email`", "(UPPER(`test_scope_models`.name)) `computed`"}, db.Statement.Selects) + assert.Equal(t, expected, paginator.DB.Statement.Clauses) + assert.ElementsMatch(t, []string{"`test_scope_models`.`id`", "`test_scope_models`.`relation_id`", "`test_scope_models`.`name`", "`test_scope_models`.`email`", "(UPPER(`test_scope_models`.name)) `computed`"}, paginator.DB.Statement.Selects) } func TestScopeUnpaginatedDisableFields(t *testing.T) { @@ -440,8 +443,9 @@ func TestScopeUnpaginatedDisableFields(t *testing.T) { } func TestScopeDisableFilter(t *testing.T) { - paginator, db := prepareTestScope(t, &Settings[*TestScopeModel]{DisableFilter: true, FieldsSearch: []string{"email"}}) + paginator, err := prepareTestScope(t, &Settings[*TestScopeModel]{DisableFilter: true, FieldsSearch: []string{"email"}}) assert.NotNil(t, paginator) + assert.NoError(t, err) expected := map[string]clause.Clause{ "WHERE": { @@ -496,8 +500,8 @@ func TestScopeDisableFilter(t *testing.T) { }, }, } - assert.Equal(t, expected, db.Statement.Clauses) - assert.ElementsMatch(t, []string{"`test_scope_models`.`id`", "`test_scope_models`.`name`", "`test_scope_models`.`email`", "`test_scope_models`.`relation_id`", "(UPPER(`test_scope_models`.name)) `computed`"}, db.Statement.Selects) + assert.Equal(t, expected, paginator.DB.Statement.Clauses) + assert.ElementsMatch(t, []string{"`test_scope_models`.`id`", "`test_scope_models`.`name`", "`test_scope_models`.`email`", "`test_scope_models`.`relation_id`", "(UPPER(`test_scope_models`.name)) `computed`"}, paginator.DB.Statement.Selects) } func TestScopeUnpaginatedDisableFilter(t *testing.T) { @@ -556,8 +560,9 @@ func TestScopeUnpaginatedDisableFilter(t *testing.T) { } func TestScopeDisableSort(t *testing.T) { - paginator, db := prepareTestScope(t, &Settings[*TestScopeModel]{DisableSort: true, FieldsSearch: []string{"email"}}) + paginator, err := prepareTestScope(t, &Settings[*TestScopeModel]{DisableSort: true, FieldsSearch: []string{"email"}}) assert.NotNil(t, paginator) + assert.NoError(t, err) expected := map[string]clause.Clause{ "WHERE": { @@ -622,8 +627,8 @@ func TestScopeDisableSort(t *testing.T) { }, }, } - assert.Equal(t, expected, db.Statement.Clauses) - assert.ElementsMatch(t, []string{"`test_scope_models`.`id`", "`test_scope_models`.`name`", "`test_scope_models`.`email`", "`test_scope_models`.`relation_id`", "(UPPER(`test_scope_models`.name)) `computed`"}, db.Statement.Selects) + assert.Equal(t, expected, paginator.DB.Statement.Clauses) + assert.ElementsMatch(t, []string{"`test_scope_models`.`id`", "`test_scope_models`.`name`", "`test_scope_models`.`email`", "`test_scope_models`.`relation_id`", "(UPPER(`test_scope_models`.name)) `computed`"}, paginator.DB.Statement.Selects) } func TestScopeUnpaginatedDisableSort(t *testing.T) { @@ -692,8 +697,9 @@ func TestScopeUnpaginatedDisableSort(t *testing.T) { } func TestScopeDisableJoin(t *testing.T) { - paginator, db := prepareTestScope(t, &Settings[*TestScopeModel]{DisableJoin: true, FieldsSearch: []string{"email"}}) + paginator, err := prepareTestScope(t, &Settings[*TestScopeModel]{DisableJoin: true, FieldsSearch: []string{"email"}}) assert.NotNil(t, paginator) + assert.NoError(t, err) expected := map[string]clause.Clause{ "WHERE": { @@ -770,9 +776,9 @@ func TestScopeDisableJoin(t *testing.T) { }, }, } - assert.Equal(t, expected, db.Statement.Clauses) - assert.Empty(t, db.Statement.Preloads) - assert.ElementsMatch(t, []string{"`test_scope_models`.`id`", "`test_scope_models`.`name`", "`test_scope_models`.`email`", "(UPPER(`test_scope_models`.name)) `computed`"}, db.Statement.Selects) + assert.Equal(t, expected, paginator.DB.Statement.Clauses) + assert.Empty(t, paginator.DB.Statement.Preloads) + assert.ElementsMatch(t, []string{"`test_scope_models`.`id`", "`test_scope_models`.`name`", "`test_scope_models`.`email`", "(UPPER(`test_scope_models`.name)) `computed`"}, paginator.DB.Statement.Selects) } func TestScopeUnpaginatedDisableJoin(t *testing.T) { @@ -854,8 +860,9 @@ func TestScopeUnpaginatedDisableJoin(t *testing.T) { } func TestScopeDisableSearch(t *testing.T) { - paginator, db := prepareTestScope(t, &Settings[*TestScopeModel]{DisableSearch: true, FieldsSearch: []string{"name"}}) + paginator, err := prepareTestScope(t, &Settings[*TestScopeModel]{DisableSearch: true, FieldsSearch: []string{"name"}}) assert.NotNil(t, paginator) + assert.NoError(t, err) expected := map[string]clause.Clause{ "WHERE": { @@ -925,8 +932,8 @@ func TestScopeDisableSearch(t *testing.T) { }, } - assert.Equal(t, expected, db.Statement.Clauses) - assert.ElementsMatch(t, []string{"`test_scope_models`.`id`", "`test_scope_models`.`name`", "`test_scope_models`.`email`", "`test_scope_models`.`relation_id`", "(UPPER(`test_scope_models`.name)) `computed`"}, db.Statement.Selects) + assert.Equal(t, expected, paginator.DB.Statement.Clauses) + assert.ElementsMatch(t, []string{"`test_scope_models`.`id`", "`test_scope_models`.`name`", "`test_scope_models`.`email`", "`test_scope_models`.`relation_id`", "(UPPER(`test_scope_models`.name)) `computed`"}, paginator.DB.Statement.Selects) } func TestScopeUnpaginatedDisableSearch(t *testing.T) { @@ -1007,9 +1014,9 @@ func TestScopeNoPrimaryKey(t *testing.T) { db := openDryRunDB(t) results := []*TestScopeModelNoPrimaryKey{} - paginator, db := Scope(db, request, &results) - assert.Nil(t, paginator) - assert.Equal(t, "could not find primary key. Add `gorm:\"primaryKey\"` to your model", db.Error.Error()) + paginator, err := Scope(db, request, &results) + assert.Equal(t, "could not find primary key. Add `gorm:\"primaryKey\"` to your model", err.Error()) + assert.Equal(t, err, paginator.DB.Error) } func TestScopeUnpaginatedNoPrimaryKey(t *testing.T) { @@ -1034,9 +1041,10 @@ func TestScopeWithFieldsBlacklist(t *testing.T) { }, } results := []*TestScopeModel{} - paginator, db := settings.Scope(db, request, &results) + paginator, err := settings.Scope(db, request, &results) assert.NotNil(t, paginator) - assert.ElementsMatch(t, []string{"`test_scope_models`.`id`", "`test_scope_models`.`relation_id`", "`test_scope_models`.`email`", "(UPPER(`test_scope_models`.name)) `computed`"}, db.Statement.Selects) + assert.NoError(t, err) + assert.ElementsMatch(t, []string{"`test_scope_models`.`id`", "`test_scope_models`.`relation_id`", "`test_scope_models`.`email`", "(UPPER(`test_scope_models`.name)) `computed`"}, paginator.DB.Statement.Selects) } func TestScopeUnpaginatedWithFieldsBlacklist(t *testing.T) { @@ -1058,7 +1066,7 @@ func TestScopeInvalidModel(t *testing.T) { db := openDryRunDB(t) model := []string{} assert.Panics(t, func() { - Scope(db, request, &model) + _, _ = Scope(db, request, &model) }) } @@ -1425,9 +1433,10 @@ func TestSettingsComputedFieldWithAutoFields(t *testing.T) { db := openDryRunDB(t) results := []*TestScopeModel{} - paginator, db := Scope(db, request, &results) + paginator, err := Scope(db, request, &results) assert.NotNil(t, paginator) + assert.NoError(t, err) expected := map[string]clause.Clause{ "WHERE": { @@ -1489,8 +1498,8 @@ func TestSettingsComputedFieldWithAutoFields(t *testing.T) { }, }, } - assert.Equal(t, expected, db.Statement.Clauses) - assert.ElementsMatch(t, []string{"`test_scope_models`.`name`", "`test_scope_models`.`email`", "(UPPER(`test_scope_models`.name)) `computed`", "`test_scope_models`.`id`", "`test_scope_models`.`relation_id`"}, db.Statement.Selects) + assert.Equal(t, expected, paginator.DB.Statement.Clauses) + assert.ElementsMatch(t, []string{"`test_scope_models`.`name`", "`test_scope_models`.`email`", "(UPPER(`test_scope_models`.name)) `computed`", "`test_scope_models`.`id`", "`test_scope_models`.`relation_id`"}, paginator.DB.Statement.Selects) } func TestSettingsSelectWithExistingJoin(t *testing.T) { @@ -1507,9 +1516,10 @@ func TestSettingsSelectWithExistingJoin(t *testing.T) { db = db.Joins("Relation", db.Session(&gorm.Session{NewDB: true}).Where("Relation.id > ?", 0)) results := []*TestScopeModel{} - paginator, db := Scope(db, request, &results) + paginator, err := Scope(db, request, &results) assert.NotNil(t, paginator) + assert.NoError(t, err) expected := map[string]clause.Clause{ "WHERE": { @@ -1575,8 +1585,8 @@ func TestSettingsSelectWithExistingJoin(t *testing.T) { }, }, } - assert.Equal(t, expected, db.Statement.Clauses) - assert.Empty(t, db.Statement.Joins) + assert.Equal(t, expected, paginator.DB.Statement.Clauses) + assert.Empty(t, paginator.DB.Statement.Joins) } type TestScopeRelationWithComputed struct { @@ -1606,9 +1616,10 @@ func TestSettingsSelectWithExistingJoinAndComputed(t *testing.T) { db = db.Joins("Relation") results := []*TestScopeModelWithComputed{} - paginator, db := Scope(db, request, &results) + paginator, err := Scope(db, request, &results) assert.NotNil(t, paginator) + assert.NoError(t, err) expected := map[string]clause.Clause{ "WHERE": { @@ -1674,8 +1685,8 @@ func TestSettingsSelectWithExistingJoinAndComputed(t *testing.T) { }, }, } - assert.Equal(t, expected, db.Statement.Clauses) - assert.Empty(t, db.Statement.Joins) + assert.Equal(t, expected, paginator.DB.Statement.Clauses) + assert.Empty(t, paginator.DB.Statement.Joins) } func TestSettingsSelectWithExistingJoinAndComputedOmit(t *testing.T) { @@ -1691,9 +1702,10 @@ func TestSettingsSelectWithExistingJoinAndComputedOmit(t *testing.T) { db = db.Joins("Relation", db.Session(&gorm.Session{NewDB: true}).Omit("c")) results := []*TestScopeModelWithComputed{} - paginator, db := Scope(db, request, &results) + paginator, err := Scope(db, request, &results) assert.NotNil(t, paginator) + assert.NoError(t, err) expected := map[string]clause.Clause{ "WHERE": { @@ -1758,8 +1770,8 @@ func TestSettingsSelectWithExistingJoinAndComputedOmit(t *testing.T) { }, }, } - assert.Equal(t, expected, db.Statement.Clauses) - assert.Empty(t, db.Statement.Joins) + assert.Equal(t, expected, paginator.DB.Statement.Clauses) + assert.Empty(t, paginator.DB.Statement.Joins) } func TestSettingsSelectWithExistingJoinAndComputedWithoutFiltering(t *testing.T) { @@ -1773,9 +1785,10 @@ func TestSettingsSelectWithExistingJoinAndComputedWithoutFiltering(t *testing.T) db = db.Joins("Relation", db.Session(&gorm.Session{NewDB: true}).Where("Relation.id > ?", 0)) results := []*TestScopeModelWithComputed{} - paginator, db := Scope(db, request, &results) + paginator, err := Scope(db, request, &results) assert.NotNil(t, paginator) + assert.NoError(t, err) expected := map[string]clause.Clause{ "FROM": { @@ -1830,8 +1843,8 @@ func TestSettingsSelectWithExistingJoinAndComputedWithoutFiltering(t *testing.T) }, }, } - assert.Equal(t, expected, db.Statement.Clauses) - assert.Empty(t, db.Statement.Joins) + assert.Equal(t, expected, paginator.DB.Statement.Clauses) + assert.Empty(t, paginator.DB.Statement.Joins) } func TestSettingsDefaultSort(t *testing.T) { @@ -1854,9 +1867,10 @@ func TestSettingsDefaultSort(t *testing.T) { }, } - paginator, db := settings.Scope(db, request, &results) + paginator, err := settings.Scope(db, request, &results) assert.NotNil(t, paginator) + assert.NoError(t, err) expected := map[string]clause.Clause{ "WHERE": { @@ -1912,7 +1926,7 @@ func TestSettingsDefaultSort(t *testing.T) { }, }, } - assert.Equal(t, expected, db.Statement.Clauses) + assert.Equal(t, expected, paginator.DB.Statement.Clauses) request = &Request{ Filter: typeutil.NewUndefined([]*Filter{ @@ -1927,9 +1941,10 @@ func TestSettingsDefaultSort(t *testing.T) { results = []*TestScopeModel{} - paginator, db = settings.Scope(db, request, &results) + paginator, err = settings.Scope(db, request, &results) assert.NotNil(t, paginator) + assert.NoError(t, err) expected = map[string]clause.Clause{ "WHERE": { @@ -1979,7 +1994,7 @@ func TestSettingsDefaultSort(t *testing.T) { }, }, } - assert.Equal(t, expected, db.Statement.Clauses) + assert.Equal(t, expected, paginator.DB.Statement.Clauses) } func TestNewRequest(t *testing.T) {