Skip to content

Commit

Permalink
feat: add native mode
Browse files Browse the repository at this point in the history
  • Loading branch information
abc3 committed Oct 17, 2023
1 parent d78c7f6 commit 3cba0b0
Show file tree
Hide file tree
Showing 7 changed files with 387 additions and 116 deletions.
2 changes: 1 addition & 1 deletion lib/supavisor.ex
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ defmodule Supavisor do
@type tcp_sock :: {:gen_tcp, :gen_tcp.socket()}
@type workers :: %{manager: pid, pool: pid}
@type secrets :: {:password | :auth_query, fun()}
@type mode :: :transaction | :session
@type mode :: :transaction | :session | :native
@type id :: {String.t(), String.t(), mode}
@type subscribe_opts :: %{workers: workers, ps: list, idle_timeout: integer}

Expand Down
9 changes: 5 additions & 4 deletions lib/supavisor/application.ex
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,12 @@ defmodule Supavisor.Application do

proxy_ports = [
{:pg_proxy_transaction, Application.get_env(:supavisor, :proxy_port_transaction),
:transaction},
{:pg_proxy_session, Application.get_env(:supavisor, :proxy_port_session), :session}
:transaction, Supavisor.ClientHandler},
{:pg_proxy_native, Application.get_env(:supavisor, :proxy_port_session), :native,
Supavisor.NativeHandler}
]

for {key, port, mode} <- proxy_ports do
for {key, port, mode, module} <- proxy_ports do
:ranch.start_listener(
key,
:ranch_tcp,
Expand All @@ -42,7 +43,7 @@ defmodule Supavisor.Application do
num_acceptors: String.to_integer(System.get_env("NUM_ACCEPTORS") || "100"),
socket_opts: [port: port, keepalive: true]
},
Supavisor.ClientHandler,
module,
%{mode: mode}
)
|> then(&"Proxy started #{mode} on port #{port}, result: #{inspect(&1)}")
Expand Down
136 changes: 25 additions & 111 deletions lib/supavisor/client_handler.ex
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ defmodule Supavisor.ClientHandler do
alias Supavisor, as: S
alias Supavisor.DbHandler, as: Db
alias Supavisor.Helpers, as: H
alias Supavisor.HandlerHelpers, as: HH
alias Supavisor.{Tenants, Monitoring.Telem, Protocol.Server}

@impl true
Expand Down Expand Up @@ -65,7 +66,7 @@ defmodule Supavisor.ClientHandler do
@impl true
def handle_event(:info, {_proto, _, <<"GET", _::binary>>}, :exchange, data) do
Logger.debug("Client is trying to request HTTP")
sock_send(data.sock, "HTTP/1.1 204 OK\r\n\r\n")
HH.sock_send(data.sock, "HTTP/1.1 204 OK\r\n\r\n")
{:stop, :normal, data}
end

Expand All @@ -77,8 +78,8 @@ defmodule Supavisor.ClientHandler do

# SSL negotiation, S/N/Error
if !!downstream_cert and !!downstream_key do
:ok = setopts(sock, active: false)
:ok = sock_send(sock, "S")
:ok = HH.setopts(sock, active: false)
:ok = HH.sock_send(sock, "S")

