Chromium Code Reviews
chromiumcodereview-hr@appspot.gserviceaccount.com (chromiumcodereview-hr) | Please choose your nickname with Settings | Help | Chromium Project | Gerrit Changes | Sign out
(23)

Side by Side Diff: recipe_engine/third_party/google/protobuf/internal/python_message.py

Issue 1344583003: Recipe package system. (Closed) Base URL: git@github.com:luci/recipes-py.git@master
Patch Set: Recompiled proto Created 5 years, 3 months ago
Use n/p to move between diff chunks; N/P to move between comments. Draft comments are only viewable by you.
Jump to:
View unified diff | Download patch
OLDNEW
(Empty)
1 # Protocol Buffers - Google's data interchange format
2 # Copyright 2008 Google Inc. All rights reserved.
3 # http://code.google.com/p/protobuf/
4 #
5 # Redistribution and use in source and binary forms, with or without
6 # modification, are permitted provided that the following conditions are
7 # met:
8 #
9 # * Redistributions of source code must retain the above copyright
10 # notice, this list of conditions and the following disclaimer.
11 # * Redistributions in binary form must reproduce the above
12 # copyright notice, this list of conditions and the following disclaimer
13 # in the documentation and/or other materials provided with the
14 # distribution.
15 # * Neither the name of Google Inc. nor the names of its
16 # contributors may be used to endorse or promote products derived from
17 # this software without specific prior written permission.
18 #
19 # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20 # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21 # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22 # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23 # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24 # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25 # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26 # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27 # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
31 # This code is meant to work on Python 2.4 and above only.
32 #
33 # TODO(robinson): Helpers for verbose, common checks like seeing if a
34 # descriptor's cpp_type is CPPTYPE_MESSAGE.
35
36 """Contains a metaclass and helper functions used to create
37 protocol message classes from Descriptor objects at runtime.
38
39 Recall that a metaclass is the "type" of a class.
40 (A class is to a metaclass what an instance is to a class.)
41
42 In this case, we use the GeneratedProtocolMessageType metaclass
43 to inject all the useful functionality into the classes
44 output by the protocol compiler at compile-time.
45
46 The upshot of all this is that the real implementation
47 details for ALL pure-Python protocol buffers are *here in
48 this file*.
49 """
50
51 __author__ = 'robinson@google.com (Will Robinson)'
52
53 try:
54 from cStringIO import StringIO
55 except ImportError:
56 from StringIO import StringIO
57 import copy_reg
58 import struct
59 import weakref
60
61 # We use "as" to avoid name collisions with variables.
62 from google.protobuf.internal import containers
63 from google.protobuf.internal import decoder
64 from google.protobuf.internal import encoder
65 from google.protobuf.internal import enum_type_wrapper
66 from google.protobuf.internal import message_listener as message_listener_mod
67 from google.protobuf.internal import type_checkers
68 from google.protobuf.internal import wire_format
69 from google.protobuf import descriptor as descriptor_mod
70 from google.protobuf import message as message_mod
71 from google.protobuf import text_format
72
73 _FieldDescriptor = descriptor_mod.FieldDescriptor
74
75
76 def NewMessage(bases, descriptor, dictionary):
77 _AddClassAttributesForNestedExtensions(descriptor, dictionary)
78 _AddSlots(descriptor, dictionary)
79 return bases
80
81
82 def InitMessage(descriptor, cls):
83 cls._decoders_by_tag = {}
84 cls._extensions_by_name = {}
85 cls._extensions_by_number = {}
86 if (descriptor.has_options and
87 descriptor.GetOptions().message_set_wire_format):
88 cls._decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = (
89 decoder.MessageSetItemDecoder(cls._extensions_by_number))
90
91 # Attach stuff to each FieldDescriptor for quick lookup later on.
92 for field in descriptor.fields:
93 _AttachFieldHelpers(cls, field)
94
95 _AddEnumValues(descriptor, cls)
96 _AddInitMethod(descriptor, cls)
97 _AddPropertiesForFields(descriptor, cls)
98 _AddPropertiesForExtensions(descriptor, cls)
99 _AddStaticMethods(cls)
100 _AddMessageMethods(descriptor, cls)
101 _AddPrivateHelperMethods(cls)
102 copy_reg.pickle(cls, lambda obj: (cls, (), obj.__getstate__()))
103
104
105 # Stateless helpers for GeneratedProtocolMessageType below.
106 # Outside clients should not access these directly.
107 #
108 # I opted not to make any of these methods on the metaclass, to make it more
109 # clear that I'm not really using any state there and to keep clients from
110 # thinking that they have direct access to these construction helpers.
111
112
113 def _PropertyName(proto_field_name):
114 """Returns the name of the public property attribute which
115 clients can use to get and (in some cases) set the value
116 of a protocol message field.
117
118 Args:
119 proto_field_name: The protocol message field name, exactly
120 as it appears (or would appear) in a .proto file.
121 """
122 # TODO(robinson): Escape Python keywords (e.g., yield), and test this support.
123 # nnorwitz makes my day by writing:
124 # """
125 # FYI. See the keyword module in the stdlib. This could be as simple as:
126 #
127 # if keyword.iskeyword(proto_field_name):
128 # return proto_field_name + "_"
129 # return proto_field_name
130 # """
131 # Kenton says: The above is a BAD IDEA. People rely on being able to use
132 # getattr() and setattr() to reflectively manipulate field values. If we
133 # rename the properties, then every such user has to also make sure to apply
134 # the same transformation. Note that currently if you name a field "yield",
135 # you can still access it just fine using getattr/setattr -- it's not even
136 # that cumbersome to do so.
137 # TODO(kenton): Remove this method entirely if/when everyone agrees with my
138 # position.
139 return proto_field_name
140
141
142 def _VerifyExtensionHandle(message, extension_handle):
143 """Verify that the given extension handle is valid."""
144
145 if not isinstance(extension_handle, _FieldDescriptor):
146 raise KeyError('HasExtension() expects an extension handle, got: %s' %
147 extension_handle)
148
149 if not extension_handle.is_extension:
150 raise KeyError('"%s" is not an extension.' % extension_handle.full_name)
151
152 if not extension_handle.containing_type:
153 raise KeyError('"%s" is missing a containing_type.'
154 % extension_handle.full_name)
155
156 if extension_handle.containing_type is not message.DESCRIPTOR:
157 raise KeyError('Extension "%s" extends message type "%s", but this '
158 'message is of type "%s".' %
159 (extension_handle.full_name,
160 extension_handle.containing_type.full_name,
161 message.DESCRIPTOR.full_name))
162
163
164 def _AddSlots(message_descriptor, dictionary):
165 """Adds a __slots__ entry to dictionary, containing the names of all valid
166 attributes for this message type.
167
168 Args:
169 message_descriptor: A Descriptor instance describing this message type.
170 dictionary: Class dictionary to which we'll add a '__slots__' entry.
171 """
172 dictionary['__slots__'] = ['_cached_byte_size',
173 '_cached_byte_size_dirty',
174 '_fields',
175 '_unknown_fields',
176 '_is_present_in_parent',
177 '_listener',
178 '_listener_for_children',
179 '__weakref__']
180
181
182 def _IsMessageSetExtension(field):
183 return (field.is_extension and
184 field.containing_type.has_options and
185 field.containing_type.GetOptions().message_set_wire_format and
186 field.type == _FieldDescriptor.TYPE_MESSAGE and
187 field.message_type == field.extension_scope and
188 field.label == _FieldDescriptor.LABEL_OPTIONAL)
189
190
191 def _AttachFieldHelpers(cls, field_descriptor):
192 is_repeated = (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED)
193 is_packed = (field_descriptor.has_options and
194 field_descriptor.GetOptions().packed)
195
196 if _IsMessageSetExtension(field_descriptor):
197 field_encoder = encoder.MessageSetItemEncoder(field_descriptor.number)
198 sizer = encoder.MessageSetItemSizer(field_descriptor.number)
199 else:
200 field_encoder = type_checkers.TYPE_TO_ENCODER[field_descriptor.type](
201 field_descriptor.number, is_repeated, is_packed)
202 sizer = type_checkers.TYPE_TO_SIZER[field_descriptor.type](
203 field_descriptor.number, is_repeated, is_packed)
204
205 field_descriptor._encoder = field_encoder
206 field_descriptor._sizer = sizer
207 field_descriptor._default_constructor = _DefaultValueConstructorForField(
208 field_descriptor)
209
210 def AddDecoder(wiretype, is_packed):
211 tag_bytes = encoder.TagBytes(field_descriptor.number, wiretype)
212 cls._decoders_by_tag[tag_bytes] = (
213 type_checkers.TYPE_TO_DECODER[field_descriptor.type](
214 field_descriptor.number, is_repeated, is_packed,
215 field_descriptor, field_descriptor._default_constructor))
216
217 AddDecoder(type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type],
218 False)
219
220 if is_repeated and wire_format.IsTypePackable(field_descriptor.type):
221 # To support wire compatibility of adding packed = true, add a decoder for
222 # packed values regardless of the field's options.
223 AddDecoder(wire_format.WIRETYPE_LENGTH_DELIMITED, True)
224
225
226 def _AddClassAttributesForNestedExtensions(descriptor, dictionary):
227 extension_dict = descriptor.extensions_by_name
228 for extension_name, extension_field in extension_dict.iteritems():
229 assert extension_name not in dictionary
230 dictionary[extension_name] = extension_field
231
232
233 def _AddEnumValues(descriptor, cls):
234 """Sets class-level attributes for all enum fields defined in this message.
235
236 Also exporting a class-level object that can name enum values.
237
238 Args:
239 descriptor: Descriptor object for this message type.
240 cls: Class we're constructing for this message type.
241 """
242 for enum_type in descriptor.enum_types:
243 setattr(cls, enum_type.name, enum_type_wrapper.EnumTypeWrapper(enum_type))
244 for enum_value in enum_type.values:
245 setattr(cls, enum_value.name, enum_value.number)
246
247
248 def _DefaultValueConstructorForField(field):
249 """Returns a function which returns a default value for a field.
250
251 Args:
252 field: FieldDescriptor object for this field.
253
254 The returned function has one argument:
255 message: Message instance containing this field, or a weakref proxy
256 of same.
257
258 That function in turn returns a default value for this field. The default
259 value may refer back to |message| via a weak reference.
260 """
261
262 if field.label == _FieldDescriptor.LABEL_REPEATED:
263 if field.has_default_value and field.default_value != []:
264 raise ValueError('Repeated field default value not empty list: %s' % (
265 field.default_value))
266 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
267 # We can't look at _concrete_class yet since it might not have
268 # been set. (Depends on order in which we initialize the classes).
269 message_type = field.message_type
270 def MakeRepeatedMessageDefault(message):
271 return containers.RepeatedCompositeFieldContainer(
272 message._listener_for_children, field.message_type)
273 return MakeRepeatedMessageDefault
274 else:
275 type_checker = type_checkers.GetTypeChecker(field.cpp_type, field.type)
276 def MakeRepeatedScalarDefault(message):
277 return containers.RepeatedScalarFieldContainer(
278 message._listener_for_children, type_checker)
279 return MakeRepeatedScalarDefault
280
281 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
282 # _concrete_class may not yet be initialized.
283 message_type = field.message_type
284 def MakeSubMessageDefault(message):
285 result = message_type._concrete_class()
286 result._SetListener(message._listener_for_children)
287 return result
288 return MakeSubMessageDefault
289
290 def MakeScalarDefault(message):
291 # TODO(protobuf-team): This may be broken since there may not be
292 # default_value. Combine with has_default_value somehow.
293 return field.default_value
294 return MakeScalarDefault
295
296
297 def _AddInitMethod(message_descriptor, cls):
298 """Adds an __init__ method to cls."""
299 fields = message_descriptor.fields
300 def init(self, **kwargs):
301 self._cached_byte_size = 0
302 self._cached_byte_size_dirty = len(kwargs) > 0
303 self._fields = {}
304 # _unknown_fields is () when empty for efficiency, and will be turned into
305 # a list if fields are added.
306 self._unknown_fields = ()
307 self._is_present_in_parent = False
308 self._listener = message_listener_mod.NullMessageListener()
309 self._listener_for_children = _Listener(self)
310 for field_name, field_value in kwargs.iteritems():
311 field = _GetFieldByName(message_descriptor, field_name)
312 if field is None:
313 raise TypeError("%s() got an unexpected keyword argument '%s'" %
314 (message_descriptor.name, field_name))
315 if field.label == _FieldDescriptor.LABEL_REPEATED:
316 copy = field._default_constructor(self)
317 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: # Composite
318 for val in field_value:
319 copy.add().MergeFrom(val)
320 else: # Scalar
321 copy.extend(field_value)
322 self._fields[field] = copy
323 elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
324 copy = field._default_constructor(self)
325 copy.MergeFrom(field_value)
326 self._fields[field] = copy
327 else:
328 setattr(self, field_name, field_value)
329
330 init.__module__ = None
331 init.__doc__ = None
332 cls.__init__ = init
333
334
335 def _GetFieldByName(message_descriptor, field_name):
336 """Returns a field descriptor by field name.
337
338 Args:
339 message_descriptor: A Descriptor describing all fields in message.
340 field_name: The name of the field to retrieve.
341 Returns:
342 The field descriptor associated with the field name.
343 """
344 try:
345 return message_descriptor.fields_by_name[field_name]
346 except KeyError:
347 raise ValueError('Protocol message has no "%s" field.' % field_name)
348
349
350 def _AddPropertiesForFields(descriptor, cls):
351 """Adds properties for all fields in this protocol message type."""
352 for field in descriptor.fields:
353 _AddPropertiesForField(field, cls)
354
355 if descriptor.is_extendable:
356 # _ExtensionDict is just an adaptor with no state so we allocate a new one
357 # every time it is accessed.
358 cls.Extensions = property(lambda self: _ExtensionDict(self))
359
360
361 def _AddPropertiesForField(field, cls):
362 """Adds a public property for a protocol message field.
363 Clients can use this property to get and (in the case
364 of non-repeated scalar fields) directly set the value
365 of a protocol message field.
366
367 Args:
368 field: A FieldDescriptor for this field.
369 cls: The class we're constructing.
370 """
371 # Catch it if we add other types that we should
372 # handle specially here.
373 assert _FieldDescriptor.MAX_CPPTYPE == 10
374
375 constant_name = field.name.upper() + "_FIELD_NUMBER"
376 setattr(cls, constant_name, field.number)
377
378 if field.label == _FieldDescriptor.LABEL_REPEATED:
379 _AddPropertiesForRepeatedField(field, cls)
380 elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
381 _AddPropertiesForNonRepeatedCompositeField(field, cls)
382 else:
383 _AddPropertiesForNonRepeatedScalarField(field, cls)
384
385
386 def _AddPropertiesForRepeatedField(field, cls):
387 """Adds a public property for a "repeated" protocol message field. Clients
388 can use this property to get the value of the field, which will be either a
389 _RepeatedScalarFieldContainer or _RepeatedCompositeFieldContainer (see
390 below).
391
392 Note that when clients add values to these containers, we perform
393 type-checking in the case of repeated scalar fields, and we also set any
394 necessary "has" bits as a side-effect.
395
396 Args:
397 field: A FieldDescriptor for this field.
398 cls: The class we're constructing.
399 """
400 proto_field_name = field.name
401 property_name = _PropertyName(proto_field_name)
402
403 def getter(self):
404 field_value = self._fields.get(field)
405 if field_value is None:
406 # Construct a new object to represent this field.
407 field_value = field._default_constructor(self)
408
409 # Atomically check if another thread has preempted us and, if not, swap
410 # in the new object we just created. If someone has preempted us, we
411 # take that object and discard ours.
412 # WARNING: We are relying on setdefault() being atomic. This is true
413 # in CPython but we haven't investigated others. This warning appears
414 # in several other locations in this file.
415 field_value = self._fields.setdefault(field, field_value)
416 return field_value
417 getter.__module__ = None
418 getter.__doc__ = 'Getter for %s.' % proto_field_name
419
420 # We define a setter just so we can throw an exception with a more
421 # helpful error message.
422 def setter(self, new_value):
423 raise AttributeError('Assignment not allowed to repeated field '
424 '"%s" in protocol message object.' % proto_field_name)
425
426 doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
427 setattr(cls, property_name, property(getter, setter, doc=doc))
428
429
430 def _AddPropertiesForNonRepeatedScalarField(field, cls):
431 """Adds a public property for a nonrepeated, scalar protocol message field.
432 Clients can use this property to get and directly set the value of the field.
433 Note that when the client sets the value of a field by using this property,
434 all necessary "has" bits are set as a side-effect, and we also perform
435 type-checking.
436
437 Args:
438 field: A FieldDescriptor for this field.
439 cls: The class we're constructing.
440 """
441 proto_field_name = field.name
442 property_name = _PropertyName(proto_field_name)
443 type_checker = type_checkers.GetTypeChecker(field.cpp_type, field.type)
444 default_value = field.default_value
445 valid_values = set()
446
447 def getter(self):
448 # TODO(protobuf-team): This may be broken since there may not be
449 # default_value. Combine with has_default_value somehow.
450 return self._fields.get(field, default_value)
451 getter.__module__ = None
452 getter.__doc__ = 'Getter for %s.' % proto_field_name
453 def setter(self, new_value):
454 type_checker.CheckValue(new_value)
455 self._fields[field] = new_value
456 # Check _cached_byte_size_dirty inline to improve performance, since scalar
457 # setters are called frequently.
458 if not self._cached_byte_size_dirty:
459 self._Modified()
460
461 setter.__module__ = None
462 setter.__doc__ = 'Setter for %s.' % proto_field_name
463
464 # Add a property to encapsulate the getter/setter.
465 doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
466 setattr(cls, property_name, property(getter, setter, doc=doc))
467
468
469 def _AddPropertiesForNonRepeatedCompositeField(field, cls):
470 """Adds a public property for a nonrepeated, composite protocol message field.
471 A composite field is a "group" or "message" field.
472
473 Clients can use this property to get the value of the field, but cannot
474 assign to the property directly.
475
476 Args:
477 field: A FieldDescriptor for this field.
478 cls: The class we're constructing.
479 """
480 # TODO(robinson): Remove duplication with similar method
481 # for non-repeated scalars.
482 proto_field_name = field.name
483 property_name = _PropertyName(proto_field_name)
484
485 # TODO(komarek): Can anyone explain to me why we cache the message_type this
486 # way, instead of referring to field.message_type inside of getter(self)?
487 # What if someone sets message_type later on (which makes for simpler
488 # dyanmic proto descriptor and class creation code).
489 message_type = field.message_type
490
491 def getter(self):
492 field_value = self._fields.get(field)
493 if field_value is None:
494 # Construct a new object to represent this field.
495 field_value = message_type._concrete_class() # use field.message_type?
496 field_value._SetListener(self._listener_for_children)
497
498 # Atomically check if another thread has preempted us and, if not, swap
499 # in the new object we just created. If someone has preempted us, we
500 # take that object and discard ours.
501 # WARNING: We are relying on setdefault() being atomic. This is true
502 # in CPython but we haven't investigated others. This warning appears
503 # in several other locations in this file.
504 field_value = self._fields.setdefault(field, field_value)
505 return field_value
506 getter.__module__ = None
507 getter.__doc__ = 'Getter for %s.' % proto_field_name
508
509 # We define a setter just so we can throw an exception with a more
510 # helpful error message.
511 def setter(self, new_value):
512 raise AttributeError('Assignment not allowed to composite field '
513 '"%s" in protocol message object.' % proto_field_name)
514
515 # Add a property to encapsulate the getter.
516 doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
517 setattr(cls, property_name, property(getter, setter, doc=doc))
518
519
520 def _AddPropertiesForExtensions(descriptor, cls):
521 """Adds properties for all fields in this protocol message type."""
522 extension_dict = descriptor.extensions_by_name
523 for extension_name, extension_field in extension_dict.iteritems():
524 constant_name = extension_name.upper() + "_FIELD_NUMBER"
525 setattr(cls, constant_name, extension_field.number)
526
527
528 def _AddStaticMethods(cls):
529 # TODO(robinson): This probably needs to be thread-safe(?)
530 def RegisterExtension(extension_handle):
531 extension_handle.containing_type = cls.DESCRIPTOR
532 _AttachFieldHelpers(cls, extension_handle)
533
534 # Try to insert our extension, failing if an extension with the same number
535 # already exists.
536 actual_handle = cls._extensions_by_number.setdefault(
537 extension_handle.number, extension_handle)
538 if actual_handle is not extension_handle:
539 raise AssertionError(
540 'Extensions "%s" and "%s" both try to extend message type "%s" with '
541 'field number %d.' %
542 (extension_handle.full_name, actual_handle.full_name,
543 cls.DESCRIPTOR.full_name, extension_handle.number))
544
545 cls._extensions_by_name[extension_handle.full_name] = extension_handle
546
547 handle = extension_handle # avoid line wrapping
548 if _IsMessageSetExtension(handle):
549 # MessageSet extension. Also register under type name.
550 cls._extensions_by_name[
551 extension_handle.message_type.full_name] = extension_handle
552
553 cls.RegisterExtension = staticmethod(RegisterExtension)
554
555 def FromString(s):
556 message = cls()
557 message.MergeFromString(s)
558 return message
559 cls.FromString = staticmethod(FromString)
560
561
562 def _IsPresent(item):
563 """Given a (FieldDescriptor, value) tuple from _fields, return true if the
564 value should be included in the list returned by ListFields()."""
565
566 if item[0].label == _FieldDescriptor.LABEL_REPEATED:
567 return bool(item[1])
568 elif item[0].cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
569 return item[1]._is_present_in_parent
570 else:
571 return True
572
573
574 def _AddListFieldsMethod(message_descriptor, cls):
575 """Helper for _AddMessageMethods()."""
576
577 def ListFields(self):
578 all_fields = [item for item in self._fields.iteritems() if _IsPresent(item)]
579 all_fields.sort(key = lambda item: item[0].number)
580 return all_fields
581
582 cls.ListFields = ListFields
583
584
585 def _AddHasFieldMethod(message_descriptor, cls):
586 """Helper for _AddMessageMethods()."""
587
588 singular_fields = {}
589 for field in message_descriptor.fields:
590 if field.label != _FieldDescriptor.LABEL_REPEATED:
591 singular_fields[field.name] = field
592
593 def HasField(self, field_name):
594 try:
595 field = singular_fields[field_name]
596 except KeyError:
597 raise ValueError(
598 'Protocol message has no singular "%s" field.' % field_name)
599
600 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
601 value = self._fields.get(field)
602 return value is not None and value._is_present_in_parent
603 else:
604 return field in self._fields
605 cls.HasField = HasField
606
607
608 def _AddClearFieldMethod(message_descriptor, cls):
609 """Helper for _AddMessageMethods()."""
610 def ClearField(self, field_name):
611 try:
612 field = message_descriptor.fields_by_name[field_name]
613 except KeyError:
614 raise ValueError('Protocol message has no "%s" field.' % field_name)
615
616 if field in self._fields:
617 # Note: If the field is a sub-message, its listener will still point
618 # at us. That's fine, because the worst than can happen is that it
619 # will call _Modified() and invalidate our byte size. Big deal.
620 del self._fields[field]
621
622 # Always call _Modified() -- even if nothing was changed, this is
623 # a mutating method, and thus calling it should cause the field to become
624 # present in the parent message.
625 self._Modified()
626
627 cls.ClearField = ClearField
628
629
630 def _AddClearExtensionMethod(cls):
631 """Helper for _AddMessageMethods()."""
632 def ClearExtension(self, extension_handle):
633 _VerifyExtensionHandle(self, extension_handle)
634
635 # Similar to ClearField(), above.
636 if extension_handle in self._fields:
637 del self._fields[extension_handle]
638 self._Modified()
639 cls.ClearExtension = ClearExtension
640
641
642 def _AddClearMethod(message_descriptor, cls):
643 """Helper for _AddMessageMethods()."""
644 def Clear(self):
645 # Clear fields.
646 self._fields = {}
647 self._unknown_fields = ()
648 self._Modified()
649 cls.Clear = Clear
650
651
652 def _AddHasExtensionMethod(cls):
653 """Helper for _AddMessageMethods()."""
654 def HasExtension(self, extension_handle):
655 _VerifyExtensionHandle(self, extension_handle)
656 if extension_handle.label == _FieldDescriptor.LABEL_REPEATED:
657 raise KeyError('"%s" is repeated.' % extension_handle.full_name)
658
659 if extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
660 value = self._fields.get(extension_handle)
661 return value is not None and value._is_present_in_parent
662 else:
663 return extension_handle in self._fields
664 cls.HasExtension = HasExtension
665
666
667 def _AddEqualsMethod(message_descriptor, cls):
668 """Helper for _AddMessageMethods()."""
669 def __eq__(self, other):
670 if (not isinstance(other, message_mod.Message) or
671 other.DESCRIPTOR != self.DESCRIPTOR):
672 return False
673
674 if self is other:
675 return True
676
677 if not self.ListFields() == other.ListFields():
678 return False
679
680 # Sort unknown fields because their order shouldn't affect equality test.
681 unknown_fields = list(self._unknown_fields)
682 unknown_fields.sort()
683 other_unknown_fields = list(other._unknown_fields)
684 other_unknown_fields.sort()
685
686 return unknown_fields == other_unknown_fields
687
688 cls.__eq__ = __eq__
689
690
691 def _AddStrMethod(message_descriptor, cls):
692 """Helper for _AddMessageMethods()."""
693 def __str__(self):
694 return text_format.MessageToString(self)
695 cls.__str__ = __str__
696
697
698 def _AddUnicodeMethod(unused_message_descriptor, cls):
699 """Helper for _AddMessageMethods()."""
700
701 def __unicode__(self):
702 return text_format.MessageToString(self, as_utf8=True).decode('utf-8')
703 cls.__unicode__ = __unicode__
704
705
706 def _AddSetListenerMethod(cls):
707 """Helper for _AddMessageMethods()."""
708 def SetListener(self, listener):
709 if listener is None:
710 self._listener = message_listener_mod.NullMessageListener()
711 else:
712 self._listener = listener
713 cls._SetListener = SetListener
714
715
716 def _BytesForNonRepeatedElement(value, field_number, field_type):
717 """Returns the number of bytes needed to serialize a non-repeated element.
718 The returned byte count includes space for tag information and any
719 other additional space associated with serializing value.
720
721 Args:
722 value: Value we're serializing.
723 field_number: Field number of this value. (Since the field number
724 is stored as part of a varint-encoded tag, this has an impact
725 on the total bytes required to serialize the value).
726 field_type: The type of the field. One of the TYPE_* constants
727 within FieldDescriptor.
728 """
729 try:
730 fn = type_checkers.TYPE_TO_BYTE_SIZE_FN[field_type]
731 return fn(field_number, value)
732 except KeyError:
733 raise message_mod.EncodeError('Unrecognized field type: %d' % field_type)
734
735
736 def _AddByteSizeMethod(message_descriptor, cls):
737 """Helper for _AddMessageMethods()."""
738
739 def ByteSize(self):
740 if not self._cached_byte_size_dirty:
741 return self._cached_byte_size
742
743 size = 0
744 for field_descriptor, field_value in self.ListFields():
745 size += field_descriptor._sizer(field_value)
746
747 for tag_bytes, value_bytes in self._unknown_fields:
748 size += len(tag_bytes) + len(value_bytes)
749
750 self._cached_byte_size = size
751 self._cached_byte_size_dirty = False
752 self._listener_for_children.dirty = False
753 return size
754
755 cls.ByteSize = ByteSize
756
757
758 def _AddSerializeToStringMethod(message_descriptor, cls):
759 """Helper for _AddMessageMethods()."""
760
761 def SerializeToString(self):
762 # Check if the message has all of its required fields set.
763 errors = []
764 if not self.IsInitialized():
765 raise message_mod.EncodeError(
766 'Message %s is missing required fields: %s' % (
767 self.DESCRIPTOR.full_name, ','.join(self.FindInitializationErrors())))
768 return self.SerializePartialToString()
769 cls.SerializeToString = SerializeToString
770
771
772 def _AddSerializePartialToStringMethod(message_descriptor, cls):
773 """Helper for _AddMessageMethods()."""
774
775 def SerializePartialToString(self):
776 out = StringIO()
777 self._InternalSerialize(out.write)
778 return out.getvalue()
779 cls.SerializePartialToString = SerializePartialToString
780
781 def InternalSerialize(self, write_bytes):
782 for field_descriptor, field_value in self.ListFields():
783 field_descriptor._encoder(write_bytes, field_value)
784 for tag_bytes, value_bytes in self._unknown_fields:
785 write_bytes(tag_bytes)
786 write_bytes(value_bytes)
787 cls._InternalSerialize = InternalSerialize
788
789
790 def _AddMergeFromStringMethod(message_descriptor, cls):
791 """Helper for _AddMessageMethods()."""
792 def MergeFromString(self, serialized):
793 length = len(serialized)
794 try:
795 if self._InternalParse(serialized, 0, length) != length:
796 # The only reason _InternalParse would return early is if it
797 # encountered an end-group tag.
798 raise message_mod.DecodeError('Unexpected end-group tag.')
799 except IndexError:
800 raise message_mod.DecodeError('Truncated message.')
801 except struct.error, e:
802 raise message_mod.DecodeError(e)
803 return length # Return this for legacy reasons.
804 cls.MergeFromString = MergeFromString
805
806 local_ReadTag = decoder.ReadTag
807 local_SkipField = decoder.SkipField
808 decoders_by_tag = cls._decoders_by_tag
809
810 def InternalParse(self, buffer, pos, end):
811 self._Modified()
812 field_dict = self._fields
813 unknown_field_list = self._unknown_fields
814 while pos != end:
815 (tag_bytes, new_pos) = local_ReadTag(buffer, pos)
816 field_decoder = decoders_by_tag.get(tag_bytes)
817 if field_decoder is None:
818 value_start_pos = new_pos
819 new_pos = local_SkipField(buffer, new_pos, end, tag_bytes)
820 if new_pos == -1:
821 return pos
822 if not unknown_field_list:
823 unknown_field_list = self._unknown_fields = []
824 unknown_field_list.append((tag_bytes, buffer[value_start_pos:new_pos]))
825 pos = new_pos
826 else:
827 pos = field_decoder(buffer, new_pos, end, self, field_dict)
828 return pos
829 cls._InternalParse = InternalParse
830
831
832 def _AddIsInitializedMethod(message_descriptor, cls):
833 """Adds the IsInitialized and FindInitializationError methods to the
834 protocol message class."""
835
836 required_fields = [field for field in message_descriptor.fields
837 if field.label == _FieldDescriptor.LABEL_REQUIRED]
838
839 def IsInitialized(self, errors=None):
840 """Checks if all required fields of a message are set.
841
842 Args:
843 errors: A list which, if provided, will be populated with the field
844 paths of all missing required fields.
845
846 Returns:
847 True iff the specified message has all required fields set.
848 """
849
850 # Performance is critical so we avoid HasField() and ListFields().
851
852 for field in required_fields:
853 if (field not in self._fields or
854 (field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE and
855 not self._fields[field]._is_present_in_parent)):
856 if errors is not None:
857 errors.extend(self.FindInitializationErrors())
858 return False
859
860 for field, value in self._fields.iteritems():
861 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
862 if field.label == _FieldDescriptor.LABEL_REPEATED:
863 for element in value:
864 if not element.IsInitialized():
865 if errors is not None:
866 errors.extend(self.FindInitializationErrors())
867 return False
868 elif value._is_present_in_parent and not value.IsInitialized():
869 if errors is not None:
870 errors.extend(self.FindInitializationErrors())
871 return False
872
873 return True
874
875 cls.IsInitialized = IsInitialized
876
877 def FindInitializationErrors(self):
878 """Finds required fields which are not initialized.
879
880 Returns:
881 A list of strings. Each string is a path to an uninitialized field from
882 the top-level message, e.g. "foo.bar[5].baz".
883 """
884
885 errors = [] # simplify things
886
887 for field in required_fields:
888 if not self.HasField(field.name):
889 errors.append(field.name)
890
891 for field, value in self.ListFields():
892 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
893 if field.is_extension:
894 name = "(%s)" % field.full_name
895 else:
896 name = field.name
897
898 if field.label == _FieldDescriptor.LABEL_REPEATED:
899 for i in xrange(len(value)):
900 element = value[i]
901 prefix = "%s[%d]." % (name, i)
902 sub_errors = element.FindInitializationErrors()
903 errors += [ prefix + error for error in sub_errors ]
904 else:
905 prefix = name + "."
906 sub_errors = value.FindInitializationErrors()
907 errors += [ prefix + error for error in sub_errors ]
908
909 return errors
910
911 cls.FindInitializationErrors = FindInitializationErrors
912
913
914 def _AddMergeFromMethod(cls):
915 LABEL_REPEATED = _FieldDescriptor.LABEL_REPEATED
916 CPPTYPE_MESSAGE = _FieldDescriptor.CPPTYPE_MESSAGE
917
918 def MergeFrom(self, msg):
919 if not isinstance(msg, cls):
920 raise TypeError(
921 "Parameter to MergeFrom() must be instance of same class: "
922 "expected %s got %s." % (cls.__name__, type(msg).__name__))
923
924 assert msg is not self
925 self._Modified()
926
927 fields = self._fields
928
929 for field, value in msg._fields.iteritems():
930 if field.label == LABEL_REPEATED:
931 field_value = fields.get(field)
932 if field_value is None:
933 # Construct a new object to represent this field.
934 field_value = field._default_constructor(self)
935 fields[field] = field_value
936 field_value.MergeFrom(value)
937 elif field.cpp_type == CPPTYPE_MESSAGE:
938 if value._is_present_in_parent:
939 field_value = fields.get(field)
940 if field_value is None:
941 # Construct a new object to represent this field.
942 field_value = field._default_constructor(self)
943 fields[field] = field_value
944 field_value.MergeFrom(value)
945 else:
946 self._fields[field] = value
947
948 if msg._unknown_fields:
949 if not self._unknown_fields:
950 self._unknown_fields = []
951 self._unknown_fields.extend(msg._unknown_fields)
952
953 cls.MergeFrom = MergeFrom
954
955
956 def _AddMessageMethods(message_descriptor, cls):
957 """Adds implementations of all Message methods to cls."""
958 _AddListFieldsMethod(message_descriptor, cls)
959 _AddHasFieldMethod(message_descriptor, cls)
960 _AddClearFieldMethod(message_descriptor, cls)
961 if message_descriptor.is_extendable:
962 _AddClearExtensionMethod(cls)
963 _AddHasExtensionMethod(cls)
964 _AddClearMethod(message_descriptor, cls)
965 _AddEqualsMethod(message_descriptor, cls)
966 _AddStrMethod(message_descriptor, cls)
967 _AddUnicodeMethod(message_descriptor, cls)
968 _AddSetListenerMethod(cls)
969 _AddByteSizeMethod(message_descriptor, cls)
970 _AddSerializeToStringMethod(message_descriptor, cls)
971 _AddSerializePartialToStringMethod(message_descriptor, cls)
972 _AddMergeFromStringMethod(message_descriptor, cls)
973 _AddIsInitializedMethod(message_descriptor, cls)
974 _AddMergeFromMethod(cls)
975
976
977 def _AddPrivateHelperMethods(cls):
978 """Adds implementation of private helper methods to cls."""
979
980 def Modified(self):
981 """Sets the _cached_byte_size_dirty bit to true,
982 and propagates this to our listener iff this was a state change.
983 """
984
985 # Note: Some callers check _cached_byte_size_dirty before calling
986 # _Modified() as an extra optimization. So, if this method is ever
987 # changed such that it does stuff even when _cached_byte_size_dirty is
988 # already true, the callers need to be updated.
989 if not self._cached_byte_size_dirty:
990 self._cached_byte_size_dirty = True
991 self._listener_for_children.dirty = True
992 self._is_present_in_parent = True
993 self._listener.Modified()
994
995 cls._Modified = Modified
996 cls.SetInParent = Modified
997
998
999 class _Listener(object):
1000
1001 """MessageListener implementation that a parent message registers with its
1002 child message.
1003
1004 In order to support semantics like:
1005
1006 foo.bar.baz.qux = 23
1007 assert foo.HasField('bar')
1008
1009 ...child objects must have back references to their parents.
1010 This helper class is at the heart of this support.
1011 """
1012
1013 def __init__(self, parent_message):
1014 """Args:
1015 parent_message: The message whose _Modified() method we should call when
1016 we receive Modified() messages.
1017 """
1018 # This listener establishes a back reference from a child (contained) object
1019 # to its parent (containing) object. We make this a weak reference to avoid
1020 # creating cyclic garbage when the client finishes with the 'parent' object
1021 # in the tree.
1022 if isinstance(parent_message, weakref.ProxyType):
1023 self._parent_message_weakref = parent_message
1024 else:
1025 self._parent_message_weakref = weakref.proxy(parent_message)
1026
1027 # As an optimization, we also indicate directly on the listener whether
1028 # or not the parent message is dirty. This way we can avoid traversing
1029 # up the tree in the common case.
1030 self.dirty = False
1031
1032 def Modified(self):
1033 if self.dirty:
1034 return
1035 try:
1036 # Propagate the signal to our parents iff this is the first field set.
1037 self._parent_message_weakref._Modified()
1038 except ReferenceError:
1039 # We can get here if a client has kept a reference to a child object,
1040 # and is now setting a field on it, but the child's parent has been
1041 # garbage-collected. This is not an error.
1042 pass
1043
1044
1045 # TODO(robinson): Move elsewhere? This file is getting pretty ridiculous...
1046 # TODO(robinson): Unify error handling of "unknown extension" crap.
1047 # TODO(robinson): Support iteritems()-style iteration over all
1048 # extensions with the "has" bits turned on?
1049 class _ExtensionDict(object):
1050
1051 """Dict-like container for supporting an indexable "Extensions"
1052 field on proto instances.
1053
1054 Note that in all cases we expect extension handles to be
1055 FieldDescriptors.
1056 """
1057
1058 def __init__(self, extended_message):
1059 """extended_message: Message instance for which we are the Extensions dict.
1060 """
1061
1062 self._extended_message = extended_message
1063
1064 def __getitem__(self, extension_handle):
1065 """Returns the current value of the given extension handle."""
1066
1067 _VerifyExtensionHandle(self._extended_message, extension_handle)
1068
1069 result = self._extended_message._fields.get(extension_handle)
1070 if result is not None:
1071 return result
1072
1073 if extension_handle.label == _FieldDescriptor.LABEL_REPEATED:
1074 result = extension_handle._default_constructor(self._extended_message)
1075 elif extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
1076 result = extension_handle.message_type._concrete_class()
1077 try:
1078 result._SetListener(self._extended_message._listener_for_children)
1079 except ReferenceError:
1080 pass
1081 else:
1082 # Singular scalar -- just return the default without inserting into the
1083 # dict.
1084 return extension_handle.default_value
1085
1086 # Atomically check if another thread has preempted us and, if not, swap
1087 # in the new object we just created. If someone has preempted us, we
1088 # take that object and discard ours.
1089 # WARNING: We are relying on setdefault() being atomic. This is true
1090 # in CPython but we haven't investigated others. This warning appears
1091 # in several other locations in this file.
1092 result = self._extended_message._fields.setdefault(
1093 extension_handle, result)
1094
1095 return result
1096
1097 def __eq__(self, other):
1098 if not isinstance(other, self.__class__):
1099 return False
1100
1101 my_fields = self._extended_message.ListFields()
1102 other_fields = other._extended_message.ListFields()
1103
1104 # Get rid of non-extension fields.
1105 my_fields = [ field for field in my_fields if field.is_extension ]
1106 other_fields = [ field for field in other_fields if field.is_extension ]
1107
1108 return my_fields == other_fields
1109
1110 def __ne__(self, other):
1111 return not self == other
1112
1113 def __hash__(self):
1114 raise TypeError('unhashable object')
1115
1116 # Note that this is only meaningful for non-repeated, scalar extension
1117 # fields. Note also that we may have to call _Modified() when we do
1118 # successfully set a field this way, to set any necssary "has" bits in the
1119 # ancestors of the extended message.
1120 def __setitem__(self, extension_handle, value):
1121 """If extension_handle specifies a non-repeated, scalar extension
1122 field, sets the value of that field.
1123 """
1124
1125 _VerifyExtensionHandle(self._extended_message, extension_handle)
1126
1127 if (extension_handle.label == _FieldDescriptor.LABEL_REPEATED or
1128 extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE):
1129 raise TypeError(
1130 'Cannot assign to extension "%s" because it is a repeated or '
1131 'composite type.' % extension_handle.full_name)
1132
1133 # It's slightly wasteful to lookup the type checker each time,
1134 # but we expect this to be a vanishingly uncommon case anyway.
1135 type_checker = type_checkers.GetTypeChecker(
1136 extension_handle.cpp_type, extension_handle.type)
1137 type_checker.CheckValue(value)
1138 self._extended_message._fields[extension_handle] = value
1139 self._extended_message._Modified()
1140
1141 def _FindExtensionByName(self, name):
1142 """Tries to find a known extension with the specified name.
1143
1144 Args:
1145 name: Extension full name.
1146
1147 Returns:
1148 Extension field descriptor.
1149 """
1150 return self._extended_message._extensions_by_name.get(name, None)
OLDNEW

Powered by Google App Engine
This is Rietveld 408576698