Skip to content

Commit

Permalink
Merge pull request #70 from yookoala/fix/idpool
Browse files Browse the repository at this point in the history
Simplify idpool to fix potential memory leak issue
  • Loading branch information
yookoala authored Feb 19, 2021
2 parents 847a66b + 4ea507d commit 33ba70c
Show file tree
Hide file tree
Showing 10 changed files with 114 additions and 106 deletions.
2 changes: 0 additions & 2 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@ language: go
dist: focal

go:
- 1.7.x
- 1.8.x
- 1.9.x
- 1.10.x
- 1.11.x
Expand Down
12 changes: 6 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ func main() {
// route all requests to a single php file
http.Handle("/", gofast.NewHandler(
gofast.NewFileEndpoint("/var/www/html/index.php")(gofast.BasicSession),
gofast.SimpleClientFactory(connFactory, 0),
gofast.SimpleClientFactory(connFactory),
))

// serve at 8080 port
Expand Down Expand Up @@ -153,7 +153,7 @@ func main() {
// route all requests to relevant PHP file
http.Handle("/", gofast.NewHandler(
gofast.NewPHPFS("/var/www/html")(gofast.BasicSession),
gofast.SimpleClientFactory(connFactory, 0),
gofast.SimpleClientFactory(connFactory),
))

// serve at 8080 port
Expand Down Expand Up @@ -229,7 +229,7 @@ func main() {
// route all requests to a single php file
http.Handle("/", gofast.NewHandler(
gofast.NewFileEndpoint("/var/www/html/index.php")(sess),
gofast.SimpleClientFactory(connFactory, 0),
gofast.SimpleClientFactory(connFactory),
))

// serve at 8080 port
Expand Down Expand Up @@ -308,7 +308,7 @@ func myApp() http.Handler {
func main() {
address := os.Getenv("FASTCGI_ADDR")
connFactory := gofast.SimpleConnFactory("tcp", address)
clientFactory := gofast.SimpleClientFactory(connFactory, 0)
clientFactory := gofast.SimpleClientFactory(connFactory)

// authorization with php
authSess := gofast.Chain(
Expand Down Expand Up @@ -361,7 +361,7 @@ import (
func main() {
address := os.Getenv("FASTCGI_ADDR")
connFactory := gofast.SimpleConnFactory("tcp", address)
clientFactory := gofast.SimpleClientFactory(connFactory, 0)
clientFactory := gofast.SimpleClientFactory(connFactory)

// Note: The local file system "/var/www/html/" only need to be
// local to web server. No need for the FastCGI application to access
Expand Down Expand Up @@ -420,7 +420,7 @@ func main() {
// handle all scripts in document root
// extra pooling layer
pool := gofast.NewClientPool(
gofast.SimpleClientFactory(connFactory, 0),
gofast.SimpleClientFactory(connFactory),
10, // buffer size for pre-created client-connection
30*time.Second, // life span of a client before expire
)
Expand Down
81 changes: 36 additions & 45 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"io"
"net"
"net/http"
"runtime"
"strconv"
"strings"
"sync"
Expand All @@ -31,6 +32,8 @@ const (
RoleResponder Role = iota + 1
RoleAuthorizer
RoleFilter

MaxRequestID = ^uint16(0)
)

// NewRequest returns a standard FastCGI request
Expand Down Expand Up @@ -64,57 +67,53 @@ type Request struct {
}

type idPool struct {
IDs chan uint16
IDs uint16

Used *sync.Map
Lock *sync.Mutex
}

// AllocID implements Client.AllocID
func (p *idPool) Alloc() uint16 {
return <-p.IDs
p.Lock.Lock()
next:
idx := p.IDs
if idx == MaxRequestID {
// reset
p.IDs = 0
}
p.IDs++

if _, inuse := p.Used.Load(idx); inuse {
// Allow other go-routine to take priority
// to prevent spinlock here
runtime.Gosched()
goto next
}

p.Used.Store(idx, struct{}{})
p.Lock.Unlock()

return idx
}

// ReleaseID implements Client.ReleaseID
func (p *idPool) Release(id uint16) {
go func() {
// release the ID back to channel for reuse
// use goroutine to prev0, ent blocking ReleaseID
p.IDs <- id
}()
p.Used.Delete(id)
}

func newIDs(limit uint32) (p idPool) {

// sanatize limit
if limit == 0 || limit > 65535 {
// Note: limit is the size of the pool
// Since 0 cannot be requestId, the effective
// pool is from 1 to 65535, hence size is 65535.
limit = 65535
func newIDs() *idPool {
return &idPool{
Used: new(sync.Map),
Lock: new(sync.Mutex),
IDs: uint16(1),
}

// pool requestID for the client
//
// requestID: Identifies the FastCGI request to which the record belongs.
// The Web server re-uses FastCGI request IDs; the application
// keeps track of the current state of each request ID on a given
// transport connection.
//
// Ref: https://fast-cgi.github.io/spec#33-records
ids := make(chan uint16)
go func(maxID uint16) {
for i := uint16(1); i < maxID; i++ {
ids <- i
}
ids <- uint16(maxID)
}(uint16(limit))

p.IDs = ids
return
}

// client is the default implementation of Client
type client struct {
conn *conn
ids idPool
ids *idPool
}

// writeRequest writes params and stdin to the FastCGI application
Expand Down Expand Up @@ -387,15 +386,7 @@ type ClientFactory func() (Client, error)

// SimpleClientFactory returns a ClientFactory implementation
// with the given ConnFactory.
//
// limit is the maximum number of request that the
// applcation support. 0 means the maximum number
// available for 16bit request id (because 0 is not
// a valid reqeust id, 65535).
//
// Default 0.
//
func SimpleClientFactory(connFactory ConnFactory, limit uint32) ClientFactory {
func SimpleClientFactory(connFactory ConnFactory) ClientFactory {
return func() (c Client, err error) {
// connect to given network address
conn, err := connFactory()
Expand All @@ -406,7 +397,7 @@ func SimpleClientFactory(connFactory ConnFactory, limit uint32) ClientFactory {
// create client
c = &client{
conn: newConn(conn),
ids: newIDs(limit),
ids: newIDs(),
}
return
}
Expand Down
104 changes: 63 additions & 41 deletions client_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,77 +6,99 @@ import (
"time"
)

// requestId is supposed to be unique among all active requests in a connection. So a requestId
// should not be reused until the previous request of the same id is inactive (releasing the id).
func TestIDPool_Alloc(t *testing.T) {
t.Logf("default limit: %d", 65535)
ids := newIDs(0)
for i := uint32(1); i <= 65535; i++ {
ids := newIDs()
idToReserve := uint16(rand.Int31n(int32(MaxRequestID)))

// Loop over all ids to make sure it is sequencely returning
// 1 to 65535.
//
// Note: Use uint as loop counter so it can loop past 65535
// to end the loop (also keep the code readable)
for i := uint(1); i <= uint(MaxRequestID); i++ {
if want, have := uint16(i), ids.Alloc(); want != have {
t.Errorf("expected %d, got %d", want, have)
t.Fatalf("expected %v, got %v", want, have)
}
if i != uint(idToReserve) {
ids.Release(uint16(i))
}
}

// test if new id can be allocated
// when all ids are already allocated
newAlloc := make(chan uint16)
go func(ids idPool, newAlloc chan<- uint16) {
newAlloc <- ids.Alloc()
}(ids, newAlloc)

select {
case reqID := <-newAlloc:
t.Errorf("unexpected new allocation: %d", reqID)
case <-time.After(time.Millisecond * 100):
t.Log("blocks as expected")
// Loop over all requestids 5 times
for i := 0; i < 5; i++ {
for j := uint(1); j <= uint(MaxRequestID-1); j++ {
id := ids.Alloc()
if id == 0 {
t.Fatal("ids.Alloc() is never allowed to return 0")
} else if id == idToReserve {
t.Fatalf("The requestId %v was not reserved as expect", id)
} else if j < uint(idToReserve) {
if want, have := uint(id), j; want != have {
t.Fatalf("expected %v, got %v", want, have)
}
} else if j >= uint(idToReserve) {
if want, have := uint(id), j+1; want != have {
t.Fatalf("expected %v, got %v", want, have)
}
}
ids.Release(id) // always release the allocated id
}
}

// now, release a random ID
released := uint16(rand.Int31n(65535))
go func(ids idPool, released uint16) {
ids.Release(released)
}(ids, released)
// release the reserved id
ids.Release(idToReserve)

select {
case reqID := <-newAlloc:
if want, have := released, reqID; want != have {
t.Errorf("expected %d, got %d", want, have)
// make sure all ids are available again
for i := uint(1); i <= uint(MaxRequestID); i++ {
if want, have := uint16(i), ids.Alloc(); want != have {
t.Fatalf("expected %v, got %v", want, have)
}
case <-time.After(time.Millisecond * 100):
t.Errorf("unexpected blocking")
}
}

func TestIDPool_Alloc_withLimit(t *testing.T) {
// If all IDs are used up, pool is supposed to block on alloc after exhaustion.
func TestIDPool_block(t *testing.T) {

limit := uint32(rand.Int31n(100) + 10)
t.Logf("random limit: %d", limit)
ids := newIDs()

ids := newIDs(limit)
for i := uint32(1); i <= limit; i++ {
if want, have := uint16(i), ids.Alloc(); want != have {
t.Errorf("expected %d, got %d", want, have)
// Test allocating all ids once.
for i := uint(1); i <= uint(MaxRequestID); i++ {
id := ids.Alloc()
if want, have := i, uint(id); want != have {
t.Errorf("expected to allocate %v, got %v", want, have)
t.FailNow()
}
}

// test if new id can be allocated
// when all ids are already allocated
newAlloc := make(chan uint16)
go func(ids idPool, newAlloc chan<- uint16) {
waitAlloc := func(ids *idPool, newAlloc chan<- uint16) {
newAlloc <- ids.Alloc()
}(ids, newAlloc)
}
go waitAlloc(ids, newAlloc)
go waitAlloc(ids, newAlloc)
go waitAlloc(ids, newAlloc)
go waitAlloc(ids, newAlloc)
go waitAlloc(ids, newAlloc)

// wait some time to see if we can allocate id again
select {
case reqID := <-newAlloc:
t.Errorf("unexpected new allocation: %d", reqID)
t.Fatalf("unexpected new allocation: %d", reqID)
case <-time.After(time.Millisecond * 100):
t.Log("blocks as expected")
}

// now, release a random ID
released := uint16(rand.Int31n(int32(limit)))
go func(ids idPool, released uint16) {
released := uint16(rand.Int31n(int32(MaxRequestID)))
go func(ids *idPool, released uint16) {
// release an id
ids.Release(released)
t.Logf("id released: %v", released)
}(ids, released)

// wait some time to see if we can allocate id again
select {
case reqID := <-newAlloc:
if want, have := released, reqID; want != have {
Expand Down
2 changes: 0 additions & 2 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,6 @@ func testHandlerForCancel(t *testing.T, p *appServer, w http.ResponseWriter, r *

c, err := gofast.SimpleClientFactory(
gofast.SimpleConnFactory(p.Network(), p.Address()),
0,
)()
if err != nil {
http.Error(w, "failed to connect to FastCGI application", http.StatusBadGateway)
Expand Down Expand Up @@ -203,7 +202,6 @@ func TestClient_StdErr(t *testing.T) {
doRequest := func(w http.ResponseWriter, r *http.Request) (errStr string) {
c, err := gofast.SimpleClientFactory(
gofast.SimpleConnFactory(p.Network(), p.Address()),
0,
)()
if err != nil {
errStr = "web server: unable to connect to FastCGI application: " + err.Error()
Expand Down
4 changes: 2 additions & 2 deletions example/nodejs/nodejs.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import (
func NewHandler(entrypoint, network, address string) http.Handler {
connFactory := gofast.SimpleConnFactory(network, address)
pool := gofast.NewClientPool(
gofast.SimpleClientFactory(connFactory, 0),
gofast.SimpleClientFactory(connFactory),
10,
60*time.Second,
)
Expand Down Expand Up @@ -59,7 +59,7 @@ func NewMuxHandler(
// common client pool for both filter and responder handler
connFactory := gofast.SimpleConnFactory(network, address)
pool := gofast.NewClientPool(
gofast.SimpleClientFactory(connFactory, 0),
gofast.SimpleClientFactory(connFactory),
10,
60*time.Second,
)
Expand Down
4 changes: 2 additions & 2 deletions example/php/php.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ func NewSimpleHandler(docroot, network, address string) http.Handler {
connFactory := gofast.SimpleConnFactory(network, address)
h := gofast.NewHandler(
gofast.NewPHPFS(docroot)(gofast.BasicSession),
gofast.SimpleClientFactory(connFactory, 0),
gofast.SimpleClientFactory(connFactory),
)
return h
}
Expand All @@ -41,7 +41,7 @@ func NewFileEndpointHandler(filepath, network, address string) http.Handler {
connFactory := gofast.SimpleConnFactory(network, address)
h := gofast.NewHandler(
gofast.NewFileEndpoint(filepath)(gofast.BasicSession),
gofast.SimpleClientFactory(connFactory, 0),
gofast.SimpleClientFactory(connFactory),
)
return h
}
2 changes: 1 addition & 1 deletion example/python3/python3.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import (
func NewHandler(entrypoint, network, address string) http.Handler {
connFactory := gofast.SimpleConnFactory(network, address)
pool := gofast.NewClientPool(
gofast.SimpleClientFactory(connFactory, 0),
gofast.SimpleClientFactory(connFactory),
10,
60*time.Second,
)
Expand Down
Loading

0 comments on commit 33ba70c

Please sign in to comment.