Index: third_party/gsutil/third_party/apitools/apitools/gen/service_registry.py |
diff --git a/third_party/gsutil/third_party/apitools/apitools/gen/service_registry.py b/third_party/gsutil/third_party/apitools/apitools/gen/service_registry.py |
new file mode 100755 |
index 0000000000000000000000000000000000000000..d4826a6c27c755dcfbcf3b4d3f5ae05bc45d77ea |
--- /dev/null |
+++ b/third_party/gsutil/third_party/apitools/apitools/gen/service_registry.py |
@@ -0,0 +1,456 @@ |
+#!/usr/bin/env python |
+"""Service registry for apitools.""" |
+ |
+import collections |
+import logging |
+import re |
+import textwrap |
+ |
+from apitools.base.py import base_api |
+from apitools.gen import util |
+ |
+# We're a code generator. I don't care. |
+# pylint:disable=too-many-statements |
+ |
+_MIME_PATTERN_RE = re.compile(r'(?i)[a-z0-9_*-]+/[a-z0-9_*-]+') |
+ |
+ |
+class ServiceRegistry(object): |
+ |
+ """Registry for service types.""" |
+ |
+ def __init__(self, client_info, message_registry, command_registry, |
+ base_url, base_path, names, |
+ root_package_dir, base_files_package, |
+ unelidable_request_methods): |
+ self.__client_info = client_info |
+ self.__package = client_info.package |
+ self.__names = names |
+ self.__service_method_info_map = collections.OrderedDict() |
+ self.__message_registry = message_registry |
+ self.__command_registry = command_registry |
+ self.__base_url = base_url |
+ self.__base_path = base_path |
+ self.__root_package_dir = root_package_dir |
+ self.__base_files_package = base_files_package |
+ self.__unelidable_request_methods = unelidable_request_methods |
+ self.__all_scopes = set(self.__client_info.scopes) |
+ |
+ def Validate(self): |
+ self.__message_registry.Validate() |
+ |
+ @property |
+ def scopes(self): |
+ return sorted(list(self.__all_scopes)) |
+ |
+ def __GetServiceClassName(self, service_name): |
+ return self.__names.ClassName( |
+ '%sService' % self.__names.ClassName(service_name)) |
+ |
+ def __PrintDocstring(self, printer, method_info, method_name, name): |
+ """Print a docstring for a service method.""" |
+ if method_info.description: |
+ description = util.CleanDescription(method_info.description) |
+ first_line, newline, remaining = method_info.description.partition( |
+ '\n') |
+ if not first_line.endswith('.'): |
+ first_line = '%s.' % first_line |
+ description = '%s%s%s' % (first_line, newline, remaining) |
+ else: |
+ description = '%s method for the %s service.' % (method_name, name) |
+ with printer.CommentContext(): |
+ printer('"""%s' % description) |
+ printer() |
+ printer('Args:') |
+ printer(' request: (%s) input message', method_info.request_type_name) |
+ printer(' global_params: (StandardQueryParameters, default: None) ' |
+ 'global arguments') |
+ if method_info.upload_config: |
+ printer(' upload: (Upload, default: None) If present, upload') |
+ printer(' this stream with the request.') |
+ if method_info.supports_download: |
+ printer( |
+ ' download: (Download, default: None) If present, download') |
+ printer(' data from the request via this stream.') |
+ printer('Returns:') |
+ printer(' (%s) The response message.', method_info.response_type_name) |
+ printer('"""') |
+ |
+ def __WriteSingleService( |
+ self, printer, name, method_info_map, client_class_name): |
+ printer() |
+ class_name = self.__GetServiceClassName(name) |
+ printer('class %s(base_api.BaseApiService):', class_name) |
+ with printer.Indent(): |
+ printer('"""Service class for the %s resource."""', name) |
+ printer() |
+ printer('_NAME = %s', repr(name)) |
+ |
+ # Print the configs for the methods first. |
+ printer() |
+ printer('def __init__(self, client):') |
+ with printer.Indent(): |
+ printer('super(%s.%s, self).__init__(client)', |
+ client_class_name, class_name) |
+ printer('self._method_configs = {') |
+ with printer.Indent(indent=' '): |
+ for method_name, method_info in method_info_map.items(): |
+ printer("'%s': base_api.ApiMethodInfo(", method_name) |
+ with printer.Indent(indent=' '): |
+ attrs = sorted( |
+ x.name for x in method_info.all_fields()) |
+ for attr in attrs: |
+ if attr in ('upload_config', 'description'): |
+ continue |
+ printer( |
+ '%s=%r,', attr, getattr(method_info, attr)) |
+ printer('),') |
+ printer('}') |
+ printer() |
+ printer('self._upload_configs = {') |
+ with printer.Indent(indent=' '): |
+ for method_name, method_info in method_info_map.items(): |
+ upload_config = method_info.upload_config |
+ if upload_config is not None: |
+ printer( |
+ "'%s': base_api.ApiUploadInfo(", method_name) |
+ with printer.Indent(indent=' '): |
+ attrs = sorted( |
+ x.name for x in upload_config.all_fields()) |
+ for attr in attrs: |
+ printer('%s=%r,', |
+ attr, getattr(upload_config, attr)) |
+ printer('),') |
+ printer('}') |
+ |
+ # Now write each method in turn. |
+ for method_name, method_info in method_info_map.items(): |
+ printer() |
+ params = ['self', 'request', 'global_params=None'] |
+ if method_info.upload_config: |
+ params.append('upload=None') |
+ if method_info.supports_download: |
+ params.append('download=None') |
+ printer('def %s(%s):', method_name, ', '.join(params)) |
+ with printer.Indent(): |
+ self.__PrintDocstring( |
+ printer, method_info, method_name, name) |
+ printer("config = self.GetMethodConfig('%s')", method_name) |
+ upload_config = method_info.upload_config |
+ if upload_config is not None: |
+ printer("upload_config = self.GetUploadConfig('%s')", |
+ method_name) |
+ arg_lines = [ |
+ 'config, request, global_params=global_params'] |
+ if method_info.upload_config: |
+ arg_lines.append( |
+ 'upload=upload, upload_config=upload_config') |
+ if method_info.supports_download: |
+ arg_lines.append('download=download') |
+ printer('return self._RunMethod(') |
+ with printer.Indent(indent=' '): |
+ for line in arg_lines[:-1]: |
+ printer('%s,', line) |
+ printer('%s)', arg_lines[-1]) |
+ |
+ def __WriteProtoServiceDeclaration(self, printer, name, method_info_map): |
+ """Write a single service declaration to a proto file.""" |
+ printer() |
+ printer('service %s {', self.__GetServiceClassName(name)) |
+ with printer.Indent(): |
+ for method_name, method_info in method_info_map.items(): |
+ for line in textwrap.wrap(method_info.description, |
+ printer.CalculateWidth() - 3): |
+ printer('// %s', line) |
+ printer('rpc %s (%s) returns (%s);', |
+ method_name, |
+ method_info.request_type_name, |
+ method_info.response_type_name) |
+ printer('}') |
+ |
+ def WriteProtoFile(self, printer): |
+ """Write the services in this registry to out as proto.""" |
+ self.Validate() |
+ client_info = self.__client_info |
+ printer('// Generated services for %s version %s.', |
+ client_info.package, client_info.version) |
+ printer() |
+ printer('syntax = "proto2";') |
+ printer('package %s;', self.__package) |
+ printer('import "%s";', client_info.messages_proto_file_name) |
+ printer() |
+ for name, method_info_map in self.__service_method_info_map.items(): |
+ self.__WriteProtoServiceDeclaration(printer, name, method_info_map) |
+ |
+ def WriteFile(self, printer): |
+ """Write the services in this registry to out.""" |
+ self.Validate() |
+ client_info = self.__client_info |
+ printer('"""Generated client library for %s version %s."""', |
+ client_info.package, client_info.version) |
+ printer('# NOTE: This file is autogenerated and should not be edited ' |
+ 'by hand.') |
+ printer('from %s import base_api', self.__base_files_package) |
+ import_prefix = '' |
+ printer('%simport %s as messages', import_prefix, |
+ client_info.messages_rule_name) |
+ printer() |
+ printer() |
+ printer('class %s(base_api.BaseApiClient):', |
+ client_info.client_class_name) |
+ with printer.Indent(): |
+ printer( |
+ '"""Generated client library for service %s version %s."""', |
+ client_info.package, client_info.version) |
+ printer() |
+ printer('MESSAGES_MODULE = messages') |
+ printer() |
+ client_info_items = client_info._asdict( |
+ ).items() # pylint:disable=protected-access |
+ for attr, val in client_info_items: |
+ if attr == 'scopes' and not val: |
+ val = ['https://www.googleapis.com/auth/userinfo.email'] |
+ printer('_%s = %r' % (attr.upper(), val)) |
+ printer() |
+ printer("def __init__(self, url='', credentials=None,") |
+ with printer.Indent(indent=' '): |
+ printer('get_credentials=True, http=None, model=None,') |
+ printer('log_request=False, log_response=False,') |
+ printer('credentials_args=None, default_global_params=None,') |
+ printer('additional_http_headers=None):') |
+ with printer.Indent(): |
+ printer('"""Create a new %s handle."""', client_info.package) |
+ printer('url = url or %r', self.__base_url) |
+ printer( |
+ 'super(%s, self).__init__(', client_info.client_class_name) |
+ printer(' url, credentials=credentials,') |
+ printer(' get_credentials=get_credentials, http=http, ' |
+ 'model=model,') |
+ printer(' log_request=log_request, ' |
+ 'log_response=log_response,') |
+ printer(' credentials_args=credentials_args,') |
+ printer(' default_global_params=default_global_params,') |
+ printer(' additional_http_headers=additional_http_headers)') |
+ for name in self.__service_method_info_map.keys(): |
+ printer('self.%s = self.%s(self)', |
+ name, self.__GetServiceClassName(name)) |
+ for name, method_info in self.__service_method_info_map.items(): |
+ self.__WriteSingleService( |
+ printer, name, method_info, client_info.client_class_name) |
+ |
+ def __RegisterService(self, service_name, method_info_map): |
+ if service_name in self.__service_method_info_map: |
+ raise ValueError( |
+ 'Attempt to re-register descriptor %s' % service_name) |
+ self.__service_method_info_map[service_name] = method_info_map |
+ |
+ def __CreateRequestType(self, method_description, body_type=None): |
+ """Create a request type for this method.""" |
+ schema = {} |
+ schema['id'] = self.__names.ClassName('%sRequest' % ( |
+ self.__names.ClassName(method_description['id'], separator='.'),)) |
+ schema['type'] = 'object' |
+ schema['properties'] = collections.OrderedDict() |
+ if 'parameterOrder' not in method_description: |
+ ordered_parameters = list(method_description.get('parameters', [])) |
+ else: |
+ ordered_parameters = method_description['parameterOrder'][:] |
+ for k in method_description['parameters']: |
+ if k not in ordered_parameters: |
+ ordered_parameters.append(k) |
+ for parameter_name in ordered_parameters: |
+ field_name = self.__names.CleanName(parameter_name) |
+ field = dict(method_description['parameters'][parameter_name]) |
+ if 'type' not in field: |
+ raise ValueError('No type found in parameter %s' % field) |
+ schema['properties'][field_name] = field |
+ if body_type is not None: |
+ body_field_name = self.__GetRequestField( |
+ method_description, body_type) |
+ if body_field_name in schema['properties']: |
+ raise ValueError('Failed to normalize request resource name') |
+ if 'description' not in body_type: |
+ body_type['description'] = ( |
+ 'A %s resource to be passed as the request body.' % ( |
+ self.__GetRequestType(body_type),)) |
+ schema['properties'][body_field_name] = body_type |
+ self.__message_registry.AddDescriptorFromSchema(schema['id'], schema) |
+ return schema['id'] |
+ |
+ def __CreateVoidResponseType(self, method_description): |
+ """Create an empty response type.""" |
+ schema = {} |
+ method_name = self.__names.ClassName( |
+ method_description['id'], separator='.') |
+ schema['id'] = self.__names.ClassName('%sResponse' % method_name) |
+ schema['type'] = 'object' |
+ schema['description'] = 'An empty %s response.' % method_name |
+ self.__message_registry.AddDescriptorFromSchema(schema['id'], schema) |
+ return schema['id'] |
+ |
+ def __NeedRequestType(self, method_description, request_type): |
+ """Determine if this method needs a new request type created.""" |
+ if not request_type: |
+ return True |
+ method_id = method_description.get('id', '') |
+ if method_id in self.__unelidable_request_methods: |
+ return True |
+ message = self.__message_registry.LookupDescriptorOrDie(request_type) |
+ if message is None: |
+ return True |
+ field_names = [x.name for x in message.fields] |
+ parameters = method_description.get('parameters', {}) |
+ for param_name, param_info in parameters.items(): |
+ if (param_info.get('location') != 'path' or |
+ self.__names.CleanName(param_name) not in field_names): |
+ break |
+ else: |
+ return False |
+ return True |
+ |
+ def __MaxSizeToInt(self, max_size): |
+ """Convert max_size to an int.""" |
+ size_groups = re.match(r'(?P<size>\d+)(?P<unit>.B)?$', max_size) |
+ if size_groups is None: |
+ raise ValueError('Could not parse maxSize') |
+ size, unit = size_groups.group('size', 'unit') |
+ shift = 0 |
+ if unit is not None: |
+ unit_dict = {'KB': 10, 'MB': 20, 'GB': 30, 'TB': 40} |
+ shift = unit_dict.get(unit.upper()) |
+ if shift is None: |
+ raise ValueError('Unknown unit %s' % unit) |
+ return int(size) * (1 << shift) |
+ |
+ def __ComputeUploadConfig(self, media_upload_config, method_id): |
+ """Fill out the upload config for this method.""" |
+ config = base_api.ApiUploadInfo() |
+ if 'maxSize' in media_upload_config: |
+ config.max_size = self.__MaxSizeToInt( |
+ media_upload_config['maxSize']) |
+ if 'accept' not in media_upload_config: |
+ logging.warn( |
+ 'No accept types found for upload configuration in ' |
+ 'method %s, using */*', method_id) |
+ config.accept.extend([ |
+ str(a) for a in media_upload_config.get('accept', '*/*')]) |
+ |
+ for accept_pattern in config.accept: |
+ if not _MIME_PATTERN_RE.match(accept_pattern): |
+ logging.warn('Unexpected MIME type: %s', accept_pattern) |
+ protocols = media_upload_config.get('protocols', {}) |
+ for protocol in ('simple', 'resumable'): |
+ media = protocols.get(protocol, {}) |
+ for attr in ('multipart', 'path'): |
+ if attr in media: |
+ setattr(config, '%s_%s' % (protocol, attr), media[attr]) |
+ return config |
+ |
+ def __ComputeMethodInfo(self, method_description, request, response, |
+ request_field): |
+ """Compute the base_api.ApiMethodInfo for this method.""" |
+ relative_path = self.__names.NormalizeRelativePath( |
+ ''.join((self.__base_path, method_description['path']))) |
+ method_id = method_description['id'] |
+ ordered_params = [] |
+ for param_name in method_description.get('parameterOrder', []): |
+ param_info = method_description['parameters'][param_name] |
+ if param_info.get('required', False): |
+ ordered_params.append(param_name) |
+ method_info = base_api.ApiMethodInfo( |
+ relative_path=relative_path, |
+ method_id=method_id, |
+ http_method=method_description['httpMethod'], |
+ description=util.CleanDescription( |
+ method_description.get('description', '')), |
+ query_params=[], |
+ path_params=[], |
+ ordered_params=ordered_params, |
+ request_type_name=self.__names.ClassName(request), |
+ response_type_name=self.__names.ClassName(response), |
+ request_field=request_field, |
+ ) |
+ if method_description.get('supportsMediaUpload', False): |
+ method_info.upload_config = self.__ComputeUploadConfig( |
+ method_description.get('mediaUpload'), method_id) |
+ method_info.supports_download = method_description.get( |
+ 'supportsMediaDownload', False) |
+ self.__all_scopes.update(method_description.get('scopes', ())) |
+ for param, desc in method_description.get('parameters', {}).items(): |
+ param = self.__names.CleanName(param) |
+ location = desc['location'] |
+ if location == 'query': |
+ method_info.query_params.append(param) |
+ elif location == 'path': |
+ method_info.path_params.append(param) |
+ else: |
+ raise ValueError( |
+ 'Unknown parameter location %s for parameter %s' % ( |
+ location, param)) |
+ method_info.path_params.sort() |
+ method_info.query_params.sort() |
+ return method_info |
+ |
+ def __BodyFieldName(self, body_type): |
+ if body_type is None: |
+ return '' |
+ return self.__names.FieldName(body_type['$ref']) |
+ |
+ def __GetRequestType(self, body_type): |
+ return self.__names.ClassName(body_type.get('$ref')) |
+ |
+ def __GetRequestField(self, method_description, body_type): |
+ """Determine the request field for this method.""" |
+ body_field_name = self.__BodyFieldName(body_type) |
+ if body_field_name in method_description.get('parameters', {}): |
+ body_field_name = self.__names.FieldName( |
+ '%s_resource' % body_field_name) |
+ # It's exceedingly unlikely that we'd get two name collisions, which |
+ # means it's bound to happen at some point. |
+ while body_field_name in method_description.get('parameters', {}): |
+ body_field_name = self.__names.FieldName( |
+ '%s_body' % body_field_name) |
+ return body_field_name |
+ |
+ def AddServiceFromResource(self, service_name, methods): |
+ """Add a new service named service_name with the given methods.""" |
+ method_descriptions = methods.get('methods', {}) |
+ method_info_map = collections.OrderedDict() |
+ items = sorted(method_descriptions.items()) |
+ for method_name, method_description in items: |
+ method_name = self.__names.MethodName(method_name) |
+ |
+ # NOTE: According to the discovery document, if the request or |
+ # response is present, it will simply contain a `$ref`. |
+ body_type = method_description.get('request') |
+ if body_type is None: |
+ request_type = None |
+ else: |
+ request_type = self.__GetRequestType(body_type) |
+ if self.__NeedRequestType(method_description, request_type): |
+ request = self.__CreateRequestType( |
+ method_description, body_type=body_type) |
+ request_field = self.__GetRequestField( |
+ method_description, body_type) |
+ else: |
+ request = request_type |
+ request_field = base_api.REQUEST_IS_BODY |
+ |
+ if 'response' in method_description: |
+ response = method_description['response']['$ref'] |
+ else: |
+ response = self.__CreateVoidResponseType(method_description) |
+ |
+ method_info_map[method_name] = self.__ComputeMethodInfo( |
+ method_description, request, response, request_field) |
+ self.__command_registry.AddCommandForMethod( |
+ service_name, method_name, method_info_map[method_name], |
+ request, response) |
+ |
+ nested_services = methods.get('resources', {}) |
+ services = sorted(nested_services.items()) |
+ for subservice_name, submethods in services: |
+ new_service_name = '%s_%s' % (service_name, subservice_name) |
+ self.AddServiceFromResource(new_service_name, submethods) |
+ |
+ self.__RegisterService(service_name, method_info_map) |