| OLD | NEW |
| 1 // Copyright 2017 The LUCI Authors. All rights reserved. | 1 // Copyright 2017 The LUCI Authors. All rights reserved. |
| 2 // Use of this source code is governed under the Apache License, Version 2.0 | 2 // Use of this source code is governed under the Apache License, Version 2.0 |
| 3 // that can be found in the LICENSE file. | 3 // that can be found in the LICENSE file. |
| 4 | 4 |
| 5 package localauth | 5 package localauth |
| 6 | 6 |
| 7 import ( | 7 import ( |
| 8 "crypto/subtle" | 8 "crypto/subtle" |
| 9 "encoding/json" | 9 "encoding/json" |
| 10 "fmt" | 10 "fmt" |
| (...skipping 53 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
| 64 | 64 |
| 65 // Code returns a code to put into RPC response alongside the error mess
age. | 65 // Code returns a code to put into RPC response alongside the error mess
age. |
| 66 Code() int | 66 Code() int |
| 67 } | 67 } |
| 68 | 68 |
| 69 // Server runs a local RPC server that hands out access tokens. | 69 // Server runs a local RPC server that hands out access tokens. |
| 70 // | 70 // |
| 71 // Processes that need a token can discover location of this server by looking | 71 // Processes that need a token can discover location of this server by looking |
| 72 // at "local_auth" section of LUCI_CONTEXT. | 72 // at "local_auth" section of LUCI_CONTEXT. |
| 73 type Server struct { | 73 type Server struct { |
| 74 » // TokenGenerator produces access tokens. | 74 » // TokenGenerators produce access tokens for given account IDs. |
| 75 » TokenGenerator TokenGenerator | 75 » TokenGenerators map[string]TokenGenerator |
| 76 |
| 77 » // DefaultAccountID is account ID subprocesses should pick by default. |
| 78 » // |
| 79 » // It is put into "local_auth" section of LUIC_CONTEXT. If empty string, |
| 80 » // subprocesses won't attempt to use any account by default (they still
can |
| 81 » // pick some non-default account though). |
| 82 » DefaultAccountID string |
| 76 | 83 |
| 77 // Port is a local TCP port to bind to or 0 to allow the OS to pick one. | 84 // Port is a local TCP port to bind to or 0 to allow the OS to pick one. |
| 78 Port int | 85 Port int |
| 79 | 86 |
| 80 l sync.Mutex | 87 l sync.Mutex |
| 81 secret []byte // the clients are expected to send this sec
ret | 88 secret []byte // the clients are expected to send this sec
ret |
| 82 listener net.Listener // to know what to stop in Close, nil after
Close | 89 listener net.Listener // to know what to stop in Close, nil after
Close |
| 83 wg sync.WaitGroup // +1 for each request being processed now | 90 wg sync.WaitGroup // +1 for each request being processed now |
| 84 ctx context.Context // derived from ctx in Initialize | 91 ctx context.Context // derived from ctx in Initialize |
| 85 cancel context.CancelFunc // cancels 'ctx' | 92 cancel context.CancelFunc // cancels 'ctx' |
| 86 | 93 |
| 87 testingServeHook func() // called right before serving | 94 testingServeHook func() // called right before serving |
| 88 } | 95 } |
| 89 | 96 |
| 90 // Initialize binds the server to a local port and prepares it for serving. | 97 // Initialize binds the server to a local port and prepares it for serving. |
| 91 // | 98 // |
| 92 // The provided context is used as base context for request handlers and for | 99 // The provided context is used as base context for request handlers and for |
| 93 // logging. | 100 // logging. |
| 94 // | 101 // |
| 95 // Returns lucictx.LocalAuth structure that specifies how to contact the server. | 102 // Returns a copy of lucictx.LocalAuth structure that specifies how to contact |
| 96 // It should be put into "local_auth" section of LUCI_CONTEXT where clients can | 103 // the server. It should be put into "local_auth" section of LUCI_CONTEXT where |
| 97 // discover it. | 104 // clients can discover it. |
| 98 func (s *Server) Initialize(ctx context.Context) (*lucictx.LocalAuth, error) { | 105 func (s *Server) Initialize(ctx context.Context) (*lucictx.LocalAuth, error) { |
| 99 s.l.Lock() | 106 s.l.Lock() |
| 100 defer s.l.Unlock() | 107 defer s.l.Unlock() |
| 101 | 108 |
| 102 if s.ctx != nil { | 109 if s.ctx != nil { |
| 103 return nil, fmt.Errorf("already initialized") | 110 return nil, fmt.Errorf("already initialized") |
| 104 } | 111 } |
| 105 | 112 |
| 106 secret := make([]byte, 48) | 113 secret := make([]byte, 48) |
| 107 if _, err := cryptorand.Read(ctx, secret); err != nil { | 114 if _, err := cryptorand.Read(ctx, secret); err != nil { |
| 108 return nil, err | 115 return nil, err |
| 109 } | 116 } |
| 110 | 117 |
| 111 ln, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", s.Port)) | 118 ln, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", s.Port)) |
| 112 if err != nil { | 119 if err != nil { |
| 113 return nil, err | 120 return nil, err |
| 114 } | 121 } |
| 115 | 122 |
| 116 s.ctx, s.cancel = context.WithCancel(ctx) | 123 s.ctx, s.cancel = context.WithCancel(ctx) |
| 117 s.listener = ln | 124 s.listener = ln |
| 118 s.secret = secret | 125 s.secret = secret |
| 119 | 126 |
| 127 // Build a sorted list of LocalAuthAccount to put into the context. |
| 128 ids := make([]string, 0, len(s.TokenGenerators)) |
| 129 for id := range s.TokenGenerators { |
| 130 ids = append(ids, id) |
| 131 } |
| 132 sort.Strings(ids) |
| 133 accounts := make([]lucictx.LocalAuthAccount, len(ids)) |
| 134 for i, id := range ids { |
| 135 accounts[i] = lucictx.LocalAuthAccount{ID: id} |
| 136 } |
| 137 |
| 120 return &lucictx.LocalAuth{ | 138 return &lucictx.LocalAuth{ |
| 121 » » RPCPort: uint32(ln.Addr().(*net.TCPAddr).Port), | 139 » » RPCPort: uint32(ln.Addr().(*net.TCPAddr).Port), |
| 122 » » Secret: secret, | 140 » » Secret: secret, |
| 141 » » Accounts: accounts, |
| 142 » » DefaultAccountID: s.DefaultAccountID, |
| 123 }, nil | 143 }, nil |
| 124 } | 144 } |
| 125 | 145 |
| 126 // Serve runs a serving loop. | 146 // Serve runs a serving loop. |
| 127 // | 147 // |
| 128 // It unblocks once Close is called and all pending requests are served. | 148 // It unblocks once Close is called and all pending requests are served. |
| 129 // | 149 // |
| 130 // Returns nil if serving was stopped by Close or non-nil if it failed for some | 150 // Returns nil if serving was stopped by Close or non-nil if it failed for some |
| 131 // other reason. | 151 // other reason. |
| 132 func (s *Server) Serve() (err error) { | 152 func (s *Server) Serve() (err error) { |
| 133 s.l.Lock() | 153 s.l.Lock() |
| 134 switch { | 154 switch { |
| 135 case s.ctx == nil: | 155 case s.ctx == nil: |
| 136 err = fmt.Errorf("not initialized") | 156 err = fmt.Errorf("not initialized") |
| 137 case s.listener == nil: | 157 case s.listener == nil: |
| 138 err = fmt.Errorf("already closed") | 158 err = fmt.Errorf("already closed") |
| 139 } | 159 } |
| 140 if err != nil { | 160 if err != nil { |
| 141 s.l.Unlock() | 161 s.l.Unlock() |
| 142 return | 162 return |
| 143 } | 163 } |
| 144 listener := s.listener // accessed outside the lock | 164 listener := s.listener // accessed outside the lock |
| 145 srv := http.Server{ | 165 srv := http.Server{ |
| 146 Handler: &protocolHandler{ | 166 Handler: &protocolHandler{ |
| 147 ctx: s.ctx, | 167 ctx: s.ctx, |
| 148 wg: &s.wg, | 168 wg: &s.wg, |
| 149 secret: s.secret, | 169 secret: s.secret, |
| 150 » » » tokens: s.TokenGenerator, | 170 » » » tokens: s.TokenGenerators, |
| 151 }, | 171 }, |
| 152 } | 172 } |
| 153 s.l.Unlock() | 173 s.l.Unlock() |
| 154 | 174 |
| 155 // Notify unit tests that we have initialized. | 175 // Notify unit tests that we have initialized. |
| 156 if s.testingServeHook != nil { | 176 if s.testingServeHook != nil { |
| 157 s.testingServeHook() | 177 s.testingServeHook() |
| 158 } | 178 } |
| 159 | 179 |
| 160 err = srv.Serve(listener) // blocks until Close() is called | 180 err = srv.Serve(listener) // blocks until Close() is called |
| (...skipping 53 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
| 214 // * The server sets Content-Length header in the response. | 234 // * The server sets Content-Length header in the response. |
| 215 // * Protocol-level errors have non-200 HTTP status code. | 235 // * Protocol-level errors have non-200 HTTP status code. |
| 216 // * Logic errors have 200 HTTP status code and error is communicated in | 236 // * Logic errors have 200 HTTP status code and error is communicated in |
| 217 // the response body. | 237 // the response body. |
| 218 // | 238 // |
| 219 // The only supported method currently is 'GetOAuthToken': | 239 // The only supported method currently is 'GetOAuthToken': |
| 220 // | 240 // |
| 221 // Request body: | 241 // Request body: |
| 222 // { | 242 // { |
| 223 // "scopes": [<string scope1>, <string scope2>, ...], | 243 // "scopes": [<string scope1>, <string scope2>, ...], |
| 224 // "secret": <string from LUCI_CONTEXT.local_auth.secret> | 244 // "secret": <string from LUCI_CONTEXT.local_auth.secret>, |
| 245 // "account_id": <ID of some account from LUCI_CONTEXT.local_auth.accounts> |
| 225 // } | 246 // } |
| 226 // Response body: | 247 // Response body: |
| 227 // { | 248 // { |
| 228 // "error_code": <int, on success not set or 0>, | 249 // "error_code": <int, on success not set or 0>, |
| 229 // "error_message": <string, on success not set>, | 250 // "error_message": <string, on success not set>, |
| 230 // "access_token": <string with actual token (on success)>, | 251 // "access_token": <string with actual token (on success)>, |
| 231 // "expiry": <int with unix timestamp in seconds (on success)> | 252 // "expiry": <int with unix timestamp in seconds (on success)> |
| 232 // } | 253 // } |
| 233 // | 254 // |
| 234 // See also python counterpart of this code: | 255 // See also python counterpart of this code: |
| 235 // https://github.com/luci/luci-py/blob/master/client/utils/auth_server.py | 256 // https://github.com/luci/luci-py/blob/master/client/utils/auth_server.py |
| 236 type protocolHandler struct { | 257 type protocolHandler struct { |
| 237 » ctx context.Context // the parent context | 258 » ctx context.Context // the parent context |
| 238 » wg *sync.WaitGroup // used for graceful shutdown | 259 » wg *sync.WaitGroup // used for graceful shutdown |
| 239 » secret []byte // expected "secret" value | 260 » secret []byte // expected "secret" value |
| 240 » tokens TokenGenerator // the actual producer of tokens | 261 » tokens map[string]TokenGenerator // the actual producer of tokens (per a
ccount) |
| 241 } | 262 } |
| 242 | 263 |
| 243 // protocolError triggers an HTTP reply with some non-200 status code. | 264 // protocolError triggers an HTTP reply with some non-200 status code. |
| 244 type protocolError struct { | 265 type protocolError struct { |
| 245 Status int // HTTP status to set | 266 Status int // HTTP status to set |
| 246 Message string // the message to put in the body | 267 Message string // the message to put in the body |
| 247 } | 268 } |
| 248 | 269 |
| 249 func (e *protocolError) Error() string { | 270 func (e *protocolError) Error() string { |
| 250 return fmt.Sprintf("%s (HTTP %d)", e.Message, e.Status) | 271 return fmt.Sprintf("%s (HTTP %d)", e.Message, e.Status) |
| (...skipping 115 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
| 366 Status: http.StatusNotFound, | 387 Status: http.StatusNotFound, |
| 367 Message: fmt.Sprintf("Unknown RPC method %q", method), | 388 Message: fmt.Sprintf("Unknown RPC method %q", method), |
| 368 } | 389 } |
| 369 } | 390 } |
| 370 } | 391 } |
| 371 | 392 |
| 372 //////////////////////////////////////////////////////////////////////////////// | 393 //////////////////////////////////////////////////////////////////////////////// |
| 373 // RPC implementations. | 394 // RPC implementations. |
| 374 | 395 |
| 375 func (h *protocolHandler) handleGetOAuthToken(req *rpcs.GetOAuthTokenRequest) (*
rpcs.GetOAuthTokenResponse, error) { | 396 func (h *protocolHandler) handleGetOAuthToken(req *rpcs.GetOAuthTokenRequest) (*
rpcs.GetOAuthTokenResponse, error) { |
| 376 » // Validate the request, verify the correct secret is passed. | 397 » // Validate the request. |
| 377 if err := req.Validate(); err != nil { | 398 if err := req.Validate(); err != nil { |
| 378 return nil, &protocolError{ | 399 return nil, &protocolError{ |
| 379 Status: 400, | 400 Status: 400, |
| 380 » » » Message: err.Error(), | 401 » » » Message: fmt.Sprintf("Bad request: %s.", err.Error()), |
| 402 » » } |
| 403 » } |
| 404 » // TODO(vadimsh): Remove this check once it is moved into Validate(). |
| 405 » if req.AccountID == "" { |
| 406 » » return nil, &protocolError{ |
| 407 » » » Status: 400, |
| 408 » » » Message: `Bad request: field "account_id" is required.`, |
| 381 } | 409 } |
| 382 } | 410 } |
| 383 if subtle.ConstantTimeCompare(h.secret, req.Secret) != 1 { | 411 if subtle.ConstantTimeCompare(h.secret, req.Secret) != 1 { |
| 384 return nil, &protocolError{ | 412 return nil, &protocolError{ |
| 385 Status: 403, | 413 Status: 403, |
| 386 » » » Message: `Invalid secret.`, | 414 » » » Message: "Invalid secret.", |
| 415 » » } |
| 416 » } |
| 417 » generator := h.tokens[req.AccountID] |
| 418 » if generator == nil { |
| 419 » » return nil, &protocolError{ |
| 420 » » » Status: 404, |
| 421 » » » Message: fmt.Sprintf("Unrecognized account ID %q.", req.
AccountID), |
| 387 } | 422 } |
| 388 } | 423 } |
| 389 | 424 |
| 390 // Dedup and sort scopes. | 425 // Dedup and sort scopes. |
| 391 scopes := stringset.New(len(req.Scopes)) | 426 scopes := stringset.New(len(req.Scopes)) |
| 392 for _, s := range req.Scopes { | 427 for _, s := range req.Scopes { |
| 393 scopes.Add(s) | 428 scopes.Add(s) |
| 394 } | 429 } |
| 395 sortedScopes := scopes.ToSlice() | 430 sortedScopes := scopes.ToSlice() |
| 396 sort.Strings(sortedScopes) | 431 sort.Strings(sortedScopes) |
| 397 | 432 |
| 398 // Ask the token provider for the token. This may produce ErrorWithCode. | 433 // Ask the token provider for the token. This may produce ErrorWithCode. |
| 399 » tok, err := h.tokens(h.ctx, sortedScopes, minTokenLifetime) | 434 » tok, err := generator(h.ctx, sortedScopes, minTokenLifetime) |
| 400 if err != nil { | 435 if err != nil { |
| 401 return nil, err | 436 return nil, err |
| 402 } | 437 } |
| 403 return &rpcs.GetOAuthTokenResponse{ | 438 return &rpcs.GetOAuthTokenResponse{ |
| 404 AccessToken: tok.AccessToken, | 439 AccessToken: tok.AccessToken, |
| 405 Expiry: tok.Expiry.Unix(), | 440 Expiry: tok.Expiry.Unix(), |
| 406 }, nil | 441 }, nil |
| 407 } | 442 } |
| OLD | NEW |