Chromium Code Reviews
chromiumcodereview-hr@appspot.gserviceaccount.com (chromiumcodereview-hr) | Please choose your nickname with Settings | Help | Chromium Project | Gerrit Changes | Sign out
(152)

Side by Side Diff: third_party/google-endpoints/apitools/gen/command_registry.py

Issue 2666783008: Add google-endpoints to third_party/. (Closed)
Patch Set: Created 3 years, 10 months ago
Use n/p to move between diff chunks; N/P to move between comments. Draft comments are only viewable by you.
Jump to:
View unified diff | Download patch
OLDNEW
(Empty)
1 #!/usr/bin/env python
2 #
3 # Copyright 2015 Google Inc.
4 #
5 # Licensed under the Apache License, Version 2.0 (the "License");
6 # you may not use this file except in compliance with the License.
7 # You may obtain a copy of the License at
8 #
9 # http://www.apache.org/licenses/LICENSE-2.0
10 #
11 # Unless required by applicable law or agreed to in writing, software
12 # distributed under the License is distributed on an "AS IS" BASIS,
13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 # See the License for the specific language governing permissions and
15 # limitations under the License.
16
17 """Command registry for apitools."""
18
19 import logging
20 import textwrap
21
22 from apitools.base.protorpclite import descriptor
23 from apitools.base.protorpclite import messages
24 from apitools.gen import extended_descriptor
25
26 # This is a code generator; we're purposely verbose.
27 # pylint:disable=too-many-statements
28
29 _VARIANT_TO_FLAG_TYPE_MAP = {
30 messages.Variant.DOUBLE: 'float',
31 messages.Variant.FLOAT: 'float',
32 messages.Variant.INT64: 'string',
33 messages.Variant.UINT64: 'string',
34 messages.Variant.INT32: 'integer',
35 messages.Variant.BOOL: 'boolean',
36 messages.Variant.STRING: 'string',
37 messages.Variant.MESSAGE: 'string',
38 messages.Variant.BYTES: 'string',
39 messages.Variant.UINT32: 'integer',
40 messages.Variant.ENUM: 'enum',
41 messages.Variant.SINT32: 'integer',
42 messages.Variant.SINT64: 'integer',
43 }
44
45
46 class FlagInfo(messages.Message):
47
48 """Information about a flag and conversion to a message.
49
50 Fields:
51 name: name of this flag.
52 type: type of the flag.
53 description: description of the flag.
54 default: default value for this flag.
55 enum_values: if this flag is an enum, the list of possible
56 values.
57 required: whether or not this flag is required.
58 fv: name of the flag_values object where this flag should
59 be registered.
60 conversion: template for type conversion.
61 special: (boolean, default: False) If True, this flag doesn't
62 correspond to an attribute on the request.
63 """
64 name = messages.StringField(1)
65 type = messages.StringField(2)
66 description = messages.StringField(3)
67 default = messages.StringField(4)
68 enum_values = messages.StringField(5, repeated=True)
69 required = messages.BooleanField(6, default=False)
70 fv = messages.StringField(7)
71 conversion = messages.StringField(8)
72 special = messages.BooleanField(9, default=False)
73
74
75 class ArgInfo(messages.Message):
76
77 """Information about a single positional command argument.
78
79 Fields:
80 name: argument name.
81 description: description of this argument.
82 conversion: template for type conversion.
83 """
84 name = messages.StringField(1)
85 description = messages.StringField(2)
86 conversion = messages.StringField(3)
87
88
89 class CommandInfo(messages.Message):
90
91 """Information about a single command.
92
93 Fields:
94 name: name of this command.
95 class_name: name of the apitools_base.NewCmd class for this command.
96 description: description of this command.
97 flags: list of FlagInfo messages for the command-specific flags.
98 args: list of ArgInfo messages for the positional args.
99 request_type: name of the request type for this command.
100 client_method_path: path from the client object to the method
101 this command is wrapping.
102 """
103 name = messages.StringField(1)
104 class_name = messages.StringField(2)
105 description = messages.StringField(3)
106 flags = messages.MessageField(FlagInfo, 4, repeated=True)
107 args = messages.MessageField(ArgInfo, 5, repeated=True)
108 request_type = messages.StringField(6)
109 client_method_path = messages.StringField(7)
110 has_upload = messages.BooleanField(8, default=False)
111 has_download = messages.BooleanField(9, default=False)
112
113
114 class CommandRegistry(object):
115
116 """Registry for CLI commands."""
117
118 def __init__(self, package, version, client_info, message_registry,
119 root_package, base_files_package, protorpc_package,
120 base_url, names):
121 self.__package = package
122 self.__version = version
123 self.__client_info = client_info
124 self.__names = names
125 self.__message_registry = message_registry
126 self.__root_package = root_package
127 self.__base_files_package = base_files_package
128 self.__protorpc_package = protorpc_package
129 self.__base_url = base_url
130 self.__command_list = []
131 self.__global_flags = []
132
133 def Validate(self):
134 self.__message_registry.Validate()
135
136 def AddGlobalParameters(self, schema):
137 for field in schema.fields:
138 self.__global_flags.append(self.__FlagInfoFromField(field, schema))
139
140 def AddCommandForMethod(self, service_name, method_name, method_info,
141 request, _):
142 """Add the given method as a command."""
143 command_name = self.__GetCommandName(method_info.method_id)
144 calling_path = '%s.%s' % (service_name, method_name)
145 request_type = self.__message_registry.LookupDescriptor(request)
146 description = method_info.description
147 if not description:
148 description = 'Call the %s method.' % method_info.method_id
149 field_map = dict((f.name, f) for f in request_type.fields)
150 args = []
151 arg_names = []
152 for field_name in method_info.ordered_params:
153 extended_field = field_map[field_name]
154 name = extended_field.name
155 args.append(ArgInfo(
156 name=name,
157 description=extended_field.description,
158 conversion=self.__GetConversion(extended_field, request_type),
159 ))
160 arg_names.append(name)
161 flags = []
162 for extended_field in sorted(request_type.fields,
163 key=lambda x: x.name):
164 field = extended_field.field_descriptor
165 if extended_field.name in arg_names:
166 continue
167 if self.__FieldIsRequired(field):
168 logging.warning(
169 'Required field %s not in ordered_params for command %s',
170 extended_field.name, command_name)
171 flags.append(self.__FlagInfoFromField(
172 extended_field, request_type, fv='fv'))
173 if method_info.upload_config:
174 # TODO(craigcitro): Consider adding additional flags to allow
175 # determining the filename from the object metadata.
176 upload_flag_info = FlagInfo(
177 name='upload_filename', type='string', default='',
178 description='Filename to use for upload.', fv='fv',
179 special=True)
180 flags.append(upload_flag_info)
181 mime_description = (
182 'MIME type to use for the upload. Only needed if '
183 'the extension on --upload_filename does not determine '
184 'the correct (or any) MIME type.')
185 mime_type_flag_info = FlagInfo(
186 name='upload_mime_type', type='string', default='',
187 description=mime_description, fv='fv', special=True)
188 flags.append(mime_type_flag_info)
189 if method_info.supports_download:
190 download_flag_info = FlagInfo(
191 name='download_filename', type='string', default='',
192 description='Filename to use for download.', fv='fv',
193 special=True)
194 flags.append(download_flag_info)
195 overwrite_description = (
196 'If True, overwrite the existing file when downloading.')
197 overwrite_flag_info = FlagInfo(
198 name='overwrite', type='boolean', default='False',
199 description=overwrite_description, fv='fv', special=True)
200 flags.append(overwrite_flag_info)
201 command_info = CommandInfo(
202 name=command_name,
203 class_name=self.__names.ClassName(command_name),
204 description=description,
205 flags=flags,
206 args=args,
207 request_type=request_type.full_name,
208 client_method_path=calling_path,
209 has_upload=bool(method_info.upload_config),
210 has_download=bool(method_info.supports_download)
211 )
212 self.__command_list.append(command_info)
213
214 def __LookupMessage(self, message, field):
215 message_type = self.__message_registry.LookupDescriptor(
216 '%s.%s' % (message.name, field.type_name))
217 if message_type is None:
218 message_type = self.__message_registry.LookupDescriptor(
219 field.type_name)
220 return message_type
221
222 def __GetCommandName(self, method_id):
223 command_name = method_id
224 prefix = '%s.' % self.__package
225 if command_name.startswith(prefix):
226 command_name = command_name[len(prefix):]
227 command_name = command_name.replace('.', '_')
228 return command_name
229
230 def __GetConversion(self, extended_field, extended_message):
231 field = extended_field.field_descriptor
232
233 type_name = ''
234 if field.variant in (messages.Variant.MESSAGE, messages.Variant.ENUM):
235 if field.type_name.startswith('apitools.base.protorpclite.'):
236 type_name = field.type_name
237 else:
238 field_message = self.__LookupMessage(extended_message, field)
239 if field_message is None:
240 raise ValueError(
241 'Could not find type for field %s' % field.name)
242 type_name = 'messages.%s' % field_message.full_name
243
244 template = ''
245 if field.variant in (messages.Variant.INT64, messages.Variant.UINT64):
246 template = 'int(%s)'
247 elif field.variant == messages.Variant.MESSAGE:
248 template = 'apitools_base.JsonToMessage(%s, %%s)' % type_name
249 elif field.variant == messages.Variant.ENUM:
250 template = '%s(%%s)' % type_name
251 elif field.variant == messages.Variant.STRING:
252 template = "%s.decode('utf8')"
253
254 if self.__FieldIsRepeated(extended_field.field_descriptor):
255 if template:
256 template = '[%s for x in %%s]' % (template % 'x')
257
258 return template
259
260 def __FieldIsRequired(self, field):
261 return field.label == descriptor.FieldDescriptor.Label.REQUIRED
262
263 def __FieldIsRepeated(self, field):
264 return field.label == descriptor.FieldDescriptor.Label.REPEATED
265
266 def __FlagInfoFromField(self, extended_field, extended_message, fv=''):
267 field = extended_field.field_descriptor
268 flag_info = FlagInfo()
269 flag_info.name = str(field.name)
270 # TODO(craigcitro): We should key by variant.
271 flag_info.type = _VARIANT_TO_FLAG_TYPE_MAP[field.variant]
272 flag_info.description = extended_field.description
273 if field.default_value:
274 # TODO(craigcitro): Formatting?
275 flag_info.default = field.default_value
276 if flag_info.type == 'enum':
277 # TODO(craigcitro): Does protorpc do this for us?
278 enum_type = self.__LookupMessage(extended_message, field)
279 if enum_type is None:
280 raise ValueError('Cannot find enum type %s', field.type_name)
281 flag_info.enum_values = [x.name for x in enum_type.values]
282 # Note that this choice is completely arbitrary -- but we only
283 # push the value through if the user specifies it, so this
284 # doesn't hurt anything.
285 if flag_info.default is None:
286 flag_info.default = flag_info.enum_values[0]
287 if self.__FieldIsRequired(field):
288 flag_info.required = True
289 flag_info.fv = fv
290 flag_info.conversion = self.__GetConversion(
291 extended_field, extended_message)
292 return flag_info
293
294 def __PrintFlagDeclarations(self, printer):
295 package = self.__client_info.package
296 function_name = '_Declare%sFlags' % (package[0].upper() + package[1:])
297 printer()
298 printer()
299 printer('def %s():', function_name)
300 with printer.Indent():
301 printer('"""Declare global flags in an idempotent way."""')
302 printer("if 'api_endpoint' in flags.FLAGS:")
303 with printer.Indent():
304 printer('return')
305 printer('flags.DEFINE_string(')
306 with printer.Indent(' '):
307 printer("'api_endpoint',")
308 printer('%r,', self.__base_url)
309 printer("'URL of the API endpoint to use.',")
310 printer("short_name='%s_url')", self.__package)
311 printer('flags.DEFINE_string(')
312 with printer.Indent(' '):
313 printer("'history_file',")
314 printer('%r,', '~/.%s.%s.history' %
315 (self.__package, self.__version))
316 printer("'File with interactive shell history.')")
317 printer('flags.DEFINE_multistring(')
318 with printer.Indent(' '):
319 printer("'add_header', [],")
320 printer("'Additional http headers (as key=value strings). '")
321 printer("'Can be specified multiple times.')")
322 printer('flags.DEFINE_string(')
323 with printer.Indent(' '):
324 printer("'service_account_json_keyfile', '',")
325 printer("'Filename for a JSON service account key downloaded'")
326 printer("' from the Developer Console.')")
327 for flag_info in self.__global_flags:
328 self.__PrintFlag(printer, flag_info)
329 printer()
330 printer()
331 printer('FLAGS = flags.FLAGS')
332 printer('apitools_base_cli.DeclareBaseFlags()')
333 printer('%s()', function_name)
334
335 def __PrintGetGlobalParams(self, printer):
336 printer('def GetGlobalParamsFromFlags():')
337 with printer.Indent():
338 printer('"""Return a StandardQueryParameters based on flags."""')
339 printer('result = messages.StandardQueryParameters()')
340
341 for flag_info in self.__global_flags:
342 rhs = 'FLAGS.%s' % flag_info.name
343 if flag_info.conversion:
344 rhs = flag_info.conversion % rhs
345 printer('if FLAGS[%r].present:', flag_info.name)
346 with printer.Indent():
347 printer('result.%s = %s', flag_info.name, rhs)
348 printer('return result')
349 printer()
350 printer()
351
352 def __PrintGetClient(self, printer):
353 printer('def GetClientFromFlags():')
354 with printer.Indent():
355 printer('"""Return a client object, configured from flags."""')
356 printer('log_request = FLAGS.log_request or '
357 'FLAGS.log_request_response')
358 printer('log_response = FLAGS.log_response or '
359 'FLAGS.log_request_response')
360 printer('api_endpoint = apitools_base.NormalizeApiEndpoint('
361 'FLAGS.api_endpoint)')
362 printer("additional_http_headers = dict(x.split('=', 1) for x in "
363 "FLAGS.add_header)")
364 printer('credentials_args = {')
365 with printer.Indent(' '):
366 printer("'service_account_json_keyfile': os.path.expanduser("
367 'FLAGS.service_account_json_keyfile)')
368 printer('}')
369 printer('try:')
370 with printer.Indent():
371 printer('client = client_lib.%s(',
372 self.__client_info.client_class_name)
373 with printer.Indent(indent=' '):
374 printer('api_endpoint, log_request=log_request,')
375 printer('log_response=log_response,')
376 printer('credentials_args=credentials_args,')
377 printer('additional_http_headers=additional_http_headers)')
378 printer('except apitools_base.CredentialsError as e:')
379 with printer.Indent():
380 printer("print 'Error creating credentials: %%s' %% e")
381 printer('sys.exit(1)')
382 printer('return client')
383 printer()
384 printer()
385
386 def __PrintCommandDocstring(self, printer, command_info):
387 with printer.CommentContext():
388 for line in textwrap.wrap('"""%s' % command_info.description,
389 printer.CalculateWidth()):
390 printer(line)
391 extended_descriptor.PrintIndentedDescriptions(
392 printer, command_info.args, 'Args')
393 extended_descriptor.PrintIndentedDescriptions(
394 printer, command_info.flags, 'Flags')
395 printer('"""')
396
397 def __PrintFlag(self, printer, flag_info):
398 printer('flags.DEFINE_%s(', flag_info.type)
399 with printer.Indent(indent=' '):
400 printer('%r,', flag_info.name)
401 printer('%r,', flag_info.default)
402 if flag_info.type == 'enum':
403 printer('%r,', flag_info.enum_values)
404
405 # TODO(craigcitro): Consider using 'drop_whitespace' elsewhere.
406 description_lines = textwrap.wrap(
407 flag_info.description, 75 - len(printer.indent),
408 drop_whitespace=False)
409 for line in description_lines[:-1]:
410 printer('%r', line)
411 last_line = description_lines[-1] if description_lines else ''
412 printer('%r%s', last_line, ',' if flag_info.fv else ')')
413 if flag_info.fv:
414 printer('flag_values=%s)', flag_info.fv)
415 if flag_info.required:
416 printer('flags.MarkFlagAsRequired(%r)', flag_info.name)
417
418 def __PrintPyShell(self, printer):
419 printer('class PyShell(appcommands.Cmd):')
420 printer()
421 with printer.Indent():
422 printer('def Run(self, _):')
423 with printer.Indent():
424 printer(
425 '"""Run an interactive python shell with the client."""')
426 printer('client = GetClientFromFlags()')
427 printer('params = GetGlobalParamsFromFlags()')
428 printer('for field in params.all_fields():')
429 with printer.Indent():
430 printer('value = params.get_assigned_value(field.name)')
431 printer('if value != field.default:')
432 with printer.Indent():
433 printer('client.AddGlobalParam(field.name, value)')
434 printer('banner = """')
435 printer(' == %s interactive console ==' % (
436 self.__client_info.package))
437 printer(' client: a %s client' %
438 self.__client_info.package)
439 printer(' apitools_base: base apitools module')
440 printer(' messages: the generated messages module')
441 printer('"""')
442 printer('local_vars = {')
443 with printer.Indent(indent=' '):
444 printer("'apitools_base': apitools_base,")
445 printer("'client': client,")
446 printer("'client_lib': client_lib,")
447 printer("'messages': messages,")
448 printer('}')
449 printer("if platform.system() == 'Linux':")
450 with printer.Indent():
451 printer('console = apitools_base_cli.ConsoleWithReadline(')
452 with printer.Indent(indent=' '):
453 printer('local_vars, histfile=FLAGS.history_file)')
454 printer('else:')
455 with printer.Indent():
456 printer('console = code.InteractiveConsole(local_vars)')
457 printer('try:')
458 with printer.Indent():
459 printer('console.interact(banner)')
460 printer('except SystemExit as e:')
461 with printer.Indent():
462 printer('return e.code')
463 printer()
464 printer()
465
466 def WriteFile(self, printer):
467 """Write a simple CLI (currently just a stub)."""
468 printer('#!/usr/bin/env python')
469 printer('"""CLI for %s, version %s."""',
470 self.__package, self.__version)
471 printer('# NOTE: This file is autogenerated and should not be edited '
472 'by hand.')
473 # TODO(craigcitro): Add a build stamp, along with some other
474 # information.
475 printer()
476 printer('import code')
477 printer('import os')
478 printer('import platform')
479 printer('import sys')
480 printer()
481 printer('from %s import message_types', self.__protorpc_package)
482 printer('from %s import messages', self.__protorpc_package)
483 printer()
484 appcommands_import = 'from google.apputils import appcommands'
485 printer(appcommands_import)
486
487 flags_import = 'import gflags as flags'
488 printer(flags_import)
489 printer()
490 printer('import %s as apitools_base', self.__base_files_package)
491 printer('from %s import cli as apitools_base_cli',
492 self.__base_files_package)
493 import_prefix = ''
494 printer('%simport %s as client_lib',
495 import_prefix, self.__client_info.client_rule_name)
496 printer('%simport %s as messages',
497 import_prefix, self.__client_info.messages_rule_name)
498 self.__PrintFlagDeclarations(printer)
499 printer()
500 printer()
501 self.__PrintGetGlobalParams(printer)
502 self.__PrintGetClient(printer)
503 self.__PrintPyShell(printer)
504 self.__PrintCommands(printer)
505 printer('def main(_):')
506 with printer.Indent():
507 printer("appcommands.AddCmd('pyshell', PyShell)")
508 for command_info in self.__command_list:
509 printer("appcommands.AddCmd('%s', %s)",
510 command_info.name, command_info.class_name)
511 printer()
512 printer('apitools_base_cli.SetupLogger()')
513 # TODO(craigcitro): Just call SetDefaultCommand as soon as
514 # another appcommands release happens and this exists
515 # externally.
516 printer("if hasattr(appcommands, 'SetDefaultCommand'):")
517 with printer.Indent():
518 printer("appcommands.SetDefaultCommand('pyshell')")
519 printer()
520 printer()
521 printer('run_main = apitools_base_cli.run_main')
522 printer()
523 printer("if __name__ == '__main__':")
524 with printer.Indent():
525 printer('appcommands.Run()')
526
527 def __PrintCommands(self, printer):
528 """Print all commands in this registry using printer."""
529 for command_info in self.__command_list:
530 arg_list = [arg_info.name for arg_info in command_info.args]
531 printer(
532 'class %s(apitools_base_cli.NewCmd):', command_info.class_name)
533 with printer.Indent():
534 printer('"""Command wrapping %s."""',
535 command_info.client_method_path)
536 printer()
537 printer('usage = """%s%s%s"""',
538 command_info.name,
539 ' ' if arg_list else '',
540 ' '.join('<%s>' % argname for argname in arg_list))
541 printer()
542 printer('def __init__(self, name, fv):')
543 with printer.Indent():
544 printer('super(%s, self).__init__(name, fv)',
545 command_info.class_name)
546 for flag in command_info.flags:
547 self.__PrintFlag(printer, flag)
548 printer()
549 printer('def RunWithArgs(%s):', ', '.join(['self'] + arg_list))
550 with printer.Indent():
551 self.__PrintCommandDocstring(printer, command_info)
552 printer('client = GetClientFromFlags()')
553 printer('global_params = GetGlobalParamsFromFlags()')
554 printer(
555 'request = messages.%s(', command_info.request_type)
556 with printer.Indent(indent=' '):
557 for arg in command_info.args:
558 rhs = arg.name
559 if arg.conversion:
560 rhs = arg.conversion % arg.name
561 printer('%s=%s,', arg.name, rhs)
562 printer(')')
563 for flag_info in command_info.flags:
564 if flag_info.special:
565 continue
566 rhs = 'FLAGS.%s' % flag_info.name
567 if flag_info.conversion:
568 rhs = flag_info.conversion % rhs
569 printer('if FLAGS[%r].present:', flag_info.name)
570 with printer.Indent():
571 printer('request.%s = %s', flag_info.name, rhs)
572 call_args = ['request', 'global_params=global_params']
573 if command_info.has_upload:
574 call_args.append('upload=upload')
575 printer('upload = None')
576 printer('if FLAGS.upload_filename:')
577 with printer.Indent():
578 printer('upload = apitools_base.Upload.FromFile(')
579 printer(' FLAGS.upload_filename, '
580 'FLAGS.upload_mime_type,')
581 printer(' progress_callback='
582 'apitools_base.UploadProgressPrinter,')
583 printer(' finish_callback='
584 'apitools_base.UploadCompletePrinter)')
585 if command_info.has_download:
586 call_args.append('download=download')
587 printer('download = None')
588 printer('if FLAGS.download_filename:')
589 with printer.Indent():
590 printer('download = apitools_base.Download.'
591 'FromFile(FLAGS.download_filename, '
592 'overwrite=FLAGS.overwrite,')
593 printer(' progress_callback='
594 'apitools_base.DownloadProgressPrinter,')
595 printer(' finish_callback='
596 'apitools_base.DownloadCompletePrinter)')
597 printer(
598 'result = client.%s(', command_info.client_method_path)
599 with printer.Indent(indent=' '):
600 printer('%s)', ', '.join(call_args))
601 printer('print apitools_base_cli.FormatOutput(result)')
602 printer()
603 printer()
OLDNEW

Powered by Google App Engine
This is Rietveld 408576698