diff --git a/assets/js/phoenix/ajax.js b/assets/js/phoenix/ajax.js index 62c771e90d..6d69414df2 100644 --- a/assets/js/phoenix/ajax.js +++ b/assets/js/phoenix/ajax.js @@ -5,13 +5,13 @@ import { export default class Ajax { - static request(method, endPoint, accept, body, timeout, ontimeout, callback){ + static request(method, endPoint, headers, body, timeout, ontimeout, callback){ if(global.XDomainRequest){ let req = new global.XDomainRequest() // IE8, IE9 return this.xdomainRequest(req, method, endPoint, body, timeout, ontimeout, callback) } else { let req = new global.XMLHttpRequest() // IE7+, Firefox, Chrome, Opera, Safari - return this.xhrRequest(req, method, endPoint, accept, body, timeout, ontimeout, callback) + return this.xhrRequest(req, method, endPoint, headers, body, timeout, ontimeout, callback) } } @@ -31,10 +31,12 @@ export default class Ajax { return req } - static xhrRequest(req, method, endPoint, accept, body, timeout, ontimeout, callback){ + static xhrRequest(req, method, endPoint, headers, body, timeout, ontimeout, callback){ req.open(method, endPoint, true) req.timeout = timeout - req.setRequestHeader("Content-Type", accept) + for (let [key, value] of Object.entries(headers)) { + req.setRequestHeader(key, value) + } req.onerror = () => callback && callback(null) req.onreadystatechange = () => { if(req.readyState === XHR_STATES.complete && callback){ diff --git a/assets/js/phoenix/constants.js b/assets/js/phoenix/constants.js index 591ec7a248..231e4f4eea 100644 --- a/assets/js/phoenix/constants.js +++ b/assets/js/phoenix/constants.js @@ -27,3 +27,4 @@ export const TRANSPORTS = { export const XHR_STATES = { complete: 4 } +export const AUTH_TOKEN_PREFIX = "base64url.bearer.phx." \ No newline at end of file diff --git a/assets/js/phoenix/longpoll.js b/assets/js/phoenix/longpoll.js index 42471c7ce7..25f98f5deb 100644 --- a/assets/js/phoenix/longpoll.js +++ b/assets/js/phoenix/longpoll.js @@ -1,6 +1,7 @@ import { SOCKET_STATES, - TRANSPORTS + TRANSPORTS, + AUTH_TOKEN_PREFIX } from "./constants" import Ajax from "./ajax" @@ -15,7 +16,12 @@ let arrayBufferToBase64 = (buffer) => { export default class LongPoll { - constructor(endPoint){ + constructor(endPoint, protocols){ + // we only support subprotocols for authToken + // ["phoenix", "base64url.bearer.phx.BASE64_ENCODED_TOKEN"] + if (protocols.length === 2 && protocols[1].startsWith(AUTH_TOKEN_PREFIX)) { + this.authToken = atob(protocols[1].slice(AUTH_TOKEN_PREFIX.length)) + } this.endPoint = null this.token = null this.skipHeartbeat = true @@ -58,7 +64,11 @@ export default class LongPoll { isActive(){ return this.readyState === SOCKET_STATES.open || this.readyState === SOCKET_STATES.connecting } poll(){ - this.ajax("GET", "application/json", null, () => this.ontimeout(), resp => { + const headers = {"Accept": "application/json"} + if(this.authToken){ + headers["X-Phoenix-AuthToken"] = this.authToken + } + this.ajax("GET", headers, null, () => this.ontimeout(), resp => { if(resp){ var {status, token, messages} = resp this.token = token @@ -160,13 +170,13 @@ export default class LongPoll { } } - ajax(method, contentType, body, onCallerTimeout, callback){ + ajax(method, headers, body, onCallerTimeout, callback){ let req let ontimeout = () => { this.reqs.delete(req) onCallerTimeout() } - req = Ajax.request(method, this.endpointURL(), contentType, body, this.timeout, ontimeout, resp => { + req = Ajax.request(method, this.endpointURL(), headers, body, this.timeout, ontimeout, resp => { this.reqs.delete(req) if(this.isActive()){ callback(resp) } }) diff --git a/assets/js/phoenix/socket.js b/assets/js/phoenix/socket.js index 8da13dd92d..66886cd0c6 100644 --- a/assets/js/phoenix/socket.js +++ b/assets/js/phoenix/socket.js @@ -6,7 +6,8 @@ import { DEFAULT_VSN, SOCKET_STATES, TRANSPORTS, - WS_CLOSE_NORMAL + WS_CLOSE_NORMAL, + AUTH_TOKEN_PREFIX } from "./constants" import { @@ -86,6 +87,8 @@ import Timer from "./timer" * Defaults to 20s (double the server long poll timer). * * @param {(Object|function)} [opts.params] - The optional params to pass when connecting + * @param {string} [opts.authToken] - the optional authentication token to be exposed on the server + * under the `:auth_token` connect_info key. * @param {string} [opts.binaryType] - The binary type to use for binary WebSocket frames. * * Defaults to "arraybuffer" @@ -176,6 +179,7 @@ export default class Socket { this.reconnectTimer = new Timer(() => { this.teardown(() => this.connect()) }, this.reconnectAfterMs) + this.authToken = opts.authToken } /** @@ -345,7 +349,13 @@ export default class Socket { transportConnect(){ this.connectClock++ this.closeWasClean = false - this.conn = new this.transport(this.endPointURL()) + let protocols = ["phoenix"] + // Sec-WebSocket-Protocol based token + // (longpoll uses Authorization header instead) + if (this.authToken) { + protocols.push(`${AUTH_TOKEN_PREFIX}${btoa(this.authToken).replace(/=/g, "")}`) + } + this.conn = new this.transport(this.endPointURL(), protocols) this.conn.binaryType = this.binaryType this.conn.timeout = this.longpollerTimeout this.conn.onopen = () => this.onConnOpen() diff --git a/assets/test/channel_test.js b/assets/test/channel_test.js index a5166a8cef..1ad88d291e 100644 --- a/assets/test/channel_test.js +++ b/assets/test/channel_test.js @@ -7,7 +7,10 @@ const defaultRef = 1 const defaultTimeout = 10000 class WSMock { - constructor(){} + constructor(url, protocols){ + this.url = url + this.protocols = protocols + } close(){} send(){} } @@ -58,6 +61,14 @@ describe("with transport", function (){ expect(joinPush.event).toBe("phx_join") expect(joinPush.timeout).toBe(1234) }) + + it("sets subprotocols when authToken is provided", function (){ + const authToken = "1234" + const socket = new Socket("/socket", {authToken}) + + socket.connect() + expect(socket.conn.protocols).toEqual(["phoenix", "base64url.bearer.phx.MTIzNA"]) + }) }) describe("updating join params", function (){ diff --git a/guides/real_time/channels.md b/guides/real_time/channels.md index f2b12fb236..14510c0578 100644 --- a/guides/real_time/channels.md +++ b/guides/real_time/channels.md @@ -385,7 +385,24 @@ That's all there is to our basic chat app. Fire up multiple browser tabs and you When we connect, we'll often need to authenticate the client. Fortunately, this is a 4-step process with [Phoenix.Token](https://hexdocs.pm/phoenix/Phoenix.Token.html). -### Step 1 - Assign a Token in the Connection +### Step 1 - Enable the `auth_token` functionality in the socket + +Phoenix supports a transport agnostic way to pass an authentication token to the server. To enable this, we need to pass the `:auth_token` option to the socket declaration in our `Endpoint` module and configure the `connect_info` to include the `:auth_token` key. + +```elixir +defmodule HelloWeb.Endpoint do + use Phoenix.Endpoint, otp_app: :hello + + socket "/socket", HelloWeb.UserSocket, + websocket: [connect_info: [:auth_token]], + longpoll: false, + auth_token: true + + ... +end +``` + +### Step 2 - Assign a Token in the Connection Let's say we have an authentication plug in our app called `OurAuth`. When `OurAuth` authenticates a user, it sets a value for the `:current_user` key in `conn.assigns`. Since the `current_user` exists, we can simply assign the user's token in the connection for use in the layout. We can wrap that behavior up in a private function plug, `put_user_token/2`. This could also be put in its own module as well. To make this all work, we just add `OurAuth` and `put_user_token/2` to the browser pipeline. @@ -408,7 +425,7 @@ end Now our `conn.assigns` contains the `current_user` and `user_token`. -### Step 2 - Pass the Token to the JavaScript +### Step 3 - Pass the Token to the JavaScript Next, we need to pass this token to JavaScript. We can do so inside a script tag in `lib/hello_web/components/layouts/root.html.heex` right above the app.js script, as follows: @@ -417,14 +434,14 @@ Next, we need to pass this token to JavaScript. We can do so inside a script tag ``` -### Step 3 - Pass the Token to the Socket Constructor and Verify +### Step 4 - Pass the Token to the Socket Constructor and Verify -We also need to pass the `:params` to the socket constructor and verify the user token in the `connect/3` function. To do so, edit `lib/hello_web/channels/user_socket.ex`, as follows: +We also need to pass the `:auth_token` to the socket constructor and verify the user token in the `connect/3` function. To do so, edit `lib/hello_web/channels/user_socket.ex`, as follows: ```elixir -def connect(%{"token" => token}, socket, _connect_info) do +def connect(_params_, socket, connect_info) do # max_age: 1209600 is equivalent to two weeks in seconds - case Phoenix.Token.verify(socket, "user socket", token, max_age: 1209600) do + case Phoenix.Token.verify(socket, "user socket", connect_info[:auth_token], max_age: 1209600) do {:ok, user_id} -> {:ok, assign(socket, :current_user, user_id)} {:error, reason} -> @@ -436,17 +453,17 @@ end In our JavaScript, we can use the token set previously when constructing the Socket: ```javascript -let socket = new Socket("/socket", {params: {token: window.userToken}}) +let socket = new Socket("/socket", {authToken: window.userToken}) ``` We used `Phoenix.Token.verify/4` to verify the user token provided by the client. `Phoenix.Token.verify/4` returns either `{:ok, user_id}` or `{:error, reason}`. We can pattern match on that return in a `case` statement. With a verified token, we set the user's id as the value to `:current_user` in the socket. Otherwise, we return `:error`. -### Step 4 - Connect to the socket in JavaScript +### Step 5 - Connect to the socket in JavaScript With authentication set up, we can connect to sockets and channels from JavaScript. ```javascript -let socket = new Socket("/socket", {params: {token: window.userToken}}) +let socket = new Socket("/socket", {authToken: window.userToken}) socket.connect() ``` diff --git a/lib/phoenix/endpoint.ex b/lib/phoenix/endpoint.ex index 48ffae62bd..b684d0f9a6 100644 --- a/lib/phoenix/endpoint.ex +++ b/lib/phoenix/endpoint.ex @@ -708,7 +708,8 @@ defmodule Phoenix.Endpoint do :check_origin, :check_csrf, :code_reloader, - :connect_info + :connect_info, + :auth_token ] websocket = @@ -740,6 +741,7 @@ defmodule Phoenix.Endpoint do paths = if websocket do + websocket = put_auth_token(websocket, opts[:auth_token]) config = Phoenix.Socket.Transport.load_config(websocket, Phoenix.Transports.WebSocket) plug_init = {endpoint, socket, config} {conn_ast, match_path} = socket_path(path, config) @@ -750,6 +752,7 @@ defmodule Phoenix.Endpoint do paths = if longpoll do + longpoll = put_auth_token(longpoll, opts[:auth_token]) config = Phoenix.Socket.Transport.load_config(longpoll, Phoenix.Transports.LongPoll) plug_init = {endpoint, socket, config} {conn_ast, match_path} = socket_path(path, config) @@ -761,6 +764,9 @@ defmodule Phoenix.Endpoint do paths end + defp put_auth_token(true, enabled), do: [auth_token: enabled] + defp put_auth_token(opts, enabled), do: Keyword.put(opts, :auth_token, enabled) + defp socket_path(path, config) do end_path_fragment = Keyword.fetch!(config, :path) @@ -844,6 +850,16 @@ defmodule Phoenix.Endpoint do HTTP/HTTPS connection drainer will still run, and apply to all connections. Set it to `false` to disable draining. + * `auth_token` - a boolean that enables the use of the channels client's auth_token option. + The exact token exchange mechanism depends on the transport: + + * the websocket transport, this enables a token to be passed through the `Sec-WebSocket-Protocol` header. + * the longpoll transport, this allows the token to be passed through the `Authorization` header. + + The token is available in the `connect_info` as `:auth_token`. + + Custom transports might implement their own mechanism. + You can also pass the options below on `use Phoenix.Socket`. The values specified here override the value in `use Phoenix.Socket`. diff --git a/lib/phoenix/socket/transport.ex b/lib/phoenix/socket/transport.ex index 3c08e0ec9a..1514723335 100644 --- a/lib/phoenix/socket/transport.ex +++ b/lib/phoenix/socket/transport.ex @@ -259,9 +259,17 @@ defmodule Phoenix.Socket.Transport do def load_config(config) do {connect_info, config} = Keyword.pop(config, :connect_info, []) + connect_info = + if config[:auth_token] do + # auth_token is included by default when enabled + [:auth_token | connect_info] + else + connect_info + end + connect_info = Enum.map(connect_info, fn - key when key in [:peer_data, :trace_context_headers, :uri, :user_agent, :x_headers] -> + key when key in [:peer_data, :trace_context_headers, :uri, :user_agent, :x_headers, :auth_token] -> key {:session, session} -> @@ -485,6 +493,9 @@ defmodule Phoenix.Socket.Transport do {:session, session} -> {:session, connect_session(conn, endpoint, session, opts)} + :auth_token -> + {:auth_token, conn.private[:phoenix_transport_auth_token]} + {key, val} -> {key, val} end @@ -549,7 +560,7 @@ defmodule Phoenix.Socket.Transport do with csrf_token when is_binary(csrf_token) <- conn.params["_csrf_token"], csrf_state when is_binary(csrf_state) <- Plug.CSRFProtection.dump_state_from_session(session[csrf_token_key]) do - Plug.CSRFProtection.valid_state_and_csrf_token?(csrf_state, csrf_token) + Plug.CSRFProtection.valid_state_and_csrf_token?(csrf_state, csrf_token) end end diff --git a/lib/phoenix/transports/long_poll.ex b/lib/phoenix/transports/long_poll.ex index 3a9915bb4e..a30c34f819 100644 --- a/lib/phoenix/transports/long_poll.ex +++ b/lib/phoenix/transports/long_poll.ex @@ -138,6 +138,8 @@ defmodule Phoenix.Transports.LongPoll do keys = Keyword.get(opts, :connect_info, []) + conn = maybe_auth_token_from_header(conn, opts[:auth_token]) + connect_info = Transport.connect_info(conn, endpoint, keys, Keyword.take(opts, @connect_info_opts)) @@ -265,6 +267,18 @@ defmodule Phoenix.Transports.LongPoll do ) end + defp maybe_auth_token_from_header(conn, true) do + case Plug.Conn.get_req_header(conn, "x-phoenix-authtoken") do + [] -> + conn + + [token | _] -> + Plug.Conn.put_private(conn, :phoenix_transport_auth_token, token) + end + end + + defp maybe_auth_token_from_header(conn, _), do: conn + defp status_json(conn) do send_json(conn, %{"status" => conn.status || 200}) end diff --git a/lib/phoenix/transports/websocket.ex b/lib/phoenix/transports/websocket.ex index dfc7bd2508..e9422570ab 100644 --- a/lib/phoenix/transports/websocket.ex +++ b/lib/phoenix/transports/websocket.ex @@ -17,6 +17,8 @@ defmodule Phoenix.Transports.WebSocket do @connect_info_opts [:check_csrf] + @auth_token_prefix "base64url.bearer.phx." + import Plug.Conn alias Phoenix.Socket.{V1, V2, Transport} @@ -35,12 +37,23 @@ defmodule Phoenix.Transports.WebSocket do def init(opts), do: opts def call(%{method: "GET"} = conn, {endpoint, handler, opts}) do + subprotocols = + if opts[:auth_token] do + # when using Sec-WebSocket-Protocol for passing an auth token + # the server must reply with one of the subprotocols in the request; + # therefore we include "phoenix" as allowed subprotocol and include it on the client + ["phoenix" | Keyword.get(opts, :subprotocols, [])] + else + opts[:subprotocols] + end + conn |> fetch_query_params() |> Transport.code_reload(endpoint, opts) |> Transport.transport_log(opts[:transport_log]) |> Transport.check_origin(handler, endpoint, opts) - |> Transport.check_subprotocols(opts[:subprotocols]) + |> maybe_auth_token_from_header(opts[:auth_token]) + |> Transport.check_subprotocols(subprotocols) |> case do %{halted: true} = conn -> conn @@ -82,4 +95,36 @@ defmodule Phoenix.Transports.WebSocket do def call(conn, _), do: send_resp(conn, 400, "") def handle_error(conn, _reason), do: send_resp(conn, 403, "") + + defp maybe_auth_token_from_header(conn, true) do + case get_req_header(conn, "sec-websocket-protocol") do + [] -> + conn + + [subprotocols_header | _] -> + request_subprotocols = + subprotocols_header + |> Plug.Conn.Utils.list() + |> Enum.split_with(&String.starts_with?(&1, @auth_token_prefix)) + + case request_subprotocols do + {[@auth_token_prefix <> encoded_token], actual_subprotocols} -> + token = Base.decode64!(encoded_token, padding: false) + + conn + |> put_private(:phoenix_transport_auth_token, token) + |> set_actual_subprotocols(actual_subprotocols) + + _ -> + conn + end + end + end + + defp maybe_auth_token_from_header(conn, _), do: conn + + defp set_actual_subprotocols(conn, []), do: delete_req_header(conn, "sec-websocket-protocol") + + defp set_actual_subprotocols(conn, subprotocols), + do: put_req_header(conn, "sec-websocket-protocol", Enum.join(subprotocols, ", ")) end