From 9195006c2ba9b37b0500e2224f0378a2620d22d5 Mon Sep 17 00:00:00 2001 From: "Valber M. Silva de Souza" Date: Tue, 7 Feb 2023 01:25:44 +0100 Subject: [PATCH] improving handling of invalid client messages --- src/Main/Exceptions.fs | 4 + src/Main/GraphQLWebsocketMiddleware.fs | 58 ++++--- src/Main/Main.fsproj | 1 + src/Main/MessageMapping.fs | 68 ++++++--- src/Main/Messages.fs | 4 +- src/Main/RawMessageConverter.fs | 15 +- tests/unit-tests/InvalidMessageTests.fs | 194 ++++++++++++++++++++++++ tests/unit-tests/Main.UnitTests.fsproj | 1 + tests/unit-tests/SerializationTests.fs | 19 +-- 9 files changed, 309 insertions(+), 55 deletions(-) create mode 100644 src/Main/Exceptions.fs create mode 100644 tests/unit-tests/InvalidMessageTests.fs diff --git a/src/Main/Exceptions.fs b/src/Main/Exceptions.fs new file mode 100644 index 0000000..042cc3f --- /dev/null +++ b/src/Main/Exceptions.fs @@ -0,0 +1,4 @@ +namespace GraphQLTransportWS + +type InvalidMessageException (explanation : string) = + inherit System.Exception(explanation) \ No newline at end of file diff --git a/src/Main/GraphQLWebsocketMiddleware.fs b/src/Main/GraphQLWebsocketMiddleware.fs index ad10b1a..ff54e0a 100644 --- a/src/Main/GraphQLWebsocketMiddleware.fs +++ b/src/Main/GraphQLWebsocketMiddleware.fs @@ -50,7 +50,9 @@ type GraphQLWebSocketMiddleware<'Root>(next : RequestDelegate, applicationLifeti task { return "dummySerializedServerMessage" } let deserializeClientMessage (jsonOptions: JsonOptions) (msg: string) = - task { return ConnectionInit None } + task { + return JsonSerializer.Deserialize(msg, jsonOptions.SerializerOptions) + } let isSocketOpen (theSocket : WebSocket) = not (theSocket.State = WebSocketState.Aborted) && @@ -61,7 +63,7 @@ type GraphQLWebSocketMiddleware<'Root>(next : RequestDelegate, applicationLifeti not (theSocket.State = WebSocketState.Aborted) && not (theSocket.State = WebSocketState.Closed) - let receiveMessageViaSocket (cancellationToken : CancellationToken) (jsonOptions: JsonOptions) (executor : Executor<'Root>) (replacements : Map) (socket : WebSocket) = + let receiveMessageViaSocket (cancellationToken : CancellationToken) (jsonOptions: JsonOptions) (executor : Executor<'Root>) (replacements : Map) (socket : WebSocket) : Task option> = task { let buffer = Array.zeroCreate 4096 let completeMessage = new List() @@ -84,8 +86,17 @@ type GraphQLWebSocketMiddleware<'Root>(next : RequestDelegate, applicationLifeti if String.IsNullOrWhiteSpace message then return None else - let! deserializedMsg = deserializeClientMessage jsonOptions message - return Some deserializedMsg + try + let! deserializedRawMsg = deserializeClientMessage jsonOptions message + return + deserializedRawMsg + |> MessageMapping.toClientMessage executor + |> Some + with :? JsonException as e -> + printfn "%s" (e.ToString()) + return + (MessageMapping.invalidMsg <| "invalid json in client message" + |> Some) } let sendMessageViaSocket (cancellationToken : CancellationToken) (jsonOptions: JsonOptions) (socket : WebSocket) (message : ServerMessage) = @@ -167,6 +178,15 @@ type GraphQLWebSocketMiddleware<'Root>(next : RequestDelegate, applicationLifeti let logMsgWithIdReceived id msgAsStr = printfn "%s (id: %s)" msgAsStr id + let tryToGracefullyCloseSocket theSocket = + task { + if theSocket |> canCloseSocket + then + do! theSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, "NormalClosure", new CancellationToken()) + else + () + } + // <-------------- // <-- Helpers --| // <-------------- @@ -183,15 +203,20 @@ type GraphQLWebSocketMiddleware<'Root>(next : RequestDelegate, applicationLifeti printfn "Warn: websocket socket received empty message! (socket state = %A)" socket.State | Some msg -> match msg with - | ConnectionInit p -> + | Failure failureMsgs -> + "InvalidMessage" |> logMsgReceivedWithOptionalPayload None + match failureMsgs |> List.head with + | InvalidMessage (code, explanation) -> + do! socket.CloseAsync(enum code, explanation, cancellationToken) + | Success (ConnectionInit p, _) -> "ConnectionInit" |> logMsgReceivedWithOptionalPayload p do! ConnectionAck |> safe_Send - | ClientPing p -> + | Success (ClientPing p, _) -> "ClientPing" |> logMsgReceivedWithOptionalPayload p do! ServerPong None |> safe_Send - | ClientPong p -> + | Success (ClientPong p, _) -> "ClientPong" |> logMsgReceivedWithOptionalPayload p - | Subscribe (id, query) -> + | Success (Subscribe (id, query), _) -> "Subscribe" |> logMsgWithIdReceived id let! planExecutionResult = Async.StartAsTask ( @@ -200,22 +225,19 @@ type GraphQLWebSocketMiddleware<'Root>(next : RequestDelegate, applicationLifeti ) do! planExecutionResult |> safe_ApplyPlanExecutionResult id - | ClientComplete id -> + | Success (ClientComplete id, _) -> "ClientComplete" |> logMsgWithIdReceived id id |> GraphQLSubscriptionsManagement.removeSubscription do! Complete id |> safe_Send - | InvalidMessage explanation -> - "InvalidMessage" |> logMsgReceivedWithOptionalPayload None - do! socket.CloseAsync(enum CustomWebSocketStatus.invalidMessage, explanation, cancellationToken) printfn "Leaving graphql-ws connection loop..." - if socket |> canCloseSocket - then - do! socket.CloseAsync(WebSocketCloseStatus.NormalClosure, "NormalClosure", new CancellationToken()) - else - () - with // TODO: MAKE A PROPER GRAPHQL ERROR HANDLING! + do! socket |> tryToGracefullyCloseSocket + with | ex -> printfn "Unexpected exception \"%s\" in GraphQLWebsocketMiddleware (handleMessages). More:\n%s" (ex.GetType().Name) (ex.ToString()) + // at this point, only something really weird must have happened. + // In order to avoid faulty state scenarios and unimagined damages, + // just close the socket without further ado. + do! socket |> tryToGracefullyCloseSocket } let waitForConnectionInit (jsonOptions : JsonOptions) (schemaExecutor : Executor<'Root>) (replacements : Map) (connectionInitTimeoutInMs : int) (socket : WebSocket) : Task> = diff --git a/src/Main/Main.fsproj b/src/Main/Main.fsproj index 148d432..cd79b8a 100644 --- a/src/Main/Main.fsproj +++ b/src/Main/Main.fsproj @@ -12,6 +12,7 @@ + diff --git a/src/Main/MessageMapping.fs b/src/Main/MessageMapping.fs index 936567e..c7ca245 100644 --- a/src/Main/MessageMapping.fs +++ b/src/Main/MessageMapping.fs @@ -2,21 +2,32 @@ namespace GraphQLTransportWS module MessageMapping = open FSharp.Data.GraphQL + open Rop - let requireId (raw : RawMessage) : string = + /// From the spec: "Receiving a message of a type or format which is not specified in this document will result in an immediate socket closure with the event 4400: <error-message>. + /// The <error-message> can be vaguely descriptive on why the received message is invalid." + let invalidMsg (explanation : string) = + InvalidMessage (4400, explanation) + |> fail + + let private requireId (raw : RawMessage) : RopResult = match raw.Id with - | Some s -> s - | None -> failwith "property \"id\" is required but was not there" + | Some s -> succeed s + | None -> invalidMsg <| "property \"id\" is required for this message but was not present." - let requirePayloadToBeAnOptionalString (payload : RawPayload option) : string option = + let private requirePayloadToBeAnOptionalString (payload : RawPayload option) : RopResult = match payload with | Some p -> match p with - | StringPayload strPayload -> Some strPayload - | _ -> failwith "payload was expected to be a string, but it wasn't" - | None -> None + | StringPayload strPayload -> + Some strPayload + |> succeed + | SubscribePayload _ -> + invalidMsg <| "for this message, payload was expected to be an optional string, but it was a \"subscribe\" payload instead." + | None -> + succeed None - let requireSubscribePayload (executor : Executor<'a>) (payload : RawPayload option) : GraphQLQuery = + let private requireSubscribePayload (executor : Executor<'a>) (payload : RawPayload option) : RopResult = match payload with | Some p -> match p with @@ -25,29 +36,44 @@ module MessageMapping = | Some query -> { ExecutionPlan = executor.CreateExecutionPlan(query) Variables = Map.empty } + |> succeed | None -> - failwith "there was no query in subscribe message!" + invalidMsg <| sprintf "there was no query in the client's subscribe message!" | _ -> - failwith "payload was expected to be a subscribe payload object, but it wasn't." + invalidMsg <| "for this message, payload was expected to be a \"subscribe\" payload object, but it wasn't." | None -> - failwith "payload is required for this message, but none was available" + invalidMsg <| "payload is required for this message, but none was present." - let toClientMessage (executor : Executor<'a>) (raw : RawMessage) : ClientMessage = + let toClientMessage (executor : Executor<'a>) (raw : RawMessage) : RopResult = match raw.Type with | None -> - failwithf "property \"type\" was not found in the client message" + invalidMsg <| sprintf "message type was not specified by client." | Some "connection_init" -> - ConnectionInit (raw.Payload |> requirePayloadToBeAnOptionalString) + raw.Payload + |> requirePayloadToBeAnOptionalString + |> mapR ConnectionInit | Some "ping" -> - ClientPing (raw.Payload |> requirePayloadToBeAnOptionalString) + raw.Payload + |> requirePayloadToBeAnOptionalString + |> mapR ClientPing | Some "pong" -> - ClientPong (raw.Payload |> requirePayloadToBeAnOptionalString) + raw.Payload + |> requirePayloadToBeAnOptionalString + |> mapR ClientPong | Some "complete" -> - ClientComplete (raw |> requireId) + raw + |> requireId + |> mapR ClientComplete | Some "subscribe" -> - let id = raw |> requireId - let payload = raw.Payload |> requireSubscribePayload executor - Subscribe (id, payload) + raw + |> requireId + |> bindR + (fun id -> + raw.Payload + |> requireSubscribePayload executor + |> mapR (fun payload -> (id, payload)) + ) + |> mapR Subscribe | Some other -> - failwithf "type \"%s\" is not supported as a client message type" other + invalidMsg <| sprintf "invalid type \"%s\" specified by client." other diff --git a/src/Main/Messages.fs b/src/Main/Messages.fs index 8516d89..4cd8873 100644 --- a/src/Main/Messages.fs +++ b/src/Main/Messages.fs @@ -28,7 +28,9 @@ type ClientMessage = | ClientPong of payload: string option | Subscribe of id: string * query: GraphQLQuery | ClientComplete of id: string - | InvalidMessage of explanation: string + +type ClientMessageProtocolFailure = + | InvalidMessage of code: int * explanation: string type ServerMessage = | ConnectionAck diff --git a/src/Main/RawMessageConverter.fs b/src/Main/RawMessageConverter.fs index 50b29dc..14dadc4 100644 --- a/src/Main/RawMessageConverter.fs +++ b/src/Main/RawMessageConverter.fs @@ -8,6 +8,9 @@ open System.Text.Json.Serialization type RawMessageConverter() = inherit JsonConverter() + let raiseInvalidMsg explanation = + raise <| InvalidMessageException explanation + let getOptionalString (reader : byref) = if reader.TokenType.Equals(JsonTokenType.Null) then None @@ -18,7 +21,7 @@ type RawMessageConverter() = if reader.Read() then getOptionalString(&reader) else - failwithf "was expecting a value for property \"%s\"" propertyName + raiseInvalidMsg <| sprintf "was expecting a value for property \"%s\"" propertyName let readSubscribePayload (reader : byref) : RawSubscribePayload = let mutable operationName : string option = None @@ -36,7 +39,7 @@ type RawMessageConverter() = | "extensions" -> extensions <- readPropertyValueAsAString "extensions" &reader | other -> - failwithf "unexpected property \"%s\" in payload object" other + raiseInvalidMsg <| sprintf "unexpected property \"%s\" in payload object" other { OperationName = operationName Query = query Variables = variables @@ -51,11 +54,11 @@ type RawMessageConverter() = SubscribePayload (readSubscribePayload &reader) |> Some elif reader.TokenType.Equals(JsonTokenType.Null) then - failwith "was expecting a value for property \"payload\"" + raiseInvalidMsg <| "was expecting a value for property \"payload\"" else - failwith "Not implemented yet. Uh-oh, this is a bug." + raiseInvalidMsg <| sprintf "payload is a \"%A\", which is not supported" reader.TokenType else - failwith "was expecting a value for property \"payload\"" + raiseInvalidMsg <| "was expecting a value for property \"payload\"" override __.Read(reader : byref, typeToConvert: Type, options: JsonSerializerOptions) : RawMessage = if not (reader.TokenType.Equals(JsonTokenType.StartObject)) @@ -73,7 +76,7 @@ type RawMessageConverter() = | "payload" -> payload <- readPayload &reader | other -> - failwithf "unknown property \"%s\"" other + raiseInvalidMsg <| sprintf "unknown property \"%s\"" other { Id = id Type = theType Payload = payload } diff --git a/tests/unit-tests/InvalidMessageTests.fs b/tests/unit-tests/InvalidMessageTests.fs new file mode 100644 index 0000000..e35b7b4 --- /dev/null +++ b/tests/unit-tests/InvalidMessageTests.fs @@ -0,0 +1,194 @@ +module InvalidMessageTests + +open GraphQLTransportWS.Rop +open UnitTest +open GraphQLTransportWS +open System +open System.Text.Json +open Xunit +open FSharp.Data.GraphQL.Ast + +let toClientMessage (theInput : string) = + let serializerOptions = new JsonSerializerOptions() + serializerOptions.Converters.Add(new RawMessageConverter()) + JsonSerializer.Deserialize(theInput, serializerOptions) + |> MessageMapping.toClientMessage TestSchema.executor + +let willResultInInvalidMessage expectedExplanation input = + try + let result = + input + |> toClientMessage + match result with + | Failure msgs -> + match msgs |> List.head with + | InvalidMessage (code, explanation) -> + Assert.Equal(4400, code) + Assert.Equal(expectedExplanation, explanation) + | other -> + Assert.Fail(sprintf "unexpected actual value: '%A'" other) + with + | :? InvalidMessageException as ex -> + Assert.Equal(expectedExplanation, ex.Message) + +let willResultInJsonException input = + try + input + |> toClientMessage + |> ignore + Assert.Fail("expected that a JsonException would have already been thrown at this point") + with + | :? JsonException as ex -> + Assert.True(true) + +[] +let ``unknown message type`` () = + """{ + "type": "connection_start" + } + """ + |> willResultInInvalidMessage "invalid type \"connection_start\" specified by client." + +[] +let ``type not specified`` () = + """{ + "payload": "hello, let us connect" + } + """ + |> willResultInInvalidMessage "message type was not specified by client." + +[] +let ``no payload in subscribe message`` () = + """{ + "type": "subscribe", + "id": "b5d4d2ff-d262-4882-a7b9-d6aec5e4faa6" + } + """ + |> willResultInInvalidMessage "payload is required for this message, but none was present." + +[] +let ``null payload json in subscribe message`` () = + """{ + "type": "subscribe", + "id": "b5d4d2ff-d262-4882-a7b9-d6aec5e4faa6", + "payload": null + } + """ + |> willResultInInvalidMessage "was expecting a value for property \"payload\"" + +[] +let ``payload type of number in subscribe message`` () = + """{ + "type": "subscribe", + "id": "b5d4d2ff-d262-4882-a7b9-d6aec5e4faa6", + "payload": 42 + } + """ + |> willResultInInvalidMessage "payload is a \"Number\", which is not supported" + +[] +let ``payload type of number in connection_init message is not supported`` () = + """{ + "type": "connection_init", + "payload": 42 + } + """ + |> willResultInInvalidMessage "payload is a \"Number\", which is not supported" + +[] +let ``no id in subscribe message`` () = + """{ + "type": "subscribe", + "payload": { + "query": "subscription { watchMoon(id: \"1\") { id name isMoon } }" + } + } + """ + |> willResultInInvalidMessage "property \"id\" is required for this message but was not present." + +[] +let ``subscribe payload format wrongly used in connection_init`` () = + """{ + "type": "connection_init", + "payload": { + "query": "subscription { watchMoon(id: \"1\") { id name isMoon } }" + } + } + """ + |> willResultInInvalidMessage "for this message, payload was expected to be an optional string, but it was a \"subscribe\" payload instead." + +[] +let ``string payload wrongly used in subscribe`` () = + """{ + "type": "subscribe", + "id": "b5d4d2ff-d262-4882-a7b9-d6aec5e4faa6", + "payload": "{\"query\": \"subscription { watchMoon(id: \\\"1\\\") { id name isMoon } }\"}" + } + """ + |> willResultInInvalidMessage "for this message, payload was expected to be a \"subscribe\" payload object, but it wasn't." + +[] +let ``ping payload object is totally unknown`` () = + """{ + "type": "ping", + "payload": { "isEmergency": true } + } + """ + |> willResultInInvalidMessage "unexpected property \"isEmergency\" in payload object" + +[] +let ``subscribe payload object is wrongly used in ping``() = + """{ + "type": "ping", + "payload": { + "query": "subscription { watchMoon(id: \"1\") { id name isMoon } }" + } + } + """ + |> willResultInInvalidMessage "for this message, payload was expected to be an optional string, but it was a \"subscribe\" payload instead." + +[] +let ``id is incorrectly a number in a subscribe message`` () = + """{ + "type": "subscribe", + "id": 42, + "payload": { + "query": "subscription { watchMoon(id: \"1\") { id name isMoon } }" + } + } + """ + |> willResultInJsonException + +[] +let ``typo in one of the messages root properties`` () = + """{ + "typo": "subscribe", + "id": "b5d4d2ff-d262-4882-a7b9-d6aec5e4faa6", + "payload": { + "query": "subscription { watchMoon(id: \"1\") { id name isMoon } }" + } + } + """ + |> willResultInInvalidMessage "unknown property \"typo\"" + +[] +let ``complete message without an id`` () = + """{ + "type": "complete" + } + """ + |> willResultInInvalidMessage "property \"id\" is required for this message but was not present." + +[] +let ``complete message with a null id`` () = + """{ + "type": "complete", + "id": null + } + """ + |> willResultInInvalidMessage "property \"id\" is required for this message but was not present." + + + + + diff --git a/tests/unit-tests/Main.UnitTests.fsproj b/tests/unit-tests/Main.UnitTests.fsproj index f228968..577c618 100644 --- a/tests/unit-tests/Main.UnitTests.fsproj +++ b/tests/unit-tests/Main.UnitTests.fsproj @@ -10,6 +10,7 @@ + diff --git a/tests/unit-tests/SerializationTests.fs b/tests/unit-tests/SerializationTests.fs index 644616d..320cb6c 100644 --- a/tests/unit-tests/SerializationTests.fs +++ b/tests/unit-tests/SerializationTests.fs @@ -1,5 +1,6 @@ -module Tests +module SerializationTests +open GraphQLTransportWS.Rop open UnitTest open GraphQLTransportWS open System @@ -20,7 +21,7 @@ let ``Deserializes ConnectionInit correctly`` () = |> MessageMapping.toClientMessage (TestSchema.executor) match result with - | ConnectionInit None -> () // <-- expected + | Success (ConnectionInit None, _) -> () // <-- expected | other -> Assert.Fail(sprintf "unexpected actual value: '%A'" other) @@ -37,7 +38,7 @@ let ``Deserializes ConnectionInit with payload correctly`` () = |> MessageMapping.toClientMessage (TestSchema.executor) match result with - | ConnectionInit (Some "hello") -> () // <-- expected + | Success (ConnectionInit (Some "hello"), _) -> () // <-- expected | other -> Assert.Fail(sprintf "unexpected actual value: '%A'" other) @@ -54,7 +55,7 @@ let ``Deserializes ClientPing correctly`` () = |> MessageMapping.toClientMessage (TestSchema.executor) match result with - | ClientPing None -> () // <-- expected + | Success (ClientPing None, _) -> () // <-- expected | other -> Assert.Fail(sprintf "unexpected actual value '%A'" other) @@ -71,7 +72,7 @@ let ``Deserializes ClientPing with payload correctly`` () = |> MessageMapping.toClientMessage (TestSchema.executor) match result with - | ClientPing (Some "ping!") -> () // <-- expected + | Success (ClientPing (Some "ping!"), _) -> () // <-- expected | other -> Assert.Fail(sprintf "unexpected actual value '%A" other) @@ -88,7 +89,7 @@ let ``Deserializes ClientPong correctly`` () = |> MessageMapping.toClientMessage (TestSchema.executor) match result with - | ClientPong None -> () // <-- expected + | Success (ClientPong None, _) -> () // <-- expected | other -> Assert.Fail(sprintf "unexpected actual value: '%A'" other) @@ -105,7 +106,7 @@ let ``Deserializes ClientPong with payload correctly`` () = |> MessageMapping.toClientMessage (TestSchema.executor) match result with - | ClientPong (Some "pong!") -> () // <-- expected + | Success (ClientPong (Some "pong!"), _) -> () // <-- expected | other -> Assert.Fail(sprintf "unexpected actual value: '%A'" other) @@ -122,7 +123,7 @@ let ``Deserializes ClientComplete correctly``() = |> MessageMapping.toClientMessage (TestSchema.executor) match result with - | ClientComplete id -> + | Success (ClientComplete id, _) -> Assert.Equal("65fca2b5-f149-4a70-a055-5123dea4628f", id) | other -> Assert.Fail(sprintf "unexpected actual value: '%A'" other) @@ -148,7 +149,7 @@ let ``Deserializes client subscription correctly`` () = |> MessageMapping.toClientMessage (TestSchema.executor) match result with - | Subscribe (id, payload) -> + | Success (Subscribe (id, payload), _) -> Assert.Equal("b5d4d2ff-d262-4882-a7b9-d6aec5e4faa6", id) Assert.Equal(1, payload.ExecutionPlan.Operation.SelectionSet.Length) let watchMoonSelection = payload.ExecutionPlan.Operation.SelectionSet |> List.head