OLD | NEW |
1 # Protocol Buffers - Google's data interchange format | 1 # Protocol Buffers - Google's data interchange format |
2 # Copyright 2008 Google Inc. All rights reserved. | 2 # Copyright 2008 Google Inc. All rights reserved. |
3 # http://code.google.com/p/protobuf/ | 3 # https://developers.google.com/protocol-buffers/ |
4 # | 4 # |
5 # Redistribution and use in source and binary forms, with or without | 5 # Redistribution and use in source and binary forms, with or without |
6 # modification, are permitted provided that the following conditions are | 6 # modification, are permitted provided that the following conditions are |
7 # met: | 7 # met: |
8 # | 8 # |
9 # * Redistributions of source code must retain the above copyright | 9 # * Redistributions of source code must retain the above copyright |
10 # notice, this list of conditions and the following disclaimer. | 10 # notice, this list of conditions and the following disclaimer. |
11 # * Redistributions in binary form must reproduce the above | 11 # * Redistributions in binary form must reproduce the above |
12 # copyright notice, this list of conditions and the following disclaimer | 12 # copyright notice, this list of conditions and the following disclaimer |
13 # in the documentation and/or other materials provided with the | 13 # in the documentation and/or other materials provided with the |
(...skipping 29 matching lines...) Expand all Loading... |
43 to inject all the useful functionality into the classes | 43 to inject all the useful functionality into the classes |
44 output by the protocol compiler at compile-time. | 44 output by the protocol compiler at compile-time. |
45 | 45 |
46 The upshot of all this is that the real implementation | 46 The upshot of all this is that the real implementation |
47 details for ALL pure-Python protocol buffers are *here in | 47 details for ALL pure-Python protocol buffers are *here in |
48 this file*. | 48 this file*. |
49 """ | 49 """ |
50 | 50 |
51 __author__ = 'robinson@google.com (Will Robinson)' | 51 __author__ = 'robinson@google.com (Will Robinson)' |
52 | 52 |
53 try: | 53 from io import BytesIO |
54 from cStringIO import StringIO | 54 import sys |
55 except ImportError: | |
56 from StringIO import StringIO | |
57 import copy_reg | |
58 import struct | 55 import struct |
59 import weakref | 56 import weakref |
60 | 57 |
| 58 import six |
| 59 import six.moves.copyreg as copyreg |
| 60 |
61 # We use "as" to avoid name collisions with variables. | 61 # We use "as" to avoid name collisions with variables. |
62 from google.protobuf.internal import containers | 62 from google.protobuf.internal import containers |
63 from google.protobuf.internal import decoder | 63 from google.protobuf.internal import decoder |
64 from google.protobuf.internal import encoder | 64 from google.protobuf.internal import encoder |
65 from google.protobuf.internal import enum_type_wrapper | 65 from google.protobuf.internal import enum_type_wrapper |
66 from google.protobuf.internal import message_listener as message_listener_mod | 66 from google.protobuf.internal import message_listener as message_listener_mod |
67 from google.protobuf.internal import type_checkers | 67 from google.protobuf.internal import type_checkers |
| 68 from google.protobuf.internal import well_known_types |
68 from google.protobuf.internal import wire_format | 69 from google.protobuf.internal import wire_format |
69 from google.protobuf import descriptor as descriptor_mod | 70 from google.protobuf import descriptor as descriptor_mod |
70 from google.protobuf import message as message_mod | 71 from google.protobuf import message as message_mod |
| 72 from google.protobuf import symbol_database |
71 from google.protobuf import text_format | 73 from google.protobuf import text_format |
72 | 74 |
73 _FieldDescriptor = descriptor_mod.FieldDescriptor | 75 _FieldDescriptor = descriptor_mod.FieldDescriptor |
| 76 _AnyFullTypeName = 'google.protobuf.Any' |
74 | 77 |
75 | 78 |
76 def NewMessage(bases, descriptor, dictionary): | 79 class GeneratedProtocolMessageType(type): |
77 _AddClassAttributesForNestedExtensions(descriptor, dictionary) | |
78 _AddSlots(descriptor, dictionary) | |
79 return bases | |
80 | 80 |
| 81 """Metaclass for protocol message classes created at runtime from Descriptors. |
81 | 82 |
82 def InitMessage(descriptor, cls): | 83 We add implementations for all methods described in the Message class. We |
83 cls._decoders_by_tag = {} | 84 also create properties to allow getting/setting all fields in the protocol |
84 cls._extensions_by_name = {} | 85 message. Finally, we create slots to prevent users from accidentally |
85 cls._extensions_by_number = {} | 86 "setting" nonexistent fields in the protocol message, which then wouldn't get |
86 if (descriptor.has_options and | 87 serialized / deserialized properly. |
87 descriptor.GetOptions().message_set_wire_format): | |
88 cls._decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = ( | |
89 decoder.MessageSetItemDecoder(cls._extensions_by_number)) | |
90 | 88 |
91 # Attach stuff to each FieldDescriptor for quick lookup later on. | 89 The protocol compiler currently uses this metaclass to create protocol |
92 for field in descriptor.fields: | 90 message classes at runtime. Clients can also manually create their own |
93 _AttachFieldHelpers(cls, field) | 91 classes at runtime, as in this example: |
94 | 92 |
95 _AddEnumValues(descriptor, cls) | 93 mydescriptor = Descriptor(.....) |
96 _AddInitMethod(descriptor, cls) | 94 class MyProtoClass(Message): |
97 _AddPropertiesForFields(descriptor, cls) | 95 __metaclass__ = GeneratedProtocolMessageType |
98 _AddPropertiesForExtensions(descriptor, cls) | 96 DESCRIPTOR = mydescriptor |
99 _AddStaticMethods(cls) | 97 myproto_instance = MyProtoClass() |
100 _AddMessageMethods(descriptor, cls) | 98 myproto.foo_field = 23 |
101 _AddPrivateHelperMethods(cls) | 99 ... |
102 copy_reg.pickle(cls, lambda obj: (cls, (), obj.__getstate__())) | 100 |
| 101 The above example will not work for nested types. If you wish to include them, |
| 102 use reflection.MakeClass() instead of manually instantiating the class in |
| 103 order to create the appropriate class structure. |
| 104 """ |
| 105 |
| 106 # Must be consistent with the protocol-compiler code in |
| 107 # proto2/compiler/internal/generator.*. |
| 108 _DESCRIPTOR_KEY = 'DESCRIPTOR' |
| 109 |
| 110 def __new__(cls, name, bases, dictionary): |
| 111 """Custom allocation for runtime-generated class types. |
| 112 |
| 113 We override __new__ because this is apparently the only place |
| 114 where we can meaningfully set __slots__ on the class we're creating(?). |
| 115 (The interplay between metaclasses and slots is not very well-documented). |
| 116 |
| 117 Args: |
| 118 name: Name of the class (ignored, but required by the |
| 119 metaclass protocol). |
| 120 bases: Base classes of the class we're constructing. |
| 121 (Should be message.Message). We ignore this field, but |
| 122 it's required by the metaclass protocol |
| 123 dictionary: The class dictionary of the class we're |
| 124 constructing. dictionary[_DESCRIPTOR_KEY] must contain |
| 125 a Descriptor object describing this protocol message |
| 126 type. |
| 127 |
| 128 Returns: |
| 129 Newly-allocated class. |
| 130 """ |
| 131 descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY] |
| 132 if descriptor.full_name in well_known_types.WKTBASES: |
| 133 bases += (well_known_types.WKTBASES[descriptor.full_name],) |
| 134 _AddClassAttributesForNestedExtensions(descriptor, dictionary) |
| 135 _AddSlots(descriptor, dictionary) |
| 136 |
| 137 superclass = super(GeneratedProtocolMessageType, cls) |
| 138 new_class = superclass.__new__(cls, name, bases, dictionary) |
| 139 return new_class |
| 140 |
| 141 def __init__(cls, name, bases, dictionary): |
| 142 """Here we perform the majority of our work on the class. |
| 143 We add enum getters, an __init__ method, implementations |
| 144 of all Message methods, and properties for all fields |
| 145 in the protocol type. |
| 146 |
| 147 Args: |
| 148 name: Name of the class (ignored, but required by the |
| 149 metaclass protocol). |
| 150 bases: Base classes of the class we're constructing. |
| 151 (Should be message.Message). We ignore this field, but |
| 152 it's required by the metaclass protocol |
| 153 dictionary: The class dictionary of the class we're |
| 154 constructing. dictionary[_DESCRIPTOR_KEY] must contain |
| 155 a Descriptor object describing this protocol message |
| 156 type. |
| 157 """ |
| 158 descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY] |
| 159 cls._decoders_by_tag = {} |
| 160 cls._extensions_by_name = {} |
| 161 cls._extensions_by_number = {} |
| 162 if (descriptor.has_options and |
| 163 descriptor.GetOptions().message_set_wire_format): |
| 164 cls._decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = ( |
| 165 decoder.MessageSetItemDecoder(cls._extensions_by_number), None) |
| 166 |
| 167 # Attach stuff to each FieldDescriptor for quick lookup later on. |
| 168 for field in descriptor.fields: |
| 169 _AttachFieldHelpers(cls, field) |
| 170 |
| 171 descriptor._concrete_class = cls # pylint: disable=protected-access |
| 172 _AddEnumValues(descriptor, cls) |
| 173 _AddInitMethod(descriptor, cls) |
| 174 _AddPropertiesForFields(descriptor, cls) |
| 175 _AddPropertiesForExtensions(descriptor, cls) |
| 176 _AddStaticMethods(cls) |
| 177 _AddMessageMethods(descriptor, cls) |
| 178 _AddPrivateHelperMethods(descriptor, cls) |
| 179 copyreg.pickle(cls, lambda obj: (cls, (), obj.__getstate__())) |
| 180 |
| 181 superclass = super(GeneratedProtocolMessageType, cls) |
| 182 superclass.__init__(name, bases, dictionary) |
103 | 183 |
104 | 184 |
105 # Stateless helpers for GeneratedProtocolMessageType below. | 185 # Stateless helpers for GeneratedProtocolMessageType below. |
106 # Outside clients should not access these directly. | 186 # Outside clients should not access these directly. |
107 # | 187 # |
108 # I opted not to make any of these methods on the metaclass, to make it more | 188 # I opted not to make any of these methods on the metaclass, to make it more |
109 # clear that I'm not really using any state there and to keep clients from | 189 # clear that I'm not really using any state there and to keep clients from |
110 # thinking that they have direct access to these construction helpers. | 190 # thinking that they have direct access to these construction helpers. |
111 | 191 |
112 | 192 |
(...skipping 56 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
169 message_descriptor: A Descriptor instance describing this message type. | 249 message_descriptor: A Descriptor instance describing this message type. |
170 dictionary: Class dictionary to which we'll add a '__slots__' entry. | 250 dictionary: Class dictionary to which we'll add a '__slots__' entry. |
171 """ | 251 """ |
172 dictionary['__slots__'] = ['_cached_byte_size', | 252 dictionary['__slots__'] = ['_cached_byte_size', |
173 '_cached_byte_size_dirty', | 253 '_cached_byte_size_dirty', |
174 '_fields', | 254 '_fields', |
175 '_unknown_fields', | 255 '_unknown_fields', |
176 '_is_present_in_parent', | 256 '_is_present_in_parent', |
177 '_listener', | 257 '_listener', |
178 '_listener_for_children', | 258 '_listener_for_children', |
179 '__weakref__'] | 259 '__weakref__', |
| 260 '_oneofs'] |
180 | 261 |
181 | 262 |
182 def _IsMessageSetExtension(field): | 263 def _IsMessageSetExtension(field): |
183 return (field.is_extension and | 264 return (field.is_extension and |
184 field.containing_type.has_options and | 265 field.containing_type.has_options and |
185 field.containing_type.GetOptions().message_set_wire_format and | 266 field.containing_type.GetOptions().message_set_wire_format and |
186 field.type == _FieldDescriptor.TYPE_MESSAGE and | 267 field.type == _FieldDescriptor.TYPE_MESSAGE and |
187 field.message_type == field.extension_scope and | |
188 field.label == _FieldDescriptor.LABEL_OPTIONAL) | 268 field.label == _FieldDescriptor.LABEL_OPTIONAL) |
189 | 269 |
190 | 270 |
| 271 def _IsMapField(field): |
| 272 return (field.type == _FieldDescriptor.TYPE_MESSAGE and |
| 273 field.message_type.has_options and |
| 274 field.message_type.GetOptions().map_entry) |
| 275 |
| 276 |
| 277 def _IsMessageMapField(field): |
| 278 value_type = field.message_type.fields_by_name["value"] |
| 279 return value_type.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE |
| 280 |
| 281 |
191 def _AttachFieldHelpers(cls, field_descriptor): | 282 def _AttachFieldHelpers(cls, field_descriptor): |
192 is_repeated = (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED) | 283 is_repeated = (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED) |
193 is_packed = (field_descriptor.has_options and | 284 is_packable = (is_repeated and |
194 field_descriptor.GetOptions().packed) | 285 wire_format.IsTypePackable(field_descriptor.type)) |
| 286 if not is_packable: |
| 287 is_packed = False |
| 288 elif field_descriptor.containing_type.syntax == "proto2": |
| 289 is_packed = (field_descriptor.has_options and |
| 290 field_descriptor.GetOptions().packed) |
| 291 else: |
| 292 has_packed_false = (field_descriptor.has_options and |
| 293 field_descriptor.GetOptions().HasField("packed") and |
| 294 field_descriptor.GetOptions().packed == False) |
| 295 is_packed = not has_packed_false |
| 296 is_map_entry = _IsMapField(field_descriptor) |
195 | 297 |
196 if _IsMessageSetExtension(field_descriptor): | 298 if is_map_entry: |
| 299 field_encoder = encoder.MapEncoder(field_descriptor) |
| 300 sizer = encoder.MapSizer(field_descriptor) |
| 301 elif _IsMessageSetExtension(field_descriptor): |
197 field_encoder = encoder.MessageSetItemEncoder(field_descriptor.number) | 302 field_encoder = encoder.MessageSetItemEncoder(field_descriptor.number) |
198 sizer = encoder.MessageSetItemSizer(field_descriptor.number) | 303 sizer = encoder.MessageSetItemSizer(field_descriptor.number) |
199 else: | 304 else: |
200 field_encoder = type_checkers.TYPE_TO_ENCODER[field_descriptor.type]( | 305 field_encoder = type_checkers.TYPE_TO_ENCODER[field_descriptor.type]( |
201 field_descriptor.number, is_repeated, is_packed) | 306 field_descriptor.number, is_repeated, is_packed) |
202 sizer = type_checkers.TYPE_TO_SIZER[field_descriptor.type]( | 307 sizer = type_checkers.TYPE_TO_SIZER[field_descriptor.type]( |
203 field_descriptor.number, is_repeated, is_packed) | 308 field_descriptor.number, is_repeated, is_packed) |
204 | 309 |
205 field_descriptor._encoder = field_encoder | 310 field_descriptor._encoder = field_encoder |
206 field_descriptor._sizer = sizer | 311 field_descriptor._sizer = sizer |
207 field_descriptor._default_constructor = _DefaultValueConstructorForField( | 312 field_descriptor._default_constructor = _DefaultValueConstructorForField( |
208 field_descriptor) | 313 field_descriptor) |
209 | 314 |
210 def AddDecoder(wiretype, is_packed): | 315 def AddDecoder(wiretype, is_packed): |
211 tag_bytes = encoder.TagBytes(field_descriptor.number, wiretype) | 316 tag_bytes = encoder.TagBytes(field_descriptor.number, wiretype) |
212 cls._decoders_by_tag[tag_bytes] = ( | 317 decode_type = field_descriptor.type |
213 type_checkers.TYPE_TO_DECODER[field_descriptor.type]( | 318 if (decode_type == _FieldDescriptor.TYPE_ENUM and |
214 field_descriptor.number, is_repeated, is_packed, | 319 type_checkers.SupportsOpenEnums(field_descriptor)): |
215 field_descriptor, field_descriptor._default_constructor)) | 320 decode_type = _FieldDescriptor.TYPE_INT32 |
| 321 |
| 322 oneof_descriptor = None |
| 323 if field_descriptor.containing_oneof is not None: |
| 324 oneof_descriptor = field_descriptor |
| 325 |
| 326 if is_map_entry: |
| 327 is_message_map = _IsMessageMapField(field_descriptor) |
| 328 |
| 329 field_decoder = decoder.MapDecoder( |
| 330 field_descriptor, _GetInitializeDefaultForMap(field_descriptor), |
| 331 is_message_map) |
| 332 else: |
| 333 field_decoder = type_checkers.TYPE_TO_DECODER[decode_type]( |
| 334 field_descriptor.number, is_repeated, is_packed, |
| 335 field_descriptor, field_descriptor._default_constructor) |
| 336 |
| 337 cls._decoders_by_tag[tag_bytes] = (field_decoder, oneof_descriptor) |
216 | 338 |
217 AddDecoder(type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type], | 339 AddDecoder(type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type], |
218 False) | 340 False) |
219 | 341 |
220 if is_repeated and wire_format.IsTypePackable(field_descriptor.type): | 342 if is_repeated and wire_format.IsTypePackable(field_descriptor.type): |
221 # To support wire compatibility of adding packed = true, add a decoder for | 343 # To support wire compatibility of adding packed = true, add a decoder for |
222 # packed values regardless of the field's options. | 344 # packed values regardless of the field's options. |
223 AddDecoder(wire_format.WIRETYPE_LENGTH_DELIMITED, True) | 345 AddDecoder(wire_format.WIRETYPE_LENGTH_DELIMITED, True) |
224 | 346 |
225 | 347 |
226 def _AddClassAttributesForNestedExtensions(descriptor, dictionary): | 348 def _AddClassAttributesForNestedExtensions(descriptor, dictionary): |
227 extension_dict = descriptor.extensions_by_name | 349 extension_dict = descriptor.extensions_by_name |
228 for extension_name, extension_field in extension_dict.iteritems(): | 350 for extension_name, extension_field in extension_dict.items(): |
229 assert extension_name not in dictionary | 351 assert extension_name not in dictionary |
230 dictionary[extension_name] = extension_field | 352 dictionary[extension_name] = extension_field |
231 | 353 |
232 | 354 |
233 def _AddEnumValues(descriptor, cls): | 355 def _AddEnumValues(descriptor, cls): |
234 """Sets class-level attributes for all enum fields defined in this message. | 356 """Sets class-level attributes for all enum fields defined in this message. |
235 | 357 |
236 Also exporting a class-level object that can name enum values. | 358 Also exporting a class-level object that can name enum values. |
237 | 359 |
238 Args: | 360 Args: |
239 descriptor: Descriptor object for this message type. | 361 descriptor: Descriptor object for this message type. |
240 cls: Class we're constructing for this message type. | 362 cls: Class we're constructing for this message type. |
241 """ | 363 """ |
242 for enum_type in descriptor.enum_types: | 364 for enum_type in descriptor.enum_types: |
243 setattr(cls, enum_type.name, enum_type_wrapper.EnumTypeWrapper(enum_type)) | 365 setattr(cls, enum_type.name, enum_type_wrapper.EnumTypeWrapper(enum_type)) |
244 for enum_value in enum_type.values: | 366 for enum_value in enum_type.values: |
245 setattr(cls, enum_value.name, enum_value.number) | 367 setattr(cls, enum_value.name, enum_value.number) |
246 | 368 |
247 | 369 |
| 370 def _GetInitializeDefaultForMap(field): |
| 371 if field.label != _FieldDescriptor.LABEL_REPEATED: |
| 372 raise ValueError('map_entry set on non-repeated field %s' % ( |
| 373 field.name)) |
| 374 fields_by_name = field.message_type.fields_by_name |
| 375 key_checker = type_checkers.GetTypeChecker(fields_by_name['key']) |
| 376 |
| 377 value_field = fields_by_name['value'] |
| 378 if _IsMessageMapField(field): |
| 379 def MakeMessageMapDefault(message): |
| 380 return containers.MessageMap( |
| 381 message._listener_for_children, value_field.message_type, key_checker) |
| 382 return MakeMessageMapDefault |
| 383 else: |
| 384 value_checker = type_checkers.GetTypeChecker(value_field) |
| 385 def MakePrimitiveMapDefault(message): |
| 386 return containers.ScalarMap( |
| 387 message._listener_for_children, key_checker, value_checker) |
| 388 return MakePrimitiveMapDefault |
| 389 |
248 def _DefaultValueConstructorForField(field): | 390 def _DefaultValueConstructorForField(field): |
249 """Returns a function which returns a default value for a field. | 391 """Returns a function which returns a default value for a field. |
250 | 392 |
251 Args: | 393 Args: |
252 field: FieldDescriptor object for this field. | 394 field: FieldDescriptor object for this field. |
253 | 395 |
254 The returned function has one argument: | 396 The returned function has one argument: |
255 message: Message instance containing this field, or a weakref proxy | 397 message: Message instance containing this field, or a weakref proxy |
256 of same. | 398 of same. |
257 | 399 |
258 That function in turn returns a default value for this field. The default | 400 That function in turn returns a default value for this field. The default |
259 value may refer back to |message| via a weak reference. | 401 value may refer back to |message| via a weak reference. |
260 """ | 402 """ |
261 | 403 |
| 404 if _IsMapField(field): |
| 405 return _GetInitializeDefaultForMap(field) |
| 406 |
262 if field.label == _FieldDescriptor.LABEL_REPEATED: | 407 if field.label == _FieldDescriptor.LABEL_REPEATED: |
263 if field.has_default_value and field.default_value != []: | 408 if field.has_default_value and field.default_value != []: |
264 raise ValueError('Repeated field default value not empty list: %s' % ( | 409 raise ValueError('Repeated field default value not empty list: %s' % ( |
265 field.default_value)) | 410 field.default_value)) |
266 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: | 411 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: |
267 # We can't look at _concrete_class yet since it might not have | 412 # We can't look at _concrete_class yet since it might not have |
268 # been set. (Depends on order in which we initialize the classes). | 413 # been set. (Depends on order in which we initialize the classes). |
269 message_type = field.message_type | 414 message_type = field.message_type |
270 def MakeRepeatedMessageDefault(message): | 415 def MakeRepeatedMessageDefault(message): |
271 return containers.RepeatedCompositeFieldContainer( | 416 return containers.RepeatedCompositeFieldContainer( |
272 message._listener_for_children, field.message_type) | 417 message._listener_for_children, field.message_type) |
273 return MakeRepeatedMessageDefault | 418 return MakeRepeatedMessageDefault |
274 else: | 419 else: |
275 type_checker = type_checkers.GetTypeChecker(field.cpp_type, field.type) | 420 type_checker = type_checkers.GetTypeChecker(field) |
276 def MakeRepeatedScalarDefault(message): | 421 def MakeRepeatedScalarDefault(message): |
277 return containers.RepeatedScalarFieldContainer( | 422 return containers.RepeatedScalarFieldContainer( |
278 message._listener_for_children, type_checker) | 423 message._listener_for_children, type_checker) |
279 return MakeRepeatedScalarDefault | 424 return MakeRepeatedScalarDefault |
280 | 425 |
281 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: | 426 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: |
282 # _concrete_class may not yet be initialized. | 427 # _concrete_class may not yet be initialized. |
283 message_type = field.message_type | 428 message_type = field.message_type |
284 def MakeSubMessageDefault(message): | 429 def MakeSubMessageDefault(message): |
285 result = message_type._concrete_class() | 430 result = message_type._concrete_class() |
286 result._SetListener(message._listener_for_children) | 431 result._SetListener( |
| 432 _OneofListener(message, field) |
| 433 if field.containing_oneof is not None |
| 434 else message._listener_for_children) |
287 return result | 435 return result |
288 return MakeSubMessageDefault | 436 return MakeSubMessageDefault |
289 | 437 |
290 def MakeScalarDefault(message): | 438 def MakeScalarDefault(message): |
291 # TODO(protobuf-team): This may be broken since there may not be | 439 # TODO(protobuf-team): This may be broken since there may not be |
292 # default_value. Combine with has_default_value somehow. | 440 # default_value. Combine with has_default_value somehow. |
293 return field.default_value | 441 return field.default_value |
294 return MakeScalarDefault | 442 return MakeScalarDefault |
295 | 443 |
296 | 444 |
| 445 def _ReraiseTypeErrorWithFieldName(message_name, field_name): |
| 446 """Re-raise the currently-handled TypeError with the field name added.""" |
| 447 exc = sys.exc_info()[1] |
| 448 if len(exc.args) == 1 and type(exc) is TypeError: |
| 449 # simple TypeError; add field name to exception message |
| 450 exc = TypeError('%s for field %s.%s' % (str(exc), message_name, field_name)) |
| 451 |
| 452 # re-raise possibly-amended exception with original traceback: |
| 453 six.reraise(type(exc), exc, sys.exc_info()[2]) |
| 454 |
| 455 |
297 def _AddInitMethod(message_descriptor, cls): | 456 def _AddInitMethod(message_descriptor, cls): |
298 """Adds an __init__ method to cls.""" | 457 """Adds an __init__ method to cls.""" |
299 fields = message_descriptor.fields | 458 |
| 459 def _GetIntegerEnumValue(enum_type, value): |
| 460 """Convert a string or integer enum value to an integer. |
| 461 |
| 462 If the value is a string, it is converted to the enum value in |
| 463 enum_type with the same name. If the value is not a string, it's |
| 464 returned as-is. (No conversion or bounds-checking is done.) |
| 465 """ |
| 466 if isinstance(value, six.string_types): |
| 467 try: |
| 468 return enum_type.values_by_name[value].number |
| 469 except KeyError: |
| 470 raise ValueError('Enum type %s: unknown label "%s"' % ( |
| 471 enum_type.full_name, value)) |
| 472 return value |
| 473 |
300 def init(self, **kwargs): | 474 def init(self, **kwargs): |
301 self._cached_byte_size = 0 | 475 self._cached_byte_size = 0 |
302 self._cached_byte_size_dirty = len(kwargs) > 0 | 476 self._cached_byte_size_dirty = len(kwargs) > 0 |
303 self._fields = {} | 477 self._fields = {} |
| 478 # Contains a mapping from oneof field descriptors to the descriptor |
| 479 # of the currently set field in that oneof field. |
| 480 self._oneofs = {} |
| 481 |
304 # _unknown_fields is () when empty for efficiency, and will be turned into | 482 # _unknown_fields is () when empty for efficiency, and will be turned into |
305 # a list if fields are added. | 483 # a list if fields are added. |
306 self._unknown_fields = () | 484 self._unknown_fields = () |
307 self._is_present_in_parent = False | 485 self._is_present_in_parent = False |
308 self._listener = message_listener_mod.NullMessageListener() | 486 self._listener = message_listener_mod.NullMessageListener() |
309 self._listener_for_children = _Listener(self) | 487 self._listener_for_children = _Listener(self) |
310 for field_name, field_value in kwargs.iteritems(): | 488 for field_name, field_value in kwargs.items(): |
311 field = _GetFieldByName(message_descriptor, field_name) | 489 field = _GetFieldByName(message_descriptor, field_name) |
312 if field is None: | 490 if field is None: |
313 raise TypeError("%s() got an unexpected keyword argument '%s'" % | 491 raise TypeError("%s() got an unexpected keyword argument '%s'" % |
314 (message_descriptor.name, field_name)) | 492 (message_descriptor.name, field_name)) |
315 if field.label == _FieldDescriptor.LABEL_REPEATED: | 493 if field.label == _FieldDescriptor.LABEL_REPEATED: |
316 copy = field._default_constructor(self) | 494 copy = field._default_constructor(self) |
317 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: # Composite | 495 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: # Composite |
318 for val in field_value: | 496 if _IsMapField(field): |
319 copy.add().MergeFrom(val) | 497 if _IsMessageMapField(field): |
| 498 for key in field_value: |
| 499 copy[key].MergeFrom(field_value[key]) |
| 500 else: |
| 501 copy.update(field_value) |
| 502 else: |
| 503 for val in field_value: |
| 504 if isinstance(val, dict): |
| 505 copy.add(**val) |
| 506 else: |
| 507 copy.add().MergeFrom(val) |
320 else: # Scalar | 508 else: # Scalar |
| 509 if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM: |
| 510 field_value = [_GetIntegerEnumValue(field.enum_type, val) |
| 511 for val in field_value] |
321 copy.extend(field_value) | 512 copy.extend(field_value) |
322 self._fields[field] = copy | 513 self._fields[field] = copy |
323 elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: | 514 elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: |
324 copy = field._default_constructor(self) | 515 copy = field._default_constructor(self) |
325 copy.MergeFrom(field_value) | 516 new_val = field_value |
| 517 if isinstance(field_value, dict): |
| 518 new_val = field.message_type._concrete_class(**field_value) |
| 519 try: |
| 520 copy.MergeFrom(new_val) |
| 521 except TypeError: |
| 522 _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name) |
326 self._fields[field] = copy | 523 self._fields[field] = copy |
327 else: | 524 else: |
328 setattr(self, field_name, field_value) | 525 if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM: |
| 526 field_value = _GetIntegerEnumValue(field.enum_type, field_value) |
| 527 try: |
| 528 setattr(self, field_name, field_value) |
| 529 except TypeError: |
| 530 _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name) |
329 | 531 |
330 init.__module__ = None | 532 init.__module__ = None |
331 init.__doc__ = None | 533 init.__doc__ = None |
332 cls.__init__ = init | 534 cls.__init__ = init |
333 | 535 |
334 | 536 |
335 def _GetFieldByName(message_descriptor, field_name): | 537 def _GetFieldByName(message_descriptor, field_name): |
336 """Returns a field descriptor by field name. | 538 """Returns a field descriptor by field name. |
337 | 539 |
338 Args: | 540 Args: |
339 message_descriptor: A Descriptor describing all fields in message. | 541 message_descriptor: A Descriptor describing all fields in message. |
340 field_name: The name of the field to retrieve. | 542 field_name: The name of the field to retrieve. |
341 Returns: | 543 Returns: |
342 The field descriptor associated with the field name. | 544 The field descriptor associated with the field name. |
343 """ | 545 """ |
344 try: | 546 try: |
345 return message_descriptor.fields_by_name[field_name] | 547 return message_descriptor.fields_by_name[field_name] |
346 except KeyError: | 548 except KeyError: |
347 raise ValueError('Protocol message has no "%s" field.' % field_name) | 549 raise ValueError('Protocol message %s has no "%s" field.' % |
| 550 (message_descriptor.name, field_name)) |
348 | 551 |
349 | 552 |
350 def _AddPropertiesForFields(descriptor, cls): | 553 def _AddPropertiesForFields(descriptor, cls): |
351 """Adds properties for all fields in this protocol message type.""" | 554 """Adds properties for all fields in this protocol message type.""" |
352 for field in descriptor.fields: | 555 for field in descriptor.fields: |
353 _AddPropertiesForField(field, cls) | 556 _AddPropertiesForField(field, cls) |
354 | 557 |
355 if descriptor.is_extendable: | 558 if descriptor.is_extendable: |
356 # _ExtensionDict is just an adaptor with no state so we allocate a new one | 559 # _ExtensionDict is just an adaptor with no state so we allocate a new one |
357 # every time it is accessed. | 560 # every time it is accessed. |
(...skipping 75 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
433 Note that when the client sets the value of a field by using this property, | 636 Note that when the client sets the value of a field by using this property, |
434 all necessary "has" bits are set as a side-effect, and we also perform | 637 all necessary "has" bits are set as a side-effect, and we also perform |
435 type-checking. | 638 type-checking. |
436 | 639 |
437 Args: | 640 Args: |
438 field: A FieldDescriptor for this field. | 641 field: A FieldDescriptor for this field. |
439 cls: The class we're constructing. | 642 cls: The class we're constructing. |
440 """ | 643 """ |
441 proto_field_name = field.name | 644 proto_field_name = field.name |
442 property_name = _PropertyName(proto_field_name) | 645 property_name = _PropertyName(proto_field_name) |
443 type_checker = type_checkers.GetTypeChecker(field.cpp_type, field.type) | 646 type_checker = type_checkers.GetTypeChecker(field) |
444 default_value = field.default_value | 647 default_value = field.default_value |
445 valid_values = set() | 648 valid_values = set() |
| 649 is_proto3 = field.containing_type.syntax == "proto3" |
446 | 650 |
447 def getter(self): | 651 def getter(self): |
448 # TODO(protobuf-team): This may be broken since there may not be | 652 # TODO(protobuf-team): This may be broken since there may not be |
449 # default_value. Combine with has_default_value somehow. | 653 # default_value. Combine with has_default_value somehow. |
450 return self._fields.get(field, default_value) | 654 return self._fields.get(field, default_value) |
451 getter.__module__ = None | 655 getter.__module__ = None |
452 getter.__doc__ = 'Getter for %s.' % proto_field_name | 656 getter.__doc__ = 'Getter for %s.' % proto_field_name |
453 def setter(self, new_value): | 657 |
454 type_checker.CheckValue(new_value) | 658 clear_when_set_to_default = is_proto3 and not field.containing_oneof |
455 self._fields[field] = new_value | 659 |
| 660 def field_setter(self, new_value): |
| 661 # pylint: disable=protected-access |
| 662 # Testing the value for truthiness captures all of the proto3 defaults |
| 663 # (0, 0.0, enum 0, and False). |
| 664 new_value = type_checker.CheckValue(new_value) |
| 665 if clear_when_set_to_default and not new_value: |
| 666 self._fields.pop(field, None) |
| 667 else: |
| 668 self._fields[field] = new_value |
456 # Check _cached_byte_size_dirty inline to improve performance, since scalar | 669 # Check _cached_byte_size_dirty inline to improve performance, since scalar |
457 # setters are called frequently. | 670 # setters are called frequently. |
458 if not self._cached_byte_size_dirty: | 671 if not self._cached_byte_size_dirty: |
459 self._Modified() | 672 self._Modified() |
460 | 673 |
| 674 if field.containing_oneof: |
| 675 def setter(self, new_value): |
| 676 field_setter(self, new_value) |
| 677 self._UpdateOneofState(field) |
| 678 else: |
| 679 setter = field_setter |
| 680 |
461 setter.__module__ = None | 681 setter.__module__ = None |
462 setter.__doc__ = 'Setter for %s.' % proto_field_name | 682 setter.__doc__ = 'Setter for %s.' % proto_field_name |
463 | 683 |
464 # Add a property to encapsulate the getter/setter. | 684 # Add a property to encapsulate the getter/setter. |
465 doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name | 685 doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name |
466 setattr(cls, property_name, property(getter, setter, doc=doc)) | 686 setattr(cls, property_name, property(getter, setter, doc=doc)) |
467 | 687 |
468 | 688 |
469 def _AddPropertiesForNonRepeatedCompositeField(field, cls): | 689 def _AddPropertiesForNonRepeatedCompositeField(field, cls): |
470 """Adds a public property for a nonrepeated, composite protocol message field. | 690 """Adds a public property for a nonrepeated, composite protocol message field. |
471 A composite field is a "group" or "message" field. | 691 A composite field is a "group" or "message" field. |
472 | 692 |
473 Clients can use this property to get the value of the field, but cannot | 693 Clients can use this property to get the value of the field, but cannot |
474 assign to the property directly. | 694 assign to the property directly. |
475 | 695 |
476 Args: | 696 Args: |
477 field: A FieldDescriptor for this field. | 697 field: A FieldDescriptor for this field. |
478 cls: The class we're constructing. | 698 cls: The class we're constructing. |
479 """ | 699 """ |
480 # TODO(robinson): Remove duplication with similar method | 700 # TODO(robinson): Remove duplication with similar method |
481 # for non-repeated scalars. | 701 # for non-repeated scalars. |
482 proto_field_name = field.name | 702 proto_field_name = field.name |
483 property_name = _PropertyName(proto_field_name) | 703 property_name = _PropertyName(proto_field_name) |
484 | 704 |
485 # TODO(komarek): Can anyone explain to me why we cache the message_type this | |
486 # way, instead of referring to field.message_type inside of getter(self)? | |
487 # What if someone sets message_type later on (which makes for simpler | |
488 # dyanmic proto descriptor and class creation code). | |
489 message_type = field.message_type | |
490 | |
491 def getter(self): | 705 def getter(self): |
492 field_value = self._fields.get(field) | 706 field_value = self._fields.get(field) |
493 if field_value is None: | 707 if field_value is None: |
494 # Construct a new object to represent this field. | 708 # Construct a new object to represent this field. |
495 field_value = message_type._concrete_class() # use field.message_type? | 709 field_value = field._default_constructor(self) |
496 field_value._SetListener(self._listener_for_children) | |
497 | 710 |
498 # Atomically check if another thread has preempted us and, if not, swap | 711 # Atomically check if another thread has preempted us and, if not, swap |
499 # in the new object we just created. If someone has preempted us, we | 712 # in the new object we just created. If someone has preempted us, we |
500 # take that object and discard ours. | 713 # take that object and discard ours. |
501 # WARNING: We are relying on setdefault() being atomic. This is true | 714 # WARNING: We are relying on setdefault() being atomic. This is true |
502 # in CPython but we haven't investigated others. This warning appears | 715 # in CPython but we haven't investigated others. This warning appears |
503 # in several other locations in this file. | 716 # in several other locations in this file. |
504 field_value = self._fields.setdefault(field, field_value) | 717 field_value = self._fields.setdefault(field, field_value) |
505 return field_value | 718 return field_value |
506 getter.__module__ = None | 719 getter.__module__ = None |
507 getter.__doc__ = 'Getter for %s.' % proto_field_name | 720 getter.__doc__ = 'Getter for %s.' % proto_field_name |
508 | 721 |
509 # We define a setter just so we can throw an exception with a more | 722 # We define a setter just so we can throw an exception with a more |
510 # helpful error message. | 723 # helpful error message. |
511 def setter(self, new_value): | 724 def setter(self, new_value): |
512 raise AttributeError('Assignment not allowed to composite field ' | 725 raise AttributeError('Assignment not allowed to composite field ' |
513 '"%s" in protocol message object.' % proto_field_name) | 726 '"%s" in protocol message object.' % proto_field_name) |
514 | 727 |
515 # Add a property to encapsulate the getter. | 728 # Add a property to encapsulate the getter. |
516 doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name | 729 doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name |
517 setattr(cls, property_name, property(getter, setter, doc=doc)) | 730 setattr(cls, property_name, property(getter, setter, doc=doc)) |
518 | 731 |
519 | 732 |
520 def _AddPropertiesForExtensions(descriptor, cls): | 733 def _AddPropertiesForExtensions(descriptor, cls): |
521 """Adds properties for all fields in this protocol message type.""" | 734 """Adds properties for all fields in this protocol message type.""" |
522 extension_dict = descriptor.extensions_by_name | 735 extension_dict = descriptor.extensions_by_name |
523 for extension_name, extension_field in extension_dict.iteritems(): | 736 for extension_name, extension_field in extension_dict.items(): |
524 constant_name = extension_name.upper() + "_FIELD_NUMBER" | 737 constant_name = extension_name.upper() + "_FIELD_NUMBER" |
525 setattr(cls, constant_name, extension_field.number) | 738 setattr(cls, constant_name, extension_field.number) |
526 | 739 |
527 | 740 |
528 def _AddStaticMethods(cls): | 741 def _AddStaticMethods(cls): |
529 # TODO(robinson): This probably needs to be thread-safe(?) | 742 # TODO(robinson): This probably needs to be thread-safe(?) |
530 def RegisterExtension(extension_handle): | 743 def RegisterExtension(extension_handle): |
531 extension_handle.containing_type = cls.DESCRIPTOR | 744 extension_handle.containing_type = cls.DESCRIPTOR |
532 _AttachFieldHelpers(cls, extension_handle) | 745 _AttachFieldHelpers(cls, extension_handle) |
533 | 746 |
(...skipping 34 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
568 elif item[0].cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: | 781 elif item[0].cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: |
569 return item[1]._is_present_in_parent | 782 return item[1]._is_present_in_parent |
570 else: | 783 else: |
571 return True | 784 return True |
572 | 785 |
573 | 786 |
574 def _AddListFieldsMethod(message_descriptor, cls): | 787 def _AddListFieldsMethod(message_descriptor, cls): |
575 """Helper for _AddMessageMethods().""" | 788 """Helper for _AddMessageMethods().""" |
576 | 789 |
577 def ListFields(self): | 790 def ListFields(self): |
578 all_fields = [item for item in self._fields.iteritems() if _IsPresent(item)] | 791 all_fields = [item for item in self._fields.items() if _IsPresent(item)] |
579 all_fields.sort(key = lambda item: item[0].number) | 792 all_fields.sort(key = lambda item: item[0].number) |
580 return all_fields | 793 return all_fields |
581 | 794 |
582 cls.ListFields = ListFields | 795 cls.ListFields = ListFields |
583 | 796 |
| 797 _Proto3HasError = 'Protocol message has no non-repeated submessage field "%s"' |
| 798 _Proto2HasError = 'Protocol message has no non-repeated field "%s"' |
584 | 799 |
585 def _AddHasFieldMethod(message_descriptor, cls): | 800 def _AddHasFieldMethod(message_descriptor, cls): |
586 """Helper for _AddMessageMethods().""" | 801 """Helper for _AddMessageMethods().""" |
587 | 802 |
588 singular_fields = {} | 803 is_proto3 = (message_descriptor.syntax == "proto3") |
| 804 error_msg = _Proto3HasError if is_proto3 else _Proto2HasError |
| 805 |
| 806 hassable_fields = {} |
589 for field in message_descriptor.fields: | 807 for field in message_descriptor.fields: |
590 if field.label != _FieldDescriptor.LABEL_REPEATED: | 808 if field.label == _FieldDescriptor.LABEL_REPEATED: |
591 singular_fields[field.name] = field | 809 continue |
| 810 # For proto3, only submessages and fields inside a oneof have presence. |
| 811 if (is_proto3 and field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE and |
| 812 not field.containing_oneof): |
| 813 continue |
| 814 hassable_fields[field.name] = field |
| 815 |
| 816 if not is_proto3: |
| 817 # Fields inside oneofs are never repeated (enforced by the compiler). |
| 818 for oneof in message_descriptor.oneofs: |
| 819 hassable_fields[oneof.name] = oneof |
592 | 820 |
593 def HasField(self, field_name): | 821 def HasField(self, field_name): |
594 try: | 822 try: |
595 field = singular_fields[field_name] | 823 field = hassable_fields[field_name] |
596 except KeyError: | 824 except KeyError: |
597 raise ValueError( | 825 raise ValueError(error_msg % field_name) |
598 'Protocol message has no singular "%s" field.' % field_name) | |
599 | 826 |
600 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: | 827 if isinstance(field, descriptor_mod.OneofDescriptor): |
601 value = self._fields.get(field) | 828 try: |
602 return value is not None and value._is_present_in_parent | 829 return HasField(self, self._oneofs[field].name) |
| 830 except KeyError: |
| 831 return False |
603 else: | 832 else: |
604 return field in self._fields | 833 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: |
| 834 value = self._fields.get(field) |
| 835 return value is not None and value._is_present_in_parent |
| 836 else: |
| 837 return field in self._fields |
| 838 |
605 cls.HasField = HasField | 839 cls.HasField = HasField |
606 | 840 |
607 | 841 |
608 def _AddClearFieldMethod(message_descriptor, cls): | 842 def _AddClearFieldMethod(message_descriptor, cls): |
609 """Helper for _AddMessageMethods().""" | 843 """Helper for _AddMessageMethods().""" |
610 def ClearField(self, field_name): | 844 def ClearField(self, field_name): |
611 try: | 845 try: |
612 field = message_descriptor.fields_by_name[field_name] | 846 field = message_descriptor.fields_by_name[field_name] |
613 except KeyError: | 847 except KeyError: |
614 raise ValueError('Protocol message has no "%s" field.' % field_name) | 848 try: |
| 849 field = message_descriptor.oneofs_by_name[field_name] |
| 850 if field in self._oneofs: |
| 851 field = self._oneofs[field] |
| 852 else: |
| 853 return |
| 854 except KeyError: |
| 855 raise ValueError('Protocol message %s() has no "%s" field.' % |
| 856 (message_descriptor.name, field_name)) |
615 | 857 |
616 if field in self._fields: | 858 if field in self._fields: |
| 859 # To match the C++ implementation, we need to invalidate iterators |
| 860 # for map fields when ClearField() happens. |
| 861 if hasattr(self._fields[field], 'InvalidateIterators'): |
| 862 self._fields[field].InvalidateIterators() |
| 863 |
617 # Note: If the field is a sub-message, its listener will still point | 864 # Note: If the field is a sub-message, its listener will still point |
618 # at us. That's fine, because the worst than can happen is that it | 865 # at us. That's fine, because the worst than can happen is that it |
619 # will call _Modified() and invalidate our byte size. Big deal. | 866 # will call _Modified() and invalidate our byte size. Big deal. |
620 del self._fields[field] | 867 del self._fields[field] |
621 | 868 |
| 869 if self._oneofs.get(field.containing_oneof, None) is field: |
| 870 del self._oneofs[field.containing_oneof] |
| 871 |
622 # Always call _Modified() -- even if nothing was changed, this is | 872 # Always call _Modified() -- even if nothing was changed, this is |
623 # a mutating method, and thus calling it should cause the field to become | 873 # a mutating method, and thus calling it should cause the field to become |
624 # present in the parent message. | 874 # present in the parent message. |
625 self._Modified() | 875 self._Modified() |
626 | 876 |
627 cls.ClearField = ClearField | 877 cls.ClearField = ClearField |
628 | 878 |
629 | 879 |
630 def _AddClearExtensionMethod(cls): | 880 def _AddClearExtensionMethod(cls): |
631 """Helper for _AddMessageMethods().""" | 881 """Helper for _AddMessageMethods().""" |
632 def ClearExtension(self, extension_handle): | 882 def ClearExtension(self, extension_handle): |
633 _VerifyExtensionHandle(self, extension_handle) | 883 _VerifyExtensionHandle(self, extension_handle) |
634 | 884 |
635 # Similar to ClearField(), above. | 885 # Similar to ClearField(), above. |
636 if extension_handle in self._fields: | 886 if extension_handle in self._fields: |
637 del self._fields[extension_handle] | 887 del self._fields[extension_handle] |
638 self._Modified() | 888 self._Modified() |
639 cls.ClearExtension = ClearExtension | 889 cls.ClearExtension = ClearExtension |
640 | 890 |
641 | 891 |
642 def _AddClearMethod(message_descriptor, cls): | 892 def _AddClearMethod(message_descriptor, cls): |
643 """Helper for _AddMessageMethods().""" | 893 """Helper for _AddMessageMethods().""" |
644 def Clear(self): | 894 def Clear(self): |
645 # Clear fields. | 895 # Clear fields. |
646 self._fields = {} | 896 self._fields = {} |
647 self._unknown_fields = () | 897 self._unknown_fields = () |
| 898 self._oneofs = {} |
648 self._Modified() | 899 self._Modified() |
649 cls.Clear = Clear | 900 cls.Clear = Clear |
650 | 901 |
651 | 902 |
652 def _AddHasExtensionMethod(cls): | 903 def _AddHasExtensionMethod(cls): |
653 """Helper for _AddMessageMethods().""" | 904 """Helper for _AddMessageMethods().""" |
654 def HasExtension(self, extension_handle): | 905 def HasExtension(self, extension_handle): |
655 _VerifyExtensionHandle(self, extension_handle) | 906 _VerifyExtensionHandle(self, extension_handle) |
656 if extension_handle.label == _FieldDescriptor.LABEL_REPEATED: | 907 if extension_handle.label == _FieldDescriptor.LABEL_REPEATED: |
657 raise KeyError('"%s" is repeated.' % extension_handle.full_name) | 908 raise KeyError('"%s" is repeated.' % extension_handle.full_name) |
658 | 909 |
659 if extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: | 910 if extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: |
660 value = self._fields.get(extension_handle) | 911 value = self._fields.get(extension_handle) |
661 return value is not None and value._is_present_in_parent | 912 return value is not None and value._is_present_in_parent |
662 else: | 913 else: |
663 return extension_handle in self._fields | 914 return extension_handle in self._fields |
664 cls.HasExtension = HasExtension | 915 cls.HasExtension = HasExtension |
665 | 916 |
| 917 def _InternalUnpackAny(msg): |
| 918 """Unpacks Any message and returns the unpacked message. |
| 919 |
| 920 This internal method is differnt from public Any Unpack method which takes |
| 921 the target message as argument. _InternalUnpackAny method does not have |
| 922 target message type and need to find the message type in descriptor pool. |
| 923 |
| 924 Args: |
| 925 msg: An Any message to be unpacked. |
| 926 |
| 927 Returns: |
| 928 The unpacked message. |
| 929 """ |
| 930 type_url = msg.type_url |
| 931 db = symbol_database.Default() |
| 932 |
| 933 if not type_url: |
| 934 return None |
| 935 |
| 936 # TODO(haberman): For now we just strip the hostname. Better logic will be |
| 937 # required. |
| 938 type_name = type_url.split("/")[-1] |
| 939 descriptor = db.pool.FindMessageTypeByName(type_name) |
| 940 |
| 941 if descriptor is None: |
| 942 return None |
| 943 |
| 944 message_class = db.GetPrototype(descriptor) |
| 945 message = message_class() |
| 946 |
| 947 message.ParseFromString(msg.value) |
| 948 return message |
666 | 949 |
667 def _AddEqualsMethod(message_descriptor, cls): | 950 def _AddEqualsMethod(message_descriptor, cls): |
668 """Helper for _AddMessageMethods().""" | 951 """Helper for _AddMessageMethods().""" |
669 def __eq__(self, other): | 952 def __eq__(self, other): |
670 if (not isinstance(other, message_mod.Message) or | 953 if (not isinstance(other, message_mod.Message) or |
671 other.DESCRIPTOR != self.DESCRIPTOR): | 954 other.DESCRIPTOR != self.DESCRIPTOR): |
672 return False | 955 return False |
673 | 956 |
674 if self is other: | 957 if self is other: |
675 return True | 958 return True |
676 | 959 |
| 960 if self.DESCRIPTOR.full_name == _AnyFullTypeName: |
| 961 any_a = _InternalUnpackAny(self) |
| 962 any_b = _InternalUnpackAny(other) |
| 963 if any_a and any_b: |
| 964 return any_a == any_b |
| 965 |
677 if not self.ListFields() == other.ListFields(): | 966 if not self.ListFields() == other.ListFields(): |
678 return False | 967 return False |
679 | 968 |
680 # Sort unknown fields because their order shouldn't affect equality test. | 969 # Sort unknown fields because their order shouldn't affect equality test. |
681 unknown_fields = list(self._unknown_fields) | 970 unknown_fields = list(self._unknown_fields) |
682 unknown_fields.sort() | 971 unknown_fields.sort() |
683 other_unknown_fields = list(other._unknown_fields) | 972 other_unknown_fields = list(other._unknown_fields) |
684 other_unknown_fields.sort() | 973 other_unknown_fields.sort() |
685 | 974 |
686 return unknown_fields == other_unknown_fields | 975 return unknown_fields == other_unknown_fields |
687 | 976 |
688 cls.__eq__ = __eq__ | 977 cls.__eq__ = __eq__ |
689 | 978 |
690 | 979 |
691 def _AddStrMethod(message_descriptor, cls): | 980 def _AddStrMethod(message_descriptor, cls): |
692 """Helper for _AddMessageMethods().""" | 981 """Helper for _AddMessageMethods().""" |
693 def __str__(self): | 982 def __str__(self): |
694 return text_format.MessageToString(self) | 983 return text_format.MessageToString(self) |
695 cls.__str__ = __str__ | 984 cls.__str__ = __str__ |
696 | 985 |
697 | 986 |
| 987 def _AddReprMethod(message_descriptor, cls): |
| 988 """Helper for _AddMessageMethods().""" |
| 989 def __repr__(self): |
| 990 return text_format.MessageToString(self) |
| 991 cls.__repr__ = __repr__ |
| 992 |
| 993 |
698 def _AddUnicodeMethod(unused_message_descriptor, cls): | 994 def _AddUnicodeMethod(unused_message_descriptor, cls): |
699 """Helper for _AddMessageMethods().""" | 995 """Helper for _AddMessageMethods().""" |
700 | 996 |
701 def __unicode__(self): | 997 def __unicode__(self): |
702 return text_format.MessageToString(self, as_utf8=True).decode('utf-8') | 998 return text_format.MessageToString(self, as_utf8=True).decode('utf-8') |
703 cls.__unicode__ = __unicode__ | 999 cls.__unicode__ = __unicode__ |
704 | 1000 |
705 | 1001 |
706 def _AddSetListenerMethod(cls): | 1002 def _AddSetListenerMethod(cls): |
707 """Helper for _AddMessageMethods().""" | 1003 """Helper for _AddMessageMethods().""" |
(...skipping 58 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
766 'Message %s is missing required fields: %s' % ( | 1062 'Message %s is missing required fields: %s' % ( |
767 self.DESCRIPTOR.full_name, ','.join(self.FindInitializationErrors()))) | 1063 self.DESCRIPTOR.full_name, ','.join(self.FindInitializationErrors()))) |
768 return self.SerializePartialToString() | 1064 return self.SerializePartialToString() |
769 cls.SerializeToString = SerializeToString | 1065 cls.SerializeToString = SerializeToString |
770 | 1066 |
771 | 1067 |
772 def _AddSerializePartialToStringMethod(message_descriptor, cls): | 1068 def _AddSerializePartialToStringMethod(message_descriptor, cls): |
773 """Helper for _AddMessageMethods().""" | 1069 """Helper for _AddMessageMethods().""" |
774 | 1070 |
775 def SerializePartialToString(self): | 1071 def SerializePartialToString(self): |
776 out = StringIO() | 1072 out = BytesIO() |
777 self._InternalSerialize(out.write) | 1073 self._InternalSerialize(out.write) |
778 return out.getvalue() | 1074 return out.getvalue() |
779 cls.SerializePartialToString = SerializePartialToString | 1075 cls.SerializePartialToString = SerializePartialToString |
780 | 1076 |
781 def InternalSerialize(self, write_bytes): | 1077 def InternalSerialize(self, write_bytes): |
782 for field_descriptor, field_value in self.ListFields(): | 1078 for field_descriptor, field_value in self.ListFields(): |
783 field_descriptor._encoder(write_bytes, field_value) | 1079 field_descriptor._encoder(write_bytes, field_value) |
784 for tag_bytes, value_bytes in self._unknown_fields: | 1080 for tag_bytes, value_bytes in self._unknown_fields: |
785 write_bytes(tag_bytes) | 1081 write_bytes(tag_bytes) |
786 write_bytes(value_bytes) | 1082 write_bytes(value_bytes) |
787 cls._InternalSerialize = InternalSerialize | 1083 cls._InternalSerialize = InternalSerialize |
788 | 1084 |
789 | 1085 |
790 def _AddMergeFromStringMethod(message_descriptor, cls): | 1086 def _AddMergeFromStringMethod(message_descriptor, cls): |
791 """Helper for _AddMessageMethods().""" | 1087 """Helper for _AddMessageMethods().""" |
792 def MergeFromString(self, serialized): | 1088 def MergeFromString(self, serialized): |
793 length = len(serialized) | 1089 length = len(serialized) |
794 try: | 1090 try: |
795 if self._InternalParse(serialized, 0, length) != length: | 1091 if self._InternalParse(serialized, 0, length) != length: |
796 # The only reason _InternalParse would return early is if it | 1092 # The only reason _InternalParse would return early is if it |
797 # encountered an end-group tag. | 1093 # encountered an end-group tag. |
798 raise message_mod.DecodeError('Unexpected end-group tag.') | 1094 raise message_mod.DecodeError('Unexpected end-group tag.') |
799 except IndexError: | 1095 except (IndexError, TypeError): |
| 1096 # Now ord(buf[p:p+1]) == ord('') gets TypeError. |
800 raise message_mod.DecodeError('Truncated message.') | 1097 raise message_mod.DecodeError('Truncated message.') |
801 except struct.error, e: | 1098 except struct.error as e: |
802 raise message_mod.DecodeError(e) | 1099 raise message_mod.DecodeError(e) |
803 return length # Return this for legacy reasons. | 1100 return length # Return this for legacy reasons. |
804 cls.MergeFromString = MergeFromString | 1101 cls.MergeFromString = MergeFromString |
805 | 1102 |
806 local_ReadTag = decoder.ReadTag | 1103 local_ReadTag = decoder.ReadTag |
807 local_SkipField = decoder.SkipField | 1104 local_SkipField = decoder.SkipField |
808 decoders_by_tag = cls._decoders_by_tag | 1105 decoders_by_tag = cls._decoders_by_tag |
| 1106 is_proto3 = message_descriptor.syntax == "proto3" |
809 | 1107 |
810 def InternalParse(self, buffer, pos, end): | 1108 def InternalParse(self, buffer, pos, end): |
811 self._Modified() | 1109 self._Modified() |
812 field_dict = self._fields | 1110 field_dict = self._fields |
813 unknown_field_list = self._unknown_fields | 1111 unknown_field_list = self._unknown_fields |
814 while pos != end: | 1112 while pos != end: |
815 (tag_bytes, new_pos) = local_ReadTag(buffer, pos) | 1113 (tag_bytes, new_pos) = local_ReadTag(buffer, pos) |
816 field_decoder = decoders_by_tag.get(tag_bytes) | 1114 field_decoder, field_desc = decoders_by_tag.get(tag_bytes, (None, None)) |
817 if field_decoder is None: | 1115 if field_decoder is None: |
818 value_start_pos = new_pos | 1116 value_start_pos = new_pos |
819 new_pos = local_SkipField(buffer, new_pos, end, tag_bytes) | 1117 new_pos = local_SkipField(buffer, new_pos, end, tag_bytes) |
820 if new_pos == -1: | 1118 if new_pos == -1: |
821 return pos | 1119 return pos |
822 if not unknown_field_list: | 1120 if not is_proto3: |
823 unknown_field_list = self._unknown_fields = [] | 1121 if not unknown_field_list: |
824 unknown_field_list.append((tag_bytes, buffer[value_start_pos:new_pos])) | 1122 unknown_field_list = self._unknown_fields = [] |
| 1123 unknown_field_list.append( |
| 1124 (tag_bytes, buffer[value_start_pos:new_pos])) |
825 pos = new_pos | 1125 pos = new_pos |
826 else: | 1126 else: |
827 pos = field_decoder(buffer, new_pos, end, self, field_dict) | 1127 pos = field_decoder(buffer, new_pos, end, self, field_dict) |
| 1128 if field_desc: |
| 1129 self._UpdateOneofState(field_desc) |
828 return pos | 1130 return pos |
829 cls._InternalParse = InternalParse | 1131 cls._InternalParse = InternalParse |
830 | 1132 |
831 | 1133 |
832 def _AddIsInitializedMethod(message_descriptor, cls): | 1134 def _AddIsInitializedMethod(message_descriptor, cls): |
833 """Adds the IsInitialized and FindInitializationError methods to the | 1135 """Adds the IsInitialized and FindInitializationError methods to the |
834 protocol message class.""" | 1136 protocol message class.""" |
835 | 1137 |
836 required_fields = [field for field in message_descriptor.fields | 1138 required_fields = [field for field in message_descriptor.fields |
837 if field.label == _FieldDescriptor.LABEL_REQUIRED] | 1139 if field.label == _FieldDescriptor.LABEL_REQUIRED] |
(...skipping 12 matching lines...) Expand all Loading... |
850 # Performance is critical so we avoid HasField() and ListFields(). | 1152 # Performance is critical so we avoid HasField() and ListFields(). |
851 | 1153 |
852 for field in required_fields: | 1154 for field in required_fields: |
853 if (field not in self._fields or | 1155 if (field not in self._fields or |
854 (field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE and | 1156 (field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE and |
855 not self._fields[field]._is_present_in_parent)): | 1157 not self._fields[field]._is_present_in_parent)): |
856 if errors is not None: | 1158 if errors is not None: |
857 errors.extend(self.FindInitializationErrors()) | 1159 errors.extend(self.FindInitializationErrors()) |
858 return False | 1160 return False |
859 | 1161 |
860 for field, value in self._fields.iteritems(): | 1162 for field, value in list(self._fields.items()): # dict can change size! |
861 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: | 1163 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: |
862 if field.label == _FieldDescriptor.LABEL_REPEATED: | 1164 if field.label == _FieldDescriptor.LABEL_REPEATED: |
| 1165 if (field.message_type.has_options and |
| 1166 field.message_type.GetOptions().map_entry): |
| 1167 continue |
863 for element in value: | 1168 for element in value: |
864 if not element.IsInitialized(): | 1169 if not element.IsInitialized(): |
865 if errors is not None: | 1170 if errors is not None: |
866 errors.extend(self.FindInitializationErrors()) | 1171 errors.extend(self.FindInitializationErrors()) |
867 return False | 1172 return False |
868 elif value._is_present_in_parent and not value.IsInitialized(): | 1173 elif value._is_present_in_parent and not value.IsInitialized(): |
869 if errors is not None: | 1174 if errors is not None: |
870 errors.extend(self.FindInitializationErrors()) | 1175 errors.extend(self.FindInitializationErrors()) |
871 return False | 1176 return False |
872 | 1177 |
(...skipping 15 matching lines...) Expand all Loading... |
888 if not self.HasField(field.name): | 1193 if not self.HasField(field.name): |
889 errors.append(field.name) | 1194 errors.append(field.name) |
890 | 1195 |
891 for field, value in self.ListFields(): | 1196 for field, value in self.ListFields(): |
892 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: | 1197 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: |
893 if field.is_extension: | 1198 if field.is_extension: |
894 name = "(%s)" % field.full_name | 1199 name = "(%s)" % field.full_name |
895 else: | 1200 else: |
896 name = field.name | 1201 name = field.name |
897 | 1202 |
898 if field.label == _FieldDescriptor.LABEL_REPEATED: | 1203 if _IsMapField(field): |
899 for i in xrange(len(value)): | 1204 if _IsMessageMapField(field): |
| 1205 for key in value: |
| 1206 element = value[key] |
| 1207 prefix = "%s[%s]." % (name, key) |
| 1208 sub_errors = element.FindInitializationErrors() |
| 1209 errors += [prefix + error for error in sub_errors] |
| 1210 else: |
| 1211 # ScalarMaps can't have any initialization errors. |
| 1212 pass |
| 1213 elif field.label == _FieldDescriptor.LABEL_REPEATED: |
| 1214 for i in range(len(value)): |
900 element = value[i] | 1215 element = value[i] |
901 prefix = "%s[%d]." % (name, i) | 1216 prefix = "%s[%d]." % (name, i) |
902 sub_errors = element.FindInitializationErrors() | 1217 sub_errors = element.FindInitializationErrors() |
903 errors += [ prefix + error for error in sub_errors ] | 1218 errors += [prefix + error for error in sub_errors] |
904 else: | 1219 else: |
905 prefix = name + "." | 1220 prefix = name + "." |
906 sub_errors = value.FindInitializationErrors() | 1221 sub_errors = value.FindInitializationErrors() |
907 errors += [ prefix + error for error in sub_errors ] | 1222 errors += [prefix + error for error in sub_errors] |
908 | 1223 |
909 return errors | 1224 return errors |
910 | 1225 |
911 cls.FindInitializationErrors = FindInitializationErrors | 1226 cls.FindInitializationErrors = FindInitializationErrors |
912 | 1227 |
913 | 1228 |
914 def _AddMergeFromMethod(cls): | 1229 def _AddMergeFromMethod(cls): |
915 LABEL_REPEATED = _FieldDescriptor.LABEL_REPEATED | 1230 LABEL_REPEATED = _FieldDescriptor.LABEL_REPEATED |
916 CPPTYPE_MESSAGE = _FieldDescriptor.CPPTYPE_MESSAGE | 1231 CPPTYPE_MESSAGE = _FieldDescriptor.CPPTYPE_MESSAGE |
917 | 1232 |
918 def MergeFrom(self, msg): | 1233 def MergeFrom(self, msg): |
919 if not isinstance(msg, cls): | 1234 if not isinstance(msg, cls): |
920 raise TypeError( | 1235 raise TypeError( |
921 "Parameter to MergeFrom() must be instance of same class: " | 1236 "Parameter to MergeFrom() must be instance of same class: " |
922 "expected %s got %s." % (cls.__name__, type(msg).__name__)) | 1237 "expected %s got %s." % (cls.__name__, type(msg).__name__)) |
923 | 1238 |
924 assert msg is not self | 1239 assert msg is not self |
925 self._Modified() | 1240 self._Modified() |
926 | 1241 |
927 fields = self._fields | 1242 fields = self._fields |
928 | 1243 |
929 for field, value in msg._fields.iteritems(): | 1244 for field, value in msg._fields.items(): |
930 if field.label == LABEL_REPEATED: | 1245 if field.label == LABEL_REPEATED: |
931 field_value = fields.get(field) | 1246 field_value = fields.get(field) |
932 if field_value is None: | 1247 if field_value is None: |
933 # Construct a new object to represent this field. | 1248 # Construct a new object to represent this field. |
934 field_value = field._default_constructor(self) | 1249 field_value = field._default_constructor(self) |
935 fields[field] = field_value | 1250 fields[field] = field_value |
936 field_value.MergeFrom(value) | 1251 field_value.MergeFrom(value) |
937 elif field.cpp_type == CPPTYPE_MESSAGE: | 1252 elif field.cpp_type == CPPTYPE_MESSAGE: |
938 if value._is_present_in_parent: | 1253 if value._is_present_in_parent: |
939 field_value = fields.get(field) | 1254 field_value = fields.get(field) |
940 if field_value is None: | 1255 if field_value is None: |
941 # Construct a new object to represent this field. | 1256 # Construct a new object to represent this field. |
942 field_value = field._default_constructor(self) | 1257 field_value = field._default_constructor(self) |
943 fields[field] = field_value | 1258 fields[field] = field_value |
944 field_value.MergeFrom(value) | 1259 field_value.MergeFrom(value) |
945 else: | 1260 else: |
946 self._fields[field] = value | 1261 self._fields[field] = value |
| 1262 if field.containing_oneof: |
| 1263 self._UpdateOneofState(field) |
947 | 1264 |
948 if msg._unknown_fields: | 1265 if msg._unknown_fields: |
949 if not self._unknown_fields: | 1266 if not self._unknown_fields: |
950 self._unknown_fields = [] | 1267 self._unknown_fields = [] |
951 self._unknown_fields.extend(msg._unknown_fields) | 1268 self._unknown_fields.extend(msg._unknown_fields) |
952 | 1269 |
953 cls.MergeFrom = MergeFrom | 1270 cls.MergeFrom = MergeFrom |
954 | 1271 |
955 | 1272 |
| 1273 def _AddWhichOneofMethod(message_descriptor, cls): |
| 1274 def WhichOneof(self, oneof_name): |
| 1275 """Returns the name of the currently set field inside a oneof, or None.""" |
| 1276 try: |
| 1277 field = message_descriptor.oneofs_by_name[oneof_name] |
| 1278 except KeyError: |
| 1279 raise ValueError( |
| 1280 'Protocol message has no oneof "%s" field.' % oneof_name) |
| 1281 |
| 1282 nested_field = self._oneofs.get(field, None) |
| 1283 if nested_field is not None and self.HasField(nested_field.name): |
| 1284 return nested_field.name |
| 1285 else: |
| 1286 return None |
| 1287 |
| 1288 cls.WhichOneof = WhichOneof |
| 1289 |
| 1290 |
956 def _AddMessageMethods(message_descriptor, cls): | 1291 def _AddMessageMethods(message_descriptor, cls): |
957 """Adds implementations of all Message methods to cls.""" | 1292 """Adds implementations of all Message methods to cls.""" |
958 _AddListFieldsMethod(message_descriptor, cls) | 1293 _AddListFieldsMethod(message_descriptor, cls) |
959 _AddHasFieldMethod(message_descriptor, cls) | 1294 _AddHasFieldMethod(message_descriptor, cls) |
960 _AddClearFieldMethod(message_descriptor, cls) | 1295 _AddClearFieldMethod(message_descriptor, cls) |
961 if message_descriptor.is_extendable: | 1296 if message_descriptor.is_extendable: |
962 _AddClearExtensionMethod(cls) | 1297 _AddClearExtensionMethod(cls) |
963 _AddHasExtensionMethod(cls) | 1298 _AddHasExtensionMethod(cls) |
964 _AddClearMethod(message_descriptor, cls) | 1299 _AddClearMethod(message_descriptor, cls) |
965 _AddEqualsMethod(message_descriptor, cls) | 1300 _AddEqualsMethod(message_descriptor, cls) |
966 _AddStrMethod(message_descriptor, cls) | 1301 _AddStrMethod(message_descriptor, cls) |
| 1302 _AddReprMethod(message_descriptor, cls) |
967 _AddUnicodeMethod(message_descriptor, cls) | 1303 _AddUnicodeMethod(message_descriptor, cls) |
968 _AddSetListenerMethod(cls) | 1304 _AddSetListenerMethod(cls) |
969 _AddByteSizeMethod(message_descriptor, cls) | 1305 _AddByteSizeMethod(message_descriptor, cls) |
970 _AddSerializeToStringMethod(message_descriptor, cls) | 1306 _AddSerializeToStringMethod(message_descriptor, cls) |
971 _AddSerializePartialToStringMethod(message_descriptor, cls) | 1307 _AddSerializePartialToStringMethod(message_descriptor, cls) |
972 _AddMergeFromStringMethod(message_descriptor, cls) | 1308 _AddMergeFromStringMethod(message_descriptor, cls) |
973 _AddIsInitializedMethod(message_descriptor, cls) | 1309 _AddIsInitializedMethod(message_descriptor, cls) |
974 _AddMergeFromMethod(cls) | 1310 _AddMergeFromMethod(cls) |
| 1311 _AddWhichOneofMethod(message_descriptor, cls) |
975 | 1312 |
976 | 1313 |
977 def _AddPrivateHelperMethods(cls): | 1314 def _AddPrivateHelperMethods(message_descriptor, cls): |
978 """Adds implementation of private helper methods to cls.""" | 1315 """Adds implementation of private helper methods to cls.""" |
979 | 1316 |
980 def Modified(self): | 1317 def Modified(self): |
981 """Sets the _cached_byte_size_dirty bit to true, | 1318 """Sets the _cached_byte_size_dirty bit to true, |
982 and propagates this to our listener iff this was a state change. | 1319 and propagates this to our listener iff this was a state change. |
983 """ | 1320 """ |
984 | 1321 |
985 # Note: Some callers check _cached_byte_size_dirty before calling | 1322 # Note: Some callers check _cached_byte_size_dirty before calling |
986 # _Modified() as an extra optimization. So, if this method is ever | 1323 # _Modified() as an extra optimization. So, if this method is ever |
987 # changed such that it does stuff even when _cached_byte_size_dirty is | 1324 # changed such that it does stuff even when _cached_byte_size_dirty is |
988 # already true, the callers need to be updated. | 1325 # already true, the callers need to be updated. |
989 if not self._cached_byte_size_dirty: | 1326 if not self._cached_byte_size_dirty: |
990 self._cached_byte_size_dirty = True | 1327 self._cached_byte_size_dirty = True |
991 self._listener_for_children.dirty = True | 1328 self._listener_for_children.dirty = True |
992 self._is_present_in_parent = True | 1329 self._is_present_in_parent = True |
993 self._listener.Modified() | 1330 self._listener.Modified() |
994 | 1331 |
| 1332 def _UpdateOneofState(self, field): |
| 1333 """Sets field as the active field in its containing oneof. |
| 1334 |
| 1335 Will also delete currently active field in the oneof, if it is different |
| 1336 from the argument. Does not mark the message as modified. |
| 1337 """ |
| 1338 other_field = self._oneofs.setdefault(field.containing_oneof, field) |
| 1339 if other_field is not field: |
| 1340 del self._fields[other_field] |
| 1341 self._oneofs[field.containing_oneof] = field |
| 1342 |
995 cls._Modified = Modified | 1343 cls._Modified = Modified |
996 cls.SetInParent = Modified | 1344 cls.SetInParent = Modified |
| 1345 cls._UpdateOneofState = _UpdateOneofState |
997 | 1346 |
998 | 1347 |
999 class _Listener(object): | 1348 class _Listener(object): |
1000 | 1349 |
1001 """MessageListener implementation that a parent message registers with its | 1350 """MessageListener implementation that a parent message registers with its |
1002 child message. | 1351 child message. |
1003 | 1352 |
1004 In order to support semantics like: | 1353 In order to support semantics like: |
1005 | 1354 |
1006 foo.bar.baz.qux = 23 | 1355 foo.bar.baz.qux = 23 |
(...skipping 28 matching lines...) Expand all Loading... |
1035 try: | 1384 try: |
1036 # Propagate the signal to our parents iff this is the first field set. | 1385 # Propagate the signal to our parents iff this is the first field set. |
1037 self._parent_message_weakref._Modified() | 1386 self._parent_message_weakref._Modified() |
1038 except ReferenceError: | 1387 except ReferenceError: |
1039 # We can get here if a client has kept a reference to a child object, | 1388 # We can get here if a client has kept a reference to a child object, |
1040 # and is now setting a field on it, but the child's parent has been | 1389 # and is now setting a field on it, but the child's parent has been |
1041 # garbage-collected. This is not an error. | 1390 # garbage-collected. This is not an error. |
1042 pass | 1391 pass |
1043 | 1392 |
1044 | 1393 |
| 1394 class _OneofListener(_Listener): |
| 1395 """Special listener implementation for setting composite oneof fields.""" |
| 1396 |
| 1397 def __init__(self, parent_message, field): |
| 1398 """Args: |
| 1399 parent_message: The message whose _Modified() method we should call when |
| 1400 we receive Modified() messages. |
| 1401 field: The descriptor of the field being set in the parent message. |
| 1402 """ |
| 1403 super(_OneofListener, self).__init__(parent_message) |
| 1404 self._field = field |
| 1405 |
| 1406 def Modified(self): |
| 1407 """Also updates the state of the containing oneof in the parent message.""" |
| 1408 try: |
| 1409 self._parent_message_weakref._UpdateOneofState(self._field) |
| 1410 super(_OneofListener, self).Modified() |
| 1411 except ReferenceError: |
| 1412 pass |
| 1413 |
| 1414 |
1045 # TODO(robinson): Move elsewhere? This file is getting pretty ridiculous... | 1415 # TODO(robinson): Move elsewhere? This file is getting pretty ridiculous... |
1046 # TODO(robinson): Unify error handling of "unknown extension" crap. | 1416 # TODO(robinson): Unify error handling of "unknown extension" crap. |
1047 # TODO(robinson): Support iteritems()-style iteration over all | 1417 # TODO(robinson): Support iteritems()-style iteration over all |
1048 # extensions with the "has" bits turned on? | 1418 # extensions with the "has" bits turned on? |
1049 class _ExtensionDict(object): | 1419 class _ExtensionDict(object): |
1050 | 1420 |
1051 """Dict-like container for supporting an indexable "Extensions" | 1421 """Dict-like container for supporting an indexable "Extensions" |
1052 field on proto instances. | 1422 field on proto instances. |
1053 | 1423 |
1054 Note that in all cases we expect extension handles to be | 1424 Note that in all cases we expect extension handles to be |
(...skipping 70 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
1125 _VerifyExtensionHandle(self._extended_message, extension_handle) | 1495 _VerifyExtensionHandle(self._extended_message, extension_handle) |
1126 | 1496 |
1127 if (extension_handle.label == _FieldDescriptor.LABEL_REPEATED or | 1497 if (extension_handle.label == _FieldDescriptor.LABEL_REPEATED or |
1128 extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE): | 1498 extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE): |
1129 raise TypeError( | 1499 raise TypeError( |
1130 'Cannot assign to extension "%s" because it is a repeated or ' | 1500 'Cannot assign to extension "%s" because it is a repeated or ' |
1131 'composite type.' % extension_handle.full_name) | 1501 'composite type.' % extension_handle.full_name) |
1132 | 1502 |
1133 # It's slightly wasteful to lookup the type checker each time, | 1503 # It's slightly wasteful to lookup the type checker each time, |
1134 # but we expect this to be a vanishingly uncommon case anyway. | 1504 # but we expect this to be a vanishingly uncommon case anyway. |
1135 type_checker = type_checkers.GetTypeChecker( | 1505 type_checker = type_checkers.GetTypeChecker(extension_handle) |
1136 extension_handle.cpp_type, extension_handle.type) | 1506 # pylint: disable=protected-access |
1137 type_checker.CheckValue(value) | 1507 self._extended_message._fields[extension_handle] = ( |
1138 self._extended_message._fields[extension_handle] = value | 1508 type_checker.CheckValue(value)) |
1139 self._extended_message._Modified() | 1509 self._extended_message._Modified() |
1140 | 1510 |
1141 def _FindExtensionByName(self, name): | 1511 def _FindExtensionByName(self, name): |
1142 """Tries to find a known extension with the specified name. | 1512 """Tries to find a known extension with the specified name. |
1143 | 1513 |
1144 Args: | 1514 Args: |
1145 name: Extension full name. | 1515 name: Extension full name. |
1146 | 1516 |
1147 Returns: | 1517 Returns: |
1148 Extension field descriptor. | 1518 Extension field descriptor. |
1149 """ | 1519 """ |
1150 return self._extended_message._extensions_by_name.get(name, None) | 1520 return self._extended_message._extensions_by_name.get(name, None) |
OLD | NEW |