Chromium Code Reviews
chromiumcodereview-hr@appspot.gserviceaccount.com (chromiumcodereview-hr) | Please choose your nickname with Settings | Help | Chromium Project | Gerrit Changes | Sign out
(10)

Unified Diff: third_party/protobuf/python/google/protobuf/internal/python_message.py

Issue 1291903002: Pull new version of protobuf sources. (Closed) Base URL: https://chromium.googlesource.com/chromium/src.git@master
Patch Set: Build fix attempts Created 5 years, 4 months ago
Use n/p to move between diff chunks; N/P to move between comments. Draft comments are only viewable by you.
Jump to:
View side-by-side diff with in-line comments
Download patch
Index: third_party/protobuf/python/google/protobuf/internal/python_message.py
diff --git a/third_party/protobuf/python/google/protobuf/internal/python_message.py b/third_party/protobuf/python/google/protobuf/internal/python_message.py
index 4bea57ac6c33616454a02efd845cbc50752b5db2..ca9f76753b72562fed5bbbbfd2a810fbfd68d151 100644
--- a/third_party/protobuf/python/google/protobuf/internal/python_message.py
+++ b/third_party/protobuf/python/google/protobuf/internal/python_message.py
@@ -1,6 +1,6 @@
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
-# http://code.google.com/p/protobuf/
+# https://developers.google.com/protocol-buffers/
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
@@ -28,6 +28,10 @@
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+# Keep it Python2.5 compatible for GAE.
+#
+# Copyright 2007 Google Inc. All Rights Reserved.
+#
# This code is meant to work on Python 2.4 and above only.
#
# TODO(robinson): Helpers for verbose, common checks like seeing if a
@@ -50,11 +54,18 @@ this file*.
__author__ = 'robinson@google.com (Will Robinson)'
-try:
- from cStringIO import StringIO
-except ImportError:
- from StringIO import StringIO
-import copy_reg
+import sys
+if sys.version_info[0] < 3:
+ try:
+ from cStringIO import StringIO as BytesIO
+ except ImportError:
+ from StringIO import StringIO as BytesIO
+ import copy_reg as copyreg
+ _basestring = basestring
+else:
+ from io import BytesIO
+ import copyreg
+ _basestring = str
import struct
import weakref
@@ -68,6 +79,7 @@ from google.protobuf.internal import type_checkers
from google.protobuf.internal import wire_format
from google.protobuf import descriptor as descriptor_mod
from google.protobuf import message as message_mod
+from google.protobuf import symbol_database
from google.protobuf import text_format
_FieldDescriptor = descriptor_mod.FieldDescriptor
@@ -86,20 +98,21 @@ def InitMessage(descriptor, cls):
if (descriptor.has_options and
descriptor.GetOptions().message_set_wire_format):
cls._decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = (
- decoder.MessageSetItemDecoder(cls._extensions_by_number))
+ decoder.MessageSetItemDecoder(cls._extensions_by_number), None)
# Attach stuff to each FieldDescriptor for quick lookup later on.
for field in descriptor.fields:
_AttachFieldHelpers(cls, field)
+ descriptor._concrete_class = cls # pylint: disable=protected-access
_AddEnumValues(descriptor, cls)
_AddInitMethod(descriptor, cls)
_AddPropertiesForFields(descriptor, cls)
_AddPropertiesForExtensions(descriptor, cls)
_AddStaticMethods(cls)
_AddMessageMethods(descriptor, cls)
- _AddPrivateHelperMethods(cls)
- copy_reg.pickle(cls, lambda obj: (cls, (), obj.__getstate__()))
+ _AddPrivateHelperMethods(descriptor, cls)
+ copyreg.pickle(cls, lambda obj: (cls, (), obj.__getstate__()))
# Stateless helpers for GeneratedProtocolMessageType below.
@@ -176,7 +189,8 @@ def _AddSlots(message_descriptor, dictionary):
'_is_present_in_parent',
'_listener',
'_listener_for_children',
- '__weakref__']
+ '__weakref__',
+ '_oneofs']
def _IsMessageSetExtension(field):
@@ -188,12 +202,37 @@ def _IsMessageSetExtension(field):
field.label == _FieldDescriptor.LABEL_OPTIONAL)
+def _IsMapField(field):
+ return (field.type == _FieldDescriptor.TYPE_MESSAGE and
+ field.message_type.has_options and
+ field.message_type.GetOptions().map_entry)
+
+
+def _IsMessageMapField(field):
+ value_type = field.message_type.fields_by_name["value"]
+ return value_type.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE
+
+
def _AttachFieldHelpers(cls, field_descriptor):
is_repeated = (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED)
- is_packed = (field_descriptor.has_options and
- field_descriptor.GetOptions().packed)
-
- if _IsMessageSetExtension(field_descriptor):
+ is_packable = (is_repeated and
+ wire_format.IsTypePackable(field_descriptor.type))
+ if not is_packable:
+ is_packed = False
+ elif field_descriptor.containing_type.syntax == "proto2":
+ is_packed = (field_descriptor.has_options and
+ field_descriptor.GetOptions().packed)
+ else:
+ has_packed_false = (field_descriptor.has_options and
+ field_descriptor.GetOptions().HasField("packed") and
+ field_descriptor.GetOptions().packed == False)
+ is_packed = not has_packed_false
+ is_map_entry = _IsMapField(field_descriptor)
+
+ if is_map_entry:
+ field_encoder = encoder.MapEncoder(field_descriptor)
+ sizer = encoder.MapSizer(field_descriptor)
+ elif _IsMessageSetExtension(field_descriptor):
field_encoder = encoder.MessageSetItemEncoder(field_descriptor.number)
sizer = encoder.MessageSetItemSizer(field_descriptor.number)
else:
@@ -209,10 +248,27 @@ def _AttachFieldHelpers(cls, field_descriptor):
def AddDecoder(wiretype, is_packed):
tag_bytes = encoder.TagBytes(field_descriptor.number, wiretype)
- cls._decoders_by_tag[tag_bytes] = (
- type_checkers.TYPE_TO_DECODER[field_descriptor.type](
- field_descriptor.number, is_repeated, is_packed,
- field_descriptor, field_descriptor._default_constructor))
+ decode_type = field_descriptor.type
+ if (decode_type == _FieldDescriptor.TYPE_ENUM and
+ type_checkers.SupportsOpenEnums(field_descriptor)):
+ decode_type = _FieldDescriptor.TYPE_INT32
+
+ oneof_descriptor = None
+ if field_descriptor.containing_oneof is not None:
+ oneof_descriptor = field_descriptor
+
+ if is_map_entry:
+ is_message_map = _IsMessageMapField(field_descriptor)
+
+ field_decoder = decoder.MapDecoder(
+ field_descriptor, _GetInitializeDefaultForMap(field_descriptor),
+ is_message_map)
+ else:
+ field_decoder = type_checkers.TYPE_TO_DECODER[decode_type](
+ field_descriptor.number, is_repeated, is_packed,
+ field_descriptor, field_descriptor._default_constructor)
+
+ cls._decoders_by_tag[tag_bytes] = (field_decoder, oneof_descriptor)
AddDecoder(type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type],
False)
@@ -245,6 +301,26 @@ def _AddEnumValues(descriptor, cls):
setattr(cls, enum_value.name, enum_value.number)
+def _GetInitializeDefaultForMap(field):
+ if field.label != _FieldDescriptor.LABEL_REPEATED:
+ raise ValueError('map_entry set on non-repeated field %s' % (
+ field.name))
+ fields_by_name = field.message_type.fields_by_name
+ key_checker = type_checkers.GetTypeChecker(fields_by_name['key'])
+
+ value_field = fields_by_name['value']
+ if _IsMessageMapField(field):
+ def MakeMessageMapDefault(message):
+ return containers.MessageMap(
+ message._listener_for_children, value_field.message_type, key_checker)
+ return MakeMessageMapDefault
+ else:
+ value_checker = type_checkers.GetTypeChecker(value_field)
+ def MakePrimitiveMapDefault(message):
+ return containers.ScalarMap(
+ message._listener_for_children, key_checker, value_checker)
+ return MakePrimitiveMapDefault
+
def _DefaultValueConstructorForField(field):
"""Returns a function which returns a default value for a field.
@@ -259,6 +335,9 @@ def _DefaultValueConstructorForField(field):
value may refer back to |message| via a weak reference.
"""
+ if _IsMapField(field):
+ return _GetInitializeDefaultForMap(field)
+
if field.label == _FieldDescriptor.LABEL_REPEATED:
if field.has_default_value and field.default_value != []:
raise ValueError('Repeated field default value not empty list: %s' % (
@@ -272,7 +351,7 @@ def _DefaultValueConstructorForField(field):
message._listener_for_children, field.message_type)
return MakeRepeatedMessageDefault
else:
- type_checker = type_checkers.GetTypeChecker(field.cpp_type, field.type)
+ type_checker = type_checkers.GetTypeChecker(field)
def MakeRepeatedScalarDefault(message):
return containers.RepeatedScalarFieldContainer(
message._listener_for_children, type_checker)
@@ -284,6 +363,8 @@ def _DefaultValueConstructorForField(field):
def MakeSubMessageDefault(message):
result = message_type._concrete_class()
result._SetListener(message._listener_for_children)
+ if field.containing_oneof:
+ message._UpdateOneofState(field)
return result
return MakeSubMessageDefault
@@ -294,13 +375,43 @@ def _DefaultValueConstructorForField(field):
return MakeScalarDefault
+def _ReraiseTypeErrorWithFieldName(message_name, field_name):
+ """Re-raise the currently-handled TypeError with the field name added."""
+ exc = sys.exc_info()[1]
+ if len(exc.args) == 1 and type(exc) is TypeError:
+ # simple TypeError; add field name to exception message
+ exc = TypeError('%s for field %s.%s' % (str(exc), message_name, field_name))
+
+ # re-raise possibly-amended exception with original traceback:
+ raise type(exc)(exc, sys.exc_info()[2])
+
+
def _AddInitMethod(message_descriptor, cls):
"""Adds an __init__ method to cls."""
- fields = message_descriptor.fields
+
+ def _GetIntegerEnumValue(enum_type, value):
+ """Convert a string or integer enum value to an integer.
+
+ If the value is a string, it is converted to the enum value in
+ enum_type with the same name. If the value is not a string, it's
+ returned as-is. (No conversion or bounds-checking is done.)
+ """
+ if isinstance(value, _basestring):
+ try:
+ return enum_type.values_by_name[value].number
+ except KeyError:
+ raise ValueError('Enum type %s: unknown label "%s"' % (
+ enum_type.full_name, value))
+ return value
+
def init(self, **kwargs):
self._cached_byte_size = 0
self._cached_byte_size_dirty = len(kwargs) > 0
self._fields = {}
+ # Contains a mapping from oneof field descriptors to the descriptor
+ # of the currently set field in that oneof field.
+ self._oneofs = {}
+
# _unknown_fields is () when empty for efficiency, and will be turned into
# a list if fields are added.
self._unknown_fields = ()
@@ -315,17 +426,41 @@ def _AddInitMethod(message_descriptor, cls):
if field.label == _FieldDescriptor.LABEL_REPEATED:
copy = field._default_constructor(self)
if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: # Composite
- for val in field_value:
- copy.add().MergeFrom(val)
+ if _IsMapField(field):
+ if _IsMessageMapField(field):
+ for key in field_value:
+ copy[key].MergeFrom(field_value[key])
+ else:
+ copy.update(field_value)
+ else:
+ for val in field_value:
+ if isinstance(val, dict):
+ copy.add(**val)
+ else:
+ copy.add().MergeFrom(val)
else: # Scalar
+ if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
+ field_value = [_GetIntegerEnumValue(field.enum_type, val)
+ for val in field_value]
copy.extend(field_value)
self._fields[field] = copy
elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
copy = field._default_constructor(self)
- copy.MergeFrom(field_value)
+ new_val = field_value
+ if isinstance(field_value, dict):
+ new_val = field.message_type._concrete_class(**field_value)
+ try:
+ copy.MergeFrom(new_val)
+ except TypeError:
+ _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name)
self._fields[field] = copy
else:
- setattr(self, field_name, field_value)
+ if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
+ field_value = _GetIntegerEnumValue(field.enum_type, field_value)
+ try:
+ setattr(self, field_name, field_value)
+ except TypeError:
+ _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name)
init.__module__ = None
init.__doc__ = None
@@ -440,9 +575,10 @@ def _AddPropertiesForNonRepeatedScalarField(field, cls):
"""
proto_field_name = field.name
property_name = _PropertyName(proto_field_name)
- type_checker = type_checkers.GetTypeChecker(field.cpp_type, field.type)
+ type_checker = type_checkers.GetTypeChecker(field)
default_value = field.default_value
valid_values = set()
+ is_proto3 = field.containing_type.syntax == "proto3"
def getter(self):
# TODO(protobuf-team): This may be broken since there may not be
@@ -450,14 +586,30 @@ def _AddPropertiesForNonRepeatedScalarField(field, cls):
return self._fields.get(field, default_value)
getter.__module__ = None
getter.__doc__ = 'Getter for %s.' % proto_field_name
- def setter(self, new_value):
- type_checker.CheckValue(new_value)
- self._fields[field] = new_value
+
+ clear_when_set_to_default = is_proto3 and not field.containing_oneof
+
+ def field_setter(self, new_value):
+ # pylint: disable=protected-access
+ # Testing the value for truthiness captures all of the proto3 defaults
+ # (0, 0.0, enum 0, and False).
+ new_value = type_checker.CheckValue(new_value)
+ if clear_when_set_to_default and not new_value:
+ self._fields.pop(field, None)
+ else:
+ self._fields[field] = new_value
# Check _cached_byte_size_dirty inline to improve performance, since scalar
# setters are called frequently.
if not self._cached_byte_size_dirty:
self._Modified()
+ if field.containing_oneof:
+ def setter(self, new_value):
+ field_setter(self, new_value)
+ self._UpdateOneofState(field)
+ else:
+ setter = field_setter
+
setter.__module__ = None
setter.__doc__ = 'Setter for %s.' % proto_field_name
@@ -493,7 +645,10 @@ def _AddPropertiesForNonRepeatedCompositeField(field, cls):
if field_value is None:
# Construct a new object to represent this field.
field_value = message_type._concrete_class() # use field.message_type?
- field_value._SetListener(self._listener_for_children)
+ field_value._SetListener(
+ _OneofListener(self, field)
+ if field.containing_oneof is not None
+ else self._listener_for_children)
# Atomically check if another thread has preempted us and, if not, swap
# in the new object we just created. If someone has preempted us, we
@@ -581,27 +736,48 @@ def _AddListFieldsMethod(message_descriptor, cls):
cls.ListFields = ListFields
+_Proto3HasError = 'Protocol message has no non-repeated submessage field "%s"'
+_Proto2HasError = 'Protocol message has no non-repeated field "%s"'
def _AddHasFieldMethod(message_descriptor, cls):
"""Helper for _AddMessageMethods()."""
- singular_fields = {}
+ is_proto3 = (message_descriptor.syntax == "proto3")
+ error_msg = _Proto3HasError if is_proto3 else _Proto2HasError
+
+ hassable_fields = {}
for field in message_descriptor.fields:
- if field.label != _FieldDescriptor.LABEL_REPEATED:
- singular_fields[field.name] = field
+ if field.label == _FieldDescriptor.LABEL_REPEATED:
+ continue
+ # For proto3, only submessages and fields inside a oneof have presence.
+ if (is_proto3 and field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE and
+ not field.containing_oneof):
+ continue
+ hassable_fields[field.name] = field
+
+ if not is_proto3:
+ # Fields inside oneofs are never repeated (enforced by the compiler).
+ for oneof in message_descriptor.oneofs:
+ hassable_fields[oneof.name] = oneof
def HasField(self, field_name):
try:
- field = singular_fields[field_name]
+ field = hassable_fields[field_name]
except KeyError:
- raise ValueError(
- 'Protocol message has no singular "%s" field.' % field_name)
+ raise ValueError(error_msg % field_name)
- if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
- value = self._fields.get(field)
- return value is not None and value._is_present_in_parent
+ if isinstance(field, descriptor_mod.OneofDescriptor):
+ try:
+ return HasField(self, self._oneofs[field].name)
+ except KeyError:
+ return False
else:
- return field in self._fields
+ if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
+ value = self._fields.get(field)
+ return value is not None and value._is_present_in_parent
+ else:
+ return field in self._fields
+
cls.HasField = HasField
@@ -611,7 +787,14 @@ def _AddClearFieldMethod(message_descriptor, cls):
try:
field = message_descriptor.fields_by_name[field_name]
except KeyError:
- raise ValueError('Protocol message has no "%s" field.' % field_name)
+ try:
+ field = message_descriptor.oneofs_by_name[field_name]
+ if field in self._oneofs:
+ field = self._oneofs[field]
+ else:
+ return
+ except KeyError:
+ raise ValueError('Protocol message has no "%s" field.' % field_name)
if field in self._fields:
# Note: If the field is a sub-message, its listener will still point
@@ -619,6 +802,9 @@ def _AddClearFieldMethod(message_descriptor, cls):
# will call _Modified() and invalidate our byte size. Big deal.
del self._fields[field]
+ if self._oneofs.get(field.containing_oneof, None) is field:
+ del self._oneofs[field.containing_oneof]
+
# Always call _Modified() -- even if nothing was changed, this is
# a mutating method, and thus calling it should cause the field to become
# present in the parent message.
@@ -645,6 +831,7 @@ def _AddClearMethod(message_descriptor, cls):
# Clear fields.
self._fields = {}
self._unknown_fields = ()
+ self._oneofs = {}
self._Modified()
cls.Clear = Clear
@@ -663,6 +850,26 @@ def _AddHasExtensionMethod(cls):
return extension_handle in self._fields
cls.HasExtension = HasExtension
+def _UnpackAny(msg):
+ type_url = msg.type_url
+ db = symbol_database.Default()
+
+ if not type_url:
+ return None
+
+ # TODO(haberman): For now we just strip the hostname. Better logic will be
+ # required.
+ type_name = type_url.split("/")[-1]
+ descriptor = db.pool.FindMessageTypeByName(type_name)
+
+ if descriptor is None:
+ return None
+
+ message_class = db.GetPrototype(descriptor)
+ message = message_class()
+
+ message.ParseFromString(msg.value)
+ return message
def _AddEqualsMethod(message_descriptor, cls):
"""Helper for _AddMessageMethods()."""
@@ -674,6 +881,12 @@ def _AddEqualsMethod(message_descriptor, cls):
if self is other:
return True
+ if self.DESCRIPTOR.full_name == "google.protobuf.Any":
+ any_a = _UnpackAny(self)
+ any_b = _UnpackAny(other)
+ if any_a and any_b:
+ return any_a == any_b
+
if not self.ListFields() == other.ListFields():
return False
@@ -773,7 +986,7 @@ def _AddSerializePartialToStringMethod(message_descriptor, cls):
"""Helper for _AddMessageMethods()."""
def SerializePartialToString(self):
- out = StringIO()
+ out = BytesIO()
self._InternalSerialize(out.write)
return out.getvalue()
cls.SerializePartialToString = SerializePartialToString
@@ -796,9 +1009,10 @@ def _AddMergeFromStringMethod(message_descriptor, cls):
# The only reason _InternalParse would return early is if it
# encountered an end-group tag.
raise message_mod.DecodeError('Unexpected end-group tag.')
- except IndexError:
+ except (IndexError, TypeError):
+ # Now ord(buf[p:p+1]) == ord('') gets TypeError.
raise message_mod.DecodeError('Truncated message.')
- except struct.error, e:
+ except struct.error as e:
raise message_mod.DecodeError(e)
return length # Return this for legacy reasons.
cls.MergeFromString = MergeFromString
@@ -806,6 +1020,7 @@ def _AddMergeFromStringMethod(message_descriptor, cls):
local_ReadTag = decoder.ReadTag
local_SkipField = decoder.SkipField
decoders_by_tag = cls._decoders_by_tag
+ is_proto3 = message_descriptor.syntax == "proto3"
def InternalParse(self, buffer, pos, end):
self._Modified()
@@ -813,18 +1028,22 @@ def _AddMergeFromStringMethod(message_descriptor, cls):
unknown_field_list = self._unknown_fields
while pos != end:
(tag_bytes, new_pos) = local_ReadTag(buffer, pos)
- field_decoder = decoders_by_tag.get(tag_bytes)
+ field_decoder, field_desc = decoders_by_tag.get(tag_bytes, (None, None))
if field_decoder is None:
value_start_pos = new_pos
new_pos = local_SkipField(buffer, new_pos, end, tag_bytes)
if new_pos == -1:
return pos
- if not unknown_field_list:
- unknown_field_list = self._unknown_fields = []
- unknown_field_list.append((tag_bytes, buffer[value_start_pos:new_pos]))
+ if not is_proto3:
+ if not unknown_field_list:
+ unknown_field_list = self._unknown_fields = []
+ unknown_field_list.append(
+ (tag_bytes, buffer[value_start_pos:new_pos]))
pos = new_pos
else:
pos = field_decoder(buffer, new_pos, end, self, field_dict)
+ if field_desc:
+ self._UpdateOneofState(field_desc)
return pos
cls._InternalParse = InternalParse
@@ -857,9 +1076,12 @@ def _AddIsInitializedMethod(message_descriptor, cls):
errors.extend(self.FindInitializationErrors())
return False
- for field, value in self._fields.iteritems():
+ for field, value in list(self._fields.items()): # dict can change size!
if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
if field.label == _FieldDescriptor.LABEL_REPEATED:
+ if (field.message_type.has_options and
+ field.message_type.GetOptions().map_entry):
+ continue
for element in value:
if not element.IsInitialized():
if errors is not None:
@@ -895,16 +1117,26 @@ def _AddIsInitializedMethod(message_descriptor, cls):
else:
name = field.name
- if field.label == _FieldDescriptor.LABEL_REPEATED:
+ if _IsMapField(field):
+ if _IsMessageMapField(field):
+ for key in value:
+ element = value[key]
+ prefix = "%s[%d]." % (name, key)
+ sub_errors = element.FindInitializationErrors()
+ errors += [prefix + error for error in sub_errors]
+ else:
+ # ScalarMaps can't have any initialization errors.
+ pass
+ elif field.label == _FieldDescriptor.LABEL_REPEATED:
for i in xrange(len(value)):
element = value[i]
prefix = "%s[%d]." % (name, i)
sub_errors = element.FindInitializationErrors()
- errors += [ prefix + error for error in sub_errors ]
+ errors += [prefix + error for error in sub_errors]
else:
prefix = name + "."
sub_errors = value.FindInitializationErrors()
- errors += [ prefix + error for error in sub_errors ]
+ errors += [prefix + error for error in sub_errors]
return errors
@@ -941,9 +1173,13 @@ def _AddMergeFromMethod(cls):
# Construct a new object to represent this field.
field_value = field._default_constructor(self)
fields[field] = field_value
+ if field.containing_oneof:
+ self._UpdateOneofState(field)
field_value.MergeFrom(value)
else:
self._fields[field] = value
+ if field.containing_oneof:
+ self._UpdateOneofState(field)
if msg._unknown_fields:
if not self._unknown_fields:
@@ -953,6 +1189,24 @@ def _AddMergeFromMethod(cls):
cls.MergeFrom = MergeFrom
+def _AddWhichOneofMethod(message_descriptor, cls):
+ def WhichOneof(self, oneof_name):
+ """Returns the name of the currently set field inside a oneof, or None."""
+ try:
+ field = message_descriptor.oneofs_by_name[oneof_name]
+ except KeyError:
+ raise ValueError(
+ 'Protocol message has no oneof "%s" field.' % oneof_name)
+
+ nested_field = self._oneofs.get(field, None)
+ if nested_field is not None and self.HasField(nested_field.name):
+ return nested_field.name
+ else:
+ return None
+
+ cls.WhichOneof = WhichOneof
+
+
def _AddMessageMethods(message_descriptor, cls):
"""Adds implementations of all Message methods to cls."""
_AddListFieldsMethod(message_descriptor, cls)
@@ -972,9 +1226,9 @@ def _AddMessageMethods(message_descriptor, cls):
_AddMergeFromStringMethod(message_descriptor, cls)
_AddIsInitializedMethod(message_descriptor, cls)
_AddMergeFromMethod(cls)
+ _AddWhichOneofMethod(message_descriptor, cls)
-
-def _AddPrivateHelperMethods(cls):
+def _AddPrivateHelperMethods(message_descriptor, cls):
"""Adds implementation of private helper methods to cls."""
def Modified(self):
@@ -992,8 +1246,20 @@ def _AddPrivateHelperMethods(cls):
self._is_present_in_parent = True
self._listener.Modified()
+ def _UpdateOneofState(self, field):
+ """Sets field as the active field in its containing oneof.
+
+ Will also delete currently active field in the oneof, if it is different
+ from the argument. Does not mark the message as modified.
+ """
+ other_field = self._oneofs.setdefault(field.containing_oneof, field)
+ if other_field is not field:
+ del self._fields[other_field]
+ self._oneofs[field.containing_oneof] = field
+
cls._Modified = Modified
cls.SetInParent = Modified
+ cls._UpdateOneofState = _UpdateOneofState
class _Listener(object):
@@ -1042,6 +1308,27 @@ class _Listener(object):
pass
+class _OneofListener(_Listener):
+ """Special listener implementation for setting composite oneof fields."""
+
+ def __init__(self, parent_message, field):
+ """Args:
+ parent_message: The message whose _Modified() method we should call when
+ we receive Modified() messages.
+ field: The descriptor of the field being set in the parent message.
+ """
+ super(_OneofListener, self).__init__(parent_message)
+ self._field = field
+
+ def Modified(self):
+ """Also updates the state of the containing oneof in the parent message."""
+ try:
+ self._parent_message_weakref._UpdateOneofState(self._field)
+ super(_OneofListener, self).Modified()
+ except ReferenceError:
+ pass
+
+
# TODO(robinson): Move elsewhere? This file is getting pretty ridiculous...
# TODO(robinson): Unify error handling of "unknown extension" crap.
# TODO(robinson): Support iteritems()-style iteration over all
@@ -1132,10 +1419,10 @@ class _ExtensionDict(object):
# It's slightly wasteful to lookup the type checker each time,
# but we expect this to be a vanishingly uncommon case anyway.
- type_checker = type_checkers.GetTypeChecker(
- extension_handle.cpp_type, extension_handle.type)
- type_checker.CheckValue(value)
- self._extended_message._fields[extension_handle] = value
+ type_checker = type_checkers.GetTypeChecker(extension_handle)
+ # pylint: disable=protected-access
+ self._extended_message._fields[extension_handle] = (
+ type_checker.CheckValue(value))
self._extended_message._Modified()
def _FindExtensionByName(self, name):

Powered by Google App Engine
This is Rietveld 408576698