Chromium Code Reviews
chromiumcodereview-hr@appspot.gserviceaccount.com (chromiumcodereview-hr) | Please choose your nickname with Settings | Help | Chromium Project | Gerrit Changes | Sign out
(287)

Side by Side Diff: third_party/protobuf/python/google/protobuf/internal/python_message.py

Issue 6737030: third_party/protobuf: update to upstream r371 (Closed) Base URL: svn://svn.chromium.org/chrome/trunk/src
Patch Set: Created 9 years, 8 months ago
Use n/p to move between diff chunks; N/P to move between comments. Draft comments are only viewable by you.
Jump to:
View unified diff | Download patch | Annotate | Revision Log
OLDNEW
(Empty)
1 # Protocol Buffers - Google's data interchange format
2 # Copyright 2008 Google Inc. All rights reserved.
3 # http://code.google.com/p/protobuf/
4 #
5 # Redistribution and use in source and binary forms, with or without
6 # modification, are permitted provided that the following conditions are
7 # met:
8 #
9 # * Redistributions of source code must retain the above copyright
10 # notice, this list of conditions and the following disclaimer.
11 # * Redistributions in binary form must reproduce the above
12 # copyright notice, this list of conditions and the following disclaimer
13 # in the documentation and/or other materials provided with the
14 # distribution.
15 # * Neither the name of Google Inc. nor the names of its
16 # contributors may be used to endorse or promote products derived from
17 # this software without specific prior written permission.
18 #
19 # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20 # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21 # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22 # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23 # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24 # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25 # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26 # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27 # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
31 # This code is meant to work on Python 2.4 and above only.
32 #
33 # TODO(robinson): Helpers for verbose, common checks like seeing if a
34 # descriptor's cpp_type is CPPTYPE_MESSAGE.
35
36 """Contains a metaclass and helper functions used to create
37 protocol message classes from Descriptor objects at runtime.
38
39 Recall that a metaclass is the "type" of a class.
40 (A class is to a metaclass what an instance is to a class.)
41
42 In this case, we use the GeneratedProtocolMessageType metaclass
43 to inject all the useful functionality into the classes
44 output by the protocol compiler at compile-time.
45
46 The upshot of all this is that the real implementation
47 details for ALL pure-Python protocol buffers are *here in
48 this file*.
49 """
50
51 __author__ = 'robinson@google.com (Will Robinson)'
52
53 try:
54 from cStringIO import StringIO
55 except ImportError:
56 from StringIO import StringIO
57 import struct
58 import weakref
59
60 # We use "as" to avoid name collisions with variables.
61 from google.protobuf.internal import containers
62 from google.protobuf.internal import decoder
63 from google.protobuf.internal import encoder
64 from google.protobuf.internal import message_listener as message_listener_mod
65 from google.protobuf.internal import type_checkers
66 from google.protobuf.internal import wire_format
67 from google.protobuf import descriptor as descriptor_mod
68 from google.protobuf import message as message_mod
69 from google.protobuf import text_format
70
71 _FieldDescriptor = descriptor_mod.FieldDescriptor
72
73
74 def NewMessage(descriptor, dictionary):
75 _AddClassAttributesForNestedExtensions(descriptor, dictionary)
76 _AddSlots(descriptor, dictionary)
77
78
79 def InitMessage(descriptor, cls):
80 cls._decoders_by_tag = {}
81 cls._extensions_by_name = {}
82 cls._extensions_by_number = {}
83 if (descriptor.has_options and
84 descriptor.GetOptions().message_set_wire_format):
85 cls._decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = (
86 decoder.MessageSetItemDecoder(cls._extensions_by_number))
87
88 # Attach stuff to each FieldDescriptor for quick lookup later on.
89 for field in descriptor.fields:
90 _AttachFieldHelpers(cls, field)
91
92 _AddEnumValues(descriptor, cls)
93 _AddInitMethod(descriptor, cls)
94 _AddPropertiesForFields(descriptor, cls)
95 _AddPropertiesForExtensions(descriptor, cls)
96 _AddStaticMethods(cls)
97 _AddMessageMethods(descriptor, cls)
98 _AddPrivateHelperMethods(cls)
99
100
101 # Stateless helpers for GeneratedProtocolMessageType below.
102 # Outside clients should not access these directly.
103 #
104 # I opted not to make any of these methods on the metaclass, to make it more
105 # clear that I'm not really using any state there and to keep clients from
106 # thinking that they have direct access to these construction helpers.
107
108
109 def _PropertyName(proto_field_name):
110 """Returns the name of the public property attribute which
111 clients can use to get and (in some cases) set the value
112 of a protocol message field.
113
114 Args:
115 proto_field_name: The protocol message field name, exactly
116 as it appears (or would appear) in a .proto file.
117 """
118 # TODO(robinson): Escape Python keywords (e.g., yield), and test this support.
119 # nnorwitz makes my day by writing:
120 # """
121 # FYI. See the keyword module in the stdlib. This could be as simple as:
122 #
123 # if keyword.iskeyword(proto_field_name):
124 # return proto_field_name + "_"
125 # return proto_field_name
126 # """
127 # Kenton says: The above is a BAD IDEA. People rely on being able to use
128 # getattr() and setattr() to reflectively manipulate field values. If we
129 # rename the properties, then every such user has to also make sure to apply
130 # the same transformation. Note that currently if you name a field "yield",
131 # you can still access it just fine using getattr/setattr -- it's not even
132 # that cumbersome to do so.
133 # TODO(kenton): Remove this method entirely if/when everyone agrees with my
134 # position.
135 return proto_field_name
136
137
138 def _VerifyExtensionHandle(message, extension_handle):
139 """Verify that the given extension handle is valid."""
140
141 if not isinstance(extension_handle, _FieldDescriptor):
142 raise KeyError('HasExtension() expects an extension handle, got: %s' %
143 extension_handle)
144
145 if not extension_handle.is_extension:
146 raise KeyError('"%s" is not an extension.' % extension_handle.full_name)
147
148 if extension_handle.containing_type is not message.DESCRIPTOR:
149 raise KeyError('Extension "%s" extends message type "%s", but this '
150 'message is of type "%s".' %
151 (extension_handle.full_name,
152 extension_handle.containing_type.full_name,
153 message.DESCRIPTOR.full_name))
154
155
156 def _AddSlots(message_descriptor, dictionary):
157 """Adds a __slots__ entry to dictionary, containing the names of all valid
158 attributes for this message type.
159
160 Args:
161 message_descriptor: A Descriptor instance describing this message type.
162 dictionary: Class dictionary to which we'll add a '__slots__' entry.
163 """
164 dictionary['__slots__'] = ['_cached_byte_size',
165 '_cached_byte_size_dirty',
166 '_fields',
167 '_is_present_in_parent',
168 '_listener',
169 '_listener_for_children',
170 '__weakref__']
171
172
173 def _IsMessageSetExtension(field):
174 return (field.is_extension and
175 field.containing_type.has_options and
176 field.containing_type.GetOptions().message_set_wire_format and
177 field.type == _FieldDescriptor.TYPE_MESSAGE and
178 field.message_type == field.extension_scope and
179 field.label == _FieldDescriptor.LABEL_OPTIONAL)
180
181
182 def _AttachFieldHelpers(cls, field_descriptor):
183 is_repeated = (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED)
184 is_packed = (field_descriptor.has_options and
185 field_descriptor.GetOptions().packed)
186
187 if _IsMessageSetExtension(field_descriptor):
188 field_encoder = encoder.MessageSetItemEncoder(field_descriptor.number)
189 sizer = encoder.MessageSetItemSizer(field_descriptor.number)
190 else:
191 field_encoder = type_checkers.TYPE_TO_ENCODER[field_descriptor.type](
192 field_descriptor.number, is_repeated, is_packed)
193 sizer = type_checkers.TYPE_TO_SIZER[field_descriptor.type](
194 field_descriptor.number, is_repeated, is_packed)
195
196 field_descriptor._encoder = field_encoder
197 field_descriptor._sizer = sizer
198 field_descriptor._default_constructor = _DefaultValueConstructorForField(
199 field_descriptor)
200
201 def AddDecoder(wiretype, is_packed):
202 tag_bytes = encoder.TagBytes(field_descriptor.number, wiretype)
203 cls._decoders_by_tag[tag_bytes] = (
204 type_checkers.TYPE_TO_DECODER[field_descriptor.type](
205 field_descriptor.number, is_repeated, is_packed,
206 field_descriptor, field_descriptor._default_constructor))
207
208 AddDecoder(type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type],
209 False)
210
211 if is_repeated and wire_format.IsTypePackable(field_descriptor.type):
212 # To support wire compatibility of adding packed = true, add a decoder for
213 # packed values regardless of the field's options.
214 AddDecoder(wire_format.WIRETYPE_LENGTH_DELIMITED, True)
215
216
217 def _AddClassAttributesForNestedExtensions(descriptor, dictionary):
218 extension_dict = descriptor.extensions_by_name
219 for extension_name, extension_field in extension_dict.iteritems():
220 assert extension_name not in dictionary
221 dictionary[extension_name] = extension_field
222
223
224 def _AddEnumValues(descriptor, cls):
225 """Sets class-level attributes for all enum fields defined in this message.
226
227 Args:
228 descriptor: Descriptor object for this message type.
229 cls: Class we're constructing for this message type.
230 """
231 for enum_type in descriptor.enum_types:
232 for enum_value in enum_type.values:
233 setattr(cls, enum_value.name, enum_value.number)
234
235
236 def _DefaultValueConstructorForField(field):
237 """Returns a function which returns a default value for a field.
238
239 Args:
240 field: FieldDescriptor object for this field.
241
242 The returned function has one argument:
243 message: Message instance containing this field, or a weakref proxy
244 of same.
245
246 That function in turn returns a default value for this field. The default
247 value may refer back to |message| via a weak reference.
248 """
249
250 if field.label == _FieldDescriptor.LABEL_REPEATED:
251 if field.default_value != []:
252 raise ValueError('Repeated field default value not empty list: %s' % (
253 field.default_value))
254 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
255 # We can't look at _concrete_class yet since it might not have
256 # been set. (Depends on order in which we initialize the classes).
257 message_type = field.message_type
258 def MakeRepeatedMessageDefault(message):
259 return containers.RepeatedCompositeFieldContainer(
260 message._listener_for_children, field.message_type)
261 return MakeRepeatedMessageDefault
262 else:
263 type_checker = type_checkers.GetTypeChecker(field.cpp_type, field.type)
264 def MakeRepeatedScalarDefault(message):
265 return containers.RepeatedScalarFieldContainer(
266 message._listener_for_children, type_checker)
267 return MakeRepeatedScalarDefault
268
269 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
270 # _concrete_class may not yet be initialized.
271 message_type = field.message_type
272 def MakeSubMessageDefault(message):
273 result = message_type._concrete_class()
274 result._SetListener(message._listener_for_children)
275 return result
276 return MakeSubMessageDefault
277
278 def MakeScalarDefault(message):
279 return field.default_value
280 return MakeScalarDefault
281
282
283 def _AddInitMethod(message_descriptor, cls):
284 """Adds an __init__ method to cls."""
285 fields = message_descriptor.fields
286 def init(self, **kwargs):
287 self._cached_byte_size = 0
288 self._cached_byte_size_dirty = len(kwargs) > 0
289 self._fields = {}
290 self._is_present_in_parent = False
291 self._listener = message_listener_mod.NullMessageListener()
292 self._listener_for_children = _Listener(self)
293 for field_name, field_value in kwargs.iteritems():
294 field = _GetFieldByName(message_descriptor, field_name)
295 if field is None:
296 raise TypeError("%s() got an unexpected keyword argument '%s'" %
297 (message_descriptor.name, field_name))
298 if field.label == _FieldDescriptor.LABEL_REPEATED:
299 copy = field._default_constructor(self)
300 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: # Composite
301 for val in field_value:
302 copy.add().MergeFrom(val)
303 else: # Scalar
304 copy.extend(field_value)
305 self._fields[field] = copy
306 elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
307 copy = field._default_constructor(self)
308 copy.MergeFrom(field_value)
309 self._fields[field] = copy
310 else:
311 setattr(self, field_name, field_value)
312
313 init.__module__ = None
314 init.__doc__ = None
315 cls.__init__ = init
316
317
318 def _GetFieldByName(message_descriptor, field_name):
319 """Returns a field descriptor by field name.
320
321 Args:
322 message_descriptor: A Descriptor describing all fields in message.
323 field_name: The name of the field to retrieve.
324 Returns:
325 The field descriptor associated with the field name.
326 """
327 try:
328 return message_descriptor.fields_by_name[field_name]
329 except KeyError:
330 raise ValueError('Protocol message has no "%s" field.' % field_name)
331
332
333 def _AddPropertiesForFields(descriptor, cls):
334 """Adds properties for all fields in this protocol message type."""
335 for field in descriptor.fields:
336 _AddPropertiesForField(field, cls)
337
338 if descriptor.is_extendable:
339 # _ExtensionDict is just an adaptor with no state so we allocate a new one
340 # every time it is accessed.
341 cls.Extensions = property(lambda self: _ExtensionDict(self))
342
343
344 def _AddPropertiesForField(field, cls):
345 """Adds a public property for a protocol message field.
346 Clients can use this property to get and (in the case
347 of non-repeated scalar fields) directly set the value
348 of a protocol message field.
349
350 Args:
351 field: A FieldDescriptor for this field.
352 cls: The class we're constructing.
353 """
354 # Catch it if we add other types that we should
355 # handle specially here.
356 assert _FieldDescriptor.MAX_CPPTYPE == 10
357
358 constant_name = field.name.upper() + "_FIELD_NUMBER"
359 setattr(cls, constant_name, field.number)
360
361 if field.label == _FieldDescriptor.LABEL_REPEATED:
362 _AddPropertiesForRepeatedField(field, cls)
363 elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
364 _AddPropertiesForNonRepeatedCompositeField(field, cls)
365 else:
366 _AddPropertiesForNonRepeatedScalarField(field, cls)
367
368
369 def _AddPropertiesForRepeatedField(field, cls):
370 """Adds a public property for a "repeated" protocol message field. Clients
371 can use this property to get the value of the field, which will be either a
372 _RepeatedScalarFieldContainer or _RepeatedCompositeFieldContainer (see
373 below).
374
375 Note that when clients add values to these containers, we perform
376 type-checking in the case of repeated scalar fields, and we also set any
377 necessary "has" bits as a side-effect.
378
379 Args:
380 field: A FieldDescriptor for this field.
381 cls: The class we're constructing.
382 """
383 proto_field_name = field.name
384 property_name = _PropertyName(proto_field_name)
385
386 def getter(self):
387 field_value = self._fields.get(field)
388 if field_value is None:
389 # Construct a new object to represent this field.
390 field_value = field._default_constructor(self)
391
392 # Atomically check if another thread has preempted us and, if not, swap
393 # in the new object we just created. If someone has preempted us, we
394 # take that object and discard ours.
395 # WARNING: We are relying on setdefault() being atomic. This is true
396 # in CPython but we haven't investigated others. This warning appears
397 # in several other locations in this file.
398 field_value = self._fields.setdefault(field, field_value)
399 return field_value
400 getter.__module__ = None
401 getter.__doc__ = 'Getter for %s.' % proto_field_name
402
403 # We define a setter just so we can throw an exception with a more
404 # helpful error message.
405 def setter(self, new_value):
406 raise AttributeError('Assignment not allowed to repeated field '
407 '"%s" in protocol message object.' % proto_field_name)
408
409 doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
410 setattr(cls, property_name, property(getter, setter, doc=doc))
411
412
413 def _AddPropertiesForNonRepeatedScalarField(field, cls):
414 """Adds a public property for a nonrepeated, scalar protocol message field.
415 Clients can use this property to get and directly set the value of the field.
416 Note that when the client sets the value of a field by using this property,
417 all necessary "has" bits are set as a side-effect, and we also perform
418 type-checking.
419
420 Args:
421 field: A FieldDescriptor for this field.
422 cls: The class we're constructing.
423 """
424 proto_field_name = field.name
425 property_name = _PropertyName(proto_field_name)
426 type_checker = type_checkers.GetTypeChecker(field.cpp_type, field.type)
427 default_value = field.default_value
428 valid_values = set()
429
430 def getter(self):
431 return self._fields.get(field, default_value)
432 getter.__module__ = None
433 getter.__doc__ = 'Getter for %s.' % proto_field_name
434 def setter(self, new_value):
435 type_checker.CheckValue(new_value)
436 self._fields[field] = new_value
437 # Check _cached_byte_size_dirty inline to improve performance, since scalar
438 # setters are called frequently.
439 if not self._cached_byte_size_dirty:
440 self._Modified()
441
442 setter.__module__ = None
443 setter.__doc__ = 'Setter for %s.' % proto_field_name
444
445 # Add a property to encapsulate the getter/setter.
446 doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
447 setattr(cls, property_name, property(getter, setter, doc=doc))
448
449
450 def _AddPropertiesForNonRepeatedCompositeField(field, cls):
451 """Adds a public property for a nonrepeated, composite protocol message field.
452 A composite field is a "group" or "message" field.
453
454 Clients can use this property to get the value of the field, but cannot
455 assign to the property directly.
456
457 Args:
458 field: A FieldDescriptor for this field.
459 cls: The class we're constructing.
460 """
461 # TODO(robinson): Remove duplication with similar method
462 # for non-repeated scalars.
463 proto_field_name = field.name
464 property_name = _PropertyName(proto_field_name)
465 message_type = field.message_type
466
467 def getter(self):
468 field_value = self._fields.get(field)
469 if field_value is None:
470 # Construct a new object to represent this field.
471 field_value = message_type._concrete_class()
472 field_value._SetListener(self._listener_for_children)
473
474 # Atomically check if another thread has preempted us and, if not, swap
475 # in the new object we just created. If someone has preempted us, we
476 # take that object and discard ours.
477 # WARNING: We are relying on setdefault() being atomic. This is true
478 # in CPython but we haven't investigated others. This warning appears
479 # in several other locations in this file.
480 field_value = self._fields.setdefault(field, field_value)
481 return field_value
482 getter.__module__ = None
483 getter.__doc__ = 'Getter for %s.' % proto_field_name
484
485 # We define a setter just so we can throw an exception with a more
486 # helpful error message.
487 def setter(self, new_value):
488 raise AttributeError('Assignment not allowed to composite field '
489 '"%s" in protocol message object.' % proto_field_name)
490
491 # Add a property to encapsulate the getter.
492 doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
493 setattr(cls, property_name, property(getter, setter, doc=doc))
494
495
496 def _AddPropertiesForExtensions(descriptor, cls):
497 """Adds properties for all fields in this protocol message type."""
498 extension_dict = descriptor.extensions_by_name
499 for extension_name, extension_field in extension_dict.iteritems():
500 constant_name = extension_name.upper() + "_FIELD_NUMBER"
501 setattr(cls, constant_name, extension_field.number)
502
503
504 def _AddStaticMethods(cls):
505 # TODO(robinson): This probably needs to be thread-safe(?)
506 def RegisterExtension(extension_handle):
507 extension_handle.containing_type = cls.DESCRIPTOR
508 _AttachFieldHelpers(cls, extension_handle)
509
510 # Try to insert our extension, failing if an extension with the same number
511 # already exists.
512 actual_handle = cls._extensions_by_number.setdefault(
513 extension_handle.number, extension_handle)
514 if actual_handle is not extension_handle:
515 raise AssertionError(
516 'Extensions "%s" and "%s" both try to extend message type "%s" with '
517 'field number %d.' %
518 (extension_handle.full_name, actual_handle.full_name,
519 cls.DESCRIPTOR.full_name, extension_handle.number))
520
521 cls._extensions_by_name[extension_handle.full_name] = extension_handle
522
523 handle = extension_handle # avoid line wrapping
524 if _IsMessageSetExtension(handle):
525 # MessageSet extension. Also register under type name.
526 cls._extensions_by_name[
527 extension_handle.message_type.full_name] = extension_handle
528
529 cls.RegisterExtension = staticmethod(RegisterExtension)
530
531 def FromString(s):
532 message = cls()
533 message.MergeFromString(s)
534 return message
535 cls.FromString = staticmethod(FromString)
536
537
538 def _IsPresent(item):
539 """Given a (FieldDescriptor, value) tuple from _fields, return true if the
540 value should be included in the list returned by ListFields()."""
541
542 if item[0].label == _FieldDescriptor.LABEL_REPEATED:
543 return bool(item[1])
544 elif item[0].cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
545 return item[1]._is_present_in_parent
546 else:
547 return True
548
549
550 def _AddListFieldsMethod(message_descriptor, cls):
551 """Helper for _AddMessageMethods()."""
552
553 def ListFields(self):
554 all_fields = [item for item in self._fields.iteritems() if _IsPresent(item)]
555 all_fields.sort(key = lambda item: item[0].number)
556 return all_fields
557
558 cls.ListFields = ListFields
559
560
561 def _AddHasFieldMethod(message_descriptor, cls):
562 """Helper for _AddMessageMethods()."""
563
564 singular_fields = {}
565 for field in message_descriptor.fields:
566 if field.label != _FieldDescriptor.LABEL_REPEATED:
567 singular_fields[field.name] = field
568
569 def HasField(self, field_name):
570 try:
571 field = singular_fields[field_name]
572 except KeyError:
573 raise ValueError(
574 'Protocol message has no singular "%s" field.' % field_name)
575
576 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
577 value = self._fields.get(field)
578 return value is not None and value._is_present_in_parent
579 else:
580 return field in self._fields
581 cls.HasField = HasField
582
583
584 def _AddClearFieldMethod(message_descriptor, cls):
585 """Helper for _AddMessageMethods()."""
586 def ClearField(self, field_name):
587 try:
588 field = message_descriptor.fields_by_name[field_name]
589 except KeyError:
590 raise ValueError('Protocol message has no "%s" field.' % field_name)
591
592 if field in self._fields:
593 # Note: If the field is a sub-message, its listener will still point
594 # at us. That's fine, because the worst than can happen is that it
595 # will call _Modified() and invalidate our byte size. Big deal.
596 del self._fields[field]
597
598 # Always call _Modified() -- even if nothing was changed, this is
599 # a mutating method, and thus calling it should cause the field to become
600 # present in the parent message.
601 self._Modified()
602
603 cls.ClearField = ClearField
604
605
606 def _AddClearExtensionMethod(cls):
607 """Helper for _AddMessageMethods()."""
608 def ClearExtension(self, extension_handle):
609 _VerifyExtensionHandle(self, extension_handle)
610
611 # Similar to ClearField(), above.
612 if extension_handle in self._fields:
613 del self._fields[extension_handle]
614 self._Modified()
615 cls.ClearExtension = ClearExtension
616
617
618 def _AddClearMethod(message_descriptor, cls):
619 """Helper for _AddMessageMethods()."""
620 def Clear(self):
621 # Clear fields.
622 self._fields = {}
623 self._Modified()
624 cls.Clear = Clear
625
626
627 def _AddHasExtensionMethod(cls):
628 """Helper for _AddMessageMethods()."""
629 def HasExtension(self, extension_handle):
630 _VerifyExtensionHandle(self, extension_handle)
631 if extension_handle.label == _FieldDescriptor.LABEL_REPEATED:
632 raise KeyError('"%s" is repeated.' % extension_handle.full_name)
633
634 if extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
635 value = self._fields.get(extension_handle)
636 return value is not None and value._is_present_in_parent
637 else:
638 return extension_handle in self._fields
639 cls.HasExtension = HasExtension
640
641
642 def _AddEqualsMethod(message_descriptor, cls):
643 """Helper for _AddMessageMethods()."""
644 def __eq__(self, other):
645 if (not isinstance(other, message_mod.Message) or
646 other.DESCRIPTOR != self.DESCRIPTOR):
647 return False
648
649 if self is other:
650 return True
651
652 return self.ListFields() == other.ListFields()
653
654 cls.__eq__ = __eq__
655
656
657 def _AddStrMethod(message_descriptor, cls):
658 """Helper for _AddMessageMethods()."""
659 def __str__(self):
660 return text_format.MessageToString(self)
661 cls.__str__ = __str__
662
663
664 def _AddUnicodeMethod(unused_message_descriptor, cls):
665 """Helper for _AddMessageMethods()."""
666
667 def __unicode__(self):
668 return text_format.MessageToString(self, as_utf8=True).decode('utf-8')
669 cls.__unicode__ = __unicode__
670
671
672 def _AddSetListenerMethod(cls):
673 """Helper for _AddMessageMethods()."""
674 def SetListener(self, listener):
675 if listener is None:
676 self._listener = message_listener_mod.NullMessageListener()
677 else:
678 self._listener = listener
679 cls._SetListener = SetListener
680
681
682 def _BytesForNonRepeatedElement(value, field_number, field_type):
683 """Returns the number of bytes needed to serialize a non-repeated element.
684 The returned byte count includes space for tag information and any
685 other additional space associated with serializing value.
686
687 Args:
688 value: Value we're serializing.
689 field_number: Field number of this value. (Since the field number
690 is stored as part of a varint-encoded tag, this has an impact
691 on the total bytes required to serialize the value).
692 field_type: The type of the field. One of the TYPE_* constants
693 within FieldDescriptor.
694 """
695 try:
696 fn = type_checkers.TYPE_TO_BYTE_SIZE_FN[field_type]
697 return fn(field_number, value)
698 except KeyError:
699 raise message_mod.EncodeError('Unrecognized field type: %d' % field_type)
700
701
702 def _AddByteSizeMethod(message_descriptor, cls):
703 """Helper for _AddMessageMethods()."""
704
705 def ByteSize(self):
706 if not self._cached_byte_size_dirty:
707 return self._cached_byte_size
708
709 size = 0
710 for field_descriptor, field_value in self.ListFields():
711 size += field_descriptor._sizer(field_value)
712
713 self._cached_byte_size = size
714 self._cached_byte_size_dirty = False
715 self._listener_for_children.dirty = False
716 return size
717
718 cls.ByteSize = ByteSize
719
720
721 def _AddSerializeToStringMethod(message_descriptor, cls):
722 """Helper for _AddMessageMethods()."""
723
724 def SerializeToString(self):
725 # Check if the message has all of its required fields set.
726 errors = []
727 if not self.IsInitialized():
728 raise message_mod.EncodeError(
729 'Message is missing required fields: ' +
730 ','.join(self.FindInitializationErrors()))
731 return self.SerializePartialToString()
732 cls.SerializeToString = SerializeToString
733
734
735 def _AddSerializePartialToStringMethod(message_descriptor, cls):
736 """Helper for _AddMessageMethods()."""
737
738 def SerializePartialToString(self):
739 out = StringIO()
740 self._InternalSerialize(out.write)
741 return out.getvalue()
742 cls.SerializePartialToString = SerializePartialToString
743
744 def InternalSerialize(self, write_bytes):
745 for field_descriptor, field_value in self.ListFields():
746 field_descriptor._encoder(write_bytes, field_value)
747 cls._InternalSerialize = InternalSerialize
748
749
750 def _AddMergeFromStringMethod(message_descriptor, cls):
751 """Helper for _AddMessageMethods()."""
752 def MergeFromString(self, serialized):
753 length = len(serialized)
754 try:
755 if self._InternalParse(serialized, 0, length) != length:
756 # The only reason _InternalParse would return early is if it
757 # encountered an end-group tag.
758 raise message_mod.DecodeError('Unexpected end-group tag.')
759 except IndexError:
760 raise message_mod.DecodeError('Truncated message.')
761 except struct.error, e:
762 raise message_mod.DecodeError(e)
763 return length # Return this for legacy reasons.
764 cls.MergeFromString = MergeFromString
765
766 local_ReadTag = decoder.ReadTag
767 local_SkipField = decoder.SkipField
768 decoders_by_tag = cls._decoders_by_tag
769
770 def InternalParse(self, buffer, pos, end):
771 self._Modified()
772 field_dict = self._fields
773 while pos != end:
774 (tag_bytes, new_pos) = local_ReadTag(buffer, pos)
775 field_decoder = decoders_by_tag.get(tag_bytes)
776 if field_decoder is None:
777 new_pos = local_SkipField(buffer, new_pos, end, tag_bytes)
778 if new_pos == -1:
779 return pos
780 pos = new_pos
781 else:
782 pos = field_decoder(buffer, new_pos, end, self, field_dict)
783 return pos
784 cls._InternalParse = InternalParse
785
786
787 def _AddIsInitializedMethod(message_descriptor, cls):
788 """Adds the IsInitialized and FindInitializationError methods to the
789 protocol message class."""
790
791 required_fields = [field for field in message_descriptor.fields
792 if field.label == _FieldDescriptor.LABEL_REQUIRED]
793
794 def IsInitialized(self, errors=None):
795 """Checks if all required fields of a message are set.
796
797 Args:
798 errors: A list which, if provided, will be populated with the field
799 paths of all missing required fields.
800
801 Returns:
802 True iff the specified message has all required fields set.
803 """
804
805 # Performance is critical so we avoid HasField() and ListFields().
806
807 for field in required_fields:
808 if (field not in self._fields or
809 (field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE and
810 not self._fields[field]._is_present_in_parent)):
811 if errors is not None:
812 errors.extend(self.FindInitializationErrors())
813 return False
814
815 for field, value in self._fields.iteritems():
816 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
817 if field.label == _FieldDescriptor.LABEL_REPEATED:
818 for element in value:
819 if not element.IsInitialized():
820 if errors is not None:
821 errors.extend(self.FindInitializationErrors())
822 return False
823 elif value._is_present_in_parent and not value.IsInitialized():
824 if errors is not None:
825 errors.extend(self.FindInitializationErrors())
826 return False
827
828 return True
829
830 cls.IsInitialized = IsInitialized
831
832 def FindInitializationErrors(self):
833 """Finds required fields which are not initialized.
834
835 Returns:
836 A list of strings. Each string is a path to an uninitialized field from
837 the top-level message, e.g. "foo.bar[5].baz".
838 """
839
840 errors = [] # simplify things
841
842 for field in required_fields:
843 if not self.HasField(field.name):
844 errors.append(field.name)
845
846 for field, value in self.ListFields():
847 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
848 if field.is_extension:
849 name = "(%s)" % field.full_name
850 else:
851 name = field.name
852
853 if field.label == _FieldDescriptor.LABEL_REPEATED:
854 for i in xrange(len(value)):
855 element = value[i]
856 prefix = "%s[%d]." % (name, i)
857 sub_errors = element.FindInitializationErrors()
858 errors += [ prefix + error for error in sub_errors ]
859 else:
860 prefix = name + "."
861 sub_errors = value.FindInitializationErrors()
862 errors += [ prefix + error for error in sub_errors ]
863
864 return errors
865
866 cls.FindInitializationErrors = FindInitializationErrors
867
868
869 def _AddMergeFromMethod(cls):
870 LABEL_REPEATED = _FieldDescriptor.LABEL_REPEATED
871 CPPTYPE_MESSAGE = _FieldDescriptor.CPPTYPE_MESSAGE
872
873 def MergeFrom(self, msg):
874 if not isinstance(msg, cls):
875 raise TypeError(
876 "Parameter to MergeFrom() must be instance of same class.")
877
878 assert msg is not self
879 self._Modified()
880
881 fields = self._fields
882
883 for field, value in msg._fields.iteritems():
884 if field.label == LABEL_REPEATED:
885 field_value = fields.get(field)
886 if field_value is None:
887 # Construct a new object to represent this field.
888 field_value = field._default_constructor(self)
889 fields[field] = field_value
890 field_value.MergeFrom(value)
891 elif field.cpp_type == CPPTYPE_MESSAGE:
892 if value._is_present_in_parent:
893 field_value = fields.get(field)
894 if field_value is None:
895 # Construct a new object to represent this field.
896 field_value = field._default_constructor(self)
897 fields[field] = field_value
898 field_value.MergeFrom(value)
899 else:
900 self._fields[field] = value
901 cls.MergeFrom = MergeFrom
902
903
904 def _AddMessageMethods(message_descriptor, cls):
905 """Adds implementations of all Message methods to cls."""
906 _AddListFieldsMethod(message_descriptor, cls)
907 _AddHasFieldMethod(message_descriptor, cls)
908 _AddClearFieldMethod(message_descriptor, cls)
909 if message_descriptor.is_extendable:
910 _AddClearExtensionMethod(cls)
911 _AddHasExtensionMethod(cls)
912 _AddClearMethod(message_descriptor, cls)
913 _AddEqualsMethod(message_descriptor, cls)
914 _AddStrMethod(message_descriptor, cls)
915 _AddUnicodeMethod(message_descriptor, cls)
916 _AddSetListenerMethod(cls)
917 _AddByteSizeMethod(message_descriptor, cls)
918 _AddSerializeToStringMethod(message_descriptor, cls)
919 _AddSerializePartialToStringMethod(message_descriptor, cls)
920 _AddMergeFromStringMethod(message_descriptor, cls)
921 _AddIsInitializedMethod(message_descriptor, cls)
922 _AddMergeFromMethod(cls)
923
924
925 def _AddPrivateHelperMethods(cls):
926 """Adds implementation of private helper methods to cls."""
927
928 def Modified(self):
929 """Sets the _cached_byte_size_dirty bit to true,
930 and propagates this to our listener iff this was a state change.
931 """
932
933 # Note: Some callers check _cached_byte_size_dirty before calling
934 # _Modified() as an extra optimization. So, if this method is ever
935 # changed such that it does stuff even when _cached_byte_size_dirty is
936 # already true, the callers need to be updated.
937 if not self._cached_byte_size_dirty:
938 self._cached_byte_size_dirty = True
939 self._listener_for_children.dirty = True
940 self._is_present_in_parent = True
941 self._listener.Modified()
942
943 cls._Modified = Modified
944 cls.SetInParent = Modified
945
946
947 class _Listener(object):
948
949 """MessageListener implementation that a parent message registers with its
950 child message.
951
952 In order to support semantics like:
953
954 foo.bar.baz.qux = 23
955 assert foo.HasField('bar')
956
957 ...child objects must have back references to their parents.
958 This helper class is at the heart of this support.
959 """
960
961 def __init__(self, parent_message):
962 """Args:
963 parent_message: The message whose _Modified() method we should call when
964 we receive Modified() messages.
965 """
966 # This listener establishes a back reference from a child (contained) object
967 # to its parent (containing) object. We make this a weak reference to avoid
968 # creating cyclic garbage when the client finishes with the 'parent' object
969 # in the tree.
970 if isinstance(parent_message, weakref.ProxyType):
971 self._parent_message_weakref = parent_message
972 else:
973 self._parent_message_weakref = weakref.proxy(parent_message)
974
975 # As an optimization, we also indicate directly on the listener whether
976 # or not the parent message is dirty. This way we can avoid traversing
977 # up the tree in the common case.
978 self.dirty = False
979
980 def Modified(self):
981 if self.dirty:
982 return
983 try:
984 # Propagate the signal to our parents iff this is the first field set.
985 self._parent_message_weakref._Modified()
986 except ReferenceError:
987 # We can get here if a client has kept a reference to a child object,
988 # and is now setting a field on it, but the child's parent has been
989 # garbage-collected. This is not an error.
990 pass
991
992
993 # TODO(robinson): Move elsewhere? This file is getting pretty ridiculous...
994 # TODO(robinson): Unify error handling of "unknown extension" crap.
995 # TODO(robinson): Support iteritems()-style iteration over all
996 # extensions with the "has" bits turned on?
997 class _ExtensionDict(object):
998
999 """Dict-like container for supporting an indexable "Extensions"
1000 field on proto instances.
1001
1002 Note that in all cases we expect extension handles to be
1003 FieldDescriptors.
1004 """
1005
1006 def __init__(self, extended_message):
1007 """extended_message: Message instance for which we are the Extensions dict.
1008 """
1009
1010 self._extended_message = extended_message
1011
1012 def __getitem__(self, extension_handle):
1013 """Returns the current value of the given extension handle."""
1014
1015 _VerifyExtensionHandle(self._extended_message, extension_handle)
1016
1017 result = self._extended_message._fields.get(extension_handle)
1018 if result is not None:
1019 return result
1020
1021 if extension_handle.label == _FieldDescriptor.LABEL_REPEATED:
1022 result = extension_handle._default_constructor(self._extended_message)
1023 elif extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
1024 result = extension_handle.message_type._concrete_class()
1025 try:
1026 result._SetListener(self._extended_message._listener_for_children)
1027 except ReferenceError:
1028 pass
1029 else:
1030 # Singular scalar -- just return the default without inserting into the
1031 # dict.
1032 return extension_handle.default_value
1033
1034 # Atomically check if another thread has preempted us and, if not, swap
1035 # in the new object we just created. If someone has preempted us, we
1036 # take that object and discard ours.
1037 # WARNING: We are relying on setdefault() being atomic. This is true
1038 # in CPython but we haven't investigated others. This warning appears
1039 # in several other locations in this file.
1040 result = self._extended_message._fields.setdefault(
1041 extension_handle, result)
1042
1043 return result
1044
1045 def __eq__(self, other):
1046 if not isinstance(other, self.__class__):
1047 return False
1048
1049 my_fields = self._extended_message.ListFields()
1050 other_fields = other._extended_message.ListFields()
1051
1052 # Get rid of non-extension fields.
1053 my_fields = [ field for field in my_fields if field.is_extension ]
1054 other_fields = [ field for field in other_fields if field.is_extension ]
1055
1056 return my_fields == other_fields
1057
1058 def __ne__(self, other):
1059 return not self == other
1060
1061 def __hash__(self):
1062 raise TypeError('unhashable object')
1063
1064 # Note that this is only meaningful for non-repeated, scalar extension
1065 # fields. Note also that we may have to call _Modified() when we do
1066 # successfully set a field this way, to set any necssary "has" bits in the
1067 # ancestors of the extended message.
1068 def __setitem__(self, extension_handle, value):
1069 """If extension_handle specifies a non-repeated, scalar extension
1070 field, sets the value of that field.
1071 """
1072
1073 _VerifyExtensionHandle(self._extended_message, extension_handle)
1074
1075 if (extension_handle.label == _FieldDescriptor.LABEL_REPEATED or
1076 extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE):
1077 raise TypeError(
1078 'Cannot assign to extension "%s" because it is a repeated or '
1079 'composite type.' % extension_handle.full_name)
1080
1081 # It's slightly wasteful to lookup the type checker each time,
1082 # but we expect this to be a vanishingly uncommon case anyway.
1083 type_checker = type_checkers.GetTypeChecker(
1084 extension_handle.cpp_type, extension_handle.type)
1085 type_checker.CheckValue(value)
1086 self._extended_message._fields[extension_handle] = value
1087 self._extended_message._Modified()
1088
1089 def _FindExtensionByName(self, name):
1090 """Tries to find a known extension with the specified name.
1091
1092 Args:
1093 name: Extension full name.
1094
1095 Returns:
1096 Extension field descriptor.
1097 """
1098 return self._extended_message._extensions_by_name.get(name, None)
OLDNEW

Powered by Google App Engine
This is Rietveld 408576698