Index: tools/chromeproxy_testserver/server_test.go |
diff --git a/tools/chromeproxy_testserver/server_test.go b/tools/chromeproxy_testserver/server_test.go |
new file mode 100644 |
index 0000000000000000000000000000000000000000..d215131194a231d3233c647a1f8e01b4c910913c |
--- /dev/null |
+++ b/tools/chromeproxy_testserver/server_test.go |
@@ -0,0 +1,99 @@ |
+package server |
+ |
+import ( |
+ "encoding/base64" |
+ "encoding/json" |
+ "net/http" |
+ "net/http/httptest" |
+ "net/url" |
+ "reflect" |
+ "strconv" |
+ "testing" |
+) |
+ |
+func composeQuery(path string, code int, headers http.Header, body []byte) (string, error) { |
+ u, err := url.Parse(path) |
+ if err != nil { |
+ return "", err |
+ } |
+ q := u.Query() |
+ if code > 0 { |
+ q.Set("respStatus", strconv.Itoa(code)) |
+ } |
+ if headers != nil { |
+ h, err := json.Marshal(headers) |
+ if err != nil { |
+ return "", err |
+ } |
+ q.Set("respHeader", base64.URLEncoding.EncodeToString(h)) |
+ } |
+ if len(body) > 0 { |
+ q.Set("respBody", base64.URLEncoding.EncodeToString(body)) |
+ } |
+ u.RawQuery = q.Encode() |
+ return u.String(), nil |
+} |
+ |
+func TestResponseOverride(t *testing.T) { |
+ tests := []struct { |
+ name string |
+ code int |
+ headers http.Header |
+ body []byte |
+ }{ |
+ {name: "code", code: 204}, |
+ {name: "body", body: []byte("new body")}, |
+ { |
+ name: "headers", |
+ headers: http.Header{ |
+ "Via": []string{"Via1", "Via2"}, |
+ "Content-Type": []string{"random content"}, |
+ }, |
+ }, |
+ { |
+ name: "everything", |
+ code: 204, |
+ body: []byte("new body"), |
+ headers: http.Header{ |
+ "Via": []string{"Via1", "Via2"}, |
+ "Content-Type": []string{"random content"}, |
+ }, |
+ }, |
+ } |
+ |
+ for _, test := range tests { |
+ u, err := composeQuery("http://test.com/override", test.code, test.headers, test.body) |
+ if err != nil { |
+ t.Errorf("%s: composeQuery: %v", test.name, err) |
+ return |
+ } |
+ req, err := http.NewRequest("GET", u, nil) |
+ if err != nil { |
+ t.Errorf("%s: http.NewRequest: %v", test.name, err) |
+ return |
+ } |
+ w := httptest.NewRecorder() |
+ defaultResponse(w, req) |
+ if test.code > 0 { |
+ if got, want := w.Code, test.code; got != want { |
+ t.Errorf("%s: response code: got %d want %d", test.name, got, want) |
+ return |
+ } |
+ } |
+ if test.headers != nil { |
+ for k, want := range test.headers { |
+ got, ok := w.HeaderMap[k] |
+ if !ok || !reflect.DeepEqual(got, want) { |
+ t.Errorf("%s: header %s: code: got %v want %v", test.name, k, got, want) |
+ return |
+ } |
+ } |
+ } |
+ if test.body != nil { |
+ if got, want := string(w.Body.Bytes()), string(test.body); got != want { |
+ t.Errorf("%s: body: got %s want %s", test.name, got, want) |
+ return |
+ } |
+ } |
+ } |
+} |