Index: third_party/protobuf/python/google/protobuf/internal/python_message.py |
=================================================================== |
--- third_party/protobuf/python/google/protobuf/internal/python_message.py (revision 216642) |
+++ third_party/protobuf/python/google/protobuf/internal/python_message.py (working copy) |
@@ -54,6 +54,7 @@ |
from cStringIO import StringIO |
except ImportError: |
from StringIO import StringIO |
+import copy_reg |
import struct |
import weakref |
@@ -61,6 +62,7 @@ |
from google.protobuf.internal import containers |
from google.protobuf.internal import decoder |
from google.protobuf.internal import encoder |
+from google.protobuf.internal import enum_type_wrapper |
from google.protobuf.internal import message_listener as message_listener_mod |
from google.protobuf.internal import type_checkers |
from google.protobuf.internal import wire_format |
@@ -71,9 +73,10 @@ |
_FieldDescriptor = descriptor_mod.FieldDescriptor |
-def NewMessage(descriptor, dictionary): |
+def NewMessage(bases, descriptor, dictionary): |
_AddClassAttributesForNestedExtensions(descriptor, dictionary) |
_AddSlots(descriptor, dictionary) |
+ return bases |
def InitMessage(descriptor, cls): |
@@ -96,6 +99,7 @@ |
_AddStaticMethods(cls) |
_AddMessageMethods(descriptor, cls) |
_AddPrivateHelperMethods(cls) |
+ copy_reg.pickle(cls, lambda obj: (cls, (), obj.__getstate__())) |
# Stateless helpers for GeneratedProtocolMessageType below. |
@@ -145,6 +149,10 @@ |
if not extension_handle.is_extension: |
raise KeyError('"%s" is not an extension.' % extension_handle.full_name) |
+ if not extension_handle.containing_type: |
+ raise KeyError('"%s" is missing a containing_type.' |
+ % extension_handle.full_name) |
+ |
if extension_handle.containing_type is not message.DESCRIPTOR: |
raise KeyError('Extension "%s" extends message type "%s", but this ' |
'message is of type "%s".' % |
@@ -164,6 +172,7 @@ |
dictionary['__slots__'] = ['_cached_byte_size', |
'_cached_byte_size_dirty', |
'_fields', |
+ '_unknown_fields', |
'_is_present_in_parent', |
'_listener', |
'_listener_for_children', |
@@ -224,11 +233,14 @@ |
def _AddEnumValues(descriptor, cls): |
"""Sets class-level attributes for all enum fields defined in this message. |
+ Also exporting a class-level object that can name enum values. |
+ |
Args: |
descriptor: Descriptor object for this message type. |
cls: Class we're constructing for this message type. |
""" |
for enum_type in descriptor.enum_types: |
+ setattr(cls, enum_type.name, enum_type_wrapper.EnumTypeWrapper(enum_type)) |
for enum_value in enum_type.values: |
setattr(cls, enum_value.name, enum_value.number) |
@@ -248,7 +260,7 @@ |
""" |
if field.label == _FieldDescriptor.LABEL_REPEATED: |
- if field.default_value != []: |
+ if field.has_default_value and field.default_value != []: |
raise ValueError('Repeated field default value not empty list: %s' % ( |
field.default_value)) |
if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: |
@@ -276,6 +288,8 @@ |
return MakeSubMessageDefault |
def MakeScalarDefault(message): |
+ # TODO(protobuf-team): This may be broken since there may not be |
+ # default_value. Combine with has_default_value somehow. |
return field.default_value |
return MakeScalarDefault |
@@ -287,6 +301,9 @@ |
self._cached_byte_size = 0 |
self._cached_byte_size_dirty = len(kwargs) > 0 |
self._fields = {} |
+ # _unknown_fields is () when empty for efficiency, and will be turned into |
+ # a list if fields are added. |
+ self._unknown_fields = () |
self._is_present_in_parent = False |
self._listener = message_listener_mod.NullMessageListener() |
self._listener_for_children = _Listener(self) |
@@ -428,6 +445,8 @@ |
valid_values = set() |
def getter(self): |
+ # TODO(protobuf-team): This may be broken since there may not be |
+ # default_value. Combine with has_default_value somehow. |
return self._fields.get(field, default_value) |
getter.__module__ = None |
getter.__doc__ = 'Getter for %s.' % proto_field_name |
@@ -462,13 +481,18 @@ |
# for non-repeated scalars. |
proto_field_name = field.name |
property_name = _PropertyName(proto_field_name) |
+ |
+ # TODO(komarek): Can anyone explain to me why we cache the message_type this |
+ # way, instead of referring to field.message_type inside of getter(self)? |
+ # What if someone sets message_type later on (which makes for simpler |
+ # dyanmic proto descriptor and class creation code). |
message_type = field.message_type |
def getter(self): |
field_value = self._fields.get(field) |
if field_value is None: |
# Construct a new object to represent this field. |
- field_value = message_type._concrete_class() |
+ field_value = message_type._concrete_class() # use field.message_type? |
field_value._SetListener(self._listener_for_children) |
# Atomically check if another thread has preempted us and, if not, swap |
@@ -620,6 +644,7 @@ |
def Clear(self): |
# Clear fields. |
self._fields = {} |
+ self._unknown_fields = () |
self._Modified() |
cls.Clear = Clear |
@@ -649,8 +674,17 @@ |
if self is other: |
return True |
- return self.ListFields() == other.ListFields() |
+ if not self.ListFields() == other.ListFields(): |
+ return False |
+ # Sort unknown fields because their order shouldn't affect equality test. |
+ unknown_fields = list(self._unknown_fields) |
+ unknown_fields.sort() |
+ other_unknown_fields = list(other._unknown_fields) |
+ other_unknown_fields.sort() |
+ |
+ return unknown_fields == other_unknown_fields |
+ |
cls.__eq__ = __eq__ |
@@ -710,6 +744,9 @@ |
for field_descriptor, field_value in self.ListFields(): |
size += field_descriptor._sizer(field_value) |
+ for tag_bytes, value_bytes in self._unknown_fields: |
+ size += len(tag_bytes) + len(value_bytes) |
+ |
self._cached_byte_size = size |
self._cached_byte_size_dirty = False |
self._listener_for_children.dirty = False |
@@ -726,8 +763,8 @@ |
errors = [] |
if not self.IsInitialized(): |
raise message_mod.EncodeError( |
- 'Message is missing required fields: ' + |
- ','.join(self.FindInitializationErrors())) |
+ 'Message %s is missing required fields: %s' % ( |
+ self.DESCRIPTOR.full_name, ','.join(self.FindInitializationErrors()))) |
return self.SerializePartialToString() |
cls.SerializeToString = SerializeToString |
@@ -744,6 +781,9 @@ |
def InternalSerialize(self, write_bytes): |
for field_descriptor, field_value in self.ListFields(): |
field_descriptor._encoder(write_bytes, field_value) |
+ for tag_bytes, value_bytes in self._unknown_fields: |
+ write_bytes(tag_bytes) |
+ write_bytes(value_bytes) |
cls._InternalSerialize = InternalSerialize |
@@ -770,13 +810,18 @@ |
def InternalParse(self, buffer, pos, end): |
self._Modified() |
field_dict = self._fields |
+ unknown_field_list = self._unknown_fields |
while pos != end: |
(tag_bytes, new_pos) = local_ReadTag(buffer, pos) |
field_decoder = decoders_by_tag.get(tag_bytes) |
if field_decoder is None: |
+ value_start_pos = new_pos |
new_pos = local_SkipField(buffer, new_pos, end, tag_bytes) |
if new_pos == -1: |
return pos |
+ if not unknown_field_list: |
+ unknown_field_list = self._unknown_fields = [] |
+ unknown_field_list.append((tag_bytes, buffer[value_start_pos:new_pos])) |
pos = new_pos |
else: |
pos = field_decoder(buffer, new_pos, end, self, field_dict) |
@@ -873,7 +918,8 @@ |
def MergeFrom(self, msg): |
if not isinstance(msg, cls): |
raise TypeError( |
- "Parameter to MergeFrom() must be instance of same class.") |
+ "Parameter to MergeFrom() must be instance of same class: " |
+ "expected %s got %s." % (cls.__name__, type(msg).__name__)) |
assert msg is not self |
self._Modified() |
@@ -898,6 +944,12 @@ |
field_value.MergeFrom(value) |
else: |
self._fields[field] = value |
+ |
+ if msg._unknown_fields: |
+ if not self._unknown_fields: |
+ self._unknown_fields = [] |
+ self._unknown_fields.extend(msg._unknown_fields) |
+ |
cls.MergeFrom = MergeFrom |