From a6792ed3f55f495be2c4cc5745fe4dd0ba54e63e Mon Sep 17 00:00:00 2001 From: Nikolay Tkachenko Date: Wed, 22 Nov 2023 14:56:19 +0300 Subject: [PATCH] Add limits for the amount of aliases in the GraphQL document and field duplication detection. --- cmd/api-firewall/tests/main_graphql_test.go | 378 ++++++++++++++------ internal/config/config.go | 2 + internal/platform/validator/graphql.go | 85 ++++- 3 files changed, 355 insertions(+), 110 deletions(-) diff --git a/cmd/api-firewall/tests/main_graphql_test.go b/cmd/api-firewall/tests/main_graphql_test.go index 5a8159c..ca512f9 100644 --- a/cmd/api-firewall/tests/main_graphql_test.go +++ b/cmd/api-firewall/tests/main_graphql_test.go @@ -21,6 +21,7 @@ import ( "github.com/wallarm/api-firewall/internal/config" "github.com/wallarm/api-firewall/internal/platform/denylist" "github.com/wallarm/api-firewall/internal/platform/proxy" + "github.com/wallarm/api-firewall/internal/platform/validator" "github.com/wundergraph/graphql-go-tools/pkg/graphql" ) @@ -43,6 +44,7 @@ type Chatroom { type Message { id: ID! text: String! + name: String! createdBy: String! createdAt: Time! } @@ -118,6 +120,9 @@ func TestGraphQLBasic(t *testing.T) { // the sequence of messages in the tests: hello -> invalid gql message (<- response from APIFW) -> valid gql message -> complete -> stop t.Run("basicGraphQLQuerySubscription", apifwTests.testGQLSubscription) t.Run("basicGraphQLQuerySubscriptionLogOnly", apifwTests.testGQLSubscriptionLogOnly) + + t.Run("basicGraphQLMaxAliasesNum", apifwTests.testGQLMaxAliasesNum) + t.Run("basicGraphQLDuplicateFields", apifwTests.testGQLDuplicateFields) } func (s *ServiceGraphQLTests) testGQLSuccess(t *testing.T) { @@ -194,9 +199,9 @@ func (s *ServiceGraphQLTests) testGQLSuccess(t *testing.T) { Request: *req, } - s.proxy.EXPECT().Get().Return(s.client, nil) - s.client.EXPECT().Do(gomock.Any(), gomock.Any()).SetArg(1, *resp) - s.proxy.EXPECT().Put(s.client).Return(nil) + s.proxy.EXPECT().Get().Return(s.client, nil).Times(1) + s.client.EXPECT().Do(gomock.Any(), gomock.Any()).SetArg(1, *resp).Times(1) + s.proxy.EXPECT().Put(s.client).Return(nil).Times(1) handler(&reqCtx) @@ -271,14 +276,8 @@ func (s *ServiceGraphQLTests) testGQLGETSuccess(t *testing.T) { req := fasthttp.AcquireRequest() req.SetRequestURI("/query") req.Header.SetMethod("GET") - //req.SetBodyStream(bytes.NewReader(jsonValue), -1) req.URI().QueryArgs().Add("query", url.QueryEscape(query)) - testqqq, _ := url.QueryUnescape(url.QueryEscape(query)) - t.Log(url.QueryEscape(query)) - t.Log(testqqq) - //req.Header.SetContentType("application/json") - resp := fasthttp.AcquireResponse() resp.SetStatusCode(fasthttp.StatusOK) resp.Header.SetContentType("application/json") @@ -288,9 +287,9 @@ func (s *ServiceGraphQLTests) testGQLGETSuccess(t *testing.T) { Request: *req, } - s.proxy.EXPECT().Get().Return(s.client, nil) - s.client.EXPECT().Do(gomock.Any(), gomock.Any()).SetArg(1, *resp) - s.proxy.EXPECT().Put(s.client).Return(nil) + s.proxy.EXPECT().Get().Return(s.client, nil).Times(1) + s.client.EXPECT().Do(gomock.Any(), gomock.Any()).SetArg(1, *resp).Times(1) + s.proxy.EXPECT().Put(s.client).Return(nil).Times(1) handler(&reqCtx) @@ -343,46 +342,15 @@ func (s *ServiceGraphQLTests) testGQLGETMutationFailed(t *testing.T) { } ` - responseBody := `{ - "data": { - "room": { - "name": "GeneralChat", - "messages": [ - { - "id": "TrsXJcKa", - "text": "Hello, world!", - "createdBy": "TestUser", - "createdAt": "2023-01-01T00:00:00+00:00" - } - ] - } - } -}` - req := fasthttp.AcquireRequest() req.SetRequestURI("/query") req.Header.SetMethod("GET") - //req.SetBodyStream(bytes.NewReader(jsonValue), -1) req.URI().QueryArgs().Add("query", url.QueryEscape(query)) - testqqq, _ := url.QueryUnescape(url.QueryEscape(query)) - t.Log(url.QueryEscape(query)) - t.Log(testqqq) - //req.Header.SetContentType("application/json") - - resp := fasthttp.AcquireResponse() - resp.SetStatusCode(fasthttp.StatusOK) - resp.Header.SetContentType("application/json") - resp.SetBody([]byte(responseBody)) - reqCtx := fasthttp.RequestCtx{ Request: *req, } - //s.proxy.EXPECT().Get().Return(s.client, nil) - //s.client.EXPECT().Do(gomock.Any(), gomock.Any()).SetArg(1, *resp) - //s.proxy.EXPECT().Put(s.client).Return(nil) - handler(&reqCtx) if reqCtx.Response.StatusCode() != 200 { @@ -610,22 +578,6 @@ func (s *ServiceGraphQLTests) testGQLInvalidMaxComplexity(t *testing.T) { "query": query, } - responseBody := `{ - "data": { - "room": { - "name": "GeneralChat", - "messages": [ - { - "id": "TrsXJcKa", - "text": "Hello, world!", - "createdBy": "TestUser", - "createdAt": "2023-01-01T00:00:00+00:00" - } - ] - } - } -}` - jsonValue, _ := json.Marshal(requestBody) req := fasthttp.AcquireRequest() @@ -634,11 +586,6 @@ func (s *ServiceGraphQLTests) testGQLInvalidMaxComplexity(t *testing.T) { req.SetBodyStream(bytes.NewReader(jsonValue), -1) req.Header.SetContentType("application/json") - resp := fasthttp.AcquireResponse() - resp.SetStatusCode(fasthttp.StatusOK) - resp.Header.SetContentType("application/json") - resp.SetBody([]byte(responseBody)) - reqCtx := fasthttp.RequestCtx{ Request: *req, } @@ -711,22 +658,6 @@ func (s *ServiceGraphQLTests) testGQLInvalidMaxDepth(t *testing.T) { "query": query, } - responseBody := `{ - "data": { - "room": { - "name": "GeneralChat", - "messages": [ - { - "id": "TrsXJcKa", - "text": "Hello, world!", - "createdBy": "TestUser", - "createdAt": "2023-01-01T00:00:00+00:00" - } - ] - } - } -}` - jsonValue, _ := json.Marshal(requestBody) req := fasthttp.AcquireRequest() @@ -735,11 +666,6 @@ func (s *ServiceGraphQLTests) testGQLInvalidMaxDepth(t *testing.T) { req.SetBodyStream(bytes.NewReader(jsonValue), -1) req.Header.SetContentType("application/json") - resp := fasthttp.AcquireResponse() - resp.SetStatusCode(fasthttp.StatusOK) - resp.Header.SetContentType("application/json") - resp.SetBody([]byte(responseBody)) - reqCtx := fasthttp.RequestCtx{ Request: *req, } @@ -813,22 +739,6 @@ func (s *ServiceGraphQLTests) testGQLInvalidNodeLimit(t *testing.T) { "query": query, } - responseBody := `{ - "data": { - "room": { - "name": "GeneralChat", - "messages": [ - { - "id": "TrsXJcKa", - "text": "Hello, world!", - "createdBy": "TestUser", - "createdAt": "2023-01-01T00:00:00+00:00" - } - ] - } - } -}` - jsonValue, _ := json.Marshal(requestBody) req := fasthttp.AcquireRequest() @@ -837,11 +747,6 @@ func (s *ServiceGraphQLTests) testGQLInvalidNodeLimit(t *testing.T) { req.SetBodyStream(bytes.NewReader(jsonValue), -1) req.Header.SetContentType("application/json") - resp := fasthttp.AcquireResponse() - resp.SetStatusCode(fasthttp.StatusOK) - resp.Header.SetContentType("application/json") - resp.SetBody([]byte(responseBody)) - reqCtx := fasthttp.RequestCtx{ Request: *req, } @@ -1411,3 +1316,264 @@ func (s *ServiceGraphQLTests) testGQLSubscriptionLogOnly(t *testing.T) { time.Sleep(1 * time.Second) } + +func (s *ServiceGraphQLTests) testGQLMaxAliasesNum(t *testing.T) { + + gqlCfg := config.GraphQL{ + MaxQueryComplexity: 0, + MaxQueryDepth: 0, + NodeCountLimit: 0, + MaxAliasesNum: 1, + Playground: false, + Introspection: false, + Schema: "", + RequestValidation: "BLOCK", + } + var cfg = config.GraphQLMode{ + Graphql: gqlCfg, + } + + // parse the GraphQL schema + schema, err := graphql.NewSchemaFromString(testSchema) + if err != nil { + t.Fatalf("Loading GraphQL Schema error: %v", err) + } + + handler := graphqlHandler.Handlers(&cfg, schema, s.serverUrl, s.shutdown, s.logger, s.proxy, s.backendWSClient, nil) + + // Construct GraphQL request payload + query := ` + query { + a0:room(name: "GeneralChat") { + name + messages { + id + text + createdBy + createdAt + } + } + a1:room(name: "GeneralChat") { + name + messages { + id + text + createdBy + createdAt + } + } +} + ` + var requestBody = map[string]interface{}{ + "query": query, + } + + jsonValue, _ := json.Marshal(requestBody) + + req := fasthttp.AcquireRequest() + req.SetRequestURI("/query") + req.Header.SetMethod("POST") + req.SetBodyStream(bytes.NewReader(jsonValue), -1) + req.Header.SetContentType("application/json") + + reqCtx := fasthttp.RequestCtx{ + Request: *req, + } + + handler(&reqCtx) + + if reqCtx.Response.StatusCode() != 200 { + t.Errorf("Incorrect response status code. Expected: 200 and got %d", + reqCtx.Response.StatusCode()) + } + + gqlResp := new(Response) + + if err := json.Unmarshal(reqCtx.Response.Body(), &gqlResp); err != nil { + t.Fatal(err) + } + + if len(gqlResp.Errors) != 1 { + t.Errorf("Incorrect amount of errors in the response. Expected: 1 and got %d", + len(gqlResp.Errors)) + } + + expectedErrMsg := "the maximum number of aliases in the GraphQL document has been exceeded. The maximum number of aliases value is 1. The current number of aliases is 2" + + if gqlResp.Errors[0].Message != expectedErrMsg { + t.Errorf("Incorrect error message. Expected: \"%s\" and got \"%s\"", + expectedErrMsg, gqlResp.Errors[0].Message) + } + +} + +func (s *ServiceGraphQLTests) testGQLDuplicateFields(t *testing.T) { + + gqlCfg := config.GraphQL{ + MaxQueryComplexity: 0, + MaxQueryDepth: 0, + NodeCountLimit: 0, + MaxAliasesNum: 0, + FieldDuplication: true, + Playground: false, + Introspection: false, + Schema: "", + RequestValidation: "BLOCK", + } + var cfg = config.GraphQLMode{ + Graphql: gqlCfg, + } + + // parse the GraphQL schema + schema, err := graphql.NewSchemaFromString(testSchema) + if err != nil { + t.Fatalf("Loading GraphQL Schema error: %v", err) + } + + handler := graphqlHandler.Handlers(&cfg, schema, s.serverUrl, s.shutdown, s.logger, s.proxy, s.backendWSClient, nil) + + // Construct GraphQL request payload + query := ` + query { + a0:room(name: "GeneralChat") { + name + messages { + id + text + createdBy + createdAt + } + } + a1:room(name: "GeneralChat") { + name + messages { + id + name + text + createdBy + createdAt + } + } +} + ` + var requestBody = map[string]interface{}{ + "query": query, + } + + jsonValue, _ := json.Marshal(requestBody) + + req := fasthttp.AcquireRequest() + req.SetRequestURI("/query") + req.Header.SetMethod("POST") + req.SetBodyStream(bytes.NewReader(jsonValue), -1) + req.Header.SetContentType("application/json") + + responseBody := `{ + "data": { + "room": { + "name": "GeneralChat", + "messages": [ + { + "id": "TrsXJcKa", + "text": "Hello, world!", + "createdBy": "TestUser", + "createdAt": "2023-01-01T00:00:00+00:00" + } + ] + } + } +}` + + resp := fasthttp.AcquireResponse() + resp.SetStatusCode(fasthttp.StatusOK) + resp.Header.SetContentType("application/json") + resp.SetBody([]byte(responseBody)) + + reqCtx := fasthttp.RequestCtx{ + Request: *req, + } + + s.proxy.EXPECT().Get().Return(s.client, nil).AnyTimes() + s.client.EXPECT().Do(gomock.Any(), gomock.Any()).SetArg(1, *resp).AnyTimes() + s.proxy.EXPECT().Put(s.client).Return(nil).AnyTimes() + + handler(&reqCtx) + + if reqCtx.Response.StatusCode() != 200 { + t.Errorf("Incorrect response status code. Expected: 200 and got %d", + reqCtx.Response.StatusCode()) + } + + recvBody := strings.TrimSpace(string(reqCtx.Response.Body())) + + if recvBody != responseBody { + t.Errorf("Incorrect response status code. Expected: %s and got %s", + responseBody, recvBody) + } + + // query with duplication of the name field + query = ` + query { + a0:room(name: "GeneralChat") { + name + messages { + id + text + createdBy + createdAt + } + } + a1:room(name: "GeneralChat") { + name + messages { + id + name + name + text + createdBy + createdAt + } + } +} + ` + + requestBody = map[string]interface{}{ + "query": query, + } + + jsonValue, _ = json.Marshal(requestBody) + + req = fasthttp.AcquireRequest() + req.SetRequestURI("/query") + req.Header.SetMethod("POST") + req.SetBodyStream(bytes.NewReader(jsonValue), -1) + req.Header.SetContentType("application/json") + + reqCtx = fasthttp.RequestCtx{ + Request: *req, + } + + handler(&reqCtx) + + if reqCtx.Response.StatusCode() != 200 { + t.Errorf("Incorrect response status code. Expected: 200 and got %d", + reqCtx.Response.StatusCode()) + } + + gqlResp := new(Response) + + if err := json.Unmarshal(reqCtx.Response.Body(), &gqlResp); err != nil { + t.Fatal(err) + } + + if len(gqlResp.Errors) != 1 { + t.Errorf("Incorrect amount of errors in the response. Expected: 1 and got %d", + len(gqlResp.Errors)) + } + + if gqlResp.Errors[0].Message != validator.ErrFieldDuplicationFound.Error() { + t.Errorf("Incorrect error message. Expected: \"%s\" and got \"%s\"", + validator.ErrFieldDuplicationFound.Error(), gqlResp.Errors[0].Message) + } + +} diff --git a/internal/config/config.go b/internal/config/config.go index 5f0cda6..9693bd3 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -131,7 +131,9 @@ type ShadowAPI struct { type GraphQL struct { MaxQueryComplexity int `conf:"required" validate:"required"` MaxQueryDepth int `conf:"required" validate:"required"` + MaxAliasesNum int `conf:"required" validate:"required"` NodeCountLimit int `conf:"required" validate:"required"` + FieldDuplication bool `conf:"default:false"` Playground bool `conf:"default:false"` PlaygroundPath string `conf:"default:/" validate:"path"` Introspection bool `conf:"required" validate:"required"` diff --git a/internal/platform/validator/graphql.go b/internal/platform/validator/graphql.go index 2fbb27c..4418e69 100644 --- a/internal/platform/validator/graphql.go +++ b/internal/platform/validator/graphql.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/json" "errors" + "fmt" "io" "net/url" "strings" @@ -13,6 +14,7 @@ import ( "github.com/valyala/fasthttp" "github.com/wallarm/api-firewall/internal/config" "github.com/wallarm/api-firewall/internal/platform/complexity" + "github.com/wundergraph/graphql-go-tools/pkg/ast" "github.com/wundergraph/graphql-go-tools/pkg/astparser" "github.com/wundergraph/graphql-go-tools/pkg/graphql" "github.com/wundergraph/graphql-go-tools/pkg/operationreport" @@ -28,6 +30,7 @@ var ( ErrNotAllowIntrospectionQuery = errors.New("introspection query is not allowed") ErrGraphQLQueryNotFound = errors.New("GraphQL query not found in the request") ErrWrongGraphQLQueryTypeInGETRequest = errors.New("wrong GraphQL query type in GET request") + ErrFieldDuplicationFound = errors.New("duplicate fields were found in the GraphQL document") ) // ValidateGraphQLRequest validates the GraphQL request @@ -46,11 +49,28 @@ func ValidateGraphQLRequest(cfg *config.GraphQL, schema *graphql.Schema, r *grap } - // validate operation name value - if err := validateOperationName(r); err != nil { + // parse the GraphQL document + document, _ := astparser.ParseGraphqlDocumentString(r.Query) + + // validate that there are no duplication fields + if err := validateOperationName(&document, r); err != nil { return &graphql.ValidationResult{Valid: false, Errors: graphql.RequestErrorsFromError(err)}, nil } + if cfg.MaxAliasesNum > 0 { + // validate max aliases in the GraphQL document + if err := validateAliasesNum(&document, cfg.MaxAliasesNum); err != nil { + return &graphql.ValidationResult{Valid: false, Errors: graphql.RequestErrorsFromError(err)}, nil + } + } + + if cfg.FieldDuplication { + // validate operation name value + if err := validateDuplicateFields(&document); err != nil { + return &graphql.ValidationResult{Valid: false, Errors: graphql.RequestErrorsFromError(err)}, nil + } + } + // skip query complexity check if it is not configured if cfg.NodeCountLimit > 0 || cfg.MaxQueryDepth > 0 || cfg.MaxQueryComplexity > 0 { @@ -140,9 +160,66 @@ func ParseGraphQLRequest(ctx *fasthttp.RequestCtx) (*graphql.Request, error) { return nil, ErrGraphQLQueryNotFound } -func validateOperationName(gqlRequest *graphql.Request) error { +// validateAliasesNum validates that the total amount of aliases in the GraphQL document does not exceed the configured max value +func validateAliasesNum(document *ast.Document, MaxAliasesNum int) error { + + numOfAliases := getNumOfAliases(document) + if numOfAliases > MaxAliasesNum { + return fmt.Errorf("the maximum number of aliases in the GraphQL document has been exceeded. The maximum number of aliases value is %d. The current number of aliases is %d", MaxAliasesNum, numOfAliases) + } + + return nil +} + +// getNumOfAliases returns amount of aliases in the GraphQL documents +func getNumOfAliases(document *ast.Document) int { + numOfAliases := 0 + + for _, f := range document.Fields { + if f.Alias.IsDefined { + numOfAliases += 1 + } + } + + return numOfAliases +} + +// validateDuplicateFields checks that there are now duplicates fields in the document +func validateDuplicateFields(document *ast.Document) error { + + for _, ss := range document.SelectionSets { + fieldsRepeatMap := make(map[string]int) + for _, cs := range ss.SelectionRefs { + field := document.Fields[document.Selections[cs].Ref] + fieldName := document.FieldNameString(document.Selections[cs].Ref) + + // skip objects + if field.HasSelections { + continue + } + + currentNum, ok := fieldsRepeatMap[fieldName] + if !ok { + fieldsRepeatMap[fieldName] = 1 + continue + } + + currentNum += 1 + fieldsRepeatMap[fieldName] = currentNum + } + + for f := range fieldsRepeatMap { + if fieldsRepeatMap[f] > 1 { + return ErrFieldDuplicationFound + } + } + } + + return nil +} + +func validateOperationName(operation *ast.Document, gqlRequest *graphql.Request) error { - operation, _ := astparser.ParseGraphqlDocumentString(gqlRequest.Query) numOfOperations := operation.NumOfOperationDefinitions() operationName := strings.TrimSpace(gqlRequest.OperationName) report := &operationreport.Report{}