| Index: appengine/gaemiddleware/appengine_test.go
|
| diff --git a/appengine/gaemiddleware/appengine_test.go b/appengine/gaemiddleware/appengine_test.go
|
| index a95bc160401b2b597e28ae51485249722fbd3f81..d29707e4a4a33cb33dfee76350a95c94cda3697e 100644
|
| --- a/appengine/gaemiddleware/appengine_test.go
|
| +++ b/appengine/gaemiddleware/appengine_test.go
|
| @@ -2,116 +2,151 @@
|
| // Use of this source code is governed under the Apache License, Version 2.0
|
| // that can be found in the LICENSE file.
|
|
|
| package gaemiddleware
|
|
|
| import (
|
| "net/http"
|
| "net/http/httptest"
|
| "testing"
|
|
|
| - "github.com/julienschmidt/httprouter"
|
| "github.com/luci/luci-go/appengine/gaetesting"
|
| + "github.com/luci/luci-go/server/router"
|
| . "github.com/smartystreets/goconvey/convey"
|
| "golang.org/x/net/context"
|
| )
|
|
|
| func init() {
|
| // disable this so that we can actually check the logic in these middlewares
|
| devAppserverBypassFn = func(context.Context) bool { return false }
|
| }
|
|
|
| func TestRequireCron(t *testing.T) {
|
| t.Parallel()
|
|
|
| Convey("Test RequireCron", t, func() {
|
| hit := false
|
| - f := func(c context.Context, rw http.ResponseWriter, r *http.Request, p httprouter.Params) {
|
| + f := func(c *router.Context) {
|
| hit = true
|
| - rw.Write([]byte("ok"))
|
| + c.Writer.Write([]byte("ok"))
|
| }
|
|
|
| Convey("from non-cron fails", func() {
|
| rec := httptest.NewRecorder()
|
| - gaetesting.BaseTest(RequireCron(f))(rec, &http.Request{}, nil)
|
| + c := &router.Context{
|
| + Context: gaetesting.TestingContext(),
|
| + Writer: rec,
|
| + Request: &http.Request{},
|
| + }
|
| + router.RunMiddleware(c, router.MiddlewareChain{RequireCron}, f)
|
| So(hit, ShouldBeFalse)
|
| So(rec.Body.String(), ShouldEqual, "error: must be run from cron")
|
| So(rec.Code, ShouldEqual, http.StatusForbidden)
|
| })
|
|
|
| Convey("from cron succeeds", func() {
|
| rec := httptest.NewRecorder()
|
| - gaetesting.BaseTest(RequireCron(f))(rec, &http.Request{
|
| - Header: http.Header{
|
| - http.CanonicalHeaderKey("x-appengine-cron"): []string{"true"},
|
| + c := &router.Context{
|
| + Context: gaetesting.TestingContext(),
|
| + Writer: rec,
|
| + Request: &http.Request{
|
| + Header: http.Header{
|
| + http.CanonicalHeaderKey("x-appengine-cron"): []string{"true"},
|
| + },
|
| },
|
| - }, nil)
|
| + }
|
| + router.RunMiddleware(c, router.MiddlewareChain{RequireCron}, f)
|
| So(hit, ShouldBeTrue)
|
| So(rec.Body.String(), ShouldEqual, "ok")
|
| So(rec.Code, ShouldEqual, http.StatusOK)
|
| })
|
| })
|
| }
|
|
|
| func TestRequireTQ(t *testing.T) {
|
| t.Parallel()
|
|
|
| Convey("Test RequireTQ", t, func() {
|
| hit := false
|
| - f := func(c context.Context, rw http.ResponseWriter, r *http.Request, p httprouter.Params) {
|
| + f := func(c *router.Context) {
|
| hit = true
|
| - rw.Write([]byte("ok"))
|
| + c.Writer.Write([]byte("ok"))
|
| }
|
|
|
| Convey("from non-tq fails (wat)", func() {
|
| rec := httptest.NewRecorder()
|
| - gaetesting.BaseTest(RequireTaskQueue("wat", f))(rec, &http.Request{}, nil)
|
| + c := &router.Context{
|
| + Context: gaetesting.TestingContext(),
|
| + Writer: rec,
|
| + Request: &http.Request{},
|
| + }
|
| + router.RunMiddleware(c, router.MiddlewareChain{RequireTaskQueue("wat")}, f)
|
| So(hit, ShouldBeFalse)
|
| So(rec.Body.String(), ShouldEqual, "error: must be run from the correct taskqueue")
|
| So(rec.Code, ShouldEqual, http.StatusForbidden)
|
| })
|
|
|
| Convey("from non-tq fails", func() {
|
| rec := httptest.NewRecorder()
|
| - gaetesting.BaseTest(RequireTaskQueue("", f))(rec, &http.Request{}, nil)
|
| + c := &router.Context{
|
| + Context: gaetesting.TestingContext(),
|
| + Writer: rec,
|
| + Request: &http.Request{},
|
| + }
|
| + router.RunMiddleware(c, router.MiddlewareChain{RequireTaskQueue("")}, f)
|
| So(hit, ShouldBeFalse)
|
| So(rec.Body.String(), ShouldEqual, "error: must be run from the correct taskqueue")
|
| So(rec.Code, ShouldEqual, http.StatusForbidden)
|
| })
|
|
|
| Convey("from wrong tq fails (wat)", func() {
|
| rec := httptest.NewRecorder()
|
| - gaetesting.BaseTest(RequireTaskQueue("wat", f))(rec, &http.Request{
|
| - Header: http.Header{
|
| - http.CanonicalHeaderKey("x-appengine-queuename"): []string{"else"},
|
| + c := &router.Context{
|
| + Context: gaetesting.TestingContext(),
|
| + Writer: rec,
|
| + Request: &http.Request{
|
| + Header: http.Header{
|
| + http.CanonicalHeaderKey("x-appengine-queuename"): []string{"else"},
|
| + },
|
| },
|
| - }, nil)
|
| + }
|
| + router.RunMiddleware(c, router.MiddlewareChain{RequireTaskQueue("wat")}, f)
|
| So(hit, ShouldBeFalse)
|
| So(rec.Body.String(), ShouldEqual, "error: must be run from the correct taskqueue")
|
| So(rec.Code, ShouldEqual, http.StatusForbidden)
|
| })
|
|
|
| Convey("from right tq succeeds (wat)", func() {
|
| rec := httptest.NewRecorder()
|
| - gaetesting.BaseTest(RequireTaskQueue("wat", f))(rec, &http.Request{
|
| - Header: http.Header{
|
| - http.CanonicalHeaderKey("x-appengine-queuename"): []string{"wat"},
|
| + c := &router.Context{
|
| + Context: gaetesting.TestingContext(),
|
| + Writer: rec,
|
| + Request: &http.Request{
|
| + Header: http.Header{
|
| + http.CanonicalHeaderKey("x-appengine-queuename"): []string{"wat"},
|
| + },
|
| },
|
| - }, nil)
|
| + }
|
| + router.RunMiddleware(c, router.MiddlewareChain{RequireTaskQueue("wat")}, f)
|
| So(hit, ShouldBeTrue)
|
| So(rec.Body.String(), ShouldEqual, "ok")
|
| So(rec.Code, ShouldEqual, http.StatusOK)
|
| })
|
|
|
| Convey("from any tq succeeds", func() {
|
| rec := httptest.NewRecorder()
|
| - gaetesting.BaseTest(RequireTaskQueue("", f))(rec, &http.Request{
|
| - Header: http.Header{
|
| - http.CanonicalHeaderKey("x-appengine-queuename"): []string{"wat"},
|
| + c := &router.Context{
|
| + Context: gaetesting.TestingContext(),
|
| + Writer: rec,
|
| + Request: &http.Request{
|
| + Header: http.Header{
|
| + http.CanonicalHeaderKey("x-appengine-queuename"): []string{"wat"},
|
| + },
|
| },
|
| - }, nil)
|
| + }
|
| + router.RunMiddleware(c, router.MiddlewareChain{RequireTaskQueue("")}, f)
|
| So(hit, ShouldBeTrue)
|
| So(rec.Body.String(), ShouldEqual, "ok")
|
| So(rec.Code, ShouldEqual, http.StatusOK)
|
| })
|
| })
|
| }
|
|
|