Skip to content

Commit

Permalink
Ensure context propagation in MySQL binding (dapr#1829)
Browse files Browse the repository at this point in the history
Spin-off from PR adding contexts to input bindings

Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>

Co-authored-by: Dapr Bot <56698301+dapr-bot@users.noreply.github.com>
  • Loading branch information
2 people authored and cmendible committed Jul 4, 2022
1 parent b5edec6 commit 06a49bd
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 19 deletions.
18 changes: 8 additions & 10 deletions bindings/mysql/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,6 @@ type Mysql struct {
logger logger.Logger
}

var _ = bindings.OutputBinding(&Mysql{})

// NewMysql returns a new MySQL output binding.
func NewMysql(logger logger.Logger) *Mysql {
return &Mysql{logger: logger}
Expand Down Expand Up @@ -147,7 +145,7 @@ func (m *Mysql) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bindi
return nil, errors.Errorf("required metadata not set: %s", commandSQLKey)
}

startTime := time.Now().UTC()
startTime := time.Now()

resp := &bindings.InvokeResponse{
Metadata: map[string]string{
Expand All @@ -159,14 +157,14 @@ func (m *Mysql) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bindi

switch req.Operation { // nolint: exhaustive
case execOperation:
r, err := m.exec(s)
r, err := m.exec(ctx, s)
if err != nil {
return nil, err
}
resp.Metadata[respRowsAffectedKey] = strconv.FormatInt(r, 10)

case queryOperation:
d, err := m.query(s)
d, err := m.query(ctx, s)
if err != nil {
return nil, err
}
Expand All @@ -177,7 +175,7 @@ func (m *Mysql) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bindi
req.Operation, execOperation, queryOperation, closeOperation)
}

endTime := time.Now().UTC()
endTime := time.Now()
resp.Metadata[respEndTimeKey] = endTime.Format(time.RFC3339Nano)
resp.Metadata[respDurationKey] = endTime.Sub(startTime).String()

Expand All @@ -202,10 +200,10 @@ func (m *Mysql) Close() error {
return nil
}

func (m *Mysql) query(sql string) ([]byte, error) {
func (m *Mysql) query(ctx context.Context, sql string) ([]byte, error) {
m.logger.Debugf("query: %s", sql)

rows, err := m.db.Query(sql)
rows, err := m.db.QueryContext(ctx, sql)
if err != nil {
return nil, errors.Wrapf(err, "error executing %s", sql)
}
Expand All @@ -223,10 +221,10 @@ func (m *Mysql) query(sql string) ([]byte, error) {
return result, nil
}

func (m *Mysql) exec(sql string) (int64, error) {
func (m *Mysql) exec(ctx context.Context, sql string) (int64, error) {
m.logger.Debugf("exec: %s", sql)

res, err := m.db.Exec(sql)
res, err := m.db.ExecContext(ctx, sql)
if err != nil {
return 0, errors.Wrapf(err, "error executing %s", sql)
}
Expand Down
18 changes: 9 additions & 9 deletions bindings/mysql/mysql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func TestQuery(t *testing.T) {
AddRow(3, "value-3", time.Now().Add(2000))

mock.ExpectQuery("SELECT \\* FROM foo WHERE id < 4").WillReturnRows(rows)
ret, err := m.query(`SELECT * FROM foo WHERE id < 4`)
ret, err := m.query(context.Background(), `SELECT * FROM foo WHERE id < 4`)
assert.Nil(t, err)
t.Logf("query result: %s", ret)
assert.Contains(t, string(ret), "\"id\":1")
Expand All @@ -57,7 +57,7 @@ func TestQuery(t *testing.T) {
AddRow(2, 2.2, time.Now().Add(1000)).
AddRow(3, 3.3, time.Now().Add(2000))
mock.ExpectQuery("SELECT \\* FROM foo WHERE id < 4").WillReturnRows(rows)
ret, err := m.query("SELECT * FROM foo WHERE id < 4")
ret, err := m.query(context.Background(), "SELECT * FROM foo WHERE id < 4")
assert.Nil(t, err)
t.Logf("query result: %s", ret)

Expand All @@ -84,7 +84,7 @@ func TestExec(t *testing.T) {
m, mock, _ := mockDatabase(t)
defer m.Close()
mock.ExpectExec("INSERT INTO foo \\(id, v1, ts\\) VALUES \\(.*\\)").WillReturnResult(sqlmock.NewResult(1, 1))
i, err := m.exec("INSERT INTO foo (id, v1, ts) VALUES (1, 'test-1', '2021-01-22')")
i, err := m.exec(context.Background(), "INSERT INTO foo (id, v1, ts) VALUES (1, 'test-1', '2021-01-22')")
assert.Equal(t, int64(1), i)
assert.Nil(t, err)
}
Expand All @@ -101,7 +101,7 @@ func TestInvoke(t *testing.T) {
Metadata: metadata,
Operation: execOperation,
}
resp, err := m.Invoke(context.TODO(), req)
resp, err := m.Invoke(context.Background(), req)
assert.Nil(t, err)
assert.Equal(t, "1", resp.Metadata[respRowsAffectedKey])
})
Expand All @@ -114,7 +114,7 @@ func TestInvoke(t *testing.T) {
Metadata: metadata,
Operation: execOperation,
}
resp, err := m.Invoke(context.TODO(), req)
resp, err := m.Invoke(context.Background(), req)
assert.Nil(t, resp)
assert.NotNil(t, err)
})
Expand All @@ -132,7 +132,7 @@ func TestInvoke(t *testing.T) {
Metadata: metadata,
Operation: queryOperation,
}
resp, err := m.Invoke(context.TODO(), req)
resp, err := m.Invoke(context.Background(), req)
assert.Nil(t, err)
var data []interface{}
err = json.Unmarshal(resp.Data, &data)
Expand All @@ -148,7 +148,7 @@ func TestInvoke(t *testing.T) {
Metadata: metadata,
Operation: queryOperation,
}
resp, err := m.Invoke(context.TODO(), req)
resp, err := m.Invoke(context.Background(), req)
assert.Nil(t, resp)
assert.NotNil(t, err)
})
Expand All @@ -158,7 +158,7 @@ func TestInvoke(t *testing.T) {
req := &bindings.InvokeRequest{
Operation: closeOperation,
}
resp, _ := m.Invoke(context.TODO(), req)
resp, _ := m.Invoke(context.Background(), req)
assert.Nil(t, resp)
})

Expand All @@ -168,7 +168,7 @@ func TestInvoke(t *testing.T) {
Metadata: map[string]string{},
Operation: "unsupported",
}
resp, err := m.Invoke(context.TODO(), req)
resp, err := m.Invoke(context.Background(), req)
assert.Nil(t, resp)
assert.NotNil(t, err)
})
Expand Down

0 comments on commit 06a49bd

Please sign in to comment.