| Index: server/auth/xsrf/xsrf_test.go
|
| diff --git a/server/auth/xsrf/xsrf_test.go b/server/auth/xsrf/xsrf_test.go
|
| index fbda5788f4e6d1e68da6e44a757e93fa7999ee45..cf554ad18b694f39ff8b84b9cbd1497f395658b6 100644
|
| --- a/server/auth/xsrf/xsrf_test.go
|
| +++ b/server/auth/xsrf/xsrf_test.go
|
| @@ -6,24 +6,24 @@ package xsrf
|
|
|
| import (
|
| "html/template"
|
| "net/http"
|
| "net/http/httptest"
|
| "net/url"
|
| "strings"
|
| "testing"
|
| "time"
|
|
|
| - "github.com/julienschmidt/httprouter"
|
| "golang.org/x/net/context"
|
|
|
| "github.com/luci/luci-go/common/clock/testclock"
|
| + "github.com/luci/luci-go/server/router"
|
| "github.com/luci/luci-go/server/secrets/testsecrets"
|
|
|
| . "github.com/smartystreets/goconvey/convey"
|
| )
|
|
|
| func TestXsrf(t *testing.T) {
|
| Convey("Token + Check", t, func() {
|
| c := makeContext()
|
| tok, err := Token(c)
|
| So(err, ShouldBeNil)
|
| @@ -35,40 +35,53 @@ func TestXsrf(t *testing.T) {
|
| c := makeContext()
|
| So(TokenField(c), ShouldResemble,
|
| template.HTML("<input type=\"hidden\" name=\"xsrf_token\" "+
|
| "value=\"AXsiX2kiOiIxNDQyMjcwNTIwMDAwIn1ceiDv1yfNK9OHcdb209l3fM4p_gn-Uaembaa8gr3WXg\">"))
|
| })
|
|
|
| Convey("Middleware works", t, func() {
|
| c := makeContext()
|
| tok, _ := Token(c)
|
|
|
| - h := WithTokenCheck(func(c context.Context, rw http.ResponseWriter, r *http.Request, p httprouter.Params) {
|
| - rw.Write([]byte("hi"))
|
| - })
|
| + h := func(c *router.Context) {
|
| + c.Writer.Write([]byte("hi"))
|
| + }
|
| + mc := router.MiddlewareChain{WithTokenCheck}
|
|
|
| // Has token -> works.
|
| rec := httptest.NewRecorder()
|
| req := makeRequest(tok)
|
| - h(c, rec, req, nil)
|
| + router.RunMiddleware(&router.Context{
|
| + Context: c,
|
| + Writer: rec,
|
| + Request: req,
|
| + }, mc, h)
|
| So(rec.Code, ShouldEqual, 200)
|
|
|
| // No token.
|
| rec = httptest.NewRecorder()
|
| req = makeRequest("")
|
| - h(c, rec, req, nil)
|
| + router.RunMiddleware(&router.Context{
|
| + Context: c,
|
| + Writer: rec,
|
| + Request: req,
|
| + }, mc, h)
|
| So(rec.Code, ShouldEqual, 403)
|
|
|
| // Bad token.
|
| rec = httptest.NewRecorder()
|
| req = makeRequest("blah")
|
| - h(c, rec, req, nil)
|
| + router.RunMiddleware(&router.Context{
|
| + Context: c,
|
| + Writer: rec,
|
| + Request: req,
|
| + }, mc, h)
|
| So(rec.Code, ShouldEqual, 403)
|
| })
|
| }
|
|
|
| func makeContext() context.Context {
|
| c := testsecrets.Use(context.Background())
|
| c, _ = testclock.UseTime(c, time.Unix(1442270520, 0))
|
| return c
|
| }
|
|
|
|
|