| OLD | NEW |
| 1 // Copyright 2015 The LUCI Authors. All rights reserved. | 1 // Copyright 2015 The LUCI Authors. All rights reserved. |
| 2 // Use of this source code is governed under the Apache License, Version 2.0 | 2 // Use of this source code is governed under the Apache License, Version 2.0 |
| 3 // that can be found in the LICENSE file. | 3 // that can be found in the LICENSE file. |
| 4 | 4 |
| 5 package gaemiddleware | 5 package gaemiddleware |
| 6 | 6 |
| 7 import ( | 7 import ( |
| 8 "net/http" | 8 "net/http" |
| 9 "net/http/httptest" | 9 "net/http/httptest" |
| 10 "testing" | 10 "testing" |
| 11 | 11 |
| 12 "github.com/julienschmidt/httprouter" | 12 "github.com/julienschmidt/httprouter" |
| 13 "github.com/luci/luci-go/appengine/gaetesting" | 13 "github.com/luci/luci-go/appengine/gaetesting" |
| 14 "github.com/luci/luci-go/server/router" |
| 14 . "github.com/smartystreets/goconvey/convey" | 15 . "github.com/smartystreets/goconvey/convey" |
| 15 "golang.org/x/net/context" | 16 "golang.org/x/net/context" |
| 16 ) | 17 ) |
| 17 | 18 |
| 18 func init() { | 19 func init() { |
| 19 // disable this so that we can actually check the logic in these middlew
ares | 20 // disable this so that we can actually check the logic in these middlew
ares |
| 20 devAppserverBypassFn = func(context.Context) bool { return false } | 21 devAppserverBypassFn = func(context.Context) bool { return false } |
| 21 } | 22 } |
| 22 | 23 |
| 23 func TestRequireCron(t *testing.T) { | 24 func TestRequireCron(t *testing.T) { |
| 24 t.Parallel() | 25 t.Parallel() |
| 25 | 26 |
| 27 initialize := func(w http.ResponseWriter, r *http.Request, p httprouter.
Params) router.Handler { |
| 28 return func(c *router.Context) { |
| 29 c.Writer = w |
| 30 c.Request = r |
| 31 c.Params = p |
| 32 } |
| 33 } |
| 34 |
| 26 Convey("Test RequireCron", t, func() { | 35 Convey("Test RequireCron", t, func() { |
| 27 hit := false | 36 hit := false |
| 28 » » f := func(c context.Context, rw http.ResponseWriter, r *http.Req
uest, p httprouter.Params) { | 37 |
| 38 » » f := func(c *router.Context) { |
| 29 hit = true | 39 hit = true |
| 30 » » » rw.Write([]byte("ok")) | 40 » » » c.Writer.Write([]byte("ok")) |
| 31 } | 41 } |
| 32 | 42 |
| 33 Convey("from non-cron fails", func() { | 43 Convey("from non-cron fails", func() { |
| 34 rec := httptest.NewRecorder() | 44 rec := httptest.NewRecorder() |
| 35 » » » gaetesting.BaseTest(RequireCron(f))(rec, &http.Request{}
, nil) | 45 » » » router.ChainHandlers(initialize(rec, &http.Request{}, ni
l), gaetesting.BaseTest(), RequireCron(), f)() |
| 36 So(hit, ShouldBeFalse) | 46 So(hit, ShouldBeFalse) |
| 37 So(rec.Body.String(), ShouldEqual, "error: must be run f
rom cron") | 47 So(rec.Body.String(), ShouldEqual, "error: must be run f
rom cron") |
| 38 So(rec.Code, ShouldEqual, http.StatusForbidden) | 48 So(rec.Code, ShouldEqual, http.StatusForbidden) |
| 39 }) | 49 }) |
| 40 | 50 |
| 41 Convey("from cron succeeds", func() { | 51 Convey("from cron succeeds", func() { |
| 42 rec := httptest.NewRecorder() | 52 rec := httptest.NewRecorder() |
| 43 » » » gaetesting.BaseTest(RequireCron(f))(rec, &http.Request{ | 53 » » » router.ChainHandlers(initialize(rec, &http.Request{ |
| 44 Header: http.Header{ | 54 Header: http.Header{ |
| 45 http.CanonicalHeaderKey("x-appengine-cro
n"): []string{"true"}, | 55 http.CanonicalHeaderKey("x-appengine-cro
n"): []string{"true"}, |
| 46 }, | 56 }, |
| 47 » » » }, nil) | 57 » » » }, nil), gaetesting.BaseTest(), RequireCron(), f)() |
| 48 So(hit, ShouldBeTrue) | 58 So(hit, ShouldBeTrue) |
| 49 So(rec.Body.String(), ShouldEqual, "ok") | 59 So(rec.Body.String(), ShouldEqual, "ok") |
| 50 So(rec.Code, ShouldEqual, http.StatusOK) | 60 So(rec.Code, ShouldEqual, http.StatusOK) |
| 51 }) | 61 }) |
| 52 }) | 62 }) |
| 53 } | 63 } |
| 54 | 64 |
| 55 func TestRequireTQ(t *testing.T) { | 65 func TestRequireTQ(t *testing.T) { |
| 56 t.Parallel() | 66 t.Parallel() |
| 57 | 67 |
| 68 initialize := func(w http.ResponseWriter, r *http.Request, p httprouter.
Params) router.Handler { |
| 69 return func(c *router.Context) { |
| 70 c.Writer = w |
| 71 c.Request = r |
| 72 c.Params = p |
| 73 } |
| 74 } |
| 75 |
| 58 Convey("Test RequireTQ", t, func() { | 76 Convey("Test RequireTQ", t, func() { |
| 59 hit := false | 77 hit := false |
| 60 » » f := func(c context.Context, rw http.ResponseWriter, r *http.Req
uest, p httprouter.Params) { | 78 » » f := func(c *router.Context) { |
| 61 hit = true | 79 hit = true |
| 62 » » » rw.Write([]byte("ok")) | 80 » » » c.Writer.Write([]byte("ok")) |
| 63 } | 81 } |
| 64 | 82 |
| 65 Convey("from non-tq fails (wat)", func() { | 83 Convey("from non-tq fails (wat)", func() { |
| 66 rec := httptest.NewRecorder() | 84 rec := httptest.NewRecorder() |
| 67 » » » gaetesting.BaseTest(RequireTaskQueue("wat", f))(rec, &ht
tp.Request{}, nil) | 85 » » » router.ChainHandlers(initialize(rec, &http.Request{}, ni
l), gaetesting.BaseTest(), RequireTaskQueue("wat"), f)() |
| 68 So(hit, ShouldBeFalse) | 86 So(hit, ShouldBeFalse) |
| 69 So(rec.Body.String(), ShouldEqual, "error: must be run f
rom the correct taskqueue") | 87 So(rec.Body.String(), ShouldEqual, "error: must be run f
rom the correct taskqueue") |
| 70 So(rec.Code, ShouldEqual, http.StatusForbidden) | 88 So(rec.Code, ShouldEqual, http.StatusForbidden) |
| 71 }) | 89 }) |
| 72 | 90 |
| 73 Convey("from non-tq fails", func() { | 91 Convey("from non-tq fails", func() { |
| 74 rec := httptest.NewRecorder() | 92 rec := httptest.NewRecorder() |
| 75 » » » gaetesting.BaseTest(RequireTaskQueue("", f))(rec, &http.
Request{}, nil) | 93 » » » router.ChainHandlers(initialize(rec, &http.Request{}, ni
l), gaetesting.BaseTest(), RequireTaskQueue(""), f)() |
| 76 So(hit, ShouldBeFalse) | 94 So(hit, ShouldBeFalse) |
| 77 So(rec.Body.String(), ShouldEqual, "error: must be run f
rom the correct taskqueue") | 95 So(rec.Body.String(), ShouldEqual, "error: must be run f
rom the correct taskqueue") |
| 78 So(rec.Code, ShouldEqual, http.StatusForbidden) | 96 So(rec.Code, ShouldEqual, http.StatusForbidden) |
| 79 }) | 97 }) |
| 80 | 98 |
| 81 Convey("from wrong tq fails (wat)", func() { | 99 Convey("from wrong tq fails (wat)", func() { |
| 82 rec := httptest.NewRecorder() | 100 rec := httptest.NewRecorder() |
| 83 » » » gaetesting.BaseTest(RequireTaskQueue("wat", f))(rec, &ht
tp.Request{ | 101 » » » router.ChainHandlers(initialize(rec, &http.Request{ |
| 84 Header: http.Header{ | 102 Header: http.Header{ |
| 85 http.CanonicalHeaderKey("x-appengine-que
uename"): []string{"else"}, | 103 http.CanonicalHeaderKey("x-appengine-que
uename"): []string{"else"}, |
| 86 }, | 104 }, |
| 87 » » » }, nil) | 105 » » » }, nil), gaetesting.BaseTest(), RequireTaskQueue("wat"),
f)() |
| 88 So(hit, ShouldBeFalse) | 106 So(hit, ShouldBeFalse) |
| 89 So(rec.Body.String(), ShouldEqual, "error: must be run f
rom the correct taskqueue") | 107 So(rec.Body.String(), ShouldEqual, "error: must be run f
rom the correct taskqueue") |
| 90 So(rec.Code, ShouldEqual, http.StatusForbidden) | 108 So(rec.Code, ShouldEqual, http.StatusForbidden) |
| 91 }) | 109 }) |
| 92 | 110 |
| 93 Convey("from right tq succeeds (wat)", func() { | 111 Convey("from right tq succeeds (wat)", func() { |
| 94 rec := httptest.NewRecorder() | 112 rec := httptest.NewRecorder() |
| 95 » » » gaetesting.BaseTest(RequireTaskQueue("wat", f))(rec, &ht
tp.Request{ | 113 » » » router.ChainHandlers(initialize(rec, &http.Request{ |
| 96 Header: http.Header{ | 114 Header: http.Header{ |
| 97 http.CanonicalHeaderKey("x-appengine-que
uename"): []string{"wat"}, | 115 http.CanonicalHeaderKey("x-appengine-que
uename"): []string{"wat"}, |
| 98 }, | 116 }, |
| 99 » » » }, nil) | 117 » » » }, nil), gaetesting.BaseTest(), RequireTaskQueue("wat"),
f)() |
| 100 So(hit, ShouldBeTrue) | 118 So(hit, ShouldBeTrue) |
| 101 So(rec.Body.String(), ShouldEqual, "ok") | 119 So(rec.Body.String(), ShouldEqual, "ok") |
| 102 So(rec.Code, ShouldEqual, http.StatusOK) | 120 So(rec.Code, ShouldEqual, http.StatusOK) |
| 103 }) | 121 }) |
| 104 | 122 |
| 105 Convey("from any tq succeeds", func() { | 123 Convey("from any tq succeeds", func() { |
| 106 rec := httptest.NewRecorder() | 124 rec := httptest.NewRecorder() |
| 107 » » » gaetesting.BaseTest(RequireTaskQueue("", f))(rec, &http.
Request{ | 125 » » » router.ChainHandlers(initialize(rec, &http.Request{ |
| 108 Header: http.Header{ | 126 Header: http.Header{ |
| 109 http.CanonicalHeaderKey("x-appengine-que
uename"): []string{"wat"}, | 127 http.CanonicalHeaderKey("x-appengine-que
uename"): []string{"wat"}, |
| 110 }, | 128 }, |
| 111 » » » }, nil) | 129 » » » }, nil), gaetesting.BaseTest(), RequireTaskQueue(""), f)
() |
| 112 So(hit, ShouldBeTrue) | 130 So(hit, ShouldBeTrue) |
| 113 So(rec.Body.String(), ShouldEqual, "ok") | 131 So(rec.Body.String(), ShouldEqual, "ok") |
| 114 So(rec.Code, ShouldEqual, http.StatusOK) | 132 So(rec.Code, ShouldEqual, http.StatusOK) |
| 115 }) | 133 }) |
| 116 }) | 134 }) |
| 117 } | 135 } |
| OLD | NEW |