| OLD | NEW |
| 1 // Copyright 2015 The LUCI Authors. All rights reserved. | 1 // Copyright 2015 The LUCI Authors. All rights reserved. |
| 2 // Use of this source code is governed under the Apache License, Version 2.0 | 2 // Use of this source code is governed under the Apache License, Version 2.0 |
| 3 // that can be found in the LICENSE file. | 3 // that can be found in the LICENSE file. |
| 4 | 4 |
| 5 package xsrf | 5 package xsrf |
| 6 | 6 |
| 7 import ( | 7 import ( |
| 8 "html/template" | 8 "html/template" |
| 9 "net/http" | 9 "net/http" |
| 10 "net/http/httptest" | 10 "net/http/httptest" |
| 11 "net/url" | 11 "net/url" |
| 12 "strings" | 12 "strings" |
| 13 "testing" | 13 "testing" |
| 14 "time" | 14 "time" |
| 15 | 15 |
| 16 "github.com/julienschmidt/httprouter" | |
| 17 "golang.org/x/net/context" | 16 "golang.org/x/net/context" |
| 18 | 17 |
| 19 "github.com/luci/luci-go/common/clock/testclock" | 18 "github.com/luci/luci-go/common/clock/testclock" |
| 19 "github.com/luci/luci-go/server/router" |
| 20 "github.com/luci/luci-go/server/secrets/testsecrets" | 20 "github.com/luci/luci-go/server/secrets/testsecrets" |
| 21 | 21 |
| 22 . "github.com/smartystreets/goconvey/convey" | 22 . "github.com/smartystreets/goconvey/convey" |
| 23 ) | 23 ) |
| 24 | 24 |
| 25 func TestXsrf(t *testing.T) { | 25 func TestXsrf(t *testing.T) { |
| 26 Convey("Token + Check", t, func() { | 26 Convey("Token + Check", t, func() { |
| 27 c := makeContext() | 27 c := makeContext() |
| 28 tok, err := Token(c) | 28 tok, err := Token(c) |
| 29 So(err, ShouldBeNil) | 29 So(err, ShouldBeNil) |
| 30 So(Check(c, tok), ShouldBeNil) | 30 So(Check(c, tok), ShouldBeNil) |
| 31 So(Check(c, tok+"abc"), ShouldNotBeNil) | 31 So(Check(c, tok+"abc"), ShouldNotBeNil) |
| 32 }) | 32 }) |
| 33 | 33 |
| 34 Convey("TokenField works", t, func() { | 34 Convey("TokenField works", t, func() { |
| 35 c := makeContext() | 35 c := makeContext() |
| 36 So(TokenField(c), ShouldResemble, | 36 So(TokenField(c), ShouldResemble, |
| 37 template.HTML("<input type=\"hidden\" name=\"xsrf_token\
" "+ | 37 template.HTML("<input type=\"hidden\" name=\"xsrf_token\
" "+ |
| 38 "value=\"AXsiX2kiOiIxNDQyMjcwNTIwMDAwIn1ceiDv1yf
NK9OHcdb209l3fM4p_gn-Uaembaa8gr3WXg\">")) | 38 "value=\"AXsiX2kiOiIxNDQyMjcwNTIwMDAwIn1ceiDv1yf
NK9OHcdb209l3fM4p_gn-Uaembaa8gr3WXg\">")) |
| 39 }) | 39 }) |
| 40 | 40 |
| 41 Convey("Middleware works", t, func() { | 41 Convey("Middleware works", t, func() { |
| 42 c := makeContext() | 42 c := makeContext() |
| 43 tok, _ := Token(c) | 43 tok, _ := Token(c) |
| 44 | 44 |
| 45 » » h := WithTokenCheck(func(c context.Context, rw http.ResponseWrit
er, r *http.Request, p httprouter.Params) { | 45 » » h := func(c *router.Context) { |
| 46 » » » rw.Write([]byte("hi")) | 46 » » » c.Writer.Write([]byte("hi")) |
| 47 » » }) | 47 » » } |
| 48 » » mc := router.MiddlewareChain{WithTokenCheck} |
| 48 | 49 |
| 49 // Has token -> works. | 50 // Has token -> works. |
| 50 rec := httptest.NewRecorder() | 51 rec := httptest.NewRecorder() |
| 51 req := makeRequest(tok) | 52 req := makeRequest(tok) |
| 52 » » h(c, rec, req, nil) | 53 » » router.RunMiddleware(&router.Context{ |
| 54 » » » Context: c, |
| 55 » » » Writer: rec, |
| 56 » » » Request: req, |
| 57 » » }, mc, h) |
| 53 So(rec.Code, ShouldEqual, 200) | 58 So(rec.Code, ShouldEqual, 200) |
| 54 | 59 |
| 55 // No token. | 60 // No token. |
| 56 rec = httptest.NewRecorder() | 61 rec = httptest.NewRecorder() |
| 57 req = makeRequest("") | 62 req = makeRequest("") |
| 58 » » h(c, rec, req, nil) | 63 » » router.RunMiddleware(&router.Context{ |
| 64 » » » Context: c, |
| 65 » » » Writer: rec, |
| 66 » » » Request: req, |
| 67 » » }, mc, h) |
| 59 So(rec.Code, ShouldEqual, 403) | 68 So(rec.Code, ShouldEqual, 403) |
| 60 | 69 |
| 61 // Bad token. | 70 // Bad token. |
| 62 rec = httptest.NewRecorder() | 71 rec = httptest.NewRecorder() |
| 63 req = makeRequest("blah") | 72 req = makeRequest("blah") |
| 64 » » h(c, rec, req, nil) | 73 » » router.RunMiddleware(&router.Context{ |
| 74 » » » Context: c, |
| 75 » » » Writer: rec, |
| 76 » » » Request: req, |
| 77 » » }, mc, h) |
| 65 So(rec.Code, ShouldEqual, 403) | 78 So(rec.Code, ShouldEqual, 403) |
| 66 }) | 79 }) |
| 67 } | 80 } |
| 68 | 81 |
| 69 func makeContext() context.Context { | 82 func makeContext() context.Context { |
| 70 c := testsecrets.Use(context.Background()) | 83 c := testsecrets.Use(context.Background()) |
| 71 c, _ = testclock.UseTime(c, time.Unix(1442270520, 0)) | 84 c, _ = testclock.UseTime(c, time.Unix(1442270520, 0)) |
| 72 return c | 85 return c |
| 73 } | 86 } |
| 74 | 87 |
| 75 func makeRequest(tok string) *http.Request { | 88 func makeRequest(tok string) *http.Request { |
| 76 body := url.Values{} | 89 body := url.Values{} |
| 77 if tok != "" { | 90 if tok != "" { |
| 78 body.Add("xsrf_token", tok) | 91 body.Add("xsrf_token", tok) |
| 79 } | 92 } |
| 80 req, _ := http.NewRequest("POST", "https://example.com", strings.NewRead
er(body.Encode())) | 93 req, _ := http.NewRequest("POST", "https://example.com", strings.NewRead
er(body.Encode())) |
| 81 req.Header.Add("Content-Type", "application/x-www-form-urlencoded") | 94 req.Header.Add("Content-Type", "application/x-www-form-urlencoded") |
| 82 return req | 95 return req |
| 83 } | 96 } |
| OLD | NEW |