OLD | NEW |
(Empty) | |
| 1 # Copyright 2016 Google Inc. All Rights Reserved. |
| 2 # |
| 3 # Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 # you may not use this file except in compliance with the License. |
| 5 # You may obtain a copy of the License at |
| 6 # |
| 7 # http://www.apache.org/licenses/LICENSE-2.0 |
| 8 # |
| 9 # Unless required by applicable law or agreed to in writing, software |
| 10 # distributed under the License is distributed on an "AS IS" BASIS, |
| 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 # See the License for the specific language governing permissions and |
| 13 # limitations under the License. |
| 14 |
| 15 import copy |
| 16 import flask |
| 17 import mock |
| 18 import os |
| 19 import ssl |
| 20 import threading |
| 21 import time |
| 22 import unittest |
| 23 |
| 24 from Crypto import PublicKey |
| 25 from jwkest import ecc |
| 26 from jwkest import jwk |
| 27 from test import token_utils |
| 28 |
| 29 from google.api import auth |
| 30 from google.api.auth import suppliers |
| 31 from google.api.auth import tokens |
| 32 |
| 33 |
| 34 class IntegrationTest(unittest.TestCase): |
| 35 |
| 36 _CURRENT_TIME = int(time.time()) |
| 37 _PORT = 8080 |
| 38 _ISSUER = "https://localhost:%d" % _PORT |
| 39 _PROVIDER_ID = "localhost" |
| 40 _INVALID_X509_PATH = "invalid-x509" |
| 41 _JWKS_PATH = "jwks" |
| 42 _SERVICE_NAME = "service@name.com" |
| 43 _X509_PATH = "x509" |
| 44 |
| 45 _JWT_CLAIMS = { |
| 46 "aud": ["https://aud1.local.host", "https://aud2.local.host"], |
| 47 "exp": _CURRENT_TIME + 60, |
| 48 "email": "user@local.host", |
| 49 "iss": _ISSUER, |
| 50 "sub": "subject-id" |
| 51 } |
| 52 |
| 53 _ec_jwk = jwk.ECKey(use="sig").load_key(ecc.P256) |
| 54 _rsa_key = jwk.RSAKey(use="sig").load_key(PublicKey.RSA.generate(1024)) |
| 55 |
| 56 _ec_jwk.kid = "ec-key-id" |
| 57 _rsa_key.kid = "rsa-key-id" |
| 58 |
| 59 _mock_timer = mock.MagicMock() |
| 60 |
| 61 _jwks = jwk.KEYS() |
| 62 _jwks._keys.append(_ec_jwk) |
| 63 _jwks._keys.append(_rsa_key) |
| 64 |
| 65 _AUTH_TOKEN = token_utils.generate_auth_token(_JWT_CLAIMS, _jwks._keys, |
| 66 alg="RS256", kid=_rsa_key.kid) |
| 67 |
| 68 |
| 69 @classmethod |
| 70 def setUpClass(cls): |
| 71 dirname = os.path.dirname(os.path.realpath(__file__)) |
| 72 cls._cert_file = os.path.join(dirname, "ssl.cert") |
| 73 cls._key_file = os.path.join(dirname, "ssl.key") |
| 74 os.environ["REQUESTS_CA_BUNDLE"] = cls._cert_file |
| 75 |
| 76 rest_server = cls._RestServer() |
| 77 rest_server.start() |
| 78 |
| 79 def setUp(self): |
| 80 self._provider_ids = {} |
| 81 self._configs = {} |
| 82 self._authenticator = auth.create_authenticator(self._provider_ids, |
| 83 self._configs) |
| 84 |
| 85 self._auth_info = mock.MagicMock() |
| 86 self._auth_info.is_provider_allowed.return_value = True |
| 87 self._auth_info.get_allowed_audiences.return_value = [ |
| 88 "https://aud1.local.host" |
| 89 ] |
| 90 |
| 91 def test_verify_auth_token_with_jwks(self): |
| 92 url = get_url(IntegrationTest._JWKS_PATH) |
| 93 self._provider_ids[self._ISSUER] = self._PROVIDER_ID |
| 94 self._configs[IntegrationTest._ISSUER] = suppliers.IssuerUriConfig(False, |
| 95 url) |
| 96 user_info = self._authenticator.authenticate(IntegrationTest._AUTH_TOKEN, |
| 97 self._auth_info, |
| 98 IntegrationTest._SERVICE_NAME) |
| 99 self._assert_user_info_equals(tokens.UserInfo(IntegrationTest._JWT_CLAIMS), |
| 100 user_info) |
| 101 |
| 102 def test_authenticate_auth_token_with_bad_signature(self): |
| 103 new_rsa_key = jwk.RSAKey(use="sig").load_key(PublicKey.RSA.generate(2048)) |
| 104 kid = IntegrationTest._rsa_key.kid |
| 105 new_rsa_key.kid = kid |
| 106 new_jwks = jwk.KEYS() |
| 107 new_jwks._keys.append(new_rsa_key) |
| 108 auth_token = token_utils.generate_auth_token(IntegrationTest._JWT_CLAIMS, |
| 109 new_jwks._keys, alg="RS256", |
| 110 kid=kid) |
| 111 url = get_url(IntegrationTest._JWKS_PATH) |
| 112 self._provider_ids[self._ISSUER] = self._PROVIDER_ID |
| 113 self._configs[IntegrationTest._ISSUER] = suppliers.IssuerUriConfig(False, |
| 114 url) |
| 115 message = "Signature verification failed" |
| 116 with self.assertRaisesRegexp(suppliers.UnauthenticatedException, message): |
| 117 self._authenticator.authenticate(auth_token, self._auth_info, |
| 118 IntegrationTest._SERVICE_NAME) |
| 119 |
| 120 def test_verify_auth_token_with_x509(self): |
| 121 url = get_url(IntegrationTest._X509_PATH) |
| 122 self._provider_ids[self._ISSUER] = self._PROVIDER_ID |
| 123 self._configs[IntegrationTest._ISSUER] = suppliers.IssuerUriConfig(False, |
| 124 url) |
| 125 user_info = self._authenticator.authenticate(IntegrationTest._AUTH_TOKEN, |
| 126 self._auth_info, |
| 127 IntegrationTest._SERVICE_NAME) |
| 128 self._assert_user_info_equals(tokens.UserInfo(IntegrationTest._JWT_CLAIMS), |
| 129 user_info) |
| 130 |
| 131 def test_verify_auth_token_with_invalid_x509(self): |
| 132 url = get_url(IntegrationTest._INVALID_X509_PATH) |
| 133 self._provider_ids[self._ISSUER] = self._PROVIDER_ID |
| 134 self._configs[IntegrationTest._ISSUER] = suppliers.IssuerUriConfig(False, |
| 135 url) |
| 136 message = "Cannot load X.509 certificate" |
| 137 with self.assertRaisesRegexp(suppliers.UnauthenticatedException, message): |
| 138 self._authenticator.authenticate(IntegrationTest._AUTH_TOKEN, |
| 139 self._auth_info, |
| 140 IntegrationTest._SERVICE_NAME) |
| 141 |
| 142 def test_openid_discovery(self): |
| 143 self._provider_ids[self._ISSUER] = self._PROVIDER_ID |
| 144 self._configs[IntegrationTest._ISSUER] = suppliers.IssuerUriConfig(True, |
| 145 None) |
| 146 user_info = self._authenticator.authenticate(IntegrationTest._AUTH_TOKEN, |
| 147 self._auth_info, |
| 148 IntegrationTest._SERVICE_NAME) |
| 149 self._assert_user_info_equals(tokens.UserInfo(IntegrationTest._JWT_CLAIMS), |
| 150 user_info) |
| 151 |
| 152 def test_openid_discovery_failed(self): |
| 153 self._provider_ids[self._ISSUER] = self._PROVIDER_ID |
| 154 self._configs[IntegrationTest._ISSUER] = suppliers.IssuerUriConfig(False, |
| 155 None) |
| 156 message = ("Cannot find the `jwks_uri` for issuer %s" % |
| 157 IntegrationTest._ISSUER) |
| 158 with self.assertRaisesRegexp(suppliers.UnauthenticatedException, message): |
| 159 self._authenticator.authenticate(IntegrationTest._AUTH_TOKEN, |
| 160 self._auth_info, |
| 161 IntegrationTest._SERVICE_NAME) |
| 162 |
| 163 def test_authenticate_with_malformed_auth_code(self): |
| 164 with self.assertRaisesRegexp(suppliers.UnauthenticatedException, |
| 165 "Cannot decode the auth token"): |
| 166 self._authenticator.authenticate("invalid-auth-code", self._auth_info, |
| 167 IntegrationTest._SERVICE_NAME) |
| 168 |
| 169 def test_authenticate_with_disallowed_issuer(self): |
| 170 url = get_url(IntegrationTest._JWKS_PATH) |
| 171 self._configs[IntegrationTest._ISSUER] = suppliers.IssuerUriConfig(False, |
| 172 url) |
| 173 message = "Unknown issuer: " + self._ISSUER |
| 174 with self.assertRaisesRegexp(suppliers.UnauthenticatedException, message): |
| 175 self._authenticator.authenticate(IntegrationTest._AUTH_TOKEN, |
| 176 self._auth_info, |
| 177 IntegrationTest._SERVICE_NAME) |
| 178 |
| 179 def test_authenticate_with_unknown_issuer(self): |
| 180 message = ("Cannot find the `jwks_uri` for issuer %s: " |
| 181 "either the issuer is unknown") % IntegrationTest._ISSUER |
| 182 with self.assertRaisesRegexp(suppliers.UnauthenticatedException, message): |
| 183 self._authenticator.authenticate(IntegrationTest._AUTH_TOKEN, |
| 184 self._auth_info, |
| 185 IntegrationTest._SERVICE_NAME) |
| 186 |
| 187 def test_authenticate_with_invalid_audience(self): |
| 188 url = get_url(IntegrationTest._JWKS_PATH) |
| 189 self._provider_ids[self._ISSUER] = self._PROVIDER_ID |
| 190 self._configs[IntegrationTest._ISSUER] = suppliers.IssuerUriConfig(False, |
| 191 url) |
| 192 self._auth_info.get_allowed_audiences.return_value = [] |
| 193 with self.assertRaisesRegexp(suppliers.UnauthenticatedException, |
| 194 "Audiences not allowed"): |
| 195 self._authenticator.authenticate(IntegrationTest._AUTH_TOKEN, |
| 196 self._auth_info, |
| 197 IntegrationTest._SERVICE_NAME) |
| 198 |
| 199 @mock.patch("time.time", _mock_timer) |
| 200 def test_authenticate_with_expired_auth_token(self): |
| 201 url = get_url(IntegrationTest._JWKS_PATH) |
| 202 self._provider_ids[self._ISSUER] = self._PROVIDER_ID |
| 203 self._configs[IntegrationTest._ISSUER] = suppliers.IssuerUriConfig(False, |
| 204 url) |
| 205 IntegrationTest._mock_timer.return_value = 0 |
| 206 |
| 207 # Create an auth token that expires in 10 seconds. |
| 208 jwt_claims = copy.deepcopy(IntegrationTest._JWT_CLAIMS) |
| 209 jwt_claims["exp"] = time.time() + 10 |
| 210 auth_token = token_utils.generate_auth_token(jwt_claims, |
| 211 IntegrationTest._jwks._keys, |
| 212 alg="RS256", |
| 213 kid=IntegrationTest._rsa_key.ki
d) |
| 214 |
| 215 # Verify that the auth token can be authenticated successfully. |
| 216 self._authenticator.authenticate(IntegrationTest._AUTH_TOKEN, |
| 217 self._auth_info, |
| 218 IntegrationTest._SERVICE_NAME) |
| 219 |
| 220 # Advance the timer by 20 seconds and make sure the token is expired. |
| 221 IntegrationTest._mock_timer.return_value += 20 |
| 222 message = "The auth token has already expired" |
| 223 with self.assertRaisesRegexp(suppliers.UnauthenticatedException, message): |
| 224 self._authenticator.authenticate(auth_token, self._auth_info, |
| 225 IntegrationTest._SERVICE_NAME) |
| 226 |
| 227 def test_invalid_openid_discovery_url(self): |
| 228 issuer = "https://invalid.issuer" |
| 229 self._provider_ids[self._ISSUER] = self._PROVIDER_ID |
| 230 self._configs[issuer] = suppliers.IssuerUriConfig(True, None) |
| 231 |
| 232 jwt_claims = copy.deepcopy(IntegrationTest._JWT_CLAIMS) |
| 233 jwt_claims["iss"] = issuer |
| 234 auth_token = token_utils.generate_auth_token(jwt_claims, |
| 235 IntegrationTest._jwks._keys, |
| 236 alg="RS256", |
| 237 kid=IntegrationTest._rsa_key.ki
d) |
| 238 message = "Cannot discover the jwks uri" |
| 239 with self.assertRaisesRegexp(suppliers.UnauthenticatedException, message): |
| 240 self._authenticator.authenticate(auth_token, self._auth_info, |
| 241 IntegrationTest._SERVICE_NAME) |
| 242 |
| 243 def test_invalid_jwks_uri(self): |
| 244 url = "https://invalid.jwks.uri" |
| 245 self._provider_ids[self._ISSUER] = self._PROVIDER_ID |
| 246 self._configs[IntegrationTest._ISSUER] = suppliers.IssuerUriConfig(False, |
| 247 url) |
| 248 message = "Cannot retrieve valid verification keys from the `jwks_uri`" |
| 249 with self.assertRaisesRegexp(suppliers.UnauthenticatedException, message): |
| 250 self._authenticator.authenticate(IntegrationTest._AUTH_TOKEN, |
| 251 self._auth_info, |
| 252 IntegrationTest._SERVICE_NAME) |
| 253 |
| 254 def _assert_user_info_equals(self, expected, actual): |
| 255 self.assertEqual(expected.audiences, actual.audiences) |
| 256 self.assertEqual(expected.email, actual.email) |
| 257 self.assertEqual(expected.subject_id, actual.subject_id) |
| 258 self.assertEqual(expected.issuer, actual.issuer) |
| 259 |
| 260 |
| 261 class _RestServer(object): |
| 262 |
| 263 def __init__(self): |
| 264 app = flask.Flask("integration-test-server") |
| 265 |
| 266 @app.route("/" + IntegrationTest._JWKS_PATH) |
| 267 def get_json_web_key_set(): # pylint: disable=unused-variable |
| 268 return IntegrationTest._jwks.dump_jwks() |
| 269 |
| 270 @app.route("/" + IntegrationTest._X509_PATH) |
| 271 def get_x509_certificates(): # pylint: disable=unused-variable |
| 272 cert = IntegrationTest._rsa_key.key.publickey().exportKey("PEM") |
| 273 return flask.jsonify({IntegrationTest._rsa_key.kid: cert}) |
| 274 |
| 275 @app.route("/" + IntegrationTest._INVALID_X509_PATH) |
| 276 def get_invalid_x509_certificates(): # pylint: disable=unused-variable |
| 277 return flask.jsonify({IntegrationTest._rsa_key.kid: "invalid cert"}) |
| 278 |
| 279 @app.route("/.well-known/openid-configuration") |
| 280 def get_openid_configuration(): # pylint: disable=unused-variable |
| 281 return flask.jsonify({"jwks_uri": get_url(IntegrationTest._JWKS_PATH)}) |
| 282 |
| 283 self._application = app |
| 284 |
| 285 def start(self): |
| 286 def run_app(): |
| 287 ssl_context = (IntegrationTest._cert_file, IntegrationTest._key_file) |
| 288 self._application.run(port=IntegrationTest._PORT, |
| 289 ssl_context=ssl_context) |
| 290 |
| 291 thread = threading.Thread(target=run_app, args=()) |
| 292 thread.daemon = True |
| 293 thread.start() |
| 294 |
| 295 |
| 296 def get_url(path): |
| 297 return "https://localhost:%d/%s" % (IntegrationTest._PORT, path) |
OLD | NEW |