From f6cbfe135f01bda68df65ff2ab5aa174a4811816 Mon Sep 17 00:00:00 2001 From: John Hosie Date: Thu, 19 Oct 2023 22:54:47 +0100 Subject: [PATCH] namespace scoped web sockets Signed-off-by: John Hosie --- Makefile | 1 + internal/apiserver/server.go | 20 +++ internal/apiserver/server_test.go | 36 ++++ internal/coremsgs/en_error_messages.go | 1 + .../events/websockets/websocket_connection.go | 69 ++++++-- internal/events/websockets/websockets.go | 23 +++ internal/events/websockets/websockets_test.go | 165 +++++++++++++++++- .../websocketsmocks/web_sockets_namespaced.go | 33 ++++ 8 files changed, 327 insertions(+), 21 deletions(-) create mode 100644 mocks/websocketsmocks/web_sockets_namespaced.go diff --git a/Makefile b/Makefile index f675bc5a5..20d3e7f71 100644 --- a/Makefile +++ b/Makefile @@ -79,6 +79,7 @@ $(eval $(call makemock, internal/operations, Manager, operat $(eval $(call makemock, internal/multiparty, Manager, multipartymocks)) $(eval $(call makemock, internal/apiserver, FFISwaggerGen, apiservermocks)) $(eval $(call makemock, internal/apiserver, Server, apiservermocks)) +$(eval $(call makemock, internal/events/websockets, WebSocketsNamespaced, websocketsmocks)) firefly-nocgo: ${GOFILES} CGO_ENABLED=0 $(VGO) build -o ${BINARY_NAME}-nocgo -ldflags "-X main.buildDate=$(DATE) -X main.buildVersion=$(BUILD_VERSION) -X 'github.com/hyperledger/firefly/cmd.BuildVersionOverride=$(BUILD_VERSION)' -X 'github.com/hyperledger/firefly/cmd.BuildDate=$(DATE)' -X 'github.com/hyperledger/firefly/cmd.BuildCommit=$(GIT_REF)'" -tags=prod -tags=prod -v diff --git a/internal/apiserver/server.go b/internal/apiserver/server.go index 5dff12d6b..46a57ba22 100644 --- a/internal/apiserver/server.go +++ b/internal/apiserver/server.go @@ -385,6 +385,9 @@ func (as *apiServer) createMuxRouter(ctx context.Context, mgr namespace.Manager) ws.(*websockets.WebSockets).SetAuthorizer(mgr) r.HandleFunc(`/ws`, ws.(*websockets.WebSockets).ServeHTTP) + // namespace scoped web sockets + r.HandleFunc("/api/v1/namespaces/{ns}/ws", hf.APIWrapper(getNamespacedWebSocketHandler(ws.(*websockets.WebSockets), mgr))) + uiPath := config.GetString(coreconfig.UIPath) if uiPath != "" && config.GetBool(coreconfig.UIEnabled) { r.PathPrefix(`/ui`).Handler(newStaticHandler(uiPath, "index.html", `/ui`)) @@ -394,6 +397,23 @@ func (as *apiServer) createMuxRouter(ctx context.Context, mgr namespace.Manager) return r } +func getNamespacedWebSocketHandler(ws websockets.WebSocketsNamespaced, mgr namespace.Manager) ffapi.HandlerFunction { + return func(res http.ResponseWriter, req *http.Request) (status int, err error) { + + vars := mux.Vars(req) + namespace := vars["ns"] + or, err := mgr.Orchestrator(req.Context(), namespace, false) + if err != nil || or == nil { + return 404, i18n.NewError(req.Context(), coremsgs.Msg404NotFound) + } + + ws.ServeHTTPNamespaced(namespace, res, req) + + return 200, nil + } + +} + func (as *apiServer) notFoundHandler(res http.ResponseWriter, req *http.Request) (status int, err error) { res.Header().Add("Content-Type", "application/json") return 404, i18n.NewError(req.Context(), coremsgs.Msg404NotFound) diff --git a/internal/apiserver/server_test.go b/internal/apiserver/server_test.go index 8a5ce5ae7..8d61e03c0 100644 --- a/internal/apiserver/server_test.go +++ b/internal/apiserver/server_test.go @@ -20,6 +20,7 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" "io" "mime/multipart" @@ -43,6 +44,7 @@ import ( "github.com/hyperledger/firefly/mocks/namespacemocks" "github.com/hyperledger/firefly/mocks/orchestratormocks" "github.com/hyperledger/firefly/mocks/spieventsmocks" + "github.com/hyperledger/firefly/mocks/websocketsmocks" "github.com/hyperledger/firefly/pkg/core" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" @@ -513,3 +515,37 @@ func TestGetOrchestratorMissingTag(t *testing.T) { _, err := getOrchestrator(context.Background(), &namespacemocks.Manager{}, "", nil) assert.Regexp(t, "FF10437", err) } + +func TestGetNamespacedWebSocketHandler(t *testing.T) { + mgr, _, _ := newTestServer() + mwsns := &websocketsmocks.WebSocketsNamespaced{} + mwsns.On("ServeHTTPNamespaced", "ns1", mock.Anything, mock.Anything).Return() + + var b bytes.Buffer + req := httptest.NewRequest("GET", "/api/v1/namespaces/ns1/ws", &b) + req = mux.SetURLVars(req, map[string]string{"ns": "ns1"}) + req.Header.Set("Content-Type", "application/json; charset=utf-8") + res := httptest.NewRecorder() + + handler := getNamespacedWebSocketHandler(mwsns, mgr) + status, err := handler(res, req) + assert.NoError(t, err) + assert.Equal(t, 200, status) +} + +func TestGetNamespacedWebSocketHandlerUnknownNamespace(t *testing.T) { + mgr, _, _ := newTestServer() + mwsns := &websocketsmocks.WebSocketsNamespaced{} + + mgr.On("Orchestrator", mock.Anything, "unknown", false).Return(nil, errors.New("unknown namespace")).Maybe() + var b bytes.Buffer + req := httptest.NewRequest("GET", "/api/v1/namespaces/unknown/ws", &b) + req = mux.SetURLVars(req, map[string]string{"ns": "unknown"}) + req.Header.Set("Content-Type", "application/json; charset=utf-8") + res := httptest.NewRecorder() + + handler := getNamespacedWebSocketHandler(mwsns, mgr) + status, err := handler(res, req) + assert.Error(t, err) + assert.Equal(t, 404, status) +} diff --git a/internal/coremsgs/en_error_messages.go b/internal/coremsgs/en_error_messages.go index c7a87ff0d..bcd1067d6 100644 --- a/internal/coremsgs/en_error_messages.go +++ b/internal/coremsgs/en_error_messages.go @@ -300,4 +300,5 @@ var ( MsgTokensRESTErrConflict = ffe("FF10459", "Conflict from tokens service: %s", 409) MsgBatchWithDataNotSupported = ffe("FF10460", "Provided subscription '%s' enables batching and withData which is not supported", 400) MsgBatchDeliveryNotSupported = ffe("FF10461", "Batch delivery not supported by transport '%s'", 400) + MsgWSWrongNamespace = ffe("FF10462", "Websocket request received on a namespace scoped connection but the provided namespace does not match") ) diff --git a/internal/events/websockets/websocket_connection.go b/internal/events/websockets/websocket_connection.go index 70aa9b50f..e72d1eff4 100644 --- a/internal/events/websockets/websocket_connection.go +++ b/internal/events/websockets/websocket_connection.go @@ -38,23 +38,25 @@ type websocketStartedSub struct { } type websocketConnection struct { - ctx context.Context - ws *WebSockets - wsConn *websocket.Conn - cancelCtx func() - connID string - sendMessages chan interface{} - senderDone chan struct{} - receiverDone chan struct{} - autoAck bool - started []*websocketStartedSub - inflight []*core.EventDeliveryResponse - mux sync.Mutex - closed bool - remoteAddr string - userAgent string - header http.Header - auth core.Authorizer + ctx context.Context + ws *WebSockets + wsConn *websocket.Conn + cancelCtx func() + connID string + sendMessages chan interface{} + senderDone chan struct{} + receiverDone chan struct{} + autoAck bool + started []*websocketStartedSub + inflight []*core.EventDeliveryResponse + mux sync.Mutex + closed bool + remoteAddr string + userAgent string + header http.Header + auth core.Authorizer + namespaceScoped bool // if true then any request to listen is asserted to be in the context of namespace + namespace string } func newConnection(pCtx context.Context, ws *WebSockets, wsConn *websocket.Conn, req *http.Request, auth core.Authorizer) *websocketConnection { @@ -80,6 +82,18 @@ func newConnection(pCtx context.Context, ws *WebSockets, wsConn *websocket.Conn, return wc } +func (wc *websocketConnection) assertNamespace(namespace string) (string, error) { + + if wc.namespaceScoped { + if namespace == "" { + namespace = wc.namespace + } else if namespace != wc.namespace { + return "", i18n.NewError(wc.ctx, coremsgs.MsgWSWrongNamespace) + } + } + return namespace, nil +} + // processAutoStart gives a helper to specify query parameters to auto-start your subscription func (wc *websocketConnection) processAutoStart(req *http.Request) { query := req.URL.Query() @@ -88,12 +102,18 @@ func (wc *websocketConnection) processAutoStart(req *http.Request) { _, hasName := query["name"] autoAck, hasAutoack := req.URL.Query()["autoack"] isAutoack := hasAutoack && (len(autoAck) == 0 || autoAck[0] != "false") + namespace, err := wc.assertNamespace(query.Get("namespace")) + if err != nil { + wc.protocolError(err) + return + } + if hasEphemeral || hasName { filter := core.NewSubscriptionFilterFromQuery(query) err := wc.handleStart(&core.WSStart{ AutoAck: &isAutoack, Ephemeral: isEphemeral, - Namespace: query.Get("namespace"), + Namespace: namespace, Name: query.Get("name"), Filter: filter, }) @@ -157,7 +177,10 @@ func (wc *websocketConnection) receiveLoop() { var msg core.WSStart err = json.Unmarshal(msgData, &msg) if err == nil { - err = wc.authorizeMessage(msg.Namespace) + msg.Namespace, err = wc.assertNamespace(msg.Namespace) + if err == nil { + err = wc.authorizeMessage(msg.Namespace) + } if err == nil { err = wc.handleStart(&msg) } @@ -251,6 +274,14 @@ func (wc *websocketConnection) restartForNamespace(ns string, startTime time.Tim } func (wc *websocketConnection) handleStart(start *core.WSStart) (err error) { + // this will very likely already be checked before we get here but + // it doesn't do any harm to do a final assertion just in case it hasn't been done yet + + start.Namespace, err = wc.assertNamespace(start.Namespace) + if err != nil { + return err + } + wc.mux.Lock() if start.AutoAck != nil { if *start.AutoAck != wc.autoAck && len(wc.started) > 0 { diff --git a/internal/events/websockets/websockets.go b/internal/events/websockets/websockets.go index 4388aff16..a12a325dd 100644 --- a/internal/events/websockets/websockets.go +++ b/internal/events/websockets/websockets.go @@ -31,6 +31,10 @@ import ( "github.com/hyperledger/firefly/pkg/events" ) +type WebSocketsNamespaced interface { + ServeHTTPNamespaced(namespace string, res http.ResponseWriter, req *http.Request) +} + type WebSockets struct { ctx context.Context capabilities *events.Capabilities @@ -122,6 +126,25 @@ func (ws *WebSockets) ServeHTTP(res http.ResponseWriter, req *http.Request) { wc.processAutoStart(req) } +func (ws *WebSockets) ServeHTTPNamespaced(namespace string, res http.ResponseWriter, req *http.Request) { + + wsConn, err := ws.upgrader.Upgrade(res, req, nil) + if err != nil { + log.L(ws.ctx).Errorf("WebSocket upgrade failed: %s", err) + return + } + + ws.connMux.Lock() + wc := newConnection(ws.ctx, ws, wsConn, req, ws.auth) + wc.namespaceScoped = true + wc.namespace = namespace + ws.connections[wc.connID] = wc + ws.connMux.Unlock() + + wc.processAutoStart(req) + +} + func (ws *WebSockets) ack(connID string, inflight *core.EventDeliveryResponse) { if cb, ok := ws.callbacks.handlers[inflight.Subscription.Namespace]; ok { cb.DeliveryResponse(connID, inflight) diff --git a/internal/events/websockets/websockets_test.go b/internal/events/websockets/websockets_test.go index 6947d66e7..e2f837988 100644 --- a/internal/events/websockets/websockets_test.go +++ b/internal/events/websockets/websockets_test.go @@ -51,6 +51,18 @@ func (t *testAuthorizer) Authorize(ctx context.Context, authReq *fftypes.AuthReq } func newTestWebsockets(t *testing.T, cbs *eventsmocks.Callbacks, authorizer core.Authorizer, queryParams ...string) (ws *WebSockets, wsc wsclient.WSClient, cancel func()) { + return newTestWebsocketsCommon(t, cbs, authorizer, "", queryParams...) +} + +type testNamespacedHandler struct { + ws *WebSockets + namespace string +} + +func (h *testNamespacedHandler) ServeHTTP(res http.ResponseWriter, req *http.Request) { + h.ws.ServeHTTPNamespaced(h.namespace, res, req) +} +func newTestWebsocketsCommon(t *testing.T, cbs *eventsmocks.Callbacks, authorizer core.Authorizer, namespace string, queryParams ...string) (ws *WebSockets, wsc wsclient.WSClient, cancel func()) { coreconfig.Reset() ws = &WebSockets{} @@ -63,8 +75,16 @@ func newTestWebsockets(t *testing.T, cbs *eventsmocks.Callbacks, authorizer core assert.Equal(t, "websockets", ws.Name()) assert.NotNil(t, ws.Capabilities()) cbs.On("ConnectionClosed", mock.Anything).Return(nil).Maybe() - - svr := httptest.NewServer(ws) + var svr *httptest.Server + if namespace == "" { + svr = httptest.NewServer(ws) + } else { + namespacedHandler := &testNamespacedHandler{ + ws: ws, + namespace: namespace, + } + svr = httptest.NewServer(namespacedHandler) + } clientConfig := config.RootSection("ut.wsclient") wsclient.InitConfig(clientConfig) @@ -820,3 +840,144 @@ func TestEventDeliveryBatchReturnsUnsupported(t *testing.T) { err := ws.BatchDeliveryRequest(ws.ctx, "id", sub, []*core.CombinedEventDataDelivery{}) assert.Regexp(t, "FF10461", err) } + +func TestNamespaceScopedSendWrongNamespaceStartAction(t *testing.T) { + cbs := &eventsmocks.Callbacks{} + _, wsc, cancel := newTestWebsocketsCommon(t, cbs, nil, "ns1") + defer cancel() + cbs.On("ConnectionClosed", mock.Anything).Return(nil) + + err := wsc.Send(context.Background(), []byte(`{"type":"start","namespace":"ns2"}`)) + assert.NoError(t, err) + b := <-wsc.Receive() + var res core.WSError + err = json.Unmarshal(b, &res) + assert.NoError(t, err) + assert.Equal(t, core.WSProtocolErrorEventType, res.Type) + assert.Regexp(t, "FF10462", res.Error) +} + +func TestNamespaceScopedSendWrongNamespaceQueryParameter(t *testing.T) { + cbs := &eventsmocks.Callbacks{} + _, wsc, cancel := newTestWebsocketsCommon(t, cbs, nil, "ns1", "namespace=ns2") + defer cancel() + cbs.On("ConnectionClosed", mock.Anything).Return(nil) + + b := <-wsc.Receive() + var res core.WSError + err := json.Unmarshal(b, &res) + assert.NoError(t, err) + assert.Equal(t, core.WSProtocolErrorEventType, res.Type) + assert.Regexp(t, "FF10462", res.Error) +} + +func TestNamespaceScopedUpgradeFail(t *testing.T) { + cbs := &eventsmocks.Callbacks{} + _, wsc, cancel := newTestWebsocketsCommon(t, cbs, nil, "ns1") + defer cancel() + + u, _ := url.Parse(wsc.URL()) + u.Scheme = "http" + res, err := http.Get(u.String()) + assert.NoError(t, err) + assert.Equal(t, 400, res.StatusCode) + +} + +func TestNamespaceScopedSuccess(t *testing.T) { + cbs := &eventsmocks.Callbacks{} + ws, wsc, cancel := newTestWebsocketsCommon(t, cbs, nil, "ns1") + defer cancel() + var connID string + sub := cbs.On("RegisterConnection", + mock.MatchedBy(func(s string) bool { connID = s; return true }), + mock.MatchedBy(func(subMatch events.SubscriptionMatcher) bool { + return subMatch(core.SubscriptionRef{Namespace: "ns1", Name: "sub1"}) && + !subMatch(core.SubscriptionRef{Namespace: "ns2", Name: "sub1"}) && + !subMatch(core.SubscriptionRef{Namespace: "ns1", Name: "sub2"}) + }), + ).Return(nil) + ack := cbs.On("DeliveryResponse", + mock.MatchedBy(func(s string) bool { return s == connID }), + mock.Anything).Return(nil) + + waitSubscribed := make(chan struct{}) + sub.RunFn = func(a mock.Arguments) { + close(waitSubscribed) + } + + waitAcked := make(chan struct{}) + ack.RunFn = func(a mock.Arguments) { + close(waitAcked) + } + + err := wsc.Send(context.Background(), []byte(`{"type":"start","name":"sub1"}`)) + assert.NoError(t, err) + + <-waitSubscribed + ws.DeliveryRequest(ws.ctx, connID, nil, &core.EventDelivery{ + EnrichedEvent: core.EnrichedEvent{ + Event: core.Event{ID: fftypes.NewUUID()}, + }, + Subscription: core.SubscriptionRef{ + ID: fftypes.NewUUID(), + Namespace: "ns1", + Name: "sub1", + }, + }, nil) + // Put a second in flight + ws.DeliveryRequest(ws.ctx, connID, nil, &core.EventDelivery{ + EnrichedEvent: core.EnrichedEvent{ + Event: core.Event{ID: fftypes.NewUUID()}, + }, + Subscription: core.SubscriptionRef{ + ID: fftypes.NewUUID(), + Namespace: "ns1", + Name: "sub2", + }, + }, nil) + + b := <-wsc.Receive() + var res core.EventDelivery + err = json.Unmarshal(b, &res) + assert.NoError(t, err) + + assert.Equal(t, "ns1", res.Subscription.Namespace) + assert.Equal(t, "sub1", res.Subscription.Name) + err = wsc.Send(context.Background(), []byte(fmt.Sprintf(`{ + "type":"ack", + "id": "%s", + "subscription": { + "namespace": "ns1", + "name": "sub1" + } + }`, res.ID))) + assert.NoError(t, err) + + <-waitAcked + + // Check we left the right one behind + conn := ws.connections[connID] + assert.Equal(t, 1, len(conn.inflight)) + assert.Equal(t, "sub2", conn.inflight[0].Subscription.Name) + + cbs.AssertExpectations(t) +} + +func TestHandleStartWrongNamespace(t *testing.T) { + + // it is not currently possible through exported functions to get to handleStart with the wrong namespace + // but we like to have a final assertion in there as a safety net for accidentaly data leakage across namespaces + // so to prove that safety net, we need to drive the private function handleStart directly. + wc := &websocketConnection{ + ctx: context.Background(), + namespaceScoped: true, + namespace: "ns1", + } + startMessage := &core.WSStart{ + Namespace: "ns2", + } + err := wc.handleStart(startMessage) + assert.Error(t, err) + assert.Regexp(t, "FF10462", err) +} diff --git a/mocks/websocketsmocks/web_sockets_namespaced.go b/mocks/websocketsmocks/web_sockets_namespaced.go new file mode 100644 index 000000000..8d0651a98 --- /dev/null +++ b/mocks/websocketsmocks/web_sockets_namespaced.go @@ -0,0 +1,33 @@ +// Code generated by mockery v2.36.0. DO NOT EDIT. + +package websocketsmocks + +import ( + http "net/http" + + mock "github.com/stretchr/testify/mock" +) + +// WebSocketsNamespaced is an autogenerated mock type for the WebSocketsNamespaced type +type WebSocketsNamespaced struct { + mock.Mock +} + +// ServeHTTPNamespaced provides a mock function with given fields: namespace, res, req +func (_m *WebSocketsNamespaced) ServeHTTPNamespaced(namespace string, res http.ResponseWriter, req *http.Request) { + _m.Called(namespace, res, req) +} + +// NewWebSocketsNamespaced creates a new instance of WebSocketsNamespaced. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewWebSocketsNamespaced(t interface { + mock.TestingT + Cleanup(func()) +}) *WebSocketsNamespaced { + mock := &WebSocketsNamespaced{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +}