Skip to content

Commit

Permalink
dialect/sql/sqlgraph: expose standard modifier to eager-load N neighb…
Browse files Browse the repository at this point in the history
…ors (ent#3603)
  • Loading branch information
a8m authored Jun 17, 2023
1 parent ee7a50b commit b49d5f5
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 2 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ jobs:
--health-timeout 5s
--health-retries 10
maria:
image: mariadb
image: mariadb:10.4 # Temporary to unblock PRs from failing.
env:
MYSQL_DATABASE: test
MYSQL_ROOT_PASSWORD: pass
Expand Down Expand Up @@ -321,7 +321,7 @@ jobs:
--health-timeout 5s
--health-retries 10
maria:
image: mariadb
image: mariadb:10.4 # Temporary to unblock PRs from failing.
env:
MYSQL_DATABASE: test
MYSQL_ROOT_PASSWORD: pass
Expand Down
63 changes: 63 additions & 0 deletions dialect/sql/sqlgraph/graph.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"context"
"database/sql/driver"
"encoding/json"
"errors"
"fmt"
"math"
"sort"
Expand Down Expand Up @@ -520,6 +521,68 @@ func OrderByNeighborTerms(q *sql.Selector, s *Step, opts ...sql.OrderTerm) {
orderTerms(q, join, opts)
}

// NeighborsLimit provides a modifier function that limits the
// number of neighbors (rows) loaded per parent row (node).
type NeighborsLimit struct {
// SrcCTE, LimitCTE and RowNumber hold the identifier names
// to src query, new limited one (using window function) and
// the column for counting rows.
SrcCTE, LimitCTE, RowNumber string
// DefaultOrderField sets the default ordering for
// sub-queries in case no order terms were provided.
DefaultOrderField string
}

// LimitNeighbors returns a modifier that limits the number of neighbors (rows) loaded per parent
// row (node). The "partitionBy" is the foreign-key column (edge) to partition the window function
// by, the "limit" is the maximum number of rows per parent, and the "orderBy" defines the order of
// how neighbors (connected by the edge) are returned.
//
// This function is useful for non-unique edges, such as O2M and M2M, where the same parent can
// have multiple children.
func LimitNeighbors(partitionBy string, limit int, orderBy ...sql.Querier) func(*sql.Selector) {
l := &NeighborsLimit{
SrcCTE: "src_query",
LimitCTE: "limited_query",
RowNumber: "row_number",
DefaultOrderField: "id",
}
return l.Modifier(partitionBy, limit, orderBy...)
}

// Modifier returns a modifier function that limits the number of rows of the eager load query.
func (l *NeighborsLimit) Modifier(partitionBy string, limit int, orderBy ...sql.Querier) func(s *sql.Selector) {
return func(s *sql.Selector) {
var (
d = sql.Dialect(s.Dialect())
rn = sql.RowNumber().PartitionBy(partitionBy)
)
switch {
case len(orderBy) > 0:
rn.OrderExpr(orderBy...)
case l.DefaultOrderField != "":
rn.OrderBy(l.DefaultOrderField)
default:
s.AddError(errors.New("no order terms provided for window function"))
return
}
s.SetDistinct(false)
with := d.With(l.SrcCTE).
As(s.Clone()).
With(l.LimitCTE).
As(
d.Select("*").
AppendSelectExprAs(rn, l.RowNumber).
From(d.Table(l.SrcCTE)),
)
t := d.Table(l.LimitCTE).As(s.TableName())
*s = *d.Select(s.UnqualifiedColumns()...).
From(t).
Where(sql.LTE(t.C(l.RowNumber), limit)).
Prefix(with)
}
}

type (
// FieldSpec holds the information for updating a field
// column in the database.
Expand Down
27 changes: 27 additions & 0 deletions dialect/sql/sqlgraph/graph_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2648,6 +2648,33 @@ func TestIsConstraintError(t *testing.T) {
}
}

func TestLimitNeighbors(t *testing.T) {
t.Run("O2M", func(t *testing.T) {
const fk = "author_id"
// Authors load their posts.
s := sql.Select(fk, "id").From(sql.Table("posts"))
LimitNeighbors(fk, 2)(s)
query, args := s.Query()
require.Equal(t,
"WITH `src_query` AS (SELECT `author_id`, `id` FROM `posts`), `limited_query` AS (SELECT *, (ROW_NUMBER() OVER (PARTITION BY `author_id` ORDER BY `id`)) AS `row_number` FROM `src_query`) SELECT `author_id`, `id` FROM `limited_query` AS `posts` WHERE `posts`.`row_number` <= ?",
query,
)
require.Equal(t, []any{2}, args)
})
t.Run("M2M", func(t *testing.T) {
const fk = "user_id"
edgeT, neighborsT := sql.Table("user_groups"), sql.Table("groups")
s := sql.Select(fk, "id", "name").From(neighborsT).Join(edgeT).On(neighborsT.C("id"), edgeT.C("group_id"))
LimitNeighbors(fk, 1, sql.ExprFunc(func(b *sql.Builder) { b.Ident("updated_at") }))(s)
query, args := s.Query()
require.Equal(t,
"WITH `src_query` AS (SELECT `user_id`, `id`, `name` FROM `groups` JOIN `user_groups` AS `t1` ON `groups`.`id` = `t1`.`group_id`), `limited_query` AS (SELECT *, (ROW_NUMBER() OVER (PARTITION BY `user_id` ORDER BY `updated_at`)) AS `row_number` FROM `src_query`) SELECT `user_id`, `id`, `name` FROM `limited_query` AS `groups` WHERE `groups`.`row_number` <= ?",
query,
)
require.Equal(t, []any{1}, args)
})
}

func escape(query string) string {
rows := strings.Split(query, "\n")
for i := range rows {
Expand Down

0 comments on commit b49d5f5

Please sign in to comment.