From c813c4a6d9b342ad85aa7444ce2e4915a176ef13 Mon Sep 17 00:00:00 2001 From: James Fish Date: Fri, 11 Aug 2017 12:14:32 +0100 Subject: [PATCH 1/6] Use database transaction status to prevent transaction errors --- lib/mariaex/protocol.ex | 97 +++++++++++++++++++++++++++-------------- mix.exs | 2 +- mix.lock | 2 +- 3 files changed, 66 insertions(+), 35 deletions(-) diff --git a/lib/mariaex/protocol.ex b/lib/mariaex/protocol.ex index 50619ec..952f565 100644 --- a/lib/mariaex/protocol.ex +++ b/lib/mariaex/protocol.ex @@ -35,6 +35,7 @@ defmodule Mariaex.Protocol do @client_ps_multi_results 0x00040000 @client_deprecate_eof 0x01000000 + @server_status_in_trans 0x0001 @server_more_results_exists 0x0008 @server_status_cursor_exists 0x0040 @server_status_last_row_sent 0x0080 @@ -63,6 +64,7 @@ defmodule Mariaex.Protocol do seqnum: 0, datetime: :structs, json_library: Poison, + transaction_status: :idle, ssl_conn_state: :undefined # :undefined | :not_used | :ssl_handshake | :connected @doc """ @@ -377,8 +379,8 @@ defmodule Mariaex.Protocol do defp prepare_recv(state, query) do case prepare_recv(state) do - {:prepared, id, num_params, state} -> - {:ok, prepare_insert(id, num_params, query, state), clean_state(state)} + {:prepared, id, num_params, flags, state} -> + {:ok, prepare_insert(id, num_params, query, state), clean_state(state, flags)} {:ok, packet, state} -> handle_error(packet, query, state) {:error, reason} -> @@ -389,13 +391,13 @@ defmodule Mariaex.Protocol do defp prepare_recv(state) do state = %{state | state: :prepare_send} with {:ok, packet(msg: stmt_prepare_ok(statement_id: id, num_columns: num_cols, num_params: num_params)), state} <- msg_recv(state), - {:eof, state} <- skip_definitions(state, num_params), - {:eof, state} <- skip_definitions(state, num_cols) do - {:prepared, id, num_params, state} + {:eof, _, state} <- skip_definitions(state, num_params), + {:eof, flags, state} <- skip_definitions(state, num_cols) do + {:prepared, id, num_params, flags, state} end end - defp skip_definitions(state, 0), do: {:eof, state} + defp skip_definitions(state, 0), do: {:eof, nil, state} defp skip_definitions(state, count) do do_skip_definitions(%{state | state: :column_definitions}, count) end @@ -409,12 +411,12 @@ defmodule Mariaex.Protocol do end end defp do_skip_definitions(%{deprecated_eof: true} = state, 0) do - {:eof, state} + {:eof, nil, state} end defp do_skip_definitions(%{deprecated_eof: false} = state, 0) do case msg_recv(state) do - {:ok, packet(msg: eof_resp()), state} -> - {:eof, state} + {:ok, packet(msg: eof_resp(status_flags: flags)), state} -> + {:eof, flags, state} other -> other end @@ -501,9 +503,9 @@ defmodule Mariaex.Protocol do defp text_query_recv(state, query) do case text_query_recv(state) do - {:resultset, columns, rows, _flags, state} -> + {:resultset, columns, rows, flags, state} -> result = %Mariaex.Result{rows: rows, connection_id: state.connection_id} - {:ok, {result, columns}, clean_state(state)} + {:ok, {result, columns}, clean_state(state, flags)} {:ok, packet(msg: ok_resp()) = packet, state} -> handle_ok_packet(packet, query, state) {:ok, packet, state} -> @@ -624,16 +626,16 @@ defmodule Mariaex.Protocol do binary_query_more(state, query, columns, rows) true -> result = %Mariaex.Result{rows: rows, connection_id: state.connection_id} - {:ok, {result, columns}, clean_state(state)} + {:ok, {result, columns}, clean_state(state, flags)} end end defp binary_query_more(state, query, columns, rows) do case msg_recv(state) do - {:ok, packet(msg: ok_resp(affected_rows: affected_rows, last_insert_id: last_insert_id)), state} -> + {:ok, packet(msg: ok_resp(affected_rows: affected_rows, last_insert_id: last_insert_id, status_flags: flags)), state} -> result = %Mariaex.Result{rows: rows, num_rows: affected_rows, last_insert_id: last_insert_id, connection_id: state.connection_id} - {:ok, {result, columns}, clean_state(state)} + {:ok, {result, columns}, clean_state(state, flags)} {:ok, packet, state} -> handle_error(packet, query, state) {:error, reason} -> @@ -662,14 +664,27 @@ defmodule Mariaex.Protocol do end end - defp handle_ok_packet(packet(msg: ok_resp(affected_rows: affected_rows, last_insert_id: last_insert_id)), _query, s) do + defp handle_ok_packet(packet(msg: ok_resp(affected_rows: affected_rows, last_insert_id: last_insert_id, status_flags: flags)), _query, s) do result = %Mariaex.Result{columns: [], rows: nil, num_rows: affected_rows, last_insert_id: last_insert_id, connection_id: s.connection_id} - {:ok, {result, nil}, clean_state(s)} + {:ok, {result, nil}, clean_state(s, flags)} end - defp clean_state(state) do - %{state | state: :running, state_data: nil} + defp clean_state(state, flags) do + status = transaction_status(state, flags) + %{state | state: :running, state_data: nil, transaction_status: status} + end + + defp transaction_status(_, flags) when is_integer(flags) do + case flags &&& @server_status_in_trans do + @server_status_in_trans -> + :transaction + 0 -> + :idle + end + end + defp transaction_status(%{transaction_status: status}, nil) do + status end @doc """ @@ -811,10 +826,11 @@ defmodule Mariaex.Protocol do binary_first_more(state, query, columns, rows) (flags &&& @server_status_cursor_exists) == @server_status_cursor_exists -> %{cursors: cursors} = state - state = clean_state(%{state | cursors: Map.put(cursors, ref, columns)}) + state = clean_state(%{state | cursors: Map.put(cursors, ref, columns)}, flags) {:ok, {%Mariaex.Result{rows: rows, connection_id: state.connection_id}, columns}, state} true -> - {:deallocate, {%Mariaex.Result{rows: rows, connection_id: state.connection_id}, columns}, clean_state(state)} + result = %Mariaex.Result{rows: rows, connection_id: state.connection_id} + {:deallocate, {result, columns}, clean_state(state, flags)} end end @@ -847,9 +863,11 @@ defmodule Mariaex.Protocol do defp binary_next_resultset(state, columns, rows, flags) do cond do (flags &&& @server_status_last_row_sent) == @server_status_last_row_sent -> - {:deallocate, {%Mariaex.Result{rows: rows, connection_id: state.connection_id}, columns}, clean_state(state)} + result = %Mariaex.Result{rows: rows, connection_id: state.connection_id} + {:deallocate, {result, columns}, clean_state(state, flags)} (flags &&& @server_status_cursor_exists) == @server_status_cursor_exists -> - {:ok, {%Mariaex.Result{rows: rows, connection_id: state.connection_id}, columns}, clean_state(state)} + result = %Mariaex.Result{rows: rows, connection_id: state.connection_id} + {:ok, {result, columns}, clean_state(state, flags)} end end @@ -910,46 +928,59 @@ defmodule Mariaex.Protocol do @doc """ DBConnection callback """ - def handle_begin(opts, s) do + def handle_begin(opts, %{transaction_status: status} = s) do case Keyword.get(opts, :mode, :transaction) do - :transaction -> + :transaction when status == :idle -> name = @reserved_prefix <> "BEGIN" handle_transaction(name, :begin, opts, s) - :savepoint -> + :savepoint when status == :transaction -> name = @reserved_prefix <> "SAVEPOINT mariaex_savepoint" handle_savepoint([name], [:savepoint], opts, s) + mode when mode in [:transaction, :savepoint] -> + {status, s} end end @doc """ DBConnection callback """ - def handle_commit(opts, s) do + def handle_commit(opts, %{transaction_status: status} = s) do case Keyword.get(opts, :mode, :transaction) do - :transaction -> + :transaction when status == :transaction -> name = @reserved_prefix <> "COMMIT" handle_transaction(name, :commit, opts, s) - :savepoint -> + :savepoint when status == :transaction -> name = @reserved_prefix <> "RELEASE SAVEPOINT mariaex_savepoint" handle_savepoint([name], [:release], opts, s) + mode when mode in [:transaction, :savepoint] -> + {status, s} end end @doc """ DBConnection callback """ - def handle_rollback(opts, s) do + def handle_rollback(opts, %{transaction_status: status} = s) do case Keyword.get(opts, :mode, :transaction) do - :transaction -> + :transaction when status == :transaction -> name = @reserved_prefix <> "ROLLBACK" handle_transaction(name, :rollback, opts, s) - :savepoint -> + :savepoint when status == :transaction -> names = [@reserved_prefix <> "ROLLBACK TO SAVEPOINT mariaex_savepoint", @reserved_prefix <> "RELEASE SAVEPOINT mariaex_savepoint"] handle_savepoint(names, [:rollback, :release], opts, s) + mode when mode in [:transaction, :savepoint] -> + {status, s} end end + @doc """ + DBConnection callback + """ + def handle_status(_, %{transaction_status: status} = state) do + {status, state} + end + defp handle_transaction(name, cmd, opts, state) do query = %Query{type: :text, name: name, statement: to_string(cmd), reserved?: true} handle_execute(query, [], opts, state) @@ -1152,9 +1183,9 @@ defmodule Mariaex.Protocol do case query do %Query{} -> {:ok, nil, s} = handle_close(query, [], s) - {:error, error, clean_state(s)} + {:error, error, clean_state(s, nil)} nil -> - {:error, error, clean_state(s)} + {:error, error, clean_state(s, nil)} end end diff --git a/mix.exs b/mix.exs index 6d71435..354d927 100644 --- a/mix.exs +++ b/mix.exs @@ -25,7 +25,7 @@ defmodule Mariaex.Mixfile do defp deps do [{:decimal, "~> 1.0"}, - {:db_connection, "~> 1.1"}, + {:db_connection, "~> 1.1", github: "elixir-ecto/db_connection", ref: "4947966"}, {:coverex, "~> 1.4.10", only: :test}, {:ex_doc, ">= 0.0.0", only: :dev}, {:poison, ">= 0.0.0", optional: true}] diff --git a/mix.lock b/mix.lock index 8c88065..d7d061b 100644 --- a/mix.lock +++ b/mix.lock @@ -1,7 +1,7 @@ %{"certifi": {:hex, :certifi, "0.7.0", "861a57f3808f7eb0c2d1802afeaae0fa5de813b0df0979153cbafcd853ababaf", [:rebar3], [], "hexpm"}, "connection": {:hex, :connection, "1.0.4", "a1cae72211f0eef17705aaededacac3eb30e6625b04a6117c1b2db6ace7d5976", [:mix], [], "hexpm"}, "coverex": {:hex, :coverex, "1.4.10", "f6b68f95b3d51d04571a09dd2071c980e8398a38cf663db22b903ecad1083d51", [:mix], [{:httpoison, "~> 0.9", [hex: :httpoison, repo: "hexpm", optional: false]}, {:poison, "~> 1.5 or ~> 2.0", [hex: :poison, repo: "hexpm", optional: false]}], "hexpm"}, - "db_connection": {:hex, :db_connection, "1.1.0", "b2b88db6d7d12f99997b584d09fad98e560b817a20dab6a526830e339f54cdb3", [:mix], [{:connection, "~> 1.0.2", [hex: :connection, repo: "hexpm", optional: false]}, {:poolboy, "~> 1.5", [hex: :poolboy, repo: "hexpm", optional: true]}, {:sbroker, "~> 1.0", [hex: :sbroker, repo: "hexpm", optional: true]}], "hexpm"}, + "db_connection": {:git, "https://github.com/elixir-ecto/db_connection.git", "49479667131329376adf1c2c0e9a16bcf470aa84", [ref: "4947966"]}, "decimal": {:hex, :decimal, "1.1.0", "3333732f17a90ff3057d7ab8c65f0930ca2d67e15cca812a91ead5633ed472fe", [:mix], [], "hexpm"}, "earmark": {:hex, :earmark, "1.0.3", "89bdbaf2aca8bbb5c97d8b3b55c5dd0cff517ecc78d417e87f1d0982e514557b", [:mix], [], "hexpm"}, "ex_doc": {:hex, :ex_doc, "0.14.5", "c0433c8117e948404d93ca69411dd575ec6be39b47802e81ca8d91017a0cf83c", [:mix], [{:earmark, "~> 1.0", [hex: :earmark, repo: "hexpm", optional: false]}], "hexpm"}, From 75d1482d21eb0b2331eaaec417e6365f4da79a10 Mon Sep 17 00:00:00 2001 From: James Fish Date: Sat, 12 Aug 2017 12:08:44 +0100 Subject: [PATCH 2/6] Use one text statement for all queries --- lib/mariaex/protocol.ex | 58 ++++++++++++++-------------- test/transaction_test.exs | 79 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 108 insertions(+), 29 deletions(-) create mode 100644 test/transaction_test.exs diff --git a/lib/mariaex/protocol.ex b/lib/mariaex/protocol.ex index 952f565..d9889c6 100644 --- a/lib/mariaex/protocol.ex +++ b/lib/mariaex/protocol.ex @@ -931,11 +931,9 @@ defmodule Mariaex.Protocol do def handle_begin(opts, %{transaction_status: status} = s) do case Keyword.get(opts, :mode, :transaction) do :transaction when status == :idle -> - name = @reserved_prefix <> "BEGIN" - handle_transaction(name, :begin, opts, s) + handle_transaction("BEGIN", s) :savepoint when status == :transaction -> - name = @reserved_prefix <> "SAVEPOINT mariaex_savepoint" - handle_savepoint([name], [:savepoint], opts, s) + handle_transaction("SAVEPOINT mariaex_savepoint", s) mode when mode in [:transaction, :savepoint] -> {status, s} end @@ -947,11 +945,9 @@ defmodule Mariaex.Protocol do def handle_commit(opts, %{transaction_status: status} = s) do case Keyword.get(opts, :mode, :transaction) do :transaction when status == :transaction -> - name = @reserved_prefix <> "COMMIT" - handle_transaction(name, :commit, opts, s) + handle_transaction("COMMIT", s) :savepoint when status == :transaction -> - name = @reserved_prefix <> "RELEASE SAVEPOINT mariaex_savepoint" - handle_savepoint([name], [:release], opts, s) + handle_transaction("RELEASE SAVEPOINT mariaex_savepoint", s) mode when mode in [:transaction, :savepoint] -> {status, s} end @@ -963,12 +959,11 @@ defmodule Mariaex.Protocol do def handle_rollback(opts, %{transaction_status: status} = s) do case Keyword.get(opts, :mode, :transaction) do :transaction when status == :transaction -> - name = @reserved_prefix <> "ROLLBACK" - handle_transaction(name, :rollback, opts, s) + handle_transaction("ROLLBACK", s) :savepoint when status == :transaction -> - names = [@reserved_prefix <> "ROLLBACK TO SAVEPOINT mariaex_savepoint", - @reserved_prefix <> "RELEASE SAVEPOINT mariaex_savepoint"] - handle_savepoint(names, [:rollback, :release], opts, s) + rollback_release = + "ROLLBACK TO SAVEPOINT mariaex_savepoint; RELEASE SAVEPOINT mariaex_savepoint" + handle_transaction(rollback_release, s) mode when mode in [:transaction, :savepoint] -> {status, s} end @@ -981,24 +976,29 @@ defmodule Mariaex.Protocol do {status, state} end - defp handle_transaction(name, cmd, opts, state) do - query = %Query{type: :text, name: name, statement: to_string(cmd), reserved?: true} - handle_execute(query, [], opts, state) + defp handle_transaction(statement, state) do + state + |> send_text_query(statement) + |> transaction_recv() end - defp handle_savepoint(names, cmds, opts, state) do - Enum.zip(names, cmds) |> Enum.reduce({:ok, nil, state}, - fn({@reserved_prefix <> name, _cmd}, {:ok, _, state}) -> - query = %Query{type: :text, name: @reserved_prefix <> name, statement: name} - case handle_execute(query, [], opts, state) do - {:ok, res, state} -> - {:ok, res, state} - other -> - other - end - ({_name, _cmd}, {:error, _, _} = error) -> - error - end) + defp transaction_recv(state) do + case msg_recv(state) do + {:ok, packet(msg: ok_resp(status_flags: flags)), state} + when (flags &&& @server_more_results_exists) == @server_more_results_exists -> + # rollback/release has multiple results + transaction_recv(state) + {:ok, packet(msg: ok_resp(status_flags: flags)), state} -> + result = %Mariaex.Result{columns: [], rows: nil, num_rows: 0, + last_insert_id: 0} + {:ok, result, clean_state(state, flags)} + {:ok, packet(msg: error_resp(error_code: code, error_message: message)), state} -> + err = %Mariaex.Error{mariadb: %{code: code, message: message}} + # connection in bad state and unlikely to recover + {:disconnect, err, state} + {:error, reason} -> + recv_error(reason, state) + end end defp recv_error(reason, %{sock: {sock_mod, _}} = state) do diff --git a/test/transaction_test.exs b/test/transaction_test.exs new file mode 100644 index 0000000..c650784 --- /dev/null +++ b/test/transaction_test.exs @@ -0,0 +1,79 @@ +defmodule TransactionTest do + use ExUnit.Case + import Mariaex.TestHelper + + setup do + opts = [database: "mariaex_test", username: "mariaex_user", password: "mariaex_pass", backoff_type: :stop] + {:ok, pid} = Mariaex.Connection.start_link(opts) + {:ok, [pid: pid]} + end + + test "transaction shows correct transaction status", context do + pid = context[:pid] + opts = [mode: :transaction] + + assert DBConnection.status(pid, opts) == :idle + assert query("SELECT 42", []) == [[42]] + assert DBConnection.status(pid, opts) == :idle + {conn, _} = DBConnection.begin!(pid, opts) + assert DBConnection.status(conn, opts) == :transaction + DBConnection.commit!(conn, opts) + assert DBConnection.status(pid, opts) == :idle + assert query("SELECT 42", []) == [[42]] + assert DBConnection.status(pid) == :idle + end + + test "can not begin transaction if already begun", context do + pid = context[:pid] + opts = [mode: :transaction] + + {conn, _} = DBConnection.begin!(pid, opts) + assert {:error, %DBConnection.TransactionError{status: :transaction}} = + DBConnection.begin(conn, opts) + DBConnection.commit!(conn, opts) + end + + test "can not commit or rollback transaction if not begun", context do + pid = context[:pid] + opts = [mode: :transaction] + + assert {:error, %DBConnection.TransactionError{status: :idle}} = + DBConnection.commit(pid, opts) + assert {:error, %DBConnection.TransactionError{status: :idle}} = + DBConnection.rollback(pid, opts) + end + + test "savepoint transaction shows correct transaction status", context do + pid = context[:pid] + opts = [mode: :savepoint] + + {conn, _} = DBConnection.begin!(pid, [mode: :transaction]) + assert DBConnection.status(conn, opts) == :transaction + + assert {:ok, conn, _} = DBConnection.begin(conn, opts) + assert DBConnection.status(conn, opts) == :transaction + DBConnection.commit!(conn, opts) + assert DBConnection.status(pid, opts) == :transaction + + assert {:ok, conn, _} = DBConnection.begin(pid, opts) + assert DBConnection.status(conn, opts) == :transaction + DBConnection.rollback!(conn, opts) + assert DBConnection.status(pid, opts) == :transaction + + DBConnection.commit!(pid, [mode: :transaction]) + assert DBConnection.status(pid) == :idle + assert query("SELECT 42", []) == [[42]] + end + + test "can not begin, commit or rollback savepoint transaction if not begun", context do + pid = context[:pid] + opts = [mode: :savepoint] + + assert {:error, %DBConnection.TransactionError{status: :idle}} = + DBConnection.begin(pid, opts) + assert {:error, %DBConnection.TransactionError{status: :idle}} = + DBConnection.commit(pid, opts) + assert {:error, %DBConnection.TransactionError{status: :idle}} = + DBConnection.rollback(pid, opts) + end +end From bba273d92b804e65be4c2063a61d6d56d6925f69 Mon Sep 17 00:00:00 2001 From: James Fish Date: Sun, 13 Aug 2017 00:29:08 +0100 Subject: [PATCH 3/6] Support explicit cursor API --- lib/mariaex/protocol.ex | 114 +++++++++++++++++++++++++++++----------- lib/mariaex/structs.ex | 2 +- test/stream_test.exs | 79 ++++++++++++++++++++++++++-- 3 files changed, 160 insertions(+), 35 deletions(-) diff --git a/lib/mariaex/protocol.ex b/lib/mariaex/protocol.ex index d9889c6..3d8bfb7 100644 --- a/lib/mariaex/protocol.ex +++ b/lib/mariaex/protocol.ex @@ -672,7 +672,13 @@ defmodule Mariaex.Protocol do defp clean_state(state, flags) do status = transaction_status(state, flags) - %{state | state: :running, state_data: nil, transaction_status: status} + state = %{state | state: :running, state_data: nil, transaction_status: status} + case status do + :idle -> + clean_cursors(state) + :transaction -> + state + end end defp transaction_status(_, flags) when is_integer(flags) do @@ -687,6 +693,13 @@ defmodule Mariaex.Protocol do status end + defp clean_cursors(%{cursors: cursors} = state) do + for {_ref, {_status, id, _info}} <- cursors, is_integer(id) do + msg_send(stmt_close(command: com_stmt_close(), statement_id: id), state, 0) + end + %{state | cursors: %{}} + end + @doc """ DBConnection callback """ @@ -732,8 +745,8 @@ defmodule Mariaex.Protocol do {:close_prepare_declare, id, query} -> prepare_declare(&close_prepare(id, query, &1), params, opts, state) {:text, _} -> - cursor = %Cursor{statement_id: :text, params: params, ref: make_ref()} - {:ok, cursor, state} + cursor = %Cursor{statement_id: :text, ref: make_ref()} + declare(cursor, params, state) end end @@ -761,8 +774,14 @@ defmodule Mariaex.Protocol do defp declare(id, params, opts, state) do max_rows = Keyword.get(opts, :max_rows, @max_rows) - cursor = %Cursor{statement_id: id, params: params, ref: make_ref(), max_rows: max_rows} - {:ok, cursor, state} + cursor = %Cursor{statement_id: id, ref: make_ref(), max_rows: max_rows} + declare(cursor, params, state) + end + + defp declare(%Cursor{ref: ref, statement_id: id} = cursor, params, state) do + state = put_in(state.cursors[ref], {:first, id, params}) + # close cursor if idle + {:ok, cursor, clean_state(state, nil)} end defp prepare_declare(prepare, params, opts, state) do @@ -782,28 +801,59 @@ defmodule Mariaex.Protocol do Cache.take(cache, name) end - def handle_first(query, %Cursor{statement_id: :text, params: params}, opts, state) do + def handle_fetch(query, cursor, opts, state) do + %Cursor{ref: ref, statement_id: id} = cursor + %{cursors: cursors} = state + case cursors do + %{^ref => {:first, _, params}} -> + first(query, cursor, params, opts, state) |> fetch_result(ref, id) + %{^ref => {:cont, _, columns}} -> + next(query, cursor, columns, state) |> fetch_result(ref, id) + %{^ref => {:halt, _, columns}} -> + # cursor finished, empty result + result = %Mariaex.Result{rows: [], num_rows: 0} + {:halt, {result, columns}, state} + %{} -> + msg = "could not find active cursor: #{inspect cursor}" + {:error, Mariaex.Error.exception(msg), state} + end + end + + defp fetch_result({:cont, {_, columns} = res, state}, ref, id) do + {:cont, res, put_in(state.cursors[ref], {:cont, id, columns})} + end + defp fetch_result({:halt, {_, columns} = res, state}, ref, id) do + {:halt, res, put_in(state.cursors[ref], {:halt, id, columns})} + end + defp fetch_result({:error, _, _} = error, _ref, _id) do + error + end + defp fetch_result({:disconnect, _, _} = disconnect, _ref, _id) do + disconnect + end + + defp first(query, %Cursor{statement_id: :text}, params, opts, state) do case handle_execute(query, params, opts, state) do {:ok, result, state} -> - {:deallocate, result, state} + {:halt, result, state} other -> other end end - def handle_first(query, %Cursor{statement_id: id, ref: ref, params: params}, _, state) do + defp first(query, %Cursor{statement_id: id}, params, _, state) do msg_send(stmt_execute(command: com_stmt_execute(), parameters: params, statement_id: id, flags: @cursor_type_read_only, iteration_count: 1), state, 0) - binary_first_recv(state, ref, query) + binary_first_recv(state, query) end - defp binary_first_recv(state, ref, query) do + defp binary_first_recv(state, query) do case binary_first_recv(state) do {:eof, columns, flags, state} -> - binary_first_resultset(state, query, ref, columns, [], flags) + binary_first_resultset(state, query, columns, [], flags) {:resultset, columns, rows, flags, state} -> - binary_first_resultset(state, query, ref, columns, rows, flags) + binary_first_resultset(state, query, columns, rows, flags) {:ok, packet(msg: ok_resp()) = packet, state} -> {:ok, result, state} = handle_ok_packet(packet, query, state) - {:deallocate, result, state} + {:halt, result, state} {:ok, packet, state} -> handle_error(packet, query, state) {:error, reason} -> @@ -820,36 +870,34 @@ defmodule Mariaex.Protocol do end end - defp binary_first_resultset(state, query, ref, columns, rows, flags) do + defp binary_first_resultset(state, query, columns, rows, flags) do cond do (flags &&& @server_more_results_exists) == @server_more_results_exists -> binary_first_more(state, query, columns, rows) (flags &&& @server_status_cursor_exists) == @server_status_cursor_exists -> - %{cursors: cursors} = state - state = clean_state(%{state | cursors: Map.put(cursors, ref, columns)}, flags) - {:ok, {%Mariaex.Result{rows: rows, connection_id: state.connection_id}, columns}, state} + result = %Mariaex.Result{rows: rows, connection_id: state.connection_id} + {:cont, {result, columns}, clean_state(state, flags)} true -> result = %Mariaex.Result{rows: rows, connection_id: state.connection_id} - {:deallocate, {result, columns}, clean_state(state, flags)} + {:halt, {result, columns}, clean_state(state, flags)} end end defp binary_first_more(state, query, columns, rows) do case binary_query_more(state, query, columns, rows) do {:ok, res, state} -> - {:deallocate, res, state} + {:halt, res, state} other -> other end end - def handle_next(query, %Cursor{statement_id: id, ref: ref, max_rows: max_rows}, _, state) do + defp next(query, %Cursor{statement_id: id, max_rows: max_rows}, columns, state) do msg_send(stmt_fetch(command: com_stmt_fetch(), statement_id: id, num_rows: max_rows), state, 0) - binary_next_recv(state, ref, query) + binary_next_recv(state, query, columns) end - defp binary_next_recv(%{cursors: cursors} = state, ref, query) do - columns = Map.fetch!(cursors, ref) + defp binary_next_recv(state, query, columns) do case bin_rows_recv(state, columns) do {:eof, rows, flags, state} -> binary_next_resultset(state, columns, rows, flags) @@ -864,19 +912,23 @@ defmodule Mariaex.Protocol do cond do (flags &&& @server_status_last_row_sent) == @server_status_last_row_sent -> result = %Mariaex.Result{rows: rows, connection_id: state.connection_id} - {:deallocate, {result, columns}, clean_state(state, flags)} + {:halt, {result, columns}, clean_state(state, flags)} (flags &&& @server_status_cursor_exists) == @server_status_cursor_exists -> result = %Mariaex.Result{rows: rows, connection_id: state.connection_id} - {:ok, {result, columns}, clean_state(state, flags)} + {:cont, {result, columns}, clean_state(state, flags)} end end - def handle_deallocate(_, %Cursor{statement_id: :text}, _, state) do - {:ok, nil, state} - end - def handle_deallocate(query, %Cursor{statement_id: id, ref: ref}, _, state) do - %{cursors: cursors} = state - deallocate(id, query, %{state | cursors: Map.delete(cursors, ref)}) + def handle_deallocate(query, cursor, _, state) do + %Cursor{ref: ref, statement_id: id} = cursor + case pop_in(state.cursors[ref]) do + {nil, state} -> + {:ok, nil, state} + {_exists, state} when id == :text -> + {:ok, nil, state} + {_exists, state} -> + deallocate(id, query, state) + end end defp deallocate(id, query, state) do diff --git a/lib/mariaex/structs.ex b/lib/mariaex/structs.ex index 6527d85..6cf1f3c 100644 --- a/lib/mariaex/structs.ex +++ b/lib/mariaex/structs.ex @@ -35,5 +35,5 @@ end defmodule Mariaex.Cursor do @moduledoc false - defstruct [:ref, :statement_id, :params, max_rows: 0] + defstruct [:ref, :statement_id, max_rows: 0] end diff --git a/test/stream_test.exs b/test/stream_test.exs index 0cc25f7..e89674a 100644 --- a/test/stream_test.exs +++ b/test/stream_test.exs @@ -17,9 +17,8 @@ defmodule StreamTest do test "simple text stream", context do assert Mariaex.transaction(context[:pid], fn(conn) -> - stream = Mariaex.stream(conn, "SELECT * FROM stream", [], []) - assert [%Mariaex.Result{num_rows: 0, rows: []}, - %Mariaex.Result{num_rows: 2, rows: [[1, "foo"], [2, "bar"]]}] = + stream = Mariaex.stream(conn, "SELECT * FROM stream", [], [query_type: :text]) + assert [%Mariaex.Result{num_rows: 2, rows: [[1, "foo"], [2, "bar"]]}] = Enum.to_list(stream) :done end) == {:ok, :done} @@ -147,6 +146,80 @@ defmodule StreamTest do end) == {:ok, :done} end + test "simple text cursor", context do + query = %Mariaex.Query{type: :text, statement: "SELECT * FROM stream", + ref: make_ref(), num_params: 0} + assert {:ok, cursor} = Mariaex.transaction(context[:pid], fn(conn) -> + assert {:ok, cursor} = DBConnection.declare(conn, query, []) + assert {:halt, %Mariaex.Result{num_rows: 2, rows: [[1, "foo"], [2, "bar"]]}} = DBConnection.fetch(conn, query, cursor) + + # no results once halt, don't re-execute + assert {:halt, %Mariaex.Result{num_rows: 0, rows: []}} = DBConnection.fetch(conn, query, cursor) + cursor + end) + + pid = context[:pid] + + # cursor gets removed when transaction ends + assert_raise Mariaex.Error, ~r"could not find active cursor", + fn -> DBConnection.fetch!(pid, query, cursor) end + + # deallocate should never fail + assert {:ok, _} = DBConnection.deallocate(pid, query, cursor) + end + + test "simple unnamed prepared cursor", context do + query = prepare("", "SELECT * FROM stream") + assert {:ok, cursor} = Mariaex.transaction(context[:pid], fn(conn) -> + cursor = DBConnection.declare!(conn, query, []) + assert {:cont, %Mariaex.Result{num_rows: 0, rows: []}} = + DBConnection.fetch(conn, query, cursor) + assert {:halt, %Mariaex.Result{num_rows: 2, rows: [[1, "foo"], [2, "bar"]]}} = + DBConnection.fetch(conn, query, cursor) + + # no results once halt, don't re-execute + assert {:halt, %Mariaex.Result{num_rows: 0, rows: []}} = DBConnection.fetch(conn, query, cursor) + cursor + end) + + pid = context[:pid] + + # cursor gets removed when transaction ends + assert_raise Mariaex.Error, ~r"could not find active cursor", + fn -> DBConnection.fetch!(pid, query, cursor) end + + # deallocate should never fail + assert {:ok, _} = DBConnection.deallocate(pid, query, cursor) + + assert [[1, "foo"], [2, "bar"]] = execute(query, []) + end + + test "simple named prepared cursor", context do + query = prepare("stream", "SELECT * FROM stream") + assert {:ok, cursor} = Mariaex.transaction(context[:pid], fn(conn) -> + cursor = DBConnection.declare!(conn, query, []) + assert {:cont, %Mariaex.Result{num_rows: 0, rows: []}} = + DBConnection.fetch(conn, query, cursor) + assert {:halt, %Mariaex.Result{num_rows: 2, rows: [[1, "foo"], [2, "bar"]]}} = + DBConnection.fetch(conn, query, cursor) + + # no results once halt, don't re-execute + assert {:halt, %Mariaex.Result{num_rows: 0, rows: []}} = DBConnection.fetch(conn, query, cursor) + cursor + end) + + pid = context[:pid] + + # cursor gets removed when transaction ends + assert_raise Mariaex.Error, ~r"could not find active cursor", + fn -> DBConnection.fetch!(pid, query, cursor) end + + # deallocate should never fail + assert {:ok, _} = DBConnection.deallocate(pid, query, cursor) + + assert [[1, "foo"], [2, "bar"]] = execute(query, []) + end + defp connect() do opts = [database: "mariaex_test", username: "mariaex_user", password: "mariaex_pass", backoff_type: :stop] Mariaex.Connection.start_link(opts) From d088b77ecfc54172746a913edf9ccfea54bc072c Mon Sep 17 00:00:00 2001 From: James Fish Date: Sun, 13 Aug 2017 11:45:58 +0100 Subject: [PATCH 4/6] Remove reserved query name errors --- lib/mariaex/protocol.ex | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/lib/mariaex/protocol.ex b/lib/mariaex/protocol.ex index 3d8bfb7..8ccda58 100644 --- a/lib/mariaex/protocol.ex +++ b/lib/mariaex/protocol.ex @@ -12,7 +12,6 @@ defmodule Mariaex.Protocol do use DBConnection use Bitwise - @reserved_prefix "MARIAEX_" @timeout 5000 @cache_size 100 @max_rows 500 @@ -320,9 +319,6 @@ defmodule Mariaex.Protocol do @doc """ DBConnection callback """ - def handle_prepare(%Query{name: @reserved_prefix <> _} = query, _, s) do - reserved_error(query, s) - end def handle_prepare(%Query{type: nil} = query, opts, s) do case handle_prepare(%Query{query | type: :binary}, opts, s) do {:error, %Mariaex.Error{mariadb: %{code: 1295}}, s} -> @@ -436,9 +432,6 @@ defmodule Mariaex.Protocol do @doc """ DBConnection callback """ - def handle_execute(%Query{name: @reserved_prefix <> _, reserved?: false} = query, _, s) do - reserved_error(query, s) - end def handle_execute(%Query{type: :text, statement: statement} = query, [], _opts, state) do send_text_query(state, statement) |> text_query_recv(query) end @@ -703,9 +696,6 @@ defmodule Mariaex.Protocol do @doc """ DBConnection callback """ - def handle_close(%Query{name: @reserved_prefix <> _ , reserved?: false} = query, _, s) do - reserved_error(query, s) - end def handle_close(%Query{type: :text}, _, s) do {:ok, nil, s} end @@ -1240,9 +1230,4 @@ defmodule Mariaex.Protocol do {:error, error, clean_state(s, nil)} end end - - defp reserved_error(query, s) do - error = ArgumentError.exception("query #{inspect query} uses reserved name") - {:error, error, s} - end end From 64d51fb70fd3a0d624835d173b840b213e121b11 Mon Sep 17 00:00:00 2001 From: James Fish Date: Sun, 13 Aug 2017 11:57:28 +0100 Subject: [PATCH 5/6] Add max rows option to each cursor fetch --- lib/mariaex/protocol.ex | 25 +++++++++++-------------- lib/mariaex/structs.ex | 2 +- test/stream_test.exs | 18 ++++++++++++++++++ 3 files changed, 30 insertions(+), 15 deletions(-) diff --git a/lib/mariaex/protocol.ex b/lib/mariaex/protocol.ex index 8ccda58..1c1ff82 100644 --- a/lib/mariaex/protocol.ex +++ b/lib/mariaex/protocol.ex @@ -726,14 +726,15 @@ defmodule Mariaex.Protocol do end end - def handle_declare(query, params, opts, state) do + def handle_declare(query, params, _, state) do case declare_lookup(query, state) do {:declare, id} -> - declare(id, params, opts, state) + cursor = %Cursor{statement_id: id, ref: make_ref()} + declare(cursor, params, state) {:prepare_declare, query} -> - prepare_declare(&prepare(query, &1), params, opts, state) + prepare_declare(&prepare(query, &1), params, state) {:close_prepare_declare, id, query} -> - prepare_declare(&close_prepare(id, query, &1), params, opts, state) + prepare_declare(&close_prepare(id, query, &1), params, state) {:text, _} -> cursor = %Cursor{statement_id: :text, ref: make_ref()} declare(cursor, params, state) @@ -762,23 +763,18 @@ defmodule Mariaex.Protocol do end end - defp declare(id, params, opts, state) do - max_rows = Keyword.get(opts, :max_rows, @max_rows) - cursor = %Cursor{statement_id: id, ref: make_ref(), max_rows: max_rows} - declare(cursor, params, state) - end - defp declare(%Cursor{ref: ref, statement_id: id} = cursor, params, state) do state = put_in(state.cursors[ref], {:first, id, params}) # close cursor if idle {:ok, cursor, clean_state(state, nil)} end - defp prepare_declare(prepare, params, opts, state) do + defp prepare_declare(prepare, params, state) do case prepare.(state) do {:ok, query, state} -> id = prepare_declare_lookup(query, state) - declare(id, params, opts, state) + cursor = %Cursor{statement_id: id, ref: make_ref()} + declare(cursor, params, state) {err, _, _} = error when err in [:error, :disconnect] -> error end @@ -798,7 +794,7 @@ defmodule Mariaex.Protocol do %{^ref => {:first, _, params}} -> first(query, cursor, params, opts, state) |> fetch_result(ref, id) %{^ref => {:cont, _, columns}} -> - next(query, cursor, columns, state) |> fetch_result(ref, id) + next(query, cursor, columns, opts, state) |> fetch_result(ref, id) %{^ref => {:halt, _, columns}} -> # cursor finished, empty result result = %Mariaex.Result{rows: [], num_rows: 0} @@ -882,7 +878,8 @@ defmodule Mariaex.Protocol do end end - defp next(query, %Cursor{statement_id: id, max_rows: max_rows}, columns, state) do + defp next(query, %Cursor{statement_id: id}, columns, opts, state) do + max_rows = Keyword.get(opts, :max_rows, @max_rows) msg_send(stmt_fetch(command: com_stmt_fetch(), statement_id: id, num_rows: max_rows), state, 0) binary_next_recv(state, query, columns) end diff --git a/lib/mariaex/structs.ex b/lib/mariaex/structs.ex index 6cf1f3c..6ef6591 100644 --- a/lib/mariaex/structs.ex +++ b/lib/mariaex/structs.ex @@ -35,5 +35,5 @@ end defmodule Mariaex.Cursor do @moduledoc false - defstruct [:ref, :statement_id, max_rows: 0] + defstruct [:ref, :statement_id] end diff --git a/test/stream_test.exs b/test/stream_test.exs index e89674a..bbda627 100644 --- a/test/stream_test.exs +++ b/test/stream_test.exs @@ -220,6 +220,24 @@ defmodule StreamTest do assert [[1, "foo"], [2, "bar"]] = execute(query, []) end + test "fetch fetches max_rows", context do + query = prepare("", "SELECT * FROM stream") + assert Mariaex.transaction(context[:pid], fn(conn) -> + cursor = DBConnection.declare!(conn, query, []) + assert {:cont, %Mariaex.Result{num_rows: 0, rows: []}} = + DBConnection.fetch(conn, query, cursor) + assert {:cont, %Mariaex.Result{num_rows: 1, rows: [[1, "foo"]]}} = + DBConnection.fetch(conn, query, cursor, [max_rows: 1]) + assert {:ok, _} = DBConnection.deallocate(conn, query, cursor) + + assert %Mariaex.Result{rows: [[1, "foo"], [2, "bar"]]} = + Mariaex.execute!(conn, query, []) + :done + end) == {:ok, :done} + + assert [[1, "foo"], [2, "bar"]] = execute(query, []) + end + defp connect() do opts = [database: "mariaex_test", username: "mariaex_user", password: "mariaex_pass", backoff_type: :stop] Mariaex.Connection.start_link(opts) From cf37d5324879f9c90a481a6f8bcaee25be516a13 Mon Sep 17 00:00:00 2001 From: James Fish Date: Sun, 26 Nov 2017 15:24:44 -0800 Subject: [PATCH 6/6] Raise DBConnection.ConnectionError on connection error --- lib/mariaex.ex | 28 ++++++++++++++++------------ lib/mariaex/protocol.ex | 15 +++++++++++++-- test/query_test.exs | 3 ++- 3 files changed, 31 insertions(+), 15 deletions(-) diff --git a/lib/mariaex.ex b/lib/mariaex.ex index 9869f5a..9bbfc1a 100644 --- a/lib/mariaex.ex +++ b/lib/mariaex.ex @@ -239,10 +239,12 @@ defmodule Mariaex do {:ok, Mariaex.Result.t} | {:error, Mariaex.Error.t} def execute(conn, %Query{} = query, params, opts \\ []) do case DBConnection.execute(conn, query, params, defaults(opts)) do - {:error, %ArgumentError{} = err} -> + {:ok, _} = ok -> + ok + {:error, %Mariaex.Error{}} = error -> + error + {:error, err} -> raise err - other -> - other end end @@ -288,10 +290,10 @@ defmodule Mariaex do case DBConnection.close(conn, query, defaults(opts)) do {:ok, _} -> :ok - {:error, %ArgumentError{} = err} -> + {:error, %Mariaex.Error{}} = error -> + error + {:error, err} -> raise err - other -> - other end end @@ -420,19 +422,21 @@ defmodule Mariaex do case DBConnection.prepare_execute(conn, query, params, defaults(opts)) do {:ok, _, result} -> {:ok, result} - {:error, %ArgumentError{} = err} -> - raise err - {:error, _} = error -> + {:error, %Mariaex.Error{}} = error -> error + {:error, err} -> + raise err end end defp prepare_binary(conn, query, opts) do case DBConnection.prepare(conn, query, defaults(opts)) do - {:error, %ArgumentError{} = err} -> + {:ok, _} = ok -> + ok + {:error, %Mariaex.Error{}} = error -> + error + {:error, err} -> raise err - other -> - other end end diff --git a/lib/mariaex/protocol.ex b/lib/mariaex/protocol.ex index 1c1ff82..72f70ee 100644 --- a/lib/mariaex/protocol.ex +++ b/lib/mariaex/protocol.ex @@ -15,6 +15,7 @@ defmodule Mariaex.Protocol do @timeout 5000 @cache_size 100 @max_rows 500 + @nonposix_errors [:closed, :timeout] @maxpacketbytes 50000000 @mysql_native_password "mysql_native_password" @@ -1048,8 +1049,18 @@ defmodule Mariaex.Protocol do Do disconnect """ def do_disconnect(s, {tag, action, reason, buffer}) do - err = Mariaex.Error.exception(tag: tag, action: action, reason: reason) - do_disconnect(s, err, buffer) + msg = "#{tag} #{action}: #{format_error(tag, reason)}" + {:disconnect, DBConnection.ConnectionError.exception(msg), %{s | buffer: buffer}} + end + + defp format_error(_, reason) when reason in @nonposix_errors do + Atom.to_string(reason) + end + defp format_error(:tcp, reason) do + "#{:inet.format_error(reason)} - #{inspect(reason)}" + end + defp format_error(:ssl, reason) do + "#{:ssl.format_error(reason)} - #{inspect(reason)}" end defp do_disconnect(%{connection_id: connection_id} = state, %Mariaex.Error{} = err, buffer) do diff --git a/test/query_test.exs b/test/query_test.exs index 4eda2c0..36f3948 100644 --- a/test/query_test.exs +++ b/test/query_test.exs @@ -40,7 +40,8 @@ defmodule QueryTest do Process.flag(:trap_exit, true) capture_log fn -> - assert %Mariaex.Error{} = query("DO SLEEP(0.1)", [], timeout: 0) + assert_raise DBConnection.ConnectionError, "tcp recv: closed", + fn -> query("DO SLEEP(10)", [], timeout: 50) end assert_receive {:EXIT, ^conn, {:shutdown, %DBConnection.ConnectionError{}}} end end