| Index: third_party/protobuf/python/google/protobuf/internal/cpp_message.py
|
| ===================================================================
|
| --- third_party/protobuf/python/google/protobuf/internal/cpp_message.py (revision 216642)
|
| +++ third_party/protobuf/python/google/protobuf/internal/cpp_message.py (working copy)
|
| @@ -34,8 +34,10 @@
|
|
|
| __author__ = 'petar@google.com (Petar Petrov)'
|
|
|
| +import copy_reg
|
| import operator
|
| from google.protobuf.internal import _net_proto2___python
|
| +from google.protobuf.internal import enum_type_wrapper
|
| from google.protobuf import message
|
|
|
|
|
| @@ -156,10 +158,12 @@
|
| def __hash__(self):
|
| raise TypeError('unhashable object')
|
|
|
| - def sort(self, sort_function=cmp):
|
| - values = self[slice(None, None, None)]
|
| - values.sort(sort_function)
|
| - self._cmsg.AssignRepeatedScalar(self._cfield_descriptor, values)
|
| + def sort(self, *args, **kwargs):
|
| + # Maintain compatibility with the previous interface.
|
| + if 'sort_function' in kwargs:
|
| + kwargs['cmp'] = kwargs.pop('sort_function')
|
| + self._cmsg.AssignRepeatedScalar(self._cfield_descriptor,
|
| + sorted(self, *args, **kwargs))
|
|
|
|
|
| def RepeatedScalarProperty(cdescriptor):
|
| @@ -202,6 +206,12 @@
|
| for message in elem_seq:
|
| self.add().MergeFrom(message)
|
|
|
| + def remove(self, value):
|
| + # TODO(protocol-devel): This is inefficient as it needs to generate a
|
| + # message pointer for each message only to do index(). Move this to a C++
|
| + # extension function.
|
| + self.__delitem__(self[slice(None, None, None)].index(value))
|
| +
|
| def MergeFrom(self, other):
|
| for message in other[:]:
|
| self.add().MergeFrom(message)
|
| @@ -236,27 +246,29 @@
|
| def __hash__(self):
|
| raise TypeError('unhashable object')
|
|
|
| - def sort(self, sort_function=cmp):
|
| - messages = []
|
| - for index in range(len(self)):
|
| - # messages[i][0] is where the i-th element of the new array has to come
|
| - # from.
|
| - # messages[i][1] is where the i-th element of the old array has to go.
|
| - messages.append([index, 0, self[index]])
|
| - messages.sort(lambda x,y: sort_function(x[2], y[2]))
|
| + def sort(self, cmp=None, key=None, reverse=False, **kwargs):
|
| + # Maintain compatibility with the old interface.
|
| + if cmp is None and 'sort_function' in kwargs:
|
| + cmp = kwargs.pop('sort_function')
|
|
|
| - # Remember which position each elements has to move to.
|
| - for i in range(len(messages)):
|
| - messages[messages[i][0]][1] = i
|
| + # The cmp function, if provided, is passed the results of the key function,
|
| + # so we only need to wrap one of them.
|
| + if key is None:
|
| + index_key = self.__getitem__
|
| + else:
|
| + index_key = lambda i: key(self[i])
|
|
|
| + # Sort the list of current indexes by the underlying object.
|
| + indexes = range(len(self))
|
| + indexes.sort(cmp=cmp, key=index_key, reverse=reverse)
|
| +
|
| # Apply the transposition.
|
| - for i in range(len(messages)):
|
| - from_position = messages[i][0]
|
| - if i == from_position:
|
| + for dest, src in enumerate(indexes):
|
| + if dest == src:
|
| continue
|
| - self._cmsg.SwapRepeatedFieldElements(
|
| - self._cfield_descriptor, i, from_position)
|
| - messages[messages[i][1]][0] = from_position
|
| + self._cmsg.SwapRepeatedFieldElements(self._cfield_descriptor, dest, src)
|
| + # Don't swap the same value twice.
|
| + indexes[src] = src
|
|
|
|
|
| def RepeatedCompositeProperty(cdescriptor, message_type):
|
| @@ -359,11 +371,12 @@
|
| return None
|
|
|
|
|
| -def NewMessage(message_descriptor, dictionary):
|
| +def NewMessage(bases, message_descriptor, dictionary):
|
| """Creates a new protocol message *class*."""
|
| _AddClassAttributesForNestedExtensions(message_descriptor, dictionary)
|
| _AddEnumValues(message_descriptor, dictionary)
|
| _AddDescriptors(message_descriptor, dictionary)
|
| + return bases
|
|
|
|
|
| def InitMessage(message_descriptor, cls):
|
| @@ -372,6 +385,7 @@
|
| _AddInitMethod(message_descriptor, cls)
|
| _AddMessageMethods(message_descriptor, cls)
|
| _AddPropertiesForExtensions(message_descriptor, cls)
|
| + copy_reg.pickle(cls, lambda obj: (cls, (), obj.__getstate__()))
|
|
|
|
|
| def _AddDescriptors(message_descriptor, dictionary):
|
| @@ -387,7 +401,7 @@
|
| field.full_name)
|
|
|
| dictionary['__slots__'] = list(dictionary['__descriptors'].iterkeys()) + [
|
| - '_cmsg', '_owner', '_composite_fields', 'Extensions']
|
| + '_cmsg', '_owner', '_composite_fields', 'Extensions', '_HACK_REFCOUNTS']
|
|
|
|
|
| def _AddEnumValues(message_descriptor, dictionary):
|
| @@ -398,6 +412,7 @@
|
| dictionary: Class dictionary that should be populated.
|
| """
|
| for enum_type in message_descriptor.enum_types:
|
| + dictionary[enum_type.name] = enum_type_wrapper.EnumTypeWrapper(enum_type)
|
| for enum_value in enum_type.values:
|
| dictionary[enum_value.name] = enum_value.number
|
|
|
| @@ -439,28 +454,35 @@
|
| def Init(self, **kwargs):
|
| """Message constructor."""
|
| cmessage = kwargs.pop('__cmessage', None)
|
| - if cmessage is None:
|
| + if cmessage:
|
| + self._cmsg = cmessage
|
| + else:
|
| self._cmsg = NewCMessage(message_descriptor.full_name)
|
| - else:
|
| - self._cmsg = cmessage
|
|
|
| # Keep a reference to the owner, as the owner keeps a reference to the
|
| # underlying protocol buffer message.
|
| owner = kwargs.pop('__owner', None)
|
| - if owner is not None:
|
| + if owner:
|
| self._owner = owner
|
|
|
| - self.Extensions = ExtensionDict(self)
|
| + if message_descriptor.is_extendable:
|
| + self.Extensions = ExtensionDict(self)
|
| + else:
|
| + # Reference counting in the C++ code is broken and depends on
|
| + # the Extensions reference to keep this object alive during unit
|
| + # tests (see b/4856052). Remove this once b/4945904 is fixed.
|
| + self._HACK_REFCOUNTS = self
|
| self._composite_fields = {}
|
|
|
| for field_name, field_value in kwargs.iteritems():
|
| field_cdescriptor = self.__descriptors.get(field_name, None)
|
| - if field_cdescriptor is None:
|
| + if not field_cdescriptor:
|
| raise ValueError('Protocol message has no "%s" field.' % field_name)
|
| if field_cdescriptor.label == _LABEL_REPEATED:
|
| if field_cdescriptor.cpp_type == _CPPTYPE_MESSAGE:
|
| + field_name = getattr(self, field_name)
|
| for val in field_value:
|
| - getattr(self, field_name).add().MergeFrom(val)
|
| + field_name.add().MergeFrom(val)
|
| else:
|
| getattr(self, field_name).extend(field_value)
|
| elif field_cdescriptor.cpp_type == _CPPTYPE_MESSAGE:
|
| @@ -497,12 +519,34 @@
|
| return self._cmsg.HasField(field_name)
|
|
|
| def ClearField(self, field_name):
|
| + child_cmessage = None
|
| if field_name in self._composite_fields:
|
| + child_field = self._composite_fields[field_name]
|
| del self._composite_fields[field_name]
|
| - self._cmsg.ClearField(field_name)
|
|
|
| + child_cdescriptor = self.__descriptors[field_name]
|
| + # TODO(anuraag): Support clearing repeated message fields as well.
|
| + if (child_cdescriptor.label != _LABEL_REPEATED and
|
| + child_cdescriptor.cpp_type == _CPPTYPE_MESSAGE):
|
| + child_field._owner = None
|
| + child_cmessage = child_field._cmsg
|
| +
|
| + if child_cmessage is not None:
|
| + self._cmsg.ClearField(field_name, child_cmessage)
|
| + else:
|
| + self._cmsg.ClearField(field_name)
|
| +
|
| def Clear(self):
|
| - return self._cmsg.Clear()
|
| + cmessages_to_release = []
|
| + for field_name, child_field in self._composite_fields.iteritems():
|
| + child_cdescriptor = self.__descriptors[field_name]
|
| + # TODO(anuraag): Support clearing repeated message fields as well.
|
| + if (child_cdescriptor.label != _LABEL_REPEATED and
|
| + child_cdescriptor.cpp_type == _CPPTYPE_MESSAGE):
|
| + child_field._owner = None
|
| + cmessages_to_release.append((child_cdescriptor, child_field._cmsg))
|
| + self._composite_fields.clear()
|
| + self._cmsg.Clear(cmessages_to_release)
|
|
|
| def IsInitialized(self, errors=None):
|
| if self._cmsg.IsInitialized():
|
| @@ -514,8 +558,8 @@
|
| def SerializeToString(self):
|
| if not self.IsInitialized():
|
| raise message.EncodeError(
|
| - 'Message is missing required fields: ' +
|
| - ','.join(self.FindInitializationErrors()))
|
| + 'Message %s is missing required fields: %s' % (
|
| + self._cmsg.full_name, ','.join(self.FindInitializationErrors())))
|
| return self._cmsg.SerializeToString()
|
|
|
| def SerializePartialToString(self):
|
| @@ -534,7 +578,8 @@
|
| def MergeFrom(self, msg):
|
| if not isinstance(msg, cls):
|
| raise TypeError(
|
| - "Parameter to MergeFrom() must be instance of same class.")
|
| + "Parameter to MergeFrom() must be instance of same class: "
|
| + "expected %s got %s." % (cls.__name__, type(msg).__name__))
|
| self._cmsg.MergeFrom(msg._cmsg)
|
|
|
| def CopyFrom(self, msg):
|
| @@ -581,6 +626,8 @@
|
| raise TypeError('unhashable object')
|
|
|
| def __unicode__(self):
|
| + # Lazy import to prevent circular import when text_format imports this file.
|
| + from google.protobuf import text_format
|
| return text_format.MessageToString(self, as_utf8=True).decode('utf-8')
|
|
|
| # Attach the local methods to the message class.
|
|
|