opts = [
certfile: downstream_cert,
Expand All @@ -88,7 +89,7 @@ defmodule Supavisor.ClientHandler do
case :ssl.handshake(elem(sock, 1), opts) do
{:ok, ssl_sock} ->
socket = {:ssl, ssl_sock}
:ok = setopts(socket, active: true)
:ok = HH.setopts(socket, active: true)
{:keep_state, %{data | sock: socket, ssl: true}}

error ->
Expand All @@ -97,16 +98,16 @@ defmodule Supavisor.ClientHandler do
end
else
Logger.error("User requested SSL connection but no downstream cert/key found")
:ok = sock_send(data.sock, "N")
:ok = HH.sock_send(data.sock, "N")
:keep_state_and_data
end
end

def handle_event(:info, {_, _, bin}, :exchange, data) do
case decode_startup_packet(bin) do
case Server.decode_startup_packet(bin) do
{:ok, hello} ->
Logger.debug("Client startup message: #{inspect(hello)}")
{user, external_id} = parse_user_info(hello.payload)
{user, external_id} = HH.parse_user_info(hello.payload)
Logger.metadata(project: external_id, user: user, mode: data.mode)
{:keep_state, data, {:next_event, :internal, {:hello, {user, external_id}}}}

Expand All @@ -117,7 +118,7 @@ defmodule Supavisor.ClientHandler do
end

def handle_event(:internal, {:hello, {user, external_id}}, :exchange, %{sock: sock} = data) do
sni_hostname = try_get_sni(sock)
sni_hostname = HH.try_get_sni(sock)

case Tenants.get_user_cache(user, external_id, sni_hostname) do
{:ok, info} ->
Expand All @@ -126,7 +127,7 @@ defmodule Supavisor.ClientHandler do

if info.tenant.enforce_ssl and !data.ssl do
Logger.error("Tenant is not allowed to connect without SSL, user #{user}")
:ok = send_error(sock, "XX000", "SSL connection is required")
:ok = HH.send_error(sock, "XX000", "SSL connection is required")
{:stop, :normal, data}
else
new_data = update_user_data(data, info, user, id)
Expand All @@ -139,15 +140,15 @@ defmodule Supavisor.ClientHandler do
{:error, reason} ->
Logger.error("Authentication auth_secrets error: #{inspect(reason)}")

:ok = send_error(sock, "XX000", "Authentication error")
:ok = HH.send_error(sock, "XX000", "Authentication error")
{:stop, :normal, data}
end
end

{:error, reason} ->
Logger.error("User not found: #{inspect(reason)} #{inspect({user, external_id})}")

:ok = send_error(sock, "XX000", "Tenant or user not found")
:ok = HH.send_error(sock, "XX000", "Tenant or user not found")
{:stop, :normal, data}
end
end
Expand All @@ -166,7 +167,7 @@ defmodule Supavisor.ClientHandler do
Server.exchange_message(:final, "e=#{reason}")
end

sock_send(sock, msg)
HH.sock_send(sock, msg)

{:stop, :normal, data}

Expand All @@ -181,7 +182,7 @@ defmodule Supavisor.ClientHandler do
end

Logger.debug("Exchange success")
:ok = sock_send(sock, Server.authentication_ok())
:ok = HH.sock_send(sock, Server.authentication_ok())

{:keep_state, %{data | auth_secrets: {method, secrets}},
{:next_event, :internal, :subscribe}}
Expand Down Expand Up @@ -210,7 +211,7 @@ defmodule Supavisor.ClientHandler do
{:error, :max_clients_reached} ->
msg = "Max client connections reached"
Logger.error(msg)
:ok = send_error(data.sock, "XX000", msg)
:ok = HH.send_error(data.sock, "XX000", msg)
{:stop, :normal, data}

error ->
Expand All @@ -220,7 +221,7 @@ defmodule Supavisor.ClientHandler do
end

def handle_event(:internal, {:greetings, ps}, _, %{sock: sock} = data) do
:ok = sock_send(sock, Server.greetings(ps))
:ok = HH.sock_send(sock, Server.greetings(ps))

if data.idle_timeout > 0 do
{:next_state, :idle, data, idle_check(data.idle_timeout)}
Expand Down Expand Up @@ -254,7 +255,7 @@ defmodule Supavisor.ClientHandler do
# handle Sync message
def handle_event(:info, {proto, _, <<?S, 4::32>>}, :idle, data) when proto in [:tcp, :ssl] do
Logger.debug("Receive sync")
:ok = sock_send(data.sock, Server.ready_for_query())
:ok = HH.sock_send(data.sock, Server.ready_for_query())

if data.idle_timeout > 0 do
{:keep_state_and_data, idle_check(data.idle_timeout)}
Expand Down Expand Up @@ -283,7 +284,7 @@ defmodule Supavisor.ClientHandler do
if size > 1_000_000 do
msg = "Db buffer size is too big: #{size}"
Logger.error(msg)
sock_send(data.sock, Server.error_message("XX000", msg))
HH.sock_send(data.sock, Server.error_message("XX000", msg))
{:stop, :normal, data}
else
Logger.debug("DB call buffering")
Expand All @@ -293,7 +294,7 @@ defmodule Supavisor.ClientHandler do
{:error, reason} ->
msg = "DB call error: #{inspect(reason)}"
Logger.error(msg)
sock_send(data.sock, Server.error_message("XX000", msg))
HH.sock_send(data.sock, Server.error_message("XX000", msg))
{:stop, :normal, data}
end
end
Expand Down Expand Up @@ -339,7 +340,7 @@ defmodule Supavisor.ClientHandler do
def handle_event({:call, from}, {:client_call, bin, ready?}, _, data) do
Logger.debug("--> --> bin #{inspect(byte_size(bin))} bytes")

reply = {:reply, from, sock_send(data.sock, bin)}
reply = {:reply, from, HH.sock_send(data.sock, bin)}

if ready? do
Logger.debug("Client is ready")
Expand Down Expand Up @@ -392,78 +393,18 @@ defmodule Supavisor.ClientHandler do
end

Logger.error(msg)
sock_send(data.sock, Server.error_message("XX000", msg))
HH.sock_send(data.sock, Server.error_message("XX000", msg))
:ok
end

def terminate(_reason, _state, _data), do: :ok

## Internal functions

@spec parse_user_info(map) :: {String.t() | nil, String.t()}
def parse_user_info(%{"user" => user, "options" => %{"reference" => ref}}) do
{user, ref}
end

def parse_user_info(%{"user" => user}) do
case :binary.matches(user, ".") do
[] ->
{user, nil}

matches ->
{pos, 1} = List.last(matches)
<<name::size(pos)-binary, ?., external_id::binary>> = user
{name, external_id}
end
end

def decode_startup_packet(<<len::integer-32, _protocol::binary-4, rest::binary>>) do
with {:ok, payload} <- decode_startup_packet_payload(rest) do
pkt = %{
len: len,
payload: payload,
tag: :startup
}

{:ok, pkt}
end
end

def decode_startup_packet(_) do
{:error, :bad_startup_payload}
end

# The startup packet payload is a list of key/value pairs, separated by null bytes
defp decode_startup_packet_payload(payload) do
fields = String.split(payload, <<0>>, trim: true)

# If the number of fields is odd, then the payload is malformed
if rem(length(fields), 2) == 1 do
{:error, :bad_startup_payload}
else
map =
fields
|> Enum.chunk_every(2)
|> Enum.map(fn
["options" = k, v] -> {k, URI.decode_query(v)}
[k, v] -> {k, v}
end)
|> Map.new()

# We only do light validation on the fields in the payload. The only field we use at the
# moment is `user`. If that's missing, this is a bad payload.
if Map.has_key?(map, "user") do
{:ok, map}
else
{:error, :bad_startup_payload}
end
end
end

@spec handle_exchange(S.sock(), {atom(), fun()}) :: {:ok, binary() | nil} | {:error, String.t()}
def handle_exchange({_, socket} = sock, {:auth_query_md5 = method, secrets}) do
salt = :crypto.strong_rand_bytes(4)
:ok = sock_send(sock, Server.md5_request(salt))
:ok = HH.sock_send(sock, Server.md5_request(salt))

with {:ok,
%{
Expand All @@ -479,7 +420,7 @@ defmodule Supavisor.ClientHandler do
end

def handle_exchange({_, socket} = sock, {method, secrets}) do
:ok = sock_send(sock, Server.scram_request())
:ok = HH.sock_send(sock, Server.scram_request())

with {:ok,
%{
Expand All @@ -494,7 +435,7 @@ defmodule Supavisor.ClientHandler do
}, _} <- receive_next(socket, "Timeout while waiting for the second password message"),
{:ok, key} <- authenticate_exchange(method, secrets, signatures, p) do
message = "v=#{Base.encode64(signatures.server)}"
:ok = sock_send(sock, Server.exchange_message(:final, message))
:ok = HH.sock_send(sock, Server.exchange_message(:final, message))
{:ok, key}
else
{:error, message} -> {:error, message}
Expand All @@ -512,7 +453,7 @@ defmodule Supavisor.ClientHandler do

defp reply_first_exchange(sock, method, secrets, channel, nonce, user) do
{message, signatures} = exchange_first(method, secrets, nonce, user, channel)
:ok = sock_send(sock, Server.exchange_message(:first, message))
:ok = HH.sock_send(sock, Server.exchange_message(:first, message))
{:ok, signatures}
end

Expand Down Expand Up @@ -585,23 +526,6 @@ defmodule Supavisor.ClientHandler do
}
end

@spec sock_send(S.sock(), iodata()) :: :ok | {:error, term()}
defp sock_send({mod, sock}, data) do
mod.send(sock, data)
end

@spec send_error(S.sock(), String.t(), String.t()) :: :ok | {:error, term()}
defp send_error(sock, code, message) do
data = Server.error_message(code, message)
sock_send(sock, data)
end

@spec setopts(S.sock(), term()) :: :ok | {:error, term()}
defp setopts({mod, sock}, opts) do
mod = if mod == :gen_tcp, do: :inet, else: mod
mod.setopts(sock, opts)
end

@spec auth_secrets(map, String.t()) :: {:ok, S.secrets()} | {:error, term()}
## password secrets
def auth_secrets(%{user: user, tenant: %{require_user: true}}, _) do
Expand Down Expand Up @@ -704,16 +628,6 @@ defmodule Supavisor.ClientHandler do
{message, sings}
end

@spec try_get_sni(S.sock()) :: String.t() | nil
def try_get_sni({:ssl, sock}) do
case :ssl.connection_information(sock, [:sni_hostname]) do
{:ok, [sni_hostname: sni]} -> List.to_string(sni)
_ -> nil
end
end

def try_get_sni(_), do: nil

@spec idle_check(non_neg_integer) :: {:timeout, non_neg_integer, :idle_terminate}
defp idle_check(timeout) do
{:timeout, timeout, :idle_terminate}
Expand Down
Loading

0 comments on commit 3cba0b0

Please sign in to comment.