Index: third_party/protobuf/python/google/protobuf/descriptor_pool.py |
diff --git a/third_party/protobuf/python/google/protobuf/descriptor_pool.py b/third_party/protobuf/python/google/protobuf/descriptor_pool.py |
index 20a33701720ee1ed4102a5b8d9906e792f9915cf..fc3a7f4404442e46921ba50bca834237605e77c1 100644 |
--- a/third_party/protobuf/python/google/protobuf/descriptor_pool.py |
+++ b/third_party/protobuf/python/google/protobuf/descriptor_pool.py |
@@ -57,12 +57,14 @@ directly instead of this class. |
__author__ = 'matthewtoia@google.com (Matt Toia)' |
+import collections |
+ |
from google.protobuf import descriptor |
from google.protobuf import descriptor_database |
from google.protobuf import text_encoding |
-_USE_C_DESCRIPTORS = descriptor._USE_C_DESCRIPTORS |
+_USE_C_DESCRIPTORS = descriptor._USE_C_DESCRIPTORS # pylint: disable=protected-access |
def _NormalizeFullyQualifiedName(name): |
@@ -80,6 +82,22 @@ def _NormalizeFullyQualifiedName(name): |
return name.lstrip('.') |
+def _OptionsOrNone(descriptor_proto): |
+ """Returns the value of the field `options`, or None if it is not set.""" |
+ if descriptor_proto.HasField('options'): |
+ return descriptor_proto.options |
+ else: |
+ return None |
+ |
+ |
+def _IsMessageSetExtension(field): |
+ return (field.is_extension and |
+ field.containing_type.has_options and |
+ field.containing_type.GetOptions().message_set_wire_format and |
+ field.type == descriptor.FieldDescriptor.TYPE_MESSAGE and |
+ field.label == descriptor.FieldDescriptor.LABEL_OPTIONAL) |
+ |
+ |
class DescriptorPool(object): |
"""A collection of protobufs dynamically constructed by descriptor protos.""" |
@@ -107,6 +125,12 @@ class DescriptorPool(object): |
self._descriptors = {} |
self._enum_descriptors = {} |
self._file_descriptors = {} |
+ self._toplevel_extensions = {} |
+ # We store extensions in two two-level mappings: The first key is the |
+ # descriptor of the message being extended, the second key is the extension |
+ # full name or its tag number. |
+ self._extensions_by_name = collections.defaultdict(dict) |
+ self._extensions_by_number = collections.defaultdict(dict) |
def Add(self, file_desc_proto): |
"""Adds the FileDescriptorProto and its types to this pool. |
@@ -162,6 +186,48 @@ class DescriptorPool(object): |
self._enum_descriptors[enum_desc.full_name] = enum_desc |
self.AddFileDescriptor(enum_desc.file) |
+ def AddExtensionDescriptor(self, extension): |
+ """Adds a FieldDescriptor describing an extension to the pool. |
+ |
+ Args: |
+ extension: A FieldDescriptor. |
+ |
+ Raises: |
+ AssertionError: when another extension with the same number extends the |
+ same message. |
+ TypeError: when the specified extension is not a |
+ descriptor.FieldDescriptor. |
+ """ |
+ if not (isinstance(extension, descriptor.FieldDescriptor) and |
+ extension.is_extension): |
+ raise TypeError('Expected an extension descriptor.') |
+ |
+ if extension.extension_scope is None: |
+ self._toplevel_extensions[extension.full_name] = extension |
+ |
+ try: |
+ existing_desc = self._extensions_by_number[ |
+ extension.containing_type][extension.number] |
+ except KeyError: |
+ pass |
+ else: |
+ if extension is not existing_desc: |
+ raise AssertionError( |
+ 'Extensions "%s" and "%s" both try to extend message type "%s" ' |
+ 'with field number %d.' % |
+ (extension.full_name, existing_desc.full_name, |
+ extension.containing_type.full_name, extension.number)) |
+ |
+ self._extensions_by_number[extension.containing_type][ |
+ extension.number] = extension |
+ self._extensions_by_name[extension.containing_type][ |
+ extension.full_name] = extension |
+ |
+ # Also register MessageSet extensions with the type name. |
+ if _IsMessageSetExtension(extension): |
+ self._extensions_by_name[extension.containing_type][ |
+ extension.message_type.full_name] = extension |
+ |
def AddFileDescriptor(self, file_desc): |
"""Adds a FileDescriptor to the pool, non-recursively. |
@@ -294,6 +360,14 @@ class DescriptorPool(object): |
A FieldDescriptor, describing the named extension. |
""" |
full_name = _NormalizeFullyQualifiedName(full_name) |
+ try: |
+ # The proto compiler does not give any link between the FileDescriptor |
+ # and top-level extensions unless the FileDescriptorProto is added to |
+ # the DescriptorDatabase, but this can impact memory usage. |
+ # So we registered these extensions by name explicitly. |
+ return self._toplevel_extensions[full_name] |
+ except KeyError: |
+ pass |
message_name, _, extension_name = full_name.rpartition('.') |
try: |
# Most extensions are nested inside a message. |
@@ -303,6 +377,39 @@ class DescriptorPool(object): |
scope = self.FindFileContainingSymbol(full_name) |
return scope.extensions_by_name[extension_name] |
+ def FindExtensionByNumber(self, message_descriptor, number): |
+ """Gets the extension of the specified message with the specified number. |
+ |
+ Extensions have to be registered to this pool by calling |
+ AddExtensionDescriptor. |
+ |
+ Args: |
+ message_descriptor: descriptor of the extended message. |
+ number: integer, number of the extension field. |
+ |
+ Returns: |
+ A FieldDescriptor describing the extension. |
+ |
+ Raise: |
+ KeyError: when no extension with the given number is known for the |
+ specified message. |
+ """ |
+ return self._extensions_by_number[message_descriptor][number] |
+ |
+ def FindAllExtensions(self, message_descriptor): |
+ """Gets all the known extension of a given message. |
+ |
+ Extensions have to be registered to this pool by calling |
+ AddExtensionDescriptor. |
+ |
+ Args: |
+ message_descriptor: descriptor of the extended message. |
+ |
+ Returns: |
+ A list of FieldDescriptor describing the extensions. |
+ """ |
+ return list(self._extensions_by_number[message_descriptor].values()) |
+ |
def _ConvertFileProtoToFileDescriptor(self, file_proto): |
"""Creates a FileDescriptor from a proto or returns a cached copy. |
@@ -326,73 +433,61 @@ class DescriptorPool(object): |
name=file_proto.name, |
package=file_proto.package, |
syntax=file_proto.syntax, |
- options=file_proto.options, |
+ options=_OptionsOrNone(file_proto), |
serialized_pb=file_proto.SerializeToString(), |
dependencies=direct_deps, |
public_dependencies=public_deps) |
- if _USE_C_DESCRIPTORS: |
- # When using C++ descriptors, all objects defined in the file were added |
- # to the C++ database when the FileDescriptor was built above. |
- # Just add them to this descriptor pool. |
- def _AddMessageDescriptor(message_desc): |
- self._descriptors[message_desc.full_name] = message_desc |
- for nested in message_desc.nested_types: |
- _AddMessageDescriptor(nested) |
- for enum_type in message_desc.enum_types: |
- _AddEnumDescriptor(enum_type) |
- def _AddEnumDescriptor(enum_desc): |
- self._enum_descriptors[enum_desc.full_name] = enum_desc |
- for message_type in file_descriptor.message_types_by_name.values(): |
- _AddMessageDescriptor(message_type) |
- for enum_type in file_descriptor.enum_types_by_name.values(): |
- _AddEnumDescriptor(enum_type) |
+ scope = {} |
+ |
+ # This loop extracts all the message and enum types from all the |
+ # dependencies of the file_proto. This is necessary to create the |
+ # scope of available message types when defining the passed in |
+ # file proto. |
+ for dependency in built_deps: |
+ scope.update(self._ExtractSymbols( |
+ dependency.message_types_by_name.values())) |
+ scope.update((_PrefixWithDot(enum.full_name), enum) |
+ for enum in dependency.enum_types_by_name.values()) |
+ |
+ for message_type in file_proto.message_type: |
+ message_desc = self._ConvertMessageDescriptor( |
+ message_type, file_proto.package, file_descriptor, scope, |
+ file_proto.syntax) |
+ file_descriptor.message_types_by_name[message_desc.name] = ( |
+ message_desc) |
+ |
+ for enum_type in file_proto.enum_type: |
+ file_descriptor.enum_types_by_name[enum_type.name] = ( |
+ self._ConvertEnumDescriptor(enum_type, file_proto.package, |
+ file_descriptor, None, scope)) |
+ |
+ for index, extension_proto in enumerate(file_proto.extension): |
+ extension_desc = self._MakeFieldDescriptor( |
+ extension_proto, file_proto.package, index, is_extension=True) |
+ extension_desc.containing_type = self._GetTypeFromScope( |
+ file_descriptor.package, extension_proto.extendee, scope) |
+ self._SetFieldType(extension_proto, extension_desc, |
+ file_descriptor.package, scope) |
+ file_descriptor.extensions_by_name[extension_desc.name] = ( |
+ extension_desc) |
+ |
+ for desc_proto in file_proto.message_type: |
+ self._SetAllFieldTypes(file_proto.package, desc_proto, scope) |
+ |
+ if file_proto.package: |
+ desc_proto_prefix = _PrefixWithDot(file_proto.package) |
else: |
- scope = {} |
- |
- # This loop extracts all the message and enum types from all the |
- # dependencies of the file_proto. This is necessary to create the |
- # scope of available message types when defining the passed in |
- # file proto. |
- for dependency in built_deps: |
- scope.update(self._ExtractSymbols( |
- dependency.message_types_by_name.values())) |
- scope.update((_PrefixWithDot(enum.full_name), enum) |
- for enum in dependency.enum_types_by_name.values()) |
- |
- for message_type in file_proto.message_type: |
- message_desc = self._ConvertMessageDescriptor( |
- message_type, file_proto.package, file_descriptor, scope, |
- file_proto.syntax) |
- file_descriptor.message_types_by_name[message_desc.name] = ( |
- message_desc) |
- |
- for enum_type in file_proto.enum_type: |
- file_descriptor.enum_types_by_name[enum_type.name] = ( |
- self._ConvertEnumDescriptor(enum_type, file_proto.package, |
- file_descriptor, None, scope)) |
- |
- for index, extension_proto in enumerate(file_proto.extension): |
- extension_desc = self._MakeFieldDescriptor( |
- extension_proto, file_proto.package, index, is_extension=True) |
- extension_desc.containing_type = self._GetTypeFromScope( |
- file_descriptor.package, extension_proto.extendee, scope) |
- self._SetFieldType(extension_proto, extension_desc, |
- file_descriptor.package, scope) |
- file_descriptor.extensions_by_name[extension_desc.name] = ( |
- extension_desc) |
- |
- for desc_proto in file_proto.message_type: |
- self._SetAllFieldTypes(file_proto.package, desc_proto, scope) |
- |
- if file_proto.package: |
- desc_proto_prefix = _PrefixWithDot(file_proto.package) |
- else: |
- desc_proto_prefix = '' |
+ desc_proto_prefix = '' |
- for desc_proto in file_proto.message_type: |
- desc = self._GetTypeFromScope( |
- desc_proto_prefix, desc_proto.name, scope) |
- file_descriptor.message_types_by_name[desc_proto.name] = desc |
+ for desc_proto in file_proto.message_type: |
+ desc = self._GetTypeFromScope( |
+ desc_proto_prefix, desc_proto.name, scope) |
+ file_descriptor.message_types_by_name[desc_proto.name] = desc |
+ |
+ for index, service_proto in enumerate(file_proto.service): |
+ file_descriptor.services_by_name[service_proto.name] = ( |
+ self._MakeServiceDescriptor(service_proto, index, scope, |
+ file_proto.package, file_descriptor)) |
self.Add(file_proto) |
self._file_descriptors[file_proto.name] = file_descriptor |
@@ -408,6 +503,7 @@ class DescriptorPool(object): |
package: The package the proto should be located in. |
file_desc: The file containing this message. |
scope: Dict mapping short and full symbols to message and enum types. |
+ syntax: string indicating syntax of the file ("proto2" or "proto3") |
Returns: |
The added descriptor. |
@@ -441,7 +537,7 @@ class DescriptorPool(object): |
for index, extension in enumerate(desc_proto.extension)] |
oneofs = [ |
descriptor.OneofDescriptor(desc.name, '.'.join((desc_name, desc.name)), |
- index, None, []) |
+ index, None, [], desc.options) |
for index, desc in enumerate(desc_proto.oneof_decl)] |
extension_ranges = [(r.start, r.end) for r in desc_proto.extension_range] |
if extension_ranges: |
@@ -458,7 +554,7 @@ class DescriptorPool(object): |
nested_types=nested, |
enum_types=enums, |
extensions=extensions, |
- options=desc_proto.options, |
+ options=_OptionsOrNone(desc_proto), |
is_extendable=is_extendable, |
extension_ranges=extension_ranges, |
file=file_desc, |
@@ -512,7 +608,7 @@ class DescriptorPool(object): |
file=file_desc, |
values=values, |
containing_type=containing_type, |
- options=enum_proto.options) |
+ options=_OptionsOrNone(enum_proto)) |
scope['.%s' % enum_name] = desc |
self._enum_descriptors[enum_name] = desc |
return desc |
@@ -557,7 +653,7 @@ class DescriptorPool(object): |
default_value=None, |
is_extension=is_extension, |
extension_scope=None, |
- options=field_proto.options) |
+ options=_OptionsOrNone(field_proto)) |
def _SetAllFieldTypes(self, package, desc_proto, scope): |
"""Sets all the descriptor's fields's types. |
@@ -676,9 +772,67 @@ class DescriptorPool(object): |
name=value_proto.name, |
index=index, |
number=value_proto.number, |
- options=value_proto.options, |
+ options=_OptionsOrNone(value_proto), |
type=None) |
+ def _MakeServiceDescriptor(self, service_proto, service_index, scope, |
+ package, file_desc): |
+ """Make a protobuf ServiceDescriptor given a ServiceDescriptorProto. |
+ |
+ Args: |
+ service_proto: The descriptor_pb2.ServiceDescriptorProto protobuf message. |
+ service_index: The index of the service in the File. |
+ scope: Dict mapping short and full symbols to message and enum types. |
+ package: Optional package name for the new message EnumDescriptor. |
+ file_desc: The file containing the service descriptor. |
+ |
+ Returns: |
+ The added descriptor. |
+ """ |
+ |
+ if package: |
+ service_name = '.'.join((package, service_proto.name)) |
+ else: |
+ service_name = service_proto.name |
+ |
+ methods = [self._MakeMethodDescriptor(method_proto, service_name, package, |
+ scope, index) |
+ for index, method_proto in enumerate(service_proto.method)] |
+ desc = descriptor.ServiceDescriptor(name=service_proto.name, |
+ full_name=service_name, |
+ index=service_index, |
+ methods=methods, |
+ options=_OptionsOrNone(service_proto), |
+ file=file_desc) |
+ return desc |
+ |
+ def _MakeMethodDescriptor(self, method_proto, service_name, package, scope, |
+ index): |
+ """Creates a method descriptor from a MethodDescriptorProto. |
+ |
+ Args: |
+ method_proto: The proto describing the method. |
+ service_name: The name of the containing service. |
+ package: Optional package name to look up for types. |
+ scope: Scope containing available types. |
+ index: Index of the method in the service. |
+ |
+ Returns: |
+ An initialized MethodDescriptor object. |
+ """ |
+ full_name = '.'.join((service_name, method_proto.name)) |
+ input_type = self._GetTypeFromScope( |
+ package, method_proto.input_type, scope) |
+ output_type = self._GetTypeFromScope( |
+ package, method_proto.output_type, scope) |
+ return descriptor.MethodDescriptor(name=method_proto.name, |
+ full_name=full_name, |
+ index=index, |
+ containing_service=None, |
+ input_type=input_type, |
+ output_type=output_type, |
+ options=_OptionsOrNone(method_proto)) |
+ |
def _ExtractSymbols(self, descriptors): |
"""Pulls out all the symbols from descriptor protos. |