diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..34eb00e --- /dev/null +++ b/.dockerignore @@ -0,0 +1 @@ +resources/dev/wallarm_api.db \ No newline at end of file diff --git a/.github/workflows/binaries.yml b/.github/workflows/binaries.yml index c4915fd..ebf6312 100644 --- a/.github/workflows/binaries.yml +++ b/.github/workflows/binaries.yml @@ -51,7 +51,7 @@ jobs: needs: - draft-release env: - X_GO_DISTRIBUTION: "https://go.dev/dl/go1.19.5.linux-amd64.tar.gz" + X_GO_DISTRIBUTION: "https://go.dev/dl/go1.20.7.linux-amd64.tar.gz" strategy: matrix: include: @@ -78,6 +78,7 @@ jobs: apt-get update -y && \ apt-get install --no-install-recommends -y \ + build-essential \ binutils \ ca-certificates \ curl \ @@ -159,7 +160,7 @@ jobs: needs: - draft-release env: - X_GO_VERSION: "1.19.5-r0" + X_GO_VERSION: "1.20.7-r0" strategy: matrix: include: @@ -177,7 +178,7 @@ jobs: - uses: addnab/docker-run-action@v3 with: - image: alpine:3.17 + image: alpine:3.18 options: > --volume ${{ github.workspace }}:/build --workdir /build @@ -261,18 +262,16 @@ jobs: runs-on: ubuntu-latest needs: - draft-release - env: - X_GO_VERSION: "1.19.5-r0" strategy: matrix: include: - arch: armv6 distro: bullseye - go_distribution: https://go.dev/dl/go1.19.5.linux-armv6l.tar.gz + go_distribution: https://go.dev/dl/go1.20.7.linux-armv6l.tar.gz artifact: armv6-libc - arch: aarch64 distro: bullseye - go_distribution: https://go.dev/dl/go1.19.5.linux-arm64.tar.gz + go_distribution: https://go.dev/dl/go1.20.7.linux-arm64.tar.gz artifact: arm64-libc - arch: armv6 distro: alpine_latest @@ -328,10 +327,11 @@ jobs: git \ gzip \ make \ - go=${{ env.X_GO_VERSION }} + musl-dev \ + go ;; esac - + go version run: |- export PATH=${PATH}:/usr/local/go/bin && \ diff --git a/.gitignore b/.gitignore index b0efd9b..cbc4afa 100644 --- a/.gitignore +++ b/.gitignore @@ -15,3 +15,4 @@ vendor/ .DS_Store .idea/ +dev/ diff --git a/Dockerfile b/Dockerfile index 8da9312..6fda827 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -FROM golang:1.19-alpine AS build +FROM golang:1.20-alpine3.18 AS build ARG APIFIREWALL_VERSION ENV APIFIREWALL_VERSION=${APIFIREWALL_VERSION} @@ -10,7 +10,7 @@ RUN apk add --no-cache \ musl-dev WORKDIR /build -COPY . . +COPY .. . RUN go mod download -x && \ go build \ @@ -22,17 +22,17 @@ RUN go mod download -x && \ # Smoke test RUN ./api-firewall -v -FROM alpine:3.17 AS composer +FROM alpine:3.18 AS composer WORKDIR /output COPY --from=build /build/api-firewall ./usr/local/bin/ -COPY ./docker-entrypoint.sh ./usr/local/bin/docker-entrypoint.sh +COPY docker-entrypoint.sh ./usr/local/bin/docker-entrypoint.sh RUN chmod 755 ./usr/local/bin/* && \ chown root:root ./usr/local/bin/* -FROM alpine:3.17 +FROM alpine:3.18 RUN adduser -u 1000 -H -h /opt -D -s /bin/sh api-firewall diff --git a/Makefile b/Makefile index aca57d0..39c2a76 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -VERSION := 0.6.11 +VERSION := 0.6.12 .DEFAULT_GOAL := build @@ -13,11 +13,12 @@ tidy: go mod vendor test: - go test ./... -count=1 + go test ./... -count=1 -race -cover genmocks: mockgen -source ./internal/platform/proxy/chainpool.go -destination ./internal/platform/proxy/httppool_mock.go -package proxy - mockgen -source ./internal/platform/shadowAPI/shadowAPI.go -destination ./internal/platform/shadowAPI/shadowAPI_mock.go -package shadowAPI + mockgen -source ./internal/platform/database/database.go -destination ./internal/platform/database/database_mock.go -package database + mockgen -source ./cmd/api-firewall/internal/updater/updater.go -destination ./cmd/api-firewall/internal/updater/updater_mock.go -package updater update: go get -u ./... diff --git a/cmd/api-firewall/internal/handlers/api/health.go b/cmd/api-firewall/internal/handlers/api/health.go new file mode 100644 index 0000000..cceec98 --- /dev/null +++ b/cmd/api-firewall/internal/handlers/api/health.go @@ -0,0 +1,68 @@ +package api + +import ( + "os" + + "github.com/sirupsen/logrus" + "github.com/valyala/fasthttp" + "github.com/wallarm/api-firewall/internal/platform/database" + "github.com/wallarm/api-firewall/internal/platform/web" +) + +type Health struct { + Build string + Logger *logrus.Logger + OpenAPIDB database.DBOpenAPILoader +} + +// Readiness checks if the Fasthttp connection pool is ready to handle new requests. +func (h *Health) Readiness(ctx *fasthttp.RequestCtx) error { + + status := "ok" + statusCode := fasthttp.StatusOK + + if len(h.OpenAPIDB.SchemaIDs()) == 0 { + status = "not ready" + statusCode = fasthttp.StatusInternalServerError + } + + data := struct { + Status string `json:"status"` + }{ + Status: status, + } + + return web.Respond(ctx, data, statusCode) +} + +// Liveness returns simple status info if the service is alive. If the +// app is deployed to a Kubernetes cluster, it will also return pod, node, and +// namespace details via the Downward API. The Kubernetes environment variables +// need to be set within your Pod/Deployment manifest. +func (h *Health) Liveness(ctx *fasthttp.RequestCtx) error { + host, err := os.Hostname() + if err != nil { + host = "unavailable" + } + + data := struct { + Status string `json:"status,omitempty"` + Build string `json:"build,omitempty"` + Host string `json:"host,omitempty"` + Pod string `json:"pod,omitempty"` + PodIP string `json:"podIP,omitempty"` + Node string `json:"node,omitempty"` + Namespace string `json:"namespace,omitempty"` + }{ + Status: "up", + Build: h.Build, + Host: host, + Pod: os.Getenv("KUBERNETES_PODNAME"), + PodIP: os.Getenv("KUBERNETES_NAMESPACE_POD_IP"), + Node: os.Getenv("KUBERNETES_NODENAME"), + Namespace: os.Getenv("KUBERNETES_NAMESPACE"), + } + + statusCode := fasthttp.StatusOK + return web.Respond(ctx, data, statusCode) +} diff --git a/cmd/api-firewall/internal/handlers/api/openapi.go b/cmd/api-firewall/internal/handlers/api/openapi.go new file mode 100644 index 0000000..c01694b --- /dev/null +++ b/cmd/api-firewall/internal/handlers/api/openapi.go @@ -0,0 +1,512 @@ +package api + +import ( + "context" + "errors" + "fmt" + "net/http" + "strings" + "sync" + + "github.com/getkin/kin-openapi/openapi3" + "github.com/getkin/kin-openapi/openapi3filter" + "github.com/savsgio/gotils/strconv" + "github.com/sirupsen/logrus" + "github.com/valyala/fasthttp" + "github.com/valyala/fasthttp/fasthttpadaptor" + "github.com/valyala/fastjson" + "github.com/wallarm/api-firewall/internal/config" + "github.com/wallarm/api-firewall/internal/platform/router" + "github.com/wallarm/api-firewall/internal/platform/validator" + "github.com/wallarm/api-firewall/internal/platform/web" +) + +var ( + ErrAuthHeaderMissed = errors.New("missing Authorization header") + ErrAPITokenMissed = errors.New("missing API keys for authorization") +) + +var apiModeSecurityRequirementsOptions = &openapi3filter.Options{ + MultiError: true, + AuthenticationFunc: func(ctx context.Context, input *openapi3filter.AuthenticationInput) error { + switch input.SecurityScheme.Type { + case "http": + switch input.SecurityScheme.Scheme { + case "basic": + bHeader := input.RequestValidationInput.Request.Header.Get("Authorization") + if bHeader == "" || !strings.HasPrefix(strings.ToLower(bHeader), "basic ") { + return fmt.Errorf("%w: basic authentication is required", ErrAuthHeaderMissed) + } + case "bearer": + bHeader := input.RequestValidationInput.Request.Header.Get("Authorization") + if bHeader == "" || !strings.HasPrefix(strings.ToLower(bHeader), "bearer ") { + return fmt.Errorf("%w: bearer authentication is required", ErrAuthHeaderMissed) + } + } + case "apiKey": + switch input.SecurityScheme.In { + case "header": + if input.RequestValidationInput.Request.Header.Get(input.SecurityScheme.Name) == "" { + return fmt.Errorf("%w: missing %s header", ErrAPITokenMissed, input.SecurityScheme.Name) + } + case "query": + if input.RequestValidationInput.Request.URL.Query().Get(input.SecurityScheme.Name) == "" { + return fmt.Errorf("%w: missing %s query parameter", ErrAPITokenMissed, input.SecurityScheme.Name) + } + case "cookie": + _, err := input.RequestValidationInput.Request.Cookie(input.SecurityScheme.Name) + if err != nil { + return fmt.Errorf("%w: missing %s cookie", ErrAPITokenMissed, input.SecurityScheme.Name) + } + } + } + return nil + }, +} + +type APIMode struct { + CustomRoute *router.CustomRoute + OpenAPIRouter *router.Router + Log *logrus.Logger + Cfg *config.APIFWConfigurationAPIMode + ParserPool *fastjson.ParserPool +} + +const ( + ErrCodeMethodAndPathNotFound = "method_and_path_not_found" + ErrCodeRequiredBodyMissed = "required_body_missed" + ErrCodeRequiredBodyParseError = "required_body_parse_error" + ErrCodeRequiredBodyParameterMissed = "required_body_parameter_missed" + ErrCodeRequiredBodyParameterInvalidValue = "required_body_parameter_invalid_value" + ErrCodeRequiredPathParameterMissed = "required_path_parameter_missed" + ErrCodeRequiredPathParameterInvalidValue = "required_path_parameter_invalid_value" + ErrCodeRequiredQueryParameterMissed = "required_query_parameter_missed" + ErrCodeRequiredQueryParameterInvalidValue = "required_query_parameter_invalid_value" + ErrCodeRequiredCookieParameterMissed = "required_cookie_parameter_missed" + ErrCodeRequiredCookieParameterInvalidValue = "required_cookie_parameter_invalid_value" + ErrCodeRequiredHeaderMissed = "required_header_missed" + ErrCodeRequiredHeaderInvalidValue = "required_header_invalid_value" + + ErrCodeSecRequirementsFailed = "required_security_requirements_failed" + + ErrCodeUnknownParameterFound = "unknown_parameter_found" + + ErrCodeUnknownValidationError = "unknown_validation_error" +) + +var ( + ErrMethodAndPathNotFound = errors.New("method and path are not found") + + ErrRequiredBodyIsMissing = errors.New("required body is missing") + ErrMissedRequiredParameters = errors.New("required parameters missed") +) + +type ValidationError struct { + Message string `json:"message"` + Code string `json:"code"` + SchemaVersion string `json:"schema_version,omitempty"` + Fields []string `json:"related_fields,omitempty"` +} + +type Response struct { + Errors []*ValidationError `json:"errors"` +} + +func getErrorResponse(validationError error) ([]*ValidationError, error) { + var responseErrors []*ValidationError + + switch err := validationError.(type) { + + case *openapi3filter.RequestError: + if err.Parameter != nil { + + // required parameter is missed + if errors.Is(err, validator.ErrInvalidRequired) || errors.Is(err, validator.ErrInvalidEmptyValue) { + response := ValidationError{} + switch err.Parameter.In { + case "path": + response.Code = ErrCodeRequiredPathParameterMissed + case "query": + response.Code = ErrCodeRequiredQueryParameterMissed + case "cookie": + response.Code = ErrCodeRequiredCookieParameterMissed + case "header": + response.Code = ErrCodeRequiredHeaderMissed + } + response.Message = err.Error() + response.Fields = []string{err.Parameter.Name} + responseErrors = append(responseErrors, &response) + } + + // invalid parameter value + if strings.HasSuffix(err.Error(), "invalid syntax") { + response := ValidationError{} + switch err.Parameter.In { + case "path": + response.Code = ErrCodeRequiredPathParameterInvalidValue + case "query": + response.Code = ErrCodeRequiredQueryParameterInvalidValue + case "cookie": + response.Code = ErrCodeRequiredCookieParameterInvalidValue + case "header": + response.Code = ErrCodeRequiredHeaderInvalidValue + } + response.Message = err.Error() + response.Fields = []string{err.Parameter.Name} + responseErrors = append(responseErrors, &response) + } + + // validation of the required parameter error + switch multiErrors := err.Err.(type) { + case openapi3.MultiError: + for _, multiErr := range multiErrors { + schemaError, ok := multiErr.(*openapi3.SchemaError) + if ok { + response := ValidationError{} + switch schemaError.SchemaField { + case "required": + switch err.Parameter.In { + case "query": + response.Code = ErrCodeRequiredQueryParameterMissed + case "cookie": + response.Code = ErrCodeRequiredCookieParameterMissed + case "header": + response.Code = ErrCodeRequiredHeaderMissed + } + response.Fields = schemaError.JSONPointer() + response.Message = ErrMissedRequiredParameters.Error() + responseErrors = append(responseErrors, &response) + default: + switch err.Parameter.In { + case "query": + response.Code = ErrCodeRequiredQueryParameterInvalidValue + case "cookie": + response.Code = ErrCodeRequiredCookieParameterInvalidValue + case "header": + response.Code = ErrCodeRequiredHeaderInvalidValue + } + response.Fields = []string{err.Parameter.Name} + response.Message = schemaError.Error() + responseErrors = append(responseErrors, &response) + } + } + } + default: + schemaError, ok := multiErrors.(*openapi3.SchemaError) + if ok { + response := ValidationError{} + switch schemaError.SchemaField { + case "required": + switch err.Parameter.In { + case "query": + response.Code = ErrCodeRequiredQueryParameterMissed + case "cookie": + response.Code = ErrCodeRequiredCookieParameterMissed + case "header": + response.Code = ErrCodeRequiredHeaderMissed + } + response.Fields = schemaError.JSONPointer() + response.Message = ErrMissedRequiredParameters.Error() + responseErrors = append(responseErrors, &response) + default: + switch err.Parameter.In { + case "query": + response.Code = ErrCodeRequiredQueryParameterInvalidValue + case "cookie": + response.Code = ErrCodeRequiredCookieParameterInvalidValue + case "header": + response.Code = ErrCodeRequiredHeaderInvalidValue + } + response.Fields = []string{err.Parameter.Name} + response.Message = schemaError.Error() + responseErrors = append(responseErrors, &response) + } + } + } + + } + + // validation of the required body error + switch multiErrors := err.Err.(type) { + case openapi3.MultiError: + for _, multiErr := range multiErrors { + schemaError, ok := multiErr.(*openapi3.SchemaError) + if ok { + response := ValidationError{} + switch schemaError.SchemaField { + case "required": + response.Code = ErrCodeRequiredBodyParameterMissed + response.Fields = schemaError.JSONPointer() + response.Message = schemaError.Error() + responseErrors = append(responseErrors, &response) + default: + response.Code = ErrCodeRequiredBodyParameterInvalidValue + response.Fields = schemaError.JSONPointer() + response.Message = schemaError.Error() + responseErrors = append(responseErrors, &response) + } + } + } + default: + schemaError, ok := multiErrors.(*openapi3.SchemaError) + if ok { + response := ValidationError{} + switch schemaError.SchemaField { + case "required": + response.Code = ErrCodeRequiredBodyParameterMissed + response.Fields = schemaError.JSONPointer() + response.Message = schemaError.Error() + responseErrors = append(responseErrors, &response) + default: + response.Code = ErrCodeRequiredBodyParameterInvalidValue + response.Fields = schemaError.JSONPointer() + response.Message = schemaError.Error() + responseErrors = append(responseErrors, &response) + } + } + } + + // handle request body errors + if err.RequestBody != nil { + + // body required but missed + if err.RequestBody.Required { + if err.Err != nil && err.Err.Error() == validator.ErrInvalidRequired.Error() { + response := ValidationError{} + response.Code = ErrCodeRequiredBodyMissed + response.Message = ErrRequiredBodyIsMissing.Error() + responseErrors = append(responseErrors, &response) + } + } + + // body parser not found + if strings.HasPrefix(err.Error(), "request body has an error: failed to decode request body: unsupported content type") { + return nil, err + } + + // body parse errors + _, isParseErr := err.Err.(*validator.ParseError) + if isParseErr || strings.HasPrefix(err.Error(), "request body has an error: header Content-Type has unexpected value") { + response := ValidationError{} + response.Code = ErrCodeRequiredBodyParseError + response.Message = err.Error() + responseErrors = append(responseErrors, &response) + } + } + + case *openapi3filter.SecurityRequirementsError: + + response := ValidationError{} + + secErrors := "" + for _, secError := range err.Errors { + secErrors += secError.Error() + "," + } + + response.Code = ErrCodeSecRequirementsFailed + response.Message = secErrors + responseErrors = append(responseErrors, &response) + } + + // set the error as unknown + if len(responseErrors) == 0 { + response := ValidationError{} + response.Code = ErrCodeUnknownValidationError + response.Message = validationError.Error() + responseErrors = append(responseErrors, &response) + } + + return responseErrors, nil +} + +// APIModeHandler validates request and respond with 200, 403 (with error) or 500 status code +func (s *APIMode) APIModeHandler(ctx *fasthttp.RequestCtx) error { + + // route not found + if s.CustomRoute == nil { + s.Log.WithFields(logrus.Fields{ + "request_id": fmt.Sprintf("#%016X", ctx.ID()), + }).Debug("method or path were not found") + return web.Respond(ctx, Response{Errors: []*ValidationError{{Message: ErrMethodAndPathNotFound.Error(), Code: ErrCodeMethodAndPathNotFound}}}, fasthttp.StatusForbidden) + } + + // get path parameters + var pathParams map[string]string + + if s.CustomRoute.ParametersNumberInPath > 0 { + pathParams = make(map[string]string) + + ctx.VisitUserValues(func(key []byte, value interface{}) { + keyStr := strconv.B2S(key) + if keyStr != web.WallarmSchemaID { + pathParams[keyStr] = value.(string) + } + }) + } + + // Convert fasthttp request to net/http request + req := http.Request{} + if err := fasthttpadaptor.ConvertRequest(ctx, &req, false); err != nil { + s.Log.WithFields(logrus.Fields{ + "error": err, + "request_id": fmt.Sprintf("#%016X", ctx.ID()), + }).Error("error while converting http request") + return web.RespondError(ctx, fasthttp.StatusBadRequest, "") + } + + // decode request body + requestContentEncoding := string(ctx.Request.Header.ContentEncoding()) + if requestContentEncoding != "" { + var err error + if req.Body, err = web.GetDecompressedRequestBody(&ctx.Request, requestContentEncoding); err != nil { + s.Log.WithFields(logrus.Fields{ + "error": err, + "request_id": fmt.Sprintf("#%016X", ctx.ID()), + }).Error("request body decompression error") + return err + } + } + + // Validate request + requestValidationInput := &openapi3filter.RequestValidationInput{ + Request: &req, + PathParams: pathParams, + Route: s.CustomRoute.Route, + Options: apiModeSecurityRequirementsOptions, + } + + var wg sync.WaitGroup + + var valReqErrors error + + wg.Add(1) + go func() { + defer wg.Done() + + // Get fastjson parser + jsonParser := s.ParserPool.Get() + defer s.ParserPool.Put(jsonParser) + + valReqErrors = validator.ValidateRequest(ctx, requestValidationInput, jsonParser) + }() + + var valUPReqErrors error + var upResults []validator.RequestUnknownParameterError + + // validate unknown parameters + if s.Cfg.UnknownParametersDetection { + wg.Add(1) + go func() { + defer wg.Done() + + // Get fastjson parser + jsonParser := s.ParserPool.Get() + defer s.ParserPool.Put(jsonParser) + + upResults, valUPReqErrors = validator.ValidateUnknownRequestParameters(ctx, requestValidationInput.Route, req.Header, jsonParser) + }() + } + + wg.Wait() + + var respErrors []*ValidationError + + if valReqErrors != nil { + + switch valErr := valReqErrors.(type) { + + case openapi3.MultiError: + + for _, currentErr := range valErr { + // parse validation error and build the response + parsedValErrs, unknownErr := getErrorResponse(currentErr) + if unknownErr != nil { + return web.RespondError(ctx, fasthttp.StatusInternalServerError, "") + } + + if len(parsedValErrs) > 0 { + for i := range parsedValErrs { + parsedValErrs[i].SchemaVersion = s.OpenAPIRouter.SchemaVersion + } + respErrors = append(respErrors, parsedValErrs...) + } + } + + s.Log.WithFields(logrus.Fields{ + "error": valReqErrors, + "request_id": fmt.Sprintf("#%016X", ctx.ID()), + }).Error("request validation error") + default: + // parse validation error and build the response + parsedValErrs, unknownErr := getErrorResponse(valErr) + if unknownErr != nil { + return web.RespondError(ctx, fasthttp.StatusInternalServerError, "") + } + if parsedValErrs != nil { + s.Log.WithFields(logrus.Fields{ + "error": valErr, + "request_id": fmt.Sprintf("#%016X", ctx.ID()), + }).Warning("request validation error") + + // set schema version for each validation + if len(parsedValErrs) > 0 { + for i := range parsedValErrs { + parsedValErrs[i].SchemaVersion = s.OpenAPIRouter.SchemaVersion + } + } + respErrors = append(respErrors, parsedValErrs...) + } + } + + if len(respErrors) == 0 { + s.Log.WithFields(logrus.Fields{ + "error": valReqErrors, + "request_id": fmt.Sprintf("#%016X", ctx.ID()), + }).Error("request validation error") + + // validation function returned unknown error + return web.RespondError(ctx, fasthttp.StatusInternalServerError, "") + } + } + + // validate unknown parameters + if s.Cfg.UnknownParametersDetection { + + if valUPReqErrors != nil { + s.Log.WithFields(logrus.Fields{ + "error": valUPReqErrors, + "request_id": fmt.Sprintf("#%016X", ctx.ID()), + }).Error("searching for undefined parameters") + + // if it is not a parsing error then return 500 + // if it is a parsing error then it already handled by the request validator + if _, ok := valUPReqErrors.(*validator.ParseError); !ok { + return web.RespondError(ctx, fasthttp.StatusInternalServerError, "") + } + } + + if len(upResults) > 0 { + for _, upResult := range upResults { + s.Log.WithFields(logrus.Fields{ + "error": upResult.Err, + "request_id": fmt.Sprintf("#%016X", ctx.ID()), + }).Error("searching for undefined parameters") + + response := ValidationError{} + response.SchemaVersion = s.OpenAPIRouter.SchemaVersion + response.Message = upResult.Err.Error() + response.Code = ErrCodeUnknownParameterFound + response.Fields = upResult.Parameters + respErrors = append(respErrors, &response) + } + } + } + + // respond 403 with errors + if len(respErrors) > 0 { + return web.Respond(ctx, Response{Errors: respErrors}, fasthttp.StatusForbidden) + } + + // request successfully validated + return web.RespondError(ctx, fasthttp.StatusOK, "") +} diff --git a/cmd/api-firewall/internal/handlers/api/routes.go b/cmd/api-firewall/internal/handlers/api/routes.go new file mode 100644 index 0000000..d6ef1d8 --- /dev/null +++ b/cmd/api-firewall/internal/handlers/api/routes.go @@ -0,0 +1,79 @@ +package api + +import ( + "net/url" + "os" + "path" + "sync" + + "github.com/sirupsen/logrus" + "github.com/valyala/fasthttp" + "github.com/valyala/fastjson" + "github.com/wallarm/api-firewall/internal/config" + "github.com/wallarm/api-firewall/internal/mid" + "github.com/wallarm/api-firewall/internal/platform/database" + "github.com/wallarm/api-firewall/internal/platform/router" + "github.com/wallarm/api-firewall/internal/platform/web" +) + +func Handlers(lock *sync.RWMutex, cfg *config.APIFWConfigurationAPIMode, shutdown chan os.Signal, logger *logrus.Logger, storedSpecs database.DBOpenAPILoader) fasthttp.RequestHandler { + + // define FastJSON parsers pool + var parserPool fastjson.ParserPool + schemaIDs := storedSpecs.SchemaIDs() + + // Construct the web.App which holds all routes as well as common Middleware. + apps := web.NewApps(lock, cfg.PassOptionsRequests, storedSpecs, shutdown, logger, mid.Logger(logger), mid.MIMETypeIdentifier(logger), mid.Errors(logger), mid.Panics(logger)) + + for _, schemaID := range schemaIDs { + + serverURLStr := "/" + spec := storedSpecs.Specification(schemaID) + servers := spec.Servers + if servers != nil { + var err error + if serverURLStr, err = servers.BasePath(); err != nil { + logger.Errorf("getting server URL from OpenAPI specification: %v", err) + } + } + + serverURL, err := url.Parse(serverURLStr) + if err != nil { + logger.Errorf("parsing server URL from OpenAPI specification: %v", err) + } + + // get new router + newSwagRouter, err := router.NewRouterDBLoader(schemaID, storedSpecs) + if err != nil { + logger.WithFields(logrus.Fields{"error": err}).Error("new router creation failed") + } + + for i := 0; i < len(newSwagRouter.Routes); i++ { + + s := APIMode{ + CustomRoute: &newSwagRouter.Routes[i], + Log: logger, + Cfg: cfg, + ParserPool: &parserPool, + OpenAPIRouter: newSwagRouter, + } + updRoutePath := path.Join(serverURL.Path, newSwagRouter.Routes[i].Path) + + s.Log.Debugf("handler: Schema ID %d: OpenAPI version %s: Loaded path %s - %s", schemaID, storedSpecs.SpecificationVersion(schemaID), newSwagRouter.Routes[i].Method, updRoutePath) + + apps.Handle(schemaID, newSwagRouter.Routes[i].Method, updRoutePath, s.APIModeHandler) + } + + //set handler for default behavior (404, 405) + s := APIMode{ + CustomRoute: nil, + Log: logger, + Cfg: cfg, + ParserPool: &parserPool, + OpenAPIRouter: newSwagRouter, + } + apps.SetDefaultBehavior(schemaID, s.APIModeHandler) + } + + return apps.APIModeHandler +} diff --git a/cmd/api-firewall/internal/handlers/check.go b/cmd/api-firewall/internal/handlers/proxy/health.go similarity index 99% rename from cmd/api-firewall/internal/handlers/check.go rename to cmd/api-firewall/internal/handlers/proxy/health.go index 82efda5..63342b4 100644 --- a/cmd/api-firewall/internal/handlers/check.go +++ b/cmd/api-firewall/internal/handlers/proxy/health.go @@ -1,4 +1,4 @@ -package handlers +package proxy import ( "os" diff --git a/cmd/api-firewall/internal/handlers/openapi.go b/cmd/api-firewall/internal/handlers/proxy/openapi.go similarity index 61% rename from cmd/api-firewall/internal/handlers/openapi.go rename to cmd/api-firewall/internal/handlers/proxy/openapi.go index e0e277a..d180ae4 100644 --- a/cmd/api-firewall/internal/handlers/openapi.go +++ b/cmd/api-firewall/internal/handlers/proxy/openapi.go @@ -1,4 +1,4 @@ -package handlers +package proxy import ( "context" @@ -9,7 +9,6 @@ import ( "github.com/getkin/kin-openapi/openapi3" "github.com/getkin/kin-openapi/openapi3filter" - "github.com/getkin/kin-openapi/routers" "github.com/savsgio/gotils/strconv" "github.com/sirupsen/logrus" "github.com/valyala/fasthttp" @@ -18,20 +17,18 @@ import ( "github.com/wallarm/api-firewall/internal/config" "github.com/wallarm/api-firewall/internal/platform/oauth2" "github.com/wallarm/api-firewall/internal/platform/proxy" - "github.com/wallarm/api-firewall/internal/platform/shadowAPI" + "github.com/wallarm/api-firewall/internal/platform/router" "github.com/wallarm/api-firewall/internal/platform/validator" "github.com/wallarm/api-firewall/internal/platform/web" ) type openapiWaf struct { - route *routers.Route - proxyPool proxy.Pool - logger *logrus.Logger - cfg *config.APIFWConfiguration - pathParamLength int - parserPool *fastjson.ParserPool - oauthValidator oauth2.OAuth2 - shadowAPI shadowAPI.Checker + customRoute *router.CustomRoute + proxyPool proxy.Pool + logger *logrus.Logger + cfg *config.APIFWConfiguration + parserPool *fastjson.ParserPool + oauthValidator oauth2.OAuth2 } // EXPERIMENTAL feature @@ -39,13 +36,10 @@ type openapiWaf struct { func getValidationHeader(ctx *fasthttp.RequestCtx, err error) *string { var reason = "unknown" - switch err.(type) { - + switch err := err.(type) { case *openapi3filter.ResponseError: - responseError, ok := err.(*openapi3filter.ResponseError) - - if ok && responseError.Reason != "" { - reason = responseError.Reason + if err.Reason != "" { + reason = err.Reason } id := fmt.Sprintf("response-%d-%s", ctx.Response.StatusCode(), strings.Split(string(ctx.Response.Header.ContentType()), ";")[0]) @@ -54,46 +48,42 @@ func getValidationHeader(ctx *fasthttp.RequestCtx, err error) *string { case *openapi3filter.RequestError: - requestError, ok := err.(*openapi3filter.RequestError) - if !ok { - return nil - } - - if requestError.Reason != "" { - reason = requestError.Reason + if err.Reason != "" { + reason = err.Reason } - if requestError.Parameter != nil { + if err.Parameter != nil { paramName := "request-parameter" - if requestError.Reason == "" { - schemaError, ok := requestError.Err.(*openapi3.SchemaError) + if err.Reason == "" { + schemaError, ok := err.Err.(*openapi3.SchemaError) if ok && schemaError.Reason != "" { reason = schemaError.Reason } - paramName = requestError.Parameter.Name + paramName = err.Parameter.Name } value := fmt.Sprintf("request-parameter:%s:%s", reason, paramName) return &value } - if requestError.RequestBody != nil { + if err.RequestBody != nil { id := fmt.Sprintf("request-body-%s", strings.Split(string(ctx.Request.Header.ContentType()), ";")[0]) value := fmt.Sprintf("%s:%s:request-body", id, reason) return &value } case *openapi3filter.SecurityRequirementsError: + secRequirements := err.SecurityRequirements secSchemeName := "" - for _, scheme := range err.(*openapi3filter.SecurityRequirementsError).SecurityRequirements { + for _, scheme := range secRequirements { for key := range scheme { secSchemeName += key + "," } } secErrors := "" - for _, secError := range err.(*openapi3filter.SecurityRequirementsError).Errors { + for _, secError := range err.Errors { secErrors += secError.Error() + "," } @@ -105,20 +95,35 @@ func getValidationHeader(ctx *fasthttp.RequestCtx, err error) *string { } // Proxy request -func performProxy(ctx *fasthttp.RequestCtx, logger *logrus.Logger, client proxy.HTTPClient) error { +func performProxy(ctx *fasthttp.RequestCtx, proxyPool proxy.Pool) error { + + client, err := proxyPool.Get() + if err != nil { + return err + } + defer proxyPool.Put(client) + if err := client.Do(&ctx.Request, &ctx.Response); err != nil { - logger.WithFields(logrus.Fields{ - "error": err, - "request_id": fmt.Sprintf("#%016X", ctx.ID()), - }).Error("error while proxying request") + // request proxy has been failed + ctx.SetUserValue(web.RequestProxyFailed, true) + switch err { case fasthttp.ErrDialTimeout: - return web.RespondError(ctx, fasthttp.StatusGatewayTimeout, nil) + if err := web.RespondError(ctx, fasthttp.StatusGatewayTimeout, ""); err != nil { + return err + } case fasthttp.ErrNoFreeConns: - return web.RespondError(ctx, fasthttp.StatusServiceUnavailable, nil) + if err := web.RespondError(ctx, fasthttp.StatusServiceUnavailable, ""); err != nil { + return err + } default: - return web.RespondError(ctx, fasthttp.StatusBadGateway, nil) + if err := web.RespondError(ctx, fasthttp.StatusBadGateway, ""); err != nil { + return err + } } + + // The error has been handled so we can stop propagating it + return err } return nil @@ -126,49 +131,43 @@ func performProxy(ctx *fasthttp.RequestCtx, logger *logrus.Logger, client proxy. func (s *openapiWaf) openapiWafHandler(ctx *fasthttp.RequestCtx) error { - client, err := s.proxyPool.Get() - if err != nil { - s.logger.WithFields(logrus.Fields{ - "error": err, - "request_id": fmt.Sprintf("#%016X", ctx.ID()), - }).Error("error while proxying request") - return web.RespondError(ctx, fasthttp.StatusServiceUnavailable, nil) - } - defer s.proxyPool.Put(client) - - // Proxy request if APIFW is disabled + // Proxy request if APIFW is disabled OR pass requests with OPTIONS method is enabled and request method is OPTIONS if s.cfg.RequestValidation == web.ValidationDisable && s.cfg.ResponseValidation == web.ValidationDisable { - return performProxy(ctx, s.logger, client) + if err := performProxy(ctx, s.proxyPool); err != nil { + s.logger.WithFields(logrus.Fields{ + "error": err, + "request_id": fmt.Sprintf("#%016X", ctx.ID()), + }).Error("error while proxying request") + } + return nil } // If Validation is BLOCK for request and response then respond by CustomBlockStatusCode - if s.route == nil { + if s.customRoute == nil { + // route for the request not found + ctx.SetUserValue(web.RequestProxyNoRoute, true) + if s.cfg.RequestValidation == web.ValidationBlock || s.cfg.ResponseValidation == web.ValidationBlock { if s.cfg.AddValidationStatusHeader { - vh := "request: route not found" - return web.RespondError(ctx, s.cfg.CustomBlockStatusCode, &vh) + vh := "request: customRoute not found" + return web.RespondError(ctx, s.cfg.CustomBlockStatusCode, vh) } - return web.RespondError(ctx, s.cfg.CustomBlockStatusCode, nil) + return web.RespondError(ctx, s.cfg.CustomBlockStatusCode, "") } - // Check shadow api if path or method are not found and validation mode is LOG_ONLY - if s.cfg.RequestValidation == web.ValidationLog || s.cfg.ResponseValidation == web.ValidationLog { - // Check Shadow API endpoints - err := performProxy(ctx, s.logger, client) - if sErr := s.shadowAPI.Check(ctx); sErr != nil { - s.logger.WithFields(logrus.Fields{ - "error": err, - "request_id": fmt.Sprintf("#%016X", ctx.ID()), - }).Error("Shadow API check error") - } - return err + if err := performProxy(ctx, s.proxyPool); err != nil { + s.logger.WithFields(logrus.Fields{ + "error": err, + "request_id": fmt.Sprintf("#%016X", ctx.ID()), + }).Error("error while proxying request") } + return nil } var pathParams map[string]string - if s.pathParamLength > 0 { - pathParams = make(map[string]string, s.pathParamLength) + if s.customRoute.ParametersNumberInPath > 0 { + pathParams = make(map[string]string) ctx.VisitUserValues(func(key []byte, value interface{}) { keyStr := strconv.B2S(key) @@ -183,12 +182,13 @@ func (s *openapiWaf) openapiWafHandler(ctx *fasthttp.RequestCtx) error { "error": err, "request_id": fmt.Sprintf("#%016X", ctx.ID()), }).Error("error while converting http request") - return web.RespondError(ctx, fasthttp.StatusBadRequest, nil) + return web.RespondError(ctx, fasthttp.StatusBadRequest, "") } // decode request body requestContentEncoding := string(ctx.Request.Header.ContentEncoding()) if requestContentEncoding != "" { + var err error req.Body, err = web.GetDecompressedRequestBody(&ctx.Request, requestContentEncoding) if err != nil { s.logger.WithFields(logrus.Fields{ @@ -203,7 +203,7 @@ func (s *openapiWaf) openapiWafHandler(ctx *fasthttp.RequestCtx) error { requestValidationInput := &openapi3filter.RequestValidationInput{ Request: &req, PathParams: pathParams, - Route: s.route, + Route: s.customRoute.Route, Options: &openapi3filter.Options{ AuthenticationFunc: func(ctx context.Context, input *openapi3filter.AuthenticationInput) error { switch input.SecurityScheme.Type { @@ -257,21 +257,64 @@ func (s *openapiWaf) openapiWafHandler(ctx *fasthttp.RequestCtx) error { switch s.cfg.RequestValidation { case web.ValidationBlock: if err := validator.ValidateRequest(ctx, requestValidationInput, jsonParser); err != nil { - s.logger.WithFields(logrus.Fields{ - "error": err, - "request_id": fmt.Sprintf("#%016X", ctx.ID()), - }).Error("request validation error") - if s.cfg.AddValidationStatusHeader { - if vh := getValidationHeader(ctx, err); vh != nil { + + isRequestBlocked := true + if requestErr, ok := err.(*openapi3filter.RequestError); ok { + + // body parser not found + if strings.HasPrefix(requestErr.Error(), "request body has an error: failed to decode request body: unsupported content type") { s.logger.WithFields(logrus.Fields{ "error": err, "request_id": fmt.Sprintf("#%016X", ctx.ID()), - }).Errorf("add header %s: %s", web.ValidationStatus, *vh) - ctx.Request.Header.Add(web.ValidationStatus, *vh) - return web.RespondError(ctx, s.cfg.CustomBlockStatusCode, vh) + }).Error("request body parsing error: request passed") + isRequestBlocked = false } } - return web.RespondError(ctx, s.cfg.CustomBlockStatusCode, nil) + + if isRequestBlocked { + // request has been blocked + ctx.SetUserValue(web.RequestBlocked, true) + + s.logger.WithFields(logrus.Fields{ + "error": err, + "request_id": fmt.Sprintf("#%016X", ctx.ID()), + }).Error("request validation error: request blocked") + + if s.cfg.AddValidationStatusHeader { + if vh := getValidationHeader(ctx, err); vh != nil { + s.logger.WithFields(logrus.Fields{ + "error": err, + "request_id": fmt.Sprintf("#%016X", ctx.ID()), + }).Errorf("add header %s: %s", web.ValidationStatus, *vh) + ctx.Request.Header.Add(web.ValidationStatus, *vh) + return web.RespondError(ctx, s.cfg.CustomBlockStatusCode, *vh) + } + } + + return web.RespondError(ctx, s.cfg.CustomBlockStatusCode, "") + } + } + + if s.cfg.ShadowAPI.UnknownParametersDetection { + upResults, valUPReqErrors := validator.ValidateUnknownRequestParameters(ctx, requestValidationInput.Route, req.Header, jsonParser) + // log only error and pass request if unknown params module can't parse it + if valUPReqErrors != nil { + s.logger.WithFields(logrus.Fields{ + "error": valUPReqErrors, + "request_id": fmt.Sprintf("#%016X", ctx.ID()), + }).Warning("Shadow API: searching for undefined parameters") + } + + if len(upResults) > 0 { + s.logger.WithFields(logrus.Fields{ + "errors": upResults, + "request_id": fmt.Sprintf("#%016X", ctx.ID()), + }).Error("Shadow API: undefined parameters found") + + // request has been blocked + ctx.SetUserValue(web.RequestBlocked, true) + return web.RespondError(ctx, s.cfg.CustomBlockStatusCode, "") + } } case web.ValidationLog: if err := validator.ValidateRequest(ctx, requestValidationInput, jsonParser); err != nil { @@ -280,10 +323,32 @@ func (s *openapiWaf) openapiWafHandler(ctx *fasthttp.RequestCtx) error { "request_id": fmt.Sprintf("#%016X", ctx.ID()), }).Error("request validation error") } + + if s.cfg.ShadowAPI.UnknownParametersDetection { + upResults, valUPReqErrors := validator.ValidateUnknownRequestParameters(ctx, requestValidationInput.Route, req.Header, jsonParser) + // log only error and pass request if unknown params module can't parse it + if valUPReqErrors != nil { + s.logger.WithFields(logrus.Fields{ + "error": valUPReqErrors, + "request_id": fmt.Sprintf("#%016X", ctx.ID()), + }).Warning("Shadow API: searching for undefined parameters") + } + + if len(upResults) > 0 { + s.logger.WithFields(logrus.Fields{ + "errors": upResults, + "request_id": fmt.Sprintf("#%016X", ctx.ID()), + }).Error("Shadow API: undefined parameters found") + } + } } - if err := performProxy(ctx, s.logger, client); err != nil { - return err + if err := performProxy(ctx, s.proxyPool); err != nil { + s.logger.WithFields(logrus.Fields{ + "error": err, + "request_id": fmt.Sprintf("#%016X", ctx.ID()), + }).Error("error while proxying request") + return nil } // Prepare http response headers @@ -323,6 +388,9 @@ func (s *openapiWaf) openapiWafHandler(ctx *fasthttp.RequestCtx) error { switch s.cfg.ResponseValidation { case web.ValidationBlock: if err := validator.ValidateResponse(ctx, responseValidationInput, jsonParser); err != nil { + // response has been blocked + ctx.SetUserValue(web.ResponseBlocked, true) + s.logger.WithFields(logrus.Fields{ "error": err, "request_id": fmt.Sprintf("#%016X", ctx.ID()), @@ -334,13 +402,21 @@ func (s *openapiWaf) openapiWafHandler(ctx *fasthttp.RequestCtx) error { "request_id": fmt.Sprintf("#%016X", ctx.ID()), }).Errorf("add header %s: %s", web.ValidationStatus, *vh) ctx.Response.Header.Add(web.ValidationStatus, *vh) - return web.RespondError(ctx, s.cfg.CustomBlockStatusCode, vh) + return web.RespondError(ctx, s.cfg.CustomBlockStatusCode, *vh) } } - return web.RespondError(ctx, s.cfg.CustomBlockStatusCode, nil) + return web.RespondError(ctx, s.cfg.CustomBlockStatusCode, "") } case web.ValidationLog: if err := validator.ValidateResponse(ctx, responseValidationInput, jsonParser); err != nil { + if respErr, ok := err.(*openapi3filter.ResponseError); ok { + // body parser not found + if respErr.Reason == "status is not supported" { + // received response status was not found in the OpenAPI spec + ctx.SetUserValue(web.ResponseStatusNotFound, true) + } + return nil + } s.logger.WithFields(logrus.Fields{ "error": err, "request_id": fmt.Sprintf("#%016X", ctx.ID()), diff --git a/cmd/api-firewall/internal/handlers/routes.go b/cmd/api-firewall/internal/handlers/proxy/routes.go similarity index 55% rename from cmd/api-firewall/internal/handlers/routes.go rename to cmd/api-firewall/internal/handlers/proxy/routes.go index 0a52d15..64fb5c1 100644 --- a/cmd/api-firewall/internal/handlers/routes.go +++ b/cmd/api-firewall/internal/handlers/proxy/routes.go @@ -1,4 +1,4 @@ -package handlers +package proxy import ( "crypto/rsa" @@ -7,7 +7,6 @@ import ( "path" "strings" - "github.com/getkin/kin-openapi/openapi3" "github.com/golang-jwt/jwt" "github.com/karlseguin/ccache/v2" "github.com/sirupsen/logrus" @@ -19,11 +18,10 @@ import ( woauth2 "github.com/wallarm/api-firewall/internal/platform/oauth2" "github.com/wallarm/api-firewall/internal/platform/proxy" "github.com/wallarm/api-firewall/internal/platform/router" - "github.com/wallarm/api-firewall/internal/platform/shadowAPI" "github.com/wallarm/api-firewall/internal/platform/web" ) -func OpenapiProxy(cfg *config.APIFWConfiguration, serverUrl *url.URL, shutdown chan os.Signal, logger *logrus.Logger, proxy proxy.Pool, swagRouter *router.Router, deniedTokens *denylist.DeniedTokens, shadowAPI shadowAPI.Checker) fasthttp.RequestHandler { +func Handlers(cfg *config.APIFWConfiguration, serverURL *url.URL, shutdown chan os.Signal, logger *logrus.Logger, proxy proxy.Pool, swagRouter *router.Router, deniedTokens *denylist.DeniedTokens) fasthttp.RequestHandler { // define FastJSON parsers pool var parserPool fastjson.ParserPool @@ -66,53 +64,31 @@ func OpenapiProxy(cfg *config.APIFWConfiguration, serverUrl *url.URL, shutdown c } // Construct the web.App which holds all routes as well as common Middleware. - app := web.NewApp(shutdown, cfg, logger, mid.Logger(logger), mid.Errors(logger), mid.Panics(logger), mid.Proxy(cfg, serverUrl), mid.Denylist(cfg, deniedTokens, logger)) - - for _, route := range swagRouter.Routes { - pathParamLength := 0 - if getOp := route.Route.PathItem.GetOperation(route.Method); getOp != nil { - for _, param := range getOp.Parameters { - if param.Value.In == openapi3.ParameterInPath { - pathParamLength += 1 - } - } - } - - // check common parameters - if getOp := route.Route.PathItem.Parameters; getOp != nil { - for _, param := range getOp { - if param.Value.In == openapi3.ParameterInPath { - pathParamLength += 1 - } - } - } + app := web.NewApp(shutdown, cfg, logger, mid.Logger(logger), mid.Errors(logger), mid.Panics(logger), mid.Proxy(cfg, serverURL), mid.Denylist(cfg, deniedTokens, logger), mid.ShadowAPIMonitor(logger, &cfg.ShadowAPI)) + for i := 0; i < len(swagRouter.Routes); i++ { s := openapiWaf{ - route: route.Route, - proxyPool: proxy, - pathParamLength: pathParamLength, - logger: logger, - cfg: cfg, - parserPool: &parserPool, - oauthValidator: oauthValidator, - shadowAPI: shadowAPI, + customRoute: &swagRouter.Routes[i], + proxyPool: proxy, + logger: logger, + cfg: cfg, + parserPool: &parserPool, + oauthValidator: oauthValidator, } - updRoutePath := path.Join(serverUrl.Path, route.Path) + updRoutePath := path.Join(serverURL.Path, swagRouter.Routes[i].Path) - s.logger.Debugf("handler: Loaded path : %s - %s", route.Method, updRoutePath) + s.logger.Debugf("handler: Loaded path %s - %s", swagRouter.Routes[i].Method, updRoutePath) - app.Handle(route.Method, updRoutePath, s.openapiWafHandler) + app.Handle(swagRouter.Routes[i].Method, updRoutePath, s.openapiWafHandler) } // set handler for default behavior (404, 405) s := openapiWaf{ - route: nil, - proxyPool: proxy, - pathParamLength: 0, - logger: logger, - cfg: cfg, - parserPool: &parserPool, - shadowAPI: shadowAPI, + customRoute: nil, + proxyPool: proxy, + logger: logger, + cfg: cfg, + parserPool: &parserPool, } app.SetDefaultBehavior(s.openapiWafHandler) diff --git a/cmd/api-firewall/internal/updater/updater.go b/cmd/api-firewall/internal/updater/updater.go new file mode 100644 index 0000000..eaf6295 --- /dev/null +++ b/cmd/api-firewall/internal/updater/updater.go @@ -0,0 +1,106 @@ +package updater + +import ( + "fmt" + "os" + "reflect" + "sync" + "time" + + "github.com/sirupsen/logrus" + "github.com/valyala/fasthttp" + handlersAPI "github.com/wallarm/api-firewall/cmd/api-firewall/internal/handlers/api" + "github.com/wallarm/api-firewall/internal/config" + "github.com/wallarm/api-firewall/internal/platform/database" +) + +type Updater interface { + Start() error + Shutdown() error + Update() error +} + +type Specification struct { + logger *logrus.Logger + sqlLiteStorage database.DBOpenAPILoader + stop chan struct{} + updateTime time.Duration + cfg *config.APIFWConfigurationAPIMode + api *fasthttp.Server + shutdown chan os.Signal + health *handlersAPI.Health + lock *sync.RWMutex +} + +// NewController function defines configuration updater controller +func NewController(lock *sync.RWMutex, logger *logrus.Logger, sqlLiteStorage database.DBOpenAPILoader, cfg *config.APIFWConfigurationAPIMode, api *fasthttp.Server, shutdown chan os.Signal, health *handlersAPI.Health) Updater { + return &Specification{ + logger: logger, + sqlLiteStorage: sqlLiteStorage, + stop: make(chan struct{}), + updateTime: cfg.SpecificationUpdatePeriod, + cfg: cfg, + api: api, + shutdown: shutdown, + health: health, + lock: lock, + } +} + +func getSchemaVersions(dbSpecs database.DBOpenAPILoader) map[int]string { + result := make(map[int]string) + schemaIDs := dbSpecs.SchemaIDs() + for _, schemaID := range schemaIDs { + result[schemaID] = dbSpecs.SpecificationVersion(schemaID) + } + return result +} + +// Start function starts update process every ConfigurationUpdatePeriod +func (s *Specification) Start() error { + + go func() { + updateTicker := time.NewTicker(s.updateTime) + for { + select { + case <-updateTicker.C: + beforeUpdateSpecs := getSchemaVersions(s.sqlLiteStorage) + if err := s.Update(); err != nil { + s.logger.WithFields(logrus.Fields{"error": err}).Error("updating OpenAPI specification") + continue + } + afterUpdateSpecs := getSchemaVersions(s.sqlLiteStorage) + if !reflect.DeepEqual(beforeUpdateSpecs, afterUpdateSpecs) { + s.logger.Debugf("OpenAPI specifications has been updated. Loaded OpenAPI specification versions: %v", afterUpdateSpecs) + s.lock.Lock() + s.api.Handler = handlersAPI.Handlers(s.lock, s.cfg, s.shutdown, s.logger, s.sqlLiteStorage) + s.health.OpenAPIDB = s.sqlLiteStorage + s.lock.Unlock() + continue + } + s.logger.Debugf("regular update checker: new OpenAPI specifications not found") + } + } + }() + + <-s.stop + return nil +} + +// Shutdown function stops update process +func (s *Specification) Shutdown() error { + defer s.logger.Infof("specification updater: stopped") + s.stop <- struct{}{} + return nil +} + +// Update function performs a specification update +func (s *Specification) Update() error { + + // Update specification + if err := s.sqlLiteStorage.Load(s.cfg.PathToSpecDB); err != nil { + return fmt.Errorf("error while spicification update: %w", err) + } + + return nil +} diff --git a/cmd/api-firewall/internal/updater/updater_mock.go b/cmd/api-firewall/internal/updater/updater_mock.go new file mode 100644 index 0000000..0e6dfac --- /dev/null +++ b/cmd/api-firewall/internal/updater/updater_mock.go @@ -0,0 +1,76 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: ./cmd/api-firewall/internal/updater/updater.go + +// Package updater is a generated GoMock package. +package updater + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" +) + +// MockUpdater is a mock of Updater interface. +type MockUpdater struct { + ctrl *gomock.Controller + recorder *MockUpdaterMockRecorder +} + +// MockUpdaterMockRecorder is the mock recorder for MockUpdater. +type MockUpdaterMockRecorder struct { + mock *MockUpdater +} + +// NewMockUpdater creates a new mock instance. +func NewMockUpdater(ctrl *gomock.Controller) *MockUpdater { + mock := &MockUpdater{ctrl: ctrl} + mock.recorder = &MockUpdaterMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockUpdater) EXPECT() *MockUpdaterMockRecorder { + return m.recorder +} + +// Shutdown mocks base method. +func (m *MockUpdater) Shutdown() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Shutdown") + ret0, _ := ret[0].(error) + return ret0 +} + +// Shutdown indicates an expected call of Shutdown. +func (mr *MockUpdaterMockRecorder) Shutdown() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Shutdown", reflect.TypeOf((*MockUpdater)(nil).Shutdown)) +} + +// Start mocks base method. +func (m *MockUpdater) Start() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Start") + ret0, _ := ret[0].(error) + return ret0 +} + +// Start indicates an expected call of Start. +func (mr *MockUpdaterMockRecorder) Start() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Start", reflect.TypeOf((*MockUpdater)(nil).Start)) +} + +// Update mocks base method. +func (m *MockUpdater) Update() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Update") + ret0, _ := ret[0].(error) + return ret0 +} + +// Update indicates an expected call of Update. +func (mr *MockUpdaterMockRecorder) Update() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockUpdater)(nil).Update)) +} diff --git a/cmd/api-firewall/internal/updater/updater_test.go b/cmd/api-firewall/internal/updater/updater_test.go new file mode 100644 index 0000000..346fa3b --- /dev/null +++ b/cmd/api-firewall/internal/updater/updater_test.go @@ -0,0 +1,142 @@ +package updater + +import ( + "fmt" + "os" + "os/signal" + "sync" + "syscall" + "testing" + "time" + + "github.com/sirupsen/logrus" + "github.com/valyala/fasthttp" + handlersAPI "github.com/wallarm/api-firewall/cmd/api-firewall/internal/handlers/api" + "github.com/wallarm/api-firewall/internal/config" + "github.com/wallarm/api-firewall/internal/platform/database" + "github.com/wallarm/api-firewall/internal/platform/web" +) + +const ( + DefaultSchemaID = 1 +) + +var cfg = config.APIFWConfigurationAPIMode{ + APIFWMode: config.APIFWMode{Mode: web.APIMode}, + SpecificationUpdatePeriod: 2 * time.Second, + PathToSpecDB: "./wallarm_api_after_update.db", + UnknownParametersDetection: true, + PassOptionsRequests: false, +} + +func TestUpdaterBasic(t *testing.T) { + + logger := logrus.New() + logger.SetLevel(logrus.ErrorLevel) + + var lock sync.RWMutex + + // load spec from the database + specStorage, err := database.NewOpenAPIDB(logger, "./wallarm_api_before_update.db") + if err != nil { + t.Fatal(err) + } + + shutdown := make(chan os.Signal, 1) + signal.Notify(shutdown, os.Interrupt, syscall.SIGTERM) + + api := fasthttp.Server{} + api.Handler = handlersAPI.Handlers(&lock, &cfg, shutdown, logger, specStorage) + health := handlersAPI.Health{} + + // invalid route in the old spec + req := fasthttp.AcquireRequest() + req.SetRequestURI("/test/new") + req.Header.SetMethod("GET") + req.Header.Add(web.XWallarmSchemaIDHeader, fmt.Sprintf("%d", DefaultSchemaID)) + + reqCtx := fasthttp.RequestCtx{ + Request: *req, + } + + lock.RLock() + api.Handler(&reqCtx) + lock.RUnlock() + + if reqCtx.Response.StatusCode() != 403 { + t.Errorf("Incorrect response status code. Expected: 403 and got %d", + reqCtx.Response.StatusCode()) + } + + // valid route in the old spec + req = fasthttp.AcquireRequest() + req.SetRequestURI("/") + req.Header.SetMethod("GET") + req.Header.Add(web.XWallarmSchemaIDHeader, fmt.Sprintf("%d", DefaultSchemaID)) + + reqCtx = fasthttp.RequestCtx{ + Request: *req, + } + + lock.RLock() + api.Handler(&reqCtx) + lock.RUnlock() + + if reqCtx.Response.StatusCode() != 200 { + t.Errorf("Incorrect response status code. Expected: 200 and got %d", + reqCtx.Response.StatusCode()) + } + + // start updater + updSpecErrors := make(chan error, 1) + updater := NewController(&lock, logger, specStorage, &cfg, &api, shutdown, &health) + go func() { + t.Logf("starting specification regular update process every %.0f seconds", cfg.SpecificationUpdatePeriod.Seconds()) + updSpecErrors <- updater.Start() + }() + + time.Sleep(3 * time.Second) + + if err := updater.Shutdown(); err != nil { + t.Fatal(err) + } + + // valid route in the new spec + req = fasthttp.AcquireRequest() + req.SetRequestURI("/test/new") + req.Header.SetMethod("GET") + req.Header.Add(web.XWallarmSchemaIDHeader, fmt.Sprintf("%d", DefaultSchemaID)) + + reqCtx = fasthttp.RequestCtx{ + Request: *req, + } + + lock.RLock() + api.Handler(&reqCtx) + lock.RUnlock() + + if reqCtx.Response.StatusCode() != 200 { + t.Errorf("Incorrect response status code. Expected: 200 and got %d", + reqCtx.Response.StatusCode()) + } + + // invalid route in the new spec + req = fasthttp.AcquireRequest() + req.SetRequestURI("/") + req.Header.SetMethod("GET") + req.Header.Add(web.XWallarmSchemaIDHeader, fmt.Sprintf("%d", DefaultSchemaID)) + + reqCtx = fasthttp.RequestCtx{ + Request: *req, + } + + lock.RLock() + api.Handler(&reqCtx) + lock.RUnlock() + + if reqCtx.Response.StatusCode() != 403 { + t.Errorf("Incorrect response status code. Expected: 403 and got %d", + reqCtx.Response.StatusCode()) + } + +} diff --git a/cmd/api-firewall/internal/updater/wallarm_api_after_update.db b/cmd/api-firewall/internal/updater/wallarm_api_after_update.db new file mode 100644 index 0000000..427f82e Binary files /dev/null and b/cmd/api-firewall/internal/updater/wallarm_api_after_update.db differ diff --git a/cmd/api-firewall/internal/updater/wallarm_api_before_update.db b/cmd/api-firewall/internal/updater/wallarm_api_before_update.db new file mode 100644 index 0000000..3f698a8 Binary files /dev/null and b/cmd/api-firewall/internal/updater/wallarm_api_before_update.db differ diff --git a/cmd/api-firewall/main.go b/cmd/api-firewall/main.go index 7e8b2f7..dc084d6 100644 --- a/cmd/api-firewall/main.go +++ b/cmd/api-firewall/main.go @@ -9,6 +9,7 @@ import ( "os/signal" "path" "strings" + "sync" "syscall" "github.com/ardanlabs/conf" @@ -17,19 +18,23 @@ import ( "github.com/pkg/errors" "github.com/sirupsen/logrus" "github.com/valyala/fasthttp" - "github.com/wallarm/api-firewall/cmd/api-firewall/internal/handlers" + handlersAPI "github.com/wallarm/api-firewall/cmd/api-firewall/internal/handlers/api" + handlersProxy "github.com/wallarm/api-firewall/cmd/api-firewall/internal/handlers/proxy" + "github.com/wallarm/api-firewall/cmd/api-firewall/internal/updater" "github.com/wallarm/api-firewall/internal/config" + "github.com/wallarm/api-firewall/internal/platform/database" "github.com/wallarm/api-firewall/internal/platform/denylist" "github.com/wallarm/api-firewall/internal/platform/proxy" "github.com/wallarm/api-firewall/internal/platform/router" - "github.com/wallarm/api-firewall/internal/platform/shadowAPI" + "github.com/wallarm/api-firewall/internal/platform/web" ) var build = "develop" const ( - namespace = "apifw" - logPrefix = "main" + namespace = "apifw" + logPrefix = "main" + projectName = "Wallarm API-Firewall" ) func main() { @@ -38,24 +43,256 @@ func main() { logger.SetLevel(logrus.DebugLevel) logger.SetFormatter(&logrus.TextFormatter{ - DisableQuote: true, - FullTimestamp: true, + DisableQuote: true, + FullTimestamp: true, + DisableLevelTruncation: true, }) - if err := run(logger); err != nil { - logger.Infof("%s: error: %s", logPrefix, err) - os.Exit(1) + // if MODE var has invalid value then proxy mode will be used + var currentMode config.APIFWMode + if err := conf.Parse(os.Args[1:], namespace, ¤tMode); err != nil { + if err := runProxyMode(logger); err != nil { + logger.Infof("%s: error: %s", logPrefix, err) + os.Exit(1) + } + return } + + // if MODE var has valid or default value then an appropriate mode will be used + switch strings.ToLower(currentMode.Mode) { + case web.APIMode: + if err := runAPIMode(logger); err != nil { + logger.Infof("%s: error: %s", logPrefix, err) + os.Exit(1) + } + default: + if err := runProxyMode(logger); err != nil { + logger.Infof("%s: error: %s", logPrefix, err) + os.Exit(1) + } + } + } -func run(logger *logrus.Logger) error { +func runAPIMode(logger *logrus.Logger) error { + + // ========================================================================= + // Configuration + + var cfg config.APIFWConfigurationAPIMode + cfg.Version.SVN = build + cfg.Version.Desc = projectName + + if err := conf.Parse(os.Args[1:], namespace, &cfg); err != nil { + switch err { + case conf.ErrHelpWanted: + usage, err := conf.Usage(namespace, &cfg) + if err != nil { + return errors.Wrap(err, "generating config usage") + } + fmt.Println(usage) + return nil + case conf.ErrVersionWanted: + version, err := conf.VersionString(namespace, &cfg) + if err != nil { + return errors.Wrap(err, "generating config version") + } + fmt.Println(version) + return nil + } + return errors.Wrap(err, "parsing config") + } + + // ========================================================================= + // Init Logger + + if strings.ToLower(cfg.LogFormat) == "json" { + logger.SetFormatter(&logrus.JSONFormatter{}) + } + + switch strings.ToLower(cfg.LogLevel) { + case "trace": + logger.SetLevel(logrus.TraceLevel) + case "debug": + logger.SetLevel(logrus.DebugLevel) + case "error": + logger.SetLevel(logrus.ErrorLevel) + case "warning": + logger.SetLevel(logrus.WarnLevel) + case "info": + logger.SetLevel(logrus.InfoLevel) + default: + return errors.New("invalid log level") + } + + // Print the build version for our logs. Also expose it under /debug/vars. + expvar.NewString("build").Set(build) + + logger.Infof("%s : Started : Application initializing : version %q", logPrefix, build) + defer logger.Infof("%s: Completed", logPrefix) + + out, err := conf.String(&cfg) + if err != nil { + return errors.Wrap(err, "generating config for output") + } + logger.Infof("%s: Configuration Loaded :\n%v\n", logPrefix, out) + + // Make a channel to listen for an interrupt or terminate signal from the OS. + // Use a buffered channel because the signal package requires it. + shutdown := make(chan os.Signal, 1) + signal.Notify(shutdown, os.Interrupt, syscall.SIGTERM) + + // DB Usage Lock + var dbLock sync.RWMutex + + // Make a channel to listen for errors coming from the listener. Use a + // buffered channel so the goroutine can exit if we don't collect this error. + serverErrors := make(chan error, 1) + + // load spec from the database + specStorage, err := database.NewOpenAPIDB(logger, cfg.PathToSpecDB) + if err != nil { + logger.Fatalf("%s: trying to load API Spec value from SQLLite Database : %v\n", logPrefix, err.Error()) + return err + } + + // ========================================================================= + // Init Handlers + + requestHandlers := handlersAPI.Handlers(&dbLock, &cfg, shutdown, logger, specStorage) + + // ========================================================================= + // Start Health API Service + + healthData := handlersAPI.Health{ + Build: build, + Logger: logger, + OpenAPIDB: specStorage, + } + + // health service handler + healthHandler := func(ctx *fasthttp.RequestCtx) { + switch string(ctx.Path()) { + case "/v1/liveness": + if err := healthData.Liveness(ctx); err != nil { + healthData.Logger.Errorf("%s: liveness: %s", logPrefix, err.Error()) + } + case "/v1/readiness": + if err := healthData.Readiness(ctx); err != nil { + healthData.Logger.Errorf("%s: readiness: %s", logPrefix, err.Error()) + } + default: + ctx.Error("Unsupported path", fasthttp.StatusNotFound) + } + } + + healthApi := fasthttp.Server{ + Handler: healthHandler, + ReadTimeout: cfg.ReadTimeout, + WriteTimeout: cfg.WriteTimeout, + Logger: logger, + NoDefaultServerHeader: true, + } + + // Start the service listening for requests. + go func() { + logger.Infof("%s: Health API listening on %s", logPrefix, cfg.HealthAPIHost) + serverErrors <- healthApi.ListenAndServe(cfg.HealthAPIHost) + }() + + // ========================================================================= + // Start API Service + + logger.Infof("%s: Initializing API support", logPrefix) + + apiHost, err := url.ParseRequestURI(cfg.APIHost) + if err != nil { + return errors.Wrap(err, "parsing API Host URL") + } + + var isTLS bool + + switch apiHost.Scheme { + case "http": + isTLS = false + case "https": + isTLS = true + } + + api := fasthttp.Server{ + Handler: requestHandlers, + ReadTimeout: cfg.ReadTimeout, + WriteTimeout: cfg.WriteTimeout, + Logger: logger, + NoDefaultServerHeader: true, + } + + // ========================================================================= + // Init Regular Update Controller + + updSpecErrors := make(chan error, 1) + + updOpenAPISpec := updater.NewController(&dbLock, logger, specStorage, &cfg, &api, shutdown, &healthData) + + // disable updater if SpecificationUpdatePeriod == 0 + if cfg.SpecificationUpdatePeriod.Seconds() > 0 { + go func() { + logger.Infof("%s: starting specification regular update process every %.0f seconds", logPrefix, cfg.SpecificationUpdatePeriod.Seconds()) + updSpecErrors <- updOpenAPISpec.Start() + }() + } + + // Start the service listening for requests. + go func() { + logger.Infof("%s: API listening on %s", logPrefix, cfg.APIHost) + switch isTLS { + case false: + serverErrors <- api.ListenAndServe(apiHost.Host) + case true: + serverErrors <- api.ListenAndServeTLS(apiHost.Host, path.Join(cfg.TLS.CertsPath, cfg.TLS.CertFile), + path.Join(cfg.TLS.CertsPath, cfg.TLS.CertKey)) + } + }() + + // ========================================================================= + // Shutdown + + // Blocking main and waiting for shutdown. + select { + case err := <-serverErrors: + return errors.Wrap(err, "server error") + + case err := <-updSpecErrors: + return errors.Wrap(err, "regular updater error") + + case sig := <-shutdown: + logger.Infof("%s: %v: Start shutdown", logPrefix, sig) + + if cfg.SpecificationUpdatePeriod.Seconds() > 0 { + if err := updOpenAPISpec.Shutdown(); err != nil { + return errors.Wrap(err, "could not stop configuration updater gracefully") + } + } + + // Asking listener to shutdown and shed load. + if err := api.Shutdown(); err != nil { + return errors.Wrap(err, "could not stop server gracefully") + } + logger.Infof("%s: %v: Completed shutdown", logPrefix, sig) + } + + return nil + +} + +func runProxyMode(logger *logrus.Logger) error { // ========================================================================= // Configuration var cfg config.APIFWConfiguration cfg.Version.SVN = build - cfg.Version.Desc = "Wallarm API-Firewall" + cfg.Version.Desc = projectName if err := conf.Parse(os.Args[1:], namespace, &cfg); err != nil { switch err { @@ -77,7 +314,7 @@ func run(logger *logrus.Logger) error { return errors.Wrap(err, "parsing config") } - // validate + // validate env parameter values validate := validator.New() if err := validate.RegisterValidation("HttpStatusCodes", config.ValidateStatusList); err != nil { @@ -115,6 +352,8 @@ func run(logger *logrus.Logger) error { } switch strings.ToLower(cfg.LogLevel) { + case "trace": + logger.SetLevel(logrus.TraceLevel) case "debug": logger.SetLevel(logrus.DebugLevel) case "error": @@ -142,6 +381,17 @@ func run(logger *logrus.Logger) error { } logger.Infof("%s: Configuration Loaded :\n%v\n", logPrefix, out) + var requestHandlers fasthttp.RequestHandler + + // Make a channel to listen for an interrupt or terminate signal from the OS. + // Use a buffered channel because the signal package requires it. + shutdown := make(chan os.Signal, 1) + signal.Notify(shutdown, os.Interrupt, syscall.SIGTERM) + + // Make a channel to listen for errors coming from the listener. Use a + // buffered channel so the goroutine can exit if we don't collect this error. + serverErrors := make(chan error, 1) + // ========================================================================= // Init Swagger @@ -198,11 +448,6 @@ func run(logger *logrus.Logger) error { return errors.Wrap(err, "proxy pool init") } - // ========================================================================= - // Init ShadowAPI checker - - shadowAPI := shadowAPI.New(&cfg.ShadowAPI, logger) - // ========================================================================= // Init Cache @@ -221,57 +466,14 @@ func run(logger *logrus.Logger) error { } // ========================================================================= - // Start API Service - - logger.Infof("%s: Initializing API support", logPrefix) - - apiHost, err := url.ParseRequestURI(cfg.APIHost) - if err != nil { - return errors.Wrap(err, "parsing API Host URL") - } - - var isTLS bool - - switch apiHost.Scheme { - case "http": - isTLS = false - case "https": - isTLS = true - } - - // Make a channel to listen for an interrupt or terminate signal from the OS. - // Use a buffered channel because the signal package requires it. - shutdown := make(chan os.Signal, 1) - signal.Notify(shutdown, os.Interrupt, syscall.SIGTERM) + // Init Handlers - api := fasthttp.Server{ - Handler: handlers.OpenapiProxy(&cfg, serverUrl, shutdown, logger, pool, swagRouter, deniedTokens, shadowAPI), - ReadTimeout: cfg.ReadTimeout, - WriteTimeout: cfg.WriteTimeout, - Logger: logger, - NoDefaultServerHeader: true, - } - - // Make a channel to listen for errors coming from the listener. Use a - // buffered channel so the goroutine can exit if we don't collect this error. - serverErrors := make(chan error, 1) - - // Start the service listening for requests. - go func() { - logger.Infof("%s: API listening on %s", logPrefix, cfg.APIHost) - switch isTLS { - case false: - serverErrors <- api.ListenAndServe(apiHost.Host) - case true: - serverErrors <- api.ListenAndServeTLS(apiHost.Host, path.Join(cfg.TLS.CertsPath, cfg.TLS.CertFile), - path.Join(cfg.TLS.CertsPath, cfg.TLS.CertKey)) - } - }() + requestHandlers = handlersProxy.Handlers(&cfg, serverUrl, shutdown, logger, pool, swagRouter, deniedTokens) // ========================================================================= // Start Health API Service - healthData := handlers.Health{ + healthData := handlersProxy.Health{ Build: build, Logger: logger, Pool: pool, @@ -307,6 +509,45 @@ func run(logger *logrus.Logger) error { serverErrors <- healthApi.ListenAndServe(cfg.HealthAPIHost) }() + // ========================================================================= + // Start API Service + + logger.Infof("%s: Initializing API support", logPrefix) + + apiHost, err := url.ParseRequestURI(cfg.APIHost) + if err != nil { + return errors.Wrap(err, "parsing API Host URL") + } + + var isTLS bool + + switch apiHost.Scheme { + case "http": + isTLS = false + case "https": + isTLS = true + } + + api := fasthttp.Server{ + Handler: requestHandlers, + ReadTimeout: cfg.ReadTimeout, + WriteTimeout: cfg.WriteTimeout, + Logger: logger, + NoDefaultServerHeader: true, + } + + // Start the service listening for requests. + go func() { + logger.Infof("%s: API listening on %s", logPrefix, cfg.APIHost) + switch isTLS { + case false: + serverErrors <- api.ListenAndServe(apiHost.Host) + case true: + serverErrors <- api.ListenAndServeTLS(apiHost.Host, path.Join(cfg.TLS.CertsPath, cfg.TLS.CertFile), + path.Join(cfg.TLS.CertsPath, cfg.TLS.CertKey)) + } + }() + // ========================================================================= // Shutdown diff --git a/cmd/api-firewall/tests/main_api_mode_test.go b/cmd/api-firewall/tests/main_api_mode_test.go new file mode 100644 index 0000000..3dfd6af --- /dev/null +++ b/cmd/api-firewall/tests/main_api_mode_test.go @@ -0,0 +1,2094 @@ +package tests + +import ( + "bytes" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "math/rand" + "mime/multipart" + "net/url" + "os" + "os/signal" + "strings" + "sync" + "syscall" + "testing" + "time" + + "github.com/getkin/kin-openapi/openapi3" + "github.com/golang/mock/gomock" + "github.com/google/uuid" + "github.com/sirupsen/logrus" + "github.com/valyala/fasthttp" + handlersAPI "github.com/wallarm/api-firewall/cmd/api-firewall/internal/handlers/api" + "github.com/wallarm/api-firewall/cmd/api-firewall/internal/updater" + "github.com/wallarm/api-firewall/internal/config" + "github.com/wallarm/api-firewall/internal/platform/database" + "github.com/wallarm/api-firewall/internal/platform/web" +) + +const apiModeOpenAPISpecAPIModeTest = ` +openapi: 3.0.1 +info: + title: Service + version: 1.1.0 +servers: + - url: / +paths: + /absolute-redirect/{n}: + get: + tags: + - Redirects + summary: Absolutely 302 Redirects n times. + parameters: + - name: 'n' + in: path + required: true + schema: {} + responses: + '302': + description: A redirection. + content: {} + /redirect-to: + put: + summary: 302/3XX Redirects to the given URL. + requestBody: + content: + multipart/form-data: + schema: + required: + - url + properties: + url: + type: string + status_code: {} + required: true + responses: + '302': + description: A redirection. + content: {} + /test/security/basic: + get: + responses: + '200': + description: Static page + content: {} + '403': + description: operation forbidden + content: {} + security: + - basicAuth: [] + /test/security/bearer: + get: + responses: + '200': + description: Static page + content: {} + '403': + description: operation forbidden + content: {} + security: + - bearerAuth: [] + /test/security/cookie: + get: + responses: + '200': + description: Static page + content: {} + '403': + description: operation forbidden + content: {} + security: + - cookieAuth: [] + /test/signup: + post: + requestBody: + required: true + content: + application/x-www-form-urlencoded: + schema: + type: object + required: + - email + - firstname + - lastname + properties: + email: + type: string + format: email + pattern: '^[0-9a-zA-Z]+@[0-9a-zA-Z\.]+$' + example: example@mail.com + firstname: + type: string + example: test + lastname: + type: string + example: test + job: + type: string + example: test + url: + type: string + example: http://test.com + application/json: + schema: + type: object + required: + - email + - firstname + - lastname + properties: + email: + type: string + format: email + pattern: '^[0-9a-zA-Z]+@[0-9a-zA-Z\.]+$' + example: example@mail.com + firstname: + type: string + example: test + lastname: + type: string + example: test + job: + type: string + example: test + url: + type: string + example: http://test.com + responses: + '200': + description: successful operation + content: + application/json: + schema: + type: object + required: + - status + properties: + status: + type: string + example: "success" + error: + type: string + '403': + description: operation forbidden + content: {} + /test/multipart: + post: + requestBody: + content: + multipart/form-data: + schema: + type: object + required: + - url + properties: + url: + type: string + id: + type: integer + required: true + responses: + '302': + description: "A redirection." + content: {} + '/test/query': + get: + parameters: + - name: id + in: query + required: true + schema: + type: string + format: uuid + pattern: '^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$' + responses: + '200': + description: Static page + content: {} + '403': + description: operation forbidden + content: {} + '/test/plain': + post: + requestBody: + content: + text/plain: + schema: + type: string + required: true + responses: + '200': + description: Static page + content: {} + '403': + description: operation forbidden + content: {} + '/test/unknownCT': + post: + requestBody: + content: + application/unknownCT: + schema: + type: string + required: true + responses: + '200': + description: Static page + content: {} + '403': + description: operation forbidden + content: {} + /test/headers/request: + get: + summary: Get Request to test Request Headers validation + parameters: + - in: header + name: X-Request-Test + schema: + type: string + format: uuid + pattern: '^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$' + required: true + responses: + 200: + description: Ok + content: { } + /test/cookies/request: + get: + summary: Get Request to test Request Cookies presence + parameters: + - in: cookie + name: cookie_test + schema: + type: string + format: uuid + pattern: '^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$' + required: true + responses: + 200: + description: Ok + content: { } + /test/body/request: + post: + summary: Post Request to test Request Body presence + requestBody: + required: true + content: + application/json: + schema: + type: object + required: + - status + properties: + status: + type: string + format: uuid + pattern: '^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$' + error: + type: string + responses: + 200: + description: Ok + content: { } +components: + securitySchemes: + basicAuth: + type: http + scheme: basic + bearerAuth: + type: http + scheme: bearer + bearerFormat: JWT + cookieAuth: + type: apiKey + in: cookie + name: MyAuthHeader + petstore_auth: + type: oauth2 + flows: + implicit: + authorizationUrl: /login + scopes: + read: read + write: write +` + +const ( + testDeleteMethod = "DELETE" + testUnknownPath = "/unknown/path/test" + + testRequestCookie = "cookie_test" + testSecCookieName = "MyAuthHeader" + + DefaultSchemaID = 0 + DefaultSpecVersion = "1.1.0" + UpdatedSpecVersion = "1.1.1" +) + +const apiModeOpenAPISpecAPIModeTestUpdated = ` +openapi: 3.0.1 +info: + title: Service + version: 1.1.1 +servers: + - url: / +paths: + /test/new: + get: + tags: + - Redirects + summary: Absolutely 302 Redirects n times. + responses: + '200': + description: A redirection. + content: {} +` + +var cfg = config.APIFWConfigurationAPIMode{ + APIFWMode: config.APIFWMode{Mode: web.APIMode}, + SpecificationUpdatePeriod: 2 * time.Second, + UnknownParametersDetection: true, + PassOptionsRequests: false, +} + +type APIModeServiceTests struct { + serverUrl *url.URL + shutdown chan os.Signal + logger *logrus.Logger + dbSpec *database.MockDBOpenAPILoader + lock *sync.RWMutex +} + +func TestAPIModeBasic(t *testing.T) { + + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + + dbSpec := database.NewMockDBOpenAPILoader(mockCtrl) + + logger := logrus.New() + logger.SetLevel(logrus.ErrorLevel) + + var lock sync.RWMutex + + serverUrl, err := url.ParseRequestURI("http://127.0.0.1:80") + if err != nil { + t.Fatalf("parsing API Host URL: %s", err.Error()) + } + + swagger, err := openapi3.NewLoader().LoadFromData([]byte(apiModeOpenAPISpecAPIModeTest)) + if err != nil { + t.Fatalf("loading swagwaf file: %s", err.Error()) + } + + dbSpec.EXPECT().SchemaIDs().Return([]int{DefaultSchemaID}).AnyTimes() + dbSpec.EXPECT().Specification(DefaultSchemaID).Return(swagger).AnyTimes() + dbSpec.EXPECT().SpecificationVersion(DefaultSchemaID).Return(DefaultSpecVersion).AnyTimes() + dbSpec.EXPECT().IsLoaded(DefaultSchemaID).Return(true).AnyTimes() + + shutdown := make(chan os.Signal, 1) + signal.Notify(shutdown, os.Interrupt, syscall.SIGTERM) + + apifwTests := APIModeServiceTests{ + serverUrl: serverUrl, + shutdown: shutdown, + logger: logger, + dbSpec: dbSpec, + lock: &lock, + } + + // basic test + t.Run("testAPIModeSuccess", apifwTests.testAPIModeSuccess) + t.Run("testAPIModeMissedMultipleReqParams", apifwTests.testAPIModeMissedMultipleReqParams) + t.Run("testAPIModeNoXWallarmSchemaIDHeader", apifwTests.testAPIModeNoXWallarmSchemaIDHeader) + + t.Run("testAPIModeMethodAndPathNotFound", apifwTests.testAPIModeMethodAndPathNotFound) + + t.Run("testAPIModeRequiredQueryParameterMissed", apifwTests.testAPIModeRequiredQueryParameterMissed) + t.Run("testAPIModeRequiredHeaderParameterMissed", apifwTests.testAPIModeRequiredHeaderParameterMissed) + t.Run("testAPIModeRequiredCookieParameterMissed", apifwTests.testAPIModeRequiredCookieParameterMissed) + t.Run("testAPIModeRequiredBodyMissed", apifwTests.testAPIModeRequiredBodyMissed) + t.Run("testAPIModeRequiredBodyParameterMissed", apifwTests.testAPIModeRequiredBodyParameterMissed) + + t.Run("testAPIModeRequiredQueryParameterInvalidValue", apifwTests.testAPIModeRequiredQueryParameterInvalidValue) + t.Run("testAPIModeRequiredHeaderParameterInvalidValue", apifwTests.testAPIModeRequiredHeaderParameterInvalidValue) + t.Run("testAPIModeRequiredCookieParameterInvalidValue", apifwTests.testAPIModeRequiredCookieParameterInvalidValue) + t.Run("testAPIModeRequiredBodyParameterInvalidValue", apifwTests.testAPIModeRequiredBodyParameterInvalidValue) + + t.Run("testAPIModeBasicAuthFailed", apifwTests.testAPIModeBasicAuthFailed) + t.Run("testAPIModeBearerTokenFailed", apifwTests.testAPIModeBearerTokenFailed) + t.Run("testAPIModeAPITokenCookieFailed", apifwTests.testAPIModeAPITokenCookieFailed) + + t.Run("testAPIModeSuccessEmptyPathParameter", apifwTests.testAPIModeSuccessEmptyPathParameter) + t.Run("testAPIModeSuccessMultipartStringParameter", apifwTests.testAPIModeSuccessMultipartStringParameter) + + t.Run("testAPIModeJSONParseError", apifwTests.testAPIModeJSONParseError) + t.Run("testAPIModeInvalidCTParseError", apifwTests.testAPIModeInvalidCTParseError) + t.Run("testAPIModeCTNotInSpec", apifwTests.testAPIModeCTNotInSpec) + t.Run("testAPIModeEmptyBody", apifwTests.testAPIModeEmptyBody) + + t.Run("testAPIModeUnknownParameterBodyJSON", apifwTests.testAPIModeUnknownParameterBodyJSON) + t.Run("testAPIModeUnknownParameterBodyPost", apifwTests.testAPIModeUnknownParameterBodyPost) + t.Run("testAPIModeUnknownParameterQuery", apifwTests.testAPIModeUnknownParameterQuery) + t.Run("testAPIModeUnknownParameterTextPlainCT", apifwTests.testAPIModeUnknownParameterTextPlainCT) + t.Run("testAPIModeUnknownParameterInvalidCT", apifwTests.testAPIModeUnknownParameterInvalidCT) + + t.Run("testAPIModePassOptionsRequest", apifwTests.testAPIModePassOptionsRequest) + + t.Run("testAPIModeMultipartOptionalParams", apifwTests.testAPIModeMultipartOptionalParams) +} + +func createForm(form map[string]string) (string, io.Reader, error) { + body := new(bytes.Buffer) + mp := multipart.NewWriter(body) + defer mp.Close() + for key, val := range form { + if strings.HasPrefix(val, "@") { + val = val[1:] + file, err := os.Open(val) + if err != nil { + return "", nil, err + } + defer file.Close() + part, err := mp.CreateFormFile(key, val) + if err != nil { + return "", nil, err + } + io.Copy(part, file) + } else { + mp.WriteField(key, val) + } + } + return mp.FormDataContentType(), body, nil +} + +func (s *APIModeServiceTests) testAPIModeSuccess(t *testing.T) { + + handler := handlersAPI.Handlers(s.lock, &cfg, s.shutdown, s.logger, s.dbSpec) + + p, err := json.Marshal(map[string]interface{}{ + "firstname": "test", + "lastname": "test", + "job": "test", + "email": "test@wallarm.com", + "url": "http://wallarm.com", + }) + + if err != nil { + t.Fatal(err) + } + + req := fasthttp.AcquireRequest() + req.SetRequestURI("/test/signup") + req.Header.SetMethod("POST") + req.SetBodyStream(bytes.NewReader(p), -1) + req.Header.SetContentType("application/json") + req.Header.Add(web.XWallarmSchemaIDHeader, fmt.Sprintf("%d", DefaultSchemaID)) + + reqCtx := fasthttp.RequestCtx{ + Request: *req, + } + + handler(&reqCtx) + + if reqCtx.Response.StatusCode() != 200 { + t.Errorf("Incorrect response status code. Expected: 200 and got %d", + reqCtx.Response.StatusCode()) + } + + // Repeat request with invalid email + reqInvalidEmail, err := json.Marshal(map[string]interface{}{ + "firstname": "test", + "lastname": "test", + "job": "test", + "email": "wallarm.com", + "url": "http://wallarm.com", + }) + + if err != nil { + t.Fatal(err) + } + + req.SetBodyStream(bytes.NewReader(reqInvalidEmail), -1) + + reqCtx = fasthttp.RequestCtx{ + Request: *req, + } + + t.Logf("Name of the test: %s; status code: %d; response body: %s", t.Name(), reqCtx.Response.StatusCode(), string(reqCtx.Response.Body())) + + handler(&reqCtx) + + if reqCtx.Response.StatusCode() != 403 { + t.Errorf("Incorrect response status code. Expected: 403 and got %d", + reqCtx.Response.StatusCode()) + } + + t.Logf("Name of the test: %s; status code: %d; response body: %s", t.Name(), reqCtx.Response.StatusCode(), string(reqCtx.Response.Body())) + +} + +func (s *APIModeServiceTests) testAPIModeMissedMultipleReqParams(t *testing.T) { + + handler := handlersAPI.Handlers(s.lock, &cfg, s.shutdown, s.logger, s.dbSpec) + + p, err := json.Marshal(map[string]interface{}{ + "firstname": "test", + "lastname": "test", + "job": "test", + "email": "test@wallarm.com", + "url": "http://wallarm.com", + }) + + if err != nil { + t.Fatal(err) + } + + req := fasthttp.AcquireRequest() + req.SetRequestURI("/test/signup") + req.Header.SetMethod("POST") + req.SetBodyStream(bytes.NewReader(p), -1) + req.Header.SetContentType("application/json") + req.Header.Add(web.XWallarmSchemaIDHeader, fmt.Sprintf("%d", DefaultSchemaID)) + + reqCtx := fasthttp.RequestCtx{ + Request: *req, + } + + handler(&reqCtx) + + if reqCtx.Response.StatusCode() != 200 { + t.Errorf("Incorrect response status code. Expected: 200 and got %d", + reqCtx.Response.StatusCode()) + } + + // Repeat request with invalid email + reqInvalidEmail, err := json.Marshal(map[string]interface{}{ + "email": "test@wallarm.com", + }) + + if err != nil { + t.Fatal(err) + } + + req.SetBodyStream(bytes.NewReader(reqInvalidEmail), -1) + + missedParams := map[string]interface{}{ + "firstname": struct{}{}, + "lastname": struct{}{}, + } + + reqCtx = fasthttp.RequestCtx{ + Request: *req, + } + + t.Logf("Name of the test: %s; status code: %d; response body: %s", t.Name(), reqCtx.Response.StatusCode(), string(reqCtx.Response.Body())) + + handler(&reqCtx) + + if reqCtx.Response.StatusCode() != 403 { + t.Errorf("Incorrect response status code. Expected: 403 and got %d", + reqCtx.Response.StatusCode()) + } + + apifwResponse := handlersAPI.Response{} + if err := json.Unmarshal(reqCtx.Response.Body(), &apifwResponse); err != nil { + t.Errorf("Error while JSON response parsing: %v", err) + } + + if len(apifwResponse.Errors) != 2 { + t.Errorf("wrong number of errors. Expected: 2. Got: %d", len(apifwResponse.Errors)) + } + + for _, apifwErr := range apifwResponse.Errors { + + if apifwErr.Code != handlersAPI.ErrCodeRequiredBodyParameterMissed { + t.Errorf("Incorrect error code. Expected: %s and got %s", + handlersAPI.ErrCodeRequiredBodyParameterMissed, apifwErr.Code) + } + + if len(apifwErr.Fields) != 1 { + t.Errorf("wrong number of related fields. Expected: 1. Got: %d", len(apifwErr.Fields)) + } + + if _, ok := missedParams[apifwErr.Fields[0]]; !ok { + t.Errorf("Invalid missed field. Expected: firstname or lastname but got %s", + apifwErr.Fields[0]) + } + + } + + t.Logf("Name of the test: %s; status code: %d; response body: %s", t.Name(), reqCtx.Response.StatusCode(), string(reqCtx.Response.Body())) + +} + +func (s *APIModeServiceTests) testAPIModeSuccessEmptyPathParameter(t *testing.T) { + + handler := handlersAPI.Handlers(s.lock, &cfg, s.shutdown, s.logger, s.dbSpec) + + req := fasthttp.AcquireRequest() + req.SetRequestURI(fmt.Sprintf("/absolute-redirect/%d", rand.Uint32())) + req.Header.SetMethod("GET") + req.Header.Add(web.XWallarmSchemaIDHeader, fmt.Sprintf("%d", DefaultSchemaID)) + + reqCtx := fasthttp.RequestCtx{ + Request: *req, + } + + handler(&reqCtx) + + if reqCtx.Response.StatusCode() != 200 { + t.Errorf("Incorrect response status code. Expected: 200 and got %d", + reqCtx.Response.StatusCode()) + } + + req.SetRequestURI("/absolute-redirect/testString") + + reqCtx = fasthttp.RequestCtx{ + Request: *req, + } + + t.Logf("Name of the test: %s; status code: %d; response body: %s", t.Name(), reqCtx.Response.StatusCode(), string(reqCtx.Response.Body())) + + handler(&reqCtx) + + if reqCtx.Response.StatusCode() != 200 { + t.Errorf("Incorrect response status code. Expected: 200 and got %d", + reqCtx.Response.StatusCode()) + } + + t.Logf("Name of the test: %s; status code: %d; response body: %s", t.Name(), reqCtx.Response.StatusCode(), string(reqCtx.Response.Body())) + +} + +func (s *APIModeServiceTests) testAPIModeSuccessMultipartStringParameter(t *testing.T) { + + handler := handlersAPI.Handlers(s.lock, &cfg, s.shutdown, s.logger, s.dbSpec) + + req := fasthttp.AcquireRequest() + req.SetRequestURI("/redirect-to") + req.Header.SetMethod("PUT") + req.Header.Add(web.XWallarmSchemaIDHeader, fmt.Sprintf("%d", DefaultSchemaID)) + + form := map[string]string{"url": "test"} + ct, body, err := createForm(form) + if err != nil { + t.Fatal(err) + } + + bodyData, err := io.ReadAll(body) + if err != nil { + t.Fatal(err) + } + + req.Header.SetContentType(ct) + req.SetBody(bodyData) + + reqCtx := fasthttp.RequestCtx{ + Request: *req, + } + + handler(&reqCtx) + + t.Logf("Name of the test: %s; status code: %d; response body: %s", t.Name(), reqCtx.Response.StatusCode(), string(reqCtx.Response.Body())) + + if reqCtx.Response.StatusCode() != 200 { + t.Errorf("Incorrect response status code. Expected: 200 and got %d", + reqCtx.Response.StatusCode()) + } + + req = fasthttp.AcquireRequest() + req.SetRequestURI("/redirect-to") + req.Header.SetMethod("PUT") + req.Header.Add(web.XWallarmSchemaIDHeader, fmt.Sprintf("%d", DefaultSchemaID)) + + form = map[string]string{"wrongKey": "test"} + ct, body, err = createForm(form) + if err != nil { + t.Fatal(err) + } + + bodyData, err = io.ReadAll(body) + if err != nil { + t.Fatal(err) + } + + req.Header.SetContentType(ct) + req.SetBody(bodyData) + + reqCtx = fasthttp.RequestCtx{ + Request: *req, + } + + handler(&reqCtx) + + if reqCtx.Response.StatusCode() != 403 { + t.Errorf("Incorrect response status code. Expected: 403 and got %d", + reqCtx.Response.StatusCode()) + } + + t.Logf("Name of the test: %s; status code: %d; response body: %s", t.Name(), reqCtx.Response.StatusCode(), string(reqCtx.Response.Body())) + +} + +func (s *APIModeServiceTests) testAPIModeJSONParseError(t *testing.T) { + + handler := handlersAPI.Handlers(s.lock, &cfg, s.shutdown, s.logger, s.dbSpec) + + req := fasthttp.AcquireRequest() + req.SetRequestURI("/test/signup") + req.Header.SetMethod("POST") + req.SetBodyStream(bytes.NewReader([]byte("{\"test\"=\"wrongSyntax\"}")), -1) + req.Header.SetContentType("application/json") + req.Header.Add(web.XWallarmSchemaIDHeader, fmt.Sprintf("%d", DefaultSchemaID)) + + reqCtx := fasthttp.RequestCtx{ + Request: *req, + } + + handler(&reqCtx) + + if reqCtx.Response.StatusCode() != 403 { + t.Errorf("Incorrect response status code. Expected: 403 and got %d", + reqCtx.Response.StatusCode()) + } + + apifwResponse := handlersAPI.Response{} + if err := json.Unmarshal(reqCtx.Response.Body(), &apifwResponse); err != nil { + t.Errorf("Error while JSON response parsing: %v", err) + } + + if apifwResponse.Errors[0].Code != handlersAPI.ErrCodeRequiredBodyParseError { + t.Errorf("Incorrect error code. Expected: %s and got %s", + handlersAPI.ErrCodeRequiredBodyParseError, apifwResponse.Errors[0].Code) + } +} + +func (s *APIModeServiceTests) testAPIModeInvalidCTParseError(t *testing.T) { + + handler := handlersAPI.Handlers(s.lock, &cfg, s.shutdown, s.logger, s.dbSpec) + + p, err := json.Marshal(map[string]interface{}{ + "firstname": "test", + "lastname": "test", + "job": "test", + "email": "test@wallarm.com", + "url": "http://wallarm.com", + }) + + if err != nil { + t.Fatal(err) + } + + req := fasthttp.AcquireRequest() + req.SetRequestURI("/test/signup") + req.Header.SetMethod("POST") + req.SetBodyStream(bytes.NewReader(p), -1) + req.Header.SetContentType("invalid/mimetype") + req.Header.Add(web.XWallarmSchemaIDHeader, fmt.Sprintf("%d", DefaultSchemaID)) + + reqCtx := fasthttp.RequestCtx{ + Request: *req, + } + + handler(&reqCtx) + + if reqCtx.Response.StatusCode() != 403 { + t.Errorf("Incorrect response status code. Expected: 403 and got %d", + reqCtx.Response.StatusCode()) + } + + apifwResponse := handlersAPI.Response{} + if err := json.Unmarshal(reqCtx.Response.Body(), &apifwResponse); err != nil { + t.Errorf("Error while JSON response parsing: %v", err) + } + + if apifwResponse.Errors[0].Code != handlersAPI.ErrCodeRequiredBodyParseError { + t.Errorf("Incorrect error code. Expected: %s and got %s", + handlersAPI.ErrCodeRequiredBodyParseError, apifwResponse.Errors[0].Code) + } + +} + +func (s *APIModeServiceTests) testAPIModeCTNotInSpec(t *testing.T) { + + handler := handlersAPI.Handlers(s.lock, &cfg, s.shutdown, s.logger, s.dbSpec) + + p, err := json.Marshal(map[string]interface{}{ + "firstname": "test", + "lastname": "test", + "job": "test", + "email": "test@wallarm.com", + "url": "http://wallarm.com", + }) + + if err != nil { + t.Fatal(err) + } + + req := fasthttp.AcquireRequest() + req.SetRequestURI("/test/multipart") + req.Header.SetMethod("POST") + req.SetBodyStream(bytes.NewReader(p), -1) + req.Header.SetContentType("application/json") + req.Header.Add(web.XWallarmSchemaIDHeader, fmt.Sprintf("%d", DefaultSchemaID)) + + reqCtx := fasthttp.RequestCtx{ + Request: *req, + } + + handler(&reqCtx) + + t.Logf("Name of the test: %s; status code: %d; response body: %s", t.Name(), reqCtx.Response.StatusCode(), string(reqCtx.Response.Body())) + + if reqCtx.Response.StatusCode() != 403 { + t.Errorf("Incorrect response status code. Expected: 403 and got %d", + reqCtx.Response.StatusCode()) + } + + apifwResponse := handlersAPI.Response{} + if err := json.Unmarshal(reqCtx.Response.Body(), &apifwResponse); err != nil { + t.Errorf("Error while JSON response parsing: %v", err) + } + + if apifwResponse.Errors[0].Code != handlersAPI.ErrCodeRequiredBodyParseError { + t.Errorf("Incorrect error code. Expected: %s and got %s", + handlersAPI.ErrCodeRequiredBodyParseError, apifwResponse.Errors[0].Code) + } + +} + +func (s *APIModeServiceTests) testAPIModeEmptyBody(t *testing.T) { + + handler := handlersAPI.Handlers(s.lock, &cfg, s.shutdown, s.logger, s.dbSpec) + + req := fasthttp.AcquireRequest() + req.SetRequestURI("/test/signup") + req.Header.SetMethod("POST") + //req.SetBodyStream(bytes.NewReader(p), -1) + req.Header.SetContentType("application/json") + req.Header.Add(web.XWallarmSchemaIDHeader, fmt.Sprintf("%d", DefaultSchemaID)) + + reqCtx := fasthttp.RequestCtx{ + Request: *req, + } + + handler(&reqCtx) + + t.Logf("Name of the test: %s; status code: %d; response body: %s", t.Name(), reqCtx.Response.StatusCode(), string(reqCtx.Response.Body())) + + if reqCtx.Response.StatusCode() != 403 { + t.Errorf("Incorrect response status code. Expected: 403 and got %d", + reqCtx.Response.StatusCode()) + } + + apifwResponse := handlersAPI.Response{} + if err := json.Unmarshal(reqCtx.Response.Body(), &apifwResponse); err != nil { + t.Errorf("Error while JSON response parsing: %v", err) + } + + if apifwResponse.Errors[0].Code != handlersAPI.ErrCodeRequiredBodyMissed { + t.Errorf("Incorrect error code. Expected: %s and got %s", + handlersAPI.ErrCodeRequiredBodyMissed, apifwResponse.Errors[0].Code) + } + +} + +func (s *APIModeServiceTests) testAPIModeNoXWallarmSchemaIDHeader(t *testing.T) { + + handler := handlersAPI.Handlers(s.lock, &cfg, s.shutdown, s.logger, s.dbSpec) + + p, err := json.Marshal(map[string]interface{}{ + "firstname": "test", + "lastname": "test", + "job": "test", + "email": "test@wallarm.com", + "url": "http://wallarm.com", + }) + + if err != nil { + t.Fatal(err) + } + + req := fasthttp.AcquireRequest() + req.SetRequestURI("/test/signup") + req.Header.SetMethod("POST") + req.SetBodyStream(bytes.NewReader(p), -1) + req.Header.SetContentType("application/json") + + reqCtx := fasthttp.RequestCtx{ + Request: *req, + } + + handler(&reqCtx) + + if reqCtx.Response.StatusCode() != 500 { + t.Errorf("Incorrect response status code. Expected: 500 and got %d", + reqCtx.Response.StatusCode()) + } + + t.Logf("Name of the test: %s; status code: %d; response body: %s", t.Name(), reqCtx.Response.StatusCode(), string(reqCtx.Response.Body())) +} + +func (s *APIModeServiceTests) testAPIModeMethodAndPathNotFound(t *testing.T) { + + handler := handlersAPI.Handlers(s.lock, &cfg, s.shutdown, s.logger, s.dbSpec) + + p, err := json.Marshal(map[string]interface{}{ + "firstname": "test", + "lastname": "test", + "job": "test", + "email": "test@wallarm.com", + "url": "http://wallarm.com", + }) + + if err != nil { + t.Fatal(err) + } + + req := fasthttp.AcquireRequest() + req.SetRequestURI("/test/signup") + req.Header.SetMethod(testDeleteMethod) + req.SetBodyStream(bytes.NewReader(p), -1) + req.Header.SetContentType("application/json") + req.Header.Add(web.XWallarmSchemaIDHeader, fmt.Sprintf("%d", DefaultSchemaID)) + + reqCtx := fasthttp.RequestCtx{ + Request: *req, + } + + handler(&reqCtx) + + if reqCtx.Response.StatusCode() != 403 { + t.Errorf("Incorrect response status code. Expected: 403 and got %d", + reqCtx.Response.StatusCode()) + } + + apifwResponse := handlersAPI.Response{} + if err := json.Unmarshal(reqCtx.Response.Body(), &apifwResponse); err != nil { + t.Errorf("Error while JSON response parsing: %v", err) + } + + if apifwResponse.Errors[0].Code != handlersAPI.ErrCodeMethodAndPathNotFound { + t.Errorf("Incorrect error code. Expected: %s and got %s", + handlersAPI.ErrCodeMethodAndPathNotFound, apifwResponse.Errors[0].Code) + } + + t.Logf("Name of the test: %s; status code: %d; response body: %s", t.Name(), reqCtx.Response.StatusCode(), string(reqCtx.Response.Body())) + + // check path + req.Header.SetMethod("POST") + req.Header.SetRequestURI(testUnknownPath) + + reqCtx = fasthttp.RequestCtx{ + Request: *req, + } + + handler(&reqCtx) + + if reqCtx.Response.StatusCode() != 403 { + t.Errorf("Incorrect response status code. Expected: 403 and got %d", + reqCtx.Response.StatusCode()) + } + + apifwResponse = handlersAPI.Response{} + if err := json.Unmarshal(reqCtx.Response.Body(), &apifwResponse); err != nil { + t.Errorf("Error while JSON response parsing: %v", err) + } + + if apifwResponse.Errors[0].Code != handlersAPI.ErrCodeMethodAndPathNotFound { + t.Errorf("Incorrect error code. Expected: %s and got %s", + handlersAPI.ErrCodeMethodAndPathNotFound, apifwResponse.Errors[0].Code) + } + + t.Logf("Name of the test: %s; status code: %d; response body: %s", t.Name(), reqCtx.Response.StatusCode(), string(reqCtx.Response.Body())) + +} + +func (s *APIModeServiceTests) testAPIModeRequiredQueryParameterMissed(t *testing.T) { + + handler := handlersAPI.Handlers(s.lock, &cfg, s.shutdown, s.logger, s.dbSpec) + + req := fasthttp.AcquireRequest() + req.SetRequestURI("/test/query?id=" + uuid.New().String()) + req.Header.SetMethod("GET") + req.Header.Add(web.XWallarmSchemaIDHeader, fmt.Sprintf("%d", DefaultSchemaID)) + + reqCtx := fasthttp.RequestCtx{ + Request: *req, + } + + handler(&reqCtx) + + if reqCtx.Response.StatusCode() != 200 { + t.Errorf("Incorrect response status code. Expected: 200 and got %d", + reqCtx.Response.StatusCode()) + } + + t.Logf("Name of the test: %s; status code: %d; response body: %s", t.Name(), reqCtx.Response.StatusCode(), string(reqCtx.Response.Body())) + + req.SetRequestURI("/test/query?wrong_q_parameter=test") + + reqCtx = fasthttp.RequestCtx{ + Request: *req, + } + + handler(&reqCtx) + + if reqCtx.Response.StatusCode() != 403 { + t.Errorf("Incorrect response status code. Expected: 403 and got %d", + reqCtx.Response.StatusCode()) + } + + apifwResponse := handlersAPI.Response{} + if err := json.Unmarshal(reqCtx.Response.Body(), &apifwResponse); err != nil { + t.Errorf("Error while JSON response parsing: %v", err) + } + + if apifwResponse.Errors[0].Code != handlersAPI.ErrCodeRequiredQueryParameterMissed { + t.Errorf("Incorrect error code. Expected: %s and got %s", + handlersAPI.ErrCodeRequiredQueryParameterMissed, apifwResponse.Errors[0].Code) + } + + t.Logf("Name of the test: %s; status code: %d; response body: %s", t.Name(), reqCtx.Response.StatusCode(), string(reqCtx.Response.Body())) +} + +func (s *APIModeServiceTests) testAPIModeRequiredHeaderParameterMissed(t *testing.T) { + + handler := handlersAPI.Handlers(s.lock, &cfg, s.shutdown, s.logger, s.dbSpec) + + xReqTestValue := uuid.New() + + req := fasthttp.AcquireRequest() + req.SetRequestURI("/test/headers/request") + req.Header.SetMethod("GET") + req.Header.Add(web.XWallarmSchemaIDHeader, fmt.Sprintf("%d", DefaultSchemaID)) + req.Header.Add(testRequestHeader, xReqTestValue.String()) + + reqCtx := fasthttp.RequestCtx{ + Request: *req, + } + + handler(&reqCtx) + + if reqCtx.Response.StatusCode() != 200 { + t.Errorf("Incorrect response status code. Expected: 200 and got %d", + reqCtx.Response.StatusCode()) + } + + t.Logf("Name of the test: %s; status code: %d; response body: %s", t.Name(), reqCtx.Response.StatusCode(), string(reqCtx.Response.Body())) + + req.Header.Del(testRequestHeader) + + reqCtx = fasthttp.RequestCtx{ + Request: *req, + } + + handler(&reqCtx) + + if reqCtx.Response.StatusCode() != 403 { + t.Errorf("Incorrect response status code. Expected: 403 and got %d", + reqCtx.Response.StatusCode()) + } + + apifwResponse := handlersAPI.Response{} + if err := json.Unmarshal(reqCtx.Response.Body(), &apifwResponse); err != nil { + t.Errorf("Error while JSON response parsing: %v", err) + } + + if apifwResponse.Errors[0].Code != handlersAPI.ErrCodeRequiredHeaderMissed { + t.Errorf("Incorrect error code. Expected: %s and got %s", + handlersAPI.ErrCodeRequiredHeaderMissed, apifwResponse.Errors[0].Code) + } + + t.Logf("Name of the test: %s; status code: %d; response body: %s", t.Name(), reqCtx.Response.StatusCode(), string(reqCtx.Response.Body())) +} + +func (s *APIModeServiceTests) testAPIModeRequiredCookieParameterMissed(t *testing.T) { + + handler := handlersAPI.Handlers(s.lock, &cfg, s.shutdown, s.logger, s.dbSpec) + + req := fasthttp.AcquireRequest() + req.SetRequestURI("/test/cookies/request") + req.Header.SetMethod("GET") + req.Header.Add(web.XWallarmSchemaIDHeader, fmt.Sprintf("%d", DefaultSchemaID)) + req.Header.SetCookie(testRequestCookie, uuid.New().String()) + + reqCtx := fasthttp.RequestCtx{ + Request: *req, + } + + handler(&reqCtx) + + if reqCtx.Response.StatusCode() != 200 { + t.Errorf("Incorrect response status code. Expected: 200 and got %d", + reqCtx.Response.StatusCode()) + } + + t.Logf("Name of the test: %s; status code: %d; response body: %s", t.Name(), reqCtx.Response.StatusCode(), string(reqCtx.Response.Body())) + + req.Header.DelAllCookies() + + reqCtx = fasthttp.RequestCtx{ + Request: *req, + } + + handler(&reqCtx) + + if reqCtx.Response.StatusCode() != 403 { + t.Errorf("Incorrect response status code. Expected: 403 and got %d", + reqCtx.Response.StatusCode()) + } + + apifwResponse := handlersAPI.Response{} + if err := json.Unmarshal(reqCtx.Response.Body(), &apifwResponse); err != nil { + t.Errorf("Error while JSON response parsing: %v", err) + } + + if apifwResponse.Errors[0].Code != handlersAPI.ErrCodeRequiredCookieParameterMissed { + t.Errorf("Incorrect error code. Expected: %s and got %s", + handlersAPI.ErrCodeRequiredCookieParameterMissed, apifwResponse.Errors[0].Code) + } + + t.Logf("Name of the test: %s; status code: %d; response body: %s", t.Name(), reqCtx.Response.StatusCode(), string(reqCtx.Response.Body())) +} + +func (s *APIModeServiceTests) testAPIModeRequiredBodyMissed(t *testing.T) { + + handler := handlersAPI.Handlers(s.lock, &cfg, s.shutdown, s.logger, s.dbSpec) + + p, err := json.Marshal(map[string]interface{}{ + "status": uuid.New().String(), + "error": "test", + }) + + if err != nil { + t.Fatal(err) + } + + req := fasthttp.AcquireRequest() + req.SetRequestURI("/test/body/request") + req.Header.SetMethod("POST") + req.SetBodyStream(bytes.NewReader(p), -1) + req.Header.SetContentType("application/json") + req.Header.Add(web.XWallarmSchemaIDHeader, fmt.Sprintf("%d", DefaultSchemaID)) + + reqCtx := fasthttp.RequestCtx{ + Request: *req, + } + + handler(&reqCtx) + + if reqCtx.Response.StatusCode() != 200 { + t.Errorf("Incorrect response status code. Expected: 200 and got %d", + reqCtx.Response.StatusCode()) + } + + t.Logf("Name of the test: %s; status code: %d; response body: %s", t.Name(), reqCtx.Response.StatusCode(), string(reqCtx.Response.Body())) + + req = fasthttp.AcquireRequest() + req.SetRequestURI("/test/body/request") + req.Header.SetMethod("POST") + req.Header.Add(web.XWallarmSchemaIDHeader, fmt.Sprintf("%d", DefaultSchemaID)) + + reqCtx = fasthttp.RequestCtx{ + Request: *req, + } + + handler(&reqCtx) + + if reqCtx.Response.StatusCode() != 403 { + t.Errorf("Incorrect response status code. Expected: 403 and got %d", + reqCtx.Response.StatusCode()) + } + + apifwResponse := handlersAPI.Response{} + if err := json.Unmarshal(reqCtx.Response.Body(), &apifwResponse); err != nil { + t.Errorf("Error while JSON response parsing: %v", err) + } + + if apifwResponse.Errors[0].Code != handlersAPI.ErrCodeRequiredBodyMissed { + t.Errorf("Incorrect error code. Expected: %s and got %s", + handlersAPI.ErrCodeRequiredBodyMissed, apifwResponse.Errors[0].Code) + } + + t.Logf("Name of the test: %s; status code: %d; response body: %s", t.Name(), reqCtx.Response.StatusCode(), string(reqCtx.Response.Body())) +} + +func (s *APIModeServiceTests) testAPIModeRequiredBodyParameterMissed(t *testing.T) { + + handler := handlersAPI.Handlers(s.lock, &cfg, s.shutdown, s.logger, s.dbSpec) + + p, err := json.Marshal(map[string]interface{}{ + "status": uuid.New().String(), + "error": "test", + }) + + if err != nil { + t.Fatal(err) + } + + req := fasthttp.AcquireRequest() + req.SetRequestURI("/test/body/request") + req.Header.SetMethod("POST") + req.SetBodyStream(bytes.NewReader(p), -1) + req.Header.SetContentType("application/json") + req.Header.Add(web.XWallarmSchemaIDHeader, fmt.Sprintf("%d", DefaultSchemaID)) + + reqCtx := fasthttp.RequestCtx{ + Request: *req, + } + + handler(&reqCtx) + + if reqCtx.Response.StatusCode() != 200 { + t.Errorf("Incorrect response status code. Expected: 200 and got %d", + reqCtx.Response.StatusCode()) + } + + t.Logf("Name of the test: %s; status code: %d; response body: %s", t.Name(), reqCtx.Response.StatusCode(), string(reqCtx.Response.Body())) + + // body without required parameter + p, err = json.Marshal(map[string]interface{}{ + "error": "test", + }) + + if err != nil { + t.Fatal(err) + } + + req = fasthttp.AcquireRequest() + req.SetRequestURI("/test/body/request") + req.Header.SetMethod("POST") + req.SetBodyStream(bytes.NewReader(p), -1) + req.Header.SetContentType("application/json") + req.Header.Add(web.XWallarmSchemaIDHeader, fmt.Sprintf("%d", DefaultSchemaID)) + + reqCtx = fasthttp.RequestCtx{ + Request: *req, + } + + handler(&reqCtx) + + if reqCtx.Response.StatusCode() != 403 { + t.Errorf("Incorrect response status code. Expected: 403 and got %d", + reqCtx.Response.StatusCode()) + } + + apifwResponse := handlersAPI.Response{} + if err := json.Unmarshal(reqCtx.Response.Body(), &apifwResponse); err != nil { + t.Errorf("Error while JSON response parsing: %v", err) + } + + if apifwResponse.Errors[0].Code != handlersAPI.ErrCodeRequiredBodyParameterMissed { + t.Errorf("Incorrect error code. Expected: %s and got %s", + handlersAPI.ErrCodeRequiredBodyParameterMissed, apifwResponse.Errors[0].Code) + } + + t.Logf("Name of the test: %s; status code: %d; response body: %s", t.Name(), reqCtx.Response.StatusCode(), string(reqCtx.Response.Body())) +} + +// Invalid parameters errors +func (s *APIModeServiceTests) testAPIModeRequiredQueryParameterInvalidValue(t *testing.T) { + + handler := handlersAPI.Handlers(s.lock, &cfg, s.shutdown, s.logger, s.dbSpec) + + req := fasthttp.AcquireRequest() + req.SetRequestURI("/test/query?id=" + uuid.New().String()) + req.Header.SetMethod("GET") + req.Header.Add(web.XWallarmSchemaIDHeader, fmt.Sprintf("%d", DefaultSchemaID)) + + reqCtx := fasthttp.RequestCtx{ + Request: *req, + } + + handler(&reqCtx) + + if reqCtx.Response.StatusCode() != 200 { + t.Errorf("Incorrect response status code. Expected: 200 and got %d", + reqCtx.Response.StatusCode()) + } + + t.Logf("Name of the test: %s; status code: %d; response body: %s", t.Name(), reqCtx.Response.StatusCode(), string(reqCtx.Response.Body())) + + req.SetRequestURI("/test/query?id=invalid_value_test") + + reqCtx = fasthttp.RequestCtx{ + Request: *req, + } + + handler(&reqCtx) + + if reqCtx.Response.StatusCode() != 403 { + t.Errorf("Incorrect response status code. Expected: 403 and got %d", + reqCtx.Response.StatusCode()) + } + + apifwResponse := handlersAPI.Response{} + if err := json.Unmarshal(reqCtx.Response.Body(), &apifwResponse); err != nil { + t.Errorf("Error while JSON response parsing: %v", err) + } + + if apifwResponse.Errors[0].Code != handlersAPI.ErrCodeRequiredQueryParameterInvalidValue { + t.Errorf("Incorrect error code. Expected: %s and got %s", + handlersAPI.ErrCodeRequiredQueryParameterInvalidValue, apifwResponse.Errors[0].Code) + } + + t.Logf("Name of the test: %s; status code: %d; response body: %s", t.Name(), reqCtx.Response.StatusCode(), string(reqCtx.Response.Body())) +} + +func (s *APIModeServiceTests) testAPIModeRequiredHeaderParameterInvalidValue(t *testing.T) { + + handler := handlersAPI.Handlers(s.lock, &cfg, s.shutdown, s.logger, s.dbSpec) + + xReqTestValue := uuid.New() + + req := fasthttp.AcquireRequest() + req.SetRequestURI("/test/headers/request") + req.Header.SetMethod("GET") + req.Header.Add(web.XWallarmSchemaIDHeader, fmt.Sprintf("%d", DefaultSchemaID)) + req.Header.Add(testRequestHeader, xReqTestValue.String()) + + reqCtx := fasthttp.RequestCtx{ + Request: *req, + } + + handler(&reqCtx) + + if reqCtx.Response.StatusCode() != 200 { + t.Errorf("Incorrect response status code. Expected: 200 and got %d", + reqCtx.Response.StatusCode()) + } + + t.Logf("Name of the test: %s; status code: %d; response body: %s", t.Name(), reqCtx.Response.StatusCode(), string(reqCtx.Response.Body())) + + req.Header.Del(testRequestHeader) + req.Header.Add(testRequestHeader, "invalid_value") + + reqCtx = fasthttp.RequestCtx{ + Request: *req, + } + + handler(&reqCtx) + + if reqCtx.Response.StatusCode() != 403 { + t.Errorf("Incorrect response status code. Expected: 403 and got %d", + reqCtx.Response.StatusCode()) + } + + apifwResponse := handlersAPI.Response{} + if err := json.Unmarshal(reqCtx.Response.Body(), &apifwResponse); err != nil { + t.Errorf("Error while JSON response parsing: %v", err) + } + + if apifwResponse.Errors[0].Code != handlersAPI.ErrCodeRequiredHeaderInvalidValue { + t.Errorf("Incorrect error code. Expected: %s and got %s", + handlersAPI.ErrCodeRequiredHeaderInvalidValue, apifwResponse.Errors[0].Code) + } + + t.Logf("Name of the test: %s; status code: %d; response body: %s", t.Name(), reqCtx.Response.StatusCode(), string(reqCtx.Response.Body())) +} + +func (s *APIModeServiceTests) testAPIModeRequiredCookieParameterInvalidValue(t *testing.T) { + + handler := handlersAPI.Handlers(s.lock, &cfg, s.shutdown, s.logger, s.dbSpec) + + req := fasthttp.AcquireRequest() + req.SetRequestURI("/test/cookies/request") + req.Header.SetMethod("GET") + req.Header.Add(web.XWallarmSchemaIDHeader, fmt.Sprintf("%d", DefaultSchemaID)) + req.Header.SetCookie(testRequestCookie, uuid.New().String()) + + reqCtx := fasthttp.RequestCtx{ + Request: *req, + } + + handler(&reqCtx) + + if reqCtx.Response.StatusCode() != 200 { + t.Errorf("Incorrect response status code. Expected: 200 and got %d", + reqCtx.Response.StatusCode()) + } + + t.Logf("Name of the test: %s; status code: %d; response body: %s", t.Name(), reqCtx.Response.StatusCode(), string(reqCtx.Response.Body())) + + req.Header.SetCookie(testRequestCookie, "invalid_test_value") + + reqCtx = fasthttp.RequestCtx{ + Request: *req, + } + + handler(&reqCtx) + + if reqCtx.Response.StatusCode() != 403 { + t.Errorf("Incorrect response status code. Expected: 403 and got %d", + reqCtx.Response.StatusCode()) + } + + apifwResponse := handlersAPI.Response{} + if err := json.Unmarshal(reqCtx.Response.Body(), &apifwResponse); err != nil { + t.Errorf("Error while JSON response parsing: %v", err) + } + + if apifwResponse.Errors[0].Code != handlersAPI.ErrCodeRequiredCookieParameterInvalidValue { + t.Errorf("Incorrect error code. Expected: %s and got %s", + handlersAPI.ErrCodeRequiredCookieParameterInvalidValue, apifwResponse.Errors[0].Code) + } + + t.Logf("Name of the test: %s; status code: %d; response body: %s", t.Name(), reqCtx.Response.StatusCode(), string(reqCtx.Response.Body())) +} + +func (s *APIModeServiceTests) testAPIModeRequiredBodyParameterInvalidValue(t *testing.T) { + + handler := handlersAPI.Handlers(s.lock, &cfg, s.shutdown, s.logger, s.dbSpec) + + p, err := json.Marshal(map[string]interface{}{ + "status": uuid.New().String(), + "error": "test", + }) + + if err != nil { + t.Fatal(err) + } + + req := fasthttp.AcquireRequest() + req.SetRequestURI("/test/body/request") + req.Header.SetMethod("POST") + req.SetBodyStream(bytes.NewReader(p), -1) + req.Header.SetContentType("application/json") + req.Header.Add(web.XWallarmSchemaIDHeader, fmt.Sprintf("%d", DefaultSchemaID)) + + reqCtx := fasthttp.RequestCtx{ + Request: *req, + } + + handler(&reqCtx) + + if reqCtx.Response.StatusCode() != 200 { + t.Errorf("Incorrect response status code. Expected: 200 and got %d", + reqCtx.Response.StatusCode()) + } + + t.Logf("Name of the test: %s; status code: %d; response body: %s", t.Name(), reqCtx.Response.StatusCode(), string(reqCtx.Response.Body())) + + // body without required parameter + p, err = json.Marshal(map[string]interface{}{ + "status": "invalid_test_value", + "error": "test", + }) + + if err != nil { + t.Fatal(err) + } + + req = fasthttp.AcquireRequest() + req.SetRequestURI("/test/body/request") + req.Header.SetMethod("POST") + req.SetBodyStream(bytes.NewReader(p), -1) + req.Header.SetContentType("application/json") + req.Header.Add(web.XWallarmSchemaIDHeader, fmt.Sprintf("%d", DefaultSchemaID)) + + reqCtx = fasthttp.RequestCtx{ + Request: *req, + } + + handler(&reqCtx) + + if reqCtx.Response.StatusCode() != 403 { + t.Errorf("Incorrect response status code. Expected: 403 and got %d", + reqCtx.Response.StatusCode()) + } + + apifwResponse := handlersAPI.Response{} + if err := json.Unmarshal(reqCtx.Response.Body(), &apifwResponse); err != nil { + t.Errorf("Error while JSON response parsing: %v", err) + } + + if apifwResponse.Errors[0].Code != handlersAPI.ErrCodeRequiredBodyParameterInvalidValue { + t.Errorf("Incorrect error code. Expected: %s and got %s", + handlersAPI.ErrCodeRequiredBodyParameterInvalidValue, apifwResponse.Errors[0].Code) + } + + t.Logf("Name of the test: %s; status code: %d; response body: %s", t.Name(), reqCtx.Response.StatusCode(), string(reqCtx.Response.Body())) +} + +// security requirements +func (s *APIModeServiceTests) testAPIModeBasicAuthFailed(t *testing.T) { + + handler := handlersAPI.Handlers(s.lock, &cfg, s.shutdown, s.logger, s.dbSpec) + + req := fasthttp.AcquireRequest() + req.SetRequestURI("/test/security/basic") + req.Header.SetMethod("GET") + req.Header.Add(web.XWallarmSchemaIDHeader, fmt.Sprintf("%d", DefaultSchemaID)) + req.Header.Add("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte("user1:password1"))) + + reqCtx := fasthttp.RequestCtx{ + Request: *req, + } + + handler(&reqCtx) + + if reqCtx.Response.StatusCode() != 200 { + t.Errorf("Incorrect response status code. Expected: 200 and got %d", + reqCtx.Response.StatusCode()) + } + + t.Logf("Name of the test: %s; status code: %d; response body: %s", t.Name(), reqCtx.Response.StatusCode(), string(reqCtx.Response.Body())) + + req.Header.Del("Authorization") + + reqCtx = fasthttp.RequestCtx{ + Request: *req, + } + + handler(&reqCtx) + + if reqCtx.Response.StatusCode() != 403 { + t.Errorf("Incorrect response status code. Expected: 403 and got %d", + reqCtx.Response.StatusCode()) + } + + apifwResponse := handlersAPI.Response{} + if err := json.Unmarshal(reqCtx.Response.Body(), &apifwResponse); err != nil { + t.Errorf("Error while JSON response parsing: %v", err) + } + + if apifwResponse.Errors[0].Code != handlersAPI.ErrCodeSecRequirementsFailed { + t.Errorf("Incorrect error code. Expected: %s and got %s", + handlersAPI.ErrCodeSecRequirementsFailed, apifwResponse.Errors[0].Code) + } + + t.Logf("Name of the test: %s; status code: %d; response body: %s", t.Name(), reqCtx.Response.StatusCode(), string(reqCtx.Response.Body())) +} + +func (s *APIModeServiceTests) testAPIModeBearerTokenFailed(t *testing.T) { + + handler := handlersAPI.Handlers(s.lock, &cfg, s.shutdown, s.logger, s.dbSpec) + + req := fasthttp.AcquireRequest() + req.SetRequestURI("/test/security/bearer") + req.Header.SetMethod("GET") + req.Header.Add(web.XWallarmSchemaIDHeader, fmt.Sprintf("%d", DefaultSchemaID)) + req.Header.Add("Authorization", "Bearer "+uuid.New().String()) + + reqCtx := fasthttp.RequestCtx{ + Request: *req, + } + + handler(&reqCtx) + + if reqCtx.Response.StatusCode() != 200 { + t.Errorf("Incorrect response status code. Expected: 200 and got %d", + reqCtx.Response.StatusCode()) + } + + t.Logf("Name of the test: %s; status code: %d; response body: %s", t.Name(), reqCtx.Response.StatusCode(), string(reqCtx.Response.Body())) + + req.Header.Del("Authorization") + + reqCtx = fasthttp.RequestCtx{ + Request: *req, + } + + handler(&reqCtx) + + if reqCtx.Response.StatusCode() != 403 { + t.Errorf("Incorrect response status code. Expected: 403 and got %d", + reqCtx.Response.StatusCode()) + } + + apifwResponse := handlersAPI.Response{} + if err := json.Unmarshal(reqCtx.Response.Body(), &apifwResponse); err != nil { + t.Errorf("Error while JSON response parsing: %v", err) + } + + if apifwResponse.Errors[0].Code != handlersAPI.ErrCodeSecRequirementsFailed { + t.Errorf("Incorrect error code. Expected: %s and got %s", + handlersAPI.ErrCodeSecRequirementsFailed, apifwResponse.Errors[0].Code) + } + + t.Logf("Name of the test: %s; status code: %d; response body: %s", t.Name(), reqCtx.Response.StatusCode(), string(reqCtx.Response.Body())) +} + +func (s *APIModeServiceTests) testAPIModeAPITokenCookieFailed(t *testing.T) { + + handler := handlersAPI.Handlers(s.lock, &cfg, s.shutdown, s.logger, s.dbSpec) + + req := fasthttp.AcquireRequest() + req.SetRequestURI("/test/security/cookie") + req.Header.SetMethod("GET") + req.Header.Add(web.XWallarmSchemaIDHeader, fmt.Sprintf("%d", DefaultSchemaID)) + req.Header.SetCookie(testSecCookieName, uuid.New().String()) + + reqCtx := fasthttp.RequestCtx{ + Request: *req, + } + + handler(&reqCtx) + + if reqCtx.Response.StatusCode() != 200 { + t.Errorf("Incorrect response status code. Expected: 200 and got %d", + reqCtx.Response.StatusCode()) + } + + t.Logf("Name of the test: %s; status code: %d; response body: %s", t.Name(), reqCtx.Response.StatusCode(), string(reqCtx.Response.Body())) + + req.Header.DelAllCookies() + + reqCtx = fasthttp.RequestCtx{ + Request: *req, + } + + handler(&reqCtx) + + if reqCtx.Response.StatusCode() != 403 { + t.Errorf("Incorrect response status code. Expected: 403 and got %d", + reqCtx.Response.StatusCode()) + } + + apifwResponse := handlersAPI.Response{} + if err := json.Unmarshal(reqCtx.Response.Body(), &apifwResponse); err != nil { + t.Errorf("Error while JSON response parsing: %v", err) + } + + if apifwResponse.Errors[0].Code != handlersAPI.ErrCodeSecRequirementsFailed { + t.Errorf("Incorrect error code. Expected: %s and got %s", + handlersAPI.ErrCodeSecRequirementsFailed, apifwResponse.Errors[0].Code) + } + + t.Logf("Name of the test: %s; status code: %d; response body: %s", t.Name(), reqCtx.Response.StatusCode(), string(reqCtx.Response.Body())) +} + +// unknown parameters +func (s *APIModeServiceTests) testAPIModeUnknownParameterBodyJSON(t *testing.T) { + + handler := handlersAPI.Handlers(s.lock, &cfg, s.shutdown, s.logger, s.dbSpec) + + p, err := json.Marshal(map[string]interface{}{ + "firstname": "test", + "lastname": "test", + "job": "test", + "unknownParam": "test", + "email": "test@wallarm.com", + "url": "http://wallarm.com", + }) + + if err != nil { + t.Fatal(err) + } + + req := fasthttp.AcquireRequest() + req.SetRequestURI("/test/signup") + req.Header.SetMethod("POST") + req.SetBodyStream(bytes.NewReader(p), -1) + req.Header.SetContentType("application/json") + req.Header.Add(web.XWallarmSchemaIDHeader, fmt.Sprintf("%d", DefaultSchemaID)) + + reqCtx := fasthttp.RequestCtx{ + Request: *req, + } + + handler(&reqCtx) + + t.Logf("Name of the test: %s; status code: %d; response body: %s", t.Name(), reqCtx.Response.StatusCode(), string(reqCtx.Response.Body())) + + if reqCtx.Response.StatusCode() != 403 { + t.Errorf("Incorrect response status code. Expected: 403 and got %d", + reqCtx.Response.StatusCode()) + } + + apifwResponse := handlersAPI.Response{} + if err := json.Unmarshal(reqCtx.Response.Body(), &apifwResponse); err != nil { + t.Errorf("Error while JSON response parsing: %v", err) + } + + if apifwResponse.Errors[0].Code != handlersAPI.ErrCodeUnknownParameterFound { + t.Errorf("Incorrect error code. Expected: %s and got %s", + handlersAPI.ErrCodeUnknownParameterFound, apifwResponse.Errors[0].Code) + } + + p, err = json.Marshal(map[string]interface{}{ + "firstname": "test", + "lastname": "test", + "job": "test", + "email": "test@wallarm.com", + "url": "http://wallarm.com", + }) + + if err != nil { + t.Fatal(err) + } + + req = fasthttp.AcquireRequest() + req.SetRequestURI("/test/signup") + req.Header.SetMethod("POST") + req.SetBodyStream(bytes.NewReader(p), -1) + req.Header.SetContentType("application/json") + req.Header.Add(web.XWallarmSchemaIDHeader, fmt.Sprintf("%d", DefaultSchemaID)) + + reqCtx = fasthttp.RequestCtx{ + Request: *req, + } + + handler(&reqCtx) + + if reqCtx.Response.StatusCode() != 200 { + t.Errorf("Incorrect response status code. Expected: 200 and got %d", + reqCtx.Response.StatusCode()) + } + + t.Logf("Name of the test: %s; status code: %d; response body: %s", t.Name(), reqCtx.Response.StatusCode(), string(reqCtx.Response.Body())) +} + +func (s *APIModeServiceTests) testAPIModeUnknownParameterBodyPost(t *testing.T) { + + handler := handlersAPI.Handlers(s.lock, &cfg, s.shutdown, s.logger, s.dbSpec) + + req := fasthttp.AcquireRequest() + req.SetRequestURI("/test/signup") + req.Header.SetMethod("POST") + req.PostArgs().Add("firstname", "test") + req.PostArgs().Add("lastname", "test") + req.PostArgs().Add("job", "test") + req.PostArgs().Add("unknownParam", "test") + req.PostArgs().Add("email", "test@example.com") + req.PostArgs().Add("url", "test") + req.SetBodyString(req.PostArgs().String()) + req.Header.SetContentType("application/x-www-form-urlencoded") + req.Header.Add(web.XWallarmSchemaIDHeader, fmt.Sprintf("%d", DefaultSchemaID)) + + reqCtx := fasthttp.RequestCtx{ + Request: *req, + } + + handler(&reqCtx) + + t.Logf("Name of the test: %s; status code: %d; response body: %s", t.Name(), reqCtx.Response.StatusCode(), string(reqCtx.Response.Body())) + + if reqCtx.Response.StatusCode() != 403 { + t.Errorf("Incorrect response status code. Expected: 403 and got %d", + reqCtx.Response.StatusCode()) + } + + apifwResponse := handlersAPI.Response{} + if err := json.Unmarshal(reqCtx.Response.Body(), &apifwResponse); err != nil { + t.Errorf("Error while JSON response parsing: %v", err) + } + + if apifwResponse.Errors[0].Code != handlersAPI.ErrCodeUnknownParameterFound { + t.Errorf("Incorrect error code. Expected: %s and got %s", + handlersAPI.ErrCodeUnknownParameterFound, apifwResponse.Errors[0].Code) + } + + req = fasthttp.AcquireRequest() + req.SetRequestURI("/test/signup") + req.Header.SetMethod("POST") + req.PostArgs().Add("firstname", "test") + req.PostArgs().Add("lastname", "test") + req.PostArgs().Add("job", "test") + req.PostArgs().Add("email", "test@example.com") + req.PostArgs().Add("url", "test") + req.SetBodyString(req.PostArgs().String()) + req.Header.SetContentType("application/x-www-form-urlencoded") + req.Header.Add(web.XWallarmSchemaIDHeader, fmt.Sprintf("%d", DefaultSchemaID)) + + reqCtx = fasthttp.RequestCtx{ + Request: *req, + } + + handler(&reqCtx) + + if reqCtx.Response.StatusCode() != 200 { + t.Errorf("Incorrect response status code. Expected: 200 and got %d", + reqCtx.Response.StatusCode()) + } + + t.Logf("Name of the test: %s; status code: %d; response body: %s", t.Name(), reqCtx.Response.StatusCode(), string(reqCtx.Response.Body())) +} + +func (s *APIModeServiceTests) testAPIModeUnknownParameterQuery(t *testing.T) { + + handler := handlersAPI.Handlers(s.lock, &cfg, s.shutdown, s.logger, s.dbSpec) + + req := fasthttp.AcquireRequest() + req.SetRequestURI("/test/query?uparam=test&id=" + uuid.New().String()) + req.Header.SetMethod("GET") + req.Header.Add(web.XWallarmSchemaIDHeader, fmt.Sprintf("%d", DefaultSchemaID)) + + reqCtx := fasthttp.RequestCtx{ + Request: *req, + } + + handler(&reqCtx) + + t.Logf("Name of the test: %s; status code: %d; response body: %s", t.Name(), reqCtx.Response.StatusCode(), string(reqCtx.Response.Body())) + + if reqCtx.Response.StatusCode() != 403 { + t.Errorf("Incorrect response status code. Expected: 403 and got %d", + reqCtx.Response.StatusCode()) + } + + apifwResponse := handlersAPI.Response{} + if err := json.Unmarshal(reqCtx.Response.Body(), &apifwResponse); err != nil { + t.Errorf("Error while JSON response parsing: %v", err) + } + + if apifwResponse.Errors[0].Code != handlersAPI.ErrCodeUnknownParameterFound { + t.Errorf("Incorrect error code. Expected: %s and got %s", + handlersAPI.ErrCodeUnknownParameterFound, apifwResponse.Errors[0].Code) + } + + req = fasthttp.AcquireRequest() + req.SetRequestURI("/test/query?id=" + uuid.New().String()) + req.Header.SetMethod("GET") + req.Header.Add(web.XWallarmSchemaIDHeader, fmt.Sprintf("%d", DefaultSchemaID)) + + reqCtx = fasthttp.RequestCtx{ + Request: *req, + } + + handler(&reqCtx) + + if reqCtx.Response.StatusCode() != 200 { + t.Errorf("Incorrect response status code. Expected: 200 and got %d", + reqCtx.Response.StatusCode()) + } + + t.Logf("Name of the test: %s; status code: %d; response body: %s", t.Name(), reqCtx.Response.StatusCode(), string(reqCtx.Response.Body())) +} + +func (s *APIModeServiceTests) testAPIModeUnknownParameterTextPlainCT(t *testing.T) { + + handler := handlersAPI.Handlers(s.lock, &cfg, s.shutdown, s.logger, s.dbSpec) + + req := fasthttp.AcquireRequest() + req.SetRequestURI("/test/plain") + req.Header.SetMethod("POST") + req.SetBodyString("testString") + req.Header.SetContentType("text/plain") + req.Header.Add(web.XWallarmSchemaIDHeader, fmt.Sprintf("%d", DefaultSchemaID)) + + reqCtx := fasthttp.RequestCtx{ + Request: *req, + } + + handler(&reqCtx) + + t.Logf("Name of the test: %s; status code: %d; response body: %s", t.Name(), reqCtx.Response.StatusCode(), string(reqCtx.Response.Body())) + + if reqCtx.Response.StatusCode() != 200 { + t.Errorf("Incorrect response status code. Expected: 200 and got %d", + reqCtx.Response.StatusCode()) + } +} + +func (s *APIModeServiceTests) testAPIModeUnknownParameterInvalidCT(t *testing.T) { + + handler := handlersAPI.Handlers(s.lock, &cfg, s.shutdown, s.logger, s.dbSpec) + + req := fasthttp.AcquireRequest() + req.SetRequestURI("/test/unknownCT") + req.Header.SetMethod("POST") + req.SetBodyString("testString") + req.Header.SetContentType("application/unknownCT") + req.Header.Add(web.XWallarmSchemaIDHeader, fmt.Sprintf("%d", DefaultSchemaID)) + + reqCtx := fasthttp.RequestCtx{ + Request: *req, + } + + handler(&reqCtx) + + t.Logf("Name of the test: %s; status code: %d; response body: %s", t.Name(), reqCtx.Response.StatusCode(), string(reqCtx.Response.Body())) + + if reqCtx.Response.StatusCode() != 500 { + t.Errorf("Incorrect response status code. Expected: 500 and got %d", + reqCtx.Response.StatusCode()) + } +} + +func (s *APIModeServiceTests) testAPIModePassOptionsRequest(t *testing.T) { + + cfg.PassOptionsRequests = true + handler := handlersAPI.Handlers(s.lock, &cfg, s.shutdown, s.logger, s.dbSpec) + + req := fasthttp.AcquireRequest() + req.SetRequestURI("/test/signup") + req.Header.SetMethod("OPTIONS") + req.Header.Add(web.XWallarmSchemaIDHeader, fmt.Sprintf("%d", DefaultSchemaID)) + + reqCtx := fasthttp.RequestCtx{ + Request: *req, + } + + handler(&reqCtx) + + if reqCtx.Response.StatusCode() != 200 { + t.Errorf("Incorrect response status code. Expected: 200 and got %d", + reqCtx.Response.StatusCode()) + } + + t.Logf("Name of the test: %s; status code: %d; response body: %s", t.Name(), reqCtx.Response.StatusCode(), string(reqCtx.Response.Body())) + +} + +func (s *APIModeServiceTests) testAPIModeMultipartOptionalParams(t *testing.T) { + + handler := handlersAPI.Handlers(s.lock, &cfg, s.shutdown, s.logger, s.dbSpec) + + req := fasthttp.AcquireRequest() + req.SetRequestURI("/test/multipart") + req.Header.SetMethod("POST") + req.Header.Add(web.XWallarmSchemaIDHeader, fmt.Sprintf("%d", DefaultSchemaID)) + + form := map[string]string{"url": "test", "id": "10"} + ct, body, err := createForm(form) + if err != nil { + t.Fatal(err) + } + + bodyData, err := io.ReadAll(body) + if err != nil { + t.Fatal(err) + } + + req.Header.SetContentType(ct) + req.SetBody(bodyData) + + reqCtx := fasthttp.RequestCtx{ + Request: *req, + } + + handler(&reqCtx) + + t.Logf("Name of the test: %s; status code: %d; response body: %s", t.Name(), reqCtx.Response.StatusCode(), string(reqCtx.Response.Body())) + + if reqCtx.Response.StatusCode() != 200 { + t.Errorf("Incorrect response status code. Expected: 200 and got %d", + reqCtx.Response.StatusCode()) + } + + req = fasthttp.AcquireRequest() + req.SetRequestURI("/test/multipart") + req.Header.SetMethod("POST") + req.Header.Add(web.XWallarmSchemaIDHeader, fmt.Sprintf("%d", DefaultSchemaID)) + + form = map[string]string{"url": "test", "id": "test"} + ct, body, err = createForm(form) + if err != nil { + t.Fatal(err) + } + + bodyData, err = io.ReadAll(body) + if err != nil { + t.Fatal(err) + } + + req.Header.SetContentType(ct) + req.SetBody(bodyData) + + reqCtx = fasthttp.RequestCtx{ + Request: *req, + } + + handler(&reqCtx) + + if reqCtx.Response.StatusCode() != 403 { + t.Errorf("Incorrect response status code. Expected: 403 and got %d", + reqCtx.Response.StatusCode()) + } + + t.Logf("Name of the test: %s; status code: %d; response body: %s", t.Name(), reqCtx.Response.StatusCode(), string(reqCtx.Response.Body())) + +} + +func TestAPIModeMockedUpdater(t *testing.T) { + + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + + dbSpecBeforeUpdate := database.NewMockDBOpenAPILoader(mockCtrl) + + dbSpec := database.NewMockDBOpenAPILoader(mockCtrl) + + logger := logrus.New() + logger.SetLevel(logrus.ErrorLevel) + + schemaIDsBefore := dbSpec.EXPECT().SchemaIDs().Return([]int{DefaultSchemaID}) + specVersionBefore := dbSpec.EXPECT().SpecificationVersion(DefaultSchemaID).Return(DefaultSpecVersion) + loadUpdater := dbSpec.EXPECT().Load(gomock.Any()).Return(nil) + schemaIDsAfter := dbSpec.EXPECT().SchemaIDs().Return([]int{DefaultSchemaID}) + specVersionAfter := dbSpec.EXPECT().SpecificationVersion(DefaultSchemaID).Return(UpdatedSpecVersion) + + // updater calls + gomock.InOrder(schemaIDsBefore, specVersionBefore, loadUpdater, schemaIDsAfter, specVersionAfter) + + swagger, err := openapi3.NewLoader().LoadFromData([]byte(apiModeOpenAPISpecAPIModeTestUpdated)) + if err != nil { + t.Fatalf("loading swagwaf file: %s", err.Error()) + } + specRoutes := dbSpec.EXPECT().Specification(DefaultSchemaID).Return(swagger) + + schemaIDsRoutes := dbSpec.EXPECT().SchemaIDs().Return([]int{DefaultSchemaID}) + schemaIDsApps := dbSpec.EXPECT().SchemaIDs().Return([]int{DefaultSchemaID}) + specRouter := dbSpec.EXPECT().Specification(DefaultSchemaID).Return(swagger) + specVersionRouter := dbSpec.EXPECT().SpecificationVersion(DefaultSchemaID).Return(UpdatedSpecVersion) + specVersionLogMsg := dbSpec.EXPECT().SpecificationVersion(DefaultSchemaID).Return(UpdatedSpecVersion) + + // router calls + gomock.InOrder(schemaIDsRoutes, schemaIDsApps, specRoutes, specRouter, specVersionRouter, specVersionLogMsg) + + shutdown := make(chan os.Signal, 1) + signal.Notify(shutdown, os.Interrupt, syscall.SIGTERM) + + health := handlersAPI.Health{} + + var lock sync.RWMutex + + dbSpecBeforeUpdate.EXPECT().Specification(DefaultSchemaID).Return(swagger).AnyTimes() + dbSpecBeforeUpdate.EXPECT().SchemaIDs().Return([]int{DefaultSchemaID}).AnyTimes() + dbSpecBeforeUpdate.EXPECT().SpecificationVersion(DefaultSchemaID).Return(DefaultSpecVersion).AnyTimes() + dbSpecBeforeUpdate.EXPECT().IsLoaded(DefaultSchemaID).Return(true).AnyTimes() + + handler := handlersAPI.Handlers(&lock, &cfg, shutdown, logger, dbSpecBeforeUpdate) + api := fasthttp.Server{Handler: handler} + + updSpecErrors := make(chan error, 1) + updater := updater.NewController(&lock, logger, dbSpec, &cfg, &api, shutdown, &health) + go func() { + t.Logf("starting specification regular update process every %.0f seconds", cfg.SpecificationUpdatePeriod.Seconds()) + updSpecErrors <- updater.Start() + }() + + time.Sleep(3 * time.Second) + + if err := updater.Shutdown(); err != nil { + t.Fatal(err) + } + + req := fasthttp.AcquireRequest() + req.SetRequestURI("/test/new") + req.Header.SetMethod("GET") + req.Header.Add(web.XWallarmSchemaIDHeader, fmt.Sprintf("%d", DefaultSchemaID)) + + reqCtx := fasthttp.RequestCtx{ + Request: *req, + } + + // checker in the request handler + dbSpec.EXPECT().IsLoaded(DefaultSchemaID).Return(true).AnyTimes() + + lock.RLock() + api.Handler(&reqCtx) + lock.RUnlock() + + if reqCtx.Response.StatusCode() != 200 { + t.Errorf("Incorrect response status code. Expected: 200 and got %d", + reqCtx.Response.StatusCode()) + } + +} diff --git a/cmd/api-firewall/tests/main_json_test.go b/cmd/api-firewall/tests/main_json_test.go index 4cecdc9..825c3bb 100644 --- a/cmd/api-firewall/tests/main_json_test.go +++ b/cmd/api-firewall/tests/main_json_test.go @@ -13,11 +13,10 @@ import ( "github.com/golang/mock/gomock" "github.com/sirupsen/logrus" "github.com/valyala/fasthttp" - "github.com/wallarm/api-firewall/cmd/api-firewall/internal/handlers" + proxyHandler "github.com/wallarm/api-firewall/cmd/api-firewall/internal/handlers/proxy" "github.com/wallarm/api-firewall/internal/config" "github.com/wallarm/api-firewall/internal/platform/proxy" "github.com/wallarm/api-firewall/internal/platform/router" - "github.com/wallarm/api-firewall/internal/platform/shadowAPI" ) const openAPIJSONSpecTest = ` @@ -100,7 +99,6 @@ func TestJSONBasic(t *testing.T) { pool := proxy.NewMockPool(mockCtrl) client := proxy.NewMockHTTPClient(mockCtrl) - checker := shadowAPI.NewMockChecker(mockCtrl) swagger, err := openapi3.NewLoader().LoadFromData([]byte(openAPIJSONSpecTest)) if err != nil { @@ -122,7 +120,6 @@ func TestJSONBasic(t *testing.T) { proxy: pool, client: client, swagRouter: swagRouter, - shadowAPI: checker, } // basic test @@ -134,7 +131,7 @@ func TestJSONBasic(t *testing.T) { func (s *ServiceTests) testBasicObjJSONFieldValidation(t *testing.T) { - handler := handlers.OpenapiProxy(&apifwCfg, s.serverUrl, s.shutdown, s.logger, s.proxy, s.swagRouter, nil, s.shadowAPI) + handler := proxyHandler.Handlers(&apifwCfg, s.serverUrl, s.shutdown, s.logger, s.proxy, s.swagRouter, nil) // basic object check p, err := json.Marshal(map[string]interface{}{ @@ -181,7 +178,7 @@ func (s *ServiceTests) testBasicObjJSONFieldValidation(t *testing.T) { func (s *ServiceTests) testBasicArrJSONFieldValidation(t *testing.T) { - handler := handlers.OpenapiProxy(&apifwCfg, s.serverUrl, s.shutdown, s.logger, s.proxy, s.swagRouter, nil, s.shadowAPI) + handler := proxyHandler.Handlers(&apifwCfg, s.serverUrl, s.shutdown, s.logger, s.proxy, s.swagRouter, nil) p, err := json.Marshal([]map[string]interface{}{{ "valueNum": 10.1, @@ -229,7 +226,7 @@ func (s *ServiceTests) testBasicArrJSONFieldValidation(t *testing.T) { func (s *ServiceTests) testNegativeJSONFieldValidation(t *testing.T) { - handler := handlers.OpenapiProxy(&apifwCfg, s.serverUrl, s.shutdown, s.logger, s.proxy, s.swagRouter, nil, s.shadowAPI) + handler := proxyHandler.Handlers(&apifwCfg, s.serverUrl, s.shutdown, s.logger, s.proxy, s.swagRouter, nil) req := fasthttp.AcquireRequest() req.SetRequestURI("/test") @@ -373,9 +370,6 @@ func (s *ServiceTests) testNegativeJSONFieldValidation(t *testing.T) { Request: *req, } - s.proxy.EXPECT().Get().Return(s.client, nil) - s.proxy.EXPECT().Put(s.client) - handler(&reqCtx) if reqCtx.Response.StatusCode() != 403 { diff --git a/cmd/api-firewall/tests/main_test.go b/cmd/api-firewall/tests/main_test.go index 85d9f13..2eaa257 100644 --- a/cmd/api-firewall/tests/main_test.go +++ b/cmd/api-firewall/tests/main_test.go @@ -22,12 +22,11 @@ import ( "github.com/google/uuid" "github.com/sirupsen/logrus" "github.com/valyala/fasthttp" - "github.com/wallarm/api-firewall/cmd/api-firewall/internal/handlers" + proxy2 "github.com/wallarm/api-firewall/cmd/api-firewall/internal/handlers/proxy" "github.com/wallarm/api-firewall/internal/config" "github.com/wallarm/api-firewall/internal/platform/denylist" "github.com/wallarm/api-firewall/internal/platform/proxy" "github.com/wallarm/api-firewall/internal/platform/router" - "github.com/wallarm/api-firewall/internal/platform/shadowAPI" ) const openAPISpecTest = ` @@ -38,6 +37,53 @@ info: servers: - url: / paths: + /cookie_params: + get: + tags: + - Cookie parameters + summary: The endpoint with cookie parameters only + parameters: + - name: cookie_mandatory + in: cookie + description: mandatory cookie parameter + required: true + schema: + type: string + - name: cookie_optional + in: cookie + description: optional cookie parameter + required: false + schema: + type: integer + enum: [0, 10, 100] + responses: + '200': + description: Set cookies. + content: {} + /cookie_params_min_max: + get: + tags: + - Cookie parameters + summary: The endpoint with cookie parameters only + parameters: + - name: cookie_mandatory + in: cookie + description: mandatory cookie parameter + required: true + schema: + type: string + - name: cookie_optional_min_max + in: cookie + description: optional cookie parameter + required: false + schema: + type: integer + minimum: 1000 + maximum: 2000 + responses: + '200': + description: Set cookies. + content: {} /users/{id}/{test}: parameters: - in: path @@ -69,6 +115,34 @@ paths: post: requestBody: content: + application/unsupported-type: + schema: + {} + application/x-www-form-urlencoded: + schema: + type: object + required: + - email + - firstname + - lastname + properties: + email: + type: string + format: email + pattern: '^[0-9a-zA-Z]+@[0-9a-zA-Z\.]+$' + example: example@mail.com + firstname: + type: string + example: test + lastname: + type: string + example: test + url: + type: string + example: test + job: + type: string + example: test application/json: schema: type: object @@ -80,6 +154,7 @@ paths: email: type: string format: email + pattern: '^[0-9a-zA-Z]+@[0-9a-zA-Z\.]+$' example: example@mail.com firstname: type: string @@ -87,6 +162,12 @@ paths: lastname: type: string example: test + url: + type: string + example: test + job: + type: string + example: test responses: '200': description: successful operation @@ -128,6 +209,13 @@ paths: '403': description: operation forbidden content: {} + /get/test: + get: + summary: Get Test Info + responses: + 200: + description: Ok + content: { } /user: get: summary: Get User Info @@ -221,7 +309,6 @@ type ServiceTests struct { proxy *proxy.MockPool client *proxy.MockHTTPClient swagRouter *router.Router - shadowAPI *shadowAPI.MockChecker } func compressFlate(data []byte) ([]byte, error) { @@ -293,7 +380,6 @@ func TestBasic(t *testing.T) { pool := proxy.NewMockPool(mockCtrl) client := proxy.NewMockHTTPClient(mockCtrl) - checker := shadowAPI.NewMockChecker(mockCtrl) swagger, err := openapi3.NewLoader().LoadFromData([]byte(openAPISpecTest)) if err != nil { @@ -315,7 +401,6 @@ func TestBasic(t *testing.T) { proxy: pool, client: client, swagRouter: swagRouter, - shadowAPI: checker, } // basic test @@ -345,6 +430,15 @@ func TestBasic(t *testing.T) { t.Run("reqBodyCompression", apifwTests.testRequestBodyCompression) t.Run("respBodyCompression", apifwTests.testResponseBodyCompression) + + t.Run("requestOptionalCookies", apifwTests.requestOptionalCookies) + t.Run("requestOptionalMinMaxCookies", apifwTests.requestOptionalMinMaxCookies) + + // unknown parameters in requests + t.Run("unknownParamQuery", apifwTests.unknownParamQuery) + t.Run("unknownParamPostBody", apifwTests.unknownParamPostBody) + t.Run("unknownParamJSONParam", apifwTests.unknownParamJSONParam) + t.Run("unknownParamInvalidMimeType", apifwTests.unknownParamUnsupportedMimeType) } func (s *ServiceTests) testBlockMode(t *testing.T) { @@ -359,7 +453,7 @@ func (s *ServiceTests) testBlockMode(t *testing.T) { }, } - handler := handlers.OpenapiProxy(&cfg, s.serverUrl, s.shutdown, s.logger, s.proxy, s.swagRouter, nil, s.shadowAPI) + handler := proxy2.Handlers(&cfg, s.serverUrl, s.shutdown, s.logger, s.proxy, s.swagRouter, nil) p, err := json.Marshal(map[string]interface{}{ "firstname": "test", @@ -414,9 +508,6 @@ func (s *ServiceTests) testBlockMode(t *testing.T) { Request: *req, } - s.proxy.EXPECT().Get().Return(s.client, nil) - s.proxy.EXPECT().Put(s.client) - handler(&reqCtx) if reqCtx.Response.StatusCode() != 403 { @@ -452,7 +543,7 @@ func (s *ServiceTests) testDenylist(t *testing.T) { t.Fatal(err) } - handler := handlers.OpenapiProxy(&cfg, s.serverUrl, s.shutdown, s.logger, s.proxy, s.swagRouter, deniedTokens, s.shadowAPI) + handler := proxy2.Handlers(&cfg, s.serverUrl, s.shutdown, s.logger, s.proxy, s.swagRouter, deniedTokens) p, err := json.Marshal(map[string]interface{}{ "firstname": "test", @@ -534,7 +625,7 @@ func (s *ServiceTests) testShadowAPI(t *testing.T) { t.Fatal(err) } - handler := handlers.OpenapiProxy(&cfg, s.serverUrl, s.shutdown, s.logger, s.proxy, s.swagRouter, deniedTokens, s.shadowAPI) + handler := proxy2.Handlers(&cfg, s.serverUrl, s.shutdown, s.logger, s.proxy, s.swagRouter, deniedTokens) p, err := json.Marshal(map[string]interface{}{ "firstname": "test", @@ -565,7 +656,6 @@ func (s *ServiceTests) testShadowAPI(t *testing.T) { s.proxy.EXPECT().Get().Return(s.client, nil) s.client.EXPECT().Do(gomock.Any(), gomock.Any()).SetArg(1, *resp) - s.shadowAPI.EXPECT().Check(gomock.Any()).Times(1) s.proxy.EXPECT().Put(s.client).Return(nil) handler(&reqCtx) @@ -588,7 +678,7 @@ func (s *ServiceTests) testLogOnlyMode(t *testing.T) { }, } - handler := handlers.OpenapiProxy(&cfg, s.serverUrl, s.shutdown, s.logger, s.proxy, s.swagRouter, nil, s.shadowAPI) + handler := proxy2.Handlers(&cfg, s.serverUrl, s.shutdown, s.logger, s.proxy, s.swagRouter, nil) p, err := json.Marshal(map[string]interface{}{ "firstname": "test", @@ -618,7 +708,6 @@ func (s *ServiceTests) testLogOnlyMode(t *testing.T) { } s.proxy.EXPECT().Get().Return(s.client, nil) - s.shadowAPI.EXPECT().Check(gomock.Any()).Times(0) s.client.EXPECT().Do(gomock.Any(), gomock.Any()).SetArg(1, *resp) s.proxy.EXPECT().Put(s.client).Return(nil) @@ -643,7 +732,7 @@ func (s *ServiceTests) testDisableMode(t *testing.T) { }, } - handler := handlers.OpenapiProxy(&cfg, s.serverUrl, s.shutdown, s.logger, s.proxy, s.swagRouter, nil, s.shadowAPI) + handler := proxy2.Handlers(&cfg, s.serverUrl, s.shutdown, s.logger, s.proxy, s.swagRouter, nil) p, err := json.Marshal(map[string]interface{}{ "email": "wallarm.com", @@ -694,7 +783,7 @@ func (s *ServiceTests) testBlockLogOnlyMode(t *testing.T) { }, } - handler := handlers.OpenapiProxy(&cfg, s.serverUrl, s.shutdown, s.logger, s.proxy, s.swagRouter, nil, s.shadowAPI) + handler := proxy2.Handlers(&cfg, s.serverUrl, s.shutdown, s.logger, s.proxy, s.swagRouter, nil) p, err := json.Marshal(map[string]interface{}{ "firstname": "test", @@ -746,7 +835,7 @@ func (s *ServiceTests) testLogOnlyBlockMode(t *testing.T) { }, } - handler := handlers.OpenapiProxy(&cfg, s.serverUrl, s.shutdown, s.logger, s.proxy, s.swagRouter, nil, s.shadowAPI) + handler := proxy2.Handlers(&cfg, s.serverUrl, s.shutdown, s.logger, s.proxy, s.swagRouter, nil) p, err := json.Marshal(map[string]interface{}{ "firstname": "test", @@ -799,7 +888,7 @@ func (s *ServiceTests) testCommonParameters(t *testing.T) { }, } - handler := handlers.OpenapiProxy(&cfg, s.serverUrl, s.shutdown, s.logger, s.proxy, s.swagRouter, nil, s.shadowAPI) + handler := proxy2.Handlers(&cfg, s.serverUrl, s.shutdown, s.logger, s.proxy, s.swagRouter, nil) req := fasthttp.AcquireRequest() req.SetRequestURI("/users/1/1") @@ -933,7 +1022,7 @@ func (s *ServiceTests) testOauthIntrospectionReadSuccess(t *testing.T) { Server: serverConf, } - handler := handlers.OpenapiProxy(&cfg, s.serverUrl, s.shutdown, s.logger, s.proxy, s.swagRouter, nil, s.shadowAPI) + handler := proxy2.Handlers(&cfg, s.serverUrl, s.shutdown, s.logger, s.proxy, s.swagRouter, nil) resp := fasthttp.AcquireResponse() resp.SetStatusCode(fasthttp.StatusOK) @@ -1018,7 +1107,7 @@ func (s *ServiceTests) testOauthIntrospectionReadUnsuccessful(t *testing.T) { Server: serverConf, } - handler := handlers.OpenapiProxy(&cfg, s.serverUrl, s.shutdown, s.logger, s.proxy, s.swagRouter, nil, s.shadowAPI) + handler := proxy2.Handlers(&cfg, s.serverUrl, s.shutdown, s.logger, s.proxy, s.swagRouter, nil) resp := fasthttp.AcquireResponse() resp.SetStatusCode(fasthttp.StatusOK) @@ -1027,9 +1116,6 @@ func (s *ServiceTests) testOauthIntrospectionReadUnsuccessful(t *testing.T) { Request: *req, } - s.proxy.EXPECT().Get().Return(s.client, nil) - s.proxy.EXPECT().Put(s.client).Return(nil) - handler(&reqCtx) if reqCtx.Response.StatusCode() != 403 { @@ -1085,7 +1171,7 @@ func (s *ServiceTests) testOauthIntrospectionInvalidResponse(t *testing.T) { Server: serverConf, } - handler := handlers.OpenapiProxy(&cfg, s.serverUrl, s.shutdown, s.logger, s.proxy, s.swagRouter, nil, s.shadowAPI) + handler := proxy2.Handlers(&cfg, s.serverUrl, s.shutdown, s.logger, s.proxy, s.swagRouter, nil) resp := fasthttp.AcquireResponse() resp.SetStatusCode(fasthttp.StatusOK) @@ -1094,10 +1180,6 @@ func (s *ServiceTests) testOauthIntrospectionInvalidResponse(t *testing.T) { Request: *req, } - s.proxy.EXPECT().Get().Return(s.client, nil) - //s.client.EXPECT().Do(gomock.Any(), gomock.Any()).SetArg(1, *resp) - s.proxy.EXPECT().Put(s.client).Return(nil) - handler(&reqCtx) if reqCtx.Response.StatusCode() != 403 { @@ -1153,7 +1235,7 @@ func (s *ServiceTests) testOauthIntrospectionReadWriteSuccess(t *testing.T) { Server: serverConf, } - handler := handlers.OpenapiProxy(&cfg, s.serverUrl, s.shutdown, s.logger, s.proxy, s.swagRouter, nil, s.shadowAPI) + handler := proxy2.Handlers(&cfg, s.serverUrl, s.shutdown, s.logger, s.proxy, s.swagRouter, nil) resp := fasthttp.AcquireResponse() resp.SetStatusCode(fasthttp.StatusOK) @@ -1222,7 +1304,7 @@ func (s *ServiceTests) testOauthIntrospectionContentTypeRequest(t *testing.T) { Server: serverConf, } - handler := handlers.OpenapiProxy(&cfg, s.serverUrl, s.shutdown, s.logger, s.proxy, s.swagRouter, nil, s.shadowAPI) + handler := proxy2.Handlers(&cfg, s.serverUrl, s.shutdown, s.logger, s.proxy, s.swagRouter, nil) resp := fasthttp.AcquireResponse() resp.SetStatusCode(fasthttp.StatusOK) @@ -1283,7 +1365,7 @@ func (s *ServiceTests) testOauthJWTRS256(t *testing.T) { Server: serverConf, } - handler := handlers.OpenapiProxy(&cfg, s.serverUrl, s.shutdown, s.logger, s.proxy, s.swagRouter, nil, s.shadowAPI) + handler := proxy2.Handlers(&cfg, s.serverUrl, s.shutdown, s.logger, s.proxy, s.swagRouter, nil) resp := fasthttp.AcquireResponse() resp.SetStatusCode(fasthttp.StatusOK) @@ -1310,9 +1392,6 @@ func (s *ServiceTests) testOauthJWTRS256(t *testing.T) { Request: *req, } - s.proxy.EXPECT().Get().Return(s.client, nil) - s.proxy.EXPECT().Put(s.client).Return(nil) - handler(&reqCtx) if reqCtx.Response.StatusCode() != 403 { @@ -1361,7 +1440,7 @@ func (s *ServiceTests) testOauthJWTHS256(t *testing.T) { Server: serverConf, } - handler := handlers.OpenapiProxy(&cfg, s.serverUrl, s.shutdown, s.logger, s.proxy, s.swagRouter, nil, s.shadowAPI) + handler := proxy2.Handlers(&cfg, s.serverUrl, s.shutdown, s.logger, s.proxy, s.swagRouter, nil) resp := fasthttp.AcquireResponse() resp.SetStatusCode(fasthttp.StatusOK) @@ -1388,9 +1467,6 @@ func (s *ServiceTests) testOauthJWTHS256(t *testing.T) { Request: *req, } - s.proxy.EXPECT().Get().Return(s.client, nil) - s.proxy.EXPECT().Put(s.client).Return(nil) - handler(&reqCtx) if reqCtx.Response.StatusCode() != 403 { @@ -1412,7 +1488,7 @@ func (s *ServiceTests) testRequestHeaders(t *testing.T) { }, } - handler := handlers.OpenapiProxy(&cfg, s.serverUrl, s.shutdown, s.logger, s.proxy, s.swagRouter, nil, s.shadowAPI) + handler := proxy2.Handlers(&cfg, s.serverUrl, s.shutdown, s.logger, s.proxy, s.swagRouter, nil) xReqTestValue := uuid.New() @@ -1448,9 +1524,6 @@ func (s *ServiceTests) testRequestHeaders(t *testing.T) { Request: *req, } - s.proxy.EXPECT().Get().Return(s.client, nil) - s.proxy.EXPECT().Put(s.client) - handler(&reqCtx) if reqCtx.Response.StatusCode() != 403 { @@ -1472,7 +1545,7 @@ func (s *ServiceTests) testResponseHeaders(t *testing.T) { }, } - handler := handlers.OpenapiProxy(&cfg, s.serverUrl, s.shutdown, s.logger, s.proxy, s.swagRouter, nil, s.shadowAPI) + handler := proxy2.Handlers(&cfg, s.serverUrl, s.shutdown, s.logger, s.proxy, s.swagRouter, nil) xRespTestValue := uuid.New() @@ -1533,7 +1606,7 @@ func (s *ServiceTests) testRequestBodyCompression(t *testing.T) { }, } - handler := handlers.OpenapiProxy(&cfg, s.serverUrl, s.shutdown, s.logger, s.proxy, s.swagRouter, nil, s.shadowAPI) + handler := proxy2.Handlers(&cfg, s.serverUrl, s.shutdown, s.logger, s.proxy, s.swagRouter, nil) req := fasthttp.AcquireRequest() req.SetRequestURI("/test/signup") @@ -1617,9 +1690,6 @@ func (s *ServiceTests) testRequestBodyCompression(t *testing.T) { Request: *req, } - s.proxy.EXPECT().Get().Return(s.client, nil) - s.proxy.EXPECT().Put(s.client) - handler(&reqCtx) if reqCtx.Response.StatusCode() != 403 { @@ -1643,7 +1713,7 @@ func (s *ServiceTests) testResponseBodyCompression(t *testing.T) { }, } - handler := handlers.OpenapiProxy(&cfg, s.serverUrl, s.shutdown, s.logger, s.proxy, s.swagRouter, nil, s.shadowAPI) + handler := proxy2.Handlers(&cfg, s.serverUrl, s.shutdown, s.logger, s.proxy, s.swagRouter, nil) p, err := json.Marshal(map[string]interface{}{ "firstname": "test", @@ -1719,3 +1789,444 @@ func (s *ServiceTests) testResponseBodyCompression(t *testing.T) { } } + +func (s *ServiceTests) requestOptionalCookies(t *testing.T) { + + var cfg = config.APIFWConfiguration{ + RequestValidation: "BLOCK", + ResponseValidation: "BLOCK", + CustomBlockStatusCode: 403, + AddValidationStatusHeader: false, + ShadowAPI: config.ShadowAPI{ + ExcludeList: []int{404, 401}, + }, + } + + handler := proxy2.Handlers(&cfg, s.serverUrl, s.shutdown, s.logger, s.proxy, s.swagRouter, nil) + + req := fasthttp.AcquireRequest() + req.SetRequestURI("/cookie_params") + req.Header.SetMethod("GET") + req.Header.SetCookie("cookie_mandatory", "test") + req.Header.SetCookie("cookie_optional", "10") + + resp := fasthttp.AcquireResponse() + resp.SetStatusCode(fasthttp.StatusOK) + resp.Header.SetContentType("application/json") + resp.SetBody([]byte("{\"status\":\"success\"}")) + + reqCtx := fasthttp.RequestCtx{ + Request: *req, + } + + s.proxy.EXPECT().Get().Return(s.client, nil) + s.client.EXPECT().Do(gomock.Any(), gomock.Any()).SetArg(1, *resp) + s.proxy.EXPECT().Put(s.client).Return(nil) + + handler(&reqCtx) + + if reqCtx.Response.StatusCode() != 200 { + t.Errorf("Incorrect response status code. Expected: 200 and got %d", + reqCtx.Response.StatusCode()) + } + + // Repeat request without an optional cookie + req.Header.DelCookie("cookie_optional") + + reqCtx = fasthttp.RequestCtx{ + Request: *req, + } + + s.proxy.EXPECT().Get().Return(s.client, nil) + s.client.EXPECT().Do(gomock.Any(), gomock.Any()).SetArg(1, *resp) + s.proxy.EXPECT().Put(s.client).Return(nil) + + handler(&reqCtx) + + if reqCtx.Response.StatusCode() != 200 { + t.Errorf("Incorrect response status code. Expected: 200 and got %d", + reqCtx.Response.StatusCode()) + } + + // Repeat request with an optional cookie but optional cookie has invalid value + req.Header.SetCookie("cookie_optional", "wrongValue") + + reqCtx = fasthttp.RequestCtx{ + Request: *req, + } + + handler(&reqCtx) + + if reqCtx.Response.StatusCode() != 403 { + t.Errorf("Incorrect response status code. Expected: 403 and got %d", + reqCtx.Response.StatusCode()) + } + + // Repeat request without an optional cookie + req.Header.DelCookie("cookie_mandatory") + + reqCtx = fasthttp.RequestCtx{ + Request: *req, + } + + handler(&reqCtx) + + if reqCtx.Response.StatusCode() != 403 { + t.Errorf("Incorrect response status code. Expected: 403 and got %d", + reqCtx.Response.StatusCode()) + } + +} + +func (s *ServiceTests) requestOptionalMinMaxCookies(t *testing.T) { + + var cfg = config.APIFWConfiguration{ + RequestValidation: "BLOCK", + ResponseValidation: "BLOCK", + CustomBlockStatusCode: 403, + AddValidationStatusHeader: false, + ShadowAPI: config.ShadowAPI{ + ExcludeList: []int{404, 401}, + }, + } + + handler := proxy2.Handlers(&cfg, s.serverUrl, s.shutdown, s.logger, s.proxy, s.swagRouter, nil) + + req := fasthttp.AcquireRequest() + req.SetRequestURI("/cookie_params_min_max") + req.Header.SetMethod("GET") + req.Header.SetCookie("cookie_mandatory", "test") + req.Header.SetCookie("cookie_optional_min_max", "1001") + + resp := fasthttp.AcquireResponse() + resp.SetStatusCode(fasthttp.StatusOK) + resp.Header.SetContentType("application/json") + resp.SetBody([]byte("{\"status\":\"success\"}")) + + reqCtx := fasthttp.RequestCtx{ + Request: *req, + } + + s.proxy.EXPECT().Get().Return(s.client, nil) + s.client.EXPECT().Do(gomock.Any(), gomock.Any()).SetArg(1, *resp) + s.proxy.EXPECT().Put(s.client).Return(nil) + + handler(&reqCtx) + + if reqCtx.Response.StatusCode() != 200 { + t.Errorf("Incorrect response status code. Expected: 200 and got %d", + reqCtx.Response.StatusCode()) + } + + // Repeat request without an optional cookie + req.Header.DelCookie("cookie_optional_min_max") + + reqCtx = fasthttp.RequestCtx{ + Request: *req, + } + + s.proxy.EXPECT().Get().Return(s.client, nil) + s.client.EXPECT().Do(gomock.Any(), gomock.Any()).SetArg(1, *resp) + s.proxy.EXPECT().Put(s.client).Return(nil) + + handler(&reqCtx) + + if reqCtx.Response.StatusCode() != 200 { + t.Errorf("Incorrect response status code. Expected: 200 and got %d", + reqCtx.Response.StatusCode()) + } + + // Repeat request with an optional cookie but optional cookie has invalid value + req.Header.SetCookie("cookie_optional_min_max", "999") + + reqCtx = fasthttp.RequestCtx{ + Request: *req, + } + + handler(&reqCtx) + + if reqCtx.Response.StatusCode() != 403 { + t.Errorf("Incorrect response status code. Expected: 403 and got %d", + reqCtx.Response.StatusCode()) + } + + // Repeat request with an optional cookie but optional cookie has invalid value + req.Header.SetCookie("cookie_optional_min_max", "wrongValue") + + reqCtx = fasthttp.RequestCtx{ + Request: *req, + } + + handler(&reqCtx) + + if reqCtx.Response.StatusCode() != 403 { + t.Errorf("Incorrect response status code. Expected: 403 and got %d", + reqCtx.Response.StatusCode()) + } + + // Repeat request without an optional cookie + req.Header.DelCookie("cookie_mandatory") + + reqCtx = fasthttp.RequestCtx{ + Request: *req, + } + + handler(&reqCtx) + + if reqCtx.Response.StatusCode() != 403 { + t.Errorf("Incorrect response status code. Expected: 403 and got %d", + reqCtx.Response.StatusCode()) + } + +} + +func (s *ServiceTests) unknownParamQuery(t *testing.T) { + + var cfg = config.APIFWConfiguration{ + RequestValidation: "BLOCK", + ResponseValidation: "BLOCK", + CustomBlockStatusCode: 403, + AddValidationStatusHeader: false, + ShadowAPI: config.ShadowAPI{ + ExcludeList: []int{404, 401}, + UnknownParametersDetection: true, + }, + } + + handler := proxy2.Handlers(&cfg, s.serverUrl, s.shutdown, s.logger, s.proxy, s.swagRouter, nil) + + req := fasthttp.AcquireRequest() + req.SetRequestURI("/get/test") + req.Header.SetMethod("GET") + + resp := fasthttp.AcquireResponse() + resp.SetStatusCode(fasthttp.StatusOK) + + reqCtx := fasthttp.RequestCtx{ + Request: *req, + } + + s.proxy.EXPECT().Get().Return(s.client, nil) + s.client.EXPECT().Do(gomock.Any(), gomock.Any()).SetArg(1, *resp) + s.proxy.EXPECT().Put(s.client).Return(nil) + + handler(&reqCtx) + + if reqCtx.Response.StatusCode() != 200 { + t.Errorf("Incorrect response status code. Expected: 200 and got %d", + reqCtx.Response.StatusCode()) + } + + req = fasthttp.AcquireRequest() + req.SetRequestURI("/get/test?test=123") + req.Header.SetMethod("GET") + + reqCtx = fasthttp.RequestCtx{ + Request: *req, + } + + handler(&reqCtx) + + if reqCtx.Response.StatusCode() != 403 { + t.Errorf("Incorrect response status code. Expected: 403 and got %d", + reqCtx.Response.StatusCode()) + } + +} + +func (s *ServiceTests) unknownParamPostBody(t *testing.T) { + + var cfg = config.APIFWConfiguration{ + RequestValidation: "BLOCK", + ResponseValidation: "BLOCK", + CustomBlockStatusCode: 403, + AddValidationStatusHeader: false, + ShadowAPI: config.ShadowAPI{ + ExcludeList: []int{404, 401}, + UnknownParametersDetection: true, + }, + } + + handler := proxy2.Handlers(&cfg, s.serverUrl, s.shutdown, s.logger, s.proxy, s.swagRouter, nil) + + req := fasthttp.AcquireRequest() + req.SetRequestURI("/test/signup") + req.Header.SetMethod("POST") + req.SetBodyString("firstname=test&lastname=testjob=test&email=test@wallarm.com&url=http://wallarm.com") + req.Header.SetContentType("application/x-www-form-urlencoded") + + resp := fasthttp.AcquireResponse() + resp.SetStatusCode(fasthttp.StatusOK) + resp.Header.SetContentType("application/json") + resp.SetBody([]byte("{\"status\":\"success\"}")) + + s.proxy.EXPECT().Get().Return(s.client, nil) + s.client.EXPECT().Do(gomock.Any(), gomock.Any()).SetArg(1, *resp) + s.proxy.EXPECT().Put(s.client).Return(nil) + + reqCtx := fasthttp.RequestCtx{ + Request: *req, + } + + handler(&reqCtx) + + if reqCtx.Response.StatusCode() != 200 { + t.Errorf("Incorrect response status code. Expected: 200 and got %d", + reqCtx.Response.StatusCode()) + } + + req.SetBodyString("firstname=test&lastname=testjob=test&email=test@wallarm.com&url=http://wallarm.com&test=hello") + + reqCtx = fasthttp.RequestCtx{ + Request: *req, + } + + handler(&reqCtx) + + if reqCtx.Response.StatusCode() != 403 { + t.Errorf("Incorrect response status code. Expected: 403 and got %d", + reqCtx.Response.StatusCode()) + } + +} + +func (s *ServiceTests) unknownParamJSONParam(t *testing.T) { + + var cfg = config.APIFWConfiguration{ + RequestValidation: "BLOCK", + ResponseValidation: "BLOCK", + CustomBlockStatusCode: 403, + AddValidationStatusHeader: false, + ShadowAPI: config.ShadowAPI{ + ExcludeList: []int{404, 401}, + UnknownParametersDetection: true, + }, + } + + handler := proxy2.Handlers(&cfg, s.serverUrl, s.shutdown, s.logger, s.proxy, s.swagRouter, nil) + + p, err := json.Marshal(map[string]interface{}{ + "firstname": "test", + "lastname": "test", + "job": "test", + "email": "test@wallarm.com", + "url": "http://wallarm.com", + }) + + if err != nil { + t.Fatal(err) + } + + req := fasthttp.AcquireRequest() + req.SetRequestURI("/test/signup") + req.Header.SetMethod("POST") + req.SetBodyStream(bytes.NewReader(p), -1) + req.Header.SetContentType("application/json") + + resp := fasthttp.AcquireResponse() + resp.SetStatusCode(fasthttp.StatusOK) + resp.Header.SetContentType("application/json") + resp.SetBody([]byte("{\"status\":\"success\"}")) + + s.proxy.EXPECT().Get().Return(s.client, nil) + s.client.EXPECT().Do(gomock.Any(), gomock.Any()).SetArg(1, *resp) + s.proxy.EXPECT().Put(s.client).Return(nil) + + reqCtx := fasthttp.RequestCtx{ + Request: *req, + } + + handler(&reqCtx) + + if reqCtx.Response.StatusCode() != 200 { + t.Errorf("Incorrect response status code. Expected: 200 and got %d", + reqCtx.Response.StatusCode()) + } + + p, err = json.Marshal(map[string]interface{}{ + "firstname": "test", + "lastname": "test", + "job": "test", + "email": "test@wallarm.com", + "url": "http://wallarm.com", + "test": "hello", + }) + + if err != nil { + t.Fatal(err) + } + + req.SetBodyStream(bytes.NewReader(p), -1) + + reqCtx = fasthttp.RequestCtx{ + Request: *req, + } + + handler(&reqCtx) + + if reqCtx.Response.StatusCode() != 403 { + t.Errorf("Incorrect response status code. Expected: 403 and got %d", + reqCtx.Response.StatusCode()) + } + +} + +func (s *ServiceTests) unknownParamUnsupportedMimeType(t *testing.T) { + + var cfg = config.APIFWConfiguration{ + RequestValidation: "BLOCK", + ResponseValidation: "BLOCK", + CustomBlockStatusCode: 403, + AddValidationStatusHeader: false, + ShadowAPI: config.ShadowAPI{ + ExcludeList: []int{404, 401}, + UnknownParametersDetection: true, + }, + } + + handler := proxy2.Handlers(&cfg, s.serverUrl, s.shutdown, s.logger, s.proxy, s.swagRouter, nil) + + req := fasthttp.AcquireRequest() + req.SetRequestURI("/test/signup") + req.Header.SetMethod("POST") + req.Header.SetContentType("application/x-www-form-urlencoded") + req.SetBodyString("firstname=test&lastname=testjob=test&email=test@wallarm.com&url=http://wallarm.com") + + resp := fasthttp.AcquireResponse() + resp.SetStatusCode(fasthttp.StatusOK) + resp.Header.SetContentType("application/json") + resp.SetBody([]byte("{\"status\":\"success\"}")) + + s.proxy.EXPECT().Get().Return(s.client, nil) + s.client.EXPECT().Do(gomock.Any(), gomock.Any()).SetArg(1, *resp) + s.proxy.EXPECT().Put(s.client).Return(nil) + + reqCtx := fasthttp.RequestCtx{ + Request: *req, + } + + handler(&reqCtx) + + if reqCtx.Response.StatusCode() != 200 { + t.Errorf("Incorrect response status code. Expected: 200 and got %d", + reqCtx.Response.StatusCode()) + } + + req.Header.SetContentType("application/unsupported-type") + + reqCtx = fasthttp.RequestCtx{ + Request: *req, + } + + s.proxy.EXPECT().Get().Return(s.client, nil) + s.client.EXPECT().Do(gomock.Any(), gomock.Any()).SetArg(1, *resp) + s.proxy.EXPECT().Put(s.client).Return(nil) + + handler(&reqCtx) + + if reqCtx.Response.StatusCode() != 200 { + t.Errorf("Incorrect response status code. Expected: 200 and got %d", + reqCtx.Response.StatusCode()) + } + +} diff --git a/demo/docker-compose/docker-compose-api-mode.yml b/demo/docker-compose/docker-compose-api-mode.yml new file mode 100644 index 0000000..0dbf49a --- /dev/null +++ b/demo/docker-compose/docker-compose-api-mode.yml @@ -0,0 +1,22 @@ +version: '3.8' +services: + api-firewall: + container_name: api-firewall + image: wallarm/api-firewall:v0.6.12 + restart: on-failure + environment: + APIFW_MODE: "api" + APIFW_SPECIFICATION_UPDATE_PERIOD: "1m" + APIFW_API_MODE_UNKNOWN_PARAMETERS_DETECTION: "true" + APIFW_PASS_OPTIONS: "false" + APIFW_URL: "http://0.0.0.0:8080" + APIFW_HEALTH_HOST: "0.0.0.0:9667" + APIFW_READ_TIMEOUT: "5s" + APIFW_WRITE_TIMEOUT: "5s" + APIFW_LOG_LEVEL: "info" + volumes: + - ./volumes/wallarm_api.db:/var/lib/wallarm-api/1/wallarm_api.db:ro + ports: + - "8080:8080" + - "9667:9667" + stop_grace_period: 1s \ No newline at end of file diff --git a/demo/docker-compose/docker-compose.yml b/demo/docker-compose/docker-compose.yml index a5d81fc..02bb32e 100644 --- a/demo/docker-compose/docker-compose.yml +++ b/demo/docker-compose/docker-compose.yml @@ -2,7 +2,7 @@ version: '3.8' services: api-firewall: container_name: api-firewall - image: wallarm/api-firewall:v0.6.11 + image: wallarm/api-firewall:v0.6.12 restart: on-failure environment: APIFW_URL: "http://0.0.0.0:8080" diff --git a/demo/docker-compose/volumes/wallarm_api.db b/demo/docker-compose/volumes/wallarm_api.db new file mode 100644 index 0000000..35e2bb9 Binary files /dev/null and b/demo/docker-compose/volumes/wallarm_api.db differ diff --git a/go.mod b/go.mod index 23435fb..5fae9a8 100644 --- a/go.mod +++ b/go.mod @@ -1,49 +1,55 @@ module github.com/wallarm/api-firewall -go 1.19 +go 1.20 require ( + github.com/andybalholm/brotli v1.0.5 github.com/ardanlabs/conf v1.5.0 + github.com/clbanning/mxj/v2 v2.7.0 github.com/dgraph-io/ristretto v0.1.1 - github.com/fasthttp/router v1.4.15 - github.com/getkin/kin-openapi v0.112.0 + github.com/fasthttp/router v1.4.20 + github.com/gabriel-vasile/mimetype v1.4.2 + github.com/getkin/kin-openapi v0.118.0 github.com/go-playground/validator v9.31.0+incompatible github.com/golang-jwt/jwt v3.2.2+incompatible github.com/golang/mock v1.6.0 github.com/google/uuid v1.3.0 github.com/karlseguin/ccache/v2 v2.0.8 + github.com/klauspost/compress v1.16.7 + github.com/mattn/go-sqlite3 v1.14.17 github.com/pkg/errors v0.9.1 - github.com/savsgio/gotils v0.0.0-20220530130905-52f3993e8d6d - github.com/sirupsen/logrus v1.9.0 - github.com/stretchr/testify v1.8.1 - github.com/valyala/fasthttp v1.44.0 + github.com/savsgio/gotils v0.0.0-20230208104028-c358bd845dee + github.com/sirupsen/logrus v1.9.3 + github.com/stretchr/testify v1.8.4 + github.com/valyala/fasthttp v1.48.0 github.com/valyala/fastjson v1.6.4 - golang.org/x/exp v0.0.0-20230127193734-31bee513bff7 + golang.org/x/exp v0.0.0-20230807204917-050eac23e9de ) require ( - github.com/andybalholm/brotli v1.0.4 // indirect github.com/cespare/xxhash/v2 v2.2.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/dustin/go-humanize v1.0.1 // indirect - github.com/go-openapi/jsonpointer v0.19.6 // indirect - github.com/go-openapi/swag v0.22.3 // indirect + github.com/go-openapi/jsonpointer v0.20.0 // indirect + github.com/go-openapi/swag v0.22.4 // indirect github.com/go-playground/locales v0.14.1 // indirect - github.com/golang/glog v1.0.0 // indirect + github.com/golang/glog v1.1.1 // indirect github.com/gorilla/mux v1.8.0 // indirect github.com/invopop/yaml v0.2.0 // indirect github.com/josharian/intern v1.0.0 // indirect - github.com/klauspost/compress v1.15.15 // indirect github.com/mailru/easyjson v0.7.7 // indirect github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 // indirect + github.com/perimeterx/marshmallow v1.1.5 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/rogpeppe/go-internal v1.11.0 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect - golang.org/x/sys v0.4.0 // indirect + golang.org/x/net v0.14.0 // indirect + golang.org/x/sys v0.11.0 // indirect ) require ( - github.com/go-playground/universal-translator v0.18.0 // indirect - github.com/leodido/go-urn v1.2.1 // indirect + github.com/go-playground/universal-translator v0.18.1 // indirect + github.com/leodido/go-urn v1.2.4 // indirect gopkg.in/go-playground/assert.v1 v1.2.1 // indirect gopkg.in/yaml.v3 v3.0.1 ) diff --git a/go.sum b/go.sum index a2dd52a..80d028e 100644 --- a/go.sum +++ b/go.sum @@ -1,11 +1,12 @@ -github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY3JY= -github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= +github.com/andybalholm/brotli v1.0.5 h1:8uQZIdzKmjc/iuPu7O2ioW48L81FgatrcpfFmiq/cCs= +github.com/andybalholm/brotli v1.0.5/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= github.com/ardanlabs/conf v1.5.0 h1:5TwP6Wu9Xi07eLFEpiCUF3oQXh9UzHMDVnD3u/I5d5c= github.com/ardanlabs/conf v1.5.0/go.mod h1:ILsMo9dMqYzCxDjDXTiwMI0IgxOJd0MOiucbQY2wlJw= github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= -github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/clbanning/mxj/v2 v2.7.0 h1:WA/La7UGCanFe5NpHF0Q3DNtnCsVoxbPKuyBNHWRyME= +github.com/clbanning/mxj/v2 v2.7.0/go.mod h1:hNiWqW14h+kc+MdF9C6/YoRfjEJoR3ou6tn/Qo+ve2s= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -16,28 +17,31 @@ github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2/go.mod h1:SqUrOPUn github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= -github.com/fasthttp/router v1.4.15 h1:ERaILezYX6ks1I+Z2v5qY4vqiQKnujauo9nV6M+HIOg= -github.com/fasthttp/router v1.4.15/go.mod h1:NFNlTCilbRVkeLc+E5JDkcxUdkpiJGKDL8Zy7Ey2JTI= -github.com/getkin/kin-openapi v0.112.0 h1:lnLXx3bAG53EJVI4E/w0N8i1Y/vUZUEsnrXkgnfn7/Y= -github.com/getkin/kin-openapi v0.112.0/go.mod h1:QtwUNt0PAAgIIBEvFWYfB7dfngxtAaqCX1zYHMZDeK8= +github.com/fasthttp/router v1.4.20 h1:yPeNxz5WxZGojzolKqiP15DTXnxZce9Drv577GBrDgU= +github.com/fasthttp/router v1.4.20/go.mod h1:um867yNQKtERxBm+C+yzgWxjspTiQoA8z86Ec3fK/tc= +github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU= +github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA= +github.com/getkin/kin-openapi v0.118.0 h1:z43njxPmJ7TaPpMSCQb7PN0dEYno4tyBPQcrFdHoLuM= +github.com/getkin/kin-openapi v0.118.0/go.mod h1:l5e9PaFUo9fyLJCPGQeXI2ML8c3P8BHOEV2VaAVf/pc= github.com/go-openapi/jsonpointer v0.19.5/go.mod h1:Pl9vOtqEWErmShwVjC8pYs9cog34VGT37dQOVbmoatg= -github.com/go-openapi/jsonpointer v0.19.6 h1:eCs3fxoIi3Wh6vtgmLTOjdhSpiqphQ+DaPn38N2ZdrE= -github.com/go-openapi/jsonpointer v0.19.6/go.mod h1:osyAmYz/mB/C3I+WsTTSgw1ONzaLJoLCyoi6/zppojs= +github.com/go-openapi/jsonpointer v0.20.0 h1:ESKJdU9ASRfaPNOPRx12IUyA1vn3R9GiE3KYD14BXdQ= +github.com/go-openapi/jsonpointer v0.20.0/go.mod h1:6PGzBjjIIumbLYysB73Klnms1mwnU4G3YHOECG3CedA= github.com/go-openapi/swag v0.19.5/go.mod h1:POnQmlKehdgb5mhVOsnJFsivZCEZ/vjK9gh66Z9tfKk= -github.com/go-openapi/swag v0.22.3 h1:yMBqmnQ0gyZvEb/+KzuWZOXgllrXT4SADYbvDaXHv/g= -github.com/go-openapi/swag v0.22.3/go.mod h1:UzaqsxGiab7freDnrUUra0MwWfN/q7tE4j+VcZ0yl14= -github.com/go-playground/locales v0.14.0/go.mod h1:sawfccIbzZTqEDETgFXqTho0QybSa7l++s0DH+LDiLs= +github.com/go-openapi/swag v0.22.4 h1:QLMzNJnMGPRNDCbySlcj1x01tzU8/9LTTL9hZZZogBU= +github.com/go-openapi/swag v0.22.4/go.mod h1:UzaqsxGiab7freDnrUUra0MwWfN/q7tE4j+VcZ0yl14= github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= -github.com/go-playground/universal-translator v0.18.0 h1:82dyy6p4OuJq4/CByFNOn/jYrnRPArHwAcmLoJZxyho= -github.com/go-playground/universal-translator v0.18.0/go.mod h1:UvRDBj+xPUEGrFYl+lu/H90nyDXpg0fqeB/AQUGNTVA= +github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= +github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= github.com/go-playground/validator v9.31.0+incompatible h1:UA72EPEogEnq76ehGdEDp4Mit+3FDh548oRqwVgNsHA= github.com/go-playground/validator v9.31.0+incompatible/go.mod h1:yrEkQXlcI+PugkyDjY2bRrL/UBU4f3rvrgkN3V8JEig= +github.com/go-test/deep v1.0.8 h1:TDsG77qcSprGbC6vTN8OuXp5g+J+b5Pcguhf7Zt61VM= +github.com/go-test/deep v1.0.8/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE= github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= -github.com/golang/glog v1.0.0 h1:nfP3RFugxnNRyKgeWd4oI1nYvXpxrx8ck8ZrcizshdQ= -github.com/golang/glog v1.0.0/go.mod h1:EWib/APOK0SL3dFbYqvxE3UYd8E6s1ouQ7iEp/0LWV4= +github.com/golang/glog v1.1.1 h1:jxpi2eWoU84wbX9iIEyAeeoac3FLuifZpY9tcNUD9kw= +github.com/golang/glog v1.1.1/go.mod h1:zR+okUeTbrL6EL3xHUDxZuEtGv04p5shwip1+mL/rLQ= github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= @@ -55,84 +59,85 @@ github.com/karlseguin/ccache/v2 v2.0.8 h1:lT38cE//uyf6KcFok0rlgXtGFBWxkI6h/qg4tb github.com/karlseguin/ccache/v2 v2.0.8/go.mod h1:2BDThcfQMf/c0jnZowt16eW405XIqZPavt+HoYEtcxQ= github.com/karlseguin/expect v1.0.2-0.20190806010014-778a5f0c6003 h1:vJ0Snvo+SLMY72r5J4sEfkuE7AFbixEP2qRbEcum/wA= github.com/karlseguin/expect v1.0.2-0.20190806010014-778a5f0c6003/go.mod h1:zNBxMY8P21owkeogJELCLeHIt+voOSduHYTFUbwRAV8= -github.com/klauspost/compress v1.15.9/go.mod h1:PhcZ0MbTNciWF3rruxRgKxI5NkcHHrHUDtV4Yw2GlzU= -github.com/klauspost/compress v1.15.15 h1:EF27CXIuDsYJ6mmvtBRlEuB2UVOqHG1tAXgZ7yIO+lw= -github.com/klauspost/compress v1.15.15/go.mod h1:ZcK2JAFqKOpnBlxcLsJzYfrS9X1akm9fHZNnD9+Vo/4= +github.com/klauspost/compress v1.16.7 h1:2mk3MPGNzKyxErAw8YaohYh69+pa4sIQSC0fPGCFR9I= +github.com/klauspost/compress v1.16.7/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= -github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI= -github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= -github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= -github.com/leodido/go-urn v1.2.1 h1:BqpAaACuzVSgi/VLzGZIobT2z4v53pjosyNd9Yv6n/w= -github.com/leodido/go-urn v1.2.1/go.mod h1:zt4jvISO2HfUBqxjfIshjdMTYS56ZS/qv49ictyFfxY= +github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q= +github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4= github.com/mailru/easyjson v0.0.0-20190614124828-94de47d64c63/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= github.com/mailru/easyjson v0.0.0-20190626092158-b2ccc519800e/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= +github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM= +github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 h1:RWengNIwukTxcDr9M+97sNutRR1RKhG96O6jWumTTnw= github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826/go.mod h1:TaXosZuwdSHYgviHp1DAtfrULt5eUgsSMsZf+YrPgl8= +github.com/perimeterx/marshmallow v1.1.4/go.mod h1:dsXbUu8CRzfYP5a87xpp0xq9S3u0Vchtcl8we9tYaXw= +github.com/perimeterx/marshmallow v1.1.5 h1:a2LALqQ1BlHM8PZblsDdidgv1mWi1DgC2UmX50IvK2s= +github.com/perimeterx/marshmallow v1.1.5/go.mod h1:dsXbUu8CRzfYP5a87xpp0xq9S3u0Vchtcl8we9tYaXw= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/savsgio/gotils v0.0.0-20220530130905-52f3993e8d6d h1:Q+gqLBOPkFGHyCJxXMRqtUgUbTjI8/Ze8vu8GGyNFwo= -github.com/savsgio/gotils v0.0.0-20220530130905-52f3993e8d6d/go.mod h1:Gy+0tqhJvgGlqnTF8CVGP0AaGRjwBtXs/a5PA0Y3+A4= -github.com/sirupsen/logrus v1.9.0 h1:trlNQbNUG3OdDrDil03MCb1H2o9nJ1x4/5LYw7byDE0= -github.com/sirupsen/logrus v1.9.0/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= +github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= +github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA= +github.com/savsgio/gotils v0.0.0-20230208104028-c358bd845dee h1:8Iv5m6xEo1NR1AvpV+7XmhI4r39LGNzwUL4YpMuL5vk= +github.com/savsgio/gotils v0.0.0-20230208104028-c358bd845dee/go.mod h1:qwtSXrKuJh/zsFQ12yEE89xfCrGKK63Rr7ctU/uCo4g= +github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= +github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= -github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= -github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/ugorji/go v1.2.7 h1:qYhyWUUd6WbiM+C6JZAUkIJt/1WrjzNHY9+KCIjVqTo= +github.com/ugorji/go v1.2.7/go.mod h1:nF9osbDWLy6bDVv/Rtoh6QgnvNDpmCalQV5urGCCS6M= +github.com/ugorji/go/codec v1.2.7 h1:YPXUKf7fYbp/y8xloBqZOw2qaVggbfwMlI8WM3wZUJ0= +github.com/ugorji/go/codec v1.2.7/go.mod h1:WGN1fab3R1fzQlVQTkfxVtIBhWDRqOviHU95kRgeqEY= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= -github.com/valyala/fasthttp v1.44.0 h1:R+gLUhldIsfg1HokMuQjdQ5bh9nuXHPIfvkYUu9eR5Q= -github.com/valyala/fasthttp v1.44.0/go.mod h1:f6VbjjoI3z1NDOZOv17o6RvtRSWxC77seBFc2uWtgiY= +github.com/valyala/fasthttp v1.48.0 h1:oJWvHb9BIZToTQS3MuQ2R3bJZiNSa2KiNdeI8A+79Tc= +github.com/valyala/fasthttp v1.48.0/go.mod h1:k2zXd82h/7UZc3VOdJ2WaUqt1uZ/XpXAfE9i+HBC3lA= github.com/valyala/fastjson v1.6.4 h1:uAUNq9Z6ymTgGhcm0UynUAB6tlbakBrz6CQFax3BXVQ= github.com/valyala/fastjson v1.6.4/go.mod h1:CLCAqky6SMuOcxStkYQvblddUtoRxhYMGLrsQns1aXY= -github.com/valyala/tcplisten v1.0.0/go.mod h1:T0xQ8SeCZGxckz9qRXTfG43PvQ/mcWh7FwZEA7Ioqkc= github.com/wsxiaoys/terminal v0.0.0-20160513160801-0940f3fc43a0 h1:3UeQBvD0TFrlVjOeLOBz+CPAI8dnbqNSVwUwRrkp7vQ= github.com/wsxiaoys/terminal v0.0.0-20160513160801-0940f3fc43a0/go.mod h1:IXCdmsXIht47RaVFLEdVnh1t+pgYtTAhQGj73kz+2DM= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.0.0-20220214200702-86341886e292/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= -golang.org/x/exp v0.0.0-20230127193734-31bee513bff7 h1:pXR8mGh4q8ooBT7HXruL4Xa2IxoL8XZ6lOgXY/0Ryg8= -golang.org/x/exp v0.0.0-20230127193734-31bee513bff7/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc= +golang.org/x/exp v0.0.0-20230807204917-050eac23e9de h1:l5Za6utMv/HsBWWqzt4S8X17j+kt1uVETUX5UFhn2rE= +golang.org/x/exp v0.0.0-20230807204917-050eac23e9de/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= -golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= -golang.org/x/net v0.0.0-20220906165146-f3363e06e74c/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk= +golang.org/x/net v0.14.0 h1:BONx9s002vGdD9umnlX1Po8vOZmrgH34qlHcD1MfK14= +golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20221010170243-090e33056c14/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.4.0 h1:Zr2JFtRQNX3BCZ8YtxRE9hNJYC8J6I1MVbMg6owUp18= -golang.org/x/sys v0.4.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.11.0 h1:eG7RXZHdqOJ1i+0lgLgCpSXAp6M3LYlAo6osgSi0xOM= +golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= @@ -142,7 +147,6 @@ golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8T gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= -gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/go-playground/assert.v1 v1.2.1 h1:xoYuJVE7KT85PYWrN730RguIQO0ePzVRfFMXadIrXTM= gopkg.in/go-playground/assert.v1 v1.2.1/go.mod h1:9RXL0bg/zibRAgZUYszZSwO/z8Y/a8bDuhia5mkpMnE= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/helm/api-firewall/Chart.yaml b/helm/api-firewall/Chart.yaml index de28fa9..54d8a94 100644 --- a/helm/api-firewall/Chart.yaml +++ b/helm/api-firewall/Chart.yaml @@ -1,7 +1,7 @@ apiVersion: v1 name: api-firewall -version: 0.6.11 -appVersion: 0.6.11 +version: 0.6.12 +appVersion: 0.6.12 description: Wallarm OpenAPI-based API Firewall home: https://github.com/wallarm/api-firewall icon: https://static.wallarm.com/wallarm-logo.svg diff --git a/helm/api-firewall/templates/deployment.yaml b/helm/api-firewall/templates/deployment.yaml index 3e4fc3b..6042ed6 100644 --- a/helm/api-firewall/templates/deployment.yaml +++ b/helm/api-firewall/templates/deployment.yaml @@ -84,6 +84,8 @@ spec: securityContext: {{ toYaml .Values.apiFirewall.securityContext | trimSuffix "\n" | nindent 10 }} {{ end -}} env: + - name: APIFW_MODE + value: {{ .Values.apiFirewall.config.mode | quote }} - name: APIFW_URL value: http://{{ .Values.apiFirewall.config.listenAddress }}:{{ .Values.apiFirewall.config.listenPort }} {{ if .Values.manifest.enabled -}} @@ -104,6 +106,12 @@ spec: value: {{ .Values.apiFirewall.config.validationMode.request | upper }} - name: APIFW_RESPONSE_VALIDATION value: {{ .Values.apiFirewall.config.validationMode.response | upper }} + - name: APIFW_PASS_OPTIONS + value: {{ .Values.apiFirewall.config.passOptions | quote }} + - name: APIFW_SHADOW_API_EXCLUDE_LIST + value: {{ .Values.apiFirewall.config.shadowAPI.excludeList | quote }} + - name: APIFW_SHADOW_API_UNKNOWN_PARAMETERS_DETECTION + value: {{ .Values.apiFirewall.config.shadowAPI.unknownParametersDetection | quote }} {{- if .Values.apiFirewall.extraEnvs -}} {{ toYaml .Values.apiFirewall.extraEnvs | trimSuffix "\n" | nindent 10 }} {{- end }} diff --git a/helm/api-firewall/values.yaml b/helm/api-firewall/values.yaml index 82c4fac..00d773e 100644 --- a/helm/api-firewall/values.yaml +++ b/helm/api-firewall/values.yaml @@ -70,6 +70,7 @@ apiFirewall: ## Main settings of API Firewall config: + mode: proxy listenAddress: 0.0.0.0 listenPort: 8080 maxConnsPerHost: 512 @@ -80,6 +81,10 @@ apiFirewall: validationMode: request: block response: block + shadowAPI: + excludeList: "404" + unknownParametersDetection: true + passOptions: false ## Number of deployment replicas for the API Firewall container ## https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.21/#deploymentspec-v1-apps diff --git a/internal/config/config.go b/internal/config/config.go index 704de6b..8029856 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -59,25 +59,51 @@ type Oauth struct { } type ShadowAPI struct { - ExcludeList []int `conf:"default:404,env:SHADOW_API_EXCLUDE_LIST" validate:"HttpStatusCodes"` + ExcludeList []int `conf:"default:404,env:SHADOW_API_EXCLUDE_LIST" validate:"HttpStatusCodes"` + UnknownParametersDetection bool `conf:"default:true,env:SHADOW_API_UNKNOWN_PARAMETERS_DETECTION"` } type APIFWConfiguration struct { conf.Version - TLS TLS - Server Server - - APIHost string `conf:"default:http://0.0.0.0:8282,env:URL" validate:"required,url"` - HealthAPIHost string `conf:"default:0.0.0.0:9667,env:HEALTH_HOST" validate:"required"` - ReadTimeout time.Duration `conf:"default:5s"` - WriteTimeout time.Duration `conf:"default:5s"` - LogLevel string `conf:"default:DEBUG" validate:"required,oneof=DEBUG INFO ERROR WARNING"` - LogFormat string `conf:"default:TEXT" validate:"required,oneof=TEXT JSON"` - RequestValidation string `conf:"required" validate:"required,oneof=DISABLE BLOCK LOG_ONLY"` - ResponseValidation string `conf:"required" validate:"required,oneof=DISABLE BLOCK LOG_ONLY"` - CustomBlockStatusCode int `conf:"default:403" validate:"HttpStatusCodes"` - AddValidationStatusHeader bool `conf:"default:false"` - APISpecs string `conf:"default:swagger.json,env:API_SPECS"` - ShadowAPI ShadowAPI - Denylist Denylist + APIFWMode + TLS TLS + ShadowAPI ShadowAPI + Denylist Denylist + Server Server + + APIHost string `conf:"default:http://0.0.0.0:8282,env:URL" validate:"required,url"` + HealthAPIHost string `conf:"default:0.0.0.0:9667,env:HEALTH_HOST" validate:"required"` + ReadTimeout time.Duration `conf:"default:5s"` + WriteTimeout time.Duration `conf:"default:5s"` + LogLevel string `conf:"default:INFO" validate:"oneof=TRACE DEBUG INFO ERROR WARNING"` + LogFormat string `conf:"default:TEXT" validate:"oneof=TEXT JSON"` + + RequestValidation string `conf:"required" validate:"required,oneof=DISABLE BLOCK LOG_ONLY"` + ResponseValidation string `conf:"required" validate:"required,oneof=DISABLE BLOCK LOG_ONLY"` + CustomBlockStatusCode int `conf:"default:403" validate:"HttpStatusCodes"` + AddValidationStatusHeader bool `conf:"default:false"` + APISpecs string `conf:"default:swagger.json,env:API_SPECS"` + PassOptionsRequests bool `conf:"default:false,env:PASS_OPTIONS"` +} + +type APIFWConfigurationAPIMode struct { + conf.Version + APIFWMode + TLS TLS + + SpecificationUpdatePeriod time.Duration `conf:"default:1m,env:API_MODE_SPECIFICATION_UPDATE_PERIOD"` + PathToSpecDB string `conf:"env:API_MODE_DEBUG_PATH_DB"` + UnknownParametersDetection bool `conf:"default:true,env:API_MODE_UNKNOWN_PARAMETERS_DETECTION"` + + APIHost string `conf:"default:http://0.0.0.0:8282,env:URL" validate:"required,url"` + HealthAPIHost string `conf:"default:0.0.0.0:9667,env:HEALTH_HOST" validate:"required"` + ReadTimeout time.Duration `conf:"default:5s"` + WriteTimeout time.Duration `conf:"default:5s"` + LogLevel string `conf:"default:INFO" validate:"oneof=TRACE DEBUG INFO ERROR WARNING"` + LogFormat string `conf:"default:TEXT" validate:"oneof=TEXT JSON"` + PassOptionsRequests bool `conf:"default:false,env:PASS_OPTIONS"` +} + +type APIFWMode struct { + Mode string `conf:"default:PROXY" validate:"oneof=PROXY API"` } diff --git a/internal/mid/denylist.go b/internal/mid/denylist.go index ca94691..e30d462 100644 --- a/internal/mid/denylist.go +++ b/internal/mid/denylist.go @@ -24,7 +24,7 @@ func Denylist(cfg *config.APIFWConfiguration, deniedTokens *denylist.DeniedToken if cfg.Denylist.Tokens.CookieName != "" { token := string(ctx.Request.Header.Cookie(cfg.Denylist.Tokens.CookieName)) if _, found := deniedTokens.Cache.Get(token); found { - return web.RespondError(ctx, cfg.CustomBlockStatusCode, nil) + return web.RespondError(ctx, cfg.CustomBlockStatusCode, "") } } if cfg.Denylist.Tokens.HeaderName != "" { @@ -33,7 +33,7 @@ func Denylist(cfg *config.APIFWConfiguration, deniedTokens *denylist.DeniedToken token = strings.TrimPrefix(token, "Bearer ") } if _, found := deniedTokens.Cache.Get(token); found { - return web.RespondError(ctx, cfg.CustomBlockStatusCode, nil) + return web.RespondError(ctx, cfg.CustomBlockStatusCode, "") } } } diff --git a/internal/mid/errors.go b/internal/mid/errors.go index de0c4e9..85058eb 100644 --- a/internal/mid/errors.go +++ b/internal/mid/errors.go @@ -28,7 +28,7 @@ func Errors(logger *logrus.Logger) web.Middleware { }).Error("common error") // Respond to the error. - if err := web.RespondError(ctx, fasthttp.StatusBadGateway, nil); err != nil { + if err := web.RespondError(ctx, fasthttp.StatusInternalServerError, ""); err != nil { return err } diff --git a/internal/mid/logger.go b/internal/mid/logger.go index faef67b..423fb08 100644 --- a/internal/mid/logger.go +++ b/internal/mid/logger.go @@ -22,15 +22,34 @@ func Logger(logger *logrus.Logger) web.Middleware { err := before(ctx) + // check method and path + if isProxyNoRouteValue := ctx.Value(web.RequestProxyNoRoute); isProxyNoRouteValue != nil { + if isProxyNoRouteValue.(bool) { + logger.WithFields(logrus.Fields{ + "request_id": fmt.Sprintf("#%016X", ctx.ID()), + "status_code": ctx.Response.StatusCode(), + "response_length": fmt.Sprintf("%d", ctx.Response.Header.ContentLength()), + "method": string(ctx.Request.Header.Method()), + "path": string(ctx.Path()), + "uri": string(ctx.Request.URI().RequestURI()), + "client_address": ctx.RemoteAddr(), + }).Error("method or path not found in the OpenAPI specification") + } + } + logger.WithFields(logrus.Fields{ "request_id": fmt.Sprintf("#%016X", ctx.ID()), "status_code": ctx.Response.StatusCode(), - "method": fmt.Sprintf("%s", ctx.Request.Header.Method()), - "path": fmt.Sprintf("%s", ctx.Path()), + "method": string(ctx.Request.Header.Method()), + "path": string(ctx.Path()), + "uri": string(ctx.Request.URI().RequestURI()), "client_address": ctx.RemoteAddr(), "processing_time": time.Since(start), }).Debug("new request") + // log all information about the request + web.LogRequestResponseAtTraceLevel(ctx, logger) + // Return the error, so it can be handled further up the chain. return err } diff --git a/internal/mid/mimetype.go b/internal/mid/mimetype.go new file mode 100644 index 0000000..80c06c4 --- /dev/null +++ b/internal/mid/mimetype.go @@ -0,0 +1,59 @@ +package mid + +import ( + "fmt" + + "github.com/gabriel-vasile/mimetype" + "github.com/sirupsen/logrus" + "github.com/valyala/fasthttp" + "github.com/wallarm/api-firewall/internal/platform/web" +) + +// MIMETypeIdentifier identifies the MIME type of the content in case of CT header is missing +func MIMETypeIdentifier(logger *logrus.Logger) web.Middleware { + + // This is the actual middleware function to be executed. + m := func(before web.Handler) web.Handler { + + // Create the handler that will be attached in the middleware chain. + h := func(ctx *fasthttp.RequestCtx) error { + + // get current Wallarm schema ID + if len(ctx.Request.Header.ContentType()) == 0 && len(ctx.Request.Body()) > 0 { + // decode request body + requestContentEncoding := string(ctx.Request.Header.ContentEncoding()) + if requestContentEncoding != "" { + body, err := web.GetDecompressedRequestBody(&ctx.Request, requestContentEncoding) + if err != nil { + logger.WithFields(logrus.Fields{ + "error": err, + "request_id": fmt.Sprintf("#%016X", ctx.ID()), + }).Error("request body decompression error") + return web.RespondError(ctx, fasthttp.StatusInternalServerError, "") + } + mtype, err := mimetype.DetectReader(body) + if err != nil { + logger.WithFields(logrus.Fields{ + "request_id": fmt.Sprintf("#%016X", ctx.ID()), + }).Error("schema version mismatch") + return web.RespondError(ctx, fasthttp.StatusInternalServerError, "") + } + + // set the identified mime type + ctx.Request.Header.SetContentType(mtype.String()) + } + + // set the identified mime type + ctx.Request.Header.SetContentType(mimetype.Detect(ctx.Request.Body()).String()) + } + + err := before(ctx) + + return err + } + + return h + } + + return m +} diff --git a/internal/mid/shadowAPI.go b/internal/mid/shadowAPI.go new file mode 100644 index 0000000..8f2a73f --- /dev/null +++ b/internal/mid/shadowAPI.go @@ -0,0 +1,73 @@ +package mid + +import ( + "fmt" + + "golang.org/x/exp/slices" + + "github.com/sirupsen/logrus" + "github.com/valyala/fasthttp" + "github.com/wallarm/api-firewall/internal/config" + "github.com/wallarm/api-firewall/internal/platform/web" +) + +// ShadowAPIMonitor check each request for the params, methods or paths that are not specified +// in the OpenAPI specification and log each violation +func ShadowAPIMonitor(logger *logrus.Logger, config *config.ShadowAPI) web.Middleware { + + // This is the actual middleware function to be executed. + m := func(before web.Handler) web.Handler { + + // Create the handler that will be attached in the middleware chain. + h := func(ctx *fasthttp.RequestCtx) error { + + err := before(ctx) + + if isProxyFailedValue := ctx.Value(web.RequestProxyFailed); isProxyFailedValue != nil { + if isProxyFailedValue.(bool) { + return err + } + } + + // skip check if request has been blocked + if isBlockedValue := ctx.Value(web.RequestBlocked); isBlockedValue != nil { + if isBlockedValue.(bool) { + return err + } + } + + currentMethod := string(ctx.Request.Header.Method()) + currentPath := string(ctx.Path()) + + // get the response status code presence in the OpenAPI status + isProxyStatusCodeNotFound := false + statusCodeNotFoundValue := ctx.Value(web.ResponseStatusNotFound) + if statusCodeNotFoundValue != nil { + isProxyStatusCodeNotFound = statusCodeNotFoundValue.(bool) + } + + // check response status code + statusCode := ctx.Response.StatusCode() + idx := slices.IndexFunc(config.ExcludeList, func(c int) bool { return c == statusCode }) + // if response status code not found in the OpenAPI spec AND the code not in the exclude list + if isProxyStatusCodeNotFound && idx < 0 { + logger.WithFields(logrus.Fields{ + "request_id": fmt.Sprintf("#%016X", ctx.ID()), + "status_code": ctx.Response.StatusCode(), + "response_length": fmt.Sprintf("%d", ctx.Response.Header.ContentLength()), + "method": currentMethod, + "path": currentPath, + "client_address": ctx.RemoteAddr(), + "violation": "shadow_api", + }).Error("Shadow API detected: response status code not found in the OpenAPI specification") + } + + // Return the error, so it can be handled further up the chain. + return err + } + + return h + } + + return m +} diff --git a/internal/platform/database/database.go b/internal/platform/database/database.go new file mode 100644 index 0000000..3af0b14 --- /dev/null +++ b/internal/platform/database/database.go @@ -0,0 +1,194 @@ +package database + +import ( + "bytes" + "database/sql" + "fmt" + "os" + "sync" + "time" + + "github.com/getkin/kin-openapi/openapi3" + _ "github.com/mattn/go-sqlite3" + "github.com/pkg/errors" + "github.com/sirupsen/logrus" +) + +const currentSQLSchemaVersion = 1 + +type OpenAPISpecStorage struct { + Specs []SpecificationEntry +} + +type SpecificationEntry struct { + SchemaID int `db:"schema_id"` + SchemaVersion string `db:"schema_version"` + SchemaFormat string `db:"schema_format"` + SchemaContent string `db:"schema_content"` +} + +type DBOpenAPILoader interface { + Load(dbStoragePath string) error + SpecificationRaw(schemaID int) []byte + SpecificationVersion(schemaID int) string + Specification(schemaID int) *openapi3.T + IsLoaded(schemaID int) bool + SchemaIDs() []int +} + +type SQLLite struct { + Log *logrus.Logger + RawSpecs map[int]*SpecificationEntry + LastUpdate time.Time + OpenAPISpec map[int]*openapi3.T + lock *sync.RWMutex +} + +func getSpecBytes(spec string) []byte { + return bytes.NewBufferString(spec).Bytes() +} + +func NewOpenAPIDB(log *logrus.Logger, dbStoragePath string) (DBOpenAPILoader, error) { + + sqlObj := SQLLite{ + Log: log, + lock: &sync.RWMutex{}, + } + + if err := sqlObj.Load(dbStoragePath); err != nil { + return nil, err + } + + log.Debugf("OpenAPI specifications with the following IDs has been loaded: %v", sqlObj.SchemaIDs()) + + return &sqlObj, nil +} + +func (s *SQLLite) Load(dbStoragePath string) error { + + entries := make(map[int]*SpecificationEntry) + specs := make(map[int]*openapi3.T) + + currentDBPath := dbStoragePath + if currentDBPath == "" { + currentDBPath = fmt.Sprintf("/var/lib/wallarm-api/%d/wallarm_api.db", currentSQLSchemaVersion) + } + + // check if file exists + if _, err := os.Stat(currentDBPath); errors.Is(err, os.ErrNotExist) { + return err + } + + db, err := sql.Open("sqlite3", currentDBPath) + if err != nil { + return err + } + defer db.Close() + + rows, err := db.Query("select schema_id,schema_version,schema_format,schema_content from openapi_schemas") + if err != nil { + return err + } + defer rows.Close() + + for rows.Next() { + entry := SpecificationEntry{} + err = rows.Scan(&entry.SchemaID, &entry.SchemaVersion, &entry.SchemaFormat, &entry.SchemaContent) + if err != nil { + return err + } + entries[entry.SchemaID] = &entry + } + + if err = rows.Err(); err != nil { + return err + } + + s.RawSpecs = entries + s.LastUpdate = time.Now().UTC() + + for schemaID, spec := range s.RawSpecs { + + // parse specification + loader := openapi3.NewLoader() + parsedSpec, err := loader.LoadFromData(getSpecBytes(spec.SchemaContent)) + if err != nil { + s.Log.Errorf("error: parsing of the OpenAPI specification %s (schema ID %d): %v", spec.SchemaVersion, schemaID, err) + delete(s.RawSpecs, schemaID) + continue + } + + if err := parsedSpec.Validate(loader.Context); err != nil { + s.Log.Errorf("error: validation of the OpenAPI specification %s (schema ID %d): %v", spec.SchemaVersion, schemaID, err) + delete(s.RawSpecs, schemaID) + continue + } + + specs[spec.SchemaID] = parsedSpec + } + + if len(specs) == 0 { + return errors.New("no OpenAPI specs has been loaded") + } + + s.lock.Lock() + defer s.lock.Unlock() + + s.RawSpecs = entries + s.OpenAPISpec = specs + + return nil +} + +func (s *SQLLite) Specification(schemaID int) *openapi3.T { + s.lock.RLock() + defer s.lock.RUnlock() + + spec, ok := s.OpenAPISpec[schemaID] + if !ok { + return nil + } + return spec +} + +func (s *SQLLite) SpecificationRaw(schemaID int) []byte { + s.lock.RLock() + defer s.lock.RUnlock() + + rawSpec, ok := s.RawSpecs[schemaID] + if !ok { + return nil + } + return getSpecBytes(rawSpec.SchemaContent) +} + +func (s *SQLLite) SpecificationVersion(schemaID int) string { + s.lock.RLock() + defer s.lock.RUnlock() + + rawSpec, ok := s.RawSpecs[schemaID] + if !ok { + return "" + } + return rawSpec.SchemaVersion +} + +func (s *SQLLite) IsLoaded(schemaID int) bool { + s.lock.RLock() + defer s.lock.RUnlock() + + _, ok := s.OpenAPISpec[schemaID] + return ok +} + +func (s *SQLLite) SchemaIDs() []int { + s.lock.RLock() + defer s.lock.RUnlock() + + var schemaIDs []int + for _, spec := range s.RawSpecs { + schemaIDs = append(schemaIDs, spec.SchemaID) + } + + return schemaIDs +} diff --git a/internal/platform/database/database_mock.go b/internal/platform/database/database_mock.go new file mode 100644 index 0000000..0250eb6 --- /dev/null +++ b/internal/platform/database/database_mock.go @@ -0,0 +1,119 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: ./internal/platform/database/database.go + +// Package database is a generated GoMock package. +package database + +import ( + reflect "reflect" + + openapi3 "github.com/getkin/kin-openapi/openapi3" + gomock "github.com/golang/mock/gomock" +) + +// MockDBOpenAPILoader is a mock of DBOpenAPILoader interface. +type MockDBOpenAPILoader struct { + ctrl *gomock.Controller + recorder *MockDBOpenAPILoaderMockRecorder +} + +// MockDBOpenAPILoaderMockRecorder is the mock recorder for MockDBOpenAPILoader. +type MockDBOpenAPILoaderMockRecorder struct { + mock *MockDBOpenAPILoader +} + +// NewMockDBOpenAPILoader creates a new mock instance. +func NewMockDBOpenAPILoader(ctrl *gomock.Controller) *MockDBOpenAPILoader { + mock := &MockDBOpenAPILoader{ctrl: ctrl} + mock.recorder = &MockDBOpenAPILoaderMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockDBOpenAPILoader) EXPECT() *MockDBOpenAPILoaderMockRecorder { + return m.recorder +} + +// IsLoaded mocks base method. +func (m *MockDBOpenAPILoader) IsLoaded(schemaID int) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IsLoaded", schemaID) + ret0, _ := ret[0].(bool) + return ret0 +} + +// IsLoaded indicates an expected call of IsLoaded. +func (mr *MockDBOpenAPILoaderMockRecorder) IsLoaded(schemaID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsLoaded", reflect.TypeOf((*MockDBOpenAPILoader)(nil).IsLoaded), schemaID) +} + +// Load mocks base method. +func (m *MockDBOpenAPILoader) Load(dbStoragePath string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Load", dbStoragePath) + ret0, _ := ret[0].(error) + return ret0 +} + +// Load indicates an expected call of Load. +func (mr *MockDBOpenAPILoaderMockRecorder) Load(dbStoragePath interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Load", reflect.TypeOf((*MockDBOpenAPILoader)(nil).Load), dbStoragePath) +} + +// SchemaIDs mocks base method. +func (m *MockDBOpenAPILoader) SchemaIDs() []int { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SchemaIDs") + ret0, _ := ret[0].([]int) + return ret0 +} + +// SchemaIDs indicates an expected call of SchemaIDs. +func (mr *MockDBOpenAPILoaderMockRecorder) SchemaIDs() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SchemaIDs", reflect.TypeOf((*MockDBOpenAPILoader)(nil).SchemaIDs)) +} + +// Specification mocks base method. +func (m *MockDBOpenAPILoader) Specification(schemaID int) *openapi3.T { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Specification", schemaID) + ret0, _ := ret[0].(*openapi3.T) + return ret0 +} + +// Specification indicates an expected call of Specification. +func (mr *MockDBOpenAPILoaderMockRecorder) Specification(schemaID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Specification", reflect.TypeOf((*MockDBOpenAPILoader)(nil).Specification), schemaID) +} + +// SpecificationRaw mocks base method. +func (m *MockDBOpenAPILoader) SpecificationRaw(schemaID int) []byte { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SpecificationRaw", schemaID) + ret0, _ := ret[0].([]byte) + return ret0 +} + +// SpecificationRaw indicates an expected call of SpecificationRaw. +func (mr *MockDBOpenAPILoaderMockRecorder) SpecificationRaw(schemaID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SpecificationRaw", reflect.TypeOf((*MockDBOpenAPILoader)(nil).SpecificationRaw), schemaID) +} + +// SpecificationVersion mocks base method. +func (m *MockDBOpenAPILoader) SpecificationVersion(schemaID int) string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SpecificationVersion", schemaID) + ret0, _ := ret[0].(string) + return ret0 +} + +// SpecificationVersion indicates an expected call of SpecificationVersion. +func (mr *MockDBOpenAPILoaderMockRecorder) SpecificationVersion(schemaID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SpecificationVersion", reflect.TypeOf((*MockDBOpenAPILoader)(nil).SpecificationVersion), schemaID) +} diff --git a/internal/platform/database/database_test.go b/internal/platform/database/database_test.go new file mode 100644 index 0000000..f8f93de --- /dev/null +++ b/internal/platform/database/database_test.go @@ -0,0 +1,120 @@ +package database + +import ( + "bytes" + "sort" + "testing" + + "github.com/sirupsen/logrus" +) + +const ( + testSchemaID1 = 1 + testSpecVersion1 = "1" + testSchemaID2 = 4 + testSpecVersion2 = "2" +) + +const ( + testOpenAPIScheme1 = `openapi: 3.0.1 +info: + title: Minimal integer field example + version: 0.0.1 +paths: + /ok: + get: + responses: + '200': + description: OK + content: + application/json: + schema: + type: object + required: + - status + properties: + status: + type: string + example: "success" + error: + type: string` + testOpenAPIScheme2 = `{ + "openapi": "3.0.1", + "info": { + "title": "Minimal integer field example", + "version": "0.0.1" + }, + "paths": { + "/wrong": { + "get": { + "responses": { + "200": { + "description": "OK", + "content": { + "application/json": { + "schema": { + "type": "object", + "required": [ + "status" + ], + "properties": { + "status": { + "type": "string", + "example": "example" + }, + "error": { + "type": "string" + } + } + } + } + } + } + } + } + } + } +}` +) + +func TestBasicDBSpecsLoading(t *testing.T) { + + logger := logrus.New() + logger.SetLevel(logrus.ErrorLevel) + + dbSpec, err := NewOpenAPIDB(logger, "../../../resources/test/database/wallarm_api.db") + if err != nil { + t.Fatal(err) + } + + // test first OpenAPI spec + openAPISpec := bytes.Trim(dbSpec.SpecificationRaw(testSchemaID1), "\xef\xbb\xbf") + if !bytes.Equal(openAPISpec, bytes.NewBufferString(testOpenAPIScheme1).Bytes()) { + t.Error("loaded and the original specifications are not equal") + } + + loadedSchemaIDs := dbSpec.SchemaIDs() + sort.Ints(loadedSchemaIDs) + + if len(loadedSchemaIDs) != 2 || loadedSchemaIDs[0] != testSchemaID1 { + t.Error("loaded and the original schema IDs are not equal") + } + + if testSpecVersion1 != dbSpec.SpecificationVersion(testSchemaID1) { + t.Error("loaded and the original specifications versions are not equal") + } + + // test second OpenAPI spec + openAPISpec = bytes.Trim(dbSpec.SpecificationRaw(testSchemaID2), "\xef\xbb\xbf") + if !bytes.Equal(openAPISpec, bytes.NewBufferString(testOpenAPIScheme2).Bytes()) { + t.Error("loaded and the original specifications are not equal") + } + + if len(loadedSchemaIDs) != 2 || loadedSchemaIDs[1] != testSchemaID2 { + t.Error("loaded and the original schema IDs are not equal") + } + + if testSpecVersion2 != dbSpec.SpecificationVersion(testSchemaID2) { + t.Error("loaded and the original specifications versions are not equal") + } +} diff --git a/internal/platform/router/router.go b/internal/platform/router/router.go index 7c08452..21150b3 100644 --- a/internal/platform/router/router.go +++ b/internal/platform/router/router.go @@ -7,17 +7,20 @@ import ( "github.com/getkin/kin-openapi/openapi3" "github.com/getkin/kin-openapi/routers" + "github.com/wallarm/api-firewall/internal/platform/database" ) // Router helps link http.Request.s and an OpenAPIv3 spec type Router struct { - Routes []Route + Routes []CustomRoute + SchemaVersion string } -type Route struct { - Route *routers.Route - Path string - Method string +type CustomRoute struct { + Route *routers.Route + Path string + Method string + ParametersNumberInPath int } // NewRouter creates a new router. @@ -40,12 +43,48 @@ func NewRouter(doc *openapi3.T) (*Router, error) { Method: method, Operation: operation, } - router.Routes = append(router.Routes, Route{ - Route: &route, - Path: path, - Method: method, + + // count number of parameters in the path + pathParamLength := 0 + if getOp := pathItem.GetOperation(route.Method); getOp != nil { + for _, param := range getOp.Parameters { + if param.Value.In == openapi3.ParameterInPath { + pathParamLength += 1 + } + } + } + + // check common parameters + if getOp := pathItem.Parameters; getOp != nil { + for _, param := range getOp { + if param.Value.In == openapi3.ParameterInPath { + pathParamLength += 1 + } + } + } + + router.Routes = append(router.Routes, CustomRoute{ + Route: &route, + Path: path, + Method: method, + ParametersNumberInPath: pathParamLength, }) } } + return &router, nil } + +// NewRouterDBLoader creates a new router based on DB OpenAPI loader. +func NewRouterDBLoader(schemaID int, openAPISpec database.DBOpenAPILoader) (*Router, error) { + doc := openAPISpec.Specification(schemaID) + + router, err := NewRouter(doc) + if err != nil { + return nil, err + } + + router.SchemaVersion = openAPISpec.SpecificationVersion(schemaID) + + return router, nil +} diff --git a/internal/platform/shadowAPI/shadowAPI.go b/internal/platform/shadowAPI/shadowAPI.go deleted file mode 100644 index 54b1ba0..0000000 --- a/internal/platform/shadowAPI/shadowAPI.go +++ /dev/null @@ -1,43 +0,0 @@ -package shadowAPI - -import ( - "fmt" - - "github.com/sirupsen/logrus" - "github.com/valyala/fasthttp" - "github.com/wallarm/api-firewall/internal/config" - - "golang.org/x/exp/slices" -) - -type Checker interface { - Check(ctx *fasthttp.RequestCtx) error -} - -type ShadowAPI struct { - Config *config.ShadowAPI - Logger *logrus.Logger -} - -func New(config *config.ShadowAPI, logger *logrus.Logger) Checker { - return &ShadowAPI{ - Config: config, - Logger: logger, - } -} - -func (s *ShadowAPI) Check(ctx *fasthttp.RequestCtx) error { - statusCode := ctx.Response.StatusCode() - idx := slices.IndexFunc(s.Config.ExcludeList, func(c int) bool { return c == statusCode }) - if idx < 0 { - s.Logger.WithFields(logrus.Fields{ - "request_id": fmt.Sprintf("#%016X", ctx.ID()), - "status_code": ctx.Response.StatusCode(), - "response_length": fmt.Sprintf("%d", ctx.Response.Header.ContentLength()), - "method": fmt.Sprintf("%s", ctx.Request.Header.Method()), - "path": fmt.Sprintf("%s", ctx.Path()), - "client_address": ctx.RemoteAddr(), - }).Error("Shadow API detected") - } - return nil -} diff --git a/internal/platform/shadowAPI/shadowAPI_mock.go b/internal/platform/shadowAPI/shadowAPI_mock.go deleted file mode 100644 index ff1583a..0000000 --- a/internal/platform/shadowAPI/shadowAPI_mock.go +++ /dev/null @@ -1,49 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: ./internal/platform/shadowAPI/shadowAPI.go - -// Package shadowAPI is a generated GoMock package. -package shadowAPI - -import ( - reflect "reflect" - - gomock "github.com/golang/mock/gomock" - fasthttp "github.com/valyala/fasthttp" -) - -// MockChecker is a mock of Checker interface. -type MockChecker struct { - ctrl *gomock.Controller - recorder *MockCheckerMockRecorder -} - -// MockCheckerMockRecorder is the mock recorder for MockChecker. -type MockCheckerMockRecorder struct { - mock *MockChecker -} - -// NewMockChecker creates a new mock instance. -func NewMockChecker(ctrl *gomock.Controller) *MockChecker { - mock := &MockChecker{ctrl: ctrl} - mock.recorder = &MockCheckerMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockChecker) EXPECT() *MockCheckerMockRecorder { - return m.recorder -} - -// Check mocks base method. -func (m *MockChecker) Check(ctx *fasthttp.RequestCtx) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Check", ctx) - ret0, _ := ret[0].(error) - return ret0 -} - -// Check indicates an expected call of Check. -func (mr *MockCheckerMockRecorder) Check(ctx interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Check", reflect.TypeOf((*MockChecker)(nil).Check), ctx) -} diff --git a/internal/platform/validator/req_resp_decoder.go b/internal/platform/validator/req_resp_decoder.go index 6cec077..6670b95 100644 --- a/internal/platform/validator/req_resp_decoder.go +++ b/internal/platform/validator/req_resp_decoder.go @@ -8,7 +8,6 @@ import ( "errors" "fmt" "io" - "io/ioutil" "mime" "mime/multipart" "net/http" @@ -17,11 +16,12 @@ import ( "strconv" "strings" - "github.com/valyala/fastjson" "gopkg.in/yaml.v3" + "github.com/clbanning/mxj/v2" "github.com/getkin/kin-openapi/openapi3" "github.com/getkin/kin-openapi/openapi3filter" + "github.com/valyala/fastjson" ) // ParseErrorKind describes a kind of ParseError. @@ -895,6 +895,14 @@ func parsePrimitive(raw string, schema *openapi3.SchemaRef) (interface{}, error) } switch schema.Value.Type { case "integer": + if len(schema.Value.Enum) > 0 { + // parse int as float because of the comparison with float enum values + v, err := strconv.ParseFloat(raw, 64) + if err != nil { + return nil, &ParseError{Kind: KindInvalidFormat, Value: raw, Reason: "an invalid " + schema.Value.Type, Cause: err.(*strconv.NumError).Err} + } + return v, nil + } if schema.Value.Format == "int32" { v, err := strconv.ParseInt(raw, 0, 32) if err != nil { @@ -986,8 +994,14 @@ func decodeBody(body io.Reader, header http.Header, schema *openapi3.SchemaRef, if contentType == "" { if _, ok := body.(*multipart.Part); ok { contentType = "text/plain" + value, err := multipartPartBodyDecoder(body, header, schema, encFn, jsonParser) + if err != nil { + return "", nil, err + } + return parseMediaType(contentType), value, nil } } + mediaType := parseMediaType(contentType) decoder, ok := bodyDecoders[mediaType] if !ok { @@ -1005,6 +1019,7 @@ func decodeBody(body io.Reader, header http.Header, schema *openapi3.SchemaRef, func init() { RegisterBodyDecoder("application/json", jsonBodyDecoder) + RegisterBodyDecoder("application/xml", xmlBodyDecoder) RegisterBodyDecoder("application/json-patch+json", jsonBodyDecoder) RegisterBodyDecoder("application/octet-stream", FileBodyDecoder) RegisterBodyDecoder("application/problem+json", jsonBodyDecoder) @@ -1018,13 +1033,39 @@ func init() { } func plainBodyDecoder(body io.Reader, header http.Header, schema *openapi3.SchemaRef, encFn EncodingFn, jsonParser *fastjson.Parser) (interface{}, error) { - data, err := ioutil.ReadAll(body) + data, err := io.ReadAll(body) if err != nil { return nil, &ParseError{Kind: KindInvalidFormat, Cause: err} } return string(data), nil } +func multipartPartBodyDecoder(body io.Reader, header http.Header, schema *openapi3.SchemaRef, encFn EncodingFn, jsonParser *fastjson.Parser) (interface{}, error) { + data, err := io.ReadAll(body) + if err != nil { + return nil, &ParseError{Kind: KindInvalidFormat, Cause: err} + } + + dataStr := string(data) + + switch schema.Value.Type { + case "integer", "number": + floatValue, err := strconv.ParseFloat(dataStr, 64) + if err != nil { + return nil, &ParseError{Kind: KindInvalidFormat, Cause: err} + } + return floatValue, nil + case "boolean": + boolValue, err := strconv.ParseBool(dataStr) + if err != nil { + return nil, &ParseError{Kind: KindInvalidFormat, Cause: err} + } + return boolValue, nil + } + + return dataStr, nil +} + func jsonBodyDecoder(body io.Reader, header http.Header, schema *openapi3.SchemaRef, encFn EncodingFn, jsonParser *fastjson.Parser) (interface{}, error) { var data []byte var err error @@ -1042,6 +1083,23 @@ func jsonBodyDecoder(body io.Reader, header http.Header, schema *openapi3.Schema return convertToMap(parsedDoc), nil } +func xmlBodyDecoder(body io.Reader, header http.Header, schema *openapi3.SchemaRef, encFn EncodingFn, jsonParser *fastjson.Parser) (interface{}, error) { + var data []byte + var err error + + data, err = io.ReadAll(body) + if err != nil { + return nil, &ParseError{Kind: KindInvalidFormat, Cause: err} + } + + mv, err := mxj.NewMapXml(data) + if err != nil { + return nil, &ParseError{Kind: KindInvalidFormat, Cause: err} + } + + return mv.Old(), nil +} + func yamlBodyDecoder(body io.Reader, header http.Header, schema *openapi3.SchemaRef, encFn EncodingFn, jsonParser *fastjson.Parser) (interface{}, error) { var value interface{} if err := yaml.NewDecoder(body).Decode(&value); err != nil { @@ -1051,12 +1109,7 @@ func yamlBodyDecoder(body io.Reader, header http.Header, schema *openapi3.Schema } func urlencodedBodyDecoder(body io.Reader, header http.Header, schema *openapi3.SchemaRef, encFn EncodingFn, jsonParser *fastjson.Parser) (interface{}, error) { - // Validate schema of request body. - // By the OpenAPI 3 specification request body's schema must have type "object". - // Properties of the schema describes individual parts of request body. - if schema.Value.Type != "object" { - return nil, errors.New("unsupported schema of request body") - } + for propName, propSchema := range schema.Value.Properties { switch propSchema.Value.Type { case "object": @@ -1070,7 +1123,7 @@ func urlencodedBodyDecoder(body io.Reader, header http.Header, schema *openapi3. } // Parse form. - b, err := ioutil.ReadAll(body) + b, err := io.ReadAll(body) if err != nil { return nil, err } @@ -1092,19 +1145,19 @@ func urlencodedBodyDecoder(body io.Reader, header http.Header, schema *openapi3. } sm := enc.SerializationMethod() - if value, _, err = decodeValue(dec, name, sm, prop, false); err != nil { + found := false + if value, found, err = decodeValue(dec, name, sm, prop, false); err != nil { return nil, err } - obj[name] = value + if found { + obj[name] = value + } } return obj, nil } func multipartBodyDecoder(body io.Reader, header http.Header, schema *openapi3.SchemaRef, encFn EncodingFn, jsonParser *fastjson.Parser) (interface{}, error) { - if schema.Value.Type != "object" { - return nil, errors.New("unsupported schema of request body") - } // Parse form. values := make(map[string][]interface{}) @@ -1131,33 +1184,43 @@ func multipartBodyDecoder(body io.Reader, header http.Header, schema *openapi3.S enc = encFn(name) } subEncFn := func(string) *openapi3.Encoding { return enc } - // If the property's schema has type "array" it is means that the form contains a few parts with the same name. - // Every such part has a type that is defined by an items schema in the property's schema. + var valueSchema *openapi3.SchemaRef - var exists bool - valueSchema, exists = schema.Value.Properties[name] - if !exists { - anyProperties := schema.Value.AdditionalPropertiesAllowed - if anyProperties != nil { - switch *anyProperties { - case true: - //additionalProperties: true - continue - default: - //additionalProperties: false - return nil, &ParseError{Kind: KindOther, Cause: fmt.Errorf("part %s: undefined", name)} + if len(schema.Value.AllOf) > 0 { + var exists bool + for _, sr := range schema.Value.AllOf { + if valueSchema, exists = sr.Value.Properties[name]; exists { + break } } - if schema.Value.AdditionalProperties == nil { - return nil, &ParseError{Kind: KindOther, Cause: fmt.Errorf("part %s: undefined", name)} - } - valueSchema, exists = schema.Value.AdditionalProperties.Value.Properties[name] if !exists { return nil, &ParseError{Kind: KindOther, Cause: fmt.Errorf("part %s: undefined", name)} } - } - if valueSchema.Value.Type == "array" { - valueSchema = valueSchema.Value.Items + } else { + // If the property's schema has type "array" it is means that the form contains a few parts with the same name. + // Every such part has a type that is defined by an items schema in the property's schema. + var exists bool + if valueSchema, exists = schema.Value.Properties[name]; !exists { + if anyProperties := schema.Value.AdditionalProperties.Has; anyProperties != nil { + switch *anyProperties { + case true: + //additionalProperties: true + continue + default: + //additionalProperties: false + return nil, &ParseError{Kind: KindOther, Cause: fmt.Errorf("part %s: undefined", name)} + } + } + if schema.Value.AdditionalProperties.Schema == nil { + return nil, &ParseError{Kind: KindOther, Cause: fmt.Errorf("part %s: undefined", name)} + } + if valueSchema, exists = schema.Value.AdditionalProperties.Schema.Value.Properties[name]; !exists { + return nil, &ParseError{Kind: KindOther, Cause: fmt.Errorf("part %s: undefined", name)} + } + } + if valueSchema.Value.Type == "array" { + valueSchema = valueSchema.Value.Items + } } var value interface{} @@ -1171,14 +1234,28 @@ func multipartBodyDecoder(body io.Reader, header http.Header, schema *openapi3.S } allTheProperties := make(map[string]*openapi3.SchemaRef) - for k, v := range schema.Value.Properties { - allTheProperties[k] = v - } - if schema.Value.AdditionalProperties != nil { - for k, v := range schema.Value.AdditionalProperties.Value.Properties { + if len(schema.Value.AllOf) > 0 { + for _, sr := range schema.Value.AllOf { + for k, v := range sr.Value.Properties { + allTheProperties[k] = v + } + if addProps := sr.Value.AdditionalProperties.Schema; addProps != nil { + for k, v := range addProps.Value.Properties { + allTheProperties[k] = v + } + } + } + } else { + for k, v := range schema.Value.Properties { allTheProperties[k] = v } + if addProps := schema.Value.AdditionalProperties.Schema; addProps != nil { + for k, v := range addProps.Value.Properties { + allTheProperties[k] = v + } + } } + // Make an object value from form values. obj := make(map[string]interface{}) for name, prop := range allTheProperties { @@ -1198,7 +1275,7 @@ func multipartBodyDecoder(body io.Reader, header http.Header, schema *openapi3.S // FileBodyDecoder is a body decoder that decodes a file body to a string. func FileBodyDecoder(body io.Reader, header http.Header, schema *openapi3.SchemaRef, encFn EncodingFn, jsonParser *fastjson.Parser) (interface{}, error) { - data, err := ioutil.ReadAll(body) + data, err := io.ReadAll(body) if err != nil { return nil, err } diff --git a/internal/platform/validator/req_resp_decoder_test.go b/internal/platform/validator/req_resp_decoder_test.go index c7ea5e4..efae67a 100644 --- a/internal/platform/validator/req_resp_decoder_test.go +++ b/internal/platform/validator/req_resp_decoder_test.go @@ -5,10 +5,7 @@ import ( "context" "encoding/json" "fmt" - "github.com/getkin/kin-openapi/openapi3filter" - "github.com/valyala/fastjson" "io" - "io/ioutil" "mime/multipart" "net/http" "net/textproto" @@ -18,8 +15,10 @@ import ( "testing" "github.com/getkin/kin-openapi/openapi3" + "github.com/getkin/kin-openapi/openapi3filter" legacyrouter "github.com/getkin/kin-openapi/routers/legacy" "github.com/stretchr/testify/require" + "github.com/valyala/fastjson" ) func TestDecodeParameter(t *testing.T) { @@ -1147,7 +1146,7 @@ func TestDecodeBody(t *testing.T) { }{ { name: prefixUnsupportedCT, - mime: "application/xml", + mime: "application/none-xml", wantErr: &ParseError{Kind: KindUnsupportedFormat}, }, { @@ -1341,7 +1340,7 @@ func TestRegisterAndUnregisterBodyDecoder(t *testing.T) { var decoder BodyDecoder decoder = func(body io.Reader, h http.Header, schema *openapi3.SchemaRef, encFn EncodingFn, jsonParser *fastjson.Parser) (decoded interface{}, err error) { var data []byte - if data, err = ioutil.ReadAll(body); err != nil { + if data, err = io.ReadAll(body); err != nil { return } return strings.Split(string(data), ","), nil diff --git a/internal/platform/validator/unknown_parameters_request.go b/internal/platform/validator/unknown_parameters_request.go new file mode 100644 index 0000000..c705082 --- /dev/null +++ b/internal/platform/validator/unknown_parameters_request.go @@ -0,0 +1,169 @@ +package validator + +import ( + "bytes" + "encoding/csv" + "github.com/getkin/kin-openapi/routers" + "io" + "net/http" + + "github.com/getkin/kin-openapi/openapi3" + "github.com/getkin/kin-openapi/openapi3filter" + "github.com/pkg/errors" + "github.com/valyala/fasthttp" + "github.com/valyala/fastjson" +) + +// ErrUnknownQueryParameter is returned when a query parameter not defined in the OpenAPI specification. +var ErrUnknownQueryParameter = errors.New("query parameter not defined in the OpenAPI specification") + +// ErrUnknownBodyParameter is returned when a body parameter not defined in the OpenAPI specification. +var ErrUnknownBodyParameter = errors.New("body parameter not defined in the OpenAPI specification") + +// ErrUnknownContentType is returned when the API FW can't parse the request body +var ErrUnknownContentType = errors.New("unknown content type of the request body") + +// ErrDecodingFailed is returned when the API FW got error or unexpected value from the decoder +var ErrDecodingFailed = errors.New("the decoder returned the error") + +// RequestUnknownParameterError is returned by ValidateRequest when request does not match OpenAPI spec +type RequestUnknownParameterError struct { + Input *openapi3filter.RequestValidationInput + Parameters []string + RequestBody *openapi3.RequestBody + Err error +} + +// ValidateUnknownRequestParameters is used to get a list of request parameters that are not specified in the OpenAPI specification +func ValidateUnknownRequestParameters(ctx *fasthttp.RequestCtx, route *routers.Route, header http.Header, jsonParser *fastjson.Parser) (foundUnknownParams []RequestUnknownParameterError, valError error) { + + operation := route.Operation + operationParameters := operation.Parameters + pathItemParameters := route.PathItem.Parameters + + // prepare a map with the list of params that defined in the OpenAPI specification + specParams := make(map[string]*openapi3.Parameter) + for _, parameterRef := range pathItemParameters { + parameter := parameterRef.Value + specParams[parameter.Name+parameter.In] = parameter + } + + // add optional parameters to the map with parameters + for _, parameterRef := range operationParameters { + parameter := parameterRef.Value + specParams[parameter.Name+parameter.In] = parameter + } + + unknownQueryParams := RequestUnknownParameterError{} + // compare list of all query params and list of params defined in the specification + ctx.Request.URI().QueryArgs().VisitAll(func(key, value []byte) { + if _, ok := specParams[string(key)+openapi3.ParameterInQuery]; !ok { + unknownQueryParams.Err = ErrUnknownQueryParameter + unknownQueryParams.Parameters = append(unknownQueryParams.Parameters, string(key)) + } + }) + + if unknownQueryParams.Err != nil { + foundUnknownParams = append(foundUnknownParams, unknownQueryParams) + } + + if operation.RequestBody == nil { + return + } + + // validate body params + requestBody := operation.RequestBody.Value + + content := requestBody.Content + if len(content) == 0 { + // A request's body does not have declared content, so skip validation. + return + } + + if len(ctx.Request.Body()) == 0 { + return foundUnknownParams, nil + } + + // check post params + inputMIME := string(ctx.Request.Header.ContentType()) + contentType := requestBody.Content.Get(inputMIME) + if contentType == nil { + return foundUnknownParams, nil + } + + encFn := func(name string) *openapi3.Encoding { return contentType.Encoding[name] } + mediaType, value, err := decodeBody(io.NopCloser(bytes.NewReader(ctx.Request.Body())), header, contentType.Schema, encFn, jsonParser) + if err != nil { + return foundUnknownParams, err + } + + unknownBodyParams := RequestUnknownParameterError{} + + switch mediaType { + case "text/plain": + return nil, nil + case "text/csv": + r := csv.NewReader(io.NopCloser(bytes.NewReader(ctx.Request.Body()))) + + record, err := r.Read() + if err != nil { + return foundUnknownParams, err + } + + for _, rName := range record { + found := false + for propName := range contentType.Schema.Value.Properties { + if rName == propName { + found = true + break + } + } + if !found { + unknownBodyParams.Err = ErrUnknownBodyParameter + unknownBodyParams.Parameters = append(unknownBodyParams.Parameters, rName) + } + } + case "application/x-www-form-urlencoded": + // required params in paramList + paramList, ok := value.(map[string]interface{}) + if !ok { + return foundUnknownParams, ErrDecodingFailed + } + if ok { + ctx.Request.PostArgs().VisitAll(func(key, value []byte) { + if _, ok := paramList[string(key)]; !ok { + unknownBodyParams.Err = ErrUnknownBodyParameter + unknownBodyParams.Parameters = append(unknownBodyParams.Parameters, string(key)) + } + }) + } + case "application/json", "application/xml", "multipart/form-data": + paramList, ok := value.(map[string]interface{}) + if !ok { + return foundUnknownParams, ErrDecodingFailed + } + if ok { + for paramName, paramValue := range paramList { + found := false + for propName := range contentType.Schema.Value.Properties { + if paramName == propName && paramValue != nil { + found = true + break + } + } + if !found { + unknownBodyParams.Err = ErrUnknownBodyParameter + unknownBodyParams.Parameters = append(unknownBodyParams.Parameters, paramName) + } + } + } + default: + return foundUnknownParams, ErrDecodingFailed + } + + if unknownBodyParams.Err != nil { + foundUnknownParams = append(foundUnknownParams, unknownBodyParams) + } + + return +} diff --git a/internal/platform/validator/unknown_parameters_request_test.go b/internal/platform/validator/unknown_parameters_request_test.go new file mode 100644 index 0000000..6b84be5 --- /dev/null +++ b/internal/platform/validator/unknown_parameters_request_test.go @@ -0,0 +1,287 @@ +package validator + +import ( + "bytes" + "encoding/json" + "errors" + "io" + "net/http" + "reflect" + "strings" + "testing" + + "github.com/getkin/kin-openapi/openapi3filter" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/valyala/fasthttp" + "github.com/valyala/fasthttp/fasthttpadaptor" + "github.com/valyala/fastjson" +) + +func TestUnknownParametersRequest(t *testing.T) { + const spec = ` +openapi: 3.0.0 +info: + title: 'Validator' + version: 0.0.1 +paths: + /category: + post: + parameters: + - name: category + in: query + schema: + type: string + required: true + requestBody: + required: true + content: + application/json: + schema: + type: object + required: + - subCategory + properties: + subCategory: + type: string + category: + type: string + default: Sweets + application/x-www-form-urlencoded: + schema: + type: object + required: + - subCategory + properties: + subCategory: + type: string + category: + type: string + default: Sweets + responses: + '201': + description: Created + /unknown: + post: + requestBody: + required: true + content: + application/json: + schema: {} + application/x-www-form-urlencoded: + schema: {} + responses: + '201': + description: Created +` + + router := setupTestRouter(t, spec) + + type testRequestBody struct { + SubCategory string `json:"subCategory"` + Category string `json:"category,omitempty"` + UnknownParameter string `json:"unknown,omitempty"` + } + type args struct { + requestBody *testRequestBody + ct string + url string + } + tests := []struct { + name string + args args + expectedErr error + expectedResp []*RequestUnknownParameterError + }{ + { + name: "Valid request with all fields set", + args: args{ + requestBody: &testRequestBody{SubCategory: "Chocolate", Category: "Food"}, + url: "/category?category=cookies", + ct: "application/json", + }, + expectedErr: nil, + expectedResp: nil, + }, + { + name: "Valid request without certain fields", + args: args{ + requestBody: &testRequestBody{SubCategory: "Chocolate"}, + url: "/category?category=cookies", + ct: "application/json", + }, + expectedErr: nil, + expectedResp: nil, + }, + { + name: "Invalid operation params", + args: args{ + requestBody: &testRequestBody{SubCategory: "Chocolate"}, + url: "/category?invalidCategory=badCookie", + ct: "application/json", + }, + expectedErr: nil, + expectedResp: []*RequestUnknownParameterError{ + { + Parameters: []string{"invalidCategory"}, + Err: ErrUnknownQueryParameter, + }, + }, + }, + { + name: "Invalid request body", + args: args{ + requestBody: nil, + url: "/category?category=cookies", + ct: "application/json", + }, + expectedErr: nil, + expectedResp: nil, + }, + { + name: "Unknown query param", + args: args{ + requestBody: nil, + url: "/category?category=cookies&unknown=test", + }, + expectedErr: nil, + expectedResp: []*RequestUnknownParameterError{ + { + Parameters: []string{"unknown"}, + Err: ErrUnknownQueryParameter, + }, + }, + }, + { + name: "Unknown JSON param", + args: args{ + requestBody: &testRequestBody{SubCategory: "Chocolate", Category: "Food", UnknownParameter: "test"}, + url: "/category?category=cookies", + ct: "application/json", + }, + expectedErr: nil, + expectedResp: []*RequestUnknownParameterError{ + { + Parameters: []string{"unknown"}, + Err: ErrUnknownBodyParameter, + }, + }, + }, + { + name: "Unknown POST param", + args: args{ + requestBody: &testRequestBody{SubCategory: "Chocolate", Category: "Food", UnknownParameter: "test"}, + url: "/category?category=cookies", + ct: "application/x-www-form-urlencoded", + }, + expectedErr: nil, + expectedResp: []*RequestUnknownParameterError{ + { + Parameters: []string{"unknown"}, + Err: ErrUnknownBodyParameter, + }, + }, + }, + { + name: "Valid POST params", + args: args{ + requestBody: &testRequestBody{SubCategory: "Chocolate", Category: "Food"}, + url: "/category?category=cookies", + ct: "application/x-www-form-urlencoded", + }, + expectedErr: nil, + expectedResp: nil, + }, + { + name: "Valid POST unknown params", + args: args{ + requestBody: &testRequestBody{SubCategory: "Chocolate", Category: "Food"}, + url: "/unknown", + ct: "application/x-www-form-urlencoded", + }, + expectedErr: nil, + expectedResp: []*RequestUnknownParameterError{ + { + Parameters: []string{"subCategory", "category"}, + Err: ErrUnknownBodyParameter, + }, + }, + }, + { + name: "Valid JSON unknown params", + args: args{ + requestBody: &testRequestBody{SubCategory: "Chocolate"}, + url: "/unknown", + ct: "application/json", + }, + expectedErr: nil, + expectedResp: []*RequestUnknownParameterError{ + { + Parameters: []string{"subCategory"}, + Err: ErrUnknownBodyParameter, + }, + }, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + req := fasthttp.AcquireRequest() + req.SetRequestURI(tc.args.url) + req.Header.SetMethod("POST") + req.Header.SetContentType(tc.args.ct) + + var requestBody io.Reader + if tc.args.requestBody != nil { + switch tc.args.ct { + case "application/x-www-form-urlencoded": + if tc.args.requestBody.UnknownParameter != "" { + req.PostArgs().Add("unknown", tc.args.requestBody.UnknownParameter) + } + if tc.args.requestBody.SubCategory != "" { + req.PostArgs().Add("subCategory", tc.args.requestBody.SubCategory) + } + if tc.args.requestBody.Category != "" { + req.PostArgs().Add("category", tc.args.requestBody.Category) + } + requestBody = strings.NewReader(req.PostArgs().String()) + case "application/json": + testingBody, err := json.Marshal(tc.args.requestBody) + require.NoError(t, err) + requestBody = bytes.NewReader(testingBody) + } + } + + req.SetBodyStream(requestBody, -1) + + ctx := fasthttp.RequestCtx{ + Request: *req, + } + + reqHttp := http.Request{} + + err := fasthttpadaptor.ConvertRequest(&ctx, &reqHttp, false) + require.NoError(t, err) + + route, pathParams, err := router.FindRoute(&reqHttp) + require.NoError(t, err) + + validationInput := &openapi3filter.RequestValidationInput{ + Request: &reqHttp, + PathParams: pathParams, + Route: route, + } + upRes, err := ValidateUnknownRequestParameters(&ctx, validationInput.Route, validationInput.Request.Header, &fastjson.Parser{}) + assert.IsType(t, tc.expectedErr, err, "ValidateUnknownRequestParameters(): error = %v, expectedError %v", err, tc.expectedErr) + if tc.expectedErr != nil { + return + } + if tc.expectedResp != nil || len(tc.expectedResp) > 0 { + assert.Equal(t, len(tc.expectedResp), len(upRes), "expect the number of unknown parameters: %t, got %t", len(tc.expectedResp), len(upRes)) + + if isEq := reflect.DeepEqual(tc.expectedResp, upRes); !isEq { + assert.Errorf(t, errors.New("got unexpected unknown parameters"), "expect unknown parameters: %v, got %v", tc.expectedResp, upRes) + } + } + }) + } +} diff --git a/internal/platform/validator/validate_request.go b/internal/platform/validator/validate_request.go index 15ae15f..9e460c1 100644 --- a/internal/platform/validator/validate_request.go +++ b/internal/platform/validator/validate_request.go @@ -35,7 +35,7 @@ func ValidateRequest(ctx context.Context, input *openapi3filter.RequestValidatio options := input.Options if options == nil { - options = openapi3filter.DefaultOptions + options = &openapi3filter.Options{} } route := input.Route operation := route.Operation @@ -121,7 +121,7 @@ func ValidateParameter(ctx context.Context, input *openapi3filter.RequestValidat options := input.Options if options == nil { - options = openapi3filter.DefaultOptions + options = &openapi3filter.Options{} } var value interface{} @@ -169,6 +169,12 @@ func ValidateParameter(ctx context.Context, input *openapi3filter.RequestValidat } if isNilValue(value) { + + // if parameter has empty schema + if schema.IsEmpty() && found { + return nil + } + if !parameter.AllowEmptyValue && found { return &openapi3filter.RequestError{Input: input, Parameter: parameter, Reason: ErrInvalidEmptyValue.Error(), Err: ErrInvalidEmptyValue} } @@ -204,7 +210,7 @@ func ValidateRequestBody(ctx context.Context, input *openapi3filter.RequestValid options := input.Options if options == nil { - options = openapi3filter.DefaultOptions + options = &openapi3filter.Options{} } if req.Body != http.NoBody && req.Body != nil { @@ -290,7 +296,7 @@ func ValidateRequestBody(ctx context.Context, input *openapi3filter.RequestValid return &openapi3filter.RequestError{ Input: input, RequestBody: requestBody, - Reason: fmt.Sprintf("doesn't match schema%s", schemaId), + Reason: fmt.Sprintf("doesn't match schema %s", schemaId), Err: err, } } @@ -358,7 +364,7 @@ func validateSecurityRequirement(ctx context.Context, input *openapi3filter.Requ // Get authentication function options := input.Options if options == nil { - options = openapi3filter.DefaultOptions + options = &openapi3filter.Options{} } f := options.AuthenticationFunc if f == nil { diff --git a/internal/platform/validator/validate_response.go b/internal/platform/validator/validate_response.go index a8e796b..06174b7 100644 --- a/internal/platform/validator/validate_response.go +++ b/internal/platform/validator/validate_response.go @@ -40,7 +40,7 @@ func ValidateResponse(ctx context.Context, input *openapi3filter.ResponseValidat route := input.RequestValidationInput.Route options := input.Options if options == nil { - options = openapi3filter.DefaultOptions + options = &openapi3filter.Options{} } // Find input for the current status diff --git a/internal/platform/web/apps.go b/internal/platform/web/apps.go new file mode 100644 index 0000000..bdeae72 --- /dev/null +++ b/internal/platform/web/apps.go @@ -0,0 +1,169 @@ +package web + +import ( + "errors" + "fmt" + "os" + strconv2 "strconv" + "sync" + "syscall" + + "github.com/fasthttp/router" + "github.com/savsgio/gotils/strconv" + "github.com/sirupsen/logrus" + "github.com/valyala/fasthttp" + "github.com/wallarm/api-firewall/internal/platform/database" +) + +// Apps is the entrypoint into our application and what configures our context +// object for each of our http handlers. Feel free to add any configuration +// data/logic on this App struct +type Apps struct { + Routers map[int]*router.Router + Log *logrus.Logger + passOPTIONS bool + shutdown chan os.Signal + mw []Middleware + storedSpecs database.DBOpenAPILoader + lock *sync.RWMutex +} + +func (a *Apps) SetDefaultBehavior(schemaID int, handler Handler, mw ...Middleware) { + // First wrap handler specific middleware around this handler. + handler = wrapMiddleware(mw, handler) + + // Add the application's general middleware to the handler chain. + handler = wrapMiddleware(a.mw, handler) + + customHandler := func(ctx *fasthttp.RequestCtx) { + + if err := handler(ctx); err != nil { + a.SignalShutdown() + return + } + + } + + //Set NOT FOUND behavior + a.Routers[schemaID].NotFound = customHandler + + // Set Method Not Allowed behavior + a.Routers[schemaID].MethodNotAllowed = customHandler +} + +// NewApps creates an Apps value that handle a set of routes for the set of application. +func NewApps(lock *sync.RWMutex, passOPTIONS bool, storedSpecs database.DBOpenAPILoader, shutdown chan os.Signal, logger *logrus.Logger, mw ...Middleware) *Apps { + + schemaIDs := storedSpecs.SchemaIDs() + + // init routers + routers := make(map[int]*router.Router) + for _, schemaID := range schemaIDs { + routers[schemaID] = router.New() + routers[schemaID].HandleOPTIONS = passOPTIONS + } + + app := Apps{ + Routers: routers, + shutdown: shutdown, + mw: mw, + Log: logger, + storedSpecs: storedSpecs, + lock: lock, + passOPTIONS: passOPTIONS, + } + + return &app +} + +// Handle is our mechanism for mounting Handlers for a given HTTP verb and path +// pair, this makes for really easy, convenient routing. +func (a *Apps) Handle(schemaID int, method string, path string, handler Handler, mw ...Middleware) { + + // First wrap handler specific middleware around this handler. + handler = wrapMiddleware(mw, handler) + + // Add the application's general middleware to the handler chain. + handler = wrapMiddleware(a.mw, handler) + + // The function to execute for each request. + h := func(ctx *fasthttp.RequestCtx) { + + if err := handler(ctx); err != nil { + a.SignalShutdown() + return + } + } + + // Add this handler for the specified verb and route. + a.Routers[schemaID].Handle(method, path, h) +} + +func getWallarmSchemaID(ctx *fasthttp.RequestCtx, storedSpecs database.DBOpenAPILoader) (int, error) { + + // get Wallarm Schema ID + xWallarmSchemaID := string(ctx.Request.Header.Peek(XWallarmSchemaIDHeader)) + if xWallarmSchemaID == "" { + return 0, errors.New("required X-WALLARM-SCHEMA-ID header is missing") + } + + // get schema version + schemaID, err := strconv2.Atoi(xWallarmSchemaID) + if err != nil { + return 0, fmt.Errorf("error parsing value: %v", err) + } + + // check if schema ID is loaded + if !storedSpecs.IsLoaded(schemaID) { + return 0, fmt.Errorf("provided via X-WALLARM-SCHEMA-ID header schema ID %d not found", schemaID) + } + + return schemaID, nil +} + +// APIModeHandler routes request to the appropriate handler according to the OpenAPI specification schema ID +func (a *Apps) APIModeHandler(ctx *fasthttp.RequestCtx) { + + schemaID, err := getWallarmSchemaID(ctx, a.storedSpecs) + if err != nil { + defer LogRequestResponseAtTraceLevel(ctx, a.Log) + + a.Log.WithFields(logrus.Fields{ + "error": err, + "request_id": fmt.Sprintf("#%016X", ctx.ID()), + }).Error("error while getting schema ID") + + if err := RespondError(ctx, fasthttp.StatusInternalServerError, ""); err != nil { + a.Log.WithFields(logrus.Fields{ + "error": err, + "request_id": fmt.Sprintf("#%016X", ctx.ID()), + }).Error("error while sending response") + } + + return + } + + // add internal header to the context + ctx.SetUserValue(WallarmSchemaID, schemaID) + + // delete internal header + ctx.Request.Header.Del(XWallarmSchemaIDHeader) + + a.lock.RLock() + defer a.lock.RUnlock() + + a.Routers[schemaID].Handler(ctx) + + // if pass request with OPTIONS method is enabled then log request + if ctx.Response.StatusCode() == fasthttp.StatusOK && a.passOPTIONS && strconv.B2S(ctx.Method()) == fasthttp.MethodOptions { + a.Log.WithFields(logrus.Fields{ + "request_id": fmt.Sprintf("#%016X", ctx.ID()), + }).Debug("pass request with OPTIONS method") + } +} + +// SignalShutdown is used to gracefully shutdown the app when an integrity +// issue is identified. +func (a *Apps) SignalShutdown() { + a.shutdown <- syscall.SIGTERM +} diff --git a/internal/platform/web/response.go b/internal/platform/web/response.go index 7f3c4aa..a4fd710 100644 --- a/internal/platform/web/response.go +++ b/internal/platform/web/response.go @@ -18,37 +18,6 @@ var ( supportedEncodings = []string{"gzip", "deflate", "br"} ) -//// GetDecompressedBody function returns the Reader of the decompressed body -//func GetDecompressedBody(ctx *fasthttp.RequestCtx) (io.ReadCloser, error) { -// -// bodyBytes := ctx.Response.Body() -// compression := ctx.Response.Header.ContentEncoding() -// -// if compression != nil { -// for _, sc := range [][]byte{gzip, deflate, br} { -// if bytes.Equal(sc, compression) { -// var body []byte -// var err error -// if body, err = ctx.Response.BodyUncompressed(); err != nil { -// if errors.Is(zlib.ErrHeader, err) && bytes.Equal(compression, deflate) { -// // deflate rfc 1951 implementation -// return flate.NewReader(bytes.NewReader(bodyBytes)), nil -// } -// // got error while body decompression -// return nil, err -// } -// // body has been successfully uncompressed -// return io.NopCloser(bytes.NewReader(body)), nil -// } -// } -// // body compression schema not supported -// return nil, fasthttp.ErrContentEncodingUnsupported -// } -// -// // body without compression -// return io.NopCloser(bytes.NewReader(bodyBytes)), nil -//} - // GetDecompressedResponseBody function returns the Reader of the decompressed response body func GetDecompressedResponseBody(resp *fasthttp.Response, contentEncoding string) (io.ReadCloser, error) { @@ -133,19 +102,27 @@ func Respond(ctx *fasthttp.RequestCtx, data interface{}, statusCode int) error { return nil } -// RespondError sends an error reponse back to the client. -func RespondError(ctx *fasthttp.RequestCtx, statusCode int, statusHeader *string) error { +// RespondError sends an error response back to the client. +func RespondError(ctx *fasthttp.RequestCtx, statusCode int, statusHeader string) error { ctx.Error("", statusCode) // Add validation status header - if statusHeader != nil { - ctx.Response.Header.Add(ValidationStatus, *statusHeader) + if statusHeader != "" { + ctx.Response.Header.Add(ValidationStatus, statusHeader) } return nil } +// RespondOk sends an empty response with 200 status OK back to the client. +func RespondOk(ctx *fasthttp.RequestCtx) error { + + ctx.Error("", fasthttp.StatusOK) + + return nil +} + // Redirect302 redirects client with code 302 func Redirect302(ctx *fasthttp.RequestCtx, redirectUrl string) error { diff --git a/internal/platform/web/trace.go b/internal/platform/web/trace.go new file mode 100644 index 0000000..1d80e02 --- /dev/null +++ b/internal/platform/web/trace.go @@ -0,0 +1,39 @@ +package web + +import ( + "fmt" + + "github.com/sirupsen/logrus" + "github.com/valyala/fasthttp" +) + +func LogRequestResponseAtTraceLevel(ctx *fasthttp.RequestCtx, logger *logrus.Logger) { + if logger.Level == logrus.TraceLevel { + requestHeaders := "" + ctx.Request.Header.VisitAll(func(key, value []byte) { + requestHeaders += string(key) + ":" + string(value) + "\n" + }) + + logger.WithFields(logrus.Fields{ + "request_id": fmt.Sprintf("#%016X", ctx.ID()), + "method": string(ctx.Request.Header.Method()), + "uri": string(ctx.Request.URI().RequestURI()), + "headers": requestHeaders, + "body": string(ctx.Request.Body()), + "client_address": ctx.RemoteAddr(), + }).Trace("new request") + + responseHeaders := "" + ctx.Response.Header.VisitAll(func(key, value []byte) { + responseHeaders += string(key) + ":" + string(value) + "\n" + }) + + logger.WithFields(logrus.Fields{ + "request_id": fmt.Sprintf("#%016X", ctx.ID()), + "status_code": ctx.Response.StatusCode(), + "headers": responseHeaders, + "body": string(ctx.Response.Body()), + "client_address": ctx.RemoteAddr(), + }).Trace("response from the API-Firewall") + } +} diff --git a/internal/platform/web/web.go b/internal/platform/web/web.go index 88da191..856c090 100644 --- a/internal/platform/web/web.go +++ b/internal/platform/web/web.go @@ -1,7 +1,9 @@ package web import ( + "bytes" "fmt" + "github.com/savsgio/gotils/strconv" "os" "syscall" @@ -14,9 +16,21 @@ import ( const ( ValidationStatus = "APIFW-Validation-Status" + XWallarmSchemaIDHeader = "X-WALLARM-SCHEMA-ID" + WallarmSchemaID = "WallarmSchemaID" + ValidationDisable = "DISABLE" ValidationBlock = "BLOCK" ValidationLog = "LOG_ONLY" + + RequestProxyNoRoute = "proxy_no_route" + RequestProxyFailed = "proxy_failed" + RequestBlocked = "request_blocked" + ResponseBlocked = "response_blocked" + ResponseStatusNotFound = "response_status_not_found" + + APIMode = "api" + ProxyMode = "proxy" ) // A Handler is a type that handles an http request within our own little mini @@ -43,16 +57,18 @@ func (a *App) SetDefaultBehavior(handler Handler, mw ...Middleware) { customHandler := func(ctx *fasthttp.RequestCtx) { - // Block request if it's not found in the route - if a.cfg.RequestValidation == ValidationBlock || a.cfg.ResponseValidation == ValidationBlock { - a.Log.WithFields(logrus.Fields{ - "request_id": fmt.Sprintf("#%016X", ctx.ID()), - "method": fmt.Sprintf("%s", ctx.Request.Header.Method()), - "path": fmt.Sprintf("%s", ctx.Path()), - "client_address": ctx.RemoteAddr(), - }).Info("request blocked") - ctx.Error("", a.cfg.CustomBlockStatusCode) - return + // Block request if it's not found in the route. Not for API mode. + if a.cfg.Mode == ProxyMode { + if a.cfg.RequestValidation == ValidationBlock || a.cfg.ResponseValidation == ValidationBlock { + a.Log.WithFields(logrus.Fields{ + "request_id": fmt.Sprintf("#%016X", ctx.ID()), + "method": bytes.NewBuffer(ctx.Request.Header.Method()).String(), + "path": string(ctx.Path()), + "client_address": ctx.RemoteAddr(), + }).Info("request blocked") + ctx.Error("", a.cfg.CustomBlockStatusCode) + return + } } if err := handler(ctx); err != nil { @@ -79,6 +95,8 @@ func NewApp(shutdown chan os.Signal, cfg *config.APIFWConfiguration, logger *log cfg: cfg, } + app.Router.HandleOPTIONS = cfg.PassOptionsRequests + return &app } @@ -99,6 +117,13 @@ func (a *App) Handle(method string, path string, handler Handler, mw ...Middlewa a.SignalShutdown() return } + + // if pass request with OPTIONS method is enabled then log request + if ctx.Response.StatusCode() == fasthttp.StatusOK && a.cfg.PassOptionsRequests && strconv.B2S(ctx.Method()) == fasthttp.MethodOptions { + a.Log.WithFields(logrus.Fields{ + "request_id": fmt.Sprintf("#%016X", ctx.ID()), + }).Debug("pass request with OPTIONS method") + } } // Add this handler for the specified verb and route. diff --git a/resources/dev/httpbin.json b/resources/dev/httpbin.json new file mode 100644 index 0000000..bbe252d --- /dev/null +++ b/resources/dev/httpbin.json @@ -0,0 +1,1704 @@ +{ + "openapi": "3.0.1", + "info": { + "title": "httpbin.org", + "description": "A simple HTTP Request & Response Service.

