| OLD | NEW |
| 1 # Protocol Buffers - Google's data interchange format | 1 # Protocol Buffers - Google's data interchange format |
| 2 # Copyright 2008 Google Inc. All rights reserved. | 2 # Copyright 2008 Google Inc. All rights reserved. |
| 3 # http://code.google.com/p/protobuf/ | 3 # https://developers.google.com/protocol-buffers/ |
| 4 # | 4 # |
| 5 # Redistribution and use in source and binary forms, with or without | 5 # Redistribution and use in source and binary forms, with or without |
| 6 # modification, are permitted provided that the following conditions are | 6 # modification, are permitted provided that the following conditions are |
| 7 # met: | 7 # met: |
| 8 # | 8 # |
| 9 # * Redistributions of source code must retain the above copyright | 9 # * Redistributions of source code must retain the above copyright |
| 10 # notice, this list of conditions and the following disclaimer. | 10 # notice, this list of conditions and the following disclaimer. |
| 11 # * Redistributions in binary form must reproduce the above | 11 # * Redistributions in binary form must reproduce the above |
| 12 # copyright notice, this list of conditions and the following disclaimer | 12 # copyright notice, this list of conditions and the following disclaimer |
| 13 # in the documentation and/or other materials provided with the | 13 # in the documentation and/or other materials provided with the |
| (...skipping 29 matching lines...) Expand all Loading... |
| 43 to inject all the useful functionality into the classes | 43 to inject all the useful functionality into the classes |
| 44 output by the protocol compiler at compile-time. | 44 output by the protocol compiler at compile-time. |
| 45 | 45 |
| 46 The upshot of all this is that the real implementation | 46 The upshot of all this is that the real implementation |
| 47 details for ALL pure-Python protocol buffers are *here in | 47 details for ALL pure-Python protocol buffers are *here in |
| 48 this file*. | 48 this file*. |
| 49 """ | 49 """ |
| 50 | 50 |
| 51 __author__ = 'robinson@google.com (Will Robinson)' | 51 __author__ = 'robinson@google.com (Will Robinson)' |
| 52 | 52 |
| 53 try: | 53 from io import BytesIO |
| 54 from cStringIO import StringIO | 54 import sys |
| 55 except ImportError: | |
| 56 from StringIO import StringIO | |
| 57 import copy_reg | |
| 58 import struct | 55 import struct |
| 59 import weakref | 56 import weakref |
| 60 | 57 |
| 58 import six |
| 59 import six.moves.copyreg as copyreg |
| 60 |
| 61 # We use "as" to avoid name collisions with variables. | 61 # We use "as" to avoid name collisions with variables. |
| 62 from google.protobuf.internal import containers | 62 from google.protobuf.internal import containers |
| 63 from google.protobuf.internal import decoder | 63 from google.protobuf.internal import decoder |
| 64 from google.protobuf.internal import encoder | 64 from google.protobuf.internal import encoder |
| 65 from google.protobuf.internal import enum_type_wrapper | 65 from google.protobuf.internal import enum_type_wrapper |
| 66 from google.protobuf.internal import message_listener as message_listener_mod | 66 from google.protobuf.internal import message_listener as message_listener_mod |
| 67 from google.protobuf.internal import type_checkers | 67 from google.protobuf.internal import type_checkers |
| 68 from google.protobuf.internal import well_known_types |
| 68 from google.protobuf.internal import wire_format | 69 from google.protobuf.internal import wire_format |
| 69 from google.protobuf import descriptor as descriptor_mod | 70 from google.protobuf import descriptor as descriptor_mod |
| 70 from google.protobuf import message as message_mod | 71 from google.protobuf import message as message_mod |
| 72 from google.protobuf import symbol_database |
| 71 from google.protobuf import text_format | 73 from google.protobuf import text_format |
| 72 | 74 |
| 73 _FieldDescriptor = descriptor_mod.FieldDescriptor | 75 _FieldDescriptor = descriptor_mod.FieldDescriptor |
| 76 _AnyFullTypeName = 'google.protobuf.Any' |
| 74 | 77 |
| 75 | 78 |
| 76 def NewMessage(bases, descriptor, dictionary): | 79 class GeneratedProtocolMessageType(type): |
| 77 _AddClassAttributesForNestedExtensions(descriptor, dictionary) | |
| 78 _AddSlots(descriptor, dictionary) | |
| 79 return bases | |
| 80 | 80 |
| 81 """Metaclass for protocol message classes created at runtime from Descriptors. |
| 81 | 82 |
| 82 def InitMessage(descriptor, cls): | 83 We add implementations for all methods described in the Message class. We |
| 83 cls._decoders_by_tag = {} | 84 also create properties to allow getting/setting all fields in the protocol |
| 84 cls._extensions_by_name = {} | 85 message. Finally, we create slots to prevent users from accidentally |
| 85 cls._extensions_by_number = {} | 86 "setting" nonexistent fields in the protocol message, which then wouldn't get |
| 86 if (descriptor.has_options and | 87 serialized / deserialized properly. |
| 87 descriptor.GetOptions().message_set_wire_format): | |
| 88 cls._decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = ( | |
| 89 decoder.MessageSetItemDecoder(cls._extensions_by_number)) | |
| 90 | 88 |
| 91 # Attach stuff to each FieldDescriptor for quick lookup later on. | 89 The protocol compiler currently uses this metaclass to create protocol |
| 92 for field in descriptor.fields: | 90 message classes at runtime. Clients can also manually create their own |
| 93 _AttachFieldHelpers(cls, field) | 91 classes at runtime, as in this example: |
| 94 | 92 |
| 95 _AddEnumValues(descriptor, cls) | 93 mydescriptor = Descriptor(.....) |
| 96 _AddInitMethod(descriptor, cls) | 94 class MyProtoClass(Message): |
| 97 _AddPropertiesForFields(descriptor, cls) | 95 __metaclass__ = GeneratedProtocolMessageType |
| 98 _AddPropertiesForExtensions(descriptor, cls) | 96 DESCRIPTOR = mydescriptor |
| 99 _AddStaticMethods(cls) | 97 myproto_instance = MyProtoClass() |
| 100 _AddMessageMethods(descriptor, cls) | 98 myproto.foo_field = 23 |
| 101 _AddPrivateHelperMethods(cls) | 99 ... |
| 102 copy_reg.pickle(cls, lambda obj: (cls, (), obj.__getstate__())) | 100 |
| 101 The above example will not work for nested types. If you wish to include them, |
| 102 use reflection.MakeClass() instead of manually instantiating the class in |
| 103 order to create the appropriate class structure. |
| 104 """ |
| 105 |
| 106 # Must be consistent with the protocol-compiler code in |
| 107 # proto2/compiler/internal/generator.*. |
| 108 _DESCRIPTOR_KEY = 'DESCRIPTOR' |
| 109 |
| 110 def __new__(cls, name, bases, dictionary): |
| 111 """Custom allocation for runtime-generated class types. |
| 112 |
| 113 We override __new__ because this is apparently the only place |
| 114 where we can meaningfully set __slots__ on the class we're creating(?). |
| 115 (The interplay between metaclasses and slots is not very well-documented). |
| 116 |
| 117 Args: |
| 118 name: Name of the class (ignored, but required by the |
| 119 metaclass protocol). |
| 120 bases: Base classes of the class we're constructing. |
| 121 (Should be message.Message). We ignore this field, but |
| 122 it's required by the metaclass protocol |
| 123 dictionary: The class dictionary of the class we're |
| 124 constructing. dictionary[_DESCRIPTOR_KEY] must contain |
| 125 a Descriptor object describing this protocol message |
| 126 type. |
| 127 |
| 128 Returns: |
| 129 Newly-allocated class. |
| 130 """ |
| 131 descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY] |
| 132 if descriptor.full_name in well_known_types.WKTBASES: |
| 133 bases += (well_known_types.WKTBASES[descriptor.full_name],) |
| 134 _AddClassAttributesForNestedExtensions(descriptor, dictionary) |
| 135 _AddSlots(descriptor, dictionary) |
| 136 |
| 137 superclass = super(GeneratedProtocolMessageType, cls) |
| 138 new_class = superclass.__new__(cls, name, bases, dictionary) |
| 139 return new_class |
| 140 |
| 141 def __init__(cls, name, bases, dictionary): |
| 142 """Here we perform the majority of our work on the class. |
| 143 We add enum getters, an __init__ method, implementations |
| 144 of all Message methods, and properties for all fields |
| 145 in the protocol type. |
| 146 |
| 147 Args: |
| 148 name: Name of the class (ignored, but required by the |
| 149 metaclass protocol). |
| 150 bases: Base classes of the class we're constructing. |
| 151 (Should be message.Message). We ignore this field, but |
| 152 it's required by the metaclass protocol |
| 153 dictionary: The class dictionary of the class we're |
| 154 constructing. dictionary[_DESCRIPTOR_KEY] must contain |
| 155 a Descriptor object describing this protocol message |
| 156 type. |
| 157 """ |
| 158 descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY] |
| 159 cls._decoders_by_tag = {} |
| 160 cls._extensions_by_name = {} |
| 161 cls._extensions_by_number = {} |
| 162 if (descriptor.has_options and |
| 163 descriptor.GetOptions().message_set_wire_format): |
| 164 cls._decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = ( |
| 165 decoder.MessageSetItemDecoder(cls._extensions_by_number), None) |
| 166 |
| 167 # Attach stuff to each FieldDescriptor for quick lookup later on. |
| 168 for field in descriptor.fields: |
| 169 _AttachFieldHelpers(cls, field) |
| 170 |
| 171 descriptor._concrete_class = cls # pylint: disable=protected-access |
| 172 _AddEnumValues(descriptor, cls) |
| 173 _AddInitMethod(descriptor, cls) |
| 174 _AddPropertiesForFields(descriptor, cls) |
| 175 _AddPropertiesForExtensions(descriptor, cls) |
| 176 _AddStaticMethods(cls) |
| 177 _AddMessageMethods(descriptor, cls) |
| 178 _AddPrivateHelperMethods(descriptor, cls) |
| 179 copyreg.pickle(cls, lambda obj: (cls, (), obj.__getstate__())) |
| 180 |
| 181 superclass = super(GeneratedProtocolMessageType, cls) |
| 182 superclass.__init__(name, bases, dictionary) |
| 103 | 183 |
| 104 | 184 |
| 105 # Stateless helpers for GeneratedProtocolMessageType below. | 185 # Stateless helpers for GeneratedProtocolMessageType below. |
| 106 # Outside clients should not access these directly. | 186 # Outside clients should not access these directly. |
| 107 # | 187 # |
| 108 # I opted not to make any of these methods on the metaclass, to make it more | 188 # I opted not to make any of these methods on the metaclass, to make it more |
| 109 # clear that I'm not really using any state there and to keep clients from | 189 # clear that I'm not really using any state there and to keep clients from |
| 110 # thinking that they have direct access to these construction helpers. | 190 # thinking that they have direct access to these construction helpers. |
| 111 | 191 |
| 112 | 192 |
| (...skipping 56 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
| 169 message_descriptor: A Descriptor instance describing this message type. | 249 message_descriptor: A Descriptor instance describing this message type. |
| 170 dictionary: Class dictionary to which we'll add a '__slots__' entry. | 250 dictionary: Class dictionary to which we'll add a '__slots__' entry. |
| 171 """ | 251 """ |
| 172 dictionary['__slots__'] = ['_cached_byte_size', | 252 dictionary['__slots__'] = ['_cached_byte_size', |
| 173 '_cached_byte_size_dirty', | 253 '_cached_byte_size_dirty', |
| 174 '_fields', | 254 '_fields', |
| 175 '_unknown_fields', | 255 '_unknown_fields', |
| 176 '_is_present_in_parent', | 256 '_is_present_in_parent', |
| 177 '_listener', | 257 '_listener', |
| 178 '_listener_for_children', | 258 '_listener_for_children', |
| 179 '__weakref__'] | 259 '__weakref__', |
| 260 '_oneofs'] |
| 180 | 261 |
| 181 | 262 |
| 182 def _IsMessageSetExtension(field): | 263 def _IsMessageSetExtension(field): |
| 183 return (field.is_extension and | 264 return (field.is_extension and |
| 184 field.containing_type.has_options and | 265 field.containing_type.has_options and |
| 185 field.containing_type.GetOptions().message_set_wire_format and | 266 field.containing_type.GetOptions().message_set_wire_format and |
| 186 field.type == _FieldDescriptor.TYPE_MESSAGE and | 267 field.type == _FieldDescriptor.TYPE_MESSAGE and |
| 187 field.message_type == field.extension_scope and | |
| 188 field.label == _FieldDescriptor.LABEL_OPTIONAL) | 268 field.label == _FieldDescriptor.LABEL_OPTIONAL) |
| 189 | 269 |
| 190 | 270 |
| 271 def _IsMapField(field): |
| 272 return (field.type == _FieldDescriptor.TYPE_MESSAGE and |
| 273 field.message_type.has_options and |
| 274 field.message_type.GetOptions().map_entry) |
| 275 |
| 276 |
| 277 def _IsMessageMapField(field): |
| 278 value_type = field.message_type.fields_by_name["value"] |
| 279 return value_type.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE |
| 280 |
| 281 |
| 191 def _AttachFieldHelpers(cls, field_descriptor): | 282 def _AttachFieldHelpers(cls, field_descriptor): |
| 192 is_repeated = (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED) | 283 is_repeated = (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED) |
| 193 is_packed = (field_descriptor.has_options and | 284 is_packable = (is_repeated and |
| 194 field_descriptor.GetOptions().packed) | 285 wire_format.IsTypePackable(field_descriptor.type)) |
| 286 if not is_packable: |
| 287 is_packed = False |
| 288 elif field_descriptor.containing_type.syntax == "proto2": |
| 289 is_packed = (field_descriptor.has_options and |
| 290 field_descriptor.GetOptions().packed) |
| 291 else: |
| 292 has_packed_false = (field_descriptor.has_options and |
| 293 field_descriptor.GetOptions().HasField("packed") and |
| 294 field_descriptor.GetOptions().packed == False) |
| 295 is_packed = not has_packed_false |
| 296 is_map_entry = _IsMapField(field_descriptor) |
| 195 | 297 |
| 196 if _IsMessageSetExtension(field_descriptor): | 298 if is_map_entry: |
| 299 field_encoder = encoder.MapEncoder(field_descriptor) |
| 300 sizer = encoder.MapSizer(field_descriptor) |
| 301 elif _IsMessageSetExtension(field_descriptor): |
| 197 field_encoder = encoder.MessageSetItemEncoder(field_descriptor.number) | 302 field_encoder = encoder.MessageSetItemEncoder(field_descriptor.number) |
| 198 sizer = encoder.MessageSetItemSizer(field_descriptor.number) | 303 sizer = encoder.MessageSetItemSizer(field_descriptor.number) |
| 199 else: | 304 else: |
| 200 field_encoder = type_checkers.TYPE_TO_ENCODER[field_descriptor.type]( | 305 field_encoder = type_checkers.TYPE_TO_ENCODER[field_descriptor.type]( |
| 201 field_descriptor.number, is_repeated, is_packed) | 306 field_descriptor.number, is_repeated, is_packed) |
| 202 sizer = type_checkers.TYPE_TO_SIZER[field_descriptor.type]( | 307 sizer = type_checkers.TYPE_TO_SIZER[field_descriptor.type]( |
| 203 field_descriptor.number, is_repeated, is_packed) | 308 field_descriptor.number, is_repeated, is_packed) |
| 204 | 309 |
| 205 field_descriptor._encoder = field_encoder | 310 field_descriptor._encoder = field_encoder |
| 206 field_descriptor._sizer = sizer | 311 field_descriptor._sizer = sizer |
| 207 field_descriptor._default_constructor = _DefaultValueConstructorForField( | 312 field_descriptor._default_constructor = _DefaultValueConstructorForField( |
| 208 field_descriptor) | 313 field_descriptor) |
| 209 | 314 |
| 210 def AddDecoder(wiretype, is_packed): | 315 def AddDecoder(wiretype, is_packed): |
| 211 tag_bytes = encoder.TagBytes(field_descriptor.number, wiretype) | 316 tag_bytes = encoder.TagBytes(field_descriptor.number, wiretype) |
| 212 cls._decoders_by_tag[tag_bytes] = ( | 317 decode_type = field_descriptor.type |
| 213 type_checkers.TYPE_TO_DECODER[field_descriptor.type]( | 318 if (decode_type == _FieldDescriptor.TYPE_ENUM and |
| 214 field_descriptor.number, is_repeated, is_packed, | 319 type_checkers.SupportsOpenEnums(field_descriptor)): |
| 215 field_descriptor, field_descriptor._default_constructor)) | 320 decode_type = _FieldDescriptor.TYPE_INT32 |
| 321 |
| 322 oneof_descriptor = None |
| 323 if field_descriptor.containing_oneof is not None: |
| 324 oneof_descriptor = field_descriptor |
| 325 |
| 326 if is_map_entry: |
| 327 is_message_map = _IsMessageMapField(field_descriptor) |
| 328 |
| 329 field_decoder = decoder.MapDecoder( |
| 330 field_descriptor, _GetInitializeDefaultForMap(field_descriptor), |
| 331 is_message_map) |
| 332 else: |
| 333 field_decoder = type_checkers.TYPE_TO_DECODER[decode_type]( |
| 334 field_descriptor.number, is_repeated, is_packed, |
| 335 field_descriptor, field_descriptor._default_constructor) |
| 336 |
| 337 cls._decoders_by_tag[tag_bytes] = (field_decoder, oneof_descriptor) |
| 216 | 338 |
| 217 AddDecoder(type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type], | 339 AddDecoder(type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type], |
| 218 False) | 340 False) |
| 219 | 341 |
| 220 if is_repeated and wire_format.IsTypePackable(field_descriptor.type): | 342 if is_repeated and wire_format.IsTypePackable(field_descriptor.type): |
| 221 # To support wire compatibility of adding packed = true, add a decoder for | 343 # To support wire compatibility of adding packed = true, add a decoder for |
| 222 # packed values regardless of the field's options. | 344 # packed values regardless of the field's options. |
| 223 AddDecoder(wire_format.WIRETYPE_LENGTH_DELIMITED, True) | 345 AddDecoder(wire_format.WIRETYPE_LENGTH_DELIMITED, True) |
| 224 | 346 |
| 225 | 347 |
| 226 def _AddClassAttributesForNestedExtensions(descriptor, dictionary): | 348 def _AddClassAttributesForNestedExtensions(descriptor, dictionary): |
| 227 extension_dict = descriptor.extensions_by_name | 349 extension_dict = descriptor.extensions_by_name |
| 228 for extension_name, extension_field in extension_dict.iteritems(): | 350 for extension_name, extension_field in extension_dict.items(): |
| 229 assert extension_name not in dictionary | 351 assert extension_name not in dictionary |
| 230 dictionary[extension_name] = extension_field | 352 dictionary[extension_name] = extension_field |
| 231 | 353 |
| 232 | 354 |
| 233 def _AddEnumValues(descriptor, cls): | 355 def _AddEnumValues(descriptor, cls): |
| 234 """Sets class-level attributes for all enum fields defined in this message. | 356 """Sets class-level attributes for all enum fields defined in this message. |
| 235 | 357 |
| 236 Also exporting a class-level object that can name enum values. | 358 Also exporting a class-level object that can name enum values. |
| 237 | 359 |
| 238 Args: | 360 Args: |
| 239 descriptor: Descriptor object for this message type. | 361 descriptor: Descriptor object for this message type. |
| 240 cls: Class we're constructing for this message type. | 362 cls: Class we're constructing for this message type. |
| 241 """ | 363 """ |
| 242 for enum_type in descriptor.enum_types: | 364 for enum_type in descriptor.enum_types: |
| 243 setattr(cls, enum_type.name, enum_type_wrapper.EnumTypeWrapper(enum_type)) | 365 setattr(cls, enum_type.name, enum_type_wrapper.EnumTypeWrapper(enum_type)) |
| 244 for enum_value in enum_type.values: | 366 for enum_value in enum_type.values: |
| 245 setattr(cls, enum_value.name, enum_value.number) | 367 setattr(cls, enum_value.name, enum_value.number) |
| 246 | 368 |
| 247 | 369 |
| 370 def _GetInitializeDefaultForMap(field): |
| 371 if field.label != _FieldDescriptor.LABEL_REPEATED: |
| 372 raise ValueError('map_entry set on non-repeated field %s' % ( |
| 373 field.name)) |
| 374 fields_by_name = field.message_type.fields_by_name |
| 375 key_checker = type_checkers.GetTypeChecker(fields_by_name['key']) |
| 376 |
| 377 value_field = fields_by_name['value'] |
| 378 if _IsMessageMapField(field): |
| 379 def MakeMessageMapDefault(message): |
| 380 return containers.MessageMap( |
| 381 message._listener_for_children, value_field.message_type, key_checker) |
| 382 return MakeMessageMapDefault |
| 383 else: |
| 384 value_checker = type_checkers.GetTypeChecker(value_field) |
| 385 def MakePrimitiveMapDefault(message): |
| 386 return containers.ScalarMap( |
| 387 message._listener_for_children, key_checker, value_checker) |
| 388 return MakePrimitiveMapDefault |
| 389 |
| 248 def _DefaultValueConstructorForField(field): | 390 def _DefaultValueConstructorForField(field): |
| 249 """Returns a function which returns a default value for a field. | 391 """Returns a function which returns a default value for a field. |
| 250 | 392 |
| 251 Args: | 393 Args: |
| 252 field: FieldDescriptor object for this field. | 394 field: FieldDescriptor object for this field. |
| 253 | 395 |
| 254 The returned function has one argument: | 396 The returned function has one argument: |
| 255 message: Message instance containing this field, or a weakref proxy | 397 message: Message instance containing this field, or a weakref proxy |
| 256 of same. | 398 of same. |
| 257 | 399 |
| 258 That function in turn returns a default value for this field. The default | 400 That function in turn returns a default value for this field. The default |
| 259 value may refer back to |message| via a weak reference. | 401 value may refer back to |message| via a weak reference. |
| 260 """ | 402 """ |
| 261 | 403 |
| 404 if _IsMapField(field): |
| 405 return _GetInitializeDefaultForMap(field) |
| 406 |
| 262 if field.label == _FieldDescriptor.LABEL_REPEATED: | 407 if field.label == _FieldDescriptor.LABEL_REPEATED: |
| 263 if field.has_default_value and field.default_value != []: | 408 if field.has_default_value and field.default_value != []: |
| 264 raise ValueError('Repeated field default value not empty list: %s' % ( | 409 raise ValueError('Repeated field default value not empty list: %s' % ( |
| 265 field.default_value)) | 410 field.default_value)) |
| 266 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: | 411 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: |
| 267 # We can't look at _concrete_class yet since it might not have | 412 # We can't look at _concrete_class yet since it might not have |
| 268 # been set. (Depends on order in which we initialize the classes). | 413 # been set. (Depends on order in which we initialize the classes). |
| 269 message_type = field.message_type | 414 message_type = field.message_type |
| 270 def MakeRepeatedMessageDefault(message): | 415 def MakeRepeatedMessageDefault(message): |
| 271 return containers.RepeatedCompositeFieldContainer( | 416 return containers.RepeatedCompositeFieldContainer( |
| 272 message._listener_for_children, field.message_type) | 417 message._listener_for_children, field.message_type) |
| 273 return MakeRepeatedMessageDefault | 418 return MakeRepeatedMessageDefault |
| 274 else: | 419 else: |
| 275 type_checker = type_checkers.GetTypeChecker(field.cpp_type, field.type) | 420 type_checker = type_checkers.GetTypeChecker(field) |
| 276 def MakeRepeatedScalarDefault(message): | 421 def MakeRepeatedScalarDefault(message): |
| 277 return containers.RepeatedScalarFieldContainer( | 422 return containers.RepeatedScalarFieldContainer( |
| 278 message._listener_for_children, type_checker) | 423 message._listener_for_children, type_checker) |
| 279 return MakeRepeatedScalarDefault | 424 return MakeRepeatedScalarDefault |
| 280 | 425 |
| 281 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: | 426 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: |
| 282 # _concrete_class may not yet be initialized. | 427 # _concrete_class may not yet be initialized. |
| 283 message_type = field.message_type | 428 message_type = field.message_type |
| 284 def MakeSubMessageDefault(message): | 429 def MakeSubMessageDefault(message): |
| 285 result = message_type._concrete_class() | 430 result = message_type._concrete_class() |
| 286 result._SetListener(message._listener_for_children) | 431 result._SetListener( |
| 432 _OneofListener(message, field) |
| 433 if field.containing_oneof is not None |
| 434 else message._listener_for_children) |
| 287 return result | 435 return result |
| 288 return MakeSubMessageDefault | 436 return MakeSubMessageDefault |
| 289 | 437 |
| 290 def MakeScalarDefault(message): | 438 def MakeScalarDefault(message): |
| 291 # TODO(protobuf-team): This may be broken since there may not be | 439 # TODO(protobuf-team): This may be broken since there may not be |
| 292 # default_value. Combine with has_default_value somehow. | 440 # default_value. Combine with has_default_value somehow. |
| 293 return field.default_value | 441 return field.default_value |
| 294 return MakeScalarDefault | 442 return MakeScalarDefault |
| 295 | 443 |
| 296 | 444 |
| 445 def _ReraiseTypeErrorWithFieldName(message_name, field_name): |
| 446 """Re-raise the currently-handled TypeError with the field name added.""" |
| 447 exc = sys.exc_info()[1] |
| 448 if len(exc.args) == 1 and type(exc) is TypeError: |
| 449 # simple TypeError; add field name to exception message |
| 450 exc = TypeError('%s for field %s.%s' % (str(exc), message_name, field_name)) |
| 451 |
| 452 # re-raise possibly-amended exception with original traceback: |
| 453 six.reraise(type(exc), exc, sys.exc_info()[2]) |
| 454 |
| 455 |
| 297 def _AddInitMethod(message_descriptor, cls): | 456 def _AddInitMethod(message_descriptor, cls): |
| 298 """Adds an __init__ method to cls.""" | 457 """Adds an __init__ method to cls.""" |
| 299 fields = message_descriptor.fields | 458 |
| 459 def _GetIntegerEnumValue(enum_type, value): |
| 460 """Convert a string or integer enum value to an integer. |
| 461 |
| 462 If the value is a string, it is converted to the enum value in |
| 463 enum_type with the same name. If the value is not a string, it's |
| 464 returned as-is. (No conversion or bounds-checking is done.) |
| 465 """ |
| 466 if isinstance(value, six.string_types): |
| 467 try: |
| 468 return enum_type.values_by_name[value].number |
| 469 except KeyError: |
| 470 raise ValueError('Enum type %s: unknown label "%s"' % ( |
| 471 enum_type.full_name, value)) |
| 472 return value |
| 473 |
| 300 def init(self, **kwargs): | 474 def init(self, **kwargs): |
| 301 self._cached_byte_size = 0 | 475 self._cached_byte_size = 0 |
| 302 self._cached_byte_size_dirty = len(kwargs) > 0 | 476 self._cached_byte_size_dirty = len(kwargs) > 0 |
| 303 self._fields = {} | 477 self._fields = {} |
| 478 # Contains a mapping from oneof field descriptors to the descriptor |
| 479 # of the currently set field in that oneof field. |
| 480 self._oneofs = {} |
| 481 |
| 304 # _unknown_fields is () when empty for efficiency, and will be turned into | 482 # _unknown_fields is () when empty for efficiency, and will be turned into |
| 305 # a list if fields are added. | 483 # a list if fields are added. |
| 306 self._unknown_fields = () | 484 self._unknown_fields = () |
| 307 self._is_present_in_parent = False | 485 self._is_present_in_parent = False |
| 308 self._listener = message_listener_mod.NullMessageListener() | 486 self._listener = message_listener_mod.NullMessageListener() |
| 309 self._listener_for_children = _Listener(self) | 487 self._listener_for_children = _Listener(self) |
| 310 for field_name, field_value in kwargs.iteritems(): | 488 for field_name, field_value in kwargs.items(): |
| 311 field = _GetFieldByName(message_descriptor, field_name) | 489 field = _GetFieldByName(message_descriptor, field_name) |
| 312 if field is None: | 490 if field is None: |
| 313 raise TypeError("%s() got an unexpected keyword argument '%s'" % | 491 raise TypeError("%s() got an unexpected keyword argument '%s'" % |
| 314 (message_descriptor.name, field_name)) | 492 (message_descriptor.name, field_name)) |
| 315 if field.label == _FieldDescriptor.LABEL_REPEATED: | 493 if field.label == _FieldDescriptor.LABEL_REPEATED: |
| 316 copy = field._default_constructor(self) | 494 copy = field._default_constructor(self) |
| 317 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: # Composite | 495 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: # Composite |
| 318 for val in field_value: | 496 if _IsMapField(field): |
| 319 copy.add().MergeFrom(val) | 497 if _IsMessageMapField(field): |
| 498 for key in field_value: |
| 499 copy[key].MergeFrom(field_value[key]) |
| 500 else: |
| 501 copy.update(field_value) |
| 502 else: |
| 503 for val in field_value: |
| 504 if isinstance(val, dict): |
| 505 copy.add(**val) |
| 506 else: |
| 507 copy.add().MergeFrom(val) |
| 320 else: # Scalar | 508 else: # Scalar |
| 509 if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM: |
| 510 field_value = [_GetIntegerEnumValue(field.enum_type, val) |
| 511 for val in field_value] |
| 321 copy.extend(field_value) | 512 copy.extend(field_value) |
| 322 self._fields[field] = copy | 513 self._fields[field] = copy |
| 323 elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: | 514 elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: |
| 324 copy = field._default_constructor(self) | 515 copy = field._default_constructor(self) |
| 325 copy.MergeFrom(field_value) | 516 new_val = field_value |
| 517 if isinstance(field_value, dict): |
| 518 new_val = field.message_type._concrete_class(**field_value) |
| 519 try: |
| 520 copy.MergeFrom(new_val) |
| 521 except TypeError: |
| 522 _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name) |
| 326 self._fields[field] = copy | 523 self._fields[field] = copy |
| 327 else: | 524 else: |
| 328 setattr(self, field_name, field_value) | 525 if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM: |
| 526 field_value = _GetIntegerEnumValue(field.enum_type, field_value) |
| 527 try: |
| 528 setattr(self, field_name, field_value) |
| 529 except TypeError: |
| 530 _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name) |
| 329 | 531 |
| 330 init.__module__ = None | 532 init.__module__ = None |
| 331 init.__doc__ = None | 533 init.__doc__ = None |
| 332 cls.__init__ = init | 534 cls.__init__ = init |
| 333 | 535 |
| 334 | 536 |
| 335 def _GetFieldByName(message_descriptor, field_name): | 537 def _GetFieldByName(message_descriptor, field_name): |
| 336 """Returns a field descriptor by field name. | 538 """Returns a field descriptor by field name. |
| 337 | 539 |
| 338 Args: | 540 Args: |
| 339 message_descriptor: A Descriptor describing all fields in message. | 541 message_descriptor: A Descriptor describing all fields in message. |
| 340 field_name: The name of the field to retrieve. | 542 field_name: The name of the field to retrieve. |
| 341 Returns: | 543 Returns: |
| 342 The field descriptor associated with the field name. | 544 The field descriptor associated with the field name. |
| 343 """ | 545 """ |
| 344 try: | 546 try: |
| 345 return message_descriptor.fields_by_name[field_name] | 547 return message_descriptor.fields_by_name[field_name] |
| 346 except KeyError: | 548 except KeyError: |
| 347 raise ValueError('Protocol message has no "%s" field.' % field_name) | 549 raise ValueError('Protocol message %s has no "%s" field.' % |
| 550 (message_descriptor.name, field_name)) |
| 348 | 551 |
| 349 | 552 |
| 350 def _AddPropertiesForFields(descriptor, cls): | 553 def _AddPropertiesForFields(descriptor, cls): |
| 351 """Adds properties for all fields in this protocol message type.""" | 554 """Adds properties for all fields in this protocol message type.""" |
| 352 for field in descriptor.fields: | 555 for field in descriptor.fields: |
| 353 _AddPropertiesForField(field, cls) | 556 _AddPropertiesForField(field, cls) |
| 354 | 557 |
| 355 if descriptor.is_extendable: | 558 if descriptor.is_extendable: |
| 356 # _ExtensionDict is just an adaptor with no state so we allocate a new one | 559 # _ExtensionDict is just an adaptor with no state so we allocate a new one |
| 357 # every time it is accessed. | 560 # every time it is accessed. |
| (...skipping 75 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
| 433 Note that when the client sets the value of a field by using this property, | 636 Note that when the client sets the value of a field by using this property, |
| 434 all necessary "has" bits are set as a side-effect, and we also perform | 637 all necessary "has" bits are set as a side-effect, and we also perform |
| 435 type-checking. | 638 type-checking. |
| 436 | 639 |
| 437 Args: | 640 Args: |
| 438 field: A FieldDescriptor for this field. | 641 field: A FieldDescriptor for this field. |
| 439 cls: The class we're constructing. | 642 cls: The class we're constructing. |
| 440 """ | 643 """ |
| 441 proto_field_name = field.name | 644 proto_field_name = field.name |
| 442 property_name = _PropertyName(proto_field_name) | 645 property_name = _PropertyName(proto_field_name) |
| 443 type_checker = type_checkers.GetTypeChecker(field.cpp_type, field.type) | 646 type_checker = type_checkers.GetTypeChecker(field) |
| 444 default_value = field.default_value | 647 default_value = field.default_value |
| 445 valid_values = set() | 648 valid_values = set() |
| 649 is_proto3 = field.containing_type.syntax == "proto3" |
| 446 | 650 |
| 447 def getter(self): | 651 def getter(self): |
| 448 # TODO(protobuf-team): This may be broken since there may not be | 652 # TODO(protobuf-team): This may be broken since there may not be |
| 449 # default_value. Combine with has_default_value somehow. | 653 # default_value. Combine with has_default_value somehow. |
| 450 return self._fields.get(field, default_value) | 654 return self._fields.get(field, default_value) |
| 451 getter.__module__ = None | 655 getter.__module__ = None |
| 452 getter.__doc__ = 'Getter for %s.' % proto_field_name | 656 getter.__doc__ = 'Getter for %s.' % proto_field_name |
| 453 def setter(self, new_value): | 657 |
| 454 type_checker.CheckValue(new_value) | 658 clear_when_set_to_default = is_proto3 and not field.containing_oneof |
| 455 self._fields[field] = new_value | 659 |
| 660 def field_setter(self, new_value): |
| 661 # pylint: disable=protected-access |
| 662 # Testing the value for truthiness captures all of the proto3 defaults |
| 663 # (0, 0.0, enum 0, and False). |
| 664 new_value = type_checker.CheckValue(new_value) |
| 665 if clear_when_set_to_default and not new_value: |
| 666 self._fields.pop(field, None) |
| 667 else: |
| 668 self._fields[field] = new_value |
| 456 # Check _cached_byte_size_dirty inline to improve performance, since scalar | 669 # Check _cached_byte_size_dirty inline to improve performance, since scalar |
| 457 # setters are called frequently. | 670 # setters are called frequently. |
| 458 if not self._cached_byte_size_dirty: | 671 if not self._cached_byte_size_dirty: |
| 459 self._Modified() | 672 self._Modified() |
| 460 | 673 |
| 674 if field.containing_oneof: |
| 675 def setter(self, new_value): |
| 676 field_setter(self, new_value) |
| 677 self._UpdateOneofState(field) |
| 678 else: |
| 679 setter = field_setter |
| 680 |
| 461 setter.__module__ = None | 681 setter.__module__ = None |
| 462 setter.__doc__ = 'Setter for %s.' % proto_field_name | 682 setter.__doc__ = 'Setter for %s.' % proto_field_name |
| 463 | 683 |
| 464 # Add a property to encapsulate the getter/setter. | 684 # Add a property to encapsulate the getter/setter. |
| 465 doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name | 685 doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name |
| 466 setattr(cls, property_name, property(getter, setter, doc=doc)) | 686 setattr(cls, property_name, property(getter, setter, doc=doc)) |
| 467 | 687 |
| 468 | 688 |
| 469 def _AddPropertiesForNonRepeatedCompositeField(field, cls): | 689 def _AddPropertiesForNonRepeatedCompositeField(field, cls): |
| 470 """Adds a public property for a nonrepeated, composite protocol message field. | 690 """Adds a public property for a nonrepeated, composite protocol message field. |
| 471 A composite field is a "group" or "message" field. | 691 A composite field is a "group" or "message" field. |
| 472 | 692 |
| 473 Clients can use this property to get the value of the field, but cannot | 693 Clients can use this property to get the value of the field, but cannot |
| 474 assign to the property directly. | 694 assign to the property directly. |
| 475 | 695 |
| 476 Args: | 696 Args: |
| 477 field: A FieldDescriptor for this field. | 697 field: A FieldDescriptor for this field. |
| 478 cls: The class we're constructing. | 698 cls: The class we're constructing. |
| 479 """ | 699 """ |
| 480 # TODO(robinson): Remove duplication with similar method | 700 # TODO(robinson): Remove duplication with similar method |
| 481 # for non-repeated scalars. | 701 # for non-repeated scalars. |
| 482 proto_field_name = field.name | 702 proto_field_name = field.name |
| 483 property_name = _PropertyName(proto_field_name) | 703 property_name = _PropertyName(proto_field_name) |
| 484 | 704 |
| 485 # TODO(komarek): Can anyone explain to me why we cache the message_type this | |
| 486 # way, instead of referring to field.message_type inside of getter(self)? | |
| 487 # What if someone sets message_type later on (which makes for simpler | |
| 488 # dyanmic proto descriptor and class creation code). | |
| 489 message_type = field.message_type | |
| 490 | |
| 491 def getter(self): | 705 def getter(self): |
| 492 field_value = self._fields.get(field) | 706 field_value = self._fields.get(field) |
| 493 if field_value is None: | 707 if field_value is None: |
| 494 # Construct a new object to represent this field. | 708 # Construct a new object to represent this field. |
| 495 field_value = message_type._concrete_class() # use field.message_type? | 709 field_value = field._default_constructor(self) |
| 496 field_value._SetListener(self._listener_for_children) | |
| 497 | 710 |
| 498 # Atomically check if another thread has preempted us and, if not, swap | 711 # Atomically check if another thread has preempted us and, if not, swap |
| 499 # in the new object we just created. If someone has preempted us, we | 712 # in the new object we just created. If someone has preempted us, we |
| 500 # take that object and discard ours. | 713 # take that object and discard ours. |
| 501 # WARNING: We are relying on setdefault() being atomic. This is true | 714 # WARNING: We are relying on setdefault() being atomic. This is true |
| 502 # in CPython but we haven't investigated others. This warning appears | 715 # in CPython but we haven't investigated others. This warning appears |
| 503 # in several other locations in this file. | 716 # in several other locations in this file. |
| 504 field_value = self._fields.setdefault(field, field_value) | 717 field_value = self._fields.setdefault(field, field_value) |
| 505 return field_value | 718 return field_value |
| 506 getter.__module__ = None | 719 getter.__module__ = None |
| 507 getter.__doc__ = 'Getter for %s.' % proto_field_name | 720 getter.__doc__ = 'Getter for %s.' % proto_field_name |
| 508 | 721 |
| 509 # We define a setter just so we can throw an exception with a more | 722 # We define a setter just so we can throw an exception with a more |
| 510 # helpful error message. | 723 # helpful error message. |
| 511 def setter(self, new_value): | 724 def setter(self, new_value): |
| 512 raise AttributeError('Assignment not allowed to composite field ' | 725 raise AttributeError('Assignment not allowed to composite field ' |
| 513 '"%s" in protocol message object.' % proto_field_name) | 726 '"%s" in protocol message object.' % proto_field_name) |
| 514 | 727 |
| 515 # Add a property to encapsulate the getter. | 728 # Add a property to encapsulate the getter. |
| 516 doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name | 729 doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name |
| 517 setattr(cls, property_name, property(getter, setter, doc=doc)) | 730 setattr(cls, property_name, property(getter, setter, doc=doc)) |
| 518 | 731 |
| 519 | 732 |
| 520 def _AddPropertiesForExtensions(descriptor, cls): | 733 def _AddPropertiesForExtensions(descriptor, cls): |
| 521 """Adds properties for all fields in this protocol message type.""" | 734 """Adds properties for all fields in this protocol message type.""" |
| 522 extension_dict = descriptor.extensions_by_name | 735 extension_dict = descriptor.extensions_by_name |
| 523 for extension_name, extension_field in extension_dict.iteritems(): | 736 for extension_name, extension_field in extension_dict.items(): |
| 524 constant_name = extension_name.upper() + "_FIELD_NUMBER" | 737 constant_name = extension_name.upper() + "_FIELD_NUMBER" |
| 525 setattr(cls, constant_name, extension_field.number) | 738 setattr(cls, constant_name, extension_field.number) |
| 526 | 739 |
| 527 | 740 |
| 528 def _AddStaticMethods(cls): | 741 def _AddStaticMethods(cls): |
| 529 # TODO(robinson): This probably needs to be thread-safe(?) | 742 # TODO(robinson): This probably needs to be thread-safe(?) |
| 530 def RegisterExtension(extension_handle): | 743 def RegisterExtension(extension_handle): |
| 531 extension_handle.containing_type = cls.DESCRIPTOR | 744 extension_handle.containing_type = cls.DESCRIPTOR |
| 532 _AttachFieldHelpers(cls, extension_handle) | 745 _AttachFieldHelpers(cls, extension_handle) |
| 533 | 746 |
| (...skipping 34 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
| 568 elif item[0].cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: | 781 elif item[0].cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: |
| 569 return item[1]._is_present_in_parent | 782 return item[1]._is_present_in_parent |
| 570 else: | 783 else: |
| 571 return True | 784 return True |
| 572 | 785 |
| 573 | 786 |
| 574 def _AddListFieldsMethod(message_descriptor, cls): | 787 def _AddListFieldsMethod(message_descriptor, cls): |
| 575 """Helper for _AddMessageMethods().""" | 788 """Helper for _AddMessageMethods().""" |
| 576 | 789 |
| 577 def ListFields(self): | 790 def ListFields(self): |
| 578 all_fields = [item for item in self._fields.iteritems() if _IsPresent(item)] | 791 all_fields = [item for item in self._fields.items() if _IsPresent(item)] |
| 579 all_fields.sort(key = lambda item: item[0].number) | 792 all_fields.sort(key = lambda item: item[0].number) |
| 580 return all_fields | 793 return all_fields |
| 581 | 794 |
| 582 cls.ListFields = ListFields | 795 cls.ListFields = ListFields |
| 583 | 796 |
| 797 _Proto3HasError = 'Protocol message has no non-repeated submessage field "%s"' |
| 798 _Proto2HasError = 'Protocol message has no non-repeated field "%s"' |
| 584 | 799 |
| 585 def _AddHasFieldMethod(message_descriptor, cls): | 800 def _AddHasFieldMethod(message_descriptor, cls): |
| 586 """Helper for _AddMessageMethods().""" | 801 """Helper for _AddMessageMethods().""" |
| 587 | 802 |
| 588 singular_fields = {} | 803 is_proto3 = (message_descriptor.syntax == "proto3") |
| 804 error_msg = _Proto3HasError if is_proto3 else _Proto2HasError |
| 805 |
| 806 hassable_fields = {} |
| 589 for field in message_descriptor.fields: | 807 for field in message_descriptor.fields: |
| 590 if field.label != _FieldDescriptor.LABEL_REPEATED: | 808 if field.label == _FieldDescriptor.LABEL_REPEATED: |
| 591 singular_fields[field.name] = field | 809 continue |
| 810 # For proto3, only submessages and fields inside a oneof have presence. |
| 811 if (is_proto3 and field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE and |
| 812 not field.containing_oneof): |
| 813 continue |
| 814 hassable_fields[field.name] = field |
| 815 |
| 816 if not is_proto3: |
| 817 # Fields inside oneofs are never repeated (enforced by the compiler). |
| 818 for oneof in message_descriptor.oneofs: |
| 819 hassable_fields[oneof.name] = oneof |
| 592 | 820 |
| 593 def HasField(self, field_name): | 821 def HasField(self, field_name): |
| 594 try: | 822 try: |
| 595 field = singular_fields[field_name] | 823 field = hassable_fields[field_name] |
| 596 except KeyError: | 824 except KeyError: |
| 597 raise ValueError( | 825 raise ValueError(error_msg % field_name) |
| 598 'Protocol message has no singular "%s" field.' % field_name) | |
| 599 | 826 |
| 600 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: | 827 if isinstance(field, descriptor_mod.OneofDescriptor): |
| 601 value = self._fields.get(field) | 828 try: |
| 602 return value is not None and value._is_present_in_parent | 829 return HasField(self, self._oneofs[field].name) |
| 830 except KeyError: |
| 831 return False |
| 603 else: | 832 else: |
| 604 return field in self._fields | 833 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: |
| 834 value = self._fields.get(field) |
| 835 return value is not None and value._is_present_in_parent |
| 836 else: |
| 837 return field in self._fields |
| 838 |
| 605 cls.HasField = HasField | 839 cls.HasField = HasField |
| 606 | 840 |
| 607 | 841 |
| 608 def _AddClearFieldMethod(message_descriptor, cls): | 842 def _AddClearFieldMethod(message_descriptor, cls): |
| 609 """Helper for _AddMessageMethods().""" | 843 """Helper for _AddMessageMethods().""" |
| 610 def ClearField(self, field_name): | 844 def ClearField(self, field_name): |
| 611 try: | 845 try: |
| 612 field = message_descriptor.fields_by_name[field_name] | 846 field = message_descriptor.fields_by_name[field_name] |
| 613 except KeyError: | 847 except KeyError: |
| 614 raise ValueError('Protocol message has no "%s" field.' % field_name) | 848 try: |
| 849 field = message_descriptor.oneofs_by_name[field_name] |
| 850 if field in self._oneofs: |
| 851 field = self._oneofs[field] |
| 852 else: |
| 853 return |
| 854 except KeyError: |
| 855 raise ValueError('Protocol message %s() has no "%s" field.' % |
| 856 (message_descriptor.name, field_name)) |
| 615 | 857 |
| 616 if field in self._fields: | 858 if field in self._fields: |
| 859 # To match the C++ implementation, we need to invalidate iterators |
| 860 # for map fields when ClearField() happens. |
| 861 if hasattr(self._fields[field], 'InvalidateIterators'): |
| 862 self._fields[field].InvalidateIterators() |
| 863 |
| 617 # Note: If the field is a sub-message, its listener will still point | 864 # Note: If the field is a sub-message, its listener will still point |
| 618 # at us. That's fine, because the worst than can happen is that it | 865 # at us. That's fine, because the worst than can happen is that it |
| 619 # will call _Modified() and invalidate our byte size. Big deal. | 866 # will call _Modified() and invalidate our byte size. Big deal. |
| 620 del self._fields[field] | 867 del self._fields[field] |
| 621 | 868 |
| 869 if self._oneofs.get(field.containing_oneof, None) is field: |
| 870 del self._oneofs[field.containing_oneof] |
| 871 |
| 622 # Always call _Modified() -- even if nothing was changed, this is | 872 # Always call _Modified() -- even if nothing was changed, this is |
| 623 # a mutating method, and thus calling it should cause the field to become | 873 # a mutating method, and thus calling it should cause the field to become |
| 624 # present in the parent message. | 874 # present in the parent message. |
| 625 self._Modified() | 875 self._Modified() |
| 626 | 876 |
| 627 cls.ClearField = ClearField | 877 cls.ClearField = ClearField |
| 628 | 878 |
| 629 | 879 |
| 630 def _AddClearExtensionMethod(cls): | 880 def _AddClearExtensionMethod(cls): |
| 631 """Helper for _AddMessageMethods().""" | 881 """Helper for _AddMessageMethods().""" |
| 632 def ClearExtension(self, extension_handle): | 882 def ClearExtension(self, extension_handle): |
| 633 _VerifyExtensionHandle(self, extension_handle) | 883 _VerifyExtensionHandle(self, extension_handle) |
| 634 | 884 |
| 635 # Similar to ClearField(), above. | 885 # Similar to ClearField(), above. |
| 636 if extension_handle in self._fields: | 886 if extension_handle in self._fields: |
| 637 del self._fields[extension_handle] | 887 del self._fields[extension_handle] |
| 638 self._Modified() | 888 self._Modified() |
| 639 cls.ClearExtension = ClearExtension | 889 cls.ClearExtension = ClearExtension |
| 640 | 890 |
| 641 | 891 |
| 642 def _AddClearMethod(message_descriptor, cls): | 892 def _AddClearMethod(message_descriptor, cls): |
| 643 """Helper for _AddMessageMethods().""" | 893 """Helper for _AddMessageMethods().""" |
| 644 def Clear(self): | 894 def Clear(self): |
| 645 # Clear fields. | 895 # Clear fields. |
| 646 self._fields = {} | 896 self._fields = {} |
| 647 self._unknown_fields = () | 897 self._unknown_fields = () |
| 898 self._oneofs = {} |
| 648 self._Modified() | 899 self._Modified() |
| 649 cls.Clear = Clear | 900 cls.Clear = Clear |
| 650 | 901 |
| 651 | 902 |
| 652 def _AddHasExtensionMethod(cls): | 903 def _AddHasExtensionMethod(cls): |
| 653 """Helper for _AddMessageMethods().""" | 904 """Helper for _AddMessageMethods().""" |
| 654 def HasExtension(self, extension_handle): | 905 def HasExtension(self, extension_handle): |
| 655 _VerifyExtensionHandle(self, extension_handle) | 906 _VerifyExtensionHandle(self, extension_handle) |
| 656 if extension_handle.label == _FieldDescriptor.LABEL_REPEATED: | 907 if extension_handle.label == _FieldDescriptor.LABEL_REPEATED: |
| 657 raise KeyError('"%s" is repeated.' % extension_handle.full_name) | 908 raise KeyError('"%s" is repeated.' % extension_handle.full_name) |
| 658 | 909 |
| 659 if extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: | 910 if extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: |
| 660 value = self._fields.get(extension_handle) | 911 value = self._fields.get(extension_handle) |
| 661 return value is not None and value._is_present_in_parent | 912 return value is not None and value._is_present_in_parent |
| 662 else: | 913 else: |
| 663 return extension_handle in self._fields | 914 return extension_handle in self._fields |
| 664 cls.HasExtension = HasExtension | 915 cls.HasExtension = HasExtension |
| 665 | 916 |
| 917 def _InternalUnpackAny(msg): |
| 918 """Unpacks Any message and returns the unpacked message. |
| 919 |
| 920 This internal method is differnt from public Any Unpack method which takes |
| 921 the target message as argument. _InternalUnpackAny method does not have |
| 922 target message type and need to find the message type in descriptor pool. |
| 923 |
| 924 Args: |
| 925 msg: An Any message to be unpacked. |
| 926 |
| 927 Returns: |
| 928 The unpacked message. |
| 929 """ |
| 930 type_url = msg.type_url |
| 931 db = symbol_database.Default() |
| 932 |
| 933 if not type_url: |
| 934 return None |
| 935 |
| 936 # TODO(haberman): For now we just strip the hostname. Better logic will be |
| 937 # required. |
| 938 type_name = type_url.split("/")[-1] |
| 939 descriptor = db.pool.FindMessageTypeByName(type_name) |
| 940 |
| 941 if descriptor is None: |
| 942 return None |
| 943 |
| 944 message_class = db.GetPrototype(descriptor) |
| 945 message = message_class() |
| 946 |
| 947 message.ParseFromString(msg.value) |
| 948 return message |
| 666 | 949 |
| 667 def _AddEqualsMethod(message_descriptor, cls): | 950 def _AddEqualsMethod(message_descriptor, cls): |
| 668 """Helper for _AddMessageMethods().""" | 951 """Helper for _AddMessageMethods().""" |
| 669 def __eq__(self, other): | 952 def __eq__(self, other): |
| 670 if (not isinstance(other, message_mod.Message) or | 953 if (not isinstance(other, message_mod.Message) or |
| 671 other.DESCRIPTOR != self.DESCRIPTOR): | 954 other.DESCRIPTOR != self.DESCRIPTOR): |
| 672 return False | 955 return False |
| 673 | 956 |
| 674 if self is other: | 957 if self is other: |
| 675 return True | 958 return True |
| 676 | 959 |
| 960 if self.DESCRIPTOR.full_name == _AnyFullTypeName: |
| 961 any_a = _InternalUnpackAny(self) |
| 962 any_b = _InternalUnpackAny(other) |
| 963 if any_a and any_b: |
| 964 return any_a == any_b |
| 965 |
| 677 if not self.ListFields() == other.ListFields(): | 966 if not self.ListFields() == other.ListFields(): |
| 678 return False | 967 return False |
| 679 | 968 |
| 680 # Sort unknown fields because their order shouldn't affect equality test. | 969 # Sort unknown fields because their order shouldn't affect equality test. |
| 681 unknown_fields = list(self._unknown_fields) | 970 unknown_fields = list(self._unknown_fields) |
| 682 unknown_fields.sort() | 971 unknown_fields.sort() |
| 683 other_unknown_fields = list(other._unknown_fields) | 972 other_unknown_fields = list(other._unknown_fields) |
| 684 other_unknown_fields.sort() | 973 other_unknown_fields.sort() |
| 685 | 974 |
| 686 return unknown_fields == other_unknown_fields | 975 return unknown_fields == other_unknown_fields |
| 687 | 976 |
| 688 cls.__eq__ = __eq__ | 977 cls.__eq__ = __eq__ |
| 689 | 978 |
| 690 | 979 |
| 691 def _AddStrMethod(message_descriptor, cls): | 980 def _AddStrMethod(message_descriptor, cls): |
| 692 """Helper for _AddMessageMethods().""" | 981 """Helper for _AddMessageMethods().""" |
| 693 def __str__(self): | 982 def __str__(self): |
| 694 return text_format.MessageToString(self) | 983 return text_format.MessageToString(self) |
| 695 cls.__str__ = __str__ | 984 cls.__str__ = __str__ |
| 696 | 985 |
| 697 | 986 |
| 987 def _AddReprMethod(message_descriptor, cls): |
| 988 """Helper for _AddMessageMethods().""" |
| 989 def __repr__(self): |
| 990 return text_format.MessageToString(self) |
| 991 cls.__repr__ = __repr__ |
| 992 |
| 993 |
| 698 def _AddUnicodeMethod(unused_message_descriptor, cls): | 994 def _AddUnicodeMethod(unused_message_descriptor, cls): |
| 699 """Helper for _AddMessageMethods().""" | 995 """Helper for _AddMessageMethods().""" |
| 700 | 996 |
| 701 def __unicode__(self): | 997 def __unicode__(self): |
| 702 return text_format.MessageToString(self, as_utf8=True).decode('utf-8') | 998 return text_format.MessageToString(self, as_utf8=True).decode('utf-8') |
| 703 cls.__unicode__ = __unicode__ | 999 cls.__unicode__ = __unicode__ |
| 704 | 1000 |
| 705 | 1001 |
| 706 def _AddSetListenerMethod(cls): | 1002 def _AddSetListenerMethod(cls): |
| 707 """Helper for _AddMessageMethods().""" | 1003 """Helper for _AddMessageMethods().""" |
| (...skipping 58 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
| 766 'Message %s is missing required fields: %s' % ( | 1062 'Message %s is missing required fields: %s' % ( |
| 767 self.DESCRIPTOR.full_name, ','.join(self.FindInitializationErrors()))) | 1063 self.DESCRIPTOR.full_name, ','.join(self.FindInitializationErrors()))) |
| 768 return self.SerializePartialToString() | 1064 return self.SerializePartialToString() |
| 769 cls.SerializeToString = SerializeToString | 1065 cls.SerializeToString = SerializeToString |
| 770 | 1066 |
| 771 | 1067 |
| 772 def _AddSerializePartialToStringMethod(message_descriptor, cls): | 1068 def _AddSerializePartialToStringMethod(message_descriptor, cls): |
| 773 """Helper for _AddMessageMethods().""" | 1069 """Helper for _AddMessageMethods().""" |
| 774 | 1070 |
| 775 def SerializePartialToString(self): | 1071 def SerializePartialToString(self): |
| 776 out = StringIO() | 1072 out = BytesIO() |
| 777 self._InternalSerialize(out.write) | 1073 self._InternalSerialize(out.write) |
| 778 return out.getvalue() | 1074 return out.getvalue() |
| 779 cls.SerializePartialToString = SerializePartialToString | 1075 cls.SerializePartialToString = SerializePartialToString |
| 780 | 1076 |
| 781 def InternalSerialize(self, write_bytes): | 1077 def InternalSerialize(self, write_bytes): |
| 782 for field_descriptor, field_value in self.ListFields(): | 1078 for field_descriptor, field_value in self.ListFields(): |
| 783 field_descriptor._encoder(write_bytes, field_value) | 1079 field_descriptor._encoder(write_bytes, field_value) |
| 784 for tag_bytes, value_bytes in self._unknown_fields: | 1080 for tag_bytes, value_bytes in self._unknown_fields: |
| 785 write_bytes(tag_bytes) | 1081 write_bytes(tag_bytes) |
| 786 write_bytes(value_bytes) | 1082 write_bytes(value_bytes) |
| 787 cls._InternalSerialize = InternalSerialize | 1083 cls._InternalSerialize = InternalSerialize |
| 788 | 1084 |
| 789 | 1085 |
| 790 def _AddMergeFromStringMethod(message_descriptor, cls): | 1086 def _AddMergeFromStringMethod(message_descriptor, cls): |
| 791 """Helper for _AddMessageMethods().""" | 1087 """Helper for _AddMessageMethods().""" |
| 792 def MergeFromString(self, serialized): | 1088 def MergeFromString(self, serialized): |
| 793 length = len(serialized) | 1089 length = len(serialized) |
| 794 try: | 1090 try: |
| 795 if self._InternalParse(serialized, 0, length) != length: | 1091 if self._InternalParse(serialized, 0, length) != length: |
| 796 # The only reason _InternalParse would return early is if it | 1092 # The only reason _InternalParse would return early is if it |
| 797 # encountered an end-group tag. | 1093 # encountered an end-group tag. |
| 798 raise message_mod.DecodeError('Unexpected end-group tag.') | 1094 raise message_mod.DecodeError('Unexpected end-group tag.') |
| 799 except IndexError: | 1095 except (IndexError, TypeError): |
| 1096 # Now ord(buf[p:p+1]) == ord('') gets TypeError. |
| 800 raise message_mod.DecodeError('Truncated message.') | 1097 raise message_mod.DecodeError('Truncated message.') |
| 801 except struct.error, e: | 1098 except struct.error as e: |
| 802 raise message_mod.DecodeError(e) | 1099 raise message_mod.DecodeError(e) |
| 803 return length # Return this for legacy reasons. | 1100 return length # Return this for legacy reasons. |
| 804 cls.MergeFromString = MergeFromString | 1101 cls.MergeFromString = MergeFromString |
| 805 | 1102 |
| 806 local_ReadTag = decoder.ReadTag | 1103 local_ReadTag = decoder.ReadTag |
| 807 local_SkipField = decoder.SkipField | 1104 local_SkipField = decoder.SkipField |
| 808 decoders_by_tag = cls._decoders_by_tag | 1105 decoders_by_tag = cls._decoders_by_tag |
| 1106 is_proto3 = message_descriptor.syntax == "proto3" |
| 809 | 1107 |
| 810 def InternalParse(self, buffer, pos, end): | 1108 def InternalParse(self, buffer, pos, end): |
| 811 self._Modified() | 1109 self._Modified() |
| 812 field_dict = self._fields | 1110 field_dict = self._fields |
| 813 unknown_field_list = self._unknown_fields | 1111 unknown_field_list = self._unknown_fields |
| 814 while pos != end: | 1112 while pos != end: |
| 815 (tag_bytes, new_pos) = local_ReadTag(buffer, pos) | 1113 (tag_bytes, new_pos) = local_ReadTag(buffer, pos) |
| 816 field_decoder = decoders_by_tag.get(tag_bytes) | 1114 field_decoder, field_desc = decoders_by_tag.get(tag_bytes, (None, None)) |
| 817 if field_decoder is None: | 1115 if field_decoder is None: |
| 818 value_start_pos = new_pos | 1116 value_start_pos = new_pos |
| 819 new_pos = local_SkipField(buffer, new_pos, end, tag_bytes) | 1117 new_pos = local_SkipField(buffer, new_pos, end, tag_bytes) |
| 820 if new_pos == -1: | 1118 if new_pos == -1: |
| 821 return pos | 1119 return pos |
| 822 if not unknown_field_list: | 1120 if not is_proto3: |
| 823 unknown_field_list = self._unknown_fields = [] | 1121 if not unknown_field_list: |
| 824 unknown_field_list.append((tag_bytes, buffer[value_start_pos:new_pos])) | 1122 unknown_field_list = self._unknown_fields = [] |
| 1123 unknown_field_list.append( |
| 1124 (tag_bytes, buffer[value_start_pos:new_pos])) |
| 825 pos = new_pos | 1125 pos = new_pos |
| 826 else: | 1126 else: |
| 827 pos = field_decoder(buffer, new_pos, end, self, field_dict) | 1127 pos = field_decoder(buffer, new_pos, end, self, field_dict) |
| 1128 if field_desc: |
| 1129 self._UpdateOneofState(field_desc) |
| 828 return pos | 1130 return pos |
| 829 cls._InternalParse = InternalParse | 1131 cls._InternalParse = InternalParse |
| 830 | 1132 |
| 831 | 1133 |
| 832 def _AddIsInitializedMethod(message_descriptor, cls): | 1134 def _AddIsInitializedMethod(message_descriptor, cls): |
| 833 """Adds the IsInitialized and FindInitializationError methods to the | 1135 """Adds the IsInitialized and FindInitializationError methods to the |
| 834 protocol message class.""" | 1136 protocol message class.""" |
| 835 | 1137 |
| 836 required_fields = [field for field in message_descriptor.fields | 1138 required_fields = [field for field in message_descriptor.fields |
| 837 if field.label == _FieldDescriptor.LABEL_REQUIRED] | 1139 if field.label == _FieldDescriptor.LABEL_REQUIRED] |
| (...skipping 12 matching lines...) Expand all Loading... |
| 850 # Performance is critical so we avoid HasField() and ListFields(). | 1152 # Performance is critical so we avoid HasField() and ListFields(). |
| 851 | 1153 |
| 852 for field in required_fields: | 1154 for field in required_fields: |
| 853 if (field not in self._fields or | 1155 if (field not in self._fields or |
| 854 (field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE and | 1156 (field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE and |
| 855 not self._fields[field]._is_present_in_parent)): | 1157 not self._fields[field]._is_present_in_parent)): |
| 856 if errors is not None: | 1158 if errors is not None: |
| 857 errors.extend(self.FindInitializationErrors()) | 1159 errors.extend(self.FindInitializationErrors()) |
| 858 return False | 1160 return False |
| 859 | 1161 |
| 860 for field, value in self._fields.iteritems(): | 1162 for field, value in list(self._fields.items()): # dict can change size! |
| 861 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: | 1163 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: |
| 862 if field.label == _FieldDescriptor.LABEL_REPEATED: | 1164 if field.label == _FieldDescriptor.LABEL_REPEATED: |
| 1165 if (field.message_type.has_options and |
| 1166 field.message_type.GetOptions().map_entry): |
| 1167 continue |
| 863 for element in value: | 1168 for element in value: |
| 864 if not element.IsInitialized(): | 1169 if not element.IsInitialized(): |
| 865 if errors is not None: | 1170 if errors is not None: |
| 866 errors.extend(self.FindInitializationErrors()) | 1171 errors.extend(self.FindInitializationErrors()) |
| 867 return False | 1172 return False |
| 868 elif value._is_present_in_parent and not value.IsInitialized(): | 1173 elif value._is_present_in_parent and not value.IsInitialized(): |
| 869 if errors is not None: | 1174 if errors is not None: |
| 870 errors.extend(self.FindInitializationErrors()) | 1175 errors.extend(self.FindInitializationErrors()) |
| 871 return False | 1176 return False |
| 872 | 1177 |
| (...skipping 15 matching lines...) Expand all Loading... |
| 888 if not self.HasField(field.name): | 1193 if not self.HasField(field.name): |
| 889 errors.append(field.name) | 1194 errors.append(field.name) |
| 890 | 1195 |
| 891 for field, value in self.ListFields(): | 1196 for field, value in self.ListFields(): |
| 892 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: | 1197 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: |
| 893 if field.is_extension: | 1198 if field.is_extension: |
| 894 name = "(%s)" % field.full_name | 1199 name = "(%s)" % field.full_name |
| 895 else: | 1200 else: |
| 896 name = field.name | 1201 name = field.name |
| 897 | 1202 |
| 898 if field.label == _FieldDescriptor.LABEL_REPEATED: | 1203 if _IsMapField(field): |
| 899 for i in xrange(len(value)): | 1204 if _IsMessageMapField(field): |
| 1205 for key in value: |
| 1206 element = value[key] |
| 1207 prefix = "%s[%s]." % (name, key) |
| 1208 sub_errors = element.FindInitializationErrors() |
| 1209 errors += [prefix + error for error in sub_errors] |
| 1210 else: |
| 1211 # ScalarMaps can't have any initialization errors. |
| 1212 pass |
| 1213 elif field.label == _FieldDescriptor.LABEL_REPEATED: |
| 1214 for i in range(len(value)): |
| 900 element = value[i] | 1215 element = value[i] |
| 901 prefix = "%s[%d]." % (name, i) | 1216 prefix = "%s[%d]." % (name, i) |
| 902 sub_errors = element.FindInitializationErrors() | 1217 sub_errors = element.FindInitializationErrors() |
| 903 errors += [ prefix + error for error in sub_errors ] | 1218 errors += [prefix + error for error in sub_errors] |
| 904 else: | 1219 else: |
| 905 prefix = name + "." | 1220 prefix = name + "." |
| 906 sub_errors = value.FindInitializationErrors() | 1221 sub_errors = value.FindInitializationErrors() |
| 907 errors += [ prefix + error for error in sub_errors ] | 1222 errors += [prefix + error for error in sub_errors] |
| 908 | 1223 |
| 909 return errors | 1224 return errors |
| 910 | 1225 |
| 911 cls.FindInitializationErrors = FindInitializationErrors | 1226 cls.FindInitializationErrors = FindInitializationErrors |
| 912 | 1227 |
| 913 | 1228 |
| 914 def _AddMergeFromMethod(cls): | 1229 def _AddMergeFromMethod(cls): |
| 915 LABEL_REPEATED = _FieldDescriptor.LABEL_REPEATED | 1230 LABEL_REPEATED = _FieldDescriptor.LABEL_REPEATED |
| 916 CPPTYPE_MESSAGE = _FieldDescriptor.CPPTYPE_MESSAGE | 1231 CPPTYPE_MESSAGE = _FieldDescriptor.CPPTYPE_MESSAGE |
| 917 | 1232 |
| 918 def MergeFrom(self, msg): | 1233 def MergeFrom(self, msg): |
| 919 if not isinstance(msg, cls): | 1234 if not isinstance(msg, cls): |
| 920 raise TypeError( | 1235 raise TypeError( |
| 921 "Parameter to MergeFrom() must be instance of same class: " | 1236 "Parameter to MergeFrom() must be instance of same class: " |
| 922 "expected %s got %s." % (cls.__name__, type(msg).__name__)) | 1237 "expected %s got %s." % (cls.__name__, type(msg).__name__)) |
| 923 | 1238 |
| 924 assert msg is not self | 1239 assert msg is not self |
| 925 self._Modified() | 1240 self._Modified() |
| 926 | 1241 |
| 927 fields = self._fields | 1242 fields = self._fields |
| 928 | 1243 |
| 929 for field, value in msg._fields.iteritems(): | 1244 for field, value in msg._fields.items(): |
| 930 if field.label == LABEL_REPEATED: | 1245 if field.label == LABEL_REPEATED: |
| 931 field_value = fields.get(field) | 1246 field_value = fields.get(field) |
| 932 if field_value is None: | 1247 if field_value is None: |
| 933 # Construct a new object to represent this field. | 1248 # Construct a new object to represent this field. |
| 934 field_value = field._default_constructor(self) | 1249 field_value = field._default_constructor(self) |
| 935 fields[field] = field_value | 1250 fields[field] = field_value |
| 936 field_value.MergeFrom(value) | 1251 field_value.MergeFrom(value) |
| 937 elif field.cpp_type == CPPTYPE_MESSAGE: | 1252 elif field.cpp_type == CPPTYPE_MESSAGE: |
| 938 if value._is_present_in_parent: | 1253 if value._is_present_in_parent: |
| 939 field_value = fields.get(field) | 1254 field_value = fields.get(field) |
| 940 if field_value is None: | 1255 if field_value is None: |
| 941 # Construct a new object to represent this field. | 1256 # Construct a new object to represent this field. |
| 942 field_value = field._default_constructor(self) | 1257 field_value = field._default_constructor(self) |
| 943 fields[field] = field_value | 1258 fields[field] = field_value |
| 944 field_value.MergeFrom(value) | 1259 field_value.MergeFrom(value) |
| 945 else: | 1260 else: |
| 946 self._fields[field] = value | 1261 self._fields[field] = value |
| 1262 if field.containing_oneof: |
| 1263 self._UpdateOneofState(field) |
| 947 | 1264 |
| 948 if msg._unknown_fields: | 1265 if msg._unknown_fields: |
| 949 if not self._unknown_fields: | 1266 if not self._unknown_fields: |
| 950 self._unknown_fields = [] | 1267 self._unknown_fields = [] |
| 951 self._unknown_fields.extend(msg._unknown_fields) | 1268 self._unknown_fields.extend(msg._unknown_fields) |
| 952 | 1269 |
| 953 cls.MergeFrom = MergeFrom | 1270 cls.MergeFrom = MergeFrom |
| 954 | 1271 |
| 955 | 1272 |
| 1273 def _AddWhichOneofMethod(message_descriptor, cls): |
| 1274 def WhichOneof(self, oneof_name): |
| 1275 """Returns the name of the currently set field inside a oneof, or None.""" |
| 1276 try: |
| 1277 field = message_descriptor.oneofs_by_name[oneof_name] |
| 1278 except KeyError: |
| 1279 raise ValueError( |
| 1280 'Protocol message has no oneof "%s" field.' % oneof_name) |
| 1281 |
| 1282 nested_field = self._oneofs.get(field, None) |
| 1283 if nested_field is not None and self.HasField(nested_field.name): |
| 1284 return nested_field.name |
| 1285 else: |
| 1286 return None |
| 1287 |
| 1288 cls.WhichOneof = WhichOneof |
| 1289 |
| 1290 |
| 956 def _AddMessageMethods(message_descriptor, cls): | 1291 def _AddMessageMethods(message_descriptor, cls): |
| 957 """Adds implementations of all Message methods to cls.""" | 1292 """Adds implementations of all Message methods to cls.""" |
| 958 _AddListFieldsMethod(message_descriptor, cls) | 1293 _AddListFieldsMethod(message_descriptor, cls) |
| 959 _AddHasFieldMethod(message_descriptor, cls) | 1294 _AddHasFieldMethod(message_descriptor, cls) |
| 960 _AddClearFieldMethod(message_descriptor, cls) | 1295 _AddClearFieldMethod(message_descriptor, cls) |
| 961 if message_descriptor.is_extendable: | 1296 if message_descriptor.is_extendable: |
| 962 _AddClearExtensionMethod(cls) | 1297 _AddClearExtensionMethod(cls) |
| 963 _AddHasExtensionMethod(cls) | 1298 _AddHasExtensionMethod(cls) |
| 964 _AddClearMethod(message_descriptor, cls) | 1299 _AddClearMethod(message_descriptor, cls) |
| 965 _AddEqualsMethod(message_descriptor, cls) | 1300 _AddEqualsMethod(message_descriptor, cls) |
| 966 _AddStrMethod(message_descriptor, cls) | 1301 _AddStrMethod(message_descriptor, cls) |
| 1302 _AddReprMethod(message_descriptor, cls) |
| 967 _AddUnicodeMethod(message_descriptor, cls) | 1303 _AddUnicodeMethod(message_descriptor, cls) |
| 968 _AddSetListenerMethod(cls) | 1304 _AddSetListenerMethod(cls) |
| 969 _AddByteSizeMethod(message_descriptor, cls) | 1305 _AddByteSizeMethod(message_descriptor, cls) |
| 970 _AddSerializeToStringMethod(message_descriptor, cls) | 1306 _AddSerializeToStringMethod(message_descriptor, cls) |
| 971 _AddSerializePartialToStringMethod(message_descriptor, cls) | 1307 _AddSerializePartialToStringMethod(message_descriptor, cls) |
| 972 _AddMergeFromStringMethod(message_descriptor, cls) | 1308 _AddMergeFromStringMethod(message_descriptor, cls) |
| 973 _AddIsInitializedMethod(message_descriptor, cls) | 1309 _AddIsInitializedMethod(message_descriptor, cls) |
| 974 _AddMergeFromMethod(cls) | 1310 _AddMergeFromMethod(cls) |
| 1311 _AddWhichOneofMethod(message_descriptor, cls) |
| 975 | 1312 |
| 976 | 1313 |
| 977 def _AddPrivateHelperMethods(cls): | 1314 def _AddPrivateHelperMethods(message_descriptor, cls): |
| 978 """Adds implementation of private helper methods to cls.""" | 1315 """Adds implementation of private helper methods to cls.""" |
| 979 | 1316 |
| 980 def Modified(self): | 1317 def Modified(self): |
| 981 """Sets the _cached_byte_size_dirty bit to true, | 1318 """Sets the _cached_byte_size_dirty bit to true, |
| 982 and propagates this to our listener iff this was a state change. | 1319 and propagates this to our listener iff this was a state change. |
| 983 """ | 1320 """ |
| 984 | 1321 |
| 985 # Note: Some callers check _cached_byte_size_dirty before calling | 1322 # Note: Some callers check _cached_byte_size_dirty before calling |
| 986 # _Modified() as an extra optimization. So, if this method is ever | 1323 # _Modified() as an extra optimization. So, if this method is ever |
| 987 # changed such that it does stuff even when _cached_byte_size_dirty is | 1324 # changed such that it does stuff even when _cached_byte_size_dirty is |
| 988 # already true, the callers need to be updated. | 1325 # already true, the callers need to be updated. |
| 989 if not self._cached_byte_size_dirty: | 1326 if not self._cached_byte_size_dirty: |
| 990 self._cached_byte_size_dirty = True | 1327 self._cached_byte_size_dirty = True |
| 991 self._listener_for_children.dirty = True | 1328 self._listener_for_children.dirty = True |
| 992 self._is_present_in_parent = True | 1329 self._is_present_in_parent = True |
| 993 self._listener.Modified() | 1330 self._listener.Modified() |
| 994 | 1331 |
| 1332 def _UpdateOneofState(self, field): |
| 1333 """Sets field as the active field in its containing oneof. |
| 1334 |
| 1335 Will also delete currently active field in the oneof, if it is different |
| 1336 from the argument. Does not mark the message as modified. |
| 1337 """ |
| 1338 other_field = self._oneofs.setdefault(field.containing_oneof, field) |
| 1339 if other_field is not field: |
| 1340 del self._fields[other_field] |
| 1341 self._oneofs[field.containing_oneof] = field |
| 1342 |
| 995 cls._Modified = Modified | 1343 cls._Modified = Modified |
| 996 cls.SetInParent = Modified | 1344 cls.SetInParent = Modified |
| 1345 cls._UpdateOneofState = _UpdateOneofState |
| 997 | 1346 |
| 998 | 1347 |
| 999 class _Listener(object): | 1348 class _Listener(object): |
| 1000 | 1349 |
| 1001 """MessageListener implementation that a parent message registers with its | 1350 """MessageListener implementation that a parent message registers with its |
| 1002 child message. | 1351 child message. |
| 1003 | 1352 |
| 1004 In order to support semantics like: | 1353 In order to support semantics like: |
| 1005 | 1354 |
| 1006 foo.bar.baz.qux = 23 | 1355 foo.bar.baz.qux = 23 |
| (...skipping 28 matching lines...) Expand all Loading... |
| 1035 try: | 1384 try: |
| 1036 # Propagate the signal to our parents iff this is the first field set. | 1385 # Propagate the signal to our parents iff this is the first field set. |
| 1037 self._parent_message_weakref._Modified() | 1386 self._parent_message_weakref._Modified() |
| 1038 except ReferenceError: | 1387 except ReferenceError: |
| 1039 # We can get here if a client has kept a reference to a child object, | 1388 # We can get here if a client has kept a reference to a child object, |
| 1040 # and is now setting a field on it, but the child's parent has been | 1389 # and is now setting a field on it, but the child's parent has been |
| 1041 # garbage-collected. This is not an error. | 1390 # garbage-collected. This is not an error. |
| 1042 pass | 1391 pass |
| 1043 | 1392 |
| 1044 | 1393 |
| 1394 class _OneofListener(_Listener): |
| 1395 """Special listener implementation for setting composite oneof fields.""" |
| 1396 |
| 1397 def __init__(self, parent_message, field): |
| 1398 """Args: |
| 1399 parent_message: The message whose _Modified() method we should call when |
| 1400 we receive Modified() messages. |
| 1401 field: The descriptor of the field being set in the parent message. |
| 1402 """ |
| 1403 super(_OneofListener, self).__init__(parent_message) |
| 1404 self._field = field |
| 1405 |
| 1406 def Modified(self): |
| 1407 """Also updates the state of the containing oneof in the parent message.""" |
| 1408 try: |
| 1409 self._parent_message_weakref._UpdateOneofState(self._field) |
| 1410 super(_OneofListener, self).Modified() |
| 1411 except ReferenceError: |
| 1412 pass |
| 1413 |
| 1414 |
| 1045 # TODO(robinson): Move elsewhere? This file is getting pretty ridiculous... | 1415 # TODO(robinson): Move elsewhere? This file is getting pretty ridiculous... |
| 1046 # TODO(robinson): Unify error handling of "unknown extension" crap. | 1416 # TODO(robinson): Unify error handling of "unknown extension" crap. |
| 1047 # TODO(robinson): Support iteritems()-style iteration over all | 1417 # TODO(robinson): Support iteritems()-style iteration over all |
| 1048 # extensions with the "has" bits turned on? | 1418 # extensions with the "has" bits turned on? |
| 1049 class _ExtensionDict(object): | 1419 class _ExtensionDict(object): |
| 1050 | 1420 |
| 1051 """Dict-like container for supporting an indexable "Extensions" | 1421 """Dict-like container for supporting an indexable "Extensions" |
| 1052 field on proto instances. | 1422 field on proto instances. |
| 1053 | 1423 |
| 1054 Note that in all cases we expect extension handles to be | 1424 Note that in all cases we expect extension handles to be |
| (...skipping 70 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
| 1125 _VerifyExtensionHandle(self._extended_message, extension_handle) | 1495 _VerifyExtensionHandle(self._extended_message, extension_handle) |
| 1126 | 1496 |
| 1127 if (extension_handle.label == _FieldDescriptor.LABEL_REPEATED or | 1497 if (extension_handle.label == _FieldDescriptor.LABEL_REPEATED or |
| 1128 extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE): | 1498 extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE): |
| 1129 raise TypeError( | 1499 raise TypeError( |
| 1130 'Cannot assign to extension "%s" because it is a repeated or ' | 1500 'Cannot assign to extension "%s" because it is a repeated or ' |
| 1131 'composite type.' % extension_handle.full_name) | 1501 'composite type.' % extension_handle.full_name) |
| 1132 | 1502 |
| 1133 # It's slightly wasteful to lookup the type checker each time, | 1503 # It's slightly wasteful to lookup the type checker each time, |
| 1134 # but we expect this to be a vanishingly uncommon case anyway. | 1504 # but we expect this to be a vanishingly uncommon case anyway. |
| 1135 type_checker = type_checkers.GetTypeChecker( | 1505 type_checker = type_checkers.GetTypeChecker(extension_handle) |
| 1136 extension_handle.cpp_type, extension_handle.type) | 1506 # pylint: disable=protected-access |
| 1137 type_checker.CheckValue(value) | 1507 self._extended_message._fields[extension_handle] = ( |
| 1138 self._extended_message._fields[extension_handle] = value | 1508 type_checker.CheckValue(value)) |
| 1139 self._extended_message._Modified() | 1509 self._extended_message._Modified() |
| 1140 | 1510 |
| 1141 def _FindExtensionByName(self, name): | 1511 def _FindExtensionByName(self, name): |
| 1142 """Tries to find a known extension with the specified name. | 1512 """Tries to find a known extension with the specified name. |
| 1143 | 1513 |
| 1144 Args: | 1514 Args: |
| 1145 name: Extension full name. | 1515 name: Extension full name. |
| 1146 | 1516 |
| 1147 Returns: | 1517 Returns: |
| 1148 Extension field descriptor. | 1518 Extension field descriptor. |
| 1149 """ | 1519 """ |
| 1150 return self._extended_message._extensions_by_name.get(name, None) | 1520 return self._extended_message._extensions_by_name.get(name, None) |
| OLD | NEW |