From 2a3e67c8bda0071b1a562c23e21f1c65c27586ff Mon Sep 17 00:00:00 2001 From: Jyotinder Singh Date: Wed, 17 Jul 2024 11:09:02 +0530 Subject: [PATCH] Add support for QWATCH. --- qwatch.go | 129 ++++++++++++++++++++++++++++++++++++++++++++++++------ redis.go | 6 +-- 2 files changed, 117 insertions(+), 18 deletions(-) diff --git a/qwatch.go b/qwatch.go index 9a458e44f..7ec5078c4 100644 --- a/qwatch.go +++ b/qwatch.go @@ -12,13 +12,17 @@ import ( "github.com/dicedb/go-dice/internal/proto" ) +type KV struct { + Key string + Value interface{} +} + type QMessage struct { - Query string - Payload string + Updates []KV } func (m *QMessage) String() string { - return fmt.Sprintf("QMessage<%s: %s>", m.Query, m.Payload) + return fmt.Sprintf("QMessage(%v)", m.Updates) } // QWatch implements QWATCH commands. QMessage receiving is NOT safe @@ -50,9 +54,8 @@ func (q *QWatch) init() { func (q *QWatch) String() string { var sb strings.Builder - for query := range q.queries { - sb.WriteString(query) - sb.WriteString("; ") + for query, args := range q.queries { + sb.WriteString(fmt.Sprintf("%s(%v); ", query, args)) } return fmt.Sprintf("QWatch(%s)", sb.String()) } @@ -164,7 +167,7 @@ func (q *QWatch) WatchQuery(ctx context.Context, query string, args ...interface q.mu.Lock() defer q.mu.Unlock() - err := q.watchQuery(ctx, "qwatch", query, args...) + err := q.watchQuery(ctx, "QWATCH", query, args...) if q.queries == nil { q.queries = make(map[string][]interface{}) } @@ -182,7 +185,6 @@ func (q *QWatch) watchQuery(ctx context.Context, redisCmd string, query string, q.releaseConn(ctx, cn, err, false) return err } - func (q *QWatch) newQMessage(reply interface{}) (interface{}, error) { switch reply := reply.(type) { case string: @@ -190,15 +192,29 @@ func (q *QWatch) newQMessage(reply interface{}) (interface{}, error) { Payload: reply, }, nil case []interface{}: - switch kind := reply[0].(string); kind { + if len(reply) == 0 { + return nil, fmt.Errorf("redis: empty qwatch message") + } + + kind, ok := reply[0].(string) + if !ok { + // If the first element is not a string, assume it's a qwatch message + return q.processQWatchMessage(reply) + } + + switch kind { case "qwatch": - return &QMessage{ - Query: reply[1].(string), - Payload: reply[2].(string), - }, nil + return q.processQWatchMessage(reply[1:]) case "pong": + if len(reply) < 2 { + return nil, fmt.Errorf("redis: invalid pong message format") + } + payload, ok := reply[1].(string) + if !ok { + return nil, fmt.Errorf("redis: invalid pong payload type") + } return &Pong{ - Payload: reply[1].(string), + Payload: payload, }, nil default: return nil, fmt.Errorf("redis: unsupported qwatch message: %q", kind) @@ -208,6 +224,32 @@ func (q *QWatch) newQMessage(reply interface{}) (interface{}, error) { } } +func (q *QWatch) processQWatchMessage(payload interface{}) (*QMessage, error) { + updates := make([]KV, 0) + + switch data := payload.(type) { + case []interface{}: + for _, update := range data { + kv, ok := update.([]interface{}) + if !ok || len(kv) != 2 { + return nil, fmt.Errorf("redis: invalid key-value pair in qwatch message") + } + key, ok := kv[0].(string) + if !ok { + return nil, fmt.Errorf("redis: invalid key type in qwatch message") + } + value := kv[1] + updates = append(updates, KV{Key: key, Value: value}) + } + default: + return nil, fmt.Errorf("redis: unsupported qwatch message payload: %T", payload) + } + + return &QMessage{ + Updates: updates, + }, nil +} + // ReceiveTimeout acts like Receive but returns an error if message // is not received in time. This is low-level API and in most cases // Channel should be used instead. @@ -301,6 +343,36 @@ type qChannel struct { type QChannelOption func(c *qChannel) +// WithQChannelSize specifies the Go chan size that is used to buffer incoming messages. +// +// The default is 100 messages. +func WithQChannelSize(size int) QChannelOption { + return func(c *qChannel) { + c.chanSize = size + } +} + +// WithQChannelHealthCheckInterval specifies the health check interval. +// PubSub will ping Redis Server if it does not receive any messages within the interval. +// To disable health check, use zero interval. +// +// The default is 3 seconds. +func WithQChannelHealthCheckInterval(d time.Duration) QChannelOption { + return func(c *qChannel) { + c.checkInterval = d + } +} + +// WithQChannelSendTimeout specifies the channel send timeout after which +// the message is dropped. +// +// The default is 60 seconds. +func WithQChannelSendTimeout(d time.Duration) QChannelOption { + return func(c *qChannel) { + c.chanSendTimeout = d + } +} + func newWatchChannel(qwatch *QWatch, opts ...QChannelOption) *qChannel { c := &qChannel{ qwatch: qwatch, @@ -417,3 +489,32 @@ func (c *qChannel) initMsgChan() { } }() } + +// UnwatchQuery unsubscribes the client from the specified query. +// It returns an error if unsubscription fails. +func (q *QWatch) UnwatchQuery(ctx context.Context, query string) error { + q.mu.Lock() + defer q.mu.Unlock() + + err := q.unwatchQuery(ctx, "QUNWATCH", query) + if err == nil { + delete(q.queries, query) + } + return err +} + +func (q *QWatch) unwatchQuery(ctx context.Context, redisCmd string, query string) error { + cn, err := q.conn(ctx, query) + if err != nil { + return err + } + + err = q._unwatchQuery(ctx, cn, redisCmd, query) + q.releaseConn(ctx, cn, err, false) + return err +} + +func (q *QWatch) _unwatchQuery(ctx context.Context, cn *pool.Conn, redisCmd string, query string) error { + cmd := NewSliceCmd(ctx, redisCmd, query) + return q.writeCmd(ctx, cn, cmd) +} diff --git a/redis.go b/redis.go index 1ac94d620..30d08afe5 100644 --- a/redis.go +++ b/redis.go @@ -783,10 +783,8 @@ func (c *Client) Subscribe(ctx context.Context, channels ...string) *PubSub { return pubsub } -func (c *Client) QWatch(ctx context.Context, query string, args ...interface{}) *QWatch { - qwatch := c.qwatch() - _ = qwatch.WatchQuery(ctx, query, args...) - return qwatch +func (c *Client) QWatch(ctx context.Context) *QWatch { + return c.qwatch() } // PSubscribe subscribes the client to the given patterns.