OLD | NEW |
| (Empty) |
1 # Copyright 2010 Google Inc. All Rights Reserved. | |
2 # | |
3 # Licensed under the Apache License, Version 2.0 (the "License"); | |
4 # you may not use this file except in compliance with the License. | |
5 # You may obtain a copy of the License at | |
6 # | |
7 # http://www.apache.org/licenses/LICENSE-2.0 | |
8 # | |
9 # Unless required by applicable law or agreed to in writing, software | |
10 # distributed under the License is distributed on an "AS IS" BASIS, | |
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
12 # See the License for the specific language governing permissions and | |
13 # limitations under the License. | |
14 | |
15 """Unit tests for oauth2_client.""" | |
16 | |
17 import datetime | |
18 import logging | |
19 import os | |
20 import sys | |
21 import unittest | |
22 import urllib2 | |
23 import urlparse | |
24 from stat import S_IMODE | |
25 from StringIO import StringIO | |
26 | |
27 test_bin_dir = os.path.dirname(os.path.realpath(sys.argv[0])) | |
28 | |
29 lib_dir = os.path.join(test_bin_dir, '..') | |
30 sys.path.insert(0, lib_dir) | |
31 | |
32 # Needed for boto.cacerts | |
33 boto_lib_dir = os.path.join(test_bin_dir, '..', 'boto') | |
34 sys.path.insert(0, boto_lib_dir) | |
35 | |
36 import oauth2_client | |
37 | |
38 LOG = logging.getLogger('oauth2_client_test') | |
39 | |
40 class MockOpener: | |
41 def __init__(self): | |
42 self.reset() | |
43 | |
44 def reset(self): | |
45 self.open_error = None | |
46 self.open_result = None | |
47 self.open_capture_url = None | |
48 self.open_capture_data = None | |
49 | |
50 def open(self, req, data=None): | |
51 self.open_capture_url = req.get_full_url() | |
52 self.open_capture_data = req.get_data() | |
53 if self.open_error is not None: | |
54 raise self.open_error | |
55 else: | |
56 return StringIO(self.open_result) | |
57 | |
58 | |
59 class MockDateTime: | |
60 def __init__(self): | |
61 self.mock_now = None | |
62 | |
63 def utcnow(self): | |
64 return self.mock_now | |
65 | |
66 | |
67 class OAuth2ClientTest(unittest.TestCase): | |
68 def setUp(self): | |
69 self.opener = MockOpener() | |
70 self.mock_datetime = MockDateTime() | |
71 self.start_time = datetime.datetime(2011, 3, 1, 10, 25, 13, 300826) | |
72 self.mock_datetime.mock_now = self.start_time | |
73 self.client = oauth2_client.OAuth2Client( | |
74 oauth2_client.OAuth2Provider( | |
75 'Sample OAuth Provider', | |
76 'https://provider.example.com/oauth/provider?mode=authorize', | |
77 'https://provider.example.com/oauth/provider?mode=token'), | |
78 'clid', 'clsecret', | |
79 url_opener=self.opener, datetime_strategy=self.mock_datetime) | |
80 | |
81 def testFetchAccessToken(self): | |
82 refresh_token = '1/ZaBrxdPl77Bi4jbsO7x-NmATiaQZnWPB51nTvo8n9Sw' | |
83 access_token = '1/aalskfja-asjwerwj' | |
84 self.opener.open_result = ( | |
85 '{"access_token":"%s","expires_in":3600}' % access_token) | |
86 cred = oauth2_client.RefreshToken(self.client, refresh_token) | |
87 token = self.client.FetchAccessToken(cred) | |
88 | |
89 self.assertEquals( | |
90 self.opener.open_capture_url, | |
91 'https://provider.example.com/oauth/provider?mode=token') | |
92 self.assertEquals({ | |
93 'grant_type': ['refresh_token'], | |
94 'client_id': ['clid'], | |
95 'client_secret': ['clsecret'], | |
96 'refresh_token': [refresh_token]}, | |
97 urlparse.parse_qs(self.opener.open_capture_data, keep_blank_values=True, | |
98 strict_parsing=True)) | |
99 self.assertEquals(access_token, token.token) | |
100 self.assertEquals( | |
101 datetime.datetime(2011, 3, 1, 11, 25, 13, 300826), | |
102 token.expiry) | |
103 | |
104 def testFetchAccessTokenFailsForBadJsonResponse(self): | |
105 self.opener.open_result = 'blah' | |
106 cred = oauth2_client.RefreshToken(self.client, 'abc123') | |
107 self.assertRaises( | |
108 oauth2_client.AccessTokenRefreshError, self.client.FetchAccessToken, cre
d) | |
109 | |
110 def testFetchAccessTokenFailsForErrorResponse(self): | |
111 self.opener.open_error = urllib2.HTTPError( | |
112 None, 400, 'Bad Request', None, StringIO('{"error": "invalid token"}')) | |
113 cred = oauth2_client.RefreshToken(self.client, 'abc123') | |
114 self.assertRaises( | |
115 oauth2_client.AccessTokenRefreshError, self.client.FetchAccessToken, cre
d) | |
116 | |
117 def testFetchAccessTokenFailsForHttpError(self): | |
118 self.opener.open_result = urllib2.HTTPError( | |
119 'foo', 400, 'Bad Request', None, None) | |
120 cred = oauth2_client.RefreshToken(self.client, 'abc123') | |
121 self.assertRaises( | |
122 oauth2_client.AccessTokenRefreshError, self.client.FetchAccessToken, cre
d) | |
123 | |
124 def testGetAccessToken(self): | |
125 refresh_token = 'ref_token' | |
126 access_token_1 = 'abc123' | |
127 self.opener.open_result = ( | |
128 '{"access_token":"%s",' '"expires_in":3600}' % access_token_1) | |
129 cred = oauth2_client.RefreshToken(self.client, refresh_token) | |
130 | |
131 token_1 = self.client.GetAccessToken(cred) | |
132 | |
133 # There's no access token in the cache; verify that we fetched a fresh | |
134 # token. | |
135 self.assertEquals({ | |
136 'grant_type': ['refresh_token'], | |
137 'client_id': ['clid'], | |
138 'client_secret': ['clsecret'], | |
139 'refresh_token': [refresh_token]}, | |
140 urlparse.parse_qs(self.opener.open_capture_data, keep_blank_values=True, | |
141 strict_parsing=True)) | |
142 self.assertEquals(access_token_1, token_1.token) | |
143 self.assertEquals(self.start_time + datetime.timedelta(minutes=60), | |
144 token_1.expiry) | |
145 | |
146 # Advance time by less than expiry time, and fetch another token. | |
147 self.opener.reset() | |
148 self.mock_datetime.mock_now = ( | |
149 self.start_time + datetime.timedelta(minutes=55)) | |
150 token_2 = self.client.GetAccessToken(cred) | |
151 | |
152 # Since the access token wasn't expired, we get the cache token, and there | |
153 # was no refresh request. | |
154 self.assertEquals(token_1, token_2) | |
155 self.assertEquals(access_token_1, token_2.token) | |
156 self.assertEquals(None, self.opener.open_capture_url) | |
157 self.assertEquals(None, self.opener.open_capture_data) | |
158 | |
159 # Advance time past expiry time, and fetch another token. | |
160 self.opener.reset() | |
161 self.mock_datetime.mock_now = ( | |
162 self.start_time + datetime.timedelta(minutes=55, seconds=1)) | |
163 access_token_2 = 'zyx456' | |
164 self.opener.open_result = ( | |
165 '{"access_token":"%s",' '"expires_in":3600}' % access_token_2) | |
166 token_3 = self.client.GetAccessToken(cred) | |
167 | |
168 # This should have resulted in a refresh request and a fresh access token. | |
169 self.assertEquals({ | |
170 'grant_type': ['refresh_token'], | |
171 'client_id': ['clid'], | |
172 'client_secret': ['clsecret'], | |
173 'refresh_token': [refresh_token]}, | |
174 urlparse.parse_qs(self.opener.open_capture_data, keep_blank_values=True, | |
175 strict_parsing=True)) | |
176 self.assertEquals(access_token_2, token_3.token) | |
177 self.assertEquals(self.mock_datetime.mock_now + datetime.timedelta(minutes=6
0), | |
178 token_3.expiry) | |
179 | |
180 def testGetAuthorizationUri(self): | |
181 authn_uri = self.client.GetAuthorizationUri( | |
182 'https://www.example.com/oauth/redir?mode=approve%20me', | |
183 ('scope_foo', 'scope_bar'), | |
184 {'state': 'this and that & sundry'}) | |
185 | |
186 uri_parts = urlparse.urlsplit(authn_uri) | |
187 self.assertEquals(('https', 'provider.example.com', '/oauth/provider'), | |
188 uri_parts[:3]) | |
189 | |
190 self.assertEquals({ | |
191 'response_type': ['code'], | |
192 'client_id': ['clid'], | |
193 'redirect_uri': | |
194 ['https://www.example.com/oauth/redir?mode=approve%20me'], | |
195 'scope': ['scope_foo scope_bar'], | |
196 'state': ['this and that & sundry'], | |
197 'mode': ['authorize']}, | |
198 urlparse.parse_qs(uri_parts[3])) | |
199 | |
200 def testExchangeAuthorizationCode(self): | |
201 code = 'codeABQ1234' | |
202 exp_refresh_token = 'ref_token42' | |
203 exp_access_token = 'access_tokenXY123' | |
204 self.opener.open_result = ( | |
205 '{"access_token":"%s","expires_in":3600,"refresh_token":"%s"}' | |
206 % (exp_access_token, exp_refresh_token)) | |
207 | |
208 refresh_token, access_token = self.client.ExchangeAuthorizationCode( | |
209 code, 'urn:ietf:wg:oauth:2.0:oob', ('scope1', 'scope2')) | |
210 | |
211 self.assertEquals({ | |
212 'grant_type': ['authorization_code'], | |
213 'client_id': ['clid'], | |
214 'client_secret': ['clsecret'], | |
215 'code': [code], | |
216 'redirect_uri': ['urn:ietf:wg:oauth:2.0:oob'], | |
217 'scope': ['scope1 scope2'] }, | |
218 urlparse.parse_qs(self.opener.open_capture_data, keep_blank_values=True, | |
219 strict_parsing=True)) | |
220 self.assertEquals(exp_access_token, access_token.token) | |
221 self.assertEquals(self.start_time + datetime.timedelta(minutes=60), | |
222 access_token.expiry) | |
223 | |
224 self.assertEquals(self.client, refresh_token.oauth2_client) | |
225 self.assertEquals(exp_refresh_token, refresh_token.refresh_token) | |
226 | |
227 # Check that the access token was put in the cache. | |
228 cached_token = self.client.access_token_cache.GetToken( | |
229 refresh_token.CacheKey()) | |
230 self.assertEquals(access_token, cached_token) | |
231 | |
232 | |
233 class AccessTokenTest(unittest.TestCase): | |
234 | |
235 def testShouldRefresh(self): | |
236 mock_datetime = MockDateTime() | |
237 start = datetime.datetime(2011, 3, 1, 11, 25, 13, 300826) | |
238 expiry = start + datetime.timedelta(minutes=60) | |
239 token = oauth2_client.AccessToken( | |
240 'foo', expiry, datetime_strategy=mock_datetime) | |
241 | |
242 mock_datetime.mock_now = start | |
243 self.assertFalse(token.ShouldRefresh()) | |
244 | |
245 mock_datetime.mock_now = start + datetime.timedelta(minutes=54) | |
246 self.assertFalse(token.ShouldRefresh()) | |
247 | |
248 mock_datetime.mock_now = start + datetime.timedelta(minutes=55) | |
249 self.assertFalse(token.ShouldRefresh()) | |
250 | |
251 mock_datetime.mock_now = start + datetime.timedelta( | |
252 minutes=55, seconds=1) | |
253 self.assertTrue(token.ShouldRefresh()) | |
254 | |
255 mock_datetime.mock_now = start + datetime.timedelta( | |
256 minutes=61) | |
257 self.assertTrue(token.ShouldRefresh()) | |
258 | |
259 mock_datetime.mock_now = start + datetime.timedelta(minutes=58) | |
260 self.assertFalse(token.ShouldRefresh(time_delta=120)) | |
261 | |
262 mock_datetime.mock_now = start + datetime.timedelta( | |
263 minutes=58, seconds=1) | |
264 self.assertTrue(token.ShouldRefresh(time_delta=120)) | |
265 | |
266 def testShouldRefreshNoExpiry(self): | |
267 mock_datetime = MockDateTime() | |
268 start = datetime.datetime(2011, 3, 1, 11, 25, 13, 300826) | |
269 token = oauth2_client.AccessToken( | |
270 'foo', None, datetime_strategy=mock_datetime) | |
271 | |
272 mock_datetime.mock_now = start | |
273 self.assertFalse(token.ShouldRefresh()) | |
274 | |
275 mock_datetime.mock_now = start + datetime.timedelta( | |
276 minutes=472) | |
277 self.assertFalse(token.ShouldRefresh()) | |
278 | |
279 def testSerialization(self): | |
280 expiry = datetime.datetime(2011, 3, 1, 11, 25, 13, 300826) | |
281 token = oauth2_client.AccessToken('foo', expiry) | |
282 serialized_token = token.Serialize() | |
283 LOG.debug('testSerialization: serialized_token=%s' % serialized_token) | |
284 | |
285 token2 = oauth2_client.AccessToken.UnSerialize(serialized_token) | |
286 self.assertEquals(token, token2) | |
287 | |
288 | |
289 class RefreshTokenTest(unittest.TestCase): | |
290 def setUp(self): | |
291 self.opener = MockOpener() | |
292 self.mock_datetime = MockDateTime() | |
293 self.start_time = datetime.datetime(2011, 3, 1, 10, 25, 13, 300826) | |
294 self.mock_datetime.mock_now = self.start_time | |
295 self.client = oauth2_client.OAuth2Client( | |
296 oauth2_client.OAuth2Provider( | |
297 'Sample OAuth Provider', | |
298 'https://provider.example.com/oauth/provider?mode=authorize', | |
299 'https://provider.example.com/oauth/provider?mode=token'), | |
300 'clid', 'clsecret', | |
301 url_opener=self.opener, datetime_strategy=self.mock_datetime) | |
302 | |
303 self.cred = oauth2_client.RefreshToken(self.client, 'ref_token_abc123') | |
304 | |
305 def testUniqeId(self): | |
306 cred_id = self.cred.CacheKey() | |
307 self.assertEquals('0720afed6871f12761fbea3271f451e6ba184bf5', cred_id) | |
308 | |
309 def testGetAuthorizationHeader(self): | |
310 access_token = 'access_123' | |
311 self.opener.open_result = ( | |
312 '{"access_token":"%s","expires_in":3600}' % access_token) | |
313 | |
314 self.assertEquals('Bearer %s' % access_token, | |
315 self.cred.GetAuthorizationHeader()) | |
316 | |
317 | |
318 class FileSystemTokenCacheTest(unittest.TestCase): | |
319 | |
320 def setUp(self): | |
321 self.cache = oauth2_client.FileSystemTokenCache() | |
322 self.start_time = datetime.datetime(2011, 3, 1, 10, 25, 13, 300826) | |
323 self.token_1 = oauth2_client.AccessToken('token1', self.start_time) | |
324 self.token_2 = oauth2_client.AccessToken( | |
325 'token2', self.start_time + datetime.timedelta(seconds=492)) | |
326 self.key = 'token1key' | |
327 | |
328 def tearDown(self): | |
329 try: | |
330 os.unlink(self.cache.CacheFileName(self.key)) | |
331 except: | |
332 pass | |
333 | |
334 def testPut(self): | |
335 self.cache.PutToken(self.key, self.token_1) | |
336 # Assert that the cache file exists and has correct permissions. | |
337 self.assertEquals( | |
338 0600, S_IMODE(os.stat(self.cache.CacheFileName(self.key)).st_mode)) | |
339 | |
340 def testPutGet(self): | |
341 # No cache file present. | |
342 self.assertEquals(None, self.cache.GetToken(self.key)) | |
343 | |
344 # Put a token | |
345 self.cache.PutToken(self.key, self.token_1) | |
346 cached_token = self.cache.GetToken(self.key) | |
347 self.assertEquals(self.token_1, cached_token) | |
348 | |
349 # Put a different token | |
350 self.cache.PutToken(self.key, self.token_2) | |
351 cached_token = self.cache.GetToken(self.key) | |
352 self.assertEquals(self.token_2, cached_token) | |
353 | |
354 def testGetBadFile(self): | |
355 f = open(self.cache.CacheFileName(self.key), 'w') | |
356 f.write('blah') | |
357 f.close() | |
358 self.assertEquals(None, self.cache.GetToken(self.key)) | |
359 | |
360 def testCacheFileName(self): | |
361 cache = oauth2_client.FileSystemTokenCache( | |
362 path_pattern='/var/run/ccache/token.%(uid)s.%(key)s') | |
363 self.assertEquals('/var/run/ccache/token.%d.abc123' % os.getuid(), | |
364 cache.CacheFileName('abc123')) | |
365 | |
366 cache = oauth2_client.FileSystemTokenCache( | |
367 path_pattern='/var/run/ccache/token.%(key)s') | |
368 self.assertEquals('/var/run/ccache/token.abc123', | |
369 cache.CacheFileName('abc123')) | |
370 | |
371 | |
372 if __name__ == '__main__': | |
373 logging.basicConfig(level=logging.DEBUG) | |
374 unittest.main() | |
OLD | NEW |