Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .github/workflows/go.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,10 @@ jobs:
go-version: 1.24 # TODO: matrix

- name: Test
run: go test -v -race ./...
uses: magefile/mage-action@6f50bbb8ea47d56e62dee92392788acbc8192d0b # v3.1.0
with:
version: latest
args: test

check:
runs-on: ubuntu-latest
Expand Down
34 changes: 34 additions & 0 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ var (
_ driver.Pinger = wrappedConn{}
_ driver.Queryer = wrappedConn{}
_ driver.QueryerContext = wrappedConn{}
_ driver.NamedValueChecker = wrappedConn{}
_ driver.SessionResetter = wrappedConn{}
)

func (c wrappedConn) Prepare(query string) (driver.Stmt, error) {
Expand Down Expand Up @@ -179,3 +181,35 @@ func (c wrappedParentConn) QueryContext(ctx context.Context, query string, args
return c.Conn.(driver.Queryer).Query(query, dargs)
}
}

func defaultCheckNamedValue(nv *driver.NamedValue) (err error) {
nv.Value, err = driver.DefaultParameterConverter.ConvertValue(nv.Value)
return err
}

func (c wrappedConn) CheckNamedValue(v *driver.NamedValue) error {
if checker, ok := c.parent.(driver.NamedValueChecker); ok {
return checker.CheckNamedValue(v)
}

return defaultCheckNamedValue(v)
}

func (c wrappedConn) ResetSession(ctx context.Context) error {
conn, ok := c.parent.(driver.SessionResetter)
if !ok {
return nil
}

return conn.ResetSession(ctx)
}

func (c wrappedConn) IsValid() bool {
conn, ok := c.parent.(driver.Validator)
if !ok {
// the default if driver.Validator is not supported
return true
}

return conn.IsValid()
}
19 changes: 0 additions & 19 deletions conn_go110.go

This file was deleted.

19 changes: 0 additions & 19 deletions conn_go115.go

This file was deleted.

22 changes: 0 additions & 22 deletions conn_go19.go

This file was deleted.

2 changes: 0 additions & 2 deletions connector.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
// +build go1.10

package sqlmw

import (
Expand Down
2 changes: 0 additions & 2 deletions connector_test.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
// +build go1.10

package sqlmw

import (
Expand Down
23 changes: 21 additions & 2 deletions driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,16 @@ package sqlmw

import "database/sql/driver"

// driver wraps a sql.Driver with an interceptor.
// wrappedDriver wraps a sql.Driver with an interceptor.
type wrappedDriver struct {
intr Interceptor
parent driver.Driver
}

// Compile time validation that our types implement the expected interfaces
var (
_ driver.Driver = wrappedDriver{}
_ driver.Driver = wrappedDriver{}
_ driver.DriverContext = wrappedDriver{}
)

// WrapDriver will wrap the passed SQL driver and return a new sql driver that uses it and also logs and traces calls using the passed logger and tracer
Expand All @@ -36,3 +37,21 @@ func (d wrappedDriver) Open(name string) (driver.Conn, error) {

return wrappedConn{intr: d.intr, parent: conn}, nil
}

// OpenConnector implements the database/sql/driver.Driver interface for WrappedDriver.
func (d wrappedDriver) OpenConnector(name string) (driver.Connector, error) {
driver, ok := d.parent.(driver.DriverContext)
if !ok {
return wrappedConnector{
parent: dsnConnector{dsn: name, driver: d.parent},
driverRef: &d,
}, nil
}

conn, err := driver.OpenConnector(name)
if err != nil {
return nil, err
}

return wrappedConnector{parent: conn, driverRef: &d}, nil
}
23 changes: 0 additions & 23 deletions driver_go110.go

This file was deleted.

12 changes: 4 additions & 8 deletions fakedb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ type fakeConn struct {
called bool // nolint:structcheck // ignore unused warning, it is accessed via reflection
rowsCloseCalled bool
stmt driver.Stmt
queryStmt driver.StmtQueryContext
tx driver.Tx
}

Expand Down Expand Up @@ -221,17 +222,12 @@ func (c *fakeConn) Close() error { return nil }

func (c *fakeConn) Begin() (driver.Tx, error) { return c.tx, nil }

func (c *fakeConn) QueryContext(_ context.Context, _ string, nvs []driver.NamedValue) (driver.Rows, error) {
if c.stmt == nil {
func (c *fakeConn) QueryContext(ctx context.Context, _ string, nvs []driver.NamedValue) (driver.Rows, error) {
if c.queryStmt == nil {
return &fakeRows{con: c}, nil
}

var args []driver.Value
for _, nv := range nvs {
args = append(args, nv.Value)
}

return c.stmt.Query(args)
return c.queryStmt.QueryContext(ctx, nvs)
}

func (c *fakeConnWithCheckNamedValue) CheckNamedValue(_ *driver.NamedValue) (err error) {
Expand Down
6 changes: 6 additions & 0 deletions magefiles/mage.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,9 @@ func Lint(ctx context.Context) error {
args := []string{"run", "--config", ".golangci.yml"}
return sh.RunV("golangci-lint", args...)
}

// Test runs the tests
func Test(ctx context.Context) error {
args := []string{"test", "-v", "-race", "./..."}
return sh.RunV("go", args...)
}
4 changes: 2 additions & 2 deletions rows_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ func TestRowsNext(t *testing.T) {
stmt := fakeStmt{
rows: rows,
}
con.stmt = stmt
con.queryStmt = stmt
driverName := driverName(t)
interceptor := rowsNextInterceptor{}

Expand Down Expand Up @@ -199,7 +199,7 @@ func TestRows_LikePGX(t *testing.T) {
stmt := fakeStmt{
rows: rows,
}
con.stmt = stmt
con.queryStmt = stmt
driverName := driverName(t)
interceptor := rowsNextInterceptor{}

Expand Down
21 changes: 17 additions & 4 deletions stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@ type wrappedStmt struct {

// Compile time validation that our types implement the expected interfaces
var (
_ driver.Stmt = wrappedStmt{}
_ driver.StmtExecContext = wrappedStmt{}
_ driver.StmtQueryContext = wrappedStmt{}
_ driver.ColumnConverter = wrappedStmt{}
_ driver.Stmt = wrappedStmt{}
_ driver.StmtExecContext = wrappedStmt{}
_ driver.StmtQueryContext = wrappedStmt{}
_ driver.ColumnConverter = wrappedStmt{}
_ driver.NamedValueChecker = wrappedStmt{}
)

func (s wrappedStmt) Close() (err error) {
Expand Down Expand Up @@ -108,3 +109,15 @@ func (s wrappedParentStmt) ExecContext(ctx context.Context, args []driver.NamedV
}
return s.Exec(dargs)
}

func (s wrappedStmt) CheckNamedValue(v *driver.NamedValue) error {
if checker, ok := s.parent.(driver.NamedValueChecker); ok {
return checker.CheckNamedValue(v)
}

if checker, ok := s.conn.parent.(driver.NamedValueChecker); ok {
return checker.CheckNamedValue(v)
}

return defaultCheckNamedValue(v)
}
17 changes: 0 additions & 17 deletions stmt_go19.go

This file was deleted.

Loading