| OLD | NEW |
| 1 // Copyright 2015 The LUCI Authors. All rights reserved. | 1 // Copyright 2015 The LUCI Authors. All rights reserved. |
| 2 // Use of this source code is governed under the Apache License, Version 2.0 | 2 // Use of this source code is governed under the Apache License, Version 2.0 |
| 3 // that can be found in the LICENSE file. | 3 // that can be found in the LICENSE file. |
| 4 | 4 |
| 5 package gaemiddleware | 5 package gaemiddleware |
| 6 | 6 |
| 7 import ( | 7 import ( |
| 8 "net/http" | 8 "net/http" |
| 9 "net/http/httptest" | 9 "net/http/httptest" |
| 10 "testing" | 10 "testing" |
| 11 | 11 |
| 12 "github.com/julienschmidt/httprouter" | |
| 13 "github.com/luci/luci-go/appengine/gaetesting" | 12 "github.com/luci/luci-go/appengine/gaetesting" |
| 13 "github.com/luci/luci-go/server/router" |
| 14 . "github.com/smartystreets/goconvey/convey" | 14 . "github.com/smartystreets/goconvey/convey" |
| 15 "golang.org/x/net/context" | 15 "golang.org/x/net/context" |
| 16 ) | 16 ) |
| 17 | 17 |
| 18 func init() { | 18 func init() { |
| 19 // disable this so that we can actually check the logic in these middlew
ares | 19 // disable this so that we can actually check the logic in these middlew
ares |
| 20 devAppserverBypassFn = func(context.Context) bool { return false } | 20 devAppserverBypassFn = func(context.Context) bool { return false } |
| 21 } | 21 } |
| 22 | 22 |
| 23 func TestRequireCron(t *testing.T) { | 23 func TestRequireCron(t *testing.T) { |
| 24 t.Parallel() | 24 t.Parallel() |
| 25 | 25 |
| 26 Convey("Test RequireCron", t, func() { | 26 Convey("Test RequireCron", t, func() { |
| 27 hit := false | 27 hit := false |
| 28 » » f := func(c context.Context, rw http.ResponseWriter, r *http.Req
uest, p httprouter.Params) { | 28 » » f := func(c *router.Context) { |
| 29 hit = true | 29 hit = true |
| 30 » » » rw.Write([]byte("ok")) | 30 » » » c.Writer.Write([]byte("ok")) |
| 31 } | 31 } |
| 32 | 32 |
| 33 Convey("from non-cron fails", func() { | 33 Convey("from non-cron fails", func() { |
| 34 rec := httptest.NewRecorder() | 34 rec := httptest.NewRecorder() |
| 35 » » » gaetesting.BaseTest(RequireCron(f))(rec, &http.Request{}
, nil) | 35 » » » c := &router.Context{ |
| 36 » » » » Context: gaetesting.TestingContext(), |
| 37 » » » » Writer: rec, |
| 38 » » » » Request: &http.Request{}, |
| 39 » » » } |
| 40 » » » router.RunMiddleware(c, router.MiddlewareChain{RequireCr
on}, f) |
| 36 So(hit, ShouldBeFalse) | 41 So(hit, ShouldBeFalse) |
| 37 So(rec.Body.String(), ShouldEqual, "error: must be run f
rom cron") | 42 So(rec.Body.String(), ShouldEqual, "error: must be run f
rom cron") |
| 38 So(rec.Code, ShouldEqual, http.StatusForbidden) | 43 So(rec.Code, ShouldEqual, http.StatusForbidden) |
| 39 }) | 44 }) |
| 40 | 45 |
| 41 Convey("from cron succeeds", func() { | 46 Convey("from cron succeeds", func() { |
| 42 rec := httptest.NewRecorder() | 47 rec := httptest.NewRecorder() |
| 43 » » » gaetesting.BaseTest(RequireCron(f))(rec, &http.Request{ | 48 » » » c := &router.Context{ |
| 44 » » » » Header: http.Header{ | 49 » » » » Context: gaetesting.TestingContext(), |
| 45 » » » » » http.CanonicalHeaderKey("x-appengine-cro
n"): []string{"true"}, | 50 » » » » Writer: rec, |
| 51 » » » » Request: &http.Request{ |
| 52 » » » » » Header: http.Header{ |
| 53 » » » » » » http.CanonicalHeaderKey("x-appen
gine-cron"): []string{"true"}, |
| 54 » » » » » }, |
| 46 }, | 55 }, |
| 47 » » » }, nil) | 56 » » » } |
| 57 » » » router.RunMiddleware(c, router.MiddlewareChain{RequireCr
on}, f) |
| 48 So(hit, ShouldBeTrue) | 58 So(hit, ShouldBeTrue) |
| 49 So(rec.Body.String(), ShouldEqual, "ok") | 59 So(rec.Body.String(), ShouldEqual, "ok") |
| 50 So(rec.Code, ShouldEqual, http.StatusOK) | 60 So(rec.Code, ShouldEqual, http.StatusOK) |
| 51 }) | 61 }) |
| 52 }) | 62 }) |
| 53 } | 63 } |
| 54 | 64 |
| 55 func TestRequireTQ(t *testing.T) { | 65 func TestRequireTQ(t *testing.T) { |
| 56 t.Parallel() | 66 t.Parallel() |
| 57 | 67 |
| 58 Convey("Test RequireTQ", t, func() { | 68 Convey("Test RequireTQ", t, func() { |
| 59 hit := false | 69 hit := false |
| 60 » » f := func(c context.Context, rw http.ResponseWriter, r *http.Req
uest, p httprouter.Params) { | 70 » » f := func(c *router.Context) { |
| 61 hit = true | 71 hit = true |
| 62 » » » rw.Write([]byte("ok")) | 72 » » » c.Writer.Write([]byte("ok")) |
| 63 } | 73 } |
| 64 | 74 |
| 65 Convey("from non-tq fails (wat)", func() { | 75 Convey("from non-tq fails (wat)", func() { |
| 66 rec := httptest.NewRecorder() | 76 rec := httptest.NewRecorder() |
| 67 » » » gaetesting.BaseTest(RequireTaskQueue("wat", f))(rec, &ht
tp.Request{}, nil) | 77 » » » c := &router.Context{ |
| 78 » » » » Context: gaetesting.TestingContext(), |
| 79 » » » » Writer: rec, |
| 80 » » » » Request: &http.Request{}, |
| 81 » » » } |
| 82 » » » router.RunMiddleware(c, router.MiddlewareChain{RequireTa
skQueue("wat")}, f) |
| 68 So(hit, ShouldBeFalse) | 83 So(hit, ShouldBeFalse) |
| 69 So(rec.Body.String(), ShouldEqual, "error: must be run f
rom the correct taskqueue") | 84 So(rec.Body.String(), ShouldEqual, "error: must be run f
rom the correct taskqueue") |
| 70 So(rec.Code, ShouldEqual, http.StatusForbidden) | 85 So(rec.Code, ShouldEqual, http.StatusForbidden) |
| 71 }) | 86 }) |
| 72 | 87 |
| 73 Convey("from non-tq fails", func() { | 88 Convey("from non-tq fails", func() { |
| 74 rec := httptest.NewRecorder() | 89 rec := httptest.NewRecorder() |
| 75 » » » gaetesting.BaseTest(RequireTaskQueue("", f))(rec, &http.
Request{}, nil) | 90 » » » c := &router.Context{ |
| 91 » » » » Context: gaetesting.TestingContext(), |
| 92 » » » » Writer: rec, |
| 93 » » » » Request: &http.Request{}, |
| 94 » » » } |
| 95 » » » router.RunMiddleware(c, router.MiddlewareChain{RequireTa
skQueue("")}, f) |
| 76 So(hit, ShouldBeFalse) | 96 So(hit, ShouldBeFalse) |
| 77 So(rec.Body.String(), ShouldEqual, "error: must be run f
rom the correct taskqueue") | 97 So(rec.Body.String(), ShouldEqual, "error: must be run f
rom the correct taskqueue") |
| 78 So(rec.Code, ShouldEqual, http.StatusForbidden) | 98 So(rec.Code, ShouldEqual, http.StatusForbidden) |
| 79 }) | 99 }) |
| 80 | 100 |
| 81 Convey("from wrong tq fails (wat)", func() { | 101 Convey("from wrong tq fails (wat)", func() { |
| 82 rec := httptest.NewRecorder() | 102 rec := httptest.NewRecorder() |
| 83 » » » gaetesting.BaseTest(RequireTaskQueue("wat", f))(rec, &ht
tp.Request{ | 103 » » » c := &router.Context{ |
| 84 » » » » Header: http.Header{ | 104 » » » » Context: gaetesting.TestingContext(), |
| 85 » » » » » http.CanonicalHeaderKey("x-appengine-que
uename"): []string{"else"}, | 105 » » » » Writer: rec, |
| 106 » » » » Request: &http.Request{ |
| 107 » » » » » Header: http.Header{ |
| 108 » » » » » » http.CanonicalHeaderKey("x-appen
gine-queuename"): []string{"else"}, |
| 109 » » » » » }, |
| 86 }, | 110 }, |
| 87 » » » }, nil) | 111 » » » } |
| 112 » » » router.RunMiddleware(c, router.MiddlewareChain{RequireTa
skQueue("wat")}, f) |
| 88 So(hit, ShouldBeFalse) | 113 So(hit, ShouldBeFalse) |
| 89 So(rec.Body.String(), ShouldEqual, "error: must be run f
rom the correct taskqueue") | 114 So(rec.Body.String(), ShouldEqual, "error: must be run f
rom the correct taskqueue") |
| 90 So(rec.Code, ShouldEqual, http.StatusForbidden) | 115 So(rec.Code, ShouldEqual, http.StatusForbidden) |
| 91 }) | 116 }) |
| 92 | 117 |
| 93 Convey("from right tq succeeds (wat)", func() { | 118 Convey("from right tq succeeds (wat)", func() { |
| 94 rec := httptest.NewRecorder() | 119 rec := httptest.NewRecorder() |
| 95 » » » gaetesting.BaseTest(RequireTaskQueue("wat", f))(rec, &ht
tp.Request{ | 120 » » » c := &router.Context{ |
| 96 » » » » Header: http.Header{ | 121 » » » » Context: gaetesting.TestingContext(), |
| 97 » » » » » http.CanonicalHeaderKey("x-appengine-que
uename"): []string{"wat"}, | 122 » » » » Writer: rec, |
| 123 » » » » Request: &http.Request{ |
| 124 » » » » » Header: http.Header{ |
| 125 » » » » » » http.CanonicalHeaderKey("x-appen
gine-queuename"): []string{"wat"}, |
| 126 » » » » » }, |
| 98 }, | 127 }, |
| 99 » » » }, nil) | 128 » » » } |
| 129 » » » router.RunMiddleware(c, router.MiddlewareChain{RequireTa
skQueue("wat")}, f) |
| 100 So(hit, ShouldBeTrue) | 130 So(hit, ShouldBeTrue) |
| 101 So(rec.Body.String(), ShouldEqual, "ok") | 131 So(rec.Body.String(), ShouldEqual, "ok") |
| 102 So(rec.Code, ShouldEqual, http.StatusOK) | 132 So(rec.Code, ShouldEqual, http.StatusOK) |
| 103 }) | 133 }) |
| 104 | 134 |
| 105 Convey("from any tq succeeds", func() { | 135 Convey("from any tq succeeds", func() { |
| 106 rec := httptest.NewRecorder() | 136 rec := httptest.NewRecorder() |
| 107 » » » gaetesting.BaseTest(RequireTaskQueue("", f))(rec, &http.
Request{ | 137 » » » c := &router.Context{ |
| 108 » » » » Header: http.Header{ | 138 » » » » Context: gaetesting.TestingContext(), |
| 109 » » » » » http.CanonicalHeaderKey("x-appengine-que
uename"): []string{"wat"}, | 139 » » » » Writer: rec, |
| 140 » » » » Request: &http.Request{ |
| 141 » » » » » Header: http.Header{ |
| 142 » » » » » » http.CanonicalHeaderKey("x-appen
gine-queuename"): []string{"wat"}, |
| 143 » » » » » }, |
| 110 }, | 144 }, |
| 111 » » » }, nil) | 145 » » » } |
| 146 » » » router.RunMiddleware(c, router.MiddlewareChain{RequireTa
skQueue("")}, f) |
| 112 So(hit, ShouldBeTrue) | 147 So(hit, ShouldBeTrue) |
| 113 So(rec.Body.String(), ShouldEqual, "ok") | 148 So(rec.Body.String(), ShouldEqual, "ok") |
| 114 So(rec.Code, ShouldEqual, http.StatusOK) | 149 So(rec.Code, ShouldEqual, http.StatusOK) |
| 115 }) | 150 }) |
| 116 }) | 151 }) |
| 117 } | 152 } |
| OLD | NEW |