| Index: third_party/protobuf/python/google/protobuf/descriptor_pool.py
|
| diff --git a/third_party/protobuf/python/google/protobuf/descriptor_pool.py b/third_party/protobuf/python/google/protobuf/descriptor_pool.py
|
| index 20a33701720ee1ed4102a5b8d9906e792f9915cf..fc3a7f4404442e46921ba50bca834237605e77c1 100644
|
| --- a/third_party/protobuf/python/google/protobuf/descriptor_pool.py
|
| +++ b/third_party/protobuf/python/google/protobuf/descriptor_pool.py
|
| @@ -57,12 +57,14 @@ directly instead of this class.
|
|
|
| __author__ = 'matthewtoia@google.com (Matt Toia)'
|
|
|
| +import collections
|
| +
|
| from google.protobuf import descriptor
|
| from google.protobuf import descriptor_database
|
| from google.protobuf import text_encoding
|
|
|
|
|
| -_USE_C_DESCRIPTORS = descriptor._USE_C_DESCRIPTORS
|
| +_USE_C_DESCRIPTORS = descriptor._USE_C_DESCRIPTORS # pylint: disable=protected-access
|
|
|
|
|
| def _NormalizeFullyQualifiedName(name):
|
| @@ -80,6 +82,22 @@ def _NormalizeFullyQualifiedName(name):
|
| return name.lstrip('.')
|
|
|
|
|
| +def _OptionsOrNone(descriptor_proto):
|
| + """Returns the value of the field `options`, or None if it is not set."""
|
| + if descriptor_proto.HasField('options'):
|
| + return descriptor_proto.options
|
| + else:
|
| + return None
|
| +
|
| +
|
| +def _IsMessageSetExtension(field):
|
| + return (field.is_extension and
|
| + field.containing_type.has_options and
|
| + field.containing_type.GetOptions().message_set_wire_format and
|
| + field.type == descriptor.FieldDescriptor.TYPE_MESSAGE and
|
| + field.label == descriptor.FieldDescriptor.LABEL_OPTIONAL)
|
| +
|
| +
|
| class DescriptorPool(object):
|
| """A collection of protobufs dynamically constructed by descriptor protos."""
|
|
|
| @@ -107,6 +125,12 @@ class DescriptorPool(object):
|
| self._descriptors = {}
|
| self._enum_descriptors = {}
|
| self._file_descriptors = {}
|
| + self._toplevel_extensions = {}
|
| + # We store extensions in two two-level mappings: The first key is the
|
| + # descriptor of the message being extended, the second key is the extension
|
| + # full name or its tag number.
|
| + self._extensions_by_name = collections.defaultdict(dict)
|
| + self._extensions_by_number = collections.defaultdict(dict)
|
|
|
| def Add(self, file_desc_proto):
|
| """Adds the FileDescriptorProto and its types to this pool.
|
| @@ -162,6 +186,48 @@ class DescriptorPool(object):
|
| self._enum_descriptors[enum_desc.full_name] = enum_desc
|
| self.AddFileDescriptor(enum_desc.file)
|
|
|
| + def AddExtensionDescriptor(self, extension):
|
| + """Adds a FieldDescriptor describing an extension to the pool.
|
| +
|
| + Args:
|
| + extension: A FieldDescriptor.
|
| +
|
| + Raises:
|
| + AssertionError: when another extension with the same number extends the
|
| + same message.
|
| + TypeError: when the specified extension is not a
|
| + descriptor.FieldDescriptor.
|
| + """
|
| + if not (isinstance(extension, descriptor.FieldDescriptor) and
|
| + extension.is_extension):
|
| + raise TypeError('Expected an extension descriptor.')
|
| +
|
| + if extension.extension_scope is None:
|
| + self._toplevel_extensions[extension.full_name] = extension
|
| +
|
| + try:
|
| + existing_desc = self._extensions_by_number[
|
| + extension.containing_type][extension.number]
|
| + except KeyError:
|
| + pass
|
| + else:
|
| + if extension is not existing_desc:
|
| + raise AssertionError(
|
| + 'Extensions "%s" and "%s" both try to extend message type "%s" '
|
| + 'with field number %d.' %
|
| + (extension.full_name, existing_desc.full_name,
|
| + extension.containing_type.full_name, extension.number))
|
| +
|
| + self._extensions_by_number[extension.containing_type][
|
| + extension.number] = extension
|
| + self._extensions_by_name[extension.containing_type][
|
| + extension.full_name] = extension
|
| +
|
| + # Also register MessageSet extensions with the type name.
|
| + if _IsMessageSetExtension(extension):
|
| + self._extensions_by_name[extension.containing_type][
|
| + extension.message_type.full_name] = extension
|
| +
|
| def AddFileDescriptor(self, file_desc):
|
| """Adds a FileDescriptor to the pool, non-recursively.
|
|
|
| @@ -294,6 +360,14 @@ class DescriptorPool(object):
|
| A FieldDescriptor, describing the named extension.
|
| """
|
| full_name = _NormalizeFullyQualifiedName(full_name)
|
| + try:
|
| + # The proto compiler does not give any link between the FileDescriptor
|
| + # and top-level extensions unless the FileDescriptorProto is added to
|
| + # the DescriptorDatabase, but this can impact memory usage.
|
| + # So we registered these extensions by name explicitly.
|
| + return self._toplevel_extensions[full_name]
|
| + except KeyError:
|
| + pass
|
| message_name, _, extension_name = full_name.rpartition('.')
|
| try:
|
| # Most extensions are nested inside a message.
|
| @@ -303,6 +377,39 @@ class DescriptorPool(object):
|
| scope = self.FindFileContainingSymbol(full_name)
|
| return scope.extensions_by_name[extension_name]
|
|
|
| + def FindExtensionByNumber(self, message_descriptor, number):
|
| + """Gets the extension of the specified message with the specified number.
|
| +
|
| + Extensions have to be registered to this pool by calling
|
| + AddExtensionDescriptor.
|
| +
|
| + Args:
|
| + message_descriptor: descriptor of the extended message.
|
| + number: integer, number of the extension field.
|
| +
|
| + Returns:
|
| + A FieldDescriptor describing the extension.
|
| +
|
| + Raise:
|
| + KeyError: when no extension with the given number is known for the
|
| + specified message.
|
| + """
|
| + return self._extensions_by_number[message_descriptor][number]
|
| +
|
| + def FindAllExtensions(self, message_descriptor):
|
| + """Gets all the known extension of a given message.
|
| +
|
| + Extensions have to be registered to this pool by calling
|
| + AddExtensionDescriptor.
|
| +
|
| + Args:
|
| + message_descriptor: descriptor of the extended message.
|
| +
|
| + Returns:
|
| + A list of FieldDescriptor describing the extensions.
|
| + """
|
| + return list(self._extensions_by_number[message_descriptor].values())
|
| +
|
| def _ConvertFileProtoToFileDescriptor(self, file_proto):
|
| """Creates a FileDescriptor from a proto or returns a cached copy.
|
|
|
| @@ -326,73 +433,61 @@ class DescriptorPool(object):
|
| name=file_proto.name,
|
| package=file_proto.package,
|
| syntax=file_proto.syntax,
|
| - options=file_proto.options,
|
| + options=_OptionsOrNone(file_proto),
|
| serialized_pb=file_proto.SerializeToString(),
|
| dependencies=direct_deps,
|
| public_dependencies=public_deps)
|
| - if _USE_C_DESCRIPTORS:
|
| - # When using C++ descriptors, all objects defined in the file were added
|
| - # to the C++ database when the FileDescriptor was built above.
|
| - # Just add them to this descriptor pool.
|
| - def _AddMessageDescriptor(message_desc):
|
| - self._descriptors[message_desc.full_name] = message_desc
|
| - for nested in message_desc.nested_types:
|
| - _AddMessageDescriptor(nested)
|
| - for enum_type in message_desc.enum_types:
|
| - _AddEnumDescriptor(enum_type)
|
| - def _AddEnumDescriptor(enum_desc):
|
| - self._enum_descriptors[enum_desc.full_name] = enum_desc
|
| - for message_type in file_descriptor.message_types_by_name.values():
|
| - _AddMessageDescriptor(message_type)
|
| - for enum_type in file_descriptor.enum_types_by_name.values():
|
| - _AddEnumDescriptor(enum_type)
|
| + scope = {}
|
| +
|
| + # This loop extracts all the message and enum types from all the
|
| + # dependencies of the file_proto. This is necessary to create the
|
| + # scope of available message types when defining the passed in
|
| + # file proto.
|
| + for dependency in built_deps:
|
| + scope.update(self._ExtractSymbols(
|
| + dependency.message_types_by_name.values()))
|
| + scope.update((_PrefixWithDot(enum.full_name), enum)
|
| + for enum in dependency.enum_types_by_name.values())
|
| +
|
| + for message_type in file_proto.message_type:
|
| + message_desc = self._ConvertMessageDescriptor(
|
| + message_type, file_proto.package, file_descriptor, scope,
|
| + file_proto.syntax)
|
| + file_descriptor.message_types_by_name[message_desc.name] = (
|
| + message_desc)
|
| +
|
| + for enum_type in file_proto.enum_type:
|
| + file_descriptor.enum_types_by_name[enum_type.name] = (
|
| + self._ConvertEnumDescriptor(enum_type, file_proto.package,
|
| + file_descriptor, None, scope))
|
| +
|
| + for index, extension_proto in enumerate(file_proto.extension):
|
| + extension_desc = self._MakeFieldDescriptor(
|
| + extension_proto, file_proto.package, index, is_extension=True)
|
| + extension_desc.containing_type = self._GetTypeFromScope(
|
| + file_descriptor.package, extension_proto.extendee, scope)
|
| + self._SetFieldType(extension_proto, extension_desc,
|
| + file_descriptor.package, scope)
|
| + file_descriptor.extensions_by_name[extension_desc.name] = (
|
| + extension_desc)
|
| +
|
| + for desc_proto in file_proto.message_type:
|
| + self._SetAllFieldTypes(file_proto.package, desc_proto, scope)
|
| +
|
| + if file_proto.package:
|
| + desc_proto_prefix = _PrefixWithDot(file_proto.package)
|
| else:
|
| - scope = {}
|
| -
|
| - # This loop extracts all the message and enum types from all the
|
| - # dependencies of the file_proto. This is necessary to create the
|
| - # scope of available message types when defining the passed in
|
| - # file proto.
|
| - for dependency in built_deps:
|
| - scope.update(self._ExtractSymbols(
|
| - dependency.message_types_by_name.values()))
|
| - scope.update((_PrefixWithDot(enum.full_name), enum)
|
| - for enum in dependency.enum_types_by_name.values())
|
| -
|
| - for message_type in file_proto.message_type:
|
| - message_desc = self._ConvertMessageDescriptor(
|
| - message_type, file_proto.package, file_descriptor, scope,
|
| - file_proto.syntax)
|
| - file_descriptor.message_types_by_name[message_desc.name] = (
|
| - message_desc)
|
| -
|
| - for enum_type in file_proto.enum_type:
|
| - file_descriptor.enum_types_by_name[enum_type.name] = (
|
| - self._ConvertEnumDescriptor(enum_type, file_proto.package,
|
| - file_descriptor, None, scope))
|
| -
|
| - for index, extension_proto in enumerate(file_proto.extension):
|
| - extension_desc = self._MakeFieldDescriptor(
|
| - extension_proto, file_proto.package, index, is_extension=True)
|
| - extension_desc.containing_type = self._GetTypeFromScope(
|
| - file_descriptor.package, extension_proto.extendee, scope)
|
| - self._SetFieldType(extension_proto, extension_desc,
|
| - file_descriptor.package, scope)
|
| - file_descriptor.extensions_by_name[extension_desc.name] = (
|
| - extension_desc)
|
| -
|
| - for desc_proto in file_proto.message_type:
|
| - self._SetAllFieldTypes(file_proto.package, desc_proto, scope)
|
| -
|
| - if file_proto.package:
|
| - desc_proto_prefix = _PrefixWithDot(file_proto.package)
|
| - else:
|
| - desc_proto_prefix = ''
|
| + desc_proto_prefix = ''
|
|
|
| - for desc_proto in file_proto.message_type:
|
| - desc = self._GetTypeFromScope(
|
| - desc_proto_prefix, desc_proto.name, scope)
|
| - file_descriptor.message_types_by_name[desc_proto.name] = desc
|
| + for desc_proto in file_proto.message_type:
|
| + desc = self._GetTypeFromScope(
|
| + desc_proto_prefix, desc_proto.name, scope)
|
| + file_descriptor.message_types_by_name[desc_proto.name] = desc
|
| +
|
| + for index, service_proto in enumerate(file_proto.service):
|
| + file_descriptor.services_by_name[service_proto.name] = (
|
| + self._MakeServiceDescriptor(service_proto, index, scope,
|
| + file_proto.package, file_descriptor))
|
|
|
| self.Add(file_proto)
|
| self._file_descriptors[file_proto.name] = file_descriptor
|
| @@ -408,6 +503,7 @@ class DescriptorPool(object):
|
| package: The package the proto should be located in.
|
| file_desc: The file containing this message.
|
| scope: Dict mapping short and full symbols to message and enum types.
|
| + syntax: string indicating syntax of the file ("proto2" or "proto3")
|
|
|
| Returns:
|
| The added descriptor.
|
| @@ -441,7 +537,7 @@ class DescriptorPool(object):
|
| for index, extension in enumerate(desc_proto.extension)]
|
| oneofs = [
|
| descriptor.OneofDescriptor(desc.name, '.'.join((desc_name, desc.name)),
|
| - index, None, [])
|
| + index, None, [], desc.options)
|
| for index, desc in enumerate(desc_proto.oneof_decl)]
|
| extension_ranges = [(r.start, r.end) for r in desc_proto.extension_range]
|
| if extension_ranges:
|
| @@ -458,7 +554,7 @@ class DescriptorPool(object):
|
| nested_types=nested,
|
| enum_types=enums,
|
| extensions=extensions,
|
| - options=desc_proto.options,
|
| + options=_OptionsOrNone(desc_proto),
|
| is_extendable=is_extendable,
|
| extension_ranges=extension_ranges,
|
| file=file_desc,
|
| @@ -512,7 +608,7 @@ class DescriptorPool(object):
|
| file=file_desc,
|
| values=values,
|
| containing_type=containing_type,
|
| - options=enum_proto.options)
|
| + options=_OptionsOrNone(enum_proto))
|
| scope['.%s' % enum_name] = desc
|
| self._enum_descriptors[enum_name] = desc
|
| return desc
|
| @@ -557,7 +653,7 @@ class DescriptorPool(object):
|
| default_value=None,
|
| is_extension=is_extension,
|
| extension_scope=None,
|
| - options=field_proto.options)
|
| + options=_OptionsOrNone(field_proto))
|
|
|
| def _SetAllFieldTypes(self, package, desc_proto, scope):
|
| """Sets all the descriptor's fields's types.
|
| @@ -676,9 +772,67 @@ class DescriptorPool(object):
|
| name=value_proto.name,
|
| index=index,
|
| number=value_proto.number,
|
| - options=value_proto.options,
|
| + options=_OptionsOrNone(value_proto),
|
| type=None)
|
|
|
| + def _MakeServiceDescriptor(self, service_proto, service_index, scope,
|
| + package, file_desc):
|
| + """Make a protobuf ServiceDescriptor given a ServiceDescriptorProto.
|
| +
|
| + Args:
|
| + service_proto: The descriptor_pb2.ServiceDescriptorProto protobuf message.
|
| + service_index: The index of the service in the File.
|
| + scope: Dict mapping short and full symbols to message and enum types.
|
| + package: Optional package name for the new message EnumDescriptor.
|
| + file_desc: The file containing the service descriptor.
|
| +
|
| + Returns:
|
| + The added descriptor.
|
| + """
|
| +
|
| + if package:
|
| + service_name = '.'.join((package, service_proto.name))
|
| + else:
|
| + service_name = service_proto.name
|
| +
|
| + methods = [self._MakeMethodDescriptor(method_proto, service_name, package,
|
| + scope, index)
|
| + for index, method_proto in enumerate(service_proto.method)]
|
| + desc = descriptor.ServiceDescriptor(name=service_proto.name,
|
| + full_name=service_name,
|
| + index=service_index,
|
| + methods=methods,
|
| + options=_OptionsOrNone(service_proto),
|
| + file=file_desc)
|
| + return desc
|
| +
|
| + def _MakeMethodDescriptor(self, method_proto, service_name, package, scope,
|
| + index):
|
| + """Creates a method descriptor from a MethodDescriptorProto.
|
| +
|
| + Args:
|
| + method_proto: The proto describing the method.
|
| + service_name: The name of the containing service.
|
| + package: Optional package name to look up for types.
|
| + scope: Scope containing available types.
|
| + index: Index of the method in the service.
|
| +
|
| + Returns:
|
| + An initialized MethodDescriptor object.
|
| + """
|
| + full_name = '.'.join((service_name, method_proto.name))
|
| + input_type = self._GetTypeFromScope(
|
| + package, method_proto.input_type, scope)
|
| + output_type = self._GetTypeFromScope(
|
| + package, method_proto.output_type, scope)
|
| + return descriptor.MethodDescriptor(name=method_proto.name,
|
| + full_name=full_name,
|
| + index=index,
|
| + containing_service=None,
|
| + input_type=input_type,
|
| + output_type=output_type,
|
| + options=_OptionsOrNone(method_proto))
|
| +
|
| def _ExtractSymbols(self, descriptors):
|
| """Pulls out all the symbols from descriptor protos.
|
|
|
|
|