From bb6c448612dc06f7c9b2f879d17705a79292ef60 Mon Sep 17 00:00:00 2001 From: Jay Miracola Date: Thu, 9 Apr 2026 16:35:55 -0400 Subject: [PATCH 1/6] feat: tunnel raw TCP database connections through HTTP CONNECT proxy When HTTP_PROXY or HTTPS_PROXY is set, PostgreSQL, MySQL, and MSSQL connections are tunnelled through the proxy via HTTP CONNECT. This mirrors how net/http handles proxies automatically for HTTP-based providers, extending the same behaviour to raw TCP database drivers which bypass net/http entirely. Signed-off-by: Jay Miracola --- pkg/clients/mssql/mssql.go | 6 +-- pkg/clients/mssql/proxy.go | 62 ++++++++++++++++++++++++ pkg/clients/mssql/proxy_test.go | 72 ++++++++++++++++++++++++++++ pkg/clients/mysql/mysql.go | 6 +-- pkg/clients/mysql/proxy.go | 64 +++++++++++++++++++++++++ pkg/clients/mysql/proxy_test.go | 72 ++++++++++++++++++++++++++++ pkg/clients/postgresql/postgresql.go | 8 ++-- pkg/clients/postgresql/proxy.go | 71 +++++++++++++++++++++++++++ pkg/clients/postgresql/proxy_test.go | 71 +++++++++++++++++++++++++++ 9 files changed, 422 insertions(+), 10 deletions(-) create mode 100644 pkg/clients/mssql/proxy.go create mode 100644 pkg/clients/mssql/proxy_test.go create mode 100644 pkg/clients/mysql/proxy.go create mode 100644 pkg/clients/mysql/proxy_test.go create mode 100644 pkg/clients/postgresql/proxy.go create mode 100644 pkg/clients/postgresql/proxy_test.go diff --git a/pkg/clients/mssql/mssql.go b/pkg/clients/mssql/mssql.go index ff6cebba..e6cd1c69 100644 --- a/pkg/clients/mssql/mssql.go +++ b/pkg/clients/mssql/mssql.go @@ -77,7 +77,7 @@ func (c mssqlDB) ExecTx(_ context.Context, _ []xsql.Query) error { // Exec the supplied query. func (c mssqlDB) Exec(ctx context.Context, q xsql.Query) error { - d, err := sql.Open(driverName, c.dsn) + d, err := openDB(c.endpoint, c.port, c.dsn) if err != nil { return err } @@ -89,7 +89,7 @@ func (c mssqlDB) Exec(ctx context.Context, q xsql.Query) error { // Query the supplied query. func (c mssqlDB) Query(ctx context.Context, q xsql.Query) (*sql.Rows, error) { - d, err := sql.Open(driverName, c.dsn) + d, err := openDB(c.endpoint, c.port, c.dsn) if err != nil { return nil, err } @@ -100,7 +100,7 @@ func (c mssqlDB) Query(ctx context.Context, q xsql.Query) (*sql.Rows, error) { // Scan the results of the supplied query into the supplied destination. func (c mssqlDB) Scan(ctx context.Context, q xsql.Query, dest ...interface{}) error { - db, err := sql.Open(driverName, c.dsn) + db, err := openDB(c.endpoint, c.port, c.dsn) if err != nil { return err } diff --git a/pkg/clients/mssql/proxy.go b/pkg/clients/mssql/proxy.go new file mode 100644 index 00000000..89500353 --- /dev/null +++ b/pkg/clients/mssql/proxy.go @@ -0,0 +1,62 @@ +package mssql + +import ( + "bufio" + "context" + "database/sql" + "fmt" + "net" + "net/http" + "net/url" + + mssqldb "github.com/microsoft/go-mssqldb" +) + +type httpProxyDialer struct { + proxy *url.URL +} + +func (d *httpProxyDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { + conn, err := (&net.Dialer{}).DialContext(ctx, "tcp", d.proxy.Host) + if err != nil { + return nil, err + } + req := &http.Request{ + Method: http.MethodConnect, + URL: &url.URL{Host: addr}, + Host: addr, + Header: http.Header{}, + } + if err := req.Write(conn); err != nil { + conn.Close() //nolint:errcheck + return nil, err + } + resp, err := http.ReadResponse(bufio.NewReader(conn), req) + if err != nil { + conn.Close() //nolint:errcheck + return nil, err + } + resp.Body.Close() //nolint:errcheck + if resp.StatusCode != http.StatusOK { + conn.Close() //nolint:errcheck + return nil, fmt.Errorf("proxy CONNECT to %s: %s", addr, resp.Status) + } + return conn, nil +} + +// openDB opens a *sql.DB for the given DSN. When HTTP_PROXY or HTTPS_PROXY is +// set and applicable to the target endpoint, connections are tunnelled through +// the proxy via HTTP CONNECT. +func openDB(endpoint, port, dsn string) (*sql.DB, error) { + req, _ := http.NewRequest(http.MethodConnect, "https://"+endpoint+":"+port, nil) + proxyURL, err := http.ProxyFromEnvironment(req) + if err != nil || proxyURL == nil { + return sql.Open(driverName, dsn) + } + connector, err := mssqldb.NewConnector(dsn) + if err != nil { + return nil, err + } + connector.Dialer = &httpProxyDialer{proxy: proxyURL} + return sql.OpenDB(connector), nil +} diff --git a/pkg/clients/mssql/proxy_test.go b/pkg/clients/mssql/proxy_test.go new file mode 100644 index 00000000..f44a7b70 --- /dev/null +++ b/pkg/clients/mssql/proxy_test.go @@ -0,0 +1,72 @@ +package mssql + +import ( + "context" + "fmt" + "net" + "net/url" + "testing" +) + +func TestHttpProxyDialerSuccess(t *testing.T) { + proxy, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer proxy.Close() //nolint:errcheck + + go func() { + conn, err := proxy.Accept() + if err != nil { + return + } + defer conn.Close() //nolint:errcheck + buf := make([]byte, 4096) + _, _ = conn.Read(buf) + _, _ = fmt.Fprint(conn, "HTTP/1.1 200 Connection established\r\n\r\n") + }() + + d := &httpProxyDialer{proxy: &url.URL{Host: proxy.Addr().String()}} + conn, err := d.DialContext(context.Background(), "tcp", "db.example.com:1433") + if err != nil { + t.Fatalf("DialContext() unexpected error: %v", err) + } + conn.Close() //nolint:errcheck +} + +func TestHttpProxyDialerRejected(t *testing.T) { + proxy, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer proxy.Close() //nolint:errcheck + + go func() { + conn, err := proxy.Accept() + if err != nil { + return + } + defer conn.Close() //nolint:errcheck + buf := make([]byte, 4096) + _, _ = conn.Read(buf) + _, _ = fmt.Fprint(conn, "HTTP/1.1 403 Forbidden\r\n\r\n") + }() + + d := &httpProxyDialer{proxy: &url.URL{Host: proxy.Addr().String()}} + _, err = d.DialContext(context.Background(), "tcp", "db.example.com:1433") + if err == nil { + t.Fatal("DialContext() expected error on non-200 response, got nil") + } +} + +func TestOpenDBNoProxy(t *testing.T) { + t.Setenv("HTTP_PROXY", "") + t.Setenv("HTTPS_PROXY", "") + t.Setenv("NO_PROXY", "*") + + db, err := openDB("localhost", "1433", "sqlserver://u:p@localhost:1433?database=db") + if err != nil { + t.Fatalf("openDB() unexpected error: %v", err) + } + db.Close() //nolint:errcheck +} diff --git a/pkg/clients/mysql/mysql.go b/pkg/clients/mysql/mysql.go index a6ea5306..13af55b8 100644 --- a/pkg/clients/mysql/mysql.go +++ b/pkg/clients/mysql/mysql.go @@ -75,7 +75,7 @@ func (c mySQLDB) ExecTx(ctx context.Context, ql []xsql.Query) error { // Exec the supplied query. func (c mySQLDB) Exec(ctx context.Context, q xsql.Query) error { - d, err := sql.Open("mysql", c.dsn) + d, err := openDB(c.endpoint, c.port, c.dsn) if err != nil { return err } @@ -87,7 +87,7 @@ func (c mySQLDB) Exec(ctx context.Context, q xsql.Query) error { // Query the supplied query. func (c mySQLDB) Query(ctx context.Context, q xsql.Query) (*sql.Rows, error) { - d, err := sql.Open("mysql", c.dsn) + d, err := openDB(c.endpoint, c.port, c.dsn) if err != nil { return nil, err } @@ -99,7 +99,7 @@ func (c mySQLDB) Query(ctx context.Context, q xsql.Query) (*sql.Rows, error) { // Scan the results of the supplied query into the supplied destination. func (c mySQLDB) Scan(ctx context.Context, q xsql.Query, dest ...interface{}) error { - db, err := sql.Open("mysql", c.dsn) + db, err := openDB(c.endpoint, c.port, c.dsn) if err != nil { return err } diff --git a/pkg/clients/mysql/proxy.go b/pkg/clients/mysql/proxy.go new file mode 100644 index 00000000..357f838e --- /dev/null +++ b/pkg/clients/mysql/proxy.go @@ -0,0 +1,64 @@ +package mysql + +import ( + "bufio" + "context" + "database/sql" + "fmt" + "net" + "net/http" + "net/url" + + gomysql "github.com/go-sql-driver/mysql" +) + +func tunnel(ctx context.Context, proxy *url.URL, addr string) (net.Conn, error) { + conn, err := (&net.Dialer{}).DialContext(ctx, "tcp", proxy.Host) + if err != nil { + return nil, err + } + req := &http.Request{ + Method: http.MethodConnect, + URL: &url.URL{Host: addr}, + Host: addr, + Header: http.Header{}, + } + if err := req.Write(conn); err != nil { + conn.Close() //nolint:errcheck + return nil, err + } + resp, err := http.ReadResponse(bufio.NewReader(conn), req) + if err != nil { + conn.Close() //nolint:errcheck + return nil, err + } + resp.Body.Close() //nolint:errcheck + if resp.StatusCode != http.StatusOK { + conn.Close() //nolint:errcheck + return nil, fmt.Errorf("proxy CONNECT to %s: %s", addr, resp.Status) + } + return conn, nil +} + +// openDB opens a *sql.DB for the given DSN. When HTTP_PROXY or HTTPS_PROXY is +// set and applicable to the target endpoint, connections are tunnelled through +// the proxy via HTTP CONNECT. +func openDB(endpoint, port, dsn string) (*sql.DB, error) { + req, _ := http.NewRequest(http.MethodConnect, "https://"+endpoint+":"+port, nil) + proxyURL, err := http.ProxyFromEnvironment(req) + if err != nil || proxyURL == nil { + return sql.Open("mysql", dsn) + } + cfg, err := gomysql.ParseDSN(dsn) + if err != nil { + return nil, err + } + cfg.DialFunc = func(ctx context.Context, network, addr string) (net.Conn, error) { + return tunnel(ctx, proxyURL, addr) + } + connector, err := gomysql.NewConnector(cfg) + if err != nil { + return nil, err + } + return sql.OpenDB(connector), nil +} diff --git a/pkg/clients/mysql/proxy_test.go b/pkg/clients/mysql/proxy_test.go new file mode 100644 index 00000000..3ba6ac22 --- /dev/null +++ b/pkg/clients/mysql/proxy_test.go @@ -0,0 +1,72 @@ +package mysql + +import ( + "context" + "fmt" + "net" + "net/url" + "testing" +) + +func TestTunnelSuccess(t *testing.T) { + proxy, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer proxy.Close() //nolint:errcheck + + go func() { + conn, err := proxy.Accept() + if err != nil { + return + } + defer conn.Close() //nolint:errcheck + buf := make([]byte, 4096) + _, _ = conn.Read(buf) + _, _ = fmt.Fprint(conn, "HTTP/1.1 200 Connection established\r\n\r\n") + }() + + proxyURL := &url.URL{Host: proxy.Addr().String()} + conn, err := tunnel(context.Background(), proxyURL, "db.example.com:3306") + if err != nil { + t.Fatalf("tunnel() unexpected error: %v", err) + } + conn.Close() //nolint:errcheck +} + +func TestTunnelProxyRejected(t *testing.T) { + proxy, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer proxy.Close() //nolint:errcheck + + go func() { + conn, err := proxy.Accept() + if err != nil { + return + } + defer conn.Close() //nolint:errcheck + buf := make([]byte, 4096) + _, _ = conn.Read(buf) + _, _ = fmt.Fprint(conn, "HTTP/1.1 407 Proxy Authentication Required\r\n\r\n") + }() + + proxyURL := &url.URL{Host: proxy.Addr().String()} + _, err = tunnel(context.Background(), proxyURL, "db.example.com:3306") + if err == nil { + t.Fatal("tunnel() expected error on non-200 response, got nil") + } +} + +func TestOpenDBNoProxy(t *testing.T) { + t.Setenv("HTTP_PROXY", "") + t.Setenv("HTTPS_PROXY", "") + t.Setenv("NO_PROXY", "*") + + db, err := openDB("localhost", "3306", "u:p@tcp(localhost:3306)/db") + if err != nil { + t.Fatalf("openDB() unexpected error: %v", err) + } + db.Close() //nolint:errcheck +} diff --git a/pkg/clients/postgresql/postgresql.go b/pkg/clients/postgresql/postgresql.go index 53e12744..18e91615 100644 --- a/pkg/clients/postgresql/postgresql.go +++ b/pkg/clients/postgresql/postgresql.go @@ -63,7 +63,7 @@ func DSN(username, password, endpoint, port, database, sslmode string) string { // ExecTx executes an array of queries, committing if all are successful and // rolling back immediately on failure. func (c postgresDB) ExecTx(ctx context.Context, ql []xsql.Query) error { - d, err := sql.Open("postgres", c.dsn) + d, err := openDB(c.endpoint, c.port, c.dsn) if err != nil { return err } @@ -95,7 +95,7 @@ func (c postgresDB) ExecTx(ctx context.Context, ql []xsql.Query) error { // Exec the supplied query. func (c postgresDB) Exec(ctx context.Context, q xsql.Query) error { - d, err := sql.Open("postgres", c.dsn) + d, err := openDB(c.endpoint, c.port, c.dsn) if err != nil { return err } @@ -107,7 +107,7 @@ func (c postgresDB) Exec(ctx context.Context, q xsql.Query) error { // Query the supplied query. func (c postgresDB) Query(ctx context.Context, q xsql.Query) (*sql.Rows, error) { - d, err := sql.Open("postgres", c.dsn) + d, err := openDB(c.endpoint, c.port, c.dsn) if err != nil { return nil, err } @@ -119,7 +119,7 @@ func (c postgresDB) Query(ctx context.Context, q xsql.Query) (*sql.Rows, error) // Scan the results of the supplied query into the supplied destination. func (c postgresDB) Scan(ctx context.Context, q xsql.Query, dest ...interface{}) error { - db, err := sql.Open("postgres", c.dsn) + db, err := openDB(c.endpoint, c.port, c.dsn) if err != nil { return err } diff --git a/pkg/clients/postgresql/proxy.go b/pkg/clients/postgresql/proxy.go new file mode 100644 index 00000000..b305232e --- /dev/null +++ b/pkg/clients/postgresql/proxy.go @@ -0,0 +1,71 @@ +package postgresql + +import ( + "bufio" + "database/sql" + "fmt" + "net" + "net/http" + "net/url" + "time" + + "github.com/lib/pq" +) + +// httpProxyDialer tunnels connections through an HTTP CONNECT proxy. +type httpProxyDialer struct { + proxy *url.URL +} + +func (d *httpProxyDialer) Dial(network, addr string) (net.Conn, error) { + return tunnel(&net.Dialer{}, d.proxy, addr) +} + +func (d *httpProxyDialer) DialTimeout(network, addr string, timeout time.Duration) (net.Conn, error) { + return tunnel(&net.Dialer{Timeout: timeout}, d.proxy, addr) +} + +func tunnel(nd *net.Dialer, proxy *url.URL, addr string) (net.Conn, error) { + conn, err := nd.Dial("tcp", proxy.Host) + if err != nil { + return nil, err + } + req := &http.Request{ + Method: http.MethodConnect, + URL: &url.URL{Host: addr}, + Host: addr, + Header: http.Header{}, + } + if err := req.Write(conn); err != nil { + conn.Close() //nolint:errcheck + return nil, err + } + resp, err := http.ReadResponse(bufio.NewReader(conn), req) + if err != nil { + conn.Close() //nolint:errcheck + return nil, err + } + resp.Body.Close() //nolint:errcheck + if resp.StatusCode != http.StatusOK { + conn.Close() //nolint:errcheck + return nil, fmt.Errorf("proxy CONNECT to %s: %s", addr, resp.Status) + } + return conn, nil +} + +// openDB opens a *sql.DB for the given DSN. When HTTP_PROXY or HTTPS_PROXY is +// set and applicable to the target endpoint, connections are tunnelled through +// the proxy via HTTP CONNECT. +func openDB(endpoint, port, dsn string) (*sql.DB, error) { + req, _ := http.NewRequest(http.MethodConnect, "https://"+endpoint+":"+port, nil) + proxyURL, err := http.ProxyFromEnvironment(req) + if err != nil || proxyURL == nil { + return sql.Open("postgres", dsn) + } + connector, err := pq.NewConnector(dsn) + if err != nil { + return nil, err + } + connector.Dialer(&httpProxyDialer{proxy: proxyURL}) + return sql.OpenDB(connector), nil +} diff --git a/pkg/clients/postgresql/proxy_test.go b/pkg/clients/postgresql/proxy_test.go new file mode 100644 index 00000000..628192ca --- /dev/null +++ b/pkg/clients/postgresql/proxy_test.go @@ -0,0 +1,71 @@ +package postgresql + +import ( + "fmt" + "net" + "net/url" + "testing" +) + +func TestTunnelSuccess(t *testing.T) { + proxy, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer proxy.Close() //nolint:errcheck + + go func() { + conn, err := proxy.Accept() + if err != nil { + return + } + defer conn.Close() //nolint:errcheck + buf := make([]byte, 4096) + _, _ = conn.Read(buf) + _, _ = fmt.Fprint(conn, "HTTP/1.1 200 Connection established\r\n\r\n") + }() + + proxyURL := &url.URL{Host: proxy.Addr().String()} + conn, err := tunnel(&net.Dialer{}, proxyURL, "db.example.com:5432") + if err != nil { + t.Fatalf("tunnel() unexpected error: %v", err) + } + conn.Close() //nolint:errcheck +} + +func TestTunnelProxyRejected(t *testing.T) { + proxy, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer proxy.Close() //nolint:errcheck + + go func() { + conn, err := proxy.Accept() + if err != nil { + return + } + defer conn.Close() //nolint:errcheck + buf := make([]byte, 4096) + _, _ = conn.Read(buf) + _, _ = fmt.Fprint(conn, "HTTP/1.1 403 Forbidden\r\n\r\n") + }() + + proxyURL := &url.URL{Host: proxy.Addr().String()} + _, err = tunnel(&net.Dialer{}, proxyURL, "db.example.com:5432") + if err == nil { + t.Fatal("tunnel() expected error on non-200 response, got nil") + } +} + +func TestOpenDBNoProxy(t *testing.T) { + t.Setenv("HTTP_PROXY", "") + t.Setenv("HTTPS_PROXY", "") + t.Setenv("NO_PROXY", "*") + + db, err := openDB("localhost", "5432", "postgres://u:p@localhost:5432/db?sslmode=disable") + if err != nil { + t.Fatalf("openDB() unexpected error: %v", err) + } + db.Close() //nolint:errcheck +} From 0bf8d147f791c4f42320bfc4c691e8335e9d2137 Mon Sep 17 00:00:00 2001 From: Jay Miracola <30883275+jaymiracola@users.noreply.github.com> Date: Tue, 14 Apr 2026 10:46:01 -0400 Subject: [PATCH 2/6] Update pkg/clients/mssql/proxy.go MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: J. Fernández <7312236+fernandezcuesta@users.noreply.github.com> Signed-off-by: Jay Miracola --- pkg/clients/mssql/proxy.go | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/pkg/clients/mssql/proxy.go b/pkg/clients/mssql/proxy.go index 89500353..33758927 100644 --- a/pkg/clients/mssql/proxy.go +++ b/pkg/clients/mssql/proxy.go @@ -21,6 +21,14 @@ func (d *httpProxyDialer) DialContext(ctx context.Context, network, addr string) if err != nil { return nil, err } + defer func() + if err != nil { + closeErr := conn.Close() + if closeErr != nil { + err = fmt.Errorf("error closing connection (%w) after: %w", closeErr, err) + } + } + }() req := &http.Request{ Method: http.MethodConnect, URL: &url.URL{Host: addr}, From 5b58c8efce8eb0b3b3b251c43d35d1897ce06d28 Mon Sep 17 00:00:00 2001 From: Jay Miracola <30883275+jaymiracola@users.noreply.github.com> Date: Tue, 14 Apr 2026 10:46:14 -0400 Subject: [PATCH 3/6] Update pkg/clients/mssql/proxy.go MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: J. Fernández <7312236+fernandezcuesta@users.noreply.github.com> Signed-off-by: Jay Miracola --- pkg/clients/mssql/proxy.go | 1 - 1 file changed, 1 deletion(-) diff --git a/pkg/clients/mssql/proxy.go b/pkg/clients/mssql/proxy.go index 33758927..cbfedfec 100644 --- a/pkg/clients/mssql/proxy.go +++ b/pkg/clients/mssql/proxy.go @@ -36,7 +36,6 @@ func (d *httpProxyDialer) DialContext(ctx context.Context, network, addr string) Header: http.Header{}, } if err := req.Write(conn); err != nil { - conn.Close() //nolint:errcheck return nil, err } resp, err := http.ReadResponse(bufio.NewReader(conn), req) From 96f9325bcfc400f5b6a99cbc011eccc2be8eb7e0 Mon Sep 17 00:00:00 2001 From: Jay Miracola <30883275+jaymiracola@users.noreply.github.com> Date: Tue, 14 Apr 2026 10:46:21 -0400 Subject: [PATCH 4/6] Update pkg/clients/mssql/proxy.go MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: J. Fernández <7312236+fernandezcuesta@users.noreply.github.com> Signed-off-by: Jay Miracola --- pkg/clients/mssql/proxy.go | 1 - 1 file changed, 1 deletion(-) diff --git a/pkg/clients/mssql/proxy.go b/pkg/clients/mssql/proxy.go index cbfedfec..1ffb65b5 100644 --- a/pkg/clients/mssql/proxy.go +++ b/pkg/clients/mssql/proxy.go @@ -40,7 +40,6 @@ func (d *httpProxyDialer) DialContext(ctx context.Context, network, addr string) } resp, err := http.ReadResponse(bufio.NewReader(conn), req) if err != nil { - conn.Close() //nolint:errcheck return nil, err } resp.Body.Close() //nolint:errcheck From ae5cd16ba78fc788cd416ec31bf3e55925f5b563 Mon Sep 17 00:00:00 2001 From: Jay Miracola <30883275+jaymiracola@users.noreply.github.com> Date: Tue, 14 Apr 2026 10:46:28 -0400 Subject: [PATCH 5/6] Update pkg/clients/mssql/proxy.go MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: J. Fernández <7312236+fernandezcuesta@users.noreply.github.com> Signed-off-by: Jay Miracola --- pkg/clients/mssql/proxy.go | 1 - 1 file changed, 1 deletion(-) diff --git a/pkg/clients/mssql/proxy.go b/pkg/clients/mssql/proxy.go index 1ffb65b5..2053c451 100644 --- a/pkg/clients/mssql/proxy.go +++ b/pkg/clients/mssql/proxy.go @@ -44,7 +44,6 @@ func (d *httpProxyDialer) DialContext(ctx context.Context, network, addr string) } resp.Body.Close() //nolint:errcheck if resp.StatusCode != http.StatusOK { - conn.Close() //nolint:errcheck return nil, fmt.Errorf("proxy CONNECT to %s: %s", addr, resp.Status) } return conn, nil From 08a59bb0f6eee68572597a47ee824d09c235c41d Mon Sep 17 00:00:00 2001 From: Jay Miracola Date: Tue, 14 Apr 2026 10:49:54 -0400 Subject: [PATCH 6/6] fix: use named returns and defer for conn cleanup in proxy dialers Apply the defer-based error handling pattern consistently across all three proxy files. Named return values allow the deferred closure to capture and wrap close errors alongside the original error rather than discarding them silently. Signed-off-by: Jay Miracola --- pkg/clients/mssql/proxy.go | 12 +++++++----- pkg/clients/mysql/proxy.go | 21 ++++++++++++++------- pkg/clients/postgresql/proxy.go | 21 ++++++++++++++------- 3 files changed, 35 insertions(+), 19 deletions(-) diff --git a/pkg/clients/mssql/proxy.go b/pkg/clients/mssql/proxy.go index 2053c451..1a7f7d04 100644 --- a/pkg/clients/mssql/proxy.go +++ b/pkg/clients/mssql/proxy.go @@ -16,12 +16,12 @@ type httpProxyDialer struct { proxy *url.URL } -func (d *httpProxyDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { +func (d *httpProxyDialer) DialContext(ctx context.Context, network, addr string) (_ net.Conn, err error) { conn, err := (&net.Dialer{}).DialContext(ctx, "tcp", d.proxy.Host) if err != nil { return nil, err } - defer func() + defer func() { if err != nil { closeErr := conn.Close() if closeErr != nil { @@ -35,16 +35,18 @@ func (d *httpProxyDialer) DialContext(ctx context.Context, network, addr string) Host: addr, Header: http.Header{}, } - if err := req.Write(conn); err != nil { + if err = req.Write(conn); err != nil { return nil, err } - resp, err := http.ReadResponse(bufio.NewReader(conn), req) + var resp *http.Response + resp, err = http.ReadResponse(bufio.NewReader(conn), req) if err != nil { return nil, err } resp.Body.Close() //nolint:errcheck if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("proxy CONNECT to %s: %s", addr, resp.Status) + err = fmt.Errorf("proxy CONNECT to %s: %s", addr, resp.Status) + return nil, err } return conn, nil } diff --git a/pkg/clients/mysql/proxy.go b/pkg/clients/mysql/proxy.go index 357f838e..3b6e94a3 100644 --- a/pkg/clients/mysql/proxy.go +++ b/pkg/clients/mysql/proxy.go @@ -12,30 +12,37 @@ import ( gomysql "github.com/go-sql-driver/mysql" ) -func tunnel(ctx context.Context, proxy *url.URL, addr string) (net.Conn, error) { +func tunnel(ctx context.Context, proxy *url.URL, addr string) (_ net.Conn, err error) { conn, err := (&net.Dialer{}).DialContext(ctx, "tcp", proxy.Host) if err != nil { return nil, err } + defer func() { + if err != nil { + closeErr := conn.Close() + if closeErr != nil { + err = fmt.Errorf("error closing connection (%w) after: %w", closeErr, err) + } + } + }() req := &http.Request{ Method: http.MethodConnect, URL: &url.URL{Host: addr}, Host: addr, Header: http.Header{}, } - if err := req.Write(conn); err != nil { - conn.Close() //nolint:errcheck + if err = req.Write(conn); err != nil { return nil, err } - resp, err := http.ReadResponse(bufio.NewReader(conn), req) + var resp *http.Response + resp, err = http.ReadResponse(bufio.NewReader(conn), req) if err != nil { - conn.Close() //nolint:errcheck return nil, err } resp.Body.Close() //nolint:errcheck if resp.StatusCode != http.StatusOK { - conn.Close() //nolint:errcheck - return nil, fmt.Errorf("proxy CONNECT to %s: %s", addr, resp.Status) + err = fmt.Errorf("proxy CONNECT to %s: %s", addr, resp.Status) + return nil, err } return conn, nil } diff --git a/pkg/clients/postgresql/proxy.go b/pkg/clients/postgresql/proxy.go index b305232e..70e6c05a 100644 --- a/pkg/clients/postgresql/proxy.go +++ b/pkg/clients/postgresql/proxy.go @@ -25,30 +25,37 @@ func (d *httpProxyDialer) DialTimeout(network, addr string, timeout time.Duratio return tunnel(&net.Dialer{Timeout: timeout}, d.proxy, addr) } -func tunnel(nd *net.Dialer, proxy *url.URL, addr string) (net.Conn, error) { +func tunnel(nd *net.Dialer, proxy *url.URL, addr string) (_ net.Conn, err error) { conn, err := nd.Dial("tcp", proxy.Host) if err != nil { return nil, err } + defer func() { + if err != nil { + closeErr := conn.Close() + if closeErr != nil { + err = fmt.Errorf("error closing connection (%w) after: %w", closeErr, err) + } + } + }() req := &http.Request{ Method: http.MethodConnect, URL: &url.URL{Host: addr}, Host: addr, Header: http.Header{}, } - if err := req.Write(conn); err != nil { - conn.Close() //nolint:errcheck + if err = req.Write(conn); err != nil { return nil, err } - resp, err := http.ReadResponse(bufio.NewReader(conn), req) + var resp *http.Response + resp, err = http.ReadResponse(bufio.NewReader(conn), req) if err != nil { - conn.Close() //nolint:errcheck return nil, err } resp.Body.Close() //nolint:errcheck if resp.StatusCode != http.StatusOK { - conn.Close() //nolint:errcheck - return nil, fmt.Errorf("proxy CONNECT to %s: %s", addr, resp.Status) + err = fmt.Errorf("proxy CONNECT to %s: %s", addr, resp.Status) + return nil, err } return conn, nil }