From 2bb1195b242c282d3af12beb1a2cb56b8600aadb Mon Sep 17 00:00:00 2001 From: Herman Slatman Date: Thu, 27 Nov 2025 20:07:34 +0100 Subject: [PATCH 1/2] Remove some old Go version specific logic This fork doesn't aim at supporting very old Go versions, so the logic that existed to keep the implementation backwards compatible with those old Go versions was removed. --- conn.go | 34 +++++++++ conn_go110.go | 19 ------ conn_go115.go | 19 ------ conn_go19.go | 22 ------ connector.go | 2 - connector_test.go | 2 - driver.go | 23 ++++++- driver_go110.go | 23 ------- stmt.go | 21 ++++-- stmt_go19.go | 17 ----- stmt_go19_test.go | 171 ---------------------------------------------- stmt_test.go | 164 ++++++++++++++++++++++++++++++++++++++++++++ 12 files changed, 236 insertions(+), 281 deletions(-) delete mode 100755 conn_go110.go delete mode 100755 conn_go115.go delete mode 100755 conn_go19.go delete mode 100755 driver_go110.go delete mode 100755 stmt_go19.go delete mode 100644 stmt_go19_test.go diff --git a/conn.go b/conn.go index 46d13c8..c4033cf 100755 --- a/conn.go +++ b/conn.go @@ -20,6 +20,8 @@ var ( _ driver.Pinger = wrappedConn{} _ driver.Queryer = wrappedConn{} _ driver.QueryerContext = wrappedConn{} + _ driver.NamedValueChecker = wrappedConn{} + _ driver.SessionResetter = wrappedConn{} ) func (c wrappedConn) Prepare(query string) (driver.Stmt, error) { @@ -179,3 +181,35 @@ func (c wrappedParentConn) QueryContext(ctx context.Context, query string, args return c.Conn.(driver.Queryer).Query(query, dargs) } } + +func defaultCheckNamedValue(nv *driver.NamedValue) (err error) { + nv.Value, err = driver.DefaultParameterConverter.ConvertValue(nv.Value) + return err +} + +func (c wrappedConn) CheckNamedValue(v *driver.NamedValue) error { + if checker, ok := c.parent.(driver.NamedValueChecker); ok { + return checker.CheckNamedValue(v) + } + + return defaultCheckNamedValue(v) +} + +func (c wrappedConn) ResetSession(ctx context.Context) error { + conn, ok := c.parent.(driver.SessionResetter) + if !ok { + return nil + } + + return conn.ResetSession(ctx) +} + +func (c wrappedConn) IsValid() bool { + conn, ok := c.parent.(driver.Validator) + if !ok { + // the default if driver.Validator is not supported + return true + } + + return conn.IsValid() +} diff --git a/conn_go110.go b/conn_go110.go deleted file mode 100755 index b980d93..0000000 --- a/conn_go110.go +++ /dev/null @@ -1,19 +0,0 @@ -// +build go1.10 - -package sqlmw - -import ( - "context" - "database/sql/driver" -) - -var _ driver.SessionResetter = wrappedConn{} - -func (c wrappedConn) ResetSession(ctx context.Context) error { - conn, ok := c.parent.(driver.SessionResetter) - if !ok { - return nil - } - - return conn.ResetSession(ctx) -} diff --git a/conn_go115.go b/conn_go115.go deleted file mode 100755 index f635425..0000000 --- a/conn_go115.go +++ /dev/null @@ -1,19 +0,0 @@ -// +build go1.15 - -package sqlmw - -import ( - "database/sql/driver" -) - -var _ driver.SessionResetter = wrappedConn{} - -func (c wrappedConn) IsValid() bool { - conn, ok := c.parent.(driver.Validator) - if !ok { - // the default if driver.Validator is not supported - return true - } - - return conn.IsValid() -} diff --git a/conn_go19.go b/conn_go19.go deleted file mode 100755 index 268cabc..0000000 --- a/conn_go19.go +++ /dev/null @@ -1,22 +0,0 @@ -// +build go1.9 - -package sqlmw - -import "database/sql/driver" - -var ( - _ driver.NamedValueChecker = wrappedConn{} -) - -func defaultCheckNamedValue(nv *driver.NamedValue) (err error) { - nv.Value, err = driver.DefaultParameterConverter.ConvertValue(nv.Value) - return err -} - -func (c wrappedConn) CheckNamedValue(v *driver.NamedValue) error { - if checker, ok := c.parent.(driver.NamedValueChecker); ok { - return checker.CheckNamedValue(v) - } - - return defaultCheckNamedValue(v) -} diff --git a/connector.go b/connector.go index 4740a0d..86a3d57 100755 --- a/connector.go +++ b/connector.go @@ -1,5 +1,3 @@ -// +build go1.10 - package sqlmw import ( diff --git a/connector_test.go b/connector_test.go index fab236f..cbacb69 100755 --- a/connector_test.go +++ b/connector_test.go @@ -1,5 +1,3 @@ -// +build go1.10 - package sqlmw import ( diff --git a/driver.go b/driver.go index 254972e..65a7247 100755 --- a/driver.go +++ b/driver.go @@ -2,7 +2,7 @@ package sqlmw import "database/sql/driver" -// driver wraps a sql.Driver with an interceptor. +// wrappedDriver wraps a sql.Driver with an interceptor. type wrappedDriver struct { intr Interceptor parent driver.Driver @@ -10,7 +10,8 @@ type wrappedDriver struct { // Compile time validation that our types implement the expected interfaces var ( - _ driver.Driver = wrappedDriver{} + _ driver.Driver = wrappedDriver{} + _ driver.DriverContext = wrappedDriver{} ) // WrapDriver will wrap the passed SQL driver and return a new sql driver that uses it and also logs and traces calls using the passed logger and tracer @@ -36,3 +37,21 @@ func (d wrappedDriver) Open(name string) (driver.Conn, error) { return wrappedConn{intr: d.intr, parent: conn}, nil } + +// OpenConnector implements the database/sql/driver.Driver interface for WrappedDriver. +func (d wrappedDriver) OpenConnector(name string) (driver.Connector, error) { + driver, ok := d.parent.(driver.DriverContext) + if !ok { + return wrappedConnector{ + parent: dsnConnector{dsn: name, driver: d.parent}, + driverRef: &d, + }, nil + } + + conn, err := driver.OpenConnector(name) + if err != nil { + return nil, err + } + + return wrappedConnector{parent: conn, driverRef: &d}, nil +} diff --git a/driver_go110.go b/driver_go110.go deleted file mode 100755 index 0a2c0a3..0000000 --- a/driver_go110.go +++ /dev/null @@ -1,23 +0,0 @@ -// +build go1.10 - -package sqlmw - -import "database/sql/driver" - -var _ driver.DriverContext = wrappedDriver{} - -func (d wrappedDriver) OpenConnector(name string) (driver.Connector, error) { - driver, ok := d.parent.(driver.DriverContext) - if !ok { - return wrappedConnector{ - parent: dsnConnector{dsn: name, driver: d.parent}, - driverRef: &d, - }, nil - } - conn, err := driver.OpenConnector(name) - if err != nil { - return nil, err - } - - return wrappedConnector{parent: conn, driverRef: &d}, nil -} diff --git a/stmt.go b/stmt.go index f2d7e2f..2c91411 100755 --- a/stmt.go +++ b/stmt.go @@ -15,10 +15,11 @@ type wrappedStmt struct { // Compile time validation that our types implement the expected interfaces var ( - _ driver.Stmt = wrappedStmt{} - _ driver.StmtExecContext = wrappedStmt{} - _ driver.StmtQueryContext = wrappedStmt{} - _ driver.ColumnConverter = wrappedStmt{} + _ driver.Stmt = wrappedStmt{} + _ driver.StmtExecContext = wrappedStmt{} + _ driver.StmtQueryContext = wrappedStmt{} + _ driver.ColumnConverter = wrappedStmt{} + _ driver.NamedValueChecker = wrappedStmt{} ) func (s wrappedStmt) Close() (err error) { @@ -108,3 +109,15 @@ func (s wrappedParentStmt) ExecContext(ctx context.Context, args []driver.NamedV } return s.Exec(dargs) } + +func (s wrappedStmt) CheckNamedValue(v *driver.NamedValue) error { + if checker, ok := s.parent.(driver.NamedValueChecker); ok { + return checker.CheckNamedValue(v) + } + + if checker, ok := s.conn.parent.(driver.NamedValueChecker); ok { + return checker.CheckNamedValue(v) + } + + return defaultCheckNamedValue(v) +} diff --git a/stmt_go19.go b/stmt_go19.go deleted file mode 100755 index 8442101..0000000 --- a/stmt_go19.go +++ /dev/null @@ -1,17 +0,0 @@ -package sqlmw - -import "database/sql/driver" - -var _ driver.NamedValueChecker = wrappedStmt{} - -func (s wrappedStmt) CheckNamedValue(v *driver.NamedValue) error { - if checker, ok := s.parent.(driver.NamedValueChecker); ok { - return checker.CheckNamedValue(v) - } - - if checker, ok := s.conn.parent.(driver.NamedValueChecker); ok { - return checker.CheckNamedValue(v) - } - - return defaultCheckNamedValue(v) -} diff --git a/stmt_go19_test.go b/stmt_go19_test.go deleted file mode 100644 index 8f58153..0000000 --- a/stmt_go19_test.go +++ /dev/null @@ -1,171 +0,0 @@ -package sqlmw - -import ( - "database/sql" - "database/sql/driver" - "reflect" - "testing" -) - -// TestDefaultParameterConversion ensures that -// driver.DefaultParameterConverter is used when neither stmt nor con -// implements any value converters. -func TestDefaultParameterConversion(t *testing.T) { - driverName := driverName(t) - - expectVal := int64(1) - con := &fakeConn{} - fakeStmt := &fakeStmt{ - rows: &fakeRows{ - con: con, - vals: [][]driver.Value{{expectVal}}, - }, - } - con.stmt = fakeStmt - - sql.Register( - driverName, - Driver(&fakeDriver{conn: con}, &NullInterceptor{}), - ) - - db, err := sql.Open(driverName, "") - if err != nil { - t.Fatalf("Failed to open: %v", err) - } - - t.Cleanup(func() { - if err := db.Close(); err != nil { - t.Errorf("Failed to close db: %v", err) - } - }) - - stmt, err := db.Prepare("") - if err != nil { - t.Fatalf("Prepare failed: %s", err) - } - - // int32 values are converted by driver.DefaultParameterConverter to - // int64 - queryVal := int32(1) - rows, err := stmt.Query(queryVal) - if err != nil { - t.Fatalf("Query failed: %s", err) - } - - count := 0 - for rows.Next() { - var v int64 - err := rows.Scan(&v) - if err != nil { - t.Fatalf("rows.Scan failed, %v", err) - } - if v != 1 { - t.Errorf("converted value is %d, passed value to Query was: %d", v, expectVal) - } - count++ - } - - if count != 1 { - t.Fatalf("got too many rows, expected 1, got %d ", 1) - } -} - -func TestWrappedStmt_CheckNamedValue(t *testing.T) { - tests := map[string]struct { - fd *fakeDriver - expected struct { - cc bool // Whether the fakeConn's CheckNamedValue was called - sc bool // Whether the fakeStmt's CheckNamedValue was called - } - }{ - "When both conn and stmt implement CheckNamedValue": { - fd: &fakeDriver{ - conn: &fakeConnWithCheckNamedValue{ - fakeConn: fakeConn{ - stmt: &fakeStmtWithCheckNamedValue{}, - }, - }, - }, - expected: struct { - cc bool - sc bool - }{cc: false, sc: true}, - }, - "When only conn implements CheckNamedValue": { - fd: &fakeDriver{ - conn: &fakeConnWithCheckNamedValue{ - fakeConn: fakeConn{ - stmt: &fakeStmtWithoutCheckNamedValue{}, - }, - }, - }, - expected: struct { - cc bool - sc bool - }{cc: true, sc: false}, - }, - "When only stmt implements CheckNamedValue": { - fd: &fakeDriver{ - conn: &fakeConnWithoutCheckNamedValue{ - fakeConn: fakeConn{ - stmt: &fakeStmtWithCheckNamedValue{}, - }, - }, - }, - expected: struct { - cc bool - sc bool - }{cc: false, sc: true}, - }, - "When both stmt do not implement CheckNamedValue": { - fd: &fakeDriver{ - conn: &fakeConnWithoutCheckNamedValue{ - fakeConn: fakeConn{ - stmt: &fakeStmtWithoutCheckNamedValue{}, - }, - }, - }, - expected: struct { - cc bool - sc bool - }{cc: false, sc: false}, - }, - } - - for name, test := range tests { - t.Run(name, func(t *testing.T) { - driverName := driverName(t) - sql.Register(driverName, Driver(test.fd, &fakeInterceptor{})) - db, err := sql.Open(driverName, "dummy") - if err != nil { - t.Errorf("Failed to open: %v", err) - } - defer func() { - if err := db.Close(); err != nil { - t.Errorf("Failed to close db: %v", err) - } - }() - - stmt, err := db.Prepare("SELECT foo FROM bar Where 1 = ?") - if err != nil { - t.Errorf("Failed to prepare: %v", err) - } - - if _, err := stmt.Query(1); err != nil { - t.Errorf("Failed to query: %v", err) - } - - conn := reflect.ValueOf(test.fd.conn).Elem() - sc := conn.FieldByName("stmt").Elem().Elem().FieldByName("called").Bool() - cc := conn.FieldByName("called").Bool() - - if test.expected.sc != sc { - t.Errorf("sc mismatch.\n got: %#v\nwant: %#v", sc, test.expected.sc) - } - - if test.expected.cc != cc { - t.Errorf("cc mismatch.\n got: %#v\nwant: %#v", cc, test.expected.cc) - } - }) - } -} diff --git a/stmt_test.go b/stmt_test.go index 9d60083..02afa75 100644 --- a/stmt_test.go +++ b/stmt_test.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "database/sql/driver" + "reflect" "testing" ) @@ -99,3 +100,166 @@ func TestStmtQueryContext_PassWrappedRowContext(t *testing.T) { t.Error("RowsClose context not valid") } } + +// TestDefaultParameterConversion ensures that +// driver.DefaultParameterConverter is used when neither stmt nor con +// implements any value converters. +func TestDefaultParameterConversion(t *testing.T) { + driverName := driverName(t) + + expectVal := int64(1) + con := &fakeConn{} + fakeStmt := &fakeStmt{ + rows: &fakeRows{ + con: con, + vals: [][]driver.Value{{expectVal}}, + }, + } + con.stmt = fakeStmt + + sql.Register( + driverName, + Driver(&fakeDriver{conn: con}, &NullInterceptor{}), + ) + + db, err := sql.Open(driverName, "") + if err != nil { + t.Fatalf("Failed to open: %v", err) + } + + t.Cleanup(func() { + if err := db.Close(); err != nil { + t.Errorf("Failed to close db: %v", err) + } + }) + + stmt, err := db.Prepare("") + if err != nil { + t.Fatalf("Prepare failed: %s", err) + } + + // int32 values are converted by driver.DefaultParameterConverter to + // int64 + queryVal := int32(1) + rows, err := stmt.Query(queryVal) + if err != nil { + t.Fatalf("Query failed: %s", err) + } + + count := 0 + for rows.Next() { + var v int64 + err := rows.Scan(&v) + if err != nil { + t.Fatalf("rows.Scan failed, %v", err) + } + if v != 1 { + t.Errorf("converted value is %d, passed value to Query was: %d", v, expectVal) + } + count++ + } + + if count != 1 { + t.Fatalf("got too many rows, expected 1, got %d ", 1) + } +} + +func TestWrappedStmt_CheckNamedValue(t *testing.T) { + tests := map[string]struct { + fd *fakeDriver + expected struct { + cc bool // Whether the fakeConn's CheckNamedValue was called + sc bool // Whether the fakeStmt's CheckNamedValue was called + } + }{ + "When both conn and stmt implement CheckNamedValue": { + fd: &fakeDriver{ + conn: &fakeConnWithCheckNamedValue{ + fakeConn: fakeConn{ + stmt: &fakeStmtWithCheckNamedValue{}, + }, + }, + }, + expected: struct { + cc bool + sc bool + }{cc: false, sc: true}, + }, + "When only conn implements CheckNamedValue": { + fd: &fakeDriver{ + conn: &fakeConnWithCheckNamedValue{ + fakeConn: fakeConn{ + stmt: &fakeStmtWithoutCheckNamedValue{}, + }, + }, + }, + expected: struct { + cc bool + sc bool + }{cc: true, sc: false}, + }, + "When only stmt implements CheckNamedValue": { + fd: &fakeDriver{ + conn: &fakeConnWithoutCheckNamedValue{ + fakeConn: fakeConn{ + stmt: &fakeStmtWithCheckNamedValue{}, + }, + }, + }, + expected: struct { + cc bool + sc bool + }{cc: false, sc: true}, + }, + "When both stmt do not implement CheckNamedValue": { + fd: &fakeDriver{ + conn: &fakeConnWithoutCheckNamedValue{ + fakeConn: fakeConn{ + stmt: &fakeStmtWithoutCheckNamedValue{}, + }, + }, + }, + expected: struct { + cc bool + sc bool + }{cc: false, sc: false}, + }, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + driverName := driverName(t) + sql.Register(driverName, Driver(test.fd, &fakeInterceptor{})) + db, err := sql.Open(driverName, "dummy") + if err != nil { + t.Errorf("Failed to open: %v", err) + } + defer func() { + if err := db.Close(); err != nil { + t.Errorf("Failed to close db: %v", err) + } + }() + + stmt, err := db.Prepare("SELECT foo FROM bar Where 1 = ?") + if err != nil { + t.Errorf("Failed to prepare: %v", err) + } + + if _, err := stmt.Query(1); err != nil { + t.Errorf("Failed to query: %v", err) + } + + conn := reflect.ValueOf(test.fd.conn).Elem() + sc := conn.FieldByName("stmt").Elem().Elem().FieldByName("called").Bool() + cc := conn.FieldByName("called").Bool() + + if test.expected.sc != sc { + t.Errorf("sc mismatch.\n got: %#v\nwant: %#v", sc, test.expected.sc) + } + + if test.expected.cc != cc { + t.Errorf("cc mismatch.\n got: %#v\nwant: %#v", cc, test.expected.cc) + } + }) + } +} From 54f0fda0d74bcc4e1c3ef0ac8ee3b2a679456d17 Mon Sep 17 00:00:00 2001 From: Herman Slatman Date: Fri, 28 Nov 2025 14:14:02 +0100 Subject: [PATCH 2/2] Use new `go:build` and add `test` target to `mage` --- .github/workflows/go.yml | 5 ++++- fakedb_test.go | 12 ++++-------- magefiles/mage.go | 6 ++++++ rows_test.go | 4 ++-- tools/rows_picker_gen.go | 2 +- 5 files changed, 17 insertions(+), 12 deletions(-) diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 99db15b..c2df737 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -32,7 +32,10 @@ jobs: go-version: 1.24 # TODO: matrix - name: Test - run: go test -v -race ./... + uses: magefile/mage-action@6f50bbb8ea47d56e62dee92392788acbc8192d0b # v3.1.0 + with: + version: latest + args: test check: runs-on: ubuntu-latest diff --git a/fakedb_test.go b/fakedb_test.go index bf261a6..18d509f 100644 --- a/fakedb_test.go +++ b/fakedb_test.go @@ -194,6 +194,7 @@ type fakeConn struct { called bool // nolint:structcheck // ignore unused warning, it is accessed via reflection rowsCloseCalled bool stmt driver.Stmt + queryStmt driver.StmtQueryContext tx driver.Tx } @@ -221,17 +222,12 @@ func (c *fakeConn) Close() error { return nil } func (c *fakeConn) Begin() (driver.Tx, error) { return c.tx, nil } -func (c *fakeConn) QueryContext(_ context.Context, _ string, nvs []driver.NamedValue) (driver.Rows, error) { - if c.stmt == nil { +func (c *fakeConn) QueryContext(ctx context.Context, _ string, nvs []driver.NamedValue) (driver.Rows, error) { + if c.queryStmt == nil { return &fakeRows{con: c}, nil } - var args []driver.Value - for _, nv := range nvs { - args = append(args, nv.Value) - } - - return c.stmt.Query(args) + return c.queryStmt.QueryContext(ctx, nvs) } func (c *fakeConnWithCheckNamedValue) CheckNamedValue(_ *driver.NamedValue) (err error) { diff --git a/magefiles/mage.go b/magefiles/mage.go index 091bb68..72bcf00 100644 --- a/magefiles/mage.go +++ b/magefiles/mage.go @@ -14,3 +14,9 @@ func Lint(ctx context.Context) error { args := []string{"run", "--config", ".golangci.yml"} return sh.RunV("golangci-lint", args...) } + +// Test runs the tests +func Test(ctx context.Context) error { + args := []string{"test", "-v", "-race", "./..."} + return sh.RunV("go", args...) +} diff --git a/rows_test.go b/rows_test.go index 7b9243c..defb142 100644 --- a/rows_test.go +++ b/rows_test.go @@ -109,7 +109,7 @@ func TestRowsNext(t *testing.T) { stmt := fakeStmt{ rows: rows, } - con.stmt = stmt + con.queryStmt = stmt driverName := driverName(t) interceptor := rowsNextInterceptor{} @@ -199,7 +199,7 @@ func TestRows_LikePGX(t *testing.T) { stmt := fakeStmt{ rows: rows, } - con.stmt = stmt + con.queryStmt = stmt driverName := driverName(t) interceptor := rowsNextInterceptor{} diff --git a/tools/rows_picker_gen.go b/tools/rows_picker_gen.go index c02f99c..c4b036b 100755 --- a/tools/rows_picker_gen.go +++ b/tools/rows_picker_gen.go @@ -1,4 +1,4 @@ -// +build ignore +//go:build ignore package main