Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions pkg/clients/mssql/mssql.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ func (c mssqlDB) ExecTx(_ context.Context, _ []xsql.Query) error {

// Exec the supplied query.
func (c mssqlDB) Exec(ctx context.Context, q xsql.Query) error {
d, err := sql.Open(driverName, c.dsn)
d, err := openDB(c.endpoint, c.port, c.dsn)
if err != nil {
return err
}
Expand All @@ -89,7 +89,7 @@ func (c mssqlDB) Exec(ctx context.Context, q xsql.Query) error {

// Query the supplied query.
func (c mssqlDB) Query(ctx context.Context, q xsql.Query) (*sql.Rows, error) {
d, err := sql.Open(driverName, c.dsn)
d, err := openDB(c.endpoint, c.port, c.dsn)
if err != nil {
return nil, err
}
Expand All @@ -100,7 +100,7 @@ func (c mssqlDB) Query(ctx context.Context, q xsql.Query) (*sql.Rows, error) {

// Scan the results of the supplied query into the supplied destination.
func (c mssqlDB) Scan(ctx context.Context, q xsql.Query, dest ...interface{}) error {
db, err := sql.Open(driverName, c.dsn)
db, err := openDB(c.endpoint, c.port, c.dsn)
if err != nil {
return err
}
Expand Down
69 changes: 69 additions & 0 deletions pkg/clients/mssql/proxy.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package mssql

import (
"bufio"
"context"
"database/sql"
"fmt"
"net"
"net/http"
"net/url"

mssqldb "github.com/microsoft/go-mssqldb"
)

type httpProxyDialer struct {
proxy *url.URL
}

func (d *httpProxyDialer) DialContext(ctx context.Context, network, addr string) (_ net.Conn, err error) {
conn, err := (&net.Dialer{}).DialContext(ctx, "tcp", d.proxy.Host)
if err != nil {
return nil, err
}
Comment thread
jaymiracola marked this conversation as resolved.
defer func() {
if err != nil {
closeErr := conn.Close()
if closeErr != nil {
err = fmt.Errorf("error closing connection (%w) after: %w", closeErr, err)
}
}
}()
req := &http.Request{
Method: http.MethodConnect,
URL: &url.URL{Host: addr},
Host: addr,
Header: http.Header{},
}
if err = req.Write(conn); err != nil {
return nil, err
}
var resp *http.Response
resp, err = http.ReadResponse(bufio.NewReader(conn), req)
Copy link
Copy Markdown
Collaborator

@chlunde chlunde May 15, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this will possibly hijack bytes from the database server - as NewReader does a 4 KB read and that could be some of the bytes not just from the proxy but the sql server banner data?

if err != nil {
return nil, err
}
resp.Body.Close() //nolint:errcheck
if resp.StatusCode != http.StatusOK {
err = fmt.Errorf("proxy CONNECT to %s: %s", addr, resp.Status)
return nil, err
}
return conn, nil
}

// openDB opens a *sql.DB for the given DSN. When HTTP_PROXY or HTTPS_PROXY is
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does it respect HTTP_PROXY? you construct a https URL below?

// set and applicable to the target endpoint, connections are tunnelled through
// the proxy via HTTP CONNECT.
func openDB(endpoint, port, dsn string) (*sql.DB, error) {
req, _ := http.NewRequest(http.MethodConnect, "https://"+endpoint+":"+port, nil)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will this panic on invalid URL?

Copy link
Copy Markdown
Collaborator

@chlunde chlunde May 15, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You construct a &http.Request{ manually in the code the other place - should probably use the same here? this would avoid discarding the error too

Comment thread
chlunde marked this conversation as resolved.
proxyURL, err := http.ProxyFromEnvironment(req)
if err != nil || proxyURL == nil {
return sql.Open(driverName, dsn)
}
connector, err := mssqldb.NewConnector(dsn)
if err != nil {
return nil, err
}
connector.Dialer = &httpProxyDialer{proxy: proxyURL}
return sql.OpenDB(connector), nil
}
72 changes: 72 additions & 0 deletions pkg/clients/mssql/proxy_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package mssql

import (
"context"
"fmt"
"net"
"net/url"
"testing"
)

func TestHttpProxyDialerSuccess(t *testing.T) {
proxy, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
defer proxy.Close() //nolint:errcheck

go func() {
conn, err := proxy.Accept()
if err != nil {
return
}
defer conn.Close() //nolint:errcheck
buf := make([]byte, 4096)
_, _ = conn.Read(buf)
_, _ = fmt.Fprint(conn, "HTTP/1.1 200 Connection established\r\n\r\n")
}()

d := &httpProxyDialer{proxy: &url.URL{Host: proxy.Addr().String()}}
conn, err := d.DialContext(context.Background(), "tcp", "db.example.com:1433")
if err != nil {
t.Fatalf("DialContext() unexpected error: %v", err)
}
conn.Close() //nolint:errcheck
}

func TestHttpProxyDialerRejected(t *testing.T) {
proxy, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
defer proxy.Close() //nolint:errcheck

go func() {
conn, err := proxy.Accept()
if err != nil {
return
}
defer conn.Close() //nolint:errcheck
buf := make([]byte, 4096)
_, _ = conn.Read(buf)
_, _ = fmt.Fprint(conn, "HTTP/1.1 403 Forbidden\r\n\r\n")
}()

d := &httpProxyDialer{proxy: &url.URL{Host: proxy.Addr().String()}}
_, err = d.DialContext(context.Background(), "tcp", "db.example.com:1433")
if err == nil {
t.Fatal("DialContext() expected error on non-200 response, got nil")
}
}

func TestOpenDBNoProxy(t *testing.T) {
t.Setenv("HTTP_PROXY", "")
t.Setenv("HTTPS_PROXY", "")
t.Setenv("NO_PROXY", "*")

db, err := openDB("localhost", "1433", "sqlserver://u:p@localhost:1433?database=db")
if err != nil {
t.Fatalf("openDB() unexpected error: %v", err)
}
db.Close() //nolint:errcheck
}
6 changes: 3 additions & 3 deletions pkg/clients/mysql/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ func (c mySQLDB) ExecTx(ctx context.Context, ql []xsql.Query) error {

// Exec the supplied query.
func (c mySQLDB) Exec(ctx context.Context, q xsql.Query) error {
d, err := sql.Open("mysql", c.dsn)
d, err := openDB(c.endpoint, c.port, c.dsn)
if err != nil {
return err
}
Expand All @@ -87,7 +87,7 @@ func (c mySQLDB) Exec(ctx context.Context, q xsql.Query) error {

// Query the supplied query.
func (c mySQLDB) Query(ctx context.Context, q xsql.Query) (*sql.Rows, error) {
d, err := sql.Open("mysql", c.dsn)
d, err := openDB(c.endpoint, c.port, c.dsn)
if err != nil {
return nil, err
}
Expand All @@ -99,7 +99,7 @@ func (c mySQLDB) Query(ctx context.Context, q xsql.Query) (*sql.Rows, error) {

// Scan the results of the supplied query into the supplied destination.
func (c mySQLDB) Scan(ctx context.Context, q xsql.Query, dest ...interface{}) error {
db, err := sql.Open("mysql", c.dsn)
db, err := openDB(c.endpoint, c.port, c.dsn)
if err != nil {
return err
}
Expand Down
71 changes: 71 additions & 0 deletions pkg/clients/mysql/proxy.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package mysql

import (
"bufio"
"context"
"database/sql"
"fmt"
"net"
"net/http"
"net/url"

gomysql "github.com/go-sql-driver/mysql"
)

func tunnel(ctx context.Context, proxy *url.URL, addr string) (_ net.Conn, err error) {
conn, err := (&net.Dialer{}).DialContext(ctx, "tcp", proxy.Host)
if err != nil {
return nil, err
}
defer func() {
if err != nil {
closeErr := conn.Close()
if closeErr != nil {
err = fmt.Errorf("error closing connection (%w) after: %w", closeErr, err)
}
}
}()
req := &http.Request{
Method: http.MethodConnect,
URL: &url.URL{Host: addr},
Host: addr,
Header: http.Header{},
}
if err = req.Write(conn); err != nil {
return nil, err
}
var resp *http.Response
resp, err = http.ReadResponse(bufio.NewReader(conn), req)
if err != nil {
return nil, err
}
resp.Body.Close() //nolint:errcheck
if resp.StatusCode != http.StatusOK {
err = fmt.Errorf("proxy CONNECT to %s: %s", addr, resp.Status)
return nil, err
}
return conn, nil
}

// openDB opens a *sql.DB for the given DSN. When HTTP_PROXY or HTTPS_PROXY is
// set and applicable to the target endpoint, connections are tunnelled through
// the proxy via HTTP CONNECT.
func openDB(endpoint, port, dsn string) (*sql.DB, error) {
req, _ := http.NewRequest(http.MethodConnect, "https://"+endpoint+":"+port, nil)
proxyURL, err := http.ProxyFromEnvironment(req)
if err != nil || proxyURL == nil {
return sql.Open("mysql", dsn)
}
cfg, err := gomysql.ParseDSN(dsn)
if err != nil {
return nil, err
}
cfg.DialFunc = func(ctx context.Context, network, addr string) (net.Conn, error) {
return tunnel(ctx, proxyURL, addr)
}
connector, err := gomysql.NewConnector(cfg)
if err != nil {
return nil, err
}
return sql.OpenDB(connector), nil
}
72 changes: 72 additions & 0 deletions pkg/clients/mysql/proxy_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package mysql

import (
"context"
"fmt"
"net"
"net/url"
"testing"
)

func TestTunnelSuccess(t *testing.T) {
proxy, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
defer proxy.Close() //nolint:errcheck

go func() {
conn, err := proxy.Accept()
if err != nil {
return
}
defer conn.Close() //nolint:errcheck
buf := make([]byte, 4096)
_, _ = conn.Read(buf)
_, _ = fmt.Fprint(conn, "HTTP/1.1 200 Connection established\r\n\r\n")
}()

proxyURL := &url.URL{Host: proxy.Addr().String()}
conn, err := tunnel(context.Background(), proxyURL, "db.example.com:3306")
if err != nil {
t.Fatalf("tunnel() unexpected error: %v", err)
}
conn.Close() //nolint:errcheck
}

func TestTunnelProxyRejected(t *testing.T) {
proxy, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
defer proxy.Close() //nolint:errcheck

go func() {
conn, err := proxy.Accept()
if err != nil {
return
}
defer conn.Close() //nolint:errcheck
buf := make([]byte, 4096)
_, _ = conn.Read(buf)
_, _ = fmt.Fprint(conn, "HTTP/1.1 407 Proxy Authentication Required\r\n\r\n")
}()

proxyURL := &url.URL{Host: proxy.Addr().String()}
_, err = tunnel(context.Background(), proxyURL, "db.example.com:3306")
if err == nil {
t.Fatal("tunnel() expected error on non-200 response, got nil")
}
}

func TestOpenDBNoProxy(t *testing.T) {
t.Setenv("HTTP_PROXY", "")
t.Setenv("HTTPS_PROXY", "")
t.Setenv("NO_PROXY", "*")

db, err := openDB("localhost", "3306", "u:p@tcp(localhost:3306)/db")
if err != nil {
t.Fatalf("openDB() unexpected error: %v", err)
}
db.Close() //nolint:errcheck
}
8 changes: 4 additions & 4 deletions pkg/clients/postgresql/postgresql.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ func DSN(username, password, endpoint, port, database, sslmode string) string {
// ExecTx executes an array of queries, committing if all are successful and
// rolling back immediately on failure.
func (c postgresDB) ExecTx(ctx context.Context, ql []xsql.Query) error {
d, err := sql.Open("postgres", c.dsn)
d, err := openDB(c.endpoint, c.port, c.dsn)
if err != nil {
return err
}
Expand Down Expand Up @@ -96,7 +96,7 @@ func (c postgresDB) ExecTx(ctx context.Context, ql []xsql.Query) error {

// Exec the supplied query.
func (c postgresDB) Exec(ctx context.Context, q xsql.Query) error {
d, err := sql.Open("postgres", c.dsn)
d, err := openDB(c.endpoint, c.port, c.dsn)
if err != nil {
return err
}
Expand All @@ -108,7 +108,7 @@ func (c postgresDB) Exec(ctx context.Context, q xsql.Query) error {

// Query the supplied query.
func (c postgresDB) Query(ctx context.Context, q xsql.Query) (*sql.Rows, error) {
d, err := sql.Open("postgres", c.dsn)
d, err := openDB(c.endpoint, c.port, c.dsn)
if err != nil {
return nil, err
}
Expand All @@ -120,7 +120,7 @@ func (c postgresDB) Query(ctx context.Context, q xsql.Query) (*sql.Rows, error)

// Scan the results of the supplied query into the supplied destination.
func (c postgresDB) Scan(ctx context.Context, q xsql.Query, dest ...interface{}) error {
db, err := sql.Open("postgres", c.dsn)
db, err := openDB(c.endpoint, c.port, c.dsn)
if err != nil {
return err
}
Expand Down
Loading
Loading