diff --git a/godo.go b/godo.go index f9ab41d0..abb4cfef 100644 --- a/godo.go +++ b/godo.go @@ -6,7 +6,6 @@ import ( "encoding/json" "fmt" "io" - "io/ioutil" "net/http" "net/url" "reflect" @@ -27,9 +26,11 @@ const ( userAgent = "godo/" + libraryVersion mediaType = "application/json" - headerRateLimit = "RateLimit-Limit" - headerRateRemaining = "RateLimit-Remaining" - headerRateReset = "RateLimit-Reset" + headerRateLimit = "RateLimit-Limit" + headerRateRemaining = "RateLimit-Remaining" + headerRateReset = "RateLimit-Reset" + headerRequestID = "x-request-id" + internalHeaderRetryAttempts = "X-Godo-Retry-Attempts" ) // Client manages communication with DigitalOcean V2 API. @@ -178,6 +179,9 @@ type ErrorResponse struct { // RequestID returned from the API, useful to contact support. RequestID string `json:"request_id"` + + // Attempts is the number of times the request was attempted when retries are enabled. + Attempts int } // Rate contains the rate limit for the current client. @@ -314,6 +318,19 @@ func New(httpClient *http.Client, opts ...ClientOpt) (*Client, error) { // if timeout is set, it is maintained before overwriting client with StandardClient() retryableClient.HTTPClient.Timeout = c.HTTPClient.Timeout + // This custom ErrorHandler is required to provide errors that are consistent + // with a *godo.ErrorResponse and a non-nil *godo.Response while providing + // insight into retries using an internal header. + retryableClient.ErrorHandler = func(resp *http.Response, err error, numTries int) (*http.Response, error) { + if resp != nil { + resp.Header.Add(internalHeaderRetryAttempts, strconv.Itoa(numTries)) + + return resp, err + } + + return resp, err + } + var source *oauth2.Transport if _, ok := c.HTTPClient.Transport.(*oauth2.Transport); ok { source = c.HTTPClient.Transport.(*oauth2.Transport) @@ -489,7 +506,7 @@ func (c *Client) Do(ctx context.Context, req *http.Request, v interface{}) (*Res // won't reuse it anyway. const maxBodySlurpSize = 2 << 10 if resp.ContentLength == -1 || resp.ContentLength <= maxBodySlurpSize { - io.CopyN(ioutil.Discard, resp.Body, maxBodySlurpSize) + io.CopyN(io.Discard, resp.Body, maxBodySlurpSize) } if rerr := resp.Body.Close(); err == nil { @@ -539,12 +556,17 @@ func DoRequestWithClient( } func (r *ErrorResponse) Error() string { + var attempted string + if r.Attempts > 0 { + attempted = fmt.Sprintf("; giving up after %d attempt(s)", r.Attempts) + } + if r.RequestID != "" { - return fmt.Sprintf("%v %v: %d (request %q) %v", - r.Response.Request.Method, r.Response.Request.URL, r.Response.StatusCode, r.RequestID, r.Message) + return fmt.Sprintf("%v %v: %d (request %q) %v%s", + r.Response.Request.Method, r.Response.Request.URL, r.Response.StatusCode, r.RequestID, r.Message, attempted) } - return fmt.Sprintf("%v %v: %d %v", - r.Response.Request.Method, r.Response.Request.URL, r.Response.StatusCode, r.Message) + return fmt.Sprintf("%v %v: %d %v%s", + r.Response.Request.Method, r.Response.Request.URL, r.Response.StatusCode, r.Message, attempted) } // CheckResponse checks the API response for errors, and returns them if present. A response is considered an @@ -557,7 +579,7 @@ func CheckResponse(r *http.Response) error { } errorResponse := &ErrorResponse{Response: r} - data, err := ioutil.ReadAll(r.Body) + data, err := io.ReadAll(r.Body) if err == nil && len(data) > 0 { err := json.Unmarshal(data, errorResponse) if err != nil { @@ -566,7 +588,12 @@ func CheckResponse(r *http.Response) error { } if errorResponse.RequestID == "" { - errorResponse.RequestID = r.Header.Get("x-request-id") + errorResponse.RequestID = r.Header.Get(headerRequestID) + } + + attempts, strconvErr := strconv.Atoi(r.Header.Get(internalHeaderRetryAttempts)) + if strconvErr == nil { + errorResponse.Attempts = attempts } return errorResponse diff --git a/godo_test.go b/godo_test.go index 323eba72..e7575ac9 100644 --- a/godo_test.go +++ b/godo_test.go @@ -4,7 +4,7 @@ import ( "bytes" "context" "fmt" - "io/ioutil" + "io" "log" "net/http" "net/http/httptest" @@ -191,7 +191,7 @@ func TestNewRequest(t *testing.T) { } // test body was JSON encoded - body, _ := ioutil.ReadAll(req.Body) + body, _ := io.ReadAll(req.Body) if string(body) != outBody { t.Errorf("NewRequest(%v)Body = %v, expected %v", inBody, string(body), outBody) } @@ -242,7 +242,7 @@ func TestNewRequest_withUserData(t *testing.T) { } // test body was JSON encoded - body, _ := ioutil.ReadAll(req.Body) + body, _ := io.ReadAll(req.Body) if string(body) != outBody { t.Errorf("NewRequest(%v)Body = %v, expected %v", inBody, string(body), outBody) } @@ -271,7 +271,7 @@ func TestNewRequest_withDropletAgent(t *testing.T) { } // test body was JSON encoded - body, _ := ioutil.ReadAll(req.Body) + body, _ := io.ReadAll(req.Body) if string(body) != outBody { t.Errorf("NewRequest(%v)Body = %v, expected %v", inBody, string(body), outBody) } @@ -406,7 +406,7 @@ func TestCheckResponse(t *testing.T) { input: &http.Response{ Request: &http.Request{}, StatusCode: http.StatusBadRequest, - Body: ioutil.NopCloser(strings.NewReader(`{"message":"m", + Body: io.NopCloser(strings.NewReader(`{"message":"m", "errors": [{"resource": "r", "field": "f", "code": "c"}]}`)), }, expected: &ErrorResponse{ @@ -418,7 +418,7 @@ func TestCheckResponse(t *testing.T) { input: &http.Response{ Request: &http.Request{}, StatusCode: http.StatusBadRequest, - Body: ioutil.NopCloser(strings.NewReader(`{"message":"m", "request_id": "dead-beef", + Body: io.NopCloser(strings.NewReader(`{"message":"m", "request_id": "dead-beef", "errors": [{"resource": "r", "field": "f", "code": "c"}]}`)), }, expected: &ErrorResponse{ @@ -432,7 +432,7 @@ func TestCheckResponse(t *testing.T) { Request: &http.Request{}, StatusCode: http.StatusBadRequest, Header: testHeaders, - Body: ioutil.NopCloser(strings.NewReader(`{"message":"m", + Body: io.NopCloser(strings.NewReader(`{"message":"m", "errors": [{"resource": "r", "field": "f", "code": "c"}]}`)), }, expected: &ErrorResponse{ @@ -448,7 +448,7 @@ func TestCheckResponse(t *testing.T) { Request: &http.Request{}, StatusCode: http.StatusBadRequest, Header: testHeaders, - Body: ioutil.NopCloser(strings.NewReader(`{"message":"m", "request_id": "dead-beef-body", + Body: io.NopCloser(strings.NewReader(`{"message":"m", "request_id": "dead-beef-body", "errors": [{"resource": "r", "field": "f", "code": "c"}]}`)), }, expected: &ErrorResponse{ @@ -463,7 +463,7 @@ func TestCheckResponse(t *testing.T) { input: &http.Response{ Request: &http.Request{}, StatusCode: http.StatusBadRequest, - Body: ioutil.NopCloser(strings.NewReader("")), + Body: io.NopCloser(strings.NewReader("")), }, expected: &ErrorResponse{}, }, @@ -614,14 +614,15 @@ func TestWithRetryAndBackoffs(t *testing.T) { url, _ := url.Parse(server.URL) mux.HandleFunc("/foo", func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(500) + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte(`{"id": "bad_request", "message": "broken"}`)) }) tokenSrc := oauth2.StaticTokenSource(&oauth2.Token{ AccessToken: "new_token", }) - oauth_client := oauth2.NewClient(oauth2.NoContext, tokenSrc) + oauthClient := oauth2.NewClient(oauth2.NoContext, tokenSrc) waitMax := PtrTo(6.0) waitMin := PtrTo(3.0) @@ -633,7 +634,7 @@ func TestWithRetryAndBackoffs(t *testing.T) { } // Create the client. Use short retry windows so we fail faster. - client, err := New(oauth_client, WithRetryAndBackoffs(retryConfig)) + client, err := New(oauthClient, WithRetryAndBackoffs(retryConfig)) client.BaseURL = url if err != nil { t.Fatalf("err: %v", err) @@ -645,13 +646,12 @@ func TestWithRetryAndBackoffs(t *testing.T) { t.Fatalf("err: %v", err) } - expectingErr := "giving up after 4 attempt(s)" + expectingErr := fmt.Sprintf("GET %s/foo: 500 broken; giving up after 4 attempt(s)", url) // Send the request. _, err = client.Do(context.Background(), req, nil) - if err == nil || !strings.HasSuffix(err.Error(), expectingErr) { + if err == nil || (err.Error() != expectingErr) { t.Fatalf("expected giving up error, got: %#v", err) } - } func TestWithRetryAndBackoffsLogger(t *testing.T) { @@ -701,6 +701,70 @@ func TestWithRetryAndBackoffsLogger(t *testing.T) { } } +func TestWithRetryAndBackoffsForResourceMethods(t *testing.T) { + // Mock server which always responds 500. + setup() + defer teardown() + + url, _ := url.Parse(server.URL) + mux.HandleFunc("/v2/account", func(w http.ResponseWriter, r *http.Request) { + w.Header().Add(headerRateLimit, "500") + w.Header().Add(headerRateRemaining, "42") + w.Header().Add(headerRateReset, "1372700873") + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte(`{"id": "bad_request", "message": "broken"}`)) + }) + + tokenSrc := oauth2.StaticTokenSource(&oauth2.Token{ + AccessToken: "new_token", + }) + + oauthClient := oauth2.NewClient(context.TODO(), tokenSrc) + + waitMax := PtrTo(6.0) + waitMin := PtrTo(3.0) + + retryConfig := RetryConfig{ + RetryMax: 3, + RetryWaitMin: waitMin, + RetryWaitMax: waitMax, + } + + // Create the client. Use short retry windows so we fail faster. + client, err := New(oauthClient, WithRetryAndBackoffs(retryConfig)) + client.BaseURL = url + if err != nil { + t.Fatalf("err: %v", err) + } + + expectingErr := fmt.Sprintf("GET %s/v2/account: 500 broken; giving up after 4 attempt(s)", url) + _, resp, err := client.Account.Get(context.Background()) + if err == nil || (err.Error() != expectingErr) { + t.Fatalf("expected giving up error, got: %s", err.Error()) + } + if _, ok := err.(*ErrorResponse); !ok { + t.Fatalf("expected error to be *godo.ErrorResponse, got: %#v", err) + } + + // Ensure that the *Response is properly populated + if resp == nil { + t.Fatal("expected non-nil *godo.Response") + } + if resp.StatusCode != http.StatusInternalServerError { + t.Errorf("expected %d, got: %d", http.StatusInternalServerError, resp.StatusCode) + } + if expected := 500; resp.Rate.Limit != expected { + t.Errorf("expected rate limit to be populate: got: %v, expected: %v", resp.Rate.Limit, expected) + } + if expected := 42; resp.Rate.Remaining != expected { + t.Errorf("expected rate limit remaining to be populate: got: %v, expected: %v", resp.Rate.Remaining, expected) + } + reset := time.Date(2013, 7, 1, 17, 47, 53, 0, time.UTC) + if client.Rate.Reset.UTC() != reset { + t.Errorf("expected rate limit reset to be populate: got: %v, expected: %v", resp.Rate.Reset, reset) + } +} + func checkCurrentPage(t *testing.T, resp *Response, expectedPage int) { links := resp.Links p, err := links.CurrentPage()