| OLD | NEW |
| 1 // Copyright 2015 The LUCI Authors. All rights reserved. | 1 // Copyright 2015 The LUCI Authors. All rights reserved. |
| 2 // Use of this source code is governed under the Apache License, Version 2.0 | 2 // Use of this source code is governed under the Apache License, Version 2.0 |
| 3 // that can be found in the LICENSE file. | 3 // that can be found in the LICENSE file. |
| 4 | 4 |
| 5 package auth | 5 package auth |
| 6 | 6 |
| 7 import ( | 7 import ( |
| 8 "fmt" | 8 "fmt" |
| 9 "net/http" | 9 "net/http" |
| 10 | 10 |
| 11 "github.com/julienschmidt/httprouter" | |
| 12 "golang.org/x/net/context" | 11 "golang.org/x/net/context" |
| 13 | 12 |
| 14 "github.com/luci/luci-go/common/errors" | 13 "github.com/luci/luci-go/common/errors" |
| 15 "github.com/luci/luci-go/common/logging" | 14 "github.com/luci/luci-go/common/logging" |
| 16 | 15 |
| 17 "github.com/luci/luci-go/server/auth/identity" | 16 "github.com/luci/luci-go/server/auth/identity" |
| 18 » "github.com/luci/luci-go/server/middleware" | 17 » "github.com/luci/luci-go/server/router" |
| 19 ) | 18 ) |
| 20 | 19 |
| 21 type authenticatorKey int | 20 type authenticatorKey int |
| 22 | 21 |
| 23 // SetAuthenticator injects copy of Authenticator (list of auth methods) into | 22 // SetAuthenticator injects copy of Authenticator (list of auth methods) into |
| 24 // the context to use by default in LoginURL, LogoutURL and Authenticate. | 23 // the context to use by default in LoginURL, LogoutURL and Authenticate. |
| 25 // Usually installed into the context by some base middleware. | 24 // Usually installed into the context by some base middleware. |
| 26 func SetAuthenticator(c context.Context, a Authenticator) context.Context { | 25 func SetAuthenticator(c context.Context, a Authenticator) context.Context { |
| 27 return context.WithValue(c, authenticatorKey(0), append(Authenticator(ni
l), a...)) | 26 return context.WithValue(c, authenticatorKey(0), append(Authenticator(ni
l), a...)) |
| 28 } | 27 } |
| 29 | 28 |
| 30 // GetAuthenticator extracts instance of Authenticator (list of auth methods) | 29 // GetAuthenticator extracts instance of Authenticator (list of auth methods) |
| 31 // from the context. Returns nil if no authenticator is set. | 30 // from the context. Returns nil if no authenticator is set. |
| 32 func GetAuthenticator(c context.Context) Authenticator { | 31 func GetAuthenticator(c context.Context) Authenticator { |
| 33 if a, ok := c.Value(authenticatorKey(0)).(Authenticator); ok { | 32 if a, ok := c.Value(authenticatorKey(0)).(Authenticator); ok { |
| 34 return a | 33 return a |
| 35 } | 34 } |
| 36 return nil | 35 return nil |
| 37 } | 36 } |
| 38 | 37 |
| 39 // Use is a middleware that simply puts given Authenticator into the context. | 38 // Use is a middleware that simply puts given Authenticator into the context. |
| 40 func Use(h middleware.Handler, a Authenticator) middleware.Handler { | 39 func Use(a Authenticator) router.Middleware { |
| 41 » return func(c context.Context, rw http.ResponseWriter, r *http.Request,
p httprouter.Params) { | 40 » return func(c *router.Context, next router.Handler) { |
| 42 » » h(SetAuthenticator(c, a), rw, r, p) | 41 » » c.Context = SetAuthenticator(c.Context, a) |
| 42 » » next(c) |
| 43 } | 43 } |
| 44 } | 44 } |
| 45 | 45 |
| 46 // LoginURL returns a URL that, when visited, prompts the user to sign in, | 46 // LoginURL returns a URL that, when visited, prompts the user to sign in, |
| 47 // then redirects the user to the URL specified by dest. It is wrapper around | 47 // then redirects the user to the URL specified by dest. It is wrapper around |
| 48 // LoginURL method of Authenticator in the context. | 48 // LoginURL method of Authenticator in the context. |
| 49 func LoginURL(c context.Context, dest string) (string, error) { | 49 func LoginURL(c context.Context, dest string) (string, error) { |
| 50 return GetAuthenticator(c).LoginURL(c, dest) | 50 return GetAuthenticator(c).LoginURL(c, dest) |
| 51 } | 51 } |
| 52 | 52 |
| 53 // LogoutURL returns a URL that, when visited, signs the user out, | 53 // LogoutURL returns a URL that, when visited, signs the user out, |
| 54 // then redirects the user to the URL specified by dest. It is wrapper around | 54 // then redirects the user to the URL specified by dest. It is wrapper around |
| 55 // LogoutURL method of Authenticator in the context. | 55 // LogoutURL method of Authenticator in the context. |
| 56 func LogoutURL(c context.Context, dest string) (string, error) { | 56 func LogoutURL(c context.Context, dest string) (string, error) { |
| 57 return GetAuthenticator(c).LogoutURL(c, dest) | 57 return GetAuthenticator(c).LogoutURL(c, dest) |
| 58 } | 58 } |
| 59 | 59 |
| 60 // Authenticate returns a wrapper around middleware.Handler that performs | 60 // Authenticate is a middleware that performs authentication (using Authenticato
r |
| 61 // authentication (using Authenticator in the context) and calls `h`. | 61 // in the context) and calls next handler. |
| 62 func Authenticate(h middleware.Handler) middleware.Handler { | 62 func Authenticate(c *router.Context, next router.Handler) { |
| 63 » return func(c context.Context, rw http.ResponseWriter, r *http.Request,
p httprouter.Params) { | 63 » a := GetAuthenticator(c.Context) |
| 64 » » a := GetAuthenticator(c) | 64 » if a == nil { |
| 65 » » if a == nil { | 65 » » replyError(c.Context, c.Writer, 500, "Authentication middleware
is not configured") |
| 66 » » » replyError(c, rw, 500, "Authentication middleware is not
configured") | 66 » » return |
| 67 » » » return | 67 » } |
| 68 » » } | 68 » ctx, err := a.Authenticate(c.Context, c.Request) |
| 69 » » ctx, err := a.Authenticate(c, r) | 69 » switch { |
| 70 » » switch { | 70 » case errors.IsTransient(err): |
| 71 » » case errors.IsTransient(err): | 71 » » replyError(c.Context, c.Writer, 500, fmt.Sprintf("Transient erro
r during authentication - %s", err)) |
| 72 » » » replyError(c, rw, 500, fmt.Sprintf("Transient error duri
ng authentication - %s", err)) | 72 » case err != nil: |
| 73 » » case err != nil: | 73 » » replyError(c.Context, c.Writer, 401, fmt.Sprintf("Authentication
error - %s", err)) |
| 74 » » » replyError(c, rw, 401, fmt.Sprintf("Authentication error
- %s", err)) | 74 » default: |
| 75 » » default: | 75 » » c.Context = ctx |
| 76 » » » h(ctx, rw, r, p) | 76 » » next(c) |
| 77 » » } | |
| 78 } | 77 } |
| 79 } | 78 } |
| 80 | 79 |
| 81 // Autologin is a middleware that redirects the user to login page if the user | 80 // Autologin is a middleware that redirects the user to login page if the user |
| 82 // is not signed in yet or authentication methods do not recognize user | 81 // is not signed in yet or authentication methods do not recognize user |
| 83 // credentials. Uses Authenticator instance in the context. | 82 // credentials. Uses Authenticator instance in the context. |
| 84 func Autologin(h middleware.Handler) middleware.Handler { | 83 func Autologin(c *router.Context, next router.Handler) { |
| 85 » return func(c context.Context, rw http.ResponseWriter, r *http.Request,
p httprouter.Params) { | 84 » a := GetAuthenticator(c.Context) |
| 86 » » a := GetAuthenticator(c) | 85 » if a == nil { |
| 87 » » if a == nil { | 86 » » replyError(c.Context, c.Writer, 500, "Authentication middleware
is not configured") |
| 88 » » » replyError(c, rw, 500, "Authentication middleware is not
configured") | 87 » » return |
| 88 » } |
| 89 » ctx, err := a.Authenticate(c.Context, c.Request) |
| 90 |
| 91 » switch { |
| 92 » case errors.IsTransient(err): |
| 93 » » replyError(c.Context, c.Writer, 500, fmt.Sprintf("Transient erro
r during authentication - %s", err)) |
| 94 |
| 95 » case err != nil: |
| 96 » » replyError(c.Context, c.Writer, 401, fmt.Sprintf("Authentication
error - %s", err)) |
| 97 |
| 98 » case CurrentIdentity(ctx).Kind() == identity.Anonymous: |
| 99 » » dest := c.Request.RequestURI |
| 100 » » if dest == "" { |
| 101 » » » // Make r.URL relative. |
| 102 » » » destURL := *c.Request.URL |
| 103 » » » destURL.Host = "" |
| 104 » » » destURL.Scheme = "" |
| 105 » » » dest = destURL.String() |
| 106 » » } |
| 107 » » url, err := a.LoginURL(c.Context, dest) |
| 108 » » if err != nil { |
| 109 » » » if errors.IsTransient(err) { |
| 110 » » » » replyError(c.Context, c.Writer, 500, fmt.Sprintf
("Transient error during authentication - %s", err)) |
| 111 » » » } else { |
| 112 » » » » replyError(c.Context, c.Writer, 401, fmt.Sprintf
("Authentication error - %s", err)) |
| 113 » » » } |
| 89 return | 114 return |
| 90 } | 115 } |
| 91 » » ctx, err := a.Authenticate(c, r) | 116 » » http.Redirect(c.Writer, c.Request, url, 302) |
| 92 | 117 |
| 93 » » switch { | 118 » default: |
| 94 » » case errors.IsTransient(err): | 119 » » c.Context = ctx |
| 95 » » » replyError(c, rw, 500, fmt.Sprintf("Transient error duri
ng authentication - %s", err)) | 120 » » next(c) |
| 96 | |
| 97 » » case err != nil: | |
| 98 » » » replyError(c, rw, 401, fmt.Sprintf("Authentication error
- %s", err)) | |
| 99 | |
| 100 » » case CurrentIdentity(ctx).Kind() == identity.Anonymous: | |
| 101 » » » dest := r.RequestURI | |
| 102 » » » if dest == "" { | |
| 103 » » » » // Make r.URL relative. | |
| 104 » » » » destURL := *r.URL | |
| 105 » » » » destURL.Host = "" | |
| 106 » » » » destURL.Scheme = "" | |
| 107 » » » » dest = destURL.String() | |
| 108 » » » } | |
| 109 » » » url, err := a.LoginURL(c, dest) | |
| 110 » » » if err != nil { | |
| 111 » » » » if errors.IsTransient(err) { | |
| 112 » » » » » replyError(c, rw, 500, fmt.Sprintf("Tran
sient error during authentication - %s", err)) | |
| 113 » » » » } else { | |
| 114 » » » » » replyError(c, rw, 401, fmt.Sprintf("Auth
entication error - %s", err)) | |
| 115 » » » » } | |
| 116 » » » » return | |
| 117 » » » } | |
| 118 » » » http.Redirect(rw, r, url, 302) | |
| 119 | |
| 120 » » default: | |
| 121 » » » h(ctx, rw, r, p) | |
| 122 » » } | |
| 123 } | 121 } |
| 124 } | 122 } |
| 125 | 123 |
| 126 // replyError logs the error and writes it to ResponseWriter. | 124 // replyError logs the error and writes it to ResponseWriter. |
| 127 func replyError(c context.Context, rw http.ResponseWriter, code int, msg string)
{ | 125 func replyError(c context.Context, rw http.ResponseWriter, code int, msg string)
{ |
| 128 logging.Errorf(c, "HTTP %d: %s", code, msg) | 126 logging.Errorf(c, "HTTP %d: %s", code, msg) |
| 129 http.Error(rw, msg, code) | 127 http.Error(rw, msg, code) |
| 130 } | 128 } |
| OLD | NEW |