diff --git a/lib/mariaex.ex b/lib/mariaex.ex index 9869f5a..9bbfc1a 100644 --- a/lib/mariaex.ex +++ b/lib/mariaex.ex @@ -239,10 +239,12 @@ defmodule Mariaex do {:ok, Mariaex.Result.t} | {:error, Mariaex.Error.t} def execute(conn, %Query{} = query, params, opts \\ []) do case DBConnection.execute(conn, query, params, defaults(opts)) do - {:error, %ArgumentError{} = err} -> + {:ok, _} = ok -> + ok + {:error, %Mariaex.Error{}} = error -> + error + {:error, err} -> raise err - other -> - other end end @@ -288,10 +290,10 @@ defmodule Mariaex do case DBConnection.close(conn, query, defaults(opts)) do {:ok, _} -> :ok - {:error, %ArgumentError{} = err} -> + {:error, %Mariaex.Error{}} = error -> + error + {:error, err} -> raise err - other -> - other end end @@ -420,19 +422,21 @@ defmodule Mariaex do case DBConnection.prepare_execute(conn, query, params, defaults(opts)) do {:ok, _, result} -> {:ok, result} - {:error, %ArgumentError{} = err} -> - raise err - {:error, _} = error -> + {:error, %Mariaex.Error{}} = error -> error + {:error, err} -> + raise err end end defp prepare_binary(conn, query, opts) do case DBConnection.prepare(conn, query, defaults(opts)) do - {:error, %ArgumentError{} = err} -> + {:ok, _} = ok -> + ok + {:error, %Mariaex.Error{}} = error -> + error + {:error, err} -> raise err - other -> - other end end diff --git a/lib/mariaex/protocol.ex b/lib/mariaex/protocol.ex index 50619ec..72f70ee 100644 --- a/lib/mariaex/protocol.ex +++ b/lib/mariaex/protocol.ex @@ -12,10 +12,10 @@ defmodule Mariaex.Protocol do use DBConnection use Bitwise - @reserved_prefix "MARIAEX_" @timeout 5000 @cache_size 100 @max_rows 500 + @nonposix_errors [:closed, :timeout] @maxpacketbytes 50000000 @mysql_native_password "mysql_native_password" @@ -35,6 +35,7 @@ defmodule Mariaex.Protocol do @client_ps_multi_results 0x00040000 @client_deprecate_eof 0x01000000 + @server_status_in_trans 0x0001 @server_more_results_exists 0x0008 @server_status_cursor_exists 0x0040 @server_status_last_row_sent 0x0080 @@ -63,6 +64,7 @@ defmodule Mariaex.Protocol do seqnum: 0, datetime: :structs, json_library: Poison, + transaction_status: :idle, ssl_conn_state: :undefined # :undefined | :not_used | :ssl_handshake | :connected @doc """ @@ -318,9 +320,6 @@ defmodule Mariaex.Protocol do @doc """ DBConnection callback """ - def handle_prepare(%Query{name: @reserved_prefix <> _} = query, _, s) do - reserved_error(query, s) - end def handle_prepare(%Query{type: nil} = query, opts, s) do case handle_prepare(%Query{query | type: :binary}, opts, s) do {:error, %Mariaex.Error{mariadb: %{code: 1295}}, s} -> @@ -377,8 +376,8 @@ defmodule Mariaex.Protocol do defp prepare_recv(state, query) do case prepare_recv(state) do - {:prepared, id, num_params, state} -> - {:ok, prepare_insert(id, num_params, query, state), clean_state(state)} + {:prepared, id, num_params, flags, state} -> + {:ok, prepare_insert(id, num_params, query, state), clean_state(state, flags)} {:ok, packet, state} -> handle_error(packet, query, state) {:error, reason} -> @@ -389,13 +388,13 @@ defmodule Mariaex.Protocol do defp prepare_recv(state) do state = %{state | state: :prepare_send} with {:ok, packet(msg: stmt_prepare_ok(statement_id: id, num_columns: num_cols, num_params: num_params)), state} <- msg_recv(state), - {:eof, state} <- skip_definitions(state, num_params), - {:eof, state} <- skip_definitions(state, num_cols) do - {:prepared, id, num_params, state} + {:eof, _, state} <- skip_definitions(state, num_params), + {:eof, flags, state} <- skip_definitions(state, num_cols) do + {:prepared, id, num_params, flags, state} end end - defp skip_definitions(state, 0), do: {:eof, state} + defp skip_definitions(state, 0), do: {:eof, nil, state} defp skip_definitions(state, count) do do_skip_definitions(%{state | state: :column_definitions}, count) end @@ -409,12 +408,12 @@ defmodule Mariaex.Protocol do end end defp do_skip_definitions(%{deprecated_eof: true} = state, 0) do - {:eof, state} + {:eof, nil, state} end defp do_skip_definitions(%{deprecated_eof: false} = state, 0) do case msg_recv(state) do - {:ok, packet(msg: eof_resp()), state} -> - {:eof, state} + {:ok, packet(msg: eof_resp(status_flags: flags)), state} -> + {:eof, flags, state} other -> other end @@ -434,9 +433,6 @@ defmodule Mariaex.Protocol do @doc """ DBConnection callback """ - def handle_execute(%Query{name: @reserved_prefix <> _, reserved?: false} = query, _, s) do - reserved_error(query, s) - end def handle_execute(%Query{type: :text, statement: statement} = query, [], _opts, state) do send_text_query(state, statement) |> text_query_recv(query) end @@ -501,9 +497,9 @@ defmodule Mariaex.Protocol do defp text_query_recv(state, query) do case text_query_recv(state) do - {:resultset, columns, rows, _flags, state} -> + {:resultset, columns, rows, flags, state} -> result = %Mariaex.Result{rows: rows, connection_id: state.connection_id} - {:ok, {result, columns}, clean_state(state)} + {:ok, {result, columns}, clean_state(state, flags)} {:ok, packet(msg: ok_resp()) = packet, state} -> handle_ok_packet(packet, query, state) {:ok, packet, state} -> @@ -624,16 +620,16 @@ defmodule Mariaex.Protocol do binary_query_more(state, query, columns, rows) true -> result = %Mariaex.Result{rows: rows, connection_id: state.connection_id} - {:ok, {result, columns}, clean_state(state)} + {:ok, {result, columns}, clean_state(state, flags)} end end defp binary_query_more(state, query, columns, rows) do case msg_recv(state) do - {:ok, packet(msg: ok_resp(affected_rows: affected_rows, last_insert_id: last_insert_id)), state} -> + {:ok, packet(msg: ok_resp(affected_rows: affected_rows, last_insert_id: last_insert_id, status_flags: flags)), state} -> result = %Mariaex.Result{rows: rows, num_rows: affected_rows, last_insert_id: last_insert_id, connection_id: state.connection_id} - {:ok, {result, columns}, clean_state(state)} + {:ok, {result, columns}, clean_state(state, flags)} {:ok, packet, state} -> handle_error(packet, query, state) {:error, reason} -> @@ -662,22 +658,45 @@ defmodule Mariaex.Protocol do end end - defp handle_ok_packet(packet(msg: ok_resp(affected_rows: affected_rows, last_insert_id: last_insert_id)), _query, s) do + defp handle_ok_packet(packet(msg: ok_resp(affected_rows: affected_rows, last_insert_id: last_insert_id, status_flags: flags)), _query, s) do result = %Mariaex.Result{columns: [], rows: nil, num_rows: affected_rows, last_insert_id: last_insert_id, connection_id: s.connection_id} - {:ok, {result, nil}, clean_state(s)} + {:ok, {result, nil}, clean_state(s, flags)} + end + + defp clean_state(state, flags) do + status = transaction_status(state, flags) + state = %{state | state: :running, state_data: nil, transaction_status: status} + case status do + :idle -> + clean_cursors(state) + :transaction -> + state + end end - defp clean_state(state) do - %{state | state: :running, state_data: nil} + defp transaction_status(_, flags) when is_integer(flags) do + case flags &&& @server_status_in_trans do + @server_status_in_trans -> + :transaction + 0 -> + :idle + end + end + defp transaction_status(%{transaction_status: status}, nil) do + status + end + + defp clean_cursors(%{cursors: cursors} = state) do + for {_ref, {_status, id, _info}} <- cursors, is_integer(id) do + msg_send(stmt_close(command: com_stmt_close(), statement_id: id), state, 0) + end + %{state | cursors: %{}} end @doc """ DBConnection callback """ - def handle_close(%Query{name: @reserved_prefix <> _ , reserved?: false} = query, _, s) do - reserved_error(query, s) - end def handle_close(%Query{type: :text}, _, s) do {:ok, nil, s} end @@ -708,17 +727,18 @@ defmodule Mariaex.Protocol do end end - def handle_declare(query, params, opts, state) do + def handle_declare(query, params, _, state) do case declare_lookup(query, state) do {:declare, id} -> - declare(id, params, opts, state) + cursor = %Cursor{statement_id: id, ref: make_ref()} + declare(cursor, params, state) {:prepare_declare, query} -> - prepare_declare(&prepare(query, &1), params, opts, state) + prepare_declare(&prepare(query, &1), params, state) {:close_prepare_declare, id, query} -> - prepare_declare(&close_prepare(id, query, &1), params, opts, state) + prepare_declare(&close_prepare(id, query, &1), params, state) {:text, _} -> - cursor = %Cursor{statement_id: :text, params: params, ref: make_ref()} - {:ok, cursor, state} + cursor = %Cursor{statement_id: :text, ref: make_ref()} + declare(cursor, params, state) end end @@ -744,17 +764,18 @@ defmodule Mariaex.Protocol do end end - defp declare(id, params, opts, state) do - max_rows = Keyword.get(opts, :max_rows, @max_rows) - cursor = %Cursor{statement_id: id, params: params, ref: make_ref(), max_rows: max_rows} - {:ok, cursor, state} + defp declare(%Cursor{ref: ref, statement_id: id} = cursor, params, state) do + state = put_in(state.cursors[ref], {:first, id, params}) + # close cursor if idle + {:ok, cursor, clean_state(state, nil)} end - defp prepare_declare(prepare, params, opts, state) do + defp prepare_declare(prepare, params, state) do case prepare.(state) do {:ok, query, state} -> id = prepare_declare_lookup(query, state) - declare(id, params, opts, state) + cursor = %Cursor{statement_id: id, ref: make_ref()} + declare(cursor, params, state) {err, _, _} = error when err in [:error, :disconnect] -> error end @@ -767,28 +788,59 @@ defmodule Mariaex.Protocol do Cache.take(cache, name) end - def handle_first(query, %Cursor{statement_id: :text, params: params}, opts, state) do + def handle_fetch(query, cursor, opts, state) do + %Cursor{ref: ref, statement_id: id} = cursor + %{cursors: cursors} = state + case cursors do + %{^ref => {:first, _, params}} -> + first(query, cursor, params, opts, state) |> fetch_result(ref, id) + %{^ref => {:cont, _, columns}} -> + next(query, cursor, columns, opts, state) |> fetch_result(ref, id) + %{^ref => {:halt, _, columns}} -> + # cursor finished, empty result + result = %Mariaex.Result{rows: [], num_rows: 0} + {:halt, {result, columns}, state} + %{} -> + msg = "could not find active cursor: #{inspect cursor}" + {:error, Mariaex.Error.exception(msg), state} + end + end + + defp fetch_result({:cont, {_, columns} = res, state}, ref, id) do + {:cont, res, put_in(state.cursors[ref], {:cont, id, columns})} + end + defp fetch_result({:halt, {_, columns} = res, state}, ref, id) do + {:halt, res, put_in(state.cursors[ref], {:halt, id, columns})} + end + defp fetch_result({:error, _, _} = error, _ref, _id) do + error + end + defp fetch_result({:disconnect, _, _} = disconnect, _ref, _id) do + disconnect + end + + defp first(query, %Cursor{statement_id: :text}, params, opts, state) do case handle_execute(query, params, opts, state) do {:ok, result, state} -> - {:deallocate, result, state} + {:halt, result, state} other -> other end end - def handle_first(query, %Cursor{statement_id: id, ref: ref, params: params}, _, state) do + defp first(query, %Cursor{statement_id: id}, params, _, state) do msg_send(stmt_execute(command: com_stmt_execute(), parameters: params, statement_id: id, flags: @cursor_type_read_only, iteration_count: 1), state, 0) - binary_first_recv(state, ref, query) + binary_first_recv(state, query) end - defp binary_first_recv(state, ref, query) do + defp binary_first_recv(state, query) do case binary_first_recv(state) do {:eof, columns, flags, state} -> - binary_first_resultset(state, query, ref, columns, [], flags) + binary_first_resultset(state, query, columns, [], flags) {:resultset, columns, rows, flags, state} -> - binary_first_resultset(state, query, ref, columns, rows, flags) + binary_first_resultset(state, query, columns, rows, flags) {:ok, packet(msg: ok_resp()) = packet, state} -> {:ok, result, state} = handle_ok_packet(packet, query, state) - {:deallocate, result, state} + {:halt, result, state} {:ok, packet, state} -> handle_error(packet, query, state) {:error, reason} -> @@ -805,35 +857,35 @@ defmodule Mariaex.Protocol do end end - defp binary_first_resultset(state, query, ref, columns, rows, flags) do + defp binary_first_resultset(state, query, columns, rows, flags) do cond do (flags &&& @server_more_results_exists) == @server_more_results_exists -> binary_first_more(state, query, columns, rows) (flags &&& @server_status_cursor_exists) == @server_status_cursor_exists -> - %{cursors: cursors} = state - state = clean_state(%{state | cursors: Map.put(cursors, ref, columns)}) - {:ok, {%Mariaex.Result{rows: rows, connection_id: state.connection_id}, columns}, state} + result = %Mariaex.Result{rows: rows, connection_id: state.connection_id} + {:cont, {result, columns}, clean_state(state, flags)} true -> - {:deallocate, {%Mariaex.Result{rows: rows, connection_id: state.connection_id}, columns}, clean_state(state)} + result = %Mariaex.Result{rows: rows, connection_id: state.connection_id} + {:halt, {result, columns}, clean_state(state, flags)} end end defp binary_first_more(state, query, columns, rows) do case binary_query_more(state, query, columns, rows) do {:ok, res, state} -> - {:deallocate, res, state} + {:halt, res, state} other -> other end end - def handle_next(query, %Cursor{statement_id: id, ref: ref, max_rows: max_rows}, _, state) do + defp next(query, %Cursor{statement_id: id}, columns, opts, state) do + max_rows = Keyword.get(opts, :max_rows, @max_rows) msg_send(stmt_fetch(command: com_stmt_fetch(), statement_id: id, num_rows: max_rows), state, 0) - binary_next_recv(state, ref, query) + binary_next_recv(state, query, columns) end - defp binary_next_recv(%{cursors: cursors} = state, ref, query) do - columns = Map.fetch!(cursors, ref) + defp binary_next_recv(state, query, columns) do case bin_rows_recv(state, columns) do {:eof, rows, flags, state} -> binary_next_resultset(state, columns, rows, flags) @@ -847,18 +899,24 @@ defmodule Mariaex.Protocol do defp binary_next_resultset(state, columns, rows, flags) do cond do (flags &&& @server_status_last_row_sent) == @server_status_last_row_sent -> - {:deallocate, {%Mariaex.Result{rows: rows, connection_id: state.connection_id}, columns}, clean_state(state)} + result = %Mariaex.Result{rows: rows, connection_id: state.connection_id} + {:halt, {result, columns}, clean_state(state, flags)} (flags &&& @server_status_cursor_exists) == @server_status_cursor_exists -> - {:ok, {%Mariaex.Result{rows: rows, connection_id: state.connection_id}, columns}, clean_state(state)} + result = %Mariaex.Result{rows: rows, connection_id: state.connection_id} + {:cont, {result, columns}, clean_state(state, flags)} end end - def handle_deallocate(_, %Cursor{statement_id: :text}, _, state) do - {:ok, nil, state} - end - def handle_deallocate(query, %Cursor{statement_id: id, ref: ref}, _, state) do - %{cursors: cursors} = state - deallocate(id, query, %{state | cursors: Map.delete(cursors, ref)}) + def handle_deallocate(query, cursor, _, state) do + %Cursor{ref: ref, statement_id: id} = cursor + case pop_in(state.cursors[ref]) do + {nil, state} -> + {:ok, nil, state} + {_exists, state} when id == :text -> + {:ok, nil, state} + {_exists, state} -> + deallocate(id, query, state) + end end defp deallocate(id, query, state) do @@ -910,64 +968,77 @@ defmodule Mariaex.Protocol do @doc """ DBConnection callback """ - def handle_begin(opts, s) do + def handle_begin(opts, %{transaction_status: status} = s) do case Keyword.get(opts, :mode, :transaction) do - :transaction -> - name = @reserved_prefix <> "BEGIN" - handle_transaction(name, :begin, opts, s) - :savepoint -> - name = @reserved_prefix <> "SAVEPOINT mariaex_savepoint" - handle_savepoint([name], [:savepoint], opts, s) + :transaction when status == :idle -> + handle_transaction("BEGIN", s) + :savepoint when status == :transaction -> + handle_transaction("SAVEPOINT mariaex_savepoint", s) + mode when mode in [:transaction, :savepoint] -> + {status, s} end end @doc """ DBConnection callback """ - def handle_commit(opts, s) do + def handle_commit(opts, %{transaction_status: status} = s) do case Keyword.get(opts, :mode, :transaction) do - :transaction -> - name = @reserved_prefix <> "COMMIT" - handle_transaction(name, :commit, opts, s) - :savepoint -> - name = @reserved_prefix <> "RELEASE SAVEPOINT mariaex_savepoint" - handle_savepoint([name], [:release], opts, s) + :transaction when status == :transaction -> + handle_transaction("COMMIT", s) + :savepoint when status == :transaction -> + handle_transaction("RELEASE SAVEPOINT mariaex_savepoint", s) + mode when mode in [:transaction, :savepoint] -> + {status, s} end end @doc """ DBConnection callback """ - def handle_rollback(opts, s) do + def handle_rollback(opts, %{transaction_status: status} = s) do case Keyword.get(opts, :mode, :transaction) do - :transaction -> - name = @reserved_prefix <> "ROLLBACK" - handle_transaction(name, :rollback, opts, s) - :savepoint -> - names = [@reserved_prefix <> "ROLLBACK TO SAVEPOINT mariaex_savepoint", - @reserved_prefix <> "RELEASE SAVEPOINT mariaex_savepoint"] - handle_savepoint(names, [:rollback, :release], opts, s) + :transaction when status == :transaction -> + handle_transaction("ROLLBACK", s) + :savepoint when status == :transaction -> + rollback_release = + "ROLLBACK TO SAVEPOINT mariaex_savepoint; RELEASE SAVEPOINT mariaex_savepoint" + handle_transaction(rollback_release, s) + mode when mode in [:transaction, :savepoint] -> + {status, s} end end - defp handle_transaction(name, cmd, opts, state) do - query = %Query{type: :text, name: name, statement: to_string(cmd), reserved?: true} - handle_execute(query, [], opts, state) + @doc """ + DBConnection callback + """ + def handle_status(_, %{transaction_status: status} = state) do + {status, state} end - defp handle_savepoint(names, cmds, opts, state) do - Enum.zip(names, cmds) |> Enum.reduce({:ok, nil, state}, - fn({@reserved_prefix <> name, _cmd}, {:ok, _, state}) -> - query = %Query{type: :text, name: @reserved_prefix <> name, statement: name} - case handle_execute(query, [], opts, state) do - {:ok, res, state} -> - {:ok, res, state} - other -> - other - end - ({_name, _cmd}, {:error, _, _} = error) -> - error - end) + defp handle_transaction(statement, state) do + state + |> send_text_query(statement) + |> transaction_recv() + end + + defp transaction_recv(state) do + case msg_recv(state) do + {:ok, packet(msg: ok_resp(status_flags: flags)), state} + when (flags &&& @server_more_results_exists) == @server_more_results_exists -> + # rollback/release has multiple results + transaction_recv(state) + {:ok, packet(msg: ok_resp(status_flags: flags)), state} -> + result = %Mariaex.Result{columns: [], rows: nil, num_rows: 0, + last_insert_id: 0} + {:ok, result, clean_state(state, flags)} + {:ok, packet(msg: error_resp(error_code: code, error_message: message)), state} -> + err = %Mariaex.Error{mariadb: %{code: code, message: message}} + # connection in bad state and unlikely to recover + {:disconnect, err, state} + {:error, reason} -> + recv_error(reason, state) + end end defp recv_error(reason, %{sock: {sock_mod, _}} = state) do @@ -978,8 +1049,18 @@ defmodule Mariaex.Protocol do Do disconnect """ def do_disconnect(s, {tag, action, reason, buffer}) do - err = Mariaex.Error.exception(tag: tag, action: action, reason: reason) - do_disconnect(s, err, buffer) + msg = "#{tag} #{action}: #{format_error(tag, reason)}" + {:disconnect, DBConnection.ConnectionError.exception(msg), %{s | buffer: buffer}} + end + + defp format_error(_, reason) when reason in @nonposix_errors do + Atom.to_string(reason) + end + defp format_error(:tcp, reason) do + "#{:inet.format_error(reason)} - #{inspect(reason)}" + end + defp format_error(:ssl, reason) do + "#{:ssl.format_error(reason)} - #{inspect(reason)}" end defp do_disconnect(%{connection_id: connection_id} = state, %Mariaex.Error{} = err, buffer) do @@ -1152,14 +1233,9 @@ defmodule Mariaex.Protocol do case query do %Query{} -> {:ok, nil, s} = handle_close(query, [], s) - {:error, error, clean_state(s)} + {:error, error, clean_state(s, nil)} nil -> - {:error, error, clean_state(s)} + {:error, error, clean_state(s, nil)} end end - - defp reserved_error(query, s) do - error = ArgumentError.exception("query #{inspect query} uses reserved name") - {:error, error, s} - end end diff --git a/lib/mariaex/structs.ex b/lib/mariaex/structs.ex index 6527d85..6ef6591 100644 --- a/lib/mariaex/structs.ex +++ b/lib/mariaex/structs.ex @@ -35,5 +35,5 @@ end defmodule Mariaex.Cursor do @moduledoc false - defstruct [:ref, :statement_id, :params, max_rows: 0] + defstruct [:ref, :statement_id] end diff --git a/mix.exs b/mix.exs index 6d71435..354d927 100644 --- a/mix.exs +++ b/mix.exs @@ -25,7 +25,7 @@ defmodule Mariaex.Mixfile do defp deps do [{:decimal, "~> 1.0"}, - {:db_connection, "~> 1.1"}, + {:db_connection, "~> 1.1", github: "elixir-ecto/db_connection", ref: "4947966"}, {:coverex, "~> 1.4.10", only: :test}, {:ex_doc, ">= 0.0.0", only: :dev}, {:poison, ">= 0.0.0", optional: true}] diff --git a/mix.lock b/mix.lock index 8c88065..d7d061b 100644 --- a/mix.lock +++ b/mix.lock @@ -1,7 +1,7 @@ %{"certifi": {:hex, :certifi, "0.7.0", "861a57f3808f7eb0c2d1802afeaae0fa5de813b0df0979153cbafcd853ababaf", [:rebar3], [], "hexpm"}, "connection": {:hex, :connection, "1.0.4", "a1cae72211f0eef17705aaededacac3eb30e6625b04a6117c1b2db6ace7d5976", [:mix], [], "hexpm"}, "coverex": {:hex, :coverex, "1.4.10", "f6b68f95b3d51d04571a09dd2071c980e8398a38cf663db22b903ecad1083d51", [:mix], [{:httpoison, "~> 0.9", [hex: :httpoison, repo: "hexpm", optional: false]}, {:poison, "~> 1.5 or ~> 2.0", [hex: :poison, repo: "hexpm", optional: false]}], "hexpm"}, - "db_connection": {:hex, :db_connection, "1.1.0", "b2b88db6d7d12f99997b584d09fad98e560b817a20dab6a526830e339f54cdb3", [:mix], [{:connection, "~> 1.0.2", [hex: :connection, repo: "hexpm", optional: false]}, {:poolboy, "~> 1.5", [hex: :poolboy, repo: "hexpm", optional: true]}, {:sbroker, "~> 1.0", [hex: :sbroker, repo: "hexpm", optional: true]}], "hexpm"}, + "db_connection": {:git, "https://github.com/elixir-ecto/db_connection.git", "49479667131329376adf1c2c0e9a16bcf470aa84", [ref: "4947966"]}, "decimal": {:hex, :decimal, "1.1.0", "3333732f17a90ff3057d7ab8c65f0930ca2d67e15cca812a91ead5633ed472fe", [:mix], [], "hexpm"}, "earmark": {:hex, :earmark, "1.0.3", "89bdbaf2aca8bbb5c97d8b3b55c5dd0cff517ecc78d417e87f1d0982e514557b", [:mix], [], "hexpm"}, "ex_doc": {:hex, :ex_doc, "0.14.5", "c0433c8117e948404d93ca69411dd575ec6be39b47802e81ca8d91017a0cf83c", [:mix], [{:earmark, "~> 1.0", [hex: :earmark, repo: "hexpm", optional: false]}], "hexpm"}, diff --git a/test/query_test.exs b/test/query_test.exs index 4eda2c0..36f3948 100644 --- a/test/query_test.exs +++ b/test/query_test.exs @@ -40,7 +40,8 @@ defmodule QueryTest do Process.flag(:trap_exit, true) capture_log fn -> - assert %Mariaex.Error{} = query("DO SLEEP(0.1)", [], timeout: 0) + assert_raise DBConnection.ConnectionError, "tcp recv: closed", + fn -> query("DO SLEEP(10)", [], timeout: 50) end assert_receive {:EXIT, ^conn, {:shutdown, %DBConnection.ConnectionError{}}} end end diff --git a/test/stream_test.exs b/test/stream_test.exs index 0cc25f7..bbda627 100644 --- a/test/stream_test.exs +++ b/test/stream_test.exs @@ -17,9 +17,8 @@ defmodule StreamTest do test "simple text stream", context do assert Mariaex.transaction(context[:pid], fn(conn) -> - stream = Mariaex.stream(conn, "SELECT * FROM stream", [], []) - assert [%Mariaex.Result{num_rows: 0, rows: []}, - %Mariaex.Result{num_rows: 2, rows: [[1, "foo"], [2, "bar"]]}] = + stream = Mariaex.stream(conn, "SELECT * FROM stream", [], [query_type: :text]) + assert [%Mariaex.Result{num_rows: 2, rows: [[1, "foo"], [2, "bar"]]}] = Enum.to_list(stream) :done end) == {:ok, :done} @@ -147,6 +146,98 @@ defmodule StreamTest do end) == {:ok, :done} end + test "simple text cursor", context do + query = %Mariaex.Query{type: :text, statement: "SELECT * FROM stream", + ref: make_ref(), num_params: 0} + assert {:ok, cursor} = Mariaex.transaction(context[:pid], fn(conn) -> + assert {:ok, cursor} = DBConnection.declare(conn, query, []) + assert {:halt, %Mariaex.Result{num_rows: 2, rows: [[1, "foo"], [2, "bar"]]}} = DBConnection.fetch(conn, query, cursor) + + # no results once halt, don't re-execute + assert {:halt, %Mariaex.Result{num_rows: 0, rows: []}} = DBConnection.fetch(conn, query, cursor) + cursor + end) + + pid = context[:pid] + + # cursor gets removed when transaction ends + assert_raise Mariaex.Error, ~r"could not find active cursor", + fn -> DBConnection.fetch!(pid, query, cursor) end + + # deallocate should never fail + assert {:ok, _} = DBConnection.deallocate(pid, query, cursor) + end + + test "simple unnamed prepared cursor", context do + query = prepare("", "SELECT * FROM stream") + assert {:ok, cursor} = Mariaex.transaction(context[:pid], fn(conn) -> + cursor = DBConnection.declare!(conn, query, []) + assert {:cont, %Mariaex.Result{num_rows: 0, rows: []}} = + DBConnection.fetch(conn, query, cursor) + assert {:halt, %Mariaex.Result{num_rows: 2, rows: [[1, "foo"], [2, "bar"]]}} = + DBConnection.fetch(conn, query, cursor) + + # no results once halt, don't re-execute + assert {:halt, %Mariaex.Result{num_rows: 0, rows: []}} = DBConnection.fetch(conn, query, cursor) + cursor + end) + + pid = context[:pid] + + # cursor gets removed when transaction ends + assert_raise Mariaex.Error, ~r"could not find active cursor", + fn -> DBConnection.fetch!(pid, query, cursor) end + + # deallocate should never fail + assert {:ok, _} = DBConnection.deallocate(pid, query, cursor) + + assert [[1, "foo"], [2, "bar"]] = execute(query, []) + end + + test "simple named prepared cursor", context do + query = prepare("stream", "SELECT * FROM stream") + assert {:ok, cursor} = Mariaex.transaction(context[:pid], fn(conn) -> + cursor = DBConnection.declare!(conn, query, []) + assert {:cont, %Mariaex.Result{num_rows: 0, rows: []}} = + DBConnection.fetch(conn, query, cursor) + assert {:halt, %Mariaex.Result{num_rows: 2, rows: [[1, "foo"], [2, "bar"]]}} = + DBConnection.fetch(conn, query, cursor) + + # no results once halt, don't re-execute + assert {:halt, %Mariaex.Result{num_rows: 0, rows: []}} = DBConnection.fetch(conn, query, cursor) + cursor + end) + + pid = context[:pid] + + # cursor gets removed when transaction ends + assert_raise Mariaex.Error, ~r"could not find active cursor", + fn -> DBConnection.fetch!(pid, query, cursor) end + + # deallocate should never fail + assert {:ok, _} = DBConnection.deallocate(pid, query, cursor) + + assert [[1, "foo"], [2, "bar"]] = execute(query, []) + end + + test "fetch fetches max_rows", context do + query = prepare("", "SELECT * FROM stream") + assert Mariaex.transaction(context[:pid], fn(conn) -> + cursor = DBConnection.declare!(conn, query, []) + assert {:cont, %Mariaex.Result{num_rows: 0, rows: []}} = + DBConnection.fetch(conn, query, cursor) + assert {:cont, %Mariaex.Result{num_rows: 1, rows: [[1, "foo"]]}} = + DBConnection.fetch(conn, query, cursor, [max_rows: 1]) + assert {:ok, _} = DBConnection.deallocate(conn, query, cursor) + + assert %Mariaex.Result{rows: [[1, "foo"], [2, "bar"]]} = + Mariaex.execute!(conn, query, []) + :done + end) == {:ok, :done} + + assert [[1, "foo"], [2, "bar"]] = execute(query, []) + end + defp connect() do opts = [database: "mariaex_test", username: "mariaex_user", password: "mariaex_pass", backoff_type: :stop] Mariaex.Connection.start_link(opts) diff --git a/test/transaction_test.exs b/test/transaction_test.exs new file mode 100644 index 0000000..c650784 --- /dev/null +++ b/test/transaction_test.exs @@ -0,0 +1,79 @@ +defmodule TransactionTest do + use ExUnit.Case + import Mariaex.TestHelper + + setup do + opts = [database: "mariaex_test", username: "mariaex_user", password: "mariaex_pass", backoff_type: :stop] + {:ok, pid} = Mariaex.Connection.start_link(opts) + {:ok, [pid: pid]} + end + + test "transaction shows correct transaction status", context do + pid = context[:pid] + opts = [mode: :transaction] + + assert DBConnection.status(pid, opts) == :idle + assert query("SELECT 42", []) == [[42]] + assert DBConnection.status(pid, opts) == :idle + {conn, _} = DBConnection.begin!(pid, opts) + assert DBConnection.status(conn, opts) == :transaction + DBConnection.commit!(conn, opts) + assert DBConnection.status(pid, opts) == :idle + assert query("SELECT 42", []) == [[42]] + assert DBConnection.status(pid) == :idle + end + + test "can not begin transaction if already begun", context do + pid = context[:pid] + opts = [mode: :transaction] + + {conn, _} = DBConnection.begin!(pid, opts) + assert {:error, %DBConnection.TransactionError{status: :transaction}} = + DBConnection.begin(conn, opts) + DBConnection.commit!(conn, opts) + end + + test "can not commit or rollback transaction if not begun", context do + pid = context[:pid] + opts = [mode: :transaction] + + assert {:error, %DBConnection.TransactionError{status: :idle}} = + DBConnection.commit(pid, opts) + assert {:error, %DBConnection.TransactionError{status: :idle}} = + DBConnection.rollback(pid, opts) + end + + test "savepoint transaction shows correct transaction status", context do + pid = context[:pid] + opts = [mode: :savepoint] + + {conn, _} = DBConnection.begin!(pid, [mode: :transaction]) + assert DBConnection.status(conn, opts) == :transaction + + assert {:ok, conn, _} = DBConnection.begin(conn, opts) + assert DBConnection.status(conn, opts) == :transaction + DBConnection.commit!(conn, opts) + assert DBConnection.status(pid, opts) == :transaction + + assert {:ok, conn, _} = DBConnection.begin(pid, opts) + assert DBConnection.status(conn, opts) == :transaction + DBConnection.rollback!(conn, opts) + assert DBConnection.status(pid, opts) == :transaction + + DBConnection.commit!(pid, [mode: :transaction]) + assert DBConnection.status(pid) == :idle + assert query("SELECT 42", []) == [[42]] + end + + test "can not begin, commit or rollback savepoint transaction if not begun", context do + pid = context[:pid] + opts = [mode: :savepoint] + + assert {:error, %DBConnection.TransactionError{status: :idle}} = + DBConnection.begin(pid, opts) + assert {:error, %DBConnection.TransactionError{status: :idle}} = + DBConnection.commit(pid, opts) + assert {:error, %DBConnection.TransactionError{status: :idle}} = + DBConnection.rollback(pid, opts) + end +end