diff --git a/br/pkg/pdutil/pd.go b/br/pkg/pdutil/pd.go index f7aa759e6c4a1..8f88fa73249b0 100644 --- a/br/pkg/pdutil/pd.go +++ b/br/pkg/pdutil/pd.go @@ -130,13 +130,13 @@ var ( // pdHTTPRequest defines the interface to send a request to pd and return the result in bytes. type pdHTTPRequest func(ctx context.Context, addr string, prefix string, - cli *http.Client, method string, body io.Reader) ([]byte, error) + cli *http.Client, method string, body []byte) ([]byte, error) // pdRequest is a func to send an HTTP to pd and return the result bytes. func pdRequest( ctx context.Context, addr string, prefix string, - cli *http.Client, method string, body io.Reader) ([]byte, error) { + cli *http.Client, method string, body []byte) ([]byte, error) { _, respBody, err := pdRequestWithCode(ctx, addr, prefix, cli, method, body) return respBody, err } @@ -144,7 +144,7 @@ func pdRequest( func pdRequestWithCode( ctx context.Context, addr string, prefix string, - cli *http.Client, method string, body io.Reader) (int, []byte, error) { + cli *http.Client, method string, body []byte) (int, []byte, error) { u, err := url.Parse(addr) if err != nil { return 0, nil, errors.Trace(err) @@ -154,10 +154,13 @@ func pdRequestWithCode( req *http.Request resp *http.Response ) + if body == nil { + body = []byte("") + } count := 0 // the total retry duration: 120*1 = 2min for { - req, err = http.NewRequestWithContext(ctx, method, reqURL, body) + req, err = http.NewRequestWithContext(ctx, method, reqURL, bytes.NewBuffer(body)) if err != nil { return 0, nil, errors.Trace(err) } @@ -184,6 +187,8 @@ func pdRequestWithCode( (err != nil && !common.IsRetryableError(err)) { break } + log.Warn("request failed, will retry later", + zap.String("url", reqURL), zap.Int("retry-count", count), zap.Error(err)) if resp != nil { _ = resp.Body.Close() } @@ -434,7 +439,7 @@ func (p *PdController) doPauseSchedulers(ctx context.Context, removedSchedulers := make([]string, 0, len(schedulers)) for _, scheduler := range schedulers { for _, addr := range p.getAllPDAddrs() { - _, err = post(ctx, addr, pdapi.SchedulerByName(scheduler), p.cli, http.MethodPost, bytes.NewBuffer(body)) + _, err = post(ctx, addr, pdapi.SchedulerByName(scheduler), p.cli, http.MethodPost, body) if err == nil { removedSchedulers = append(removedSchedulers, scheduler) break @@ -516,7 +521,7 @@ func (p *PdController) resumeSchedulerWith(ctx context.Context, schedulers []str } for _, scheduler := range schedulers { for _, addr := range p.getAllPDAddrs() { - _, err = post(ctx, addr, pdapi.SchedulerByName(scheduler), p.cli, http.MethodPost, bytes.NewBuffer(body)) + _, err = post(ctx, addr, pdapi.SchedulerByName(scheduler), p.cli, http.MethodPost, body) if err == nil { break } @@ -605,7 +610,7 @@ func (p *PdController) doUpdatePDScheduleConfig( return errors.Trace(err) } _, e := post(ctx, addr, prefix, - p.cli, http.MethodPost, bytes.NewBuffer(reqData)) + p.cli, http.MethodPost, reqData) if e == nil { return nil } @@ -861,7 +866,7 @@ func (p *PdController) RecoverBaseAllocID(ctx context.Context, id uint64) error }) var err error for _, addr := range p.getAllPDAddrs() { - _, e := pdRequest(ctx, addr, pdapi.BaseAllocID, p.cli, http.MethodPost, bytes.NewBuffer(reqData)) + _, e := pdRequest(ctx, addr, pdapi.BaseAllocID, p.cli, http.MethodPost, reqData) if e != nil { log.Warn("failed to recover base alloc id", zap.String("addr", addr), zap.Error(e)) err = e @@ -885,7 +890,7 @@ func (p *PdController) ResetTS(ctx context.Context, ts uint64) error { }) var err error for _, addr := range p.getAllPDAddrs() { - code, _, e := pdRequestWithCode(ctx, addr, pdapi.ResetTS, p.cli, http.MethodPost, bytes.NewBuffer(reqData)) + code, _, e := pdRequestWithCode(ctx, addr, pdapi.ResetTS, p.cli, http.MethodPost, reqData) if e != nil { // for pd version <= 6.2, if the given ts < current ts of pd, pd returns StatusForbidden. // it's not an error for br @@ -962,7 +967,7 @@ func (p *PdController) CreateOrUpdateRegionLabelRule(ctx context.Context, rule L addrs := p.getAllPDAddrs() for i, addr := range addrs { _, lastErr = pdRequest(ctx, addr, pdapi.RegionLabelRule, - p.cli, http.MethodPost, bytes.NewBuffer(reqData)) + p.cli, http.MethodPost, reqData) if lastErr == nil { return nil } diff --git a/br/pkg/pdutil/pd_serial_test.go b/br/pkg/pdutil/pd_serial_test.go index 67a37c072b834..5dc91f7192a46 100644 --- a/br/pkg/pdutil/pd_serial_test.go +++ b/br/pkg/pdutil/pd_serial_test.go @@ -3,7 +3,6 @@ package pdutil import ( - "bytes" "context" "encoding/hex" "encoding/json" @@ -32,7 +31,7 @@ func TestScheduler(t *testing.T) { defer cancel() scheduler := "balance-leader-scheduler" - mock := func(context.Context, string, string, *http.Client, string, io.Reader) ([]byte, error) { + mock := func(context.Context, string, string, *http.Client, string, []byte) ([]byte, error) { return nil, errors.New("failed") } schedulerPauseCh := make(chan struct{}) @@ -67,7 +66,7 @@ func TestScheduler(t *testing.T) { _, err = pdController.listSchedulersWith(ctx, mock) require.EqualError(t, err, "failed") - mock = func(context.Context, string, string, *http.Client, string, io.Reader) ([]byte, error) { + mock = func(context.Context, string, string, *http.Client, string, []byte) ([]byte, error) { return []byte(`["` + scheduler + `"]`), nil } @@ -87,7 +86,7 @@ func TestScheduler(t *testing.T) { func TestGetClusterVersion(t *testing.T) { pdController := &PdController{addrs: []string{"", ""}} // two endpoints counter := 0 - mock := func(context.Context, string, string, *http.Client, string, io.Reader) ([]byte, error) { + mock := func(context.Context, string, string, *http.Client, string, []byte) ([]byte, error) { counter++ if counter <= 1 { return nil, errors.New("mock error") @@ -100,7 +99,7 @@ func TestGetClusterVersion(t *testing.T) { require.NoError(t, err) require.Equal(t, "test", respString) - mock = func(context.Context, string, string, *http.Client, string, io.Reader) ([]byte, error) { + mock = func(context.Context, string, string, *http.Client, string, []byte) ([]byte, error) { return nil, errors.New("mock error") } _, err = pdController.getClusterVersionWith(ctx, mock) @@ -130,7 +129,7 @@ func TestRegionCount(t *testing.T) { require.Equal(t, 3, len(regions.Regions)) mock := func( - _ context.Context, addr string, prefix string, _ *http.Client, _ string, _ io.Reader, + _ context.Context, addr string, prefix string, _ *http.Client, _ string, _ []byte, ) ([]byte, error) { query := fmt.Sprintf("%s/%s", addr, prefix) u, e := url.Parse(query) @@ -181,6 +180,9 @@ func TestPDRequestRetry(t *testing.T) { count := 0 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { count++ + bytes, err := io.ReadAll(r.Body) + require.NoError(t, err) + require.Equal(t, "test", string(bytes)) if count <= pdRequestRetryTime-1 { w.WriteHeader(http.StatusGatewayTimeout) return @@ -196,8 +198,7 @@ func TestPDRequestRetry(t *testing.T) { cli.Transport.(*http.Transport).DisableKeepAlives = true taddr := ts.URL - body := bytes.NewBuffer([]byte("test")) - _, reqErr := pdRequest(ctx, taddr, "", cli, http.MethodPost, body) + _, reqErr := pdRequest(ctx, taddr, "", cli, http.MethodPost, []byte("test")) require.NoError(t, reqErr) ts.Close() count = 0 @@ -269,7 +270,7 @@ func TestStoreInfo(t *testing.T) { }, } mock := func( - _ context.Context, addr string, prefix string, _ *http.Client, _ string, _ io.Reader, + _ context.Context, addr string, prefix string, _ *http.Client, _ string, _ []byte, ) ([]byte, error) { require.Equal(t, fmt.Sprintf("http://mock%s", pdapi.StoreByID(1)),