Skip to content

Commit

Permalink
Add support for QWATCH.
Browse files Browse the repository at this point in the history
  • Loading branch information
JyotinderSingh committed Jul 17, 2024
1 parent 0ca5f8b commit 2a3e67c
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 18 deletions.
129 changes: 115 additions & 14 deletions qwatch.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,17 @@ import (
"github.com/dicedb/go-dice/internal/proto"
)

type KV struct {
Key string
Value interface{}
}

type QMessage struct {
Query string
Payload string
Updates []KV
}

func (m *QMessage) String() string {
return fmt.Sprintf("QMessage<%s: %s>", m.Query, m.Payload)
return fmt.Sprintf("QMessage(%v)", m.Updates)
}

// QWatch implements QWATCH commands. QMessage receiving is NOT safe
Expand Down Expand Up @@ -50,9 +54,8 @@ func (q *QWatch) init() {

func (q *QWatch) String() string {
var sb strings.Builder
for query := range q.queries {
sb.WriteString(query)
sb.WriteString("; ")
for query, args := range q.queries {
sb.WriteString(fmt.Sprintf("%s(%v); ", query, args))
}
return fmt.Sprintf("QWatch(%s)", sb.String())
}
Expand Down Expand Up @@ -164,7 +167,7 @@ func (q *QWatch) WatchQuery(ctx context.Context, query string, args ...interface
q.mu.Lock()
defer q.mu.Unlock()

err := q.watchQuery(ctx, "qwatch", query, args...)
err := q.watchQuery(ctx, "QWATCH", query, args...)
if q.queries == nil {
q.queries = make(map[string][]interface{})
}
Expand All @@ -182,23 +185,36 @@ func (q *QWatch) watchQuery(ctx context.Context, redisCmd string, query string,
q.releaseConn(ctx, cn, err, false)
return err
}

func (q *QWatch) newQMessage(reply interface{}) (interface{}, error) {
switch reply := reply.(type) {
case string:
return &Pong{
Payload: reply,
}, nil
case []interface{}:
switch kind := reply[0].(string); kind {
if len(reply) == 0 {
return nil, fmt.Errorf("redis: empty qwatch message")
}

kind, ok := reply[0].(string)
if !ok {
// If the first element is not a string, assume it's a qwatch message
return q.processQWatchMessage(reply)
}

switch kind {
case "qwatch":
return &QMessage{
Query: reply[1].(string),
Payload: reply[2].(string),
}, nil
return q.processQWatchMessage(reply[1:])
case "pong":
if len(reply) < 2 {
return nil, fmt.Errorf("redis: invalid pong message format")
}
payload, ok := reply[1].(string)
if !ok {
return nil, fmt.Errorf("redis: invalid pong payload type")
}
return &Pong{
Payload: reply[1].(string),
Payload: payload,
}, nil
default:
return nil, fmt.Errorf("redis: unsupported qwatch message: %q", kind)
Expand All @@ -208,6 +224,32 @@ func (q *QWatch) newQMessage(reply interface{}) (interface{}, error) {
}
}

func (q *QWatch) processQWatchMessage(payload interface{}) (*QMessage, error) {
updates := make([]KV, 0)

switch data := payload.(type) {
case []interface{}:
for _, update := range data {
kv, ok := update.([]interface{})
if !ok || len(kv) != 2 {
return nil, fmt.Errorf("redis: invalid key-value pair in qwatch message")
}
key, ok := kv[0].(string)
if !ok {
return nil, fmt.Errorf("redis: invalid key type in qwatch message")
}
value := kv[1]
updates = append(updates, KV{Key: key, Value: value})
}
default:
return nil, fmt.Errorf("redis: unsupported qwatch message payload: %T", payload)
}

return &QMessage{
Updates: updates,
}, nil
}

// ReceiveTimeout acts like Receive but returns an error if message
// is not received in time. This is low-level API and in most cases
// Channel should be used instead.
Expand Down Expand Up @@ -301,6 +343,36 @@ type qChannel struct {

type QChannelOption func(c *qChannel)

// WithQChannelSize specifies the Go chan size that is used to buffer incoming messages.
//
// The default is 100 messages.
func WithQChannelSize(size int) QChannelOption {
return func(c *qChannel) {
c.chanSize = size
}
}

// WithQChannelHealthCheckInterval specifies the health check interval.
// PubSub will ping Redis Server if it does not receive any messages within the interval.
// To disable health check, use zero interval.
//
// The default is 3 seconds.
func WithQChannelHealthCheckInterval(d time.Duration) QChannelOption {
return func(c *qChannel) {
c.checkInterval = d
}
}

// WithQChannelSendTimeout specifies the channel send timeout after which
// the message is dropped.
//
// The default is 60 seconds.
func WithQChannelSendTimeout(d time.Duration) QChannelOption {
return func(c *qChannel) {
c.chanSendTimeout = d
}
}

func newWatchChannel(qwatch *QWatch, opts ...QChannelOption) *qChannel {
c := &qChannel{
qwatch: qwatch,
Expand Down Expand Up @@ -417,3 +489,32 @@ func (c *qChannel) initMsgChan() {
}
}()
}

// UnwatchQuery unsubscribes the client from the specified query.
// It returns an error if unsubscription fails.
func (q *QWatch) UnwatchQuery(ctx context.Context, query string) error {
q.mu.Lock()
defer q.mu.Unlock()

err := q.unwatchQuery(ctx, "QUNWATCH", query)
if err == nil {
delete(q.queries, query)
}
return err
}

func (q *QWatch) unwatchQuery(ctx context.Context, redisCmd string, query string) error {
cn, err := q.conn(ctx, query)
if err != nil {
return err
}

err = q._unwatchQuery(ctx, cn, redisCmd, query)
q.releaseConn(ctx, cn, err, false)
return err
}

func (q *QWatch) _unwatchQuery(ctx context.Context, cn *pool.Conn, redisCmd string, query string) error {
cmd := NewSliceCmd(ctx, redisCmd, query)
return q.writeCmd(ctx, cn, cmd)
}
6 changes: 2 additions & 4 deletions redis.go
Original file line number Diff line number Diff line change
Expand Up @@ -783,10 +783,8 @@ func (c *Client) Subscribe(ctx context.Context, channels ...string) *PubSub {
return pubsub
}

func (c *Client) QWatch(ctx context.Context, query string, args ...interface{}) *QWatch {
qwatch := c.qwatch()
_ = qwatch.WatchQuery(ctx, query, args...)
return qwatch
func (c *Client) QWatch(ctx context.Context) *QWatch {
return c.qwatch()
}

// PSubscribe subscribes the client to the given patterns.
Expand Down

0 comments on commit 2a3e67c

Please sign in to comment.