| OLD | NEW |
| 1 // Copyright 2017 The LUCI Authors. All rights reserved. | 1 // Copyright 2017 The LUCI Authors. All rights reserved. |
| 2 // Use of this source code is governed under the Apache License, Version 2.0 | 2 // Use of this source code is governed under the Apache License, Version 2.0 |
| 3 // that can be found in the LICENSE file. | 3 // that can be found in the LICENSE file. |
| 4 | 4 |
| 5 package localauth | 5 package localauth |
| 6 | 6 |
| 7 import ( | 7 import ( |
| 8 "bytes" | 8 "bytes" |
| 9 "encoding/json" | 9 "encoding/json" |
| 10 "fmt" | 10 "fmt" |
| (...skipping 72 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
| 83 func TestProtocol(t *testing.T) { | 83 func TestProtocol(t *testing.T) { |
| 84 t.Parallel() | 84 t.Parallel() |
| 85 | 85 |
| 86 ctx := context.Background() | 86 ctx := context.Background() |
| 87 ctx, _ = testclock.UseTime(ctx, testclock.TestRecentTimeLocal) | 87 ctx, _ = testclock.UseTime(ctx, testclock.TestRecentTimeLocal) |
| 88 | 88 |
| 89 Convey("With server", t, func(c C) { | 89 Convey("With server", t, func(c C) { |
| 90 // Use channels to pass mocked requests/responses back and forth
. | 90 // Use channels to pass mocked requests/responses back and forth
. |
| 91 requests := make(chan []string, 10000) | 91 requests := make(chan []string, 10000) |
| 92 responses := make(chan interface{}, 1) | 92 responses := make(chan interface{}, 1) |
| 93 |
| 94 testGen := func(ctx context.Context, scopes []string, lifetime t
ime.Duration) (*oauth2.Token, error) { |
| 95 requests <- scopes |
| 96 var resp interface{} |
| 97 select { |
| 98 case resp = <-responses: |
| 99 default: |
| 100 c.Println("Unexpected token request") |
| 101 return nil, fmt.Errorf("Unexpected request") |
| 102 } |
| 103 switch resp := resp.(type) { |
| 104 case error: |
| 105 return nil, resp |
| 106 case *oauth2.Token: |
| 107 return resp, nil |
| 108 default: |
| 109 panic("unknown response") |
| 110 } |
| 111 } |
| 112 |
| 93 s := Server{ | 113 s := Server{ |
| 94 » » » TokenGenerator: func(ctx context.Context, scopes []strin
g, lifetime time.Duration) (*oauth2.Token, error) { | 114 » » » TokenGenerators: map[string]TokenGenerator{ |
| 95 » » » » requests <- scopes | 115 » » » » "acc_id": testGen, |
| 96 » » » » var resp interface{} | 116 » » » » "another_id": testGen, |
| 97 » » » » select { | |
| 98 » » » » case resp = <-responses: | |
| 99 » » » » default: | |
| 100 » » » » » c.Println("Unexpected token request") | |
| 101 » » » » » return nil, fmt.Errorf("Unexpected reque
st") | |
| 102 » » » » } | |
| 103 » » » » switch resp := resp.(type) { | |
| 104 » » » » case error: | |
| 105 » » » » » return nil, resp | |
| 106 » » » » case *oauth2.Token: | |
| 107 » » » » » return resp, nil | |
| 108 » » » » default: | |
| 109 » » » » » panic("unknown response") | |
| 110 » » » » } | |
| 111 }, | 117 }, |
| 118 DefaultAccountID: "acc_id", |
| 112 } | 119 } |
| 113 p, err := s.Initialize(ctx) | 120 p, err := s.Initialize(ctx) |
| 114 So(err, ShouldBeNil) | 121 So(err, ShouldBeNil) |
| 115 | 122 |
| 123 So(p.Accounts, ShouldResemble, []lucictx.LocalAuthAccount{{ID: "
acc_id"}, {ID: "another_id"}}) |
| 124 So(p.DefaultAccountID, ShouldEqual, "acc_id") |
| 125 |
| 116 done := make(chan struct{}) | 126 done := make(chan struct{}) |
| 117 go func() { | 127 go func() { |
| 118 s.Serve() | 128 s.Serve() |
| 119 close(done) | 129 close(done) |
| 120 }() | 130 }() |
| 121 defer func() { | 131 defer func() { |
| 122 s.Close() | 132 s.Close() |
| 123 <-done | 133 <-done |
| 124 }() | 134 }() |
| 125 | 135 |
| 126 goodRequest := func() *http.Request { | 136 goodRequest := func() *http.Request { |
| 127 return prepReq(p, "/rpc/LuciLocalAuthService.GetOAuthTok
en", map[string]interface{}{ | 137 return prepReq(p, "/rpc/LuciLocalAuthService.GetOAuthTok
en", map[string]interface{}{ |
| 128 » » » » "scopes": []string{"B", "A"}, | 138 » » » » "scopes": []string{"B", "A"}, |
| 129 » » » » "secret": p.Secret, | 139 » » » » "secret": p.Secret, |
| 140 » » » » "account_id": "acc_id", |
| 130 }) | 141 }) |
| 131 } | 142 } |
| 132 | 143 |
| 133 Convey("Happy path", func() { | 144 Convey("Happy path", func() { |
| 134 responses <- &oauth2.Token{ | 145 responses <- &oauth2.Token{ |
| 135 AccessToken: "tok1", | 146 AccessToken: "tok1", |
| 136 Expiry: clock.Now(ctx).Add(30 * time.Minute
), | 147 Expiry: clock.Now(ctx).Add(30 * time.Minute
), |
| 137 } | 148 } |
| 138 So(call(goodRequest()), ShouldEqual, `HTTP 200 (json): {
"access_token":"tok1","expiry":1454502906}`) | 149 So(call(goodRequest()), ShouldEqual, `HTTP 200 (json): {
"access_token":"tok1","expiry":1454502906}`) |
| 139 So(<-requests, ShouldResemble, []string{"A", "B"}) | 150 So(<-requests, ShouldResemble, []string{"A", "B"}) |
| (...skipping 54 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
| 194 | 205 |
| 195 Convey("Unknown RPC method", func() { | 206 Convey("Unknown RPC method", func() { |
| 196 req := prepReq(p, "/rpc/LuciLocalAuthService.UnknownMeth
od", map[string]interface{}{}) | 207 req := prepReq(p, "/rpc/LuciLocalAuthService.UnknownMeth
od", map[string]interface{}{}) |
| 197 So(call(req), ShouldEqual, `HTTP 404: Unknown RPC method
"UnknownMethod"`) | 208 So(call(req), ShouldEqual, `HTTP 404: Unknown RPC method
"UnknownMethod"`) |
| 198 }) | 209 }) |
| 199 | 210 |
| 200 Convey("No scopes", func() { | 211 Convey("No scopes", func() { |
| 201 req := prepReq(p, "/rpc/LuciLocalAuthService.GetOAuthTok
en", map[string]interface{}{ | 212 req := prepReq(p, "/rpc/LuciLocalAuthService.GetOAuthTok
en", map[string]interface{}{ |
| 202 "secret": p.Secret, | 213 "secret": p.Secret, |
| 203 }) | 214 }) |
| 204 » » » So(call(req), ShouldEqual, `HTTP 400: Field "scopes" is
required.`) | 215 » » » So(call(req), ShouldEqual, `HTTP 400: Bad request: field
"scopes" is required.`) |
| 205 }) | 216 }) |
| 206 | 217 |
| 207 Convey("No secret", func() { | 218 Convey("No secret", func() { |
| 208 req := prepReq(p, "/rpc/LuciLocalAuthService.GetOAuthTok
en", map[string]interface{}{ | 219 req := prepReq(p, "/rpc/LuciLocalAuthService.GetOAuthTok
en", map[string]interface{}{ |
| 209 "scopes": []string{"B", "A"}, | 220 "scopes": []string{"B", "A"}, |
| 210 }) | 221 }) |
| 211 » » » So(call(req), ShouldEqual, `HTTP 400: Field "secret" is
required.`) | 222 » » » So(call(req), ShouldEqual, `HTTP 400: Bad request: field
"secret" is required.`) |
| 212 }) | 223 }) |
| 213 | 224 |
| 214 Convey("Bad secret", func() { | 225 Convey("Bad secret", func() { |
| 215 req := prepReq(p, "/rpc/LuciLocalAuthService.GetOAuthTok
en", map[string]interface{}{ | 226 req := prepReq(p, "/rpc/LuciLocalAuthService.GetOAuthTok
en", map[string]interface{}{ |
| 216 » » » » "scopes": []string{"B", "A"}, | 227 » » » » "scopes": []string{"B", "A"}, |
| 217 » » » » "secret": []byte{0, 1, 2, 3}, | 228 » » » » "secret": []byte{0, 1, 2, 3}, |
| 229 » » » » "account_id": "acc_id", |
| 218 }) | 230 }) |
| 219 So(call(req), ShouldEqual, `HTTP 403: Invalid secret.`) | 231 So(call(req), ShouldEqual, `HTTP 403: Invalid secret.`) |
| 220 }) | 232 }) |
| 221 | 233 |
| 234 Convey("No account ID", func() { |
| 235 req := prepReq(p, "/rpc/LuciLocalAuthService.GetOAuthTok
en", map[string]interface{}{ |
| 236 "scopes": []string{"B", "A"}, |
| 237 "secret": p.Secret, |
| 238 }) |
| 239 So(call(req), ShouldEqual, `HTTP 400: Bad request: field
"account_id" is required.`) |
| 240 }) |
| 241 |
| 242 Convey("Unknown account ID", func() { |
| 243 req := prepReq(p, "/rpc/LuciLocalAuthService.GetOAuthTok
en", map[string]interface{}{ |
| 244 "scopes": []string{"B", "A"}, |
| 245 "secret": p.Secret, |
| 246 "account_id": "unknown_acc_id", |
| 247 }) |
| 248 So(call(req), ShouldEqual, `HTTP 404: Unrecognized accou
nt ID "unknown_acc_id".`) |
| 249 }) |
| 250 |
| 222 Convey("Token generator returns fatal error", func() { | 251 Convey("Token generator returns fatal error", func() { |
| 223 responses <- fmt.Errorf("fatal!!111") | 252 responses <- fmt.Errorf("fatal!!111") |
| 224 So(call(goodRequest()), ShouldEqual, `HTTP 200 (json): {
"error_code":-1,"error_message":"fatal!!111"}`) | 253 So(call(goodRequest()), ShouldEqual, `HTTP 200 (json): {
"error_code":-1,"error_message":"fatal!!111"}`) |
| 225 }) | 254 }) |
| 226 | 255 |
| 227 Convey("Token generator returns ErrorWithCode", func() { | 256 Convey("Token generator returns ErrorWithCode", func() { |
| 228 responses <- errWithCode{ | 257 responses <- errWithCode{ |
| 229 error: fmt.Errorf("with code"), | 258 error: fmt.Errorf("with code"), |
| 230 code: 123, | 259 code: 123, |
| 231 } | 260 } |
| (...skipping 53 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
| 285 panic(err) | 314 panic(err) |
| 286 } | 315 } |
| 287 | 316 |
| 288 tp := "" | 317 tp := "" |
| 289 if resp.Header.Get("Content-Type") == "application/json; charset=utf-8"
{ | 318 if resp.Header.Get("Content-Type") == "application/json; charset=utf-8"
{ |
| 290 tp = " (json)" | 319 tp = " (json)" |
| 291 } | 320 } |
| 292 | 321 |
| 293 return fmt.Sprintf("HTTP %d%s: %s", resp.StatusCode, tp, strings.TrimSpa
ce(string(blob))) | 322 return fmt.Sprintf("HTTP %d%s: %s", resp.StatusCode, tp, strings.TrimSpa
ce(string(blob))) |
| 294 } | 323 } |
| OLD | NEW |