| OLD | NEW |
| 1 // Copyright 2015 The LUCI Authors. All rights reserved. | 1 // Copyright 2015 The LUCI Authors. All rights reserved. |
| 2 // Use of this source code is governed under the Apache License, Version 2.0 | 2 // Use of this source code is governed under the Apache License, Version 2.0 |
| 3 // that can be found in the LICENSE file. | 3 // that can be found in the LICENSE file. |
| 4 | 4 |
| 5 package xsrf | 5 package xsrf |
| 6 | 6 |
| 7 import ( | 7 import ( |
| 8 "html/template" | 8 "html/template" |
| 9 "net/http" | 9 "net/http" |
| 10 "net/http/httptest" | 10 "net/http/httptest" |
| 11 "net/url" | 11 "net/url" |
| 12 "strings" | 12 "strings" |
| 13 "testing" | 13 "testing" |
| 14 "time" | 14 "time" |
| 15 | 15 |
| 16 "github.com/julienschmidt/httprouter" | 16 "github.com/julienschmidt/httprouter" |
| 17 "golang.org/x/net/context" | 17 "golang.org/x/net/context" |
| 18 | 18 |
| 19 "github.com/luci/luci-go/common/clock/testclock" | 19 "github.com/luci/luci-go/common/clock/testclock" |
| 20 "github.com/luci/luci-go/server/router" |
| 20 "github.com/luci/luci-go/server/secrets/testsecrets" | 21 "github.com/luci/luci-go/server/secrets/testsecrets" |
| 21 | 22 |
| 22 . "github.com/smartystreets/goconvey/convey" | 23 . "github.com/smartystreets/goconvey/convey" |
| 23 ) | 24 ) |
| 24 | 25 |
| 25 func TestXsrf(t *testing.T) { | 26 func TestXsrf(t *testing.T) { |
| 26 Convey("Token + Check", t, func() { | 27 Convey("Token + Check", t, func() { |
| 27 c := makeContext() | 28 c := makeContext() |
| 28 tok, err := Token(c) | 29 tok, err := Token(c) |
| 29 So(err, ShouldBeNil) | 30 So(err, ShouldBeNil) |
| 30 So(Check(c, tok), ShouldBeNil) | 31 So(Check(c, tok), ShouldBeNil) |
| 31 So(Check(c, tok+"abc"), ShouldNotBeNil) | 32 So(Check(c, tok+"abc"), ShouldNotBeNil) |
| 32 }) | 33 }) |
| 33 | 34 |
| 34 Convey("TokenField works", t, func() { | 35 Convey("TokenField works", t, func() { |
| 35 c := makeContext() | 36 c := makeContext() |
| 36 So(TokenField(c), ShouldResemble, | 37 So(TokenField(c), ShouldResemble, |
| 37 template.HTML("<input type=\"hidden\" name=\"xsrf_token\
" "+ | 38 template.HTML("<input type=\"hidden\" name=\"xsrf_token\
" "+ |
| 38 "value=\"AXsiX2kiOiIxNDQyMjcwNTIwMDAwIn1ceiDv1yf
NK9OHcdb209l3fM4p_gn-Uaembaa8gr3WXg\">")) | 39 "value=\"AXsiX2kiOiIxNDQyMjcwNTIwMDAwIn1ceiDv1yf
NK9OHcdb209l3fM4p_gn-Uaembaa8gr3WXg\">")) |
| 39 }) | 40 }) |
| 40 | 41 |
| 41 Convey("Middleware works", t, func() { | 42 Convey("Middleware works", t, func() { |
| 42 c := makeContext() | 43 c := makeContext() |
| 43 tok, _ := Token(c) | 44 tok, _ := Token(c) |
| 44 | 45 |
| 45 » » h := WithTokenCheck(func(c context.Context, rw http.ResponseWrit
er, r *http.Request, p httprouter.Params) { | 46 » » initialize := func(c context.Context, rw http.ResponseWriter, r
*http.Request, p httprouter.Params) router.Handler { |
| 46 » » » rw.Write([]byte("hi")) | 47 » » » return func(ctx *router.Context) { |
| 47 » » }) | 48 » » » » ctx.Context = c |
| 49 » » » » ctx.Writer = rw |
| 50 » » » » ctx.Request = r |
| 51 » » » » ctx.Params = p |
| 52 » » » } |
| 53 » » } |
| 54 » » handler := func(c *router.Context) { |
| 55 » » » c.Writer.Write([]byte("hi")) |
| 56 » » } |
| 48 | 57 |
| 49 // Has token -> works. | 58 // Has token -> works. |
| 50 rec := httptest.NewRecorder() | 59 rec := httptest.NewRecorder() |
| 51 req := makeRequest(tok) | 60 req := makeRequest(tok) |
| 52 » » h(c, rec, req, nil) | 61 » » router.ChainHandlers(initialize(c, rec, req, nil), WithTokenChec
k(), handler)() |
| 53 So(rec.Code, ShouldEqual, 200) | 62 So(rec.Code, ShouldEqual, 200) |
| 54 | 63 |
| 55 // No token. | 64 // No token. |
| 56 rec = httptest.NewRecorder() | 65 rec = httptest.NewRecorder() |
| 57 req = makeRequest("") | 66 req = makeRequest("") |
| 58 » » h(c, rec, req, nil) | 67 » » router.ChainHandlers(initialize(c, rec, req, nil), WithTokenChec
k(), handler)() |
| 59 So(rec.Code, ShouldEqual, 403) | 68 So(rec.Code, ShouldEqual, 403) |
| 60 | 69 |
| 61 // Bad token. | 70 // Bad token. |
| 62 rec = httptest.NewRecorder() | 71 rec = httptest.NewRecorder() |
| 63 req = makeRequest("blah") | 72 req = makeRequest("blah") |
| 64 » » h(c, rec, req, nil) | 73 » » router.ChainHandlers(initialize(c, rec, req, nil), WithTokenChec
k(), handler)() |
| 65 So(rec.Code, ShouldEqual, 403) | 74 So(rec.Code, ShouldEqual, 403) |
| 66 }) | 75 }) |
| 67 } | 76 } |
| 68 | 77 |
| 69 func makeContext() context.Context { | 78 func makeContext() context.Context { |
| 70 c := testsecrets.Use(context.Background()) | 79 c := testsecrets.Use(context.Background()) |
| 71 c, _ = testclock.UseTime(c, time.Unix(1442270520, 0)) | 80 c, _ = testclock.UseTime(c, time.Unix(1442270520, 0)) |
| 72 return c | 81 return c |
| 73 } | 82 } |
| 74 | 83 |
| 75 func makeRequest(tok string) *http.Request { | 84 func makeRequest(tok string) *http.Request { |
| 76 body := url.Values{} | 85 body := url.Values{} |
| 77 if tok != "" { | 86 if tok != "" { |
| 78 body.Add("xsrf_token", tok) | 87 body.Add("xsrf_token", tok) |
| 79 } | 88 } |
| 80 req, _ := http.NewRequest("POST", "https://example.com", strings.NewRead
er(body.Encode())) | 89 req, _ := http.NewRequest("POST", "https://example.com", strings.NewRead
er(body.Encode())) |
| 81 req.Header.Add("Content-Type", "application/x-www-form-urlencoded") | 90 req.Header.Add("Content-Type", "application/x-www-form-urlencoded") |
| 82 return req | 91 return req |
| 83 } | 92 } |
| OLD | NEW |