From 8f68d1802889b0805c8825d0f0ac63808fdcea0e Mon Sep 17 00:00:00 2001 From: Eliott Bouhana Date: Thu, 14 Nov 2024 18:04:21 +0100 Subject: [PATCH] internal/appsec: refactor blocking actions to call Block() Signed-off-by: Eliott Bouhana --- internal/appsec/emitter/httpsec/http.go | 58 +++----- .../appsec/emitter/waf/actions/actions.go | 2 +- .../emitter/waf/actions/actions_test.go | 6 +- internal/appsec/emitter/waf/actions/block.go | 139 ++++++++++++------ .../emitter/waf/actions/http_redirect.go | 24 ++- internal/appsec/waf_test.go | 56 ++++++- 6 files changed, 186 insertions(+), 99 deletions(-) diff --git a/internal/appsec/emitter/httpsec/http.go b/internal/appsec/emitter/httpsec/http.go index 07f6dd9ba8..6adcfbf8fa 100644 --- a/internal/appsec/emitter/httpsec/http.go +++ b/internal/appsec/emitter/httpsec/http.go @@ -56,20 +56,31 @@ type ( func (HandlerOperationArgs) IsArgOf(*HandlerOperation) {} func (HandlerOperationRes) IsResultOf(*HandlerOperation) {} -func StartOperation(ctx context.Context, args HandlerOperationArgs) (*HandlerOperation, *atomic.Pointer[actions.BlockHTTP], context.Context) { - wafOp, ctx := waf.StartContextOperation(ctx) +func StartOperation(w http.ResponseWriter, r *http.Request, pathParams map[string]string, opts *Config) (*HandlerOperation, *atomic.Bool, context.Context) { + wafOp, ctx := waf.StartContextOperation(r.Context()) op := &HandlerOperation{ Operation: dyngo.NewOperation(wafOp), ContextOperation: wafOp, } - // We need to use an atomic pointer to store the action because the action may be created asynchronously in the future - var action atomic.Pointer[actions.BlockHTTP] + var blocked atomic.Bool dyngo.OnData(op, func(a *actions.BlockHTTP) { - action.Store(a) + a.Handler.ServeHTTP(w, r) + for _, f := range opts.OnBlock { + f() + } }) - return op, &action, dyngo.StartAndRegisterOperation(ctx, op, args) + return op, &blocked, dyngo.StartAndRegisterOperation(ctx, op, HandlerOperationArgs{ + Method: r.Method, + RequestURI: r.RequestURI, + Host: r.Host, + RemoteAddr: r.RemoteAddr, + Headers: r.Header, + Cookies: makeCookies(r.Cookies()), + QueryParams: r.URL.Query(), + PathParams: pathParams, + }) } // Finish the HTTP handler operation and its children operations and write everything to the service entry span. @@ -125,18 +136,7 @@ func BeforeHandle( opts.ResponseHeaderCopier = defaultWrapHandlerConfig.ResponseHeaderCopier } - op, blockAtomic, ctx := StartOperation(r.Context(), HandlerOperationArgs{ - Method: r.Method, - RequestURI: r.RequestURI, - Host: r.Host, - RemoteAddr: r.RemoteAddr, - Headers: r.Header, - Cookies: makeCookies(r.Cookies()), - QueryParams: r.URL.Query(), - PathParams: pathParams, - }) - tr := r.WithContext(ctx) - var blocked atomic.Bool + op, blocked, ctx := StartOperation(w, r, pathParams, opts) afterHandle := func() { var statusCode int @@ -147,28 +147,9 @@ func BeforeHandle( Headers: opts.ResponseHeaderCopier(w), StatusCode: statusCode, }, span) - - if blockPtr := blockAtomic.Swap(nil); blockPtr != nil { - blockPtr.Handler.ServeHTTP(w, tr) - blocked.Store(true) - } - - // Execute the onBlock functions to make sure blocking works properly - // in case we are instrumenting the Gin framework - if blocked.Load() { - for _, f := range opts.OnBlock { - f() - } - } - } - - if blockPtr := blockAtomic.Swap(nil); blockPtr != nil { - // handler is replaced - blockPtr.Handler.ServeHTTP(w, tr) - blocked.Store(true) } - return w, tr, afterHandle, blocked.Load() + return w, r.WithContext(ctx), afterHandle, blocked.Load() } // WrapHandler wraps the given HTTP handler with the abstract HTTP operation defined by HandlerOperationArgs and @@ -177,7 +158,6 @@ func BeforeHandle( // It is a specific patch meant for Gin, for which we must abort the // context since it uses a queue of handlers and it's the only way to make // sure other queued handlers don't get executed. -// TODO: this patch must be removed/improved when we rework our actions/operations system func WrapHandler(handler http.Handler, span ddtrace.Span, pathParams map[string]string, opts *Config) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { tw, tr, afterHandle, handled := BeforeHandle(w, r, span, pathParams, opts) diff --git a/internal/appsec/emitter/waf/actions/actions.go b/internal/appsec/emitter/waf/actions/actions.go index 4eabcfaff6..1f9cd30bd5 100644 --- a/internal/appsec/emitter/waf/actions/actions.go +++ b/internal/appsec/emitter/waf/actions/actions.go @@ -25,7 +25,7 @@ var actionHandlers = map[string]actionHandler{} func registerActionHandler(aType string, handler actionHandler) { if _, ok := actionHandlers[aType]; ok { - log.Warn("appsec: action type `%s` already registered", aType) + log.Debug("appsec: action type `%s` already registered", aType) return } actionHandlers[aType] = handler diff --git a/internal/appsec/emitter/waf/actions/actions_test.go b/internal/appsec/emitter/waf/actions/actions_test.go index 1e77c8e2d1..434c091057 100644 --- a/internal/appsec/emitter/waf/actions/actions_test.go +++ b/internal/appsec/emitter/waf/actions/actions_test.go @@ -17,9 +17,9 @@ import ( func TestNewHTTPBlockRequestAction(t *testing.T) { mux := http.NewServeMux() srv := httptest.NewServer(mux) - mux.HandleFunc("/json", newHTTPBlockRequestAction(403, "json").ServeHTTP) - mux.HandleFunc("/html", newHTTPBlockRequestAction(403, "html").ServeHTTP) - mux.HandleFunc("/auto", newHTTPBlockRequestAction(403, "auto").ServeHTTP) + mux.HandleFunc("/json", newHTTPBlockRequestAction(403, BlockingTemplateJSON).ServeHTTP) + mux.HandleFunc("/html", newHTTPBlockRequestAction(403, BlockingTemplateHTML).ServeHTTP) + mux.HandleFunc("/auto", newHTTPBlockRequestAction(403, BlockingTemplateAuto).ServeHTTP) defer srv.Close() t.Run("json", func(t *testing.T) { diff --git a/internal/appsec/emitter/waf/actions/block.go b/internal/appsec/emitter/waf/actions/block.go index ae802b60bd..901141d159 100644 --- a/internal/appsec/emitter/waf/actions/block.go +++ b/internal/appsec/emitter/waf/actions/block.go @@ -47,15 +47,23 @@ func init() { registerActionHandler("block_request", NewBlockAction) } +const ( + BlockingTemplateJSON blockingTemplateType = "json" + BlockingTemplateHTML blockingTemplateType = "html" + BlockingTemplateAuto blockingTemplateType = "auto" +) + type ( + blockingTemplateType string + // blockActionParams are the dynamic parameters to be provided to a "block_request" // action type upon invocation blockActionParams struct { // GRPCStatusCode is the gRPC status code to be returned. Since 0 is the OK status, the value is nullable to // be able to distinguish between unset and defaulting to Abort (10), or set to OK (0). - GRPCStatusCode *int `mapstructure:"grpc_status_code,omitempty"` - StatusCode int `mapstructure:"status_code"` - Type string `mapstructure:"type,omitempty"` + GRPCStatusCode *int `mapstructure:"grpc_status_code,omitempty"` + StatusCode int `mapstructure:"status_code"` + Type blockingTemplateType `mapstructure:"type,omitempty"` } // GRPCWrapper is an opaque prototype abstraction for a gRPC handler (to avoid importing grpc) // that returns a status code and an error @@ -70,6 +78,12 @@ type ( BlockHTTP struct { http.Handler } + + HTTPBlockHandlerConfig struct { + Template []byte + ContentType string + StatusCode int + } ) func (a *BlockGRPC) EmitData(op dyngo.Operation) { @@ -83,32 +97,28 @@ func (a *BlockHTTP) EmitData(op dyngo.Operation) { } func newGRPCBlockRequestAction(status int) *BlockGRPC { - return &BlockGRPC{GRPCWrapper: newGRPCBlockHandler(status)} -} - -func newGRPCBlockHandler(status int) GRPCWrapper { - return func() (uint32, error) { + return &BlockGRPC{GRPCWrapper: func() (uint32, error) { return uint32(status), &events.BlockingSecurityEvent{} - } + }} } func blockParamsFromMap(params map[string]any) (blockActionParams, error) { grpcCode := 10 - p := blockActionParams{ - Type: "auto", + parsedParams := blockActionParams{ + Type: BlockingTemplateAuto, StatusCode: 403, GRPCStatusCode: &grpcCode, } - if err := mapstructure.WeakDecode(params, &p); err != nil { - return p, err + if err := mapstructure.WeakDecode(params, &parsedParams); err != nil { + return parsedParams, err } - if p.GRPCStatusCode == nil { - p.GRPCStatusCode = &grpcCode + if parsedParams.GRPCStatusCode == nil { + parsedParams.GRPCStatusCode = &grpcCode } - return p, nil + return parsedParams, nil } // NewBlockAction creates an action for the "block_request" action type @@ -124,38 +134,79 @@ func NewBlockAction(params map[string]any) []Action { } } -func newHTTPBlockRequestAction(status int, template string) *BlockHTTP { - return &BlockHTTP{Handler: newBlockHandler(status, template)} +func newHTTPBlockRequestAction(statusCode int, template blockingTemplateType) *BlockHTTP { + return &BlockHTTP{Handler: http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { + template := template + if template == BlockingTemplateAuto { + template = blockingTemplateTypeFromHeaders(request.Header) + } + + blocker, found := UnwrapBlocker(writer) + if !found { + return + } + + defer blocker() + + // Remove all headers to avoid leaking information + for key := range writer.Header() { + delete(writer.Header(), key) + } + + writer.Header().Set("Content-Type", template.ContentType()) + writer.WriteHeader(statusCode) + writer.Write(template.Template()) + })} +} + +func blockingTemplateTypeFromHeaders(headers http.Header) blockingTemplateType { + hdr := headers.Get("Accept") + htmlIdx := strings.Index(hdr, "text/html") + jsonIdx := strings.Index(hdr, "application/json") + // Switch to html handler if text/html comes before application/json in the Accept header + if htmlIdx != -1 && (jsonIdx == -1 || htmlIdx < jsonIdx) { + return BlockingTemplateHTML + } + + return BlockingTemplateJSON +} + +func (typ blockingTemplateType) Template() []byte { + if typ == BlockingTemplateHTML { + return blockedTemplateHTML + } + + return blockedTemplateJSON } -// newBlockHandler creates, initializes and returns a new BlockRequestAction -func newBlockHandler(status int, template string) http.Handler { - htmlHandler := newBlockRequestHandler(status, "text/html", blockedTemplateHTML) - jsonHandler := newBlockRequestHandler(status, "application/json", blockedTemplateJSON) - switch template { - case "json": - return jsonHandler - case "html": - return htmlHandler - default: - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - h := jsonHandler - hdr := r.Header.Get("Accept") - htmlIdx := strings.Index(hdr, "text/html") - jsonIdx := strings.Index(hdr, "application/json") - // Switch to html handler if text/html comes before application/json in the Accept header - if htmlIdx != -1 && (jsonIdx == -1 || htmlIdx < jsonIdx) { - h = htmlHandler - } - h.ServeHTTP(w, r) - }) +func (typ blockingTemplateType) ContentType() string { + if typ == BlockingTemplateHTML { + return "text/html" } + + return "application/json" } -func newBlockRequestHandler(status int, ct string, payload []byte) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.Header().Set("Content-Type", ct) - w.WriteHeader(status) - w.Write(payload) +// UnwrapBlocker unwraps the right struct methods from contrib/internal/httptrace.responseWriter +// and returns if the Block() function and if if was found. +func UnwrapBlocker(writer http.ResponseWriter) (func(), bool) { + // this is a copy of the contrib/internal/httptrace.responseWriter interface + wrapped, ok := writer.(interface { + Block() + Status() int }) + if !ok { + // Somehow the http handler ended up not being instrumented, we cannot block the request + log.Debug("appsec: could not block request, response writer is not wrapped") + return nil, false + } + + if wrapped.Status() != 0 { + // The request has already been started to be handled by the user code or was blocked before + // if we try to write a response now, worse case scenario the request will be malformed + // treat it as if we can't block it + return nil, false + } + + return wrapped.Block, true } diff --git a/internal/appsec/emitter/waf/actions/http_redirect.go b/internal/appsec/emitter/waf/actions/http_redirect.go index 3cdca4c818..04122649a2 100644 --- a/internal/appsec/emitter/waf/actions/http_redirect.go +++ b/internal/appsec/emitter/waf/actions/http_redirect.go @@ -25,9 +25,9 @@ func init() { } func redirectParamsFromMap(params map[string]any) (redirectActionParams, error) { - var p redirectActionParams - err := mapstructure.WeakDecode(params, &p) - return p, err + var parsedParams redirectActionParams + err := mapstructure.WeakDecode(params, &parsedParams) + return parsedParams, err } func newRedirectRequestAction(status int, loc string) *BlockHTTP { @@ -38,9 +38,23 @@ func newRedirectRequestAction(status int, loc string) *BlockHTTP { // If location is not set we fall back on a default block action if loc == "" { - return &BlockHTTP{Handler: newBlockHandler(http.StatusForbidden, string(blockedTemplateJSON))} + return newHTTPBlockRequestAction(http.StatusForbidden, BlockingTemplateAuto) } - return &BlockHTTP{Handler: http.RedirectHandler(loc, status)} + + redirectHandler := http.RedirectHandler(loc, status) + return &BlockHTTP{Handler: http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { + blocker, found := UnwrapBlocker(writer) + if !found { + return + } + + defer blocker() + + for key := range request.Header { + delete(request.Header, key) + } + redirectHandler.ServeHTTP(writer, request) + })} } // NewRedirectAction creates an action for the "redirect_request" action type diff --git a/internal/appsec/waf_test.go b/internal/appsec/waf_test.go index 7e169ffb7a..8b32f2dc83 100644 --- a/internal/appsec/waf_test.go +++ b/internal/appsec/waf_test.go @@ -335,15 +335,22 @@ func TestBlocking(t *testing.T) { w.Write([]byte("Hello World!\n")) }) mux.HandleFunc("/user", func(w http.ResponseWriter, r *http.Request) { - if err := pAppsec.SetUser(r.Context(), r.Header.Get("test-usr")); err != nil { + if r.Header.Get("write-before-block") != "" { + w.WriteHeader(204) + } + + if err := pAppsec.SetUser(r.Context(), r.Header.Get("test-usr")); err != nil && r.Header.Get("write-after-block") == "" { return } w.Write([]byte("Hello World!\n")) }) mux.HandleFunc("/body", func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("write-before-block") != "" { + w.WriteHeader(204) + } buf := new(strings.Builder) io.Copy(buf, r.Body) - if err := pAppsec.MonitorParsedHTTPBody(r.Context(), buf.String()); err != nil { + if err := pAppsec.MonitorParsedHTTPBody(r.Context(), buf.String()); err != nil && r.Header.Get("write-after-block") == "" { return } w.Write([]byte("Hello World!\n")) @@ -395,6 +402,20 @@ func TestBlocking(t *testing.T) { status: 403, ruleMatch: userBlockingRule, }, + { + name: "user/no-write-after-block", + headers: map[string]string{"test-usr": "blocked-user-1", "write-after-block": "true"}, + endpoint: "/user", + status: 403, + ruleMatch: userBlockingRule, + }, + { + name: "user/cannot-block-because-write-before-block", + headers: map[string]string{"test-usr": "blocked-user-1", "write-before-block": "true"}, + endpoint: "/user", + status: 204, + ruleMatch: userBlockingRule, + }, // This test checks that IP blocking happens BEFORE user blocking, since user blocking needs the request handler // to be invoked while IP blocking doesn't { @@ -417,6 +438,22 @@ func TestBlocking(t *testing.T) { reqBody: "$globals", ruleMatch: bodyBlockingRule, }, + { + name: "body/no-write-after-block", + headers: map[string]string{"write-after-block": "true"}, + endpoint: "/body", + status: 403, + reqBody: "$globals", + ruleMatch: bodyBlockingRule, + }, + { + name: "body/cannot-block-because-write-before-block", + headers: map[string]string{"write-before-block": "true"}, + endpoint: "/body", + status: 204, + reqBody: "$globals", + ruleMatch: bodyBlockingRule, + }, } { t.Run(tc.name, func(t *testing.T) { mt := mocktracer.Start() @@ -430,12 +467,17 @@ func TestBlocking(t *testing.T) { require.NoError(t, err) defer res.Body.Close() require.Equal(t, tc.status, res.StatusCode) - b, err := io.ReadAll(res.Body) + body, err := io.ReadAll(res.Body) require.NoError(t, err) - if tc.status == 200 { - require.Equal(t, "Hello World!\n", string(b)) - } else { - require.NotEqual(t, "Hello World!\n", string(b)) + switch tc.status { + case 200: + require.Equal(t, "Hello World!\n", string(body)) + case 204: + require.Empty(t, string(body)) + case 403: + require.Contains(t, string(body), "Security provided by Datadog") + default: + panic("unexpected status code") } if tc.ruleMatch != "" { spans := mt.FinishedSpans()