| Index: common/auth/localauth/server.go
|
| diff --git a/common/auth/localauth/server.go b/common/auth/localauth/server.go
|
| index 96b44f63434d291dec3436923b13e78240ba38c9..49df2dc78075e1008eb1380fd77da79b9d1ccc9e 100644
|
| --- a/common/auth/localauth/server.go
|
| +++ b/common/auth/localauth/server.go
|
| @@ -71,8 +71,15 @@ type ErrorWithCode interface {
|
| // Processes that need a token can discover location of this server by looking
|
| // at "local_auth" section of LUCI_CONTEXT.
|
| type Server struct {
|
| - // TokenGenerator produces access tokens.
|
| - TokenGenerator TokenGenerator
|
| + // TokenGenerators produce access tokens for given account IDs.
|
| + TokenGenerators map[string]TokenGenerator
|
| +
|
| + // DefaultAccountID is account ID subprocesses should pick by default.
|
| + //
|
| + // It is put into "local_auth" section of LUCI_CONTEXT. If empty string,
|
| + // subprocesses won't attempt to use any account by default (they still can
|
| + // pick some non-default account though).
|
| + DefaultAccountID string
|
|
|
| // Port is a local TCP port to bind to or 0 to allow the OS to pick one.
|
| Port int
|
| @@ -92,9 +99,9 @@ type Server struct {
|
| // The provided context is used as base context for request handlers and for
|
| // logging.
|
| //
|
| -// Returns lucictx.LocalAuth structure that specifies how to contact the server.
|
| -// It should be put into "local_auth" section of LUCI_CONTEXT where clients can
|
| -// discover it.
|
| +// Returns a copy of lucictx.LocalAuth structure that specifies how to contact
|
| +// the server. It should be put into "local_auth" section of LUCI_CONTEXT where
|
| +// clients can discover it.
|
| func (s *Server) Initialize(ctx context.Context) (*lucictx.LocalAuth, error) {
|
| s.l.Lock()
|
| defer s.l.Unlock()
|
| @@ -117,9 +124,22 @@ func (s *Server) Initialize(ctx context.Context) (*lucictx.LocalAuth, error) {
|
| s.listener = ln
|
| s.secret = secret
|
|
|
| + // Build a sorted list of LocalAuthAccount to put into the context.
|
| + ids := make([]string, 0, len(s.TokenGenerators))
|
| + for id := range s.TokenGenerators {
|
| + ids = append(ids, id)
|
| + }
|
| + sort.Strings(ids)
|
| + accounts := make([]lucictx.LocalAuthAccount, len(ids))
|
| + for i, id := range ids {
|
| + accounts[i] = lucictx.LocalAuthAccount{ID: id}
|
| + }
|
| +
|
| return &lucictx.LocalAuth{
|
| - RPCPort: uint32(ln.Addr().(*net.TCPAddr).Port),
|
| - Secret: secret,
|
| + RPCPort: uint32(ln.Addr().(*net.TCPAddr).Port),
|
| + Secret: secret,
|
| + Accounts: accounts,
|
| + DefaultAccountID: s.DefaultAccountID,
|
| }, nil
|
| }
|
|
|
| @@ -147,7 +167,7 @@ func (s *Server) Serve() (err error) {
|
| ctx: s.ctx,
|
| wg: &s.wg,
|
| secret: s.secret,
|
| - tokens: s.TokenGenerator,
|
| + tokens: s.TokenGenerators,
|
| },
|
| }
|
| s.l.Unlock()
|
| @@ -221,7 +241,8 @@ const minTokenLifetime = 3 * time.Minute
|
| // Request body:
|
| // {
|
| // "scopes": [<string scope1>, <string scope2>, ...],
|
| -// "secret": <string from LUCI_CONTEXT.local_auth.secret>
|
| +// "secret": <string from LUCI_CONTEXT.local_auth.secret>,
|
| +// "account_id": <ID of some account from LUCI_CONTEXT.local_auth.accounts>
|
| // }
|
| // Response body:
|
| // {
|
| @@ -234,10 +255,10 @@ const minTokenLifetime = 3 * time.Minute
|
| // See also python counterpart of this code:
|
| // https://github.com/luci/luci-py/blob/master/client/utils/auth_server.py
|
| type protocolHandler struct {
|
| - ctx context.Context // the parent context
|
| - wg *sync.WaitGroup // used for graceful shutdown
|
| - secret []byte // expected "secret" value
|
| - tokens TokenGenerator // the actual producer of tokens
|
| + ctx context.Context // the parent context
|
| + wg *sync.WaitGroup // used for graceful shutdown
|
| + secret []byte // expected "secret" value
|
| + tokens map[string]TokenGenerator // the actual producer of tokens (per account)
|
| }
|
|
|
| // protocolError triggers an HTTP reply with some non-200 status code.
|
| @@ -373,17 +394,31 @@ func (h *protocolHandler) routeToImpl(method string, request []byte) (interface{
|
| // RPC implementations.
|
|
|
| func (h *protocolHandler) handleGetOAuthToken(req *rpcs.GetOAuthTokenRequest) (*rpcs.GetOAuthTokenResponse, error) {
|
| - // Validate the request, verify the correct secret is passed.
|
| + // Validate the request.
|
| if err := req.Validate(); err != nil {
|
| return nil, &protocolError{
|
| Status: 400,
|
| - Message: err.Error(),
|
| + Message: fmt.Sprintf("Bad request: %s.", err.Error()),
|
| + }
|
| + }
|
| + // TODO(vadimsh): Remove this check once it is moved into Validate().
|
| + if req.AccountID == "" {
|
| + return nil, &protocolError{
|
| + Status: 400,
|
| + Message: `Bad request: field "account_id" is required.`,
|
| }
|
| }
|
| if subtle.ConstantTimeCompare(h.secret, req.Secret) != 1 {
|
| return nil, &protocolError{
|
| Status: 403,
|
| - Message: `Invalid secret.`,
|
| + Message: "Invalid secret.",
|
| + }
|
| + }
|
| + generator := h.tokens[req.AccountID]
|
| + if generator == nil {
|
| + return nil, &protocolError{
|
| + Status: 404,
|
| + Message: fmt.Sprintf("Unrecognized account ID %q.", req.AccountID),
|
| }
|
| }
|
|
|
| @@ -396,7 +431,7 @@ func (h *protocolHandler) handleGetOAuthToken(req *rpcs.GetOAuthTokenRequest) (*
|
| sort.Strings(sortedScopes)
|
|
|
| // Ask the token provider for the token. This may produce ErrorWithCode.
|
| - tok, err := h.tokens(h.ctx, sortedScopes, minTokenLifetime)
|
| + tok, err := generator(h.ctx, sortedScopes, minTokenLifetime)
|
| if err != nil {
|
| return nil, err
|
| }
|
|
|