OLD | NEW |
(Empty) | |
| 1 #!/usr/bin/env python |
| 2 # |
| 3 # Copyright 2015 Google Inc. |
| 4 # |
| 5 # Licensed under the Apache License, Version 2.0 (the "License"); |
| 6 # you may not use this file except in compliance with the License. |
| 7 # You may obtain a copy of the License at |
| 8 # |
| 9 # http://www.apache.org/licenses/LICENSE-2.0 |
| 10 # |
| 11 # Unless required by applicable law or agreed to in writing, software |
| 12 # distributed under the License is distributed on an "AS IS" BASIS, |
| 13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 14 # See the License for the specific language governing permissions and |
| 15 # limitations under the License. |
| 16 |
| 17 """Service registry for apitools.""" |
| 18 |
| 19 import collections |
| 20 import logging |
| 21 import re |
| 22 import textwrap |
| 23 |
| 24 from apitools.base.py import base_api |
| 25 from apitools.gen import util |
| 26 |
| 27 # We're a code generator. I don't care. |
| 28 # pylint:disable=too-many-statements |
| 29 |
| 30 _MIME_PATTERN_RE = re.compile(r'(?i)[a-z0-9_*-]+/[a-z0-9_*-]+') |
| 31 |
| 32 |
| 33 class ServiceRegistry(object): |
| 34 |
| 35 """Registry for service types.""" |
| 36 |
| 37 def __init__(self, client_info, message_registry, command_registry, |
| 38 base_url, base_path, names, |
| 39 root_package, base_files_package, |
| 40 unelidable_request_methods): |
| 41 self.__client_info = client_info |
| 42 self.__package = client_info.package |
| 43 self.__names = names |
| 44 self.__service_method_info_map = collections.OrderedDict() |
| 45 self.__message_registry = message_registry |
| 46 self.__command_registry = command_registry |
| 47 self.__base_url = base_url |
| 48 self.__base_path = base_path |
| 49 self.__root_package = root_package |
| 50 self.__base_files_package = base_files_package |
| 51 self.__unelidable_request_methods = unelidable_request_methods |
| 52 self.__all_scopes = set(self.__client_info.scopes) |
| 53 |
| 54 def Validate(self): |
| 55 self.__message_registry.Validate() |
| 56 |
| 57 @property |
| 58 def scopes(self): |
| 59 return sorted(list(self.__all_scopes)) |
| 60 |
| 61 def __GetServiceClassName(self, service_name): |
| 62 return self.__names.ClassName( |
| 63 '%sService' % self.__names.ClassName(service_name)) |
| 64 |
| 65 def __PrintDocstring(self, printer, method_info, method_name, name): |
| 66 """Print a docstring for a service method.""" |
| 67 if method_info.description: |
| 68 description = util.CleanDescription(method_info.description) |
| 69 first_line, newline, remaining = method_info.description.partition( |
| 70 '\n') |
| 71 if not first_line.endswith('.'): |
| 72 first_line = '%s.' % first_line |
| 73 description = '%s%s%s' % (first_line, newline, remaining) |
| 74 else: |
| 75 description = '%s method for the %s service.' % (method_name, name) |
| 76 with printer.CommentContext(): |
| 77 printer('"""%s' % description) |
| 78 printer() |
| 79 printer('Args:') |
| 80 printer(' request: (%s) input message', method_info.request_type_name) |
| 81 printer(' global_params: (StandardQueryParameters, default: None) ' |
| 82 'global arguments') |
| 83 if method_info.upload_config: |
| 84 printer(' upload: (Upload, default: None) If present, upload') |
| 85 printer(' this stream with the request.') |
| 86 if method_info.supports_download: |
| 87 printer( |
| 88 ' download: (Download, default: None) If present, download') |
| 89 printer(' data from the request via this stream.') |
| 90 printer('Returns:') |
| 91 printer(' (%s) The response message.', method_info.response_type_name) |
| 92 printer('"""') |
| 93 |
| 94 def __WriteSingleService( |
| 95 self, printer, name, method_info_map, client_class_name): |
| 96 printer() |
| 97 class_name = self.__GetServiceClassName(name) |
| 98 printer('class %s(base_api.BaseApiService):', class_name) |
| 99 with printer.Indent(): |
| 100 printer('"""Service class for the %s resource."""', name) |
| 101 printer() |
| 102 printer('_NAME = %s', repr(name)) |
| 103 |
| 104 # Print the configs for the methods first. |
| 105 printer() |
| 106 printer('def __init__(self, client):') |
| 107 with printer.Indent(): |
| 108 printer('super(%s.%s, self).__init__(client)', |
| 109 client_class_name, class_name) |
| 110 printer('self._method_configs = {') |
| 111 with printer.Indent(indent=' '): |
| 112 for method_name, method_info in method_info_map.items(): |
| 113 printer("'%s': base_api.ApiMethodInfo(", method_name) |
| 114 with printer.Indent(indent=' '): |
| 115 attrs = sorted( |
| 116 x.name for x in method_info.all_fields()) |
| 117 for attr in attrs: |
| 118 if attr in ('upload_config', 'description'): |
| 119 continue |
| 120 printer( |
| 121 '%s=%r,', attr, getattr(method_info, attr)) |
| 122 printer('),') |
| 123 printer('}') |
| 124 printer() |
| 125 printer('self._upload_configs = {') |
| 126 with printer.Indent(indent=' '): |
| 127 for method_name, method_info in method_info_map.items(): |
| 128 upload_config = method_info.upload_config |
| 129 if upload_config is not None: |
| 130 printer( |
| 131 "'%s': base_api.ApiUploadInfo(", method_name) |
| 132 with printer.Indent(indent=' '): |
| 133 attrs = sorted( |
| 134 x.name for x in upload_config.all_fields()) |
| 135 for attr in attrs: |
| 136 printer('%s=%r,', |
| 137 attr, getattr(upload_config, attr)) |
| 138 printer('),') |
| 139 printer('}') |
| 140 |
| 141 # Now write each method in turn. |
| 142 for method_name, method_info in method_info_map.items(): |
| 143 printer() |
| 144 params = ['self', 'request', 'global_params=None'] |
| 145 if method_info.upload_config: |
| 146 params.append('upload=None') |
| 147 if method_info.supports_download: |
| 148 params.append('download=None') |
| 149 printer('def %s(%s):', method_name, ', '.join(params)) |
| 150 with printer.Indent(): |
| 151 self.__PrintDocstring( |
| 152 printer, method_info, method_name, name) |
| 153 printer("config = self.GetMethodConfig('%s')", method_name) |
| 154 upload_config = method_info.upload_config |
| 155 if upload_config is not None: |
| 156 printer("upload_config = self.GetUploadConfig('%s')", |
| 157 method_name) |
| 158 arg_lines = [ |
| 159 'config, request, global_params=global_params'] |
| 160 if method_info.upload_config: |
| 161 arg_lines.append( |
| 162 'upload=upload, upload_config=upload_config') |
| 163 if method_info.supports_download: |
| 164 arg_lines.append('download=download') |
| 165 printer('return self._RunMethod(') |
| 166 with printer.Indent(indent=' '): |
| 167 for line in arg_lines[:-1]: |
| 168 printer('%s,', line) |
| 169 printer('%s)', arg_lines[-1]) |
| 170 |
| 171 def __WriteProtoServiceDeclaration(self, printer, name, method_info_map): |
| 172 """Write a single service declaration to a proto file.""" |
| 173 printer() |
| 174 printer('service %s {', self.__GetServiceClassName(name)) |
| 175 with printer.Indent(): |
| 176 for method_name, method_info in method_info_map.items(): |
| 177 for line in textwrap.wrap(method_info.description, |
| 178 printer.CalculateWidth() - 3): |
| 179 printer('// %s', line) |
| 180 printer('rpc %s (%s) returns (%s);', |
| 181 method_name, |
| 182 method_info.request_type_name, |
| 183 method_info.response_type_name) |
| 184 printer('}') |
| 185 |
| 186 def WriteProtoFile(self, printer): |
| 187 """Write the services in this registry to out as proto.""" |
| 188 self.Validate() |
| 189 client_info = self.__client_info |
| 190 printer('// Generated services for %s version %s.', |
| 191 client_info.package, client_info.version) |
| 192 printer() |
| 193 printer('syntax = "proto2";') |
| 194 printer('package %s;', self.__package) |
| 195 printer('import "%s";', client_info.messages_proto_file_name) |
| 196 printer() |
| 197 for name, method_info_map in self.__service_method_info_map.items(): |
| 198 self.__WriteProtoServiceDeclaration(printer, name, method_info_map) |
| 199 |
| 200 def WriteFile(self, printer): |
| 201 """Write the services in this registry to out.""" |
| 202 self.Validate() |
| 203 client_info = self.__client_info |
| 204 printer('"""Generated client library for %s version %s."""', |
| 205 client_info.package, client_info.version) |
| 206 printer('# NOTE: This file is autogenerated and should not be edited ' |
| 207 'by hand.') |
| 208 printer('from %s import base_api', self.__base_files_package) |
| 209 if self.__root_package: |
| 210 import_prefix = 'from {0} '.format(self.__root_package) |
| 211 else: |
| 212 import_prefix = '' |
| 213 printer('%simport %s as messages', import_prefix, |
| 214 client_info.messages_rule_name) |
| 215 printer() |
| 216 printer() |
| 217 printer('class %s(base_api.BaseApiClient):', |
| 218 client_info.client_class_name) |
| 219 with printer.Indent(): |
| 220 printer( |
| 221 '"""Generated client library for service %s version %s."""', |
| 222 client_info.package, client_info.version) |
| 223 printer() |
| 224 printer('MESSAGES_MODULE = messages') |
| 225 printer() |
| 226 # pylint: disable=protected-access |
| 227 client_info_items = client_info._asdict().items() |
| 228 for attr, val in client_info_items: |
| 229 if attr == 'scopes' and not val: |
| 230 val = ['https://www.googleapis.com/auth/userinfo.email'] |
| 231 printer('_%s = %r' % (attr.upper(), val)) |
| 232 printer() |
| 233 printer("def __init__(self, url='', credentials=None,") |
| 234 with printer.Indent(indent=' '): |
| 235 printer('get_credentials=True, http=None, model=None,') |
| 236 printer('log_request=False, log_response=False,') |
| 237 printer('credentials_args=None, default_global_params=None,') |
| 238 printer('additional_http_headers=None):') |
| 239 with printer.Indent(): |
| 240 printer('"""Create a new %s handle."""', client_info.package) |
| 241 printer('url = url or %r', self.__base_url) |
| 242 printer( |
| 243 'super(%s, self).__init__(', client_info.client_class_name) |
| 244 printer(' url, credentials=credentials,') |
| 245 printer(' get_credentials=get_credentials, http=http, ' |
| 246 'model=model,') |
| 247 printer(' log_request=log_request, ' |
| 248 'log_response=log_response,') |
| 249 printer(' credentials_args=credentials_args,') |
| 250 printer(' default_global_params=default_global_params,') |
| 251 printer(' additional_http_headers=additional_http_headers)') |
| 252 for name in self.__service_method_info_map.keys(): |
| 253 printer('self.%s = self.%s(self)', |
| 254 name, self.__GetServiceClassName(name)) |
| 255 for name, method_info in self.__service_method_info_map.items(): |
| 256 self.__WriteSingleService( |
| 257 printer, name, method_info, client_info.client_class_name) |
| 258 |
| 259 def __RegisterService(self, service_name, method_info_map): |
| 260 if service_name in self.__service_method_info_map: |
| 261 raise ValueError( |
| 262 'Attempt to re-register descriptor %s' % service_name) |
| 263 self.__service_method_info_map[service_name] = method_info_map |
| 264 |
| 265 def __CreateRequestType(self, method_description, body_type=None): |
| 266 """Create a request type for this method.""" |
| 267 schema = {} |
| 268 schema['id'] = self.__names.ClassName('%sRequest' % ( |
| 269 self.__names.ClassName(method_description['id'], separator='.'),)) |
| 270 schema['type'] = 'object' |
| 271 schema['properties'] = collections.OrderedDict() |
| 272 if 'parameterOrder' not in method_description: |
| 273 ordered_parameters = list(method_description.get('parameters', [])) |
| 274 else: |
| 275 ordered_parameters = method_description['parameterOrder'][:] |
| 276 for k in method_description['parameters']: |
| 277 if k not in ordered_parameters: |
| 278 ordered_parameters.append(k) |
| 279 for parameter_name in ordered_parameters: |
| 280 field_name = self.__names.CleanName(parameter_name) |
| 281 field = dict(method_description['parameters'][parameter_name]) |
| 282 if 'type' not in field: |
| 283 raise ValueError('No type found in parameter %s' % field) |
| 284 schema['properties'][field_name] = field |
| 285 if body_type is not None: |
| 286 body_field_name = self.__GetRequestField( |
| 287 method_description, body_type) |
| 288 if body_field_name in schema['properties']: |
| 289 raise ValueError('Failed to normalize request resource name') |
| 290 if 'description' not in body_type: |
| 291 body_type['description'] = ( |
| 292 'A %s resource to be passed as the request body.' % ( |
| 293 self.__GetRequestType(body_type),)) |
| 294 schema['properties'][body_field_name] = body_type |
| 295 self.__message_registry.AddDescriptorFromSchema(schema['id'], schema) |
| 296 return schema['id'] |
| 297 |
| 298 def __CreateVoidResponseType(self, method_description): |
| 299 """Create an empty response type.""" |
| 300 schema = {} |
| 301 method_name = self.__names.ClassName( |
| 302 method_description['id'], separator='.') |
| 303 schema['id'] = self.__names.ClassName('%sResponse' % method_name) |
| 304 schema['type'] = 'object' |
| 305 schema['description'] = 'An empty %s response.' % method_name |
| 306 self.__message_registry.AddDescriptorFromSchema(schema['id'], schema) |
| 307 return schema['id'] |
| 308 |
| 309 def __NeedRequestType(self, method_description, request_type): |
| 310 """Determine if this method needs a new request type created.""" |
| 311 if not request_type: |
| 312 return True |
| 313 method_id = method_description.get('id', '') |
| 314 if method_id in self.__unelidable_request_methods: |
| 315 return True |
| 316 message = self.__message_registry.LookupDescriptorOrDie(request_type) |
| 317 if message is None: |
| 318 return True |
| 319 field_names = [x.name for x in message.fields] |
| 320 parameters = method_description.get('parameters', {}) |
| 321 for param_name, param_info in parameters.items(): |
| 322 if (param_info.get('location') != 'path' or |
| 323 self.__names.CleanName(param_name) not in field_names): |
| 324 break |
| 325 else: |
| 326 return False |
| 327 return True |
| 328 |
| 329 def __MaxSizeToInt(self, max_size): |
| 330 """Convert max_size to an int.""" |
| 331 size_groups = re.match(r'(?P<size>\d+)(?P<unit>.B)?$', max_size) |
| 332 if size_groups is None: |
| 333 raise ValueError('Could not parse maxSize') |
| 334 size, unit = size_groups.group('size', 'unit') |
| 335 shift = 0 |
| 336 if unit is not None: |
| 337 unit_dict = {'KB': 10, 'MB': 20, 'GB': 30, 'TB': 40} |
| 338 shift = unit_dict.get(unit.upper()) |
| 339 if shift is None: |
| 340 raise ValueError('Unknown unit %s' % unit) |
| 341 return int(size) * (1 << shift) |
| 342 |
| 343 def __ComputeUploadConfig(self, media_upload_config, method_id): |
| 344 """Fill out the upload config for this method.""" |
| 345 config = base_api.ApiUploadInfo() |
| 346 if 'maxSize' in media_upload_config: |
| 347 config.max_size = self.__MaxSizeToInt( |
| 348 media_upload_config['maxSize']) |
| 349 if 'accept' not in media_upload_config: |
| 350 logging.warn( |
| 351 'No accept types found for upload configuration in ' |
| 352 'method %s, using */*', method_id) |
| 353 config.accept.extend([ |
| 354 str(a) for a in media_upload_config.get('accept', '*/*')]) |
| 355 |
| 356 for accept_pattern in config.accept: |
| 357 if not _MIME_PATTERN_RE.match(accept_pattern): |
| 358 logging.warn('Unexpected MIME type: %s', accept_pattern) |
| 359 protocols = media_upload_config.get('protocols', {}) |
| 360 for protocol in ('simple', 'resumable'): |
| 361 media = protocols.get(protocol, {}) |
| 362 for attr in ('multipart', 'path'): |
| 363 if attr in media: |
| 364 setattr(config, '%s_%s' % (protocol, attr), media[attr]) |
| 365 return config |
| 366 |
| 367 def __ComputeMethodInfo(self, method_description, request, response, |
| 368 request_field): |
| 369 """Compute the base_api.ApiMethodInfo for this method.""" |
| 370 relative_path = self.__names.NormalizeRelativePath( |
| 371 ''.join((self.__base_path, method_description['path']))) |
| 372 method_id = method_description['id'] |
| 373 ordered_params = [] |
| 374 for param_name in method_description.get('parameterOrder', []): |
| 375 param_info = method_description['parameters'][param_name] |
| 376 if param_info.get('required', False): |
| 377 ordered_params.append(param_name) |
| 378 method_info = base_api.ApiMethodInfo( |
| 379 relative_path=relative_path, |
| 380 method_id=method_id, |
| 381 http_method=method_description['httpMethod'], |
| 382 description=util.CleanDescription( |
| 383 method_description.get('description', '')), |
| 384 query_params=[], |
| 385 path_params=[], |
| 386 ordered_params=ordered_params, |
| 387 request_type_name=self.__names.ClassName(request), |
| 388 response_type_name=self.__names.ClassName(response), |
| 389 request_field=request_field, |
| 390 ) |
| 391 if method_description.get('supportsMediaUpload', False): |
| 392 method_info.upload_config = self.__ComputeUploadConfig( |
| 393 method_description.get('mediaUpload'), method_id) |
| 394 method_info.supports_download = method_description.get( |
| 395 'supportsMediaDownload', False) |
| 396 self.__all_scopes.update(method_description.get('scopes', ())) |
| 397 for param, desc in method_description.get('parameters', {}).items(): |
| 398 param = self.__names.CleanName(param) |
| 399 location = desc['location'] |
| 400 if location == 'query': |
| 401 method_info.query_params.append(param) |
| 402 elif location == 'path': |
| 403 method_info.path_params.append(param) |
| 404 else: |
| 405 raise ValueError( |
| 406 'Unknown parameter location %s for parameter %s' % ( |
| 407 location, param)) |
| 408 method_info.path_params.sort() |
| 409 method_info.query_params.sort() |
| 410 return method_info |
| 411 |
| 412 def __BodyFieldName(self, body_type): |
| 413 if body_type is None: |
| 414 return '' |
| 415 return self.__names.FieldName(body_type['$ref']) |
| 416 |
| 417 def __GetRequestType(self, body_type): |
| 418 return self.__names.ClassName(body_type.get('$ref')) |
| 419 |
| 420 def __GetRequestField(self, method_description, body_type): |
| 421 """Determine the request field for this method.""" |
| 422 body_field_name = self.__BodyFieldName(body_type) |
| 423 if body_field_name in method_description.get('parameters', {}): |
| 424 body_field_name = self.__names.FieldName( |
| 425 '%s_resource' % body_field_name) |
| 426 # It's exceedingly unlikely that we'd get two name collisions, which |
| 427 # means it's bound to happen at some point. |
| 428 while body_field_name in method_description.get('parameters', {}): |
| 429 body_field_name = self.__names.FieldName( |
| 430 '%s_body' % body_field_name) |
| 431 return body_field_name |
| 432 |
| 433 def AddServiceFromResource(self, service_name, methods): |
| 434 """Add a new service named service_name with the given methods.""" |
| 435 method_descriptions = methods.get('methods', {}) |
| 436 method_info_map = collections.OrderedDict() |
| 437 items = sorted(method_descriptions.items()) |
| 438 for method_name, method_description in items: |
| 439 method_name = self.__names.MethodName(method_name) |
| 440 |
| 441 # NOTE: According to the discovery document, if the request or |
| 442 # response is present, it will simply contain a `$ref`. |
| 443 body_type = method_description.get('request') |
| 444 if body_type is None: |
| 445 request_type = None |
| 446 else: |
| 447 request_type = self.__GetRequestType(body_type) |
| 448 if self.__NeedRequestType(method_description, request_type): |
| 449 request = self.__CreateRequestType( |
| 450 method_description, body_type=body_type) |
| 451 request_field = self.__GetRequestField( |
| 452 method_description, body_type) |
| 453 else: |
| 454 request = request_type |
| 455 request_field = base_api.REQUEST_IS_BODY |
| 456 |
| 457 if 'response' in method_description: |
| 458 response = method_description['response']['$ref'] |
| 459 else: |
| 460 response = self.__CreateVoidResponseType(method_description) |
| 461 |
| 462 method_info_map[method_name] = self.__ComputeMethodInfo( |
| 463 method_description, request, response, request_field) |
| 464 self.__command_registry.AddCommandForMethod( |
| 465 service_name, method_name, method_info_map[method_name], |
| 466 request, response) |
| 467 |
| 468 nested_services = methods.get('resources', {}) |
| 469 services = sorted(nested_services.items()) |
| 470 for subservice_name, submethods in services: |
| 471 new_service_name = '%s_%s' % (service_name, subservice_name) |
| 472 self.AddServiceFromResource(new_service_name, submethods) |
| 473 |
| 474 self.__RegisterService(service_name, method_info_map) |
OLD | NEW |