| Index: common/lhttp/client_test.go
|
| diff --git a/common/lhttp/client_test.go b/common/lhttp/client_test.go
|
| index a6c721d51aeee95b6532bce7736038b55238d000..5df535dfdc27d594786f06cd113da42505f48d75 100644
|
| --- a/common/lhttp/client_test.go
|
| +++ b/common/lhttp/client_test.go
|
| @@ -12,11 +12,11 @@ import (
|
| "io/ioutil"
|
| "net/http"
|
| "net/http/httptest"
|
| + "sync"
|
| "testing"
|
|
|
| "golang.org/x/net/context"
|
|
|
| - "github.com/luci/luci-go/common/errors"
|
| "github.com/luci/luci-go/common/retry"
|
|
|
| "github.com/maruel/ut"
|
| @@ -125,8 +125,7 @@ func TestNewRequestGETFail(t *testing.T) {
|
| })
|
|
|
| status, err := clientReq()
|
| - ut.AssertEqual(t, true, errors.IsTransient(err))
|
| - ut.AssertEqual(t, "http request failed: Internal Server Error (HTTP 500)", err.Error())
|
| + ut.AssertEqual(t, "http request failed: Internal Server Error (HTTP 500) (attempts: 4)", err.Error())
|
| ut.AssertEqual(t, 500, status)
|
| }
|
|
|
| @@ -169,8 +168,7 @@ func TestGetJSONBadResult(t *testing.T) {
|
|
|
| actual := map[string]string{}
|
| status, err := GetJSON(ctx, fast, http.DefaultClient, ts.URL, &actual)
|
| - ut.AssertEqual(t, true, errors.IsTransient(err))
|
| - ut.AssertEqual(t, "bad response "+ts.URL+": invalid character 'y' looking for beginning of value", err.Error())
|
| + ut.AssertEqual(t, "bad response "+ts.URL+": invalid character 'y' looking for beginning of value (attempts: 4)", err.Error())
|
| ut.AssertEqual(t, 200, status)
|
| ut.AssertEqual(t, map[string]string{}, actual)
|
| }
|
| @@ -188,8 +186,7 @@ func TestGetJSONBadResultIgnore(t *testing.T) {
|
| defer ts.Close()
|
|
|
| status, err := GetJSON(ctx, fast, http.DefaultClient, ts.URL, nil)
|
| - ut.AssertEqual(t, true, errors.IsTransient(err))
|
| - ut.AssertEqual(t, "bad response "+ts.URL+": invalid character 'y' looking for beginning of value", err.Error())
|
| + ut.AssertEqual(t, "bad response "+ts.URL+": invalid character 'y' looking for beginning of value (attempts: 4)", err.Error())
|
| ut.AssertEqual(t, 200, status)
|
| }
|
|
|
| @@ -203,7 +200,7 @@ func TestGetJSONBadContentTypeIgnore(t *testing.T) {
|
| defer ts.Close()
|
|
|
| status, err := GetJSON(ctx, fast, http.DefaultClient, ts.URL, nil)
|
| - ut.AssertEqual(t, "unexpected Content-Type, expected \"application/json\", got \"text/plain; charset=utf-8\"", err.Error())
|
| + ut.AssertEqual(t, "unexpected Content-Type, expected \"application/json\", got \"text/plain; charset=utf-8\" (attempts: 4)", err.Error())
|
| ut.AssertEqual(t, 200, status)
|
| }
|
|
|
| @@ -253,6 +250,96 @@ func TestPostJSONwithHeaders(t *testing.T) {
|
| ut.AssertEqual(t, 1, serverCalls)
|
| }
|
|
|
| +func TestNewRequestClosesBody(t *testing.T) {
|
| + ctx := context.Background()
|
| + serverCalls := 0
|
| +
|
| + // Return a 500 for the first 2 requests.
|
| + ts := httptest.NewServer(
|
| + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
| + defer r.Body.Close()
|
| + serverCalls++
|
| + if serverCalls <= 2 {
|
| + w.WriteHeader(500)
|
| + }
|
| + fmt.Fprintf(w, "Hello World!\n")
|
| + }))
|
| + defer ts.Close()
|
| +
|
| + rt := &trackingRoundTripper{RoundTripper: http.DefaultTransport}
|
| + hc := &http.Client{Transport: rt}
|
| + httpReq := httpReqGen("GET", ts.URL, nil)
|
| +
|
| + clientCalls := 0
|
| + var lastResp *http.Response
|
| + req := NewRequest(ctx, hc, fast, httpReq, func(resp *http.Response) error {
|
| + clientCalls++
|
| + lastResp = resp
|
| + return resp.Body.Close()
|
| + })
|
| +
|
| + status, err := req()
|
| + if err != nil {
|
| + t.Fatalf("req returned err %v, want nil", err)
|
| + }
|
| + if got, want := status, http.StatusOK; got != want {
|
| + t.Errorf("req returned status %d, want %d", got, want)
|
| + }
|
| +
|
| + // We expect only one client call, but three requests through to the server.
|
| + if got, want := clientCalls, 1; got != want {
|
| + t.Errorf("handler callback invoked %d times, want %d", got, want)
|
| + }
|
| + if got, want := len(rt.Responses), 3; got != want {
|
| + t.Errorf("len(Responses) = %d, want %d", got, want)
|
| + }
|
| +
|
| + // Check that the last response is the one we handled, and that all the bodies
|
| + // were closed.
|
| + if got, want := lastResp, rt.Responses[2]; got != want {
|
| + t.Errorf("Last Response did not match Response in handler callback.\nGot: %v\nWant: %v", got, want)
|
| + }
|
| + for i, resp := range rt.Responses {
|
| + rc := resp.Body.(*trackingReadCloser)
|
| + if !rc.Closed {
|
| + t.Errorf("Reponses[%d].Body was not closed", i)
|
| + }
|
| + }
|
| +}
|
| +
|
| +// trackingRoundTripper wraps an http.RoundTripper, keeping track of any
|
| +// returned Responses. Each response's Body, when set, is wrapped with a
|
| +// trackingReadCloser.
|
| +type trackingRoundTripper struct {
|
| + http.RoundTripper
|
| +
|
| + mu sync.Mutex
|
| + Responses []*http.Response
|
| +}
|
| +
|
| +func (t *trackingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
| + resp, err := t.RoundTripper.RoundTrip(req)
|
| + if resp != nil && resp.Body != nil {
|
| + resp.Body = &trackingReadCloser{ReadCloser: resp.Body}
|
| + }
|
| + t.mu.Lock()
|
| + defer t.mu.Unlock()
|
| + t.Responses = append(t.Responses, resp)
|
| + return resp, err
|
| +}
|
| +
|
| +// trackingReadCloser wraps an io.ReadCloser, keeping track of whether Closed was
|
| +// called.
|
| +type trackingReadCloser struct {
|
| + io.ReadCloser
|
| + Closed bool
|
| +}
|
| +
|
| +func (t *trackingReadCloser) Close() error {
|
| + t.Closed = true
|
| + return t.ReadCloser.Close()
|
| +}
|
| +
|
| // Private details.
|
|
|
| func fast() retry.Iterator {
|
|
|