Index: third_party/protobuf/python/google/protobuf/internal/python_message.py |
diff --git a/third_party/protobuf/python/google/protobuf/internal/python_message.py b/third_party/protobuf/python/google/protobuf/internal/python_message.py |
index ca9f76753b72562fed5bbbbfd2a810fbfd68d151..4bea57ac6c33616454a02efd845cbc50752b5db2 100644 |
--- a/third_party/protobuf/python/google/protobuf/internal/python_message.py |
+++ b/third_party/protobuf/python/google/protobuf/internal/python_message.py |
@@ -1,6 +1,6 @@ |
# Protocol Buffers - Google's data interchange format |
# Copyright 2008 Google Inc. All rights reserved. |
-# https://developers.google.com/protocol-buffers/ |
+# http://code.google.com/p/protobuf/ |
# |
# Redistribution and use in source and binary forms, with or without |
# modification, are permitted provided that the following conditions are |
@@ -28,10 +28,6 @@ |
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE |
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
-# Keep it Python2.5 compatible for GAE. |
-# |
-# Copyright 2007 Google Inc. All Rights Reserved. |
-# |
# This code is meant to work on Python 2.4 and above only. |
# |
# TODO(robinson): Helpers for verbose, common checks like seeing if a |
@@ -54,18 +50,11 @@ this file*. |
__author__ = 'robinson@google.com (Will Robinson)' |
-import sys |
-if sys.version_info[0] < 3: |
- try: |
- from cStringIO import StringIO as BytesIO |
- except ImportError: |
- from StringIO import StringIO as BytesIO |
- import copy_reg as copyreg |
- _basestring = basestring |
-else: |
- from io import BytesIO |
- import copyreg |
- _basestring = str |
+try: |
+ from cStringIO import StringIO |
+except ImportError: |
+ from StringIO import StringIO |
+import copy_reg |
import struct |
import weakref |
@@ -79,7 +68,6 @@ from google.protobuf.internal import type_checkers |
from google.protobuf.internal import wire_format |
from google.protobuf import descriptor as descriptor_mod |
from google.protobuf import message as message_mod |
-from google.protobuf import symbol_database |
from google.protobuf import text_format |
_FieldDescriptor = descriptor_mod.FieldDescriptor |
@@ -98,21 +86,20 @@ def InitMessage(descriptor, cls): |
if (descriptor.has_options and |
descriptor.GetOptions().message_set_wire_format): |
cls._decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = ( |
- decoder.MessageSetItemDecoder(cls._extensions_by_number), None) |
+ decoder.MessageSetItemDecoder(cls._extensions_by_number)) |
# Attach stuff to each FieldDescriptor for quick lookup later on. |
for field in descriptor.fields: |
_AttachFieldHelpers(cls, field) |
- descriptor._concrete_class = cls # pylint: disable=protected-access |
_AddEnumValues(descriptor, cls) |
_AddInitMethod(descriptor, cls) |
_AddPropertiesForFields(descriptor, cls) |
_AddPropertiesForExtensions(descriptor, cls) |
_AddStaticMethods(cls) |
_AddMessageMethods(descriptor, cls) |
- _AddPrivateHelperMethods(descriptor, cls) |
- copyreg.pickle(cls, lambda obj: (cls, (), obj.__getstate__())) |
+ _AddPrivateHelperMethods(cls) |
+ copy_reg.pickle(cls, lambda obj: (cls, (), obj.__getstate__())) |
# Stateless helpers for GeneratedProtocolMessageType below. |
@@ -189,8 +176,7 @@ def _AddSlots(message_descriptor, dictionary): |
'_is_present_in_parent', |
'_listener', |
'_listener_for_children', |
- '__weakref__', |
- '_oneofs'] |
+ '__weakref__'] |
def _IsMessageSetExtension(field): |
@@ -202,37 +188,12 @@ def _IsMessageSetExtension(field): |
field.label == _FieldDescriptor.LABEL_OPTIONAL) |
-def _IsMapField(field): |
- return (field.type == _FieldDescriptor.TYPE_MESSAGE and |
- field.message_type.has_options and |
- field.message_type.GetOptions().map_entry) |
- |
- |
-def _IsMessageMapField(field): |
- value_type = field.message_type.fields_by_name["value"] |
- return value_type.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE |
- |
- |
def _AttachFieldHelpers(cls, field_descriptor): |
is_repeated = (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED) |
- is_packable = (is_repeated and |
- wire_format.IsTypePackable(field_descriptor.type)) |
- if not is_packable: |
- is_packed = False |
- elif field_descriptor.containing_type.syntax == "proto2": |
- is_packed = (field_descriptor.has_options and |
- field_descriptor.GetOptions().packed) |
- else: |
- has_packed_false = (field_descriptor.has_options and |
- field_descriptor.GetOptions().HasField("packed") and |
- field_descriptor.GetOptions().packed == False) |
- is_packed = not has_packed_false |
- is_map_entry = _IsMapField(field_descriptor) |
- |
- if is_map_entry: |
- field_encoder = encoder.MapEncoder(field_descriptor) |
- sizer = encoder.MapSizer(field_descriptor) |
- elif _IsMessageSetExtension(field_descriptor): |
+ is_packed = (field_descriptor.has_options and |
+ field_descriptor.GetOptions().packed) |
+ |
+ if _IsMessageSetExtension(field_descriptor): |
field_encoder = encoder.MessageSetItemEncoder(field_descriptor.number) |
sizer = encoder.MessageSetItemSizer(field_descriptor.number) |
else: |
@@ -248,27 +209,10 @@ def _AttachFieldHelpers(cls, field_descriptor): |
def AddDecoder(wiretype, is_packed): |
tag_bytes = encoder.TagBytes(field_descriptor.number, wiretype) |
- decode_type = field_descriptor.type |
- if (decode_type == _FieldDescriptor.TYPE_ENUM and |
- type_checkers.SupportsOpenEnums(field_descriptor)): |
- decode_type = _FieldDescriptor.TYPE_INT32 |
- |
- oneof_descriptor = None |
- if field_descriptor.containing_oneof is not None: |
- oneof_descriptor = field_descriptor |
- |
- if is_map_entry: |
- is_message_map = _IsMessageMapField(field_descriptor) |
- |
- field_decoder = decoder.MapDecoder( |
- field_descriptor, _GetInitializeDefaultForMap(field_descriptor), |
- is_message_map) |
- else: |
- field_decoder = type_checkers.TYPE_TO_DECODER[decode_type]( |
- field_descriptor.number, is_repeated, is_packed, |
- field_descriptor, field_descriptor._default_constructor) |
- |
- cls._decoders_by_tag[tag_bytes] = (field_decoder, oneof_descriptor) |
+ cls._decoders_by_tag[tag_bytes] = ( |
+ type_checkers.TYPE_TO_DECODER[field_descriptor.type]( |
+ field_descriptor.number, is_repeated, is_packed, |
+ field_descriptor, field_descriptor._default_constructor)) |
AddDecoder(type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type], |
False) |
@@ -301,26 +245,6 @@ def _AddEnumValues(descriptor, cls): |
setattr(cls, enum_value.name, enum_value.number) |
-def _GetInitializeDefaultForMap(field): |
- if field.label != _FieldDescriptor.LABEL_REPEATED: |
- raise ValueError('map_entry set on non-repeated field %s' % ( |
- field.name)) |
- fields_by_name = field.message_type.fields_by_name |
- key_checker = type_checkers.GetTypeChecker(fields_by_name['key']) |
- |
- value_field = fields_by_name['value'] |
- if _IsMessageMapField(field): |
- def MakeMessageMapDefault(message): |
- return containers.MessageMap( |
- message._listener_for_children, value_field.message_type, key_checker) |
- return MakeMessageMapDefault |
- else: |
- value_checker = type_checkers.GetTypeChecker(value_field) |
- def MakePrimitiveMapDefault(message): |
- return containers.ScalarMap( |
- message._listener_for_children, key_checker, value_checker) |
- return MakePrimitiveMapDefault |
- |
def _DefaultValueConstructorForField(field): |
"""Returns a function which returns a default value for a field. |
@@ -335,9 +259,6 @@ def _DefaultValueConstructorForField(field): |
value may refer back to |message| via a weak reference. |
""" |
- if _IsMapField(field): |
- return _GetInitializeDefaultForMap(field) |
- |
if field.label == _FieldDescriptor.LABEL_REPEATED: |
if field.has_default_value and field.default_value != []: |
raise ValueError('Repeated field default value not empty list: %s' % ( |
@@ -351,7 +272,7 @@ def _DefaultValueConstructorForField(field): |
message._listener_for_children, field.message_type) |
return MakeRepeatedMessageDefault |
else: |
- type_checker = type_checkers.GetTypeChecker(field) |
+ type_checker = type_checkers.GetTypeChecker(field.cpp_type, field.type) |
def MakeRepeatedScalarDefault(message): |
return containers.RepeatedScalarFieldContainer( |
message._listener_for_children, type_checker) |
@@ -363,8 +284,6 @@ def _DefaultValueConstructorForField(field): |
def MakeSubMessageDefault(message): |
result = message_type._concrete_class() |
result._SetListener(message._listener_for_children) |
- if field.containing_oneof: |
- message._UpdateOneofState(field) |
return result |
return MakeSubMessageDefault |
@@ -375,43 +294,13 @@ def _DefaultValueConstructorForField(field): |
return MakeScalarDefault |
-def _ReraiseTypeErrorWithFieldName(message_name, field_name): |
- """Re-raise the currently-handled TypeError with the field name added.""" |
- exc = sys.exc_info()[1] |
- if len(exc.args) == 1 and type(exc) is TypeError: |
- # simple TypeError; add field name to exception message |
- exc = TypeError('%s for field %s.%s' % (str(exc), message_name, field_name)) |
- |
- # re-raise possibly-amended exception with original traceback: |
- raise type(exc)(exc, sys.exc_info()[2]) |
- |
- |
def _AddInitMethod(message_descriptor, cls): |
"""Adds an __init__ method to cls.""" |
- |
- def _GetIntegerEnumValue(enum_type, value): |
- """Convert a string or integer enum value to an integer. |
- |
- If the value is a string, it is converted to the enum value in |
- enum_type with the same name. If the value is not a string, it's |
- returned as-is. (No conversion or bounds-checking is done.) |
- """ |
- if isinstance(value, _basestring): |
- try: |
- return enum_type.values_by_name[value].number |
- except KeyError: |
- raise ValueError('Enum type %s: unknown label "%s"' % ( |
- enum_type.full_name, value)) |
- return value |
- |
+ fields = message_descriptor.fields |
def init(self, **kwargs): |
self._cached_byte_size = 0 |
self._cached_byte_size_dirty = len(kwargs) > 0 |
self._fields = {} |
- # Contains a mapping from oneof field descriptors to the descriptor |
- # of the currently set field in that oneof field. |
- self._oneofs = {} |
- |
# _unknown_fields is () when empty for efficiency, and will be turned into |
# a list if fields are added. |
self._unknown_fields = () |
@@ -426,41 +315,17 @@ def _AddInitMethod(message_descriptor, cls): |
if field.label == _FieldDescriptor.LABEL_REPEATED: |
copy = field._default_constructor(self) |
if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: # Composite |
- if _IsMapField(field): |
- if _IsMessageMapField(field): |
- for key in field_value: |
- copy[key].MergeFrom(field_value[key]) |
- else: |
- copy.update(field_value) |
- else: |
- for val in field_value: |
- if isinstance(val, dict): |
- copy.add(**val) |
- else: |
- copy.add().MergeFrom(val) |
+ for val in field_value: |
+ copy.add().MergeFrom(val) |
else: # Scalar |
- if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM: |
- field_value = [_GetIntegerEnumValue(field.enum_type, val) |
- for val in field_value] |
copy.extend(field_value) |
self._fields[field] = copy |
elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: |
copy = field._default_constructor(self) |
- new_val = field_value |
- if isinstance(field_value, dict): |
- new_val = field.message_type._concrete_class(**field_value) |
- try: |
- copy.MergeFrom(new_val) |
- except TypeError: |
- _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name) |
+ copy.MergeFrom(field_value) |
self._fields[field] = copy |
else: |
- if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM: |
- field_value = _GetIntegerEnumValue(field.enum_type, field_value) |
- try: |
- setattr(self, field_name, field_value) |
- except TypeError: |
- _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name) |
+ setattr(self, field_name, field_value) |
init.__module__ = None |
init.__doc__ = None |
@@ -575,10 +440,9 @@ def _AddPropertiesForNonRepeatedScalarField(field, cls): |
""" |
proto_field_name = field.name |
property_name = _PropertyName(proto_field_name) |
- type_checker = type_checkers.GetTypeChecker(field) |
+ type_checker = type_checkers.GetTypeChecker(field.cpp_type, field.type) |
default_value = field.default_value |
valid_values = set() |
- is_proto3 = field.containing_type.syntax == "proto3" |
def getter(self): |
# TODO(protobuf-team): This may be broken since there may not be |
@@ -586,30 +450,14 @@ def _AddPropertiesForNonRepeatedScalarField(field, cls): |
return self._fields.get(field, default_value) |
getter.__module__ = None |
getter.__doc__ = 'Getter for %s.' % proto_field_name |
- |
- clear_when_set_to_default = is_proto3 and not field.containing_oneof |
- |
- def field_setter(self, new_value): |
- # pylint: disable=protected-access |
- # Testing the value for truthiness captures all of the proto3 defaults |
- # (0, 0.0, enum 0, and False). |
- new_value = type_checker.CheckValue(new_value) |
- if clear_when_set_to_default and not new_value: |
- self._fields.pop(field, None) |
- else: |
- self._fields[field] = new_value |
+ def setter(self, new_value): |
+ type_checker.CheckValue(new_value) |
+ self._fields[field] = new_value |
# Check _cached_byte_size_dirty inline to improve performance, since scalar |
# setters are called frequently. |
if not self._cached_byte_size_dirty: |
self._Modified() |
- if field.containing_oneof: |
- def setter(self, new_value): |
- field_setter(self, new_value) |
- self._UpdateOneofState(field) |
- else: |
- setter = field_setter |
- |
setter.__module__ = None |
setter.__doc__ = 'Setter for %s.' % proto_field_name |
@@ -645,10 +493,7 @@ def _AddPropertiesForNonRepeatedCompositeField(field, cls): |
if field_value is None: |
# Construct a new object to represent this field. |
field_value = message_type._concrete_class() # use field.message_type? |
- field_value._SetListener( |
- _OneofListener(self, field) |
- if field.containing_oneof is not None |
- else self._listener_for_children) |
+ field_value._SetListener(self._listener_for_children) |
# Atomically check if another thread has preempted us and, if not, swap |
# in the new object we just created. If someone has preempted us, we |
@@ -736,48 +581,27 @@ def _AddListFieldsMethod(message_descriptor, cls): |
cls.ListFields = ListFields |
-_Proto3HasError = 'Protocol message has no non-repeated submessage field "%s"' |
-_Proto2HasError = 'Protocol message has no non-repeated field "%s"' |
def _AddHasFieldMethod(message_descriptor, cls): |
"""Helper for _AddMessageMethods().""" |
- is_proto3 = (message_descriptor.syntax == "proto3") |
- error_msg = _Proto3HasError if is_proto3 else _Proto2HasError |
- |
- hassable_fields = {} |
+ singular_fields = {} |
for field in message_descriptor.fields: |
- if field.label == _FieldDescriptor.LABEL_REPEATED: |
- continue |
- # For proto3, only submessages and fields inside a oneof have presence. |
- if (is_proto3 and field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE and |
- not field.containing_oneof): |
- continue |
- hassable_fields[field.name] = field |
- |
- if not is_proto3: |
- # Fields inside oneofs are never repeated (enforced by the compiler). |
- for oneof in message_descriptor.oneofs: |
- hassable_fields[oneof.name] = oneof |
+ if field.label != _FieldDescriptor.LABEL_REPEATED: |
+ singular_fields[field.name] = field |
def HasField(self, field_name): |
try: |
- field = hassable_fields[field_name] |
+ field = singular_fields[field_name] |
except KeyError: |
- raise ValueError(error_msg % field_name) |
+ raise ValueError( |
+ 'Protocol message has no singular "%s" field.' % field_name) |
- if isinstance(field, descriptor_mod.OneofDescriptor): |
- try: |
- return HasField(self, self._oneofs[field].name) |
- except KeyError: |
- return False |
+ if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: |
+ value = self._fields.get(field) |
+ return value is not None and value._is_present_in_parent |
else: |
- if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: |
- value = self._fields.get(field) |
- return value is not None and value._is_present_in_parent |
- else: |
- return field in self._fields |
- |
+ return field in self._fields |
cls.HasField = HasField |
@@ -787,14 +611,7 @@ def _AddClearFieldMethod(message_descriptor, cls): |
try: |
field = message_descriptor.fields_by_name[field_name] |
except KeyError: |
- try: |
- field = message_descriptor.oneofs_by_name[field_name] |
- if field in self._oneofs: |
- field = self._oneofs[field] |
- else: |
- return |
- except KeyError: |
- raise ValueError('Protocol message has no "%s" field.' % field_name) |
+ raise ValueError('Protocol message has no "%s" field.' % field_name) |
if field in self._fields: |
# Note: If the field is a sub-message, its listener will still point |
@@ -802,9 +619,6 @@ def _AddClearFieldMethod(message_descriptor, cls): |
# will call _Modified() and invalidate our byte size. Big deal. |
del self._fields[field] |
- if self._oneofs.get(field.containing_oneof, None) is field: |
- del self._oneofs[field.containing_oneof] |
- |
# Always call _Modified() -- even if nothing was changed, this is |
# a mutating method, and thus calling it should cause the field to become |
# present in the parent message. |
@@ -831,7 +645,6 @@ def _AddClearMethod(message_descriptor, cls): |
# Clear fields. |
self._fields = {} |
self._unknown_fields = () |
- self._oneofs = {} |
self._Modified() |
cls.Clear = Clear |
@@ -850,26 +663,6 @@ def _AddHasExtensionMethod(cls): |
return extension_handle in self._fields |
cls.HasExtension = HasExtension |
-def _UnpackAny(msg): |
- type_url = msg.type_url |
- db = symbol_database.Default() |
- |
- if not type_url: |
- return None |
- |
- # TODO(haberman): For now we just strip the hostname. Better logic will be |
- # required. |
- type_name = type_url.split("/")[-1] |
- descriptor = db.pool.FindMessageTypeByName(type_name) |
- |
- if descriptor is None: |
- return None |
- |
- message_class = db.GetPrototype(descriptor) |
- message = message_class() |
- |
- message.ParseFromString(msg.value) |
- return message |
def _AddEqualsMethod(message_descriptor, cls): |
"""Helper for _AddMessageMethods().""" |
@@ -881,12 +674,6 @@ def _AddEqualsMethod(message_descriptor, cls): |
if self is other: |
return True |
- if self.DESCRIPTOR.full_name == "google.protobuf.Any": |
- any_a = _UnpackAny(self) |
- any_b = _UnpackAny(other) |
- if any_a and any_b: |
- return any_a == any_b |
- |
if not self.ListFields() == other.ListFields(): |
return False |
@@ -986,7 +773,7 @@ def _AddSerializePartialToStringMethod(message_descriptor, cls): |
"""Helper for _AddMessageMethods().""" |
def SerializePartialToString(self): |
- out = BytesIO() |
+ out = StringIO() |
self._InternalSerialize(out.write) |
return out.getvalue() |
cls.SerializePartialToString = SerializePartialToString |
@@ -1009,10 +796,9 @@ def _AddMergeFromStringMethod(message_descriptor, cls): |
# The only reason _InternalParse would return early is if it |
# encountered an end-group tag. |
raise message_mod.DecodeError('Unexpected end-group tag.') |
- except (IndexError, TypeError): |
- # Now ord(buf[p:p+1]) == ord('') gets TypeError. |
+ except IndexError: |
raise message_mod.DecodeError('Truncated message.') |
- except struct.error as e: |
+ except struct.error, e: |
raise message_mod.DecodeError(e) |
return length # Return this for legacy reasons. |
cls.MergeFromString = MergeFromString |
@@ -1020,7 +806,6 @@ def _AddMergeFromStringMethod(message_descriptor, cls): |
local_ReadTag = decoder.ReadTag |
local_SkipField = decoder.SkipField |
decoders_by_tag = cls._decoders_by_tag |
- is_proto3 = message_descriptor.syntax == "proto3" |
def InternalParse(self, buffer, pos, end): |
self._Modified() |
@@ -1028,22 +813,18 @@ def _AddMergeFromStringMethod(message_descriptor, cls): |
unknown_field_list = self._unknown_fields |
while pos != end: |
(tag_bytes, new_pos) = local_ReadTag(buffer, pos) |
- field_decoder, field_desc = decoders_by_tag.get(tag_bytes, (None, None)) |
+ field_decoder = decoders_by_tag.get(tag_bytes) |
if field_decoder is None: |
value_start_pos = new_pos |
new_pos = local_SkipField(buffer, new_pos, end, tag_bytes) |
if new_pos == -1: |
return pos |
- if not is_proto3: |
- if not unknown_field_list: |
- unknown_field_list = self._unknown_fields = [] |
- unknown_field_list.append( |
- (tag_bytes, buffer[value_start_pos:new_pos])) |
+ if not unknown_field_list: |
+ unknown_field_list = self._unknown_fields = [] |
+ unknown_field_list.append((tag_bytes, buffer[value_start_pos:new_pos])) |
pos = new_pos |
else: |
pos = field_decoder(buffer, new_pos, end, self, field_dict) |
- if field_desc: |
- self._UpdateOneofState(field_desc) |
return pos |
cls._InternalParse = InternalParse |
@@ -1076,12 +857,9 @@ def _AddIsInitializedMethod(message_descriptor, cls): |
errors.extend(self.FindInitializationErrors()) |
return False |
- for field, value in list(self._fields.items()): # dict can change size! |
+ for field, value in self._fields.iteritems(): |
if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: |
if field.label == _FieldDescriptor.LABEL_REPEATED: |
- if (field.message_type.has_options and |
- field.message_type.GetOptions().map_entry): |
- continue |
for element in value: |
if not element.IsInitialized(): |
if errors is not None: |
@@ -1117,26 +895,16 @@ def _AddIsInitializedMethod(message_descriptor, cls): |
else: |
name = field.name |
- if _IsMapField(field): |
- if _IsMessageMapField(field): |
- for key in value: |
- element = value[key] |
- prefix = "%s[%d]." % (name, key) |
- sub_errors = element.FindInitializationErrors() |
- errors += [prefix + error for error in sub_errors] |
- else: |
- # ScalarMaps can't have any initialization errors. |
- pass |
- elif field.label == _FieldDescriptor.LABEL_REPEATED: |
+ if field.label == _FieldDescriptor.LABEL_REPEATED: |
for i in xrange(len(value)): |
element = value[i] |
prefix = "%s[%d]." % (name, i) |
sub_errors = element.FindInitializationErrors() |
- errors += [prefix + error for error in sub_errors] |
+ errors += [ prefix + error for error in sub_errors ] |
else: |
prefix = name + "." |
sub_errors = value.FindInitializationErrors() |
- errors += [prefix + error for error in sub_errors] |
+ errors += [ prefix + error for error in sub_errors ] |
return errors |
@@ -1173,13 +941,9 @@ def _AddMergeFromMethod(cls): |
# Construct a new object to represent this field. |
field_value = field._default_constructor(self) |
fields[field] = field_value |
- if field.containing_oneof: |
- self._UpdateOneofState(field) |
field_value.MergeFrom(value) |
else: |
self._fields[field] = value |
- if field.containing_oneof: |
- self._UpdateOneofState(field) |
if msg._unknown_fields: |
if not self._unknown_fields: |
@@ -1189,24 +953,6 @@ def _AddMergeFromMethod(cls): |
cls.MergeFrom = MergeFrom |
-def _AddWhichOneofMethod(message_descriptor, cls): |
- def WhichOneof(self, oneof_name): |
- """Returns the name of the currently set field inside a oneof, or None.""" |
- try: |
- field = message_descriptor.oneofs_by_name[oneof_name] |
- except KeyError: |
- raise ValueError( |
- 'Protocol message has no oneof "%s" field.' % oneof_name) |
- |
- nested_field = self._oneofs.get(field, None) |
- if nested_field is not None and self.HasField(nested_field.name): |
- return nested_field.name |
- else: |
- return None |
- |
- cls.WhichOneof = WhichOneof |
- |
- |
def _AddMessageMethods(message_descriptor, cls): |
"""Adds implementations of all Message methods to cls.""" |
_AddListFieldsMethod(message_descriptor, cls) |
@@ -1226,9 +972,9 @@ def _AddMessageMethods(message_descriptor, cls): |
_AddMergeFromStringMethod(message_descriptor, cls) |
_AddIsInitializedMethod(message_descriptor, cls) |
_AddMergeFromMethod(cls) |
- _AddWhichOneofMethod(message_descriptor, cls) |
-def _AddPrivateHelperMethods(message_descriptor, cls): |
+ |
+def _AddPrivateHelperMethods(cls): |
"""Adds implementation of private helper methods to cls.""" |
def Modified(self): |
@@ -1246,20 +992,8 @@ def _AddPrivateHelperMethods(message_descriptor, cls): |
self._is_present_in_parent = True |
self._listener.Modified() |
- def _UpdateOneofState(self, field): |
- """Sets field as the active field in its containing oneof. |
- |
- Will also delete currently active field in the oneof, if it is different |
- from the argument. Does not mark the message as modified. |
- """ |
- other_field = self._oneofs.setdefault(field.containing_oneof, field) |
- if other_field is not field: |
- del self._fields[other_field] |
- self._oneofs[field.containing_oneof] = field |
- |
cls._Modified = Modified |
cls.SetInParent = Modified |
- cls._UpdateOneofState = _UpdateOneofState |
class _Listener(object): |
@@ -1308,27 +1042,6 @@ class _Listener(object): |
pass |
-class _OneofListener(_Listener): |
- """Special listener implementation for setting composite oneof fields.""" |
- |
- def __init__(self, parent_message, field): |
- """Args: |
- parent_message: The message whose _Modified() method we should call when |
- we receive Modified() messages. |
- field: The descriptor of the field being set in the parent message. |
- """ |
- super(_OneofListener, self).__init__(parent_message) |
- self._field = field |
- |
- def Modified(self): |
- """Also updates the state of the containing oneof in the parent message.""" |
- try: |
- self._parent_message_weakref._UpdateOneofState(self._field) |
- super(_OneofListener, self).Modified() |
- except ReferenceError: |
- pass |
- |
- |
# TODO(robinson): Move elsewhere? This file is getting pretty ridiculous... |
# TODO(robinson): Unify error handling of "unknown extension" crap. |
# TODO(robinson): Support iteritems()-style iteration over all |
@@ -1419,10 +1132,10 @@ class _ExtensionDict(object): |
# It's slightly wasteful to lookup the type checker each time, |
# but we expect this to be a vanishingly uncommon case anyway. |
- type_checker = type_checkers.GetTypeChecker(extension_handle) |
- # pylint: disable=protected-access |
- self._extended_message._fields[extension_handle] = ( |
- type_checker.CheckValue(value)) |
+ type_checker = type_checkers.GetTypeChecker( |
+ extension_handle.cpp_type, extension_handle.type) |
+ type_checker.CheckValue(value) |
+ self._extended_message._fields[extension_handle] = value |
self._extended_message._Modified() |
def _FindExtensionByName(self, name): |