Index: third_party/protobuf/python/google/protobuf/internal/python_message.py |
diff --git a/third_party/protobuf/python/google/protobuf/internal/python_message.py b/third_party/protobuf/python/google/protobuf/internal/python_message.py |
index f8f73dd20486fb83c5576aea8e8ab7528b2eb25a..c1bd1f9c3bde805d926f70a40e4fc70f6d8cb217 100755 |
--- a/third_party/protobuf/python/google/protobuf/internal/python_message.py |
+++ b/third_party/protobuf/python/google/protobuf/internal/python_message.py |
@@ -51,8 +51,8 @@ this file*. |
__author__ = 'robinson@google.com (Will Robinson)' |
from io import BytesIO |
-import sys |
import struct |
+import sys |
import weakref |
import six |
@@ -63,7 +63,10 @@ except ImportError: |
# nothing like hermetic Python. This means lesser control on the system and |
# the six.moves package may be missing (is missing on 20150321 on gMac). Be |
# extra conservative and try to load the old replacement if it fails. |
- import copy_reg as copyreg |
+ try: |
+ import copy_reg as copyreg #PY26 |
+ except ImportError: |
+ import copyreg |
# We use "as" to avoid name collisions with variables. |
from google.protobuf.internal import containers |
@@ -76,7 +79,6 @@ from google.protobuf.internal import well_known_types |
from google.protobuf.internal import wire_format |
from google.protobuf import descriptor as descriptor_mod |
from google.protobuf import message as message_mod |
-from google.protobuf import symbol_database |
from google.protobuf import text_format |
_FieldDescriptor = descriptor_mod.FieldDescriptor |
@@ -98,16 +100,12 @@ class GeneratedProtocolMessageType(type): |
classes at runtime, as in this example: |
mydescriptor = Descriptor(.....) |
- class MyProtoClass(Message): |
- __metaclass__ = GeneratedProtocolMessageType |
- DESCRIPTOR = mydescriptor |
+ factory = symbol_database.Default() |
+ factory.pool.AddDescriptor(mydescriptor) |
+ MyProtoClass = factory.GetPrototype(mydescriptor) |
myproto_instance = MyProtoClass() |
myproto.foo_field = 23 |
... |
- |
- The above example will not work for nested types. If you wish to include them, |
- use reflection.MakeClass() instead of manually instantiating the class in |
- order to create the appropriate class structure. |
""" |
# Must be consistent with the protocol-compiler code in |
@@ -164,12 +162,10 @@ class GeneratedProtocolMessageType(type): |
""" |
descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY] |
cls._decoders_by_tag = {} |
- cls._extensions_by_name = {} |
- cls._extensions_by_number = {} |
if (descriptor.has_options and |
descriptor.GetOptions().message_set_wire_format): |
cls._decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = ( |
- decoder.MessageSetItemDecoder(cls._extensions_by_number), None) |
+ decoder.MessageSetItemDecoder(descriptor), None) |
# Attach stuff to each FieldDescriptor for quick lookup later on. |
for field in descriptor.fields: |
@@ -385,13 +381,15 @@ def _GetInitializeDefaultForMap(field): |
if _IsMessageMapField(field): |
def MakeMessageMapDefault(message): |
return containers.MessageMap( |
- message._listener_for_children, value_field.message_type, key_checker) |
+ message._listener_for_children, value_field.message_type, key_checker, |
+ field.message_type) |
return MakeMessageMapDefault |
else: |
value_checker = type_checkers.GetTypeChecker(value_field) |
def MakePrimitiveMapDefault(message): |
return containers.ScalarMap( |
- message._listener_for_children, key_checker, value_checker) |
+ message._listener_for_children, key_checker, value_checker, |
+ field.message_type) |
return MakePrimitiveMapDefault |
def _DefaultValueConstructorForField(field): |
@@ -747,32 +745,21 @@ def _AddPropertiesForExtensions(descriptor, cls): |
constant_name = extension_name.upper() + "_FIELD_NUMBER" |
setattr(cls, constant_name, extension_field.number) |
+ # TODO(amauryfa): Migrate all users of these attributes to functions like |
+ # pool.FindExtensionByNumber(descriptor). |
+ if descriptor.file is not None: |
+ # TODO(amauryfa): Use cls.MESSAGE_FACTORY.pool when available. |
+ pool = descriptor.file.pool |
+ cls._extensions_by_number = pool._extensions_by_number[descriptor] |
+ cls._extensions_by_name = pool._extensions_by_name[descriptor] |
def _AddStaticMethods(cls): |
# TODO(robinson): This probably needs to be thread-safe(?) |
def RegisterExtension(extension_handle): |
extension_handle.containing_type = cls.DESCRIPTOR |
+ # TODO(amauryfa): Use cls.MESSAGE_FACTORY.pool when available. |
+ cls.DESCRIPTOR.file.pool.AddExtensionDescriptor(extension_handle) |
_AttachFieldHelpers(cls, extension_handle) |
- |
- # Try to insert our extension, failing if an extension with the same number |
- # already exists. |
- actual_handle = cls._extensions_by_number.setdefault( |
- extension_handle.number, extension_handle) |
- if actual_handle is not extension_handle: |
- raise AssertionError( |
- 'Extensions "%s" and "%s" both try to extend message type "%s" with ' |
- 'field number %d.' % |
- (extension_handle.full_name, actual_handle.full_name, |
- cls.DESCRIPTOR.full_name, extension_handle.number)) |
- |
- cls._extensions_by_name[extension_handle.full_name] = extension_handle |
- |
- handle = extension_handle # avoid line wrapping |
- if _IsMessageSetExtension(handle): |
- # MessageSet extension. Also register under type name. |
- cls._extensions_by_name[ |
- extension_handle.message_type.full_name] = extension_handle |
- |
cls.RegisterExtension = staticmethod(RegisterExtension) |
def FromString(s): |
@@ -926,26 +913,33 @@ def _InternalUnpackAny(msg): |
Returns: |
The unpacked message. |
""" |
+ # TODO(amauryfa): Don't use the factory of generated messages. |
+ # To make Any work with custom factories, use the message factory of the |
+ # parent message. |
+ # pylint: disable=g-import-not-at-top |
+ from google.protobuf import symbol_database |
+ factory = symbol_database.Default() |
+ |
type_url = msg.type_url |
- db = symbol_database.Default() |
if not type_url: |
return None |
# TODO(haberman): For now we just strip the hostname. Better logic will be |
# required. |
- type_name = type_url.split("/")[-1] |
- descriptor = db.pool.FindMessageTypeByName(type_name) |
+ type_name = type_url.split('/')[-1] |
+ descriptor = factory.pool.FindMessageTypeByName(type_name) |
if descriptor is None: |
return None |
- message_class = db.GetPrototype(descriptor) |
+ message_class = factory.GetPrototype(descriptor) |
message = message_class() |
message.ParseFromString(msg.value) |
return message |
+ |
def _AddEqualsMethod(message_descriptor, cls): |
"""Helper for _AddMessageMethods().""" |
def __eq__(self, other): |
@@ -1223,7 +1217,7 @@ def _AddMergeFromMethod(cls): |
if not isinstance(msg, cls): |
raise TypeError( |
"Parameter to MergeFrom() must be instance of same class: " |
- "expected %s got %s." % (cls.__name__, type(msg).__name__)) |
+ 'expected %s got %s.' % (cls.__name__, msg.__class__.__name__)) |
assert msg is not self |
self._Modified() |