Skip to content

Commit

Permalink
Support transport agnostic token passing in channels (#6086)
Browse files Browse the repository at this point in the history
* Support transport agnostic token passing in channels

For WebSocket, the `Sec-WebSocket-Protocol` header is used. For LongPoll,
a custom header is passed instead.

Fixes #5778.
  • Loading branch information
SteffenDE authored Feb 19, 2025
1 parent fbef2f3 commit db8eac8
Show file tree
Hide file tree
Showing 10 changed files with 162 additions and 25 deletions.
10 changes: 6 additions & 4 deletions assets/js/phoenix/ajax.js
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@ import {

export default class Ajax {

static request(method, endPoint, accept, body, timeout, ontimeout, callback){
static request(method, endPoint, headers, body, timeout, ontimeout, callback){
if(global.XDomainRequest){
let req = new global.XDomainRequest() // IE8, IE9
return this.xdomainRequest(req, method, endPoint, body, timeout, ontimeout, callback)
} else {
let req = new global.XMLHttpRequest() // IE7+, Firefox, Chrome, Opera, Safari
return this.xhrRequest(req, method, endPoint, accept, body, timeout, ontimeout, callback)
return this.xhrRequest(req, method, endPoint, headers, body, timeout, ontimeout, callback)
}
}

Expand All @@ -31,10 +31,12 @@ export default class Ajax {
return req
}

static xhrRequest(req, method, endPoint, accept, body, timeout, ontimeout, callback){
static xhrRequest(req, method, endPoint, headers, body, timeout, ontimeout, callback){
req.open(method, endPoint, true)
req.timeout = timeout
req.setRequestHeader("Content-Type", accept)
for (let [key, value] of Object.entries(headers)) {
req.setRequestHeader(key, value)
}
req.onerror = () => callback && callback(null)
req.onreadystatechange = () => {
if(req.readyState === XHR_STATES.complete && callback){
Expand Down
1 change: 1 addition & 0 deletions assets/js/phoenix/constants.js
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,4 @@ export const TRANSPORTS = {
export const XHR_STATES = {
complete: 4
}
export const AUTH_TOKEN_PREFIX = "base64url.bearer.phx."
20 changes: 15 additions & 5 deletions assets/js/phoenix/longpoll.js
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import {
SOCKET_STATES,
TRANSPORTS
TRANSPORTS,
AUTH_TOKEN_PREFIX
} from "./constants"

import Ajax from "./ajax"
Expand All @@ -15,7 +16,12 @@ let arrayBufferToBase64 = (buffer) => {

export default class LongPoll {

constructor(endPoint){
constructor(endPoint, protocols){
// we only support subprotocols for authToken
// ["phoenix", "base64url.bearer.phx.BASE64_ENCODED_TOKEN"]
if (protocols.length === 2 && protocols[1].startsWith(AUTH_TOKEN_PREFIX)) {
this.authToken = atob(protocols[1].slice(AUTH_TOKEN_PREFIX.length))
}
this.endPoint = null
this.token = null
this.skipHeartbeat = true
Expand Down Expand Up @@ -58,7 +64,11 @@ export default class LongPoll {
isActive(){ return this.readyState === SOCKET_STATES.open || this.readyState === SOCKET_STATES.connecting }

poll(){
this.ajax("GET", "application/json", null, () => this.ontimeout(), resp => {
const headers = {"Accept": "application/json"}
if(this.authToken){
headers["X-Phoenix-AuthToken"] = this.authToken
}
this.ajax("GET", headers, null, () => this.ontimeout(), resp => {
if(resp){
var {status, token, messages} = resp
this.token = token
Expand Down Expand Up @@ -160,13 +170,13 @@ export default class LongPoll {
}
}

ajax(method, contentType, body, onCallerTimeout, callback){
ajax(method, headers, body, onCallerTimeout, callback){
let req
let ontimeout = () => {
this.reqs.delete(req)
onCallerTimeout()
}
req = Ajax.request(method, this.endpointURL(), contentType, body, this.timeout, ontimeout, resp => {
req = Ajax.request(method, this.endpointURL(), headers, body, this.timeout, ontimeout, resp => {
this.reqs.delete(req)
if(this.isActive()){ callback(resp) }
})
Expand Down
14 changes: 12 additions & 2 deletions assets/js/phoenix/socket.js
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ import {
DEFAULT_VSN,
SOCKET_STATES,
TRANSPORTS,
WS_CLOSE_NORMAL
WS_CLOSE_NORMAL,
AUTH_TOKEN_PREFIX
} from "./constants"

import {
Expand Down Expand Up @@ -86,6 +87,8 @@ import Timer from "./timer"
* Defaults to 20s (double the server long poll timer).
*
* @param {(Object|function)} [opts.params] - The optional params to pass when connecting
* @param {string} [opts.authToken] - the optional authentication token to be exposed on the server
* under the `:auth_token` connect_info key.
* @param {string} [opts.binaryType] - The binary type to use for binary WebSocket frames.
*
* Defaults to "arraybuffer"
Expand Down Expand Up @@ -176,6 +179,7 @@ export default class Socket {
this.reconnectTimer = new Timer(() => {
this.teardown(() => this.connect())
}, this.reconnectAfterMs)
this.authToken = opts.authToken
}

/**
Expand Down Expand Up @@ -345,7 +349,13 @@ export default class Socket {
transportConnect(){
this.connectClock++
this.closeWasClean = false
this.conn = new this.transport(this.endPointURL())
let protocols = ["phoenix"]
// Sec-WebSocket-Protocol based token
// (longpoll uses Authorization header instead)
if (this.authToken) {
protocols.push(`${AUTH_TOKEN_PREFIX}${btoa(this.authToken).replace(/=/g, "")}`)
}
this.conn = new this.transport(this.endPointURL(), protocols)
this.conn.binaryType = this.binaryType
this.conn.timeout = this.longpollerTimeout
this.conn.onopen = () => this.onConnOpen()
Expand Down
13 changes: 12 additions & 1 deletion assets/test/channel_test.js
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@ const defaultRef = 1
const defaultTimeout = 10000

class WSMock {
constructor(){}
constructor(url, protocols){
this.url = url
this.protocols = protocols
}
close(){}
send(){}
}
Expand Down Expand Up @@ -58,6 +61,14 @@ describe("with transport", function (){
expect(joinPush.event).toBe("phx_join")
expect(joinPush.timeout).toBe(1234)
})

it("sets subprotocols when authToken is provided", function (){
const authToken = "1234"
const socket = new Socket("/socket", {authToken})

socket.connect()
expect(socket.conn.protocols).toEqual(["phoenix", "base64url.bearer.phx.MTIzNA"])
})
})

describe("updating join params", function (){
Expand Down
35 changes: 26 additions & 9 deletions guides/real_time/channels.md
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,24 @@ That's all there is to our basic chat app. Fire up multiple browser tabs and you

When we connect, we'll often need to authenticate the client. Fortunately, this is a 4-step process with [Phoenix.Token](https://hexdocs.pm/phoenix/Phoenix.Token.html).

### Step 1 - Assign a Token in the Connection
### Step 1 - Enable the `auth_token` functionality in the socket

Phoenix supports a transport agnostic way to pass an authentication token to the server. To enable this, we need to pass the `:auth_token` option to the socket declaration in our `Endpoint` module and configure the `connect_info` to include the `:auth_token` key.

```elixir
defmodule HelloWeb.Endpoint do
use Phoenix.Endpoint, otp_app: :hello

socket "/socket", HelloWeb.UserSocket,
websocket: [connect_info: [:auth_token]],
longpoll: false,
auth_token: true

...
end
```

### Step 2 - Assign a Token in the Connection

Let's say we have an authentication plug in our app called `OurAuth`. When `OurAuth` authenticates a user, it sets a value for the `:current_user` key in `conn.assigns`. Since the `current_user` exists, we can simply assign the user's token in the connection for use in the layout. We can wrap that behavior up in a private function plug, `put_user_token/2`. This could also be put in its own module as well. To make this all work, we just add `OurAuth` and `put_user_token/2` to the browser pipeline.

Expand All @@ -408,7 +425,7 @@ end

Now our `conn.assigns` contains the `current_user` and `user_token`.

### Step 2 - Pass the Token to the JavaScript
### Step 3 - Pass the Token to the JavaScript

Next, we need to pass this token to JavaScript. We can do so inside a script tag in `lib/hello_web/components/layouts/root.html.heex` right above the app.js script, as follows:

Expand All @@ -417,14 +434,14 @@ Next, we need to pass this token to JavaScript. We can do so inside a script tag
<script src={~p"/assets/app.js"}></script>
```

### Step 3 - Pass the Token to the Socket Constructor and Verify
### Step 4 - Pass the Token to the Socket Constructor and Verify

We also need to pass the `:params` to the socket constructor and verify the user token in the `connect/3` function. To do so, edit `lib/hello_web/channels/user_socket.ex`, as follows:
We also need to pass the `:auth_token` to the socket constructor and verify the user token in the `connect/3` function. To do so, edit `lib/hello_web/channels/user_socket.ex`, as follows:

```elixir
def connect(%{"token" => token}, socket, _connect_info) do
def connect(_params_, socket, connect_info) do
# max_age: 1209600 is equivalent to two weeks in seconds
case Phoenix.Token.verify(socket, "user socket", token, max_age: 1209600) do
case Phoenix.Token.verify(socket, "user socket", connect_info[:auth_token], max_age: 1209600) do
{:ok, user_id} ->
{:ok, assign(socket, :current_user, user_id)}
{:error, reason} ->
Expand All @@ -436,17 +453,17 @@ end
In our JavaScript, we can use the token set previously when constructing the Socket:

```javascript
let socket = new Socket("/socket", {params: {token: window.userToken}})
let socket = new Socket("/socket", {authToken: window.userToken})
```

We used `Phoenix.Token.verify/4` to verify the user token provided by the client. `Phoenix.Token.verify/4` returns either `{:ok, user_id}` or `{:error, reason}`. We can pattern match on that return in a `case` statement. With a verified token, we set the user's id as the value to `:current_user` in the socket. Otherwise, we return `:error`.

### Step 4 - Connect to the socket in JavaScript
### Step 5 - Connect to the socket in JavaScript

With authentication set up, we can connect to sockets and channels from JavaScript.

```javascript
let socket = new Socket("/socket", {params: {token: window.userToken}})
let socket = new Socket("/socket", {authToken: window.userToken})
socket.connect()
```

Expand Down
18 changes: 17 additions & 1 deletion lib/phoenix/endpoint.ex
Original file line number Diff line number Diff line change
Expand Up @@ -708,7 +708,8 @@ defmodule Phoenix.Endpoint do
:check_origin,
:check_csrf,
:code_reloader,
:connect_info
:connect_info,
:auth_token
]

websocket =
Expand Down Expand Up @@ -740,6 +741,7 @@ defmodule Phoenix.Endpoint do

paths =
if websocket do
websocket = put_auth_token(websocket, opts[:auth_token])
config = Phoenix.Socket.Transport.load_config(websocket, Phoenix.Transports.WebSocket)
plug_init = {endpoint, socket, config}
{conn_ast, match_path} = socket_path(path, config)
Expand All @@ -750,6 +752,7 @@ defmodule Phoenix.Endpoint do

paths =
if longpoll do
longpoll = put_auth_token(longpoll, opts[:auth_token])
config = Phoenix.Socket.Transport.load_config(longpoll, Phoenix.Transports.LongPoll)
plug_init = {endpoint, socket, config}
{conn_ast, match_path} = socket_path(path, config)
Expand All @@ -761,6 +764,9 @@ defmodule Phoenix.Endpoint do
paths
end

defp put_auth_token(true, enabled), do: [auth_token: enabled]
defp put_auth_token(opts, enabled), do: Keyword.put(opts, :auth_token, enabled)

defp socket_path(path, config) do
end_path_fragment = Keyword.fetch!(config, :path)

Expand Down Expand Up @@ -844,6 +850,16 @@ defmodule Phoenix.Endpoint do
HTTP/HTTPS connection drainer will still run, and apply to all connections.
Set it to `false` to disable draining.
* `auth_token` - a boolean that enables the use of the channels client's auth_token option.
The exact token exchange mechanism depends on the transport:
* the websocket transport, this enables a token to be passed through the `Sec-WebSocket-Protocol` header.
* the longpoll transport, this allows the token to be passed through the `Authorization` header.
The token is available in the `connect_info` as `:auth_token`.
Custom transports might implement their own mechanism.
You can also pass the options below on `use Phoenix.Socket`.
The values specified here override the value in `use Phoenix.Socket`.
Expand Down
15 changes: 13 additions & 2 deletions lib/phoenix/socket/transport.ex
Original file line number Diff line number Diff line change
Expand Up @@ -259,9 +259,17 @@ defmodule Phoenix.Socket.Transport do
def load_config(config) do
{connect_info, config} = Keyword.pop(config, :connect_info, [])

connect_info =
if config[:auth_token] do
# auth_token is included by default when enabled
[:auth_token | connect_info]
else
connect_info
end

connect_info =
Enum.map(connect_info, fn
key when key in [:peer_data, :trace_context_headers, :uri, :user_agent, :x_headers] ->
key when key in [:peer_data, :trace_context_headers, :uri, :user_agent, :x_headers, :auth_token] ->
key

{:session, session} ->
Expand Down Expand Up @@ -485,6 +493,9 @@ defmodule Phoenix.Socket.Transport do
{:session, session} ->
{:session, connect_session(conn, endpoint, session, opts)}

:auth_token ->
{:auth_token, conn.private[:phoenix_transport_auth_token]}

{key, val} ->
{key, val}
end
Expand Down Expand Up @@ -549,7 +560,7 @@ defmodule Phoenix.Socket.Transport do
with csrf_token when is_binary(csrf_token) <- conn.params["_csrf_token"],
csrf_state when is_binary(csrf_state) <-
Plug.CSRFProtection.dump_state_from_session(session[csrf_token_key]) do
Plug.CSRFProtection.valid_state_and_csrf_token?(csrf_state, csrf_token)
Plug.CSRFProtection.valid_state_and_csrf_token?(csrf_state, csrf_token)
end
end

Expand Down
14 changes: 14 additions & 0 deletions lib/phoenix/transports/long_poll.ex
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,8 @@ defmodule Phoenix.Transports.LongPoll do

keys = Keyword.get(opts, :connect_info, [])

conn = maybe_auth_token_from_header(conn, opts[:auth_token])

connect_info =
Transport.connect_info(conn, endpoint, keys, Keyword.take(opts, @connect_info_opts))

Expand Down Expand Up @@ -265,6 +267,18 @@ defmodule Phoenix.Transports.LongPoll do
)
end

defp maybe_auth_token_from_header(conn, true) do
case Plug.Conn.get_req_header(conn, "x-phoenix-authtoken") do
[] ->
conn

[token | _] ->
Plug.Conn.put_private(conn, :phoenix_transport_auth_token, token)
end
end

defp maybe_auth_token_from_header(conn, _), do: conn

defp status_json(conn) do
send_json(conn, %{"status" => conn.status || 200})
end
Expand Down
Loading

0 comments on commit db8eac8

Please sign in to comment.