Skip to content

Commit

Permalink
feat: add support resolve math argument for directive @Transform.
Browse files Browse the repository at this point in the history
  • Loading branch information
xufeixiang committed Nov 4, 2024
1 parent 251b052 commit 8f1fdfb
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 16 deletions.
88 changes: 73 additions & 15 deletions pkg/engine/directives/selection_transform.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,31 @@ import (
"github.com/getkin/kin-openapi/openapi3"
"github.com/vektah/gqlparser/v2/ast"
"github.com/wundergraph/wundergraph/pkg/wgpb"
"golang.org/x/exp/slices"
"strings"
)

const (
transformName = "transform"
transformArgName = "get"
transformInvalidPathFormat = `invalid path [%s] @transform`
transformName = "transform"
transformArgGetName = "get"
transformArgMathName = "math"
transformArgMathType = "TransformMath"
transformInvalidPathFormat = `invalid path [%s] @transform`
transformArgExpectTypeFormat = `argument [%s] just allow apply on [%s] result`
transformArgMathExpectNumberFormat = `invalid [%s] array, expected number array for math [%s]`
)

var (
transformNumberMathArray = []wgpb.PostResolveTransformationMath{
wgpb.PostResolveTransformationMath_MAX,
wgpb.PostResolveTransformationMath_MIN,
wgpb.PostResolveTransformationMath_AVG,
wgpb.PostResolveTransformationMath_SUM,
}
transformNumberSchemaTypes = []string{
openapi3.TypeInteger,
openapi3.TypeNumber,
}
)

type transform struct{}
Expand All @@ -31,33 +49,49 @@ func (s *transform) Directive() *ast.DirectiveDefinition {
Locations: []ast.DirectiveLocation{ast.LocationField},
Arguments: ast.ArgumentDefinitionList{{
Description: i18n.TransformArgGetDesc.String(),
Name: transformArgName,
Type: ast.NonNullNamedType(consts.ScalarString, nil),
Name: transformArgGetName,
Type: ast.NamedType(consts.ScalarString, nil),
}, {
Description: i18n.TransformArgGetDesc.String(),
Name: transformArgMathName,
Type: ast.NamedType(transformArgMathType, nil),
}},
}
}

func (s *transform) Definitions() ast.DefinitionList {
return nil
var nameEnumValues ast.EnumValueList
for k := range wgpb.PostResolveTransformationMath_value {
nameEnumValues = append(nameEnumValues, &ast.EnumValueDefinition{Name: k})
}

return ast.DefinitionList{{
Kind: ast.Enum,
Name: transformArgMathType,
EnumValues: nameEnumValues,
}}
}

func (s *transform) Resolve(resolver *SelectionResolver) (err error) {
value, ok := resolver.Arguments[transformArgName]
if !ok {
err = fmt.Errorf(argumentRequiredFormat, transformArgName)
getValue, getOk := resolver.Arguments[transformArgGetName]
mathValue, mathOk := resolver.Arguments[transformArgMathName]
if !getOk && !mathOk {
err = fmt.Errorf(argumentRequiredFormat, utils.JoinString(" or ", transformArgGetName, transformArgMathName))
return
}

postTransformation := &wgpb.PostResolveGetTransformation{}
postTransformation.From = append(postTransformation.From, resolver.Path...)
postTransformation.To = append(postTransformation.To, resolver.Path...)
var transformPaths []string
for _, item := range strings.Split(value, utils.StringDot) {
if item == utils.ArrayPath {
continue
}
if getOk {
for _, item := range strings.Split(getValue, utils.StringDot) {
if item == utils.ArrayPath {
continue
}

transformPaths = append(transformPaths, item)
transformPaths = append(transformPaths, item)
}
}
transformLength := len(transformPaths)
schema := resolver.Schema
Expand Down Expand Up @@ -102,14 +136,37 @@ loop:
return
}

var resolveMath *wgpb.PostResolveTransformationMath
endWithArrayPath := resolver.Path[len(resolver.Path)-1] == utils.ArrayPath
if mathOk {
if !arrayVisited && !endWithArrayPath {
err = fmt.Errorf(transformArgExpectTypeFormat, transformArgMathName, openapi3.TypeArray)
return
}
mathType, ok := wgpb.PostResolveTransformationMath_value[mathValue]
if !ok {
err = fmt.Errorf(argumentValueNotSupportedFormat, mathValue, transformArgMathName)
return
}
if schema.Value.Type == openapi3.TypeArray {
schema.Value = schema.Value.Items.Value
}
_resolveMath := wgpb.PostResolveTransformationMath(mathType)
if slices.Contains(transformNumberMathArray, _resolveMath) &&
!slices.Contains(transformNumberSchemaTypes, schema.Value.Type) {
err = fmt.Errorf(transformArgMathExpectNumberFormat, schema.Value.Type, mathValue)
return
}
resolveMath = &_resolveMath
}

switch schema.Value.Type {
case openapi3.TypeArray:
if endWithArrayPath {
postTransformation.From = append(postTransformation.From, utils.ArrayPath)
}
default:
if arrayVisited {
if arrayVisited && !mathOk {
schema.Value = &openapi3.Schema{Items: &openapi3.SchemaRef{Value: schema.Value}, Type: openapi3.TypeArray}
if !endWithArrayPath {
arrayIndex := utils.LastIndex(postTransformation.From, utils.ArrayPath)
Expand All @@ -130,6 +187,7 @@ loop:
Kind: wgpb.PostResolveTransformationKind_GET_POST_RESOLVE_TRANSFORMATION,
Depth: int32(len(postTransformation.From)),
Get: postTransformation,
Math: resolveMath,
})

return
Expand Down

0 comments on commit 8f1fdfb

Please sign in to comment.