| OLD | NEW |
| 1 // Copyright 2015 The LUCI Authors. All rights reserved. | 1 // Copyright 2015 The LUCI Authors. All rights reserved. |
| 2 // Use of this source code is governed under the Apache License, Version 2.0 | 2 // Use of this source code is governed under the Apache License, Version 2.0 |
| 3 // that can be found in the LICENSE file. | 3 // that can be found in the LICENSE file. |
| 4 | 4 |
| 5 package gaemiddleware | 5 package gaemiddleware |
| 6 | 6 |
| 7 import ( | 7 import ( |
| 8 "fmt" | 8 "fmt" |
| 9 "net/http" | 9 "net/http" |
| 10 | 10 |
| 11 "github.com/julienschmidt/httprouter" | |
| 12 "github.com/luci/gae/service/info" | 11 "github.com/luci/gae/service/info" |
| 13 "github.com/luci/luci-go/common/logging" | 12 "github.com/luci/luci-go/common/logging" |
| 14 » "github.com/luci/luci-go/server/middleware" | 13 » "github.com/luci/luci-go/server/router" |
| 15 "golang.org/x/net/context" | 14 "golang.org/x/net/context" |
| 16 ) | 15 ) |
| 17 | 16 |
| 18 var devAppserverBypassFn = func(c context.Context) bool { | 17 var devAppserverBypassFn = func(c context.Context) bool { |
| 19 return info.Get(c).IsDevAppServer() | 18 return info.Get(c).IsDevAppServer() |
| 20 } | 19 } |
| 21 | 20 |
| 22 // RequireCron ensures that this handler was run from the appengine 'cron' | 21 // RequireCron ensures that this handler was run from the appengine 'cron' |
| 23 // service. Otherwise it aborts the request with a StatusForbidden. | 22 // service. Otherwise it aborts the request with a StatusForbidden. |
| 24 // | 23 // |
| 25 // This middleware has no effect when using 'BaseTest' or when running under | 24 // This middleware has no effect when using 'BaseTest' or when running under |
| 26 // dev_appserver.py | 25 // dev_appserver.py |
| 27 func RequireCron(h middleware.Handler) middleware.Handler { | 26 func RequireCron() router.Handler { |
| 28 » return func(c context.Context, rw http.ResponseWriter, r *http.Request,
p httprouter.Params) { | 27 » return func(c *router.Context) { |
| 29 » » if !devAppserverBypassFn(c) { | 28 » » if !devAppserverBypassFn(c.Context) { |
| 30 » » » if r.Header.Get("X-Appengine-Cron") != "true" { | 29 » » » if c.Request.Header.Get("X-Appengine-Cron") != "true" { |
| 31 » » » » rw.WriteHeader(http.StatusForbidden) | 30 » » » » c.Writer.WriteHeader(http.StatusForbidden) |
| 32 » » » » logging.Errorf(c, "request not made from cron") | 31 » » » » logging.Errorf(c.Context, "request not made from
cron") |
| 33 » » » » fmt.Fprint(rw, "error: must be run from cron") | 32 » » » » fmt.Fprint(c.Writer, "error: must be run from cr
on") |
| 33 » » » » c.Abort() |
| 34 return | 34 return |
| 35 } | 35 } |
| 36 } | 36 } |
| 37 h(c, rw, r, p) | |
| 38 } | 37 } |
| 39 } | 38 } |
| 40 | 39 |
| 41 // RequireTaskQueue ensures that this handler was run from the specified | 40 // RequireTaskQueue ensures that this handler was run from the specified |
| 42 // appengine 'taskqueue' queue. Otherwise it aborts the request with | 41 // appengine 'taskqueue' queue. Otherwise it aborts the request with |
| 43 // a StatusForbidden. | 42 // a StatusForbidden. |
| 44 // | 43 // |
| 45 // if `queue` is the empty string, than this simply checks that this handler was | 44 // if `queue` is the empty string, than this simply checks that this handler was |
| 46 // run from ANY appengine taskqueue. | 45 // run from ANY appengine taskqueue. |
| 47 // | 46 // |
| 48 // This middleware has no effect when using 'BaseTest' or when running under | 47 // This middleware has no effect when using 'BaseTest' or when running under |
| 49 // dev_appserver.py | 48 // dev_appserver.py |
| 50 func RequireTaskQueue(queue string, h middleware.Handler) middleware.Handler { | 49 func RequireTaskQueue(queue string) router.Handler { |
| 51 » return func(c context.Context, rw http.ResponseWriter, r *http.Request,
p httprouter.Params) { | 50 » return func(c *router.Context) { |
| 52 » » if !devAppserverBypassFn(c) { | 51 » » if !devAppserverBypassFn(c.Context) { |
| 53 » » » qName := r.Header.Get("X-AppEngine-QueueName") | 52 » » » qName := c.Request.Header.Get("X-AppEngine-QueueName") |
| 54 if qName == "" || (queue != "" && queue != qName) { | 53 if qName == "" || (queue != "" && queue != qName) { |
| 55 » » » » rw.WriteHeader(http.StatusForbidden) | 54 » » » » c.Writer.WriteHeader(http.StatusForbidden) |
| 56 » » » » logging.Errorf(c, "request made from wrong taskq
ueue: %q v %q", qName, queue) | 55 » » » » logging.Errorf(c.Context, "request made from wro
ng taskqueue: %q v %q", qName, queue) |
| 57 » » » » fmt.Fprintf(rw, "error: must be run from the cor
rect taskqueue") | 56 » » » » fmt.Fprintf(c.Writer, "error: must be run from t
he correct taskqueue") |
| 57 » » » » c.Abort() |
| 58 return | 58 return |
| 59 } | 59 } |
| 60 } | 60 } |
| 61 h(c, rw, r, p) | |
| 62 } | 61 } |
| 63 } | 62 } |
| OLD | NEW |