Index: gerrit_util.py |
diff --git a/gerrit_util.py b/gerrit_util.py |
index 62fc1b9ab53cac7ddbf60f3f1dc70fae8ba93afd..691daf9182989a93339e8ed4245b4e8f326713e1 100755 |
--- a/gerrit_util.py |
+++ b/gerrit_util.py |
@@ -15,36 +15,19 @@ import logging |
import netrc |
import os |
import re |
+import socket |
import stat |
import sys |
import time |
import urllib |
+import urlparse |
from cStringIO import StringIO |
-_netrc_file = '_netrc' if sys.platform.startswith('win') else '.netrc' |
-_netrc_file = os.path.join(os.environ['HOME'], _netrc_file) |
-try: |
- NETRC = netrc.netrc(_netrc_file) |
-except IOError: |
- print >> sys.stderr, 'WARNING: Could not read netrc file %s' % _netrc_file |
- NETRC = netrc.netrc(os.devnull) |
-except netrc.NetrcParseError as e: |
- _netrc_stat = os.stat(e.filename) |
- if _netrc_stat.st_mode & (stat.S_IRWXG | stat.S_IRWXO): |
- print >> sys.stderr, ( |
- 'WARNING: netrc file %s cannot be used because its file permissions ' |
- 'are insecure. netrc file permissions should be 600.' % _netrc_file) |
- else: |
- print >> sys.stderr, ('ERROR: Cannot use netrc file %s due to a parsing ' |
- 'error.' % _netrc_file) |
- raise |
- del _netrc_stat |
- NETRC = netrc.netrc(os.devnull) |
-del _netrc_file |
LOGGER = logging.getLogger() |
TRY_LIMIT = 5 |
+ |
# Controls the transport protocol used to communicate with gerrit. |
# This is parameterized primarily to enable GerritTestCase. |
GERRIT_PROTOCOL = 'https' |
@@ -84,17 +67,141 @@ def GetConnectionClass(protocol=None): |
"Don't know how to work with protocol '%s'" % protocol) |
+class Authenticator(object): |
+ """Base authenticator class for authenticator implementations to subclass.""" |
+ |
+ def get_auth_header(self, host): |
+ raise NotImplementedError() |
+ |
+ @staticmethod |
+ def get(): |
+ """Returns: (Authenticator) The identified Authenticator to use. |
+ |
+ Probes the local system and its environment and identifies the |
+ Authenticator instance to use. |
+ """ |
+ if GceAuthenticator.is_gce(): |
+ return GceAuthenticator() |
+ return NetrcAuthenticator() |
+ |
+ |
+class NetrcAuthenticator(Authenticator): |
+ """Authenticator implementation that uses ".netrc" for token. |
+ """ |
+ |
+ def __init__(self): |
+ self.netrc = self._get_netrc() |
+ |
+ @staticmethod |
+ def _get_netrc(): |
+ path = '_netrc' if sys.platform.startswith('win') else '.netrc' |
+ path = os.path.join(os.environ['HOME'], path) |
+ try: |
+ return netrc.netrc(path) |
+ except IOError: |
+ print >> sys.stderr, 'WARNING: Could not read netrc file %s' % path |
+ return netrc.netrc(os.devnull) |
+ except netrc.NetrcParseError as e: |
+ st = os.stat(e.path) |
+ if st.st_mode & (stat.S_IRWXG | stat.S_IRWXO): |
+ print >> sys.stderr, ( |
+ 'WARNING: netrc file %s cannot be used because its file ' |
+ 'permissions are insecure. netrc file permissions should be ' |
+ '600.' % path) |
+ else: |
+ print >> sys.stderr, ('ERROR: Cannot use netrc file %s due to a ' |
+ 'parsing error.' % path) |
+ raise |
+ return netrc.netrc(os.devnull) |
+ |
+ def get_auth_header(self, host): |
+ auth = self.netrc.authenticators(host) |
+ if auth: |
+ return 'Basic %s' % (base64.b64encode('%s:%s' % (auth[0], auth[2]))) |
+ return None |
+ |
+ |
+class GceAuthenticator(Authenticator): |
+ """Authenticator implementation that uses GCE metadata service for token. |
+ """ |
+ |
+ _INFO_URL = 'http://metadata.google.internal' |
+ _ACQUIRE_URL = ('http://metadata/computeMetadata/v1/instance/' |
+ 'service-accounts/default/token') |
+ _ACQUIRE_HEADERS = {"Metadata-Flavor": "Google"} |
+ |
+ _cache_is_gce = None |
+ _token_cache = None |
+ _token_expiration = None |
+ |
+ @classmethod |
+ def is_gce(cls): |
+ if cls._cache_is_gce is None: |
+ cls._cache_is_gce = cls._test_is_gce() |
+ return cls._cache_is_gce |
+ |
+ @classmethod |
+ def _test_is_gce(cls): |
+ # Based on https://cloud.google.com/compute/docs/metadata#runninggce |
+ try: |
+ resp = cls._get(cls._INFO_URL) |
+ except socket.error: |
+ # Could not resolve URL. |
+ return False |
+ return resp.getheader('Metadata-Flavor', None) == 'Google' |
+ |
+ @staticmethod |
+ def _get(url, **kwargs): |
+ next_delay_sec = 1 |
+ for i in xrange(TRY_LIMIT): |
+ if i > 0: |
+ # Retry server error status codes. |
+ LOGGER.info('Encountered server error; retrying after %d second(s).', |
+ next_delay_sec) |
+ time.sleep(next_delay_sec) |
+ next_delay_sec *= 2 |
+ |
+ p = urlparse.urlparse(url) |
+ c = GetConnectionClass(protocol=p.scheme)(p.netloc) |
+ c.request('GET', url, **kwargs) |
+ resp = c.getresponse() |
+ LOGGER.debug('GET [%s] #%d/%d (%d)', url, i+1, TRY_LIMIT, resp.status) |
+ if resp.status < httplib.INTERNAL_SERVER_ERROR: |
+ return resp |
+ |
+ |
+ @classmethod |
+ def _get_token_dict(cls): |
+ if cls._token_cache: |
+ # If it expires within 25 seconds, refresh. |
+ if cls._token_expiration < time.time() - 25: |
+ return cls._token_cache |
+ |
+ resp = cls._get(cls._ACQUIRE_URL, headers=cls._ACQUIRE_HEADERS) |
+ if resp.status != httplib.OK: |
+ return None |
+ cls._token_cache = json.load(resp) |
+ cls._token_expiration = cls._token_cache['expires_in'] + time.time() |
+ return cls._token_cache |
+ |
+ def get_auth_header(self, _host): |
+ token_dict = self._get_token_dict() |
+ if not token_dict: |
+ return None |
+ return '%(token_type)s %(access_token)s' % token_dict |
+ |
+ |
+ |
def CreateHttpConn(host, path, reqtype='GET', headers=None, body=None): |
"""Opens an https connection to a gerrit service, and sends a request.""" |
headers = headers or {} |
bare_host = host.partition(':')[0] |
- auth = NETRC.authenticators(bare_host) |
+ auth = Authenticator.get().get_auth_header(bare_host) |
if auth: |
- headers.setdefault('Authorization', 'Basic %s' % ( |
- base64.b64encode('%s:%s' % (auth[0], auth[2])))) |
+ headers.setdefault('Authorization', auth) |
else: |
- LOGGER.debug('No authorization found in netrc for %s.' % bare_host) |
+ LOGGER.debug('No authorization found for %s.' % bare_host) |
if 'Authorization' in headers and not path.startswith('a/'): |
url = '/a/%s' % path |