| Index: appengine/auth_service/importer.py
|
| diff --git a/appengine/auth_service/importer.py b/appengine/auth_service/importer.py
|
| index 8e7ec52ba8ef210f288aa65274b1195ff7122f86..ff2a1037ff6706fde3131963c3b0d02fee9bf2fd 100644
|
| --- a/appengine/auth_service/importer.py
|
| +++ b/appengine/auth_service/importer.py
|
| @@ -35,17 +35,22 @@ service too.
|
|
|
| import collections
|
| import contextlib
|
| +import json
|
| import logging
|
| import StringIO
|
| import tarfile
|
|
|
| from google.appengine.ext import ndb
|
|
|
| +from google import protobuf
|
| +
|
| from components import auth
|
| from components import net
|
| from components import utils
|
| from components.auth import model
|
|
|
| +from proto import config_pb2
|
| +
|
|
|
| class BundleImportError(Exception):
|
| """Base class for errors while fetching external bundle."""
|
| @@ -94,92 +99,131 @@ def config_key():
|
|
|
| class GroupImporterConfig(ndb.Model):
|
| """Singleton entity with group importer configuration JSON."""
|
| - config = ndb.JsonProperty()
|
| + config = ndb.TextProperty() # legacy field with JSON config
|
| + config_proto = ndb.TextProperty()
|
| modified_by = auth.IdentityProperty(indexed=False)
|
| modified_ts = ndb.DateTimeProperty(auto_now=True, indexed=False)
|
|
|
|
|
| -def is_valid_config(config):
|
| - """Checks config for correctness."""
|
| - if not isinstance(config, list):
|
| - return False
|
| +def legacy_json_config_to_proto(config_json):
|
| + """Converts legacy JSON config to config_pb2.GroupImporterConfig message.
|
|
|
| - seen_systems = set(['external'])
|
| - seen_groups = set()
|
| + TODO(vadimsh): Remove once all instances of auth service use protobuf configs.
|
| + """
|
| + try:
|
| + config = json.loads(config_json)
|
| + except ValueError as ex:
|
| + logging.error('Invalid JSON: %s', ex)
|
| + return None
|
| + msg = config_pb2.GroupImporterConfig()
|
| for item in config:
|
| - if not isinstance(item, dict):
|
| - return False
|
| -
|
| - # 'format' is an optional string describing the format of the imported
|
| - # source. The default format is 'tarball'.
|
| fmt = item.get('format', 'tarball')
|
| - if fmt not in ['tarball', 'plainlist']:
|
| - return False
|
| -
|
| - # 'url' is a required string: where to fetch groups from.
|
| - url = item.get('url')
|
| - if not url or not isinstance(url, basestring):
|
| - return False
|
| -
|
| - # 'oauth_scopes' is an optional list of strings: used when generating OAuth
|
| - # access_token to put in Authorization header.
|
| - oauth_scopes = item.get('oauth_scopes')
|
| - if oauth_scopes is not None:
|
| - if not all(isinstance(x, basestring) for x in oauth_scopes):
|
| - return False
|
| -
|
| - # 'domain' is an optional string: will be used when constructing emails from
|
| - # naked usernames found in imported groups.
|
| - domain = item.get('domain')
|
| - if domain and not isinstance(domain, basestring):
|
| - return False
|
| -
|
| - # 'tarball' format uses 'systems' and 'groups' fields.
|
| if fmt == 'tarball':
|
| - # 'systems' is a required list of strings: group systems expected to be
|
| - # found in the archive (they act as prefixes to group names, e.g 'ldap').
|
| - systems = item.get('systems')
|
| - if not systems or not isinstance(systems, list):
|
| - return False
|
| - if not all(isinstance(x, basestring) for x in systems):
|
| - return False
|
| -
|
| - # There should be no overlap in systems between different bundles.
|
| - if set(systems) & seen_systems:
|
| - return False
|
| - seen_systems.update(systems)
|
| -
|
| - # 'groups' is an optional list of strings: if given, filters imported
|
| - # groups only to this list.
|
| - groups = item.get('groups')
|
| - if groups and not all(isinstance(x, basestring) for x in groups):
|
| - return False
|
| + entry = msg.tarball.add()
|
| + elif fmt == 'plainlist':
|
| + entry = msg.plainlist.add()
|
| + else:
|
| + logging.error('Unrecognized format: %s', fmt)
|
| + continue
|
| + entry.url = item.get('url') or ''
|
| + entry.oauth_scopes.extend(item.get('oauth_scopes') or [])
|
| + if 'domain' in item:
|
| + entry.domain = item['domain']
|
| + if fmt == 'tarball':
|
| + entry.systems.extend(item.get('systems') or [])
|
| + entry.groups.extend(item.get('groups') or [])
|
| elif fmt == 'plainlist':
|
| - # 'group' is a required name of imported group. The full group name will
|
| - # be 'external/<group>'.
|
| - group = item.get('group')
|
| - if not group or not isinstance(group, basestring) or group in seen_groups:
|
| - return False
|
| - seen_groups.add(group)
|
| + entry.group = item.get('group') or ''
|
| else:
|
| - assert False, 'Unreachable'
|
| + assert False, 'Not reachable'
|
| + return msg
|
|
|
| - return True
|
|
|
| +def validate_config(config):
|
| + """Checks config_pb2.GroupImporterConfig for correctness.
|
|
|
| -def read_config():
|
| - """Returns currently stored config or [] if not set."""
|
| + Raises:
|
| + ValueError if config has invalid structure.
|
| + """
|
| + if not isinstance(config, config_pb2.GroupImporterConfig):
|
| + raise ValueError('Not GroupImporterConfig proto message')
|
| +
|
| + # TODO(vadimsh): Can be made stricter.
|
| +
|
| + # Validate fields common to Tarball and Plainlist.
|
| + for entry in list(config.tarball) + list(config.plainlist):
|
| + if not entry.url:
|
| + raise ValueError(
|
| + '"url" field is required in %s' % entry.__class__.__name__)
|
| +
|
| + # Validate tarball fields.
|
| + seen_systems = set(['external'])
|
| + for tarball in config.tarball:
|
| + if not tarball.systems:
|
| + raise ValueError(
|
| + '"tarball" entry "%s" needs "systems" field' % tarball.url)
|
| + # There should be no overlap in systems between different bundles.
|
| + twice = set(tarball.systems) & seen_systems
|
| + if twice:
|
| + raise ValueError(
|
| + 'A system is imported twice by "%s": %s' %
|
| + (tarball.url, sorted(twice)))
|
| + seen_systems.update(tarball.systems)
|
| +
|
| + # Validate plainlist fields.
|
| + seen_groups = set()
|
| + for plainlist in config.plainlist:
|
| + if not plainlist.group:
|
| + raise ValueError(
|
| + '"plainlist" entry "%s" needs "group" field' % plainlist.url)
|
| + if plainlist.group in seen_groups:
|
| + raise ValueError(
|
| + 'In "%s" the group is imported twice: %s' %
|
| + (plainlist.url, plainlist.group))
|
| + seen_groups.add(plainlist.group)
|
| +
|
| +
|
| +def read_config_text():
|
| + """Returns importer config as a text blob (or '' if not set)."""
|
| + e = config_key().get()
|
| + if not e:
|
| + return ''
|
| + if e.config_proto:
|
| + return e.config_proto
|
| + if e.config:
|
| + msg = legacy_json_config_to_proto(e.config)
|
| + if not msg:
|
| + return ''
|
| + return protobuf.text_format.MessageToString(msg)
|
| + return ''
|
| +
|
| +
|
| +def read_legacy_config():
|
| + """Returns legacy JSON config stored in GroupImporterConfig entity.
|
| +
|
| +  TODO(vadimsh): Remove once all instances of auth service use protobuf configs.
|
| + """
|
| +  # Note: we do not bother doing this in a transaction.
|
| e = config_key().get()
|
| - return (e.config if e else []) or []
|
| + return e.config if e else None
|
|
|
|
|
| -def write_config(config):
|
| - """Updates stored configuration."""
|
| - if not is_valid_config(config):
|
| - raise ValueError('Invalid config')
|
| +def write_config_text(text):
|
| +  """Validates the config text blob and puts it into the datastore.
|
| +
|
| + Raises:
|
| + ValueError on invalid format.
|
| + """
|
| + msg = config_pb2.GroupImporterConfig()
|
| + try:
|
| + protobuf.text_format.Merge(text, msg)
|
| + except protobuf.text_format.ParseError as ex:
|
| +    raise ValueError('Config is badly formatted: %s' % ex)
|
| + validate_config(msg)
|
| e = GroupImporterConfig(
|
| key=config_key(),
|
| - config=config,
|
| + config=read_legacy_config(),
|
| + config_proto=text,
|
| modified_by=auth.get_current_identity())
|
| e.put()
|
|
|
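For context, here is a minimal sketch of the two storage formats this hunk migrates between. The message layout (repeated tarball and plainlist entries with the fields set above) is inferred from the patch itself; the URLs, system and group names are invented for illustration, and the sketch assumes the generated config_pb2 module is importable as it is in the patch.

    # Sketch only: values are invented; field names mirror the keys read by
    # legacy_json_config_to_proto() and the fields it sets on the proto entries.
    import json

    from google.protobuf import text_format  # the patch reaches this as protobuf.text_format

    from proto import config_pb2

    # Legacy JSON blob kept in the old GroupImporterConfig.config property.
    LEGACY_JSON = json.dumps([
      {
        'format': 'tarball',
        'url': 'https://example.com/groups.tar.gz',
        'systems': ['ldap'],
        'groups': ['ldap/committers'],
      },
      {
        'format': 'plainlist',
        'url': 'https://example.com/admins.txt',
        'group': 'admins',
        'domain': 'example.com',
      },
    ])

    # Equivalent text-format proto kept in the new config_proto property; this is
    # the blob that write_config_text() parses, validates and stores.
    CONFIG_PROTO_TEXT = """
    tarball {
      url: "https://example.com/groups.tar.gz"
      systems: "ldap"
      groups: "ldap/committers"
    }
    plainlist {
      url: "https://example.com/admins.txt"
      group: "admins"
      domain: "example.com"
    }
    """

    msg = config_pb2.GroupImporterConfig()
    text_format.Merge(CONFIG_PROTO_TEXT, msg)  # same parse step as write_config_text()
    print text_format.MessageToString(msg)     # round-trips back to the text blob
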
| @@ -190,34 +234,40 @@ def import_external_groups():
|
| Runs as a cron task. Raises BundleImportError in case of import errors.
|
| """
|
| # Missing config is not a error.
|
| - config = read_config()
|
| - if not config:
|
| + config_text = read_config_text()
|
| + if not config_text:
|
| logging.info('Not configured')
|
| return
|
| - if not is_valid_config(config):
|
| - raise BundleImportError('Bad config')
|
| + config = config_pb2.GroupImporterConfig()
|
| + try:
|
| + protobuf.text_format.Merge(config_text, config)
|
| + except protobuf.text_format.ParseError as ex:
|
| + raise BundleImportError('Bad config format: %s' % ex)
|
| + try:
|
| + validate_config(config)
|
| + except ValueError as ex:
|
| + raise BundleImportError('Bad config structure: %s' % ex)
|
|
|
| # Fetch all files specified in config in parallel.
|
| - futures = [fetch_file_async(p['url'], p.get('oauth_scopes')) for p in config]
|
| + entries = list(config.tarball) + list(config.plainlist)
|
| + futures = [fetch_file_async(e.url, e.oauth_scopes) for e in entries]
|
|
|
| # {system name -> group name -> list of identities}
|
| bundles = {}
|
| - for p, future in zip(config, futures):
|
| - fmt = p.get('format', 'tarball')
|
| -
|
| + for e, future in zip(entries, futures):
|
| # Unpack tarball into {system name -> group name -> list of identities}.
|
| - if fmt == 'tarball':
|
| + if isinstance(e, config_pb2.GroupImporterConfig.TarballEntry):
|
| fetched = load_tarball(
|
| - future.get_result(), p['systems'], p.get('groups'), p.get('domain'))
|
| + future.get_result(), e.systems, e.groups, e.domain)
|
| assert not (
|
| set(fetched) & set(bundles)), (fetched.keys(), bundles.keys())
|
| bundles.update(fetched)
|
| continue
|
|
|
| # Add plainlist group to 'external/*' bundle.
|
| - if fmt == 'plainlist':
|
| - group = load_group_file(future.get_result(), p.get('domain'))
|
| - name = 'external/%s' % p['group']
|
| + if isinstance(e, config_pb2.GroupImporterConfig.PlainlistEntry):
|
| + group = load_group_file(future.get_result(), e.domain)
|
| + name = 'external/%s' % e.group
|
| if 'external' not in bundles:
|
| bundles['external'] = {}
|
| assert name not in bundles['external'], name
|
|
|
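The dispatch in the hunk above switches from the legacy 'format' string to the concrete entry type. A minimal sketch of that pattern, again with invented values and the nested message names taken from the isinstance() checks in the hunk:

    # Sketch only: entry values are invented; TarballEntry / PlainlistEntry are the
    # nested message names used by the isinstance() checks above.
    from proto import config_pb2

    config = config_pb2.GroupImporterConfig()

    tarball = config.tarball.add()
    tarball.url = 'https://example.com/groups.tar.gz'
    tarball.systems.append('ldap')

    plainlist = config.plainlist.add()
    plainlist.url = 'https://example.com/admins.txt'
    plainlist.group = 'admins'

    # Same iteration order as import_external_groups(): tarballs first, then plainlists.
    for entry in list(config.tarball) + list(config.plainlist):
      if isinstance(entry, config_pb2.GroupImporterConfig.TarballEntry):
        print 'tarball %s covers systems %s' % (entry.url, list(entry.systems))
      elif isinstance(entry, config_pb2.GroupImporterConfig.PlainlistEntry):
        print 'plainlist %s feeds group external/%s' % (entry.url, entry.group)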