Run locally: $ docker run -p 80:80 kennethreitz/httpbin", + "contact": { + "url": "https://kennethreitz.org", + "email": "me@kennethreitz.org" + }, + "version": "0.9.2" + }, + "servers": [ + { + "url": "https://httpbin.org/" + } + ], + "tags": [ + { + "name": "HTTP Methods", + "description": "Testing different HTTP verbs" + }, + { + "name": "Auth", + "description": "Auth methods" + }, + { + "name": "Status codes", + "description": "Generates responses with given status code" + }, + { + "name": "Request inspection", + "description": "Inspect the request data" + }, + { + "name": "Response inspection", + "description": "Inspect the response data like caching and headers" + }, + { + "name": "Response formats", + "description": "Returns responses in different data formats" + }, + { + "name": "Dynamic data", + "description": "Generates random and dynamic data" + }, + { + "name": "Cookies", + "description": "Creates, reads and deletes Cookies" + }, + { + "name": "Images", + "description": "Returns different image formats" + }, + { + "name": "Redirects", + "description": "Returns different redirect responses" + }, + { + "name": "Anything", + "description": "Returns anything that is passed to request" + } + ], + "paths": { + "/absolute-redirect/{n}": { + "get": { + "tags": [ + "Redirects" + ], + "summary": "Absolutely 302 Redirects n times.", + "parameters": [ + { + "name": "n", + "in": "path", + "required": true, + "schema": {} + } + ], + "responses": { + "302": { + "description": "A redirection.", + "content": {} + } + } + } + }, + "/anything": { + "get": { + "tags": [ + "Anything" + ], + "summary": "Returns anything passed in request data.", + "responses": { + "200": { + "description": "Anything passed in request", + "content": {} + } + } + }, + "put": { + "tags": [ + "Anything" + ], + "summary": "Returns anything passed in request data.", + "responses": { + "200": { + "description": "Anything passed in request", + "content": {} + } + } + }, + "post": { + "tags": [ + "Anything" + ], + "summary": "Returns anything passed in request data.", + "responses": { + "200": { + "description": "Anything passed in request", + "content": {} + } + } + }, + "delete": { + "tags": [ + "Anything" + ], + "summary": "Returns anything passed in request data.", + "responses": { + "200": { + "description": "Anything passed in request", + "content": {} + } + } + }, + "patch": { + "tags": [ + "Anything" + ], + "summary": "Returns anything passed in request data.", + "responses": { + "200": { + "description": "Anything passed in request", + "content": {} + } + } + } + }, + "/anything/{anything}": { + "get": { + "parameters": [ + { + "name": "anything", + "in": "path", + "required": true, + "schema": { + "type": "string" + } + } + ], + "tags": [ + "Anything" + ], + "summary": "Returns anything passed in request data.", + "responses": { + "200": { + "description": "Anything passed in request", + "content": {} + } + } + }, + "put": { + "parameters": [ + { + "name": "anything", + "in": "path", + "required": true, + "schema": { + "type": "string" + } + } + ], + "tags": [ + "Anything" + ], + "summary": "Returns anything passed in request data.", + "responses": { + "200": { + "description": "Anything passed in request", + "content": {} + } + } + }, + "post": { + "parameters": [ + { + "name": "anything", + "in": "path", + "required": true, + "schema": { + "type": "string" + } + } + ], + "tags": [ + "Anything" + ], + "summary": "Returns anything passed in request data.", + "responses": { + "200": { + "description": "Anything passed in request", + "content": {} + } + } + }, + "delete": { + "parameters": [ + { + "name": "anything", + "in": "path", + "required": true, + "schema": { + "type": "string" + } + } + ], + "tags": [ + "Anything" + ], + "summary": "Returns anything passed in request data.", + "responses": { + "200": { + "description": "Anything passed in request", + "content": {} + } + } + }, + "patch": { + "parameters": [ + { + "name": "anything", + "in": "path", + "required": true, + "schema": { + "type": "string" + } + } + ], + "tags": [ + "Anything" + ], + "summary": "Returns anything passed in request data.", + "responses": { + "200": { + "description": "Anything passed in request", + "content": {} + } + } + } + }, + "/base64/{value}": { + "get": { + "tags": [ + "Dynamic data" + ], + "summary": "Decodes base64url-encoded string.", + "parameters": [ + { + "name": "value", + "in": "path", + "required": true, + "schema": { + "type": "string", + "default": "SFRUUEJJTiBpcyBhd2Vzb21l" + } + } + ], + "responses": { + "200": { + "description": "Decoded base64 content.", + "content": {} + } + } + } + }, + "/basic-auth/{user}/{passwd}": { + "get": { + "tags": [ + "Auth" + ], + "summary": "Prompts the user for authorization using HTTP Basic Auth.", + "parameters": [ + { + "name": "user", + "in": "path", + "required": true, + "schema": { + "type": "string" + } + }, + { + "name": "passwd", + "in": "path", + "required": true, + "schema": { + "type": "string" + } + } + ], + "responses": { + "200": { + "description": "Sucessful authentication.", + "content": {} + }, + "401": { + "description": "Unsuccessful authentication.", + "content": {} + } + } + } + }, + "/bearer": { + "get": { + "tags": [ + "Auth" + ], + "summary": "Prompts the user for authorization using bearer authentication.", + "parameters": [ + { + "name": "Authorization", + "in": "header", + "schema": {} + } + ], + "responses": { + "200": { + "description": "Sucessful authentication.", + "content": {} + }, + "401": { + "description": "Unsuccessful authentication.", + "content": {} + } + } + } + }, + "/brotli": { + "get": { + "tags": [ + "Response formats" + ], + "summary": "Returns Brotli-encoded data.", + "responses": { + "200": { + "description": "Brotli-encoded data.", + "content": {} + } + } + } + }, + "/bytes/{n}": { + "get": { + "tags": [ + "Dynamic data" + ], + "summary": "Returns n random bytes generated with given seed", + "parameters": [ + { + "name": "n", + "in": "path", + "required": true, + "schema": {} + } + ], + "responses": { + "200": { + "description": "Bytes.", + "content": {} + } + } + } + }, + "/cache": { + "get": { + "tags": [ + "Response inspection" + ], + "summary": "Returns a 304 if an If-Modified-Since header or If-None-Match is present. Returns the same as a GET otherwise.", + "parameters": [ + { + "name": "If-Modified-Since", + "in": "header", + "schema": {} + }, + { + "name": "If-None-Match", + "in": "header", + "schema": {} + } + ], + "responses": { + "200": { + "description": "Cached response", + "content": {} + }, + "304": { + "description": "Modified", + "content": {} + } + } + } + }, + "/cache/{value}": { + "get": { + "tags": [ + "Response inspection" + ], + "summary": "Sets a Cache-Control header for n seconds.", + "parameters": [ + { + "name": "value", + "in": "path", + "required": true, + "schema": { + "type": "integer" + } + } + ], + "responses": { + "200": { + "description": "Cache control set", + "content": {} + } + } + } + }, + "/cookies": { + "get": { + "tags": [ + "Cookies" + ], + "summary": "Returns cookie data.", + "responses": { + "200": { + "description": "Set cookies.", + "content": {} + } + } + } + }, + "/cookies/delete": { + "get": { + "tags": [ + "Cookies" + ], + "summary": "Deletes cookie(s) as provided by the query string and redirects to cookie list.", + "parameters": [ + { + "name": "freeform", + "in": "query", + "allowEmptyValue": true, + "schema": {} + } + ], + "responses": { + "200": { + "description": "Redirect to cookie list", + "content": {} + } + } + } + }, + "/cookies/set": { + "get": { + "tags": [ + "Cookies" + ], + "summary": "Sets cookie(s) as provided by the query string and redirects to cookie list.", + "parameters": [ + { + "name": "freeform", + "in": "query", + "allowEmptyValue": true, + "schema": {} + } + ], + "responses": { + "200": { + "description": "Redirect to cookie list", + "content": {} + } + } + } + }, + "/cookies/set/{name}/{value}": { + "get": { + "tags": [ + "Cookies" + ], + "summary": "Sets a cookie and redirects to cookie list.", + "parameters": [ + { + "name": "name", + "in": "path", + "required": true, + "schema": { + "type": "string" + } + }, + { + "name": "value", + "in": "path", + "required": true, + "schema": { + "type": "string" + } + } + ], + "responses": { + "200": { + "description": "Set cookies and redirects to cookie list.", + "content": {} + } + } + } + }, + "/deflate": { + "get": { + "tags": [ + "Response formats" + ], + "summary": "Returns Deflate-encoded data.", + "responses": { + "200": { + "description": "Defalte-encoded data.", + "content": {} + } + } + } + }, + "/delay/{delay}": { + "get": { + "tags": [ + "Dynamic data" + ], + "summary": "Returns a delayed response (max of 10 seconds).", + "parameters": [ + { + "name": "delay", + "in": "path", + "required": true, + "schema": {} + } + ], + "responses": { + "200": { + "description": "A delayed response.", + "content": {} + } + } + }, + "put": { + "tags": [ + "Dynamic data" + ], + "summary": "Returns a delayed response (max of 10 seconds).", + "parameters": [ + { + "name": "delay", + "in": "path", + "required": true, + "schema": {} + } + ], + "responses": { + "200": { + "description": "A delayed response.", + "content": {} + } + } + }, + "post": { + "tags": [ + "Dynamic data" + ], + "summary": "Returns a delayed response (max of 10 seconds).", + "parameters": [ + { + "name": "delay", + "in": "path", + "required": true, + "schema": {} + } + ], + "responses": { + "200": { + "description": "A delayed response.", + "content": {} + } + } + }, + "delete": { + "tags": [ + "Dynamic data" + ], + "summary": "Returns a delayed response (max of 10 seconds).", + "parameters": [ + { + "name": "delay", + "in": "path", + "required": true, + "schema": {} + } + ], + "responses": { + "200": { + "description": "A delayed response.", + "content": {} + } + } + }, + "patch": { + "tags": [ + "Dynamic data" + ], + "summary": "Returns a delayed response (max of 10 seconds).", + "parameters": [ + { + "name": "delay", + "in": "path", + "required": true, + "schema": {} + } + ], + "responses": { + "200": { + "description": "A delayed response.", + "content": {} + } + } + } + }, + "/delete": { + "delete": { + "tags": [ + "HTTP Methods" + ], + "summary": "The request's DELETE parameters.", + "responses": { + "200": { + "description": "The request's DELETE parameters.", + "content": {} + } + } + } + }, + "/deny": { + "get": { + "tags": [ + "Response formats" + ], + "summary": "Returns page denied by robots.txt rules.", + "responses": { + "200": { + "description": "Denied message", + "content": {} + } + } + } + }, + "/digest-auth/{qop}/{user}/{passwd}": { + "get": { + "tags": [ + "Auth" + ], + "summary": "Prompts the user for authorization using Digest Auth.", + "parameters": [ + { + "name": "qop", + "in": "path", + "description": "auth or auth-int", + "required": true, + "schema": { + "type": "string" + } + }, + { + "name": "user", + "in": "path", + "required": true, + "schema": { + "type": "string" + } + }, + { + "name": "passwd", + "in": "path", + "required": true, + "schema": { + "type": "string" + } + } + ], + "responses": { + "200": { + "description": "Sucessful authentication.", + "content": {} + }, + "401": { + "description": "Unsuccessful authentication.", + "content": {} + } + } + } + }, + "/digest-auth/{qop}/{user}/{passwd}/{algorithm}": { + "get": { + "tags": [ + "Auth" + ], + "summary": "Prompts the user for authorization using Digest Auth + Algorithm.", + "parameters": [ + { + "name": "qop", + "in": "path", + "description": "auth or auth-int", + "required": true, + "schema": { + "type": "string" + } + }, + { + "name": "user", + "in": "path", + "required": true, + "schema": { + "type": "string" + } + }, + { + "name": "passwd", + "in": "path", + "required": true, + "schema": { + "type": "string" + } + }, + { + "name": "algorithm", + "in": "path", + "description": "MD5, SHA-256, SHA-512", + "required": true, + "schema": { + "type": "string", + "default": "MD5" + } + } + ], + "responses": { + "200": { + "description": "Sucessful authentication.", + "content": {} + }, + "401": { + "description": "Unsuccessful authentication.", + "content": {} + } + } + } + }, + "/digest-auth/{qop}/{user}/{passwd}/{algorithm}/{stale_after}": { + "get": { + "tags": [ + "Auth" + ], + "summary": "Prompts the user for authorization using Digest Auth + Algorithm.", + "description": "allow settings the stale_after argument.\n", + "parameters": [ + { + "name": "qop", + "in": "path", + "description": "auth or auth-int", + "required": true, + "schema": { + "type": "string" + } + }, + { + "name": "user", + "in": "path", + "required": true, + "schema": { + "type": "string" + } + }, + { + "name": "passwd", + "in": "path", + "required": true, + "schema": { + "type": "string" + } + }, + { + "name": "algorithm", + "in": "path", + "description": "MD5, SHA-256, SHA-512", + "required": true, + "schema": { + "type": "string", + "default": "MD5" + } + }, + { + "name": "stale_after", + "in": "path", + "required": true, + "schema": { + "type": "string", + "default": "never" + } + } + ], + "responses": { + "200": { + "description": "Sucessful authentication.", + "content": {} + }, + "401": { + "description": "Unsuccessful authentication.", + "content": {} + } + } + } + }, + "/drip": { + "get": { + "tags": [ + "Dynamic data" + ], + "summary": "Drips data over a duration after an optional initial delay.", + "parameters": [ + { + "name": "duration", + "in": "query", + "description": "The amount of time (in seconds) over which to drip each byte", + "schema": { + "type": "number", + "default": 2 + } + }, + { + "name": "numbytes", + "in": "query", + "description": "The number of bytes to respond with", + "schema": { + "type": "integer", + "default": 10 + } + }, + { + "name": "code", + "in": "query", + "description": "The response code that will be returned", + "schema": { + "type": "integer", + "default": 200 + } + }, + { + "name": "delay", + "in": "query", + "description": "The amount of time (in seconds) to delay before responding", + "schema": { + "type": "number", + "default": 2 + } + } + ], + "responses": { + "200": { + "description": "A dripped response.", + "content": {} + } + } + } + }, + "/encoding/utf8": { + "get": { + "tags": [ + "Response formats" + ], + "summary": "Returns a UTF-8 encoded body.", + "responses": { + "200": { + "description": "Encoded UTF-8 content.", + "content": {} + } + } + } + }, + "/etag/{etag}": { + "get": { + "tags": [ + "Response inspection" + ], + "summary": "Assumes the resource has the given etag and responds to If-None-Match and If-Match headers appropriately.", + "parameters": [ + { + "name": "etag", + "in": "path", + "required": true, + "schema": { + "type": "string" + } + }, + { + "name": "If-None-Match", + "in": "header", + "schema": {} + }, + { + "name": "If-Match", + "in": "header", + "schema": {} + } + ], + "responses": { + "200": { + "description": "Normal response", + "content": {} + }, + "412": { + "description": "match", + "content": {} + } + } + } + }, + "/get": { + "get": { + "tags": [ + "HTTP Methods" + ], + "summary": "The request's query parameters.", + "responses": { + "200": { + "description": "The request's query parameters.", + "content": {} + } + } + } + }, + "/gzip": { + "get": { + "tags": [ + "Response formats" + ], + "summary": "Returns GZip-encoded data.", + "responses": { + "200": { + "description": "GZip-encoded data.", + "content": {} + } + } + } + }, + "/headers": { + "get": { + "tags": [ + "Request inspection" + ], + "summary": "Return the incoming request's HTTP headers.", + "responses": { + "200": { + "description": "The request's headers.", + "content": {} + } + } + } + }, + "/hidden-basic-auth/{user}/{passwd}": { + "get": { + "tags": [ + "Auth" + ], + "summary": "Prompts the user for authorization using HTTP Basic Auth.", + "parameters": [ + { + "name": "user", + "in": "path", + "required": true, + "schema": { + "type": "string" + } + }, + { + "name": "passwd", + "in": "path", + "required": true, + "schema": { + "type": "string" + } + } + ], + "responses": { + "200": { + "description": "Sucessful authentication.", + "content": {} + }, + "404": { + "description": "Unsuccessful authentication.", + "content": {} + } + } + } + }, + "/html": { + "get": { + "tags": [ + "Response formats" + ], + "summary": "Returns a simple HTML document.", + "responses": { + "200": { + "description": "An HTML page.", + "content": {} + } + } + } + }, + "/image": { + "get": { + "tags": [ + "Images" + ], + "summary": "Returns a simple image of the type suggest by the Accept header.", + "responses": { + "200": { + "description": "An image.", + "content": {} + } + } + } + }, + "/image/jpeg": { + "get": { + "tags": [ + "Images" + ], + "summary": "Returns a simple JPEG image.", + "responses": { + "200": { + "description": "A JPEG image.", + "content": {} + } + } + } + }, + "/image/png": { + "get": { + "tags": [ + "Images" + ], + "summary": "Returns a simple PNG image.", + "responses": { + "200": { + "description": "A PNG image.", + "content": {} + } + } + } + }, + "/image/svg": { + "get": { + "tags": [ + "Images" + ], + "summary": "Returns a simple SVG image.", + "responses": { + "200": { + "description": "An SVG image.", + "content": {} + } + } + } + }, + "/image/webp": { + "get": { + "tags": [ + "Images" + ], + "summary": "Returns a simple WEBP image.", + "responses": { + "200": { + "description": "A WEBP image.", + "content": {} + } + } + } + }, + "/ip": { + "get": { + "tags": [ + "Request inspection" + ], + "summary": "Returns the requester's IP Address.", + "responses": { + "200": { + "description": "The Requester's IP Address.", + "content": {} + } + } + } + }, + "/json": { + "get": { + "tags": [ + "Response formats" + ], + "summary": "Returns a simple JSON document.", + "responses": { + "200": { + "description": "An JSON document.", + "content": {} + } + } + } + }, + "/links/{n}/{offset}": { + "get": { + "tags": [ + "Dynamic data" + ], + "summary": "Generate a page containing n links to other pages which do the same.", + "parameters": [ + { + "name": "n", + "in": "path", + "required": true, + "schema": {} + }, + { + "name": "offset", + "in": "path", + "required": true, + "schema": {} + } + ], + "responses": { + "200": { + "description": "HTML links.", + "content": {} + } + } + } + }, + "/patch": { + "patch": { + "tags": [ + "HTTP Methods" + ], + "summary": "The request's PATCH parameters.", + "responses": { + "200": { + "description": "The request's PATCH parameters.", + "content": {} + } + } + } + }, + "/post": { + "post": { + "tags": [ + "HTTP Methods" + ], + "summary": "The request's POST parameters.", + "responses": { + "200": { + "description": "The request's POST parameters.", + "content": {} + } + } + } + }, + "/put": { + "put": { + "tags": [ + "HTTP Methods" + ], + "summary": "The request's PUT parameters.", + "responses": { + "200": { + "description": "The request's PUT parameters.", + "content": {} + } + } + } + }, + "/range/{numbytes}": { + "get": { + "tags": [ + "Dynamic data" + ], + "summary": "Streams n random bytes generated with given seed, at given chunk size per packet.", + "parameters": [ + { + "name": "numbytes", + "in": "path", + "required": true, + "schema": {} + } + ], + "responses": { + "200": { + "description": "Bytes.", + "content": {} + } + } + } + }, + "/redirect-to": { + "get": { + "tags": [ + "Redirects" + ], + "summary": "302/3XX Redirects to the given URL.", + "parameters": [ + { + "name": "url", + "in": "query", + "required": true, + "schema": { + "type": "string" + } + }, + { + "name": "status_code", + "in": "query", + "schema": {} + } + ], + "responses": { + "302": { + "description": "A redirection.", + "content": {} + } + } + }, + "put": { + "tags": [ + "Redirects" + ], + "summary": "302/3XX Redirects to the given URL.", + "requestBody": { + "content": { + "multipart/form-data": { + "schema": { + "required": [ + "url" + ], + "properties": { + "url": { + "type": "string" + }, + "status_code": {} + } + } + } + }, + "required": true + }, + "responses": { + "302": { + "description": "A redirection.", + "content": {} + } + } + }, + "post": { + "tags": [ + "Redirects" + ], + "summary": "302/3XX Redirects to the given URL.", + "requestBody": { + "content": { + "multipart/form-data": { + "schema": { + "required": [ + "url" + ], + "properties": { + "url": { + "type": "string" + }, + "status_code": {} + } + } + } + }, + "required": true + }, + "responses": { + "302": { + "description": "A redirection.", + "content": {} + } + } + }, + "delete": { + "tags": [ + "Redirects" + ], + "summary": "302/3XX Redirects to the given URL.", + "responses": { + "302": { + "description": "A redirection.", + "content": {} + } + } + }, + "patch": { + "tags": [ + "Redirects" + ], + "summary": "302/3XX Redirects to the given URL.", + "responses": { + "302": { + "description": "A redirection.", + "content": {} + } + } + } + }, + "/redirect/{n}": { + "get": { + "tags": [ + "Redirects" + ], + "summary": "302 Redirects n times.", + "parameters": [ + { + "name": "n", + "in": "path", + "required": true, + "schema": {} + } + ], + "responses": { + "302": { + "description": "A redirection.", + "content": {} + } + } + } + }, + "/relative-redirect/{n}": { + "get": { + "tags": [ + "Redirects" + ], + "summary": "Relatively 302 Redirects n times.", + "parameters": [ + { + "name": "n", + "in": "path", + "required": true, + "schema": {} + } + ], + "responses": { + "302": { + "description": "A redirection.", + "content": {} + } + } + } + }, + "/response-headers": { + "get": { + "tags": [ + "Response inspection" + ], + "summary": "Returns a set of response headers from the query string.", + "parameters": [ + { + "name": "freeform", + "in": "query", + "allowEmptyValue": true, + "schema": {} + } + ], + "responses": { + "200": { + "description": "Response headers", + "content": {} + } + } + }, + "post": { + "tags": [ + "Response inspection" + ], + "summary": "Returns a set of response headers from the query string.", + "parameters": [ + { + "name": "freeform", + "in": "query", + "allowEmptyValue": true, + "schema": {} + } + ], + "responses": { + "200": { + "description": "Response headers", + "content": {} + } + } + } + }, + "/robots.txt": { + "get": { + "tags": [ + "Response formats" + ], + "summary": "Returns some robots.txt rules.", + "responses": { + "200": { + "description": "Robots file", + "content": {} + } + } + } + }, + "/status/{codes}": { + "get": { + "tags": [ + "Status codes" + ], + "summary": "Return status code or random status code if more than one are given", + "parameters": [ + { + "name": "codes", + "in": "path", + "required": true, + "schema": {} + } + ], + "responses": { + "100": { + "description": "Informational responses", + "content": {} + }, + "200": { + "description": "Success", + "content": {} + }, + "300": { + "description": "Redirection", + "content": {} + }, + "400": { + "description": "Client Errors", + "content": {} + }, + "500": { + "description": "Server Errors", + "content": {} + } + } + }, + "put": { + "tags": [ + "Status codes" + ], + "summary": "Return status code or random status code if more than one are given", + "parameters": [ + { + "name": "codes", + "in": "path", + "required": true, + "schema": {} + } + ], + "responses": { + "100": { + "description": "Informational responses", + "content": {} + }, + "200": { + "description": "Success", + "content": {} + }, + "300": { + "description": "Redirection", + "content": {} + }, + "400": { + "description": "Client Errors", + "content": {} + }, + "500": { + "description": "Server Errors", + "content": {} + } + } + }, + "post": { + "tags": [ + "Status codes" + ], + "summary": "Return status code or random status code if more than one are given", + "parameters": [ + { + "name": "codes", + "in": "path", + "required": true, + "schema": {} + } + ], + "responses": { + "100": { + "description": "Informational responses", + "content": {} + }, + "200": { + "description": "Success", + "content": {} + }, + "300": { + "description": "Redirection", + "content": {} + }, + "400": { + "description": "Client Errors", + "content": {} + }, + "500": { + "description": "Server Errors", + "content": {} + } + } + }, + "delete": { + "tags": [ + "Status codes" + ], + "summary": "Return status code or random status code if more than one are given", + "parameters": [ + { + "name": "codes", + "in": "path", + "required": true, + "schema": {} + } + ], + "responses": { + "100": { + "description": "Informational responses", + "content": {} + }, + "200": { + "description": "Success", + "content": {} + }, + "300": { + "description": "Redirection", + "content": {} + }, + "400": { + "description": "Client Errors", + "content": {} + }, + "500": { + "description": "Server Errors", + "content": {} + } + } + }, + "patch": { + "tags": [ + "Status codes" + ], + "summary": "Return status code or random status code if more than one are given", + "parameters": [ + { + "name": "codes", + "in": "path", + "required": true, + "schema": {} + } + ], + "responses": { + "100": { + "description": "Informational responses", + "content": {} + }, + "200": { + "description": "Success", + "content": {} + }, + "300": { + "description": "Redirection", + "content": {} + }, + "400": { + "description": "Client Errors", + "content": {} + }, + "500": { + "description": "Server Errors", + "content": {} + } + } + } + }, + "/stream-bytes/{n}": { + "get": { + "tags": [ + "Dynamic data" + ], + "summary": "Streams n random bytes generated with given seed, at given chunk size per packet.", + "parameters": [ + { + "name": "n", + "in": "path", + "required": true, + "schema": {} + } + ], + "responses": { + "200": { + "description": "Bytes.", + "content": {} + } + } + } + }, + "/stream/{n}": { + "get": { + "tags": [ + "Dynamic data" + ], + "summary": "Stream n JSON responses", + "parameters": [ + { + "name": "n", + "in": "path", + "required": true, + "schema": {} + } + ], + "responses": { + "200": { + "description": "Streamed JSON responses.", + "content": {} + } + } + } + }, + "/user-agent": { + "get": { + "tags": [ + "Request inspection" + ], + "summary": "Return the incoming requests's User-Agent header.", + "responses": { + "200": { + "description": "The request's User-Agent header.", + "content": {} + } + } + } + }, + "/uuid": { + "get": { + "tags": [ + "Dynamic data" + ], + "summary": "Return a UUID4.", + "responses": { + "200": { + "description": "A UUID4.", + "content": {} + } + } + } + }, + "/xml": { + "get": { + "tags": [ + "Response formats" + ], + "summary": "Returns a simple XML document.", + "responses": { + "200": { + "description": "An XML document.", + "content": {} + } + } + } + } + }, + "components": {} +} \ No newline at end of file diff --git a/resources/dev/wallarm_api.db b/resources/dev/wallarm_api.db new file mode 100644 index 0000000..35e2bb9 Binary files /dev/null and b/resources/dev/wallarm_api.db differ diff --git a/resources/test/database/wallarm_api.db b/resources/test/database/wallarm_api.db new file mode 100644 index 0000000..1c87e05 Binary files /dev/null and b/resources/test/database/wallarm_api.db differ