OLD | NEW |
(Empty) | |
| 1 # Copyright 2010 Google Inc. All Rights Reserved. |
| 2 # |
| 3 # Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 # you may not use this file except in compliance with the License. |
| 5 # You may obtain a copy of the License at |
| 6 # |
| 7 # http://www.apache.org/licenses/LICENSE-2.0 |
| 8 # |
| 9 # Unless required by applicable law or agreed to in writing, software |
| 10 # distributed under the License is distributed on an "AS IS" BASIS, |
| 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 # See the License for the specific language governing permissions and |
| 13 # limitations under the License. |
| 14 |
| 15 """Unit tests for oauth2_client.""" |
| 16 |
| 17 import datetime |
| 18 import logging |
| 19 import os |
| 20 import sys |
| 21 import unittest |
| 22 import urllib2 |
| 23 import urlparse |
| 24 from stat import S_IMODE |
| 25 from StringIO import StringIO |
| 26 |
| 27 test_bin_dir = os.path.dirname(os.path.realpath(sys.argv[0])) |
| 28 |
| 29 lib_dir = os.path.join(test_bin_dir, '..') |
| 30 sys.path.insert(0, lib_dir) |
| 31 |
| 32 # Needed for boto.cacerts |
| 33 boto_lib_dir = os.path.join(test_bin_dir, '..', 'boto') |
| 34 sys.path.insert(0, boto_lib_dir) |
| 35 |
| 36 import oauth2_client |
| 37 |
| 38 LOG = logging.getLogger('oauth2_client_test') |
| 39 |
| 40 class MockOpener: |
| 41 def __init__(self): |
| 42 self.reset() |
| 43 |
| 44 def reset(self): |
| 45 self.open_error = None |
| 46 self.open_result = None |
| 47 self.open_capture_url = None |
| 48 self.open_capture_data = None |
| 49 |
| 50 def open(self, req, data=None): |
| 51 self.open_capture_url = req.get_full_url() |
| 52 self.open_capture_data = req.get_data() |
| 53 if self.open_error is not None: |
| 54 raise self.open_error |
| 55 else: |
| 56 return StringIO(self.open_result) |
| 57 |
| 58 |
| 59 class MockDateTime: |
| 60 def __init__(self): |
| 61 self.mock_now = None |
| 62 |
| 63 def utcnow(self): |
| 64 return self.mock_now |
| 65 |
| 66 |
| 67 class OAuth2ClientTest(unittest.TestCase): |
| 68 def setUp(self): |
| 69 self.opener = MockOpener() |
| 70 self.mock_datetime = MockDateTime() |
| 71 self.start_time = datetime.datetime(2011, 3, 1, 10, 25, 13, 300826) |
| 72 self.mock_datetime.mock_now = self.start_time |
| 73 self.client = oauth2_client.OAuth2Client( |
| 74 oauth2_client.OAuth2Provider( |
| 75 'Sample OAuth Provider', |
| 76 'https://provider.example.com/oauth/provider?mode=authorize', |
| 77 'https://provider.example.com/oauth/provider?mode=token'), |
| 78 'clid', 'clsecret', |
| 79 url_opener=self.opener, datetime_strategy=self.mock_datetime) |
| 80 |
| 81 def testFetchAccessToken(self): |
| 82 refresh_token = '1/ZaBrxdPl77Bi4jbsO7x-NmATiaQZnWPB51nTvo8n9Sw' |
| 83 access_token = '1/aalskfja-asjwerwj' |
| 84 self.opener.open_result = ( |
| 85 '{"access_token":"%s","expires_in":3600}' % access_token) |
| 86 cred = oauth2_client.RefreshToken(self.client, refresh_token) |
| 87 token = self.client.FetchAccessToken(cred) |
| 88 |
| 89 self.assertEquals( |
| 90 self.opener.open_capture_url, |
| 91 'https://provider.example.com/oauth/provider?mode=token') |
| 92 self.assertEquals({ |
| 93 'grant_type': ['refresh_token'], |
| 94 'client_id': ['clid'], |
| 95 'client_secret': ['clsecret'], |
| 96 'refresh_token': [refresh_token]}, |
| 97 urlparse.parse_qs(self.opener.open_capture_data, keep_blank_values=True, |
| 98 strict_parsing=True)) |
| 99 self.assertEquals(access_token, token.token) |
| 100 self.assertEquals( |
| 101 datetime.datetime(2011, 3, 1, 11, 25, 13, 300826), |
| 102 token.expiry) |
| 103 |
| 104 def testFetchAccessTokenFailsForBadJsonResponse(self): |
| 105 self.opener.open_result = 'blah' |
| 106 cred = oauth2_client.RefreshToken(self.client, 'abc123') |
| 107 self.assertRaises( |
| 108 oauth2_client.AccessTokenRefreshError, self.client.FetchAccessToken, cre
d) |
| 109 |
| 110 def testFetchAccessTokenFailsForErrorResponse(self): |
| 111 self.opener.open_error = urllib2.HTTPError( |
| 112 None, 400, 'Bad Request', None, StringIO('{"error": "invalid token"}')) |
| 113 cred = oauth2_client.RefreshToken(self.client, 'abc123') |
| 114 self.assertRaises( |
| 115 oauth2_client.AccessTokenRefreshError, self.client.FetchAccessToken, cre
d) |
| 116 |
| 117 def testFetchAccessTokenFailsForHttpError(self): |
| 118 self.opener.open_result = urllib2.HTTPError( |
| 119 'foo', 400, 'Bad Request', None, None) |
| 120 cred = oauth2_client.RefreshToken(self.client, 'abc123') |
| 121 self.assertRaises( |
| 122 oauth2_client.AccessTokenRefreshError, self.client.FetchAccessToken, cre
d) |
| 123 |
| 124 def testGetAccessToken(self): |
| 125 refresh_token = 'ref_token' |
| 126 access_token_1 = 'abc123' |
| 127 self.opener.open_result = ( |
| 128 '{"access_token":"%s",' '"expires_in":3600}' % access_token_1) |
| 129 cred = oauth2_client.RefreshToken(self.client, refresh_token) |
| 130 |
| 131 token_1 = self.client.GetAccessToken(cred) |
| 132 |
| 133 # There's no access token in the cache; verify that we fetched a fresh |
| 134 # token. |
| 135 self.assertEquals({ |
| 136 'grant_type': ['refresh_token'], |
| 137 'client_id': ['clid'], |
| 138 'client_secret': ['clsecret'], |
| 139 'refresh_token': [refresh_token]}, |
| 140 urlparse.parse_qs(self.opener.open_capture_data, keep_blank_values=True, |
| 141 strict_parsing=True)) |
| 142 self.assertEquals(access_token_1, token_1.token) |
| 143 self.assertEquals(self.start_time + datetime.timedelta(minutes=60), |
| 144 token_1.expiry) |
| 145 |
| 146 # Advance time by less than expiry time, and fetch another token. |
| 147 self.opener.reset() |
| 148 self.mock_datetime.mock_now = ( |
| 149 self.start_time + datetime.timedelta(minutes=55)) |
| 150 token_2 = self.client.GetAccessToken(cred) |
| 151 |
| 152 # Since the access token wasn't expired, we get the cache token, and there |
| 153 # was no refresh request. |
| 154 self.assertEquals(token_1, token_2) |
| 155 self.assertEquals(access_token_1, token_2.token) |
| 156 self.assertEquals(None, self.opener.open_capture_url) |
| 157 self.assertEquals(None, self.opener.open_capture_data) |
| 158 |
| 159 # Advance time past expiry time, and fetch another token. |
| 160 self.opener.reset() |
| 161 self.mock_datetime.mock_now = ( |
| 162 self.start_time + datetime.timedelta(minutes=55, seconds=1)) |
| 163 access_token_2 = 'zyx456' |
| 164 self.opener.open_result = ( |
| 165 '{"access_token":"%s",' '"expires_in":3600}' % access_token_2) |
| 166 token_3 = self.client.GetAccessToken(cred) |
| 167 |
| 168 # This should have resulted in a refresh request and a fresh access token. |
| 169 self.assertEquals({ |
| 170 'grant_type': ['refresh_token'], |
| 171 'client_id': ['clid'], |
| 172 'client_secret': ['clsecret'], |
| 173 'refresh_token': [refresh_token]}, |
| 174 urlparse.parse_qs(self.opener.open_capture_data, keep_blank_values=True, |
| 175 strict_parsing=True)) |
| 176 self.assertEquals(access_token_2, token_3.token) |
| 177 self.assertEquals(self.mock_datetime.mock_now + datetime.timedelta(minutes=6
0), |
| 178 token_3.expiry) |
| 179 |
| 180 def testGetAuthorizationUri(self): |
| 181 authn_uri = self.client.GetAuthorizationUri( |
| 182 'https://www.example.com/oauth/redir?mode=approve%20me', |
| 183 ('scope_foo', 'scope_bar'), |
| 184 {'state': 'this and that & sundry'}) |
| 185 |
| 186 uri_parts = urlparse.urlsplit(authn_uri) |
| 187 self.assertEquals(('https', 'provider.example.com', '/oauth/provider'), |
| 188 uri_parts[:3]) |
| 189 |
| 190 self.assertEquals({ |
| 191 'response_type': ['code'], |
| 192 'client_id': ['clid'], |
| 193 'redirect_uri': |
| 194 ['https://www.example.com/oauth/redir?mode=approve%20me'], |
| 195 'scope': ['scope_foo scope_bar'], |
| 196 'state': ['this and that & sundry'], |
| 197 'mode': ['authorize']}, |
| 198 urlparse.parse_qs(uri_parts[3])) |
| 199 |
| 200 def testExchangeAuthorizationCode(self): |
| 201 code = 'codeABQ1234' |
| 202 exp_refresh_token = 'ref_token42' |
| 203 exp_access_token = 'access_tokenXY123' |
| 204 self.opener.open_result = ( |
| 205 '{"access_token":"%s","expires_in":3600,"refresh_token":"%s"}' |
| 206 % (exp_access_token, exp_refresh_token)) |
| 207 |
| 208 refresh_token, access_token = self.client.ExchangeAuthorizationCode( |
| 209 code, 'urn:ietf:wg:oauth:2.0:oob', ('scope1', 'scope2')) |
| 210 |
| 211 self.assertEquals({ |
| 212 'grant_type': ['authorization_code'], |
| 213 'client_id': ['clid'], |
| 214 'client_secret': ['clsecret'], |
| 215 'code': [code], |
| 216 'redirect_uri': ['urn:ietf:wg:oauth:2.0:oob'], |
| 217 'scope': ['scope1 scope2'] }, |
| 218 urlparse.parse_qs(self.opener.open_capture_data, keep_blank_values=True, |
| 219 strict_parsing=True)) |
| 220 self.assertEquals(exp_access_token, access_token.token) |
| 221 self.assertEquals(self.start_time + datetime.timedelta(minutes=60), |
| 222 access_token.expiry) |
| 223 |
| 224 self.assertEquals(self.client, refresh_token.oauth2_client) |
| 225 self.assertEquals(exp_refresh_token, refresh_token.refresh_token) |
| 226 |
| 227 # Check that the access token was put in the cache. |
| 228 cached_token = self.client.access_token_cache.GetToken( |
| 229 refresh_token.CacheKey()) |
| 230 self.assertEquals(access_token, cached_token) |
| 231 |
| 232 |
| 233 class AccessTokenTest(unittest.TestCase): |
| 234 |
| 235 def testShouldRefresh(self): |
| 236 mock_datetime = MockDateTime() |
| 237 start = datetime.datetime(2011, 3, 1, 11, 25, 13, 300826) |
| 238 expiry = start + datetime.timedelta(minutes=60) |
| 239 token = oauth2_client.AccessToken( |
| 240 'foo', expiry, datetime_strategy=mock_datetime) |
| 241 |
| 242 mock_datetime.mock_now = start |
| 243 self.assertFalse(token.ShouldRefresh()) |
| 244 |
| 245 mock_datetime.mock_now = start + datetime.timedelta(minutes=54) |
| 246 self.assertFalse(token.ShouldRefresh()) |
| 247 |
| 248 mock_datetime.mock_now = start + datetime.timedelta(minutes=55) |
| 249 self.assertFalse(token.ShouldRefresh()) |
| 250 |
| 251 mock_datetime.mock_now = start + datetime.timedelta( |
| 252 minutes=55, seconds=1) |
| 253 self.assertTrue(token.ShouldRefresh()) |
| 254 |
| 255 mock_datetime.mock_now = start + datetime.timedelta( |
| 256 minutes=61) |
| 257 self.assertTrue(token.ShouldRefresh()) |
| 258 |
| 259 mock_datetime.mock_now = start + datetime.timedelta(minutes=58) |
| 260 self.assertFalse(token.ShouldRefresh(time_delta=120)) |
| 261 |
| 262 mock_datetime.mock_now = start + datetime.timedelta( |
| 263 minutes=58, seconds=1) |
| 264 self.assertTrue(token.ShouldRefresh(time_delta=120)) |
| 265 |
| 266 def testShouldRefreshNoExpiry(self): |
| 267 mock_datetime = MockDateTime() |
| 268 start = datetime.datetime(2011, 3, 1, 11, 25, 13, 300826) |
| 269 token = oauth2_client.AccessToken( |
| 270 'foo', None, datetime_strategy=mock_datetime) |
| 271 |
| 272 mock_datetime.mock_now = start |
| 273 self.assertFalse(token.ShouldRefresh()) |
| 274 |
| 275 mock_datetime.mock_now = start + datetime.timedelta( |
| 276 minutes=472) |
| 277 self.assertFalse(token.ShouldRefresh()) |
| 278 |
| 279 def testSerialization(self): |
| 280 expiry = datetime.datetime(2011, 3, 1, 11, 25, 13, 300826) |
| 281 token = oauth2_client.AccessToken('foo', expiry) |
| 282 serialized_token = token.Serialize() |
| 283 LOG.debug('testSerialization: serialized_token=%s' % serialized_token) |
| 284 |
| 285 token2 = oauth2_client.AccessToken.UnSerialize(serialized_token) |
| 286 self.assertEquals(token, token2) |
| 287 |
| 288 |
| 289 class RefreshTokenTest(unittest.TestCase): |
| 290 def setUp(self): |
| 291 self.opener = MockOpener() |
| 292 self.mock_datetime = MockDateTime() |
| 293 self.start_time = datetime.datetime(2011, 3, 1, 10, 25, 13, 300826) |
| 294 self.mock_datetime.mock_now = self.start_time |
| 295 self.client = oauth2_client.OAuth2Client( |
| 296 oauth2_client.OAuth2Provider( |
| 297 'Sample OAuth Provider', |
| 298 'https://provider.example.com/oauth/provider?mode=authorize', |
| 299 'https://provider.example.com/oauth/provider?mode=token'), |
| 300 'clid', 'clsecret', |
| 301 url_opener=self.opener, datetime_strategy=self.mock_datetime) |
| 302 |
| 303 self.cred = oauth2_client.RefreshToken(self.client, 'ref_token_abc123') |
| 304 |
| 305 def testUniqeId(self): |
| 306 cred_id = self.cred.CacheKey() |
| 307 self.assertEquals('0720afed6871f12761fbea3271f451e6ba184bf5', cred_id) |
| 308 |
| 309 def testGetAuthorizationHeader(self): |
| 310 access_token = 'access_123' |
| 311 self.opener.open_result = ( |
| 312 '{"access_token":"%s","expires_in":3600}' % access_token) |
| 313 |
| 314 self.assertEquals('Bearer %s' % access_token, |
| 315 self.cred.GetAuthorizationHeader()) |
| 316 |
| 317 |
| 318 class FileSystemTokenCacheTest(unittest.TestCase): |
| 319 |
| 320 def setUp(self): |
| 321 self.cache = oauth2_client.FileSystemTokenCache() |
| 322 self.start_time = datetime.datetime(2011, 3, 1, 10, 25, 13, 300826) |
| 323 self.token_1 = oauth2_client.AccessToken('token1', self.start_time) |
| 324 self.token_2 = oauth2_client.AccessToken( |
| 325 'token2', self.start_time + datetime.timedelta(seconds=492)) |
| 326 self.key = 'token1key' |
| 327 |
| 328 def tearDown(self): |
| 329 try: |
| 330 os.unlink(self.cache.CacheFileName(self.key)) |
| 331 except: |
| 332 pass |
| 333 |
| 334 def testPut(self): |
| 335 self.cache.PutToken(self.key, self.token_1) |
| 336 # Assert that the cache file exists and has correct permissions. |
| 337 self.assertEquals( |
| 338 0600, S_IMODE(os.stat(self.cache.CacheFileName(self.key)).st_mode)) |
| 339 |
| 340 def testPutGet(self): |
| 341 # No cache file present. |
| 342 self.assertEquals(None, self.cache.GetToken(self.key)) |
| 343 |
| 344 # Put a token |
| 345 self.cache.PutToken(self.key, self.token_1) |
| 346 cached_token = self.cache.GetToken(self.key) |
| 347 self.assertEquals(self.token_1, cached_token) |
| 348 |
| 349 # Put a different token |
| 350 self.cache.PutToken(self.key, self.token_2) |
| 351 cached_token = self.cache.GetToken(self.key) |
| 352 self.assertEquals(self.token_2, cached_token) |
| 353 |
| 354 def testGetBadFile(self): |
| 355 f = open(self.cache.CacheFileName(self.key), 'w') |
| 356 f.write('blah') |
| 357 f.close() |
| 358 self.assertEquals(None, self.cache.GetToken(self.key)) |
| 359 |
| 360 def testCacheFileName(self): |
| 361 cache = oauth2_client.FileSystemTokenCache( |
| 362 path_pattern='/var/run/ccache/token.%(uid)s.%(key)s') |
| 363 self.assertEquals('/var/run/ccache/token.%d.abc123' % os.getuid(), |
| 364 cache.CacheFileName('abc123')) |
| 365 |
| 366 cache = oauth2_client.FileSystemTokenCache( |
| 367 path_pattern='/var/run/ccache/token.%(key)s') |
| 368 self.assertEquals('/var/run/ccache/token.abc123', |
| 369 cache.CacheFileName('abc123')) |
| 370 |
| 371 |
| 372 if __name__ == '__main__': |
| 373 logging.basicConfig(level=logging.DEBUG) |
| 374 unittest.main() |
OLD | NEW |