From be8f2de7ab95b4e5ba5a8f9e65913d4b394b826f Mon Sep 17 00:00:00 2001 From: psr Date: Sat, 28 Sep 2024 19:11:05 +0530 Subject: [PATCH] Added Websocket support (#690) --- config/config.go | 4 + go.mod | 1 + go.sum | 2 + .../commands/websocket/main_test.go | 46 +++ .../commands/websocket/set_test.go | 262 ++++++++++++++++++ integration_tests/commands/websocket/setup.go | 131 +++++++++ internal/errors/errors.go | 3 +- internal/eval/execute.go | 11 +- internal/ops/store_op.go | 15 +- internal/server/utils/constants.go | 1 + internal/server/utils/redisCmdAdapter.go | 28 +- internal/server/utils/redisCmdAdapter_test.go | 71 +++++ internal/server/websocketServer.go | 192 +++++++++++++ internal/shard/shard_thread.go | 2 +- main.go | 21 ++ 15 files changed, 779 insertions(+), 11 deletions(-) create mode 100644 integration_tests/commands/websocket/main_test.go create mode 100644 integration_tests/commands/websocket/set_test.go create mode 100644 integration_tests/commands/websocket/setup.go create mode 100644 internal/server/websocketServer.go diff --git a/config/config.go b/config/config.go index a3467a7a1..81aa63370 100644 --- a/config/config.go +++ b/config/config.go @@ -32,6 +32,10 @@ var ( EnableMultiThreading = false EnableHTTP = true HTTPPort = 8082 + + EnableWebsocket = true + WebsocketPort = 8379 + // if RequirePass is set to an empty string, no authentication is required RequirePass = utils.EmptyStr diff --git a/go.mod b/go.mod index fc604490c..bf4955e47 100644 --- a/go.mod +++ b/go.mod @@ -45,6 +45,7 @@ require ( github.com/cockroachdb/swiss v0.0.0-20240612210725-f4de07ae6964 github.com/dgryski/go-farm v0.0.0-20200201041132-a6ae2369ad13 github.com/dicedb/go-dice v0.0.0-20240820180649-d97f15fca831 + github.com/gorilla/websocket v1.5.3 github.com/google/go-cmp v0.6.0 github.com/ohler55/ojg v1.24.0 github.com/pelletier/go-toml/v2 v2.2.3 diff --git a/go.sum b/go.sum index fc89d51ae..ae7d60a9e 100644 --- a/go.sum +++ b/go.sum @@ -41,6 +41,8 @@ github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg= github.com/google/btree v1.1.3/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= +github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= github.com/klauspost/cpuid/v2 v2.0.9 h1:lgaqFMSdTdQYdZ04uHyN2d/eKdOMyi2YLSvlQIBFYa4= diff --git a/integration_tests/commands/websocket/main_test.go b/integration_tests/commands/websocket/main_test.go new file mode 100644 index 000000000..83380f479 --- /dev/null +++ b/integration_tests/commands/websocket/main_test.go @@ -0,0 +1,46 @@ +package websocket + +import ( + "context" + "log/slog" + "os" + "sync" + "testing" + "time" + + "github.com/dicedb/dice/internal/logger" +) + +func TestMain(m *testing.M) { + logger := logger.New(logger.Opts{WithTimestamp: false}) + slog.SetDefault(logger) + var wg sync.WaitGroup + + // Run the test server + // This is a synchronous method, because internally it + // checks for available port and then forks a goroutine + // to start the server + opts := TestServerOptions{ + Port: 8380, + Logger: logger, + } + RunWebsocketServer(context.Background(), &wg, opts) + + // Wait for the server to start + time.Sleep(2 * time.Second) + + executor := NewWebsocketCommandExecutor() + + // Run the test suite + exitCode := m.Run() + + // abort + conn := executor.ConnectToServer() + executor.FireCommand(conn, WebsocketCommand{ + Message: "abort", + }) + executor.DisconnectServer(conn) + + wg.Wait() + os.Exit(exitCode) +} diff --git a/integration_tests/commands/websocket/set_test.go b/integration_tests/commands/websocket/set_test.go new file mode 100644 index 000000000..f897ea8f3 --- /dev/null +++ b/integration_tests/commands/websocket/set_test.go @@ -0,0 +1,262 @@ +package websocket + +import ( + "fmt" + "strconv" + "strings" + "testing" + "time" + + testifyAssert "github.com/stretchr/testify/assert" + + "gotest.tools/v3/assert" +) + +type TestCase struct { + name string + commands []WebsocketCommand + expected []interface{} +} + +func TestSet(t *testing.T) { + exec := NewWebsocketCommandExecutor() + + testCases := []TestCase{ + { + name: "Set and Get Simple Value", + commands: []WebsocketCommand{ + {Message: "set k v"}, + {Message: "get k"}, + }, + expected: []interface{}{"OK", "v"}, + }, + { + name: "Set and Get Integer Value", + commands: []WebsocketCommand{ + {Message: "set k 123456789"}, + {Message: "get k"}, + }, + expected: []interface{}{"OK", float64(1.23456789e+08)}, + }, + { + name: "Overwrite Existing Key", + commands: []WebsocketCommand{ + {Message: "set k v1"}, + {Message: "set k 5"}, + {Message: "get k"}, + }, + expected: []interface{}{"OK", "OK", float64(5)}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + conn := exec.ConnectToServer() + // delete existing key + _, err := exec.FireCommand(conn, WebsocketCommand{ + Message: "del k", + }) + testifyAssert.NoError(t, err) + + for i, cmd := range tc.commands { + result, err := exec.FireCommand(conn, cmd) + testifyAssert.NoError(t, err) + assert.DeepEqual(t, tc.expected[i], result) + } + }) + } +} + +func TestSetWithOptions(t *testing.T) { + exec := NewWebsocketCommandExecutor() + expiryTime := strconv.FormatInt(time.Now().Add(1*time.Minute).UnixMilli(), 10) + + testCases := []TestCase{ + { + name: "Set with EX option", + commands: []WebsocketCommand{ + {Message: "set k v ex 3"}, + {Message: "get k"}, + {Message: "sleep 3"}, + {Message: "get k"}, + }, + expected: []interface{}{"OK", "v", "OK", "(nil)"}, + }, + { + name: "Set with PX option", + commands: []WebsocketCommand{ + {Message: "set k v px 2000"}, + {Message: "get k"}, + {Message: "sleep 3"}, + {Message: "get k"}, + }, + expected: []interface{}{"OK", "v", "OK", "(nil)"}, + }, + { + name: "Set with EX and PX option", + commands: []WebsocketCommand{ + {Message: "set k v ex 2 px 2000"}, + }, + expected: []interface{}{"ERR syntax error"}, + }, + { + name: "XX on non-existing key", + commands: []WebsocketCommand{ + {Message: "del k"}, + {Message: "set k v xx true"}, + {Message: "get k"}, + }, + expected: []interface{}{float64(0), "(nil)", "(nil)"}, + }, + { + name: "NX on non-existing key", + commands: []WebsocketCommand{ + {Message: "del k"}, + {Message: "set k v nx"}, + {Message: "get k"}, + }, + expected: []interface{}{float64(0), "OK", "v"}, + }, + { + name: "NX on existing key", + commands: []WebsocketCommand{ + {Message: "del k"}, + {Message: "set k v nx"}, + {Message: "get k"}, + {Message: "set k v nx"}, + }, + expected: []interface{}{float64(0), "OK", "v", "(nil)"}, + }, + { + name: "PXAT option", + commands: []WebsocketCommand{ + {Message: fmt.Sprintf("set k v pxat %v", expiryTime)}, + {Message: "get k"}, + }, + expected: []interface{}{"OK", "v"}, + }, + { + name: "PXAT option with delete", + commands: []WebsocketCommand{ + {Message: fmt.Sprintf("set k1 v1 pxat %v", expiryTime)}, + {Message: "get k1"}, + {Message: "sleep 4"}, + {Message: "del k1"}, + }, + expected: []interface{}{"OK", "v1", "OK", float64(1)}, + }, + { + name: "PXAT option with invalid unix time ms", + commands: []WebsocketCommand{ + {Message: "set k2 v2 pxat 123123"}, + {Message: "get k2"}, + }, + expected: []interface{}{"OK", "(nil)"}, + }, + { + name: "XX on existing key", + commands: []WebsocketCommand{ + {Message: "set k v2"}, + {Message: "set k v2 xx"}, + {Message: "get k"}, + }, + expected: []interface{}{"OK", "OK", "v2"}, + }, + { + name: "Multiple XX operations", + commands: []WebsocketCommand{ + {Message: "set k v1"}, + {Message: "set k v2 xx"}, + {Message: "set k v3 xx"}, + {Message: "get k"}, + }, + expected: []interface{}{"OK", "OK", "OK", "v3"}, + }, + { + name: "EX option", + commands: []WebsocketCommand{ + {Message: "set k v ex 1"}, + {Message: "get k"}, + {Message: "sleep 2"}, + {Message: "get k"}, + }, + expected: []interface{}{"OK", "v", "OK", "(nil)"}, + }, + { + name: "XX option", + commands: []WebsocketCommand{ + {Message: "set k v xx ex 1"}, + {Message: "get k"}, + {Message: "sleep 2"}, + {Message: "get k"}, + {Message: "set k v xx ex 1"}, + {Message: "get k"}, + }, + expected: []interface{}{"(nil)", "(nil)", "OK", "(nil)", "(nil)", "(nil)"}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + conn := exec.ConnectToServer() + exec.FireCommand(conn, WebsocketCommand{Message: "del k"}) + exec.FireCommand(conn, WebsocketCommand{Message: "del k1"}) + exec.FireCommand(conn, WebsocketCommand{Message: "del k2"}) + for i, cmd := range tc.commands { + result, err := exec.FireCommand(conn, cmd) + assert.NilError(t, err) + assert.Equal(t, tc.expected[i], result) + } + }) + } +} + +func TestSetWithExat(t *testing.T) { + exec := NewWebsocketCommandExecutor() + Etime := strconv.FormatInt(time.Now().Unix()+5, 10) + BadTime := "123123" + + testCases := []TestCase{ + { + name: "SET with EXAT", + commands: []WebsocketCommand{ + {Message: "del k"}, + {Message: fmt.Sprintf("set k v exat %v", Etime)}, + {Message: "get k"}, + {Message: "ttl k"}, + }, + expected: []interface{}{float64(0), "OK", "v", float64(4)}, + }, + { + name: "SET with invalid EXAT expires key immediately", + commands: []WebsocketCommand{ + {Message: "del k"}, + {Message: fmt.Sprintf("set k v exat %v", BadTime)}, + {Message: "get k"}, + {Message: "ttl k"}, + }, + expected: []interface{}{float64(0), "OK", "(nil)", float64(-2)}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + conn := exec.ConnectToServer() + // Ensure key is deleted before the test + exec.FireCommand(conn, WebsocketCommand{ + Message: "del k", + }) + + for i, cmd := range tc.commands { + result, err := exec.FireCommand(conn, cmd) + assert.NilError(t, err) + command := strings.Split(cmd.Message, "") + if command[0] == "ttl" { + assert.Assert(t, result.(float64) <= tc.expected[i].(float64)) + } else { + assert.DeepEqual(t, tc.expected[i], result) + } + } + }) + } +} diff --git a/integration_tests/commands/websocket/setup.go b/integration_tests/commands/websocket/setup.go new file mode 100644 index 000000000..8285a008a --- /dev/null +++ b/integration_tests/commands/websocket/setup.go @@ -0,0 +1,131 @@ +package websocket + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "log" + "log/slog" + "net/http" + "sync" + "time" + + "github.com/dicedb/dice/config" + derrors "github.com/dicedb/dice/internal/errors" + "github.com/dicedb/dice/internal/server" + "github.com/dicedb/dice/internal/shard" + dstore "github.com/dicedb/dice/internal/store" + "github.com/gorilla/websocket" +) + +const url = "ws://localhost:8380" + +type TestServerOptions struct { + Port int + Logger *slog.Logger +} + +type CommandExecutor interface { + FireCommand(cmd string) interface{} + Name() string +} + +type WebsocketCommandExecutor struct { + baseURL string + websocketClient *http.Client + upgrader websocket.Upgrader +} + +func NewWebsocketCommandExecutor() *WebsocketCommandExecutor { + return &WebsocketCommandExecutor{ + baseURL: url, + websocketClient: &http.Client{ + Timeout: time.Second * 100, + }, + upgrader: websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { return true }, + }, + } +} + +type WebsocketCommand struct { + Message string +} + +func (e *WebsocketCommandExecutor) ConnectToServer() *websocket.Conn { + // connect with Websocket Server + conn, resp, err := websocket.DefaultDialer.Dial(url, nil) + if err != nil { + return nil + } + if resp != nil { + resp.Body.Close() + } + return conn +} + +func (e *WebsocketCommandExecutor) FireCommand(conn *websocket.Conn, cmd WebsocketCommand) (interface{}, error) { + command := []byte(cmd.Message) + + // send request + err := conn.WriteMessage(websocket.TextMessage, command) + if err != nil { + return nil, fmt.Errorf("error sending websocket request: %v", err) + } + + // read the response + _, resp, err := conn.ReadMessage() + if err != nil { + return nil, fmt.Errorf("error reading websocket response: %v", err) + } + + // marshal to json + var respJSON interface{} + if err = json.Unmarshal(resp, &respJSON); err != nil { + return nil, fmt.Errorf("error unmarshalling response: %v", err) + } + + return respJSON, nil +} + +func (e *WebsocketCommandExecutor) DisconnectServer(conn *websocket.Conn) { + conn.Close() +} + +func (e *WebsocketCommandExecutor) Name() string { + return "Websocket" +} + +func RunWebsocketServer(ctx context.Context, wg *sync.WaitGroup, opt TestServerOptions) { + config.DiceConfig.Network.IOBufferLength = 16 + config.DiceConfig.Server.WriteAOFOnCleanup = false + + // Initialize the WebsocketServer + globalErrChannel := make(chan error) + watchChan := make(chan dstore.WatchEvent, config.DiceConfig.Server.KeysLimit) + shardManager := shard.NewShardManager(1, watchChan, globalErrChannel, opt.Logger) + config.WebsocketPort = opt.Port + testServer := server.NewWebSocketServer(shardManager, watchChan, opt.Logger) + + shardManagerCtx, cancelShardManager := context.WithCancel(ctx) + wg.Add(1) + go func() { + defer wg.Done() + shardManager.Run(shardManagerCtx) + }() + + // Start the server in a goroutine + wg.Add(1) + go func() { + defer wg.Done() + srverr := testServer.Run(ctx) + if srverr != nil { + cancelShardManager() + if errors.Is(srverr, derrors.ErrAborted) { + return + } + log.Printf("Websocket test server encountered an error: %v", srverr) + } + }() +} diff --git a/internal/errors/errors.go b/internal/errors/errors.go index ac8db89c6..0e5764355 100644 --- a/internal/errors/errors.go +++ b/internal/errors/errors.go @@ -32,7 +32,8 @@ const ( ) var ( - ErrAborted = errors.New("server received ABORT command") + ErrAborted = errors.New("server received ABORT command") + ErrEmptyCommand = errors.New("empty command") ) type DiceError struct { diff --git a/internal/eval/execute.go b/internal/eval/execute.go index fcc0b5c27..98c1a0474 100644 --- a/internal/eval/execute.go +++ b/internal/eval/execute.go @@ -11,12 +11,21 @@ import ( dstore "github.com/dicedb/dice/internal/store" ) -func ExecuteCommand(c *cmd.RedisCmd, client *comm.Client, store *dstore.Store, httpOp bool) EvalResponse { +func ExecuteCommand(c *cmd.RedisCmd, client *comm.Client, store *dstore.Store, httpOp, websocketOp bool) EvalResponse { diceCmd, ok := DiceCmds[c.Cmd] if !ok { return EvalResponse{Result: diceerrors.NewErrWithFormattedMessage("unknown command '%s', with args beginning with: %s", c.Cmd, strings.Join(c.Args, " ")), Error: nil} } + // Till the time we refactor to handle QWATCH differently for websocket + if websocketOp { + if diceCmd.IsMigrated { + return diceCmd.NewEval(c.Args, store) + } + + return EvalResponse{Result: diceCmd.Eval(c.Args, store), Error: nil} + } + // Temporary logic till we move all commands to new eval logic. // MigratedDiceCmds map contains refactored eval commands // For any command we will first check in the exisiting map diff --git a/internal/ops/store_op.go b/internal/ops/store_op.go index 56278c7ce..a11d7af4a 100644 --- a/internal/ops/store_op.go +++ b/internal/ops/store_op.go @@ -7,13 +7,14 @@ import ( ) type StoreOp struct { - SeqID uint8 // SeqID is the sequence id of the operation within a single request (optional, may be used for ordering) - RequestID uint32 // RequestID identifies the request that this StoreOp belongs to - Cmd *cmd.RedisCmd // Cmd is the atomic Store command (e.g., GET, SET) - ShardID uint8 // ShardID of the shard on which the Store command will be executed - WorkerID string // WorkerID is the ID of the worker that sent this Store operation - Client *comm.Client // Client that sent this Store operation. TODO: This can potentially replace the WorkerID in the future - HTTPOp bool // HTTPOp is true if this Store operation is a HTTP operation + SeqID uint8 // SeqID is the sequence id of the operation within a single request (optional, may be used for ordering) + RequestID uint32 // RequestID identifies the request that this StoreOp belongs to + Cmd *cmd.RedisCmd // Cmd is the atomic Store command (e.g., GET, SET) + ShardID uint8 // ShardID of the shard on which the Store command will be executed + WorkerID string // WorkerID is the ID of the worker that sent this Store operation + Client *comm.Client // Client that sent this Store operation. TODO: This can potentially replace the WorkerID in the future + HTTPOp bool // HTTPOp is true if this Store operation is a HTTP operation + WebsocketOp bool // WebsocketOp is true if this Store operaton is a Websocket operation } // StoreResponse represents the response of a Store operation. diff --git a/internal/server/utils/constants.go b/internal/server/utils/constants.go index a3687501b..1254f830c 100644 --- a/internal/server/utils/constants.go +++ b/internal/server/utils/constants.go @@ -12,4 +12,5 @@ const ( NullType string = "null" UnknownType string = "unknown" NumberZeroValue int = 0 + JSONIngest string = "JSON.INGEST" ) diff --git a/internal/server/utils/redisCmdAdapter.go b/internal/server/utils/redisCmdAdapter.go index 3b1db61b1..7bf4b05be 100644 --- a/internal/server/utils/redisCmdAdapter.go +++ b/internal/server/utils/redisCmdAdapter.go @@ -9,9 +9,11 @@ import ( "strings" "github.com/dicedb/dice/internal/cmd" + diceerrors "github.com/dicedb/dice/internal/errors" ) const ( + Command = "command" Key = "key" Keys = "keys" KeyPrefix = "key_prefix" @@ -43,7 +45,7 @@ func ParseHTTPRequest(r *http.Request) (*cmd.RedisCmd, error) { queryParams := r.URL.Query() keyPrefix := queryParams.Get(KeyPrefix) - if keyPrefix != "" && command == "JSON.INGEST" { + if keyPrefix != "" && command == JSONIngest { args = append(args, keyPrefix) } // Step 1: Handle JSON body if present @@ -152,3 +154,27 @@ func ParseHTTPRequest(r *http.Request) (*cmd.RedisCmd, error) { Args: args, }, nil } + +func ParseWebsocketMessage(msg []byte) (*cmd.RedisCmd, error) { + cmdStr := string(msg) + cmdStr = strings.TrimSpace(cmdStr) + + if cmdStr == "" { + return nil, diceerrors.ErrEmptyCommand + } + + cmdArr := strings.Split(cmdStr, " ") + command := strings.ToUpper(cmdArr[0]) + cmdArr = cmdArr[1:] // args + + // if key prefix is empty for JSON.INGEST command + // add "" to cmdArr + if command == JSONIngest && len(cmdArr) == 2 { + cmdArr = append([]string{""}, cmdArr...) + } + + return &cmd.RedisCmd{ + Cmd: command, + Args: cmdArr, + }, nil +} diff --git a/internal/server/utils/redisCmdAdapter_test.go b/internal/server/utils/redisCmdAdapter_test.go index ead0e5f8a..1b85763a4 100644 --- a/internal/server/utils/redisCmdAdapter_test.go +++ b/internal/server/utils/redisCmdAdapter_test.go @@ -186,3 +186,74 @@ func TestParseHTTPRequest(t *testing.T) { }) } } + +func TestParseWebsocketMessage(t *testing.T) { + commands := []struct { + name string + message string + expectedCmd string + expectedArgs []string + }{ + { + name: "Test SET command with nx flag", + message: "set k1 v1 nx", + expectedCmd: "SET", + expectedArgs: []string{"k1", "v1", "nx"}, + }, + { + name: "Test GET command", + message: "get k1", + expectedCmd: "GET", + expectedArgs: []string{"k1"}, + }, + { + name: "Test JSON.SET command", + message: `json.set k1 . {"field":"value"}`, + expectedCmd: "JSON.SET", + expectedArgs: []string{"k1", ".", `{"field":"value"}`}, + }, + { + name: "Test JSON.GET command", + message: "json.get k1", + expectedCmd: "JSON.GET", + expectedArgs: []string{"k1"}, + }, + { + name: "Test HSET command with JSON body", + message: "hset hashkey f1 v1", + expectedCmd: "HSET", + expectedArgs: []string{"hashkey", "f1", "v1"}, + }, + { + name: "Test JSON.INGEST command with key prefix", + message: `json.ingest gmtr_ $..field {"field":"value"}`, + expectedCmd: "JSON.INGEST", + expectedArgs: []string{"gmtr_", "$..field", `{"field":"value"}`}, + }, + { + name: "Test JSON.INGEST command without key prefix", + message: `json.ingest $..field {"field":"value"}`, + expectedCmd: "JSON.INGEST", + expectedArgs: []string{"", "$..field", `{"field":"value"}`}, + }, + } + + for _, tc := range commands { + t.Run(tc.name, func(t *testing.T) { + // parse websocket message + redisCmd, err := ParseWebsocketMessage([]byte(tc.message)) + assert.NoError(t, err) + + expectedCmd := &cmd.RedisCmd{ + Cmd: tc.expectedCmd, + Args: tc.expectedArgs, + } + + // Check command match + assert.Equal(t, expectedCmd.Cmd, redisCmd.Cmd) + + // Check arguments match, regardless of order + assert.ElementsMatch(t, expectedCmd.Args, redisCmd.Args, "The parsed arguments should match the expected arguments, ignoring order") + }) + } +} diff --git a/internal/server/websocketServer.go b/internal/server/websocketServer.go new file mode 100644 index 000000000..9e04eeaee --- /dev/null +++ b/internal/server/websocketServer.go @@ -0,0 +1,192 @@ +package server + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "log/slog" + "net/http" + "sync" + "time" + + "github.com/dicedb/dice/config" + "github.com/dicedb/dice/internal/clientio" + diceerrors "github.com/dicedb/dice/internal/errors" + "github.com/dicedb/dice/internal/ops" + "github.com/dicedb/dice/internal/querywatcher" + "github.com/dicedb/dice/internal/server/utils" + "github.com/dicedb/dice/internal/shard" + dstore "github.com/dicedb/dice/internal/store" + "github.com/gorilla/websocket" +) + +const Qwatch = "QWATCH" +const Qunwatch = "QUNWATCH" +const Subscribe = "SUBSCRIBE" + +var unimplementedCommandsWebsocket map[string]bool = map[string]bool{ + Qwatch: true, + Qunwatch: true, + Subscribe: true, +} + +type WebsocketServer struct { + querywatcher *querywatcher.QueryManager + shardManager *shard.ShardManager + ioChan chan *ops.StoreResponse + watchChan chan dstore.WatchEvent + websocketServer *http.Server + upgrader websocket.Upgrader + logger *slog.Logger + shutdownChan chan struct{} +} + +func NewWebSocketServer(shardManager *shard.ShardManager, watchChan chan dstore.WatchEvent, logger *slog.Logger) *WebsocketServer { + mux := http.NewServeMux() + srv := &http.Server{ + Addr: fmt.Sprintf(":%d", config.WebsocketPort), + Handler: mux, + ReadHeaderTimeout: 5 * time.Second, + } + + upgrader := websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { return true }, + } + + websocketServer := &WebsocketServer{ + shardManager: shardManager, + querywatcher: querywatcher.NewQueryManager(logger), + ioChan: make(chan *ops.StoreResponse, 1000), + watchChan: watchChan, + websocketServer: srv, + upgrader: upgrader, + logger: logger, + shutdownChan: make(chan struct{}), + } + + mux.HandleFunc("/", websocketServer.WebsocketHandler) + mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) { + _, err := w.Write([]byte("OK")) + if err != nil { + return + } + }) + return websocketServer +} + +func (s *WebsocketServer) Run(ctx context.Context) error { + var wg sync.WaitGroup + var err error + + websocketCtx, cancelWebsocket := context.WithCancel(ctx) + defer cancelWebsocket() + + s.shardManager.RegisterWorker("wsServer", s.ioChan) + + wg.Add(1) + go func() { + defer wg.Done() + select { + case <-ctx.Done(): + case <-s.shutdownChan: + err = diceerrors.ErrAborted + s.logger.Debug("Shutting down Websocket Server") + } + + shutdownErr := s.websocketServer.Shutdown(websocketCtx) + if shutdownErr != nil { + s.logger.Error("Websocket Server shutdown failed:", slog.Any("error", err)) + return + } + }() + + wg.Add(1) + go func() { + defer wg.Done() + s.logger.Info("Websocket Server running", slog.String("port", s.websocketServer.Addr[1:])) + err = s.websocketServer.ListenAndServe() + }() + + wg.Wait() + return err +} + +func (s *WebsocketServer) WebsocketHandler(w http.ResponseWriter, r *http.Request) { + // upgrade http connection to websocket + conn, err := s.upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + + // closing handshake + defer func() { + _ = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "close 1000 (normal)")) + conn.Close() + }() + + for { + // read incoming message + _, msg, err := conn.ReadMessage() + if err != nil { + writeResponse(conn, []byte("error: command reading failed")) + continue + } + + // parse message to dice command + redisCmd, err := utils.ParseWebsocketMessage(msg) + if errors.Is(err, diceerrors.ErrEmptyCommand) { + continue + } else if err != nil { + writeResponse(conn, []byte("error: parsing failed")) + continue + } + + if redisCmd.Cmd == Abort { + close(s.shutdownChan) + break + } + + if unimplementedCommandsWebsocket[redisCmd.Cmd] { + writeResponse(conn, []byte("Command is not implemented with Websocket")) + continue + } + + // send request to Shard Manager + s.shardManager.GetShard(0).ReqChan <- &ops.StoreOp{ + Cmd: redisCmd, + WorkerID: "wsServer", + ShardID: 0, + WebsocketOp: true, + } + + // Wait for response + resp := <-s.ioChan + var rp *clientio.RESPParser + if resp.EvalResponse.Error != nil { + rp = clientio.NewRESPParser(bytes.NewBuffer([]byte(resp.EvalResponse.Error.Error()))) + } else { + rp = clientio.NewRESPParser(bytes.NewBuffer(resp.EvalResponse.Result.([]byte))) + } + + val, err := rp.DecodeOne() + if err != nil { + writeResponse(conn, []byte("error: decoding response")) + continue + } + + respBytes, err := json.Marshal(val) + if err != nil { + writeResponse(conn, []byte("error: marshaling json response")) + continue + } + + // Write response + writeResponse(conn, respBytes) + } +} + +func writeResponse(conn *websocket.Conn, text []byte) { + _ = conn.WriteMessage(websocket.TextMessage, text) +} diff --git a/internal/shard/shard_thread.go b/internal/shard/shard_thread.go index 088688ea4..aea14b792 100644 --- a/internal/shard/shard_thread.go +++ b/internal/shard/shard_thread.go @@ -89,7 +89,7 @@ func (shard *ShardThread) unregisterWorker(workerID string) { // processRequest processes a Store operation for the shard. func (shard *ShardThread) processRequest(op *ops.StoreOp) { - resp := eval.ExecuteCommand(op.Cmd, op.Client, shard.store, op.HTTPOp) + resp := eval.ExecuteCommand(op.Cmd, op.Client, shard.store, op.HTTPOp, op.WebsocketOp) shard.workerMutex.RLock() workerChan, ok := shard.workerMap[op.WorkerID] diff --git a/main.go b/main.go index 4e018c5d6..e462d3796 100644 --- a/main.go +++ b/main.go @@ -27,6 +27,7 @@ func init() { flag.BoolVar(&config.EnableHTTP, "enable-http", true, "run server in HTTP mode as well") flag.BoolVar(&config.EnableMultiThreading, "enable-multithreading", false, "run server in multithreading mode") flag.IntVar(&config.HTTPPort, "http-port", 8082, "HTTP port for the dice server") + flag.IntVar(&config.WebsocketPort, "websocket-port", 8379, "Websocket port for the dice server") flag.StringVar(&config.RequirePass, "requirepass", config.RequirePass, "enable authentication for the default user") flag.StringVar(&config.CustomConfigFilePath, "o", config.CustomConfigFilePath, "dir path to create the config file") flag.StringVar(&config.ConfigFileLocation, "c", config.ConfigFileLocation, "file path of the config file") @@ -182,6 +183,26 @@ func main() { }() } + websocketServer := server.NewWebSocketServer(shardManager, watchChan, logr) + serverWg.Add(1) + go func() { + defer serverWg.Done() + // Run the Websocket server + err := websocketServer.Run(ctx) + if err != nil { + if errors.Is(err, context.Canceled) { + logr.Debug("Websocket Server was canceled") + } else if errors.Is(err, diceerrors.ErrAborted) { + logr.Debug("Websocket received abort command") + } else { + logr.Error("Websocket Server error", "error", err) + } + serverErrCh <- err + } else { + logr.Debug("Websocket Server stopped without error") + } + }() + go func() { serverWg.Wait() close(serverErrCh) // Close the channel when both servers are done