From 4f7a2f641e0d851c7723a94fbe692beb63e10aab Mon Sep 17 00:00:00 2001 From: mark4z Date: Wed, 22 Nov 2023 23:36:58 +0800 Subject: [PATCH] [feature] cors support options request --- pixiu/pkg/common/constant/http.go | 10 +++++---- pixiu/pkg/common/router/router.go | 2 +- pixiu/pkg/filter/cors/cors.go | 36 ++++++++++++++++++++----------- 3 files changed, 30 insertions(+), 18 deletions(-) diff --git a/pixiu/pkg/common/constant/http.go b/pixiu/pkg/common/constant/http.go index 69fa073fb..1a17040c3 100644 --- a/pixiu/pkg/common/constant/http.go +++ b/pixiu/pkg/common/constant/http.go @@ -21,6 +21,7 @@ const ( HeaderKeyContextType = "Content-Type" HeaderKeyAccessControlAllowOrigin = "Access-Control-Allow-Origin" + HeaderKeyAccessControlAllowHeaders = "Access-Control-Allow-Headers" HeaderKeyAccessControlExposeHeaders = "Access-Control-Expose-Headers" HeaderKeyAccessControlAllowMethods = "Access-Control-Allow-Methods" HeaderKeyAccessControlMaxAge = "Access-Control-Max-Age" @@ -48,10 +49,11 @@ const ( ) const ( - Get = "GET" - Put = "PUT" - Post = "POST" - Delete = "DELETE" + Get = "GET" + Put = "PUT" + Post = "POST" + Delete = "DELETE" + Options = "OPTIONS" ) const ( diff --git a/pixiu/pkg/common/router/router.go b/pixiu/pkg/common/router/router.go index c807d3bd7..d8cbc37d5 100644 --- a/pixiu/pkg/common/router/router.go +++ b/pixiu/pkg/common/router/router.go @@ -137,7 +137,7 @@ func (rm *RouterCoordinator) OnAddRouter(r *model.Router) { rm.rw.Lock() defer rm.rw.Unlock() if r.Match.Methods == nil { - r.Match.Methods = []string{constant.Get, constant.Put, constant.Delete, constant.Post} + r.Match.Methods = []string{constant.Get, constant.Put, constant.Delete, constant.Post, constant.Options} } isPrefix := r.Match.Prefix != "" for _, method := range r.Match.Methods { diff --git a/pixiu/pkg/filter/cors/cors.go b/pixiu/pkg/filter/cors/cors.go index 4269c7565..b76b87159 100644 --- a/pixiu/pkg/filter/cors/cors.go +++ b/pixiu/pkg/filter/cors/cors.go @@ -25,6 +25,7 @@ import ( "github.com/apache/dubbo-go-pixiu/pixiu/pkg/common/constant" "github.com/apache/dubbo-go-pixiu/pixiu/pkg/common/extension/filter" "github.com/apache/dubbo-go-pixiu/pixiu/pkg/context/http" + "github.com/apache/dubbo-go-pixiu/pkg/http/headers" ) const ( @@ -79,41 +80,50 @@ func (factory *FilterFactory) PrepareFilterChain(ctx *http.HttpContext, chain fi } func (f *Filter) Decode(ctx *http.HttpContext) filter.FilterStatus { - f.handleCors(ctx) - return filter.Continue -} - -func (f *Filter) handleCors(ctx *http.HttpContext) { + writer := ctx.Writer c := f.cfg if c == nil { - return + return filter.Continue + } + if ctx.GetHeader(headers.Origin) == "" { + // not a cors request + return filter.Continue } domains := c.AllowOrigin if len(domains) != 0 { for _, domain := range domains { - if ctx.Request.Host == domain || ctx.Request.URL.Host == domain || - ctx.GetHeader("Host") == domain || ctx.GetHeader("host") == domain { - ctx.SourceResp.(*stdHttp.Response).Header.Add(constant.HeaderKeyAccessControlAllowOrigin, domain) + if domain == "*" || ctx.GetHeader("Origin") == domain { + writer.Header().Add(constant.HeaderKeyAccessControlAllowOrigin, domain) + continue } } } if c.AllowHeaders != "" { - ctx.SourceResp.(*stdHttp.Response).Header.Add(constant.HeaderKeyAccessControlExposeHeaders, c.AllowHeaders) + writer.Header().Add(constant.HeaderKeyAccessControlAllowHeaders, c.AllowHeaders) + } + + if c.ExposeHeaders != "" { + writer.Header().Add(constant.HeaderKeyAccessControlExposeHeaders, c.ExposeHeaders) } if c.AllowMethods != "" { - ctx.SourceResp.(*stdHttp.Response).Header.Add(constant.HeaderKeyAccessControlAllowMethods, c.AllowMethods) + writer.Header().Add(constant.HeaderKeyAccessControlAllowMethods, c.AllowMethods) } if c.MaxAge != "" { - ctx.SourceResp.(*stdHttp.Response).Header.Add(constant.HeaderKeyAccessControlMaxAge, c.MaxAge) + writer.Header().Add(constant.HeaderKeyAccessControlMaxAge, c.MaxAge) } if c.AllowCredentials { - ctx.SourceResp.(*stdHttp.Response).Header.Add(constant.HeaderKeyAccessControlAllowCredentials, "true") + writer.Header().Add(constant.HeaderKeyAccessControlAllowCredentials, "true") } + if ctx.Request.Method == stdHttp.MethodOptions { + ctx.SendLocalReply(stdHttp.StatusOK, nil) + return filter.Stop + } + return filter.Continue } func (factory *FilterFactory) Apply() error {