Index: third_party/protobuf/python/google/protobuf/internal/cpp_message.py |
=================================================================== |
--- third_party/protobuf/python/google/protobuf/internal/cpp_message.py (revision 216642) |
+++ third_party/protobuf/python/google/protobuf/internal/cpp_message.py (working copy) |
@@ -34,8 +34,10 @@ |
__author__ = 'petar@google.com (Petar Petrov)' |
+import copy_reg |
import operator |
from google.protobuf.internal import _net_proto2___python |
+from google.protobuf.internal import enum_type_wrapper |
from google.protobuf import message |
@@ -156,10 +158,12 @@ |
def __hash__(self): |
raise TypeError('unhashable object') |
- def sort(self, sort_function=cmp): |
- values = self[slice(None, None, None)] |
- values.sort(sort_function) |
- self._cmsg.AssignRepeatedScalar(self._cfield_descriptor, values) |
+ def sort(self, *args, **kwargs): |
+ # Maintain compatibility with the previous interface. |
+ if 'sort_function' in kwargs: |
+ kwargs['cmp'] = kwargs.pop('sort_function') |
+ self._cmsg.AssignRepeatedScalar(self._cfield_descriptor, |
+ sorted(self, *args, **kwargs)) |
def RepeatedScalarProperty(cdescriptor): |
@@ -202,6 +206,12 @@ |
for message in elem_seq: |
self.add().MergeFrom(message) |
+ def remove(self, value): |
+ # TODO(protocol-devel): This is inefficient as it needs to generate a |
+ # message pointer for each message only to do index(). Move this to a C++ |
+ # extension function. |
+ self.__delitem__(self[slice(None, None, None)].index(value)) |
+ |
def MergeFrom(self, other): |
for message in other[:]: |
self.add().MergeFrom(message) |
@@ -236,27 +246,29 @@ |
def __hash__(self): |
raise TypeError('unhashable object') |
- def sort(self, sort_function=cmp): |
- messages = [] |
- for index in range(len(self)): |
- # messages[i][0] is where the i-th element of the new array has to come |
- # from. |
- # messages[i][1] is where the i-th element of the old array has to go. |
- messages.append([index, 0, self[index]]) |
- messages.sort(lambda x,y: sort_function(x[2], y[2])) |
+ def sort(self, cmp=None, key=None, reverse=False, **kwargs): |
+ # Maintain compatibility with the old interface. |
+ if cmp is None and 'sort_function' in kwargs: |
+ cmp = kwargs.pop('sort_function') |
- # Remember which position each elements has to move to. |
- for i in range(len(messages)): |
- messages[messages[i][0]][1] = i |
+ # The cmp function, if provided, is passed the results of the key function, |
+ # so we only need to wrap one of them. |
+ if key is None: |
+ index_key = self.__getitem__ |
+ else: |
+ index_key = lambda i: key(self[i]) |
+ # Sort the list of current indexes by the underlying object. |
+ indexes = range(len(self)) |
+ indexes.sort(cmp=cmp, key=index_key, reverse=reverse) |
+ |
# Apply the transposition. |
- for i in range(len(messages)): |
- from_position = messages[i][0] |
- if i == from_position: |
+ for dest, src in enumerate(indexes): |
+ if dest == src: |
continue |
- self._cmsg.SwapRepeatedFieldElements( |
- self._cfield_descriptor, i, from_position) |
- messages[messages[i][1]][0] = from_position |
+ self._cmsg.SwapRepeatedFieldElements(self._cfield_descriptor, dest, src) |
+ # Don't swap the same value twice. |
+ indexes[src] = src |
def RepeatedCompositeProperty(cdescriptor, message_type): |
@@ -359,11 +371,12 @@ |
return None |
-def NewMessage(message_descriptor, dictionary): |
+def NewMessage(bases, message_descriptor, dictionary): |
"""Creates a new protocol message *class*.""" |
_AddClassAttributesForNestedExtensions(message_descriptor, dictionary) |
_AddEnumValues(message_descriptor, dictionary) |
_AddDescriptors(message_descriptor, dictionary) |
+ return bases |
def InitMessage(message_descriptor, cls): |
@@ -372,6 +385,7 @@ |
_AddInitMethod(message_descriptor, cls) |
_AddMessageMethods(message_descriptor, cls) |
_AddPropertiesForExtensions(message_descriptor, cls) |
+ copy_reg.pickle(cls, lambda obj: (cls, (), obj.__getstate__())) |
def _AddDescriptors(message_descriptor, dictionary): |
@@ -387,7 +401,7 @@ |
field.full_name) |
dictionary['__slots__'] = list(dictionary['__descriptors'].iterkeys()) + [ |
- '_cmsg', '_owner', '_composite_fields', 'Extensions'] |
+ '_cmsg', '_owner', '_composite_fields', 'Extensions', '_HACK_REFCOUNTS'] |
def _AddEnumValues(message_descriptor, dictionary): |
@@ -398,6 +412,7 @@ |
dictionary: Class dictionary that should be populated. |
""" |
for enum_type in message_descriptor.enum_types: |
+ dictionary[enum_type.name] = enum_type_wrapper.EnumTypeWrapper(enum_type) |
for enum_value in enum_type.values: |
dictionary[enum_value.name] = enum_value.number |
@@ -439,28 +454,35 @@ |
def Init(self, **kwargs): |
"""Message constructor.""" |
cmessage = kwargs.pop('__cmessage', None) |
- if cmessage is None: |
+ if cmessage: |
+ self._cmsg = cmessage |
+ else: |
self._cmsg = NewCMessage(message_descriptor.full_name) |
- else: |
- self._cmsg = cmessage |
# Keep a reference to the owner, as the owner keeps a reference to the |
# underlying protocol buffer message. |
owner = kwargs.pop('__owner', None) |
- if owner is not None: |
+ if owner: |
self._owner = owner |
- self.Extensions = ExtensionDict(self) |
+ if message_descriptor.is_extendable: |
+ self.Extensions = ExtensionDict(self) |
+ else: |
+ # Reference counting in the C++ code is broken and depends on |
+ # the Extensions reference to keep this object alive during unit |
+ # tests (see b/4856052). Remove this once b/4945904 is fixed. |
+ self._HACK_REFCOUNTS = self |
self._composite_fields = {} |
for field_name, field_value in kwargs.iteritems(): |
field_cdescriptor = self.__descriptors.get(field_name, None) |
- if field_cdescriptor is None: |
+ if not field_cdescriptor: |
raise ValueError('Protocol message has no "%s" field.' % field_name) |
if field_cdescriptor.label == _LABEL_REPEATED: |
if field_cdescriptor.cpp_type == _CPPTYPE_MESSAGE: |
+ field_name = getattr(self, field_name) |
for val in field_value: |
- getattr(self, field_name).add().MergeFrom(val) |
+ field_name.add().MergeFrom(val) |
else: |
getattr(self, field_name).extend(field_value) |
elif field_cdescriptor.cpp_type == _CPPTYPE_MESSAGE: |
@@ -497,12 +519,34 @@ |
return self._cmsg.HasField(field_name) |
def ClearField(self, field_name): |
+ child_cmessage = None |
if field_name in self._composite_fields: |
+ child_field = self._composite_fields[field_name] |
del self._composite_fields[field_name] |
- self._cmsg.ClearField(field_name) |
+ child_cdescriptor = self.__descriptors[field_name] |
+ # TODO(anuraag): Support clearing repeated message fields as well. |
+ if (child_cdescriptor.label != _LABEL_REPEATED and |
+ child_cdescriptor.cpp_type == _CPPTYPE_MESSAGE): |
+ child_field._owner = None |
+ child_cmessage = child_field._cmsg |
+ |
+ if child_cmessage is not None: |
+ self._cmsg.ClearField(field_name, child_cmessage) |
+ else: |
+ self._cmsg.ClearField(field_name) |
+ |
def Clear(self): |
- return self._cmsg.Clear() |
+ cmessages_to_release = [] |
+ for field_name, child_field in self._composite_fields.iteritems(): |
+ child_cdescriptor = self.__descriptors[field_name] |
+ # TODO(anuraag): Support clearing repeated message fields as well. |
+ if (child_cdescriptor.label != _LABEL_REPEATED and |
+ child_cdescriptor.cpp_type == _CPPTYPE_MESSAGE): |
+ child_field._owner = None |
+ cmessages_to_release.append((child_cdescriptor, child_field._cmsg)) |
+ self._composite_fields.clear() |
+ self._cmsg.Clear(cmessages_to_release) |
def IsInitialized(self, errors=None): |
if self._cmsg.IsInitialized(): |
@@ -514,8 +558,8 @@ |
def SerializeToString(self): |
if not self.IsInitialized(): |
raise message.EncodeError( |
- 'Message is missing required fields: ' + |
- ','.join(self.FindInitializationErrors())) |
+ 'Message %s is missing required fields: %s' % ( |
+ self._cmsg.full_name, ','.join(self.FindInitializationErrors()))) |
return self._cmsg.SerializeToString() |
def SerializePartialToString(self): |
@@ -534,7 +578,8 @@ |
def MergeFrom(self, msg): |
if not isinstance(msg, cls): |
raise TypeError( |
- "Parameter to MergeFrom() must be instance of same class.") |
+ "Parameter to MergeFrom() must be instance of same class: " |
+ "expected %s got %s." % (cls.__name__, type(msg).__name__)) |
self._cmsg.MergeFrom(msg._cmsg) |
def CopyFrom(self, msg): |
@@ -581,6 +626,8 @@ |
raise TypeError('unhashable object') |
def __unicode__(self): |
+ # Lazy import to prevent circular import when text_format imports this file. |
+ from google.protobuf import text_format |
return text_format.MessageToString(self, as_utf8=True).decode('utf-8') |
# Attach the local methods to the message class. |