OLD | NEW |
(Empty) | |
| 1 #!/usr/bin/env python |
| 2 # |
| 3 # Copyright 2015 Google Inc. |
| 4 # |
| 5 # Licensed under the Apache License, Version 2.0 (the "License"); |
| 6 # you may not use this file except in compliance with the License. |
| 7 # You may obtain a copy of the License at |
| 8 # |
| 9 # http://www.apache.org/licenses/LICENSE-2.0 |
| 10 # |
| 11 # Unless required by applicable law or agreed to in writing, software |
| 12 # distributed under the License is distributed on an "AS IS" BASIS, |
| 13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 14 # See the License for the specific language governing permissions and |
| 15 # limitations under the License. |
| 16 |
| 17 """Message registry for apitools.""" |
| 18 |
| 19 import collections |
| 20 import contextlib |
| 21 import json |
| 22 |
| 23 import six |
| 24 |
| 25 from apitools.base.protorpclite import descriptor |
| 26 from apitools.base.protorpclite import messages |
| 27 from apitools.gen import extended_descriptor |
| 28 from apitools.gen import util |
| 29 |
| 30 TypeInfo = collections.namedtuple('TypeInfo', ('type_name', 'variant')) |
| 31 |
| 32 |
| 33 class MessageRegistry(object): |
| 34 |
| 35 """Registry for message types. |
| 36 |
| 37 This closely mirrors a messages.FileDescriptor, but adds additional |
| 38 attributes (such as message and field descriptions) and some extra |
| 39 code for validation and cycle detection. |
| 40 """ |
| 41 |
| 42 # Type information from these two maps comes from here: |
| 43 # https://developers.google.com/discovery/v1/type-format |
| 44 PRIMITIVE_TYPE_INFO_MAP = { |
| 45 'string': TypeInfo(type_name='string', |
| 46 variant=messages.StringField.DEFAULT_VARIANT), |
| 47 'integer': TypeInfo(type_name='integer', |
| 48 variant=messages.IntegerField.DEFAULT_VARIANT), |
| 49 'boolean': TypeInfo(type_name='boolean', |
| 50 variant=messages.BooleanField.DEFAULT_VARIANT), |
| 51 'number': TypeInfo(type_name='number', |
| 52 variant=messages.FloatField.DEFAULT_VARIANT), |
| 53 'any': TypeInfo(type_name='extra_types.JsonValue', |
| 54 variant=messages.Variant.MESSAGE), |
| 55 } |
| 56 |
| 57 PRIMITIVE_FORMAT_MAP = { |
| 58 'int32': TypeInfo(type_name='integer', |
| 59 variant=messages.Variant.INT32), |
| 60 'uint32': TypeInfo(type_name='integer', |
| 61 variant=messages.Variant.UINT32), |
| 62 'int64': TypeInfo(type_name='string', |
| 63 variant=messages.Variant.INT64), |
| 64 'uint64': TypeInfo(type_name='string', |
| 65 variant=messages.Variant.UINT64), |
| 66 'double': TypeInfo(type_name='number', |
| 67 variant=messages.Variant.DOUBLE), |
| 68 'float': TypeInfo(type_name='number', |
| 69 variant=messages.Variant.FLOAT), |
| 70 'byte': TypeInfo(type_name='byte', |
| 71 variant=messages.BytesField.DEFAULT_VARIANT), |
| 72 'date': TypeInfo(type_name='extra_types.DateField', |
| 73 variant=messages.Variant.STRING), |
| 74 'date-time': TypeInfo( |
| 75 type_name=('apitools.base.protorpclite.message_types.' |
| 76 'DateTimeMessage'), |
| 77 variant=messages.Variant.MESSAGE), |
| 78 } |
| 79 |
| 80 def __init__(self, client_info, names, description, root_package_dir, |
| 81 base_files_package, protorpc_package): |
| 82 self.__names = names |
| 83 self.__client_info = client_info |
| 84 self.__package = client_info.package |
| 85 self.__description = util.CleanDescription(description) |
| 86 self.__root_package_dir = root_package_dir |
| 87 self.__base_files_package = base_files_package |
| 88 self.__protorpc_package = protorpc_package |
| 89 self.__file_descriptor = extended_descriptor.ExtendedFileDescriptor( |
| 90 package=self.__package, description=self.__description) |
| 91 # Add required imports |
| 92 self.__file_descriptor.additional_imports = [ |
| 93 'from %s import messages as _messages' % self.__protorpc_package, |
| 94 ] |
| 95 # Map from scoped names (i.e. Foo.Bar) to MessageDescriptors. |
| 96 self.__message_registry = collections.OrderedDict() |
| 97 # A set of types that we're currently adding (for cycle detection). |
| 98 self.__nascent_types = set() |
| 99 # A set of types for which we've seen a reference but no |
| 100 # definition; if this set is nonempty, validation fails. |
| 101 self.__unknown_types = set() |
| 102 # Used for tracking paths during message creation |
| 103 self.__current_path = [] |
| 104 # Where to register created messages |
| 105 self.__current_env = self.__file_descriptor |
| 106 # TODO(craigcitro): Add a `Finalize` method. |
| 107 |
| 108 @property |
| 109 def file_descriptor(self): |
| 110 self.Validate() |
| 111 return self.__file_descriptor |
| 112 |
| 113 def WriteProtoFile(self, printer): |
| 114 """Write the messages file to out as proto.""" |
| 115 self.Validate() |
| 116 extended_descriptor.WriteMessagesFile( |
| 117 self.__file_descriptor, self.__package, self.__client_info.version, |
| 118 printer) |
| 119 |
| 120 def WriteFile(self, printer): |
| 121 """Write the messages file to out.""" |
| 122 self.Validate() |
| 123 extended_descriptor.WritePythonFile( |
| 124 self.__file_descriptor, self.__package, self.__client_info.version, |
| 125 printer) |
| 126 |
| 127 def Validate(self): |
| 128 mysteries = self.__nascent_types or self.__unknown_types |
| 129 if mysteries: |
| 130 raise ValueError('Malformed MessageRegistry: %s' % mysteries) |
| 131 |
| 132 def __ComputeFullName(self, name): |
| 133 return '.'.join(map(six.text_type, self.__current_path[:] + [name])) |
| 134 |
| 135 def __AddImport(self, new_import): |
| 136 if new_import not in self.__file_descriptor.additional_imports: |
| 137 self.__file_descriptor.additional_imports.append(new_import) |
| 138 |
| 139 def __DeclareDescriptor(self, name): |
| 140 self.__nascent_types.add(self.__ComputeFullName(name)) |
| 141 |
| 142 def __RegisterDescriptor(self, new_descriptor): |
| 143 """Register the given descriptor in this registry.""" |
| 144 if not isinstance(new_descriptor, ( |
| 145 extended_descriptor.ExtendedMessageDescriptor, |
| 146 extended_descriptor.ExtendedEnumDescriptor)): |
| 147 raise ValueError('Cannot add descriptor of type %s' % ( |
| 148 type(new_descriptor),)) |
| 149 full_name = self.__ComputeFullName(new_descriptor.name) |
| 150 if full_name in self.__message_registry: |
| 151 raise ValueError( |
| 152 'Attempt to re-register descriptor %s' % full_name) |
| 153 if full_name not in self.__nascent_types: |
| 154 raise ValueError('Directly adding types is not supported') |
| 155 new_descriptor.full_name = full_name |
| 156 self.__message_registry[full_name] = new_descriptor |
| 157 if isinstance(new_descriptor, |
| 158 extended_descriptor.ExtendedMessageDescriptor): |
| 159 self.__current_env.message_types.append(new_descriptor) |
| 160 elif isinstance(new_descriptor, |
| 161 extended_descriptor.ExtendedEnumDescriptor): |
| 162 self.__current_env.enum_types.append(new_descriptor) |
| 163 self.__unknown_types.discard(full_name) |
| 164 self.__nascent_types.remove(full_name) |
| 165 |
| 166 def LookupDescriptor(self, name): |
| 167 return self.__GetDescriptorByName(name) |
| 168 |
| 169 def LookupDescriptorOrDie(self, name): |
| 170 message_descriptor = self.LookupDescriptor(name) |
| 171 if message_descriptor is None: |
| 172 raise ValueError('No message descriptor named "%s"', name) |
| 173 return message_descriptor |
| 174 |
| 175 def __GetDescriptor(self, name): |
| 176 return self.__GetDescriptorByName(self.__ComputeFullName(name)) |
| 177 |
| 178 def __GetDescriptorByName(self, name): |
| 179 if name in self.__message_registry: |
| 180 return self.__message_registry[name] |
| 181 if name in self.__nascent_types: |
| 182 raise ValueError( |
| 183 'Cannot retrieve type currently being created: %s' % name) |
| 184 return None |
| 185 |
| 186 @contextlib.contextmanager |
| 187 def __DescriptorEnv(self, message_descriptor): |
| 188 # TODO(craigcitro): Typecheck? |
| 189 previous_env = self.__current_env |
| 190 self.__current_path.append(message_descriptor.name) |
| 191 self.__current_env = message_descriptor |
| 192 yield |
| 193 self.__current_path.pop() |
| 194 self.__current_env = previous_env |
| 195 |
| 196 def AddEnumDescriptor(self, name, description, |
| 197 enum_values, enum_descriptions): |
| 198 """Add a new EnumDescriptor named name with the given enum values.""" |
| 199 message = extended_descriptor.ExtendedEnumDescriptor() |
| 200 message.name = self.__names.ClassName(name) |
| 201 message.description = util.CleanDescription(description) |
| 202 self.__DeclareDescriptor(message.name) |
| 203 for index, (enum_name, enum_description) in enumerate( |
| 204 zip(enum_values, enum_descriptions)): |
| 205 enum_value = extended_descriptor.ExtendedEnumValueDescriptor() |
| 206 enum_value.name = self.__names.NormalizeEnumName(enum_name) |
| 207 if enum_value.name != enum_name: |
| 208 message.enum_mappings.append( |
| 209 extended_descriptor.ExtendedEnumDescriptor.JsonEnumMapping( |
| 210 python_name=enum_value.name, json_name=enum_name)) |
| 211 self.__AddImport('from %s import encoding' % |
| 212 self.__base_files_package) |
| 213 enum_value.number = index |
| 214 enum_value.description = util.CleanDescription( |
| 215 enum_description or '<no description>') |
| 216 message.values.append(enum_value) |
| 217 self.__RegisterDescriptor(message) |
| 218 |
| 219 def __DeclareMessageAlias(self, schema, alias_for): |
| 220 """Declare schema as an alias for alias_for.""" |
| 221 # TODO(craigcitro): This is a hack. Remove it. |
| 222 message = extended_descriptor.ExtendedMessageDescriptor() |
| 223 message.name = self.__names.ClassName(schema['id']) |
| 224 message.alias_for = alias_for |
| 225 self.__DeclareDescriptor(message.name) |
| 226 self.__AddImport('from %s import extra_types' % |
| 227 self.__base_files_package) |
| 228 self.__RegisterDescriptor(message) |
| 229 |
| 230 def __AddAdditionalProperties(self, message, schema, properties): |
| 231 """Add an additionalProperties field to message.""" |
| 232 additional_properties_info = schema['additionalProperties'] |
| 233 entries_type_name = self.__AddAdditionalPropertyType( |
| 234 message.name, additional_properties_info) |
| 235 description = util.CleanDescription( |
| 236 additional_properties_info.get('description')) |
| 237 if description is None: |
| 238 description = 'Additional properties of type %s' % message.name |
| 239 attrs = { |
| 240 'items': { |
| 241 '$ref': entries_type_name, |
| 242 }, |
| 243 'description': description, |
| 244 'type': 'array', |
| 245 } |
| 246 field_name = 'additionalProperties' |
| 247 message.fields.append(self.__FieldDescriptorFromProperties( |
| 248 field_name, len(properties) + 1, attrs)) |
| 249 self.__AddImport('from %s import encoding' % self.__base_files_package) |
| 250 message.decorators.append( |
| 251 'encoding.MapUnrecognizedFields(%r)' % field_name) |
| 252 |
| 253 def AddDescriptorFromSchema(self, schema_name, schema): |
| 254 """Add a new MessageDescriptor named schema_name based on schema.""" |
| 255 # TODO(craigcitro): Is schema_name redundant? |
| 256 if self.__GetDescriptor(schema_name): |
| 257 return |
| 258 if schema.get('enum'): |
| 259 self.__DeclareEnum(schema_name, schema) |
| 260 return |
| 261 if schema.get('type') == 'any': |
| 262 self.__DeclareMessageAlias(schema, 'extra_types.JsonValue') |
| 263 return |
| 264 if schema.get('type') != 'object': |
| 265 raise ValueError('Cannot create message descriptors for type %s', |
| 266 schema.get('type')) |
| 267 message = extended_descriptor.ExtendedMessageDescriptor() |
| 268 message.name = self.__names.ClassName(schema['id']) |
| 269 message.description = util.CleanDescription(schema.get( |
| 270 'description', 'A %s object.' % message.name)) |
| 271 self.__DeclareDescriptor(message.name) |
| 272 with self.__DescriptorEnv(message): |
| 273 properties = schema.get('properties', {}) |
| 274 for index, (name, attrs) in enumerate(sorted(properties.items())): |
| 275 field = self.__FieldDescriptorFromProperties( |
| 276 name, index + 1, attrs) |
| 277 message.fields.append(field) |
| 278 if field.name != name: |
| 279 message.field_mappings.append( |
| 280 type(message).JsonFieldMapping( |
| 281 python_name=field.name, json_name=name)) |
| 282 self.__AddImport( |
| 283 'from %s import encoding' % self.__base_files_package) |
| 284 if 'additionalProperties' in schema: |
| 285 self.__AddAdditionalProperties(message, schema, properties) |
| 286 self.__RegisterDescriptor(message) |
| 287 |
| 288 def __AddAdditionalPropertyType(self, name, property_schema): |
| 289 """Add a new nested AdditionalProperty message.""" |
| 290 new_type_name = 'AdditionalProperty' |
| 291 property_schema = dict(property_schema) |
| 292 # We drop the description here on purpose, so the resulting |
| 293 # messages are less repetitive. |
| 294 property_schema.pop('description', None) |
| 295 description = 'An additional property for a %s object.' % name |
| 296 schema = { |
| 297 'id': new_type_name, |
| 298 'type': 'object', |
| 299 'description': description, |
| 300 'properties': { |
| 301 'key': { |
| 302 'type': 'string', |
| 303 'description': 'Name of the additional property.', |
| 304 }, |
| 305 'value': property_schema, |
| 306 }, |
| 307 } |
| 308 self.AddDescriptorFromSchema(new_type_name, schema) |
| 309 return new_type_name |
| 310 |
| 311 def __AddEntryType(self, entry_type_name, entry_schema, parent_name): |
| 312 """Add a type for a list entry.""" |
| 313 entry_schema.pop('description', None) |
| 314 description = 'Single entry in a %s.' % parent_name |
| 315 schema = { |
| 316 'id': entry_type_name, |
| 317 'type': 'object', |
| 318 'description': description, |
| 319 'properties': { |
| 320 'entry': { |
| 321 'type': 'array', |
| 322 'items': entry_schema, |
| 323 }, |
| 324 }, |
| 325 } |
| 326 self.AddDescriptorFromSchema(entry_type_name, schema) |
| 327 return entry_type_name |
| 328 |
| 329 def __FieldDescriptorFromProperties(self, name, index, attrs): |
| 330 """Create a field descriptor for these attrs.""" |
| 331 field = descriptor.FieldDescriptor() |
| 332 field.name = self.__names.CleanName(name) |
| 333 field.number = index |
| 334 field.label = self.__ComputeLabel(attrs) |
| 335 new_type_name_hint = self.__names.ClassName( |
| 336 '%sValue' % self.__names.ClassName(name)) |
| 337 type_info = self.__GetTypeInfo(attrs, new_type_name_hint) |
| 338 field.type_name = type_info.type_name |
| 339 field.variant = type_info.variant |
| 340 if 'default' in attrs: |
| 341 # TODO(craigcitro): Correctly handle non-primitive default values. |
| 342 default = attrs['default'] |
| 343 if not (field.type_name == 'string' or |
| 344 field.variant == messages.Variant.ENUM): |
| 345 default = str(json.loads(default)) |
| 346 if field.variant == messages.Variant.ENUM: |
| 347 default = self.__names.NormalizeEnumName(default) |
| 348 field.default_value = default |
| 349 extended_field = extended_descriptor.ExtendedFieldDescriptor() |
| 350 extended_field.name = field.name |
| 351 extended_field.description = util.CleanDescription( |
| 352 attrs.get('description', 'A %s attribute.' % field.type_name)) |
| 353 extended_field.field_descriptor = field |
| 354 return extended_field |
| 355 |
| 356 @staticmethod |
| 357 def __ComputeLabel(attrs): |
| 358 if attrs.get('required', False): |
| 359 return descriptor.FieldDescriptor.Label.REQUIRED |
| 360 elif attrs.get('type') == 'array': |
| 361 return descriptor.FieldDescriptor.Label.REPEATED |
| 362 elif attrs.get('repeated'): |
| 363 return descriptor.FieldDescriptor.Label.REPEATED |
| 364 return descriptor.FieldDescriptor.Label.OPTIONAL |
| 365 |
| 366 def __DeclareEnum(self, enum_name, attrs): |
| 367 description = util.CleanDescription(attrs.get('description', '')) |
| 368 enum_values = attrs['enum'] |
| 369 enum_descriptions = attrs.get( |
| 370 'enumDescriptions', [''] * len(enum_values)) |
| 371 self.AddEnumDescriptor(enum_name, description, |
| 372 enum_values, enum_descriptions) |
| 373 self.__AddIfUnknown(enum_name) |
| 374 return TypeInfo(type_name=enum_name, variant=messages.Variant.ENUM) |
| 375 |
| 376 def __AddIfUnknown(self, type_name): |
| 377 type_name = self.__names.ClassName(type_name) |
| 378 full_type_name = self.__ComputeFullName(type_name) |
| 379 if (full_type_name not in self.__message_registry.keys() and |
| 380 type_name not in self.__message_registry.keys()): |
| 381 self.__unknown_types.add(type_name) |
| 382 |
| 383 def __GetTypeInfo(self, attrs, name_hint): |
| 384 """Return a TypeInfo object for attrs, creating one if needed.""" |
| 385 |
| 386 type_ref = self.__names.ClassName(attrs.get('$ref')) |
| 387 type_name = attrs.get('type') |
| 388 if not (type_ref or type_name): |
| 389 raise ValueError('No type found for %s' % attrs) |
| 390 |
| 391 if type_ref: |
| 392 self.__AddIfUnknown(type_ref) |
| 393 # We don't actually know this is a message -- it might be an |
| 394 # enum. However, we can't check that until we've created all the |
| 395 # types, so we come back and fix this up later. |
| 396 return TypeInfo( |
| 397 type_name=type_ref, variant=messages.Variant.MESSAGE) |
| 398 |
| 399 if 'enum' in attrs: |
| 400 enum_name = '%sValuesEnum' % name_hint |
| 401 return self.__DeclareEnum(enum_name, attrs) |
| 402 |
| 403 if 'format' in attrs: |
| 404 type_info = self.PRIMITIVE_FORMAT_MAP.get(attrs['format']) |
| 405 if type_info is None: |
| 406 # If we don't recognize the format, the spec says we fall back |
| 407 # to just using the type name. |
| 408 if type_name in self.PRIMITIVE_TYPE_INFO_MAP: |
| 409 return self.PRIMITIVE_TYPE_INFO_MAP[type_name] |
| 410 raise ValueError('Unknown type/format "%s"/"%s"' % ( |
| 411 attrs['format'], type_name)) |
| 412 if type_info.type_name.startswith(( |
| 413 'apitools.base.protorpclite.message_types.', |
| 414 'message_types.')): |
| 415 self.__AddImport( |
| 416 'from %s import message_types as _message_types' % |
| 417 self.__protorpc_package) |
| 418 if type_info.type_name.startswith('extra_types.'): |
| 419 self.__AddImport( |
| 420 'from %s import extra_types' % self.__base_files_package) |
| 421 return type_info |
| 422 |
| 423 if type_name in self.PRIMITIVE_TYPE_INFO_MAP: |
| 424 type_info = self.PRIMITIVE_TYPE_INFO_MAP[type_name] |
| 425 return type_info |
| 426 |
| 427 if type_name == 'array': |
| 428 items = attrs.get('items') |
| 429 if not items: |
| 430 raise ValueError('Array type with no item type: %s' % attrs) |
| 431 entry_name_hint = self.__names.ClassName( |
| 432 items.get('title') or '%sListEntry' % name_hint) |
| 433 entry_label = self.__ComputeLabel(items) |
| 434 if entry_label == descriptor.FieldDescriptor.Label.REPEATED: |
| 435 parent_name = self.__names.ClassName( |
| 436 items.get('title') or name_hint) |
| 437 entry_type_name = self.__AddEntryType( |
| 438 entry_name_hint, items.get('items'), parent_name) |
| 439 return TypeInfo(type_name=entry_type_name, |
| 440 variant=messages.Variant.MESSAGE) |
| 441 else: |
| 442 return self.__GetTypeInfo(items, entry_name_hint) |
| 443 elif type_name == 'any': |
| 444 self.__AddImport('from %s import extra_types' % |
| 445 self.__base_files_package) |
| 446 return self.PRIMITIVE_TYPE_INFO_MAP['any'] |
| 447 elif type_name == 'object': |
| 448 # TODO(craigcitro): Think of a better way to come up with names. |
| 449 if not name_hint: |
| 450 raise ValueError( |
| 451 'Cannot create subtype without some name hint') |
| 452 schema = dict(attrs) |
| 453 schema['id'] = name_hint |
| 454 self.AddDescriptorFromSchema(name_hint, schema) |
| 455 self.__AddIfUnknown(name_hint) |
| 456 return TypeInfo( |
| 457 type_name=name_hint, variant=messages.Variant.MESSAGE) |
| 458 |
| 459 raise ValueError('Unknown type: %s' % type_name) |
| 460 |
| 461 def FixupMessageFields(self): |
| 462 for message_type in self.file_descriptor.message_types: |
| 463 self._FixupMessage(message_type) |
| 464 |
| 465 def _FixupMessage(self, message_type): |
| 466 with self.__DescriptorEnv(message_type): |
| 467 for field in message_type.fields: |
| 468 if field.field_descriptor.variant == messages.Variant.MESSAGE: |
| 469 field_type_name = field.field_descriptor.type_name |
| 470 field_type = self.LookupDescriptor(field_type_name) |
| 471 if isinstance(field_type, |
| 472 extended_descriptor.ExtendedEnumDescriptor): |
| 473 field.field_descriptor.variant = messages.Variant.ENUM |
| 474 for submessage_type in message_type.message_types: |
| 475 self._FixupMessage(submessage_type) |
OLD | NEW |