From d6015151ca9272b31da539e3a6503722dd9dfde4 Mon Sep 17 00:00:00 2001 From: sod-lol Date: Fri, 18 Nov 2022 05:51:34 -0500 Subject: [PATCH 01/12] test: add more tests for pkg/common/utils (#372) --- pkg/common/utils/chunk_test.go | 33 ++++--------- pkg/common/utils/ioutil_test.go | 40 +++++++++------ pkg/common/utils/netaddr_test.go | 30 ++++++++++++ pkg/common/utils/network_test.go | 43 ++++++++++++++++ pkg/common/utils/path_test.go | 46 +++++------------ pkg/common/utils/utils_test.go | 84 +++++++++++++------------------- 6 files changed, 155 insertions(+), 121 deletions(-) create mode 100644 pkg/common/utils/netaddr_test.go create mode 100644 pkg/common/utils/network_test.go diff --git a/pkg/common/utils/chunk_test.go b/pkg/common/utils/chunk_test.go index 120027ef6..ffeafa163 100644 --- a/pkg/common/utils/chunk_test.go +++ b/pkg/common/utils/chunk_test.go @@ -19,6 +19,7 @@ package utils import ( "testing" + "github.com/cloudwego/hertz/pkg/common/test/assert" "github.com/cloudwego/hertz/pkg/common/test/mock" ) @@ -29,12 +30,8 @@ func TestChunkParseChunkSizeGetCorrect(t *testing.T) { chunkSizeBody := hex + "\r\n" zr := mock.NewZeroCopyReader(chunkSizeBody) chunkSize, err := ParseChunkSize(zr) - if err != nil { - t.Fatalf("Unexpected error for ParseChunkSize: %s", err) - } - if dec != chunkSize { - t.Fatalf("Unexpected chunkSize: %d. Expecting %d", chunkSize, dec) - } + assert.DeepEqual(t, nil, err) + assert.DeepEqual(t, chunkSize, dec) } } @@ -46,12 +43,8 @@ func TestChunkParseChunkSizeCorrectWhiteSpace(t *testing.T) { chunkSizeBody := "0" + whiteSpace + "\r\n" zr := mock.NewZeroCopyReader(chunkSizeBody) chunkSize, err := ParseChunkSize(zr) - if err != nil { - t.Fatalf("Unexpected error for ParseChunkSize: %s", err) - } - if chunkSize != 0 { - t.Fatalf("Unexpected chunk size: %d. Expecting 0", chunkSize) - } + assert.DeepEqual(t, nil, err) + assert.DeepEqual(t, 0, chunkSize) } } @@ -60,28 +53,20 @@ func TestChunkParseChunkSizeNonCRLF(t *testing.T) { chunkSizeBody := "0" + "\n\r" zr := mock.NewZeroCopyReader(chunkSizeBody) chunkSize, err := ParseChunkSize(zr) - if err == nil { - t.Fatalf("Expecting an error for chunkSize, but get nil") - } - if chunkSize != -1 { - t.Fatalf("Unexpected chunk size: %d. Expecting -1", chunkSize) - } + assert.DeepEqual(t, true, err != nil) + assert.DeepEqual(t, -1, chunkSize) } func TestChunkReadTrueCRLF(t *testing.T) { CRLF := "\r\n" zr := mock.NewZeroCopyReader(CRLF) err := SkipCRLF(zr) - if err != nil { - t.Fatalf("Unexpected error for SkipCRLF: %s. Expecting nil", err) - } + assert.DeepEqual(t, nil, err) } func TestChunkReadFalseCRLF(t *testing.T) { CRLF := "\n\r" zr := mock.NewZeroCopyReader(CRLF) err := SkipCRLF(zr) - if err == nil { - t.Fatalf("Expecting error, but get nil") - } + assert.DeepEqual(t, errBrokenChunk, err) } diff --git a/pkg/common/utils/ioutil_test.go b/pkg/common/utils/ioutil_test.go index f78556cd6..c0d020c7e 100644 --- a/pkg/common/utils/ioutil_test.go +++ b/pkg/common/utils/ioutil_test.go @@ -20,36 +20,48 @@ import ( "bytes" "testing" + "github.com/cloudwego/hertz/pkg/common/test/assert" "github.com/cloudwego/hertz/pkg/network" ) func TestIoutilCopyBuffer(t *testing.T) { var writeBuffer bytes.Buffer - src := bytes.NewBufferString("hertz is very good!!!") + str := string("hertz is very good!!!") + src := bytes.NewBufferString(str) dst := network.NewWriter(&writeBuffer) var buf []byte // src.Len() will change, when use src.read(p []byte) srcLen := int64(src.Len()) written, err := CopyBuffer(dst, src, buf) - if written != srcLen { - t.Fatalf("Unexpected written: %d. Expecting: %d", written, srcLen) - } - if err != nil { - t.Fatalf("Unexpected error: %s", err) - } + assert.DeepEqual(t, written, srcLen) + assert.DeepEqual(t, err, nil) + assert.DeepEqual(t, []byte(str), writeBuffer.Bytes()) +} + +func TestIoutilCopyBufferWithNilBuffer(t *testing.T) { + var writeBuffer bytes.Buffer + str := string("hertz is very good!!!") + src := bytes.NewBufferString(str) + dst := network.NewWriter(&writeBuffer) + // src.Len() will change, when use src.read(p []byte) + srcLen := int64(src.Len()) + written, err := CopyBuffer(dst, src, nil) + + assert.DeepEqual(t, written, srcLen) + assert.DeepEqual(t, err, nil) + assert.DeepEqual(t, []byte(str), writeBuffer.Bytes()) } func TestIoutilCopyZeroAlloc(t *testing.T) { var writeBuffer bytes.Buffer - src := bytes.NewBufferString("hertz is very good!!!") + str := string("hertz is very good!!!") + src := bytes.NewBufferString(str) dst := network.NewWriter(&writeBuffer) srcLen := int64(src.Len()) written, err := CopyZeroAlloc(dst, src) - if written != srcLen { - t.Fatalf("Unexpected written: %d. Expecting: %d", written, srcLen) - } - if err != nil { - t.Fatalf("Unexpected error: %s", err) - } + + assert.DeepEqual(t, written, srcLen) + assert.DeepEqual(t, err, nil) + assert.DeepEqual(t, []byte(str), writeBuffer.Bytes()) } diff --git a/pkg/common/utils/netaddr_test.go b/pkg/common/utils/netaddr_test.go new file mode 100644 index 000000000..0e9970912 --- /dev/null +++ b/pkg/common/utils/netaddr_test.go @@ -0,0 +1,30 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package utils + +import ( + "testing" + + "github.com/cloudwego/hertz/pkg/common/test/assert" +) + +func TestNetAddr(t *testing.T) { + networkAddr := NewNetAddr("127.0.0.1", "192.168.1.1") + + assert.DeepEqual(t, networkAddr.Network(), "127.0.0.1") + assert.DeepEqual(t, networkAddr.String(), "192.168.1.1") +} diff --git a/pkg/common/utils/network_test.go b/pkg/common/utils/network_test.go new file mode 100644 index 000000000..232e35a88 --- /dev/null +++ b/pkg/common/utils/network_test.go @@ -0,0 +1,43 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package utils + +import ( + "testing" + + "github.com/cloudwego/hertz/pkg/common/test/assert" +) + +func TestTLSRecordHeaderLooksLikeHTTP(t *testing.T) { + HeaderValueAndExpectedResult := [][]interface{}{ + {[5]byte{'G', 'E', 'T', ' ', '/'}, true}, + {[5]byte{'H', 'E', 'A', 'D', ' '}, true}, + {[5]byte{'P', 'O', 'S', 'T', ' '}, true}, + {[5]byte{'P', 'U', 'T', ' ', '/'}, true}, + {[5]byte{'O', 'P', 'T', 'I', 'O'}, true}, + {[5]byte{'G', 'E', 'T', '/', ' '}, false}, + {[5]byte{' ', 'H', 'E', 'A', 'D'}, false}, + {[5]byte{' ', 'P', 'O', 'S', 'T'}, false}, + {[5]byte{'P', 'U', 'T', '/', ' '}, false}, + {[5]byte{'H', 'E', 'R', 'T', 'Z'}, false}, + } + + for _, testCase := range HeaderValueAndExpectedResult { + value, expectedResult := testCase[0].([5]byte), testCase[1].(bool) + assert.DeepEqual(t, expectedResult, TLSRecordHeaderLooksLikeHTTP(value)) + } +} diff --git a/pkg/common/utils/path_test.go b/pkg/common/utils/path_test.go index 475e4395b..6ebb8430c 100644 --- a/pkg/common/utils/path_test.go +++ b/pkg/common/utils/path_test.go @@ -42,76 +42,56 @@ package utils import ( "testing" + + "github.com/cloudwego/hertz/pkg/common/test/assert" ) func TestPathCleanPath(t *testing.T) { normalPath := "/Foo/Bar/go/src/github.com/cloudwego/hertz/pkg/common/utils/path_test.go" expectedNormalPath := "/Foo/Bar/go/src/github.com/cloudwego/hertz/pkg/common/utils/path_test.go" cleanNormalPath := CleanPath(normalPath) - if cleanNormalPath != expectedNormalPath { - t.Fatalf("Unexpected path: %s. Excepting path: %s", cleanNormalPath, expectedNormalPath) - } + assert.DeepEqual(t, expectedNormalPath, cleanNormalPath) singleDotPath := "/Foo/Bar/./././go/src" expectedSingleDotPath := "/Foo/Bar/go/src" cleanSingleDotPath := CleanPath(singleDotPath) - if cleanSingleDotPath != expectedSingleDotPath { - t.Fatalf("Unexpected path: %s. Excepting path: %s", cleanSingleDotPath, expectedSingleDotPath) - } + assert.DeepEqual(t, expectedSingleDotPath, cleanSingleDotPath) doubleDotPath := "../../.." expectedDoubleDotPath := "/" cleanDoublePotPath := CleanPath(doubleDotPath) - if cleanDoublePotPath != expectedDoubleDotPath { - t.Fatalf("Unexpected path: %s. Excepting path: %s", cleanDoublePotPath, expectedDoubleDotPath) - } + assert.DeepEqual(t, expectedDoubleDotPath, cleanDoublePotPath) // MultiDot can be treated as a file name multiDotPath := "/../...." expectedMultiDotPath := "/...." cleanMultiDotPath := CleanPath(multiDotPath) - if cleanMultiDotPath != expectedMultiDotPath { - t.Fatalf("Unexpected path: %s. Excepting path: %s", cleanMultiDotPath, expectedMultiDotPath) - } + assert.DeepEqual(t, expectedMultiDotPath, cleanMultiDotPath) nullPath := "" expectedNullPath := "/" cleanNullPath := CleanPath(nullPath) - if cleanNullPath != expectedNullPath { - t.Fatalf("Unexpected path: %s. Excepting path: %s", cleanDoublePotPath, expectedDoubleDotPath) - } + assert.DeepEqual(t, expectedNullPath, cleanNullPath) relativePath := "/Foo/Bar/../go/src/../../github.com/cloudwego/hertz" expectedRelativePath := "/Foo/github.com/cloudwego/hertz" cleanRelativePath := CleanPath(relativePath) - if cleanRelativePath != expectedRelativePath { - t.Fatalf("Unexpected path: %s. Excepting path: %s", cleanRelativePath, expectedRelativePath) - } + assert.DeepEqual(t, expectedRelativePath, cleanRelativePath) multiSlashPath := "///////Foo//Bar////go//src/github.com/cloudwego/hertz//.." expectedMultiSlashPath := "/Foo/Bar/go/src/github.com/cloudwego" cleanMultiSlashPath := CleanPath(multiSlashPath) - if cleanMultiSlashPath != expectedMultiSlashPath { - t.Fatalf("Unexpected path: %s. Excepting path: %s", cleanMultiSlashPath, expectedMultiSlashPath) - } + assert.DeepEqual(t, expectedMultiSlashPath, cleanMultiSlashPath) } // The Function AddMissingPort can only add the missed port, don't consider the other error case. func TestPathAddMissingPort(t *testing.T) { ipList := []string{"127.0.0.1", "111.111.1.1", "[0:0:0:0:0:ffff:192.1.56.10]", "[0:0:0:0:0:ffff:c0a8:101]", "www.foobar.com"} for _, ip := range ipList { - if AddMissingPort(ip, true) != ip+":443" { - t.Fatalf("Unexpected address: %s. Expecting address: %s", AddMissingPort(ip, true), ip+":443") - } - if AddMissingPort(ip, false) != ip+":80" { - t.Fatalf("Unexpected address: %s. Expecting address: %s", AddMissingPort(ip, false), ip+":80") - } + assert.DeepEqual(t, ip+":443", AddMissingPort(ip, true)) + assert.DeepEqual(t, ip+":80", AddMissingPort(ip, false)) customizedPort := ":8080" - if AddMissingPort(ip+customizedPort, true) != ip+customizedPort { - t.Fatalf("Unexpected address: %s. Expecting address: %s", AddMissingPort(ip+customizedPort, false), ip+customizedPort) - } - if AddMissingPort(ip+customizedPort, false) != ip+customizedPort { - t.Fatalf("Unexpected address: %s. Expecting address: %s", AddMissingPort(ip+customizedPort, true), ip+customizedPort) - } + assert.DeepEqual(t, ip+customizedPort, AddMissingPort(ip+customizedPort, true)) + assert.DeepEqual(t, ip+customizedPort, AddMissingPort(ip+customizedPort, false)) } } diff --git a/pkg/common/utils/utils_test.go b/pkg/common/utils/utils_test.go index a621e33dd..231ad930f 100644 --- a/pkg/common/utils/utils_test.go +++ b/pkg/common/utils/utils_test.go @@ -43,6 +43,8 @@ package utils import ( "testing" + + "github.com/cloudwego/hertz/pkg/common/test/assert" ) // test assert func @@ -55,15 +57,9 @@ func TestUtilsIsTrueString(t *testing.T) { upperTrueStr := "trUe" otherStr := "hertz" - if !IsTrueString(normalTrueStr) { - t.Fatalf("Unexpected false for %s.", normalTrueStr) - } - if !IsTrueString(upperTrueStr) { - t.Fatalf("Unexpected false for %s.", upperTrueStr) - } - if IsTrueString(otherStr) { - t.Fatalf("Unexpected true for %s.", otherStr) - } + assert.DeepEqual(t, true, IsTrueString(normalTrueStr)) + assert.DeepEqual(t, true, IsTrueString(upperTrueStr)) + assert.DeepEqual(t, false, IsTrueString(otherStr)) } // used for TestUtilsNameOfFunction @@ -77,33 +73,22 @@ func TestUtilsNameOfFunction(t *testing.T) { nameOfTestName := NameOfFunction(testName) nameOfIsTrueString := NameOfFunction(IsTrueString) - if nameOfTestName != pathOfTestName { - t.Fatalf("Unexpected name: %s for testName", nameOfTestName) - } - - if nameOfIsTrueString != pathOfIsTrueString { - t.Fatalf("Unexpected name: %s for IsTrueString", nameOfIsTrueString) - } + assert.DeepEqual(t, pathOfTestName, nameOfTestName) + assert.DeepEqual(t, pathOfIsTrueString, nameOfIsTrueString) } func TestUtilsCaseInsensitiveCompare(t *testing.T) { lowerStr := []byte("content-length") upperStr := []byte("Content-Length") - if !CaseInsensitiveCompare(lowerStr, upperStr) { - t.Fatalf("Unexpected false for %s and %s", string(lowerStr), string(upperStr)) - } + assert.DeepEqual(t, true, CaseInsensitiveCompare(lowerStr, upperStr)) lessStr := []byte("content-type") moreStr := []byte("content-length") - if CaseInsensitiveCompare(lessStr, moreStr) { - t.Fatalf("Unexpected true for %s and %s", string(lessStr), string(moreStr)) - } + assert.DeepEqual(t, false, CaseInsensitiveCompare(lessStr, moreStr)) firstStr := []byte("content-type") secondStr := []byte("contant-type") - if CaseInsensitiveCompare(firstStr, secondStr) { - t.Fatalf("Unexpected true for %s and %s", string(firstStr), string(secondStr)) - } + assert.DeepEqual(t, false, CaseInsensitiveCompare(firstStr, secondStr)) } // NormalizeHeaderKey can upper the first letter and lower the other letter in @@ -113,18 +98,16 @@ func TestUtilsNormalizeHeaderKey(t *testing.T) { contentTypeStr := []byte("Content-Type") lowerContentTypeStr := []byte("content-type") mixedContentTypeStr := []byte("conTENt-tYpE") + mixedContertTypeStrWithoutNormalizing := []byte("Content-type") NormalizeHeaderKey(contentTypeStr, false) NormalizeHeaderKey(lowerContentTypeStr, false) NormalizeHeaderKey(mixedContentTypeStr, false) - if string(contentTypeStr) != "Content-Type" { - t.Fatalf("Unexpected normalizedHeader: %s", string(contentTypeStr)) - } - if string(lowerContentTypeStr) != "Content-Type" { - t.Fatalf("Unexpected normalizedHeader: %s", string(lowerContentTypeStr)) - } - if string(mixedContentTypeStr) != "Content-Type" { - t.Fatalf("Unexpected normalizedHeader: %s", string(mixedContentTypeStr)) - } + NormalizeHeaderKey(lowerContentTypeStr, true) + + assert.DeepEqual(t, "Content-Type", string(contentTypeStr)) + assert.DeepEqual(t, "Content-Type", string(lowerContentTypeStr)) + assert.DeepEqual(t, "Content-Type", string(mixedContentTypeStr)) + assert.DeepEqual(t, "Content-type", string(mixedContertTypeStrWithoutNormalizing)) } // Cutting up the header Type. @@ -133,22 +116,23 @@ func TestUtilsNormalizeHeaderKey(t *testing.T) { func TestUtilsNextLine(t *testing.T) { multiHeaderStr := []byte("Content-Type: application/x-www-form-urlencoded\r\nDate: Fri, 6 Aug 2021 11:00:31 GMT") contentTypeStr, dateStr, hErr := NextLine(multiHeaderStr) - if hErr != nil { - t.Fatalf("Unexpected error: %s", hErr) - } - if string(contentTypeStr) != "Content-Type: application/x-www-form-urlencoded" { - t.Fatalf("Unexpected %s", string(contentTypeStr)) - } - if string(dateStr) != "Date: Fri, 6 Aug 2021 11:00:31 GMT" { - t.Fatalf("Unexpected %s", string(contentTypeStr)) - } + assert.DeepEqual(t, nil, hErr) + assert.DeepEqual(t, "Content-Type: application/x-www-form-urlencoded", string(contentTypeStr)) + assert.DeepEqual(t, "Date: Fri, 6 Aug 2021 11:00:31 GMT", string(dateStr)) + + multiHeaderStrWithoutReturn := []byte("Content-Type: application/x-www-form-urlencoded\nDate: Fri, 6 Aug 2021 11:00:31 GMT") + contentTypeStr, dateStr, hErr = NextLine(multiHeaderStrWithoutReturn) + assert.DeepEqual(t, nil, hErr) + assert.DeepEqual(t, "Content-Type: application/x-www-form-urlencoded", string(contentTypeStr)) + assert.DeepEqual(t, "Date: Fri, 6 Aug 2021 11:00:31 GMT", string(dateStr)) + + singleHeaderStrWithFirstNewLine := []byte("\nContent-Type: application/x-www-form-urlencoded") + firstStr, secondStr, sErr := NextLine(singleHeaderStrWithFirstNewLine) + assert.DeepEqual(t, nil, sErr) + assert.DeepEqual(t, string(""), string(firstStr)) + assert.DeepEqual(t, "Content-Type: application/x-www-form-urlencoded", string(secondStr)) singleHeaderStr := []byte("Content-Type: application/x-www-form-urlencoded") - firstStr, secondStr, sErr := NextLine(singleHeaderStr) - if sErr == nil { - t.Fatalf("Unexpected nil. Expecting an error: ErrNeedMore") - } - if firstStr != nil || secondStr != nil { - t.Fatalf("Unexpected string. Expecting: nil") - } + _, _, sErr = NextLine(singleHeaderStr) + assert.DeepEqual(t, errNeedMore, sErr) } From ec2ec4991db9f05df5195fa0a59a4ea8bcef17f8 Mon Sep 17 00:00:00 2001 From: raymonder jin Date: Tue, 22 Nov 2022 17:55:04 +0800 Subject: [PATCH 02/12] test: add more tests for pkg/app/middlewares/server/basic_auth (#405) --- .../server/basic_auth/basic_auth_test.go | 33 +++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/pkg/app/middlewares/server/basic_auth/basic_auth_test.go b/pkg/app/middlewares/server/basic_auth/basic_auth_test.go index 013e3197a..5b6d6aad7 100644 --- a/pkg/app/middlewares/server/basic_auth/basic_auth_test.go +++ b/pkg/app/middlewares/server/basic_auth/basic_auth_test.go @@ -41,8 +41,12 @@ package basic_auth import ( + "context" + "encoding/base64" "testing" + "github.com/cloudwego/hertz/internal/bytesconv" + "github.com/cloudwego/hertz/pkg/app" "github.com/cloudwego/hertz/pkg/common/test/assert" ) @@ -63,3 +67,32 @@ func TestPairs(t *testing.T) { assert.False(t, ok3) assert.False(t, ok4) } + +func TestBasicAuth(t *testing.T) { + userName1 := "user1" + password1 := "value1" + userName2 := "user2" + password2 := "value2" + + c1 := app.RequestContext{} + encodeStr := "Basic " + base64.StdEncoding.EncodeToString(bytesconv.S2b(userName1+":"+password1)) + c1.Request.Header.Add("Authorization", encodeStr) + + t1 := Accounts{userName1: password1} + handler := BasicAuth(t1) + handler(context.TODO(), &c1) + + user, ok := c1.Get("user") + assert.DeepEqual(t, userName1, user) + assert.True(t, ok) + + c2 := app.RequestContext{} + encodeStr = "Basic " + base64.StdEncoding.EncodeToString(bytesconv.S2b(userName2+":"+password2)) + c2.Request.Header.Add("Authorization", encodeStr) + + handler(context.TODO(), &c2) + + user, ok = c2.Get("user") + assert.Nil(t, user) + assert.False(t, ok) +} From d89206f0ee478f7b67113076325ce71e890a1e5b Mon Sep 17 00:00:00 2001 From: gityh2021 <85598202+gityh2021@users.noreply.github.com> Date: Thu, 24 Nov 2022 19:00:24 +0800 Subject: [PATCH 03/12] test: add new test cases for pkg/protocol/response.go (#374) --- pkg/protocol/response_test.go | 55 +++++++++++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) diff --git a/pkg/protocol/response_test.go b/pkg/protocol/response_test.go index 94cc6c15d..7195eec9e 100644 --- a/pkg/protocol/response_test.go +++ b/pkg/protocol/response_test.go @@ -48,6 +48,8 @@ import ( "reflect" "testing" + "github.com/cloudwego/hertz/pkg/common/bytebufferpool" + "github.com/cloudwego/hertz/pkg/common/compress" "github.com/cloudwego/hertz/pkg/common/test/assert" ) @@ -174,3 +176,56 @@ func testResponseCopyTo(t *testing.T, src *Response) { t.Fatalf("ResponseCopyTo fail, src: \n%+v\ndst: \n%+v\n", src, &dst) //nolint:govet } } + +func TestResponseMustSkipBody(t *testing.T) { + resp := Response{} + resp.SetStatusCode(200) + resp.SetBodyString("test") + assert.False(t, resp.MustSkipBody()) + // no content 204 means that skip body is necessary + resp.SetStatusCode(204) + resp.ResetBody() + assert.True(t, resp.MustSkipBody()) +} + +func TestResponseBodyGunzip(t *testing.T) { + t.Parallel() + dst1 := []byte("") + src1 := []byte("hello") + res1 := compress.AppendGzipBytes(dst1, src1) + resp := Response{} + resp.SetBody(res1) + zipData, err := resp.BodyGunzip() + assert.Nil(t, err) + assert.DeepEqual(t, zipData, src1) +} + +func TestResponseSwapResponseBody(t *testing.T) { + t.Parallel() + resp1 := Response{} + str1 := "resp1" + byteBuffer1 := &bytebufferpool.ByteBuffer{} + byteBuffer1.Set([]byte(str1)) + resp1.ConstructBodyStream(byteBuffer1, bytes.NewBufferString(str1)) + assert.True(t, resp1.HasBodyBytes()) + resp2 := Response{} + str2 := "resp2" + byteBuffer2 := &bytebufferpool.ByteBuffer{} + byteBuffer2.Set([]byte(str2)) + resp2.ConstructBodyStream(byteBuffer2, bytes.NewBufferString(str2)) + SwapResponseBody(&resp1, &resp2) + assert.DeepEqual(t, resp1.body.B, []byte(str2)) + assert.DeepEqual(t, resp1.BodyStream(), bytes.NewBufferString(str2)) + assert.DeepEqual(t, resp2.body.B, []byte(str1)) + assert.DeepEqual(t, resp2.BodyStream(), bytes.NewBufferString(str1)) +} + +func TestResponseAcquireResponse(t *testing.T) { + t.Parallel() + resp1 := AcquireResponse() + assert.NotNil(t, resp1) + resp1.SetBody([]byte("test")) + resp1.SetStatusCode(200) + ReleaseResponse(resp1) + assert.Nil(t, resp1.body) +} From d39a8186de8aaaabc10d1bd9eac4fedc7c30cd29 Mon Sep 17 00:00:00 2001 From: raymonder jin Date: Fri, 25 Nov 2022 12:11:27 +0800 Subject: [PATCH 04/12] test: add tests for pkg/app/client/discovery/discovery.go (#417) --- pkg/app/client/discovery/discovery_test.go | 73 ++++++++++++++++++++++ 1 file changed, 73 insertions(+) create mode 100644 pkg/app/client/discovery/discovery_test.go diff --git a/pkg/app/client/discovery/discovery_test.go b/pkg/app/client/discovery/discovery_test.go new file mode 100644 index 000000000..50242c6cc --- /dev/null +++ b/pkg/app/client/discovery/discovery_test.go @@ -0,0 +1,73 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package discovery + +import ( + "context" + "testing" + + "github.com/cloudwego/hertz/pkg/app/server/registry" + "github.com/cloudwego/hertz/pkg/common/test/assert" +) + +func TestInstance(t *testing.T) { + network := "192.168.1.1" + address := "/hello" + weight := 1 + instance := NewInstance(network, address, weight, nil) + + assert.DeepEqual(t, network, instance.Address().Network()) + assert.DeepEqual(t, address, instance.Address().String()) + assert.DeepEqual(t, weight, instance.Weight()) + val, ok := instance.Tag("name") + assert.DeepEqual(t, "", val) + assert.False(t, ok) + + instance2 := NewInstance("", "", 0, nil) + assert.DeepEqual(t, registry.DefaultWeight, instance2.Weight()) +} + +func TestSynthesizedResolver(t *testing.T) { + targetFunc := func(ctx context.Context, target *TargetInfo) string { + return "hello" + } + resolveFunc := func(ctx context.Context, key string) (Result, error) { + return Result{CacheKey: "name"}, nil + } + nameFunc := func() string { + return "raymonder" + } + resolver := SynthesizedResolver{ + TargetFunc: targetFunc, + ResolveFunc: resolveFunc, + NameFunc: nameFunc, + } + + assert.DeepEqual(t, "hello", resolver.Target(context.Background(), &TargetInfo{})) + res, err := resolver.Resolve(context.Background(), "") + assert.DeepEqual(t, "name", res.CacheKey) + assert.Nil(t, err) + assert.DeepEqual(t, "raymonder", resolver.Name()) + + resolver2 := SynthesizedResolver{ + TargetFunc: nil, + ResolveFunc: nil, + NameFunc: nil, + } + assert.DeepEqual(t, "", resolver2.Target(context.Background(), &TargetInfo{})) + assert.DeepEqual(t, "", resolver2.Name()) +} From f4b9e2b17dd345c086c0bf0af6c744df5fb6f8f5 Mon Sep 17 00:00:00 2001 From: Xuran <37136584+Duslia@users.noreply.github.com> Date: Wed, 30 Nov 2022 00:32:37 +0800 Subject: [PATCH 05/12] feat: Adapt the logic after moving the http2 logic to hertz-contrib (#418) --- pkg/route/engine.go | 9 --------- 1 file changed, 9 deletions(-) diff --git a/pkg/route/engine.go b/pkg/route/engine.go index 94d6d22ea..0096eb5e9 100644 --- a/pkg/route/engine.go +++ b/pkg/route/engine.go @@ -353,10 +353,6 @@ func (engine *Engine) Run() (err error) { } func (engine *Engine) Init() error { - if !h2Enable(engine.options) { - engine.protocolSuite.Delete(suite.HTTP2) - } - // add built-in http1 server by default if !engine.HasServer(suite.HTTP1) { engine.AddProtocol(suite.HTTP1, factory.NewServerFactory(newHttp1OptionFromEngine(engine))) @@ -371,7 +367,6 @@ func (engine *Engine) Init() error { if engine.alpnEnable() { engine.options.TLS.NextProtos = append(engine.options.TLS.NextProtos, suite.HTTP1) - engine.options.TLS.NextProtos = append(engine.options.TLS.NextProtos, suite.HTTP2) } if !atomic.CompareAndSwapUint32(&engine.status, 0, statusInitialized) { @@ -584,10 +579,6 @@ func initTrace(engine *Engine) stats.Level { return traceLevel } -func h2Enable(opt *config.Options) bool { - return opt.H2C || (opt.TLS != nil && opt.ALPN) -} - func debugPrintRoute(httpMethod, absolutePath string, handlers app.HandlersChain) { nuHandlers := len(handlers) handlerName := app.GetHandlerName(handlers.Last()) From 5afd02cdbeb770aac55cd2fe62092b95402a7f91 Mon Sep 17 00:00:00 2001 From: LanLanceYuan <92938836+L2ncE@users.noreply.github.com> Date: Thu, 1 Dec 2022 11:27:29 +0800 Subject: [PATCH 06/12] docs(README): add csrf and loadbalance description (#438) --- README.md | 52 +++++++++++++++++++++++++++------------------------- README_cn.md | 52 +++++++++++++++++++++++++++------------------------- 2 files changed, 54 insertions(+), 50 deletions(-) diff --git a/README.md b/README.md index 554319cf4..ab5c28c51 100644 --- a/README.md +++ b/README.md @@ -58,32 +58,34 @@ Hertz [həːts] is a high-usability, high-performance and high-extensibility Gol - [Example](https://github.com/cloudwego/hertz-examples): Use examples of Hertz. ## Extensions -| Extensions | Description | -|----------------------------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| [Websocket](https://github.com/hertz-contrib/websocket) | Enable Hertz to support the Websocket protocol. | -| [Pprof](https://github.com/hertz-contrib/pprof) | Extension for Hertz integration with Pprof. | -| [Sessions](https://github.com/hertz-contrib/sessions) | Session middleware with multi-state store support. | -| [Obs-opentelemetry](https://github.com/hertz-contrib/obs-opentelemetry) | Hertz's Opentelemetry extension that supports Metric, Logger, Tracing and works out of the box. | +| Extensions | Description | +|----------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| [Websocket](https://github.com/hertz-contrib/websocket) | Enable Hertz to support the Websocket protocol. | +| [Pprof](https://github.com/hertz-contrib/pprof) | Extension for Hertz integration with Pprof. | +| [Sessions](https://github.com/hertz-contrib/sessions) | Session middleware with multi-state store support. | +| [Obs-opentelemetry](https://github.com/hertz-contrib/obs-opentelemetry) | Hertz's Opentelemetry extension that supports Metric, Logger, Tracing and works out of the box. | | [Registry](https://github.com/hertz-contrib/registry) | Provides service registry and discovery functions. So far, the supported service discovery extensions are nacos, consul, etcd, eureka, polaris, servicecomb, zookeeper, redis. | -| [Keyauth](https://github.com/hertz-contrib/keyauth) | Provides token-based authentication. | -| [Secure](https://github.com/hertz-contrib/secure) | Secure middleware with multiple configuration items. | -| [Sentry](https://github.com/hertz-contrib/hertzsentry) | Sentry extension provides some unified interfaces to help users perform real-time error monitoring. | -| [Requestid](https://github.com/hertz-contrib/requestid) | Add request id in response. | -| [Limiter](https://github.com/hertz-contrib/limiter) | Provides a current limiter based on the bbr algorithm. | -| [Jwt](https://github.com/hertz-contrib/jwt) | Jwt extension. | -| [Autotls](https://github.com/hertz-contrib/autotls) | Make Hertz support Let's Encrypt. | -| [Monitor-prometheus](https://github.com/hertz-contrib/monitor-prometheus) | Provides service monitoring based on Prometheus. | -| [I18n](https://github.com/hertz-contrib/i18n) | Helps translate Hertz programs into multi programming languages. | -| [Reverseproxy](https://github.com/hertz-contrib/reverseproxy) | Implement a reverse proxy. | -| [Opensergo](https://github.com/hertz-contrib/opensergo) | The Opensergo extension. | -| [Gzip](https://github.com/hertz-contrib/gzip) | A Gzip extension with multiple options. | -| [Cors](https://github.com/hertz-contrib/cors) | Provides cross-domain resource sharing support. | -| [Swagger](https://github.com/hertz-contrib/swagger) | Automatically generate RESTful API documentation with Swagger 2.0. | -| [Tracer](https://github.com/hertz-contrib/tracer) | Link tracing based on Opentracing. | -| [Recovery](https://github.com/cloudwego/hertz/tree/develop/pkg/app/middlewares/server/recovery) | Recovery middleware for Hertz. | -| [Basicauth](https://github.com/cloudwego/hertz/tree/develop/pkg/app/middlewares/server/basic_auth) | Basicauth middleware can provide HTTP basic authentication. | -| [Lark](https://github.com/hertz-contrib/lark-hertz) | Use hertz handle Lark/Feishu card message and event callback. | -| [Logger](https://github.com/hertz-contrib/logger) | Logger extension for Hertz, which provides support for zap, logrus, zerologs logging frameworks. | +| [Keyauth](https://github.com/hertz-contrib/keyauth) | Provides token-based authentication. | +| [Secure](https://github.com/hertz-contrib/secure) | Secure middleware with multiple configuration items. | +| [Sentry](https://github.com/hertz-contrib/hertzsentry) | Sentry extension provides some unified interfaces to help users perform real-time error monitoring. | +| [Requestid](https://github.com/hertz-contrib/requestid) | Add request id in response. | +| [Limiter](https://github.com/hertz-contrib/limiter) | Provides a current limiter based on the bbr algorithm. | +| [Jwt](https://github.com/hertz-contrib/jwt) | Jwt extension. | +| [Autotls](https://github.com/hertz-contrib/autotls) | Make Hertz support Let's Encrypt. | +| [Monitor-prometheus](https://github.com/hertz-contrib/monitor-prometheus) | Provides service monitoring based on Prometheus. | +| [I18n](https://github.com/hertz-contrib/i18n) | Helps translate Hertz programs into multi programming languages. | +| [Reverseproxy](https://github.com/hertz-contrib/reverseproxy) | Implement a reverse proxy. | +| [Opensergo](https://github.com/hertz-contrib/opensergo) | The Opensergo extension. | +| [Gzip](https://github.com/hertz-contrib/gzip) | A Gzip extension with multiple options. | +| [Cors](https://github.com/hertz-contrib/cors) | Provides cross-domain resource sharing support. | +| [Swagger](https://github.com/hertz-contrib/swagger) | Automatically generate RESTful API documentation with Swagger 2.0. | +| [Tracer](https://github.com/hertz-contrib/tracer) | Link tracing based on Opentracing. | +| [Recovery](https://github.com/cloudwego/hertz/tree/develop/pkg/app/middlewares/server/recovery) | Recovery middleware for Hertz. | +| [Basicauth](https://github.com/cloudwego/hertz/tree/develop/pkg/app/middlewares/server/basic_auth) | Basicauth middleware can provide HTTP basic authentication. | +| [Lark](https://github.com/hertz-contrib/lark-hertz) | Use hertz handle Lark/Feishu card message and event callback. | +| [Logger](https://github.com/hertz-contrib/logger) | Logger extension for Hertz, which provides support for zap, logrus, zerologs logging frameworks. | +| [Csrf](https://github.com/hertz-contrib/csrf) | Csrf middleware is used to prevent cross-site request forgery attacks. | +| [Loadbalance](https://github.com/hertz-contrib/loadbalance) | Provides load balancing algorithms for Hertz. | ## Blogs - [ByteDance Practice on Go Network Library](https://www.cloudwego.io/blog/2021/10/09/bytedance-practices-on-go-network-library/) diff --git a/README_cn.md b/README_cn.md index 942b74341..be60589be 100644 --- a/README_cn.md +++ b/README_cn.md @@ -58,32 +58,34 @@ Hertz[həːts] 是一个 Golang 微服务 HTTP 框架,在设计之初参考了 - [Example](https://github.com/cloudwego/hertz-examples): Hertz 使用例子 ## 相关拓展 -| 拓展 | 描述 | -|----------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------| -| [Websocket](https://github.com/hertz-contrib/websocket) | 使 Hertz 支持 Websocket 协议。 | -| [Pprof](https://github.com/hertz-contrib/pprof) | Hertz 集成 Pprof 的扩展。 | -| [Sessions](https://github.com/hertz-contrib/sessions) | 具有多状态存储支持的 Session 中间件。 | -| [Obs-opentelemetry](https://github.com/hertz-contrib/obs-opentelemetry) | Hertz 的 Opentelemetry 扩展,支持 Metric、Logger、Tracing并且达到开箱即用。 | +| 拓展 | 描述 | +|----------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------| +| [Websocket](https://github.com/hertz-contrib/websocket) | 使 Hertz 支持 Websocket 协议。 | +| [Pprof](https://github.com/hertz-contrib/pprof) | Hertz 集成 Pprof 的扩展。 | +| [Sessions](https://github.com/hertz-contrib/sessions) | 具有多状态存储支持的 Session 中间件。 | +| [Obs-opentelemetry](https://github.com/hertz-contrib/obs-opentelemetry) | Hertz 的 Opentelemetry 扩展,支持 Metric、Logger、Tracing并且达到开箱即用。 | | [Registry](https://github.com/hertz-contrib/registry) | 提供服务注册与发现功能。到现在为止,支持的服务发现拓展有 nacos, consul, etcd, eureka, polaris, servicecomb, zookeeper, redis。 | -| [Keyauth](https://github.com/hertz-contrib/keyauth) | 提供基于 token 的身份验证。 | -| [Secure](https://github.com/hertz-contrib/secure) | 具有多配置项的 Secure 中间件。 | -| [Sentry](https://github.com/hertz-contrib/hertzsentry) | Sentry 拓展提供了一些统一的接口来帮助用户进行实时的错误监控。 | -| [Requestid](https://github.com/hertz-contrib/requestid) | 在 response 中添加 request id。 | -| [Limiter](https://github.com/hertz-contrib/limiter) | 提供了基于 bbr 算法的限流器。 | -| [Jwt](https://github.com/hertz-contrib/jwt) | Jwt 拓展。 | -| [Autotls](https://github.com/hertz-contrib/autotls) | 为 Hertz 支持 Let's Encrypt 。 | -| [Monitor-prometheus](https://github.com/hertz-contrib/monitor-prometheus) | 提供基于 Prometheus 服务监控功能。 | -| [I18n](https://github.com/hertz-contrib/i18n) | 可帮助将 Hertz 程序翻译成多种语言。 | -| [Reverseproxy](https://github.com/hertz-contrib/reverseproxy) | 实现反向代理。 | -| [Opensergo](https://github.com/hertz-contrib/opensergo) | Opensergo 扩展。 | -| [Gzip](https://github.com/hertz-contrib/gzip) | 含多个可选项的 Gzip 拓展。 | -| [Cors](https://github.com/hertz-contrib/cors) | 提供跨域资源共享支持。 | -| [Swagger](https://github.com/hertz-contrib/swagger) | 使用 Swagger 2.0 自动生成 RESTful API 文档。 | -| [Tracer](https://github.com/hertz-contrib/tracer) | 基于 Opentracing 的链路追踪。 | -| [Recovery](https://github.com/cloudwego/hertz/tree/develop/pkg/app/middlewares/server/recovery) | Hertz 的异常恢复中间件。 | -| [Basicauth](https://github.com/cloudwego/hertz/tree/develop/pkg/app/middlewares/server/basic_auth) | Basicauth 中间件能够提供 HTTP 基本身份验证。 | -| [Lark](https://github.com/hertz-contrib/lark-hertz) | 在 Hertz 中处理 Lark/飞书的卡片消息和事件的回调。 | -| [Logger](https://github.com/hertz-contrib/logger) | Hertz 的日志拓展,提供了对 zap、logrus、zerologs 日志框架的支持。 | +| [Keyauth](https://github.com/hertz-contrib/keyauth) | 提供基于 token 的身份验证。 | +| [Secure](https://github.com/hertz-contrib/secure) | 具有多配置项的 Secure 中间件。 | +| [Sentry](https://github.com/hertz-contrib/hertzsentry) | Sentry 拓展提供了一些统一的接口来帮助用户进行实时的错误监控。 | +| [Requestid](https://github.com/hertz-contrib/requestid) | 在 response 中添加 request id。 | +| [Limiter](https://github.com/hertz-contrib/limiter) | 提供了基于 bbr 算法的限流器。 | +| [Jwt](https://github.com/hertz-contrib/jwt) | Jwt 拓展。 | +| [Autotls](https://github.com/hertz-contrib/autotls) | 为 Hertz 支持 Let's Encrypt 。 | +| [Monitor-prometheus](https://github.com/hertz-contrib/monitor-prometheus) | 提供基于 Prometheus 服务监控功能。 | +| [I18n](https://github.com/hertz-contrib/i18n) | 可帮助将 Hertz 程序翻译成多种语言。 | +| [Reverseproxy](https://github.com/hertz-contrib/reverseproxy) | 实现反向代理。 | +| [Opensergo](https://github.com/hertz-contrib/opensergo) | Opensergo 扩展。 | +| [Gzip](https://github.com/hertz-contrib/gzip) | 含多个可选项的 Gzip 拓展。 | +| [Cors](https://github.com/hertz-contrib/cors) | 提供跨域资源共享支持。 | +| [Swagger](https://github.com/hertz-contrib/swagger) | 使用 Swagger 2.0 自动生成 RESTful API 文档。 | +| [Tracer](https://github.com/hertz-contrib/tracer) | 基于 Opentracing 的链路追踪。 | +| [Recovery](https://github.com/cloudwego/hertz/tree/develop/pkg/app/middlewares/server/recovery) | Hertz 的异常恢复中间件。 | +| [Basicauth](https://github.com/cloudwego/hertz/tree/develop/pkg/app/middlewares/server/basic_auth) | Basicauth 中间件能够提供 HTTP 基本身份验证。 | +| [Lark](https://github.com/hertz-contrib/lark-hertz) | 在 Hertz 中处理 Lark/飞书的卡片消息和事件的回调。 | +| [Logger](https://github.com/hertz-contrib/logger) | Hertz 的日志拓展,提供了对 zap、logrus、zerologs 日志框架的支持。 | +| [Csrf](https://github.com/hertz-contrib/csrf) | Csrf 中间件用于防止跨站点请求伪造攻击。 | +| [Loadbalance](https://github.com/hertz-contrib/loadbalance) | 提供适用于 Hertz 的负载均衡算法。 | ## 相关文章 - [字节跳动在 Go 网络库上的实践](https://www.cloudwego.io/blog/2021/10/09/bytedance-practices-on-go-network-library/) From a12c7a0110a1cc1277a25002ee8801a0593703f5 Mon Sep 17 00:00:00 2001 From: Xuran <37136584+Duslia@users.noreply.github.com> Date: Thu, 1 Dec 2022 12:40:29 +0800 Subject: [PATCH 07/12] feat: add `resp.SetBodyStreamNoReset` method (#440) --- pkg/protocol/response.go | 7 +++++++ pkg/protocol/response_test.go | 27 +++++++++++++++++++++++++++ 2 files changed, 34 insertions(+) diff --git a/pkg/protocol/response.go b/pkg/protocol/response.go index a170ffdc6..f46395b09 100644 --- a/pkg/protocol/response.go +++ b/pkg/protocol/response.go @@ -193,6 +193,13 @@ func (resp *Response) SetBodyStream(bodyStream io.Reader, bodySize int) { resp.Header.SetContentLength(bodySize) } +// SetBodyStreamNoReset is almost the same as SetBodyStream, +// but it doesn't reset the bodyStream before. +func (resp *Response) SetBodyStreamNoReset(bodyStream io.Reader, bodySize int) { + resp.bodyStream = bodyStream + resp.Header.SetContentLength(bodySize) +} + // BodyE returns response body. func (resp *Response) BodyE() ([]byte, error) { if resp.bodyStream != nil { diff --git a/pkg/protocol/response_test.go b/pkg/protocol/response_test.go index 7195eec9e..4f2f56e89 100644 --- a/pkg/protocol/response_test.go +++ b/pkg/protocol/response_test.go @@ -229,3 +229,30 @@ func TestResponseAcquireResponse(t *testing.T) { ReleaseResponse(resp1) assert.Nil(t, resp1.body) } + +type closeBuffer struct { + *bytes.Buffer +} + +func (b *closeBuffer) Close() error { + b.Reset() + return nil +} + +func TestSetBodyStreamNoReset(t *testing.T) { + t.Parallel() + resp := Response{} + bsA := &closeBuffer{bytes.NewBufferString("A")} + bsB := &closeBuffer{bytes.NewBufferString("B")} + bsC := &closeBuffer{bytes.NewBufferString("C")} + + resp.SetBodyStream(bsA, 1) + resp.SetBodyStreamNoReset(bsB, 1) + // resp.Body() has closed bsB + assert.DeepEqual(t, string(resp.Body()), "B") + assert.DeepEqual(t, bsA.String(), "A") + + resp.bodyStream = bsA + resp.SetBodyStream(bsC, 1) + assert.DeepEqual(t, bsA.String(), "") +} From d7097925a232aeaf7a04c277152c8a42dcdd0c47 Mon Sep 17 00:00:00 2001 From: raymonder jin Date: Mon, 5 Dec 2022 15:27:30 +0800 Subject: [PATCH 08/12] test: add tests for pkg/app/client/retry/retry.go (#419) --- pkg/app/client/retry/retry_test.go | 87 ++++++++++++++++++++++++++++++ 1 file changed, 87 insertions(+) create mode 100644 pkg/app/client/retry/retry_test.go diff --git a/pkg/app/client/retry/retry_test.go b/pkg/app/client/retry/retry_test.go new file mode 100644 index 000000000..c77c4f088 --- /dev/null +++ b/pkg/app/client/retry/retry_test.go @@ -0,0 +1,87 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package retry + +import ( + "math" + "testing" + "time" + + "github.com/cloudwego/hertz/pkg/common/test/assert" +) + +func TestApply(t *testing.T) { + delayPolicyFunc := func(attempts uint, err error, retryConfig *Config) time.Duration { + return time.Second + } + options := []Option{} + options = append(options, WithMaxAttemptTimes(100), WithInitDelay(time.Second), + WithMaxDelay(time.Second), WithDelayPolicy(delayPolicyFunc), WithMaxJitter(time.Second)) + + config := Config{} + config.Apply(options) + + assert.DeepEqual(t, uint(100), config.MaxAttemptTimes) + assert.DeepEqual(t, time.Second, config.Delay) + assert.DeepEqual(t, time.Second, config.MaxDelay) + assert.DeepEqual(t, time.Second, Delay(0, nil, &config)) + assert.DeepEqual(t, time.Second, config.MaxJitter) +} + +func TestPolicy(t *testing.T) { + dur := DefaultDelayPolicy(0, nil, nil) + assert.DeepEqual(t, 0*time.Millisecond, dur) + + config := Config{ + Delay: time.Second, + } + dur = FixedDelayPolicy(0, nil, &config) + assert.DeepEqual(t, time.Second, dur) + + dur = RandomDelayPolicy(0, nil, &config) + assert.DeepEqual(t, 0*time.Millisecond, dur) + config.MaxJitter = time.Second * 1 + dur = RandomDelayPolicy(0, nil, &config) + assert.NotEqual(t, time.Second*1, dur) + + dur = BackOffDelayPolicy(0, nil, &config) + assert.DeepEqual(t, time.Second*1, dur) + config.Delay = time.Duration(-1) + dur = BackOffDelayPolicy(0, nil, &config) + assert.DeepEqual(t, time.Second*0, dur) + config.Delay = time.Duration(1) + dur = BackOffDelayPolicy(63, nil, &config) + durExp := config.Delay << 62 + assert.DeepEqual(t, durExp, dur) + + dur = Delay(0, nil, &config) + assert.DeepEqual(t, 0*time.Millisecond, dur) + delayPolicyFunc := func(attempts uint, err error, retryConfig *Config) time.Duration { + return time.Second + } + config.DelayPolicy = delayPolicyFunc + config.MaxDelay = time.Second / 2 + dur = Delay(0, nil, &config) + assert.DeepEqual(t, config.MaxDelay, dur) + + delayPolicyFunc2 := func(attempts uint, err error, retryConfig *Config) time.Duration { + return time.Duration(math.MaxInt64) + } + delayFunc := CombineDelay(delayPolicyFunc2, delayPolicyFunc) + dur = delayFunc(0, nil, &config) + assert.DeepEqual(t, time.Duration(math.MaxInt64), dur) +} From 9b9ec92704e19626225f6fc10d69a62f1762dbda Mon Sep 17 00:00:00 2001 From: Xuran <37136584+Duslia@users.noreply.github.com> Date: Mon, 5 Dec 2022 15:32:13 +0800 Subject: [PATCH 09/12] fix: `doRequestFollowRedirectsBuffer` cannot get body in HTTP2 scenario (#421) --- pkg/app/client/client_test.go | 34 ++++++++++++++++++++++++++++++++++ pkg/protocol/client/client.go | 4 +++- 2 files changed, 37 insertions(+), 1 deletion(-) diff --git a/pkg/app/client/client_test.go b/pkg/app/client/client_test.go index 8c86aff30..e325a0e14 100644 --- a/pkg/app/client/client_test.go +++ b/pkg/app/client/client_test.go @@ -68,6 +68,7 @@ import ( "github.com/cloudwego/hertz/pkg/app/client/retry" "github.com/cloudwego/hertz/pkg/common/config" errs "github.com/cloudwego/hertz/pkg/common/errors" + "github.com/cloudwego/hertz/pkg/common/test/assert" "github.com/cloudwego/hertz/pkg/network" "github.com/cloudwego/hertz/pkg/network/dialer" "github.com/cloudwego/hertz/pkg/network/netpoll" @@ -193,6 +194,39 @@ func TestClientGetWithBody(t *testing.T) { } } +func TestClientPostBodyStream(t *testing.T) { + t.Parallel() + + opt := config.NewOptions([]config.Option{}) + opt.Addr = "unix-test-10102" + opt.Network = "unix" + engine := route.NewEngine(opt) + engine.POST("/", func(c context.Context, ctx *app.RequestContext) { + body := ctx.Request.Body() + ctx.Write(body) //nolint:errcheck + }) + go engine.Run() + defer func() { + engine.Close() + }() + time.Sleep(time.Millisecond * 500) + + cStream, _ := NewClient(WithDialer(newMockDialerWithCustomFunc(opt.Network, opt.Addr, 1*time.Second, nil)), WithResponseBodyStream(true)) + args := &protocol.Args{} + // There is some data in databuf and others is in bodystream, so we need + // to let the data exceed the max bodysize of bodystream + v := "" + for i := 0; i < 10240; i++ { + v += "b" + } + args.Add("a", v) + _, body, err := cStream.Post(context.Background(), nil, "http://example.com", args) + if err != nil { + t.Fatal(err) + } + assert.DeepEqual(t, "a="+v, string(body)) +} + func TestClientURLAuth(t *testing.T) { t.Parallel() diff --git a/pkg/protocol/client/client.go b/pkg/protocol/client/client.go index 33d7c3906..e770547b7 100644 --- a/pkg/protocol/client/client.go +++ b/pkg/protocol/client/client.go @@ -220,7 +220,9 @@ func doRequestFollowRedirectsBuffer(ctx context.Context, req *protocol.Request, statusCode, _, err = DoRequestFollowRedirects(ctx, req, resp, url, defaultMaxRedirectsCount, c) - body = bodyBuf.B + // In HTTP2 scenario, client use stream mode to create a request and its body is in body stream. + // In HTTP1, only client recv body exceed max body size and client is in stream mode can trig it. + body = resp.Body() bodyBuf.B = oldBody protocol.ReleaseResponse(resp) From a470f4abaec9b073aa9688a628992f3be472e911 Mon Sep 17 00:00:00 2001 From: Wenju Gao Date: Wed, 7 Dec 2022 11:17:59 +0800 Subject: [PATCH 10/12] optimize(http1): return 413 status code if request body is too large (#430) --- pkg/app/server/hertz_test.go | 8 ++++++-- pkg/protocol/http1/server.go | 2 ++ 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/pkg/app/server/hertz_test.go b/pkg/app/server/hertz_test.go index 80ddc0c9b..f15db14e6 100644 --- a/pkg/app/server/hertz_test.go +++ b/pkg/app/server/hertz_test.go @@ -23,6 +23,7 @@ import ( "fmt" "html/template" "io" + "io/ioutil" "net" "net/http" "os" @@ -379,8 +380,11 @@ func TestNotEnoughBodySize(t *testing.T) { r.ParseForm() r.Form.Add("xxxxxx", "xxx") body := strings.NewReader(r.Form.Encode()) - resp, _ := http.Post("http://127.0.0.1:8889/test", "application/x-www-form-urlencoded", body) - assert.DeepEqual(t, 400, resp.StatusCode) + resp, err := http.Post("http://127.0.0.1:8889/test", "application/x-www-form-urlencoded", body) + assert.Nil(t, err) + assert.DeepEqual(t, 413, resp.StatusCode) + bodyBytes, _ := ioutil.ReadAll(resp.Body) + assert.DeepEqual(t, "Request Entity Too Large", string(bodyBytes)) } func TestEnoughBodySize(t *testing.T) { diff --git a/pkg/protocol/http1/server.go b/pkg/protocol/http1/server.go index 16a65b54a..6ff8d7dd6 100644 --- a/pkg/protocol/http1/server.go +++ b/pkg/protocol/http1/server.go @@ -374,6 +374,8 @@ func writeResponse(ctx *app.RequestContext, w network.Writer) error { func defaultErrorHandler(ctx *app.RequestContext, err error) { if netErr, ok := err.(*net.OpError); ok && netErr.Timeout() { ctx.AbortWithMsg("Request timeout", consts.StatusRequestTimeout) + } else if errors.Is(err, errs.ErrBodyTooLarge) { + ctx.AbortWithMsg("Request Entity Too Large", consts.StatusRequestEntityTooLarge) } else { ctx.AbortWithMsg("Error when parsing request", consts.StatusBadRequest) } From 0329357bf1d810d4e6b7cddb223b81eca9e92c01 Mon Sep 17 00:00:00 2001 From: Wenju Gao Date: Wed, 7 Dec 2022 11:42:00 +0800 Subject: [PATCH 11/12] optimize: ignore sighup signal if binary is run by nohup (#441) --- pkg/app/server/hertz.go | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/pkg/app/server/hertz.go b/pkg/app/server/hertz.go index 651b1ede4..72a19c831 100644 --- a/pkg/app/server/hertz.go +++ b/pkg/app/server/hertz.go @@ -95,8 +95,13 @@ func (h *Hertz) SetCustomSignalWaiter(f func(err chan error) error) { // SIGTERM triggers immediately close. // SIGHUP|SIGINT triggers graceful shutdown. func waitSignal(errCh chan error) error { + signalToNotify := []os.Signal{syscall.SIGINT, syscall.SIGHUP, syscall.SIGTERM} + if signal.Ignored(syscall.SIGHUP) { + signalToNotify = []os.Signal{syscall.SIGINT, syscall.SIGTERM} + } + signals := make(chan os.Signal, 1) - signal.Notify(signals, syscall.SIGINT, syscall.SIGHUP, syscall.SIGTERM) + signal.Notify(signals, signalToNotify...) select { case sig := <-signals: @@ -105,6 +110,7 @@ func waitSignal(errCh chan error) error { // force exit return errors.New(sig.String()) // nolint case syscall.SIGHUP, syscall.SIGINT: + hlog.SystemLogger().Infof("Received signal: %s\n", sig) // graceful shutdown return nil } From a2dce784a10b1e830a20018f5dfddb4aec76e648 Mon Sep 17 00:00:00 2001 From: alice <90381261+alice-yyds@users.noreply.github.com> Date: Wed, 7 Dec 2022 11:48:27 +0800 Subject: [PATCH 12/12] chore: update version v0.4.2 --- version.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/version.go b/version.go index 006b989dc..2f877ce6b 100644 --- a/version.go +++ b/version.go @@ -19,5 +19,5 @@ package hertz // Name and Version info of this framework, used for statistics and debug const ( Name = "Hertz" - Version = "v0.4.1" + Version = "v0.4.2" )