Chromium Code Reviews
chromiumcodereview-hr@appspot.gserviceaccount.com (chromiumcodereview-hr) | Please choose your nickname with Settings | Help | Chromium Project | Gerrit Changes | Sign out
(5)

Side by Side Diff: third_party/protobuf/python/google/protobuf/internal/python_message.py

Issue 1842653006: Update //third_party/protobuf to version 3. (Closed) Base URL: https://chromium.googlesource.com/chromium/src.git@master
Patch Set: update README.chromium Created 4 years, 8 months ago
Use n/p to move between diff chunks; N/P to move between comments. Draft comments are only viewable by you.
Jump to:
View unified diff | Download patch
OLDNEW
1 # Protocol Buffers - Google's data interchange format 1 # Protocol Buffers - Google's data interchange format
2 # Copyright 2008 Google Inc. All rights reserved. 2 # Copyright 2008 Google Inc. All rights reserved.
3 # http://code.google.com/p/protobuf/ 3 # https://developers.google.com/protocol-buffers/
4 # 4 #
5 # Redistribution and use in source and binary forms, with or without 5 # Redistribution and use in source and binary forms, with or without
6 # modification, are permitted provided that the following conditions are 6 # modification, are permitted provided that the following conditions are
7 # met: 7 # met:
8 # 8 #
9 # * Redistributions of source code must retain the above copyright 9 # * Redistributions of source code must retain the above copyright
10 # notice, this list of conditions and the following disclaimer. 10 # notice, this list of conditions and the following disclaimer.
11 # * Redistributions in binary form must reproduce the above 11 # * Redistributions in binary form must reproduce the above
12 # copyright notice, this list of conditions and the following disclaimer 12 # copyright notice, this list of conditions and the following disclaimer
13 # in the documentation and/or other materials provided with the 13 # in the documentation and/or other materials provided with the
(...skipping 29 matching lines...) Expand all
43 to inject all the useful functionality into the classes 43 to inject all the useful functionality into the classes
44 output by the protocol compiler at compile-time. 44 output by the protocol compiler at compile-time.
45 45
46 The upshot of all this is that the real implementation 46 The upshot of all this is that the real implementation
47 details for ALL pure-Python protocol buffers are *here in 47 details for ALL pure-Python protocol buffers are *here in
48 this file*. 48 this file*.
49 """ 49 """
50 50
51 __author__ = 'robinson@google.com (Will Robinson)' 51 __author__ = 'robinson@google.com (Will Robinson)'
52 52
53 try: 53 from io import BytesIO
54 from cStringIO import StringIO 54 import sys
55 except ImportError:
56 from StringIO import StringIO
57 import copy_reg
58 import struct 55 import struct
59 import weakref 56 import weakref
60 57
58 import six
59 import six.moves.copyreg as copyreg
60
61 # We use "as" to avoid name collisions with variables. 61 # We use "as" to avoid name collisions with variables.
62 from google.protobuf.internal import containers 62 from google.protobuf.internal import containers
63 from google.protobuf.internal import decoder 63 from google.protobuf.internal import decoder
64 from google.protobuf.internal import encoder 64 from google.protobuf.internal import encoder
65 from google.protobuf.internal import enum_type_wrapper 65 from google.protobuf.internal import enum_type_wrapper
66 from google.protobuf.internal import message_listener as message_listener_mod 66 from google.protobuf.internal import message_listener as message_listener_mod
67 from google.protobuf.internal import type_checkers 67 from google.protobuf.internal import type_checkers
68 from google.protobuf.internal import well_known_types
68 from google.protobuf.internal import wire_format 69 from google.protobuf.internal import wire_format
69 from google.protobuf import descriptor as descriptor_mod 70 from google.protobuf import descriptor as descriptor_mod
70 from google.protobuf import message as message_mod 71 from google.protobuf import message as message_mod
72 from google.protobuf import symbol_database
71 from google.protobuf import text_format 73 from google.protobuf import text_format
72 74
73 _FieldDescriptor = descriptor_mod.FieldDescriptor 75 _FieldDescriptor = descriptor_mod.FieldDescriptor
76 _AnyFullTypeName = 'google.protobuf.Any'
74 77
75 78
76 def NewMessage(bases, descriptor, dictionary): 79 class GeneratedProtocolMessageType(type):
77 _AddClassAttributesForNestedExtensions(descriptor, dictionary)
78 _AddSlots(descriptor, dictionary)
79 return bases
80 80
81 """Metaclass for protocol message classes created at runtime from Descriptors.
81 82
82 def InitMessage(descriptor, cls): 83 We add implementations for all methods described in the Message class. We
83 cls._decoders_by_tag = {} 84 also create properties to allow getting/setting all fields in the protocol
84 cls._extensions_by_name = {} 85 message. Finally, we create slots to prevent users from accidentally
85 cls._extensions_by_number = {} 86 "setting" nonexistent fields in the protocol message, which then wouldn't get
86 if (descriptor.has_options and 87 serialized / deserialized properly.
87 descriptor.GetOptions().message_set_wire_format):
88 cls._decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = (
89 decoder.MessageSetItemDecoder(cls._extensions_by_number))
90 88
91 # Attach stuff to each FieldDescriptor for quick lookup later on. 89 The protocol compiler currently uses this metaclass to create protocol
92 for field in descriptor.fields: 90 message classes at runtime. Clients can also manually create their own
93 _AttachFieldHelpers(cls, field) 91 classes at runtime, as in this example:
94 92
95 _AddEnumValues(descriptor, cls) 93 mydescriptor = Descriptor(.....)
96 _AddInitMethod(descriptor, cls) 94 class MyProtoClass(Message):
97 _AddPropertiesForFields(descriptor, cls) 95 __metaclass__ = GeneratedProtocolMessageType
98 _AddPropertiesForExtensions(descriptor, cls) 96 DESCRIPTOR = mydescriptor
99 _AddStaticMethods(cls) 97 myproto_instance = MyProtoClass()
100 _AddMessageMethods(descriptor, cls) 98 myproto.foo_field = 23
101 _AddPrivateHelperMethods(cls) 99 ...
102 copy_reg.pickle(cls, lambda obj: (cls, (), obj.__getstate__())) 100
101 The above example will not work for nested types. If you wish to include them,
102 use reflection.MakeClass() instead of manually instantiating the class in
103 order to create the appropriate class structure.
104 """
105
106 # Must be consistent with the protocol-compiler code in
107 # proto2/compiler/internal/generator.*.
108 _DESCRIPTOR_KEY = 'DESCRIPTOR'
109
110 def __new__(cls, name, bases, dictionary):
111 """Custom allocation for runtime-generated class types.
112
113 We override __new__ because this is apparently the only place
114 where we can meaningfully set __slots__ on the class we're creating(?).
115 (The interplay between metaclasses and slots is not very well-documented).
116
117 Args:
118 name: Name of the class (ignored, but required by the
119 metaclass protocol).
120 bases: Base classes of the class we're constructing.
121 (Should be message.Message). We ignore this field, but
122 it's required by the metaclass protocol
123 dictionary: The class dictionary of the class we're
124 constructing. dictionary[_DESCRIPTOR_KEY] must contain
125 a Descriptor object describing this protocol message
126 type.
127
128 Returns:
129 Newly-allocated class.
130 """
131 descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]
132 if descriptor.full_name in well_known_types.WKTBASES:
133 bases += (well_known_types.WKTBASES[descriptor.full_name],)
134 _AddClassAttributesForNestedExtensions(descriptor, dictionary)
135 _AddSlots(descriptor, dictionary)
136
137 superclass = super(GeneratedProtocolMessageType, cls)
138 new_class = superclass.__new__(cls, name, bases, dictionary)
139 return new_class
140
141 def __init__(cls, name, bases, dictionary):
142 """Here we perform the majority of our work on the class.
143 We add enum getters, an __init__ method, implementations
144 of all Message methods, and properties for all fields
145 in the protocol type.
146
147 Args:
148 name: Name of the class (ignored, but required by the
149 metaclass protocol).
150 bases: Base classes of the class we're constructing.
151 (Should be message.Message). We ignore this field, but
152 it's required by the metaclass protocol
153 dictionary: The class dictionary of the class we're
154 constructing. dictionary[_DESCRIPTOR_KEY] must contain
155 a Descriptor object describing this protocol message
156 type.
157 """
158 descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]
159 cls._decoders_by_tag = {}
160 cls._extensions_by_name = {}
161 cls._extensions_by_number = {}
162 if (descriptor.has_options and
163 descriptor.GetOptions().message_set_wire_format):
164 cls._decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = (
165 decoder.MessageSetItemDecoder(cls._extensions_by_number), None)
166
167 # Attach stuff to each FieldDescriptor for quick lookup later on.
168 for field in descriptor.fields:
169 _AttachFieldHelpers(cls, field)
170
171 descriptor._concrete_class = cls # pylint: disable=protected-access
172 _AddEnumValues(descriptor, cls)
173 _AddInitMethod(descriptor, cls)
174 _AddPropertiesForFields(descriptor, cls)
175 _AddPropertiesForExtensions(descriptor, cls)
176 _AddStaticMethods(cls)
177 _AddMessageMethods(descriptor, cls)
178 _AddPrivateHelperMethods(descriptor, cls)
179 copyreg.pickle(cls, lambda obj: (cls, (), obj.__getstate__()))
180
181 superclass = super(GeneratedProtocolMessageType, cls)
182 superclass.__init__(name, bases, dictionary)
103 183
104 184
105 # Stateless helpers for GeneratedProtocolMessageType below. 185 # Stateless helpers for GeneratedProtocolMessageType below.
106 # Outside clients should not access these directly. 186 # Outside clients should not access these directly.
107 # 187 #
108 # I opted not to make any of these methods on the metaclass, to make it more 188 # I opted not to make any of these methods on the metaclass, to make it more
109 # clear that I'm not really using any state there and to keep clients from 189 # clear that I'm not really using any state there and to keep clients from
110 # thinking that they have direct access to these construction helpers. 190 # thinking that they have direct access to these construction helpers.
111 191
112 192
(...skipping 56 matching lines...) Expand 10 before | Expand all | Expand 10 after
169 message_descriptor: A Descriptor instance describing this message type. 249 message_descriptor: A Descriptor instance describing this message type.
170 dictionary: Class dictionary to which we'll add a '__slots__' entry. 250 dictionary: Class dictionary to which we'll add a '__slots__' entry.
171 """ 251 """
172 dictionary['__slots__'] = ['_cached_byte_size', 252 dictionary['__slots__'] = ['_cached_byte_size',
173 '_cached_byte_size_dirty', 253 '_cached_byte_size_dirty',
174 '_fields', 254 '_fields',
175 '_unknown_fields', 255 '_unknown_fields',
176 '_is_present_in_parent', 256 '_is_present_in_parent',
177 '_listener', 257 '_listener',
178 '_listener_for_children', 258 '_listener_for_children',
179 '__weakref__'] 259 '__weakref__',
260 '_oneofs']
180 261
181 262
182 def _IsMessageSetExtension(field): 263 def _IsMessageSetExtension(field):
183 return (field.is_extension and 264 return (field.is_extension and
184 field.containing_type.has_options and 265 field.containing_type.has_options and
185 field.containing_type.GetOptions().message_set_wire_format and 266 field.containing_type.GetOptions().message_set_wire_format and
186 field.type == _FieldDescriptor.TYPE_MESSAGE and 267 field.type == _FieldDescriptor.TYPE_MESSAGE and
187 field.message_type == field.extension_scope and
188 field.label == _FieldDescriptor.LABEL_OPTIONAL) 268 field.label == _FieldDescriptor.LABEL_OPTIONAL)
189 269
190 270
271 def _IsMapField(field):
272 return (field.type == _FieldDescriptor.TYPE_MESSAGE and
273 field.message_type.has_options and
274 field.message_type.GetOptions().map_entry)
275
276
277 def _IsMessageMapField(field):
278 value_type = field.message_type.fields_by_name["value"]
279 return value_type.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE
280
281
191 def _AttachFieldHelpers(cls, field_descriptor): 282 def _AttachFieldHelpers(cls, field_descriptor):
192 is_repeated = (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED) 283 is_repeated = (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED)
193 is_packed = (field_descriptor.has_options and 284 is_packable = (is_repeated and
194 field_descriptor.GetOptions().packed) 285 wire_format.IsTypePackable(field_descriptor.type))
286 if not is_packable:
287 is_packed = False
288 elif field_descriptor.containing_type.syntax == "proto2":
289 is_packed = (field_descriptor.has_options and
290 field_descriptor.GetOptions().packed)
291 else:
292 has_packed_false = (field_descriptor.has_options and
293 field_descriptor.GetOptions().HasField("packed") and
294 field_descriptor.GetOptions().packed == False)
295 is_packed = not has_packed_false
296 is_map_entry = _IsMapField(field_descriptor)
195 297
196 if _IsMessageSetExtension(field_descriptor): 298 if is_map_entry:
299 field_encoder = encoder.MapEncoder(field_descriptor)
300 sizer = encoder.MapSizer(field_descriptor)
301 elif _IsMessageSetExtension(field_descriptor):
197 field_encoder = encoder.MessageSetItemEncoder(field_descriptor.number) 302 field_encoder = encoder.MessageSetItemEncoder(field_descriptor.number)
198 sizer = encoder.MessageSetItemSizer(field_descriptor.number) 303 sizer = encoder.MessageSetItemSizer(field_descriptor.number)
199 else: 304 else:
200 field_encoder = type_checkers.TYPE_TO_ENCODER[field_descriptor.type]( 305 field_encoder = type_checkers.TYPE_TO_ENCODER[field_descriptor.type](
201 field_descriptor.number, is_repeated, is_packed) 306 field_descriptor.number, is_repeated, is_packed)
202 sizer = type_checkers.TYPE_TO_SIZER[field_descriptor.type]( 307 sizer = type_checkers.TYPE_TO_SIZER[field_descriptor.type](
203 field_descriptor.number, is_repeated, is_packed) 308 field_descriptor.number, is_repeated, is_packed)
204 309
205 field_descriptor._encoder = field_encoder 310 field_descriptor._encoder = field_encoder
206 field_descriptor._sizer = sizer 311 field_descriptor._sizer = sizer
207 field_descriptor._default_constructor = _DefaultValueConstructorForField( 312 field_descriptor._default_constructor = _DefaultValueConstructorForField(
208 field_descriptor) 313 field_descriptor)
209 314
210 def AddDecoder(wiretype, is_packed): 315 def AddDecoder(wiretype, is_packed):
211 tag_bytes = encoder.TagBytes(field_descriptor.number, wiretype) 316 tag_bytes = encoder.TagBytes(field_descriptor.number, wiretype)
212 cls._decoders_by_tag[tag_bytes] = ( 317 decode_type = field_descriptor.type
213 type_checkers.TYPE_TO_DECODER[field_descriptor.type]( 318 if (decode_type == _FieldDescriptor.TYPE_ENUM and
214 field_descriptor.number, is_repeated, is_packed, 319 type_checkers.SupportsOpenEnums(field_descriptor)):
215 field_descriptor, field_descriptor._default_constructor)) 320 decode_type = _FieldDescriptor.TYPE_INT32
321
322 oneof_descriptor = None
323 if field_descriptor.containing_oneof is not None:
324 oneof_descriptor = field_descriptor
325
326 if is_map_entry:
327 is_message_map = _IsMessageMapField(field_descriptor)
328
329 field_decoder = decoder.MapDecoder(
330 field_descriptor, _GetInitializeDefaultForMap(field_descriptor),
331 is_message_map)
332 else:
333 field_decoder = type_checkers.TYPE_TO_DECODER[decode_type](
334 field_descriptor.number, is_repeated, is_packed,
335 field_descriptor, field_descriptor._default_constructor)
336
337 cls._decoders_by_tag[tag_bytes] = (field_decoder, oneof_descriptor)
216 338
217 AddDecoder(type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type], 339 AddDecoder(type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type],
218 False) 340 False)
219 341
220 if is_repeated and wire_format.IsTypePackable(field_descriptor.type): 342 if is_repeated and wire_format.IsTypePackable(field_descriptor.type):
221 # To support wire compatibility of adding packed = true, add a decoder for 343 # To support wire compatibility of adding packed = true, add a decoder for
222 # packed values regardless of the field's options. 344 # packed values regardless of the field's options.
223 AddDecoder(wire_format.WIRETYPE_LENGTH_DELIMITED, True) 345 AddDecoder(wire_format.WIRETYPE_LENGTH_DELIMITED, True)
224 346
225 347
226 def _AddClassAttributesForNestedExtensions(descriptor, dictionary): 348 def _AddClassAttributesForNestedExtensions(descriptor, dictionary):
227 extension_dict = descriptor.extensions_by_name 349 extension_dict = descriptor.extensions_by_name
228 for extension_name, extension_field in extension_dict.iteritems(): 350 for extension_name, extension_field in extension_dict.items():
229 assert extension_name not in dictionary 351 assert extension_name not in dictionary
230 dictionary[extension_name] = extension_field 352 dictionary[extension_name] = extension_field
231 353
232 354
233 def _AddEnumValues(descriptor, cls): 355 def _AddEnumValues(descriptor, cls):
234 """Sets class-level attributes for all enum fields defined in this message. 356 """Sets class-level attributes for all enum fields defined in this message.
235 357
236 Also exporting a class-level object that can name enum values. 358 Also exporting a class-level object that can name enum values.
237 359
238 Args: 360 Args:
239 descriptor: Descriptor object for this message type. 361 descriptor: Descriptor object for this message type.
240 cls: Class we're constructing for this message type. 362 cls: Class we're constructing for this message type.
241 """ 363 """
242 for enum_type in descriptor.enum_types: 364 for enum_type in descriptor.enum_types:
243 setattr(cls, enum_type.name, enum_type_wrapper.EnumTypeWrapper(enum_type)) 365 setattr(cls, enum_type.name, enum_type_wrapper.EnumTypeWrapper(enum_type))
244 for enum_value in enum_type.values: 366 for enum_value in enum_type.values:
245 setattr(cls, enum_value.name, enum_value.number) 367 setattr(cls, enum_value.name, enum_value.number)
246 368
247 369
370 def _GetInitializeDefaultForMap(field):
371 if field.label != _FieldDescriptor.LABEL_REPEATED:
372 raise ValueError('map_entry set on non-repeated field %s' % (
373 field.name))
374 fields_by_name = field.message_type.fields_by_name
375 key_checker = type_checkers.GetTypeChecker(fields_by_name['key'])
376
377 value_field = fields_by_name['value']
378 if _IsMessageMapField(field):
379 def MakeMessageMapDefault(message):
380 return containers.MessageMap(
381 message._listener_for_children, value_field.message_type, key_checker)
382 return MakeMessageMapDefault
383 else:
384 value_checker = type_checkers.GetTypeChecker(value_field)
385 def MakePrimitiveMapDefault(message):
386 return containers.ScalarMap(
387 message._listener_for_children, key_checker, value_checker)
388 return MakePrimitiveMapDefault
389
248 def _DefaultValueConstructorForField(field): 390 def _DefaultValueConstructorForField(field):
249 """Returns a function which returns a default value for a field. 391 """Returns a function which returns a default value for a field.
250 392
251 Args: 393 Args:
252 field: FieldDescriptor object for this field. 394 field: FieldDescriptor object for this field.
253 395
254 The returned function has one argument: 396 The returned function has one argument:
255 message: Message instance containing this field, or a weakref proxy 397 message: Message instance containing this field, or a weakref proxy
256 of same. 398 of same.
257 399
258 That function in turn returns a default value for this field. The default 400 That function in turn returns a default value for this field. The default
259 value may refer back to |message| via a weak reference. 401 value may refer back to |message| via a weak reference.
260 """ 402 """
261 403
404 if _IsMapField(field):
405 return _GetInitializeDefaultForMap(field)
406
262 if field.label == _FieldDescriptor.LABEL_REPEATED: 407 if field.label == _FieldDescriptor.LABEL_REPEATED:
263 if field.has_default_value and field.default_value != []: 408 if field.has_default_value and field.default_value != []:
264 raise ValueError('Repeated field default value not empty list: %s' % ( 409 raise ValueError('Repeated field default value not empty list: %s' % (
265 field.default_value)) 410 field.default_value))
266 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: 411 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
267 # We can't look at _concrete_class yet since it might not have 412 # We can't look at _concrete_class yet since it might not have
268 # been set. (Depends on order in which we initialize the classes). 413 # been set. (Depends on order in which we initialize the classes).
269 message_type = field.message_type 414 message_type = field.message_type
270 def MakeRepeatedMessageDefault(message): 415 def MakeRepeatedMessageDefault(message):
271 return containers.RepeatedCompositeFieldContainer( 416 return containers.RepeatedCompositeFieldContainer(
272 message._listener_for_children, field.message_type) 417 message._listener_for_children, field.message_type)
273 return MakeRepeatedMessageDefault 418 return MakeRepeatedMessageDefault
274 else: 419 else:
275 type_checker = type_checkers.GetTypeChecker(field.cpp_type, field.type) 420 type_checker = type_checkers.GetTypeChecker(field)
276 def MakeRepeatedScalarDefault(message): 421 def MakeRepeatedScalarDefault(message):
277 return containers.RepeatedScalarFieldContainer( 422 return containers.RepeatedScalarFieldContainer(
278 message._listener_for_children, type_checker) 423 message._listener_for_children, type_checker)
279 return MakeRepeatedScalarDefault 424 return MakeRepeatedScalarDefault
280 425
281 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: 426 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
282 # _concrete_class may not yet be initialized. 427 # _concrete_class may not yet be initialized.
283 message_type = field.message_type 428 message_type = field.message_type
284 def MakeSubMessageDefault(message): 429 def MakeSubMessageDefault(message):
285 result = message_type._concrete_class() 430 result = message_type._concrete_class()
286 result._SetListener(message._listener_for_children) 431 result._SetListener(
432 _OneofListener(message, field)
433 if field.containing_oneof is not None
434 else message._listener_for_children)
287 return result 435 return result
288 return MakeSubMessageDefault 436 return MakeSubMessageDefault
289 437
290 def MakeScalarDefault(message): 438 def MakeScalarDefault(message):
291 # TODO(protobuf-team): This may be broken since there may not be 439 # TODO(protobuf-team): This may be broken since there may not be
292 # default_value. Combine with has_default_value somehow. 440 # default_value. Combine with has_default_value somehow.
293 return field.default_value 441 return field.default_value
294 return MakeScalarDefault 442 return MakeScalarDefault
295 443
296 444
445 def _ReraiseTypeErrorWithFieldName(message_name, field_name):
446 """Re-raise the currently-handled TypeError with the field name added."""
447 exc = sys.exc_info()[1]
448 if len(exc.args) == 1 and type(exc) is TypeError:
449 # simple TypeError; add field name to exception message
450 exc = TypeError('%s for field %s.%s' % (str(exc), message_name, field_name))
451
452 # re-raise possibly-amended exception with original traceback:
453 six.reraise(type(exc), exc, sys.exc_info()[2])
454
455
297 def _AddInitMethod(message_descriptor, cls): 456 def _AddInitMethod(message_descriptor, cls):
298 """Adds an __init__ method to cls.""" 457 """Adds an __init__ method to cls."""
299 fields = message_descriptor.fields 458
459 def _GetIntegerEnumValue(enum_type, value):
460 """Convert a string or integer enum value to an integer.
461
462 If the value is a string, it is converted to the enum value in
463 enum_type with the same name. If the value is not a string, it's
464 returned as-is. (No conversion or bounds-checking is done.)
465 """
466 if isinstance(value, six.string_types):
467 try:
468 return enum_type.values_by_name[value].number
469 except KeyError:
470 raise ValueError('Enum type %s: unknown label "%s"' % (
471 enum_type.full_name, value))
472 return value
473
300 def init(self, **kwargs): 474 def init(self, **kwargs):
301 self._cached_byte_size = 0 475 self._cached_byte_size = 0
302 self._cached_byte_size_dirty = len(kwargs) > 0 476 self._cached_byte_size_dirty = len(kwargs) > 0
303 self._fields = {} 477 self._fields = {}
478 # Contains a mapping from oneof field descriptors to the descriptor
479 # of the currently set field in that oneof field.
480 self._oneofs = {}
481
304 # _unknown_fields is () when empty for efficiency, and will be turned into 482 # _unknown_fields is () when empty for efficiency, and will be turned into
305 # a list if fields are added. 483 # a list if fields are added.
306 self._unknown_fields = () 484 self._unknown_fields = ()
307 self._is_present_in_parent = False 485 self._is_present_in_parent = False
308 self._listener = message_listener_mod.NullMessageListener() 486 self._listener = message_listener_mod.NullMessageListener()
309 self._listener_for_children = _Listener(self) 487 self._listener_for_children = _Listener(self)
310 for field_name, field_value in kwargs.iteritems(): 488 for field_name, field_value in kwargs.items():
311 field = _GetFieldByName(message_descriptor, field_name) 489 field = _GetFieldByName(message_descriptor, field_name)
312 if field is None: 490 if field is None:
313 raise TypeError("%s() got an unexpected keyword argument '%s'" % 491 raise TypeError("%s() got an unexpected keyword argument '%s'" %
314 (message_descriptor.name, field_name)) 492 (message_descriptor.name, field_name))
315 if field.label == _FieldDescriptor.LABEL_REPEATED: 493 if field.label == _FieldDescriptor.LABEL_REPEATED:
316 copy = field._default_constructor(self) 494 copy = field._default_constructor(self)
317 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: # Composite 495 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: # Composite
318 for val in field_value: 496 if _IsMapField(field):
319 copy.add().MergeFrom(val) 497 if _IsMessageMapField(field):
498 for key in field_value:
499 copy[key].MergeFrom(field_value[key])
500 else:
501 copy.update(field_value)
502 else:
503 for val in field_value:
504 if isinstance(val, dict):
505 copy.add(**val)
506 else:
507 copy.add().MergeFrom(val)
320 else: # Scalar 508 else: # Scalar
509 if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
510 field_value = [_GetIntegerEnumValue(field.enum_type, val)
511 for val in field_value]
321 copy.extend(field_value) 512 copy.extend(field_value)
322 self._fields[field] = copy 513 self._fields[field] = copy
323 elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: 514 elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
324 copy = field._default_constructor(self) 515 copy = field._default_constructor(self)
325 copy.MergeFrom(field_value) 516 new_val = field_value
517 if isinstance(field_value, dict):
518 new_val = field.message_type._concrete_class(**field_value)
519 try:
520 copy.MergeFrom(new_val)
521 except TypeError:
522 _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name)
326 self._fields[field] = copy 523 self._fields[field] = copy
327 else: 524 else:
328 setattr(self, field_name, field_value) 525 if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
526 field_value = _GetIntegerEnumValue(field.enum_type, field_value)
527 try:
528 setattr(self, field_name, field_value)
529 except TypeError:
530 _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name)
329 531
330 init.__module__ = None 532 init.__module__ = None
331 init.__doc__ = None 533 init.__doc__ = None
332 cls.__init__ = init 534 cls.__init__ = init
333 535
334 536
335 def _GetFieldByName(message_descriptor, field_name): 537 def _GetFieldByName(message_descriptor, field_name):
336 """Returns a field descriptor by field name. 538 """Returns a field descriptor by field name.
337 539
338 Args: 540 Args:
339 message_descriptor: A Descriptor describing all fields in message. 541 message_descriptor: A Descriptor describing all fields in message.
340 field_name: The name of the field to retrieve. 542 field_name: The name of the field to retrieve.
341 Returns: 543 Returns:
342 The field descriptor associated with the field name. 544 The field descriptor associated with the field name.
343 """ 545 """
344 try: 546 try:
345 return message_descriptor.fields_by_name[field_name] 547 return message_descriptor.fields_by_name[field_name]
346 except KeyError: 548 except KeyError:
347 raise ValueError('Protocol message has no "%s" field.' % field_name) 549 raise ValueError('Protocol message %s has no "%s" field.' %
550 (message_descriptor.name, field_name))
348 551
349 552
350 def _AddPropertiesForFields(descriptor, cls): 553 def _AddPropertiesForFields(descriptor, cls):
351 """Adds properties for all fields in this protocol message type.""" 554 """Adds properties for all fields in this protocol message type."""
352 for field in descriptor.fields: 555 for field in descriptor.fields:
353 _AddPropertiesForField(field, cls) 556 _AddPropertiesForField(field, cls)
354 557
355 if descriptor.is_extendable: 558 if descriptor.is_extendable:
356 # _ExtensionDict is just an adaptor with no state so we allocate a new one 559 # _ExtensionDict is just an adaptor with no state so we allocate a new one
357 # every time it is accessed. 560 # every time it is accessed.
(...skipping 75 matching lines...) Expand 10 before | Expand all | Expand 10 after
433 Note that when the client sets the value of a field by using this property, 636 Note that when the client sets the value of a field by using this property,
434 all necessary "has" bits are set as a side-effect, and we also perform 637 all necessary "has" bits are set as a side-effect, and we also perform
435 type-checking. 638 type-checking.
436 639
437 Args: 640 Args:
438 field: A FieldDescriptor for this field. 641 field: A FieldDescriptor for this field.
439 cls: The class we're constructing. 642 cls: The class we're constructing.
440 """ 643 """
441 proto_field_name = field.name 644 proto_field_name = field.name
442 property_name = _PropertyName(proto_field_name) 645 property_name = _PropertyName(proto_field_name)
443 type_checker = type_checkers.GetTypeChecker(field.cpp_type, field.type) 646 type_checker = type_checkers.GetTypeChecker(field)
444 default_value = field.default_value 647 default_value = field.default_value
445 valid_values = set() 648 valid_values = set()
649 is_proto3 = field.containing_type.syntax == "proto3"
446 650
447 def getter(self): 651 def getter(self):
448 # TODO(protobuf-team): This may be broken since there may not be 652 # TODO(protobuf-team): This may be broken since there may not be
449 # default_value. Combine with has_default_value somehow. 653 # default_value. Combine with has_default_value somehow.
450 return self._fields.get(field, default_value) 654 return self._fields.get(field, default_value)
451 getter.__module__ = None 655 getter.__module__ = None
452 getter.__doc__ = 'Getter for %s.' % proto_field_name 656 getter.__doc__ = 'Getter for %s.' % proto_field_name
453 def setter(self, new_value): 657
454 type_checker.CheckValue(new_value) 658 clear_when_set_to_default = is_proto3 and not field.containing_oneof
455 self._fields[field] = new_value 659
660 def field_setter(self, new_value):
661 # pylint: disable=protected-access
662 # Testing the value for truthiness captures all of the proto3 defaults
663 # (0, 0.0, enum 0, and False).
664 new_value = type_checker.CheckValue(new_value)
665 if clear_when_set_to_default and not new_value:
666 self._fields.pop(field, None)
667 else:
668 self._fields[field] = new_value
456 # Check _cached_byte_size_dirty inline to improve performance, since scalar 669 # Check _cached_byte_size_dirty inline to improve performance, since scalar
457 # setters are called frequently. 670 # setters are called frequently.
458 if not self._cached_byte_size_dirty: 671 if not self._cached_byte_size_dirty:
459 self._Modified() 672 self._Modified()
460 673
674 if field.containing_oneof:
675 def setter(self, new_value):
676 field_setter(self, new_value)
677 self._UpdateOneofState(field)
678 else:
679 setter = field_setter
680
461 setter.__module__ = None 681 setter.__module__ = None
462 setter.__doc__ = 'Setter for %s.' % proto_field_name 682 setter.__doc__ = 'Setter for %s.' % proto_field_name
463 683
464 # Add a property to encapsulate the getter/setter. 684 # Add a property to encapsulate the getter/setter.
465 doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name 685 doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
466 setattr(cls, property_name, property(getter, setter, doc=doc)) 686 setattr(cls, property_name, property(getter, setter, doc=doc))
467 687
468 688
469 def _AddPropertiesForNonRepeatedCompositeField(field, cls): 689 def _AddPropertiesForNonRepeatedCompositeField(field, cls):
470 """Adds a public property for a nonrepeated, composite protocol message field. 690 """Adds a public property for a nonrepeated, composite protocol message field.
471 A composite field is a "group" or "message" field. 691 A composite field is a "group" or "message" field.
472 692
473 Clients can use this property to get the value of the field, but cannot 693 Clients can use this property to get the value of the field, but cannot
474 assign to the property directly. 694 assign to the property directly.
475 695
476 Args: 696 Args:
477 field: A FieldDescriptor for this field. 697 field: A FieldDescriptor for this field.
478 cls: The class we're constructing. 698 cls: The class we're constructing.
479 """ 699 """
480 # TODO(robinson): Remove duplication with similar method 700 # TODO(robinson): Remove duplication with similar method
481 # for non-repeated scalars. 701 # for non-repeated scalars.
482 proto_field_name = field.name 702 proto_field_name = field.name
483 property_name = _PropertyName(proto_field_name) 703 property_name = _PropertyName(proto_field_name)
484 704
485 # TODO(komarek): Can anyone explain to me why we cache the message_type this
486 # way, instead of referring to field.message_type inside of getter(self)?
487 # What if someone sets message_type later on (which makes for simpler
488 # dyanmic proto descriptor and class creation code).
489 message_type = field.message_type
490
491 def getter(self): 705 def getter(self):
492 field_value = self._fields.get(field) 706 field_value = self._fields.get(field)
493 if field_value is None: 707 if field_value is None:
494 # Construct a new object to represent this field. 708 # Construct a new object to represent this field.
495 field_value = message_type._concrete_class() # use field.message_type? 709 field_value = field._default_constructor(self)
496 field_value._SetListener(self._listener_for_children)
497 710
498 # Atomically check if another thread has preempted us and, if not, swap 711 # Atomically check if another thread has preempted us and, if not, swap
499 # in the new object we just created. If someone has preempted us, we 712 # in the new object we just created. If someone has preempted us, we
500 # take that object and discard ours. 713 # take that object and discard ours.
501 # WARNING: We are relying on setdefault() being atomic. This is true 714 # WARNING: We are relying on setdefault() being atomic. This is true
502 # in CPython but we haven't investigated others. This warning appears 715 # in CPython but we haven't investigated others. This warning appears
503 # in several other locations in this file. 716 # in several other locations in this file.
504 field_value = self._fields.setdefault(field, field_value) 717 field_value = self._fields.setdefault(field, field_value)
505 return field_value 718 return field_value
506 getter.__module__ = None 719 getter.__module__ = None
507 getter.__doc__ = 'Getter for %s.' % proto_field_name 720 getter.__doc__ = 'Getter for %s.' % proto_field_name
508 721
509 # We define a setter just so we can throw an exception with a more 722 # We define a setter just so we can throw an exception with a more
510 # helpful error message. 723 # helpful error message.
511 def setter(self, new_value): 724 def setter(self, new_value):
512 raise AttributeError('Assignment not allowed to composite field ' 725 raise AttributeError('Assignment not allowed to composite field '
513 '"%s" in protocol message object.' % proto_field_name) 726 '"%s" in protocol message object.' % proto_field_name)
514 727
515 # Add a property to encapsulate the getter. 728 # Add a property to encapsulate the getter.
516 doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name 729 doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
517 setattr(cls, property_name, property(getter, setter, doc=doc)) 730 setattr(cls, property_name, property(getter, setter, doc=doc))
518 731
519 732
520 def _AddPropertiesForExtensions(descriptor, cls): 733 def _AddPropertiesForExtensions(descriptor, cls):
521 """Adds properties for all fields in this protocol message type.""" 734 """Adds properties for all fields in this protocol message type."""
522 extension_dict = descriptor.extensions_by_name 735 extension_dict = descriptor.extensions_by_name
523 for extension_name, extension_field in extension_dict.iteritems(): 736 for extension_name, extension_field in extension_dict.items():
524 constant_name = extension_name.upper() + "_FIELD_NUMBER" 737 constant_name = extension_name.upper() + "_FIELD_NUMBER"
525 setattr(cls, constant_name, extension_field.number) 738 setattr(cls, constant_name, extension_field.number)
526 739
527 740
528 def _AddStaticMethods(cls): 741 def _AddStaticMethods(cls):
529 # TODO(robinson): This probably needs to be thread-safe(?) 742 # TODO(robinson): This probably needs to be thread-safe(?)
530 def RegisterExtension(extension_handle): 743 def RegisterExtension(extension_handle):
531 extension_handle.containing_type = cls.DESCRIPTOR 744 extension_handle.containing_type = cls.DESCRIPTOR
532 _AttachFieldHelpers(cls, extension_handle) 745 _AttachFieldHelpers(cls, extension_handle)
533 746
(...skipping 34 matching lines...) Expand 10 before | Expand all | Expand 10 after
568 elif item[0].cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: 781 elif item[0].cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
569 return item[1]._is_present_in_parent 782 return item[1]._is_present_in_parent
570 else: 783 else:
571 return True 784 return True
572 785
573 786
574 def _AddListFieldsMethod(message_descriptor, cls): 787 def _AddListFieldsMethod(message_descriptor, cls):
575 """Helper for _AddMessageMethods().""" 788 """Helper for _AddMessageMethods()."""
576 789
577 def ListFields(self): 790 def ListFields(self):
578 all_fields = [item for item in self._fields.iteritems() if _IsPresent(item)] 791 all_fields = [item for item in self._fields.items() if _IsPresent(item)]
579 all_fields.sort(key = lambda item: item[0].number) 792 all_fields.sort(key = lambda item: item[0].number)
580 return all_fields 793 return all_fields
581 794
582 cls.ListFields = ListFields 795 cls.ListFields = ListFields
583 796
797 _Proto3HasError = 'Protocol message has no non-repeated submessage field "%s"'
798 _Proto2HasError = 'Protocol message has no non-repeated field "%s"'
584 799
585 def _AddHasFieldMethod(message_descriptor, cls): 800 def _AddHasFieldMethod(message_descriptor, cls):
586 """Helper for _AddMessageMethods().""" 801 """Helper for _AddMessageMethods()."""
587 802
588 singular_fields = {} 803 is_proto3 = (message_descriptor.syntax == "proto3")
804 error_msg = _Proto3HasError if is_proto3 else _Proto2HasError
805
806 hassable_fields = {}
589 for field in message_descriptor.fields: 807 for field in message_descriptor.fields:
590 if field.label != _FieldDescriptor.LABEL_REPEATED: 808 if field.label == _FieldDescriptor.LABEL_REPEATED:
591 singular_fields[field.name] = field 809 continue
810 # For proto3, only submessages and fields inside a oneof have presence.
811 if (is_proto3 and field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE and
812 not field.containing_oneof):
813 continue
814 hassable_fields[field.name] = field
815
816 if not is_proto3:
817 # Fields inside oneofs are never repeated (enforced by the compiler).
818 for oneof in message_descriptor.oneofs:
819 hassable_fields[oneof.name] = oneof
592 820
593 def HasField(self, field_name): 821 def HasField(self, field_name):
594 try: 822 try:
595 field = singular_fields[field_name] 823 field = hassable_fields[field_name]
596 except KeyError: 824 except KeyError:
597 raise ValueError( 825 raise ValueError(error_msg % field_name)
598 'Protocol message has no singular "%s" field.' % field_name)
599 826
600 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: 827 if isinstance(field, descriptor_mod.OneofDescriptor):
601 value = self._fields.get(field) 828 try:
602 return value is not None and value._is_present_in_parent 829 return HasField(self, self._oneofs[field].name)
830 except KeyError:
831 return False
603 else: 832 else:
604 return field in self._fields 833 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
834 value = self._fields.get(field)
835 return value is not None and value._is_present_in_parent
836 else:
837 return field in self._fields
838
605 cls.HasField = HasField 839 cls.HasField = HasField
606 840
607 841
608 def _AddClearFieldMethod(message_descriptor, cls): 842 def _AddClearFieldMethod(message_descriptor, cls):
609 """Helper for _AddMessageMethods().""" 843 """Helper for _AddMessageMethods()."""
610 def ClearField(self, field_name): 844 def ClearField(self, field_name):
611 try: 845 try:
612 field = message_descriptor.fields_by_name[field_name] 846 field = message_descriptor.fields_by_name[field_name]
613 except KeyError: 847 except KeyError:
614 raise ValueError('Protocol message has no "%s" field.' % field_name) 848 try:
849 field = message_descriptor.oneofs_by_name[field_name]
850 if field in self._oneofs:
851 field = self._oneofs[field]
852 else:
853 return
854 except KeyError:
855 raise ValueError('Protocol message %s() has no "%s" field.' %
856 (message_descriptor.name, field_name))
615 857
616 if field in self._fields: 858 if field in self._fields:
859 # To match the C++ implementation, we need to invalidate iterators
860 # for map fields when ClearField() happens.
861 if hasattr(self._fields[field], 'InvalidateIterators'):
862 self._fields[field].InvalidateIterators()
863
617 # Note: If the field is a sub-message, its listener will still point 864 # Note: If the field is a sub-message, its listener will still point
618 # at us. That's fine, because the worst than can happen is that it 865 # at us. That's fine, because the worst than can happen is that it
619 # will call _Modified() and invalidate our byte size. Big deal. 866 # will call _Modified() and invalidate our byte size. Big deal.
620 del self._fields[field] 867 del self._fields[field]
621 868
869 if self._oneofs.get(field.containing_oneof, None) is field:
870 del self._oneofs[field.containing_oneof]
871
622 # Always call _Modified() -- even if nothing was changed, this is 872 # Always call _Modified() -- even if nothing was changed, this is
623 # a mutating method, and thus calling it should cause the field to become 873 # a mutating method, and thus calling it should cause the field to become
624 # present in the parent message. 874 # present in the parent message.
625 self._Modified() 875 self._Modified()
626 876
627 cls.ClearField = ClearField 877 cls.ClearField = ClearField
628 878
629 879
630 def _AddClearExtensionMethod(cls): 880 def _AddClearExtensionMethod(cls):
631 """Helper for _AddMessageMethods().""" 881 """Helper for _AddMessageMethods()."""
632 def ClearExtension(self, extension_handle): 882 def ClearExtension(self, extension_handle):
633 _VerifyExtensionHandle(self, extension_handle) 883 _VerifyExtensionHandle(self, extension_handle)
634 884
635 # Similar to ClearField(), above. 885 # Similar to ClearField(), above.
636 if extension_handle in self._fields: 886 if extension_handle in self._fields:
637 del self._fields[extension_handle] 887 del self._fields[extension_handle]
638 self._Modified() 888 self._Modified()
639 cls.ClearExtension = ClearExtension 889 cls.ClearExtension = ClearExtension
640 890
641 891
642 def _AddClearMethod(message_descriptor, cls): 892 def _AddClearMethod(message_descriptor, cls):
643 """Helper for _AddMessageMethods().""" 893 """Helper for _AddMessageMethods()."""
644 def Clear(self): 894 def Clear(self):
645 # Clear fields. 895 # Clear fields.
646 self._fields = {} 896 self._fields = {}
647 self._unknown_fields = () 897 self._unknown_fields = ()
898 self._oneofs = {}
648 self._Modified() 899 self._Modified()
649 cls.Clear = Clear 900 cls.Clear = Clear
650 901
651 902
652 def _AddHasExtensionMethod(cls): 903 def _AddHasExtensionMethod(cls):
653 """Helper for _AddMessageMethods().""" 904 """Helper for _AddMessageMethods()."""
654 def HasExtension(self, extension_handle): 905 def HasExtension(self, extension_handle):
655 _VerifyExtensionHandle(self, extension_handle) 906 _VerifyExtensionHandle(self, extension_handle)
656 if extension_handle.label == _FieldDescriptor.LABEL_REPEATED: 907 if extension_handle.label == _FieldDescriptor.LABEL_REPEATED:
657 raise KeyError('"%s" is repeated.' % extension_handle.full_name) 908 raise KeyError('"%s" is repeated.' % extension_handle.full_name)
658 909
659 if extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: 910 if extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
660 value = self._fields.get(extension_handle) 911 value = self._fields.get(extension_handle)
661 return value is not None and value._is_present_in_parent 912 return value is not None and value._is_present_in_parent
662 else: 913 else:
663 return extension_handle in self._fields 914 return extension_handle in self._fields
664 cls.HasExtension = HasExtension 915 cls.HasExtension = HasExtension
665 916
917 def _InternalUnpackAny(msg):
918 """Unpacks Any message and returns the unpacked message.
919
920 This internal method is differnt from public Any Unpack method which takes
921 the target message as argument. _InternalUnpackAny method does not have
922 target message type and need to find the message type in descriptor pool.
923
924 Args:
925 msg: An Any message to be unpacked.
926
927 Returns:
928 The unpacked message.
929 """
930 type_url = msg.type_url
931 db = symbol_database.Default()
932
933 if not type_url:
934 return None
935
936 # TODO(haberman): For now we just strip the hostname. Better logic will be
937 # required.
938 type_name = type_url.split("/")[-1]
939 descriptor = db.pool.FindMessageTypeByName(type_name)
940
941 if descriptor is None:
942 return None
943
944 message_class = db.GetPrototype(descriptor)
945 message = message_class()
946
947 message.ParseFromString(msg.value)
948 return message
666 949
667 def _AddEqualsMethod(message_descriptor, cls): 950 def _AddEqualsMethod(message_descriptor, cls):
668 """Helper for _AddMessageMethods().""" 951 """Helper for _AddMessageMethods()."""
669 def __eq__(self, other): 952 def __eq__(self, other):
670 if (not isinstance(other, message_mod.Message) or 953 if (not isinstance(other, message_mod.Message) or
671 other.DESCRIPTOR != self.DESCRIPTOR): 954 other.DESCRIPTOR != self.DESCRIPTOR):
672 return False 955 return False
673 956
674 if self is other: 957 if self is other:
675 return True 958 return True
676 959
960 if self.DESCRIPTOR.full_name == _AnyFullTypeName:
961 any_a = _InternalUnpackAny(self)
962 any_b = _InternalUnpackAny(other)
963 if any_a and any_b:
964 return any_a == any_b
965
677 if not self.ListFields() == other.ListFields(): 966 if not self.ListFields() == other.ListFields():
678 return False 967 return False
679 968
680 # Sort unknown fields because their order shouldn't affect equality test. 969 # Sort unknown fields because their order shouldn't affect equality test.
681 unknown_fields = list(self._unknown_fields) 970 unknown_fields = list(self._unknown_fields)
682 unknown_fields.sort() 971 unknown_fields.sort()
683 other_unknown_fields = list(other._unknown_fields) 972 other_unknown_fields = list(other._unknown_fields)
684 other_unknown_fields.sort() 973 other_unknown_fields.sort()
685 974
686 return unknown_fields == other_unknown_fields 975 return unknown_fields == other_unknown_fields
687 976
688 cls.__eq__ = __eq__ 977 cls.__eq__ = __eq__
689 978
690 979
691 def _AddStrMethod(message_descriptor, cls): 980 def _AddStrMethod(message_descriptor, cls):
692 """Helper for _AddMessageMethods().""" 981 """Helper for _AddMessageMethods()."""
693 def __str__(self): 982 def __str__(self):
694 return text_format.MessageToString(self) 983 return text_format.MessageToString(self)
695 cls.__str__ = __str__ 984 cls.__str__ = __str__
696 985
697 986
987 def _AddReprMethod(message_descriptor, cls):
988 """Helper for _AddMessageMethods()."""
989 def __repr__(self):
990 return text_format.MessageToString(self)
991 cls.__repr__ = __repr__
992
993
698 def _AddUnicodeMethod(unused_message_descriptor, cls): 994 def _AddUnicodeMethod(unused_message_descriptor, cls):
699 """Helper for _AddMessageMethods().""" 995 """Helper for _AddMessageMethods()."""
700 996
701 def __unicode__(self): 997 def __unicode__(self):
702 return text_format.MessageToString(self, as_utf8=True).decode('utf-8') 998 return text_format.MessageToString(self, as_utf8=True).decode('utf-8')
703 cls.__unicode__ = __unicode__ 999 cls.__unicode__ = __unicode__
704 1000
705 1001
706 def _AddSetListenerMethod(cls): 1002 def _AddSetListenerMethod(cls):
707 """Helper for _AddMessageMethods().""" 1003 """Helper for _AddMessageMethods()."""
(...skipping 58 matching lines...) Expand 10 before | Expand all | Expand 10 after
766 'Message %s is missing required fields: %s' % ( 1062 'Message %s is missing required fields: %s' % (
767 self.DESCRIPTOR.full_name, ','.join(self.FindInitializationErrors()))) 1063 self.DESCRIPTOR.full_name, ','.join(self.FindInitializationErrors())))
768 return self.SerializePartialToString() 1064 return self.SerializePartialToString()
769 cls.SerializeToString = SerializeToString 1065 cls.SerializeToString = SerializeToString
770 1066
771 1067
772 def _AddSerializePartialToStringMethod(message_descriptor, cls): 1068 def _AddSerializePartialToStringMethod(message_descriptor, cls):
773 """Helper for _AddMessageMethods().""" 1069 """Helper for _AddMessageMethods()."""
774 1070
775 def SerializePartialToString(self): 1071 def SerializePartialToString(self):
776 out = StringIO() 1072 out = BytesIO()
777 self._InternalSerialize(out.write) 1073 self._InternalSerialize(out.write)
778 return out.getvalue() 1074 return out.getvalue()
779 cls.SerializePartialToString = SerializePartialToString 1075 cls.SerializePartialToString = SerializePartialToString
780 1076
781 def InternalSerialize(self, write_bytes): 1077 def InternalSerialize(self, write_bytes):
782 for field_descriptor, field_value in self.ListFields(): 1078 for field_descriptor, field_value in self.ListFields():
783 field_descriptor._encoder(write_bytes, field_value) 1079 field_descriptor._encoder(write_bytes, field_value)
784 for tag_bytes, value_bytes in self._unknown_fields: 1080 for tag_bytes, value_bytes in self._unknown_fields:
785 write_bytes(tag_bytes) 1081 write_bytes(tag_bytes)
786 write_bytes(value_bytes) 1082 write_bytes(value_bytes)
787 cls._InternalSerialize = InternalSerialize 1083 cls._InternalSerialize = InternalSerialize
788 1084
789 1085
790 def _AddMergeFromStringMethod(message_descriptor, cls): 1086 def _AddMergeFromStringMethod(message_descriptor, cls):
791 """Helper for _AddMessageMethods().""" 1087 """Helper for _AddMessageMethods()."""
792 def MergeFromString(self, serialized): 1088 def MergeFromString(self, serialized):
793 length = len(serialized) 1089 length = len(serialized)
794 try: 1090 try:
795 if self._InternalParse(serialized, 0, length) != length: 1091 if self._InternalParse(serialized, 0, length) != length:
796 # The only reason _InternalParse would return early is if it 1092 # The only reason _InternalParse would return early is if it
797 # encountered an end-group tag. 1093 # encountered an end-group tag.
798 raise message_mod.DecodeError('Unexpected end-group tag.') 1094 raise message_mod.DecodeError('Unexpected end-group tag.')
799 except IndexError: 1095 except (IndexError, TypeError):
1096 # Now ord(buf[p:p+1]) == ord('') gets TypeError.
800 raise message_mod.DecodeError('Truncated message.') 1097 raise message_mod.DecodeError('Truncated message.')
801 except struct.error, e: 1098 except struct.error as e:
802 raise message_mod.DecodeError(e) 1099 raise message_mod.DecodeError(e)
803 return length # Return this for legacy reasons. 1100 return length # Return this for legacy reasons.
804 cls.MergeFromString = MergeFromString 1101 cls.MergeFromString = MergeFromString
805 1102
806 local_ReadTag = decoder.ReadTag 1103 local_ReadTag = decoder.ReadTag
807 local_SkipField = decoder.SkipField 1104 local_SkipField = decoder.SkipField
808 decoders_by_tag = cls._decoders_by_tag 1105 decoders_by_tag = cls._decoders_by_tag
1106 is_proto3 = message_descriptor.syntax == "proto3"
809 1107
810 def InternalParse(self, buffer, pos, end): 1108 def InternalParse(self, buffer, pos, end):
811 self._Modified() 1109 self._Modified()
812 field_dict = self._fields 1110 field_dict = self._fields
813 unknown_field_list = self._unknown_fields 1111 unknown_field_list = self._unknown_fields
814 while pos != end: 1112 while pos != end:
815 (tag_bytes, new_pos) = local_ReadTag(buffer, pos) 1113 (tag_bytes, new_pos) = local_ReadTag(buffer, pos)
816 field_decoder = decoders_by_tag.get(tag_bytes) 1114 field_decoder, field_desc = decoders_by_tag.get(tag_bytes, (None, None))
817 if field_decoder is None: 1115 if field_decoder is None:
818 value_start_pos = new_pos 1116 value_start_pos = new_pos
819 new_pos = local_SkipField(buffer, new_pos, end, tag_bytes) 1117 new_pos = local_SkipField(buffer, new_pos, end, tag_bytes)
820 if new_pos == -1: 1118 if new_pos == -1:
821 return pos 1119 return pos
822 if not unknown_field_list: 1120 if not is_proto3:
823 unknown_field_list = self._unknown_fields = [] 1121 if not unknown_field_list:
824 unknown_field_list.append((tag_bytes, buffer[value_start_pos:new_pos])) 1122 unknown_field_list = self._unknown_fields = []
1123 unknown_field_list.append(
1124 (tag_bytes, buffer[value_start_pos:new_pos]))
825 pos = new_pos 1125 pos = new_pos
826 else: 1126 else:
827 pos = field_decoder(buffer, new_pos, end, self, field_dict) 1127 pos = field_decoder(buffer, new_pos, end, self, field_dict)
1128 if field_desc:
1129 self._UpdateOneofState(field_desc)
828 return pos 1130 return pos
829 cls._InternalParse = InternalParse 1131 cls._InternalParse = InternalParse
830 1132
831 1133
832 def _AddIsInitializedMethod(message_descriptor, cls): 1134 def _AddIsInitializedMethod(message_descriptor, cls):
833 """Adds the IsInitialized and FindInitializationError methods to the 1135 """Adds the IsInitialized and FindInitializationError methods to the
834 protocol message class.""" 1136 protocol message class."""
835 1137
836 required_fields = [field for field in message_descriptor.fields 1138 required_fields = [field for field in message_descriptor.fields
837 if field.label == _FieldDescriptor.LABEL_REQUIRED] 1139 if field.label == _FieldDescriptor.LABEL_REQUIRED]
(...skipping 12 matching lines...) Expand all
850 # Performance is critical so we avoid HasField() and ListFields(). 1152 # Performance is critical so we avoid HasField() and ListFields().
851 1153
852 for field in required_fields: 1154 for field in required_fields:
853 if (field not in self._fields or 1155 if (field not in self._fields or
854 (field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE and 1156 (field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE and
855 not self._fields[field]._is_present_in_parent)): 1157 not self._fields[field]._is_present_in_parent)):
856 if errors is not None: 1158 if errors is not None:
857 errors.extend(self.FindInitializationErrors()) 1159 errors.extend(self.FindInitializationErrors())
858 return False 1160 return False
859 1161
860 for field, value in self._fields.iteritems(): 1162 for field, value in list(self._fields.items()): # dict can change size!
861 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: 1163 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
862 if field.label == _FieldDescriptor.LABEL_REPEATED: 1164 if field.label == _FieldDescriptor.LABEL_REPEATED:
1165 if (field.message_type.has_options and
1166 field.message_type.GetOptions().map_entry):
1167 continue
863 for element in value: 1168 for element in value:
864 if not element.IsInitialized(): 1169 if not element.IsInitialized():
865 if errors is not None: 1170 if errors is not None:
866 errors.extend(self.FindInitializationErrors()) 1171 errors.extend(self.FindInitializationErrors())
867 return False 1172 return False
868 elif value._is_present_in_parent and not value.IsInitialized(): 1173 elif value._is_present_in_parent and not value.IsInitialized():
869 if errors is not None: 1174 if errors is not None:
870 errors.extend(self.FindInitializationErrors()) 1175 errors.extend(self.FindInitializationErrors())
871 return False 1176 return False
872 1177
(...skipping 15 matching lines...) Expand all
888 if not self.HasField(field.name): 1193 if not self.HasField(field.name):
889 errors.append(field.name) 1194 errors.append(field.name)
890 1195
891 for field, value in self.ListFields(): 1196 for field, value in self.ListFields():
892 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: 1197 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
893 if field.is_extension: 1198 if field.is_extension:
894 name = "(%s)" % field.full_name 1199 name = "(%s)" % field.full_name
895 else: 1200 else:
896 name = field.name 1201 name = field.name
897 1202
898 if field.label == _FieldDescriptor.LABEL_REPEATED: 1203 if _IsMapField(field):
899 for i in xrange(len(value)): 1204 if _IsMessageMapField(field):
1205 for key in value:
1206 element = value[key]
1207 prefix = "%s[%s]." % (name, key)
1208 sub_errors = element.FindInitializationErrors()
1209 errors += [prefix + error for error in sub_errors]
1210 else:
1211 # ScalarMaps can't have any initialization errors.
1212 pass
1213 elif field.label == _FieldDescriptor.LABEL_REPEATED:
1214 for i in range(len(value)):
900 element = value[i] 1215 element = value[i]
901 prefix = "%s[%d]." % (name, i) 1216 prefix = "%s[%d]." % (name, i)
902 sub_errors = element.FindInitializationErrors() 1217 sub_errors = element.FindInitializationErrors()
903 errors += [ prefix + error for error in sub_errors ] 1218 errors += [prefix + error for error in sub_errors]
904 else: 1219 else:
905 prefix = name + "." 1220 prefix = name + "."
906 sub_errors = value.FindInitializationErrors() 1221 sub_errors = value.FindInitializationErrors()
907 errors += [ prefix + error for error in sub_errors ] 1222 errors += [prefix + error for error in sub_errors]
908 1223
909 return errors 1224 return errors
910 1225
911 cls.FindInitializationErrors = FindInitializationErrors 1226 cls.FindInitializationErrors = FindInitializationErrors
912 1227
913 1228
914 def _AddMergeFromMethod(cls): 1229 def _AddMergeFromMethod(cls):
915 LABEL_REPEATED = _FieldDescriptor.LABEL_REPEATED 1230 LABEL_REPEATED = _FieldDescriptor.LABEL_REPEATED
916 CPPTYPE_MESSAGE = _FieldDescriptor.CPPTYPE_MESSAGE 1231 CPPTYPE_MESSAGE = _FieldDescriptor.CPPTYPE_MESSAGE
917 1232
918 def MergeFrom(self, msg): 1233 def MergeFrom(self, msg):
919 if not isinstance(msg, cls): 1234 if not isinstance(msg, cls):
920 raise TypeError( 1235 raise TypeError(
921 "Parameter to MergeFrom() must be instance of same class: " 1236 "Parameter to MergeFrom() must be instance of same class: "
922 "expected %s got %s." % (cls.__name__, type(msg).__name__)) 1237 "expected %s got %s." % (cls.__name__, type(msg).__name__))
923 1238
924 assert msg is not self 1239 assert msg is not self
925 self._Modified() 1240 self._Modified()
926 1241
927 fields = self._fields 1242 fields = self._fields
928 1243
929 for field, value in msg._fields.iteritems(): 1244 for field, value in msg._fields.items():
930 if field.label == LABEL_REPEATED: 1245 if field.label == LABEL_REPEATED:
931 field_value = fields.get(field) 1246 field_value = fields.get(field)
932 if field_value is None: 1247 if field_value is None:
933 # Construct a new object to represent this field. 1248 # Construct a new object to represent this field.
934 field_value = field._default_constructor(self) 1249 field_value = field._default_constructor(self)
935 fields[field] = field_value 1250 fields[field] = field_value
936 field_value.MergeFrom(value) 1251 field_value.MergeFrom(value)
937 elif field.cpp_type == CPPTYPE_MESSAGE: 1252 elif field.cpp_type == CPPTYPE_MESSAGE:
938 if value._is_present_in_parent: 1253 if value._is_present_in_parent:
939 field_value = fields.get(field) 1254 field_value = fields.get(field)
940 if field_value is None: 1255 if field_value is None:
941 # Construct a new object to represent this field. 1256 # Construct a new object to represent this field.
942 field_value = field._default_constructor(self) 1257 field_value = field._default_constructor(self)
943 fields[field] = field_value 1258 fields[field] = field_value
944 field_value.MergeFrom(value) 1259 field_value.MergeFrom(value)
945 else: 1260 else:
946 self._fields[field] = value 1261 self._fields[field] = value
1262 if field.containing_oneof:
1263 self._UpdateOneofState(field)
947 1264
948 if msg._unknown_fields: 1265 if msg._unknown_fields:
949 if not self._unknown_fields: 1266 if not self._unknown_fields:
950 self._unknown_fields = [] 1267 self._unknown_fields = []
951 self._unknown_fields.extend(msg._unknown_fields) 1268 self._unknown_fields.extend(msg._unknown_fields)
952 1269
953 cls.MergeFrom = MergeFrom 1270 cls.MergeFrom = MergeFrom
954 1271
955 1272
1273 def _AddWhichOneofMethod(message_descriptor, cls):
1274 def WhichOneof(self, oneof_name):
1275 """Returns the name of the currently set field inside a oneof, or None."""
1276 try:
1277 field = message_descriptor.oneofs_by_name[oneof_name]
1278 except KeyError:
1279 raise ValueError(
1280 'Protocol message has no oneof "%s" field.' % oneof_name)
1281
1282 nested_field = self._oneofs.get(field, None)
1283 if nested_field is not None and self.HasField(nested_field.name):
1284 return nested_field.name
1285 else:
1286 return None
1287
1288 cls.WhichOneof = WhichOneof
1289
1290
956 def _AddMessageMethods(message_descriptor, cls): 1291 def _AddMessageMethods(message_descriptor, cls):
957 """Adds implementations of all Message methods to cls.""" 1292 """Adds implementations of all Message methods to cls."""
958 _AddListFieldsMethod(message_descriptor, cls) 1293 _AddListFieldsMethod(message_descriptor, cls)
959 _AddHasFieldMethod(message_descriptor, cls) 1294 _AddHasFieldMethod(message_descriptor, cls)
960 _AddClearFieldMethod(message_descriptor, cls) 1295 _AddClearFieldMethod(message_descriptor, cls)
961 if message_descriptor.is_extendable: 1296 if message_descriptor.is_extendable:
962 _AddClearExtensionMethod(cls) 1297 _AddClearExtensionMethod(cls)
963 _AddHasExtensionMethod(cls) 1298 _AddHasExtensionMethod(cls)
964 _AddClearMethod(message_descriptor, cls) 1299 _AddClearMethod(message_descriptor, cls)
965 _AddEqualsMethod(message_descriptor, cls) 1300 _AddEqualsMethod(message_descriptor, cls)
966 _AddStrMethod(message_descriptor, cls) 1301 _AddStrMethod(message_descriptor, cls)
1302 _AddReprMethod(message_descriptor, cls)
967 _AddUnicodeMethod(message_descriptor, cls) 1303 _AddUnicodeMethod(message_descriptor, cls)
968 _AddSetListenerMethod(cls) 1304 _AddSetListenerMethod(cls)
969 _AddByteSizeMethod(message_descriptor, cls) 1305 _AddByteSizeMethod(message_descriptor, cls)
970 _AddSerializeToStringMethod(message_descriptor, cls) 1306 _AddSerializeToStringMethod(message_descriptor, cls)
971 _AddSerializePartialToStringMethod(message_descriptor, cls) 1307 _AddSerializePartialToStringMethod(message_descriptor, cls)
972 _AddMergeFromStringMethod(message_descriptor, cls) 1308 _AddMergeFromStringMethod(message_descriptor, cls)
973 _AddIsInitializedMethod(message_descriptor, cls) 1309 _AddIsInitializedMethod(message_descriptor, cls)
974 _AddMergeFromMethod(cls) 1310 _AddMergeFromMethod(cls)
1311 _AddWhichOneofMethod(message_descriptor, cls)
975 1312
976 1313
977 def _AddPrivateHelperMethods(cls): 1314 def _AddPrivateHelperMethods(message_descriptor, cls):
978 """Adds implementation of private helper methods to cls.""" 1315 """Adds implementation of private helper methods to cls."""
979 1316
980 def Modified(self): 1317 def Modified(self):
981 """Sets the _cached_byte_size_dirty bit to true, 1318 """Sets the _cached_byte_size_dirty bit to true,
982 and propagates this to our listener iff this was a state change. 1319 and propagates this to our listener iff this was a state change.
983 """ 1320 """
984 1321
985 # Note: Some callers check _cached_byte_size_dirty before calling 1322 # Note: Some callers check _cached_byte_size_dirty before calling
986 # _Modified() as an extra optimization. So, if this method is ever 1323 # _Modified() as an extra optimization. So, if this method is ever
987 # changed such that it does stuff even when _cached_byte_size_dirty is 1324 # changed such that it does stuff even when _cached_byte_size_dirty is
988 # already true, the callers need to be updated. 1325 # already true, the callers need to be updated.
989 if not self._cached_byte_size_dirty: 1326 if not self._cached_byte_size_dirty:
990 self._cached_byte_size_dirty = True 1327 self._cached_byte_size_dirty = True
991 self._listener_for_children.dirty = True 1328 self._listener_for_children.dirty = True
992 self._is_present_in_parent = True 1329 self._is_present_in_parent = True
993 self._listener.Modified() 1330 self._listener.Modified()
994 1331
1332 def _UpdateOneofState(self, field):
1333 """Sets field as the active field in its containing oneof.
1334
1335 Will also delete currently active field in the oneof, if it is different
1336 from the argument. Does not mark the message as modified.
1337 """
1338 other_field = self._oneofs.setdefault(field.containing_oneof, field)
1339 if other_field is not field:
1340 del self._fields[other_field]
1341 self._oneofs[field.containing_oneof] = field
1342
995 cls._Modified = Modified 1343 cls._Modified = Modified
996 cls.SetInParent = Modified 1344 cls.SetInParent = Modified
1345 cls._UpdateOneofState = _UpdateOneofState
997 1346
998 1347
999 class _Listener(object): 1348 class _Listener(object):
1000 1349
1001 """MessageListener implementation that a parent message registers with its 1350 """MessageListener implementation that a parent message registers with its
1002 child message. 1351 child message.
1003 1352
1004 In order to support semantics like: 1353 In order to support semantics like:
1005 1354
1006 foo.bar.baz.qux = 23 1355 foo.bar.baz.qux = 23
(...skipping 28 matching lines...) Expand all
1035 try: 1384 try:
1036 # Propagate the signal to our parents iff this is the first field set. 1385 # Propagate the signal to our parents iff this is the first field set.
1037 self._parent_message_weakref._Modified() 1386 self._parent_message_weakref._Modified()
1038 except ReferenceError: 1387 except ReferenceError:
1039 # We can get here if a client has kept a reference to a child object, 1388 # We can get here if a client has kept a reference to a child object,
1040 # and is now setting a field on it, but the child's parent has been 1389 # and is now setting a field on it, but the child's parent has been
1041 # garbage-collected. This is not an error. 1390 # garbage-collected. This is not an error.
1042 pass 1391 pass
1043 1392
1044 1393
1394 class _OneofListener(_Listener):
1395 """Special listener implementation for setting composite oneof fields."""
1396
1397 def __init__(self, parent_message, field):
1398 """Args:
1399 parent_message: The message whose _Modified() method we should call when
1400 we receive Modified() messages.
1401 field: The descriptor of the field being set in the parent message.
1402 """
1403 super(_OneofListener, self).__init__(parent_message)
1404 self._field = field
1405
1406 def Modified(self):
1407 """Also updates the state of the containing oneof in the parent message."""
1408 try:
1409 self._parent_message_weakref._UpdateOneofState(self._field)
1410 super(_OneofListener, self).Modified()
1411 except ReferenceError:
1412 pass
1413
1414
1045 # TODO(robinson): Move elsewhere? This file is getting pretty ridiculous... 1415 # TODO(robinson): Move elsewhere? This file is getting pretty ridiculous...
1046 # TODO(robinson): Unify error handling of "unknown extension" crap. 1416 # TODO(robinson): Unify error handling of "unknown extension" crap.
1047 # TODO(robinson): Support iteritems()-style iteration over all 1417 # TODO(robinson): Support iteritems()-style iteration over all
1048 # extensions with the "has" bits turned on? 1418 # extensions with the "has" bits turned on?
1049 class _ExtensionDict(object): 1419 class _ExtensionDict(object):
1050 1420
1051 """Dict-like container for supporting an indexable "Extensions" 1421 """Dict-like container for supporting an indexable "Extensions"
1052 field on proto instances. 1422 field on proto instances.
1053 1423
1054 Note that in all cases we expect extension handles to be 1424 Note that in all cases we expect extension handles to be
(...skipping 70 matching lines...) Expand 10 before | Expand all | Expand 10 after
1125 _VerifyExtensionHandle(self._extended_message, extension_handle) 1495 _VerifyExtensionHandle(self._extended_message, extension_handle)
1126 1496
1127 if (extension_handle.label == _FieldDescriptor.LABEL_REPEATED or 1497 if (extension_handle.label == _FieldDescriptor.LABEL_REPEATED or
1128 extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE): 1498 extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE):
1129 raise TypeError( 1499 raise TypeError(
1130 'Cannot assign to extension "%s" because it is a repeated or ' 1500 'Cannot assign to extension "%s" because it is a repeated or '
1131 'composite type.' % extension_handle.full_name) 1501 'composite type.' % extension_handle.full_name)
1132 1502
1133 # It's slightly wasteful to lookup the type checker each time, 1503 # It's slightly wasteful to lookup the type checker each time,
1134 # but we expect this to be a vanishingly uncommon case anyway. 1504 # but we expect this to be a vanishingly uncommon case anyway.
1135 type_checker = type_checkers.GetTypeChecker( 1505 type_checker = type_checkers.GetTypeChecker(extension_handle)
1136 extension_handle.cpp_type, extension_handle.type) 1506 # pylint: disable=protected-access
1137 type_checker.CheckValue(value) 1507 self._extended_message._fields[extension_handle] = (
1138 self._extended_message._fields[extension_handle] = value 1508 type_checker.CheckValue(value))
1139 self._extended_message._Modified() 1509 self._extended_message._Modified()
1140 1510
1141 def _FindExtensionByName(self, name): 1511 def _FindExtensionByName(self, name):
1142 """Tries to find a known extension with the specified name. 1512 """Tries to find a known extension with the specified name.
1143 1513
1144 Args: 1514 Args:
1145 name: Extension full name. 1515 name: Extension full name.
1146 1516
1147 Returns: 1517 Returns:
1148 Extension field descriptor. 1518 Extension field descriptor.
1149 """ 1519 """
1150 return self._extended_message._extensions_by_name.get(name, None) 1520 return self._extended_message._extensions_by_name.get(name, None)
OLDNEW

Powered by Google App Engine
This is Rietveld 408576698