Skip to content

Commit

Permalink
enh: websocket reconnect
Browse files Browse the repository at this point in the history
  • Loading branch information
huskar-t committed May 9, 2024
1 parent cff3f5e commit 6c57f27
Show file tree
Hide file tree
Showing 12 changed files with 895 additions and 222 deletions.
68 changes: 50 additions & 18 deletions ws/client/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package client
import (
"bytes"
"encoding/json"
"errors"
"sync"
"sync/atomic"
"time"
Expand Down Expand Up @@ -33,7 +34,7 @@ type EnvelopePool struct {
func (ep *EnvelopePool) Get() *Envelope {
epv := ep.p.Get()
if epv == nil {
return &Envelope{Msg: new(bytes.Buffer)}
return &Envelope{Msg: new(bytes.Buffer), ErrorChan: make(chan error, 1)}
}
return epv.(*Envelope)
}
Expand All @@ -44,14 +45,24 @@ func (ep *EnvelopePool) Put(epv *Envelope) {
}

type Envelope struct {
Type int
Msg *bytes.Buffer
Type int
Msg *bytes.Buffer
ErrorChan chan error
}

func (e *Envelope) Reset() {
e.Msg.Reset()
if e.Msg.Cap() > 64*1024 {
e.Msg = new(bytes.Buffer)
} else {
e.Msg.Reset()
}
if len(e.ErrorChan) > 0 {
e.ErrorChan = make(chan error, 1)
}
}

var ClosedError = errors.New("websocket closed")

type Client struct {
conn *websocket.Conn
status uint32
Expand All @@ -63,9 +74,10 @@ type Client struct {
TextMessageHandler func(message []byte)
BinaryMessageHandler func(message []byte)
ErrorHandler func(err error)
SendMessageHandler func(envelope *Envelope)
once sync.Once
errHandlerOnce sync.Once
//SendMessageHandler func(envelope *Envelope)
once sync.Once
errHandlerOnce sync.Once
err error
}

func NewClient(conn *websocket.Conn, sendChanLength uint) *Client {
Expand All @@ -80,9 +92,9 @@ func NewClient(conn *websocket.Conn, sendChanLength uint) *Client {
TextMessageHandler: func(message []byte) {},
BinaryMessageHandler: func(message []byte) {},
ErrorHandler: func(err error) {},
SendMessageHandler: func(envelope *Envelope) {
GlobalEnvelopePool.Put(envelope)
},
//SendMessageHandler: func(envelope *Envelope) {
// GlobalEnvelopePool.Put(envelope)
//},
}
}

Expand Down Expand Up @@ -117,41 +129,61 @@ func (c *Client) WritePump() {
defer func() {
ticker.Stop()
}()

for {
select {
case message, ok := <-c.sendChan:
if !ok {
return
if message == nil {
return
}
message.ErrorChan <- ClosedError
continue
}
c.conn.SetWriteDeadline(time.Now().Add(c.WriteWait))
err := c.conn.WriteMessage(message.Type, message.Msg.Bytes())
if err != nil {
message.ErrorChan <- err
c.handleError(err)
return
c.Close()
for message := range c.sendChan {
if message == nil {
return
}
message.ErrorChan <- ClosedError
}
}
c.SendMessageHandler(message)
message.ErrorChan <- nil
case <-ticker.C:
c.conn.SetWriteDeadline(time.Now().Add(c.WriteWait))
if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil {
c.handleError(err)
return
c.Close()
for message := range c.sendChan {
if message == nil {
return
}
message.ErrorChan <- ClosedError
}
}
}
}
}

func (c *Client) Send(envelope *Envelope) {
func (c *Client) Send(envelope *Envelope) error {
if !c.IsRunning() {
return
return ClosedError
}
var err error
defer func() {
// maybe closed
if recover() != nil {

err = ClosedError
return
}
}()
c.sendChan <- envelope
return err
}

func (c *Client) GetEnvelope() *Envelope {
Expand All @@ -168,8 +200,8 @@ func (c *Client) IsRunning() bool {

func (c *Client) Close() {
c.once.Do(func() {
close(c.sendChan)
atomic.StoreUint32(&c.status, StatusStop)
close(c.sendChan)
if c.conn != nil {
c.conn.Close()
}
Expand Down
41 changes: 31 additions & 10 deletions ws/schemaless/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,22 @@ const (
)

type Config struct {
url string
chanLength uint
user string
password string
db string
readTimeout time.Duration
writeTimeout time.Duration
errorHandler func(error)
enableCompression bool
url string
chanLength uint
user string
password string
db string
readTimeout time.Duration
writeTimeout time.Duration
errorHandler func(error)
enableCompression bool
autoReconnect bool
reconnectIntervalMs int
reconnectRetryCount int
}

func NewConfig(url string, chanLength uint, opts ...func(*Config)) *Config {
c := Config{url: url, chanLength: chanLength}
c := Config{url: url, chanLength: chanLength, reconnectRetryCount: 3, reconnectIntervalMs: 2000}
for _, opt := range opts {
opt(&c)
}
Expand Down Expand Up @@ -71,3 +74,21 @@ func SetEnableCompression(enableCompression bool) func(*Config) {
c.enableCompression = enableCompression
}
}

func SetAutoReconnect(reconnect bool) func(*Config) {
return func(c *Config) {
c.autoReconnect = reconnect
}
}

func SetReconnectIntervalMs(reconnectIntervalMs int) func(*Config) {
return func(c *Config) {
c.reconnectIntervalMs = reconnectIntervalMs
}
}

func SetReconnectRetryCount(reconnectRetryCount int) func(*Config) {
return func(c *Config) {
c.reconnectRetryCount = reconnectRetryCount
}
}
Loading

0 comments on commit 6c57f27

Please sign in to comment.