From f573831639d0006a3884693eafbdb200b57f3f17 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pedro=20Pi=C3=B1era=20Buend=C3=ADa?= Date: Mon, 20 Apr 2026 14:39:51 +0200 Subject: [PATCH 1/3] Add experimental transaction support --- lib/ch.ex | 2 + lib/ch/connection.ex | 137 ++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 130 insertions(+), 9 deletions(-) diff --git a/lib/ch.ex b/lib/ch.ex index 6f2f567..28e8a75 100644 --- a/lib/ch.ex +++ b/lib/ch.ex @@ -16,6 +16,8 @@ defmodule Ch do | {:username, String.t()} | {:password, String.t()} | {:settings, Keyword.t()} + | {:session_id, String.t()} + | {:session_timeout, pos_integer()} | {:timeout, timeout} @typedoc """ diff --git a/lib/ch/connection.ex b/lib/ch/connection.ex index b53394a..d00b3ca 100644 --- a/lib/ch/connection.ex +++ b/lib/ch/connection.ex @@ -6,12 +6,17 @@ defmodule Ch.Connection do alias Mint.HTTP1, as: HTTP @user_agent "ch/" <> Mix.Project.config()[:version] + @default_session_timeout 300 + @session_id_key :session_id + @session_timeout_key :session_timeout + @transaction_status_key :transaction_status @typep conn :: HTTP.t() @impl true @spec connect([Ch.start_option()]) :: {:ok, conn} | {:error, Error.t() | Mint.Types.error()} def connect(opts) do + opts = put_default_session_opts(opts) scheme = String.to_existing_atom(opts[:scheme] || "http") address = opts[:hostname] || "localhost" port = opts[:port] || 8123 @@ -25,6 +30,9 @@ defmodule Ch.Connection do |> maybe_put_private(:username, opts[:username]) |> maybe_put_private(:password, opts[:password]) |> maybe_put_private(:settings, opts[:settings]) + |> maybe_put_private(@session_id_key, opts[@session_id_key]) + |> maybe_put_private(@session_timeout_key, opts[@session_timeout_key]) + |> put_transaction_status(:idle) handshake = Query.build("select 1, version()") params = DBConnection.Query.encode(handshake, _params = [], _opts = []) @@ -93,17 +101,53 @@ defmodule Ch.Connection do @spec checkout(conn) :: {:ok, conn} def checkout(conn), do: {:ok, conn} - # we "support" these four tx callbacks for Repo.checkout - # even though ClickHouse doesn't support txs - @impl true - def handle_begin(_opts, conn), do: {:ok, %{}, conn} + def handle_begin(opts, conn) do + case {Keyword.get(opts, :mode, :transaction), transaction_status(conn)} do + {:transaction, :idle} -> + execute_transaction_command("BEGIN TRANSACTION", :transaction, conn) + + {:transaction, status} when status in [:transaction, :error] -> + {:error, Error.exception("nested transactions are not supported"), conn} + + {:savepoint, _status} -> + {:error, Error.exception("savepoints are not supported"), conn} + end + end + @impl true - def handle_commit(_opts, conn), do: {:ok, %{}, conn} + def handle_commit(opts, conn) do + case {Keyword.get(opts, :mode, :transaction), transaction_status(conn)} do + {:transaction, :transaction} -> + execute_transaction_command("COMMIT", :idle, conn) + + {:transaction, :error} -> + {:error, conn} + + {:transaction, :idle} -> + {:idle, conn} + + {:savepoint, _status} -> + {:error, Error.exception("savepoints are not supported"), conn} + end + end + @impl true - def handle_rollback(_opts, conn), do: {:ok, %{}, conn} + def handle_rollback(opts, conn) do + case {Keyword.get(opts, :mode, :transaction), transaction_status(conn)} do + {:transaction, status} when status in [:transaction, :error] -> + execute_transaction_command("ROLLBACK", :idle, conn) + + {:transaction, :idle} -> + {:idle, conn} + + {:savepoint, _status} -> + {:error, Error.exception("savepoints are not supported"), conn} + end + end + @impl true - def handle_status(_opts, conn), do: {:idle, conn} + def handle_status(_opts, conn), do: {transaction_status(conn), conn} @impl true def handle_prepare(_query, _opts, conn) do @@ -424,6 +468,13 @@ defmodule Ch.Connection do String.to_integer(code) end + conn = + if transaction_status(conn) == :transaction do + put_transaction_status(conn, :error) + else + conn + end + {:error, Error.exception(code: code, message: message), conn} end end @@ -474,7 +525,10 @@ defmodule Ch.Connection do end defp get_opts_or_private(conn, opts, key) do - Keyword.get(opts, key) || HTTP.get_private(conn, key) + case Keyword.fetch(opts, key) do + {:ok, value} -> value + :error -> HTTP.get_private(conn, key) + end end defp maybe_put_new_header(headers, _name, _no_value = nil), do: headers @@ -496,7 +550,72 @@ defmodule Ch.Connection do defp path(conn, query_params, opts) do settings = settings(conn, opts) - "/?" <> URI.encode_query(settings ++ query_params) + "/?" <> URI.encode_query(settings ++ session_query_params(conn, opts) ++ query_params) + end + + defp execute_transaction_command(statement, next_status, conn) do + opts = session_opts(conn) + + case request(conn, "POST", path(conn, [], opts), headers(conn, [], opts), statement, opts) do + {:ok, conn, _responses} -> + {:ok, %{}, put_transaction_status(conn, next_status)} + + {:error, error, conn} -> + {:error, error, conn} + + {:disconnect, reason, conn} -> + {:disconnect, reason, conn} + end + end + + defp session_query_params(conn, opts) do + if transaction_status(conn) != :idle or + Keyword.has_key?(opts, @session_id_key) or + Keyword.has_key?(opts, @session_timeout_key) do + session_id = get_opts_or_private(conn, opts, @session_id_key) + session_timeout = get_opts_or_private(conn, opts, @session_timeout_key) + + [] + |> maybe_put_query_param("session_id", session_id) + |> maybe_put_query_param("session_timeout", session_timeout) + else + [] + end + end + + defp maybe_put_query_param(params, _key, nil), do: params + defp maybe_put_query_param(params, key, value), do: [{key, value} | params] + + defp transaction_status(conn) do + HTTP.get_private(conn, @transaction_status_key, :idle) + end + + defp put_transaction_status(conn, status) do + HTTP.put_private(conn, @transaction_status_key, status) + end + + defp new_session_id do + "ch_" <> Base.url_encode64(:crypto.strong_rand_bytes(12), padding: false) + end + + defp put_default_session_opts(opts) do + opts + |> Keyword.put_new(@session_timeout_key, @default_session_timeout) + |> put_default_session_id() + end + + defp put_default_session_id(opts) do + case Keyword.has_key?(opts, @session_id_key) do + true -> opts + false -> Keyword.put(opts, @session_id_key, new_session_id()) + end + end + + defp session_opts(conn) do + [ + session_id: HTTP.get_private(conn, @session_id_key), + session_timeout: HTTP.get_private(conn, @session_timeout_key) + ] end @server_display_name_key :server_display_name From b35951ba96ff41a18c7d50d1025bd4288b482106 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pedro=20Pi=C3=B1era=20Buend=C3=ADa?= Date: Mon, 20 Apr 2026 14:51:20 +0200 Subject: [PATCH 2/3] Add transaction tests --- test/ch/connection_test.exs | 106 ++++++++++++++++++++++++++---- test/ch/faults_test.exs | 127 ++++++++++++++++++++++++++++++++++++ 2 files changed, 221 insertions(+), 12 deletions(-) diff --git a/test/ch/connection_test.exs b/test/ch/connection_test.exs index 57c48d3..0221ed1 100644 --- a/test/ch/connection_test.exs +++ b/test/ch/connection_test.exs @@ -1665,22 +1665,90 @@ defmodule Ch.ConnectionTest do end end - describe "transactions" do - test "commit", ctx do - DBConnection.transaction(ctx.conn, fn conn -> - ctx = Map.put(ctx, :conn, conn) - parameterize_query!(ctx, "select 1 + 1") - end) + describe "transactions when supported" do + setup %{conn: conn} do + if transactions_supported?(conn) do + table = "transaction_t_#{System.unique_integer([:positive])}" + Ch.query!(conn, "create table #{table}(id UInt8) engine = MergeTree order by tuple()") + {:ok, table: table, transactions_supported?: true} + else + {:ok, table: nil, transactions_supported?: false} + end end - test "rollback", ctx do - DBConnection.transaction(ctx.conn, fn conn -> - DBConnection.rollback(conn, :some_reason) - end) + test "commit persists rows and resets status", %{ + conn: conn, + table: table, + transactions_supported?: transactions_supported? + } do + if transactions_supported? do + assert DBConnection.status(conn) == :idle + + assert {:ok, :committed} = + DBConnection.transaction(conn, fn conn -> + assert DBConnection.status(conn) == :transaction + + assert {:ok, %{num_rows: 1}} = + Ch.query(conn, "insert into #{table} values (1)") + + assert Ch.query!(conn, "select count() from #{table}").rows == [[1]] + :committed + end) + + assert DBConnection.status(conn) == :idle + assert Ch.query!(conn, "select count() from #{table}").rows == [[1]] + end end - test "status", ctx do - assert DBConnection.status(ctx.conn) == :idle + test "rollback discards rows and resets status", %{ + conn: conn, + table: table, + transactions_supported?: transactions_supported? + } do + if transactions_supported? do + assert DBConnection.status(conn) == :idle + + assert {:error, :rolled_back} = + DBConnection.transaction(conn, fn conn -> + assert DBConnection.status(conn) == :transaction + + assert {:ok, %{num_rows: 1}} = + Ch.query(conn, "insert into #{table} values (1)") + + assert Ch.query!(conn, "select count() from #{table}").rows == [[1]] + DBConnection.rollback(conn, :rolled_back) + end) + + assert DBConnection.status(conn) == :idle + assert Ch.query!(conn, "select count() from #{table}").rows == [[0]] + end + end + + test "query errors mark the transaction as failed", %{ + conn: conn, + transactions_supported?: transactions_supported? + } do + if transactions_supported? do + assert_raise Ch.Error, "cannot commit a failed transaction; rollback is required", fn -> + DBConnection.transaction(conn, fn conn -> + assert DBConnection.status(conn) == :transaction + + assert {:error, %Ch.Error{}} = + Ch.query(conn, "select missing_transaction_column") + + assert DBConnection.status(conn) == :error + :ok + end) + end + + assert DBConnection.status(conn) == :idle + end + end + end + + describe "transaction status" do + test "is idle outside of a transaction", %{conn: conn} do + assert DBConnection.status(conn) == :idle end end @@ -1827,4 +1895,18 @@ defmodule Ch.ConnectionTest do assert List.last(row) == 1000 end end + + defp transactions_supported?(conn) do + case Ch.query(conn, "BEGIN TRANSACTION") do + {:ok, _result} -> + Ch.query!(conn, "ROLLBACK") + true + + {:error, %Ch.Error{code: 48}} -> + false + + {:error, error} -> + raise error + end + end end diff --git a/test/ch/faults_test.exs b/test/ch/faults_test.exs index bcc7457..2071608 100644 --- a/test/ch/faults_test.exs +++ b/test/ch/faults_test.exs @@ -545,7 +545,134 @@ defmodule Ch.FaultsTest do end end + describe "transactions" do + test "sends begin and commit over the same session", %{ + port: port, + listen: listen, + clickhouse: clickhouse + } do + test = self() + {:ok, conn} = Ch.start_link(port: port) + + {:ok, mint} = :gen_tcp.accept(listen) + + # handshake + :ok = :gen_tcp.send(clickhouse, intercept_packets(mint)) + :ok = :gen_tcp.send(mint, intercept_packets(clickhouse)) + + task = + Task.async(fn -> + DBConnection.transaction(conn, fn conn -> + assert DBConnection.status(conn) == :transaction + send(test, :inside_transaction) + :committed + end) + end) + + begin_request = intercept_packets(mint) + assert request_body(begin_request) == "BEGIN TRANSACTION" + + begin_params = request_query_params(begin_request) + assert begin_params["session_id"] + assert begin_params["session_timeout"] == "300" + :ok = :gen_tcp.send(mint, ok_response()) + + assert_receive :inside_transaction + + commit_request = intercept_packets(mint) + assert request_body(commit_request) == "COMMIT" + assert request_query_params(commit_request) == begin_params + :ok = :gen_tcp.send(mint, ok_response()) + + assert Task.await(task) == {:ok, :committed} + assert DBConnection.status(conn) == :idle + end + + test "marks the transaction as failed after a query error", %{ + port: port, + listen: listen, + clickhouse: clickhouse + } do + test = self() + {:ok, conn} = Ch.start_link(port: port) + + {:ok, mint} = :gen_tcp.accept(listen) + + # handshake + :ok = :gen_tcp.send(clickhouse, intercept_packets(mint)) + :ok = :gen_tcp.send(mint, intercept_packets(clickhouse)) + + task = + Task.async(fn -> + DBConnection.transaction(conn, fn conn -> + assert DBConnection.status(conn) == :transaction + + assert {:error, %Ch.Error{code: 47}} = + Ch.query(conn, "select missing_transaction_column") + + send(test, {:transaction_status, DBConnection.status(conn)}) + DBConnection.rollback(conn, :query_failed) + end) + end) + + begin_request = intercept_packets(mint) + begin_params = request_query_params(begin_request) + assert request_body(begin_request) == "BEGIN TRANSACTION" + :ok = :gen_tcp.send(mint, ok_response()) + + query_request = intercept_packets(mint) + assert request_body(query_request) == "select missing_transaction_column" + assert request_query_params(query_request) == begin_params + + :ok = + :gen_tcp.send( + mint, + error_response(47, "Code: 47. DB::Exception: Unknown expression identifier") + ) + + assert_receive {:transaction_status, :error} + rollback_request = intercept_packets(mint) + assert request_body(rollback_request) == "ROLLBACK" + assert request_query_params(rollback_request) == begin_params + :ok = :gen_tcp.send(mint, ok_response()) + + assert Task.await(task) == {:error, :query_failed} + assert DBConnection.status(conn) == :idle + end + end + defp first_byte(binary) do :binary.part(binary, 0, 1) end + + defp ok_response do + "HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n" + end + + defp error_response(code, message) do + "HTTP/1.1 500 Internal Server Error\r\n" <> + "X-ClickHouse-Exception-Code: #{code}\r\n" <> + "Content-Length: #{byte_size(message)}\r\n\r\n" <> message + end + + defp request_query_params(request) do + request + |> request_target() + |> URI.parse() + |> Map.get(:query, "") + |> URI.decode_query() + end + + defp request_body(request) do + case String.split(request, "\r\n\r\n", parts: 2) do + [_headers, body] -> body + [_headers] -> "" + end + end + + defp request_target(request) do + [request_line | _rest] = String.split(request, "\r\n", parts: 2) + [_method, target, _version] = String.split(request_line, " ", parts: 3) + target + end end From 0acf8c2f053fd56fcb299140ce4c42990f198e99 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pedro=20Pi=C3=B1era=20Buend=C3=ADa?= Date: Mon, 20 Apr 2026 15:00:31 +0200 Subject: [PATCH 3/3] Refine HTTP fault test helpers --- test/ch/faults_test.exs | 92 +++++++++++++++++++++-------------------- 1 file changed, 48 insertions(+), 44 deletions(-) diff --git a/test/ch/faults_test.exs b/test/ch/faults_test.exs index 2071608..cd257e4 100644 --- a/test/ch/faults_test.exs +++ b/test/ch/faults_test.exs @@ -569,20 +569,20 @@ defmodule Ch.FaultsTest do end) end) - begin_request = intercept_packets(mint) - assert request_body(begin_request) == "BEGIN TRANSACTION" + begin_request = parse_request(intercept_packets(mint)) + assert begin_request.body == "BEGIN TRANSACTION" - begin_params = request_query_params(begin_request) + begin_params = begin_request.query_params assert begin_params["session_id"] assert begin_params["session_timeout"] == "300" - :ok = :gen_tcp.send(mint, ok_response()) + :ok = :gen_tcp.send(mint, response()) assert_receive :inside_transaction - commit_request = intercept_packets(mint) - assert request_body(commit_request) == "COMMIT" - assert request_query_params(commit_request) == begin_params - :ok = :gen_tcp.send(mint, ok_response()) + commit_request = parse_request(intercept_packets(mint)) + assert commit_request.body == "COMMIT" + assert commit_request.query_params == begin_params + :ok = :gen_tcp.send(mint, response()) assert Task.await(task) == {:ok, :committed} assert DBConnection.status(conn) == :idle @@ -615,26 +615,30 @@ defmodule Ch.FaultsTest do end) end) - begin_request = intercept_packets(mint) - begin_params = request_query_params(begin_request) - assert request_body(begin_request) == "BEGIN TRANSACTION" - :ok = :gen_tcp.send(mint, ok_response()) + begin_request = parse_request(intercept_packets(mint)) + begin_params = begin_request.query_params + assert begin_request.body == "BEGIN TRANSACTION" + :ok = :gen_tcp.send(mint, response()) - query_request = intercept_packets(mint) - assert request_body(query_request) == "select missing_transaction_column" - assert request_query_params(query_request) == begin_params + query_request = parse_request(intercept_packets(mint)) + assert query_request.body == "select missing_transaction_column" + assert query_request.query_params == begin_params :ok = :gen_tcp.send( mint, - error_response(47, "Code: 47. DB::Exception: Unknown expression identifier") + response( + 500, + [{"X-ClickHouse-Exception-Code", "47"}], + "Code: 47. DB::Exception: Unknown expression identifier" + ) ) assert_receive {:transaction_status, :error} - rollback_request = intercept_packets(mint) - assert request_body(rollback_request) == "ROLLBACK" - assert request_query_params(rollback_request) == begin_params - :ok = :gen_tcp.send(mint, ok_response()) + rollback_request = parse_request(intercept_packets(mint)) + assert rollback_request.body == "ROLLBACK" + assert rollback_request.query_params == begin_params + :ok = :gen_tcp.send(mint, response()) assert Task.await(task) == {:error, :query_failed} assert DBConnection.status(conn) == :idle @@ -645,34 +649,34 @@ defmodule Ch.FaultsTest do :binary.part(binary, 0, 1) end - defp ok_response do - "HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n" - end + defp response(status_code \\ 200, headers \\ [], body \\ "") - defp error_response(code, message) do - "HTTP/1.1 500 Internal Server Error\r\n" <> - "X-ClickHouse-Exception-Code: #{code}\r\n" <> - "Content-Length: #{byte_size(message)}\r\n\r\n" <> message + defp response(status_code, headers, body) do + "HTTP/1.1 #{status_code} #{reason_phrase(status_code)}\r\n" <> + Enum.map_join(headers, "", fn {name, value} -> "#{name}: #{value}\r\n" end) <> + "Content-Length: #{byte_size(body)}\r\n\r\n" <> body end - defp request_query_params(request) do - request - |> request_target() - |> URI.parse() - |> Map.get(:query, "") - |> URI.decode_query() - end + defp parse_request(request) do + {head, body} = + case String.split(request, "\r\n\r\n", parts: 2) do + [head, body] -> {head, body} + [head] -> {head, ""} + end - defp request_body(request) do - case String.split(request, "\r\n\r\n", parts: 2) do - [_headers, body] -> body - [_headers] -> "" - end - end - - defp request_target(request) do - [request_line | _rest] = String.split(request, "\r\n", parts: 2) + [request_line | _headers] = String.split(head, "\r\n", parts: 2) [_method, target, _version] = String.split(request_line, " ", parts: 3) - target + + %{ + body: body, + query_params: + target + |> URI.parse() + |> Map.get(:query, "") + |> URI.decode_query() + } end + + defp reason_phrase(200), do: "OK" + defp reason_phrase(500), do: "Internal Server Error" end