diff --git a/pkg/clients/mssql/mssql.go b/pkg/clients/mssql/mssql.go index 1c0a5ac0..e6612b0c 100644 --- a/pkg/clients/mssql/mssql.go +++ b/pkg/clients/mssql/mssql.go @@ -77,7 +77,7 @@ func (c mssqlDB) ExecTx(_ context.Context, _ []xsql.Query) error { // Exec the supplied query. func (c mssqlDB) Exec(ctx context.Context, q xsql.Query) error { - d, err := sql.Open(driverName, c.dsn) + d, err := openDB(c.endpoint, c.port, c.dsn) if err != nil { return err } @@ -89,7 +89,7 @@ func (c mssqlDB) Exec(ctx context.Context, q xsql.Query) error { // Query the supplied query. func (c mssqlDB) Query(ctx context.Context, q xsql.Query) (*sql.Rows, error) { - d, err := sql.Open(driverName, c.dsn) + d, err := openDB(c.endpoint, c.port, c.dsn) if err != nil { return nil, err } @@ -100,7 +100,7 @@ func (c mssqlDB) Query(ctx context.Context, q xsql.Query) (*sql.Rows, error) { // Scan the results of the supplied query into the supplied destination. func (c mssqlDB) Scan(ctx context.Context, q xsql.Query, dest ...interface{}) error { - db, err := sql.Open(driverName, c.dsn) + db, err := openDB(c.endpoint, c.port, c.dsn) if err != nil { return err } diff --git a/pkg/clients/mssql/proxy.go b/pkg/clients/mssql/proxy.go new file mode 100644 index 00000000..1a7f7d04 --- /dev/null +++ b/pkg/clients/mssql/proxy.go @@ -0,0 +1,69 @@ +package mssql + +import ( + "bufio" + "context" + "database/sql" + "fmt" + "net" + "net/http" + "net/url" + + mssqldb "github.com/microsoft/go-mssqldb" +) + +type httpProxyDialer struct { + proxy *url.URL +} + +func (d *httpProxyDialer) DialContext(ctx context.Context, network, addr string) (_ net.Conn, err error) { + conn, err := (&net.Dialer{}).DialContext(ctx, "tcp", d.proxy.Host) + if err != nil { + return nil, err + } + defer func() { + if err != nil { + closeErr := conn.Close() + if closeErr != nil { + err = fmt.Errorf("error closing connection (%w) after: %w", closeErr, err) + } + } + }() + req := &http.Request{ + Method: http.MethodConnect, + URL: &url.URL{Host: addr}, + Host: addr, + Header: http.Header{}, + } + if err = req.Write(conn); err != nil { + return nil, err + } + var resp *http.Response + resp, err = http.ReadResponse(bufio.NewReader(conn), req) + if err != nil { + return nil, err + } + resp.Body.Close() //nolint:errcheck + if resp.StatusCode != http.StatusOK { + err = fmt.Errorf("proxy CONNECT to %s: %s", addr, resp.Status) + return nil, err + } + return conn, nil +} + +// openDB opens a *sql.DB for the given DSN. When HTTP_PROXY or HTTPS_PROXY is +// set and applicable to the target endpoint, connections are tunnelled through +// the proxy via HTTP CONNECT. +func openDB(endpoint, port, dsn string) (*sql.DB, error) { + req, _ := http.NewRequest(http.MethodConnect, "https://"+endpoint+":"+port, nil) + proxyURL, err := http.ProxyFromEnvironment(req) + if err != nil || proxyURL == nil { + return sql.Open(driverName, dsn) + } + connector, err := mssqldb.NewConnector(dsn) + if err != nil { + return nil, err + } + connector.Dialer = &httpProxyDialer{proxy: proxyURL} + return sql.OpenDB(connector), nil +} diff --git a/pkg/clients/mssql/proxy_test.go b/pkg/clients/mssql/proxy_test.go new file mode 100644 index 00000000..f44a7b70 --- /dev/null +++ b/pkg/clients/mssql/proxy_test.go @@ -0,0 +1,72 @@ +package mssql + +import ( + "context" + "fmt" + "net" + "net/url" + "testing" +) + +func TestHttpProxyDialerSuccess(t *testing.T) { + proxy, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer proxy.Close() //nolint:errcheck + + go func() { + conn, err := proxy.Accept() + if err != nil { + return + } + defer conn.Close() //nolint:errcheck + buf := make([]byte, 4096) + _, _ = conn.Read(buf) + _, _ = fmt.Fprint(conn, "HTTP/1.1 200 Connection established\r\n\r\n") + }() + + d := &httpProxyDialer{proxy: &url.URL{Host: proxy.Addr().String()}} + conn, err := d.DialContext(context.Background(), "tcp", "db.example.com:1433") + if err != nil { + t.Fatalf("DialContext() unexpected error: %v", err) + } + conn.Close() //nolint:errcheck +} + +func TestHttpProxyDialerRejected(t *testing.T) { + proxy, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer proxy.Close() //nolint:errcheck + + go func() { + conn, err := proxy.Accept() + if err != nil { + return + } + defer conn.Close() //nolint:errcheck + buf := make([]byte, 4096) + _, _ = conn.Read(buf) + _, _ = fmt.Fprint(conn, "HTTP/1.1 403 Forbidden\r\n\r\n") + }() + + d := &httpProxyDialer{proxy: &url.URL{Host: proxy.Addr().String()}} + _, err = d.DialContext(context.Background(), "tcp", "db.example.com:1433") + if err == nil { + t.Fatal("DialContext() expected error on non-200 response, got nil") + } +} + +func TestOpenDBNoProxy(t *testing.T) { + t.Setenv("HTTP_PROXY", "") + t.Setenv("HTTPS_PROXY", "") + t.Setenv("NO_PROXY", "*") + + db, err := openDB("localhost", "1433", "sqlserver://u:p@localhost:1433?database=db") + if err != nil { + t.Fatalf("openDB() unexpected error: %v", err) + } + db.Close() //nolint:errcheck +} diff --git a/pkg/clients/mysql/mysql.go b/pkg/clients/mysql/mysql.go index 1ce9c66f..9662f50e 100644 --- a/pkg/clients/mysql/mysql.go +++ b/pkg/clients/mysql/mysql.go @@ -75,7 +75,7 @@ func (c mySQLDB) ExecTx(ctx context.Context, ql []xsql.Query) error { // Exec the supplied query. func (c mySQLDB) Exec(ctx context.Context, q xsql.Query) error { - d, err := sql.Open("mysql", c.dsn) + d, err := openDB(c.endpoint, c.port, c.dsn) if err != nil { return err } @@ -87,7 +87,7 @@ func (c mySQLDB) Exec(ctx context.Context, q xsql.Query) error { // Query the supplied query. func (c mySQLDB) Query(ctx context.Context, q xsql.Query) (*sql.Rows, error) { - d, err := sql.Open("mysql", c.dsn) + d, err := openDB(c.endpoint, c.port, c.dsn) if err != nil { return nil, err } @@ -99,7 +99,7 @@ func (c mySQLDB) Query(ctx context.Context, q xsql.Query) (*sql.Rows, error) { // Scan the results of the supplied query into the supplied destination. func (c mySQLDB) Scan(ctx context.Context, q xsql.Query, dest ...interface{}) error { - db, err := sql.Open("mysql", c.dsn) + db, err := openDB(c.endpoint, c.port, c.dsn) if err != nil { return err } diff --git a/pkg/clients/mysql/proxy.go b/pkg/clients/mysql/proxy.go new file mode 100644 index 00000000..3b6e94a3 --- /dev/null +++ b/pkg/clients/mysql/proxy.go @@ -0,0 +1,71 @@ +package mysql + +import ( + "bufio" + "context" + "database/sql" + "fmt" + "net" + "net/http" + "net/url" + + gomysql "github.com/go-sql-driver/mysql" +) + +func tunnel(ctx context.Context, proxy *url.URL, addr string) (_ net.Conn, err error) { + conn, err := (&net.Dialer{}).DialContext(ctx, "tcp", proxy.Host) + if err != nil { + return nil, err + } + defer func() { + if err != nil { + closeErr := conn.Close() + if closeErr != nil { + err = fmt.Errorf("error closing connection (%w) after: %w", closeErr, err) + } + } + }() + req := &http.Request{ + Method: http.MethodConnect, + URL: &url.URL{Host: addr}, + Host: addr, + Header: http.Header{}, + } + if err = req.Write(conn); err != nil { + return nil, err + } + var resp *http.Response + resp, err = http.ReadResponse(bufio.NewReader(conn), req) + if err != nil { + return nil, err + } + resp.Body.Close() //nolint:errcheck + if resp.StatusCode != http.StatusOK { + err = fmt.Errorf("proxy CONNECT to %s: %s", addr, resp.Status) + return nil, err + } + return conn, nil +} + +// openDB opens a *sql.DB for the given DSN. When HTTP_PROXY or HTTPS_PROXY is +// set and applicable to the target endpoint, connections are tunnelled through +// the proxy via HTTP CONNECT. +func openDB(endpoint, port, dsn string) (*sql.DB, error) { + req, _ := http.NewRequest(http.MethodConnect, "https://"+endpoint+":"+port, nil) + proxyURL, err := http.ProxyFromEnvironment(req) + if err != nil || proxyURL == nil { + return sql.Open("mysql", dsn) + } + cfg, err := gomysql.ParseDSN(dsn) + if err != nil { + return nil, err + } + cfg.DialFunc = func(ctx context.Context, network, addr string) (net.Conn, error) { + return tunnel(ctx, proxyURL, addr) + } + connector, err := gomysql.NewConnector(cfg) + if err != nil { + return nil, err + } + return sql.OpenDB(connector), nil +} diff --git a/pkg/clients/mysql/proxy_test.go b/pkg/clients/mysql/proxy_test.go new file mode 100644 index 00000000..3ba6ac22 --- /dev/null +++ b/pkg/clients/mysql/proxy_test.go @@ -0,0 +1,72 @@ +package mysql + +import ( + "context" + "fmt" + "net" + "net/url" + "testing" +) + +func TestTunnelSuccess(t *testing.T) { + proxy, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer proxy.Close() //nolint:errcheck + + go func() { + conn, err := proxy.Accept() + if err != nil { + return + } + defer conn.Close() //nolint:errcheck + buf := make([]byte, 4096) + _, _ = conn.Read(buf) + _, _ = fmt.Fprint(conn, "HTTP/1.1 200 Connection established\r\n\r\n") + }() + + proxyURL := &url.URL{Host: proxy.Addr().String()} + conn, err := tunnel(context.Background(), proxyURL, "db.example.com:3306") + if err != nil { + t.Fatalf("tunnel() unexpected error: %v", err) + } + conn.Close() //nolint:errcheck +} + +func TestTunnelProxyRejected(t *testing.T) { + proxy, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer proxy.Close() //nolint:errcheck + + go func() { + conn, err := proxy.Accept() + if err != nil { + return + } + defer conn.Close() //nolint:errcheck + buf := make([]byte, 4096) + _, _ = conn.Read(buf) + _, _ = fmt.Fprint(conn, "HTTP/1.1 407 Proxy Authentication Required\r\n\r\n") + }() + + proxyURL := &url.URL{Host: proxy.Addr().String()} + _, err = tunnel(context.Background(), proxyURL, "db.example.com:3306") + if err == nil { + t.Fatal("tunnel() expected error on non-200 response, got nil") + } +} + +func TestOpenDBNoProxy(t *testing.T) { + t.Setenv("HTTP_PROXY", "") + t.Setenv("HTTPS_PROXY", "") + t.Setenv("NO_PROXY", "*") + + db, err := openDB("localhost", "3306", "u:p@tcp(localhost:3306)/db") + if err != nil { + t.Fatalf("openDB() unexpected error: %v", err) + } + db.Close() //nolint:errcheck +} diff --git a/pkg/clients/postgresql/postgresql.go b/pkg/clients/postgresql/postgresql.go index c48c7483..8db77cc7 100644 --- a/pkg/clients/postgresql/postgresql.go +++ b/pkg/clients/postgresql/postgresql.go @@ -64,7 +64,7 @@ func DSN(username, password, endpoint, port, database, sslmode string) string { // ExecTx executes an array of queries, committing if all are successful and // rolling back immediately on failure. func (c postgresDB) ExecTx(ctx context.Context, ql []xsql.Query) error { - d, err := sql.Open("postgres", c.dsn) + d, err := openDB(c.endpoint, c.port, c.dsn) if err != nil { return err } @@ -96,7 +96,7 @@ func (c postgresDB) ExecTx(ctx context.Context, ql []xsql.Query) error { // Exec the supplied query. func (c postgresDB) Exec(ctx context.Context, q xsql.Query) error { - d, err := sql.Open("postgres", c.dsn) + d, err := openDB(c.endpoint, c.port, c.dsn) if err != nil { return err } @@ -108,7 +108,7 @@ func (c postgresDB) Exec(ctx context.Context, q xsql.Query) error { // Query the supplied query. func (c postgresDB) Query(ctx context.Context, q xsql.Query) (*sql.Rows, error) { - d, err := sql.Open("postgres", c.dsn) + d, err := openDB(c.endpoint, c.port, c.dsn) if err != nil { return nil, err } @@ -120,7 +120,7 @@ func (c postgresDB) Query(ctx context.Context, q xsql.Query) (*sql.Rows, error) // Scan the results of the supplied query into the supplied destination. func (c postgresDB) Scan(ctx context.Context, q xsql.Query, dest ...interface{}) error { - db, err := sql.Open("postgres", c.dsn) + db, err := openDB(c.endpoint, c.port, c.dsn) if err != nil { return err } diff --git a/pkg/clients/postgresql/proxy.go b/pkg/clients/postgresql/proxy.go new file mode 100644 index 00000000..70e6c05a --- /dev/null +++ b/pkg/clients/postgresql/proxy.go @@ -0,0 +1,78 @@ +package postgresql + +import ( + "bufio" + "database/sql" + "fmt" + "net" + "net/http" + "net/url" + "time" + + "github.com/lib/pq" +) + +// httpProxyDialer tunnels connections through an HTTP CONNECT proxy. +type httpProxyDialer struct { + proxy *url.URL +} + +func (d *httpProxyDialer) Dial(network, addr string) (net.Conn, error) { + return tunnel(&net.Dialer{}, d.proxy, addr) +} + +func (d *httpProxyDialer) DialTimeout(network, addr string, timeout time.Duration) (net.Conn, error) { + return tunnel(&net.Dialer{Timeout: timeout}, d.proxy, addr) +} + +func tunnel(nd *net.Dialer, proxy *url.URL, addr string) (_ net.Conn, err error) { + conn, err := nd.Dial("tcp", proxy.Host) + if err != nil { + return nil, err + } + defer func() { + if err != nil { + closeErr := conn.Close() + if closeErr != nil { + err = fmt.Errorf("error closing connection (%w) after: %w", closeErr, err) + } + } + }() + req := &http.Request{ + Method: http.MethodConnect, + URL: &url.URL{Host: addr}, + Host: addr, + Header: http.Header{}, + } + if err = req.Write(conn); err != nil { + return nil, err + } + var resp *http.Response + resp, err = http.ReadResponse(bufio.NewReader(conn), req) + if err != nil { + return nil, err + } + resp.Body.Close() //nolint:errcheck + if resp.StatusCode != http.StatusOK { + err = fmt.Errorf("proxy CONNECT to %s: %s", addr, resp.Status) + return nil, err + } + return conn, nil +} + +// openDB opens a *sql.DB for the given DSN. When HTTP_PROXY or HTTPS_PROXY is +// set and applicable to the target endpoint, connections are tunnelled through +// the proxy via HTTP CONNECT. +func openDB(endpoint, port, dsn string) (*sql.DB, error) { + req, _ := http.NewRequest(http.MethodConnect, "https://"+endpoint+":"+port, nil) + proxyURL, err := http.ProxyFromEnvironment(req) + if err != nil || proxyURL == nil { + return sql.Open("postgres", dsn) + } + connector, err := pq.NewConnector(dsn) + if err != nil { + return nil, err + } + connector.Dialer(&httpProxyDialer{proxy: proxyURL}) + return sql.OpenDB(connector), nil +} diff --git a/pkg/clients/postgresql/proxy_test.go b/pkg/clients/postgresql/proxy_test.go new file mode 100644 index 00000000..628192ca --- /dev/null +++ b/pkg/clients/postgresql/proxy_test.go @@ -0,0 +1,71 @@ +package postgresql + +import ( + "fmt" + "net" + "net/url" + "testing" +) + +func TestTunnelSuccess(t *testing.T) { + proxy, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer proxy.Close() //nolint:errcheck + + go func() { + conn, err := proxy.Accept() + if err != nil { + return + } + defer conn.Close() //nolint:errcheck + buf := make([]byte, 4096) + _, _ = conn.Read(buf) + _, _ = fmt.Fprint(conn, "HTTP/1.1 200 Connection established\r\n\r\n") + }() + + proxyURL := &url.URL{Host: proxy.Addr().String()} + conn, err := tunnel(&net.Dialer{}, proxyURL, "db.example.com:5432") + if err != nil { + t.Fatalf("tunnel() unexpected error: %v", err) + } + conn.Close() //nolint:errcheck +} + +func TestTunnelProxyRejected(t *testing.T) { + proxy, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer proxy.Close() //nolint:errcheck + + go func() { + conn, err := proxy.Accept() + if err != nil { + return + } + defer conn.Close() //nolint:errcheck + buf := make([]byte, 4096) + _, _ = conn.Read(buf) + _, _ = fmt.Fprint(conn, "HTTP/1.1 403 Forbidden\r\n\r\n") + }() + + proxyURL := &url.URL{Host: proxy.Addr().String()} + _, err = tunnel(&net.Dialer{}, proxyURL, "db.example.com:5432") + if err == nil { + t.Fatal("tunnel() expected error on non-200 response, got nil") + } +} + +func TestOpenDBNoProxy(t *testing.T) { + t.Setenv("HTTP_PROXY", "") + t.Setenv("HTTPS_PROXY", "") + t.Setenv("NO_PROXY", "*") + + db, err := openDB("localhost", "5432", "postgres://u:p@localhost:5432/db?sslmode=disable") + if err != nil { + t.Fatalf("openDB() unexpected error: %v", err) + } + db.Close() //nolint:errcheck +}