Skip to content

Commit

Permalink
fix async write bug
Browse files Browse the repository at this point in the history
  • Loading branch information
shaovie committed Aug 12, 2023
1 parent e3f8205 commit 1765868
Show file tree
Hide file tree
Showing 6 changed files with 108 additions and 54 deletions.
4 changes: 2 additions & 2 deletions async_write.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func newAsyncWrite(ep *evPoll) (*asyncWrite, error) {
}
func (aw *asyncWrite) push(awi asyncWriteItem) {
aw.mtx.Lock()
aw.writeq.Push(awi)
aw.writeq.PushBack(awi)
aw.mtx.Unlock()

if !aw.notified.CompareAndSwap(0, 1) {
Expand All @@ -80,7 +80,7 @@ func (aw *asyncWrite) OnRead() bool {
}

for i := 0; i < 256; i++ { // Don't process too many at once
item, ok := aw.readq.Pop()
item, ok := aw.readq.PopFront()
if !ok {
break
}
Expand Down
82 changes: 54 additions & 28 deletions example/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ var (
asynBufPool sync.Pool
)

const asynBufSize int = 4096

const (
FrameNull = -1
FrameContinue = 0
Expand Down Expand Up @@ -124,13 +126,14 @@ const maxControlFramePayloadSize = 125
const maxFreamHeaderSize = 14

type wsFrame struct {
isfin bool
flate bool
masked bool
hlen int8 // header len
opcode int
payload int64
maskKey [4]byte
complete bool // frame complete
isfin bool
flate bool
masked bool
hlen int8 // header len
opcode int
payload int64
maskKey [4]byte
}

type continueWsFrame struct {
Expand Down Expand Up @@ -238,9 +241,6 @@ func (c *Conn) OnOpen(fd int) bool {
return true
}
func (c *Conn) OnRead() bool {
if c.closed == true {
return false
}
buf, n, _ := c.Read()
if n > 0 {
if c.upgraded == false {
Expand Down Expand Up @@ -396,8 +396,7 @@ func (c *Conn) onUpgrade(buf []byte) bool {
var resp = make([]byte, 0, 256)
c.compressEnabled = false // 还不支持
if c.compressEnabled {
resp = append(resp, (unsafe.Slice(unsafe.StringData(switchHeaderWithFlateS),
len(switchHeaderWithFlateS)))...)
resp = append(resp, (unsafe.Slice(unsafe.StringData(switchHeaderWithFlateS), len(switchHeaderWithFlateS)))...)
} else {
resp = append(resp, (unsafe.Slice(unsafe.StringData(switchHeaderS), len(switchHeaderS)))...)
}
Expand All @@ -418,11 +417,20 @@ func (c *Conn) onUpgrade(buf []byte) bool {

writen, _ := c.Write(resp)
if writen < len(resp) {
bf := asynBufPool.Get().([]byte)
n := copy(bf, buf[writen:])
var bf []byte
var flag, n int
if len(resp)-writen <= asynBufSize {
bf = asynBufPool.Get().([]byte)
n = copy(bf, resp[writen:])
} else {
bf = make([]byte, len(resp)-writen)
n = copy(bf, resp[writen:])
flag = 2
}
c.AsyncWrite(c, goev.AsyncWriteBuf{
Len: n,
Buf: bf,
Flag: flag,
Len: n,
Buf: bf,
})
}
c.upgraded = true
Expand Down Expand Up @@ -460,9 +468,9 @@ func (c *Conn) onFrame(buf []byte) bool {
payloadBuf = buf[bufOffset+hlen : bufOffset+hlen+payloadLen]
}

// get a complete frame
bufOffset += hlen + payloadLen
bufLen -= hlen + payloadLen
// get a complete frame

if wsf.isfin {
if len(payloadBuf) > 0 {
Expand All @@ -475,13 +483,13 @@ func (c *Conn) onFrame(buf []byte) bool {
} else if wsf.opcode == FramePong {
c.OnPong(payloadBuf)
} else if wsf.opcode == FrameClose {
ce := CloseInfo{Code: CloseNoCloseRcvd}
ce := CloseInfo{Code: CloseNoCloseRcvd, Info: ""}
if len(payloadBuf) > 1 {
ce.Code = CloseCode(binary.BigEndian.Uint16(payloadBuf))
ce.Info = string(payloadBuf[2:])
}
c.OnCloseFrame(ce)
return false // close
break
}
} else {
if wsf.opcode == FrameContinue {
Expand Down Expand Up @@ -603,7 +611,24 @@ func (c *Conn) writeControlFrame(opcode int, payload []byte) {
buf[0] = byte(opcode) | 1<<7
buf[1] = byte(payloadLen)
copy(buf[hlen:], payload)
c.Write(buf[0 : payloadLen+hlen])
writen, _ := c.Write(buf[:payloadLen+hlen])
if writen < payloadLen+hlen {
var bf []byte
var flag, n int
if payloadLen+hlen-writen <= asynBufSize {
bf = asynBufPool.Get().([]byte)
n = copy(bf, buf[writen:payloadLen+hlen])
} else {
bf = make([]byte, payloadLen+hlen-writen)
n = copy(bf, buf[writen:payloadLen+hlen])
flag = 2
}
c.AsyncWrite(c, goev.AsyncWriteBuf{
Flag: flag,
Len: n,
Buf: bf,
})
}
}
func (c *Conn) writeMessageFrame(opcode int, data []byte, flate, fin bool) {
if c.closed {
Expand Down Expand Up @@ -639,17 +664,17 @@ func (c *Conn) writeMessageFrame(opcode int, data []byte, flate, fin bool) {
// no mask in server side
copy(buff[hlen:], data)
wlen := hlen + payloadLen
writen, err := c.Write(buff[:wlen])
if err == nil && writen < wlen {
var flag, n int
writen, _ := c.Write(buff[:wlen])
if writen < wlen {
var bf []byte
if wlen-writen < 1024 {
var flag, n int
if wlen-writen <= asynBufSize {
bf = asynBufPool.Get().([]byte)
n = copy(bf, buf[writen:])
n = copy(bf, buff[writen:wlen])
} else {
bf = make([]byte, wlen-writen)
n = copy(bf, buff[writen:])
flat = 2
n = copy(bf, buff[writen:wlen])
flag = 2
}
c.AsyncWrite(c, goev.AsyncWriteBuf{
Flag: flag,
Expand All @@ -672,6 +697,7 @@ func (c *Conn) OnPong(data []byte) {
}
func (c *Conn) OnCloseFrame(ci CloseInfo) {
if c.closed == false {
fmt.Println("on close frame")
bf := (*(*[2]byte)(unsafe.Pointer(&ci.Code)))[:]
binary.BigEndian.PutUint16(bf, uint16(CloseNormalClosure))
c.writeControlFrame(FrameClose, bf[0:2])
Expand All @@ -684,7 +710,7 @@ func main() {
runtime.GOMAXPROCS(procNum)

asynBufPool.New = func() any {
return make([]byte, 1024)
return make([]byte, asynBufSize)
}

var err error
Expand Down
5 changes: 4 additions & 1 deletion io_handle.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,9 @@ func (h *IOHandle) WriteBuff() []byte {
// n = [n, len(bf]
func (h *IOHandle) Write(bf []byte) (n int, err error) {
if h._fd > 0 { // NOTE fd must > 0
if h._asyncWriteBufQ != nil && !h._asyncWriteBufQ.IsEmpty() {
return
}
for {
n, err = syscall.Write(h._fd, bf)
if n < 0 {
Expand Down Expand Up @@ -159,7 +162,7 @@ func (h *IOHandle) Destroy(eh EvHandler) {

if h._asyncWriteBufQ != nil && !h._asyncWriteBufQ.IsEmpty() {
for {
abf, ok := h._asyncWriteBufQ.Pop()
abf, ok := h._asyncWriteBufQ.PopFront()
if !ok {
break
}
Expand Down
26 changes: 19 additions & 7 deletions io_handle_async.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func (h *IOHandle) asyncOrderedWrite(eh EvHandler, abf AsyncWriteBuf) {
return
}
if h._asyncWriteBufQ != nil && !h._asyncWriteBufQ.IsEmpty() {
h._asyncWriteBufQ.Push(abf)
h._asyncWriteBufQ.PushBack(abf)
return
}

Expand All @@ -56,7 +56,7 @@ func (h *IOHandle) asyncOrderedWrite(eh EvHandler, abf AsyncWriteBuf) {
if h._asyncWriteBufQ == nil {
h._asyncWriteBufQ = NewRingBuffer[AsyncWriteBuf](2)
}
h._asyncWriteBufQ.Push(abf)
h._asyncWriteBufQ.PushBack(abf)

if h._asyncWriteWaiting == false {
h._asyncWriteWaiting = true
Expand All @@ -73,21 +73,33 @@ func (h *IOHandle) asyncOrderedWrite(eh EvHandler, abf AsyncWriteBuf) {
// x.AsyncOrderedFlush(x)
// }
func (h *IOHandle) AsyncOrderedFlush(eh EvHandler) {
if h._asyncWriteBufQ == nil || h._asyncWriteBufQ.IsEmpty() {
return
}
if h._fd < 1 {
return
}
n := h._asyncWriteBufQ.Len()
// It is necessary to use n to limit the number of sending attempts.
// If there is a possibility of sending failure, the data should be saved again in _asyncWriteBufQ
for i := 0; i < n; i++ {
abf, ok := h._asyncWriteBufQ.Pop()
abf, ok := h._asyncWriteBufQ.PopFront()
if !ok {
break
}
eh.asyncOrderedWrite(eh, abf)
if abf.Len < 1 || abf.Writen >= abf.Len { // so you can pass Buf=nil
eh.OnAsyncWriteBufDone(abf.Buf, abf.Flag)
continue
}
n, _ := syscall.Write(h._fd, abf.Buf[abf.Writen:abf.Len])
if n > 0 {
if n == (abf.Len - abf.Writen) {
h._asyncLastPartialWriteTime = 0
eh.OnAsyncWriteBufDone(abf.Buf, abf.Flag) // send completely
continue
}
abf.Writen += n // Partially write, shift n
}
h._asyncLastPartialWriteTime = time.Now().UnixMilli()
h._asyncWriteBufQ.PushFront(abf)
break
}
if h._asyncWriteBufQ.IsEmpty() {
h._ep.subtract(h._fd, EvOut)
Expand Down
17 changes: 13 additions & 4 deletions ringbuffer.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ func (rb *RingBuffer[T]) Len() int {
return rb.len
}

// Push an item
func (rb *RingBuffer[T]) Push(data T) {
// PushBack an item
func (rb *RingBuffer[T]) PushBack(data T) {
if rb.len == rb.size {
rb.grow()
}
Expand All @@ -50,8 +50,8 @@ func (rb *RingBuffer[T]) Push(data T) {
rb.len++
}

// Pop an item
func (rb *RingBuffer[T]) Pop() (data T, ok bool) {
// PopFront an item
func (rb *RingBuffer[T]) PopFront() (data T, ok bool) {
if rb.len == 0 {
return
}
Expand All @@ -61,6 +61,15 @@ func (rb *RingBuffer[T]) Pop() (data T, ok bool) {
ok = true
return
}
// PushFront an item
func (rb *RingBuffer[T]) PushFront(data T) {
if rb.len == rb.size {
rb.grow()
}
rb.head = (rb.size + rb.head - 1) % rb.size // prev
rb.buffer[rb.head] = data
rb.len++
}

func (rb *RingBuffer[T]) grow() {
newCapacity := rb.size * 2
Expand Down
28 changes: 16 additions & 12 deletions ringbuffer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,26 +6,30 @@ import (
)

func TestRingBuffer(t *testing.T) {
rb := NewRingBuffer[int](50)
rb := NewRingBuffer[int](10)

for i := 0; i < 100; i++ {
rb.Push(i)
for i := 0; i < 10; i++ {
rb.PushBack(i)
}
for i := 0; i < 50; i++ {
rb.Pop()
t.Logf("is full:%v\n", rb.IsFull())
v, _ := rb.PopFront()
rb.PushFront(v)
t.Logf("is full:%v\n", rb.IsFull())

for i := 0; i < 5; i++ {
rb.PopFront()
}
t.Logf("is empty:%v\n", rb.IsEmpty())
for i := 0; i < 50; i++ {
rb.Push(i)
for i := 0; i < 5; i++ {
rb.PushBack(i)
}
t.Logf("is full:%v\n", rb.IsFull())

fmt.Println("len:", rb.Len())
fmt.Println(rb.Pop()) // Output: 0
fmt.Println(rb.Pop()) // Output: 1
fmt.Println(rb.Pop()) // Output: 2
fmt.Println(rb.PopFront()) // Output: 0
fmt.Println(rb.PopFront()) // Output: 1

for !rb.IsEmpty() {
data, _ := rb.Pop()
data, _ := rb.PopFront()
fmt.Println(data)
}
fmt.Println("len:", rb.Len())
Expand Down

0 comments on commit 1765868

Please sign in to comment.