Index: appengine/auth_service/importer.py |
diff --git a/appengine/auth_service/importer.py b/appengine/auth_service/importer.py |
index 8e7ec52ba8ef210f288aa65274b1195ff7122f86..ff2a1037ff6706fde3131963c3b0d02fee9bf2fd 100644 |
--- a/appengine/auth_service/importer.py |
+++ b/appengine/auth_service/importer.py |
@@ -35,17 +35,22 @@ service too. |
import collections |
import contextlib |
+import json |
import logging |
import StringIO |
import tarfile |
from google.appengine.ext import ndb |
+from google import protobuf |
+ |
from components import auth |
from components import net |
from components import utils |
from components.auth import model |
+from proto import config_pb2 |
+ |
class BundleImportError(Exception): |
"""Base class for errors while fetching external bundle.""" |
@@ -94,92 +99,131 @@ def config_key(): |
class GroupImporterConfig(ndb.Model): |
"""Singleton entity with group importer configuration JSON.""" |
- config = ndb.JsonProperty() |
+ config = ndb.TextProperty() # legacy field with JSON config |
+ config_proto = ndb.TextProperty() |
modified_by = auth.IdentityProperty(indexed=False) |
modified_ts = ndb.DateTimeProperty(auto_now=True, indexed=False) |
-def is_valid_config(config): |
- """Checks config for correctness.""" |
- if not isinstance(config, list): |
- return False |
+def legacy_json_config_to_proto(config_json): |
+ """Converts legacy JSON config to config_pb2.GroupImporterConfig message. |
- seen_systems = set(['external']) |
- seen_groups = set() |
+ TODO(vadimsh): Remove once all instances of auth service use protobuf configs. |
+ """ |
+ try: |
+ config = json.loads(config_json) |
+ except ValueError as ex: |
+ logging.error('Invalid JSON: %s', ex) |
+ return None |
+ msg = config_pb2.GroupImporterConfig() |
for item in config: |
- if not isinstance(item, dict): |
- return False |
- |
- # 'format' is an optional string describing the format of the imported |
- # source. The default format is 'tarball'. |
fmt = item.get('format', 'tarball') |
- if fmt not in ['tarball', 'plainlist']: |
- return False |
- |
- # 'url' is a required string: where to fetch groups from. |
- url = item.get('url') |
- if not url or not isinstance(url, basestring): |
- return False |
- |
- # 'oauth_scopes' is an optional list of strings: used when generating OAuth |
- # access_token to put in Authorization header. |
- oauth_scopes = item.get('oauth_scopes') |
- if oauth_scopes is not None: |
- if not all(isinstance(x, basestring) for x in oauth_scopes): |
- return False |
- |
- # 'domain' is an optional string: will be used when constructing emails from |
- # naked usernames found in imported groups. |
- domain = item.get('domain') |
- if domain and not isinstance(domain, basestring): |
- return False |
- |
- # 'tarball' format uses 'systems' and 'groups' fields. |
if fmt == 'tarball': |
- # 'systems' is a required list of strings: group systems expected to be |
- # found in the archive (they act as prefixes to group names, e.g 'ldap'). |
- systems = item.get('systems') |
- if not systems or not isinstance(systems, list): |
- return False |
- if not all(isinstance(x, basestring) for x in systems): |
- return False |
- |
- # There should be no overlap in systems between different bundles. |
- if set(systems) & seen_systems: |
- return False |
- seen_systems.update(systems) |
- |
- # 'groups' is an optional list of strings: if given, filters imported |
- # groups only to this list. |
- groups = item.get('groups') |
- if groups and not all(isinstance(x, basestring) for x in groups): |
- return False |
+ entry = msg.tarball.add() |
+ elif fmt == 'plainlist': |
+ entry = msg.plainlist.add() |
+ else: |
+ logging.error('Unrecognized format: %s', fmt) |
+ continue |
+ entry.url = item.get('url') or '' |
+ entry.oauth_scopes.extend(item.get('oauth_scopes') or []) |
+ if 'domain' in item: |
+ entry.domain = item['domain'] |
+ if fmt == 'tarball': |
+ entry.systems.extend(item.get('systems') or []) |
+ entry.groups.extend(item.get('groups') or []) |
elif fmt == 'plainlist': |
- # 'group' is a required name of imported group. The full group name will |
- # be 'external/<group>'. |
- group = item.get('group') |
- if not group or not isinstance(group, basestring) or group in seen_groups: |
- return False |
- seen_groups.add(group) |
+ entry.group = item.get('group') or '' |
else: |
- assert False, 'Unreachable' |
+ assert False, 'Not reachable' |
+ return msg |
- return True |
+def validate_config(config): |
+ """Checks config_pb2.GroupImporterConfig for correctness. |
-def read_config(): |
- """Returns currently stored config or [] if not set.""" |
+ Raises: |
+ ValueError if config has invalid structure. |
+ """ |
+ if not isinstance(config, config_pb2.GroupImporterConfig): |
+ raise ValueError('Not GroupImporterConfig proto message') |
+ |
+ # TODO(vadimsh): Can be made stricter. |
+ |
+ # Validate fields common to Tarball and Plainlist. |
+ for entry in list(config.tarball) + list(config.plainlist): |
+ if not entry.url: |
+ raise ValueError( |
+ '"url" field is required in %s' % entry.__class__.__name__) |
+ |
+ # Validate tarball fields. |
+ seen_systems = set(['external']) |
+ for tarball in config.tarball: |
+ if not tarball.systems: |
+ raise ValueError( |
+ '"tarball" entry "%s" needs "systems" field' % tarball.url) |
+ # There should be no overlap in systems between different bundles. |
+ twice = set(tarball.systems) & seen_systems |
+ if twice: |
+ raise ValueError( |
+ 'A system is imported twice by "%s": %s' % |
+ (tarball.url, sorted(twice))) |
+ seen_systems.update(tarball.systems) |
+ |
+ # Validate plainlist fields. |
+ seen_groups = set() |
+ for plainlist in config.plainlist: |
+ if not plainlist.group: |
+ raise ValueError( |
+ '"plainlist" entry "%s" needs "group" field' % plainlist.url) |
+ if plainlist.group in seen_groups: |
+ raise ValueError( |
+ 'In "%s" the group is imported twice: %s' % |
+ (plainlist.url, plainlist.group)) |
+ seen_groups.add(plainlist.group) |
+ |
+ |
+def read_config_text(): |
+ """Returns importer config as a text blob (or '' if not set).""" |
+ e = config_key().get() |
+ if not e: |
+ return '' |
+ if e.config_proto: |
+ return e.config_proto |
+ if e.config: |
+ msg = legacy_json_config_to_proto(e.config) |
+ if not msg: |
+ return '' |
+ return protobuf.text_format.MessageToString(msg) |
+ return '' |
+ |
+ |
+def read_legacy_config(): |
+ """Returns legacy JSON config stored in GroupImporterConfig entity. |
+ |
+ TODO(vadimsh): Remove once all instances of auth service use protobuf configs.
+ """ |
+ # Note: we do not bother doing this in a transaction.
e = config_key().get() |
- return (e.config if e else []) or [] |
+ return e.config if e else None |
-def write_config(config): |
- """Updates stored configuration.""" |
- if not is_valid_config(config): |
- raise ValueError('Invalid config') |
+def write_config_text(text): |
+ """Validates a config text blob and puts it into the datastore.
+ |
+ Raises: |
+ ValueError on invalid format. |
+ """ |
+ msg = config_pb2.GroupImporterConfig() |
+ try: |
+ protobuf.text_format.Merge(text, msg) |
+ except protobuf.text_format.ParseError as ex: |
+ raise ValueError('Config is badly formated: %s' % ex) |
+ validate_config(msg) |
e = GroupImporterConfig( |
key=config_key(), |
- config=config, |
+ config=read_legacy_config(), |
+ config_proto=text, |
modified_by=auth.get_current_identity()) |
e.put() |
@@ -190,34 +234,40 @@ def import_external_groups(): |
Runs as a cron task. Raises BundleImportError in case of import errors. |
""" |
# Missing config is not a error. |
- config = read_config() |
- if not config: |
+ config_text = read_config_text() |
+ if not config_text: |
logging.info('Not configured') |
return |
- if not is_valid_config(config): |
- raise BundleImportError('Bad config') |
+ config = config_pb2.GroupImporterConfig() |
+ try: |
+ protobuf.text_format.Merge(config_text, config) |
+ except protobuf.text_format.ParseError as ex: |
+ raise BundleImportError('Bad config format: %s' % ex) |
+ try: |
+ validate_config(config) |
+ except ValueError as ex: |
+ raise BundleImportError('Bad config structure: %s' % ex) |
# Fetch all files specified in config in parallel. |
- futures = [fetch_file_async(p['url'], p.get('oauth_scopes')) for p in config] |
+ entries = list(config.tarball) + list(config.plainlist) |
+ futures = [fetch_file_async(e.url, e.oauth_scopes) for e in entries] |
# {system name -> group name -> list of identities} |
bundles = {} |
- for p, future in zip(config, futures): |
- fmt = p.get('format', 'tarball') |
- |
+ for e, future in zip(entries, futures): |
# Unpack tarball into {system name -> group name -> list of identities}. |
- if fmt == 'tarball': |
+ if isinstance(e, config_pb2.GroupImporterConfig.TarballEntry): |
fetched = load_tarball( |
- future.get_result(), p['systems'], p.get('groups'), p.get('domain')) |
+ future.get_result(), e.systems, e.groups, e.domain) |
assert not ( |
set(fetched) & set(bundles)), (fetched.keys(), bundles.keys()) |
bundles.update(fetched) |
continue |
# Add plainlist group to 'external/*' bundle. |
- if fmt == 'plainlist': |
- group = load_group_file(future.get_result(), p.get('domain')) |
- name = 'external/%s' % p['group'] |
+ if isinstance(e, config_pb2.GroupImporterConfig.PlainlistEntry): |
+ group = load_group_file(future.get_result(), e.domain) |
+ name = 'external/%s' % e.group |
if 'external' not in bundles: |
bundles['external'] = {} |
assert name not in bundles['external'], name |