| OLD | NEW |
| (Empty) |
| 1 import datetime | |
| 2 import logging | |
| 3 import os | |
| 4 import stat | |
| 5 | |
| 6 import gslib.tests.testcase as testcase | |
| 7 from gslib.tests.util import unittest | |
| 8 import gslib.util | |
| 9 import gslib.third_party.oauth2_plugin.oauth2_client as oauth2_client | |
| 10 | |
| 11 LOG = logging.getLogger('test_oauth2_client') | |
| 12 | |
| 13 ACCESS_TOKEN = 'abc123' | |
| 14 TOKEN_URI = 'https://provider.example.com/oauth/provider?mode=token' | |
| 15 AUTH_URI = 'https://provider.example.com/oauth/provider?mode=authorize' | |
| 16 DEFAULT_CA_CERTS_FILE = os.path.abspath( | |
| 17 os.path.join('gslib', 'data', 'cacerts.txt')) | |
| 18 | |
| 19 class MockDateTime: | |
| 20 def __init__(self): | |
| 21 self.mock_now = None | |
| 22 | |
| 23 def utcnow(self): | |
| 24 return self.mock_now | |
| 25 | |
| 26 class MockOAuth2ServiceAccountClient(oauth2_client.OAuth2ServiceAccountClient): | |
| 27 def __init__(self, client_id, private_key, password, auth_uri, token_uri, | |
| 28 datetime_strategy): | |
| 29 super(MockOAuth2ServiceAccountClient, self).__init__( | |
| 30 client_id, private_key, password, auth_uri=auth_uri, | |
| 31 token_uri=token_uri, datetime_strategy=datetime_strategy, | |
| 32 ca_certs_file=DEFAULT_CA_CERTS_FILE) | |
| 33 self.Reset() | |
| 34 | |
| 35 def Reset(self): | |
| 36 self.fetched_token = False | |
| 37 | |
| 38 def FetchAccessToken(self): | |
| 39 self.fetched_token = True | |
| 40 return oauth2_client.AccessToken( | |
| 41 ACCESS_TOKEN, | |
| 42 GetExpiry(self.datetime_strategy, 3600), | |
| 43 datetime_strategy=self.datetime_strategy) | |
| 44 | |
| 45 | |
| 46 class MockOAuth2UserAccountClient(oauth2_client.OAuth2UserAccountClient): | |
| 47 def __init__(self, token_uri, client_id, client_secret, refresh_token, | |
| 48 auth_uri, datetime_strategy): | |
| 49 super(MockOAuth2UserAccountClient, self).__init__( | |
| 50 token_uri, client_id, client_secret, refresh_token, auth_uri=auth_uri, | |
| 51 datetime_strategy=datetime_strategy, | |
| 52 ca_certs_file=DEFAULT_CA_CERTS_FILE) | |
| 53 self.Reset() | |
| 54 | |
| 55 def Reset(self): | |
| 56 self.fetched_token = False | |
| 57 | |
| 58 def FetchAccessToken(self): | |
| 59 self.fetched_token = True | |
| 60 return oauth2_client.AccessToken( | |
| 61 ACCESS_TOKEN, | |
| 62 GetExpiry(self.datetime_strategy, 3600), | |
| 63 datetime_strategy=self.datetime_strategy) | |
| 64 | |
| 65 def GetExpiry(datetime_strategy, lengthInSeconds): | |
| 66 token_expiry = (datetime_strategy.utcnow() | |
| 67 + datetime.timedelta(seconds=lengthInSeconds)) | |
| 68 return token_expiry | |
| 69 | |
| 70 def CreateMockUserAccountClient(start_time, mock_datetime): | |
| 71 return MockOAuth2UserAccountClient( | |
| 72 TOKEN_URI, 'clid', 'clsecret', 'ref_token_abc123', AUTH_URI, | |
| 73 mock_datetime) | |
| 74 | |
| 75 def CreateMockServiceAccountClient(start_time, mock_datetime): | |
| 76 return MockOAuth2ServiceAccountClient( | |
| 77 'clid', 'private_key', 'password', AUTH_URI, TOKEN_URI, | |
| 78 mock_datetime) | |
| 79 | |
| 80 | |
| 81 class OAuth2UserAccountClientTest(testcase.GsUtilUnitTestCase): | |
| 82 | |
| 83 def setUp(self): | |
| 84 self.tempdirs = [] | |
| 85 self.mock_datetime = MockDateTime() | |
| 86 self.start_time = datetime.datetime(2011, 3, 1, 10, 25, 13, 300826) | |
| 87 self.mock_datetime.mock_now = self.start_time | |
| 88 | |
| 89 | |
| 90 def testGetAccessTokenUserAccount(self): | |
| 91 self.client = CreateMockUserAccountClient(self.start_time, | |
| 92 self.mock_datetime) | |
| 93 self._RunGetAccessTokenTest() | |
| 94 | |
| 95 | |
| 96 def testGetAccessTokenServiceAccount(self): | |
| 97 self.client = CreateMockServiceAccountClient(self.start_time, | |
| 98 self.mock_datetime) | |
| 99 self._RunGetAccessTokenTest() | |
| 100 | |
| 101 | |
| 102 def _RunGetAccessTokenTest(self): | |
| 103 refresh_token = 'ref_token' | |
| 104 access_token_1 = 'abc123' | |
| 105 | |
| 106 self.assertFalse(self.client.fetched_token) | |
| 107 token_1 = self.client.GetAccessToken() | |
| 108 | |
| 109 # There's no access token in the cache; verify that we fetched a fresh | |
| 110 # token. | |
| 111 self.assertTrue(self.client.fetched_token) | |
| 112 self.assertEquals(access_token_1, token_1.token) | |
| 113 self.assertEquals(self.start_time + datetime.timedelta(minutes=60), | |
| 114 token_1.expiry) | |
| 115 | |
| 116 # Advance time by less than expiry time, and fetch another token. | |
| 117 self.client.Reset() | |
| 118 self.mock_datetime.mock_now = ( | |
| 119 self.start_time + datetime.timedelta(minutes=55)) | |
| 120 token_2 = self.client.GetAccessToken() | |
| 121 | |
| 122 # Since the access token wasn't expired, we get the cache token, and there | |
| 123 # was no refresh request. | |
| 124 self.assertEquals(token_1, token_2) | |
| 125 self.assertEquals(access_token_1, token_2.token) | |
| 126 self.assertFalse(self.client.fetched_token) | |
| 127 | |
| 128 # Advance time past expiry time, and fetch another token. | |
| 129 self.client.Reset() | |
| 130 self.mock_datetime.mock_now = ( | |
| 131 self.start_time + datetime.timedelta(minutes=55, seconds=1)) | |
| 132 self.client.datetime_strategy = self.mock_datetime | |
| 133 access_token_2 = 'zyx456' | |
| 134 token_3 = self.client.GetAccessToken() | |
| 135 | |
| 136 # This should have resulted in a refresh request and a fresh access token. | |
| 137 self.assertTrue(self.client.fetched_token) | |
| 138 self.assertEquals( | |
| 139 self.mock_datetime.mock_now + datetime.timedelta(minutes=60), | |
| 140 token_3.expiry) | |
| 141 | |
| 142 | |
| 143 class AccessTokenTest(unittest.TestCase): | |
| 144 | |
| 145 def testShouldRefresh(self): | |
| 146 mock_datetime = MockDateTime() | |
| 147 start = datetime.datetime(2011, 3, 1, 11, 25, 13, 300826) | |
| 148 expiry = start + datetime.timedelta(minutes=60) | |
| 149 token = oauth2_client.AccessToken( | |
| 150 'foo', expiry, datetime_strategy=mock_datetime) | |
| 151 | |
| 152 mock_datetime.mock_now = start | |
| 153 self.assertFalse(token.ShouldRefresh()) | |
| 154 | |
| 155 mock_datetime.mock_now = start + datetime.timedelta(minutes=54) | |
| 156 self.assertFalse(token.ShouldRefresh()) | |
| 157 | |
| 158 mock_datetime.mock_now = start + datetime.timedelta(minutes=55) | |
| 159 self.assertFalse(token.ShouldRefresh()) | |
| 160 | |
| 161 mock_datetime.mock_now = start + datetime.timedelta( | |
| 162 minutes=55, seconds=1) | |
| 163 self.assertTrue(token.ShouldRefresh()) | |
| 164 | |
| 165 mock_datetime.mock_now = start + datetime.timedelta( | |
| 166 minutes=61) | |
| 167 self.assertTrue(token.ShouldRefresh()) | |
| 168 | |
| 169 mock_datetime.mock_now = start + datetime.timedelta(minutes=58) | |
| 170 self.assertFalse(token.ShouldRefresh(time_delta=120)) | |
| 171 | |
| 172 mock_datetime.mock_now = start + datetime.timedelta( | |
| 173 minutes=58, seconds=1) | |
| 174 self.assertTrue(token.ShouldRefresh(time_delta=120)) | |
| 175 | |
| 176 def testShouldRefreshNoExpiry(self): | |
| 177 mock_datetime = MockDateTime() | |
| 178 start = datetime.datetime(2011, 3, 1, 11, 25, 13, 300826) | |
| 179 token = oauth2_client.AccessToken( | |
| 180 'foo', None, datetime_strategy=mock_datetime) | |
| 181 | |
| 182 mock_datetime.mock_now = start | |
| 183 self.assertFalse(token.ShouldRefresh()) | |
| 184 | |
| 185 mock_datetime.mock_now = start + datetime.timedelta( | |
| 186 minutes=472) | |
| 187 self.assertFalse(token.ShouldRefresh()) | |
| 188 | |
| 189 def testSerialization(self): | |
| 190 expiry = datetime.datetime(2011, 3, 1, 11, 25, 13, 300826) | |
| 191 token = oauth2_client.AccessToken('foo', expiry) | |
| 192 serialized_token = token.Serialize() | |
| 193 LOG.debug('testSerialization: serialized_token=%s' % serialized_token) | |
| 194 | |
| 195 token2 = oauth2_client.AccessToken.UnSerialize(serialized_token) | |
| 196 self.assertEquals(token, token2) | |
| 197 | |
| 198 class FileSystemTokenCacheTest(unittest.TestCase): | |
| 199 | |
| 200 def setUp(self): | |
| 201 self.cache = oauth2_client.FileSystemTokenCache() | |
| 202 self.start_time = datetime.datetime(2011, 3, 1, 10, 25, 13, 300826) | |
| 203 self.token_1 = oauth2_client.AccessToken('token1', self.start_time) | |
| 204 self.token_2 = oauth2_client.AccessToken( | |
| 205 'token2', self.start_time + datetime.timedelta(seconds=492)) | |
| 206 self.key = 'token1key' | |
| 207 | |
| 208 def tearDown(self): | |
| 209 try: | |
| 210 os.unlink(self.cache.CacheFileName(self.key)) | |
| 211 except: | |
| 212 pass | |
| 213 | |
| 214 def testPut(self): | |
| 215 self.cache.PutToken(self.key, self.token_1) | |
| 216 # Assert that the cache file exists and has correct permissions. | |
| 217 if not gslib.util.IS_WINDOWS: | |
| 218 self.assertEquals( | |
| 219 0600, | |
| 220 stat.S_IMODE(os.stat(self.cache.CacheFileName(self.key)).st_mode)) | |
| 221 | |
| 222 def testPutGet(self): | |
| 223 # No cache file present. | |
| 224 self.assertEquals(None, self.cache.GetToken(self.key)) | |
| 225 | |
| 226 # Put a token | |
| 227 self.cache.PutToken(self.key, self.token_1) | |
| 228 cached_token = self.cache.GetToken(self.key) | |
| 229 self.assertEquals(self.token_1, cached_token) | |
| 230 | |
| 231 # Put a different token | |
| 232 self.cache.PutToken(self.key, self.token_2) | |
| 233 cached_token = self.cache.GetToken(self.key) | |
| 234 self.assertEquals(self.token_2, cached_token) | |
| 235 | |
| 236 def testGetBadFile(self): | |
| 237 f = open(self.cache.CacheFileName(self.key), 'w') | |
| 238 f.write('blah') | |
| 239 f.close() | |
| 240 self.assertEquals(None, self.cache.GetToken(self.key)) | |
| 241 | |
| 242 def testCacheFileName(self): | |
| 243 cache = oauth2_client.FileSystemTokenCache( | |
| 244 path_pattern='/var/run/ccache/token.%(uid)s.%(key)s') | |
| 245 if gslib.util.IS_WINDOWS: | |
| 246 uid = '_' | |
| 247 else: | |
| 248 uid = os.getuid() | |
| 249 self.assertEquals('/var/run/ccache/token.%s.abc123' % uid, | |
| 250 cache.CacheFileName('abc123')) | |
| 251 | |
| 252 cache = oauth2_client.FileSystemTokenCache( | |
| 253 path_pattern='/var/run/ccache/token.%(key)s') | |
| 254 self.assertEquals('/var/run/ccache/token.abc123', | |
| 255 cache.CacheFileName('abc123')) | |
| 256 | |
| 257 | |
| 258 class RefreshTokenTest(unittest.TestCase): | |
| 259 def setUp(self): | |
| 260 self.mock_datetime = MockDateTime() | |
| 261 self.start_time = datetime.datetime(2011, 3, 1, 10, 25, 13, 300826) | |
| 262 self.mock_datetime.mock_now = self.start_time | |
| 263 self.client = CreateMockUserAccountClient(self.start_time, | |
| 264 self.mock_datetime) | |
| 265 | |
| 266 | |
| 267 def testUniqeId(self): | |
| 268 cred_id = self.client.CacheKey() | |
| 269 self.assertEquals('0720afed6871f12761fbea3271f451e6ba184bf5', cred_id) | |
| 270 | |
| 271 def testGetAuthorizationHeader(self): | |
| 272 self.assertEquals('Bearer %s' % ACCESS_TOKEN, | |
| 273 self.client.GetAuthorizationHeader()) | |
| OLD | NEW |