diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 7f46f7596..ba2af0a6c 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -19,15 +19,13 @@ jobs: key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} restore-keys: | ${{ runner.os }}-go- - - name: Lint run: | go vet -stdmethods=false $(go list ./...) go install mvdan.cc/gofumpt@v0.2.0 - test -z "$(gofumpt -l -extra .)" - + test -z "$(gofumpt -l -extra .)" - name: Unit Test run: go test -race -covermode=atomic -coverprofile=coverage.txt ./... - name: Codecov - run: bash <(curl -s https://codecov.io/bash) + run: bash <(curl -s https://codecov.io/bash) \ No newline at end of file diff --git a/README.md b/README.md index 8ee749617..7f9a4056b 100644 --- a/README.md +++ b/README.md @@ -20,7 +20,11 @@ Hertz [həːts] is a high-usability, high-performance and high-extensibility Gol - High performance Hertz uses the self-developed high-performance network library Netpoll by default. In some special scenarios, compared to Go Net, Hertz has certain advantages in QPS and time delay. For performance data, please refer to the Echo data in the figure below. - ![Performance](images/performance.png) + + Comparison of four frameworks: + ![Performance](images/performance-4.png) + Latency comparison of three frameworks: + ![Performance](images/performance-3.png) For detailed performance data, please refer to [hertz-benchmark](https://github.com/cloudwego/hertz-benchmark). - High extensibility @@ -52,6 +56,34 @@ Hertz [həːts] is a high-usability, high-performance and high-extensibility Gol - [Netpoll](https://github.com/cloudwego/netpoll): A high-performance network library. Hertz integrated by default. - [Hertz-Contrib](https://github.com/hertz-contrib): A partial extension library of Hertz, which users can integrate into Hertz through options according to their needs. - [Example](https://github.com/cloudwego/hertz-examples): Use examples of Hertz. +## Extensions + +| Extensions | Description | +|----------------------------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| [Websocket](https://github.com/hertz-contrib/websocket) | Enable Hertz to support the Websocket protocol. | +| [Pprof](https://github.com/hertz-contrib/pprof) | Extension for Hertz integration with Pprof. | +| [Sessions](https://github.com/hertz-contrib/sessions) | Session middleware with multi-state store support. | +| [Obs-opentelemetry](https://github.com/hertz-contrib/obs-opentelemetry) | Hertz's Opentelemetry extension that supports Metric, Logger, Tracing and works out of the box. | +| [Registry](https://github.com/hertz-contrib/registry) | Provides service registry and discovery functions. So far, the supported service discovery extensions are nacos, consul, etcd, eureka, polaris, servicecomb, zookeeper. | +| [Keyauth](https://github.com/hertz-contrib/keyauth) | Provides token-based authentication. | +| [Secure](https://github.com/hertz-contrib/secure) | Secure middleware with multiple configuration items. | +| [Sentry](https://github.com/hertz-contrib/hertzsentry) | Sentry extension provides some unified interfaces to help users perform real-time error monitoring. | +| [Requestid](https://github.com/hertz-contrib/requestid) | Add request id in response. | +| [Limiter](https://github.com/hertz-contrib/limiter) | Provides a current limiter based on the bbr algorithm. | +| [Jwt](https://github.com/hertz-contrib/jwt) | Jwt extension. | +| [Autotls](https://github.com/hertz-contrib/autotls) | Make Hertz support Let's Encrypt. | +| [Monitor-prometheus](https://github.com/hertz-contrib/monitor-prometheus) | Provides service monitoring based on Prometheus. | +| [I18n](https://github.com/hertz-contrib/i18n) | Helps translate Hertz programs into multi programming languages. | +| [Reverseproxy](https://github.com/hertz-contrib/reverseproxy) | Implement a reverse proxy. | +| [Opensergo](https://github.com/hertz-contrib/opensergo) | The Opensergo extension. | +| [Gzip](https://github.com/hertz-contrib/gzip) | A Gzip extension with multiple options. | +| [Cors](https://github.com/hertz-contrib/cors) | Provides cross-domain resource sharing support. | +| [Swagger](https://github.com/hertz-contrib/swagger) | Automatically generate RESTful API documentation with Swagger 2.0. | +| [Tracer](https://github.com/hertz-contrib/tracer) | Link tracing based on Opentracing. | +| [Recovery](https://github.com/cloudwego/hertz/tree/develop/pkg/app/middlewares/server/recovery) | Recovery middleware for Hertz. | +| [Basicauth](https://github.com/cloudwego/hertz/tree/develop/pkg/app/middlewares/server/basic_auth) | Basicauth middleware can provide HTTP basic authentication. | +| [Lark](https://github.com/hertz-contrib/lark-hertz) | Use hertz handle Lark/Feishu card message and event callback. | + ## Blogs - [ByteDance Practice on Go Network Library](https://www.cloudwego.io/blog/2021/10/09/bytedance-practices-on-go-network-library/) ## Contributing diff --git a/README_cn.md b/README_cn.md index 9079a4892..ac38d1d4c 100644 --- a/README_cn.md +++ b/README_cn.md @@ -19,7 +19,11 @@ Hertz[həːts] 是一个 Golang 微服务 HTTP 框架,在设计之初参考了 - 高性能 Hertz 默认使用自研的高性能网络库 Netpoll,在一些特殊场景相较于 go net,Hertz 在 QPS、时延上均具有一定优势。关于性能数据,可参考下图 Echo 数据。 - ![Performance](images/performance.png) + + 四个框架的对比: + ![Performance](images/performance-4.png) + 三个框架的时延对比: + ![Performance](images/performance-3.png) 关于详细的性能数据,可参考 [hertz-benchmark](https://github.com/cloudwego/hertz-benchmark)。 - 高扩展性 @@ -52,6 +56,34 @@ Hertz[həːts] 是一个 Golang 微服务 HTTP 框架,在设计之初参考了 - [Netpoll](https://github.com/cloudwego/netpoll): 自研高性能网络库,Hertz 默认集成 - [Hertz-Contrib](https://github.com/hertz-contrib): Hertz 扩展仓库,提供中间件、tracer 等能力 - [Example](https://github.com/cloudwego/hertz-examples): Hertz 使用例子 +## 相关拓展 + +| 拓展 | 描述 | +|----------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------| +| [Websocket](https://github.com/hertz-contrib/websocket) | 使 Hertz 支持 Websocket 协议。 | +| [Pprof](https://github.com/hertz-contrib/pprof) | Hertz 集成 Pprof 的扩展。 | +| [Sessions](https://github.com/hertz-contrib/sessions) | 具有多状态存储支持的 Session 中间件。 | +| [Obs-opentelemetry](https://github.com/hertz-contrib/obs-opentelemetry) | Hertz 的 Opentelemetry 扩展,支持 Metric、Logger、Tracing并且达到开箱即用。 | +| [Registry](https://github.com/hertz-contrib/registry) | 提供服务注册与发现功能。到现在为止,支持的服务发现拓展有 nacos, consul, etcd, eureka, polaris, servicecomb, zookeeper。 | +| [Keyauth](https://github.com/hertz-contrib/keyauth) | 提供基于 token 的身份验证。 | +| [Secure](https://github.com/hertz-contrib/secure) | 具有多配置项的 Secure 中间件。 | +| [Sentry](https://github.com/hertz-contrib/hertzsentry) | Sentry 拓展提供了一些统一的接口来帮助用户进行实时的错误监控。 | +| [Requestid](https://github.com/hertz-contrib/requestid) | 在 response 中添加 request id。 | +| [Limiter](https://github.com/hertz-contrib/limiter) | 提供了基于 bbr 算法的限流器。 | +| [Jwt](https://github.com/hertz-contrib/jwt) | Jwt 拓展。 | +| [Autotls](https://github.com/hertz-contrib/autotls) | 为 Hertz 支持 Let's Encrypt 。 | +| [Monitor-prometheus](https://github.com/hertz-contrib/monitor-prometheus) | 提供基于 Prometheus 服务监控功能。 | +| [I18n](https://github.com/hertz-contrib/i18n) | 可帮助将 Hertz 程序翻译成多种语言。 | +| [Reverseproxy](https://github.com/hertz-contrib/reverseproxy) | 实现反向代理。 | +| [Opensergo](https://github.com/hertz-contrib/opensergo) | Opensergo 扩展。 | +| [Gzip](https://github.com/hertz-contrib/gzip) | 含多个可选项的 Gzip 拓展。 | +| [Cors](https://github.com/hertz-contrib/cors) | 提供跨域资源共享支持。 | +| [Swagger](https://github.com/hertz-contrib/swagger) | 使用 Swagger 2.0 自动生成 RESTful API 文档。 | +| [Tracer](https://github.com/hertz-contrib/tracer) | 基于 Opentracing 的链路追踪。 | +| [Recovery](https://github.com/cloudwego/hertz/tree/develop/pkg/app/middlewares/server/recovery) | Hertz 的异常恢复中间件。 | +| [Basicauth](https://github.com/cloudwego/hertz/tree/develop/pkg/app/middlewares/server/basic_auth) | Basicauth 中间件能够提供 HTTP 基本身份验证。 | +| [Lark](https://github.com/hertz-contrib/lark-hertz) | 在 Hertz 中处理 Lark/飞书的卡片消息和事件的回调。 | + ## 相关文章 - [字节跳动在 Go 网络库上的实践](https://www.cloudwego.io/blog/2021/10/09/bytedance-practices-on-go-network-library/) ## 贡献代码 diff --git a/cmd/hz/go.mod b/cmd/hz/go.mod index d488ce6b6..cbf7e0928 100644 --- a/cmd/hz/go.mod +++ b/cmd/hz/go.mod @@ -6,7 +6,7 @@ require ( github.com/cloudwego/thriftgo v0.1.7 github.com/hashicorp/go-version v1.5.0 github.com/jhump/protoreflect v1.12.0 - github.com/urfave/cli/v2 v2.8.1 + github.com/urfave/cli/v2 v2.20.2 google.golang.org/protobuf v1.28.0 gopkg.in/yaml.v2 v2.4.0 ) diff --git a/cmd/hz/go.sum b/cmd/hz/go.sum index cf959c7d6..35bfc58ba 100644 --- a/cmd/hz/go.sum +++ b/cmd/hz/go.sum @@ -8,8 +8,8 @@ github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDk github.com/cloudwego/thriftgo v0.1.7 h1:mTGRv6Dtwfp0hTPZXuIHwm3vtGOuZVTrWarI0xVzUYg= github.com/cloudwego/thriftgo v0.1.7/go.mod h1:LzeafuLSiHA9JTiWC8TIMIq64iadeObgRUhmVG1OC/w= github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= -github.com/cpuguy83/go-md2man/v2 v2.0.1 h1:r/myEWzV9lfsM1tFLgDyu0atFtJ1fXn261LKYj/3DxU= -github.com/cpuguy83/go-md2man/v2 v2.0.1/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= +github.com/cpuguy83/go-md2man/v2 v2.0.2 h1:p1EgwI/C7NhT0JmVkwCD2ZBK8j4aeHQX2pMHHBfMQ6w= +github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= @@ -55,8 +55,8 @@ github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= -github.com/urfave/cli/v2 v2.8.1 h1:CGuYNZF9IKZY/rfBe3lJpccSoIY1ytfvmgQT90cNOl4= -github.com/urfave/cli/v2 v2.8.1/go.mod h1:Z41J9TPoffeoqP0Iza0YbAhGvymRdZAd2uPmZ5JxRdY= +github.com/urfave/cli/v2 v2.20.2 h1:dKA0LUjznZpwmmbrc0pOgcLTEilnHeM8Av9Yng77gHM= +github.com/urfave/cli/v2 v2.20.2/go.mod h1:1CNUng3PtjQMtRzJO4FMXBQvkGtuYRxxiR9xMa7jMwI= github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 h1:bAn7/zixMGCfxrRTfdpNzjtPYqr8smhKouy9mxVdGPU= github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673/go.mod h1:N3UwUGtsrSj3ccvlPHLoLsHnpR27oXr4ZE984MbSER8= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= @@ -125,5 +125,6 @@ gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/go.mod b/go.mod index cf0ca35c6..f5be27668 100644 --- a/go.mod +++ b/go.mod @@ -5,9 +5,10 @@ go 1.16 require ( github.com/bytedance/go-tagexpr/v2 v2.9.2 github.com/bytedance/gopkg v0.0.0-20220413063733-65bf48ffb3a7 - github.com/bytedance/sonic v1.3.5 + github.com/bytedance/sonic v1.5.0 github.com/cloudwego/netpoll v0.2.6 github.com/fsnotify/fsnotify v1.5.4 + github.com/tidwall/gjson v1.13.0 // indirect golang.org/x/sync v0.0.0-20210220032951-036812b2e83c golang.org/x/sys v0.0.0-20220412211240-33da011f77ad google.golang.org/protobuf v1.27.1 diff --git a/go.sum b/go.sum index 1ed067c29..0d1798153 100644 --- a/go.sum +++ b/go.sum @@ -2,8 +2,8 @@ github.com/bytedance/go-tagexpr/v2 v2.9.2 h1:QySJaAIQgOEDQBLS3x9BxOWrnhqu5sQ+f6H github.com/bytedance/go-tagexpr/v2 v2.9.2/go.mod h1:5qsx05dYOiUXOUgnQ7w3Oz8BYs2qtM/bJokdLb79wRM= github.com/bytedance/gopkg v0.0.0-20220413063733-65bf48ffb3a7 h1:PtwsQyQJGxf8iaPptPNaduEIu9BnrNms+pcRdHAxZaM= github.com/bytedance/gopkg v0.0.0-20220413063733-65bf48ffb3a7/go.mod h1:2ZlV9BaUH4+NXIBF0aMdKKAnHTzqH+iMU4KUjAbL23Q= -github.com/bytedance/sonic v1.3.5 h1:xfBNhsG3QCC+AMCmCHxNQg0StI5IM/B9Jtwjqi5WlI0= -github.com/bytedance/sonic v1.3.5/go.mod h1:V973WhNhGmvHxW6nQmsHEfHaoU9F3zTF+93rH03hcUQ= +github.com/bytedance/sonic v1.5.0 h1:XWdTi8bwPgxIML+eNV1IwNuTROK6EUrQ65ey8yd6fRQ= +github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM= github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06 h1:1sDoSuDPWzhkdzNVxCxtIaKiAe96ESVPv8coGwc1gZ4= github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY= github.com/cloudwego/netpoll v0.2.6 h1:vzN8cyayoa9RdCOG87tqkYO/j2hA4SMLC+vkcNUq6uI= @@ -13,47 +13,34 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/fsnotify/fsnotify v1.5.4 h1:jRbGcIw6P2Meqdwuo0H1p6JVLbL5DHKAKlYndzMwVZI= github.com/fsnotify/fsnotify v1.5.4/go.mod h1:OVB6XrOHzAwXMpEM7uPOzcehqUV2UqJxmVXmkdnm1bU= -github.com/goccy/go-json v0.9.4 h1:L8MLKG2mvVXiQu07qB6hmfqeSYQdOnqPot2GhsIwIaI= -github.com/goccy/go-json v0.9.4/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.5.0 h1:LUVKkCeviFUMKqHa4tXIIij/lbhnMbP7Fn5wKdKkRh4= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/henrylee2cn/ameda v1.4.8/go.mod h1:liZulR8DgHxdK+MEwvZIylGnmcjzQ6N6f2PlWe7nEO4= github.com/henrylee2cn/ameda v1.4.10 h1:JdvI2Ekq7tapdPsuhrc4CaFiqw6QXFvZIULWJgQyCAk= github.com/henrylee2cn/ameda v1.4.10/go.mod h1:liZulR8DgHxdK+MEwvZIylGnmcjzQ6N6f2PlWe7nEO4= github.com/henrylee2cn/goutil v0.0.0-20210127050712-89660552f6f8 h1:yE9ULgp02BhYIrO6sdV/FPe0xQM6fNHkVQW2IAymfM0= github.com/henrylee2cn/goutil v0.0.0-20210127050712-89660552f6f8/go.mod h1:Nhe/DM3671a5udlv2AdV2ni/MZzgfv2qrPL5nIi3EGQ= -github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= -github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/klauspost/cpuid/v2 v2.0.9 h1:lgaqFMSdTdQYdZ04uHyN2d/eKdOMyi2YLSvlQIBFYa4= github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= -github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 h1:ZqeYNhU3OHLH3mGKHDcjJRFFRrJa6eAM5H+CtDdOsPc= -github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= -github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= -github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/nyaruka/phonenumbers v1.0.55 h1:bj0nTO88Y68KeUQ/n3Lo2KgK7lM1hF7L9NFuwcCl3yg= github.com/nyaruka/phonenumbers v1.0.55/go.mod h1:sDaTZ/KPX5f8qyV9qN+hIm+4ZBARJrupC6LuhshJq1U= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/tidwall/gjson v1.9.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= -github.com/tidwall/gjson v1.12.1/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/gjson v1.13.0 h1:3TFY9yxOQShrvmjdM76K+jc66zJeT6D3/VFFYCGQf7M= github.com/tidwall/gjson v1.13.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= -github.com/tidwall/sjson v1.2.4 h1:cuiLzLnaMeBhRmEv00Lpk3tkYrcxpmbU81tAY4Dw0tc= -github.com/tidwall/sjson v1.2.4/go.mod h1:098SZ494YoMWPmMO6ct4dcFnqxwj9r/gF0Etp19pSNM= github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= golang.org/x/arch v0.0.0-20210923205945-b76863e36670 h1:18EFjUmQOcUvxNYSkA6jO9VAiXCnxFY6NyDX0bHDmkU= diff --git a/images/performance-3.png b/images/performance-3.png new file mode 100644 index 000000000..1cbce6414 Binary files /dev/null and b/images/performance-3.png differ diff --git a/images/performance-4.png b/images/performance-4.png new file mode 100644 index 000000000..26bfad16e Binary files /dev/null and b/images/performance-4.png differ diff --git a/images/performance.png b/images/performance.png deleted file mode 100644 index 884702fe1..000000000 Binary files a/images/performance.png and /dev/null differ diff --git a/internal/bytesconv/bytesconv_32_test.go b/internal/bytesconv/bytesconv_32_test.go new file mode 100644 index 000000000..384127ff7 --- /dev/null +++ b/internal/bytesconv/bytesconv_32_test.go @@ -0,0 +1,147 @@ +//go:build !amd64 && !arm64 && !ppc64 +// +build !amd64,!arm64,!ppc64 + +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * The MIT License (MIT) + * + * Copyright (c) 2015-present Aliaksandr Valialkin, VertaMedia, Kirill Danshin, Erik Dubbelboer, FastHTTP Authors + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + * This file may have been modified by CloudWeGo authors. All CloudWeGo + * Modifications are Copyright 2022 CloudWeGo Authors. + */ + +package bytesconv + +import ( + "fmt" + "testing" + + "github.com/cloudwego/hertz/pkg/common/test/assert" +) + +func TestWriteHexInt(t *testing.T) { + t.Parallel() + + for _, v := range []struct { + s string + n int + }{ + {"0", 0}, + {"1", 1}, + {"123", 0x123}, + {"7fffffff", 0x7fffffff}, + } { + testWriteHexInt(t, v.n, v.s) + } +} + +func TestReadHexInt(t *testing.T) { + t.Parallel() + + for _, v := range []struct { + s string + n int + }{ + //errTooLargeHexNum "too large hex number" + //{"0123456789abcdef", -1}, + {"0", 0}, + {"fF", 0xff}, + {"00abc", 0xabc}, + {"7fffffff", 0x7fffffff}, + {"000", 0}, + {"1234ZZZ", 0x1234}, + } { + testReadHexInt(t, v.s, v.n) + } +} + +func TestParseUint(t *testing.T) { + t.Parallel() + + for _, v := range []struct { + s string + i int + }{ + {"0", 0}, + {"123", 123}, + {"123456789", 123456789}, + {"2147483647", 2147483647}, + } { + n, err := ParseUint(S2b(v.s)) + if err != nil { + t.Errorf("unexpected error: %v. s=%q n=%v", err, v.s, n) + } + assert.DeepEqual(t, n, v.i) + } +} + +func TestParseUintError(t *testing.T) { + t.Parallel() + + for _, v := range []struct { + s string + }{ + {""}, + {"cloudwego123"}, + {"1234.545"}, + {"-2147483648"}, + {"2147483648"}, + {"4294967295"}, + } { + n, err := ParseUint(S2b(v.s)) + if err == nil { + t.Fatalf("Expecting error when parsing %q. obtained %d", v.s, n) + } + if n >= 0 { + t.Fatalf("Unexpected n=%d when parsing %q. Expected negative num", n, v.s) + } + } +} + +func TestAppendUint(t *testing.T) { + t.Parallel() + + for _, s := range []struct { + n int + }{ + {0}, + {123}, + {0x7fffffff}, + } { + expectedS := fmt.Sprintf("%d", s.n) + s := AppendUint(nil, s.n) + assert.DeepEqual(t, expectedS, B2s(s)) + } +} diff --git a/internal/bytesconv/bytesconv_64_test.go b/internal/bytesconv/bytesconv_64_test.go new file mode 100644 index 000000000..6f96f8b13 --- /dev/null +++ b/internal/bytesconv/bytesconv_64_test.go @@ -0,0 +1,149 @@ +//go:build amd64 || arm64 || ppc64 +// +build amd64 arm64 ppc64 + +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * The MIT License (MIT) + * + * Copyright (c) 2015-present Aliaksandr Valialkin, VertaMedia, Kirill Danshin, Erik Dubbelboer, FastHTTP Authors + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + * This file may have been modified by CloudWeGo authors. All CloudWeGo + * Modifications are Copyright 2022 CloudWeGo Authors. + */ + +package bytesconv + +import ( + "fmt" + "testing" + + "github.com/cloudwego/hertz/pkg/common/test/assert" +) + +func TestWriteHexInt(t *testing.T) { + t.Parallel() + + for _, v := range []struct { + s string + n int + }{ + {"0", 0}, + {"1", 1}, + {"123", 0x123}, + {"7fffffffffffffff", 0x7fffffffffffffff}, + } { + testWriteHexInt(t, v.n, v.s) + } +} + +func TestReadHexInt(t *testing.T) { + t.Parallel() + + for _, v := range []struct { + s string + n int + }{ + //errTooLargeHexNum "too large hex number" + //{"0123456789abcdef", -1}, + {"0", 0}, + {"fF", 0xff}, + {"00abc", 0xabc}, + {"7fffffff", 0x7fffffff}, + {"000", 0}, + {"1234ZZZ", 0x1234}, + {"7ffffffffffffff", 0x7ffffffffffffff}, + } { + testReadHexInt(t, v.s, v.n) + } +} + +func TestParseUint(t *testing.T) { + t.Parallel() + + for _, v := range []struct { + s string + i int + }{ + {"0", 0}, + {"123", 123}, + {"1234567890", 1234567890}, + {"123456789012345678", 123456789012345678}, + {"9223372036854775807", 9223372036854775807}, + } { + n, err := ParseUint(S2b(v.s)) + if err != nil { + t.Errorf("unexpected error: %v. s=%q n=%v", err, v.s, n) + } + assert.DeepEqual(t, n, v.i) + } +} + +func TestParseUintError(t *testing.T) { + t.Parallel() + + for _, v := range []struct { + s string + }{ + {""}, + {"cloudwego123"}, + {"1234.545"}, + {"-9223372036854775808"}, + {"9223372036854775808"}, + {"18446744073709551615"}, + } { + n, err := ParseUint(S2b(v.s)) + if err == nil { + t.Fatalf("Expecting error when parsing %q. obtained %d", v.s, n) + } + if n >= 0 { + t.Fatalf("Unexpected n=%d when parsing %q. Expected negative num", n, v.s) + } + } +} + +func TestAppendUint(t *testing.T) { + t.Parallel() + + for _, s := range []struct { + n int + }{ + {0}, + {123}, + {0x7fffffffffffffff}, + } { + expectedS := fmt.Sprintf("%d", s.n) + s := AppendUint(nil, s.n) + assert.DeepEqual(t, expectedS, B2s(s)) + } +} diff --git a/internal/bytesconv/bytesconv_test.go b/internal/bytesconv/bytesconv_test.go index 57caea626..cf7a9a478 100644 --- a/internal/bytesconv/bytesconv_test.go +++ b/internal/bytesconv/bytesconv_test.go @@ -17,10 +17,14 @@ package bytesconv import ( + "net/url" "testing" "time" + "github.com/cloudwego/hertz/pkg/common/bytebufferpool" "github.com/cloudwego/hertz/pkg/common/test/assert" + "github.com/cloudwego/hertz/pkg/common/test/mock" + "github.com/cloudwego/hertz/pkg/network" ) func TestAppendDate(t *testing.T) { @@ -46,7 +50,136 @@ func TestAppendDate(t *testing.T) { } { t.Run(c.name, func(t *testing.T) { s := AppendHTTPDate(nil, c.date) - assert.DeepEqual(t, c.dateStr, string(s)) + assert.DeepEqual(t, c.dateStr, B2s(s)) }) } } + +func TestLowercaseBytes(t *testing.T) { + t.Parallel() + + for _, v := range []struct { + b1, b2 []byte + }{ + {[]byte("CLOUDWEGO-HERTZ"), []byte("cloudwego-hertz")}, + {[]byte("CLOUDWEGO"), []byte("cloudwego")}, + {[]byte("HERTZ"), []byte("hertz")}, + } { + LowercaseBytes(v.b1) + assert.DeepEqual(t, v.b2, v.b1) + } +} + +// The test converts byte slice to a string without memory allocation. +func TestB2s(t *testing.T) { + t.Parallel() + + for _, v := range []struct { + s string + b []byte + }{ + {"cloudwego-hertz", []byte("cloudwego-hertz")}, + {"cloudwego", []byte("cloudwego")}, + {"hertz", []byte("hertz")}, + } { + assert.DeepEqual(t, v.s, B2s(v.b)) + } +} + +// The test converts string to a byte slice without memory allocation. +func TestS2b(t *testing.T) { + t.Parallel() + + for _, v := range []struct { + s string + b []byte + }{ + {"cloudwego-hertz", []byte("cloudwego-hertz")}, + {"cloudwego", []byte("cloudwego")}, + {"hertz", []byte("hertz")}, + } { + assert.DeepEqual(t, S2b(v.s), v.b) + } +} + +// common test function for 32bit and 64bit +func testWriteHexInt(t *testing.T, n int, expectedS string) { + w := bytebufferpool.Get() + zw := network.NewWriter(w) + if err := WriteHexInt(zw, n); err != nil { + t.Errorf("unexpected error when writing hex %x: %v", n, err) + } + if err := zw.Flush(); err != nil { + t.Fatalf("unexpected error when flushing hex %x: %v", n, err) + } + s := B2s(w.B) + assert.DeepEqual(t, s, expectedS) +} + +// common test function for 32bit and 64bit +func testReadHexInt(t *testing.T, s string, expectedN int) { + zr := mock.NewZeroCopyReader(s) + n, err := ReadHexInt(zr) + if err != nil { + t.Errorf("unexpected error: %v. s=%q", err, s) + } + assert.DeepEqual(t, n, expectedN) +} + +func TestAppendQuotedPath(t *testing.T) { + t.Parallel() + + // Test all characters + pathSegment := make([]byte, 256) + for i := 0; i < 256; i++ { + pathSegment[i] = byte(i) + } + for _, s := range []struct { + path string + }{ + {"/"}, + {"//"}, + {"/foo/bar"}, + {"*"}, + {"/foo/" + B2s(pathSegment)}, + } { + u := url.URL{Path: s.path} + expectedS := u.EscapedPath() + res := B2s(AppendQuotedPath(nil, S2b(s.path))) + assert.DeepEqual(t, expectedS, res) + } +} + +func TestAppendQuotedArg(t *testing.T) { + t.Parallel() + + // Sync with url.QueryEscape + allcases := make([]byte, 256) + for i := 0; i < 256; i++ { + allcases[i] = byte(i) + } + res := B2s(AppendQuotedArg(nil, allcases)) + expect := url.QueryEscape(B2s(allcases)) + assert.DeepEqual(t, expect, res) +} + +func TestParseHTTPDate(t *testing.T) { + t.Parallel() + + for _, v := range []struct { + t string + }{ + {"Thu, 04 Feb 2010 21:00:57 PST"}, + {"Mon, 02 Jan 2006 15:04:05 MST"}, + } { + t1, err := time.Parse(time.RFC1123, v.t) + if err != nil { + t.Fatalf("unexpected error: %v. t=%q", err, v.t) + } + t2, err := ParseHTTPDate(S2b(t1.Format(time.RFC1123))) + if err != nil { + t.Fatalf("unexpected error: %v. t=%q", err, v.t) + } + assert.DeepEqual(t, t1, t2) + } +} diff --git a/internal/stats/tracer.go b/internal/stats/tracer.go index 1d92cfb4c..e00b09a65 100644 --- a/internal/stats/tracer.go +++ b/internal/stats/tracer.go @@ -67,6 +67,6 @@ func (ctl *Controller) HasTracer() bool { func (ctl *Controller) tryRecover() { if err := recover(); err != nil { - hlog.Warnf("HERTZ: Panic happened during tracer call. This doesn't affect the http call, but may lead to lack of monitor data such as metrics and logs: %s, %s", err, string(debug.Stack())) + hlog.SystemLogger().Warnf("Panic happened during tracer call. This doesn't affect the http call, but may lead to lack of monitor data such as metrics and logs: %s, %s", err, string(debug.Stack())) } } diff --git a/pkg/app/client/client.go b/pkg/app/client/client.go index 92514ae1f..99d32a537 100644 --- a/pkg/app/client/client.go +++ b/pkg/app/client/client.go @@ -243,10 +243,8 @@ type Client struct { // If Proxy is nil or returns a nil *URL, no proxy is used. Proxy protocol.Proxy - // RetryIf controls whether a retry should be attempted after an error. - // - // By default will use isIdempotent function - RetryIf func(request *protocol.Request) bool + // RetryIfFunc sets the retry decision function. If nil, the client.DefaultRetryIf will be applied. + RetryIfFunc client.RetryIfFunc clientFactory suite.ClientFactory @@ -260,6 +258,18 @@ func (c *Client) GetOptions() *config.ClientOptions { return c.options } +func (c *Client) SetRetryIfFunc(retryIf client.RetryIfFunc) { + c.RetryIfFunc = retryIf +} + +// Deprecated: use SetRetryIfFunc instead of SetRetryIf +func (c *Client) SetRetryIf(fn func(request *protocol.Request) bool) { + f := func(req *protocol.Request, resp *protocol.Response, err error) bool { + return fn(req) + } + c.SetRetryIfFunc(f) +} + // SetProxy is used to set client proxy. // // Don't SetProxy twice for a client. @@ -268,11 +278,6 @@ func (c *Client) SetProxy(p protocol.Proxy) { c.Proxy = p } -// SetRetryIf is used to set RetryIf func. -func (c *Client) SetRetryIf(fn func(request *protocol.Request) bool) { - c.RetryIf = fn -} - // Get returns the status code and body of url. // // The contents of dst will be replaced by the body and returned, if the dst @@ -594,14 +599,14 @@ func newHttp1OptionFromClient(c *Client) *http1.ClientOptions { MaxConns: c.options.MaxConnsPerHost, MaxConnDuration: c.options.MaxConnDuration, MaxIdleConnDuration: c.options.MaxIdleConnDuration, - MaxIdempotentCallAttempts: c.options.MaxIdempotentCallAttempts, ReadTimeout: c.options.ReadTimeout, WriteTimeout: c.options.WriteTimeout, MaxResponseBodySize: c.options.MaxResponseBodySize, DisableHeaderNamesNormalizing: c.options.DisableHeaderNamesNormalizing, DisablePathNormalizing: c.options.DisablePathNormalizing, MaxConnWaitTimeout: c.options.MaxConnWaitTimeout, - RetryIf: c.RetryIf, ResponseBodyStream: c.options.ResponseBodyStream, + RetryConfig: c.options.RetryConfig, + RetryIfFunc: c.RetryIfFunc, } } diff --git a/pkg/app/client/client_test.go b/pkg/app/client/client_test.go index bd5464ca7..8c86aff30 100644 --- a/pkg/app/client/client_test.go +++ b/pkg/app/client/client_test.go @@ -65,6 +65,7 @@ import ( "github.com/cloudwego/hertz/internal/bytestr" "github.com/cloudwego/hertz/pkg/app" + "github.com/cloudwego/hertz/pkg/app/client/retry" "github.com/cloudwego/hertz/pkg/common/config" errs "github.com/cloudwego/hertz/pkg/common/errors" "github.com/cloudwego/hertz/pkg/network" @@ -389,9 +390,9 @@ func TestClientReadTimeout(t *testing.T) { c := &http1.HostClient{ ClientOptions: &http1.ClientOptions{ - ReadTimeout: time.Second * 4, - MaxIdempotentCallAttempts: 1, - Dialer: standard.NewDialer(), + ReadTimeout: time.Second * 4, + RetryConfig: &retry.Config{MaxAttemptTimes: 1}, + Dialer: standard.NewDialer(), }, Addr: opt.Addr, } @@ -434,7 +435,7 @@ func TestClientReadTimeout(t *testing.T) { case <-done: // It is abnormal when waiting time exceeds the value of readTimeout times the number of retries. // Give it extra 2 seconds just to be sure. - case <-time.After(c.ReadTimeout*time.Duration(c.MaxIdempotentCallAttempts) + time.Second*2): + case <-time.After(c.ReadTimeout*time.Duration(c.RetryConfig.MaxAttemptTimes) + time.Second*2): t.Fatal("Client.ReadTimeout didn't work") } } @@ -1880,6 +1881,128 @@ func newMockDialerWithCustomFunc(network, address string, timeout time.Duration, } } +func TestClientRetry(t *testing.T) { + t.Parallel() + client, err := NewClient( + WithDialTimeout(2*time.Second), + WithRetryConfig( + retry.WithMaxAttemptTimes(3), + retry.WithInitDelay(100*time.Millisecond), + retry.WithMaxDelay(10*time.Second), + retry.WithDelayPolicy(retry.CombineDelay(retry.FixedDelayPolicy, retry.BackOffDelayPolicy)), + ), + ) + client.SetRetryIfFunc(func(req *protocol.Request, resp *protocol.Response, err error) bool { + return err != nil + }) + if err != nil { + t.Fatal(err) + return + } + startTime := time.Now().UnixNano() + _, resp, err := client.Get(context.Background(), nil, "http://127.0.0.1:1234/ping") + if err != nil { + // first delay 100+200ms , second delay 100+400ms + if time.Duration(time.Now().UnixNano()-startTime) > 800*time.Millisecond && time.Duration(time.Now().UnixNano()-startTime) < 2*time.Second { + t.Logf("Retry triggered : delay=%dms\tresp=%v\terr=%v\n", time.Duration(time.Now().UnixNano()-startTime)/(1*time.Millisecond), string(resp), fmt.Sprintln(err)) + } else if time.Duration(time.Now().UnixNano()-startTime) < 1*time.Second { // Compatible without triggering retry + t.Logf("Retry not triggered : delay=%dms\tresp=%v\terr=%v\n", time.Duration(time.Now().UnixNano()-startTime)/(1*time.Millisecond), string(resp), fmt.Sprintln(err)) + } else { + t.Fatal(err) + } + } + + client2, err := NewClient( + WithDialTimeout(2*time.Second), + WithRetryConfig( + retry.WithMaxAttemptTimes(2), + retry.WithInitDelay(500*time.Millisecond), + retry.WithMaxJitter(1*time.Second), + retry.WithDelayPolicy(retry.CombineDelay(retry.FixedDelayPolicy, retry.BackOffDelayPolicy)), + ), + ) + if err != nil { + t.Fatal(err) + return + } + client2.SetRetryIfFunc(func(req *protocol.Request, resp *protocol.Response, err error) bool { + return err != nil + }) + startTime = time.Now().UnixNano() + _, resp, err = client2.Get(context.Background(), nil, "http://127.0.0.1:1234/ping") + if err != nil { + // delay max{500ms+rand([0,1))s,100ms}. Because if the MaxDelay is not set, we will use the default MaxDelay of 100ms + if time.Duration(time.Now().UnixNano()-startTime) > 100*time.Millisecond && time.Duration(time.Now().UnixNano()-startTime) < 1100*time.Millisecond { + t.Logf("Retry triggered : delay=%dms\tresp=%v\terr=%v\n", time.Duration(time.Now().UnixNano()-startTime)/(1*time.Millisecond), string(resp), fmt.Sprintln(err)) + } else if time.Duration(time.Now().UnixNano()-startTime) < 1*time.Second { // Compatible without triggering retry + t.Logf("Retry not triggered : delay=%dms\tresp=%v\terr=%v\n", time.Duration(time.Now().UnixNano()-startTime)/(1*time.Millisecond), string(resp), fmt.Sprintln(err)) + } else { + t.Fatal(err) + } + } + + client3, err := NewClient( + WithDialTimeout(2*time.Second), + WithRetryConfig( + retry.WithMaxAttemptTimes(2), + retry.WithInitDelay(100*time.Millisecond), + retry.WithMaxDelay(5*time.Second), + retry.WithMaxJitter(1*time.Second), + retry.WithDelayPolicy(retry.CombineDelay(retry.FixedDelayPolicy, retry.BackOffDelayPolicy, retry.RandomDelayPolicy)), + ), + ) + if err != nil { + t.Fatal(err) + return + } + client3.SetRetryIfFunc(func(req *protocol.Request, resp *protocol.Response, err error) bool { + return err != nil + }) + startTime = time.Now().UnixNano() + _, resp, err = client3.Get(context.Background(), nil, "http://127.0.0.1:1234/ping") + if err != nil { + // delay 100ms+200ms+rand([0,1))s + if time.Duration(time.Now().UnixNano()-startTime) > 300*time.Millisecond && time.Duration(time.Now().UnixNano()-startTime) < 2300*time.Millisecond { + t.Logf("Retry triggered : delay=%dms\tresp=%v\terr=%v\n", time.Duration(time.Now().UnixNano()-startTime)/(1*time.Millisecond), string(resp), fmt.Sprintln(err)) + } else if time.Duration(time.Now().UnixNano()-startTime) < 1*time.Second { // Compatible without triggering retry + t.Logf("Retry not triggered : delay=%dms\tresp=%v\terr=%v\n", time.Duration(time.Now().UnixNano()-startTime)/(1*time.Millisecond), string(resp), fmt.Sprintln(err)) + } else { + t.Fatal(err) + } + } + + client4, err := NewClient( + WithDialTimeout(2*time.Second), + WithRetryConfig( + retry.WithMaxAttemptTimes(2), + retry.WithInitDelay(1*time.Second), + retry.WithMaxDelay(10*time.Second), + retry.WithMaxJitter(5*time.Second), + retry.WithDelayPolicy(retry.CombineDelay(retry.FixedDelayPolicy, retry.BackOffDelayPolicy, retry.RandomDelayPolicy)), + ), + ) + if err != nil { + t.Fatal(err) + return + } + /* If the retryIfFunc is not set , idempotent logic is used by default */ + //client4.SetRetryIfFunc(func(req *protocol.Request, resp *protocol.Response, err error) bool { + // return err != nil + //}) + startTime = time.Now().UnixNano() + _, resp, err = client4.Get(context.Background(), nil, "http://127.0.0.1:1234/ping") + if err != nil { + if time.Duration(time.Now().UnixNano()-startTime) > 1*time.Second && time.Duration(time.Now().UnixNano()-startTime) < 9*time.Second { + t.Logf("Retry triggered : delay=%dms\tresp=%v\terr=%v\n", time.Duration(time.Now().UnixNano()-startTime)/(1*time.Millisecond), string(resp), fmt.Sprintln(err)) + } else if time.Duration(time.Now().UnixNano()-startTime) < 1*time.Second { // Compatible without triggering retry + t.Logf("Retry not triggered : delay=%dms\tresp=%v\terr=%v\n", time.Duration(time.Now().UnixNano()-startTime)/(1*time.Millisecond), string(resp), fmt.Sprintln(err)) + } else { + t.Fatal(err) + } + return + } +} + func TestClientDialerName(t *testing.T) { client, _ := NewClient() dName, err := client.GetDialerName() diff --git a/pkg/app/client/loadbalance/lbcache.go b/pkg/app/client/loadbalance/lbcache.go index c249ebe2f..47acdd450 100644 --- a/pkg/app/client/loadbalance/lbcache.go +++ b/pkg/app/client/loadbalance/lbcache.go @@ -111,7 +111,7 @@ func (b *BalancerFactory) refresh() { b.cache.Range(func(key, value interface{}) bool { res, err := b.resolver.Resolve(context.Background(), key.(string)) if err != nil { - hlog.Warnf("Hertz: resolver refresh failed, key=%s error=%s", key, err.Error()) + hlog.SystemLogger().Warnf("resolver refresh failed, key=%s error=%s", key, err.Error()) return true } renameResultCacheKey(&res, b.resolver.Name()) @@ -132,7 +132,7 @@ func (b *BalancerFactory) GetInstance(ctx context.Context, req *protocol.Request atomic.StoreInt32(&cacheRes.expire, 0) ins := b.balancer.Pick(cacheRes.res.Load().(discovery.Result)) if ins == nil { - hlog.Errorf("HERTZ: null instance. serviceName: %s, options: %v", string(req.Host()), req.Options()) + hlog.SystemLogger().Errorf("null instance. serviceName: %s, options: %v", string(req.Host()), req.Options()) return nil, errors.NewPublic("instance not found") } return ins, nil diff --git a/pkg/app/client/loadbalance/weight_random.go b/pkg/app/client/loadbalance/weight_random.go index 1fa0f57b5..381e5769e 100644 --- a/pkg/app/client/loadbalance/weight_random.go +++ b/pkg/app/client/loadbalance/weight_random.go @@ -59,7 +59,7 @@ func (wb *weightedBalancer) calcWeightInfo(e discovery.Result) *weightInfo { w.weightSum += weight cnt++ } else { - hlog.Warnf("HERTZ: Invalid weight=%d on instance address=%s", weight, e.Instances[idx].Address()) + hlog.SystemLogger().Warnf("Invalid weight=%d on instance address=%s", weight, e.Instances[idx].Address()) } } diff --git a/pkg/app/client/option.go b/pkg/app/client/option.go index 80a74ccf8..343ce62b6 100644 --- a/pkg/app/client/option.go +++ b/pkg/app/client/option.go @@ -20,9 +20,11 @@ import ( "crypto/tls" "time" + "github.com/cloudwego/hertz/pkg/app/client/retry" "github.com/cloudwego/hertz/pkg/common/config" "github.com/cloudwego/hertz/pkg/network" "github.com/cloudwego/hertz/pkg/network/standard" + "github.com/cloudwego/hertz/pkg/protocol/consts" ) // WithDialTimeout sets dial timeout. @@ -67,13 +69,6 @@ func WithKeepAlive(b bool) config.ClientOption { }} } -// WithMaxIdempotentCallAttempts sets maximum number of attempts for idempotent calls. -func WithMaxIdempotentCallAttempts(n int) config.ClientOption { - return config.ClientOption{F: func(o *config.ClientOptions) { - o.MaxIdempotentCallAttempts = n - }} -} - // WithClientReadTimeout sets maximum duration for full response reading (including body). func WithClientReadTimeout(t time.Duration) config.ClientOption { return config.ClientOption{F: func(o *config.ClientOptions) { @@ -130,3 +125,18 @@ func WithDisablePathNormalizing(isDisablePathNormalizing bool) config.ClientOpti o.DisablePathNormalizing = isDisablePathNormalizing }} } + +func WithRetryConfig(opts ...retry.Option) config.ClientOption { + retryCfg := &retry.Config{ + MaxAttemptTimes: consts.DefaultMaxRetryTimes, + Delay: 1 * time.Millisecond, + MaxDelay: 100 * time.Millisecond, + MaxJitter: 20 * time.Millisecond, + DelayPolicy: retry.CombineDelay(retry.DefaultDelayPolicy), + } + retryCfg.Apply(opts) + + return config.ClientOption{F: func(o *config.ClientOptions) { + o.RetryConfig = retryCfg + }} +} diff --git a/pkg/app/client/option_test.go b/pkg/app/client/option_test.go index ff8214708..149d6a78f 100644 --- a/pkg/app/client/option_test.go +++ b/pkg/app/client/option_test.go @@ -17,9 +17,11 @@ package client import ( + "fmt" "testing" "time" + "github.com/cloudwego/hertz/pkg/app/client/retry" "github.com/cloudwego/hertz/pkg/common/config" "github.com/cloudwego/hertz/pkg/common/test/assert" ) @@ -31,18 +33,28 @@ func TestClientOptions(t *testing.T) { WithMaxIdleConnDuration(5 * time.Second), WithMaxConnDuration(10 * time.Second), WithMaxConnWaitTimeout(5 * time.Second), - WithMaxIdempotentCallAttempts(10), WithKeepAlive(false), WithClientReadTimeout(1 * time.Second), WithResponseBodyStream(true), + WithRetryConfig( + retry.WithMaxAttemptTimes(2), + retry.WithInitDelay(100*time.Millisecond), + retry.WithMaxDelay(5*time.Second), + retry.WithMaxJitter(1*time.Second), + retry.WithDelayPolicy(retry.CombineDelay(retry.FixedDelayPolicy, retry.BackOffDelayPolicy, retry.RandomDelayPolicy)), + ), }) assert.DeepEqual(t, 100*time.Millisecond, opt.DialTimeout) assert.DeepEqual(t, 128, opt.MaxConnsPerHost) assert.DeepEqual(t, 5*time.Second, opt.MaxIdleConnDuration) assert.DeepEqual(t, 10*time.Second, opt.MaxConnDuration) assert.DeepEqual(t, 5*time.Second, opt.MaxConnWaitTimeout) - assert.DeepEqual(t, 10, opt.MaxIdempotentCallAttempts) assert.DeepEqual(t, false, opt.KeepAlive) assert.DeepEqual(t, 1*time.Second, opt.ReadTimeout) assert.DeepEqual(t, true, opt.ResponseBodyStream) + assert.DeepEqual(t, uint(2), opt.RetryConfig.MaxAttemptTimes) + assert.DeepEqual(t, 100*time.Millisecond, opt.RetryConfig.Delay) + assert.DeepEqual(t, 5*time.Second, opt.RetryConfig.MaxDelay) + assert.DeepEqual(t, 1*time.Second, opt.RetryConfig.MaxJitter) + assert.DeepEqual(t, fmt.Sprint(retry.CombineDelay(retry.FixedDelayPolicy, retry.BackOffDelayPolicy, retry.RandomDelayPolicy)), fmt.Sprint(opt.RetryConfig.DelayPolicy)) } diff --git a/pkg/app/client/retry/option.go b/pkg/app/client/retry/option.go new file mode 100644 index 000000000..bfaa9da8b --- /dev/null +++ b/pkg/app/client/retry/option.go @@ -0,0 +1,59 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package retry + +import "time" + +// Option is the only struct that can be used to set Retry Config. +type Option struct { + F func(o *Config) +} + +// WithMaxAttemptTimes set WithMaxAttemptTimes , including the first call. +func WithMaxAttemptTimes(maxAttemptTimes uint) Option { + return Option{F: func(o *Config) { + o.MaxAttemptTimes = maxAttemptTimes + }} +} + +// WithInitDelay set init Delay. +func WithInitDelay(delay time.Duration) Option { + return Option{F: func(o *Config) { + o.Delay = delay + }} +} + +// WithMaxDelay set MaxDelay. +func WithMaxDelay(maxDelay time.Duration) Option { + return Option{F: func(o *Config) { + o.MaxDelay = maxDelay + }} +} + +// WithDelayPolicy set DelayPolicy. +func WithDelayPolicy(delayPolicy DelayPolicyFunc) Option { + return Option{F: func(o *Config) { + o.DelayPolicy = delayPolicy + }} +} + +// WithMaxJitter set MaxJitter. +func WithMaxJitter(maxJitter time.Duration) Option { + return Option{F: func(o *Config) { + o.MaxJitter = maxJitter + }} +} diff --git a/pkg/app/client/retry/retry.go b/pkg/app/client/retry/retry.go new file mode 100644 index 000000000..f9e01bf27 --- /dev/null +++ b/pkg/app/client/retry/retry.go @@ -0,0 +1,115 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package retry + +import ( + "math" + "time" + + "github.com/bytedance/gopkg/lang/fastrand" +) + +// Config All configurations related to retry +type Config struct { + // The maximum number of call attempt times, including the initial call + MaxAttemptTimes uint + + // Initial retry delay time + Delay time.Duration + + // Maximum retry delay time. When the retry time increases beyond this time, + // this configuration will limit the upper limit of waiting time + MaxDelay time.Duration + + // The maximum jitter time, which takes effect when the delay policy is configured as RandomDelay + MaxJitter time.Duration + + // Delay strategy, which can combine multiple delay strategies. such as CombineDelay(BackOffDelayPolicy, RandomDelayPolicy) or BackOffDelayPolicy,etc + DelayPolicy DelayPolicyFunc +} + +func (o *Config) Apply(opts []Option) { + for _, op := range opts { + op.F(o) + } +} + +// DelayPolicyFunc signature of delay policy function +// is called to return the delay of retry +type DelayPolicyFunc func(attempts uint, err error, retryConfig *Config) time.Duration + +// DefaultDelayPolicy is a DelayPolicyFunc which keep 0 delay in all iterations +func DefaultDelayPolicy(_ uint, _ error, _ *Config) time.Duration { + return 0 * time.Millisecond +} + +// FixedDelayPolicy is a DelayPolicyFunc which keeps delay the same through all iterations +func FixedDelayPolicy(_ uint, _ error, retryConfig *Config) time.Duration { + return retryConfig.Delay +} + +// RandomDelayPolicy is a DelayPolicyFunc which picks a random delay up to RetryConfig.MaxJitter, if the retryConfig.MaxJitter less than or equal to 0, the final delay is 0 +func RandomDelayPolicy(_ uint, _ error, retryConfig *Config) time.Duration { + if retryConfig.MaxJitter <= 0 { + return 0 * time.Millisecond + } + return time.Duration(fastrand.Int63n(int64(retryConfig.MaxJitter))) +} + +// BackOffDelayPolicy is a DelayPolicyFunc which exponentially increases delay between consecutive retries, if the retryConfig.Delay less than or equal to 0, the final delay is 0 +func BackOffDelayPolicy(attempts uint, _ error, retryConfig *Config) time.Duration { + if retryConfig.Delay <= 0 { + return 0 * time.Millisecond + } + // 1 << 63 would overflow signed int64 (time.Duration), thus 62. + const max uint = 62 + if attempts > max { + attempts = max + } + + return retryConfig.Delay << attempts +} + +// CombineDelay return DelayPolicyFunc, which combines the optional DelayPolicyFunc into a new DelayPolicyFunc +func CombineDelay(delays ...DelayPolicyFunc) DelayPolicyFunc { + const maxInt64 = uint64(math.MaxInt64) + + return func(attempts uint, err error, config *Config) time.Duration { + var total uint64 + for _, delay := range delays { + total += uint64(delay(attempts, err, config)) + if total > maxInt64 { + total = maxInt64 + } + } + + return time.Duration(total) + } +} + +// Delay generate the delay time required for the current retry config, if the retryConfig.DelayPolicy == nil, the final delay is 0 +func Delay(attempts uint, err error, retryConfig *Config) time.Duration { + if retryConfig.DelayPolicy == nil { + return 0 * time.Millisecond + } + + delayTime := retryConfig.DelayPolicy(attempts, err, retryConfig) + if retryConfig.MaxDelay > 0 && delayTime > retryConfig.MaxDelay { + delayTime = retryConfig.MaxDelay + } + return delayTime +} diff --git a/pkg/app/context.go b/pkg/app/context.go index b1464251b..5e15f9bec 100644 --- a/pkg/app/context.go +++ b/pkg/app/context.go @@ -880,6 +880,12 @@ func (ctx *RequestContext) PureJSON(code int, obj interface{}) { ctx.Render(code, render.PureJSON{Data: obj}) } +// IndentedJSON serializes the given struct as pretty JSON (indented + endlines) into the response body. +// It also sets the Content-Type as "application/json". +func (ctx *RequestContext) IndentedJSON(code int, obj interface{}) { + ctx.Render(code, render.IndentedJSON{Data: obj}) +} + // HTML renders the HTTP template specified by its file name. // // It also updates the HTTP code and sets the Content-Type as "text/html". diff --git a/pkg/app/context_test.go b/pkg/app/context_test.go index c83ee6e74..138852ac3 100644 --- a/pkg/app/context_test.go +++ b/pkg/app/context_test.go @@ -59,6 +59,17 @@ func TestPureJson(t *testing.T) { } } +func TestIndentedJSON(t *testing.T) { + ctx := NewContext(0) + ctx.IndentedJSON(consts.StatusOK, utils.H{ + "foo": "bar", + "html": "h1", + }) + if string(ctx.Response.Body()) != "{\n \"foo\": \"bar\",\n \"html\": \"h1\"\n}" { + t.Fatalf("unexpected purejson: %#v, expected: %#v", string(ctx.Response.Body()), "{\n \"foo\": \"bar\",\n \"html\": \"\"\n}") + } +} + func TestContext(t *testing.T) { reqContext := NewContext(0) reqContext.Set("testContextKey", "testValue") diff --git a/pkg/app/fs.go b/pkg/app/fs.go index 24eba0408..644a9e93c 100644 --- a/pkg/app/fs.go +++ b/pkg/app/fs.go @@ -286,7 +286,7 @@ func ServeFile(ctx *RequestContext, path string) { // extend relative path to absolute path var err error if path, err = filepath.Abs(path); err != nil { - hlog.Errorf("HERTZ: Cannot resolve path=%q to absolute file error=%s", path, err) + hlog.SystemLogger().Errorf("Cannot resolve path=%q to absolute file error=%s", path, err) ctx.AbortWithMsg("Internal Server Error", consts.StatusInternalServerError) return } @@ -805,7 +805,7 @@ func (h *fsHandler) handleRequest(c context.Context, ctx *RequestContext) { path = stripTrailingSlashes(path) if n := bytes.IndexByte(path, 0); n >= 0 { - hlog.Errorf("HERTZ: Cannot serve path with nil byte at position=%d, path=%q", n, path) + hlog.SystemLogger().Errorf("Cannot serve path with nil byte at position=%d, path=%q", n, path) ctx.AbortWithMsg("Are you a hacker?", consts.StatusBadRequest) return } @@ -814,7 +814,7 @@ func (h *fsHandler) handleRequest(c context.Context, ctx *RequestContext) { // since ctx.Path must normalize and sanitize the path. if n := bytes.Index(path, bytestr.StrSlashDotDotSlash); n >= 0 { - hlog.Errorf("HERTZ: Cannot serve path with '/../' at position=%d due to security reasons, path=%q", n, path) + hlog.SystemLogger().Errorf("Cannot serve path with '/../' at position=%d due to security reasons, path=%q", n, path) ctx.AbortWithMsg("Internal Server Error", consts.StatusInternalServerError) return } @@ -842,7 +842,7 @@ func (h *fsHandler) handleRequest(c context.Context, ctx *RequestContext) { ff, err = h.openFSFile(filePath, mustCompress) if mustCompress && err == errNoCreatePermission { - hlog.Errorf("HERTZ: Insufficient permissions for saving compressed file for path=%q. Serving uncompressed file. "+ + hlog.SystemLogger().Errorf("Insufficient permissions for saving compressed file for path=%q. Serving uncompressed file. "+ "Allow write access to the directory with this file in order to improve hertz performance", filePath) mustCompress = false ff, err = h.openFSFile(filePath, mustCompress) @@ -850,12 +850,12 @@ func (h *fsHandler) handleRequest(c context.Context, ctx *RequestContext) { if err == errDirIndexRequired { ff, err = h.openIndexFile(ctx, filePath, mustCompress) if err != nil { - hlog.Errorf("HERTZ: Cannot open dir index, path=%q, error=%s", filePath, err) + hlog.SystemLogger().Errorf("Cannot open dir index, path=%q, error=%s", filePath, err) ctx.AbortWithMsg("Directory index is forbidden", consts.StatusForbidden) return } } else if err != nil { - hlog.Errorf("HERTZ: Cannot open file=%q, error=%s", filePath, err) + hlog.SystemLogger().Errorf("Cannot open file=%q, error=%s", filePath, err) if h.pathNotFound == nil { ctx.AbortWithMsg("Cannot open requested path", consts.StatusNotFound) } else { @@ -892,7 +892,7 @@ func (h *fsHandler) handleRequest(c context.Context, ctx *RequestContext) { r, err := ff.NewReader() if err != nil { - hlog.Errorf("HERTZ: Cannot obtain file reader for path=%q, error=%s", path, err) + hlog.SystemLogger().Errorf("Cannot obtain file reader for path=%q, error=%s", path, err) ctx.AbortWithMsg("Internal Server Error", consts.StatusInternalServerError) return } @@ -910,14 +910,14 @@ func (h *fsHandler) handleRequest(c context.Context, ctx *RequestContext) { startPos, endPos, err := ParseByteRange(byteRange, contentLength) if err != nil { r.(io.Closer).Close() - hlog.Errorf("HERTZ: Cannot parse byte range %q for path=%q,error=%s", byteRange, path, err) + hlog.SystemLogger().Errorf("Cannot parse byte range %q for path=%q,error=%s", byteRange, path, err) ctx.AbortWithMsg("Range Not Satisfiable", consts.StatusRequestedRangeNotSatisfiable) return } if err = r.(byteRangeUpdater).UpdateByteRange(startPos, endPos); err != nil { r.(io.Closer).Close() - hlog.Errorf("HERTZ: Cannot seek byte range %q for path=%q, error=%s", byteRange, path, err) + hlog.SystemLogger().Errorf("Cannot seek byte range %q for path=%q, error=%s", byteRange, path, err) ctx.AbortWithMsg("Internal Server Error", consts.StatusInternalServerError) return } @@ -937,7 +937,7 @@ func (h *fsHandler) handleRequest(c context.Context, ctx *RequestContext) { ctx.Response.Header.SetContentLength(contentLength) if rc, ok := r.(io.Closer); ok { if err := rc.Close(); err != nil { - hlog.Errorf("HERTZ: Cannot close file reader: error=%s", err) + hlog.SystemLogger().Errorf("Cannot close file reader: error=%s", err) ctx.AbortWithMsg("Internal Server Error", consts.StatusInternalServerError) return } diff --git a/pkg/app/middlewares/server/recovery/option.go b/pkg/app/middlewares/server/recovery/option.go new file mode 100644 index 000000000..93d9566bf --- /dev/null +++ b/pkg/app/middlewares/server/recovery/option.go @@ -0,0 +1,56 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package recovery + +import ( + "context" + + "github.com/cloudwego/hertz/pkg/app" + "github.com/cloudwego/hertz/pkg/common/hlog" + "github.com/cloudwego/hertz/pkg/protocol/consts" +) + +type ( + options struct { + recoveryHandler func(c context.Context, ctx *app.RequestContext, err interface{}, stack []byte) + } + + Option func(o *options) +) + +func defaultRecoveryHandler(c context.Context, ctx *app.RequestContext, err interface{}, stack []byte) { + hlog.SystemLogger().CtxErrorf(c, "[Recovery] err=%v\nstack=%s", err, stack) + ctx.AbortWithStatus(consts.StatusInternalServerError) +} + +func newOptions(opts ...Option) *options { + cfg := &options{ + recoveryHandler: defaultRecoveryHandler, + } + + for _, opt := range opts { + opt(cfg) + } + + return cfg +} + +func WithRecoveryHandler(f func(c context.Context, ctx *app.RequestContext, err interface{}, stack []byte)) Option { + return func(o *options) { + o.recoveryHandler = f + } +} diff --git a/pkg/app/middlewares/server/recovery/option_test.go b/pkg/app/middlewares/server/recovery/option_test.go new file mode 100644 index 000000000..4a2140d13 --- /dev/null +++ b/pkg/app/middlewares/server/recovery/option_test.go @@ -0,0 +1,44 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package recovery + +import ( + "context" + "fmt" + "testing" + + "github.com/cloudwego/hertz/pkg/app" + "github.com/cloudwego/hertz/pkg/common/hlog" + "github.com/cloudwego/hertz/pkg/common/test/assert" + "github.com/cloudwego/hertz/pkg/common/utils" +) + +func TestDefaultOption(t *testing.T) { + opts := newOptions() + assert.DeepEqual(t, fmt.Sprintf("%p", defaultRecoveryHandler), fmt.Sprintf("%p", opts.recoveryHandler)) +} + +func newRecoveryHandler(c context.Context, ctx *app.RequestContext, err interface{}, stack []byte) { + hlog.SystemLogger().CtxErrorf(c, "[New Recovery] panic recovered:\n%s\n%s\n", + err, stack) + ctx.JSON(501, utils.H{"msg": err.(string)}) +} + +func TestOption(t *testing.T) { + opts := newOptions(WithRecoveryHandler(newRecoveryHandler)) + assert.DeepEqual(t, fmt.Sprintf("%p", newRecoveryHandler), fmt.Sprintf("%p", opts.recoveryHandler)) +} diff --git a/pkg/app/middlewares/server/recovery/recovery.go b/pkg/app/middlewares/server/recovery/recovery.go index 071005640..f05aad217 100644 --- a/pkg/app/middlewares/server/recovery/recovery.go +++ b/pkg/app/middlewares/server/recovery/recovery.go @@ -22,11 +22,8 @@ import ( "fmt" "io/ioutil" "runtime" - "time" "github.com/cloudwego/hertz/pkg/app" - "github.com/cloudwego/hertz/pkg/common/hlog" - "github.com/cloudwego/hertz/pkg/protocol/consts" ) var ( @@ -36,16 +33,18 @@ var ( slash = []byte("/") ) -// Recovery returns a middleware that recovers from any panic and writes a 500 if there was one. -func Recovery() app.HandlerFunc { +// Recovery returns a middleware that recovers from any panic. +// By default, it will print the time, content, and stack information of the error and write a 500. +// Overriding the Config configuration, you can customize the error printing logic. +func Recovery(opts ...Option) app.HandlerFunc { + cfg := newOptions(opts...) + return func(c context.Context, ctx *app.RequestContext) { defer func() { if err := recover(); err != nil { stack := stack(3) - hlog.CtxErrorf(c, "[Recovery] %s panic recovered:\n%s\n%s\n", - timeFormat(time.Now()), err, stack) - ctx.AbortWithStatus(consts.StatusInternalServerError) + cfg.recoveryHandler(c, ctx, err, stack) } }() ctx.Next(c) @@ -112,8 +111,3 @@ func function(pc uintptr) []byte { name = bytes.Replace(name, centerDot, dot, -1) return name } - -func timeFormat(t time.Time) string { - timeString := t.Format("2006/01/02 - 15:04:05") - return timeString -} diff --git a/pkg/app/middlewares/server/recovery/recovery_test.go b/pkg/app/middlewares/server/recovery/recovery_test.go index 79fe4e478..ffaa1a381 100644 --- a/pkg/app/middlewares/server/recovery/recovery_test.go +++ b/pkg/app/middlewares/server/recovery/recovery_test.go @@ -22,6 +22,7 @@ import ( "testing" "github.com/cloudwego/hertz/pkg/app" + "github.com/cloudwego/hertz/pkg/common/test/assert" ) func TestRecovery(t *testing.T) { @@ -39,3 +40,20 @@ func TestRecovery(t *testing.T) { t.Fatalf("unexpected %v. Expecting %v", ctx.Response.StatusCode(), 500) } } + +func TestWithRecoveryHandler(t *testing.T) { + ctx := app.NewContext(0) + var hc app.HandlersChain + hc = append(hc, func(c context.Context, ctx *app.RequestContext) { + fmt.Println("this is test") + panic("test") + }) + ctx.SetHandlers(hc) + + Recovery(WithRecoveryHandler(newRecoveryHandler))(context.Background(), ctx) + + if ctx.Response.StatusCode() != 501 { + t.Fatalf("unexpected %v. Expecting %v", ctx.Response.StatusCode(), 501) + } + assert.DeepEqual(t, "{\"msg\":\"test\"}", string(ctx.Response.Body())) +} diff --git a/pkg/app/server/binding/binding.go b/pkg/app/server/binding/binding.go index ac5aee52a..1d7298159 100644 --- a/pkg/app/server/binding/binding.go +++ b/pkg/app/server/binding/binding.go @@ -23,13 +23,13 @@ import ( "github.com/bytedance/go-tagexpr/v2/binding" "github.com/bytedance/go-tagexpr/v2/binding/gjson" "github.com/bytedance/go-tagexpr/v2/validator" - "github.com/bytedance/sonic" + hjson "github.com/cloudwego/hertz/pkg/common/json" "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/route/param" ) func init() { - binding.ResetJSONUnmarshaler(sonic.Unmarshal) + binding.ResetJSONUnmarshaler(hjson.Unmarshal) } var defaultBinder = binding.Default() diff --git a/pkg/app/server/hertz.go b/pkg/app/server/hertz.go index b8edc92e5..651b1ede4 100644 --- a/pkg/app/server/hertz.go +++ b/pkg/app/server/hertz.go @@ -67,20 +67,20 @@ func (h *Hertz) Spin() { } if err := signalWaiter(errCh); err != nil { - hlog.Errorf("HERTZ: Receive close signal: error=%v", err) + hlog.SystemLogger().Errorf("Receive close signal: error=%v", err) if err := h.Engine.Close(); err != nil { - hlog.Errorf("HERTZ: Close error=%v", err) + hlog.SystemLogger().Errorf("Close error=%v", err) } return } - hlog.Infof("HERTZ: Begin graceful shutdown, wait at most num=%d seconds...", h.GetOptions().ExitWaitTimeout/time.Second) + hlog.SystemLogger().Infof("Begin graceful shutdown, wait at most num=%d seconds...", h.GetOptions().ExitWaitTimeout/time.Second) ctx, cancel := context.WithTimeout(context.Background(), h.GetOptions().ExitWaitTimeout) defer cancel() if err := h.Shutdown(ctx); err != nil { - hlog.Errorf("HERTZ: Shutdown error=%v", err) + hlog.SystemLogger().Errorf("Shutdown error=%v", err) } } @@ -124,7 +124,7 @@ func (h *Hertz) initOnRunHooks(errChan chan error) { // delay register 1s time.Sleep(1 * time.Second) if err := opt.Registry.Register(opt.RegistryInfo); err != nil { - hlog.Errorf("HERTZ: Register error=%v", err) + hlog.SystemLogger().Errorf("Register error=%v", err) // pass err to errChan errChan <- err } diff --git a/pkg/app/server/option.go b/pkg/app/server/option.go index ccfa97a01..fc913e400 100644 --- a/pkg/app/server/option.go +++ b/pkg/app/server/option.go @@ -288,3 +288,11 @@ func WithAutoReloadRender(b bool, interval time.Duration) config.Option { o.AutoReloadInterval = interval }} } + +// WithDisablePrintRoute sets whether disable debugPrintRoute +// If we don't set it, it will default to false +func WithDisablePrintRoute(b bool) config.Option { + return config.Option{F: func(o *config.Options) { + o.DisablePrintRoute = b + }} +} diff --git a/pkg/app/server/option_test.go b/pkg/app/server/option_test.go index f45eb22c7..c68e2c30d 100644 --- a/pkg/app/server/option_test.go +++ b/pkg/app/server/option_test.go @@ -47,9 +47,12 @@ func TestOptions(t *testing.T) { WithStreamBody(false), WithHostPorts(":8888"), WithMaxRequestBodySize(2), + WithDisablePrintRoute(true), WithNetwork("unix"), WithExitWaitTime(time.Second), WithMaxKeepBodySize(500), + WithGetOnly(true), + WithKeepAlive(false), WithTLS(nil), WithH2C(true), WithReadBufferSize(100), @@ -71,9 +74,12 @@ func TestOptions(t *testing.T) { assert.DeepEqual(t, opt.StreamRequestBody, false) assert.DeepEqual(t, opt.Addr, ":8888") assert.DeepEqual(t, opt.MaxRequestBodySize, 2) + assert.DeepEqual(t, opt.DisablePrintRoute, true) assert.DeepEqual(t, opt.Network, "unix") assert.DeepEqual(t, opt.ExitWaitTimeout, time.Second) assert.DeepEqual(t, opt.MaxKeepBodySize, 500) + assert.DeepEqual(t, opt.GetOnly, true) + assert.DeepEqual(t, opt.DisableKeepalive, true) assert.DeepEqual(t, opt.H2C, true) assert.DeepEqual(t, opt.ReadBufferSize, 100) assert.DeepEqual(t, opt.ALPN, true) @@ -99,6 +105,9 @@ func TestDefaultOptions(t *testing.T) { assert.DeepEqual(t, opt.StreamRequestBody, false) assert.DeepEqual(t, opt.Addr, ":8888") assert.DeepEqual(t, opt.MaxRequestBodySize, 4*1024*1024) + assert.DeepEqual(t, opt.GetOnly, false) + assert.DeepEqual(t, opt.DisableKeepalive, false) + assert.DeepEqual(t, opt.DisablePrintRoute, false) assert.DeepEqual(t, opt.Network, "tcp") assert.DeepEqual(t, opt.ExitWaitTimeout, time.Second*5) assert.DeepEqual(t, opt.MaxKeepBodySize, 4*1024*1024) diff --git a/pkg/app/server/render/html.go b/pkg/app/server/render/html.go index 03cf687a5..f31e33a85 100644 --- a/pkg/app/server/render/html.go +++ b/pkg/app/server/render/html.go @@ -160,13 +160,13 @@ func (h *HTMLDebug) startChecker() { if h.RefreshInterval > 0 { go func() { - hlog.Debugf("HERTZ[HTMLDebug]: HTML template reloader started with interval %v", h.RefreshInterval) + hlog.SystemLogger().Debugf("[HTMLDebug] HTML template reloader started with interval %v", h.RefreshInterval) for { n := time.Now() if n.UTC().Sub(h.updateTimeStamp.UTC()) > h.RefreshInterval { - hlog.Debugf("HERTZ[HTMLDebug]: triggering HTML template reloader") + hlog.SystemLogger().Debugf("[HTMLDebug] triggering HTML template reloader") h.reloadCh <- struct{}{} - hlog.Debugf("HERTZ[HTMLDebug]: HTML template has been reloaded, next reload in %v", h.RefreshInterval) + hlog.SystemLogger().Debugf("[HTMLDebug] HTML template has been reloaded, next reload in %v", h.RefreshInterval) h.updateTimeStamp = time.Now() } } @@ -181,15 +181,15 @@ func (h *HTMLDebug) startChecker() { h.watcher = watcher for _, f := range h.Files { err := watcher.Add(f) - hlog.Debugf("HERTZ[HTMLDebug]: watching file: %s", f) + hlog.SystemLogger().Debugf("[HTMLDebug] watching file: %s", f) if err != nil { - hlog.Errorf("HERTZ[HTMLDebug]: add watching file: %s, error happened: %v", f, err) + hlog.SystemLogger().Errorf("[HTMLDebug] add watching file: %s, error happened: %v", f, err) } } go func() { - hlog.Debugf("HERTZ[HTMLDebug]: HTML template reloader started with file watcher") + hlog.SystemLogger().Debugf("[HTMLDebug] HTML template reloader started with file watcher") for { select { case event, ok := <-watcher.Events: @@ -197,15 +197,15 @@ func (h *HTMLDebug) startChecker() { return } if event.Op&fsnotify.Write == fsnotify.Write { - hlog.Debugf("HERTZ[HTMLDebug]: modified file: %s, html render template will be reloaded at the next rendering", event.Name) + hlog.SystemLogger().Debugf("[HTMLDebug] modified file: %s, html render template will be reloaded at the next rendering", event.Name) h.reloadCh <- struct{}{} - hlog.Debugf("HERTZ[HTMLDebug]: HTML template has been reloaded") + hlog.SystemLogger().Debugf("[HTMLDebug] HTML template has been reloaded") } case err, ok := <-watcher.Errors: if !ok { return } - hlog.Errorf("HERTZ: error happened when watching the rendering files: %v", err) + hlog.SystemLogger().Errorf("error happened when watching the rendering files: %v", err) } } }() diff --git a/pkg/app/server/render/json.go b/pkg/app/server/render/json.go index 69aa2f9a1..577f66bdf 100644 --- a/pkg/app/server/render/json.go +++ b/pkg/app/server/render/json.go @@ -45,17 +45,17 @@ import ( "bytes" "encoding/json" - "github.com/bytedance/sonic" + hjson "github.com/cloudwego/hertz/pkg/common/json" "github.com/cloudwego/hertz/pkg/protocol" ) -// customize json.Marshal as you like +// JSONMarshaler customize json.Marshal as you like type JSONMarshaler func(v interface{}) ([]byte, error) var jsonMarshalFunc JSONMarshaler func init() { - ResetJSONMarshal(sonic.Marshal) + ResetJSONMarshal(hjson.Marshal) } func ResetJSONMarshal(fn JSONMarshaler) { @@ -66,7 +66,7 @@ func ResetStdJSONMarshal() { ResetJSONMarshal(json.Marshal) } -// JSON contains the given interface object. +// JSONRender JSON contains the given interface object. type JSONRender struct { Data interface{} } @@ -113,3 +113,29 @@ func (r PureJSON) Render(resp *protocol.Response) (err error) { func (r PureJSON) WriteContentType(resp *protocol.Response) { writeContentType(resp, jsonContentType) } + +// IndentedJSON contains the given interface object. +type IndentedJSON struct { + Data interface{} +} + +// Render (IndentedJSON) marshals the given interface object and writes it with custom ContentType. +func (r IndentedJSON) Render(resp *protocol.Response) (err error) { + writeContentType(resp, jsonContentType) + jsonBytes, err := jsonMarshalFunc(r.Data) + if err != nil { + return err + } + var buf bytes.Buffer + err = json.Indent(&buf, jsonBytes, "", " ") + if err != nil { + return err + } + resp.AppendBody(buf.Bytes()) + return nil +} + +// WriteContentType (JSON) writes JSON ContentType. +func (r IndentedJSON) WriteContentType(resp *protocol.Response) { + writeContentType(resp, jsonContentType) +} diff --git a/pkg/app/server/render/render_test.go b/pkg/app/server/render/render_test.go index b06468b80..bc839e0c0 100644 --- a/pkg/app/server/render/render_test.go +++ b/pkg/app/server/render/render_test.go @@ -45,6 +45,7 @@ import ( "encoding/xml" "testing" + "github.com/bytedance/sonic" "github.com/cloudwego/hertz/pkg/common/test/assert" "github.com/cloudwego/hertz/pkg/protocol" ) @@ -159,3 +160,30 @@ func TestRenderXML(t *testing.T) { assert.DeepEqual(t, []byte("bar"), resp.Body()) assert.DeepEqual(t, []byte("application/xml; charset=utf-8"), resp.Header.Peek("Content-Type")) } + +func TestRenderIndentedJSON(t *testing.T) { + data := map[string]interface{}{ + "foo": "bar", + "html": "h1", + } + t.Run("TestHeader", func(t *testing.T) { + resp := &protocol.Response{} + (IndentedJSON{data}).WriteContentType(resp) + assert.DeepEqual(t, []byte("application/json; charset=utf-8"), resp.Header.Peek("Content-Type")) + }) + t.Run("TestBody", func(t *testing.T) { + ResetStdJSONMarshal() + resp := &protocol.Response{} + err := (IndentedJSON{data}).Render(resp) + assert.Nil(t, err) + assert.DeepEqual(t, []byte("{\n \"foo\": \"bar\",\n \"html\": \"h1\"\n}"), resp.Body()) + assert.DeepEqual(t, []byte("application/json; charset=utf-8"), resp.Header.Peek("Content-Type")) + ResetJSONMarshal(sonic.Marshal) + }) + t.Run("TestError", func(t *testing.T) { + resp := &protocol.Response{} + ch := make(chan int) + err := (IndentedJSON{ch}).Render(resp) + assert.NotNil(t, err) + }) +} diff --git a/pkg/common/adaptor/request_test.go b/pkg/common/adaptor/request_test.go index f0909b19a..a3eb3f916 100644 --- a/pkg/common/adaptor/request_test.go +++ b/pkg/common/adaptor/request_test.go @@ -26,35 +26,48 @@ import ( "github.com/cloudwego/hertz/pkg/app" "github.com/cloudwego/hertz/pkg/app/server" + "github.com/cloudwego/hertz/pkg/common/test/assert" + "github.com/cloudwego/hertz/pkg/protocol" + "github.com/cloudwego/hertz/pkg/protocol/consts" ) func TestCompatResponse_WriteHeader(t *testing.T) { var testHeader http.Header var testBody string - testUrl := "http://127.0.0.1:9000/test" + testUrl1 := "http://127.0.0.1:9000/test1" + testUrl2 := "http://127.0.0.1:9000/test2" testStatusCode := 299 + testCookieValue := "cookie" testHeader = make(map[string][]string) testHeader["Key1"] = []string{"value1"} testHeader["Key2"] = []string{"value2", "value22"} testHeader["Key3"] = []string{"value3", "value33", "value333"} + testHeader[consts.HeaderSetCookie] = []string{testCookieValue} testBody = "test body" h := server.New(server.WithHostPorts("127.0.0.1:9000")) - h.POST("/test", func(c context.Context, ctx *app.RequestContext) { + h.POST("/test1", func(c context.Context, ctx *app.RequestContext) { req, _ := GetCompatRequest(&ctx.Request) resp := GetCompatResponseWriter(&ctx.Response) handlerAndCheck(t, resp, req, testHeader, testBody, testStatusCode) }) + h.POST("/test2", func(c context.Context, ctx *app.RequestContext) { + req, _ := GetCompatRequest(&ctx.Request) + resp := GetCompatResponseWriter(&ctx.Response) + handlerAndCheck(t, resp, req, testHeader, testBody) + }) + go h.Spin() time.Sleep(200 * time.Millisecond) - makeACall(t, http.MethodPost, testUrl, testHeader, testBody, testStatusCode) + makeACall(t, http.MethodPost, testUrl1, testHeader, testBody, testStatusCode, []byte(testCookieValue)) + makeACall(t, http.MethodPost, testUrl2, testHeader, testBody, consts.StatusOK, []byte(testCookieValue)) } -func makeACall(t *testing.T, method, url string, header http.Header, body string, expectStatusCode int) { +func makeACall(t *testing.T, method, url string, header http.Header, body string, expectStatusCode int, expectCookieValue []byte) { client := http.Client{} req, _ := http.NewRequest(method, url, strings.NewReader(body)) req.Header = header @@ -77,16 +90,21 @@ func makeACall(t *testing.T, method, url string, header http.Header, body string if err != nil { t.Fatalf("Read body error: %s", err) } - if string(b) != body { - t.Fatalf("Body not equal: want: %s, got: %s", body, string(b)) - } + assert.DeepEqual(t, body, string(b)) + assert.DeepEqual(t, expectStatusCode, resp.StatusCode) - if resp.StatusCode != expectStatusCode { - t.Fatalf("Status code not equal: want: %d, got: %d", expectStatusCode, resp.StatusCode) - } + // Parse out the cookie to verify it is correct + cookie := protocol.Cookie{} + _ = cookie.Parse(header[consts.HeaderSetCookie][0]) + assert.DeepEqual(t, expectCookieValue, cookie.Value()) } -func handlerAndCheck(t *testing.T, writer http.ResponseWriter, request *http.Request, wantHeader http.Header, wantBody string, statusCode int) { +// handlerAndCheck is designed to handle the program and check the header +// +// "..." is used in the type of statusCode, which is a syntactic sugar in Go. +// In this way, the statusCode can be made an optional parameter, +// and there is no need to pass in some meaningless numbers to judge some special cases. +func handlerAndCheck(t *testing.T, writer http.ResponseWriter, request *http.Request, wantHeader http.Header, wantBody string, statusCode ...int) { reqHeader := request.Header for k, v := range wantHeader { if reqHeader[k] == nil { @@ -101,15 +119,19 @@ func handlerAndCheck(t *testing.T, writer http.ResponseWriter, request *http.Req if err != nil { t.Fatalf("Read body error: %s", err) } - if string(body) != wantBody { - t.Fatalf("Body not equal: want: %s, got: %s", wantBody, string(body)) - } + assert.DeepEqual(t, wantBody, string(body)) respHeader := writer.Header() for k, v := range reqHeader { respHeader[k] = v } - writer.WriteHeader(statusCode) + + // When the incoming status code is nil, the execution of this code is skipped + // and the status code is set to 200 + if statusCode != nil { + writer.WriteHeader(statusCode[0]) + } + _, err = writer.Write([]byte("test")) if err != nil { t.Fatalf("Write body error: %s", err) diff --git a/pkg/common/bytebufferpool/bytebuffer.go b/pkg/common/bytebufferpool/bytebuffer.go index 8783aa47a..656174666 100644 --- a/pkg/common/bytebufferpool/bytebuffer.go +++ b/pkg/common/bytebufferpool/bytebuffer.go @@ -51,7 +51,6 @@ import "io" // // Use Get for obtaining an empty byte buffer. type ByteBuffer struct { - // B is a byte buffer to use in append-like workloads. // See example code for details. B []byte diff --git a/pkg/common/config/client_option.go b/pkg/common/config/client_option.go index ec6f5e68c..8b44d243c 100644 --- a/pkg/common/config/client_option.go +++ b/pkg/common/config/client_option.go @@ -20,6 +20,7 @@ import ( "crypto/tls" "time" + "github.com/cloudwego/hertz/pkg/app/client/retry" "github.com/cloudwego/hertz/pkg/network" "github.com/cloudwego/hertz/pkg/protocol/consts" ) @@ -35,14 +36,13 @@ type ClientOptions struct { // The max connection nums for each host MaxConnsPerHost int - MaxIdleConnDuration time.Duration - MaxConnDuration time.Duration - MaxConnWaitTimeout time.Duration - MaxIdempotentCallAttempts int - KeepAlive bool - ReadTimeout time.Duration - TLSConfig *tls.Config - ResponseBodyStream bool + MaxIdleConnDuration time.Duration + MaxConnDuration time.Duration + MaxConnWaitTimeout time.Duration + KeepAlive bool + ReadTimeout time.Duration + TLSConfig *tls.Config + ResponseBodyStream bool // Client name. Used in User-Agent request header. // @@ -103,15 +103,17 @@ type ClientOptions struct { // By default path values are normalized, i.e. // extra slashes are removed, special characters are encoded. DisablePathNormalizing bool + + // all configurations related to retry + RetryConfig *retry.Config } func NewClientOptions(opts []ClientOption) *ClientOptions { options := &ClientOptions{ - DialTimeout: consts.DefaultDialTimeout, - MaxConnsPerHost: consts.DefaultMaxConnsPerHost, - MaxIdleConnDuration: consts.DefaultMaxIdleConnDuration, - MaxIdempotentCallAttempts: consts.DefaultMaxIdempotentCallAttempts, - KeepAlive: true, + DialTimeout: consts.DefaultDialTimeout, + MaxConnsPerHost: consts.DefaultMaxConnsPerHost, + MaxIdleConnDuration: consts.DefaultMaxIdleConnDuration, + KeepAlive: true, } options.Apply(opts) diff --git a/pkg/common/config/client_option_test.go b/pkg/common/config/client_option_test.go new file mode 100644 index 000000000..8e9f0151d --- /dev/null +++ b/pkg/common/config/client_option_test.go @@ -0,0 +1,49 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package config + +import ( + "testing" + "time" + + "github.com/cloudwego/hertz/pkg/common/test/assert" + "github.com/cloudwego/hertz/pkg/protocol/consts" +) + +// TestDefaultClientOptions test client options with default values +func TestDefaultClientOptions(t *testing.T) { + options := NewClientOptions([]ClientOption{}) + + assert.DeepEqual(t, consts.DefaultDialTimeout, options.DialTimeout) + assert.DeepEqual(t, consts.DefaultMaxConnsPerHost, options.MaxConnsPerHost) + assert.DeepEqual(t, consts.DefaultMaxIdleConnDuration, options.MaxIdleConnDuration) + assert.DeepEqual(t, true, options.KeepAlive) +} + +// TestCustomClientOptions test client options with custom values +func TestCustomClientOptions(t *testing.T) { + options := NewClientOptions([]ClientOption{}) + + options.Apply([]ClientOption{ + { + F: func(o *ClientOptions) { + o.DialTimeout = 2 * time.Second + }, + }, + }) + assert.DeepEqual(t, 2*time.Second, options.DialTimeout) +} diff --git a/pkg/common/config/option.go b/pkg/common/config/option.go index 1fc2789ab..cd69716a3 100644 --- a/pkg/common/config/option.go +++ b/pkg/common/config/option.go @@ -56,6 +56,7 @@ type Options struct { DisablePreParseMultipartForm bool StreamRequestBody bool NoDefaultServerHeader bool + DisablePrintRoute bool Network string Addr string ExitWaitTimeout time.Duration @@ -151,6 +152,10 @@ func NewOptions(opts []Option) *Options { // like they are normal requests DisablePreParseMultipartForm: false, + // Routes info printing is not disabled by default + // Disabled when set to True + DisablePrintRoute: false, + // "tcp", "udp", "unix"(unix domain socket) Network: defaultNetwork, diff --git a/pkg/common/config/option_test.go b/pkg/common/config/option_test.go new file mode 100644 index 000000000..71eeae754 --- /dev/null +++ b/pkg/common/config/option_test.go @@ -0,0 +1,65 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package config + +import ( + "testing" + + "github.com/cloudwego/hertz/pkg/app/server/registry" + "github.com/cloudwego/hertz/pkg/common/test/assert" +) + +// TestDefaultOptions test options with default values +func TestDefaultOptions(t *testing.T) { + options := NewOptions([]Option{}) + + assert.DeepEqual(t, defaultKeepAliveTimeout, options.KeepAliveTimeout) + assert.DeepEqual(t, defaultReadTimeout, options.ReadTimeout) + assert.DeepEqual(t, defaultReadTimeout, options.IdleTimeout) + assert.True(t, options.RedirectTrailingSlash) + assert.True(t, options.RedirectTrailingSlash) + assert.False(t, options.HandleMethodNotAllowed) + assert.False(t, options.UseRawPath) + assert.False(t, options.RemoveExtraSlash) + assert.True(t, options.UnescapePathValues) + assert.False(t, options.DisablePreParseMultipartForm) + assert.DeepEqual(t, defaultNetwork, options.Network) + assert.DeepEqual(t, defaultAddr, options.Addr) + assert.DeepEqual(t, defaultMaxRequestBodySize, options.MaxRequestBodySize) + assert.False(t, options.GetOnly) + assert.False(t, options.DisableKeepalive) + assert.False(t, options.NoDefaultServerHeader) + assert.DeepEqual(t, defaultWaitExitTimeout, options.ExitWaitTimeout) + assert.Nil(t, options.TLS) + assert.DeepEqual(t, defaultReadBufferSize, options.ReadBufferSize) + assert.False(t, options.ALPN) + assert.False(t, options.H2C) + assert.DeepEqual(t, []interface{}{}, options.Tracers) + assert.DeepEqual(t, new(interface{}), options.TraceLevel) + assert.DeepEqual(t, registry.NoopRegistry, options.Registry) +} + +// TestApplyCustomOptions test apply options with custom values after init +func TestApplyCustomOptions(t *testing.T) { + options := NewOptions([]Option{}) + options.Apply([]Option{ + {F: func(o *Options) { + o.Network = "unix" + }}, + }) + assert.DeepEqual(t, "unix", options.Network) +} diff --git a/pkg/common/config/request_option_test.go b/pkg/common/config/request_option_test.go index 9658250ca..c044c434b 100644 --- a/pkg/common/config/request_option_test.go +++ b/pkg/common/config/request_option_test.go @@ -22,6 +22,7 @@ import ( "github.com/cloudwego/hertz/pkg/common/test/assert" ) +// TestRequestOptions test request options with custom values func TestRequestOptions(t *testing.T) { opt := NewRequestOptions([]RequestOption{ WithTag("a", "b"), @@ -29,21 +30,38 @@ func TestRequestOptions(t *testing.T) { WithTag("e", "f"), WithSD(true), }) - assert.DeepEqual(t, opt.Tag("a"), "b") - assert.DeepEqual(t, opt.Tag("c"), "d") - assert.DeepEqual(t, opt.Tag("e"), "f") - assert.DeepEqual(t, opt.IsSD(), true) + assert.DeepEqual(t, "b", opt.Tag("a")) + assert.DeepEqual(t, "d", opt.Tag("c")) + assert.DeepEqual(t, "f", opt.Tag("e")) + assert.True(t, opt.IsSD()) } +// TestRequestOptionsWithDefaultOpts test request options with default values func TestRequestOptionsWithDefaultOpts(t *testing.T) { SetPreDefinedOpts(WithTag("pre-defined", "blablabla"), WithTag("a", "default-value"), WithSD(true)) opt := NewRequestOptions([]RequestOption{ WithTag("a", "b"), WithSD(false), }) - assert.DeepEqual(t, opt.Tag("a"), "b") - assert.DeepEqual(t, opt.Tag("pre-defined"), "blablabla") - assert.DeepEqual(t, opt.IsSD(), false) + assert.DeepEqual(t, "b", opt.Tag("a")) + assert.DeepEqual(t, "blablabla", opt.Tag("pre-defined")) + assert.DeepEqual(t, map[string]string{ + "a": "b", + "pre-defined": "blablabla", + }, opt.Tags()) + assert.False(t, opt.IsSD()) SetPreDefinedOpts() assert.Nil(t, preDefinedOpts) } + +// TestRequestOptions_CopyTo test request options copy to another one +func TestRequestOptions_CopyTo(t *testing.T) { + opt := NewRequestOptions([]RequestOption{ + WithTag("a", "b"), + WithSD(false), + }) + var copyOpt RequestOptions + opt.CopyTo(©Opt) + assert.DeepEqual(t, opt.Tags(), copyOpt.Tags()) + assert.DeepEqual(t, opt.IsSD(), copyOpt.IsSD()) +} diff --git a/pkg/common/errors/errors.go b/pkg/common/errors/errors.go index 38e90c9c1..d764c41d5 100644 --- a/pkg/common/errors/errors.go +++ b/pkg/common/errors/errors.go @@ -222,3 +222,15 @@ func NewPublic(err string) *Error { func NewPrivate(err string) *Error { return New(errors.New(err), ErrorTypePrivate, nil) } + +func Newf(t ErrorType, meta interface{}, format string, v ...interface{}) *Error { + return New(fmt.Errorf(format, v...), t, meta) +} + +func NewPublicf(format string, v ...interface{}) *Error { + return New(fmt.Errorf(format, v...), ErrorTypePublic, nil) +} + +func NewPrivatef(format string, v ...interface{}) *Error { + return New(fmt.Errorf(format, v...), ErrorTypePrivate, nil) +} diff --git a/pkg/common/errors/errors_test.go b/pkg/common/errors/errors_test.go index 1f54297df..834b27262 100644 --- a/pkg/common/errors/errors_test.go +++ b/pkg/common/errors/errors_test.go @@ -130,3 +130,12 @@ Error #03: third assert.Nil(t, errs.JSON()) assert.DeepEqual(t, "", errs.String()) } + +func TestErrorFormat(t *testing.T) { + err := Newf(ErrorTypeAny, nil, "caused by %s", "reason") + assert.DeepEqual(t, New(errors.New("caused by reason"), ErrorTypeAny, nil), err) + publicErr := NewPublicf("caused by %s", "reason") + assert.DeepEqual(t, New(errors.New("caused by reason"), ErrorTypePublic, nil), publicErr) + privateErr := NewPrivatef("caused by %s", "reason") + assert.DeepEqual(t, New(errors.New("caused by reason"), ErrorTypePrivate, nil), privateErr) +} diff --git a/pkg/common/hlog/default.go b/pkg/common/hlog/default.go index 3b9548456..6bc624a4a 100644 --- a/pkg/common/hlog/default.go +++ b/pkg/common/hlog/default.go @@ -24,34 +24,6 @@ import ( "os" ) -var logger FullLogger = &defaultLogger{ - stdlog: log.New(os.Stderr, "", log.LstdFlags|log.Lshortfile|log.Lmicroseconds), -} - -// SetOutput sets the output of default logger. By default, it is stderr. -func SetOutput(w io.Writer) { - logger.SetOutput(w) -} - -// SetLevel sets the level of logs below which logs will not be output. -// The default log level is LevelTrace. -// Note that this method is not concurrent-safe. -func SetLevel(lv Level) { - logger.SetLevel(lv) -} - -// DefaultLogger return the default logger for hertz. -func DefaultLogger() FullLogger { - return logger -} - -// SetLogger sets the default logger. -// Note that this method is not concurrent-safe and must not be called -// after the use of DefaultLogger and global functions in this package. -func SetLogger(v FullLogger) { - logger = v -} - // Fatal calls the default logger's Fatal method and then os.Exit(1). func Fatal(v ...interface{}) { logger.Fatal(v...) @@ -160,6 +132,7 @@ func CtxTracef(ctx context.Context, format string, v ...interface{}) { type defaultLogger struct { stdlog *log.Logger level Level + depth int } func (ll *defaultLogger) SetOutput(w io.Writer) { @@ -180,7 +153,7 @@ func (ll *defaultLogger) logf(lv Level, format *string, v ...interface{}) { } else { msg += fmt.Sprint(v...) } - ll.stdlog.Output(4, msg) + ll.stdlog.Output(ll.depth, msg) if lv == LevelFatal { os.Exit(1) } diff --git a/pkg/common/hlog/hlog.go b/pkg/common/hlog/hlog.go new file mode 100644 index 000000000..11a81615e --- /dev/null +++ b/pkg/common/hlog/hlog.go @@ -0,0 +1,84 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package hlog + +import ( + "io" + "log" + "os" +) + +const ( + systemLogPrefix = "HERTZ: " +) + +var ( + // Provide default logger for users to use + logger FullLogger = &defaultLogger{ + stdlog: log.New(os.Stderr, "", log.LstdFlags|log.Lshortfile|log.Lmicroseconds), + depth: 4, + } + + // Provide system logger for print system log + sysLogger FullLogger = &systemLogger{ + &defaultLogger{ + stdlog: log.New(os.Stderr, "", log.LstdFlags|log.Lshortfile|log.Lmicroseconds), + depth: 4, + }, + systemLogPrefix, + } +) + +// SetOutput sets the output of default logger and system logger. By default, it is stderr. +func SetOutput(w io.Writer) { + logger.SetOutput(w) + sysLogger.SetOutput(w) +} + +// SetLevel sets the level of logs below which logs will not be output. +// The default logger and system logger level is LevelTrace. +// Note that this method is not concurrent-safe. +func SetLevel(lv Level) { + logger.SetLevel(lv) + sysLogger.SetLevel(lv) +} + +// DefaultLogger return the default logger for hertz. +func DefaultLogger() FullLogger { + return logger +} + +// SystemLogger return the system logger for hertz to print system log. +// This function is not recommended for users to use. +func SystemLogger() FullLogger { + return sysLogger +} + +// SetSystemLogger sets the system logger. +// Note that this method is not concurrent-safe and must not be called +// This function is not recommended for users to use. +func SetSystemLogger(v FullLogger) { + sysLogger = &systemLogger{v, systemLogPrefix} +} + +// SetLogger sets the default logger and the system logger. +// Note that this method is not concurrent-safe and must not be called +// after the use of DefaultLogger and global functions in this package. +func SetLogger(v FullLogger) { + logger = v + SetSystemLogger(v) +} diff --git a/pkg/common/hlog/system.go b/pkg/common/hlog/system.go new file mode 100644 index 000000000..e9315170d --- /dev/null +++ b/pkg/common/hlog/system.go @@ -0,0 +1,136 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package hlog + +import ( + "context" + "io" + "strings" + "sync" +) + +var builderPool = sync.Pool{New: func() interface{} { + return &strings.Builder{} // nolint:SA6002 +}} + +type systemLogger struct { + logger FullLogger + prefix string +} + +func (ll *systemLogger) SetOutput(w io.Writer) { + ll.logger.SetOutput(w) +} + +func (ll *systemLogger) SetLevel(lv Level) { + ll.logger.SetLevel(lv) +} + +func (ll *systemLogger) Fatal(v ...interface{}) { + ll.logger.Fatal(v...) +} + +func (ll *systemLogger) Error(v ...interface{}) { + ll.logger.Error(v...) +} + +func (ll *systemLogger) Warn(v ...interface{}) { + ll.logger.Warn(v...) +} + +func (ll *systemLogger) Notice(v ...interface{}) { + ll.logger.Notice(v...) +} + +func (ll *systemLogger) Info(v ...interface{}) { + ll.logger.Info(v...) +} + +func (ll *systemLogger) Debug(v ...interface{}) { + ll.logger.Debug(v...) +} + +func (ll *systemLogger) Trace(v ...interface{}) { + ll.logger.Trace(v...) +} + +func (ll *systemLogger) Fatalf(format string, v ...interface{}) { + ll.logger.Fatalf(ll.addPrefix(format), v...) +} + +func (ll *systemLogger) Errorf(format string, v ...interface{}) { + ll.logger.Errorf(ll.addPrefix(format), v...) +} + +func (ll *systemLogger) Warnf(format string, v ...interface{}) { + ll.logger.Warnf(ll.addPrefix(format), v...) +} + +func (ll *systemLogger) Noticef(format string, v ...interface{}) { + ll.logger.Noticef(ll.addPrefix(format), v...) +} + +func (ll *systemLogger) Infof(format string, v ...interface{}) { + ll.logger.Infof(ll.addPrefix(format), v...) +} + +func (ll *systemLogger) Debugf(format string, v ...interface{}) { + ll.logger.Debugf(ll.addPrefix(format), v...) +} + +func (ll *systemLogger) Tracef(format string, v ...interface{}) { + ll.logger.Tracef(ll.addPrefix(format), v...) +} + +func (ll *systemLogger) CtxFatalf(ctx context.Context, format string, v ...interface{}) { + ll.logger.CtxFatalf(ctx, ll.addPrefix(format), v...) +} + +func (ll *systemLogger) CtxErrorf(ctx context.Context, format string, v ...interface{}) { + ll.logger.CtxErrorf(ctx, ll.addPrefix(format), v...) +} + +func (ll *systemLogger) CtxWarnf(ctx context.Context, format string, v ...interface{}) { + ll.logger.CtxWarnf(ctx, ll.addPrefix(format), v...) +} + +func (ll *systemLogger) CtxNoticef(ctx context.Context, format string, v ...interface{}) { + ll.logger.CtxNoticef(ctx, ll.addPrefix(format), v...) +} + +func (ll *systemLogger) CtxInfof(ctx context.Context, format string, v ...interface{}) { + ll.logger.CtxInfof(ctx, ll.addPrefix(format), v...) +} + +func (ll *systemLogger) CtxDebugf(ctx context.Context, format string, v ...interface{}) { + ll.logger.CtxDebugf(ctx, ll.addPrefix(format), v...) +} + +func (ll *systemLogger) CtxTracef(ctx context.Context, format string, v ...interface{}) { + ll.logger.CtxTracef(ctx, ll.addPrefix(format), v...) +} + +func (ll *systemLogger) addPrefix(format string) string { + builder := builderPool.Get().(*strings.Builder) + builder.Grow(len(format) + len(ll.prefix)) + builder.WriteString(ll.prefix) + builder.WriteString(format) + s := builder.String() + builder.Reset() + builderPool.Put(builder) // nolint:SA6002 + return s +} diff --git a/pkg/common/json/sonic.go b/pkg/common/json/sonic.go new file mode 100644 index 000000000..af0174e7e --- /dev/null +++ b/pkg/common/json/sonic.go @@ -0,0 +1,39 @@ +// Copyright 2022 CloudWeGo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +//go:build (linux || windows || darwin) && amd64 +// +build linux windows darwin +// +build amd64 + +package json + +import "github.com/bytedance/sonic" + +// Name is the name of the effective json package. +const Name = "sonic" + +var ( + json = sonic.ConfigStd + // Marshal is sonic implementation exported by hertz which is used by rendering. + Marshal = json.Marshal + // Unmarshal is sonic implementation exported by hertz which is used by binding. + Unmarshal = json.Unmarshal + // MarshalIndent is sonic implementation exported by hertz. + MarshalIndent = json.MarshalIndent + // NewDecoder is sonic implementation exported by hertz. + NewDecoder = json.NewDecoder + // NewEncoder is sonic implementation exported by hertz. + NewEncoder = json.NewEncoder +) diff --git a/pkg/common/json/std.go b/pkg/common/json/std.go new file mode 100644 index 000000000..5f0f71aec --- /dev/null +++ b/pkg/common/json/std.go @@ -0,0 +1,37 @@ +// Copyright 2022 CloudWeGo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +//go:build stdjson || !(amd64 && (linux || windows || darwin)) +// +build stdjson !amd64 !linux,!windows,!darwin + +package json + +import "encoding/json" + +// Name is the name of the effective json package. +const Name = "encoding/json" + +var ( + // Marshal is standard implementation exported by hertz which is used by rendering. + Marshal = json.Marshal + // Unmarshal is standard implementation exported by hertz which is used by binding. + Unmarshal = json.Unmarshal + // MarshalIndent is standard implementation exported by hertz. + MarshalIndent = json.MarshalIndent + // NewDecoder is standard implementation exported by hertz. + NewDecoder = json.NewDecoder + // NewEncoder is standard implementation exported by hertz. + NewEncoder = json.NewEncoder +) diff --git a/pkg/network/dialer/dialer_test.go b/pkg/network/dialer/dialer_test.go new file mode 100644 index 000000000..eb9896918 --- /dev/null +++ b/pkg/network/dialer/dialer_test.go @@ -0,0 +1,57 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package dialer + +import ( + "crypto/tls" + "errors" + "net" + "testing" + "time" + + "github.com/cloudwego/hertz/pkg/common/test/assert" + "github.com/cloudwego/hertz/pkg/network" +) + +func TestDialer(t *testing.T) { + SetDialer(&mockDialer{}) + dialer := DefaultDialer() + assert.DeepEqual(t, &mockDialer{}, dialer) + + _, err := AddTLS(nil, nil) + assert.NotNil(t, err) + + _, err = DialConnection("", "", 0, nil) + assert.NotNil(t, err) + + _, err = DialTimeout("", "", 0, nil) + assert.NotNil(t, err) +} + +type mockDialer struct{} + +func (m *mockDialer) DialConnection(network, address string, timeout time.Duration, tlsConfig *tls.Config) (conn network.Conn, err error) { + return nil, errors.New("method not implement") +} + +func (m *mockDialer) DialTimeout(network, address string, timeout time.Duration, tlsConfig *tls.Config) (conn net.Conn, err error) { + return nil, errors.New("method not implement") +} + +func (m *mockDialer) AddTLS(conn network.Conn, tlsConfig *tls.Config) (network.Conn, error) { + return nil, errors.New("method not implement") +} diff --git a/pkg/network/netpoll/connection.go b/pkg/network/netpoll/connection.go index a407abd67..9281e6455 100644 --- a/pkg/network/netpoll/connection.go +++ b/pkg/network/netpoll/connection.go @@ -22,6 +22,8 @@ package netpoll import ( "errors" "io" + "strings" + "syscall" "github.com/cloudwego/hertz/pkg/common/hlog" "github.com/cloudwego/hertz/pkg/network" @@ -75,8 +77,12 @@ func (c *Conn) Flush() error { } func (c *Conn) HandleSpecificError(err error, rip string) (needIgnore bool) { - if errors.Is(err, netpoll.ErrConnClosed) { - hlog.Warnf("HERTZ: Netpoll error=%s, remoteAddr=%s", err.Error(), rip) + if errors.Is(err, netpoll.ErrConnClosed) || errors.Is(err, syscall.EPIPE) || errors.Is(err, syscall.ECONNRESET) { + // ignore flushing error when connection is closed or reset + if strings.Contains(err.Error(), "when flush") { + return true + } + hlog.SystemLogger().Debugf("Netpoll error=%s, remoteAddr=%s", err.Error(), rip) return true } return false diff --git a/pkg/network/netpoll/connection_test.go b/pkg/network/netpoll/connection_test.go new file mode 100644 index 000000000..79dc80b5d --- /dev/null +++ b/pkg/network/netpoll/connection_test.go @@ -0,0 +1,214 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package netpoll + +import ( + "errors" + "net" + "testing" + "time" + + "github.com/cloudwego/hertz/pkg/common/test/assert" + "github.com/cloudwego/netpoll" +) + +func TestReadBytes(t *testing.T) { + c := &mockConn{[]byte("a"), nil, 0} + conn := newConn(c) + assert.DeepEqual(t, 1, conn.Len()) + + b, _ := conn.Peek(1) + assert.DeepEqual(t, []byte{'a'}, b) + + readByte, _ := conn.ReadByte() + assert.DeepEqual(t, byte('a'), readByte) + + _, err := conn.ReadByte() + assert.DeepEqual(t, errors.New("readByte error: index out of range"), err) + + c = &mockConn{[]byte("bcd"), nil, 0} + conn = newConn(c) + + readBinary, _ := conn.ReadBinary(2) + assert.DeepEqual(t, []byte{'b', 'c'}, readBinary) + + _, err = conn.ReadBinary(2) + assert.DeepEqual(t, errors.New("readBinary error: index out of range"), err) +} + +func TestPeekRelease(t *testing.T) { + c := &mockConn{[]byte("abcdefg"), nil, 0} + conn := newConn(c) + + // release the buf + conn.Release() + _, err := conn.Peek(1) + assert.DeepEqual(t, errors.New("peek error"), err) + + assert.DeepEqual(t, errors.New("skip error"), conn.Skip(2)) +} + +func TestWriteLogin(t *testing.T) { + c := &mockConn{nil, []byte("abcdefg"), 0} + conn := newConn(c) + buf, _ := conn.Malloc(10) + assert.DeepEqual(t, 10, len(buf)) + n, _ := conn.WriteBinary([]byte("abcdefg")) + assert.DeepEqual(t, 7, n) + assert.DeepEqual(t, errors.New("flush error"), conn.Flush()) +} + +func TestHandleSpecificError(t *testing.T) { + conn := &Conn{} + assert.DeepEqual(t, false, conn.HandleSpecificError(nil, "")) + assert.DeepEqual(t, true, conn.HandleSpecificError(netpoll.ErrConnClosed, "")) +} + +type mockConn struct { + readBuf []byte + writeBuf []byte + // index for the first readable byte in readBuf + off int +} + +// mockConn's methods is simplified for unit test +// Peek returns the next n bytes without advancing the reader +func (m *mockConn) Peek(n int) (b []byte, err error) { + if m.off+n-1 < len(m.readBuf) { + return m.readBuf[m.off : m.off+n], nil + } + return nil, errors.New("peek error") +} + +// Skip discards the next n bytes +func (m *mockConn) Skip(n int) error { + if m.off+n < len(m.readBuf) { + m.off += n + return nil + } + return errors.New("skip error") +} + +// Release the memory space occupied by all read slices +func (m *mockConn) Release() error { + m.readBuf = nil + m.off = 0 + return nil +} + +// Len returns the total length of the readable data in the reader +func (m *mockConn) Len() int { + return len(m.readBuf) - m.off +} + +// ReadByte is used to read one byte with advancing the read pointer +func (m *mockConn) ReadByte() (byte, error) { + if m.off < len(m.readBuf) { + m.off++ + return m.readBuf[m.off-1], nil + } + return 0, errors.New("readByte error: index out of range") +} + +// ReadBinary is used to read next n byte with copy, and the read pointer will be advanced +func (m *mockConn) ReadBinary(n int) (b []byte, err error) { + if m.off+n < len(m.readBuf) { + m.off += n + return m.readBuf[m.off-n : m.off], nil + } + return nil, errors.New("readBinary error: index out of range") +} + +// Malloc will provide a n bytes buffer to send data +func (m *mockConn) Malloc(n int) (buf []byte, err error) { + m.writeBuf = make([]byte, n) + return m.writeBuf, nil +} + +// WriteBinary will use the user buffer to flush +func (m *mockConn) WriteBinary(b []byte) (n int, err error) { + return len(b), nil +} + +// Flush will send data to the peer end +func (m *mockConn) Flush() error { + return errors.New("flush error") +} + +func (m *mockConn) HandleSpecificError(err error, rip string) (needIgnore bool) { + panic("implement me") +} + +func (m *mockConn) Read(b []byte) (n int, err error) { + panic("implement me") +} + +func (m *mockConn) Write(b []byte) (n int, err error) { + panic("implement me") +} + +func (m *mockConn) Close() error { + panic("implement me") +} + +func (m *mockConn) LocalAddr() net.Addr { + panic("implement me") +} + +func (m *mockConn) RemoteAddr() net.Addr { + panic("implement me") +} + +func (m *mockConn) SetDeadline(deadline time.Time) error { + panic("implement me") +} + +func (m *mockConn) SetReadDeadline(deadline time.Time) error { + panic("implement me") +} + +func (m *mockConn) SetWriteDeadline(deadline time.Time) error { + panic("implement me") +} + +func (m *mockConn) Reader() netpoll.Reader { + panic("implement me") +} + +func (m *mockConn) Writer() netpoll.Writer { + panic("implement me") +} + +func (m *mockConn) IsActive() bool { + panic("implement me") +} + +func (m *mockConn) SetReadTimeout(timeout time.Duration) error { + panic("implement me") +} + +func (m *mockConn) SetIdleTimeout(timeout time.Duration) error { + panic("implement me") +} + +func (m *mockConn) SetOnRequest(on netpoll.OnRequest) error { + panic("implement me") +} + +func (m *mockConn) AddCloseCallback(callback netpoll.CloseCallback) error { + panic("implement me") +} diff --git a/pkg/network/netpoll/dial.go b/pkg/network/netpoll/dial.go index 9c6d7b3d7..70adbc108 100644 --- a/pkg/network/netpoll/dial.go +++ b/pkg/network/netpoll/dial.go @@ -1,22 +1,21 @@ +// Copyright 2022 CloudWeGo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + //go:build !windows // +build !windows -/* - * Copyright 2022 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - package netpoll import ( diff --git a/pkg/network/netpoll/transport.go b/pkg/network/netpoll/transport.go index 803e34e1e..844445eb4 100644 --- a/pkg/network/netpoll/transport.go +++ b/pkg/network/netpoll/transport.go @@ -88,7 +88,7 @@ func (t *transporter) ListenAndServe(onReq network.OnData) (err error) { } // Start Server - hlog.Infof("HERTZ: HTTP server listening on address=%s", t.listener.Addr().String()) + hlog.SystemLogger().Infof("HTTP server listening on address=%s", t.listener.Addr().String()) t.RLock() err = t.eventLoop.Serve(t.listener) t.RUnlock() diff --git a/pkg/network/standard/connection.go b/pkg/network/standard/connection.go index 2cf644332..d15f06092 100644 --- a/pkg/network/standard/connection.go +++ b/pkg/network/standard/connection.go @@ -163,6 +163,11 @@ func (c *Conn) Close() error { return c.c.Close() } +// CloseNoResetBuffer closes the connection without reset buffer. +func (c *Conn) CloseNoResetBuffer() error { + return c.c.Close() +} + // LocalAddr returns the local address of the connection. func (c *Conn) LocalAddr() net.Addr { return c.c.LocalAddr() diff --git a/pkg/network/standard/connection_test.go b/pkg/network/standard/connection_test.go index 6909d9fe9..8108f1ca7 100644 --- a/pkg/network/standard/connection_test.go +++ b/pkg/network/standard/connection_test.go @@ -18,6 +18,8 @@ package standard import ( "bytes" + "crypto/tls" + "errors" "io" "net" "strings" @@ -231,8 +233,48 @@ func TestWriteLogic(t *testing.T) { } } +func TestInitializeConn(t *testing.T) { + c := mockConn{ + localAddr: &mockAddr{ + network: "tcp", + address: "192.168.0.10:80", + }, + remoteAddr: &mockAddr{ + network: "tcp", + address: "192.168.0.20:80", + }, + } + conn := newConn(&c, 8192) + // check the assignment + assert.DeepEqual(t, errors.New("conn: write deadline not supported"), conn.SetDeadline(time.Time{})) + assert.DeepEqual(t, errors.New("conn: read deadline not supported"), conn.SetReadDeadline(time.Time{})) + assert.DeepEqual(t, errors.New("conn: write deadline not supported"), conn.SetWriteDeadline(time.Time{})) + assert.DeepEqual(t, errors.New("conn: read deadline not supported"), conn.SetReadTimeout(time.Duration(1)*time.Second)) + assert.DeepEqual(t, errors.New("conn: read deadline not supported"), conn.SetReadTimeout(time.Duration(-1)*time.Second)) + assert.DeepEqual(t, errors.New("conn: method not supported"), conn.Close()) + assert.DeepEqual(t, &mockAddr{network: "tcp", address: "192.168.0.10:80"}, conn.LocalAddr()) + assert.DeepEqual(t, &mockAddr{network: "tcp", address: "192.168.0.20:80"}, conn.RemoteAddr()) +} + +func TestInitializeTLSConn(t *testing.T) { + c := mockConn{} + tlsConn := newTLSConn(&c, 8192).(*TLSConn) + assert.DeepEqual(t, errors.New("conn: method not supported"), tlsConn.Handshake()) + assert.DeepEqual(t, tls.ConnectionState{}, tlsConn.ConnectionState()) +} + type mockConn struct { - buffer bytes.Buffer + buffer bytes.Buffer + localAddr net.Addr + remoteAddr net.Addr +} + +func (m *mockConn) Handshake() error { + return errors.New("conn: method not supported") +} + +func (m *mockConn) ConnectionState() tls.ConnectionState { + return tls.ConnectionState{} } func (m mockConn) Read(b []byte) (n int, err error) { @@ -243,7 +285,6 @@ func (m mockConn) Read(b []byte) (n int, err error) { if len(b) < 1024 { return 100, nil } - if len(b) < 5000 { return 4096, nil } @@ -255,26 +296,42 @@ func (m *mockConn) Write(b []byte) (n int, err error) { return m.buffer.Write(b) } -func (m mockConn) Close() error { - panic("implement me") +func (m *mockConn) Close() error { + return errors.New("conn: method not supported") +} + +func (m *mockConn) LocalAddr() net.Addr { + return m.localAddr +} + +func (m *mockConn) RemoteAddr() net.Addr { + return m.remoteAddr +} + +func (m *mockConn) SetDeadline(deadline time.Time) error { + if err := m.SetWriteDeadline(deadline); err != nil { + return err + } + return m.SetWriteDeadline(deadline) } -func (m mockConn) LocalAddr() net.Addr { - panic("implement me") +func (m *mockConn) SetReadDeadline(deadline time.Time) error { + return errors.New("conn: read deadline not supported") } -func (m mockConn) RemoteAddr() net.Addr { - panic("implement me") +func (m *mockConn) SetWriteDeadline(deadline time.Time) error { + return errors.New("conn: write deadline not supported") } -func (m mockConn) SetDeadline(t time.Time) error { - panic("implement me") +type mockAddr struct { + network string + address string } -func (m mockConn) SetReadDeadline(t time.Time) error { - panic("implement me") +func (m *mockAddr) Network() string { + return m.network } -func (m mockConn) SetWriteDeadline(t time.Time) error { - panic("implement me") +func (m *mockAddr) String() string { + return m.address } diff --git a/pkg/network/standard/transport.go b/pkg/network/standard/transport.go index 27398304a..624f9415b 100644 --- a/pkg/network/standard/transport.go +++ b/pkg/network/standard/transport.go @@ -60,11 +60,12 @@ func (t *transport) serve() (err error) { if err != nil { return err } + hlog.SystemLogger().Infof("HERTZ: HTTP server listening on address=%s", t.ln.Addr().String()) for { conn, err := t.ln.Accept() var c network.Conn if err != nil { - hlog.Errorf("HERTZ: Error=%s", err.Error()) + hlog.SystemLogger().Errorf("Error=%s", err.Error()) return err } if t.tls != nil { diff --git a/pkg/protocol/args_test.go b/pkg/protocol/args_test.go index e63e4c9b3..6a613dc19 100644 --- a/pkg/protocol/args_test.go +++ b/pkg/protocol/args_test.go @@ -43,6 +43,8 @@ package protocol import ( "testing" + + "github.com/cloudwego/hertz/pkg/common/test/assert" ) func TestArgsDeleteAll(t *testing.T) { @@ -58,3 +60,72 @@ func TestArgsDeleteAll(t *testing.T) { t.Fatalf("Expected q1 arg to be completely deleted. Current Args: %s", a.String()) } } + +func TestArgsBytesOperation(t *testing.T) { + var a Args + a.Add("q1", "foo") + a.Add("q2", "bar") + setArgBytes(a.args, a.args[0].key, a.args[0].value, false) + assert.DeepEqual(t, []byte("foo"), peekArgBytes(a.args, []byte("q1"))) + setArgBytes(a.args, a.args[1].key, a.args[1].value, true) + assert.DeepEqual(t, []byte(""), peekArgBytes(a.args, []byte("q2"))) +} + +func TestArgsPeekExists(t *testing.T) { + var a Args + a.Add("q1", "foo") + a.Add("", "") + a.Add("?", "=") + v1, b1 := a.PeekExists("q1") + assert.DeepEqual(t, []byte("foo"), []byte(v1)) + assert.True(t, b1) + v2, b2 := a.PeekExists("") + assert.DeepEqual(t, []byte(""), []byte(v2)) + assert.True(t, b2) + v3, b3 := a.PeekExists("q3") + assert.DeepEqual(t, "", v3) + assert.False(t, b3) + v4, b4 := a.PeekExists("?") + assert.DeepEqual(t, "=", v4) + assert.True(t, b4) +} + +func TestSetArg(t *testing.T) { + a := Args{args: setArg(nil, "q1", "foo", true)} + a.Add("", "") + setArgBytes(a.args, []byte("q3"), []byte("bar"), false) + s := a.String() + assert.DeepEqual(t, []byte("q1&="), []byte(s)) +} + +// Test the encoding of special parameters +func TestArgsParseBytes(t *testing.T) { + var ta1 Args + ta1.Add("q1", "foo") + ta1.Add("q1", "bar") + ta1.Add("q2", "123") + ta1.Add("q3", "") + var a1 Args + a1.ParseBytes([]byte("q1=foo&q1=bar&q2=123&q3=")) + assert.DeepEqual(t, &ta1, &a1) + + var ta2 Args + ta2.Add("?", "foo") + ta2.Add("&", "bar") + ta2.Add("&", "?") + ta2.Add("=", "=") + var a2 Args + a2.ParseBytes([]byte("%3F=foo&%26=bar&%26=%3F&%3D=%3D")) + assert.DeepEqual(t, &ta2, &a2) +} + +func TestArgsVisitAll(t *testing.T) { + var a Args + var s []string + a.Add("cloudwego", "hertz") + a.Add("hello", "world") + a.VisitAll(func(key, value []byte) { + s = append(s, string(key), string(value)) + }) + assert.DeepEqual(t, []string{"cloudwego", "hertz", "hello", "world"}, s) +} diff --git a/pkg/protocol/client/client.go b/pkg/protocol/client/client.go index e5e6cf10a..33d7c3906 100644 --- a/pkg/protocol/client/client.go +++ b/pkg/protocol/client/client.go @@ -43,6 +43,7 @@ package client import ( "context" + "io" "sync" "time" @@ -78,6 +79,40 @@ type Doer interface { Do(ctx context.Context, req *protocol.Request, resp *protocol.Response) error } +// DefaultRetryIf Default retry condition, mainly used for idempotent requests. +// If this cannot be satisfied, you can implement your own retry condition. +func DefaultRetryIf(req *protocol.Request, resp *protocol.Response, err error) bool { + // cannot retry if the request body is not rewindable + if req.IsBodyStream() { + return false + } + + if isIdempotent(req, resp, err) { + return true + } + // Retry non-idempotent requests if the server closes + // the connection before sending the response. + // + // This case is possible if the server closes the idle + // keep-alive connection on timeout. + // + // Apache and nginx usually do this. + if err == io.EOF { + return true + } + + return false +} + +func isIdempotent(req *protocol.Request, resp *protocol.Response, err error) bool { + return req.Header.IsGet() || + req.Header.IsHead() || + req.Header.IsPut() || + req.Header.IsDelete() || + req.Header.IsOptions() || + req.Header.IsTrace() +} + // DynamicConfig is config set which will be confirmed when starts a request. type DynamicConfig struct { Addr string @@ -86,9 +121,8 @@ type DynamicConfig struct { } // RetryIfFunc signature of retry if function -// -// Request argument passed to RetryIfFunc, if there are any request errors. -type RetryIfFunc func(request *protocol.Request) bool +// Judge whether to retry by request,response or error , return true is retry +type RetryIfFunc func(req *protocol.Request, resp *protocol.Response, err error) bool type clientURLResponse struct { statusCode int diff --git a/pkg/protocol/consts/default.go b/pkg/protocol/consts/default.go index ae8fbc21e..ae760742f 100644 --- a/pkg/protocol/consts/default.go +++ b/pkg/protocol/consts/default.go @@ -67,4 +67,7 @@ const ( // DefaultMaxIdempotentCallAttempts is the default idempotent calls attempts count. DefaultMaxIdempotentCallAttempts = 1 + + // DefaultMaxRetryTimes is the default call times of retry + DefaultMaxRetryTimes = 1 ) diff --git a/pkg/protocol/cookie.go b/pkg/protocol/cookie.go index c8de4c0d2..7ca22c673 100644 --- a/pkg/protocol/cookie.go +++ b/pkg/protocol/cookie.go @@ -546,7 +546,7 @@ func getCookieKey(dst, src []byte) []byte { func warnIfInvalid(value []byte) bool { for i := range value { if bytesconv.ValidCookieValueTable[value[i]] == 0 { - hlog.Warnf("HERTZ: Invalid byte %q in Cookie.Value, "+ + hlog.SystemLogger().Warnf("Invalid byte %q in Cookie.Value, "+ "it may cause compatibility problems with user agents", value[i]) return false } diff --git a/pkg/protocol/header.go b/pkg/protocol/header.go index bb2e0338b..eb28ade20 100644 --- a/pkg/protocol/header.go +++ b/pkg/protocol/header.go @@ -424,6 +424,11 @@ func (h *RequestHeader) IsPost() bool { return bytes.Equal(h.Method(), bytestr.StrPost) } +// IsDelete returns true if request method is DELETE. +func (h *RequestHeader) IsDelete() bool { + return bytes.Equal(h.Method(), bytestr.StrDelete) +} + // IsConnect returns true if request method is CONNECT. func (h *RequestHeader) IsConnect() bool { return bytes.Equal(h.Method(), bytestr.StrConnect) @@ -456,7 +461,7 @@ func checkWriteHeaderCode(code int) { // For now, we only emit a warning for bad codes. // In the future we might block things over 599 or under 100 if code < 100 || code > 599 { - hlog.Warnf("Invalid StatusCode code %v, status code should not be under 100 or over 599.\n"+ + hlog.SystemLogger().Warnf("Invalid StatusCode code %v, status code should not be under 100 or over 599.\n"+ "For more info: https://www.rfc-editor.org/rfc/rfc9110.html#name-status-codes", code) } } @@ -1063,6 +1068,16 @@ func (h *RequestHeader) IsGet() bool { return bytes.Equal(h.Method(), bytestr.StrGet) } +// IsOptions returns true if request method is Options. +func (h *RequestHeader) IsOptions() bool { + return bytes.Equal(h.Method(), bytestr.StrOptions) +} + +// IsTrace returns true if request method is Trace. +func (h *RequestHeader) IsTrace() bool { + return bytes.Equal(h.Method(), bytestr.StrTrace) +} + // SetHostBytes sets Host header value. func (h *RequestHeader) SetHostBytes(host []byte) { h.host = append(h.host[:0], host...) diff --git a/pkg/protocol/header_test.go b/pkg/protocol/header_test.go index 9a75b3d3e..b534841d1 100644 --- a/pkg/protocol/header_test.go +++ b/pkg/protocol/header_test.go @@ -53,6 +53,53 @@ import ( "github.com/cloudwego/hertz/pkg/protocol/consts" ) +func TestRequestHeaderSetRawHeaders(t *testing.T) { + h := RequestHeader{} + h.SetRawHeaders([]byte("foo")) + assert.DeepEqual(t, h.rawHeaders, []byte("foo")) +} + +func TestResponseHeaderSetHeaderLength(t *testing.T) { + h := ResponseHeader{} + h.SetHeaderLength(15) + assert.DeepEqual(t, h.headerLength, 15) + assert.DeepEqual(t, h.GetHeaderLength(), 15) +} + +func TestSetNoHTTP11(t *testing.T) { + rh := ResponseHeader{} + rh.SetNoHTTP11(true) + assert.True(t, rh.noHTTP11) + + rh.SetNoHTTP11(false) + assert.False(t, rh.noHTTP11) + assert.True(t, rh.IsHTTP11()) + + h := RequestHeader{} + h.SetNoHTTP11(true) + assert.True(t, h.noHTTP11) + + h.SetNoHTTP11(false) + assert.False(t, h.noHTTP11) + assert.True(t, h.IsHTTP11()) +} + +func TestResponseHeaderSetContentType(t *testing.T) { + h := ResponseHeader{} + h.SetContentType("foo") + assert.DeepEqual(t, h.contentType, []byte("foo")) +} + +func TestSetContentLengthBytes(t *testing.T) { + h := RequestHeader{} + h.SetContentLengthBytes([]byte("foo")) + assert.DeepEqual(t, h.contentLengthBytes, []byte("foo")) + + rh := ResponseHeader{} + rh.SetContentLengthBytes([]byte("foo")) + assert.DeepEqual(t, rh.contentLengthBytes, []byte("foo")) +} + func Test_peekRawHeader(t *testing.T) { s := "Expect: 100-continue\r\nUser-Agent: foo\r\nHost: 127.0.0.1\r\nConnection: Keep-Alive\r\nContent-Length: 5\r\nContent-Type: foo/bar\r\n\r\nabcdef4343" assert.DeepEqual(t, []byte("127.0.0.1"), peekRawHeader([]byte(s), []byte("Host"))) @@ -60,10 +107,18 @@ func Test_peekRawHeader(t *testing.T) { func TestResponseHeader_SetContentLength(t *testing.T) { rh := new(ResponseHeader) + rh.SetContentLength(-1) + assert.True(t, strings.Contains(string(rh.Header()), "Transfer-Encoding: chunked")) rh.SetContentLength(-2) assert.True(t, strings.Contains(string(rh.Header()), "Transfer-Encoding: identity")) } +func TestResponseHeader_SetContentRange(t *testing.T) { + rh := new(ResponseHeader) + rh.SetContentRange(1, 5, 10) + assert.DeepEqual(t, rh.bufKV.value, []byte("bytes 1-5/10")) +} + func TestSetCanonical(t *testing.T) { h := ResponseHeader{} h.SetCanonical([]byte(consts.HeaderContentType), []byte("foo")) @@ -83,6 +138,12 @@ func TestSetCanonical(t *testing.T) { assert.DeepEqual(t, true, strings.Contains(string(h.Header()), "bar: foo6")) } +func TestHasAcceptEncodingBytes(t *testing.T) { + h := RequestHeader{} + h.Set(consts.HeaderAcceptEncoding, "gzip") + assert.True(t, h.HasAcceptEncodingBytes([]byte("gzip"))) +} + func TestRequestHeaderGet(t *testing.T) { h := RequestHeader{} rightVal := "yyy" @@ -93,6 +154,14 @@ func TestRequestHeaderGet(t *testing.T) { } } +func TestResponseHeaderGet(t *testing.T) { + h := ResponseHeader{} + rightVal := "yyy" + h.Set("xxx", rightVal) + val := h.Get("xxx") + assert.DeepEqual(t, val, rightVal) +} + func TestRequestHeaderVisitAll(t *testing.T) { h := RequestHeader{} h.Set("xxx", "yyy") @@ -113,6 +182,65 @@ func TestRequestHeaderVisitAll(t *testing.T) { }) } +func TestRequestHeaderDel(t *testing.T) { + t.Parallel() + + var h RequestHeader + h.Set("Foo-Bar", "baz") + h.Set("aaa", "bbb") + h.Set(consts.HeaderConnection, "keep-alive") + h.Set(consts.HeaderContentType, "aaa") + h.Set(consts.HeaderServer, "aaabbb") + h.Set(consts.HeaderContentLength, "1123") + h.SetHost("foobar") + h.SetCookie("foo", "bar") + + h.del([]byte("Foo-Bar")) + h.del([]byte("Connection")) + h.DelBytes([]byte("Content-Type")) + h.del([]byte(consts.HeaderServer)) + h.del([]byte("Content-Length")) + h.del([]byte("Set-Cookie")) + h.del([]byte("Host")) + h.DelCookie("foo") + + hv := h.Peek("aaa") + if string(hv) != "bbb" { + t.Fatalf("unexpected header value: %q. Expecting %q", hv, "bbb") + } + hv = h.Peek("Foo-Bar") + if len(hv) > 0 { + t.Fatalf("non-zero header value: %q", hv) + } + hv = h.Peek(consts.HeaderConnection) + if len(hv) > 0 { + t.Fatalf("non-zero value: %q", hv) + } + hv = h.Peek(consts.HeaderContentType) + if len(hv) > 0 { + t.Fatalf("non-zero value: %q", hv) + } + hv = h.Peek(consts.HeaderServer) + if len(hv) > 0 { + t.Fatalf("non-zero value: %q", hv) + } + hv = h.Peek(consts.HeaderContentLength) + if len(hv) > 0 { + t.Fatalf("non-zero value: %q", hv) + } + hv = h.FullCookie() + if len(hv) > 0 { + t.Fatalf("non-zero value: %q", hv) + } + hv = h.Peek(consts.HeaderCookie) + if len(hv) > 0 { + t.Fatalf("non-zero value: %q", hv) + } + if h.ContentLength() != 0 { + t.Fatalf("unexpected content-length: %d. Expecting 0", h.ContentLength()) + } +} + func TestResponseHeaderDel(t *testing.T) { t.Parallel() @@ -162,7 +290,7 @@ func TestResponseHeaderDel(t *testing.T) { } if h.Cookie(&c) { - t.Fatalf("unexpected cookie obtianed: %v", &c) + t.Fatalf("unexpected cookie obtained: %v", &c) } if h.ContentLength() != 0 { @@ -194,20 +322,48 @@ func TestResponseHeaderDelClientCookie(t *testing.T) { ReleaseCookie(c) } +func TestResponseHeaderResetConnectionClose(t *testing.T) { + h := ResponseHeader{} + h.Set(consts.HeaderConnection, "close") + hv := h.Peek(consts.HeaderConnection) + assert.DeepEqual(t, hv, []byte("close")) + h.SetConnectionClose(true) + h.ResetConnectionClose() + assert.False(t, h.connectionClose) + hv = h.Peek(consts.HeaderConnection) + if len(hv) > 0 { + t.Fatalf("ResetConnectionClose do not work,Connection: %q", hv) + } +} + +func TestRequestHeaderResetConnectionClose(t *testing.T) { + h := RequestHeader{} + h.Set(consts.HeaderConnection, "close") + hv := h.Peek(consts.HeaderConnection) + assert.DeepEqual(t, hv, []byte("close")) + h.connectionClose = true + h.ResetConnectionClose() + assert.False(t, h.connectionClose) + hv = h.Peek(consts.HeaderConnection) + if len(hv) > 0 { + t.Fatalf("ResetConnectionClose do not work,Connection: %q", hv) + } +} + func TestCheckWriteHeaderCode(t *testing.T) { buffer := bytes.NewBuffer(make([]byte, 0, 1024)) hlog.SetOutput(buffer) checkWriteHeaderCode(99) - assert.True(t, strings.Contains(buffer.String(), "[Warn] Invalid StatusCode code")) + assert.True(t, strings.Contains(buffer.String(), "[Warn] HERTZ: Invalid StatusCode code")) buffer.Reset() checkWriteHeaderCode(600) - assert.True(t, strings.Contains(buffer.String(), "[Warn] Invalid StatusCode code")) + assert.True(t, strings.Contains(buffer.String(), "[Warn] HERTZ: Invalid StatusCode code")) buffer.Reset() checkWriteHeaderCode(100) - assert.False(t, strings.Contains(buffer.String(), "[Warn] Invalid StatusCode code")) + assert.False(t, strings.Contains(buffer.String(), "[Warn] HERTZ: Invalid StatusCode code")) buffer.Reset() checkWriteHeaderCode(599) - assert.False(t, strings.Contains(buffer.String(), "[Warn] Invalid StatusCode code")) + assert.False(t, strings.Contains(buffer.String(), "[Warn] HERTZ: Invalid StatusCode code")) } func TestResponseHeaderAdd(t *testing.T) { @@ -311,3 +467,54 @@ func TestRequestHeaderAddContentType(t *testing.T) { t.Errorf("Content-Type occurred %d times", n) } } + +func TestSetMultipartFormBoundary(t *testing.T) { + h := RequestHeader{} + h.SetMultipartFormBoundary("foo") + assert.DeepEqual(t, h.contentType, []byte("multipart/form-data; boundary=foo")) +} + +func TestRequestHeaderSetByteRange(t *testing.T) { + var h RequestHeader + h.SetByteRange(1, 5) + hv := h.Peek(consts.HeaderRange) + assert.DeepEqual(t, hv, []byte("bytes=1-5")) +} + +func TestRequestHeaderSetMethodBytes(t *testing.T) { + var h RequestHeader + h.SetMethodBytes([]byte("foo")) + assert.DeepEqual(t, h.Method(), []byte("foo")) +} + +func TestRequestHeaderSetBytesKV(t *testing.T) { + var h RequestHeader + h.SetBytesKV([]byte("foo"), []byte("foo1")) + hv := h.Peek("foo") + assert.DeepEqual(t, hv, []byte("foo1")) +} + +func TestResponseHeaderSetBytesV(t *testing.T) { + var h ResponseHeader + h.SetBytesV("foo", []byte("foo1")) + hv := h.Peek("foo") + assert.DeepEqual(t, hv, []byte("foo1")) +} + +func TestRequestHeaderInitBufValue(t *testing.T) { + var h RequestHeader + slice := make([]byte, 0, 10) + h.InitBufValue(10) + assert.DeepEqual(t, cap(h.bufKV.value), cap(slice)) + assert.DeepEqual(t, h.GetBufValue(), slice) +} + +func TestRequestHeaderDelAllCookies(t *testing.T) { + var h RequestHeader + h.SetCanonical([]byte(consts.HeaderSetCookie), []byte("foo2")) + h.DelAllCookies() + hv := h.FullCookie() + if len(hv) > 0 { + t.Fatalf("non-zero value: %q", hv) + } +} diff --git a/pkg/protocol/http1/client.go b/pkg/protocol/http1/client.go index 4be7be16a..cb12bc778 100644 --- a/pkg/protocol/http1/client.go +++ b/pkg/protocol/http1/client.go @@ -56,6 +56,7 @@ import ( "github.com/cloudwego/hertz/internal/bytesconv" "github.com/cloudwego/hertz/internal/bytestr" "github.com/cloudwego/hertz/internal/nocopy" + "github.com/cloudwego/hertz/pkg/app/client/retry" errs "github.com/cloudwego/hertz/pkg/common/errors" "github.com/cloudwego/hertz/pkg/common/hlog" "github.com/cloudwego/hertz/pkg/common/timer" @@ -325,45 +326,54 @@ func (c *HostClient) DoRedirects(ctx context.Context, req *protocol.Request, res // It is recommended obtaining req and resp via AcquireRequest // and AcquireResponse in performance-critical code. func (c *HostClient) Do(ctx context.Context, req *protocol.Request, resp *protocol.Response) error { - var err error - var retry bool + var ( + err error + canIdempotentRetry bool + isDefaultRetryFunc = true + attempts uint = 0 + maxAttempts uint = 1 + isRequestRetryable client.RetryIfFunc = client.DefaultRetryIf + ) + retryCfg := c.ClientOptions.RetryConfig + if retryCfg != nil { + maxAttempts = retryCfg.MaxAttemptTimes + } - maxAttempts := c.MaxIdempotentCallAttempts - isRequestRetryable := isIdempotent - if c.RetryIf != nil { - isRequestRetryable = c.RetryIf + if c.ClientOptions.RetryIfFunc != nil { + isRequestRetryable = c.ClientOptions.RetryIfFunc + // if the user has provided a custom retry function, the canIdempotentRetry has no meaning anymore. + // User will have full control over the retry logic through the custom retry function. + isDefaultRetryFunc = false } - attempts := 0 atomic.AddInt32(&c.pendingRequests, 1) + for { - retry, err = c.do(req, resp) - if err == nil || !retry { + canIdempotentRetry, err = c.do(req, resp) + if err == nil { break } + if isDefaultRetryFunc { + // canIdempotentRetry only makes sense if the user hasn't provided a custom retry function. + if !canIdempotentRetry { + break + } + } + attempts++ if attempts >= maxAttempts { break } - if req.IsBodyStream() { + // Check whether this request should be retried + if !isRequestRetryable(req, resp, err) { break } - if !isRequestRetryable(req) { - // Retry non-idempotent requests if the server closes - // the connection before sending the response. - // - // This case is possible if the server closes the idle - // keep-alive connection on timeout. - // - // Apache and nginx usually do this. - if err != io.EOF { - break - } - } - + wait := retry.Delay(attempts, err, retryCfg) + // Retry after wait time + time.Sleep(wait) } atomic.AddInt32(&c.pendingRequests, -1) @@ -382,10 +392,6 @@ func (c *HostClient) PendingRequests() int { return int(atomic.LoadInt32(&c.pendingRequests)) } -func isIdempotent(req *protocol.Request) bool { - return req.Header.IsGet() || req.Header.IsHead() || req.Header.IsPut() -} - func (c *HostClient) do(req *protocol.Request, resp *protocol.Response) (bool, error) { nilResp := false if resp == nil { @@ -393,13 +399,13 @@ func (c *HostClient) do(req *protocol.Request, resp *protocol.Response) (bool, e resp = protocol.AcquireResponse() } - ok, err := c.doNonNilReqResp(req, resp) + canIdempotentRetry, err := c.doNonNilReqResp(req, resp) if nilResp { protocol.ReleaseResponse(resp) } - return ok, err + return canIdempotentRetry, err } func (c *HostClient) doNonNilReqResp(req *protocol.Request, resp *protocol.Response) (bool, error) { @@ -424,6 +430,7 @@ func (c *HostClient) doNonNilReqResp(req *protocol.Request, resp *protocol.Respo req.URI().DisablePathNormalizing = true } cc, err := c.acquireConn() + // if getting connection error, fast fail if err != nil { return false, err } @@ -444,6 +451,7 @@ func (c *HostClient) doNonNilReqResp(req *protocol.Request, resp *protocol.Respo currentTime := time.Now() if err = conn.SetWriteDeadline(currentTime.Add(c.WriteTimeout)); err != nil { c.closeConn(cc) + // try another connection if retry is enabled return true, err } } @@ -472,6 +480,7 @@ func (c *HostClient) doNonNilReqResp(req *protocol.Request, resp *protocol.Respo if err == nil { err = zw.Flush() } + // error happened when writing request, close the connection, and try another connection if retry is enabled if err != nil { c.closeConn(cc) return true, err @@ -482,6 +491,7 @@ func (c *HostClient) doNonNilReqResp(req *protocol.Request, resp *protocol.Respo // See https://github.com/golang/go/issues/15133#issuecomment-271571395 for details if err = conn.SetReadTimeout(c.ReadTimeout); err != nil { c.closeConn(cc) + // try another connection if retry is enabled return true, err } } @@ -909,7 +919,7 @@ func dialAddr(addr string, dial network.Dialer, dialDualStack bool, tlsConfig *t var conn network.Conn var err error if dial == nil { - hlog.Warnf("HERTZ: HostClient: no dialer specified, trying to use default dialer") + hlog.SystemLogger().Warnf("HostClient: no dialer specified, trying to use default dialer") dial = dialer.DefaultDialer() } dialFunc := dial.DialConnection @@ -1108,11 +1118,6 @@ type ClientOptions struct { // after DefaultMaxIdleConnDuration. MaxIdleConnDuration time.Duration - // Maximum number of attempts for idempotent calls - // - // DefaultMaxIdempotentCallAttempts is used if not set. - MaxIdempotentCallAttempts int - // Maximum duration for full response reading (including body). // // By default response read timeout is unlimited. @@ -1161,11 +1166,11 @@ type ClientOptions struct { // By default will not wait, return errNoFreeConns immediately MaxConnWaitTimeout time.Duration - // RetryIf controls whether a retry should be attempted after an error. - // - // By default will use isIdempotent function - RetryIf client.RetryIfFunc - // ResponseBodyStream enables response body streaming ResponseBodyStream bool + + // All configurations related to retry + RetryConfig *retry.Config + + RetryIfFunc client.RetryIfFunc } diff --git a/pkg/protocol/http1/req/request_test.go b/pkg/protocol/http1/req/request_test.go index 2edee37e6..fc2fe71b1 100644 --- a/pkg/protocol/http1/req/request_test.go +++ b/pkg/protocol/http1/req/request_test.go @@ -1120,7 +1120,8 @@ func testContinueReadBodyStream(t *testing.T, header, body string, maxBodySize, } func verifyRequestHeader(t *testing.T, h *protocol.RequestHeader, expectedContentLength int, - expectedRequestURI, expectedHost, expectedReferer, expectedContentType string) { + expectedRequestURI, expectedHost, expectedReferer, expectedContentType string, +) { if h.ContentLength() != expectedContentLength { t.Fatalf("Unexpected Content-Length %d. Expected %d", h.ContentLength(), expectedContentLength) } diff --git a/pkg/protocol/http1/resp/response_test.go b/pkg/protocol/http1/resp/response_test.go index 69f42f234..072f67081 100644 --- a/pkg/protocol/http1/resp/response_test.go +++ b/pkg/protocol/http1/resp/response_test.go @@ -159,7 +159,8 @@ func testResponseReadError(t *testing.T, resp *protocol.Response, response strin } func testResponseReadSuccess(t *testing.T, resp *protocol.Response, response string, expectedStatusCode, expectedContentLength int, - expectedContentType, expectedBody, expectedTrailer string) { + expectedContentType, expectedBody, expectedTrailer string, +) { zr := mock.NewZeroCopyReader(response) err := Read(resp, zr) if err != nil { @@ -436,7 +437,8 @@ func verifyResponseHeader(t *testing.T, h *protocol.ResponseHeader, expectedStat } func testResponseSuccess(t *testing.T, statusCode int, contentType, serverName, body string, - expectedStatusCode int, expectedContentType, expectedServerName string) { + expectedStatusCode int, expectedContentType, expectedServerName string, +) { var resp protocol.Response resp.SetStatusCode(statusCode) resp.Header.Set("Content-Type", contentType) @@ -479,7 +481,8 @@ func testResponseSuccess(t *testing.T, statusCode int, contentType, serverName, } func testResponseReadWithoutBody(t *testing.T, resp *protocol.Response, s string, skipBody bool, - expectedStatusCode, expectedContentLength int, expectedContentType, expectedTrailer string) { + expectedStatusCode, expectedContentLength int, expectedContentType, expectedTrailer string, +) { zr := mock.NewZeroCopyReader(s) resp.SkipBody = skipBody err := Read(resp, zr) diff --git a/pkg/protocol/request_test.go b/pkg/protocol/request_test.go index 47e415de0..745b51dde 100644 --- a/pkg/protocol/request_test.go +++ b/pkg/protocol/request_test.go @@ -71,6 +71,51 @@ func TestMultiForm(t *testing.T) { fmt.Println(err) } +func TestRequestBodyWriterWrite(t *testing.T) { + w := requestBodyWriter{&Request{}} + w.Write([]byte("test")) + assert.DeepEqual(t, "test", string(w.r.body.B)) +} + +func TestRequestScheme(t *testing.T) { + req := NewRequest("", "ptth://127.0.0.1:8080", nil) + assert.DeepEqual(t, "ptth", string(req.Scheme())) + req = NewRequest("", "127.0.0.1:8080", nil) + assert.DeepEqual(t, "http", string(req.Scheme())) + assert.DeepEqual(t, true, req.IsURIParsed()) +} + +func TestRequestHost(t *testing.T) { + req := &Request{} + req.SetHost("127.0.0.1:8080") + assert.DeepEqual(t, "127.0.0.1:8080", string(req.Host())) +} + +func TestRequestSwapBody(t *testing.T) { + reqA := &Request{} + reqA.SetBodyRaw([]byte("testA")) + reqB := &Request{} + reqB.SetBodyRaw([]byte("testB")) + SwapRequestBody(reqA, reqB) + assert.DeepEqual(t, "testA", string(reqB.bodyRaw)) + assert.DeepEqual(t, "testB", string(reqA.bodyRaw)) + reqA.SetBody([]byte("testA")) + reqB.SetBody([]byte("testB")) + SwapRequestBody(reqA, reqB) + assert.DeepEqual(t, "testA", string(reqB.body.B)) + assert.DeepEqual(t, "", string(reqB.bodyRaw)) + assert.DeepEqual(t, "testB", string(reqA.body.B)) + assert.DeepEqual(t, "", string(reqA.bodyRaw)) + reqA.SetBodyStream(strings.NewReader("testA"), len("testA")) + reqB.SetBodyStream(strings.NewReader("testB"), len("testB")) + SwapRequestBody(reqA, reqB) + body := make([]byte, 5) + reqB.bodyStream.Read(body) + assert.DeepEqual(t, "testA", string(body)) + reqA.bodyStream.Read(body) + assert.DeepEqual(t, "testB", string(body)) +} + func TestRequestKnownSizeStreamMultipartFormWithFile(t *testing.T) { t.Parallel() @@ -94,8 +139,9 @@ tailfoobar` r := NewRequest("POST", "/upload", mr) r.Header.SetContentLength(521) r.Header.SetContentTypeBytes([]byte("multipart/form-data; boundary=----WebKitFormBoundaryJwfATyF8tmxSJnLg")) - + assert.DeepEqual(t, false, r.HasMultipartForm()) f, err := r.MultipartForm() + assert.DeepEqual(t, true, r.HasMultipartForm()) if err != nil { t.Fatalf("unexpected error: %s", err) } @@ -147,6 +193,10 @@ tailfoobar` t.Fatalf("unexpected content-type %q. Expecting %q", ct, "application/octet-stream") } } + + firstFile, err := r.FormFile("fileaaa") + assert.DeepEqual(t, "TODO", firstFile.Filename) + assert.Nil(t, err) } func TestRequestUnknownSizeStreamMultipartFormWithFile(t *testing.T) { @@ -300,6 +350,117 @@ tailfoobar` } } +func TestRequestMultipartFormBoundary(t *testing.T) { + r := &Request{} + r.SetMultipartFormBoundary("----boundary----") + assert.DeepEqual(t, "----boundary----", r.MultipartFormBoundary()) +} + +func TestRequestSetQueryString(t *testing.T) { + r := &Request{} + r.SetQueryString("test") + assert.DeepEqual(t, "test", string(r.URI().queryString)) +} + +func TestRequestSetFormData(t *testing.T) { + r := &Request{} + data := map[string]string{"username": "admin"} + r.SetFormData(data) + assert.DeepEqual(t, "username", string(r.postArgs.args[0].key)) + assert.DeepEqual(t, "admin", string(r.postArgs.args[0].value)) + assert.DeepEqual(t, true, r.parsedPostArgs) + assert.DeepEqual(t, "application/x-www-form-urlencoded", string(r.Header.contentType)) + + r = &Request{} + value := map[string][]string{"item": {"apple", "peach"}} + r.SetFormDataFromValues(value) + assert.DeepEqual(t, "item", string(r.postArgs.args[0].key)) + assert.DeepEqual(t, "apple", string(r.postArgs.args[0].value)) + assert.DeepEqual(t, "item", string(r.postArgs.args[1].key)) + assert.DeepEqual(t, "peach", string(r.postArgs.args[1].value)) +} + +func TestRequestSetFile(t *testing.T) { + r := &Request{} + r.SetFile("file", "/usr/bin/test.txt") + assert.DeepEqual(t, &File{"/usr/bin/test.txt", "file", nil}, r.multipartFiles[0]) + + files := map[string]string{"f1": "/usr/bin/test1.txt"} + r.SetFiles(files) + assert.DeepEqual(t, &File{"/usr/bin/test1.txt", "f1", nil}, r.multipartFiles[1]) + + assert.DeepEqual(t, []*File{{"/usr/bin/test.txt", "file", nil}, {"/usr/bin/test1.txt", "f1", nil}}, r.MultipartFiles()) +} + +func TestRequestSetFileReader(t *testing.T) { + r := &Request{} + r.SetFileReader("file", "/usr/bin/test.txt", nil) + assert.DeepEqual(t, &File{"/usr/bin/test.txt", "file", nil}, r.multipartFiles[0]) +} + +func TestRequestSetMultipartFormData(t *testing.T) { + r := &Request{} + data := map[string]string{"item": "apple"} + r.SetMultipartFormData(data) + assert.DeepEqual(t, &MultipartField{"item", "", "", strings.NewReader("apple")}, r.multipartFields[0]) + + r = &Request{} + fields := []*MultipartField{{"item2", "", "", strings.NewReader("apple2")}, {"item3", "", "", strings.NewReader("apple3")}} + r.SetMultipartFields(fields...) + assert.DeepEqual(t, fields, r.MultipartFields()) +} + +func TestRequestSetBasicAuth(t *testing.T) { + r := &Request{} + r.SetBasicAuth("admin", "admin") + assert.DeepEqual(t, "Authorization", string(r.Header.h[0].key)) + assert.DeepEqual(t, "Basic "+base64.StdEncoding.EncodeToString([]byte("admin:admin")), string(r.Header.h[0].value)) +} + +func TestRequestSetAuthToken(t *testing.T) { + r := &Request{} + r.SetAuthToken("token") + assert.DeepEqual(t, "Authorization", string(r.Header.h[0].key)) + assert.DeepEqual(t, "Bearer token", string(r.Header.h[0].value)) + + r = &Request{} + r.SetAuthSchemeToken("http", "token") + assert.DeepEqual(t, "Authorization", string(r.Header.h[0].key)) + assert.DeepEqual(t, "http token", string(r.Header.h[0].value)) +} + +func TestRequestSetHeaders(t *testing.T) { + r := &Request{} + headers := map[string]string{"Key1": "value1"} + r.SetHeaders(headers) + assert.DeepEqual(t, "Key1", string(r.Header.h[0].key)) + assert.DeepEqual(t, "value1", string(r.Header.h[0].value)) +} + +func TestRequestSetCookie(t *testing.T) { + r := &Request{} + r.SetCookie("cookie1", "cookie1") + assert.DeepEqual(t, "cookie1", string(r.Header.cookies[0].key)) + assert.DeepEqual(t, "cookie1", string(r.Header.cookies[0].value)) + + r.SetCookies(map[string]string{"cookie2": "cookie2"}) + assert.DeepEqual(t, "cookie2", string(r.Header.cookies[1].key)) + assert.DeepEqual(t, "cookie2", string(r.Header.cookies[1].value)) +} + +func TestRequestPath(t *testing.T) { + r := NewRequest("POST", "/upload?test", nil) + assert.DeepEqual(t, "/upload", string(r.Path())) + assert.DeepEqual(t, "test", string(r.QueryString())) +} + +func TestRequestConnectionClose(t *testing.T) { + r := NewRequest("POST", "/upload?test", nil) + assert.DeepEqual(t, false, r.ConnectionClose()) + r.SetConnectionClose() + assert.DeepEqual(t, true, r.ConnectionClose()) +} + func TestRequestBodyWriteToPlain(t *testing.T) { t.Parallel() @@ -395,6 +556,32 @@ func TestRequestResetBody(t *testing.T) { assert.Nil(t, req.body) } +func TestRequestConstructBodyStream(t *testing.T) { + r := &Request{} + b := []byte("test") + r.ConstructBodyStream(&bytebufferpool.ByteBuffer{B: b}, strings.NewReader("test")) + assert.DeepEqual(t, "test", string(r.body.B)) + stream := make([]byte, 4) + r.bodyStream.Read(stream) + assert.DeepEqual(t, "test", string(stream)) +} + +func TestRequestPostArgs(t *testing.T) { + t.Parallel() + + s := `username=admin&password=admin` + mr := strings.NewReader(s) + r := &Request{} + r.SetBodyStream(mr, len(s)) + r.Header.contentType = []byte("application/x-www-form-urlencoded") + arg := r.PostArgs() + assert.DeepEqual(t, "username", string(arg.args[0].key)) + assert.DeepEqual(t, "admin", string(arg.args[0].value)) + assert.DeepEqual(t, "password", string(arg.args[1].key)) + assert.DeepEqual(t, "admin", string(arg.args[1].value)) + assert.DeepEqual(t, "username=admin&password=admin", string(r.PostArgString())) +} + func TestRequestMayContinue(t *testing.T) { t.Parallel() @@ -509,6 +696,12 @@ func TestRequestCopyToWithOptions(t *testing.T) { assert.DeepEqual(t, true, reqCopy.options.IsSD()) } +func TestRequestSetMaxKeepBodySize(t *testing.T) { + r := &Request{} + r.SetMaxKeepBodySize(1024) + assert.DeepEqual(t, 1024, r.maxKeepBodySize) +} + func TestRequestGetBodyAfterGetBodyStream(t *testing.T) { req := AcquireRequest() req.SetBodyString("abc") diff --git a/pkg/protocol/suite/server.go b/pkg/protocol/suite/server.go index b489e0d95..8d5886526 100644 --- a/pkg/protocol/suite/server.go +++ b/pkg/protocol/suite/server.go @@ -59,7 +59,7 @@ type ServerMap map[string]protocol.Server func (c *Config) Add(protocol string, factory ServerFactory) { if fac := c.configMap[protocol]; fac != nil { - hlog.Warnf("HERTZ: ServerFactory of protocol: %s will be overridden by customized function", protocol) + hlog.SystemLogger().Warnf("ServerFactory of protocol: %s will be overridden by customized function", protocol) } c.configMap[protocol] = factory } diff --git a/pkg/protocol/uri_test.go b/pkg/protocol/uri_test.go index 9da05a709..958873483 100644 --- a/pkg/protocol/uri_test.go +++ b/pkg/protocol/uri_test.go @@ -42,7 +42,6 @@ package protocol import ( - "fmt" "path/filepath" "reflect" "testing" @@ -53,16 +52,162 @@ import ( func TestURI_Username(t *testing.T) { var req Request req.SetRequestURI("http://user:pass@example.com/foo/bar") - uri := req.URI() - user1 := string(uri.username) - fmt.Printf("1--- uri:%s user:%s\n", uri.RequestURI(), user1) + u := req.URI() + user1 := string(u.Username()) req.Header.SetRequestURIBytes([]byte("/foo/bar")) - uri = req.URI() - user2 := string(uri.username) - fmt.Printf("2--- uri:%s user:%s\n", uri.RequestURI(), user2) - if user1 != user2 { - t.Fatal("user1 != user2") - } + u = req.URI() + user2 := string(u.Username()) + assert.DeepEqual(t, user1, user2) + + expectUser3 := "user3" + expectUser4 := "user4" + + u.SetUsername(expectUser3) + user3 := string(u.Username()) + assert.DeepEqual(t, expectUser3, user3) + u.SetUsername(expectUser4) + user4 := string(u.Username()) + assert.DeepEqual(t, expectUser4, user4) + + u.SetUsernameBytes([]byte(user3)) + assert.DeepEqual(t, expectUser3, user3) + u.SetUsernameBytes([]byte(user4)) + assert.DeepEqual(t, expectUser4, user4) +} + +func TestURI_Password(t *testing.T) { + u := AcquireURI() + defer ReleaseURI(u) + + expectPassword1 := "password1" + expectPassword2 := "password2" + + u.SetPassword(expectPassword1) + password1 := string(u.Password()) + assert.DeepEqual(t, expectPassword1, password1) + u.SetPassword(expectPassword2) + password2 := string(u.Password()) + assert.DeepEqual(t, expectPassword2, password2) + + u.SetPasswordBytes([]byte(password1)) + assert.DeepEqual(t, expectPassword1, password1) + u.SetPasswordBytes([]byte(password2)) + assert.DeepEqual(t, expectPassword2, password2) +} + +func TestURI_Hash(t *testing.T) { + u := AcquireURI() + defer ReleaseURI(u) + + expectHash1 := "hash1" + expectHash2 := "hash2" + + u.SetHash(expectHash1) + hash1 := string(u.Hash()) + assert.DeepEqual(t, expectHash1, hash1) + u.SetHash(expectHash2) + hash2 := string(u.Hash()) + assert.DeepEqual(t, expectHash2, hash2) +} + +func TestURI_QueryString(t *testing.T) { + u := AcquireURI() + defer ReleaseURI(u) + + expectQueryString1 := "key1=value1&key2=value2" + expectQueryString2 := "key3=value3&key4=value4" + + u.SetQueryString(expectQueryString1) + queryString1 := string(u.QueryString()) + assert.DeepEqual(t, expectQueryString1, queryString1) + u.SetQueryString(expectQueryString2) + queryString2 := string(u.QueryString()) + assert.DeepEqual(t, expectQueryString2, queryString2) +} + +func TestURI_Path(t *testing.T) { + u := AcquireURI() + defer ReleaseURI(u) + + expectPath1 := "/" + expectPath2 := "/path1" + expectPath3 := "/path3" + + // When Path is not set, Path defaults to "/" + path1 := string(u.Path()) + assert.DeepEqual(t, expectPath1, path1) + + u.SetPath(expectPath2) + path2 := string(u.Path()) + assert.DeepEqual(t, expectPath2, path2) + u.SetPath(expectPath3) + path3 := string(u.Path()) + assert.DeepEqual(t, expectPath3, path3) + + u.SetPathBytes([]byte(path2)) + assert.DeepEqual(t, expectPath2, path2) + u.SetPathBytes([]byte(path3)) + assert.DeepEqual(t, expectPath3, path3) +} + +func TestURI_Scheme(t *testing.T) { + u := AcquireURI() + defer ReleaseURI(u) + + expectScheme1 := "scheme1" + expectScheme2 := "scheme2" + + u.SetScheme(expectScheme1) + scheme1 := string(u.Scheme()) + assert.DeepEqual(t, expectScheme1, scheme1) + u.SetScheme(expectScheme2) + scheme2 := string(u.Scheme()) + assert.DeepEqual(t, expectScheme2, scheme2) + + u.SetSchemeBytes([]byte(scheme1)) + assert.DeepEqual(t, expectScheme1, scheme1) + u.SetSchemeBytes([]byte(scheme2)) + assert.DeepEqual(t, expectScheme2, scheme2) +} + +func TestURI_Host(t *testing.T) { + u := AcquireURI() + defer ReleaseURI(u) + + expectHost1 := "host1" + expectHost2 := "host2" + + u.SetHost(expectHost1) + host1 := string(u.Host()) + assert.DeepEqual(t, expectHost1, host1) + u.SetHost(expectHost2) + host2 := string(u.Host()) + assert.DeepEqual(t, expectHost2, host2) + + u.SetHostBytes([]byte(host1)) + assert.DeepEqual(t, expectHost1, host1) + u.SetHostBytes([]byte(host2)) + assert.DeepEqual(t, expectHost2, host2) +} + +func TestURI_PathOriginal(t *testing.T) { + var u URI + expectPath := "/path" + u.Parse(nil, []byte(expectPath)) + uri := string(u.PathOriginal()) + assert.DeepEqual(t, expectPath, uri) +} + +func TestArgsKV_Get(t *testing.T) { + var argsKV argsKV + expectKey := "key" + expectValue := "value" + argsKV.key = []byte(expectKey) + argsKV.value = []byte(expectValue) + key := string(argsKV.GetKey()) + value := string(argsKV.GetValue()) + assert.DeepEqual(t, expectKey, key) + assert.DeepEqual(t, expectValue, value) } func TestURICopyToQueryArgs(t *testing.T) { @@ -79,6 +224,7 @@ func TestURICopyToQueryArgs(t *testing.T) { if string(a1.Peek("foo")) != "bar" { t.Fatalf("unexpected query args value %q. Expecting %q", a1.Peek("foo"), "bar") } + assert.DeepEqual(t, "bar", string(a1.Peek("foo"))) } func TestURICopyTo(t *testing.T) { @@ -112,9 +258,7 @@ func testURILastPathSegment(t *testing.T, path, expectedSegment string) { var u URI u.SetPath(path) segment := u.LastPathSegment() - if string(segment) != expectedSegment { - t.Fatalf("unexpected last path segment for path %q: %q. Expecting %q", path, segment, expectedSegment) - } + assert.DeepEqual(t, expectedSegment, string(segment)) } func TestURIPathEscape(t *testing.T) { @@ -157,18 +301,14 @@ func testURIUpdate(t *testing.T, base, update, result string) { u.Parse(nil, []byte(base)) u.Update(update) s := u.String() - if s != result { - t.Fatalf("unexpected result %q. Expecting %q. base=%q, update=%q", s, result, base, update) - } + assert.DeepEqual(t, result, s) } func testURIPathEscape(t *testing.T, path, expectedRequestURI string) { var u URI u.SetPath(path) requestURI := u.RequestURI() - if string(requestURI) != expectedRequestURI { - t.Fatalf("unexpected requestURI %q. Expecting %q. path %q", requestURI, expectedRequestURI, path) - } + assert.DeepEqual(t, expectedRequestURI, string(requestURI)) } func TestDelArgs(t *testing.T) { @@ -210,9 +350,7 @@ func TestURIFullURI(t *testing.T) { u.Parse([]byte("google.com"), []byte("/foo?bar=baz&baraz#qqqq")) uri := u.FullURI() expectedURI := "http://google.com/foo?bar=baz&baraz#qqqq" - if string(uri) != expectedURI { - t.Fatalf("Unexpected URI: %q. Expected %q", uri, expectedURI) - } + assert.DeepEqual(t, expectedURI, string(uri)) } func testURIFullURI(t *testing.T, scheme, host, path, hash string, args *Args, expectedURI string) { @@ -225,9 +363,7 @@ func testURIFullURI(t *testing.T, scheme, host, path, hash string, args *Args, e args.CopyTo(u.QueryArgs()) uri := u.FullURI() - if string(uri) != expectedURI { - t.Fatalf("Unexpected URI: %q. Expected %q", uri, expectedURI) - } + assert.DeepEqual(t, expectedURI, string(uri)) } func TestParsePathWindows(t *testing.T) { @@ -246,3 +382,27 @@ func testParsePathWindows(t *testing.T, path, expectedPath string) { t.Fatalf("Unexpected Path: %q. Expected %q", parsedPath, expectedPath) } } + +func TestParseHostWithStr(t *testing.T) { + expectUsername := "username" + expectPassword := "password" + + testParseHostWithStr(t, "username", "", "") + testParseHostWithStr(t, "username@", expectUsername, "") + testParseHostWithStr(t, "username:password@", expectUsername, expectPassword) + testParseHostWithStr(t, ":password@", "", expectPassword) + testParseHostWithStr(t, ":password", "", "") +} + +func testParseHostWithStr(t *testing.T, host, expectUsername, expectPassword string) { + var u URI + u.Parse([]byte(host), nil) + assert.DeepEqual(t, expectUsername, string(u.Username())) + assert.DeepEqual(t, expectPassword, string(u.Password())) +} + +func TestParseURI(t *testing.T) { + expectURI := "http://google.com/foo?bar=baz&baraz#qqqq" + uri := string(ParseURI(expectURI).FullURI()) + assert.DeepEqual(t, expectURI, uri) +} diff --git a/pkg/route/default.go b/pkg/route/default.go index f0e41539d..d3577156a 100644 --- a/pkg/route/default.go +++ b/pkg/route/default.go @@ -1,22 +1,21 @@ +// Copyright 2022 CloudWeGo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + //go:build !windows // +build !windows -/* - * Copyright 2022 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - package route import ( diff --git a/pkg/route/engine.go b/pkg/route/engine.go index 52849d434..45ae38f89 100644 --- a/pkg/route/engine.go +++ b/pkg/route/engine.go @@ -273,17 +273,17 @@ func (engine *Engine) Shutdown(ctx context.Context) (err error) { // ensure that the hook is executed until wait timeout or finish select { case <-ctx.Done(): - hlog.Infof("HERTZ: Execute OnShutdownHooks timeout: error=%v", ctx.Err()) + hlog.SystemLogger().Infof("Execute OnShutdownHooks timeout: error=%v", ctx.Err()) return case <-ch: - hlog.Info("HERTZ: Execute OnShutdownHooks finish") + hlog.SystemLogger().Info("Execute OnShutdownHooks finish") return } }() if opt := engine.options; opt != nil && opt.Registry != nil { if err = opt.Registry.Deregister(opt.RegistryInfo); err != nil { - hlog.Errorf("HERTZ: Deregister error=%v", err) + hlog.SystemLogger().Errorf("Deregister error=%v", err) return err } } @@ -363,7 +363,7 @@ func (engine *Engine) alpnEnable() bool { } func (engine *Engine) listenAndServe() error { - hlog.Infof("HERTZ: Using network library=%s", GetTransporterName()) + hlog.SystemLogger().Infof("Using network library=%s", GetTransporterName()) return engine.transport.ListenAndServe(engine.onData) } @@ -383,7 +383,7 @@ func (engine *Engine) getNextProto(conn network.Conn) (proto string, err error) if tlsConn, ok := conn.(network.ConnTLSer); ok { if engine.options.ReadTimeout > 0 { if err := conn.SetReadTimeout(engine.options.ReadTimeout); err != nil { - hlog.Errorf("HERTZ: BUG: error in SetReadDeadline=%s: error=%s", engine.options.ReadTimeout, err) + hlog.SystemLogger().Errorf("BUG: error in SetReadDeadline=%s: error=%s", engine.options.ReadTimeout, err) } } err = tlsConn.Handshake() @@ -431,7 +431,7 @@ func errProcess(conn io.Closer, err error) { } } // other errors - hlog.Errorf("HERTZ: Error=%s, remoteAddr=%s", err.Error(), rip) + hlog.SystemLogger().Errorf("Error=%s, remoteAddr=%s", err.Error(), rip) } func getRemoteAddrFromCloser(conn io.Closer) string { @@ -477,7 +477,7 @@ func (engine *Engine) Serve(c context.Context, conn network.Conn) (err error) { if bytes.Equal(buf, bytestr.StrClientPreface) && engine.protocolServers[suite.HTTP2] != nil { return engine.protocolServers[suite.HTTP2].Serve(c, conn) } - hlog.Warnf("HERTZ: HTTP2 server is not loaded, request is going to fallback to HTTP1 server") + hlog.SystemLogger().Warnf("HTTP2 server is not loaded, request is going to fallback to HTTP1 server") } // ALPN path @@ -569,7 +569,7 @@ func debugPrintRoute(httpMethod, absolutePath string, handlers app.HandlersChain if handlerName == "" { handlerName = utils.NameOfFunction(handlers.Last()) } - hlog.Debugf("HERTZ: Method=%-6s absolutePath=%-25s --> handlerName=%s (num=%d handlers)", httpMethod, absolutePath, handlerName, nuHandlers) + hlog.SystemLogger().Debugf("Method=%-6s absolutePath=%-25s --> handlerName=%s (num=%d handlers)", httpMethod, absolutePath, handlerName, nuHandlers) } func (engine *Engine) addRoute(method, path string, handlers app.HandlersChain) { @@ -580,7 +580,10 @@ func (engine *Engine) addRoute(method, path string, handlers app.HandlersChain) utils.Assert(method != "", "HTTP method can not be empty") utils.Assert(len(handlers) > 0, "there must be at least one handler") - debugPrintRoute(method, path, handlers) + if !engine.options.DisablePrintRoute { + debugPrintRoute(method, path, handlers) + } + methodRouter := engine.trees.get(method) if methodRouter == nil { methodRouter = &router{method: method, root: &node{}, hasTsrHandler: make(map[string]bool)} @@ -790,7 +793,7 @@ func (engine *Engine) LoadHTMLGlob(pattern string) { if engine.options.AutoReloadRender { files, err := filepath.Glob(pattern) if err != nil { - hlog.Errorf("LoadHTMLGlob: %v", err) + hlog.SystemLogger().Errorf("LoadHTMLGlob: %v", err) return } engine.SetAutoReloadHTMLTemplate(tmpl, files) diff --git a/version.go b/version.go index 25fc0b2f5..652a5d7be 100644 --- a/version.go +++ b/version.go @@ -19,5 +19,5 @@ package hertz // Name and Version info of this framework, used for statistics and debug const ( Name = "Hertz" - Version = "v0.3.2" + Version = "v0.4.0" )