diff --git a/lib/phoenix/socket/transport.ex b/lib/phoenix/socket/transport.ex index 1514723335..a09c778b89 100644 --- a/lib/phoenix/socket/transport.ex +++ b/lib/phoenix/socket/transport.ex @@ -269,7 +269,7 @@ defmodule Phoenix.Socket.Transport do connect_info = Enum.map(connect_info, fn - key when key in [:peer_data, :trace_context_headers, :uri, :user_agent, :x_headers, :auth_token] -> + key when key in [:peer_data, :trace_context_headers, :uri, :user_agent, :x_headers, :sec_websocket_headers, :auth_token] -> key {:session, session} -> @@ -280,7 +280,7 @@ defmodule Phoenix.Socket.Transport do other -> raise ArgumentError, - ":connect_info keys are expected to be one of :peer_data, :trace_context_headers, :x_headers, :uri, or {:session, config}, " <> + ":connect_info keys are expected to be one of :peer_data, :trace_context_headers, :x_headers, :user_agent, :sec_websocket_headers, :uri, or {:session, config}, " <> "optionally followed by custom keyword pairs, got: #{inspect(other)}" end) @@ -470,6 +470,8 @@ defmodule Phoenix.Socket.Transport do * `:user_agent` - the value of the "user-agent" request header + * `:sec_websocket_headers` - a list of all request headers that have a "sec-websocket-" prefix + The CSRF check can be disabled by setting the `:check_csrf` option to `false`. """ def connect_info(conn, endpoint, keys, opts \\ []) do @@ -482,7 +484,7 @@ defmodule Phoenix.Socket.Transport do {:trace_context_headers, fetch_trace_context_headers(conn)} :x_headers -> - {:x_headers, fetch_x_headers(conn)} + {:x_headers, fetch_headers(conn, "x-")} :uri -> {:uri, fetch_uri(conn)} @@ -490,6 +492,9 @@ defmodule Phoenix.Socket.Transport do :user_agent -> {:user_agent, fetch_user_agent(conn)} + :sec_websocket_headers -> + {:sec_websocket_headers, fetch_headers(conn, "sec-websocket-")} + {:session, session} -> {:session, connect_session(conn, endpoint, session, opts)} @@ -527,9 +532,9 @@ defmodule Phoenix.Socket.Transport do end end - defp fetch_x_headers(conn) do + defp fetch_headers(conn, prefix) do for {header, _} = pair <- conn.req_headers, - String.starts_with?(header, "x-"), + String.starts_with?(header, prefix), do: pair end diff --git a/test/phoenix/integration/websocket_channels_test.exs b/test/phoenix/integration/websocket_channels_test.exs index 2bd0a49c9a..7f3125252f 100644 --- a/test/phoenix/integration/websocket_channels_test.exs +++ b/test/phoenix/integration/websocket_channels_test.exs @@ -121,6 +121,7 @@ defmodule Phoenix.Integration.WebSocketChannelsTest do |> Map.update!(:trace_context_headers, &Map.new/1) |> Map.update!(:uri, &Map.from_struct/1) |> Map.update!(:x_headers, &Map.new/1) + |> Map.update!(:sec_websocket_headers, &Map.new/1) socket = socket @@ -191,8 +192,9 @@ defmodule Phoenix.Integration.WebSocketChannelsTest do :peer_data, :uri, :user_agent, + :sec_websocket_headers, session: @session_config, - signing_salt: "salt" + signing_salt: "salt", ] ] @@ -328,6 +330,35 @@ defmodule Phoenix.Integration.WebSocketChannelsTest do } end + test "transport sec-websocket-* headers are extracted to the socket connect_info" do + extra_headers = [ + {"sec-websocket-protocol", "phoenix, 123"}, + {"sec-websocket-extensions", "permessage-deflate; client_max_window_bits=15"} + ] + + {:ok, sock} = + WebsocketClient.connect( + self(), + "ws://127.0.0.1:#{@port}/ws/connect_info/websocket?vsn=#{@vsn}", + @serializer, + extra_headers + ) + + WebsocketClient.join(sock, lobby(), %{}) + + assert_receive %Message{ + event: "joined", + payload: %{ + "connect_info" => %{ + "sec_websocket_headers" => %{ + "sec-websocket-protocol" => "phoenix, 123", + "sec-websocket-extensions" => "permessage-deflate; client_max_window_bits=15" + } + } + } + } + end + test "transport trace_context_headers are extracted to the socket connect_info" do extra_headers = [ {"traceparent", "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01"},