OLD | NEW |
| (Empty) |
1 # Protocol Buffers - Google's data interchange format | |
2 # Copyright 2008 Google Inc. All rights reserved. | |
3 # http://code.google.com/p/protobuf/ | |
4 # | |
5 # Redistribution and use in source and binary forms, with or without | |
6 # modification, are permitted provided that the following conditions are | |
7 # met: | |
8 # | |
9 # * Redistributions of source code must retain the above copyright | |
10 # notice, this list of conditions and the following disclaimer. | |
11 # * Redistributions in binary form must reproduce the above | |
12 # copyright notice, this list of conditions and the following disclaimer | |
13 # in the documentation and/or other materials provided with the | |
14 # distribution. | |
15 # * Neither the name of Google Inc. nor the names of its | |
16 # contributors may be used to endorse or promote products derived from | |
17 # this software without specific prior written permission. | |
18 # | |
19 # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS | |
20 # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT | |
21 # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR | |
22 # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT | |
23 # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, | |
24 # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT | |
25 # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, | |
26 # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY | |
27 # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT | |
28 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | |
29 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |
30 | |
31 # Keep it Python2.5 compatible for GAE. | |
32 # | |
33 # Copyright 2007 Google Inc. All Rights Reserved. | |
34 # | |
35 # This code is meant to work on Python 2.4 and above only. | |
36 # | |
37 # TODO(robinson): Helpers for verbose, common checks like seeing if a | |
38 # descriptor's cpp_type is CPPTYPE_MESSAGE. | |
39 | |
40 """Contains a metaclass and helper functions used to create | |
41 protocol message classes from Descriptor objects at runtime. | |
42 | |
43 Recall that a metaclass is the "type" of a class. | |
44 (A class is to a metaclass what an instance is to a class.) | |
45 | |
46 In this case, we use the GeneratedProtocolMessageType metaclass | |
47 to inject all the useful functionality into the classes | |
48 output by the protocol compiler at compile-time. | |
49 | |
50 The upshot of all this is that the real implementation | |
51 details for ALL pure-Python protocol buffers are *here in | |
52 this file*. | |
53 """ | |
54 | |
55 __author__ = 'robinson@google.com (Will Robinson)' | |
56 | |
57 import sys | |
58 if sys.version_info[0] < 3: | |
59 try: | |
60 from cStringIO import StringIO as BytesIO | |
61 except ImportError: | |
62 from StringIO import StringIO as BytesIO | |
63 import copy_reg as copyreg | |
64 else: | |
65 from io import BytesIO | |
66 import copyreg | |
67 import struct | |
68 import weakref | |
69 | |
70 # We use "as" to avoid name collisions with variables. | |
71 from google.protobuf.internal import containers | |
72 from google.protobuf.internal import decoder | |
73 from google.protobuf.internal import encoder | |
74 from google.protobuf.internal import enum_type_wrapper | |
75 from google.protobuf.internal import message_listener as message_listener_mod | |
76 from google.protobuf.internal import type_checkers | |
77 from google.protobuf.internal import wire_format | |
78 from google.protobuf import descriptor as descriptor_mod | |
79 from google.protobuf import message as message_mod | |
80 from google.protobuf import text_format | |
81 | |
82 _FieldDescriptor = descriptor_mod.FieldDescriptor | |
83 | |
84 | |
85 def NewMessage(bases, descriptor, dictionary): | |
86 _AddClassAttributesForNestedExtensions(descriptor, dictionary) | |
87 _AddSlots(descriptor, dictionary) | |
88 return bases | |
89 | |
90 | |
91 def InitMessage(descriptor, cls): | |
92 cls._decoders_by_tag = {} | |
93 cls._extensions_by_name = {} | |
94 cls._extensions_by_number = {} | |
95 if (descriptor.has_options and | |
96 descriptor.GetOptions().message_set_wire_format): | |
97 cls._decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = ( | |
98 decoder.MessageSetItemDecoder(cls._extensions_by_number)) | |
99 | |
100 # Attach stuff to each FieldDescriptor for quick lookup later on. | |
101 for field in descriptor.fields: | |
102 _AttachFieldHelpers(cls, field) | |
103 | |
104 _AddEnumValues(descriptor, cls) | |
105 _AddInitMethod(descriptor, cls) | |
106 _AddPropertiesForFields(descriptor, cls) | |
107 _AddPropertiesForExtensions(descriptor, cls) | |
108 _AddStaticMethods(cls) | |
109 _AddMessageMethods(descriptor, cls) | |
110 _AddPrivateHelperMethods(descriptor, cls) | |
111 copyreg.pickle(cls, lambda obj: (cls, (), obj.__getstate__())) | |
112 | |
113 | |
114 # Stateless helpers for GeneratedProtocolMessageType below. | |
115 # Outside clients should not access these directly. | |
116 # | |
117 # I opted not to make any of these methods on the metaclass, to make it more | |
118 # clear that I'm not really using any state there and to keep clients from | |
119 # thinking that they have direct access to these construction helpers. | |
120 | |
121 | |
122 def _PropertyName(proto_field_name): | |
123 """Returns the name of the public property attribute which | |
124 clients can use to get and (in some cases) set the value | |
125 of a protocol message field. | |
126 | |
127 Args: | |
128 proto_field_name: The protocol message field name, exactly | |
129 as it appears (or would appear) in a .proto file. | |
130 """ | |
131 # TODO(robinson): Escape Python keywords (e.g., yield), and test this support. | |
132 # nnorwitz makes my day by writing: | |
133 # """ | |
134 # FYI. See the keyword module in the stdlib. This could be as simple as: | |
135 # | |
136 # if keyword.iskeyword(proto_field_name): | |
137 # return proto_field_name + "_" | |
138 # return proto_field_name | |
139 # """ | |
140 # Kenton says: The above is a BAD IDEA. People rely on being able to use | |
141 # getattr() and setattr() to reflectively manipulate field values. If we | |
142 # rename the properties, then every such user has to also make sure to apply | |
143 # the same transformation. Note that currently if you name a field "yield", | |
144 # you can still access it just fine using getattr/setattr -- it's not even | |
145 # that cumbersome to do so. | |
146 # TODO(kenton): Remove this method entirely if/when everyone agrees with my | |
147 # position. | |
148 return proto_field_name | |
149 | |
150 | |
151 def _VerifyExtensionHandle(message, extension_handle): | |
152 """Verify that the given extension handle is valid.""" | |
153 | |
154 if not isinstance(extension_handle, _FieldDescriptor): | |
155 raise KeyError('HasExtension() expects an extension handle, got: %s' % | |
156 extension_handle) | |
157 | |
158 if not extension_handle.is_extension: | |
159 raise KeyError('"%s" is not an extension.' % extension_handle.full_name) | |
160 | |
161 if not extension_handle.containing_type: | |
162 raise KeyError('"%s" is missing a containing_type.' | |
163 % extension_handle.full_name) | |
164 | |
165 if extension_handle.containing_type is not message.DESCRIPTOR: | |
166 raise KeyError('Extension "%s" extends message type "%s", but this ' | |
167 'message is of type "%s".' % | |
168 (extension_handle.full_name, | |
169 extension_handle.containing_type.full_name, | |
170 message.DESCRIPTOR.full_name)) | |
171 | |
172 | |
173 def _AddSlots(message_descriptor, dictionary): | |
174 """Adds a __slots__ entry to dictionary, containing the names of all valid | |
175 attributes for this message type. | |
176 | |
177 Args: | |
178 message_descriptor: A Descriptor instance describing this message type. | |
179 dictionary: Class dictionary to which we'll add a '__slots__' entry. | |
180 """ | |
181 dictionary['__slots__'] = ['_cached_byte_size', | |
182 '_cached_byte_size_dirty', | |
183 '_fields', | |
184 '_unknown_fields', | |
185 '_is_present_in_parent', | |
186 '_listener', | |
187 '_listener_for_children', | |
188 '__weakref__', | |
189 '_oneofs'] | |
190 | |
191 | |
192 def _IsMessageSetExtension(field): | |
193 return (field.is_extension and | |
194 field.containing_type.has_options and | |
195 field.containing_type.GetOptions().message_set_wire_format and | |
196 field.type == _FieldDescriptor.TYPE_MESSAGE and | |
197 field.message_type == field.extension_scope and | |
198 field.label == _FieldDescriptor.LABEL_OPTIONAL) | |
199 | |
200 | |
201 def _AttachFieldHelpers(cls, field_descriptor): | |
202 is_repeated = (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED) | |
203 is_packed = (field_descriptor.has_options and | |
204 field_descriptor.GetOptions().packed) | |
205 | |
206 if _IsMessageSetExtension(field_descriptor): | |
207 field_encoder = encoder.MessageSetItemEncoder(field_descriptor.number) | |
208 sizer = encoder.MessageSetItemSizer(field_descriptor.number) | |
209 else: | |
210 field_encoder = type_checkers.TYPE_TO_ENCODER[field_descriptor.type]( | |
211 field_descriptor.number, is_repeated, is_packed) | |
212 sizer = type_checkers.TYPE_TO_SIZER[field_descriptor.type]( | |
213 field_descriptor.number, is_repeated, is_packed) | |
214 | |
215 field_descriptor._encoder = field_encoder | |
216 field_descriptor._sizer = sizer | |
217 field_descriptor._default_constructor = _DefaultValueConstructorForField( | |
218 field_descriptor) | |
219 | |
220 def AddDecoder(wiretype, is_packed): | |
221 tag_bytes = encoder.TagBytes(field_descriptor.number, wiretype) | |
222 cls._decoders_by_tag[tag_bytes] = ( | |
223 type_checkers.TYPE_TO_DECODER[field_descriptor.type]( | |
224 field_descriptor.number, is_repeated, is_packed, | |
225 field_descriptor, field_descriptor._default_constructor)) | |
226 | |
227 AddDecoder(type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type], | |
228 False) | |
229 | |
230 if is_repeated and wire_format.IsTypePackable(field_descriptor.type): | |
231 # To support wire compatibility of adding packed = true, add a decoder for | |
232 # packed values regardless of the field's options. | |
233 AddDecoder(wire_format.WIRETYPE_LENGTH_DELIMITED, True) | |
234 | |
235 | |
236 def _AddClassAttributesForNestedExtensions(descriptor, dictionary): | |
237 extension_dict = descriptor.extensions_by_name | |
238 for extension_name, extension_field in extension_dict.iteritems(): | |
239 assert extension_name not in dictionary | |
240 dictionary[extension_name] = extension_field | |
241 | |
242 | |
243 def _AddEnumValues(descriptor, cls): | |
244 """Sets class-level attributes for all enum fields defined in this message. | |
245 | |
246 Also exporting a class-level object that can name enum values. | |
247 | |
248 Args: | |
249 descriptor: Descriptor object for this message type. | |
250 cls: Class we're constructing for this message type. | |
251 """ | |
252 for enum_type in descriptor.enum_types: | |
253 setattr(cls, enum_type.name, enum_type_wrapper.EnumTypeWrapper(enum_type)) | |
254 for enum_value in enum_type.values: | |
255 setattr(cls, enum_value.name, enum_value.number) | |
256 | |
257 | |
258 def _DefaultValueConstructorForField(field): | |
259 """Returns a function which returns a default value for a field. | |
260 | |
261 Args: | |
262 field: FieldDescriptor object for this field. | |
263 | |
264 The returned function has one argument: | |
265 message: Message instance containing this field, or a weakref proxy | |
266 of same. | |
267 | |
268 That function in turn returns a default value for this field. The default | |
269 value may refer back to |message| via a weak reference. | |
270 """ | |
271 | |
272 if field.label == _FieldDescriptor.LABEL_REPEATED: | |
273 if field.has_default_value and field.default_value != []: | |
274 raise ValueError('Repeated field default value not empty list: %s' % ( | |
275 field.default_value)) | |
276 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: | |
277 # We can't look at _concrete_class yet since it might not have | |
278 # been set. (Depends on order in which we initialize the classes). | |
279 message_type = field.message_type | |
280 def MakeRepeatedMessageDefault(message): | |
281 return containers.RepeatedCompositeFieldContainer( | |
282 message._listener_for_children, field.message_type) | |
283 return MakeRepeatedMessageDefault | |
284 else: | |
285 type_checker = type_checkers.GetTypeChecker(field) | |
286 def MakeRepeatedScalarDefault(message): | |
287 return containers.RepeatedScalarFieldContainer( | |
288 message._listener_for_children, type_checker) | |
289 return MakeRepeatedScalarDefault | |
290 | |
291 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: | |
292 # _concrete_class may not yet be initialized. | |
293 message_type = field.message_type | |
294 def MakeSubMessageDefault(message): | |
295 result = message_type._concrete_class() | |
296 result._SetListener(message._listener_for_children) | |
297 return result | |
298 return MakeSubMessageDefault | |
299 | |
300 def MakeScalarDefault(message): | |
301 # TODO(protobuf-team): This may be broken since there may not be | |
302 # default_value. Combine with has_default_value somehow. | |
303 return field.default_value | |
304 return MakeScalarDefault | |
305 | |
306 | |
307 def _AddInitMethod(message_descriptor, cls): | |
308 """Adds an __init__ method to cls.""" | |
309 fields = message_descriptor.fields | |
310 def init(self, **kwargs): | |
311 self._cached_byte_size = 0 | |
312 self._cached_byte_size_dirty = len(kwargs) > 0 | |
313 self._fields = {} | |
314 # Contains a mapping from oneof field descriptors to the descriptor | |
315 # of the currently set field in that oneof field. | |
316 self._oneofs = {} | |
317 | |
318 # _unknown_fields is () when empty for efficiency, and will be turned into | |
319 # a list if fields are added. | |
320 self._unknown_fields = () | |
321 self._is_present_in_parent = False | |
322 self._listener = message_listener_mod.NullMessageListener() | |
323 self._listener_for_children = _Listener(self) | |
324 for field_name, field_value in kwargs.iteritems(): | |
325 field = _GetFieldByName(message_descriptor, field_name) | |
326 if field is None: | |
327 raise TypeError("%s() got an unexpected keyword argument '%s'" % | |
328 (message_descriptor.name, field_name)) | |
329 if field.label == _FieldDescriptor.LABEL_REPEATED: | |
330 copy = field._default_constructor(self) | |
331 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: # Composite | |
332 for val in field_value: | |
333 copy.add().MergeFrom(val) | |
334 else: # Scalar | |
335 copy.extend(field_value) | |
336 self._fields[field] = copy | |
337 elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: | |
338 copy = field._default_constructor(self) | |
339 copy.MergeFrom(field_value) | |
340 self._fields[field] = copy | |
341 else: | |
342 setattr(self, field_name, field_value) | |
343 | |
344 init.__module__ = None | |
345 init.__doc__ = None | |
346 cls.__init__ = init | |
347 | |
348 | |
349 def _GetFieldByName(message_descriptor, field_name): | |
350 """Returns a field descriptor by field name. | |
351 | |
352 Args: | |
353 message_descriptor: A Descriptor describing all fields in message. | |
354 field_name: The name of the field to retrieve. | |
355 Returns: | |
356 The field descriptor associated with the field name. | |
357 """ | |
358 try: | |
359 return message_descriptor.fields_by_name[field_name] | |
360 except KeyError: | |
361 raise ValueError('Protocol message has no "%s" field.' % field_name) | |
362 | |
363 | |
364 def _AddPropertiesForFields(descriptor, cls): | |
365 """Adds properties for all fields in this protocol message type.""" | |
366 for field in descriptor.fields: | |
367 _AddPropertiesForField(field, cls) | |
368 | |
369 if descriptor.is_extendable: | |
370 # _ExtensionDict is just an adaptor with no state so we allocate a new one | |
371 # every time it is accessed. | |
372 cls.Extensions = property(lambda self: _ExtensionDict(self)) | |
373 | |
374 | |
375 def _AddPropertiesForField(field, cls): | |
376 """Adds a public property for a protocol message field. | |
377 Clients can use this property to get and (in the case | |
378 of non-repeated scalar fields) directly set the value | |
379 of a protocol message field. | |
380 | |
381 Args: | |
382 field: A FieldDescriptor for this field. | |
383 cls: The class we're constructing. | |
384 """ | |
385 # Catch it if we add other types that we should | |
386 # handle specially here. | |
387 assert _FieldDescriptor.MAX_CPPTYPE == 10 | |
388 | |
389 constant_name = field.name.upper() + "_FIELD_NUMBER" | |
390 setattr(cls, constant_name, field.number) | |
391 | |
392 if field.label == _FieldDescriptor.LABEL_REPEATED: | |
393 _AddPropertiesForRepeatedField(field, cls) | |
394 elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: | |
395 _AddPropertiesForNonRepeatedCompositeField(field, cls) | |
396 else: | |
397 _AddPropertiesForNonRepeatedScalarField(field, cls) | |
398 | |
399 | |
400 def _AddPropertiesForRepeatedField(field, cls): | |
401 """Adds a public property for a "repeated" protocol message field. Clients | |
402 can use this property to get the value of the field, which will be either a | |
403 _RepeatedScalarFieldContainer or _RepeatedCompositeFieldContainer (see | |
404 below). | |
405 | |
406 Note that when clients add values to these containers, we perform | |
407 type-checking in the case of repeated scalar fields, and we also set any | |
408 necessary "has" bits as a side-effect. | |
409 | |
410 Args: | |
411 field: A FieldDescriptor for this field. | |
412 cls: The class we're constructing. | |
413 """ | |
414 proto_field_name = field.name | |
415 property_name = _PropertyName(proto_field_name) | |
416 | |
417 def getter(self): | |
418 field_value = self._fields.get(field) | |
419 if field_value is None: | |
420 # Construct a new object to represent this field. | |
421 field_value = field._default_constructor(self) | |
422 | |
423 # Atomically check if another thread has preempted us and, if not, swap | |
424 # in the new object we just created. If someone has preempted us, we | |
425 # take that object and discard ours. | |
426 # WARNING: We are relying on setdefault() being atomic. This is true | |
427 # in CPython but we haven't investigated others. This warning appears | |
428 # in several other locations in this file. | |
429 field_value = self._fields.setdefault(field, field_value) | |
430 return field_value | |
431 getter.__module__ = None | |
432 getter.__doc__ = 'Getter for %s.' % proto_field_name | |
433 | |
434 # We define a setter just so we can throw an exception with a more | |
435 # helpful error message. | |
436 def setter(self, new_value): | |
437 raise AttributeError('Assignment not allowed to repeated field ' | |
438 '"%s" in protocol message object.' % proto_field_name) | |
439 | |
440 doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name | |
441 setattr(cls, property_name, property(getter, setter, doc=doc)) | |
442 | |
443 | |
444 def _AddPropertiesForNonRepeatedScalarField(field, cls): | |
445 """Adds a public property for a nonrepeated, scalar protocol message field. | |
446 Clients can use this property to get and directly set the value of the field. | |
447 Note that when the client sets the value of a field by using this property, | |
448 all necessary "has" bits are set as a side-effect, and we also perform | |
449 type-checking. | |
450 | |
451 Args: | |
452 field: A FieldDescriptor for this field. | |
453 cls: The class we're constructing. | |
454 """ | |
455 proto_field_name = field.name | |
456 property_name = _PropertyName(proto_field_name) | |
457 type_checker = type_checkers.GetTypeChecker(field) | |
458 default_value = field.default_value | |
459 valid_values = set() | |
460 | |
461 def getter(self): | |
462 # TODO(protobuf-team): This may be broken since there may not be | |
463 # default_value. Combine with has_default_value somehow. | |
464 return self._fields.get(field, default_value) | |
465 getter.__module__ = None | |
466 getter.__doc__ = 'Getter for %s.' % proto_field_name | |
467 def field_setter(self, new_value): | |
468 # pylint: disable=protected-access | |
469 self._fields[field] = type_checker.CheckValue(new_value) | |
470 # Check _cached_byte_size_dirty inline to improve performance, since scalar | |
471 # setters are called frequently. | |
472 if not self._cached_byte_size_dirty: | |
473 self._Modified() | |
474 | |
475 if field.containing_oneof is not None: | |
476 def setter(self, new_value): | |
477 field_setter(self, new_value) | |
478 self._UpdateOneofState(field) | |
479 else: | |
480 setter = field_setter | |
481 | |
482 setter.__module__ = None | |
483 setter.__doc__ = 'Setter for %s.' % proto_field_name | |
484 | |
485 # Add a property to encapsulate the getter/setter. | |
486 doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name | |
487 setattr(cls, property_name, property(getter, setter, doc=doc)) | |
488 | |
489 | |
490 def _AddPropertiesForNonRepeatedCompositeField(field, cls): | |
491 """Adds a public property for a nonrepeated, composite protocol message field. | |
492 A composite field is a "group" or "message" field. | |
493 | |
494 Clients can use this property to get the value of the field, but cannot | |
495 assign to the property directly. | |
496 | |
497 Args: | |
498 field: A FieldDescriptor for this field. | |
499 cls: The class we're constructing. | |
500 """ | |
501 # TODO(robinson): Remove duplication with similar method | |
502 # for non-repeated scalars. | |
503 proto_field_name = field.name | |
504 property_name = _PropertyName(proto_field_name) | |
505 | |
506 # TODO(komarek): Can anyone explain to me why we cache the message_type this | |
507 # way, instead of referring to field.message_type inside of getter(self)? | |
508 # What if someone sets message_type later on (which makes for simpler | |
509 # dyanmic proto descriptor and class creation code). | |
510 message_type = field.message_type | |
511 | |
512 def getter(self): | |
513 field_value = self._fields.get(field) | |
514 if field_value is None: | |
515 # Construct a new object to represent this field. | |
516 field_value = message_type._concrete_class() # use field.message_type? | |
517 field_value._SetListener( | |
518 _OneofListener(self, field) | |
519 if field.containing_oneof is not None | |
520 else self._listener_for_children) | |
521 | |
522 # Atomically check if another thread has preempted us and, if not, swap | |
523 # in the new object we just created. If someone has preempted us, we | |
524 # take that object and discard ours. | |
525 # WARNING: We are relying on setdefault() being atomic. This is true | |
526 # in CPython but we haven't investigated others. This warning appears | |
527 # in several other locations in this file. | |
528 field_value = self._fields.setdefault(field, field_value) | |
529 return field_value | |
530 getter.__module__ = None | |
531 getter.__doc__ = 'Getter for %s.' % proto_field_name | |
532 | |
533 # We define a setter just so we can throw an exception with a more | |
534 # helpful error message. | |
535 def setter(self, new_value): | |
536 raise AttributeError('Assignment not allowed to composite field ' | |
537 '"%s" in protocol message object.' % proto_field_name) | |
538 | |
539 # Add a property to encapsulate the getter. | |
540 doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name | |
541 setattr(cls, property_name, property(getter, setter, doc=doc)) | |
542 | |
543 | |
544 def _AddPropertiesForExtensions(descriptor, cls): | |
545 """Adds properties for all fields in this protocol message type.""" | |
546 extension_dict = descriptor.extensions_by_name | |
547 for extension_name, extension_field in extension_dict.iteritems(): | |
548 constant_name = extension_name.upper() + "_FIELD_NUMBER" | |
549 setattr(cls, constant_name, extension_field.number) | |
550 | |
551 | |
552 def _AddStaticMethods(cls): | |
553 # TODO(robinson): This probably needs to be thread-safe(?) | |
554 def RegisterExtension(extension_handle): | |
555 extension_handle.containing_type = cls.DESCRIPTOR | |
556 _AttachFieldHelpers(cls, extension_handle) | |
557 | |
558 # Try to insert our extension, failing if an extension with the same number | |
559 # already exists. | |
560 actual_handle = cls._extensions_by_number.setdefault( | |
561 extension_handle.number, extension_handle) | |
562 if actual_handle is not extension_handle: | |
563 raise AssertionError( | |
564 'Extensions "%s" and "%s" both try to extend message type "%s" with ' | |
565 'field number %d.' % | |
566 (extension_handle.full_name, actual_handle.full_name, | |
567 cls.DESCRIPTOR.full_name, extension_handle.number)) | |
568 | |
569 cls._extensions_by_name[extension_handle.full_name] = extension_handle | |
570 | |
571 handle = extension_handle # avoid line wrapping | |
572 if _IsMessageSetExtension(handle): | |
573 # MessageSet extension. Also register under type name. | |
574 cls._extensions_by_name[ | |
575 extension_handle.message_type.full_name] = extension_handle | |
576 | |
577 cls.RegisterExtension = staticmethod(RegisterExtension) | |
578 | |
579 def FromString(s): | |
580 message = cls() | |
581 message.MergeFromString(s) | |
582 return message | |
583 cls.FromString = staticmethod(FromString) | |
584 | |
585 | |
586 def _IsPresent(item): | |
587 """Given a (FieldDescriptor, value) tuple from _fields, return true if the | |
588 value should be included in the list returned by ListFields().""" | |
589 | |
590 if item[0].label == _FieldDescriptor.LABEL_REPEATED: | |
591 return bool(item[1]) | |
592 elif item[0].cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: | |
593 return item[1]._is_present_in_parent | |
594 else: | |
595 return True | |
596 | |
597 | |
598 def _AddListFieldsMethod(message_descriptor, cls): | |
599 """Helper for _AddMessageMethods().""" | |
600 | |
601 def ListFields(self): | |
602 all_fields = [item for item in self._fields.iteritems() if _IsPresent(item)] | |
603 all_fields.sort(key = lambda item: item[0].number) | |
604 return all_fields | |
605 | |
606 cls.ListFields = ListFields | |
607 | |
608 | |
609 def _AddHasFieldMethod(message_descriptor, cls): | |
610 """Helper for _AddMessageMethods().""" | |
611 | |
612 singular_fields = {} | |
613 for field in message_descriptor.fields: | |
614 if field.label != _FieldDescriptor.LABEL_REPEATED: | |
615 singular_fields[field.name] = field | |
616 # Fields inside oneofs are never repeated (enforced by the compiler). | |
617 for field in message_descriptor.oneofs: | |
618 singular_fields[field.name] = field | |
619 | |
620 def HasField(self, field_name): | |
621 try: | |
622 field = singular_fields[field_name] | |
623 except KeyError: | |
624 raise ValueError( | |
625 'Protocol message has no singular "%s" field.' % field_name) | |
626 | |
627 if isinstance(field, descriptor_mod.OneofDescriptor): | |
628 try: | |
629 return HasField(self, self._oneofs[field].name) | |
630 except KeyError: | |
631 return False | |
632 else: | |
633 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: | |
634 value = self._fields.get(field) | |
635 return value is not None and value._is_present_in_parent | |
636 else: | |
637 return field in self._fields | |
638 | |
639 cls.HasField = HasField | |
640 | |
641 | |
642 def _AddClearFieldMethod(message_descriptor, cls): | |
643 """Helper for _AddMessageMethods().""" | |
644 def ClearField(self, field_name): | |
645 try: | |
646 field = message_descriptor.fields_by_name[field_name] | |
647 except KeyError: | |
648 try: | |
649 field = message_descriptor.oneofs_by_name[field_name] | |
650 if field in self._oneofs: | |
651 field = self._oneofs[field] | |
652 else: | |
653 return | |
654 except KeyError: | |
655 raise ValueError('Protocol message has no "%s" field.' % field_name) | |
656 | |
657 if field in self._fields: | |
658 # Note: If the field is a sub-message, its listener will still point | |
659 # at us. That's fine, because the worst than can happen is that it | |
660 # will call _Modified() and invalidate our byte size. Big deal. | |
661 del self._fields[field] | |
662 | |
663 if self._oneofs.get(field.containing_oneof, None) is field: | |
664 del self._oneofs[field.containing_oneof] | |
665 | |
666 # Always call _Modified() -- even if nothing was changed, this is | |
667 # a mutating method, and thus calling it should cause the field to become | |
668 # present in the parent message. | |
669 self._Modified() | |
670 | |
671 cls.ClearField = ClearField | |
672 | |
673 | |
674 def _AddClearExtensionMethod(cls): | |
675 """Helper for _AddMessageMethods().""" | |
676 def ClearExtension(self, extension_handle): | |
677 _VerifyExtensionHandle(self, extension_handle) | |
678 | |
679 # Similar to ClearField(), above. | |
680 if extension_handle in self._fields: | |
681 del self._fields[extension_handle] | |
682 self._Modified() | |
683 cls.ClearExtension = ClearExtension | |
684 | |
685 | |
686 def _AddClearMethod(message_descriptor, cls): | |
687 """Helper for _AddMessageMethods().""" | |
688 def Clear(self): | |
689 # Clear fields. | |
690 self._fields = {} | |
691 self._unknown_fields = () | |
692 self._Modified() | |
693 cls.Clear = Clear | |
694 | |
695 | |
696 def _AddHasExtensionMethod(cls): | |
697 """Helper for _AddMessageMethods().""" | |
698 def HasExtension(self, extension_handle): | |
699 _VerifyExtensionHandle(self, extension_handle) | |
700 if extension_handle.label == _FieldDescriptor.LABEL_REPEATED: | |
701 raise KeyError('"%s" is repeated.' % extension_handle.full_name) | |
702 | |
703 if extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: | |
704 value = self._fields.get(extension_handle) | |
705 return value is not None and value._is_present_in_parent | |
706 else: | |
707 return extension_handle in self._fields | |
708 cls.HasExtension = HasExtension | |
709 | |
710 | |
711 def _AddEqualsMethod(message_descriptor, cls): | |
712 """Helper for _AddMessageMethods().""" | |
713 def __eq__(self, other): | |
714 if (not isinstance(other, message_mod.Message) or | |
715 other.DESCRIPTOR != self.DESCRIPTOR): | |
716 return False | |
717 | |
718 if self is other: | |
719 return True | |
720 | |
721 if not self.ListFields() == other.ListFields(): | |
722 return False | |
723 | |
724 # Sort unknown fields because their order shouldn't affect equality test. | |
725 unknown_fields = list(self._unknown_fields) | |
726 unknown_fields.sort() | |
727 other_unknown_fields = list(other._unknown_fields) | |
728 other_unknown_fields.sort() | |
729 | |
730 return unknown_fields == other_unknown_fields | |
731 | |
732 cls.__eq__ = __eq__ | |
733 | |
734 | |
735 def _AddStrMethod(message_descriptor, cls): | |
736 """Helper for _AddMessageMethods().""" | |
737 def __str__(self): | |
738 return text_format.MessageToString(self) | |
739 cls.__str__ = __str__ | |
740 | |
741 | |
742 def _AddUnicodeMethod(unused_message_descriptor, cls): | |
743 """Helper for _AddMessageMethods().""" | |
744 | |
745 def __unicode__(self): | |
746 return text_format.MessageToString(self, as_utf8=True).decode('utf-8') | |
747 cls.__unicode__ = __unicode__ | |
748 | |
749 | |
750 def _AddSetListenerMethod(cls): | |
751 """Helper for _AddMessageMethods().""" | |
752 def SetListener(self, listener): | |
753 if listener is None: | |
754 self._listener = message_listener_mod.NullMessageListener() | |
755 else: | |
756 self._listener = listener | |
757 cls._SetListener = SetListener | |
758 | |
759 | |
760 def _BytesForNonRepeatedElement(value, field_number, field_type): | |
761 """Returns the number of bytes needed to serialize a non-repeated element. | |
762 The returned byte count includes space for tag information and any | |
763 other additional space associated with serializing value. | |
764 | |
765 Args: | |
766 value: Value we're serializing. | |
767 field_number: Field number of this value. (Since the field number | |
768 is stored as part of a varint-encoded tag, this has an impact | |
769 on the total bytes required to serialize the value). | |
770 field_type: The type of the field. One of the TYPE_* constants | |
771 within FieldDescriptor. | |
772 """ | |
773 try: | |
774 fn = type_checkers.TYPE_TO_BYTE_SIZE_FN[field_type] | |
775 return fn(field_number, value) | |
776 except KeyError: | |
777 raise message_mod.EncodeError('Unrecognized field type: %d' % field_type) | |
778 | |
779 | |
780 def _AddByteSizeMethod(message_descriptor, cls): | |
781 """Helper for _AddMessageMethods().""" | |
782 | |
783 def ByteSize(self): | |
784 if not self._cached_byte_size_dirty: | |
785 return self._cached_byte_size | |
786 | |
787 size = 0 | |
788 for field_descriptor, field_value in self.ListFields(): | |
789 size += field_descriptor._sizer(field_value) | |
790 | |
791 for tag_bytes, value_bytes in self._unknown_fields: | |
792 size += len(tag_bytes) + len(value_bytes) | |
793 | |
794 self._cached_byte_size = size | |
795 self._cached_byte_size_dirty = False | |
796 self._listener_for_children.dirty = False | |
797 return size | |
798 | |
799 cls.ByteSize = ByteSize | |
800 | |
801 | |
802 def _AddSerializeToStringMethod(message_descriptor, cls): | |
803 """Helper for _AddMessageMethods().""" | |
804 | |
805 def SerializeToString(self): | |
806 # Check if the message has all of its required fields set. | |
807 errors = [] | |
808 if not self.IsInitialized(): | |
809 raise message_mod.EncodeError( | |
810 'Message %s is missing required fields: %s' % ( | |
811 self.DESCRIPTOR.full_name, ','.join(self.FindInitializationErrors()))) | |
812 return self.SerializePartialToString() | |
813 cls.SerializeToString = SerializeToString | |
814 | |
815 | |
816 def _AddSerializePartialToStringMethod(message_descriptor, cls): | |
817 """Helper for _AddMessageMethods().""" | |
818 | |
819 def SerializePartialToString(self): | |
820 out = BytesIO() | |
821 self._InternalSerialize(out.write) | |
822 return out.getvalue() | |
823 cls.SerializePartialToString = SerializePartialToString | |
824 | |
825 def InternalSerialize(self, write_bytes): | |
826 for field_descriptor, field_value in self.ListFields(): | |
827 field_descriptor._encoder(write_bytes, field_value) | |
828 for tag_bytes, value_bytes in self._unknown_fields: | |
829 write_bytes(tag_bytes) | |
830 write_bytes(value_bytes) | |
831 cls._InternalSerialize = InternalSerialize | |
832 | |
833 | |
834 def _AddMergeFromStringMethod(message_descriptor, cls): | |
835 """Helper for _AddMessageMethods().""" | |
836 def MergeFromString(self, serialized): | |
837 length = len(serialized) | |
838 try: | |
839 if self._InternalParse(serialized, 0, length) != length: | |
840 # The only reason _InternalParse would return early is if it | |
841 # encountered an end-group tag. | |
842 raise message_mod.DecodeError('Unexpected end-group tag.') | |
843 except (IndexError, TypeError): | |
844 # Now ord(buf[p:p+1]) == ord('') gets TypeError. | |
845 raise message_mod.DecodeError('Truncated message.') | |
846 except struct.error, e: | |
847 raise message_mod.DecodeError(e) | |
848 return length # Return this for legacy reasons. | |
849 cls.MergeFromString = MergeFromString | |
850 | |
851 local_ReadTag = decoder.ReadTag | |
852 local_SkipField = decoder.SkipField | |
853 decoders_by_tag = cls._decoders_by_tag | |
854 | |
855 def InternalParse(self, buffer, pos, end): | |
856 self._Modified() | |
857 field_dict = self._fields | |
858 unknown_field_list = self._unknown_fields | |
859 while pos != end: | |
860 (tag_bytes, new_pos) = local_ReadTag(buffer, pos) | |
861 field_decoder = decoders_by_tag.get(tag_bytes) | |
862 if field_decoder is None: | |
863 value_start_pos = new_pos | |
864 new_pos = local_SkipField(buffer, new_pos, end, tag_bytes) | |
865 if new_pos == -1: | |
866 return pos | |
867 if not unknown_field_list: | |
868 unknown_field_list = self._unknown_fields = [] | |
869 unknown_field_list.append((tag_bytes, buffer[value_start_pos:new_pos])) | |
870 pos = new_pos | |
871 else: | |
872 pos = field_decoder(buffer, new_pos, end, self, field_dict) | |
873 return pos | |
874 cls._InternalParse = InternalParse | |
875 | |
876 | |
877 def _AddIsInitializedMethod(message_descriptor, cls): | |
878 """Adds the IsInitialized and FindInitializationError methods to the | |
879 protocol message class.""" | |
880 | |
881 required_fields = [field for field in message_descriptor.fields | |
882 if field.label == _FieldDescriptor.LABEL_REQUIRED] | |
883 | |
884 def IsInitialized(self, errors=None): | |
885 """Checks if all required fields of a message are set. | |
886 | |
887 Args: | |
888 errors: A list which, if provided, will be populated with the field | |
889 paths of all missing required fields. | |
890 | |
891 Returns: | |
892 True iff the specified message has all required fields set. | |
893 """ | |
894 | |
895 # Performance is critical so we avoid HasField() and ListFields(). | |
896 | |
897 for field in required_fields: | |
898 if (field not in self._fields or | |
899 (field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE and | |
900 not self._fields[field]._is_present_in_parent)): | |
901 if errors is not None: | |
902 errors.extend(self.FindInitializationErrors()) | |
903 return False | |
904 | |
905 for field, value in list(self._fields.items()): # dict can change size! | |
906 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: | |
907 if field.label == _FieldDescriptor.LABEL_REPEATED: | |
908 for element in value: | |
909 if not element.IsInitialized(): | |
910 if errors is not None: | |
911 errors.extend(self.FindInitializationErrors()) | |
912 return False | |
913 elif value._is_present_in_parent and not value.IsInitialized(): | |
914 if errors is not None: | |
915 errors.extend(self.FindInitializationErrors()) | |
916 return False | |
917 | |
918 return True | |
919 | |
920 cls.IsInitialized = IsInitialized | |
921 | |
922 def FindInitializationErrors(self): | |
923 """Finds required fields which are not initialized. | |
924 | |
925 Returns: | |
926 A list of strings. Each string is a path to an uninitialized field from | |
927 the top-level message, e.g. "foo.bar[5].baz". | |
928 """ | |
929 | |
930 errors = [] # simplify things | |
931 | |
932 for field in required_fields: | |
933 if not self.HasField(field.name): | |
934 errors.append(field.name) | |
935 | |
936 for field, value in self.ListFields(): | |
937 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: | |
938 if field.is_extension: | |
939 name = "(%s)" % field.full_name | |
940 else: | |
941 name = field.name | |
942 | |
943 if field.label == _FieldDescriptor.LABEL_REPEATED: | |
944 for i in xrange(len(value)): | |
945 element = value[i] | |
946 prefix = "%s[%d]." % (name, i) | |
947 sub_errors = element.FindInitializationErrors() | |
948 errors += [ prefix + error for error in sub_errors ] | |
949 else: | |
950 prefix = name + "." | |
951 sub_errors = value.FindInitializationErrors() | |
952 errors += [ prefix + error for error in sub_errors ] | |
953 | |
954 return errors | |
955 | |
956 cls.FindInitializationErrors = FindInitializationErrors | |
957 | |
958 | |
959 def _AddMergeFromMethod(cls): | |
960 LABEL_REPEATED = _FieldDescriptor.LABEL_REPEATED | |
961 CPPTYPE_MESSAGE = _FieldDescriptor.CPPTYPE_MESSAGE | |
962 | |
963 def MergeFrom(self, msg): | |
964 if not isinstance(msg, cls): | |
965 raise TypeError( | |
966 "Parameter to MergeFrom() must be instance of same class: " | |
967 "expected %s got %s." % (cls.__name__, type(msg).__name__)) | |
968 | |
969 assert msg is not self | |
970 self._Modified() | |
971 | |
972 fields = self._fields | |
973 | |
974 for field, value in msg._fields.iteritems(): | |
975 if field.label == LABEL_REPEATED: | |
976 field_value = fields.get(field) | |
977 if field_value is None: | |
978 # Construct a new object to represent this field. | |
979 field_value = field._default_constructor(self) | |
980 fields[field] = field_value | |
981 field_value.MergeFrom(value) | |
982 elif field.cpp_type == CPPTYPE_MESSAGE: | |
983 if value._is_present_in_parent: | |
984 field_value = fields.get(field) | |
985 if field_value is None: | |
986 # Construct a new object to represent this field. | |
987 field_value = field._default_constructor(self) | |
988 fields[field] = field_value | |
989 field_value.MergeFrom(value) | |
990 else: | |
991 self._fields[field] = value | |
992 | |
993 if msg._unknown_fields: | |
994 if not self._unknown_fields: | |
995 self._unknown_fields = [] | |
996 self._unknown_fields.extend(msg._unknown_fields) | |
997 | |
998 cls.MergeFrom = MergeFrom | |
999 | |
1000 | |
1001 def _AddWhichOneofMethod(message_descriptor, cls): | |
1002 def WhichOneof(self, oneof_name): | |
1003 """Returns the name of the currently set field inside a oneof, or None.""" | |
1004 try: | |
1005 field = message_descriptor.oneofs_by_name[oneof_name] | |
1006 except KeyError: | |
1007 raise ValueError( | |
1008 'Protocol message has no oneof "%s" field.' % oneof_name) | |
1009 | |
1010 nested_field = self._oneofs.get(field, None) | |
1011 if nested_field is not None and self.HasField(nested_field.name): | |
1012 return nested_field.name | |
1013 else: | |
1014 return None | |
1015 | |
1016 cls.WhichOneof = WhichOneof | |
1017 | |
1018 | |
1019 def _AddMessageMethods(message_descriptor, cls): | |
1020 """Adds implementations of all Message methods to cls.""" | |
1021 _AddListFieldsMethod(message_descriptor, cls) | |
1022 _AddHasFieldMethod(message_descriptor, cls) | |
1023 _AddClearFieldMethod(message_descriptor, cls) | |
1024 if message_descriptor.is_extendable: | |
1025 _AddClearExtensionMethod(cls) | |
1026 _AddHasExtensionMethod(cls) | |
1027 _AddClearMethod(message_descriptor, cls) | |
1028 _AddEqualsMethod(message_descriptor, cls) | |
1029 _AddStrMethod(message_descriptor, cls) | |
1030 _AddUnicodeMethod(message_descriptor, cls) | |
1031 _AddSetListenerMethod(cls) | |
1032 _AddByteSizeMethod(message_descriptor, cls) | |
1033 _AddSerializeToStringMethod(message_descriptor, cls) | |
1034 _AddSerializePartialToStringMethod(message_descriptor, cls) | |
1035 _AddMergeFromStringMethod(message_descriptor, cls) | |
1036 _AddIsInitializedMethod(message_descriptor, cls) | |
1037 _AddMergeFromMethod(cls) | |
1038 _AddWhichOneofMethod(message_descriptor, cls) | |
1039 | |
1040 def _AddPrivateHelperMethods(message_descriptor, cls): | |
1041 """Adds implementation of private helper methods to cls.""" | |
1042 | |
1043 def Modified(self): | |
1044 """Sets the _cached_byte_size_dirty bit to true, | |
1045 and propagates this to our listener iff this was a state change. | |
1046 """ | |
1047 | |
1048 # Note: Some callers check _cached_byte_size_dirty before calling | |
1049 # _Modified() as an extra optimization. So, if this method is ever | |
1050 # changed such that it does stuff even when _cached_byte_size_dirty is | |
1051 # already true, the callers need to be updated. | |
1052 if not self._cached_byte_size_dirty: | |
1053 self._cached_byte_size_dirty = True | |
1054 self._listener_for_children.dirty = True | |
1055 self._is_present_in_parent = True | |
1056 self._listener.Modified() | |
1057 | |
1058 def _UpdateOneofState(self, field): | |
1059 """Sets field as the active field in its containing oneof. | |
1060 | |
1061 Will also delete currently active field in the oneof, if it is different | |
1062 from the argument. Does not mark the message as modified. | |
1063 """ | |
1064 other_field = self._oneofs.setdefault(field.containing_oneof, field) | |
1065 if other_field is not field: | |
1066 del self._fields[other_field] | |
1067 self._oneofs[field.containing_oneof] = field | |
1068 | |
1069 cls._Modified = Modified | |
1070 cls.SetInParent = Modified | |
1071 cls._UpdateOneofState = _UpdateOneofState | |
1072 | |
1073 | |
1074 class _Listener(object): | |
1075 | |
1076 """MessageListener implementation that a parent message registers with its | |
1077 child message. | |
1078 | |
1079 In order to support semantics like: | |
1080 | |
1081 foo.bar.baz.qux = 23 | |
1082 assert foo.HasField('bar') | |
1083 | |
1084 ...child objects must have back references to their parents. | |
1085 This helper class is at the heart of this support. | |
1086 """ | |
1087 | |
1088 def __init__(self, parent_message): | |
1089 """Args: | |
1090 parent_message: The message whose _Modified() method we should call when | |
1091 we receive Modified() messages. | |
1092 """ | |
1093 # This listener establishes a back reference from a child (contained) object | |
1094 # to its parent (containing) object. We make this a weak reference to avoid | |
1095 # creating cyclic garbage when the client finishes with the 'parent' object | |
1096 # in the tree. | |
1097 if isinstance(parent_message, weakref.ProxyType): | |
1098 self._parent_message_weakref = parent_message | |
1099 else: | |
1100 self._parent_message_weakref = weakref.proxy(parent_message) | |
1101 | |
1102 # As an optimization, we also indicate directly on the listener whether | |
1103 # or not the parent message is dirty. This way we can avoid traversing | |
1104 # up the tree in the common case. | |
1105 self.dirty = False | |
1106 | |
1107 def Modified(self): | |
1108 if self.dirty: | |
1109 return | |
1110 try: | |
1111 # Propagate the signal to our parents iff this is the first field set. | |
1112 self._parent_message_weakref._Modified() | |
1113 except ReferenceError: | |
1114 # We can get here if a client has kept a reference to a child object, | |
1115 # and is now setting a field on it, but the child's parent has been | |
1116 # garbage-collected. This is not an error. | |
1117 pass | |
1118 | |
1119 | |
1120 class _OneofListener(_Listener): | |
1121 """Special listener implementation for setting composite oneof fields.""" | |
1122 | |
1123 def __init__(self, parent_message, field): | |
1124 """Args: | |
1125 parent_message: The message whose _Modified() method we should call when | |
1126 we receive Modified() messages. | |
1127 field: The descriptor of the field being set in the parent message. | |
1128 """ | |
1129 super(_OneofListener, self).__init__(parent_message) | |
1130 self._field = field | |
1131 | |
1132 def Modified(self): | |
1133 """Also updates the state of the containing oneof in the parent message.""" | |
1134 try: | |
1135 self._parent_message_weakref._UpdateOneofState(self._field) | |
1136 super(_OneofListener, self).Modified() | |
1137 except ReferenceError: | |
1138 pass | |
1139 | |
1140 | |
1141 # TODO(robinson): Move elsewhere? This file is getting pretty ridiculous... | |
1142 # TODO(robinson): Unify error handling of "unknown extension" crap. | |
1143 # TODO(robinson): Support iteritems()-style iteration over all | |
1144 # extensions with the "has" bits turned on? | |
1145 class _ExtensionDict(object): | |
1146 | |
1147 """Dict-like container for supporting an indexable "Extensions" | |
1148 field on proto instances. | |
1149 | |
1150 Note that in all cases we expect extension handles to be | |
1151 FieldDescriptors. | |
1152 """ | |
1153 | |
1154 def __init__(self, extended_message): | |
1155 """extended_message: Message instance for which we are the Extensions dict. | |
1156 """ | |
1157 | |
1158 self._extended_message = extended_message | |
1159 | |
1160 def __getitem__(self, extension_handle): | |
1161 """Returns the current value of the given extension handle.""" | |
1162 | |
1163 _VerifyExtensionHandle(self._extended_message, extension_handle) | |
1164 | |
1165 result = self._extended_message._fields.get(extension_handle) | |
1166 if result is not None: | |
1167 return result | |
1168 | |
1169 if extension_handle.label == _FieldDescriptor.LABEL_REPEATED: | |
1170 result = extension_handle._default_constructor(self._extended_message) | |
1171 elif extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: | |
1172 result = extension_handle.message_type._concrete_class() | |
1173 try: | |
1174 result._SetListener(self._extended_message._listener_for_children) | |
1175 except ReferenceError: | |
1176 pass | |
1177 else: | |
1178 # Singular scalar -- just return the default without inserting into the | |
1179 # dict. | |
1180 return extension_handle.default_value | |
1181 | |
1182 # Atomically check if another thread has preempted us and, if not, swap | |
1183 # in the new object we just created. If someone has preempted us, we | |
1184 # take that object and discard ours. | |
1185 # WARNING: We are relying on setdefault() being atomic. This is true | |
1186 # in CPython but we haven't investigated others. This warning appears | |
1187 # in several other locations in this file. | |
1188 result = self._extended_message._fields.setdefault( | |
1189 extension_handle, result) | |
1190 | |
1191 return result | |
1192 | |
1193 def __eq__(self, other): | |
1194 if not isinstance(other, self.__class__): | |
1195 return False | |
1196 | |
1197 my_fields = self._extended_message.ListFields() | |
1198 other_fields = other._extended_message.ListFields() | |
1199 | |
1200 # Get rid of non-extension fields. | |
1201 my_fields = [ field for field in my_fields if field.is_extension ] | |
1202 other_fields = [ field for field in other_fields if field.is_extension ] | |
1203 | |
1204 return my_fields == other_fields | |
1205 | |
1206 def __ne__(self, other): | |
1207 return not self == other | |
1208 | |
1209 def __hash__(self): | |
1210 raise TypeError('unhashable object') | |
1211 | |
1212 # Note that this is only meaningful for non-repeated, scalar extension | |
1213 # fields. Note also that we may have to call _Modified() when we do | |
1214 # successfully set a field this way, to set any necssary "has" bits in the | |
1215 # ancestors of the extended message. | |
1216 def __setitem__(self, extension_handle, value): | |
1217 """If extension_handle specifies a non-repeated, scalar extension | |
1218 field, sets the value of that field. | |
1219 """ | |
1220 | |
1221 _VerifyExtensionHandle(self._extended_message, extension_handle) | |
1222 | |
1223 if (extension_handle.label == _FieldDescriptor.LABEL_REPEATED or | |
1224 extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE): | |
1225 raise TypeError( | |
1226 'Cannot assign to extension "%s" because it is a repeated or ' | |
1227 'composite type.' % extension_handle.full_name) | |
1228 | |
1229 # It's slightly wasteful to lookup the type checker each time, | |
1230 # but we expect this to be a vanishingly uncommon case anyway. | |
1231 type_checker = type_checkers.GetTypeChecker( | |
1232 extension_handle) | |
1233 # pylint: disable=protected-access | |
1234 self._extended_message._fields[extension_handle] = ( | |
1235 type_checker.CheckValue(value)) | |
1236 self._extended_message._Modified() | |
1237 | |
1238 def _FindExtensionByName(self, name): | |
1239 """Tries to find a known extension with the specified name. | |
1240 | |
1241 Args: | |
1242 name: Extension full name. | |
1243 | |
1244 Returns: | |
1245 Extension field descriptor. | |
1246 """ | |
1247 return self._extended_message._extensions_by_name.get(name, None) | |
OLD | NEW |