diff --git a/pkg/engine/directives/selection_transform.go b/pkg/engine/directives/selection_transform.go index 24b73c6..b7d1bef 100644 --- a/pkg/engine/directives/selection_transform.go +++ b/pkg/engine/directives/selection_transform.go @@ -13,13 +13,31 @@ import ( "github.com/getkin/kin-openapi/openapi3" "github.com/vektah/gqlparser/v2/ast" "github.com/wundergraph/wundergraph/pkg/wgpb" + "golang.org/x/exp/slices" "strings" ) const ( - transformName = "transform" - transformArgName = "get" - transformInvalidPathFormat = `invalid path [%s] @transform` + transformName = "transform" + transformArgGetName = "get" + transformArgMathName = "math" + transformArgMathType = "TransformMath" + transformInvalidPathFormat = `invalid path [%s] @transform` + transformArgExpectTypeFormat = `argument [%s] just allow apply on [%s] result` + transformArgMathExpectNumberFormat = `invalid [%s] array, expected number array for math [%s]` +) + +var ( + transformNumberMathArray = []wgpb.PostResolveTransformationMath{ + wgpb.PostResolveTransformationMath_MAX, + wgpb.PostResolveTransformationMath_MIN, + wgpb.PostResolveTransformationMath_AVG, + wgpb.PostResolveTransformationMath_SUM, + } + transformNumberSchemaTypes = []string{ + openapi3.TypeInteger, + openapi3.TypeNumber, + } ) type transform struct{} @@ -31,20 +49,34 @@ func (s *transform) Directive() *ast.DirectiveDefinition { Locations: []ast.DirectiveLocation{ast.LocationField}, Arguments: ast.ArgumentDefinitionList{{ Description: i18n.TransformArgGetDesc.String(), - Name: transformArgName, - Type: ast.NonNullNamedType(consts.ScalarString, nil), + Name: transformArgGetName, + Type: ast.NamedType(consts.ScalarString, nil), + }, { + Description: i18n.TransformArgGetDesc.String(), + Name: transformArgMathName, + Type: ast.NamedType(transformArgMathType, nil), }}, } } func (s *transform) Definitions() ast.DefinitionList { - return nil + var nameEnumValues ast.EnumValueList + for k := range wgpb.PostResolveTransformationMath_value { + nameEnumValues = append(nameEnumValues, &ast.EnumValueDefinition{Name: k}) + } + + return ast.DefinitionList{{ + Kind: ast.Enum, + Name: transformArgMathType, + EnumValues: nameEnumValues, + }} } func (s *transform) Resolve(resolver *SelectionResolver) (err error) { - value, ok := resolver.Arguments[transformArgName] - if !ok { - err = fmt.Errorf(argumentRequiredFormat, transformArgName) + getValue, getOk := resolver.Arguments[transformArgGetName] + mathValue, mathOk := resolver.Arguments[transformArgMathName] + if !getOk && !mathOk { + err = fmt.Errorf(argumentRequiredFormat, utils.JoinString(" or ", transformArgGetName, transformArgMathName)) return } @@ -52,12 +84,14 @@ func (s *transform) Resolve(resolver *SelectionResolver) (err error) { postTransformation.From = append(postTransformation.From, resolver.Path...) postTransformation.To = append(postTransformation.To, resolver.Path...) var transformPaths []string - for _, item := range strings.Split(value, utils.StringDot) { - if item == utils.ArrayPath { - continue - } + if getOk { + for _, item := range strings.Split(getValue, utils.StringDot) { + if item == utils.ArrayPath { + continue + } - transformPaths = append(transformPaths, item) + transformPaths = append(transformPaths, item) + } } transformLength := len(transformPaths) schema := resolver.Schema @@ -102,14 +136,37 @@ loop: return } + var resolveMath *wgpb.PostResolveTransformationMath endWithArrayPath := resolver.Path[len(resolver.Path)-1] == utils.ArrayPath + if mathOk { + if !arrayVisited && !endWithArrayPath { + err = fmt.Errorf(transformArgExpectTypeFormat, transformArgMathName, openapi3.TypeArray) + return + } + mathType, ok := wgpb.PostResolveTransformationMath_value[mathValue] + if !ok { + err = fmt.Errorf(argumentValueNotSupportedFormat, mathValue, transformArgMathName) + return + } + if schema.Value.Type == openapi3.TypeArray { + schema.Value = schema.Value.Items.Value + } + _resolveMath := wgpb.PostResolveTransformationMath(mathType) + if slices.Contains(transformNumberMathArray, _resolveMath) && + !slices.Contains(transformNumberSchemaTypes, schema.Value.Type) { + err = fmt.Errorf(transformArgMathExpectNumberFormat, schema.Value.Type, mathValue) + return + } + resolveMath = &_resolveMath + } + switch schema.Value.Type { case openapi3.TypeArray: if endWithArrayPath { postTransformation.From = append(postTransformation.From, utils.ArrayPath) } default: - if arrayVisited { + if arrayVisited && !mathOk { schema.Value = &openapi3.Schema{Items: &openapi3.SchemaRef{Value: schema.Value}, Type: openapi3.TypeArray} if !endWithArrayPath { arrayIndex := utils.LastIndex(postTransformation.From, utils.ArrayPath) @@ -130,6 +187,7 @@ loop: Kind: wgpb.PostResolveTransformationKind_GET_POST_RESOLVE_TRANSFORMATION, Depth: int32(len(postTransformation.From)), Get: postTransformation, + Math: resolveMath, }) return diff --git a/wundergraphGitSubmodule b/wundergraphGitSubmodule index 0941dc2..3299b86 160000 --- a/wundergraphGitSubmodule +++ b/wundergraphGitSubmodule @@ -1 +1 @@ -Subproject commit 0941dc2216dfad39f7972987d28b35db355d879f +Subproject commit 3299b863f6c44c85861375759cebedaa097a2878