From 4663d839345edb28a5d61f8f1564236ec3498067 Mon Sep 17 00:00:00 2001 From: MP Droog Date: Mon, 15 Feb 2021 20:15:33 +0100 Subject: [PATCH 1/3] client: simplfy idPool to prevent memory leak * Replaced go-routine impl with lock to simplify the code and solve the leaking. --- client.go | 78 ++++++++++++++---------------- client_internal_test.go | 104 ++++++++++++++++++++++++---------------- 2 files changed, 99 insertions(+), 83 deletions(-) diff --git a/client.go b/client.go index 6e9dc66..b7e9635 100644 --- a/client.go +++ b/client.go @@ -18,6 +18,7 @@ import ( "io" "net" "net/http" + "runtime" "strconv" "strings" "sync" @@ -31,6 +32,8 @@ const ( RoleResponder Role = iota + 1 RoleAuthorizer RoleFilter + + MaxRequestID = ^uint16(0) ) // NewRequest returns a standard FastCGI request @@ -64,57 +67,53 @@ type Request struct { } type idPool struct { - IDs chan uint16 + IDs uint16 + + Used *sync.Map + Lock *sync.Mutex } // AllocID implements Client.AllocID func (p *idPool) Alloc() uint16 { - return <-p.IDs + p.Lock.Lock() +next: + idx := p.IDs + if idx == MaxRequestID { + // reset + p.IDs = 0 + } + p.IDs++ + + if _, inuse := p.Used.Load(idx); inuse { + // Allow other go-routine to take priority + // to prevent spinlock here + runtime.Gosched() + goto next + } + + p.Used.Store(idx, struct{}{}) + p.Lock.Unlock() + + return idx } // ReleaseID implements Client.ReleaseID func (p *idPool) Release(id uint16) { - go func() { - // release the ID back to channel for reuse - // use goroutine to prev0, ent blocking ReleaseID - p.IDs <- id - }() + p.Used.Delete(id) } -func newIDs(limit uint32) (p idPool) { - - // sanatize limit - if limit == 0 || limit > 65535 { - // Note: limit is the size of the pool - // Since 0 cannot be requestId, the effective - // pool is from 1 to 65535, hence size is 65535. - limit = 65535 +func newIDs() *idPool { + return &idPool{ + Used: new(sync.Map), + Lock: new(sync.Mutex), + IDs: uint16(1), } - - // pool requestID for the client - // - // requestID: Identifies the FastCGI request to which the record belongs. - // The Web server re-uses FastCGI request IDs; the application - // keeps track of the current state of each request ID on a given - // transport connection. - // - // Ref: https://fast-cgi.github.io/spec#33-records - ids := make(chan uint16) - go func(maxID uint16) { - for i := uint16(1); i < maxID; i++ { - ids <- i - } - ids <- uint16(maxID) - }(uint16(limit)) - - p.IDs = ids - return } // client is the default implementation of Client type client struct { conn *conn - ids idPool + ids *idPool } // writeRequest writes params and stdin to the FastCGI application @@ -388,12 +387,7 @@ type ClientFactory func() (Client, error) // SimpleClientFactory returns a ClientFactory implementation // with the given ConnFactory. // -// limit is the maximum number of request that the -// applcation support. 0 means the maximum number -// available for 16bit request id (because 0 is not -// a valid reqeust id, 65535). -// -// Default 0. +// limit is UNUSED. // func SimpleClientFactory(connFactory ConnFactory, limit uint32) ClientFactory { return func() (c Client, err error) { @@ -406,7 +400,7 @@ func SimpleClientFactory(connFactory ConnFactory, limit uint32) ClientFactory { // create client c = &client{ conn: newConn(conn), - ids: newIDs(limit), + ids: newIDs(), } return } diff --git a/client_internal_test.go b/client_internal_test.go index 8cc68a2..c9da843 100644 --- a/client_internal_test.go +++ b/client_internal_test.go @@ -6,77 +6,99 @@ import ( "time" ) +// requestId is supposed to be unique among all active requests in a connection. So a requestId +// should not be reused until the previous request of the same id is inactive (releasing the id). func TestIDPool_Alloc(t *testing.T) { - t.Logf("default limit: %d", 65535) - ids := newIDs(0) - for i := uint32(1); i <= 65535; i++ { + ids := newIDs() + idToReserve := uint16(rand.Int31n(int32(MaxRequestID))) + + // Loop over all ids to make sure it is sequencely returning + // 1 to 65535. + // + // Note: Use uint as loop counter so it can loop past 65535 + // to end the loop (also keep the code readable) + for i := uint(1); i <= uint(MaxRequestID); i++ { if want, have := uint16(i), ids.Alloc(); want != have { - t.Errorf("expected %d, got %d", want, have) + t.Fatalf("expected %v, got %v", want, have) + } + if i != uint(idToReserve) { + ids.Release(uint16(i)) } } - // test if new id can be allocated - // when all ids are already allocated - newAlloc := make(chan uint16) - go func(ids idPool, newAlloc chan<- uint16) { - newAlloc <- ids.Alloc() - }(ids, newAlloc) - - select { - case reqID := <-newAlloc: - t.Errorf("unexpected new allocation: %d", reqID) - case <-time.After(time.Millisecond * 100): - t.Log("blocks as expected") + // Loop over all requestids 5 times + for i := 0; i < 5; i++ { + for j := uint(1); j <= uint(MaxRequestID-1); j++ { + id := ids.Alloc() + if id == 0 { + t.Fatal("ids.Alloc() is never allowed to return 0") + } else if id == idToReserve { + t.Fatalf("The requestId %v was not reserved as expect", id) + } else if j < uint(idToReserve) { + if want, have := uint(id), j; want != have { + t.Fatalf("expected %v, got %v", want, have) + } + } else if j >= uint(idToReserve) { + if want, have := uint(id), j+1; want != have { + t.Fatalf("expected %v, got %v", want, have) + } + } + ids.Release(id) // always release the allocated id + } } - // now, release a random ID - released := uint16(rand.Int31n(65535)) - go func(ids idPool, released uint16) { - ids.Release(released) - }(ids, released) + // release the reserved id + ids.Release(idToReserve) - select { - case reqID := <-newAlloc: - if want, have := released, reqID; want != have { - t.Errorf("expected %d, got %d", want, have) + // make sure all ids are available again + for i := uint(1); i <= uint(MaxRequestID); i++ { + if want, have := uint16(i), ids.Alloc(); want != have { + t.Fatalf("expected %v, got %v", want, have) } - case <-time.After(time.Millisecond * 100): - t.Errorf("unexpected blocking") } } -func TestIDPool_Alloc_withLimit(t *testing.T) { +// If all IDs are used up, pool is supposed to block on alloc after exhaustion. +func TestIDPool_block(t *testing.T) { - limit := uint32(rand.Int31n(100) + 10) - t.Logf("random limit: %d", limit) + ids := newIDs() - ids := newIDs(limit) - for i := uint32(1); i <= limit; i++ { - if want, have := uint16(i), ids.Alloc(); want != have { - t.Errorf("expected %d, got %d", want, have) + // Test allocating all ids once. + for i := uint(1); i <= uint(MaxRequestID); i++ { + id := ids.Alloc() + if want, have := i, uint(id); want != have { + t.Errorf("expected to allocate %v, got %v", want, have) + t.FailNow() } } - // test if new id can be allocated - // when all ids are already allocated newAlloc := make(chan uint16) - go func(ids idPool, newAlloc chan<- uint16) { + waitAlloc := func(ids *idPool, newAlloc chan<- uint16) { newAlloc <- ids.Alloc() - }(ids, newAlloc) + } + go waitAlloc(ids, newAlloc) + go waitAlloc(ids, newAlloc) + go waitAlloc(ids, newAlloc) + go waitAlloc(ids, newAlloc) + go waitAlloc(ids, newAlloc) + // wait some time to see if we can allocate id again select { case reqID := <-newAlloc: - t.Errorf("unexpected new allocation: %d", reqID) + t.Fatalf("unexpected new allocation: %d", reqID) case <-time.After(time.Millisecond * 100): t.Log("blocks as expected") } // now, release a random ID - released := uint16(rand.Int31n(int32(limit))) - go func(ids idPool, released uint16) { + released := uint16(rand.Int31n(int32(MaxRequestID))) + go func(ids *idPool, released uint16) { + // release an id ids.Release(released) + t.Logf("id released: %v", released) }(ids, released) + // wait some time to see if we can allocate id again select { case reqID := <-newAlloc: if want, have := released, reqID; want != have { From f4daef5ffaa54eff8093b8a3095141d6a40903ec Mon Sep 17 00:00:00 2001 From: Koala Yeung Date: Sat, 20 Feb 2021 03:15:20 +0800 Subject: [PATCH 2/3] travis: Remove go 1.7 and 1.8 from test * Using sync.Map for idPool. The feature is only available on go 1.9 and after. So go 1.7 and 1.8 are droped out of support. --- .travis.yml | 2 -- 1 file changed, 2 deletions(-) diff --git a/.travis.yml b/.travis.yml index db64767..0e6edb2 100644 --- a/.travis.yml +++ b/.travis.yml @@ -3,8 +3,6 @@ language: go dist: focal go: - - 1.7.x - - 1.8.x - 1.9.x - 1.10.x - 1.11.x From 4ea507d75b2fc2f728e1b50c9e4b65d36356ffc8 Mon Sep 17 00:00:00 2001 From: Koala Yeung Date: Sat, 20 Feb 2021 03:16:50 +0800 Subject: [PATCH 3/3] BREAKING: Changed SimpleClientFactory function signature * Remove the new obsoleted limit argument. No longer support in the new idPool implementation. Please discuss in issue track if you want this back. --- README.md | 12 ++++++------ client.go | 5 +---- client_test.go | 2 -- example/nodejs/nodejs.go | 4 ++-- example/php/php.go | 4 ++-- example/python3/python3.go | 2 +- host_test.go | 1 - pool_test.go | 8 ++++---- 8 files changed, 16 insertions(+), 22 deletions(-) diff --git a/README.md b/README.md index cbd40d8..2b226f2 100644 --- a/README.md +++ b/README.md @@ -104,7 +104,7 @@ func main() { // route all requests to a single php file http.Handle("/", gofast.NewHandler( gofast.NewFileEndpoint("/var/www/html/index.php")(gofast.BasicSession), - gofast.SimpleClientFactory(connFactory, 0), + gofast.SimpleClientFactory(connFactory), )) // serve at 8080 port @@ -153,7 +153,7 @@ func main() { // route all requests to relevant PHP file http.Handle("/", gofast.NewHandler( gofast.NewPHPFS("/var/www/html")(gofast.BasicSession), - gofast.SimpleClientFactory(connFactory, 0), + gofast.SimpleClientFactory(connFactory), )) // serve at 8080 port @@ -229,7 +229,7 @@ func main() { // route all requests to a single php file http.Handle("/", gofast.NewHandler( gofast.NewFileEndpoint("/var/www/html/index.php")(sess), - gofast.SimpleClientFactory(connFactory, 0), + gofast.SimpleClientFactory(connFactory), )) // serve at 8080 port @@ -308,7 +308,7 @@ func myApp() http.Handler { func main() { address := os.Getenv("FASTCGI_ADDR") connFactory := gofast.SimpleConnFactory("tcp", address) - clientFactory := gofast.SimpleClientFactory(connFactory, 0) + clientFactory := gofast.SimpleClientFactory(connFactory) // authorization with php authSess := gofast.Chain( @@ -361,7 +361,7 @@ import ( func main() { address := os.Getenv("FASTCGI_ADDR") connFactory := gofast.SimpleConnFactory("tcp", address) - clientFactory := gofast.SimpleClientFactory(connFactory, 0) + clientFactory := gofast.SimpleClientFactory(connFactory) // Note: The local file system "/var/www/html/" only need to be // local to web server. No need for the FastCGI application to access @@ -420,7 +420,7 @@ func main() { // handle all scripts in document root // extra pooling layer pool := gofast.NewClientPool( - gofast.SimpleClientFactory(connFactory, 0), + gofast.SimpleClientFactory(connFactory), 10, // buffer size for pre-created client-connection 30*time.Second, // life span of a client before expire ) diff --git a/client.go b/client.go index b7e9635..862a138 100644 --- a/client.go +++ b/client.go @@ -386,10 +386,7 @@ type ClientFactory func() (Client, error) // SimpleClientFactory returns a ClientFactory implementation // with the given ConnFactory. -// -// limit is UNUSED. -// -func SimpleClientFactory(connFactory ConnFactory, limit uint32) ClientFactory { +func SimpleClientFactory(connFactory ConnFactory) ClientFactory { return func() (c Client, err error) { // connect to given network address conn, err := connFactory() diff --git a/client_test.go b/client_test.go index c70296a..04e6106 100644 --- a/client_test.go +++ b/client_test.go @@ -109,7 +109,6 @@ func testHandlerForCancel(t *testing.T, p *appServer, w http.ResponseWriter, r * c, err := gofast.SimpleClientFactory( gofast.SimpleConnFactory(p.Network(), p.Address()), - 0, )() if err != nil { http.Error(w, "failed to connect to FastCGI application", http.StatusBadGateway) @@ -203,7 +202,6 @@ func TestClient_StdErr(t *testing.T) { doRequest := func(w http.ResponseWriter, r *http.Request) (errStr string) { c, err := gofast.SimpleClientFactory( gofast.SimpleConnFactory(p.Network(), p.Address()), - 0, )() if err != nil { errStr = "web server: unable to connect to FastCGI application: " + err.Error() diff --git a/example/nodejs/nodejs.go b/example/nodejs/nodejs.go index 5e0f57b..36a38f4 100644 --- a/example/nodejs/nodejs.go +++ b/example/nodejs/nodejs.go @@ -20,7 +20,7 @@ import ( func NewHandler(entrypoint, network, address string) http.Handler { connFactory := gofast.SimpleConnFactory(network, address) pool := gofast.NewClientPool( - gofast.SimpleClientFactory(connFactory, 0), + gofast.SimpleClientFactory(connFactory), 10, 60*time.Second, ) @@ -59,7 +59,7 @@ func NewMuxHandler( // common client pool for both filter and responder handler connFactory := gofast.SimpleConnFactory(network, address) pool := gofast.NewClientPool( - gofast.SimpleClientFactory(connFactory, 0), + gofast.SimpleClientFactory(connFactory), 10, 60*time.Second, ) diff --git a/example/php/php.go b/example/php/php.go index 2648b2f..04e9842 100644 --- a/example/php/php.go +++ b/example/php/php.go @@ -21,7 +21,7 @@ func NewSimpleHandler(docroot, network, address string) http.Handler { connFactory := gofast.SimpleConnFactory(network, address) h := gofast.NewHandler( gofast.NewPHPFS(docroot)(gofast.BasicSession), - gofast.SimpleClientFactory(connFactory, 0), + gofast.SimpleClientFactory(connFactory), ) return h } @@ -41,7 +41,7 @@ func NewFileEndpointHandler(filepath, network, address string) http.Handler { connFactory := gofast.SimpleConnFactory(network, address) h := gofast.NewHandler( gofast.NewFileEndpoint(filepath)(gofast.BasicSession), - gofast.SimpleClientFactory(connFactory, 0), + gofast.SimpleClientFactory(connFactory), ) return h } diff --git a/example/python3/python3.go b/example/python3/python3.go index aca86f5..f714d8e 100644 --- a/example/python3/python3.go +++ b/example/python3/python3.go @@ -20,7 +20,7 @@ import ( func NewHandler(entrypoint, network, address string) http.Handler { connFactory := gofast.SimpleConnFactory(network, address) pool := gofast.NewClientPool( - gofast.SimpleClientFactory(connFactory, 0), + gofast.SimpleClientFactory(connFactory), 10, 60*time.Second, ) diff --git a/host_test.go b/host_test.go index d66c774..d0682bd 100644 --- a/host_test.go +++ b/host_test.go @@ -45,7 +45,6 @@ func TestHandler(t *testing.T) { l.Addr().Network(), l.Addr().String(), ), - 0, ), ) w := httptest.NewRecorder() diff --git a/pool_test.go b/pool_test.go index edc9201..a342fd2 100644 --- a/pool_test.go +++ b/pool_test.go @@ -99,7 +99,7 @@ func TestClientPool_CreateClient_withErr(t *testing.T) { cpHasError := NewClientPool( SimpleClientFactory(func() (net.Conn, error) { return nil, fmt.Errorf("dummy error") - }, 10), + }), 10, 1*time.Millisecond, ) c, err := cpHasError.CreateClient() @@ -117,7 +117,7 @@ func TestClientPool_CreateClient_withErr(t *testing.T) { cpHasError = NewClientPool( SimpleClientFactory(func() (net.Conn, error) { return nil, fmt.Errorf("dummy error") - }, 0), + }), 0, 1*time.Millisecond, ) c, err = cpHasError.CreateClient() @@ -143,7 +143,7 @@ func TestClientPool_CreateClient_Return_0(t *testing.T) { conn := mockConn(false) atomic.AddUint64(&counter, 1) return &conn, nil - }, 0), + }), 0, 1000*time.Millisecond, ) @@ -195,7 +195,7 @@ func TestClientPool_CreateClient_Return_40(t *testing.T) { conn := mockConn(false) atomic.AddUint64(&counter, 1) return &conn, nil - }, 0), + }), 40, 1000*time.Millisecond, )