diff --git a/cmd/hz/app/app.go b/cmd/hz/app/app.go index 116df93cc..739db8823 100644 --- a/cmd/hz/app/app.go +++ b/cmd/hz/app/app.go @@ -191,6 +191,7 @@ func Init() *cli.App { customLayoutData := cli.StringFlag{Name: "customize_layout_data_path", Usage: "Specify the path for layout template render data.", Destination: &globalArgs.CustomizeLayoutData} customPackage := cli.StringFlag{Name: "customize_package", Usage: "Specify the path for package template.", Destination: &globalArgs.CustomizePackage} handlerByMethod := cli.BoolFlag{Name: "handler_by_method", Usage: "Generate a separate handler file for each method.", Destination: &globalArgs.HandlerByMethod} + trimGoPackage := cli.StringFlag{Name: "trim_gopackage", Aliases: []string{"trim_pkg"}, Usage: "Trim the prefix of go_package for protobuf.", Destination: &globalArgs.TrimGoPackage} // app app := cli.NewApp() @@ -225,6 +226,7 @@ func Init() *cli.App { &thriftOptionsFlag, &protoOptionsFlag, &optPkgFlag, + &trimGoPackage, &noRecurseFlag, &forceNewFlag, &enableExtendsFlag, @@ -261,6 +263,7 @@ func Init() *cli.App { &thriftOptionsFlag, &protoOptionsFlag, &optPkgFlag, + &trimGoPackage, &noRecurseFlag, &enableExtendsFlag, &sortRouterFlag, @@ -291,6 +294,7 @@ func Init() *cli.App { &thriftOptionsFlag, &protoOptionsFlag, &noRecurseFlag, + &trimGoPackage, &jsonEnumStrFlag, &unsetOmitemptyFlag, @@ -318,6 +322,7 @@ func Init() *cli.App { &protoOptionsFlag, &noRecurseFlag, &enableExtendsFlag, + &trimGoPackage, &jsonEnumStrFlag, &queryEnumIntFlag, diff --git a/cmd/hz/config/argument.go b/cmd/hz/config/argument.go index 4e6d75ed5..48e381297 100644 --- a/cmd/hz/config/argument.go +++ b/cmd/hz/config/argument.go @@ -41,12 +41,13 @@ type Argument struct { BaseDomain string // request domain ForceClientDir string // client dir (not use namespace as a subpath) - IdlType string // idl type - IdlPaths []string // master idl path - RawOptPkg []string // user-specified package import path - OptPkgMap map[string]string - Includes []string - PkgPrefix string + IdlType string // idl type + IdlPaths []string // master idl path + RawOptPkg []string // user-specified package import path + OptPkgMap map[string]string + Includes []string + PkgPrefix string + TrimGoPackage string // trim go_package for protobuf, avoid to generate multiple directory Gopath string // $GOPATH Gosrc string // $GOPATH/src diff --git a/cmd/hz/generator/handler.go b/cmd/hz/generator/handler.go index eeab0cf7a..6b353c626 100644 --- a/cmd/hz/generator/handler.go +++ b/cmd/hz/generator/handler.go @@ -97,7 +97,7 @@ func (pkgGen *HttpPackageGenerator) genHandler(pkg *HttpPackage, handlerDir, han tmpHandlerPackage := handlerPackage if len(s.ServiceGenDir) != 0 { tmpHandlerDir = s.ServiceGenDir - tmpHandlerPackage = util.SubPackage(pkgGen.ProjPackage, tmpHandlerDir) + tmpHandlerPackage = util.SubPackage(pkgGen.ProjPackage, strings.TrimPrefix(tmpHandlerDir, "/")) } handler = Handler{ FilePath: filepath.Join(tmpHandlerDir, util.ToSnakeCase(s.Name)+".go"), diff --git a/cmd/hz/generator/router.go b/cmd/hz/generator/router.go index 37366ea23..c5e5cbba9 100644 --- a/cmd/hz/generator/router.go +++ b/cmd/hz/generator/router.go @@ -225,6 +225,11 @@ func (routerNode *RouterNode) Insert(name string, method *HttpMethod, handlerTyp method.RefPackageAlias = c.HandlerPackageAlias } else { // generate handler by service c.Handler = handlerType + "." + method.Name + if len(method.RefPackage) != 0 { + c.Handler = method.RefPackageAlias + "." + method.Name + c.HandlerPackageAlias = method.RefPackageAlias + c.HandlerPackage = method.RefPackage + } } c.HttpMethod = getHttpMethod(method.HTTPMethod) } @@ -409,15 +414,15 @@ func (pkgGen *HttpPackageGenerator) genRouter(pkg *HttpPackage, root *RouterNode Router: root, } - if pkgGen.HandlerByMethod { - handlerMap := make(map[string]string, 1) - hook := func(layer int, node *RouterNode) error { - if len(node.HandlerPackage) != 0 { - handlerMap[node.HandlerPackageAlias] = node.HandlerPackage - } - return nil + handlerMap := make(map[string]string) + hook := func(layer int, node *RouterNode) error { + if len(node.HandlerPackage) != 0 { + handlerMap[node.HandlerPackageAlias] = node.HandlerPackage } - root.DFS(0, hook) + return nil + } + root.DFS(0, hook) + if len(handlerMap) != 0 { router.HandlerPackages = handlerMap } diff --git a/cmd/hz/meta/const.go b/cmd/hz/meta/const.go index c6018814b..a696be8e7 100644 --- a/cmd/hz/meta/const.go +++ b/cmd/hz/meta/const.go @@ -19,7 +19,7 @@ package meta import "runtime" // Version hz version -const Version = "v0.9.0" +const Version = "v0.9.1" const DefaultServiceName = "hertz_service" diff --git a/cmd/hz/protobuf/ast.go b/cmd/hz/protobuf/ast.go index 3b4759a2c..3ccfc9b51 100644 --- a/cmd/hz/protobuf/ast.go +++ b/cmd/hz/protobuf/ast.go @@ -353,6 +353,10 @@ func parseAnnotationToClient(clientMethod *generator.ClientMethod, gen *protogen val := fileAnnos.(string) clientMethod.FormFileCode += fmt.Sprintf("%q: req.Get%s(),\n", val, f.GoName) } + if proto.HasExtension(f.Desc.Options(), api.E_Cookie) { + hasAnnotation = true + // cookie do nothing + } if !hasAnnotation && strings.EqualFold(clientMethod.HTTPMethod, "get") { clientMethod.QueryParamsCode += fmt.Sprintf("%q: req.Get%s(),\n", checkSnakeName(string(f.Desc.Name())), f.GoName) } diff --git a/cmd/hz/protobuf/plugin.go b/cmd/hz/protobuf/plugin.go index d6f775cef..fbae89929 100644 --- a/cmd/hz/protobuf/plugin.go +++ b/cmd/hz/protobuf/plugin.go @@ -198,7 +198,7 @@ func (plugin *Plugin) Response(resp *pluginpb.CodeGeneratorResponse) error { } func (plugin *Plugin) Handle(req *pluginpb.CodeGeneratorRequest, args *config.Argument) error { - plugin.fixGoPackage(req, plugin.PkgMap) + plugin.fixGoPackage(req, plugin.PkgMap, args.TrimGoPackage) // new plugin opts := protogen.Options{} @@ -291,12 +291,16 @@ func (plugin *Plugin) Handle(req *pluginpb.CodeGeneratorRequest, args *config.Ar } // fixGoPackage will update go_package to store all the model files in ${model_dir} -func (plugin *Plugin) fixGoPackage(req *pluginpb.CodeGeneratorRequest, pkgMap map[string]string) { +func (plugin *Plugin) fixGoPackage(req *pluginpb.CodeGeneratorRequest, pkgMap map[string]string, trimGoPackage string) { gopkg := plugin.Package for _, f := range req.ProtoFile { if strings.HasPrefix(f.GetPackage(), "google.protobuf") { continue } + if len(trimGoPackage) != 0 && strings.HasPrefix(f.GetOptions().GetGoPackage(), trimGoPackage) { + *f.Options.GoPackage = strings.TrimPrefix(*f.Options.GoPackage, trimGoPackage) + } + opt := getGoPackage(f, pkgMap) if !strings.Contains(opt, gopkg) { if strings.HasPrefix(opt, "/") { @@ -325,7 +329,9 @@ func (plugin *Plugin) fixModelPathAndPackage(pkg string) (impt, path string) { impt = util.PathToImport(plugin.ModelDir, "") + impt } path = util.ImportToPath(impt, "") - impt = plugin.Package + "/" + impt + // bugfix: impt may have "/" suffix + //impt = plugin.Package + "/" + impt + impt = filepath.Join(plugin.Package, impt) if util.IsWindows() { impt = util.PathToImport(impt, "") } diff --git a/cmd/hz/protobuf/tag_test.go b/cmd/hz/protobuf/tag_test.go index 2e7a9ca3e..e8e9561ec 100644 --- a/cmd/hz/protobuf/tag_test.go +++ b/cmd/hz/protobuf/tag_test.go @@ -48,15 +48,15 @@ func TestTagGenerate(t *testing.T) { }, { Annotation: "form", - GeneratedTag: "protobuf:\"bytes,4,opt,name=FormTag\" json:\"FormTag,omitempty\" form:\"form\"", + GeneratedTag: "protobuf:\"bytes,4,opt,name=FormTag\" form:\"form\" json:\"FormTag,omitempty\"", }, { Annotation: "cookie", - GeneratedTag: "protobuf:\"bytes,5,opt,name=CookieTag\" json:\"CookieTag,omitempty\" cookie:\"cookie\"", + GeneratedTag: "protobuf:\"bytes,5,opt,name=CookieTag\" cookie:\"cookie\" json:\"CookieTag,omitempty\"", }, { Annotation: "header", - GeneratedTag: "protobuf:\"bytes,6,opt,name=HeaderTag\" json:\"HeaderTag,omitempty\" header:\"header\"", + GeneratedTag: "protobuf:\"bytes,6,opt,name=HeaderTag\" header:\"header\" json:\"HeaderTag,omitempty\"", }, { Annotation: "body", @@ -64,15 +64,15 @@ func TestTagGenerate(t *testing.T) { }, { Annotation: "go.tag", - GeneratedTag: "bytes,8,opt,name=GoTag\" json:\"json\" form:\"form\" goTag:\"tag\" header:\"header\" query:\"query\"", + GeneratedTag: "bytes,8,opt,name=GoTag\" form:\"form\" goTag:\"tag\" header:\"header\" json:\"json\" query:\"query\"", }, { Annotation: "vd", - GeneratedTag: "bytes,9,opt,name=VdTag\" json:\"VdTag,omitempty\" form:\"VdTag\" query:\"VdTag\" vd:\"$!='?'\"", + GeneratedTag: "bytes,9,opt,name=VdTag\" form:\"VdTag\" json:\"VdTag,omitempty\" query:\"VdTag\" vd:\"$!='?'\"", }, { Annotation: "non", - GeneratedTag: "bytes,10,opt,name=DefaultTag\" json:\"DefaultTag,omitempty\" form:\"DefaultTag\" query:\"DefaultTag\"", + GeneratedTag: "bytes,10,opt,name=DefaultTag\" form:\"DefaultTag\" json:\"DefaultTag,omitempty\" query:\"DefaultTag\"", }, { Annotation: "query required", @@ -92,11 +92,11 @@ func TestTagGenerate(t *testing.T) { }, { Annotation: "go.tag required", - GeneratedTag: "protobuf:\"bytes,15,req,name=ReqGoTag\" query:\"ReqGoTag,required\" form:\"ReqGoTag,required\" json:\"json\"", + GeneratedTag: "protobuf:\"bytes,15,req,name=ReqGoTag\" form:\"ReqGoTag,required\" json:\"json\" query:\"ReqGoTag,required\"", }, { Annotation: "go.tag optional", - GeneratedTag: "bytes,16,opt,name=OptGoTag\" query:\"OptGoTag\" form:\"OptGoTag\" json:\"json\"", + GeneratedTag: "bytes,16,opt,name=OptGoTag\" form:\"OptGoTag\" json:\"json\" query:\"OptGoTag\"", }, { Annotation: "go tag cover query", diff --git a/cmd/hz/protobuf/tags.go b/cmd/hz/protobuf/tags.go index 26c6cc7cf..4880f99a0 100644 --- a/cmd/hz/protobuf/tags.go +++ b/cmd/hz/protobuf/tags.go @@ -437,8 +437,7 @@ func injectTagsToStructTags(f protoreflect.FieldDescriptor, out *structTags, nee tags.Remove(t.Key) } } - // protobuf tag as first - sort.Sort(tags[1:]) + sort.Sort(tags) for _, t := range tags { if disableTag { *out = append(*out, [2]string{t.Key, "-"}) diff --git a/cmd/hz/thrift/ast.go b/cmd/hz/thrift/ast.go index d577b95db..83b27fd7f 100644 --- a/cmd/hz/thrift/ast.go +++ b/cmd/hz/thrift/ast.go @@ -321,6 +321,10 @@ func parseAnnotationToClient(clientMethod *generator.ClientMethod, p *parser.Typ hasFormAnnotation = true clientMethod.FormFileCode += fmt.Sprintf("%q: req.Get%s(),\n", fileName, field.GoName().String()) } + if anno := getAnnotation(field.Annotations, AnnotationCookie); len(anno) > 0 { + hasAnnotation = true + // cookie do nothing + } if !hasAnnotation && strings.EqualFold(clientMethod.HTTPMethod, "get") { clientMethod.QueryParamsCode += fmt.Sprintf("%q: req.Get%s(),\n", checkSnakeName(field.GetName()), field.GoName().String()) } diff --git a/pkg/app/server/binding/binder_test.go b/pkg/app/server/binding/binder_test.go index ec962dafa..1ea7b1e36 100644 --- a/pkg/app/server/binding/binder_test.go +++ b/pkg/app/server/binding/binder_test.go @@ -692,6 +692,31 @@ func TestBind_FileBind(t *testing.T) { assert.DeepEqual(t, fileName, (**s.D).N.Filename) } +func TestBind_FileBindWithNoFile(t *testing.T) { + var s struct { + A *multipart.FileHeader `file_name:"a"` + B *multipart.FileHeader `form:"b"` + C *multipart.FileHeader + } + fileName := "binder_test.go" + req := newMockRequest(). + SetRequestURI("http://foobar.com"). + SetFile("a", fileName). + SetFile("b", fileName) + // to parse multipart files + req2 := req2.GetHTTP1Request(req.Req) + _ = req2.String() + err := DefaultBinder().Bind(req.Req, &s, nil) + if err != nil { + t.Fatalf("unexpected err: %v", err) + } + assert.DeepEqual(t, fileName, s.A.Filename) + assert.DeepEqual(t, fileName, s.B.Filename) + if s.C != nil { + t.Fatalf("expected a nil for s.C") + } +} + func TestBind_FileSliceBind(t *testing.T) { type Nest struct { N *[]*multipart.FileHeader `form:"b"` diff --git a/pkg/app/server/binding/internal/decoder/multipart_file_decoder.go b/pkg/app/server/binding/internal/decoder/multipart_file_decoder.go index ae32dfea5..b11417728 100644 --- a/pkg/app/server/binding/internal/decoder/multipart_file_decoder.go +++ b/pkg/app/server/binding/internal/decoder/multipart_file_decoder.go @@ -20,6 +20,7 @@ import ( "fmt" "reflect" + "github.com/cloudwego/hertz/pkg/common/hlog" "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/route/param" ) @@ -52,7 +53,8 @@ func (d *fileTypeDecoder) Decode(req *protocol.Request, params param.Params, req } file, err := req.FormFile(fileName) if err != nil { - return fmt.Errorf("can not get file '%s', err: %v", fileName, err) + hlog.SystemLogger().Warnf("can not get file '%s' form request, reason: %v, so skip '%s' field binding", fileName, err, d.fieldName) + return nil } if field.Kind() == reflect.Ptr { t := field.Type() @@ -105,11 +107,13 @@ func (d *fileTypeDecoder) fileSliceDecode(req *protocol.Request, params param.Pa } multipartForm, err := req.MultipartForm() if err != nil { - return fmt.Errorf("can not get multipartForm info, err: %v", err) + hlog.SystemLogger().Warnf("can not get MultipartForm from request, reason: %v, so skip '%s' field binding", fileName, err, d.fieldName) + return nil } files, exist := multipartForm.File[fileName] if !exist { - return fmt.Errorf("the file '%s' is not existed", fileName) + hlog.SystemLogger().Warnf("the file '%s' is not existed in request, so skip '%s' field binding", fileName, d.fieldName) + return nil } if field.Kind() == reflect.Array { diff --git a/pkg/app/server/binding/internal/decoder/struct_type_decoder.go b/pkg/app/server/binding/internal/decoder/struct_type_decoder.go index c2af2c030..4a3ded138 100644 --- a/pkg/app/server/binding/internal/decoder/struct_type_decoder.go +++ b/pkg/app/server/binding/internal/decoder/struct_type_decoder.go @@ -83,7 +83,7 @@ func (d *structTypeFieldTextDecoder) Decode(req *protocol.Request, params param. var vv reflect.Value vv, err := stringToValue(t, text, req, params, d.config) if err != nil { - hlog.Infof("unable to decode '%s' as %s: %v, but it may not affect correctness, so skip it", text, d.fieldType.Name(), err) + hlog.SystemLogger().Infof("unable to decode '%s' as %s: %v, but it may not affect correctness, so skip it", text, d.fieldType.Name(), err) return nil } field.Set(ReferenceValue(vv, ptrDepth)) @@ -92,7 +92,7 @@ func (d *structTypeFieldTextDecoder) Decode(req *protocol.Request, params param. err = hjson.Unmarshal(bytesconv.S2b(text), field.Addr().Interface()) if err != nil { - hlog.Infof("unable to decode '%s' as %s: %v, but it may not affect correctness, so skip it", text, d.fieldType.Name(), err) + hlog.SystemLogger().Infof("unable to decode '%s' as %s: %v, but it may not affect correctness, so skip it", text, d.fieldType.Name(), err) } return nil diff --git a/pkg/app/server/option.go b/pkg/app/server/option.go index 18f184379..73188d1b9 100644 --- a/pkg/app/server/option.go +++ b/pkg/app/server/option.go @@ -188,6 +188,7 @@ func WithMaxRequestBodySize(bs int) config.Option { // WithMaxKeepBodySize sets max size of request/response body to keep when recycled. Unit: byte // // Body buffer which larger than this size will be put back into buffer poll. +// Note: If memory pressure is high, try setting the value to 0. func WithMaxKeepBodySize(bs int) config.Option { return config.Option{F: func(o *config.Options) { o.MaxKeepBodySize = bs diff --git a/pkg/common/config/option.go b/pkg/common/config/option.go index 958d9b3a3..38cbfb338 100644 --- a/pkg/common/config/option.go +++ b/pkg/common/config/option.go @@ -220,9 +220,9 @@ func NewOptions(opts []Option) *Options { // an error will be returned MaxRequestBodySize: defaultMaxRequestBodySize, - // max reserved body buffer size when reset Request & Request - // If the body size exceeds this value, then the buffer won't be put to - // sync.Pool to prevent OOM + // max reserved body buffer size when reset Request & Response + // If the body size exceeds this value, then the buffer will be put to + // sync.Pool instead of hold by Request/Response directly. MaxKeepBodySize: defaultMaxRequestBodySize, // only accept GET request diff --git a/pkg/common/utils/utils_test.go b/pkg/common/utils/utils_test.go index 92873b51d..225d8e42e 100644 --- a/pkg/common/utils/utils_test.go +++ b/pkg/common/utils/utils_test.go @@ -49,7 +49,34 @@ import ( // test assert func func TestUtilsAssert(t *testing.T) { - // nothing to test + assertPanic := func() (panicked bool) { + defer func() { + if r := recover(); r != nil { + panicked = true + } + }() + Assert(false, "should panic") + return false + } + + // Checking if the assertPanic function results in a panic as expected. + // We expect a true value because it should panic. + assert.DeepEqual(t, true, assertPanic()) + + // Checking if a true assertion does not result in a panic. + // We create a wrapper around Assert to capture if it panics when it should not. + noPanic := func() (panicked bool) { + defer func() { + if r := recover(); r != nil { + panicked = true + } + }() + Assert(true, "should not panic") + return false + } + + // We expect a false value because it should not panic. + assert.DeepEqual(t, false, noPanic()) } func TestUtilsIsTrueString(t *testing.T) { @@ -142,3 +169,18 @@ func TestFilterContentType(t *testing.T) { contentType = FilterContentType(contentType) assert.DeepEqual(t, "text/plain", contentType) } + +func TestNormalizeHeaderKeyEdgeCases(t *testing.T) { + empty := []byte("") + NormalizeHeaderKey(empty, false) + assert.DeepEqual(t, []byte(""), empty) + NormalizeHeaderKey(empty, true) + assert.DeepEqual(t, []byte(""), empty) +} + +func TestFilterContentTypeEdgeCases(t *testing.T) { + simpleContentType := "text/plain" + assert.DeepEqual(t, "text/plain", FilterContentType(simpleContentType)) + complexContentType := "text/html; charset=utf-8; format=flowed" + assert.DeepEqual(t, "text/html", FilterContentType(complexContentType)) +} diff --git a/pkg/protocol/request.go b/pkg/protocol/request.go index d04731867..8e4b40bf7 100644 --- a/pkg/protocol/request.go +++ b/pkg/protocol/request.go @@ -65,7 +65,7 @@ import ( ) var ( - errMissingFile = errors.NewPublic("http: no such file") + ErrMissingFile = errors.NewPublic("http: no such file") responseBodyPool bytebufferpool.Pool requestBodyPool bytebufferpool.Pool @@ -313,7 +313,7 @@ func (req *Request) FormFile(name string) (*multipart.FileHeader, error) { } fhh := mf.File[name] if fhh == nil { - return nil, errMissingFile + return nil, ErrMissingFile } return fhh[0], nil } diff --git a/pkg/route/engine.go b/pkg/route/engine.go index 30b899dac..4881ee9cf 100644 --- a/pkg/route/engine.go +++ b/pkg/route/engine.go @@ -75,6 +75,7 @@ import ( "github.com/cloudwego/hertz/pkg/protocol/http1" "github.com/cloudwego/hertz/pkg/protocol/http1/factory" "github.com/cloudwego/hertz/pkg/protocol/suite" + "github.com/cloudwego/hertz/pkg/route/param" ) const unknownTransporterName = "unknown" @@ -749,6 +750,12 @@ func (engine *Engine) ServeHTTP(c context.Context, ctx *app.RequestContext) { return } + // if Params is re-assigned in HandlerFunc and the capacity is not enough we need to realloc + maxParams := int(engine.maxParams) + if cap(ctx.Params) < maxParams { + ctx.Params = make(param.Params, 0, maxParams) + } + // Find root of the tree for the given HTTP method t := engine.trees paramsPointer := &ctx.Params diff --git a/pkg/route/engine_test.go b/pkg/route/engine_test.go index b3e0adb30..ea1bc5fd9 100644 --- a/pkg/route/engine_test.go +++ b/pkg/route/engine_test.go @@ -1029,3 +1029,33 @@ func TestAcquireHijackConn(t *testing.T) { assert.DeepEqual(t, engine, hijackConn.e) assert.DeepEqual(t, conn, hijackConn.Conn) } + +func TestHandleParamsReassignInHandleFunc(t *testing.T) { + e := NewEngine(config.NewOptions(nil)) + routes := []string{ + "/:a/:b/:c", + } + for _, r := range routes { + e.GET(r, func(c context.Context, ctx *app.RequestContext) { + ctx.Params = make([]param.Param, 1) + ctx.String(consts.StatusOK, "") + }) + } + testRoutes := []string{ + "/aaa/bbb/ccc", + "/asd/alskja/alkdjad", + "/asd/alskja/alkdjad", + "/asd/alskja/alkdjad", + "/asd/alskja/alkdjad", + "/alksjdlakjd/ooo/askda", + "/alksjdlakjd/ooo/askda", + "/alksjdlakjd/ooo/askda", + } + ctx := e.ctxPool.Get().(*app.RequestContext) + for _, tr := range testRoutes { + r := protocol.NewRequest(http.MethodGet, tr, nil) + r.CopyTo(&ctx.Request) + e.ServeHTTP(context.Background(), ctx) + ctx.ResetWithoutConn() + } +} diff --git a/version.go b/version.go index fcfce1ec8..3a0529dfc 100644 --- a/version.go +++ b/version.go @@ -19,5 +19,5 @@ package hertz // Name and Version info of this framework, used for statistics and debug const ( Name = "Hertz" - Version = "v0.9.1" + Version = "v0.9.2" )