diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 99db15b..c2df737 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -32,7 +32,10 @@ jobs: go-version: 1.24 # TODO: matrix - name: Test - run: go test -v -race ./... + uses: magefile/mage-action@6f50bbb8ea47d56e62dee92392788acbc8192d0b # v3.1.0 + with: + version: latest + args: test check: runs-on: ubuntu-latest diff --git a/conn.go b/conn.go index 46d13c8..c4033cf 100755 --- a/conn.go +++ b/conn.go @@ -20,6 +20,8 @@ var ( _ driver.Pinger = wrappedConn{} _ driver.Queryer = wrappedConn{} _ driver.QueryerContext = wrappedConn{} + _ driver.NamedValueChecker = wrappedConn{} + _ driver.SessionResetter = wrappedConn{} ) func (c wrappedConn) Prepare(query string) (driver.Stmt, error) { @@ -179,3 +181,35 @@ func (c wrappedParentConn) QueryContext(ctx context.Context, query string, args return c.Conn.(driver.Queryer).Query(query, dargs) } } + +func defaultCheckNamedValue(nv *driver.NamedValue) (err error) { + nv.Value, err = driver.DefaultParameterConverter.ConvertValue(nv.Value) + return err +} + +func (c wrappedConn) CheckNamedValue(v *driver.NamedValue) error { + if checker, ok := c.parent.(driver.NamedValueChecker); ok { + return checker.CheckNamedValue(v) + } + + return defaultCheckNamedValue(v) +} + +func (c wrappedConn) ResetSession(ctx context.Context) error { + conn, ok := c.parent.(driver.SessionResetter) + if !ok { + return nil + } + + return conn.ResetSession(ctx) +} + +func (c wrappedConn) IsValid() bool { + conn, ok := c.parent.(driver.Validator) + if !ok { + // the default if driver.Validator is not supported + return true + } + + return conn.IsValid() +} diff --git a/conn_go110.go b/conn_go110.go deleted file mode 100755 index b980d93..0000000 --- a/conn_go110.go +++ /dev/null @@ -1,19 +0,0 @@ -// +build go1.10 - -package sqlmw - -import ( - "context" - "database/sql/driver" -) - -var _ driver.SessionResetter = wrappedConn{} - -func (c wrappedConn) ResetSession(ctx context.Context) error { - conn, ok := c.parent.(driver.SessionResetter) - if !ok { - return nil - } - - return conn.ResetSession(ctx) -} diff --git a/conn_go115.go b/conn_go115.go deleted file mode 100755 index f635425..0000000 --- a/conn_go115.go +++ /dev/null @@ -1,19 +0,0 @@ -// +build go1.15 - -package sqlmw - -import ( - "database/sql/driver" -) - -var _ driver.SessionResetter = wrappedConn{} - -func (c wrappedConn) IsValid() bool { - conn, ok := c.parent.(driver.Validator) - if !ok { - // the default if driver.Validator is not supported - return true - } - - return conn.IsValid() -} diff --git a/conn_go19.go b/conn_go19.go deleted file mode 100755 index 268cabc..0000000 --- a/conn_go19.go +++ /dev/null @@ -1,22 +0,0 @@ -// +build go1.9 - -package sqlmw - -import "database/sql/driver" - -var ( - _ driver.NamedValueChecker = wrappedConn{} -) - -func defaultCheckNamedValue(nv *driver.NamedValue) (err error) { - nv.Value, err = driver.DefaultParameterConverter.ConvertValue(nv.Value) - return err -} - -func (c wrappedConn) CheckNamedValue(v *driver.NamedValue) error { - if checker, ok := c.parent.(driver.NamedValueChecker); ok { - return checker.CheckNamedValue(v) - } - - return defaultCheckNamedValue(v) -} diff --git a/connector.go b/connector.go index 4740a0d..86a3d57 100755 --- a/connector.go +++ b/connector.go @@ -1,5 +1,3 @@ -// +build go1.10 - package sqlmw import ( diff --git a/connector_test.go b/connector_test.go index fab236f..cbacb69 100755 --- a/connector_test.go +++ b/connector_test.go @@ -1,5 +1,3 @@ -// +build go1.10 - package sqlmw import ( diff --git a/driver.go b/driver.go index 254972e..65a7247 100755 --- a/driver.go +++ b/driver.go @@ -2,7 +2,7 @@ package sqlmw import "database/sql/driver" -// driver wraps a sql.Driver with an interceptor. +// wrappedDriver wraps a sql.Driver with an interceptor. type wrappedDriver struct { intr Interceptor parent driver.Driver @@ -10,7 +10,8 @@ type wrappedDriver struct { // Compile time validation that our types implement the expected interfaces var ( - _ driver.Driver = wrappedDriver{} + _ driver.Driver = wrappedDriver{} + _ driver.DriverContext = wrappedDriver{} ) // WrapDriver will wrap the passed SQL driver and return a new sql driver that uses it and also logs and traces calls using the passed logger and tracer @@ -36,3 +37,21 @@ func (d wrappedDriver) Open(name string) (driver.Conn, error) { return wrappedConn{intr: d.intr, parent: conn}, nil } + +// OpenConnector implements the database/sql/driver.Driver interface for WrappedDriver. +func (d wrappedDriver) OpenConnector(name string) (driver.Connector, error) { + driver, ok := d.parent.(driver.DriverContext) + if !ok { + return wrappedConnector{ + parent: dsnConnector{dsn: name, driver: d.parent}, + driverRef: &d, + }, nil + } + + conn, err := driver.OpenConnector(name) + if err != nil { + return nil, err + } + + return wrappedConnector{parent: conn, driverRef: &d}, nil +} diff --git a/driver_go110.go b/driver_go110.go deleted file mode 100755 index 0a2c0a3..0000000 --- a/driver_go110.go +++ /dev/null @@ -1,23 +0,0 @@ -// +build go1.10 - -package sqlmw - -import "database/sql/driver" - -var _ driver.DriverContext = wrappedDriver{} - -func (d wrappedDriver) OpenConnector(name string) (driver.Connector, error) { - driver, ok := d.parent.(driver.DriverContext) - if !ok { - return wrappedConnector{ - parent: dsnConnector{dsn: name, driver: d.parent}, - driverRef: &d, - }, nil - } - conn, err := driver.OpenConnector(name) - if err != nil { - return nil, err - } - - return wrappedConnector{parent: conn, driverRef: &d}, nil -} diff --git a/fakedb_test.go b/fakedb_test.go index bf261a6..18d509f 100644 --- a/fakedb_test.go +++ b/fakedb_test.go @@ -194,6 +194,7 @@ type fakeConn struct { called bool // nolint:structcheck // ignore unused warning, it is accessed via reflection rowsCloseCalled bool stmt driver.Stmt + queryStmt driver.StmtQueryContext tx driver.Tx } @@ -221,17 +222,12 @@ func (c *fakeConn) Close() error { return nil } func (c *fakeConn) Begin() (driver.Tx, error) { return c.tx, nil } -func (c *fakeConn) QueryContext(_ context.Context, _ string, nvs []driver.NamedValue) (driver.Rows, error) { - if c.stmt == nil { +func (c *fakeConn) QueryContext(ctx context.Context, _ string, nvs []driver.NamedValue) (driver.Rows, error) { + if c.queryStmt == nil { return &fakeRows{con: c}, nil } - var args []driver.Value - for _, nv := range nvs { - args = append(args, nv.Value) - } - - return c.stmt.Query(args) + return c.queryStmt.QueryContext(ctx, nvs) } func (c *fakeConnWithCheckNamedValue) CheckNamedValue(_ *driver.NamedValue) (err error) { diff --git a/magefiles/mage.go b/magefiles/mage.go index 091bb68..72bcf00 100644 --- a/magefiles/mage.go +++ b/magefiles/mage.go @@ -14,3 +14,9 @@ func Lint(ctx context.Context) error { args := []string{"run", "--config", ".golangci.yml"} return sh.RunV("golangci-lint", args...) } + +// Test runs the tests +func Test(ctx context.Context) error { + args := []string{"test", "-v", "-race", "./..."} + return sh.RunV("go", args...) +} diff --git a/rows_test.go b/rows_test.go index 7b9243c..defb142 100644 --- a/rows_test.go +++ b/rows_test.go @@ -109,7 +109,7 @@ func TestRowsNext(t *testing.T) { stmt := fakeStmt{ rows: rows, } - con.stmt = stmt + con.queryStmt = stmt driverName := driverName(t) interceptor := rowsNextInterceptor{} @@ -199,7 +199,7 @@ func TestRows_LikePGX(t *testing.T) { stmt := fakeStmt{ rows: rows, } - con.stmt = stmt + con.queryStmt = stmt driverName := driverName(t) interceptor := rowsNextInterceptor{} diff --git a/stmt.go b/stmt.go index f2d7e2f..2c91411 100755 --- a/stmt.go +++ b/stmt.go @@ -15,10 +15,11 @@ type wrappedStmt struct { // Compile time validation that our types implement the expected interfaces var ( - _ driver.Stmt = wrappedStmt{} - _ driver.StmtExecContext = wrappedStmt{} - _ driver.StmtQueryContext = wrappedStmt{} - _ driver.ColumnConverter = wrappedStmt{} + _ driver.Stmt = wrappedStmt{} + _ driver.StmtExecContext = wrappedStmt{} + _ driver.StmtQueryContext = wrappedStmt{} + _ driver.ColumnConverter = wrappedStmt{} + _ driver.NamedValueChecker = wrappedStmt{} ) func (s wrappedStmt) Close() (err error) { @@ -108,3 +109,15 @@ func (s wrappedParentStmt) ExecContext(ctx context.Context, args []driver.NamedV } return s.Exec(dargs) } + +func (s wrappedStmt) CheckNamedValue(v *driver.NamedValue) error { + if checker, ok := s.parent.(driver.NamedValueChecker); ok { + return checker.CheckNamedValue(v) + } + + if checker, ok := s.conn.parent.(driver.NamedValueChecker); ok { + return checker.CheckNamedValue(v) + } + + return defaultCheckNamedValue(v) +} diff --git a/stmt_go19.go b/stmt_go19.go deleted file mode 100755 index 8442101..0000000 --- a/stmt_go19.go +++ /dev/null @@ -1,17 +0,0 @@ -package sqlmw - -import "database/sql/driver" - -var _ driver.NamedValueChecker = wrappedStmt{} - -func (s wrappedStmt) CheckNamedValue(v *driver.NamedValue) error { - if checker, ok := s.parent.(driver.NamedValueChecker); ok { - return checker.CheckNamedValue(v) - } - - if checker, ok := s.conn.parent.(driver.NamedValueChecker); ok { - return checker.CheckNamedValue(v) - } - - return defaultCheckNamedValue(v) -} diff --git a/stmt_go19_test.go b/stmt_go19_test.go deleted file mode 100644 index 8f58153..0000000 --- a/stmt_go19_test.go +++ /dev/null @@ -1,171 +0,0 @@ -package sqlmw - -import ( - "database/sql" - "database/sql/driver" - "reflect" - "testing" -) - -// TestDefaultParameterConversion ensures that -// driver.DefaultParameterConverter is used when neither stmt nor con -// implements any value converters. -func TestDefaultParameterConversion(t *testing.T) { - driverName := driverName(t) - - expectVal := int64(1) - con := &fakeConn{} - fakeStmt := &fakeStmt{ - rows: &fakeRows{ - con: con, - vals: [][]driver.Value{{expectVal}}, - }, - } - con.stmt = fakeStmt - - sql.Register( - driverName, - Driver(&fakeDriver{conn: con}, &NullInterceptor{}), - ) - - db, err := sql.Open(driverName, "") - if err != nil { - t.Fatalf("Failed to open: %v", err) - } - - t.Cleanup(func() { - if err := db.Close(); err != nil { - t.Errorf("Failed to close db: %v", err) - } - }) - - stmt, err := db.Prepare("") - if err != nil { - t.Fatalf("Prepare failed: %s", err) - } - - // int32 values are converted by driver.DefaultParameterConverter to - // int64 - queryVal := int32(1) - rows, err := stmt.Query(queryVal) - if err != nil { - t.Fatalf("Query failed: %s", err) - } - - count := 0 - for rows.Next() { - var v int64 - err := rows.Scan(&v) - if err != nil { - t.Fatalf("rows.Scan failed, %v", err) - } - if v != 1 { - t.Errorf("converted value is %d, passed value to Query was: %d", v, expectVal) - } - count++ - } - - if count != 1 { - t.Fatalf("got too many rows, expected 1, got %d ", 1) - } -} - -func TestWrappedStmt_CheckNamedValue(t *testing.T) { - tests := map[string]struct { - fd *fakeDriver - expected struct { - cc bool // Whether the fakeConn's CheckNamedValue was called - sc bool // Whether the fakeStmt's CheckNamedValue was called - } - }{ - "When both conn and stmt implement CheckNamedValue": { - fd: &fakeDriver{ - conn: &fakeConnWithCheckNamedValue{ - fakeConn: fakeConn{ - stmt: &fakeStmtWithCheckNamedValue{}, - }, - }, - }, - expected: struct { - cc bool - sc bool - }{cc: false, sc: true}, - }, - "When only conn implements CheckNamedValue": { - fd: &fakeDriver{ - conn: &fakeConnWithCheckNamedValue{ - fakeConn: fakeConn{ - stmt: &fakeStmtWithoutCheckNamedValue{}, - }, - }, - }, - expected: struct { - cc bool - sc bool - }{cc: true, sc: false}, - }, - "When only stmt implements CheckNamedValue": { - fd: &fakeDriver{ - conn: &fakeConnWithoutCheckNamedValue{ - fakeConn: fakeConn{ - stmt: &fakeStmtWithCheckNamedValue{}, - }, - }, - }, - expected: struct { - cc bool - sc bool - }{cc: false, sc: true}, - }, - "When both stmt do not implement CheckNamedValue": { - fd: &fakeDriver{ - conn: &fakeConnWithoutCheckNamedValue{ - fakeConn: fakeConn{ - stmt: &fakeStmtWithoutCheckNamedValue{}, - }, - }, - }, - expected: struct { - cc bool - sc bool - }{cc: false, sc: false}, - }, - } - - for name, test := range tests { - t.Run(name, func(t *testing.T) { - driverName := driverName(t) - sql.Register(driverName, Driver(test.fd, &fakeInterceptor{})) - db, err := sql.Open(driverName, "dummy") - if err != nil { - t.Errorf("Failed to open: %v", err) - } - defer func() { - if err := db.Close(); err != nil { - t.Errorf("Failed to close db: %v", err) - } - }() - - stmt, err := db.Prepare("SELECT foo FROM bar Where 1 = ?") - if err != nil { - t.Errorf("Failed to prepare: %v", err) - } - - if _, err := stmt.Query(1); err != nil { - t.Errorf("Failed to query: %v", err) - } - - conn := reflect.ValueOf(test.fd.conn).Elem() - sc := conn.FieldByName("stmt").Elem().Elem().FieldByName("called").Bool() - cc := conn.FieldByName("called").Bool() - - if test.expected.sc != sc { - t.Errorf("sc mismatch.\n got: %#v\nwant: %#v", sc, test.expected.sc) - } - - if test.expected.cc != cc { - t.Errorf("cc mismatch.\n got: %#v\nwant: %#v", cc, test.expected.cc) - } - }) - } -} diff --git a/stmt_test.go b/stmt_test.go index 9d60083..02afa75 100644 --- a/stmt_test.go +++ b/stmt_test.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "database/sql/driver" + "reflect" "testing" ) @@ -99,3 +100,166 @@ func TestStmtQueryContext_PassWrappedRowContext(t *testing.T) { t.Error("RowsClose context not valid") } } + +// TestDefaultParameterConversion ensures that +// driver.DefaultParameterConverter is used when neither stmt nor con +// implements any value converters. +func TestDefaultParameterConversion(t *testing.T) { + driverName := driverName(t) + + expectVal := int64(1) + con := &fakeConn{} + fakeStmt := &fakeStmt{ + rows: &fakeRows{ + con: con, + vals: [][]driver.Value{{expectVal}}, + }, + } + con.stmt = fakeStmt + + sql.Register( + driverName, + Driver(&fakeDriver{conn: con}, &NullInterceptor{}), + ) + + db, err := sql.Open(driverName, "") + if err != nil { + t.Fatalf("Failed to open: %v", err) + } + + t.Cleanup(func() { + if err := db.Close(); err != nil { + t.Errorf("Failed to close db: %v", err) + } + }) + + stmt, err := db.Prepare("") + if err != nil { + t.Fatalf("Prepare failed: %s", err) + } + + // int32 values are converted by driver.DefaultParameterConverter to + // int64 + queryVal := int32(1) + rows, err := stmt.Query(queryVal) + if err != nil { + t.Fatalf("Query failed: %s", err) + } + + count := 0 + for rows.Next() { + var v int64 + err := rows.Scan(&v) + if err != nil { + t.Fatalf("rows.Scan failed, %v", err) + } + if v != 1 { + t.Errorf("converted value is %d, passed value to Query was: %d", v, expectVal) + } + count++ + } + + if count != 1 { + t.Fatalf("got too many rows, expected 1, got %d ", 1) + } +} + +func TestWrappedStmt_CheckNamedValue(t *testing.T) { + tests := map[string]struct { + fd *fakeDriver + expected struct { + cc bool // Whether the fakeConn's CheckNamedValue was called + sc bool // Whether the fakeStmt's CheckNamedValue was called + } + }{ + "When both conn and stmt implement CheckNamedValue": { + fd: &fakeDriver{ + conn: &fakeConnWithCheckNamedValue{ + fakeConn: fakeConn{ + stmt: &fakeStmtWithCheckNamedValue{}, + }, + }, + }, + expected: struct { + cc bool + sc bool + }{cc: false, sc: true}, + }, + "When only conn implements CheckNamedValue": { + fd: &fakeDriver{ + conn: &fakeConnWithCheckNamedValue{ + fakeConn: fakeConn{ + stmt: &fakeStmtWithoutCheckNamedValue{}, + }, + }, + }, + expected: struct { + cc bool + sc bool + }{cc: true, sc: false}, + }, + "When only stmt implements CheckNamedValue": { + fd: &fakeDriver{ + conn: &fakeConnWithoutCheckNamedValue{ + fakeConn: fakeConn{ + stmt: &fakeStmtWithCheckNamedValue{}, + }, + }, + }, + expected: struct { + cc bool + sc bool + }{cc: false, sc: true}, + }, + "When both stmt do not implement CheckNamedValue": { + fd: &fakeDriver{ + conn: &fakeConnWithoutCheckNamedValue{ + fakeConn: fakeConn{ + stmt: &fakeStmtWithoutCheckNamedValue{}, + }, + }, + }, + expected: struct { + cc bool + sc bool + }{cc: false, sc: false}, + }, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + driverName := driverName(t) + sql.Register(driverName, Driver(test.fd, &fakeInterceptor{})) + db, err := sql.Open(driverName, "dummy") + if err != nil { + t.Errorf("Failed to open: %v", err) + } + defer func() { + if err := db.Close(); err != nil { + t.Errorf("Failed to close db: %v", err) + } + }() + + stmt, err := db.Prepare("SELECT foo FROM bar Where 1 = ?") + if err != nil { + t.Errorf("Failed to prepare: %v", err) + } + + if _, err := stmt.Query(1); err != nil { + t.Errorf("Failed to query: %v", err) + } + + conn := reflect.ValueOf(test.fd.conn).Elem() + sc := conn.FieldByName("stmt").Elem().Elem().FieldByName("called").Bool() + cc := conn.FieldByName("called").Bool() + + if test.expected.sc != sc { + t.Errorf("sc mismatch.\n got: %#v\nwant: %#v", sc, test.expected.sc) + } + + if test.expected.cc != cc { + t.Errorf("cc mismatch.\n got: %#v\nwant: %#v", cc, test.expected.cc) + } + }) + } +} diff --git a/tools/rows_picker_gen.go b/tools/rows_picker_gen.go index c02f99c..c4b036b 100755 --- a/tools/rows_picker_gen.go +++ b/tools/rows_picker_gen.go @@ -1,4 +1,4 @@ -// +build ignore +//go:build ignore package main