OLD | NEW |
(Empty) | |
| 1 # Copyright 2016 Google Inc. All Rights Reserved. |
| 2 # |
| 3 # Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 # you may not use this file except in compliance with the License. |
| 5 # You may obtain a copy of the License at |
| 6 # |
| 7 # http://www.apache.org/licenses/LICENSE-2.0 |
| 8 # |
| 9 # Unless required by applicable law or agreed to in writing, software |
| 10 # distributed under the License is distributed on an "AS IS" BASIS, |
| 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 # See the License for the specific language governing permissions and |
| 13 # limitations under the License. |
| 14 |
| 15 import copy |
| 16 import mock |
| 17 import json |
| 18 import time |
| 19 import unittest |
| 20 |
| 21 from Crypto import PublicKey |
| 22 from jwkest import ecc |
| 23 from jwkest import jwk |
| 24 from jwkest import jws |
| 25 from test import token_utils |
| 26 |
| 27 from google.api.auth import suppliers |
| 28 from google.api.auth import tokens |
| 29 |
| 30 |
| 31 class AuthenticatorTest(unittest.TestCase): |
| 32 _ec_kid = "ec-key-id" |
| 33 _rsa_kid = "rsa-key-id" |
| 34 |
| 35 _mock_timer = mock.MagicMock() |
| 36 |
| 37 def setUp(self): |
| 38 ec_jwk = jwk.ECKey(use="sig").load_key(ecc.P256) |
| 39 ec_jwk.kid = self._ec_kid |
| 40 |
| 41 rsa_key = jwk.RSAKey(use="sig").load_key(PublicKey.RSA.generate(1024)) |
| 42 rsa_key.kid = self._rsa_kid |
| 43 |
| 44 jwks = jwk.KEYS() |
| 45 jwks._keys.append(ec_jwk) |
| 46 jwks._keys.append(rsa_key) |
| 47 |
| 48 self._issuers_to_provider_ids = {} |
| 49 self._jwks_supplier = mock.MagicMock() |
| 50 self._authenticator = tokens.Authenticator(self._issuers_to_provider_ids, |
| 51 self._jwks_supplier) |
| 52 self._jwks = jwks |
| 53 self._jwks_supplier.supply.return_value = self._jwks |
| 54 |
| 55 self._method_info = mock.MagicMock() |
| 56 self._service_name = "service.name.com" |
| 57 |
| 58 self._jwt_claims = { |
| 59 "aud": ["first.com", "second.com"], |
| 60 "email": "someone@email.com", |
| 61 "exp": int(time.time()) + 10, |
| 62 "iss": "https://issuer.com", |
| 63 "sub": "subject-id" |
| 64 } |
| 65 |
| 66 def test_get_jwt_claims(self): |
| 67 auth_token = token_utils.generate_auth_token(self._jwt_claims, |
| 68 self._jwks._keys, |
| 69 kid=self._ec_kid) |
| 70 actual_jwt_claims = self._authenticator.get_jwt_claims(auth_token) |
| 71 self.assertEqual(self._jwt_claims, actual_jwt_claims) |
| 72 |
| 73 def test_get_jwt_claims_without_kid(self): |
| 74 auth_token = token_utils.generate_auth_token(self._jwt_claims, |
| 75 self._jwks._keys) |
| 76 actual_jwt_claims = self._authenticator.get_jwt_claims(auth_token) |
| 77 self.assertEqual(self._jwt_claims, actual_jwt_claims) |
| 78 |
| 79 def test_required_claims(self): |
| 80 def assert_missing_claim_raise_exception(claim_name): |
| 81 jwt_claims = copy.deepcopy(self._jwt_claims) |
| 82 del jwt_claims[claim_name] |
| 83 auth_token = token_utils.generate_auth_token(jwt_claims, |
| 84 self._jwks._keys, |
| 85 kid=self._ec_kid) |
| 86 with self.assertRaisesRegexp(suppliers.UnauthenticatedException, |
| 87 'Missing "%s" claim' % claim_name): |
| 88 self._authenticator.get_jwt_claims(auth_token) |
| 89 |
| 90 assert_missing_claim_raise_exception("aud") |
| 91 assert_missing_claim_raise_exception("exp") |
| 92 assert_missing_claim_raise_exception("sub") |
| 93 assert_missing_claim_raise_exception("iss") |
| 94 |
| 95 @mock.patch("time.time", _mock_timer) |
| 96 def test_get_jwt_claims_via_caching(self): |
| 97 AuthenticatorTest._mock_timer.return_value = 10 |
| 98 |
| 99 auth_token = token_utils.generate_auth_token(self._jwt_claims, |
| 100 self._jwks._keys) |
| 101 # Populate the decoded result into cache. |
| 102 self._authenticator.get_jwt_claims(auth_token) |
| 103 |
| 104 # Reset the returned JWKS so the signature verification will fail next |
| 105 # time. |
| 106 self._jwks_supplier.supply.return_value = jwk.KEYS() |
| 107 |
| 108 # Forword time by 10 seconds. |
| 109 AuthenticatorTest._mock_timer.return_value += 10 |
| 110 # This call should succeed since the auth_token is cached. |
| 111 self._authenticator.get_jwt_claims(auth_token) |
| 112 |
| 113 # Forword time by 5 minutes. |
| 114 AuthenticatorTest._mock_timer.return_value += 5 * 60 |
| 115 # This call should fail since the cache expires and it needs to re-decode |
| 116 # the auth token with a different key set. |
| 117 with self.assertRaises(suppliers.UnauthenticatedException): |
| 118 self._authenticator.get_jwt_claims(auth_token) |
| 119 |
| 120 def test_auth_token_cache_capacity(self): |
| 121 authenticator = tokens.Authenticator({}, self._jwks_supplier, cache_capacity
=2) |
| 122 |
| 123 self._jwt_claims["email"] = "1@email.com" |
| 124 auth_token1 = token_utils.generate_auth_token(self._jwt_claims, |
| 125 self._jwks._keys) |
| 126 self._jwt_claims["email"] = "2@email.com" |
| 127 auth_token2 = token_utils.generate_auth_token(self._jwt_claims, |
| 128 self._jwks._keys) |
| 129 |
| 130 # Populate the decoded result into cache. |
| 131 authenticator.get_jwt_claims(auth_token1) |
| 132 authenticator.get_jwt_claims(auth_token2) |
| 133 |
| 134 # Reset the returned JWKS so the signature verification will fail next |
| 135 # time. |
| 136 new_ec_jwk = jwk.ECKey(use="sig").load_key(ecc.P256) |
| 137 new_ec_jwk.kid = self._ec_kid |
| 138 new_jwks = jwk.KEYS() |
| 139 new_jwks._keys.append(new_ec_jwk) |
| 140 self._jwks_supplier.supply.return_value = new_jwks |
| 141 |
| 142 # Verify the following calls still succeed since the auth tokens are |
| 143 # cached. |
| 144 authenticator.get_jwt_claims(auth_token1) |
| 145 authenticator.get_jwt_claims(auth_token2) |
| 146 |
| 147 # Populate a third auth token into the cache. |
| 148 self._jwt_claims["email"] = "3@email.com" |
| 149 auth_token3 = token_utils.generate_auth_token(self._jwt_claims, |
| 150 new_jwks._keys) |
| 151 authenticator.get_jwt_claims(auth_token3) |
| 152 |
| 153 # Make sure the first auth token is evicted from the cache since the cache |
| 154 # is full. |
| 155 with self.assertRaises(suppliers.UnauthenticatedException): |
| 156 authenticator.get_jwt_claims(auth_token1) |
| 157 |
| 158 def test_verify_fails(self): |
| 159 auth_token = token_utils.generate_auth_token(self._jwt_claims, |
| 160 self._jwks._keys, |
| 161 kid=self._ec_kid) |
| 162 |
| 163 # Let the _jwks_supplier return a different key than the one we use to sign |
| 164 # the JWT. |
| 165 new_jwk = jwk.ECKey(use="sig").load_key(ecc.P256) |
| 166 new_jwks = jwk.KEYS() |
| 167 new_jwks._keys.append(new_jwk) |
| 168 self._jwks_supplier.supply.return_value = new_jwks |
| 169 |
| 170 with self.assertRaises(suppliers.UnauthenticatedException): |
| 171 self._authenticator.get_jwt_claims(auth_token) |
| 172 |
| 173 def test_authenticate_successfully(self): |
| 174 auth_token = token_utils.generate_auth_token(self._jwt_claims, |
| 175 self._jwks._keys, |
| 176 kid=self._ec_kid) |
| 177 self._method_info.get_allowed_audiences.return_value = ["first.com"] |
| 178 self._issuers_to_provider_ids[self._jwt_claims["iss"]] = "provider-id" |
| 179 actual_user_info = self._authenticator.authenticate(auth_token, |
| 180 self._method_info, |
| 181 "service.name.com") |
| 182 self.assert_user_info(actual_user_info, self._jwt_claims["aud"], |
| 183 self._jwt_claims["email"], self._jwt_claims["sub"], |
| 184 self._jwt_claims["iss"]) |
| 185 |
| 186 def test_authenticate_with_single_audience(self): |
| 187 aud = "first.aud.com" |
| 188 self._jwt_claims["aud"] = aud |
| 189 auth_token = token_utils.generate_auth_token(self._jwt_claims, |
| 190 self._jwks._keys, |
| 191 kid=self._ec_kid) |
| 192 self._issuers_to_provider_ids[self._jwt_claims["iss"]] = "provider-id" |
| 193 actual_user_info = self._authenticator.authenticate(auth_token, |
| 194 self._method_info, aud) |
| 195 self.assertEqual([aud], actual_user_info.audiences) |
| 196 |
| 197 def test_authenticate_with_malformed_claims(self): |
| 198 def assert_malformed_time_claim_raises_exception(claim_name, expiration): |
| 199 jwt_claims = copy.deepcopy(self._jwt_claims) |
| 200 jwt_claims[claim_name] = expiration |
| 201 auth_token = token_utils.generate_auth_token(jwt_claims, |
| 202 self._jwks._keys) |
| 203 message = 'Malformed claim: "%s" must be an integer' % claim_name |
| 204 with self.assertRaisesRegexp(suppliers.UnauthenticatedException, message): |
| 205 self._authenticator.authenticate(auth_token, self._method_info, |
| 206 "service.name") |
| 207 |
| 208 assert_malformed_time_claim_raises_exception("exp", "1") |
| 209 assert_malformed_time_claim_raises_exception("exp", 1.1) |
| 210 assert_malformed_time_claim_raises_exception("exp", [1]) |
| 211 assert_malformed_time_claim_raises_exception("nbf", "1") |
| 212 assert_malformed_time_claim_raises_exception("nbf", 1.1) |
| 213 assert_malformed_time_claim_raises_exception("nbf", [1]) |
| 214 |
| 215 def test_authenticate_with_expired_auth_token(self): |
| 216 self._jwt_claims["exp"] = long(time.time() - 10) |
| 217 auth_token = token_utils.generate_auth_token(self._jwt_claims, |
| 218 self._jwks._keys) |
| 219 message = "The auth token has already expired" |
| 220 with self.assertRaisesRegexp(suppliers.UnauthenticatedException, message): |
| 221 self._authenticator.authenticate(auth_token, |
| 222 self._method_info, |
| 223 "service.name") |
| 224 |
| 225 def test_authenticate_with_nbf_claim(self): |
| 226 # Set the "nbf" claim to some time in the future. |
| 227 self._jwt_claims["nbf"] = long(time.time() + 5) |
| 228 auth_token = token_utils.generate_auth_token(self._jwt_claims, |
| 229 self._jwks._keys) |
| 230 message = 'Current time is less than the "nbf" time' |
| 231 with self.assertRaisesRegexp(suppliers.UnauthenticatedException, message): |
| 232 self._authenticator.authenticate(auth_token, self._method_info, |
| 233 "service.name") |
| 234 |
| 235 def test_authenticate_with_service_name_as_audience(self): |
| 236 self._jwt_claims["aud"].append(self._service_name) |
| 237 self._issuers_to_provider_ids[self._jwt_claims["iss"]] = "provider-id" |
| 238 self._method_info.get_allowed_audiences.return_value = [] |
| 239 auth_token = token_utils.generate_auth_token(self._jwt_claims, |
| 240 self._jwks._keys, |
| 241 kid=self._ec_kid) |
| 242 actual_user_info = self._authenticator.authenticate(auth_token, |
| 243 self._method_info, |
| 244 self._service_name) |
| 245 self.assert_user_info(actual_user_info, self._jwt_claims["aud"], |
| 246 self._jwt_claims["email"], self._jwt_claims["sub"], |
| 247 self._jwt_claims["iss"]) |
| 248 |
| 249 def test_authenticate_with_disallowed_provider_id(self): |
| 250 auth_token = token_utils.generate_auth_token(self._jwt_claims, |
| 251 self._jwks._keys, |
| 252 kid=self._ec_kid) |
| 253 self._method_info.is_provider_allowed.return_value = False |
| 254 self._issuers_to_provider_ids[self._jwt_claims["iss"]] = "id" |
| 255 with self.assertRaisesRegexp(suppliers.UnauthenticatedException, |
| 256 "The requested method does not allow provider " |
| 257 "id: id"): |
| 258 self._authenticator.authenticate(auth_token, self._method_info, |
| 259 self._service_name) |
| 260 |
| 261 def test_authenticate_with_disallowed_audiences(self): |
| 262 auth_token = token_utils.generate_auth_token(self._jwt_claims, |
| 263 self._jwks._keys, |
| 264 kid=self._ec_kid) |
| 265 self._method_info.get_allowed_audiences.return_value = [] |
| 266 self._issuers_to_provider_ids[self._jwt_claims["iss"]] = "project-id" |
| 267 with self.assertRaisesRegexp(suppliers.UnauthenticatedException, |
| 268 "Audiences not allowed"): |
| 269 self._authenticator.authenticate(auth_token, self._method_info, |
| 270 self._service_name) |
| 271 |
| 272 def test_unicode_decode_error(self): |
| 273 auth_token = "ya29.CjA8A3Hrca1hCCvRg69U3Tg85CG5pRqZj7gOJUsicpRafWAW63zvg6a0Z
M6wZ5mJwM0" |
| 274 with self.assertRaisesRegexp(suppliers.UnauthenticatedException, |
| 275 "Cannot decode the auth token"): |
| 276 self._authenticator.authenticate(auth_token, None, None) |
| 277 |
| 278 def assert_user_info(self, actual_user_info, audiences, email, subject_id, |
| 279 issuer): |
| 280 self.assertEqual(audiences, actual_user_info.audiences) |
| 281 self.assertEqual(email, actual_user_info.email) |
| 282 self.assertEqual(subject_id, actual_user_info.subject_id) |
| 283 self.assertEqual(issuer, actual_user_info.issuer) |
OLD | NEW |