diff --git a/src/database/sql/convert.go b/src/database/sql/convert.go index c349a96edf0d8..b44bed559d5a5 100644 --- a/src/database/sql/convert.go +++ b/src/database/sql/convert.go @@ -12,7 +12,6 @@ import ( "fmt" "reflect" "strconv" - "sync" "time" "unicode" "unicode/utf8" @@ -38,17 +37,10 @@ func validateNamedValueName(name string) error { return fmt.Errorf("name %q does not begin with a letter", name) } -func driverNumInput(ds *driverStmt) int { - ds.Lock() - defer ds.Unlock() // in case NumInput panics - return ds.si.NumInput() -} - // ccChecker wraps the driver.ColumnConverter and allows it to be used // as if it were a NamedValueChecker. If the driver ColumnConverter // is not present then the NamedValueChecker will return driver.ErrSkip. type ccChecker struct { - sync.Locker cci driver.ColumnConverter want int } @@ -88,9 +80,7 @@ func (c ccChecker) CheckNamedValue(nv *driver.NamedValue) error { // same error. var err error arg := nv.Value - c.Lock() nv.Value, err = c.cci.ColumnConverter(index).ConvertValue(arg) - c.Unlock() if err != nil { return err } @@ -112,7 +102,7 @@ func defaultCheckNamedValue(nv *driver.NamedValue) (err error) { // Stmt.Query into driver Values. // // The statement ds may be nil, if no statement is available. -func driverArgs(ci driver.Conn, ds *driverStmt, args []interface{}) ([]driver.NamedValue, error) { +func driverArgsConnLocked(ci driver.Conn, ds *driverStmt, args []interface{}) ([]driver.NamedValue, error) { nvargs := make([]driver.NamedValue, len(args)) // -1 means the driver doesn't know how to count the number of @@ -124,8 +114,7 @@ func driverArgs(ci driver.Conn, ds *driverStmt, args []interface{}) ([]driver.Na var cc ccChecker if ds != nil { si = ds.si - want = driverNumInput(ds) - cc.Locker = ds.Locker + want = ds.si.NumInput() cc.want = want } diff --git a/src/database/sql/convert_test.go b/src/database/sql/convert_test.go index 35dbab3339ebd..c198177760ac3 100644 --- a/src/database/sql/convert_test.go +++ b/src/database/sql/convert_test.go @@ -481,7 +481,7 @@ func TestDriverArgs(t *testing.T) { } for i, tt := range tests { ds := &driverStmt{Locker: &sync.Mutex{}, si: stubDriverStmt{nil}} - got, err := driverArgs(nil, ds, tt.args) + got, err := driverArgsConnLocked(nil, ds, tt.args) if err != nil { t.Errorf("test[%d]: %v", i, err) continue diff --git a/src/database/sql/fakedb_test.go b/src/database/sql/fakedb_test.go index 070b7834539d5..31e22a7a74343 100644 --- a/src/database/sql/fakedb_test.go +++ b/src/database/sql/fakedb_test.go @@ -1005,6 +1005,7 @@ type rowsCursor struct { } func (rc *rowsCursor) touchMem() { + rc.parentMem.touchMem() rc.line++ } diff --git a/src/database/sql/sql.go b/src/database/sql/sql.go index c17b2b543b572..be73b5e372920 100644 --- a/src/database/sql/sql.go +++ b/src/database/sql/sql.go @@ -1368,12 +1368,12 @@ func (db *DB) execDC(ctx context.Context, dc *driverConn, release func(error), q } if ok { var nvdargs []driver.NamedValue - nvdargs, err = driverArgs(dc.ci, nil, args) - if err != nil { - return nil, err - } var resi driver.Result withLock(dc, func() { + nvdargs, err = driverArgsConnLocked(dc.ci, nil, args) + if err != nil { + return + } resi, err = ctxDriverExec(ctx, execerCtx, execer, query, nvdargs) }) if err != driver.ErrSkip { @@ -1439,13 +1439,14 @@ func (db *DB) queryDC(ctx, txctx context.Context, dc *driverConn, releaseConn fu queryer, ok = dc.ci.(driver.Queryer) } if ok { - nvdargs, err := driverArgs(dc.ci, nil, args) - if err != nil { - releaseConn(err) - return nil, err - } + var nvdargs []driver.NamedValue var rowsi driver.Rows + var err error withLock(dc, func() { + nvdargs, err = driverArgsConnLocked(dc.ci, nil, args) + if err != nil { + return + } rowsi, err = ctxDriverQuery(ctx, queryerCtx, queryer, query, nvdargs) }) if err != driver.ErrSkip { @@ -2034,11 +2035,14 @@ func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt { stmt.mu.Unlock() if si == nil { - cs, err := stmt.prepareOnConnLocked(ctx, dc) + withLock(dc, func() { + var ds *driverStmt + ds, err = stmt.prepareOnConnLocked(ctx, dc) + si = ds.si + }) if err != nil { return &Stmt{stickyErr: err} } - si = cs.si } parentStmt = stmt } @@ -2230,14 +2234,14 @@ func (s *Stmt) Exec(args ...interface{}) (Result, error) { } func resultFromStatement(ctx context.Context, ci driver.Conn, ds *driverStmt, args ...interface{}) (Result, error) { - dargs, err := driverArgs(ci, ds, args) + ds.Lock() + defer ds.Unlock() + + dargs, err := driverArgsConnLocked(ci, ds, args) if err != nil { return nil, err } - ds.Lock() - defer ds.Unlock() - resi, err := ctxDriverStmtExec(ctx, ds.si, dargs) if err != nil { return nil, err @@ -2401,10 +2405,10 @@ func (s *Stmt) Query(args ...interface{}) (*Rows, error) { } func rowsiFromStatement(ctx context.Context, ci driver.Conn, ds *driverStmt, args ...interface{}) (driver.Rows, error) { - var want int - withLock(ds, func() { - want = ds.si.NumInput() - }) + ds.Lock() + defer ds.Unlock() + + want := ds.si.NumInput() // -1 means the driver doesn't know how to count the number of // placeholders, so we won't sanity check input here and instead let the @@ -2413,14 +2417,11 @@ func rowsiFromStatement(ctx context.Context, ci driver.Conn, ds *driverStmt, arg return nil, fmt.Errorf("sql: statement expects %d inputs; got %d", want, len(args)) } - dargs, err := driverArgs(ci, ds, args) + dargs, err := driverArgsConnLocked(ci, ds, args) if err != nil { return nil, err } - ds.Lock() - defer ds.Unlock() - rowsi, err := ctxDriverStmtQuery(ctx, ds.si, dargs) if err != nil { return nil, err @@ -2583,9 +2584,16 @@ func (rs *Rows) nextLocked() (doClose, ok bool) { if rs.closed { return false, false } + + // Lock the driver connection before calling the driver interface + // rowsi to prevent a Tx from rolling back the connection at the same time. + rs.dc.Lock() + defer rs.dc.Unlock() + if rs.lastcols == nil { rs.lastcols = make([]driver.Value, len(rs.rowsi.Columns())) } + rs.lasterr = rs.rowsi.Next(rs.lastcols) if rs.lasterr != nil { // Close the connection if there is a driver error. @@ -2635,6 +2643,12 @@ func (rs *Rows) NextResultSet() bool { doClose = true return false } + + // Lock the driver connection before calling the driver interface + // rowsi to prevent a Tx from rolling back the connection at the same time. + rs.dc.Lock() + defer rs.dc.Unlock() + rs.lasterr = nextResultSet.NextResultSet() if rs.lasterr != nil { doClose = true @@ -2666,6 +2680,9 @@ func (rs *Rows) Columns() ([]string, error) { if rs.rowsi == nil { return nil, errors.New("sql: no Rows available") } + rs.dc.Lock() + defer rs.dc.Unlock() + return rs.rowsi.Columns(), nil } @@ -2680,7 +2697,10 @@ func (rs *Rows) ColumnTypes() ([]*ColumnType, error) { if rs.rowsi == nil { return nil, errors.New("sql: no Rows available") } - return rowsColumnInfoSetup(rs.rowsi), nil + rs.dc.Lock() + defer rs.dc.Unlock() + + return rowsColumnInfoSetupConnLocked(rs.rowsi), nil } // ColumnType contains the name and type of a column. @@ -2741,7 +2761,7 @@ func (ci *ColumnType) DatabaseTypeName() string { return ci.databaseType } -func rowsColumnInfoSetup(rowsi driver.Rows) []*ColumnType { +func rowsColumnInfoSetupConnLocked(rowsi driver.Rows) []*ColumnType { names := rowsi.Columns() list := make([]*ColumnType, len(names)) diff --git a/src/database/sql/sql_test.go b/src/database/sql/sql_test.go index dead273503f5a..f7b7d988e1b80 100644 --- a/src/database/sql/sql_test.go +++ b/src/database/sql/sql_test.go @@ -3157,6 +3157,9 @@ func TestIssue6081(t *testing.T) { // In the test, a context is canceled while the query is in process so // the internal rollback will run concurrently with the explicitly called // Tx.Rollback. +// +// The addition of calling rows.Next also tests +// Issue 21117. func TestIssue18429(t *testing.T) { db := newTestDB(t, "people") defer closeDB(t, db) @@ -3189,6 +3192,9 @@ func TestIssue18429(t *testing.T) { // reported. rows, _ := tx.QueryContext(ctx, "WAIT|"+qwait+"|SELECT|people|name|") if rows != nil { + // Call Next to test Issue 21117 and check for races. + for rows.Next() { + } rows.Close() } // This call will race with the context cancel rollback to complete