diff --git a/client.go b/client.go index 3a22b99..e15ace0 100644 --- a/client.go +++ b/client.go @@ -763,12 +763,18 @@ func (c *Client) Do(req *Request) (*http.Response, error) { break } + wait := c.Backoff(c.RetryWaitMin, c.RetryWaitMax, i, resp) + // We're going to retry, consume any response to reuse the connection. if doErr == nil { c.drainBody(resp.Body) } - wait := c.Backoff(c.RetryWaitMin, c.RetryWaitMax, i, resp) + if deadline, ok := req.Context().Deadline(); ok && timeNow().Add(wait).After(deadline) { + c.HTTPClient.CloseIdleConnections() + return nil, context.DeadlineExceeded + } + if logger != nil { desc := fmt.Sprintf("%s %s", req.Method, redactURL(req.URL)) if resp != nil { diff --git a/client_test.go b/client_test.go index 0e8ca60..041a6bb 100644 --- a/client_test.go +++ b/client_test.go @@ -23,6 +23,12 @@ import ( "github.com/hashicorp/go-hclog" ) +type roundTripperFunc func(*http.Request) (*http.Response, error) + +func (f roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) { + return f(r) +} + func TestRequest(t *testing.T) { // Fails on invalid request _, err := NewRequest("GET", "://foo", nil) @@ -1332,6 +1338,57 @@ func TestClient_BackoffCustom(t *testing.T) { } } +func TestClient_BackoffStopsWhenWaitWouldExceedDeadline(t *testing.T) { + testStaticTime(t) + + client := NewClient() + client.RetryMax = 1 + + ctx, cancel := context.WithDeadline(context.Background(), timeNow().Add(500*time.Millisecond)) + defer cancel() + + req, err := NewRequestWithContext(ctx, http.MethodGet, "http://example.com", nil) + if err != nil { + t.Fatalf("err: %v", err) + } + + client.CheckRetry = func(_ context.Context, resp *http.Response, err error) (bool, error) { + return true, nil + } + + backoffCalls := 0 + client.Backoff = func(min, max time.Duration, attemptNum int, resp *http.Response) time.Duration { + backoffCalls++ + return time.Second + } + + doCalls := 0 + client.HTTPClient = &http.Client{ + Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) { + doCalls++ + return &http.Response{ + StatusCode: http.StatusTooManyRequests, + Body: io.NopCloser(strings.NewReader("retry later")), + Header: make(http.Header), + Request: r, + }, nil + }), + } + + _, err = client.Do(req) + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("expected context deadline exceeded, got: %v", err) + } + + if doCalls != 1 { + t.Fatalf("expected 1 request attempt before aborting retry wait, got %d", doCalls) + } + + if backoffCalls != 1 { + t.Fatalf("expected Backoff to be consulted once, got %d", backoffCalls) + } +} + func TestClient_StandardClient(t *testing.T) { // Create a retryable HTTP client. client := NewClient()