| Index: third_party/protobuf/python/google/protobuf/internal/python_message.py
|
| diff --git a/third_party/protobuf/python/google/protobuf/internal/python_message.py b/third_party/protobuf/python/google/protobuf/internal/python_message.py
|
| old mode 100644
|
| new mode 100755
|
| index 4bea57ac6c33616454a02efd845cbc50752b5db2..87f60666ab64012b8aeb51f7f1a7c76a916d16c2
|
| --- a/third_party/protobuf/python/google/protobuf/internal/python_message.py
|
| +++ b/third_party/protobuf/python/google/protobuf/internal/python_message.py
|
| @@ -1,6 +1,6 @@
|
| # Protocol Buffers - Google's data interchange format
|
| # Copyright 2008 Google Inc. All rights reserved.
|
| -# http://code.google.com/p/protobuf/
|
| +# https://developers.google.com/protocol-buffers/
|
| #
|
| # Redistribution and use in source and binary forms, with or without
|
| # modification, are permitted provided that the following conditions are
|
| @@ -50,14 +50,14 @@ this file*.
|
|
|
| __author__ = 'robinson@google.com (Will Robinson)'
|
|
|
| -try:
|
| - from cStringIO import StringIO
|
| -except ImportError:
|
| - from StringIO import StringIO
|
| -import copy_reg
|
| +from io import BytesIO
|
| +import sys
|
| import struct
|
| import weakref
|
|
|
| +import six
|
| +import six.moves.copyreg as copyreg
|
| +
|
| # We use "as" to avoid name collisions with variables.
|
| from google.protobuf.internal import containers
|
| from google.protobuf.internal import decoder
|
| @@ -65,41 +65,121 @@ from google.protobuf.internal import encoder
|
| from google.protobuf.internal import enum_type_wrapper
|
| from google.protobuf.internal import message_listener as message_listener_mod
|
| from google.protobuf.internal import type_checkers
|
| +from google.protobuf.internal import well_known_types
|
| from google.protobuf.internal import wire_format
|
| from google.protobuf import descriptor as descriptor_mod
|
| from google.protobuf import message as message_mod
|
| +from google.protobuf import symbol_database
|
| from google.protobuf import text_format
|
|
|
| _FieldDescriptor = descriptor_mod.FieldDescriptor
|
| +_AnyFullTypeName = 'google.protobuf.Any'
|
|
|
|
|
| -def NewMessage(bases, descriptor, dictionary):
|
| - _AddClassAttributesForNestedExtensions(descriptor, dictionary)
|
| - _AddSlots(descriptor, dictionary)
|
| - return bases
|
| +class GeneratedProtocolMessageType(type):
|
|
|
| + """Metaclass for protocol message classes created at runtime from Descriptors.
|
|
|
| -def InitMessage(descriptor, cls):
|
| - cls._decoders_by_tag = {}
|
| - cls._extensions_by_name = {}
|
| - cls._extensions_by_number = {}
|
| - if (descriptor.has_options and
|
| - descriptor.GetOptions().message_set_wire_format):
|
| - cls._decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = (
|
| - decoder.MessageSetItemDecoder(cls._extensions_by_number))
|
| + We add implementations for all methods described in the Message class. We
|
| + also create properties to allow getting/setting all fields in the protocol
|
| + message. Finally, we create slots to prevent users from accidentally
|
| + "setting" nonexistent fields in the protocol message, which then wouldn't get
|
| + serialized / deserialized properly.
|
|
|
| - # Attach stuff to each FieldDescriptor for quick lookup later on.
|
| - for field in descriptor.fields:
|
| - _AttachFieldHelpers(cls, field)
|
| + The protocol compiler currently uses this metaclass to create protocol
|
| + message classes at runtime. Clients can also manually create their own
|
| + classes at runtime, as in this example:
|
|
|
| - _AddEnumValues(descriptor, cls)
|
| - _AddInitMethod(descriptor, cls)
|
| - _AddPropertiesForFields(descriptor, cls)
|
| - _AddPropertiesForExtensions(descriptor, cls)
|
| - _AddStaticMethods(cls)
|
| - _AddMessageMethods(descriptor, cls)
|
| - _AddPrivateHelperMethods(cls)
|
| - copy_reg.pickle(cls, lambda obj: (cls, (), obj.__getstate__()))
|
| + mydescriptor = Descriptor(.....)
|
| + class MyProtoClass(Message):
|
| + __metaclass__ = GeneratedProtocolMessageType
|
| + DESCRIPTOR = mydescriptor
|
| + myproto_instance = MyProtoClass()
|
| + myproto.foo_field = 23
|
| + ...
|
| +
|
| + The above example will not work for nested types. If you wish to include them,
|
| + use reflection.MakeClass() instead of manually instantiating the class in
|
| + order to create the appropriate class structure.
|
| + """
|
| +
|
| + # Must be consistent with the protocol-compiler code in
|
| + # proto2/compiler/internal/generator.*.
|
| + _DESCRIPTOR_KEY = 'DESCRIPTOR'
|
| +
|
| + def __new__(cls, name, bases, dictionary):
|
| + """Custom allocation for runtime-generated class types.
|
| +
|
| + We override __new__ because this is apparently the only place
|
| + where we can meaningfully set __slots__ on the class we're creating(?).
|
| + (The interplay between metaclasses and slots is not very well-documented).
|
| +
|
| + Args:
|
| + name: Name of the class (ignored, but required by the
|
| + metaclass protocol).
|
| + bases: Base classes of the class we're constructing.
|
| + (Should be message.Message). We ignore this field, but
|
| + it's required by the metaclass protocol
|
| + dictionary: The class dictionary of the class we're
|
| + constructing. dictionary[_DESCRIPTOR_KEY] must contain
|
| + a Descriptor object describing this protocol message
|
| + type.
|
| +
|
| + Returns:
|
| + Newly-allocated class.
|
| + """
|
| + descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]
|
| + if descriptor.full_name in well_known_types.WKTBASES:
|
| + bases += (well_known_types.WKTBASES[descriptor.full_name],)
|
| + _AddClassAttributesForNestedExtensions(descriptor, dictionary)
|
| + _AddSlots(descriptor, dictionary)
|
| +
|
| + superclass = super(GeneratedProtocolMessageType, cls)
|
| + new_class = superclass.__new__(cls, name, bases, dictionary)
|
| + return new_class
|
| +
|
| + def __init__(cls, name, bases, dictionary):
|
| + """Here we perform the majority of our work on the class.
|
| + We add enum getters, an __init__ method, implementations
|
| + of all Message methods, and properties for all fields
|
| + in the protocol type.
|
| +
|
| + Args:
|
| + name: Name of the class (ignored, but required by the
|
| + metaclass protocol).
|
| + bases: Base classes of the class we're constructing.
|
| + (Should be message.Message). We ignore this field, but
|
| + it's required by the metaclass protocol
|
| + dictionary: The class dictionary of the class we're
|
| + constructing. dictionary[_DESCRIPTOR_KEY] must contain
|
| + a Descriptor object describing this protocol message
|
| + type.
|
| + """
|
| + descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]
|
| + cls._decoders_by_tag = {}
|
| + cls._extensions_by_name = {}
|
| + cls._extensions_by_number = {}
|
| + if (descriptor.has_options and
|
| + descriptor.GetOptions().message_set_wire_format):
|
| + cls._decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = (
|
| + decoder.MessageSetItemDecoder(cls._extensions_by_number), None)
|
| +
|
| + # Attach stuff to each FieldDescriptor for quick lookup later on.
|
| + for field in descriptor.fields:
|
| + _AttachFieldHelpers(cls, field)
|
| +
|
| + descriptor._concrete_class = cls # pylint: disable=protected-access
|
| + _AddEnumValues(descriptor, cls)
|
| + _AddInitMethod(descriptor, cls)
|
| + _AddPropertiesForFields(descriptor, cls)
|
| + _AddPropertiesForExtensions(descriptor, cls)
|
| + _AddStaticMethods(cls)
|
| + _AddMessageMethods(descriptor, cls)
|
| + _AddPrivateHelperMethods(descriptor, cls)
|
| + copyreg.pickle(cls, lambda obj: (cls, (), obj.__getstate__()))
|
| +
|
| + superclass = super(GeneratedProtocolMessageType, cls)
|
| + superclass.__init__(name, bases, dictionary)
|
|
|
|
|
| # Stateless helpers for GeneratedProtocolMessageType below.
|
| @@ -176,7 +256,8 @@ def _AddSlots(message_descriptor, dictionary):
|
| '_is_present_in_parent',
|
| '_listener',
|
| '_listener_for_children',
|
| - '__weakref__']
|
| + '__weakref__',
|
| + '_oneofs']
|
|
|
|
|
| def _IsMessageSetExtension(field):
|
| @@ -184,16 +265,40 @@ def _IsMessageSetExtension(field):
|
| field.containing_type.has_options and
|
| field.containing_type.GetOptions().message_set_wire_format and
|
| field.type == _FieldDescriptor.TYPE_MESSAGE and
|
| - field.message_type == field.extension_scope and
|
| field.label == _FieldDescriptor.LABEL_OPTIONAL)
|
|
|
|
|
| +def _IsMapField(field):
|
| + return (field.type == _FieldDescriptor.TYPE_MESSAGE and
|
| + field.message_type.has_options and
|
| + field.message_type.GetOptions().map_entry)
|
| +
|
| +
|
| +def _IsMessageMapField(field):
|
| + value_type = field.message_type.fields_by_name["value"]
|
| + return value_type.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE
|
| +
|
| +
|
| def _AttachFieldHelpers(cls, field_descriptor):
|
| is_repeated = (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED)
|
| - is_packed = (field_descriptor.has_options and
|
| - field_descriptor.GetOptions().packed)
|
| -
|
| - if _IsMessageSetExtension(field_descriptor):
|
| + is_packable = (is_repeated and
|
| + wire_format.IsTypePackable(field_descriptor.type))
|
| + if not is_packable:
|
| + is_packed = False
|
| + elif field_descriptor.containing_type.syntax == "proto2":
|
| + is_packed = (field_descriptor.has_options and
|
| + field_descriptor.GetOptions().packed)
|
| + else:
|
| + has_packed_false = (field_descriptor.has_options and
|
| + field_descriptor.GetOptions().HasField("packed") and
|
| + field_descriptor.GetOptions().packed == False)
|
| + is_packed = not has_packed_false
|
| + is_map_entry = _IsMapField(field_descriptor)
|
| +
|
| + if is_map_entry:
|
| + field_encoder = encoder.MapEncoder(field_descriptor)
|
| + sizer = encoder.MapSizer(field_descriptor)
|
| + elif _IsMessageSetExtension(field_descriptor):
|
| field_encoder = encoder.MessageSetItemEncoder(field_descriptor.number)
|
| sizer = encoder.MessageSetItemSizer(field_descriptor.number)
|
| else:
|
| @@ -209,10 +314,27 @@ def _AttachFieldHelpers(cls, field_descriptor):
|
|
|
| def AddDecoder(wiretype, is_packed):
|
| tag_bytes = encoder.TagBytes(field_descriptor.number, wiretype)
|
| - cls._decoders_by_tag[tag_bytes] = (
|
| - type_checkers.TYPE_TO_DECODER[field_descriptor.type](
|
| - field_descriptor.number, is_repeated, is_packed,
|
| - field_descriptor, field_descriptor._default_constructor))
|
| + decode_type = field_descriptor.type
|
| + if (decode_type == _FieldDescriptor.TYPE_ENUM and
|
| + type_checkers.SupportsOpenEnums(field_descriptor)):
|
| + decode_type = _FieldDescriptor.TYPE_INT32
|
| +
|
| + oneof_descriptor = None
|
| + if field_descriptor.containing_oneof is not None:
|
| + oneof_descriptor = field_descriptor
|
| +
|
| + if is_map_entry:
|
| + is_message_map = _IsMessageMapField(field_descriptor)
|
| +
|
| + field_decoder = decoder.MapDecoder(
|
| + field_descriptor, _GetInitializeDefaultForMap(field_descriptor),
|
| + is_message_map)
|
| + else:
|
| + field_decoder = type_checkers.TYPE_TO_DECODER[decode_type](
|
| + field_descriptor.number, is_repeated, is_packed,
|
| + field_descriptor, field_descriptor._default_constructor)
|
| +
|
| + cls._decoders_by_tag[tag_bytes] = (field_decoder, oneof_descriptor)
|
|
|
| AddDecoder(type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type],
|
| False)
|
| @@ -225,7 +347,7 @@ def _AttachFieldHelpers(cls, field_descriptor):
|
|
|
| def _AddClassAttributesForNestedExtensions(descriptor, dictionary):
|
| extension_dict = descriptor.extensions_by_name
|
| - for extension_name, extension_field in extension_dict.iteritems():
|
| + for extension_name, extension_field in extension_dict.items():
|
| assert extension_name not in dictionary
|
| dictionary[extension_name] = extension_field
|
|
|
| @@ -245,6 +367,26 @@ def _AddEnumValues(descriptor, cls):
|
| setattr(cls, enum_value.name, enum_value.number)
|
|
|
|
|
| +def _GetInitializeDefaultForMap(field):
|
| + if field.label != _FieldDescriptor.LABEL_REPEATED:
|
| + raise ValueError('map_entry set on non-repeated field %s' % (
|
| + field.name))
|
| + fields_by_name = field.message_type.fields_by_name
|
| + key_checker = type_checkers.GetTypeChecker(fields_by_name['key'])
|
| +
|
| + value_field = fields_by_name['value']
|
| + if _IsMessageMapField(field):
|
| + def MakeMessageMapDefault(message):
|
| + return containers.MessageMap(
|
| + message._listener_for_children, value_field.message_type, key_checker)
|
| + return MakeMessageMapDefault
|
| + else:
|
| + value_checker = type_checkers.GetTypeChecker(value_field)
|
| + def MakePrimitiveMapDefault(message):
|
| + return containers.ScalarMap(
|
| + message._listener_for_children, key_checker, value_checker)
|
| + return MakePrimitiveMapDefault
|
| +
|
| def _DefaultValueConstructorForField(field):
|
| """Returns a function which returns a default value for a field.
|
|
|
| @@ -259,6 +401,9 @@ def _DefaultValueConstructorForField(field):
|
| value may refer back to |message| via a weak reference.
|
| """
|
|
|
| + if _IsMapField(field):
|
| + return _GetInitializeDefaultForMap(field)
|
| +
|
| if field.label == _FieldDescriptor.LABEL_REPEATED:
|
| if field.has_default_value and field.default_value != []:
|
| raise ValueError('Repeated field default value not empty list: %s' % (
|
| @@ -272,7 +417,7 @@ def _DefaultValueConstructorForField(field):
|
| message._listener_for_children, field.message_type)
|
| return MakeRepeatedMessageDefault
|
| else:
|
| - type_checker = type_checkers.GetTypeChecker(field.cpp_type, field.type)
|
| + type_checker = type_checkers.GetTypeChecker(field)
|
| def MakeRepeatedScalarDefault(message):
|
| return containers.RepeatedScalarFieldContainer(
|
| message._listener_for_children, type_checker)
|
| @@ -283,7 +428,10 @@ def _DefaultValueConstructorForField(field):
|
| message_type = field.message_type
|
| def MakeSubMessageDefault(message):
|
| result = message_type._concrete_class()
|
| - result._SetListener(message._listener_for_children)
|
| + result._SetListener(
|
| + _OneofListener(message, field)
|
| + if field.containing_oneof is not None
|
| + else message._listener_for_children)
|
| return result
|
| return MakeSubMessageDefault
|
|
|
| @@ -294,20 +442,50 @@ def _DefaultValueConstructorForField(field):
|
| return MakeScalarDefault
|
|
|
|
|
| +def _ReraiseTypeErrorWithFieldName(message_name, field_name):
|
| + """Re-raise the currently-handled TypeError with the field name added."""
|
| + exc = sys.exc_info()[1]
|
| + if len(exc.args) == 1 and type(exc) is TypeError:
|
| + # simple TypeError; add field name to exception message
|
| + exc = TypeError('%s for field %s.%s' % (str(exc), message_name, field_name))
|
| +
|
| + # re-raise possibly-amended exception with original traceback:
|
| + six.reraise(type(exc), exc, sys.exc_info()[2])
|
| +
|
| +
|
| def _AddInitMethod(message_descriptor, cls):
|
| """Adds an __init__ method to cls."""
|
| - fields = message_descriptor.fields
|
| +
|
| + def _GetIntegerEnumValue(enum_type, value):
|
| + """Convert a string or integer enum value to an integer.
|
| +
|
| + If the value is a string, it is converted to the enum value in
|
| + enum_type with the same name. If the value is not a string, it's
|
| + returned as-is. (No conversion or bounds-checking is done.)
|
| + """
|
| + if isinstance(value, six.string_types):
|
| + try:
|
| + return enum_type.values_by_name[value].number
|
| + except KeyError:
|
| + raise ValueError('Enum type %s: unknown label "%s"' % (
|
| + enum_type.full_name, value))
|
| + return value
|
| +
|
| def init(self, **kwargs):
|
| self._cached_byte_size = 0
|
| self._cached_byte_size_dirty = len(kwargs) > 0
|
| self._fields = {}
|
| + # Contains a mapping from oneof field descriptors to the descriptor
|
| + # of the currently set field in that oneof field.
|
| + self._oneofs = {}
|
| +
|
| # _unknown_fields is () when empty for efficiency, and will be turned into
|
| # a list if fields are added.
|
| self._unknown_fields = ()
|
| self._is_present_in_parent = False
|
| self._listener = message_listener_mod.NullMessageListener()
|
| self._listener_for_children = _Listener(self)
|
| - for field_name, field_value in kwargs.iteritems():
|
| + for field_name, field_value in kwargs.items():
|
| field = _GetFieldByName(message_descriptor, field_name)
|
| if field is None:
|
| raise TypeError("%s() got an unexpected keyword argument '%s'" %
|
| @@ -315,17 +493,41 @@ def _AddInitMethod(message_descriptor, cls):
|
| if field.label == _FieldDescriptor.LABEL_REPEATED:
|
| copy = field._default_constructor(self)
|
| if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: # Composite
|
| - for val in field_value:
|
| - copy.add().MergeFrom(val)
|
| + if _IsMapField(field):
|
| + if _IsMessageMapField(field):
|
| + for key in field_value:
|
| + copy[key].MergeFrom(field_value[key])
|
| + else:
|
| + copy.update(field_value)
|
| + else:
|
| + for val in field_value:
|
| + if isinstance(val, dict):
|
| + copy.add(**val)
|
| + else:
|
| + copy.add().MergeFrom(val)
|
| else: # Scalar
|
| + if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
|
| + field_value = [_GetIntegerEnumValue(field.enum_type, val)
|
| + for val in field_value]
|
| copy.extend(field_value)
|
| self._fields[field] = copy
|
| elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
|
| copy = field._default_constructor(self)
|
| - copy.MergeFrom(field_value)
|
| + new_val = field_value
|
| + if isinstance(field_value, dict):
|
| + new_val = field.message_type._concrete_class(**field_value)
|
| + try:
|
| + copy.MergeFrom(new_val)
|
| + except TypeError:
|
| + _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name)
|
| self._fields[field] = copy
|
| else:
|
| - setattr(self, field_name, field_value)
|
| + if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
|
| + field_value = _GetIntegerEnumValue(field.enum_type, field_value)
|
| + try:
|
| + setattr(self, field_name, field_value)
|
| + except TypeError:
|
| + _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name)
|
|
|
| init.__module__ = None
|
| init.__doc__ = None
|
| @@ -344,7 +546,8 @@ def _GetFieldByName(message_descriptor, field_name):
|
| try:
|
| return message_descriptor.fields_by_name[field_name]
|
| except KeyError:
|
| - raise ValueError('Protocol message has no "%s" field.' % field_name)
|
| + raise ValueError('Protocol message %s has no "%s" field.' %
|
| + (message_descriptor.name, field_name))
|
|
|
|
|
| def _AddPropertiesForFields(descriptor, cls):
|
| @@ -440,9 +643,10 @@ def _AddPropertiesForNonRepeatedScalarField(field, cls):
|
| """
|
| proto_field_name = field.name
|
| property_name = _PropertyName(proto_field_name)
|
| - type_checker = type_checkers.GetTypeChecker(field.cpp_type, field.type)
|
| + type_checker = type_checkers.GetTypeChecker(field)
|
| default_value = field.default_value
|
| valid_values = set()
|
| + is_proto3 = field.containing_type.syntax == "proto3"
|
|
|
| def getter(self):
|
| # TODO(protobuf-team): This may be broken since there may not be
|
| @@ -450,14 +654,30 @@ def _AddPropertiesForNonRepeatedScalarField(field, cls):
|
| return self._fields.get(field, default_value)
|
| getter.__module__ = None
|
| getter.__doc__ = 'Getter for %s.' % proto_field_name
|
| - def setter(self, new_value):
|
| - type_checker.CheckValue(new_value)
|
| - self._fields[field] = new_value
|
| +
|
| + clear_when_set_to_default = is_proto3 and not field.containing_oneof
|
| +
|
| + def field_setter(self, new_value):
|
| + # pylint: disable=protected-access
|
| + # Testing the value for truthiness captures all of the proto3 defaults
|
| + # (0, 0.0, enum 0, and False).
|
| + new_value = type_checker.CheckValue(new_value)
|
| + if clear_when_set_to_default and not new_value:
|
| + self._fields.pop(field, None)
|
| + else:
|
| + self._fields[field] = new_value
|
| # Check _cached_byte_size_dirty inline to improve performance, since scalar
|
| # setters are called frequently.
|
| if not self._cached_byte_size_dirty:
|
| self._Modified()
|
|
|
| + if field.containing_oneof:
|
| + def setter(self, new_value):
|
| + field_setter(self, new_value)
|
| + self._UpdateOneofState(field)
|
| + else:
|
| + setter = field_setter
|
| +
|
| setter.__module__ = None
|
| setter.__doc__ = 'Setter for %s.' % proto_field_name
|
|
|
| @@ -482,18 +702,11 @@ def _AddPropertiesForNonRepeatedCompositeField(field, cls):
|
| proto_field_name = field.name
|
| property_name = _PropertyName(proto_field_name)
|
|
|
| - # TODO(komarek): Can anyone explain to me why we cache the message_type this
|
| - # way, instead of referring to field.message_type inside of getter(self)?
|
| - # What if someone sets message_type later on (which makes for simpler
|
| - # dyanmic proto descriptor and class creation code).
|
| - message_type = field.message_type
|
| -
|
| def getter(self):
|
| field_value = self._fields.get(field)
|
| if field_value is None:
|
| # Construct a new object to represent this field.
|
| - field_value = message_type._concrete_class() # use field.message_type?
|
| - field_value._SetListener(self._listener_for_children)
|
| + field_value = field._default_constructor(self)
|
|
|
| # Atomically check if another thread has preempted us and, if not, swap
|
| # in the new object we just created. If someone has preempted us, we
|
| @@ -520,7 +733,7 @@ def _AddPropertiesForNonRepeatedCompositeField(field, cls):
|
| def _AddPropertiesForExtensions(descriptor, cls):
|
| """Adds properties for all fields in this protocol message type."""
|
| extension_dict = descriptor.extensions_by_name
|
| - for extension_name, extension_field in extension_dict.iteritems():
|
| + for extension_name, extension_field in extension_dict.items():
|
| constant_name = extension_name.upper() + "_FIELD_NUMBER"
|
| setattr(cls, constant_name, extension_field.number)
|
|
|
| @@ -575,33 +788,54 @@ def _AddListFieldsMethod(message_descriptor, cls):
|
| """Helper for _AddMessageMethods()."""
|
|
|
| def ListFields(self):
|
| - all_fields = [item for item in self._fields.iteritems() if _IsPresent(item)]
|
| + all_fields = [item for item in self._fields.items() if _IsPresent(item)]
|
| all_fields.sort(key = lambda item: item[0].number)
|
| return all_fields
|
|
|
| cls.ListFields = ListFields
|
|
|
| +_Proto3HasError = 'Protocol message has no non-repeated submessage field "%s"'
|
| +_Proto2HasError = 'Protocol message has no non-repeated field "%s"'
|
|
|
| def _AddHasFieldMethod(message_descriptor, cls):
|
| """Helper for _AddMessageMethods()."""
|
|
|
| - singular_fields = {}
|
| + is_proto3 = (message_descriptor.syntax == "proto3")
|
| + error_msg = _Proto3HasError if is_proto3 else _Proto2HasError
|
| +
|
| + hassable_fields = {}
|
| for field in message_descriptor.fields:
|
| - if field.label != _FieldDescriptor.LABEL_REPEATED:
|
| - singular_fields[field.name] = field
|
| + if field.label == _FieldDescriptor.LABEL_REPEATED:
|
| + continue
|
| + # For proto3, only submessages and fields inside a oneof have presence.
|
| + if (is_proto3 and field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE and
|
| + not field.containing_oneof):
|
| + continue
|
| + hassable_fields[field.name] = field
|
| +
|
| + if not is_proto3:
|
| + # Fields inside oneofs are never repeated (enforced by the compiler).
|
| + for oneof in message_descriptor.oneofs:
|
| + hassable_fields[oneof.name] = oneof
|
|
|
| def HasField(self, field_name):
|
| try:
|
| - field = singular_fields[field_name]
|
| + field = hassable_fields[field_name]
|
| except KeyError:
|
| - raise ValueError(
|
| - 'Protocol message has no singular "%s" field.' % field_name)
|
| + raise ValueError(error_msg % field_name)
|
|
|
| - if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
|
| - value = self._fields.get(field)
|
| - return value is not None and value._is_present_in_parent
|
| + if isinstance(field, descriptor_mod.OneofDescriptor):
|
| + try:
|
| + return HasField(self, self._oneofs[field].name)
|
| + except KeyError:
|
| + return False
|
| else:
|
| - return field in self._fields
|
| + if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
|
| + value = self._fields.get(field)
|
| + return value is not None and value._is_present_in_parent
|
| + else:
|
| + return field in self._fields
|
| +
|
| cls.HasField = HasField
|
|
|
|
|
| @@ -611,14 +845,30 @@ def _AddClearFieldMethod(message_descriptor, cls):
|
| try:
|
| field = message_descriptor.fields_by_name[field_name]
|
| except KeyError:
|
| - raise ValueError('Protocol message has no "%s" field.' % field_name)
|
| + try:
|
| + field = message_descriptor.oneofs_by_name[field_name]
|
| + if field in self._oneofs:
|
| + field = self._oneofs[field]
|
| + else:
|
| + return
|
| + except KeyError:
|
| + raise ValueError('Protocol message %s() has no "%s" field.' %
|
| + (message_descriptor.name, field_name))
|
|
|
| if field in self._fields:
|
| + # To match the C++ implementation, we need to invalidate iterators
|
| + # for map fields when ClearField() happens.
|
| + if hasattr(self._fields[field], 'InvalidateIterators'):
|
| + self._fields[field].InvalidateIterators()
|
| +
|
| # Note: If the field is a sub-message, its listener will still point
|
| # at us. That's fine, because the worst than can happen is that it
|
| # will call _Modified() and invalidate our byte size. Big deal.
|
| del self._fields[field]
|
|
|
| + if self._oneofs.get(field.containing_oneof, None) is field:
|
| + del self._oneofs[field.containing_oneof]
|
| +
|
| # Always call _Modified() -- even if nothing was changed, this is
|
| # a mutating method, and thus calling it should cause the field to become
|
| # present in the parent message.
|
| @@ -645,6 +895,7 @@ def _AddClearMethod(message_descriptor, cls):
|
| # Clear fields.
|
| self._fields = {}
|
| self._unknown_fields = ()
|
| + self._oneofs = {}
|
| self._Modified()
|
| cls.Clear = Clear
|
|
|
| @@ -663,6 +914,38 @@ def _AddHasExtensionMethod(cls):
|
| return extension_handle in self._fields
|
| cls.HasExtension = HasExtension
|
|
|
| +def _InternalUnpackAny(msg):
|
| + """Unpacks Any message and returns the unpacked message.
|
| +
|
| + This internal method is differnt from public Any Unpack method which takes
|
| + the target message as argument. _InternalUnpackAny method does not have
|
| + target message type and need to find the message type in descriptor pool.
|
| +
|
| + Args:
|
| + msg: An Any message to be unpacked.
|
| +
|
| + Returns:
|
| + The unpacked message.
|
| + """
|
| + type_url = msg.type_url
|
| + db = symbol_database.Default()
|
| +
|
| + if not type_url:
|
| + return None
|
| +
|
| + # TODO(haberman): For now we just strip the hostname. Better logic will be
|
| + # required.
|
| + type_name = type_url.split("/")[-1]
|
| + descriptor = db.pool.FindMessageTypeByName(type_name)
|
| +
|
| + if descriptor is None:
|
| + return None
|
| +
|
| + message_class = db.GetPrototype(descriptor)
|
| + message = message_class()
|
| +
|
| + message.ParseFromString(msg.value)
|
| + return message
|
|
|
| def _AddEqualsMethod(message_descriptor, cls):
|
| """Helper for _AddMessageMethods()."""
|
| @@ -674,6 +957,12 @@ def _AddEqualsMethod(message_descriptor, cls):
|
| if self is other:
|
| return True
|
|
|
| + if self.DESCRIPTOR.full_name == _AnyFullTypeName:
|
| + any_a = _InternalUnpackAny(self)
|
| + any_b = _InternalUnpackAny(other)
|
| + if any_a and any_b:
|
| + return any_a == any_b
|
| +
|
| if not self.ListFields() == other.ListFields():
|
| return False
|
|
|
| @@ -695,6 +984,13 @@ def _AddStrMethod(message_descriptor, cls):
|
| cls.__str__ = __str__
|
|
|
|
|
| +def _AddReprMethod(message_descriptor, cls):
|
| + """Helper for _AddMessageMethods()."""
|
| + def __repr__(self):
|
| + return text_format.MessageToString(self)
|
| + cls.__repr__ = __repr__
|
| +
|
| +
|
| def _AddUnicodeMethod(unused_message_descriptor, cls):
|
| """Helper for _AddMessageMethods()."""
|
|
|
| @@ -773,7 +1069,7 @@ def _AddSerializePartialToStringMethod(message_descriptor, cls):
|
| """Helper for _AddMessageMethods()."""
|
|
|
| def SerializePartialToString(self):
|
| - out = StringIO()
|
| + out = BytesIO()
|
| self._InternalSerialize(out.write)
|
| return out.getvalue()
|
| cls.SerializePartialToString = SerializePartialToString
|
| @@ -796,9 +1092,10 @@ def _AddMergeFromStringMethod(message_descriptor, cls):
|
| # The only reason _InternalParse would return early is if it
|
| # encountered an end-group tag.
|
| raise message_mod.DecodeError('Unexpected end-group tag.')
|
| - except IndexError:
|
| + except (IndexError, TypeError):
|
| + # Now ord(buf[p:p+1]) == ord('') gets TypeError.
|
| raise message_mod.DecodeError('Truncated message.')
|
| - except struct.error, e:
|
| + except struct.error as e:
|
| raise message_mod.DecodeError(e)
|
| return length # Return this for legacy reasons.
|
| cls.MergeFromString = MergeFromString
|
| @@ -806,6 +1103,7 @@ def _AddMergeFromStringMethod(message_descriptor, cls):
|
| local_ReadTag = decoder.ReadTag
|
| local_SkipField = decoder.SkipField
|
| decoders_by_tag = cls._decoders_by_tag
|
| + is_proto3 = message_descriptor.syntax == "proto3"
|
|
|
| def InternalParse(self, buffer, pos, end):
|
| self._Modified()
|
| @@ -813,18 +1111,22 @@ def _AddMergeFromStringMethod(message_descriptor, cls):
|
| unknown_field_list = self._unknown_fields
|
| while pos != end:
|
| (tag_bytes, new_pos) = local_ReadTag(buffer, pos)
|
| - field_decoder = decoders_by_tag.get(tag_bytes)
|
| + field_decoder, field_desc = decoders_by_tag.get(tag_bytes, (None, None))
|
| if field_decoder is None:
|
| value_start_pos = new_pos
|
| new_pos = local_SkipField(buffer, new_pos, end, tag_bytes)
|
| if new_pos == -1:
|
| return pos
|
| - if not unknown_field_list:
|
| - unknown_field_list = self._unknown_fields = []
|
| - unknown_field_list.append((tag_bytes, buffer[value_start_pos:new_pos]))
|
| + if not is_proto3:
|
| + if not unknown_field_list:
|
| + unknown_field_list = self._unknown_fields = []
|
| + unknown_field_list.append(
|
| + (tag_bytes, buffer[value_start_pos:new_pos]))
|
| pos = new_pos
|
| else:
|
| pos = field_decoder(buffer, new_pos, end, self, field_dict)
|
| + if field_desc:
|
| + self._UpdateOneofState(field_desc)
|
| return pos
|
| cls._InternalParse = InternalParse
|
|
|
| @@ -857,9 +1159,12 @@ def _AddIsInitializedMethod(message_descriptor, cls):
|
| errors.extend(self.FindInitializationErrors())
|
| return False
|
|
|
| - for field, value in self._fields.iteritems():
|
| + for field, value in list(self._fields.items()): # dict can change size!
|
| if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
|
| if field.label == _FieldDescriptor.LABEL_REPEATED:
|
| + if (field.message_type.has_options and
|
| + field.message_type.GetOptions().map_entry):
|
| + continue
|
| for element in value:
|
| if not element.IsInitialized():
|
| if errors is not None:
|
| @@ -895,16 +1200,26 @@ def _AddIsInitializedMethod(message_descriptor, cls):
|
| else:
|
| name = field.name
|
|
|
| - if field.label == _FieldDescriptor.LABEL_REPEATED:
|
| - for i in xrange(len(value)):
|
| + if _IsMapField(field):
|
| + if _IsMessageMapField(field):
|
| + for key in value:
|
| + element = value[key]
|
| + prefix = "%s[%s]." % (name, key)
|
| + sub_errors = element.FindInitializationErrors()
|
| + errors += [prefix + error for error in sub_errors]
|
| + else:
|
| + # ScalarMaps can't have any initialization errors.
|
| + pass
|
| + elif field.label == _FieldDescriptor.LABEL_REPEATED:
|
| + for i in range(len(value)):
|
| element = value[i]
|
| prefix = "%s[%d]." % (name, i)
|
| sub_errors = element.FindInitializationErrors()
|
| - errors += [ prefix + error for error in sub_errors ]
|
| + errors += [prefix + error for error in sub_errors]
|
| else:
|
| prefix = name + "."
|
| sub_errors = value.FindInitializationErrors()
|
| - errors += [ prefix + error for error in sub_errors ]
|
| + errors += [prefix + error for error in sub_errors]
|
|
|
| return errors
|
|
|
| @@ -926,7 +1241,7 @@ def _AddMergeFromMethod(cls):
|
|
|
| fields = self._fields
|
|
|
| - for field, value in msg._fields.iteritems():
|
| + for field, value in msg._fields.items():
|
| if field.label == LABEL_REPEATED:
|
| field_value = fields.get(field)
|
| if field_value is None:
|
| @@ -944,6 +1259,8 @@ def _AddMergeFromMethod(cls):
|
| field_value.MergeFrom(value)
|
| else:
|
| self._fields[field] = value
|
| + if field.containing_oneof:
|
| + self._UpdateOneofState(field)
|
|
|
| if msg._unknown_fields:
|
| if not self._unknown_fields:
|
| @@ -953,6 +1270,24 @@ def _AddMergeFromMethod(cls):
|
| cls.MergeFrom = MergeFrom
|
|
|
|
|
| +def _AddWhichOneofMethod(message_descriptor, cls):
|
| + def WhichOneof(self, oneof_name):
|
| + """Returns the name of the currently set field inside a oneof, or None."""
|
| + try:
|
| + field = message_descriptor.oneofs_by_name[oneof_name]
|
| + except KeyError:
|
| + raise ValueError(
|
| + 'Protocol message has no oneof "%s" field.' % oneof_name)
|
| +
|
| + nested_field = self._oneofs.get(field, None)
|
| + if nested_field is not None and self.HasField(nested_field.name):
|
| + return nested_field.name
|
| + else:
|
| + return None
|
| +
|
| + cls.WhichOneof = WhichOneof
|
| +
|
| +
|
| def _AddMessageMethods(message_descriptor, cls):
|
| """Adds implementations of all Message methods to cls."""
|
| _AddListFieldsMethod(message_descriptor, cls)
|
| @@ -964,6 +1299,7 @@ def _AddMessageMethods(message_descriptor, cls):
|
| _AddClearMethod(message_descriptor, cls)
|
| _AddEqualsMethod(message_descriptor, cls)
|
| _AddStrMethod(message_descriptor, cls)
|
| + _AddReprMethod(message_descriptor, cls)
|
| _AddUnicodeMethod(message_descriptor, cls)
|
| _AddSetListenerMethod(cls)
|
| _AddByteSizeMethod(message_descriptor, cls)
|
| @@ -972,9 +1308,10 @@ def _AddMessageMethods(message_descriptor, cls):
|
| _AddMergeFromStringMethod(message_descriptor, cls)
|
| _AddIsInitializedMethod(message_descriptor, cls)
|
| _AddMergeFromMethod(cls)
|
| + _AddWhichOneofMethod(message_descriptor, cls)
|
|
|
|
|
| -def _AddPrivateHelperMethods(cls):
|
| +def _AddPrivateHelperMethods(message_descriptor, cls):
|
| """Adds implementation of private helper methods to cls."""
|
|
|
| def Modified(self):
|
| @@ -992,8 +1329,20 @@ def _AddPrivateHelperMethods(cls):
|
| self._is_present_in_parent = True
|
| self._listener.Modified()
|
|
|
| + def _UpdateOneofState(self, field):
|
| + """Sets field as the active field in its containing oneof.
|
| +
|
| + Will also delete currently active field in the oneof, if it is different
|
| + from the argument. Does not mark the message as modified.
|
| + """
|
| + other_field = self._oneofs.setdefault(field.containing_oneof, field)
|
| + if other_field is not field:
|
| + del self._fields[other_field]
|
| + self._oneofs[field.containing_oneof] = field
|
| +
|
| cls._Modified = Modified
|
| cls.SetInParent = Modified
|
| + cls._UpdateOneofState = _UpdateOneofState
|
|
|
|
|
| class _Listener(object):
|
| @@ -1042,6 +1391,27 @@ class _Listener(object):
|
| pass
|
|
|
|
|
| +class _OneofListener(_Listener):
|
| + """Special listener implementation for setting composite oneof fields."""
|
| +
|
| + def __init__(self, parent_message, field):
|
| + """Args:
|
| + parent_message: The message whose _Modified() method we should call when
|
| + we receive Modified() messages.
|
| + field: The descriptor of the field being set in the parent message.
|
| + """
|
| + super(_OneofListener, self).__init__(parent_message)
|
| + self._field = field
|
| +
|
| + def Modified(self):
|
| + """Also updates the state of the containing oneof in the parent message."""
|
| + try:
|
| + self._parent_message_weakref._UpdateOneofState(self._field)
|
| + super(_OneofListener, self).Modified()
|
| + except ReferenceError:
|
| + pass
|
| +
|
| +
|
| # TODO(robinson): Move elsewhere? This file is getting pretty ridiculous...
|
| # TODO(robinson): Unify error handling of "unknown extension" crap.
|
| # TODO(robinson): Support iteritems()-style iteration over all
|
| @@ -1132,10 +1502,10 @@ class _ExtensionDict(object):
|
|
|
| # It's slightly wasteful to lookup the type checker each time,
|
| # but we expect this to be a vanishingly uncommon case anyway.
|
| - type_checker = type_checkers.GetTypeChecker(
|
| - extension_handle.cpp_type, extension_handle.type)
|
| - type_checker.CheckValue(value)
|
| - self._extended_message._fields[extension_handle] = value
|
| + type_checker = type_checkers.GetTypeChecker(extension_handle)
|
| + # pylint: disable=protected-access
|
| + self._extended_message._fields[extension_handle] = (
|
| + type_checker.CheckValue(value))
|
| self._extended_message._Modified()
|
|
|
| def _FindExtensionByName(self, name):
|
|
|