Index: third_party/protobuf/python/google/protobuf/internal/python_message.py |
diff --git a/third_party/protobuf/python/google/protobuf/internal/python_message.py b/third_party/protobuf/python/google/protobuf/internal/python_message.py |
index c1bd1f9c3bde805d926f70a40e4fc70f6d8cb217..f8f73dd20486fb83c5576aea8e8ab7528b2eb25a 100755 |
--- a/third_party/protobuf/python/google/protobuf/internal/python_message.py |
+++ b/third_party/protobuf/python/google/protobuf/internal/python_message.py |
@@ -51,8 +51,8 @@ this file*. |
__author__ = 'robinson@google.com (Will Robinson)' |
from io import BytesIO |
-import struct |
import sys |
+import struct |
import weakref |
import six |
@@ -63,10 +63,7 @@ except ImportError: |
# nothing like hermetic Python. This means lesser control on the system and |
# the six.moves package may be missing (is missing on 20150321 on gMac). Be |
# extra conservative and try to load the old replacement if it fails. |
- try: |
- import copy_reg as copyreg #PY26 |
- except ImportError: |
- import copyreg |
+ import copy_reg as copyreg |
# We use "as" to avoid name collisions with variables. |
from google.protobuf.internal import containers |
@@ -79,6 +76,7 @@ from google.protobuf.internal import well_known_types |
from google.protobuf.internal import wire_format |
from google.protobuf import descriptor as descriptor_mod |
from google.protobuf import message as message_mod |
+from google.protobuf import symbol_database |
from google.protobuf import text_format |
_FieldDescriptor = descriptor_mod.FieldDescriptor |
@@ -100,12 +98,16 @@ class GeneratedProtocolMessageType(type): |
classes at runtime, as in this example: |
mydescriptor = Descriptor(.....) |
- factory = symbol_database.Default() |
- factory.pool.AddDescriptor(mydescriptor) |
- MyProtoClass = factory.GetPrototype(mydescriptor) |
+ class MyProtoClass(Message): |
+ __metaclass__ = GeneratedProtocolMessageType |
+ DESCRIPTOR = mydescriptor |
myproto_instance = MyProtoClass() |
myproto.foo_field = 23 |
... |
+ |
+ The above example will not work for nested types. If you wish to include them, |
+ use reflection.MakeClass() instead of manually instantiating the class in |
+ order to create the appropriate class structure. |
""" |
# Must be consistent with the protocol-compiler code in |
@@ -162,10 +164,12 @@ class GeneratedProtocolMessageType(type): |
""" |
descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY] |
cls._decoders_by_tag = {} |
+ cls._extensions_by_name = {} |
+ cls._extensions_by_number = {} |
if (descriptor.has_options and |
descriptor.GetOptions().message_set_wire_format): |
cls._decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = ( |
- decoder.MessageSetItemDecoder(descriptor), None) |
+ decoder.MessageSetItemDecoder(cls._extensions_by_number), None) |
# Attach stuff to each FieldDescriptor for quick lookup later on. |
for field in descriptor.fields: |
@@ -381,15 +385,13 @@ def _GetInitializeDefaultForMap(field): |
if _IsMessageMapField(field): |
def MakeMessageMapDefault(message): |
return containers.MessageMap( |
- message._listener_for_children, value_field.message_type, key_checker, |
- field.message_type) |
+ message._listener_for_children, value_field.message_type, key_checker) |
return MakeMessageMapDefault |
else: |
value_checker = type_checkers.GetTypeChecker(value_field) |
def MakePrimitiveMapDefault(message): |
return containers.ScalarMap( |
- message._listener_for_children, key_checker, value_checker, |
- field.message_type) |
+ message._listener_for_children, key_checker, value_checker) |
return MakePrimitiveMapDefault |
def _DefaultValueConstructorForField(field): |
@@ -745,21 +747,32 @@ def _AddPropertiesForExtensions(descriptor, cls): |
constant_name = extension_name.upper() + "_FIELD_NUMBER" |
setattr(cls, constant_name, extension_field.number) |
- # TODO(amauryfa): Migrate all users of these attributes to functions like |
- # pool.FindExtensionByNumber(descriptor). |
- if descriptor.file is not None: |
- # TODO(amauryfa): Use cls.MESSAGE_FACTORY.pool when available. |
- pool = descriptor.file.pool |
- cls._extensions_by_number = pool._extensions_by_number[descriptor] |
- cls._extensions_by_name = pool._extensions_by_name[descriptor] |
def _AddStaticMethods(cls): |
# TODO(robinson): This probably needs to be thread-safe(?) |
def RegisterExtension(extension_handle): |
extension_handle.containing_type = cls.DESCRIPTOR |
- # TODO(amauryfa): Use cls.MESSAGE_FACTORY.pool when available. |
- cls.DESCRIPTOR.file.pool.AddExtensionDescriptor(extension_handle) |
_AttachFieldHelpers(cls, extension_handle) |
+ |
+ # Try to insert our extension, failing if an extension with the same number |
+ # already exists. |
+ actual_handle = cls._extensions_by_number.setdefault( |
+ extension_handle.number, extension_handle) |
+ if actual_handle is not extension_handle: |
+ raise AssertionError( |
+ 'Extensions "%s" and "%s" both try to extend message type "%s" with ' |
+ 'field number %d.' % |
+ (extension_handle.full_name, actual_handle.full_name, |
+ cls.DESCRIPTOR.full_name, extension_handle.number)) |
+ |
+ cls._extensions_by_name[extension_handle.full_name] = extension_handle |
+ |
+ handle = extension_handle # avoid line wrapping |
+ if _IsMessageSetExtension(handle): |
+ # MessageSet extension. Also register under type name. |
+ cls._extensions_by_name[ |
+ extension_handle.message_type.full_name] = extension_handle |
+ |
cls.RegisterExtension = staticmethod(RegisterExtension) |
def FromString(s): |
@@ -913,33 +926,26 @@ def _InternalUnpackAny(msg): |
Returns: |
The unpacked message. |
""" |
- # TODO(amauryfa): Don't use the factory of generated messages. |
- # To make Any work with custom factories, use the message factory of the |
- # parent message. |
- # pylint: disable=g-import-not-at-top |
- from google.protobuf import symbol_database |
- factory = symbol_database.Default() |
- |
type_url = msg.type_url |
+ db = symbol_database.Default() |
if not type_url: |
return None |
# TODO(haberman): For now we just strip the hostname. Better logic will be |
# required. |
- type_name = type_url.split('/')[-1] |
- descriptor = factory.pool.FindMessageTypeByName(type_name) |
+ type_name = type_url.split("/")[-1] |
+ descriptor = db.pool.FindMessageTypeByName(type_name) |
if descriptor is None: |
return None |
- message_class = factory.GetPrototype(descriptor) |
+ message_class = db.GetPrototype(descriptor) |
message = message_class() |
message.ParseFromString(msg.value) |
return message |
- |
def _AddEqualsMethod(message_descriptor, cls): |
"""Helper for _AddMessageMethods().""" |
def __eq__(self, other): |
@@ -1217,7 +1223,7 @@ def _AddMergeFromMethod(cls): |
if not isinstance(msg, cls): |
raise TypeError( |
"Parameter to MergeFrom() must be instance of same class: " |
- 'expected %s got %s.' % (cls.__name__, msg.__class__.__name__)) |
+ "expected %s got %s." % (cls.__name__, type(msg).__name__)) |
assert msg is not self |
self._Modified() |