Index: third_party/protobuf/python/google/protobuf/internal/python_message.py |
diff --git a/third_party/protobuf/python/google/protobuf/internal/python_message.py b/third_party/protobuf/python/google/protobuf/internal/python_message.py |
index 4bea57ac6c33616454a02efd845cbc50752b5db2..ca9f76753b72562fed5bbbbfd2a810fbfd68d151 100644 |
--- a/third_party/protobuf/python/google/protobuf/internal/python_message.py |
+++ b/third_party/protobuf/python/google/protobuf/internal/python_message.py |
@@ -1,6 +1,6 @@ |
# Protocol Buffers - Google's data interchange format |
# Copyright 2008 Google Inc. All rights reserved. |
-# http://code.google.com/p/protobuf/ |
+# https://developers.google.com/protocol-buffers/ |
# |
# Redistribution and use in source and binary forms, with or without |
# modification, are permitted provided that the following conditions are |
@@ -28,6 +28,10 @@ |
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE |
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
+# Keep it Python2.5 compatible for GAE. |
+# |
+# Copyright 2007 Google Inc. All Rights Reserved. |
+# |
# This code is meant to work on Python 2.4 and above only. |
# |
# TODO(robinson): Helpers for verbose, common checks like seeing if a |
@@ -50,11 +54,18 @@ this file*. |
__author__ = 'robinson@google.com (Will Robinson)' |
-try: |
- from cStringIO import StringIO |
-except ImportError: |
- from StringIO import StringIO |
-import copy_reg |
+import sys |
+if sys.version_info[0] < 3: |
+ try: |
+ from cStringIO import StringIO as BytesIO |
+ except ImportError: |
+ from StringIO import StringIO as BytesIO |
+ import copy_reg as copyreg |
+ _basestring = basestring |
+else: |
+ from io import BytesIO |
+ import copyreg |
+ _basestring = str |
import struct |
import weakref |
@@ -68,6 +79,7 @@ from google.protobuf.internal import type_checkers |
from google.protobuf.internal import wire_format |
from google.protobuf import descriptor as descriptor_mod |
from google.protobuf import message as message_mod |
+from google.protobuf import symbol_database |
from google.protobuf import text_format |
_FieldDescriptor = descriptor_mod.FieldDescriptor |
@@ -86,20 +98,21 @@ def InitMessage(descriptor, cls): |
if (descriptor.has_options and |
descriptor.GetOptions().message_set_wire_format): |
cls._decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = ( |
- decoder.MessageSetItemDecoder(cls._extensions_by_number)) |
+ decoder.MessageSetItemDecoder(cls._extensions_by_number), None) |
# Attach stuff to each FieldDescriptor for quick lookup later on. |
for field in descriptor.fields: |
_AttachFieldHelpers(cls, field) |
+ descriptor._concrete_class = cls # pylint: disable=protected-access |
_AddEnumValues(descriptor, cls) |
_AddInitMethod(descriptor, cls) |
_AddPropertiesForFields(descriptor, cls) |
_AddPropertiesForExtensions(descriptor, cls) |
_AddStaticMethods(cls) |
_AddMessageMethods(descriptor, cls) |
- _AddPrivateHelperMethods(cls) |
- copy_reg.pickle(cls, lambda obj: (cls, (), obj.__getstate__())) |
+ _AddPrivateHelperMethods(descriptor, cls) |
+ copyreg.pickle(cls, lambda obj: (cls, (), obj.__getstate__())) |
# Stateless helpers for GeneratedProtocolMessageType below. |
@@ -176,7 +189,8 @@ def _AddSlots(message_descriptor, dictionary): |
'_is_present_in_parent', |
'_listener', |
'_listener_for_children', |
- '__weakref__'] |
+ '__weakref__', |
+ '_oneofs'] |
def _IsMessageSetExtension(field): |
@@ -188,12 +202,37 @@ def _IsMessageSetExtension(field): |
field.label == _FieldDescriptor.LABEL_OPTIONAL) |
+def _IsMapField(field): |
+ return (field.type == _FieldDescriptor.TYPE_MESSAGE and |
+ field.message_type.has_options and |
+ field.message_type.GetOptions().map_entry) |
+ |
+ |
+def _IsMessageMapField(field): |
+ value_type = field.message_type.fields_by_name["value"] |
+ return value_type.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE |
+ |
+ |
def _AttachFieldHelpers(cls, field_descriptor): |
is_repeated = (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED) |
- is_packed = (field_descriptor.has_options and |
- field_descriptor.GetOptions().packed) |
- |
- if _IsMessageSetExtension(field_descriptor): |
+ is_packable = (is_repeated and |
+ wire_format.IsTypePackable(field_descriptor.type)) |
+ if not is_packable: |
+ is_packed = False |
+ elif field_descriptor.containing_type.syntax == "proto2": |
+ is_packed = (field_descriptor.has_options and |
+ field_descriptor.GetOptions().packed) |
+ else: |
+ has_packed_false = (field_descriptor.has_options and |
+ field_descriptor.GetOptions().HasField("packed") and |
+ field_descriptor.GetOptions().packed == False) |
+ is_packed = not has_packed_false |
+ is_map_entry = _IsMapField(field_descriptor) |
+ |
+ if is_map_entry: |
+ field_encoder = encoder.MapEncoder(field_descriptor) |
+ sizer = encoder.MapSizer(field_descriptor) |
+ elif _IsMessageSetExtension(field_descriptor): |
field_encoder = encoder.MessageSetItemEncoder(field_descriptor.number) |
sizer = encoder.MessageSetItemSizer(field_descriptor.number) |
else: |
@@ -209,10 +248,27 @@ def _AttachFieldHelpers(cls, field_descriptor): |
def AddDecoder(wiretype, is_packed): |
tag_bytes = encoder.TagBytes(field_descriptor.number, wiretype) |
- cls._decoders_by_tag[tag_bytes] = ( |
- type_checkers.TYPE_TO_DECODER[field_descriptor.type]( |
- field_descriptor.number, is_repeated, is_packed, |
- field_descriptor, field_descriptor._default_constructor)) |
+ decode_type = field_descriptor.type |
+ if (decode_type == _FieldDescriptor.TYPE_ENUM and |
+ type_checkers.SupportsOpenEnums(field_descriptor)): |
+ decode_type = _FieldDescriptor.TYPE_INT32 |
+ |
+ oneof_descriptor = None |
+ if field_descriptor.containing_oneof is not None: |
+ oneof_descriptor = field_descriptor |
+ |
+ if is_map_entry: |
+ is_message_map = _IsMessageMapField(field_descriptor) |
+ |
+ field_decoder = decoder.MapDecoder( |
+ field_descriptor, _GetInitializeDefaultForMap(field_descriptor), |
+ is_message_map) |
+ else: |
+ field_decoder = type_checkers.TYPE_TO_DECODER[decode_type]( |
+ field_descriptor.number, is_repeated, is_packed, |
+ field_descriptor, field_descriptor._default_constructor) |
+ |
+ cls._decoders_by_tag[tag_bytes] = (field_decoder, oneof_descriptor) |
AddDecoder(type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type], |
False) |
@@ -245,6 +301,26 @@ def _AddEnumValues(descriptor, cls): |
setattr(cls, enum_value.name, enum_value.number) |
+def _GetInitializeDefaultForMap(field): |
+ if field.label != _FieldDescriptor.LABEL_REPEATED: |
+ raise ValueError('map_entry set on non-repeated field %s' % ( |
+ field.name)) |
+ fields_by_name = field.message_type.fields_by_name |
+ key_checker = type_checkers.GetTypeChecker(fields_by_name['key']) |
+ |
+ value_field = fields_by_name['value'] |
+ if _IsMessageMapField(field): |
+ def MakeMessageMapDefault(message): |
+ return containers.MessageMap( |
+ message._listener_for_children, value_field.message_type, key_checker) |
+ return MakeMessageMapDefault |
+ else: |
+ value_checker = type_checkers.GetTypeChecker(value_field) |
+ def MakePrimitiveMapDefault(message): |
+ return containers.ScalarMap( |
+ message._listener_for_children, key_checker, value_checker) |
+ return MakePrimitiveMapDefault |
+ |
def _DefaultValueConstructorForField(field): |
"""Returns a function which returns a default value for a field. |
@@ -259,6 +335,9 @@ def _DefaultValueConstructorForField(field): |
value may refer back to |message| via a weak reference. |
""" |
+ if _IsMapField(field): |
+ return _GetInitializeDefaultForMap(field) |
+ |
if field.label == _FieldDescriptor.LABEL_REPEATED: |
if field.has_default_value and field.default_value != []: |
raise ValueError('Repeated field default value not empty list: %s' % ( |
@@ -272,7 +351,7 @@ def _DefaultValueConstructorForField(field): |
message._listener_for_children, field.message_type) |
return MakeRepeatedMessageDefault |
else: |
- type_checker = type_checkers.GetTypeChecker(field.cpp_type, field.type) |
+ type_checker = type_checkers.GetTypeChecker(field) |
def MakeRepeatedScalarDefault(message): |
return containers.RepeatedScalarFieldContainer( |
message._listener_for_children, type_checker) |
@@ -284,6 +363,8 @@ def _DefaultValueConstructorForField(field): |
def MakeSubMessageDefault(message): |
result = message_type._concrete_class() |
result._SetListener(message._listener_for_children) |
+ if field.containing_oneof: |
+ message._UpdateOneofState(field) |
return result |
return MakeSubMessageDefault |
@@ -294,13 +375,43 @@ def _DefaultValueConstructorForField(field): |
return MakeScalarDefault |
+def _ReraiseTypeErrorWithFieldName(message_name, field_name): |
+ """Re-raise the currently-handled TypeError with the field name added.""" |
+ exc = sys.exc_info()[1] |
+ if len(exc.args) == 1 and type(exc) is TypeError: |
+ # simple TypeError; add field name to exception message |
+ exc = TypeError('%s for field %s.%s' % (str(exc), message_name, field_name)) |
+ |
+ # re-raise possibly-amended exception with original traceback: |
+ raise type(exc)(exc, sys.exc_info()[2]) |
+ |
+ |
def _AddInitMethod(message_descriptor, cls): |
"""Adds an __init__ method to cls.""" |
- fields = message_descriptor.fields |
+ |
+ def _GetIntegerEnumValue(enum_type, value): |
+ """Convert a string or integer enum value to an integer. |
+ |
+ If the value is a string, it is converted to the enum value in |
+ enum_type with the same name. If the value is not a string, it's |
+ returned as-is. (No conversion or bounds-checking is done.) |
+ """ |
+ if isinstance(value, _basestring): |
+ try: |
+ return enum_type.values_by_name[value].number |
+ except KeyError: |
+ raise ValueError('Enum type %s: unknown label "%s"' % ( |
+ enum_type.full_name, value)) |
+ return value |
+ |
def init(self, **kwargs): |
self._cached_byte_size = 0 |
self._cached_byte_size_dirty = len(kwargs) > 0 |
self._fields = {} |
+ # Contains a mapping from oneof field descriptors to the descriptor |
+ # of the currently set field in that oneof field. |
+ self._oneofs = {} |
+ |
# _unknown_fields is () when empty for efficiency, and will be turned into |
# a list if fields are added. |
self._unknown_fields = () |
@@ -315,17 +426,41 @@ def _AddInitMethod(message_descriptor, cls): |
if field.label == _FieldDescriptor.LABEL_REPEATED: |
copy = field._default_constructor(self) |
if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: # Composite |
- for val in field_value: |
- copy.add().MergeFrom(val) |
+ if _IsMapField(field): |
+ if _IsMessageMapField(field): |
+ for key in field_value: |
+ copy[key].MergeFrom(field_value[key]) |
+ else: |
+ copy.update(field_value) |
+ else: |
+ for val in field_value: |
+ if isinstance(val, dict): |
+ copy.add(**val) |
+ else: |
+ copy.add().MergeFrom(val) |
else: # Scalar |
+ if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM: |
+ field_value = [_GetIntegerEnumValue(field.enum_type, val) |
+ for val in field_value] |
copy.extend(field_value) |
self._fields[field] = copy |
elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: |
copy = field._default_constructor(self) |
- copy.MergeFrom(field_value) |
+ new_val = field_value |
+ if isinstance(field_value, dict): |
+ new_val = field.message_type._concrete_class(**field_value) |
+ try: |
+ copy.MergeFrom(new_val) |
+ except TypeError: |
+ _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name) |
self._fields[field] = copy |
else: |
- setattr(self, field_name, field_value) |
+ if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM: |
+ field_value = _GetIntegerEnumValue(field.enum_type, field_value) |
+ try: |
+ setattr(self, field_name, field_value) |
+ except TypeError: |
+ _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name) |
init.__module__ = None |
init.__doc__ = None |
@@ -440,9 +575,10 @@ def _AddPropertiesForNonRepeatedScalarField(field, cls): |
""" |
proto_field_name = field.name |
property_name = _PropertyName(proto_field_name) |
- type_checker = type_checkers.GetTypeChecker(field.cpp_type, field.type) |
+ type_checker = type_checkers.GetTypeChecker(field) |
default_value = field.default_value |
valid_values = set() |
+ is_proto3 = field.containing_type.syntax == "proto3" |
def getter(self): |
# TODO(protobuf-team): This may be broken since there may not be |
@@ -450,14 +586,30 @@ def _AddPropertiesForNonRepeatedScalarField(field, cls): |
return self._fields.get(field, default_value) |
getter.__module__ = None |
getter.__doc__ = 'Getter for %s.' % proto_field_name |
- def setter(self, new_value): |
- type_checker.CheckValue(new_value) |
- self._fields[field] = new_value |
+ |
+ clear_when_set_to_default = is_proto3 and not field.containing_oneof |
+ |
+ def field_setter(self, new_value): |
+ # pylint: disable=protected-access |
+ # Testing the value for truthiness captures all of the proto3 defaults |
+ # (0, 0.0, enum 0, and False). |
+ new_value = type_checker.CheckValue(new_value) |
+ if clear_when_set_to_default and not new_value: |
+ self._fields.pop(field, None) |
+ else: |
+ self._fields[field] = new_value |
# Check _cached_byte_size_dirty inline to improve performance, since scalar |
# setters are called frequently. |
if not self._cached_byte_size_dirty: |
self._Modified() |
+ if field.containing_oneof: |
+ def setter(self, new_value): |
+ field_setter(self, new_value) |
+ self._UpdateOneofState(field) |
+ else: |
+ setter = field_setter |
+ |
setter.__module__ = None |
setter.__doc__ = 'Setter for %s.' % proto_field_name |
@@ -493,7 +645,10 @@ def _AddPropertiesForNonRepeatedCompositeField(field, cls): |
if field_value is None: |
# Construct a new object to represent this field. |
field_value = message_type._concrete_class() # use field.message_type? |
- field_value._SetListener(self._listener_for_children) |
+ field_value._SetListener( |
+ _OneofListener(self, field) |
+ if field.containing_oneof is not None |
+ else self._listener_for_children) |
# Atomically check if another thread has preempted us and, if not, swap |
# in the new object we just created. If someone has preempted us, we |
@@ -581,27 +736,48 @@ def _AddListFieldsMethod(message_descriptor, cls): |
cls.ListFields = ListFields |
+_Proto3HasError = 'Protocol message has no non-repeated submessage field "%s"' |
+_Proto2HasError = 'Protocol message has no non-repeated field "%s"' |
def _AddHasFieldMethod(message_descriptor, cls): |
"""Helper for _AddMessageMethods().""" |
- singular_fields = {} |
+ is_proto3 = (message_descriptor.syntax == "proto3") |
+ error_msg = _Proto3HasError if is_proto3 else _Proto2HasError |
+ |
+ hassable_fields = {} |
for field in message_descriptor.fields: |
- if field.label != _FieldDescriptor.LABEL_REPEATED: |
- singular_fields[field.name] = field |
+ if field.label == _FieldDescriptor.LABEL_REPEATED: |
+ continue |
+ # For proto3, only submessages and fields inside a oneof have presence. |
+ if (is_proto3 and field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE and |
+ not field.containing_oneof): |
+ continue |
+ hassable_fields[field.name] = field |
+ |
+ if not is_proto3: |
+ # Fields inside oneofs are never repeated (enforced by the compiler). |
+ for oneof in message_descriptor.oneofs: |
+ hassable_fields[oneof.name] = oneof |
def HasField(self, field_name): |
try: |
- field = singular_fields[field_name] |
+ field = hassable_fields[field_name] |
except KeyError: |
- raise ValueError( |
- 'Protocol message has no singular "%s" field.' % field_name) |
+ raise ValueError(error_msg % field_name) |
- if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: |
- value = self._fields.get(field) |
- return value is not None and value._is_present_in_parent |
+ if isinstance(field, descriptor_mod.OneofDescriptor): |
+ try: |
+ return HasField(self, self._oneofs[field].name) |
+ except KeyError: |
+ return False |
else: |
- return field in self._fields |
+ if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: |
+ value = self._fields.get(field) |
+ return value is not None and value._is_present_in_parent |
+ else: |
+ return field in self._fields |
+ |
cls.HasField = HasField |
@@ -611,7 +787,14 @@ def _AddClearFieldMethod(message_descriptor, cls): |
try: |
field = message_descriptor.fields_by_name[field_name] |
except KeyError: |
- raise ValueError('Protocol message has no "%s" field.' % field_name) |
+ try: |
+ field = message_descriptor.oneofs_by_name[field_name] |
+ if field in self._oneofs: |
+ field = self._oneofs[field] |
+ else: |
+ return |
+ except KeyError: |
+ raise ValueError('Protocol message has no "%s" field.' % field_name) |
if field in self._fields: |
# Note: If the field is a sub-message, its listener will still point |
@@ -619,6 +802,9 @@ def _AddClearFieldMethod(message_descriptor, cls): |
# will call _Modified() and invalidate our byte size. Big deal. |
del self._fields[field] |
+ if self._oneofs.get(field.containing_oneof, None) is field: |
+ del self._oneofs[field.containing_oneof] |
+ |
# Always call _Modified() -- even if nothing was changed, this is |
# a mutating method, and thus calling it should cause the field to become |
# present in the parent message. |
@@ -645,6 +831,7 @@ def _AddClearMethod(message_descriptor, cls): |
# Clear fields. |
self._fields = {} |
self._unknown_fields = () |
+ self._oneofs = {} |
self._Modified() |
cls.Clear = Clear |
@@ -663,6 +850,26 @@ def _AddHasExtensionMethod(cls): |
return extension_handle in self._fields |
cls.HasExtension = HasExtension |
+def _UnpackAny(msg): |
+ type_url = msg.type_url |
+ db = symbol_database.Default() |
+ |
+ if not type_url: |
+ return None |
+ |
+ # TODO(haberman): For now we just strip the hostname. Better logic will be |
+ # required. |
+ type_name = type_url.split("/")[-1] |
+ descriptor = db.pool.FindMessageTypeByName(type_name) |
+ |
+ if descriptor is None: |
+ return None |
+ |
+ message_class = db.GetPrototype(descriptor) |
+ message = message_class() |
+ |
+ message.ParseFromString(msg.value) |
+ return message |
def _AddEqualsMethod(message_descriptor, cls): |
"""Helper for _AddMessageMethods().""" |
@@ -674,6 +881,12 @@ def _AddEqualsMethod(message_descriptor, cls): |
if self is other: |
return True |
+ if self.DESCRIPTOR.full_name == "google.protobuf.Any": |
+ any_a = _UnpackAny(self) |
+ any_b = _UnpackAny(other) |
+ if any_a and any_b: |
+ return any_a == any_b |
+ |
if not self.ListFields() == other.ListFields(): |
return False |
@@ -773,7 +986,7 @@ def _AddSerializePartialToStringMethod(message_descriptor, cls): |
"""Helper for _AddMessageMethods().""" |
def SerializePartialToString(self): |
- out = StringIO() |
+ out = BytesIO() |
self._InternalSerialize(out.write) |
return out.getvalue() |
cls.SerializePartialToString = SerializePartialToString |
@@ -796,9 +1009,10 @@ def _AddMergeFromStringMethod(message_descriptor, cls): |
# The only reason _InternalParse would return early is if it |
# encountered an end-group tag. |
raise message_mod.DecodeError('Unexpected end-group tag.') |
- except IndexError: |
+ except (IndexError, TypeError): |
+ # Now ord(buf[p:p+1]) == ord('') gets TypeError. |
raise message_mod.DecodeError('Truncated message.') |
- except struct.error, e: |
+ except struct.error as e: |
raise message_mod.DecodeError(e) |
return length # Return this for legacy reasons. |
cls.MergeFromString = MergeFromString |
@@ -806,6 +1020,7 @@ def _AddMergeFromStringMethod(message_descriptor, cls): |
local_ReadTag = decoder.ReadTag |
local_SkipField = decoder.SkipField |
decoders_by_tag = cls._decoders_by_tag |
+ is_proto3 = message_descriptor.syntax == "proto3" |
def InternalParse(self, buffer, pos, end): |
self._Modified() |
@@ -813,18 +1028,22 @@ def _AddMergeFromStringMethod(message_descriptor, cls): |
unknown_field_list = self._unknown_fields |
while pos != end: |
(tag_bytes, new_pos) = local_ReadTag(buffer, pos) |
- field_decoder = decoders_by_tag.get(tag_bytes) |
+ field_decoder, field_desc = decoders_by_tag.get(tag_bytes, (None, None)) |
if field_decoder is None: |
value_start_pos = new_pos |
new_pos = local_SkipField(buffer, new_pos, end, tag_bytes) |
if new_pos == -1: |
return pos |
- if not unknown_field_list: |
- unknown_field_list = self._unknown_fields = [] |
- unknown_field_list.append((tag_bytes, buffer[value_start_pos:new_pos])) |
+ if not is_proto3: |
+ if not unknown_field_list: |
+ unknown_field_list = self._unknown_fields = [] |
+ unknown_field_list.append( |
+ (tag_bytes, buffer[value_start_pos:new_pos])) |
pos = new_pos |
else: |
pos = field_decoder(buffer, new_pos, end, self, field_dict) |
+ if field_desc: |
+ self._UpdateOneofState(field_desc) |
return pos |
cls._InternalParse = InternalParse |
@@ -857,9 +1076,12 @@ def _AddIsInitializedMethod(message_descriptor, cls): |
errors.extend(self.FindInitializationErrors()) |
return False |
- for field, value in self._fields.iteritems(): |
+ for field, value in list(self._fields.items()): # dict can change size! |
if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: |
if field.label == _FieldDescriptor.LABEL_REPEATED: |
+ if (field.message_type.has_options and |
+ field.message_type.GetOptions().map_entry): |
+ continue |
for element in value: |
if not element.IsInitialized(): |
if errors is not None: |
@@ -895,16 +1117,26 @@ def _AddIsInitializedMethod(message_descriptor, cls): |
else: |
name = field.name |
- if field.label == _FieldDescriptor.LABEL_REPEATED: |
+ if _IsMapField(field): |
+ if _IsMessageMapField(field): |
+ for key in value: |
+ element = value[key] |
+ prefix = "%s[%d]." % (name, key) |
+ sub_errors = element.FindInitializationErrors() |
+ errors += [prefix + error for error in sub_errors] |
+ else: |
+ # ScalarMaps can't have any initialization errors. |
+ pass |
+ elif field.label == _FieldDescriptor.LABEL_REPEATED: |
for i in xrange(len(value)): |
element = value[i] |
prefix = "%s[%d]." % (name, i) |
sub_errors = element.FindInitializationErrors() |
- errors += [ prefix + error for error in sub_errors ] |
+ errors += [prefix + error for error in sub_errors] |
else: |
prefix = name + "." |
sub_errors = value.FindInitializationErrors() |
- errors += [ prefix + error for error in sub_errors ] |
+ errors += [prefix + error for error in sub_errors] |
return errors |
@@ -941,9 +1173,13 @@ def _AddMergeFromMethod(cls): |
# Construct a new object to represent this field. |
field_value = field._default_constructor(self) |
fields[field] = field_value |
+ if field.containing_oneof: |
+ self._UpdateOneofState(field) |
field_value.MergeFrom(value) |
else: |
self._fields[field] = value |
+ if field.containing_oneof: |
+ self._UpdateOneofState(field) |
if msg._unknown_fields: |
if not self._unknown_fields: |
@@ -953,6 +1189,24 @@ def _AddMergeFromMethod(cls): |
cls.MergeFrom = MergeFrom |
+def _AddWhichOneofMethod(message_descriptor, cls): |
+ def WhichOneof(self, oneof_name): |
+ """Returns the name of the currently set field inside a oneof, or None.""" |
+ try: |
+ field = message_descriptor.oneofs_by_name[oneof_name] |
+ except KeyError: |
+ raise ValueError( |
+ 'Protocol message has no oneof "%s" field.' % oneof_name) |
+ |
+ nested_field = self._oneofs.get(field, None) |
+ if nested_field is not None and self.HasField(nested_field.name): |
+ return nested_field.name |
+ else: |
+ return None |
+ |
+ cls.WhichOneof = WhichOneof |
+ |
+ |
def _AddMessageMethods(message_descriptor, cls): |
"""Adds implementations of all Message methods to cls.""" |
_AddListFieldsMethod(message_descriptor, cls) |
@@ -972,9 +1226,9 @@ def _AddMessageMethods(message_descriptor, cls): |
_AddMergeFromStringMethod(message_descriptor, cls) |
_AddIsInitializedMethod(message_descriptor, cls) |
_AddMergeFromMethod(cls) |
+ _AddWhichOneofMethod(message_descriptor, cls) |
- |
-def _AddPrivateHelperMethods(cls): |
+def _AddPrivateHelperMethods(message_descriptor, cls): |
"""Adds implementation of private helper methods to cls.""" |
def Modified(self): |
@@ -992,8 +1246,20 @@ def _AddPrivateHelperMethods(cls): |
self._is_present_in_parent = True |
self._listener.Modified() |
+ def _UpdateOneofState(self, field): |
+ """Sets field as the active field in its containing oneof. |
+ |
+ Will also delete currently active field in the oneof, if it is different |
+ from the argument. Does not mark the message as modified. |
+ """ |
+ other_field = self._oneofs.setdefault(field.containing_oneof, field) |
+ if other_field is not field: |
+ del self._fields[other_field] |
+ self._oneofs[field.containing_oneof] = field |
+ |
cls._Modified = Modified |
cls.SetInParent = Modified |
+ cls._UpdateOneofState = _UpdateOneofState |
class _Listener(object): |
@@ -1042,6 +1308,27 @@ class _Listener(object): |
pass |
+class _OneofListener(_Listener): |
+ """Special listener implementation for setting composite oneof fields.""" |
+ |
+ def __init__(self, parent_message, field): |
+ """Args: |
+ parent_message: The message whose _Modified() method we should call when |
+ we receive Modified() messages. |
+ field: The descriptor of the field being set in the parent message. |
+ """ |
+ super(_OneofListener, self).__init__(parent_message) |
+ self._field = field |
+ |
+ def Modified(self): |
+ """Also updates the state of the containing oneof in the parent message.""" |
+ try: |
+ self._parent_message_weakref._UpdateOneofState(self._field) |
+ super(_OneofListener, self).Modified() |
+ except ReferenceError: |
+ pass |
+ |
+ |
# TODO(robinson): Move elsewhere? This file is getting pretty ridiculous... |
# TODO(robinson): Unify error handling of "unknown extension" crap. |
# TODO(robinson): Support iteritems()-style iteration over all |
@@ -1132,10 +1419,10 @@ class _ExtensionDict(object): |
# It's slightly wasteful to lookup the type checker each time, |
# but we expect this to be a vanishingly uncommon case anyway. |
- type_checker = type_checkers.GetTypeChecker( |
- extension_handle.cpp_type, extension_handle.type) |
- type_checker.CheckValue(value) |
- self._extended_message._fields[extension_handle] = value |
+ type_checker = type_checkers.GetTypeChecker(extension_handle) |
+ # pylint: disable=protected-access |
+ self._extended_message._fields[extension_handle] = ( |
+ type_checker.CheckValue(value)) |
self._extended_message._Modified() |
def _FindExtensionByName(self, name): |