Index: third_party/protobuf/python/google/protobuf/internal/python_message.py |
diff --git a/third_party/protobuf/python/google/protobuf/internal/python_message.py b/third_party/protobuf/python/google/protobuf/internal/python_message.py |
old mode 100644 |
new mode 100755 |
index 4bea57ac6c33616454a02efd845cbc50752b5db2..87f60666ab64012b8aeb51f7f1a7c76a916d16c2 |
--- a/third_party/protobuf/python/google/protobuf/internal/python_message.py |
+++ b/third_party/protobuf/python/google/protobuf/internal/python_message.py |
@@ -1,6 +1,6 @@ |
# Protocol Buffers - Google's data interchange format |
# Copyright 2008 Google Inc. All rights reserved. |
-# http://code.google.com/p/protobuf/ |
+# https://developers.google.com/protocol-buffers/ |
# |
# Redistribution and use in source and binary forms, with or without |
# modification, are permitted provided that the following conditions are |
@@ -50,14 +50,14 @@ this file*. |
__author__ = 'robinson@google.com (Will Robinson)' |
-try: |
- from cStringIO import StringIO |
-except ImportError: |
- from StringIO import StringIO |
-import copy_reg |
+from io import BytesIO |
+import sys |
import struct |
import weakref |
+import six |
+import six.moves.copyreg as copyreg |
+ |
# We use "as" to avoid name collisions with variables. |
from google.protobuf.internal import containers |
from google.protobuf.internal import decoder |
@@ -65,41 +65,121 @@ from google.protobuf.internal import encoder |
from google.protobuf.internal import enum_type_wrapper |
from google.protobuf.internal import message_listener as message_listener_mod |
from google.protobuf.internal import type_checkers |
+from google.protobuf.internal import well_known_types |
from google.protobuf.internal import wire_format |
from google.protobuf import descriptor as descriptor_mod |
from google.protobuf import message as message_mod |
+from google.protobuf import symbol_database |
from google.protobuf import text_format |
_FieldDescriptor = descriptor_mod.FieldDescriptor |
+_AnyFullTypeName = 'google.protobuf.Any' |
-def NewMessage(bases, descriptor, dictionary): |
- _AddClassAttributesForNestedExtensions(descriptor, dictionary) |
- _AddSlots(descriptor, dictionary) |
- return bases |
+class GeneratedProtocolMessageType(type): |
+ """Metaclass for protocol message classes created at runtime from Descriptors. |
-def InitMessage(descriptor, cls): |
- cls._decoders_by_tag = {} |
- cls._extensions_by_name = {} |
- cls._extensions_by_number = {} |
- if (descriptor.has_options and |
- descriptor.GetOptions().message_set_wire_format): |
- cls._decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = ( |
- decoder.MessageSetItemDecoder(cls._extensions_by_number)) |
+ We add implementations for all methods described in the Message class. We |
+ also create properties to allow getting/setting all fields in the protocol |
+ message. Finally, we create slots to prevent users from accidentally |
+ "setting" nonexistent fields in the protocol message, which then wouldn't get |
+ serialized / deserialized properly. |
- # Attach stuff to each FieldDescriptor for quick lookup later on. |
- for field in descriptor.fields: |
- _AttachFieldHelpers(cls, field) |
+ The protocol compiler currently uses this metaclass to create protocol |
+ message classes at runtime. Clients can also manually create their own |
+ classes at runtime, as in this example: |
- _AddEnumValues(descriptor, cls) |
- _AddInitMethod(descriptor, cls) |
- _AddPropertiesForFields(descriptor, cls) |
- _AddPropertiesForExtensions(descriptor, cls) |
- _AddStaticMethods(cls) |
- _AddMessageMethods(descriptor, cls) |
- _AddPrivateHelperMethods(cls) |
- copy_reg.pickle(cls, lambda obj: (cls, (), obj.__getstate__())) |
+ mydescriptor = Descriptor(.....) |
+ class MyProtoClass(Message): |
+ __metaclass__ = GeneratedProtocolMessageType |
+ DESCRIPTOR = mydescriptor |
+ myproto_instance = MyProtoClass() |
+ myproto.foo_field = 23 |
+ ... |
+ |
+ The above example will not work for nested types. If you wish to include them, |
+ use reflection.MakeClass() instead of manually instantiating the class in |
+ order to create the appropriate class structure. |
+ """ |
+ |
+ # Must be consistent with the protocol-compiler code in |
+ # proto2/compiler/internal/generator.*. |
+ _DESCRIPTOR_KEY = 'DESCRIPTOR' |
+ |
+ def __new__(cls, name, bases, dictionary): |
+ """Custom allocation for runtime-generated class types. |
+ |
+ We override __new__ because this is apparently the only place |
+ where we can meaningfully set __slots__ on the class we're creating(?). |
+ (The interplay between metaclasses and slots is not very well-documented). |
+ |
+ Args: |
+ name: Name of the class (ignored, but required by the |
+ metaclass protocol). |
+ bases: Base classes of the class we're constructing. |
+ (Should be message.Message). We ignore this field, but |
+ it's required by the metaclass protocol |
+ dictionary: The class dictionary of the class we're |
+ constructing. dictionary[_DESCRIPTOR_KEY] must contain |
+ a Descriptor object describing this protocol message |
+ type. |
+ |
+ Returns: |
+ Newly-allocated class. |
+ """ |
+ descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY] |
+ if descriptor.full_name in well_known_types.WKTBASES: |
+ bases += (well_known_types.WKTBASES[descriptor.full_name],) |
+ _AddClassAttributesForNestedExtensions(descriptor, dictionary) |
+ _AddSlots(descriptor, dictionary) |
+ |
+ superclass = super(GeneratedProtocolMessageType, cls) |
+ new_class = superclass.__new__(cls, name, bases, dictionary) |
+ return new_class |
+ |
+ def __init__(cls, name, bases, dictionary): |
+ """Here we perform the majority of our work on the class. |
+ We add enum getters, an __init__ method, implementations |
+ of all Message methods, and properties for all fields |
+ in the protocol type. |
+ |
+ Args: |
+ name: Name of the class (ignored, but required by the |
+ metaclass protocol). |
+ bases: Base classes of the class we're constructing. |
+ (Should be message.Message). We ignore this field, but |
+ it's required by the metaclass protocol |
+ dictionary: The class dictionary of the class we're |
+ constructing. dictionary[_DESCRIPTOR_KEY] must contain |
+ a Descriptor object describing this protocol message |
+ type. |
+ """ |
+ descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY] |
+ cls._decoders_by_tag = {} |
+ cls._extensions_by_name = {} |
+ cls._extensions_by_number = {} |
+ if (descriptor.has_options and |
+ descriptor.GetOptions().message_set_wire_format): |
+ cls._decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = ( |
+ decoder.MessageSetItemDecoder(cls._extensions_by_number), None) |
+ |
+ # Attach stuff to each FieldDescriptor for quick lookup later on. |
+ for field in descriptor.fields: |
+ _AttachFieldHelpers(cls, field) |
+ |
+ descriptor._concrete_class = cls # pylint: disable=protected-access |
+ _AddEnumValues(descriptor, cls) |
+ _AddInitMethod(descriptor, cls) |
+ _AddPropertiesForFields(descriptor, cls) |
+ _AddPropertiesForExtensions(descriptor, cls) |
+ _AddStaticMethods(cls) |
+ _AddMessageMethods(descriptor, cls) |
+ _AddPrivateHelperMethods(descriptor, cls) |
+ copyreg.pickle(cls, lambda obj: (cls, (), obj.__getstate__())) |
+ |
+ superclass = super(GeneratedProtocolMessageType, cls) |
+ superclass.__init__(name, bases, dictionary) |
# Stateless helpers for GeneratedProtocolMessageType below. |
@@ -176,7 +256,8 @@ def _AddSlots(message_descriptor, dictionary): |
'_is_present_in_parent', |
'_listener', |
'_listener_for_children', |
- '__weakref__'] |
+ '__weakref__', |
+ '_oneofs'] |
def _IsMessageSetExtension(field): |
@@ -184,16 +265,40 @@ def _IsMessageSetExtension(field): |
field.containing_type.has_options and |
field.containing_type.GetOptions().message_set_wire_format and |
field.type == _FieldDescriptor.TYPE_MESSAGE and |
- field.message_type == field.extension_scope and |
field.label == _FieldDescriptor.LABEL_OPTIONAL) |
+def _IsMapField(field): |
+ return (field.type == _FieldDescriptor.TYPE_MESSAGE and |
+ field.message_type.has_options and |
+ field.message_type.GetOptions().map_entry) |
+ |
+ |
+def _IsMessageMapField(field): |
+ value_type = field.message_type.fields_by_name["value"] |
+ return value_type.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE |
+ |
+ |
def _AttachFieldHelpers(cls, field_descriptor): |
is_repeated = (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED) |
- is_packed = (field_descriptor.has_options and |
- field_descriptor.GetOptions().packed) |
- |
- if _IsMessageSetExtension(field_descriptor): |
+ is_packable = (is_repeated and |
+ wire_format.IsTypePackable(field_descriptor.type)) |
+ if not is_packable: |
+ is_packed = False |
+ elif field_descriptor.containing_type.syntax == "proto2": |
+ is_packed = (field_descriptor.has_options and |
+ field_descriptor.GetOptions().packed) |
+ else: |
+ has_packed_false = (field_descriptor.has_options and |
+ field_descriptor.GetOptions().HasField("packed") and |
+ field_descriptor.GetOptions().packed == False) |
+ is_packed = not has_packed_false |
+ is_map_entry = _IsMapField(field_descriptor) |
+ |
+ if is_map_entry: |
+ field_encoder = encoder.MapEncoder(field_descriptor) |
+ sizer = encoder.MapSizer(field_descriptor) |
+ elif _IsMessageSetExtension(field_descriptor): |
field_encoder = encoder.MessageSetItemEncoder(field_descriptor.number) |
sizer = encoder.MessageSetItemSizer(field_descriptor.number) |
else: |
@@ -209,10 +314,27 @@ def _AttachFieldHelpers(cls, field_descriptor): |
def AddDecoder(wiretype, is_packed): |
tag_bytes = encoder.TagBytes(field_descriptor.number, wiretype) |
- cls._decoders_by_tag[tag_bytes] = ( |
- type_checkers.TYPE_TO_DECODER[field_descriptor.type]( |
- field_descriptor.number, is_repeated, is_packed, |
- field_descriptor, field_descriptor._default_constructor)) |
+ decode_type = field_descriptor.type |
+ if (decode_type == _FieldDescriptor.TYPE_ENUM and |
+ type_checkers.SupportsOpenEnums(field_descriptor)): |
+ decode_type = _FieldDescriptor.TYPE_INT32 |
+ |
+ oneof_descriptor = None |
+ if field_descriptor.containing_oneof is not None: |
+ oneof_descriptor = field_descriptor |
+ |
+ if is_map_entry: |
+ is_message_map = _IsMessageMapField(field_descriptor) |
+ |
+ field_decoder = decoder.MapDecoder( |
+ field_descriptor, _GetInitializeDefaultForMap(field_descriptor), |
+ is_message_map) |
+ else: |
+ field_decoder = type_checkers.TYPE_TO_DECODER[decode_type]( |
+ field_descriptor.number, is_repeated, is_packed, |
+ field_descriptor, field_descriptor._default_constructor) |
+ |
+ cls._decoders_by_tag[tag_bytes] = (field_decoder, oneof_descriptor) |
AddDecoder(type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type], |
False) |
@@ -225,7 +347,7 @@ def _AttachFieldHelpers(cls, field_descriptor): |
def _AddClassAttributesForNestedExtensions(descriptor, dictionary): |
extension_dict = descriptor.extensions_by_name |
- for extension_name, extension_field in extension_dict.iteritems(): |
+ for extension_name, extension_field in extension_dict.items(): |
assert extension_name not in dictionary |
dictionary[extension_name] = extension_field |
@@ -245,6 +367,26 @@ def _AddEnumValues(descriptor, cls): |
setattr(cls, enum_value.name, enum_value.number) |
+def _GetInitializeDefaultForMap(field): |
+ if field.label != _FieldDescriptor.LABEL_REPEATED: |
+ raise ValueError('map_entry set on non-repeated field %s' % ( |
+ field.name)) |
+ fields_by_name = field.message_type.fields_by_name |
+ key_checker = type_checkers.GetTypeChecker(fields_by_name['key']) |
+ |
+ value_field = fields_by_name['value'] |
+ if _IsMessageMapField(field): |
+ def MakeMessageMapDefault(message): |
+ return containers.MessageMap( |
+ message._listener_for_children, value_field.message_type, key_checker) |
+ return MakeMessageMapDefault |
+ else: |
+ value_checker = type_checkers.GetTypeChecker(value_field) |
+ def MakePrimitiveMapDefault(message): |
+ return containers.ScalarMap( |
+ message._listener_for_children, key_checker, value_checker) |
+ return MakePrimitiveMapDefault |
+ |
def _DefaultValueConstructorForField(field): |
"""Returns a function which returns a default value for a field. |
@@ -259,6 +401,9 @@ def _DefaultValueConstructorForField(field): |
value may refer back to |message| via a weak reference. |
""" |
+ if _IsMapField(field): |
+ return _GetInitializeDefaultForMap(field) |
+ |
if field.label == _FieldDescriptor.LABEL_REPEATED: |
if field.has_default_value and field.default_value != []: |
raise ValueError('Repeated field default value not empty list: %s' % ( |
@@ -272,7 +417,7 @@ def _DefaultValueConstructorForField(field): |
message._listener_for_children, field.message_type) |
return MakeRepeatedMessageDefault |
else: |
- type_checker = type_checkers.GetTypeChecker(field.cpp_type, field.type) |
+ type_checker = type_checkers.GetTypeChecker(field) |
def MakeRepeatedScalarDefault(message): |
return containers.RepeatedScalarFieldContainer( |
message._listener_for_children, type_checker) |
@@ -283,7 +428,10 @@ def _DefaultValueConstructorForField(field): |
message_type = field.message_type |
def MakeSubMessageDefault(message): |
result = message_type._concrete_class() |
- result._SetListener(message._listener_for_children) |
+ result._SetListener( |
+ _OneofListener(message, field) |
+ if field.containing_oneof is not None |
+ else message._listener_for_children) |
return result |
return MakeSubMessageDefault |
@@ -294,20 +442,50 @@ def _DefaultValueConstructorForField(field): |
return MakeScalarDefault |
+def _ReraiseTypeErrorWithFieldName(message_name, field_name): |
+ """Re-raise the currently-handled TypeError with the field name added.""" |
+ exc = sys.exc_info()[1] |
+ if len(exc.args) == 1 and type(exc) is TypeError: |
+ # simple TypeError; add field name to exception message |
+ exc = TypeError('%s for field %s.%s' % (str(exc), message_name, field_name)) |
+ |
+ # re-raise possibly-amended exception with original traceback: |
+ six.reraise(type(exc), exc, sys.exc_info()[2]) |
+ |
+ |
def _AddInitMethod(message_descriptor, cls): |
"""Adds an __init__ method to cls.""" |
- fields = message_descriptor.fields |
+ |
+ def _GetIntegerEnumValue(enum_type, value): |
+ """Convert a string or integer enum value to an integer. |
+ |
+ If the value is a string, it is converted to the enum value in |
+ enum_type with the same name. If the value is not a string, it's |
+ returned as-is. (No conversion or bounds-checking is done.) |
+ """ |
+ if isinstance(value, six.string_types): |
+ try: |
+ return enum_type.values_by_name[value].number |
+ except KeyError: |
+ raise ValueError('Enum type %s: unknown label "%s"' % ( |
+ enum_type.full_name, value)) |
+ return value |
+ |
def init(self, **kwargs): |
self._cached_byte_size = 0 |
self._cached_byte_size_dirty = len(kwargs) > 0 |
self._fields = {} |
+ # Contains a mapping from oneof field descriptors to the descriptor |
+ # of the currently set field in that oneof field. |
+ self._oneofs = {} |
+ |
# _unknown_fields is () when empty for efficiency, and will be turned into |
# a list if fields are added. |
self._unknown_fields = () |
self._is_present_in_parent = False |
self._listener = message_listener_mod.NullMessageListener() |
self._listener_for_children = _Listener(self) |
- for field_name, field_value in kwargs.iteritems(): |
+ for field_name, field_value in kwargs.items(): |
field = _GetFieldByName(message_descriptor, field_name) |
if field is None: |
raise TypeError("%s() got an unexpected keyword argument '%s'" % |
@@ -315,17 +493,41 @@ def _AddInitMethod(message_descriptor, cls): |
if field.label == _FieldDescriptor.LABEL_REPEATED: |
copy = field._default_constructor(self) |
if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: # Composite |
- for val in field_value: |
- copy.add().MergeFrom(val) |
+ if _IsMapField(field): |
+ if _IsMessageMapField(field): |
+ for key in field_value: |
+ copy[key].MergeFrom(field_value[key]) |
+ else: |
+ copy.update(field_value) |
+ else: |
+ for val in field_value: |
+ if isinstance(val, dict): |
+ copy.add(**val) |
+ else: |
+ copy.add().MergeFrom(val) |
else: # Scalar |
+ if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM: |
+ field_value = [_GetIntegerEnumValue(field.enum_type, val) |
+ for val in field_value] |
copy.extend(field_value) |
self._fields[field] = copy |
elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: |
copy = field._default_constructor(self) |
- copy.MergeFrom(field_value) |
+ new_val = field_value |
+ if isinstance(field_value, dict): |
+ new_val = field.message_type._concrete_class(**field_value) |
+ try: |
+ copy.MergeFrom(new_val) |
+ except TypeError: |
+ _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name) |
self._fields[field] = copy |
else: |
- setattr(self, field_name, field_value) |
+ if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM: |
+ field_value = _GetIntegerEnumValue(field.enum_type, field_value) |
+ try: |
+ setattr(self, field_name, field_value) |
+ except TypeError: |
+ _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name) |
init.__module__ = None |
init.__doc__ = None |
@@ -344,7 +546,8 @@ def _GetFieldByName(message_descriptor, field_name): |
try: |
return message_descriptor.fields_by_name[field_name] |
except KeyError: |
- raise ValueError('Protocol message has no "%s" field.' % field_name) |
+ raise ValueError('Protocol message %s has no "%s" field.' % |
+ (message_descriptor.name, field_name)) |
def _AddPropertiesForFields(descriptor, cls): |
@@ -440,9 +643,10 @@ def _AddPropertiesForNonRepeatedScalarField(field, cls): |
""" |
proto_field_name = field.name |
property_name = _PropertyName(proto_field_name) |
- type_checker = type_checkers.GetTypeChecker(field.cpp_type, field.type) |
+ type_checker = type_checkers.GetTypeChecker(field) |
default_value = field.default_value |
valid_values = set() |
+ is_proto3 = field.containing_type.syntax == "proto3" |
def getter(self): |
# TODO(protobuf-team): This may be broken since there may not be |
@@ -450,14 +654,30 @@ def _AddPropertiesForNonRepeatedScalarField(field, cls): |
return self._fields.get(field, default_value) |
getter.__module__ = None |
getter.__doc__ = 'Getter for %s.' % proto_field_name |
- def setter(self, new_value): |
- type_checker.CheckValue(new_value) |
- self._fields[field] = new_value |
+ |
+ clear_when_set_to_default = is_proto3 and not field.containing_oneof |
+ |
+ def field_setter(self, new_value): |
+ # pylint: disable=protected-access |
+ # Testing the value for truthiness captures all of the proto3 defaults |
+ # (0, 0.0, enum 0, and False). |
+ new_value = type_checker.CheckValue(new_value) |
+ if clear_when_set_to_default and not new_value: |
+ self._fields.pop(field, None) |
+ else: |
+ self._fields[field] = new_value |
# Check _cached_byte_size_dirty inline to improve performance, since scalar |
# setters are called frequently. |
if not self._cached_byte_size_dirty: |
self._Modified() |
+ if field.containing_oneof: |
+ def setter(self, new_value): |
+ field_setter(self, new_value) |
+ self._UpdateOneofState(field) |
+ else: |
+ setter = field_setter |
+ |
setter.__module__ = None |
setter.__doc__ = 'Setter for %s.' % proto_field_name |
@@ -482,18 +702,11 @@ def _AddPropertiesForNonRepeatedCompositeField(field, cls): |
proto_field_name = field.name |
property_name = _PropertyName(proto_field_name) |
- # TODO(komarek): Can anyone explain to me why we cache the message_type this |
- # way, instead of referring to field.message_type inside of getter(self)? |
- # What if someone sets message_type later on (which makes for simpler |
- # dyanmic proto descriptor and class creation code). |
- message_type = field.message_type |
- |
def getter(self): |
field_value = self._fields.get(field) |
if field_value is None: |
# Construct a new object to represent this field. |
- field_value = message_type._concrete_class() # use field.message_type? |
- field_value._SetListener(self._listener_for_children) |
+ field_value = field._default_constructor(self) |
# Atomically check if another thread has preempted us and, if not, swap |
# in the new object we just created. If someone has preempted us, we |
@@ -520,7 +733,7 @@ def _AddPropertiesForNonRepeatedCompositeField(field, cls): |
def _AddPropertiesForExtensions(descriptor, cls): |
"""Adds properties for all fields in this protocol message type.""" |
extension_dict = descriptor.extensions_by_name |
- for extension_name, extension_field in extension_dict.iteritems(): |
+ for extension_name, extension_field in extension_dict.items(): |
constant_name = extension_name.upper() + "_FIELD_NUMBER" |
setattr(cls, constant_name, extension_field.number) |
@@ -575,33 +788,54 @@ def _AddListFieldsMethod(message_descriptor, cls): |
"""Helper for _AddMessageMethods().""" |
def ListFields(self): |
- all_fields = [item for item in self._fields.iteritems() if _IsPresent(item)] |
+ all_fields = [item for item in self._fields.items() if _IsPresent(item)] |
all_fields.sort(key = lambda item: item[0].number) |
return all_fields |
cls.ListFields = ListFields |
+_Proto3HasError = 'Protocol message has no non-repeated submessage field "%s"' |
+_Proto2HasError = 'Protocol message has no non-repeated field "%s"' |
def _AddHasFieldMethod(message_descriptor, cls): |
"""Helper for _AddMessageMethods().""" |
- singular_fields = {} |
+ is_proto3 = (message_descriptor.syntax == "proto3") |
+ error_msg = _Proto3HasError if is_proto3 else _Proto2HasError |
+ |
+ hassable_fields = {} |
for field in message_descriptor.fields: |
- if field.label != _FieldDescriptor.LABEL_REPEATED: |
- singular_fields[field.name] = field |
+ if field.label == _FieldDescriptor.LABEL_REPEATED: |
+ continue |
+ # For proto3, only submessages and fields inside a oneof have presence. |
+ if (is_proto3 and field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE and |
+ not field.containing_oneof): |
+ continue |
+ hassable_fields[field.name] = field |
+ |
+ if not is_proto3: |
+ # Fields inside oneofs are never repeated (enforced by the compiler). |
+ for oneof in message_descriptor.oneofs: |
+ hassable_fields[oneof.name] = oneof |
def HasField(self, field_name): |
try: |
- field = singular_fields[field_name] |
+ field = hassable_fields[field_name] |
except KeyError: |
- raise ValueError( |
- 'Protocol message has no singular "%s" field.' % field_name) |
+ raise ValueError(error_msg % field_name) |
- if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: |
- value = self._fields.get(field) |
- return value is not None and value._is_present_in_parent |
+ if isinstance(field, descriptor_mod.OneofDescriptor): |
+ try: |
+ return HasField(self, self._oneofs[field].name) |
+ except KeyError: |
+ return False |
else: |
- return field in self._fields |
+ if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: |
+ value = self._fields.get(field) |
+ return value is not None and value._is_present_in_parent |
+ else: |
+ return field in self._fields |
+ |
cls.HasField = HasField |
@@ -611,14 +845,30 @@ def _AddClearFieldMethod(message_descriptor, cls): |
try: |
field = message_descriptor.fields_by_name[field_name] |
except KeyError: |
- raise ValueError('Protocol message has no "%s" field.' % field_name) |
+ try: |
+ field = message_descriptor.oneofs_by_name[field_name] |
+ if field in self._oneofs: |
+ field = self._oneofs[field] |
+ else: |
+ return |
+ except KeyError: |
+ raise ValueError('Protocol message %s() has no "%s" field.' % |
+ (message_descriptor.name, field_name)) |
if field in self._fields: |
+ # To match the C++ implementation, we need to invalidate iterators |
+ # for map fields when ClearField() happens. |
+ if hasattr(self._fields[field], 'InvalidateIterators'): |
+ self._fields[field].InvalidateIterators() |
+ |
# Note: If the field is a sub-message, its listener will still point |
# at us. That's fine, because the worst than can happen is that it |
# will call _Modified() and invalidate our byte size. Big deal. |
del self._fields[field] |
+ if self._oneofs.get(field.containing_oneof, None) is field: |
+ del self._oneofs[field.containing_oneof] |
+ |
# Always call _Modified() -- even if nothing was changed, this is |
# a mutating method, and thus calling it should cause the field to become |
# present in the parent message. |
@@ -645,6 +895,7 @@ def _AddClearMethod(message_descriptor, cls): |
# Clear fields. |
self._fields = {} |
self._unknown_fields = () |
+ self._oneofs = {} |
self._Modified() |
cls.Clear = Clear |
@@ -663,6 +914,38 @@ def _AddHasExtensionMethod(cls): |
return extension_handle in self._fields |
cls.HasExtension = HasExtension |
+def _InternalUnpackAny(msg): |
+ """Unpacks Any message and returns the unpacked message. |
+ |
+ This internal method is differnt from public Any Unpack method which takes |
+ the target message as argument. _InternalUnpackAny method does not have |
+ target message type and need to find the message type in descriptor pool. |
+ |
+ Args: |
+ msg: An Any message to be unpacked. |
+ |
+ Returns: |
+ The unpacked message. |
+ """ |
+ type_url = msg.type_url |
+ db = symbol_database.Default() |
+ |
+ if not type_url: |
+ return None |
+ |
+ # TODO(haberman): For now we just strip the hostname. Better logic will be |
+ # required. |
+ type_name = type_url.split("/")[-1] |
+ descriptor = db.pool.FindMessageTypeByName(type_name) |
+ |
+ if descriptor is None: |
+ return None |
+ |
+ message_class = db.GetPrototype(descriptor) |
+ message = message_class() |
+ |
+ message.ParseFromString(msg.value) |
+ return message |
def _AddEqualsMethod(message_descriptor, cls): |
"""Helper for _AddMessageMethods().""" |
@@ -674,6 +957,12 @@ def _AddEqualsMethod(message_descriptor, cls): |
if self is other: |
return True |
+ if self.DESCRIPTOR.full_name == _AnyFullTypeName: |
+ any_a = _InternalUnpackAny(self) |
+ any_b = _InternalUnpackAny(other) |
+ if any_a and any_b: |
+ return any_a == any_b |
+ |
if not self.ListFields() == other.ListFields(): |
return False |
@@ -695,6 +984,13 @@ def _AddStrMethod(message_descriptor, cls): |
cls.__str__ = __str__ |
+def _AddReprMethod(message_descriptor, cls): |
+ """Helper for _AddMessageMethods().""" |
+ def __repr__(self): |
+ return text_format.MessageToString(self) |
+ cls.__repr__ = __repr__ |
+ |
+ |
def _AddUnicodeMethod(unused_message_descriptor, cls): |
"""Helper for _AddMessageMethods().""" |
@@ -773,7 +1069,7 @@ def _AddSerializePartialToStringMethod(message_descriptor, cls): |
"""Helper for _AddMessageMethods().""" |
def SerializePartialToString(self): |
- out = StringIO() |
+ out = BytesIO() |
self._InternalSerialize(out.write) |
return out.getvalue() |
cls.SerializePartialToString = SerializePartialToString |
@@ -796,9 +1092,10 @@ def _AddMergeFromStringMethod(message_descriptor, cls): |
# The only reason _InternalParse would return early is if it |
# encountered an end-group tag. |
raise message_mod.DecodeError('Unexpected end-group tag.') |
- except IndexError: |
+ except (IndexError, TypeError): |
+ # Now ord(buf[p:p+1]) == ord('') gets TypeError. |
raise message_mod.DecodeError('Truncated message.') |
- except struct.error, e: |
+ except struct.error as e: |
raise message_mod.DecodeError(e) |
return length # Return this for legacy reasons. |
cls.MergeFromString = MergeFromString |
@@ -806,6 +1103,7 @@ def _AddMergeFromStringMethod(message_descriptor, cls): |
local_ReadTag = decoder.ReadTag |
local_SkipField = decoder.SkipField |
decoders_by_tag = cls._decoders_by_tag |
+ is_proto3 = message_descriptor.syntax == "proto3" |
def InternalParse(self, buffer, pos, end): |
self._Modified() |
@@ -813,18 +1111,22 @@ def _AddMergeFromStringMethod(message_descriptor, cls): |
unknown_field_list = self._unknown_fields |
while pos != end: |
(tag_bytes, new_pos) = local_ReadTag(buffer, pos) |
- field_decoder = decoders_by_tag.get(tag_bytes) |
+ field_decoder, field_desc = decoders_by_tag.get(tag_bytes, (None, None)) |
if field_decoder is None: |
value_start_pos = new_pos |
new_pos = local_SkipField(buffer, new_pos, end, tag_bytes) |
if new_pos == -1: |
return pos |
- if not unknown_field_list: |
- unknown_field_list = self._unknown_fields = [] |
- unknown_field_list.append((tag_bytes, buffer[value_start_pos:new_pos])) |
+ if not is_proto3: |
+ if not unknown_field_list: |
+ unknown_field_list = self._unknown_fields = [] |
+ unknown_field_list.append( |
+ (tag_bytes, buffer[value_start_pos:new_pos])) |
pos = new_pos |
else: |
pos = field_decoder(buffer, new_pos, end, self, field_dict) |
+ if field_desc: |
+ self._UpdateOneofState(field_desc) |
return pos |
cls._InternalParse = InternalParse |
@@ -857,9 +1159,12 @@ def _AddIsInitializedMethod(message_descriptor, cls): |
errors.extend(self.FindInitializationErrors()) |
return False |
- for field, value in self._fields.iteritems(): |
+ for field, value in list(self._fields.items()): # dict can change size! |
if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: |
if field.label == _FieldDescriptor.LABEL_REPEATED: |
+ if (field.message_type.has_options and |
+ field.message_type.GetOptions().map_entry): |
+ continue |
for element in value: |
if not element.IsInitialized(): |
if errors is not None: |
@@ -895,16 +1200,26 @@ def _AddIsInitializedMethod(message_descriptor, cls): |
else: |
name = field.name |
- if field.label == _FieldDescriptor.LABEL_REPEATED: |
- for i in xrange(len(value)): |
+ if _IsMapField(field): |
+ if _IsMessageMapField(field): |
+ for key in value: |
+ element = value[key] |
+ prefix = "%s[%s]." % (name, key) |
+ sub_errors = element.FindInitializationErrors() |
+ errors += [prefix + error for error in sub_errors] |
+ else: |
+ # ScalarMaps can't have any initialization errors. |
+ pass |
+ elif field.label == _FieldDescriptor.LABEL_REPEATED: |
+ for i in range(len(value)): |
element = value[i] |
prefix = "%s[%d]." % (name, i) |
sub_errors = element.FindInitializationErrors() |
- errors += [ prefix + error for error in sub_errors ] |
+ errors += [prefix + error for error in sub_errors] |
else: |
prefix = name + "." |
sub_errors = value.FindInitializationErrors() |
- errors += [ prefix + error for error in sub_errors ] |
+ errors += [prefix + error for error in sub_errors] |
return errors |
@@ -926,7 +1241,7 @@ def _AddMergeFromMethod(cls): |
fields = self._fields |
- for field, value in msg._fields.iteritems(): |
+ for field, value in msg._fields.items(): |
if field.label == LABEL_REPEATED: |
field_value = fields.get(field) |
if field_value is None: |
@@ -944,6 +1259,8 @@ def _AddMergeFromMethod(cls): |
field_value.MergeFrom(value) |
else: |
self._fields[field] = value |
+ if field.containing_oneof: |
+ self._UpdateOneofState(field) |
if msg._unknown_fields: |
if not self._unknown_fields: |
@@ -953,6 +1270,24 @@ def _AddMergeFromMethod(cls): |
cls.MergeFrom = MergeFrom |
+def _AddWhichOneofMethod(message_descriptor, cls): |
+ def WhichOneof(self, oneof_name): |
+ """Returns the name of the currently set field inside a oneof, or None.""" |
+ try: |
+ field = message_descriptor.oneofs_by_name[oneof_name] |
+ except KeyError: |
+ raise ValueError( |
+ 'Protocol message has no oneof "%s" field.' % oneof_name) |
+ |
+ nested_field = self._oneofs.get(field, None) |
+ if nested_field is not None and self.HasField(nested_field.name): |
+ return nested_field.name |
+ else: |
+ return None |
+ |
+ cls.WhichOneof = WhichOneof |
+ |
+ |
def _AddMessageMethods(message_descriptor, cls): |
"""Adds implementations of all Message methods to cls.""" |
_AddListFieldsMethod(message_descriptor, cls) |
@@ -964,6 +1299,7 @@ def _AddMessageMethods(message_descriptor, cls): |
_AddClearMethod(message_descriptor, cls) |
_AddEqualsMethod(message_descriptor, cls) |
_AddStrMethod(message_descriptor, cls) |
+ _AddReprMethod(message_descriptor, cls) |
_AddUnicodeMethod(message_descriptor, cls) |
_AddSetListenerMethod(cls) |
_AddByteSizeMethod(message_descriptor, cls) |
@@ -972,9 +1308,10 @@ def _AddMessageMethods(message_descriptor, cls): |
_AddMergeFromStringMethod(message_descriptor, cls) |
_AddIsInitializedMethod(message_descriptor, cls) |
_AddMergeFromMethod(cls) |
+ _AddWhichOneofMethod(message_descriptor, cls) |
-def _AddPrivateHelperMethods(cls): |
+def _AddPrivateHelperMethods(message_descriptor, cls): |
"""Adds implementation of private helper methods to cls.""" |
def Modified(self): |
@@ -992,8 +1329,20 @@ def _AddPrivateHelperMethods(cls): |
self._is_present_in_parent = True |
self._listener.Modified() |
+ def _UpdateOneofState(self, field): |
+ """Sets field as the active field in its containing oneof. |
+ |
+ Will also delete currently active field in the oneof, if it is different |
+ from the argument. Does not mark the message as modified. |
+ """ |
+ other_field = self._oneofs.setdefault(field.containing_oneof, field) |
+ if other_field is not field: |
+ del self._fields[other_field] |
+ self._oneofs[field.containing_oneof] = field |
+ |
cls._Modified = Modified |
cls.SetInParent = Modified |
+ cls._UpdateOneofState = _UpdateOneofState |
class _Listener(object): |
@@ -1042,6 +1391,27 @@ class _Listener(object): |
pass |
+class _OneofListener(_Listener): |
+ """Special listener implementation for setting composite oneof fields.""" |
+ |
+ def __init__(self, parent_message, field): |
+ """Args: |
+ parent_message: The message whose _Modified() method we should call when |
+ we receive Modified() messages. |
+ field: The descriptor of the field being set in the parent message. |
+ """ |
+ super(_OneofListener, self).__init__(parent_message) |
+ self._field = field |
+ |
+ def Modified(self): |
+ """Also updates the state of the containing oneof in the parent message.""" |
+ try: |
+ self._parent_message_weakref._UpdateOneofState(self._field) |
+ super(_OneofListener, self).Modified() |
+ except ReferenceError: |
+ pass |
+ |
+ |
# TODO(robinson): Move elsewhere? This file is getting pretty ridiculous... |
# TODO(robinson): Unify error handling of "unknown extension" crap. |
# TODO(robinson): Support iteritems()-style iteration over all |
@@ -1132,10 +1502,10 @@ class _ExtensionDict(object): |
# It's slightly wasteful to lookup the type checker each time, |
# but we expect this to be a vanishingly uncommon case anyway. |
- type_checker = type_checkers.GetTypeChecker( |
- extension_handle.cpp_type, extension_handle.type) |
- type_checker.CheckValue(value) |
- self._extended_message._fields[extension_handle] = value |
+ type_checker = type_checkers.GetTypeChecker(extension_handle) |
+ # pylint: disable=protected-access |
+ self._extended_message._fields[extension_handle] = ( |
+ type_checker.CheckValue(value)) |
self._extended_message._Modified() |
def _FindExtensionByName(self, name): |