Chromium Code Reviews
chromiumcodereview-hr@appspot.gserviceaccount.com (chromiumcodereview-hr) | Please choose your nickname with Settings | Help | Chromium Project | Gerrit Changes | Sign out
(707)

Side by Side Diff: third_party/google/protobuf/internal/python_message.py

Issue 1153333003: Added tools to retrieve CQ builders from a CQ config (Closed) Base URL: https://chromium.googlesource.com/chromium/tools/depot_tools.git@master
Patch Set: Addressed comments Created 5 years, 6 months ago
Use n/p to move between diff chunks; N/P to move between comments. Draft comments are only viewable by you.
Jump to:
View unified diff | Download patch
OLDNEW
(Empty)
1 # Protocol Buffers - Google's data interchange format
2 # Copyright 2008 Google Inc. All rights reserved.
3 # http://code.google.com/p/protobuf/
4 #
5 # Redistribution and use in source and binary forms, with or without
6 # modification, are permitted provided that the following conditions are
7 # met:
8 #
9 # * Redistributions of source code must retain the above copyright
10 # notice, this list of conditions and the following disclaimer.
11 # * Redistributions in binary form must reproduce the above
12 # copyright notice, this list of conditions and the following disclaimer
13 # in the documentation and/or other materials provided with the
14 # distribution.
15 # * Neither the name of Google Inc. nor the names of its
16 # contributors may be used to endorse or promote products derived from
17 # this software without specific prior written permission.
18 #
19 # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20 # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21 # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22 # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23 # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24 # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25 # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26 # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27 # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
31 # Keep it Python2.5 compatible for GAE.
32 #
33 # Copyright 2007 Google Inc. All Rights Reserved.
34 #
35 # This code is meant to work on Python 2.4 and above only.
36 #
37 # TODO(robinson): Helpers for verbose, common checks like seeing if a
38 # descriptor's cpp_type is CPPTYPE_MESSAGE.
39
40 """Contains a metaclass and helper functions used to create
41 protocol message classes from Descriptor objects at runtime.
42
43 Recall that a metaclass is the "type" of a class.
44 (A class is to a metaclass what an instance is to a class.)
45
46 In this case, we use the GeneratedProtocolMessageType metaclass
47 to inject all the useful functionality into the classes
48 output by the protocol compiler at compile-time.
49
50 The upshot of all this is that the real implementation
51 details for ALL pure-Python protocol buffers are *here in
52 this file*.
53 """
54
55 __author__ = 'robinson@google.com (Will Robinson)'
56
57 import sys
58 if sys.version_info[0] < 3:
59 try:
60 from cStringIO import StringIO as BytesIO
61 except ImportError:
62 from StringIO import StringIO as BytesIO
63 import copy_reg as copyreg
64 else:
65 from io import BytesIO
66 import copyreg
67 import struct
68 import weakref
69
70 # We use "as" to avoid name collisions with variables.
71 from google.protobuf.internal import containers
72 from google.protobuf.internal import decoder
73 from google.protobuf.internal import encoder
74 from google.protobuf.internal import enum_type_wrapper
75 from google.protobuf.internal import message_listener as message_listener_mod
76 from google.protobuf.internal import type_checkers
77 from google.protobuf.internal import wire_format
78 from google.protobuf import descriptor as descriptor_mod
79 from google.protobuf import message as message_mod
80 from google.protobuf import text_format
81
82 _FieldDescriptor = descriptor_mod.FieldDescriptor
83
84
85 def NewMessage(bases, descriptor, dictionary):
86 _AddClassAttributesForNestedExtensions(descriptor, dictionary)
87 _AddSlots(descriptor, dictionary)
88 return bases
89
90
91 def InitMessage(descriptor, cls):
92 cls._decoders_by_tag = {}
93 cls._extensions_by_name = {}
94 cls._extensions_by_number = {}
95 if (descriptor.has_options and
96 descriptor.GetOptions().message_set_wire_format):
97 cls._decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = (
98 decoder.MessageSetItemDecoder(cls._extensions_by_number))
99
100 # Attach stuff to each FieldDescriptor for quick lookup later on.
101 for field in descriptor.fields:
102 _AttachFieldHelpers(cls, field)
103
104 _AddEnumValues(descriptor, cls)
105 _AddInitMethod(descriptor, cls)
106 _AddPropertiesForFields(descriptor, cls)
107 _AddPropertiesForExtensions(descriptor, cls)
108 _AddStaticMethods(cls)
109 _AddMessageMethods(descriptor, cls)
110 _AddPrivateHelperMethods(descriptor, cls)
111 copyreg.pickle(cls, lambda obj: (cls, (), obj.__getstate__()))
112
113
114 # Stateless helpers for GeneratedProtocolMessageType below.
115 # Outside clients should not access these directly.
116 #
117 # I opted not to make any of these methods on the metaclass, to make it more
118 # clear that I'm not really using any state there and to keep clients from
119 # thinking that they have direct access to these construction helpers.
120
121
122 def _PropertyName(proto_field_name):
123 """Returns the name of the public property attribute which
124 clients can use to get and (in some cases) set the value
125 of a protocol message field.
126
127 Args:
128 proto_field_name: The protocol message field name, exactly
129 as it appears (or would appear) in a .proto file.
130 """
131 # TODO(robinson): Escape Python keywords (e.g., yield), and test this support.
132 # nnorwitz makes my day by writing:
133 # """
134 # FYI. See the keyword module in the stdlib. This could be as simple as:
135 #
136 # if keyword.iskeyword(proto_field_name):
137 # return proto_field_name + "_"
138 # return proto_field_name
139 # """
140 # Kenton says: The above is a BAD IDEA. People rely on being able to use
141 # getattr() and setattr() to reflectively manipulate field values. If we
142 # rename the properties, then every such user has to also make sure to apply
143 # the same transformation. Note that currently if you name a field "yield",
144 # you can still access it just fine using getattr/setattr -- it's not even
145 # that cumbersome to do so.
146 # TODO(kenton): Remove this method entirely if/when everyone agrees with my
147 # position.
148 return proto_field_name
149
150
151 def _VerifyExtensionHandle(message, extension_handle):
152 """Verify that the given extension handle is valid."""
153
154 if not isinstance(extension_handle, _FieldDescriptor):
155 raise KeyError('HasExtension() expects an extension handle, got: %s' %
156 extension_handle)
157
158 if not extension_handle.is_extension:
159 raise KeyError('"%s" is not an extension.' % extension_handle.full_name)
160
161 if not extension_handle.containing_type:
162 raise KeyError('"%s" is missing a containing_type.'
163 % extension_handle.full_name)
164
165 if extension_handle.containing_type is not message.DESCRIPTOR:
166 raise KeyError('Extension "%s" extends message type "%s", but this '
167 'message is of type "%s".' %
168 (extension_handle.full_name,
169 extension_handle.containing_type.full_name,
170 message.DESCRIPTOR.full_name))
171
172
173 def _AddSlots(message_descriptor, dictionary):
174 """Adds a __slots__ entry to dictionary, containing the names of all valid
175 attributes for this message type.
176
177 Args:
178 message_descriptor: A Descriptor instance describing this message type.
179 dictionary: Class dictionary to which we'll add a '__slots__' entry.
180 """
181 dictionary['__slots__'] = ['_cached_byte_size',
182 '_cached_byte_size_dirty',
183 '_fields',
184 '_unknown_fields',
185 '_is_present_in_parent',
186 '_listener',
187 '_listener_for_children',
188 '__weakref__',
189 '_oneofs']
190
191
192 def _IsMessageSetExtension(field):
193 return (field.is_extension and
194 field.containing_type.has_options and
195 field.containing_type.GetOptions().message_set_wire_format and
196 field.type == _FieldDescriptor.TYPE_MESSAGE and
197 field.message_type == field.extension_scope and
198 field.label == _FieldDescriptor.LABEL_OPTIONAL)
199
200
201 def _AttachFieldHelpers(cls, field_descriptor):
202 is_repeated = (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED)
203 is_packed = (field_descriptor.has_options and
204 field_descriptor.GetOptions().packed)
205
206 if _IsMessageSetExtension(field_descriptor):
207 field_encoder = encoder.MessageSetItemEncoder(field_descriptor.number)
208 sizer = encoder.MessageSetItemSizer(field_descriptor.number)
209 else:
210 field_encoder = type_checkers.TYPE_TO_ENCODER[field_descriptor.type](
211 field_descriptor.number, is_repeated, is_packed)
212 sizer = type_checkers.TYPE_TO_SIZER[field_descriptor.type](
213 field_descriptor.number, is_repeated, is_packed)
214
215 field_descriptor._encoder = field_encoder
216 field_descriptor._sizer = sizer
217 field_descriptor._default_constructor = _DefaultValueConstructorForField(
218 field_descriptor)
219
220 def AddDecoder(wiretype, is_packed):
221 tag_bytes = encoder.TagBytes(field_descriptor.number, wiretype)
222 cls._decoders_by_tag[tag_bytes] = (
223 type_checkers.TYPE_TO_DECODER[field_descriptor.type](
224 field_descriptor.number, is_repeated, is_packed,
225 field_descriptor, field_descriptor._default_constructor))
226
227 AddDecoder(type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type],
228 False)
229
230 if is_repeated and wire_format.IsTypePackable(field_descriptor.type):
231 # To support wire compatibility of adding packed = true, add a decoder for
232 # packed values regardless of the field's options.
233 AddDecoder(wire_format.WIRETYPE_LENGTH_DELIMITED, True)
234
235
236 def _AddClassAttributesForNestedExtensions(descriptor, dictionary):
237 extension_dict = descriptor.extensions_by_name
238 for extension_name, extension_field in extension_dict.iteritems():
239 assert extension_name not in dictionary
240 dictionary[extension_name] = extension_field
241
242
243 def _AddEnumValues(descriptor, cls):
244 """Sets class-level attributes for all enum fields defined in this message.
245
246 Also exporting a class-level object that can name enum values.
247
248 Args:
249 descriptor: Descriptor object for this message type.
250 cls: Class we're constructing for this message type.
251 """
252 for enum_type in descriptor.enum_types:
253 setattr(cls, enum_type.name, enum_type_wrapper.EnumTypeWrapper(enum_type))
254 for enum_value in enum_type.values:
255 setattr(cls, enum_value.name, enum_value.number)
256
257
258 def _DefaultValueConstructorForField(field):
259 """Returns a function which returns a default value for a field.
260
261 Args:
262 field: FieldDescriptor object for this field.
263
264 The returned function has one argument:
265 message: Message instance containing this field, or a weakref proxy
266 of same.
267
268 That function in turn returns a default value for this field. The default
269 value may refer back to |message| via a weak reference.
270 """
271
272 if field.label == _FieldDescriptor.LABEL_REPEATED:
273 if field.has_default_value and field.default_value != []:
274 raise ValueError('Repeated field default value not empty list: %s' % (
275 field.default_value))
276 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
277 # We can't look at _concrete_class yet since it might not have
278 # been set. (Depends on order in which we initialize the classes).
279 message_type = field.message_type
280 def MakeRepeatedMessageDefault(message):
281 return containers.RepeatedCompositeFieldContainer(
282 message._listener_for_children, field.message_type)
283 return MakeRepeatedMessageDefault
284 else:
285 type_checker = type_checkers.GetTypeChecker(field)
286 def MakeRepeatedScalarDefault(message):
287 return containers.RepeatedScalarFieldContainer(
288 message._listener_for_children, type_checker)
289 return MakeRepeatedScalarDefault
290
291 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
292 # _concrete_class may not yet be initialized.
293 message_type = field.message_type
294 def MakeSubMessageDefault(message):
295 result = message_type._concrete_class()
296 result._SetListener(message._listener_for_children)
297 return result
298 return MakeSubMessageDefault
299
300 def MakeScalarDefault(message):
301 # TODO(protobuf-team): This may be broken since there may not be
302 # default_value. Combine with has_default_value somehow.
303 return field.default_value
304 return MakeScalarDefault
305
306
307 def _AddInitMethod(message_descriptor, cls):
308 """Adds an __init__ method to cls."""
309 fields = message_descriptor.fields
310 def init(self, **kwargs):
311 self._cached_byte_size = 0
312 self._cached_byte_size_dirty = len(kwargs) > 0
313 self._fields = {}
314 # Contains a mapping from oneof field descriptors to the descriptor
315 # of the currently set field in that oneof field.
316 self._oneofs = {}
317
318 # _unknown_fields is () when empty for efficiency, and will be turned into
319 # a list if fields are added.
320 self._unknown_fields = ()
321 self._is_present_in_parent = False
322 self._listener = message_listener_mod.NullMessageListener()
323 self._listener_for_children = _Listener(self)
324 for field_name, field_value in kwargs.iteritems():
325 field = _GetFieldByName(message_descriptor, field_name)
326 if field is None:
327 raise TypeError("%s() got an unexpected keyword argument '%s'" %
328 (message_descriptor.name, field_name))
329 if field.label == _FieldDescriptor.LABEL_REPEATED:
330 copy = field._default_constructor(self)
331 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: # Composite
332 for val in field_value:
333 copy.add().MergeFrom(val)
334 else: # Scalar
335 copy.extend(field_value)
336 self._fields[field] = copy
337 elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
338 copy = field._default_constructor(self)
339 copy.MergeFrom(field_value)
340 self._fields[field] = copy
341 else:
342 setattr(self, field_name, field_value)
343
344 init.__module__ = None
345 init.__doc__ = None
346 cls.__init__ = init
347
348
349 def _GetFieldByName(message_descriptor, field_name):
350 """Returns a field descriptor by field name.
351
352 Args:
353 message_descriptor: A Descriptor describing all fields in message.
354 field_name: The name of the field to retrieve.
355 Returns:
356 The field descriptor associated with the field name.
357 """
358 try:
359 return message_descriptor.fields_by_name[field_name]
360 except KeyError:
361 raise ValueError('Protocol message has no "%s" field.' % field_name)
362
363
364 def _AddPropertiesForFields(descriptor, cls):
365 """Adds properties for all fields in this protocol message type."""
366 for field in descriptor.fields:
367 _AddPropertiesForField(field, cls)
368
369 if descriptor.is_extendable:
370 # _ExtensionDict is just an adaptor with no state so we allocate a new one
371 # every time it is accessed.
372 cls.Extensions = property(lambda self: _ExtensionDict(self))
373
374
375 def _AddPropertiesForField(field, cls):
376 """Adds a public property for a protocol message field.
377 Clients can use this property to get and (in the case
378 of non-repeated scalar fields) directly set the value
379 of a protocol message field.
380
381 Args:
382 field: A FieldDescriptor for this field.
383 cls: The class we're constructing.
384 """
385 # Catch it if we add other types that we should
386 # handle specially here.
387 assert _FieldDescriptor.MAX_CPPTYPE == 10
388
389 constant_name = field.name.upper() + "_FIELD_NUMBER"
390 setattr(cls, constant_name, field.number)
391
392 if field.label == _FieldDescriptor.LABEL_REPEATED:
393 _AddPropertiesForRepeatedField(field, cls)
394 elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
395 _AddPropertiesForNonRepeatedCompositeField(field, cls)
396 else:
397 _AddPropertiesForNonRepeatedScalarField(field, cls)
398
399
400 def _AddPropertiesForRepeatedField(field, cls):
401 """Adds a public property for a "repeated" protocol message field. Clients
402 can use this property to get the value of the field, which will be either a
403 _RepeatedScalarFieldContainer or _RepeatedCompositeFieldContainer (see
404 below).
405
406 Note that when clients add values to these containers, we perform
407 type-checking in the case of repeated scalar fields, and we also set any
408 necessary "has" bits as a side-effect.
409
410 Args:
411 field: A FieldDescriptor for this field.
412 cls: The class we're constructing.
413 """
414 proto_field_name = field.name
415 property_name = _PropertyName(proto_field_name)
416
417 def getter(self):
418 field_value = self._fields.get(field)
419 if field_value is None:
420 # Construct a new object to represent this field.
421 field_value = field._default_constructor(self)
422
423 # Atomically check if another thread has preempted us and, if not, swap
424 # in the new object we just created. If someone has preempted us, we
425 # take that object and discard ours.
426 # WARNING: We are relying on setdefault() being atomic. This is true
427 # in CPython but we haven't investigated others. This warning appears
428 # in several other locations in this file.
429 field_value = self._fields.setdefault(field, field_value)
430 return field_value
431 getter.__module__ = None
432 getter.__doc__ = 'Getter for %s.' % proto_field_name
433
434 # We define a setter just so we can throw an exception with a more
435 # helpful error message.
436 def setter(self, new_value):
437 raise AttributeError('Assignment not allowed to repeated field '
438 '"%s" in protocol message object.' % proto_field_name)
439
440 doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
441 setattr(cls, property_name, property(getter, setter, doc=doc))
442
443
444 def _AddPropertiesForNonRepeatedScalarField(field, cls):
445 """Adds a public property for a nonrepeated, scalar protocol message field.
446 Clients can use this property to get and directly set the value of the field.
447 Note that when the client sets the value of a field by using this property,
448 all necessary "has" bits are set as a side-effect, and we also perform
449 type-checking.
450
451 Args:
452 field: A FieldDescriptor for this field.
453 cls: The class we're constructing.
454 """
455 proto_field_name = field.name
456 property_name = _PropertyName(proto_field_name)
457 type_checker = type_checkers.GetTypeChecker(field)
458 default_value = field.default_value
459 valid_values = set()
460
461 def getter(self):
462 # TODO(protobuf-team): This may be broken since there may not be
463 # default_value. Combine with has_default_value somehow.
464 return self._fields.get(field, default_value)
465 getter.__module__ = None
466 getter.__doc__ = 'Getter for %s.' % proto_field_name
467 def field_setter(self, new_value):
468 # pylint: disable=protected-access
469 self._fields[field] = type_checker.CheckValue(new_value)
470 # Check _cached_byte_size_dirty inline to improve performance, since scalar
471 # setters are called frequently.
472 if not self._cached_byte_size_dirty:
473 self._Modified()
474
475 if field.containing_oneof is not None:
476 def setter(self, new_value):
477 field_setter(self, new_value)
478 self._UpdateOneofState(field)
479 else:
480 setter = field_setter
481
482 setter.__module__ = None
483 setter.__doc__ = 'Setter for %s.' % proto_field_name
484
485 # Add a property to encapsulate the getter/setter.
486 doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
487 setattr(cls, property_name, property(getter, setter, doc=doc))
488
489
490 def _AddPropertiesForNonRepeatedCompositeField(field, cls):
491 """Adds a public property for a nonrepeated, composite protocol message field.
492 A composite field is a "group" or "message" field.
493
494 Clients can use this property to get the value of the field, but cannot
495 assign to the property directly.
496
497 Args:
498 field: A FieldDescriptor for this field.
499 cls: The class we're constructing.
500 """
501 # TODO(robinson): Remove duplication with similar method
502 # for non-repeated scalars.
503 proto_field_name = field.name
504 property_name = _PropertyName(proto_field_name)
505
506 # TODO(komarek): Can anyone explain to me why we cache the message_type this
507 # way, instead of referring to field.message_type inside of getter(self)?
508 # What if someone sets message_type later on (which makes for simpler
509 # dyanmic proto descriptor and class creation code).
510 message_type = field.message_type
511
512 def getter(self):
513 field_value = self._fields.get(field)
514 if field_value is None:
515 # Construct a new object to represent this field.
516 field_value = message_type._concrete_class() # use field.message_type?
517 field_value._SetListener(
518 _OneofListener(self, field)
519 if field.containing_oneof is not None
520 else self._listener_for_children)
521
522 # Atomically check if another thread has preempted us and, if not, swap
523 # in the new object we just created. If someone has preempted us, we
524 # take that object and discard ours.
525 # WARNING: We are relying on setdefault() being atomic. This is true
526 # in CPython but we haven't investigated others. This warning appears
527 # in several other locations in this file.
528 field_value = self._fields.setdefault(field, field_value)
529 return field_value
530 getter.__module__ = None
531 getter.__doc__ = 'Getter for %s.' % proto_field_name
532
533 # We define a setter just so we can throw an exception with a more
534 # helpful error message.
535 def setter(self, new_value):
536 raise AttributeError('Assignment not allowed to composite field '
537 '"%s" in protocol message object.' % proto_field_name)
538
539 # Add a property to encapsulate the getter.
540 doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
541 setattr(cls, property_name, property(getter, setter, doc=doc))
542
543
544 def _AddPropertiesForExtensions(descriptor, cls):
545 """Adds properties for all fields in this protocol message type."""
546 extension_dict = descriptor.extensions_by_name
547 for extension_name, extension_field in extension_dict.iteritems():
548 constant_name = extension_name.upper() + "_FIELD_NUMBER"
549 setattr(cls, constant_name, extension_field.number)
550
551
552 def _AddStaticMethods(cls):
553 # TODO(robinson): This probably needs to be thread-safe(?)
554 def RegisterExtension(extension_handle):
555 extension_handle.containing_type = cls.DESCRIPTOR
556 _AttachFieldHelpers(cls, extension_handle)
557
558 # Try to insert our extension, failing if an extension with the same number
559 # already exists.
560 actual_handle = cls._extensions_by_number.setdefault(
561 extension_handle.number, extension_handle)
562 if actual_handle is not extension_handle:
563 raise AssertionError(
564 'Extensions "%s" and "%s" both try to extend message type "%s" with '
565 'field number %d.' %
566 (extension_handle.full_name, actual_handle.full_name,
567 cls.DESCRIPTOR.full_name, extension_handle.number))
568
569 cls._extensions_by_name[extension_handle.full_name] = extension_handle
570
571 handle = extension_handle # avoid line wrapping
572 if _IsMessageSetExtension(handle):
573 # MessageSet extension. Also register under type name.
574 cls._extensions_by_name[
575 extension_handle.message_type.full_name] = extension_handle
576
577 cls.RegisterExtension = staticmethod(RegisterExtension)
578
579 def FromString(s):
580 message = cls()
581 message.MergeFromString(s)
582 return message
583 cls.FromString = staticmethod(FromString)
584
585
586 def _IsPresent(item):
587 """Given a (FieldDescriptor, value) tuple from _fields, return true if the
588 value should be included in the list returned by ListFields()."""
589
590 if item[0].label == _FieldDescriptor.LABEL_REPEATED:
591 return bool(item[1])
592 elif item[0].cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
593 return item[1]._is_present_in_parent
594 else:
595 return True
596
597
598 def _AddListFieldsMethod(message_descriptor, cls):
599 """Helper for _AddMessageMethods()."""
600
601 def ListFields(self):
602 all_fields = [item for item in self._fields.iteritems() if _IsPresent(item)]
603 all_fields.sort(key = lambda item: item[0].number)
604 return all_fields
605
606 cls.ListFields = ListFields
607
608
609 def _AddHasFieldMethod(message_descriptor, cls):
610 """Helper for _AddMessageMethods()."""
611
612 singular_fields = {}
613 for field in message_descriptor.fields:
614 if field.label != _FieldDescriptor.LABEL_REPEATED:
615 singular_fields[field.name] = field
616 # Fields inside oneofs are never repeated (enforced by the compiler).
617 for field in message_descriptor.oneofs:
618 singular_fields[field.name] = field
619
620 def HasField(self, field_name):
621 try:
622 field = singular_fields[field_name]
623 except KeyError:
624 raise ValueError(
625 'Protocol message has no singular "%s" field.' % field_name)
626
627 if isinstance(field, descriptor_mod.OneofDescriptor):
628 try:
629 return HasField(self, self._oneofs[field].name)
630 except KeyError:
631 return False
632 else:
633 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
634 value = self._fields.get(field)
635 return value is not None and value._is_present_in_parent
636 else:
637 return field in self._fields
638
639 cls.HasField = HasField
640
641
642 def _AddClearFieldMethod(message_descriptor, cls):
643 """Helper for _AddMessageMethods()."""
644 def ClearField(self, field_name):
645 try:
646 field = message_descriptor.fields_by_name[field_name]
647 except KeyError:
648 try:
649 field = message_descriptor.oneofs_by_name[field_name]
650 if field in self._oneofs:
651 field = self._oneofs[field]
652 else:
653 return
654 except KeyError:
655 raise ValueError('Protocol message has no "%s" field.' % field_name)
656
657 if field in self._fields:
658 # Note: If the field is a sub-message, its listener will still point
659 # at us. That's fine, because the worst than can happen is that it
660 # will call _Modified() and invalidate our byte size. Big deal.
661 del self._fields[field]
662
663 if self._oneofs.get(field.containing_oneof, None) is field:
664 del self._oneofs[field.containing_oneof]
665
666 # Always call _Modified() -- even if nothing was changed, this is
667 # a mutating method, and thus calling it should cause the field to become
668 # present in the parent message.
669 self._Modified()
670
671 cls.ClearField = ClearField
672
673
674 def _AddClearExtensionMethod(cls):
675 """Helper for _AddMessageMethods()."""
676 def ClearExtension(self, extension_handle):
677 _VerifyExtensionHandle(self, extension_handle)
678
679 # Similar to ClearField(), above.
680 if extension_handle in self._fields:
681 del self._fields[extension_handle]
682 self._Modified()
683 cls.ClearExtension = ClearExtension
684
685
686 def _AddClearMethod(message_descriptor, cls):
687 """Helper for _AddMessageMethods()."""
688 def Clear(self):
689 # Clear fields.
690 self._fields = {}
691 self._unknown_fields = ()
692 self._Modified()
693 cls.Clear = Clear
694
695
696 def _AddHasExtensionMethod(cls):
697 """Helper for _AddMessageMethods()."""
698 def HasExtension(self, extension_handle):
699 _VerifyExtensionHandle(self, extension_handle)
700 if extension_handle.label == _FieldDescriptor.LABEL_REPEATED:
701 raise KeyError('"%s" is repeated.' % extension_handle.full_name)
702
703 if extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
704 value = self._fields.get(extension_handle)
705 return value is not None and value._is_present_in_parent
706 else:
707 return extension_handle in self._fields
708 cls.HasExtension = HasExtension
709
710
711 def _AddEqualsMethod(message_descriptor, cls):
712 """Helper for _AddMessageMethods()."""
713 def __eq__(self, other):
714 if (not isinstance(other, message_mod.Message) or
715 other.DESCRIPTOR != self.DESCRIPTOR):
716 return False
717
718 if self is other:
719 return True
720
721 if not self.ListFields() == other.ListFields():
722 return False
723
724 # Sort unknown fields because their order shouldn't affect equality test.
725 unknown_fields = list(self._unknown_fields)
726 unknown_fields.sort()
727 other_unknown_fields = list(other._unknown_fields)
728 other_unknown_fields.sort()
729
730 return unknown_fields == other_unknown_fields
731
732 cls.__eq__ = __eq__
733
734
735 def _AddStrMethod(message_descriptor, cls):
736 """Helper for _AddMessageMethods()."""
737 def __str__(self):
738 return text_format.MessageToString(self)
739 cls.__str__ = __str__
740
741
742 def _AddUnicodeMethod(unused_message_descriptor, cls):
743 """Helper for _AddMessageMethods()."""
744
745 def __unicode__(self):
746 return text_format.MessageToString(self, as_utf8=True).decode('utf-8')
747 cls.__unicode__ = __unicode__
748
749
750 def _AddSetListenerMethod(cls):
751 """Helper for _AddMessageMethods()."""
752 def SetListener(self, listener):
753 if listener is None:
754 self._listener = message_listener_mod.NullMessageListener()
755 else:
756 self._listener = listener
757 cls._SetListener = SetListener
758
759
760 def _BytesForNonRepeatedElement(value, field_number, field_type):
761 """Returns the number of bytes needed to serialize a non-repeated element.
762 The returned byte count includes space for tag information and any
763 other additional space associated with serializing value.
764
765 Args:
766 value: Value we're serializing.
767 field_number: Field number of this value. (Since the field number
768 is stored as part of a varint-encoded tag, this has an impact
769 on the total bytes required to serialize the value).
770 field_type: The type of the field. One of the TYPE_* constants
771 within FieldDescriptor.
772 """
773 try:
774 fn = type_checkers.TYPE_TO_BYTE_SIZE_FN[field_type]
775 return fn(field_number, value)
776 except KeyError:
777 raise message_mod.EncodeError('Unrecognized field type: %d' % field_type)
778
779
780 def _AddByteSizeMethod(message_descriptor, cls):
781 """Helper for _AddMessageMethods()."""
782
783 def ByteSize(self):
784 if not self._cached_byte_size_dirty:
785 return self._cached_byte_size
786
787 size = 0
788 for field_descriptor, field_value in self.ListFields():
789 size += field_descriptor._sizer(field_value)
790
791 for tag_bytes, value_bytes in self._unknown_fields:
792 size += len(tag_bytes) + len(value_bytes)
793
794 self._cached_byte_size = size
795 self._cached_byte_size_dirty = False
796 self._listener_for_children.dirty = False
797 return size
798
799 cls.ByteSize = ByteSize
800
801
802 def _AddSerializeToStringMethod(message_descriptor, cls):
803 """Helper for _AddMessageMethods()."""
804
805 def SerializeToString(self):
806 # Check if the message has all of its required fields set.
807 errors = []
808 if not self.IsInitialized():
809 raise message_mod.EncodeError(
810 'Message %s is missing required fields: %s' % (
811 self.DESCRIPTOR.full_name, ','.join(self.FindInitializationErrors())))
812 return self.SerializePartialToString()
813 cls.SerializeToString = SerializeToString
814
815
816 def _AddSerializePartialToStringMethod(message_descriptor, cls):
817 """Helper for _AddMessageMethods()."""
818
819 def SerializePartialToString(self):
820 out = BytesIO()
821 self._InternalSerialize(out.write)
822 return out.getvalue()
823 cls.SerializePartialToString = SerializePartialToString
824
825 def InternalSerialize(self, write_bytes):
826 for field_descriptor, field_value in self.ListFields():
827 field_descriptor._encoder(write_bytes, field_value)
828 for tag_bytes, value_bytes in self._unknown_fields:
829 write_bytes(tag_bytes)
830 write_bytes(value_bytes)
831 cls._InternalSerialize = InternalSerialize
832
833
834 def _AddMergeFromStringMethod(message_descriptor, cls):
835 """Helper for _AddMessageMethods()."""
836 def MergeFromString(self, serialized):
837 length = len(serialized)
838 try:
839 if self._InternalParse(serialized, 0, length) != length:
840 # The only reason _InternalParse would return early is if it
841 # encountered an end-group tag.
842 raise message_mod.DecodeError('Unexpected end-group tag.')
843 except (IndexError, TypeError):
844 # Now ord(buf[p:p+1]) == ord('') gets TypeError.
845 raise message_mod.DecodeError('Truncated message.')
846 except struct.error, e:
847 raise message_mod.DecodeError(e)
848 return length # Return this for legacy reasons.
849 cls.MergeFromString = MergeFromString
850
851 local_ReadTag = decoder.ReadTag
852 local_SkipField = decoder.SkipField
853 decoders_by_tag = cls._decoders_by_tag
854
855 def InternalParse(self, buffer, pos, end):
856 self._Modified()
857 field_dict = self._fields
858 unknown_field_list = self._unknown_fields
859 while pos != end:
860 (tag_bytes, new_pos) = local_ReadTag(buffer, pos)
861 field_decoder = decoders_by_tag.get(tag_bytes)
862 if field_decoder is None:
863 value_start_pos = new_pos
864 new_pos = local_SkipField(buffer, new_pos, end, tag_bytes)
865 if new_pos == -1:
866 return pos
867 if not unknown_field_list:
868 unknown_field_list = self._unknown_fields = []
869 unknown_field_list.append((tag_bytes, buffer[value_start_pos:new_pos]))
870 pos = new_pos
871 else:
872 pos = field_decoder(buffer, new_pos, end, self, field_dict)
873 return pos
874 cls._InternalParse = InternalParse
875
876
877 def _AddIsInitializedMethod(message_descriptor, cls):
878 """Adds the IsInitialized and FindInitializationError methods to the
879 protocol message class."""
880
881 required_fields = [field for field in message_descriptor.fields
882 if field.label == _FieldDescriptor.LABEL_REQUIRED]
883
884 def IsInitialized(self, errors=None):
885 """Checks if all required fields of a message are set.
886
887 Args:
888 errors: A list which, if provided, will be populated with the field
889 paths of all missing required fields.
890
891 Returns:
892 True iff the specified message has all required fields set.
893 """
894
895 # Performance is critical so we avoid HasField() and ListFields().
896
897 for field in required_fields:
898 if (field not in self._fields or
899 (field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE and
900 not self._fields[field]._is_present_in_parent)):
901 if errors is not None:
902 errors.extend(self.FindInitializationErrors())
903 return False
904
905 for field, value in list(self._fields.items()): # dict can change size!
906 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
907 if field.label == _FieldDescriptor.LABEL_REPEATED:
908 for element in value:
909 if not element.IsInitialized():
910 if errors is not None:
911 errors.extend(self.FindInitializationErrors())
912 return False
913 elif value._is_present_in_parent and not value.IsInitialized():
914 if errors is not None:
915 errors.extend(self.FindInitializationErrors())
916 return False
917
918 return True
919
920 cls.IsInitialized = IsInitialized
921
922 def FindInitializationErrors(self):
923 """Finds required fields which are not initialized.
924
925 Returns:
926 A list of strings. Each string is a path to an uninitialized field from
927 the top-level message, e.g. "foo.bar[5].baz".
928 """
929
930 errors = [] # simplify things
931
932 for field in required_fields:
933 if not self.HasField(field.name):
934 errors.append(field.name)
935
936 for field, value in self.ListFields():
937 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
938 if field.is_extension:
939 name = "(%s)" % field.full_name
940 else:
941 name = field.name
942
943 if field.label == _FieldDescriptor.LABEL_REPEATED:
944 for i in xrange(len(value)):
945 element = value[i]
946 prefix = "%s[%d]." % (name, i)
947 sub_errors = element.FindInitializationErrors()
948 errors += [ prefix + error for error in sub_errors ]
949 else:
950 prefix = name + "."
951 sub_errors = value.FindInitializationErrors()
952 errors += [ prefix + error for error in sub_errors ]
953
954 return errors
955
956 cls.FindInitializationErrors = FindInitializationErrors
957
958
959 def _AddMergeFromMethod(cls):
960 LABEL_REPEATED = _FieldDescriptor.LABEL_REPEATED
961 CPPTYPE_MESSAGE = _FieldDescriptor.CPPTYPE_MESSAGE
962
963 def MergeFrom(self, msg):
964 if not isinstance(msg, cls):
965 raise TypeError(
966 "Parameter to MergeFrom() must be instance of same class: "
967 "expected %s got %s." % (cls.__name__, type(msg).__name__))
968
969 assert msg is not self
970 self._Modified()
971
972 fields = self._fields
973
974 for field, value in msg._fields.iteritems():
975 if field.label == LABEL_REPEATED:
976 field_value = fields.get(field)
977 if field_value is None:
978 # Construct a new object to represent this field.
979 field_value = field._default_constructor(self)
980 fields[field] = field_value
981 field_value.MergeFrom(value)
982 elif field.cpp_type == CPPTYPE_MESSAGE:
983 if value._is_present_in_parent:
984 field_value = fields.get(field)
985 if field_value is None:
986 # Construct a new object to represent this field.
987 field_value = field._default_constructor(self)
988 fields[field] = field_value
989 field_value.MergeFrom(value)
990 else:
991 self._fields[field] = value
992
993 if msg._unknown_fields:
994 if not self._unknown_fields:
995 self._unknown_fields = []
996 self._unknown_fields.extend(msg._unknown_fields)
997
998 cls.MergeFrom = MergeFrom
999
1000
1001 def _AddWhichOneofMethod(message_descriptor, cls):
1002 def WhichOneof(self, oneof_name):
1003 """Returns the name of the currently set field inside a oneof, or None."""
1004 try:
1005 field = message_descriptor.oneofs_by_name[oneof_name]
1006 except KeyError:
1007 raise ValueError(
1008 'Protocol message has no oneof "%s" field.' % oneof_name)
1009
1010 nested_field = self._oneofs.get(field, None)
1011 if nested_field is not None and self.HasField(nested_field.name):
1012 return nested_field.name
1013 else:
1014 return None
1015
1016 cls.WhichOneof = WhichOneof
1017
1018
1019 def _AddMessageMethods(message_descriptor, cls):
1020 """Adds implementations of all Message methods to cls."""
1021 _AddListFieldsMethod(message_descriptor, cls)
1022 _AddHasFieldMethod(message_descriptor, cls)
1023 _AddClearFieldMethod(message_descriptor, cls)
1024 if message_descriptor.is_extendable:
1025 _AddClearExtensionMethod(cls)
1026 _AddHasExtensionMethod(cls)
1027 _AddClearMethod(message_descriptor, cls)
1028 _AddEqualsMethod(message_descriptor, cls)
1029 _AddStrMethod(message_descriptor, cls)
1030 _AddUnicodeMethod(message_descriptor, cls)
1031 _AddSetListenerMethod(cls)
1032 _AddByteSizeMethod(message_descriptor, cls)
1033 _AddSerializeToStringMethod(message_descriptor, cls)
1034 _AddSerializePartialToStringMethod(message_descriptor, cls)
1035 _AddMergeFromStringMethod(message_descriptor, cls)
1036 _AddIsInitializedMethod(message_descriptor, cls)
1037 _AddMergeFromMethod(cls)
1038 _AddWhichOneofMethod(message_descriptor, cls)
1039
1040 def _AddPrivateHelperMethods(message_descriptor, cls):
1041 """Adds implementation of private helper methods to cls."""
1042
1043 def Modified(self):
1044 """Sets the _cached_byte_size_dirty bit to true,
1045 and propagates this to our listener iff this was a state change.
1046 """
1047
1048 # Note: Some callers check _cached_byte_size_dirty before calling
1049 # _Modified() as an extra optimization. So, if this method is ever
1050 # changed such that it does stuff even when _cached_byte_size_dirty is
1051 # already true, the callers need to be updated.
1052 if not self._cached_byte_size_dirty:
1053 self._cached_byte_size_dirty = True
1054 self._listener_for_children.dirty = True
1055 self._is_present_in_parent = True
1056 self._listener.Modified()
1057
1058 def _UpdateOneofState(self, field):
1059 """Sets field as the active field in its containing oneof.
1060
1061 Will also delete currently active field in the oneof, if it is different
1062 from the argument. Does not mark the message as modified.
1063 """
1064 other_field = self._oneofs.setdefault(field.containing_oneof, field)
1065 if other_field is not field:
1066 del self._fields[other_field]
1067 self._oneofs[field.containing_oneof] = field
1068
1069 cls._Modified = Modified
1070 cls.SetInParent = Modified
1071 cls._UpdateOneofState = _UpdateOneofState
1072
1073
1074 class _Listener(object):
1075
1076 """MessageListener implementation that a parent message registers with its
1077 child message.
1078
1079 In order to support semantics like:
1080
1081 foo.bar.baz.qux = 23
1082 assert foo.HasField('bar')
1083
1084 ...child objects must have back references to their parents.
1085 This helper class is at the heart of this support.
1086 """
1087
1088 def __init__(self, parent_message):
1089 """Args:
1090 parent_message: The message whose _Modified() method we should call when
1091 we receive Modified() messages.
1092 """
1093 # This listener establishes a back reference from a child (contained) object
1094 # to its parent (containing) object. We make this a weak reference to avoid
1095 # creating cyclic garbage when the client finishes with the 'parent' object
1096 # in the tree.
1097 if isinstance(parent_message, weakref.ProxyType):
1098 self._parent_message_weakref = parent_message
1099 else:
1100 self._parent_message_weakref = weakref.proxy(parent_message)
1101
1102 # As an optimization, we also indicate directly on the listener whether
1103 # or not the parent message is dirty. This way we can avoid traversing
1104 # up the tree in the common case.
1105 self.dirty = False
1106
1107 def Modified(self):
1108 if self.dirty:
1109 return
1110 try:
1111 # Propagate the signal to our parents iff this is the first field set.
1112 self._parent_message_weakref._Modified()
1113 except ReferenceError:
1114 # We can get here if a client has kept a reference to a child object,
1115 # and is now setting a field on it, but the child's parent has been
1116 # garbage-collected. This is not an error.
1117 pass
1118
1119
1120 class _OneofListener(_Listener):
1121 """Special listener implementation for setting composite oneof fields."""
1122
1123 def __init__(self, parent_message, field):
1124 """Args:
1125 parent_message: The message whose _Modified() method we should call when
1126 we receive Modified() messages.
1127 field: The descriptor of the field being set in the parent message.
1128 """
1129 super(_OneofListener, self).__init__(parent_message)
1130 self._field = field
1131
1132 def Modified(self):
1133 """Also updates the state of the containing oneof in the parent message."""
1134 try:
1135 self._parent_message_weakref._UpdateOneofState(self._field)
1136 super(_OneofListener, self).Modified()
1137 except ReferenceError:
1138 pass
1139
1140
1141 # TODO(robinson): Move elsewhere? This file is getting pretty ridiculous...
1142 # TODO(robinson): Unify error handling of "unknown extension" crap.
1143 # TODO(robinson): Support iteritems()-style iteration over all
1144 # extensions with the "has" bits turned on?
1145 class _ExtensionDict(object):
1146
1147 """Dict-like container for supporting an indexable "Extensions"
1148 field on proto instances.
1149
1150 Note that in all cases we expect extension handles to be
1151 FieldDescriptors.
1152 """
1153
1154 def __init__(self, extended_message):
1155 """extended_message: Message instance for which we are the Extensions dict.
1156 """
1157
1158 self._extended_message = extended_message
1159
1160 def __getitem__(self, extension_handle):
1161 """Returns the current value of the given extension handle."""
1162
1163 _VerifyExtensionHandle(self._extended_message, extension_handle)
1164
1165 result = self._extended_message._fields.get(extension_handle)
1166 if result is not None:
1167 return result
1168
1169 if extension_handle.label == _FieldDescriptor.LABEL_REPEATED:
1170 result = extension_handle._default_constructor(self._extended_message)
1171 elif extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
1172 result = extension_handle.message_type._concrete_class()
1173 try:
1174 result._SetListener(self._extended_message._listener_for_children)
1175 except ReferenceError:
1176 pass
1177 else:
1178 # Singular scalar -- just return the default without inserting into the
1179 # dict.
1180 return extension_handle.default_value
1181
1182 # Atomically check if another thread has preempted us and, if not, swap
1183 # in the new object we just created. If someone has preempted us, we
1184 # take that object and discard ours.
1185 # WARNING: We are relying on setdefault() being atomic. This is true
1186 # in CPython but we haven't investigated others. This warning appears
1187 # in several other locations in this file.
1188 result = self._extended_message._fields.setdefault(
1189 extension_handle, result)
1190
1191 return result
1192
1193 def __eq__(self, other):
1194 if not isinstance(other, self.__class__):
1195 return False
1196
1197 my_fields = self._extended_message.ListFields()
1198 other_fields = other._extended_message.ListFields()
1199
1200 # Get rid of non-extension fields.
1201 my_fields = [ field for field in my_fields if field.is_extension ]
1202 other_fields = [ field for field in other_fields if field.is_extension ]
1203
1204 return my_fields == other_fields
1205
1206 def __ne__(self, other):
1207 return not self == other
1208
1209 def __hash__(self):
1210 raise TypeError('unhashable object')
1211
1212 # Note that this is only meaningful for non-repeated, scalar extension
1213 # fields. Note also that we may have to call _Modified() when we do
1214 # successfully set a field this way, to set any necssary "has" bits in the
1215 # ancestors of the extended message.
1216 def __setitem__(self, extension_handle, value):
1217 """If extension_handle specifies a non-repeated, scalar extension
1218 field, sets the value of that field.
1219 """
1220
1221 _VerifyExtensionHandle(self._extended_message, extension_handle)
1222
1223 if (extension_handle.label == _FieldDescriptor.LABEL_REPEATED or
1224 extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE):
1225 raise TypeError(
1226 'Cannot assign to extension "%s" because it is a repeated or '
1227 'composite type.' % extension_handle.full_name)
1228
1229 # It's slightly wasteful to lookup the type checker each time,
1230 # but we expect this to be a vanishingly uncommon case anyway.
1231 type_checker = type_checkers.GetTypeChecker(
1232 extension_handle)
1233 # pylint: disable=protected-access
1234 self._extended_message._fields[extension_handle] = (
1235 type_checker.CheckValue(value))
1236 self._extended_message._Modified()
1237
1238 def _FindExtensionByName(self, name):
1239 """Tries to find a known extension with the specified name.
1240
1241 Args:
1242 name: Extension full name.
1243
1244 Returns:
1245 Extension field descriptor.
1246 """
1247 return self._extended_message._extensions_by_name.get(name, None)
OLDNEW
« no previous file with comments | « third_party/google/protobuf/internal/message_listener.py ('k') | third_party/google/protobuf/internal/type_checkers.py » ('j') | no next file with comments »

Powered by Google App Engine
This is Rietveld 408576698