OLD | NEW |
(Empty) | |
| 1 # Protocol Buffers - Google's data interchange format |
| 2 # Copyright 2008 Google Inc. All rights reserved. |
| 3 # http://code.google.com/p/protobuf/ |
| 4 # |
| 5 # Redistribution and use in source and binary forms, with or without |
| 6 # modification, are permitted provided that the following conditions are |
| 7 # met: |
| 8 # |
| 9 # * Redistributions of source code must retain the above copyright |
| 10 # notice, this list of conditions and the following disclaimer. |
| 11 # * Redistributions in binary form must reproduce the above |
| 12 # copyright notice, this list of conditions and the following disclaimer |
| 13 # in the documentation and/or other materials provided with the |
| 14 # distribution. |
| 15 # * Neither the name of Google Inc. nor the names of its |
| 16 # contributors may be used to endorse or promote products derived from |
| 17 # this software without specific prior written permission. |
| 18 # |
| 19 # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS |
| 20 # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT |
| 21 # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR |
| 22 # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT |
| 23 # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, |
| 24 # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT |
| 25 # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, |
| 26 # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY |
| 27 # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT |
| 28 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE |
| 29 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
| 30 |
| 31 # Keep it Python2.5 compatible for GAE. |
| 32 # |
| 33 # Copyright 2007 Google Inc. All Rights Reserved. |
| 34 # |
| 35 # This code is meant to work on Python 2.4 and above only. |
| 36 # |
| 37 # TODO(robinson): Helpers for verbose, common checks like seeing if a |
| 38 # descriptor's cpp_type is CPPTYPE_MESSAGE. |
| 39 |
| 40 """Contains a metaclass and helper functions used to create |
| 41 protocol message classes from Descriptor objects at runtime. |
| 42 |
| 43 Recall that a metaclass is the "type" of a class. |
| 44 (A class is to a metaclass what an instance is to a class.) |
| 45 |
| 46 In this case, we use the GeneratedProtocolMessageType metaclass |
| 47 to inject all the useful functionality into the classes |
| 48 output by the protocol compiler at compile-time. |
| 49 |
| 50 The upshot of all this is that the real implementation |
| 51 details for ALL pure-Python protocol buffers are *here in |
| 52 this file*. |
| 53 """ |
| 54 |
| 55 __author__ = 'robinson@google.com (Will Robinson)' |
| 56 |
| 57 import sys |
| 58 if sys.version_info[0] < 3: |
| 59 try: |
| 60 from cStringIO import StringIO as BytesIO |
| 61 except ImportError: |
| 62 from StringIO import StringIO as BytesIO |
| 63 import copy_reg as copyreg |
| 64 else: |
| 65 from io import BytesIO |
| 66 import copyreg |
| 67 import struct |
| 68 import weakref |
| 69 |
| 70 # We use "as" to avoid name collisions with variables. |
| 71 from google.protobuf.internal import containers |
| 72 from google.protobuf.internal import decoder |
| 73 from google.protobuf.internal import encoder |
| 74 from google.protobuf.internal import enum_type_wrapper |
| 75 from google.protobuf.internal import message_listener as message_listener_mod |
| 76 from google.protobuf.internal import type_checkers |
| 77 from google.protobuf.internal import wire_format |
| 78 from google.protobuf import descriptor as descriptor_mod |
| 79 from google.protobuf import message as message_mod |
| 80 from google.protobuf import text_format |
| 81 |
| 82 _FieldDescriptor = descriptor_mod.FieldDescriptor |
| 83 |
| 84 |
| 85 def NewMessage(bases, descriptor, dictionary): |
| 86 _AddClassAttributesForNestedExtensions(descriptor, dictionary) |
| 87 _AddSlots(descriptor, dictionary) |
| 88 return bases |
| 89 |
| 90 |
| 91 def InitMessage(descriptor, cls): |
| 92 cls._decoders_by_tag = {} |
| 93 cls._extensions_by_name = {} |
| 94 cls._extensions_by_number = {} |
| 95 if (descriptor.has_options and |
| 96 descriptor.GetOptions().message_set_wire_format): |
| 97 cls._decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = ( |
| 98 decoder.MessageSetItemDecoder(cls._extensions_by_number)) |
| 99 |
| 100 # Attach stuff to each FieldDescriptor for quick lookup later on. |
| 101 for field in descriptor.fields: |
| 102 _AttachFieldHelpers(cls, field) |
| 103 |
| 104 _AddEnumValues(descriptor, cls) |
| 105 _AddInitMethod(descriptor, cls) |
| 106 _AddPropertiesForFields(descriptor, cls) |
| 107 _AddPropertiesForExtensions(descriptor, cls) |
| 108 _AddStaticMethods(cls) |
| 109 _AddMessageMethods(descriptor, cls) |
| 110 _AddPrivateHelperMethods(descriptor, cls) |
| 111 copyreg.pickle(cls, lambda obj: (cls, (), obj.__getstate__())) |
| 112 |
| 113 |
| 114 # Stateless helpers for GeneratedProtocolMessageType below. |
| 115 # Outside clients should not access these directly. |
| 116 # |
| 117 # I opted not to make any of these methods on the metaclass, to make it more |
| 118 # clear that I'm not really using any state there and to keep clients from |
| 119 # thinking that they have direct access to these construction helpers. |
| 120 |
| 121 |
| 122 def _PropertyName(proto_field_name): |
| 123 """Returns the name of the public property attribute which |
| 124 clients can use to get and (in some cases) set the value |
| 125 of a protocol message field. |
| 126 |
| 127 Args: |
| 128 proto_field_name: The protocol message field name, exactly |
| 129 as it appears (or would appear) in a .proto file. |
| 130 """ |
| 131 # TODO(robinson): Escape Python keywords (e.g., yield), and test this support. |
| 132 # nnorwitz makes my day by writing: |
| 133 # """ |
| 134 # FYI. See the keyword module in the stdlib. This could be as simple as: |
| 135 # |
| 136 # if keyword.iskeyword(proto_field_name): |
| 137 # return proto_field_name + "_" |
| 138 # return proto_field_name |
| 139 # """ |
| 140 # Kenton says: The above is a BAD IDEA. People rely on being able to use |
| 141 # getattr() and setattr() to reflectively manipulate field values. If we |
| 142 # rename the properties, then every such user has to also make sure to apply |
| 143 # the same transformation. Note that currently if you name a field "yield", |
| 144 # you can still access it just fine using getattr/setattr -- it's not even |
| 145 # that cumbersome to do so. |
| 146 # TODO(kenton): Remove this method entirely if/when everyone agrees with my |
| 147 # position. |
| 148 return proto_field_name |
| 149 |
| 150 |
| 151 def _VerifyExtensionHandle(message, extension_handle): |
| 152 """Verify that the given extension handle is valid.""" |
| 153 |
| 154 if not isinstance(extension_handle, _FieldDescriptor): |
| 155 raise KeyError('HasExtension() expects an extension handle, got: %s' % |
| 156 extension_handle) |
| 157 |
| 158 if not extension_handle.is_extension: |
| 159 raise KeyError('"%s" is not an extension.' % extension_handle.full_name) |
| 160 |
| 161 if not extension_handle.containing_type: |
| 162 raise KeyError('"%s" is missing a containing_type.' |
| 163 % extension_handle.full_name) |
| 164 |
| 165 if extension_handle.containing_type is not message.DESCRIPTOR: |
| 166 raise KeyError('Extension "%s" extends message type "%s", but this ' |
| 167 'message is of type "%s".' % |
| 168 (extension_handle.full_name, |
| 169 extension_handle.containing_type.full_name, |
| 170 message.DESCRIPTOR.full_name)) |
| 171 |
| 172 |
| 173 def _AddSlots(message_descriptor, dictionary): |
| 174 """Adds a __slots__ entry to dictionary, containing the names of all valid |
| 175 attributes for this message type. |
| 176 |
| 177 Args: |
| 178 message_descriptor: A Descriptor instance describing this message type. |
| 179 dictionary: Class dictionary to which we'll add a '__slots__' entry. |
| 180 """ |
| 181 dictionary['__slots__'] = ['_cached_byte_size', |
| 182 '_cached_byte_size_dirty', |
| 183 '_fields', |
| 184 '_unknown_fields', |
| 185 '_is_present_in_parent', |
| 186 '_listener', |
| 187 '_listener_for_children', |
| 188 '__weakref__', |
| 189 '_oneofs'] |
| 190 |
| 191 |
| 192 def _IsMessageSetExtension(field): |
| 193 return (field.is_extension and |
| 194 field.containing_type.has_options and |
| 195 field.containing_type.GetOptions().message_set_wire_format and |
| 196 field.type == _FieldDescriptor.TYPE_MESSAGE and |
| 197 field.message_type == field.extension_scope and |
| 198 field.label == _FieldDescriptor.LABEL_OPTIONAL) |
| 199 |
| 200 |
| 201 def _AttachFieldHelpers(cls, field_descriptor): |
| 202 is_repeated = (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED) |
| 203 is_packed = (field_descriptor.has_options and |
| 204 field_descriptor.GetOptions().packed) |
| 205 |
| 206 if _IsMessageSetExtension(field_descriptor): |
| 207 field_encoder = encoder.MessageSetItemEncoder(field_descriptor.number) |
| 208 sizer = encoder.MessageSetItemSizer(field_descriptor.number) |
| 209 else: |
| 210 field_encoder = type_checkers.TYPE_TO_ENCODER[field_descriptor.type]( |
| 211 field_descriptor.number, is_repeated, is_packed) |
| 212 sizer = type_checkers.TYPE_TO_SIZER[field_descriptor.type]( |
| 213 field_descriptor.number, is_repeated, is_packed) |
| 214 |
| 215 field_descriptor._encoder = field_encoder |
| 216 field_descriptor._sizer = sizer |
| 217 field_descriptor._default_constructor = _DefaultValueConstructorForField( |
| 218 field_descriptor) |
| 219 |
| 220 def AddDecoder(wiretype, is_packed): |
| 221 tag_bytes = encoder.TagBytes(field_descriptor.number, wiretype) |
| 222 cls._decoders_by_tag[tag_bytes] = ( |
| 223 type_checkers.TYPE_TO_DECODER[field_descriptor.type]( |
| 224 field_descriptor.number, is_repeated, is_packed, |
| 225 field_descriptor, field_descriptor._default_constructor)) |
| 226 |
| 227 AddDecoder(type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type], |
| 228 False) |
| 229 |
| 230 if is_repeated and wire_format.IsTypePackable(field_descriptor.type): |
| 231 # To support wire compatibility of adding packed = true, add a decoder for |
| 232 # packed values regardless of the field's options. |
| 233 AddDecoder(wire_format.WIRETYPE_LENGTH_DELIMITED, True) |
| 234 |
| 235 |
| 236 def _AddClassAttributesForNestedExtensions(descriptor, dictionary): |
| 237 extension_dict = descriptor.extensions_by_name |
| 238 for extension_name, extension_field in extension_dict.iteritems(): |
| 239 assert extension_name not in dictionary |
| 240 dictionary[extension_name] = extension_field |
| 241 |
| 242 |
| 243 def _AddEnumValues(descriptor, cls): |
| 244 """Sets class-level attributes for all enum fields defined in this message. |
| 245 |
| 246 Also exporting a class-level object that can name enum values. |
| 247 |
| 248 Args: |
| 249 descriptor: Descriptor object for this message type. |
| 250 cls: Class we're constructing for this message type. |
| 251 """ |
| 252 for enum_type in descriptor.enum_types: |
| 253 setattr(cls, enum_type.name, enum_type_wrapper.EnumTypeWrapper(enum_type)) |
| 254 for enum_value in enum_type.values: |
| 255 setattr(cls, enum_value.name, enum_value.number) |
| 256 |
| 257 |
| 258 def _DefaultValueConstructorForField(field): |
| 259 """Returns a function which returns a default value for a field. |
| 260 |
| 261 Args: |
| 262 field: FieldDescriptor object for this field. |
| 263 |
| 264 The returned function has one argument: |
| 265 message: Message instance containing this field, or a weakref proxy |
| 266 of same. |
| 267 |
| 268 That function in turn returns a default value for this field. The default |
| 269 value may refer back to |message| via a weak reference. |
| 270 """ |
| 271 |
| 272 if field.label == _FieldDescriptor.LABEL_REPEATED: |
| 273 if field.has_default_value and field.default_value != []: |
| 274 raise ValueError('Repeated field default value not empty list: %s' % ( |
| 275 field.default_value)) |
| 276 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: |
| 277 # We can't look at _concrete_class yet since it might not have |
| 278 # been set. (Depends on order in which we initialize the classes). |
| 279 message_type = field.message_type |
| 280 def MakeRepeatedMessageDefault(message): |
| 281 return containers.RepeatedCompositeFieldContainer( |
| 282 message._listener_for_children, field.message_type) |
| 283 return MakeRepeatedMessageDefault |
| 284 else: |
| 285 type_checker = type_checkers.GetTypeChecker(field) |
| 286 def MakeRepeatedScalarDefault(message): |
| 287 return containers.RepeatedScalarFieldContainer( |
| 288 message._listener_for_children, type_checker) |
| 289 return MakeRepeatedScalarDefault |
| 290 |
| 291 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: |
| 292 # _concrete_class may not yet be initialized. |
| 293 message_type = field.message_type |
| 294 def MakeSubMessageDefault(message): |
| 295 result = message_type._concrete_class() |
| 296 result._SetListener(message._listener_for_children) |
| 297 return result |
| 298 return MakeSubMessageDefault |
| 299 |
| 300 def MakeScalarDefault(message): |
| 301 # TODO(protobuf-team): This may be broken since there may not be |
| 302 # default_value. Combine with has_default_value somehow. |
| 303 return field.default_value |
| 304 return MakeScalarDefault |
| 305 |
| 306 |
| 307 def _AddInitMethod(message_descriptor, cls): |
| 308 """Adds an __init__ method to cls.""" |
| 309 fields = message_descriptor.fields |
| 310 def init(self, **kwargs): |
| 311 self._cached_byte_size = 0 |
| 312 self._cached_byte_size_dirty = len(kwargs) > 0 |
| 313 self._fields = {} |
| 314 # Contains a mapping from oneof field descriptors to the descriptor |
| 315 # of the currently set field in that oneof field. |
| 316 self._oneofs = {} |
| 317 |
| 318 # _unknown_fields is () when empty for efficiency, and will be turned into |
| 319 # a list if fields are added. |
| 320 self._unknown_fields = () |
| 321 self._is_present_in_parent = False |
| 322 self._listener = message_listener_mod.NullMessageListener() |
| 323 self._listener_for_children = _Listener(self) |
| 324 for field_name, field_value in kwargs.iteritems(): |
| 325 field = _GetFieldByName(message_descriptor, field_name) |
| 326 if field is None: |
| 327 raise TypeError("%s() got an unexpected keyword argument '%s'" % |
| 328 (message_descriptor.name, field_name)) |
| 329 if field.label == _FieldDescriptor.LABEL_REPEATED: |
| 330 copy = field._default_constructor(self) |
| 331 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: # Composite |
| 332 for val in field_value: |
| 333 copy.add().MergeFrom(val) |
| 334 else: # Scalar |
| 335 copy.extend(field_value) |
| 336 self._fields[field] = copy |
| 337 elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: |
| 338 copy = field._default_constructor(self) |
| 339 copy.MergeFrom(field_value) |
| 340 self._fields[field] = copy |
| 341 else: |
| 342 setattr(self, field_name, field_value) |
| 343 |
| 344 init.__module__ = None |
| 345 init.__doc__ = None |
| 346 cls.__init__ = init |
| 347 |
| 348 |
| 349 def _GetFieldByName(message_descriptor, field_name): |
| 350 """Returns a field descriptor by field name. |
| 351 |
| 352 Args: |
| 353 message_descriptor: A Descriptor describing all fields in message. |
| 354 field_name: The name of the field to retrieve. |
| 355 Returns: |
| 356 The field descriptor associated with the field name. |
| 357 """ |
| 358 try: |
| 359 return message_descriptor.fields_by_name[field_name] |
| 360 except KeyError: |
| 361 raise ValueError('Protocol message has no "%s" field.' % field_name) |
| 362 |
| 363 |
| 364 def _AddPropertiesForFields(descriptor, cls): |
| 365 """Adds properties for all fields in this protocol message type.""" |
| 366 for field in descriptor.fields: |
| 367 _AddPropertiesForField(field, cls) |
| 368 |
| 369 if descriptor.is_extendable: |
| 370 # _ExtensionDict is just an adaptor with no state so we allocate a new one |
| 371 # every time it is accessed. |
| 372 cls.Extensions = property(lambda self: _ExtensionDict(self)) |
| 373 |
| 374 |
| 375 def _AddPropertiesForField(field, cls): |
| 376 """Adds a public property for a protocol message field. |
| 377 Clients can use this property to get and (in the case |
| 378 of non-repeated scalar fields) directly set the value |
| 379 of a protocol message field. |
| 380 |
| 381 Args: |
| 382 field: A FieldDescriptor for this field. |
| 383 cls: The class we're constructing. |
| 384 """ |
| 385 # Catch it if we add other types that we should |
| 386 # handle specially here. |
| 387 assert _FieldDescriptor.MAX_CPPTYPE == 10 |
| 388 |
| 389 constant_name = field.name.upper() + "_FIELD_NUMBER" |
| 390 setattr(cls, constant_name, field.number) |
| 391 |
| 392 if field.label == _FieldDescriptor.LABEL_REPEATED: |
| 393 _AddPropertiesForRepeatedField(field, cls) |
| 394 elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: |
| 395 _AddPropertiesForNonRepeatedCompositeField(field, cls) |
| 396 else: |
| 397 _AddPropertiesForNonRepeatedScalarField(field, cls) |
| 398 |
| 399 |
| 400 def _AddPropertiesForRepeatedField(field, cls): |
| 401 """Adds a public property for a "repeated" protocol message field. Clients |
| 402 can use this property to get the value of the field, which will be either a |
| 403 _RepeatedScalarFieldContainer or _RepeatedCompositeFieldContainer (see |
| 404 below). |
| 405 |
| 406 Note that when clients add values to these containers, we perform |
| 407 type-checking in the case of repeated scalar fields, and we also set any |
| 408 necessary "has" bits as a side-effect. |
| 409 |
| 410 Args: |
| 411 field: A FieldDescriptor for this field. |
| 412 cls: The class we're constructing. |
| 413 """ |
| 414 proto_field_name = field.name |
| 415 property_name = _PropertyName(proto_field_name) |
| 416 |
| 417 def getter(self): |
| 418 field_value = self._fields.get(field) |
| 419 if field_value is None: |
| 420 # Construct a new object to represent this field. |
| 421 field_value = field._default_constructor(self) |
| 422 |
| 423 # Atomically check if another thread has preempted us and, if not, swap |
| 424 # in the new object we just created. If someone has preempted us, we |
| 425 # take that object and discard ours. |
| 426 # WARNING: We are relying on setdefault() being atomic. This is true |
| 427 # in CPython but we haven't investigated others. This warning appears |
| 428 # in several other locations in this file. |
| 429 field_value = self._fields.setdefault(field, field_value) |
| 430 return field_value |
| 431 getter.__module__ = None |
| 432 getter.__doc__ = 'Getter for %s.' % proto_field_name |
| 433 |
| 434 # We define a setter just so we can throw an exception with a more |
| 435 # helpful error message. |
| 436 def setter(self, new_value): |
| 437 raise AttributeError('Assignment not allowed to repeated field ' |
| 438 '"%s" in protocol message object.' % proto_field_name) |
| 439 |
| 440 doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name |
| 441 setattr(cls, property_name, property(getter, setter, doc=doc)) |
| 442 |
| 443 |
| 444 def _AddPropertiesForNonRepeatedScalarField(field, cls): |
| 445 """Adds a public property for a nonrepeated, scalar protocol message field. |
| 446 Clients can use this property to get and directly set the value of the field. |
| 447 Note that when the client sets the value of a field by using this property, |
| 448 all necessary "has" bits are set as a side-effect, and we also perform |
| 449 type-checking. |
| 450 |
| 451 Args: |
| 452 field: A FieldDescriptor for this field. |
| 453 cls: The class we're constructing. |
| 454 """ |
| 455 proto_field_name = field.name |
| 456 property_name = _PropertyName(proto_field_name) |
| 457 type_checker = type_checkers.GetTypeChecker(field) |
| 458 default_value = field.default_value |
| 459 valid_values = set() |
| 460 |
| 461 def getter(self): |
| 462 # TODO(protobuf-team): This may be broken since there may not be |
| 463 # default_value. Combine with has_default_value somehow. |
| 464 return self._fields.get(field, default_value) |
| 465 getter.__module__ = None |
| 466 getter.__doc__ = 'Getter for %s.' % proto_field_name |
| 467 def field_setter(self, new_value): |
| 468 # pylint: disable=protected-access |
| 469 self._fields[field] = type_checker.CheckValue(new_value) |
| 470 # Check _cached_byte_size_dirty inline to improve performance, since scalar |
| 471 # setters are called frequently. |
| 472 if not self._cached_byte_size_dirty: |
| 473 self._Modified() |
| 474 |
| 475 if field.containing_oneof is not None: |
| 476 def setter(self, new_value): |
| 477 field_setter(self, new_value) |
| 478 self._UpdateOneofState(field) |
| 479 else: |
| 480 setter = field_setter |
| 481 |
| 482 setter.__module__ = None |
| 483 setter.__doc__ = 'Setter for %s.' % proto_field_name |
| 484 |
| 485 # Add a property to encapsulate the getter/setter. |
| 486 doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name |
| 487 setattr(cls, property_name, property(getter, setter, doc=doc)) |
| 488 |
| 489 |
| 490 def _AddPropertiesForNonRepeatedCompositeField(field, cls): |
| 491 """Adds a public property for a nonrepeated, composite protocol message field. |
| 492 A composite field is a "group" or "message" field. |
| 493 |
| 494 Clients can use this property to get the value of the field, but cannot |
| 495 assign to the property directly. |
| 496 |
| 497 Args: |
| 498 field: A FieldDescriptor for this field. |
| 499 cls: The class we're constructing. |
| 500 """ |
| 501 # TODO(robinson): Remove duplication with similar method |
| 502 # for non-repeated scalars. |
| 503 proto_field_name = field.name |
| 504 property_name = _PropertyName(proto_field_name) |
| 505 |
| 506 # TODO(komarek): Can anyone explain to me why we cache the message_type this |
| 507 # way, instead of referring to field.message_type inside of getter(self)? |
| 508 # What if someone sets message_type later on (which makes for simpler |
| 509 # dyanmic proto descriptor and class creation code). |
| 510 message_type = field.message_type |
| 511 |
| 512 def getter(self): |
| 513 field_value = self._fields.get(field) |
| 514 if field_value is None: |
| 515 # Construct a new object to represent this field. |
| 516 field_value = message_type._concrete_class() # use field.message_type? |
| 517 field_value._SetListener( |
| 518 _OneofListener(self, field) |
| 519 if field.containing_oneof is not None |
| 520 else self._listener_for_children) |
| 521 |
| 522 # Atomically check if another thread has preempted us and, if not, swap |
| 523 # in the new object we just created. If someone has preempted us, we |
| 524 # take that object and discard ours. |
| 525 # WARNING: We are relying on setdefault() being atomic. This is true |
| 526 # in CPython but we haven't investigated others. This warning appears |
| 527 # in several other locations in this file. |
| 528 field_value = self._fields.setdefault(field, field_value) |
| 529 return field_value |
| 530 getter.__module__ = None |
| 531 getter.__doc__ = 'Getter for %s.' % proto_field_name |
| 532 |
| 533 # We define a setter just so we can throw an exception with a more |
| 534 # helpful error message. |
| 535 def setter(self, new_value): |
| 536 raise AttributeError('Assignment not allowed to composite field ' |
| 537 '"%s" in protocol message object.' % proto_field_name) |
| 538 |
| 539 # Add a property to encapsulate the getter. |
| 540 doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name |
| 541 setattr(cls, property_name, property(getter, setter, doc=doc)) |
| 542 |
| 543 |
| 544 def _AddPropertiesForExtensions(descriptor, cls): |
| 545 """Adds properties for all fields in this protocol message type.""" |
| 546 extension_dict = descriptor.extensions_by_name |
| 547 for extension_name, extension_field in extension_dict.iteritems(): |
| 548 constant_name = extension_name.upper() + "_FIELD_NUMBER" |
| 549 setattr(cls, constant_name, extension_field.number) |
| 550 |
| 551 |
| 552 def _AddStaticMethods(cls): |
| 553 # TODO(robinson): This probably needs to be thread-safe(?) |
| 554 def RegisterExtension(extension_handle): |
| 555 extension_handle.containing_type = cls.DESCRIPTOR |
| 556 _AttachFieldHelpers(cls, extension_handle) |
| 557 |
| 558 # Try to insert our extension, failing if an extension with the same number |
| 559 # already exists. |
| 560 actual_handle = cls._extensions_by_number.setdefault( |
| 561 extension_handle.number, extension_handle) |
| 562 if actual_handle is not extension_handle: |
| 563 raise AssertionError( |
| 564 'Extensions "%s" and "%s" both try to extend message type "%s" with ' |
| 565 'field number %d.' % |
| 566 (extension_handle.full_name, actual_handle.full_name, |
| 567 cls.DESCRIPTOR.full_name, extension_handle.number)) |
| 568 |
| 569 cls._extensions_by_name[extension_handle.full_name] = extension_handle |
| 570 |
| 571 handle = extension_handle # avoid line wrapping |
| 572 if _IsMessageSetExtension(handle): |
| 573 # MessageSet extension. Also register under type name. |
| 574 cls._extensions_by_name[ |
| 575 extension_handle.message_type.full_name] = extension_handle |
| 576 |
| 577 cls.RegisterExtension = staticmethod(RegisterExtension) |
| 578 |
| 579 def FromString(s): |
| 580 message = cls() |
| 581 message.MergeFromString(s) |
| 582 return message |
| 583 cls.FromString = staticmethod(FromString) |
| 584 |
| 585 |
| 586 def _IsPresent(item): |
| 587 """Given a (FieldDescriptor, value) tuple from _fields, return true if the |
| 588 value should be included in the list returned by ListFields().""" |
| 589 |
| 590 if item[0].label == _FieldDescriptor.LABEL_REPEATED: |
| 591 return bool(item[1]) |
| 592 elif item[0].cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: |
| 593 return item[1]._is_present_in_parent |
| 594 else: |
| 595 return True |
| 596 |
| 597 |
| 598 def _AddListFieldsMethod(message_descriptor, cls): |
| 599 """Helper for _AddMessageMethods().""" |
| 600 |
| 601 def ListFields(self): |
| 602 all_fields = [item for item in self._fields.iteritems() if _IsPresent(item)] |
| 603 all_fields.sort(key = lambda item: item[0].number) |
| 604 return all_fields |
| 605 |
| 606 cls.ListFields = ListFields |
| 607 |
| 608 |
| 609 def _AddHasFieldMethod(message_descriptor, cls): |
| 610 """Helper for _AddMessageMethods().""" |
| 611 |
| 612 singular_fields = {} |
| 613 for field in message_descriptor.fields: |
| 614 if field.label != _FieldDescriptor.LABEL_REPEATED: |
| 615 singular_fields[field.name] = field |
| 616 # Fields inside oneofs are never repeated (enforced by the compiler). |
| 617 for field in message_descriptor.oneofs: |
| 618 singular_fields[field.name] = field |
| 619 |
| 620 def HasField(self, field_name): |
| 621 try: |
| 622 field = singular_fields[field_name] |
| 623 except KeyError: |
| 624 raise ValueError( |
| 625 'Protocol message has no singular "%s" field.' % field_name) |
| 626 |
| 627 if isinstance(field, descriptor_mod.OneofDescriptor): |
| 628 try: |
| 629 return HasField(self, self._oneofs[field].name) |
| 630 except KeyError: |
| 631 return False |
| 632 else: |
| 633 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: |
| 634 value = self._fields.get(field) |
| 635 return value is not None and value._is_present_in_parent |
| 636 else: |
| 637 return field in self._fields |
| 638 |
| 639 cls.HasField = HasField |
| 640 |
| 641 |
| 642 def _AddClearFieldMethod(message_descriptor, cls): |
| 643 """Helper for _AddMessageMethods().""" |
| 644 def ClearField(self, field_name): |
| 645 try: |
| 646 field = message_descriptor.fields_by_name[field_name] |
| 647 except KeyError: |
| 648 try: |
| 649 field = message_descriptor.oneofs_by_name[field_name] |
| 650 if field in self._oneofs: |
| 651 field = self._oneofs[field] |
| 652 else: |
| 653 return |
| 654 except KeyError: |
| 655 raise ValueError('Protocol message has no "%s" field.' % field_name) |
| 656 |
| 657 if field in self._fields: |
| 658 # Note: If the field is a sub-message, its listener will still point |
| 659 # at us. That's fine, because the worst than can happen is that it |
| 660 # will call _Modified() and invalidate our byte size. Big deal. |
| 661 del self._fields[field] |
| 662 |
| 663 if self._oneofs.get(field.containing_oneof, None) is field: |
| 664 del self._oneofs[field.containing_oneof] |
| 665 |
| 666 # Always call _Modified() -- even if nothing was changed, this is |
| 667 # a mutating method, and thus calling it should cause the field to become |
| 668 # present in the parent message. |
| 669 self._Modified() |
| 670 |
| 671 cls.ClearField = ClearField |
| 672 |
| 673 |
| 674 def _AddClearExtensionMethod(cls): |
| 675 """Helper for _AddMessageMethods().""" |
| 676 def ClearExtension(self, extension_handle): |
| 677 _VerifyExtensionHandle(self, extension_handle) |
| 678 |
| 679 # Similar to ClearField(), above. |
| 680 if extension_handle in self._fields: |
| 681 del self._fields[extension_handle] |
| 682 self._Modified() |
| 683 cls.ClearExtension = ClearExtension |
| 684 |
| 685 |
| 686 def _AddClearMethod(message_descriptor, cls): |
| 687 """Helper for _AddMessageMethods().""" |
| 688 def Clear(self): |
| 689 # Clear fields. |
| 690 self._fields = {} |
| 691 self._unknown_fields = () |
| 692 self._Modified() |
| 693 cls.Clear = Clear |
| 694 |
| 695 |
| 696 def _AddHasExtensionMethod(cls): |
| 697 """Helper for _AddMessageMethods().""" |
| 698 def HasExtension(self, extension_handle): |
| 699 _VerifyExtensionHandle(self, extension_handle) |
| 700 if extension_handle.label == _FieldDescriptor.LABEL_REPEATED: |
| 701 raise KeyError('"%s" is repeated.' % extension_handle.full_name) |
| 702 |
| 703 if extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: |
| 704 value = self._fields.get(extension_handle) |
| 705 return value is not None and value._is_present_in_parent |
| 706 else: |
| 707 return extension_handle in self._fields |
| 708 cls.HasExtension = HasExtension |
| 709 |
| 710 |
| 711 def _AddEqualsMethod(message_descriptor, cls): |
| 712 """Helper for _AddMessageMethods().""" |
| 713 def __eq__(self, other): |
| 714 if (not isinstance(other, message_mod.Message) or |
| 715 other.DESCRIPTOR != self.DESCRIPTOR): |
| 716 return False |
| 717 |
| 718 if self is other: |
| 719 return True |
| 720 |
| 721 if not self.ListFields() == other.ListFields(): |
| 722 return False |
| 723 |
| 724 # Sort unknown fields because their order shouldn't affect equality test. |
| 725 unknown_fields = list(self._unknown_fields) |
| 726 unknown_fields.sort() |
| 727 other_unknown_fields = list(other._unknown_fields) |
| 728 other_unknown_fields.sort() |
| 729 |
| 730 return unknown_fields == other_unknown_fields |
| 731 |
| 732 cls.__eq__ = __eq__ |
| 733 |
| 734 |
| 735 def _AddStrMethod(message_descriptor, cls): |
| 736 """Helper for _AddMessageMethods().""" |
| 737 def __str__(self): |
| 738 return text_format.MessageToString(self) |
| 739 cls.__str__ = __str__ |
| 740 |
| 741 |
| 742 def _AddUnicodeMethod(unused_message_descriptor, cls): |
| 743 """Helper for _AddMessageMethods().""" |
| 744 |
| 745 def __unicode__(self): |
| 746 return text_format.MessageToString(self, as_utf8=True).decode('utf-8') |
| 747 cls.__unicode__ = __unicode__ |
| 748 |
| 749 |
| 750 def _AddSetListenerMethod(cls): |
| 751 """Helper for _AddMessageMethods().""" |
| 752 def SetListener(self, listener): |
| 753 if listener is None: |
| 754 self._listener = message_listener_mod.NullMessageListener() |
| 755 else: |
| 756 self._listener = listener |
| 757 cls._SetListener = SetListener |
| 758 |
| 759 |
| 760 def _BytesForNonRepeatedElement(value, field_number, field_type): |
| 761 """Returns the number of bytes needed to serialize a non-repeated element. |
| 762 The returned byte count includes space for tag information and any |
| 763 other additional space associated with serializing value. |
| 764 |
| 765 Args: |
| 766 value: Value we're serializing. |
| 767 field_number: Field number of this value. (Since the field number |
| 768 is stored as part of a varint-encoded tag, this has an impact |
| 769 on the total bytes required to serialize the value). |
| 770 field_type: The type of the field. One of the TYPE_* constants |
| 771 within FieldDescriptor. |
| 772 """ |
| 773 try: |
| 774 fn = type_checkers.TYPE_TO_BYTE_SIZE_FN[field_type] |
| 775 return fn(field_number, value) |
| 776 except KeyError: |
| 777 raise message_mod.EncodeError('Unrecognized field type: %d' % field_type) |
| 778 |
| 779 |
| 780 def _AddByteSizeMethod(message_descriptor, cls): |
| 781 """Helper for _AddMessageMethods().""" |
| 782 |
| 783 def ByteSize(self): |
| 784 if not self._cached_byte_size_dirty: |
| 785 return self._cached_byte_size |
| 786 |
| 787 size = 0 |
| 788 for field_descriptor, field_value in self.ListFields(): |
| 789 size += field_descriptor._sizer(field_value) |
| 790 |
| 791 for tag_bytes, value_bytes in self._unknown_fields: |
| 792 size += len(tag_bytes) + len(value_bytes) |
| 793 |
| 794 self._cached_byte_size = size |
| 795 self._cached_byte_size_dirty = False |
| 796 self._listener_for_children.dirty = False |
| 797 return size |
| 798 |
| 799 cls.ByteSize = ByteSize |
| 800 |
| 801 |
| 802 def _AddSerializeToStringMethod(message_descriptor, cls): |
| 803 """Helper for _AddMessageMethods().""" |
| 804 |
| 805 def SerializeToString(self): |
| 806 # Check if the message has all of its required fields set. |
| 807 errors = [] |
| 808 if not self.IsInitialized(): |
| 809 raise message_mod.EncodeError( |
| 810 'Message %s is missing required fields: %s' % ( |
| 811 self.DESCRIPTOR.full_name, ','.join(self.FindInitializationErrors()))) |
| 812 return self.SerializePartialToString() |
| 813 cls.SerializeToString = SerializeToString |
| 814 |
| 815 |
| 816 def _AddSerializePartialToStringMethod(message_descriptor, cls): |
| 817 """Helper for _AddMessageMethods().""" |
| 818 |
| 819 def SerializePartialToString(self): |
| 820 out = BytesIO() |
| 821 self._InternalSerialize(out.write) |
| 822 return out.getvalue() |
| 823 cls.SerializePartialToString = SerializePartialToString |
| 824 |
| 825 def InternalSerialize(self, write_bytes): |
| 826 for field_descriptor, field_value in self.ListFields(): |
| 827 field_descriptor._encoder(write_bytes, field_value) |
| 828 for tag_bytes, value_bytes in self._unknown_fields: |
| 829 write_bytes(tag_bytes) |
| 830 write_bytes(value_bytes) |
| 831 cls._InternalSerialize = InternalSerialize |
| 832 |
| 833 |
| 834 def _AddMergeFromStringMethod(message_descriptor, cls): |
| 835 """Helper for _AddMessageMethods().""" |
| 836 def MergeFromString(self, serialized): |
| 837 length = len(serialized) |
| 838 try: |
| 839 if self._InternalParse(serialized, 0, length) != length: |
| 840 # The only reason _InternalParse would return early is if it |
| 841 # encountered an end-group tag. |
| 842 raise message_mod.DecodeError('Unexpected end-group tag.') |
| 843 except (IndexError, TypeError): |
| 844 # Now ord(buf[p:p+1]) == ord('') gets TypeError. |
| 845 raise message_mod.DecodeError('Truncated message.') |
| 846 except struct.error, e: |
| 847 raise message_mod.DecodeError(e) |
| 848 return length # Return this for legacy reasons. |
| 849 cls.MergeFromString = MergeFromString |
| 850 |
| 851 local_ReadTag = decoder.ReadTag |
| 852 local_SkipField = decoder.SkipField |
| 853 decoders_by_tag = cls._decoders_by_tag |
| 854 |
| 855 def InternalParse(self, buffer, pos, end): |
| 856 self._Modified() |
| 857 field_dict = self._fields |
| 858 unknown_field_list = self._unknown_fields |
| 859 while pos != end: |
| 860 (tag_bytes, new_pos) = local_ReadTag(buffer, pos) |
| 861 field_decoder = decoders_by_tag.get(tag_bytes) |
| 862 if field_decoder is None: |
| 863 value_start_pos = new_pos |
| 864 new_pos = local_SkipField(buffer, new_pos, end, tag_bytes) |
| 865 if new_pos == -1: |
| 866 return pos |
| 867 if not unknown_field_list: |
| 868 unknown_field_list = self._unknown_fields = [] |
| 869 unknown_field_list.append((tag_bytes, buffer[value_start_pos:new_pos])) |
| 870 pos = new_pos |
| 871 else: |
| 872 pos = field_decoder(buffer, new_pos, end, self, field_dict) |
| 873 return pos |
| 874 cls._InternalParse = InternalParse |
| 875 |
| 876 |
| 877 def _AddIsInitializedMethod(message_descriptor, cls): |
| 878 """Adds the IsInitialized and FindInitializationError methods to the |
| 879 protocol message class.""" |
| 880 |
| 881 required_fields = [field for field in message_descriptor.fields |
| 882 if field.label == _FieldDescriptor.LABEL_REQUIRED] |
| 883 |
| 884 def IsInitialized(self, errors=None): |
| 885 """Checks if all required fields of a message are set. |
| 886 |
| 887 Args: |
| 888 errors: A list which, if provided, will be populated with the field |
| 889 paths of all missing required fields. |
| 890 |
| 891 Returns: |
| 892 True iff the specified message has all required fields set. |
| 893 """ |
| 894 |
| 895 # Performance is critical so we avoid HasField() and ListFields(). |
| 896 |
| 897 for field in required_fields: |
| 898 if (field not in self._fields or |
| 899 (field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE and |
| 900 not self._fields[field]._is_present_in_parent)): |
| 901 if errors is not None: |
| 902 errors.extend(self.FindInitializationErrors()) |
| 903 return False |
| 904 |
| 905 for field, value in list(self._fields.items()): # dict can change size! |
| 906 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: |
| 907 if field.label == _FieldDescriptor.LABEL_REPEATED: |
| 908 for element in value: |
| 909 if not element.IsInitialized(): |
| 910 if errors is not None: |
| 911 errors.extend(self.FindInitializationErrors()) |
| 912 return False |
| 913 elif value._is_present_in_parent and not value.IsInitialized(): |
| 914 if errors is not None: |
| 915 errors.extend(self.FindInitializationErrors()) |
| 916 return False |
| 917 |
| 918 return True |
| 919 |
| 920 cls.IsInitialized = IsInitialized |
| 921 |
| 922 def FindInitializationErrors(self): |
| 923 """Finds required fields which are not initialized. |
| 924 |
| 925 Returns: |
| 926 A list of strings. Each string is a path to an uninitialized field from |
| 927 the top-level message, e.g. "foo.bar[5].baz". |
| 928 """ |
| 929 |
| 930 errors = [] # simplify things |
| 931 |
| 932 for field in required_fields: |
| 933 if not self.HasField(field.name): |
| 934 errors.append(field.name) |
| 935 |
| 936 for field, value in self.ListFields(): |
| 937 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: |
| 938 if field.is_extension: |
| 939 name = "(%s)" % field.full_name |
| 940 else: |
| 941 name = field.name |
| 942 |
| 943 if field.label == _FieldDescriptor.LABEL_REPEATED: |
| 944 for i in xrange(len(value)): |
| 945 element = value[i] |
| 946 prefix = "%s[%d]." % (name, i) |
| 947 sub_errors = element.FindInitializationErrors() |
| 948 errors += [ prefix + error for error in sub_errors ] |
| 949 else: |
| 950 prefix = name + "." |
| 951 sub_errors = value.FindInitializationErrors() |
| 952 errors += [ prefix + error for error in sub_errors ] |
| 953 |
| 954 return errors |
| 955 |
| 956 cls.FindInitializationErrors = FindInitializationErrors |
| 957 |
| 958 |
| 959 def _AddMergeFromMethod(cls): |
| 960 LABEL_REPEATED = _FieldDescriptor.LABEL_REPEATED |
| 961 CPPTYPE_MESSAGE = _FieldDescriptor.CPPTYPE_MESSAGE |
| 962 |
| 963 def MergeFrom(self, msg): |
| 964 if not isinstance(msg, cls): |
| 965 raise TypeError( |
| 966 "Parameter to MergeFrom() must be instance of same class: " |
| 967 "expected %s got %s." % (cls.__name__, type(msg).__name__)) |
| 968 |
| 969 assert msg is not self |
| 970 self._Modified() |
| 971 |
| 972 fields = self._fields |
| 973 |
| 974 for field, value in msg._fields.iteritems(): |
| 975 if field.label == LABEL_REPEATED: |
| 976 field_value = fields.get(field) |
| 977 if field_value is None: |
| 978 # Construct a new object to represent this field. |
| 979 field_value = field._default_constructor(self) |
| 980 fields[field] = field_value |
| 981 field_value.MergeFrom(value) |
| 982 elif field.cpp_type == CPPTYPE_MESSAGE: |
| 983 if value._is_present_in_parent: |
| 984 field_value = fields.get(field) |
| 985 if field_value is None: |
| 986 # Construct a new object to represent this field. |
| 987 field_value = field._default_constructor(self) |
| 988 fields[field] = field_value |
| 989 field_value.MergeFrom(value) |
| 990 else: |
| 991 self._fields[field] = value |
| 992 |
| 993 if msg._unknown_fields: |
| 994 if not self._unknown_fields: |
| 995 self._unknown_fields = [] |
| 996 self._unknown_fields.extend(msg._unknown_fields) |
| 997 |
| 998 cls.MergeFrom = MergeFrom |
| 999 |
| 1000 |
| 1001 def _AddWhichOneofMethod(message_descriptor, cls): |
| 1002 def WhichOneof(self, oneof_name): |
| 1003 """Returns the name of the currently set field inside a oneof, or None.""" |
| 1004 try: |
| 1005 field = message_descriptor.oneofs_by_name[oneof_name] |
| 1006 except KeyError: |
| 1007 raise ValueError( |
| 1008 'Protocol message has no oneof "%s" field.' % oneof_name) |
| 1009 |
| 1010 nested_field = self._oneofs.get(field, None) |
| 1011 if nested_field is not None and self.HasField(nested_field.name): |
| 1012 return nested_field.name |
| 1013 else: |
| 1014 return None |
| 1015 |
| 1016 cls.WhichOneof = WhichOneof |
| 1017 |
| 1018 |
| 1019 def _AddMessageMethods(message_descriptor, cls): |
| 1020 """Adds implementations of all Message methods to cls.""" |
| 1021 _AddListFieldsMethod(message_descriptor, cls) |
| 1022 _AddHasFieldMethod(message_descriptor, cls) |
| 1023 _AddClearFieldMethod(message_descriptor, cls) |
| 1024 if message_descriptor.is_extendable: |
| 1025 _AddClearExtensionMethod(cls) |
| 1026 _AddHasExtensionMethod(cls) |
| 1027 _AddClearMethod(message_descriptor, cls) |
| 1028 _AddEqualsMethod(message_descriptor, cls) |
| 1029 _AddStrMethod(message_descriptor, cls) |
| 1030 _AddUnicodeMethod(message_descriptor, cls) |
| 1031 _AddSetListenerMethod(cls) |
| 1032 _AddByteSizeMethod(message_descriptor, cls) |
| 1033 _AddSerializeToStringMethod(message_descriptor, cls) |
| 1034 _AddSerializePartialToStringMethod(message_descriptor, cls) |
| 1035 _AddMergeFromStringMethod(message_descriptor, cls) |
| 1036 _AddIsInitializedMethod(message_descriptor, cls) |
| 1037 _AddMergeFromMethod(cls) |
| 1038 _AddWhichOneofMethod(message_descriptor, cls) |
| 1039 |
| 1040 def _AddPrivateHelperMethods(message_descriptor, cls): |
| 1041 """Adds implementation of private helper methods to cls.""" |
| 1042 |
| 1043 def Modified(self): |
| 1044 """Sets the _cached_byte_size_dirty bit to true, |
| 1045 and propagates this to our listener iff this was a state change. |
| 1046 """ |
| 1047 |
| 1048 # Note: Some callers check _cached_byte_size_dirty before calling |
| 1049 # _Modified() as an extra optimization. So, if this method is ever |
| 1050 # changed such that it does stuff even when _cached_byte_size_dirty is |
| 1051 # already true, the callers need to be updated. |
| 1052 if not self._cached_byte_size_dirty: |
| 1053 self._cached_byte_size_dirty = True |
| 1054 self._listener_for_children.dirty = True |
| 1055 self._is_present_in_parent = True |
| 1056 self._listener.Modified() |
| 1057 |
| 1058 def _UpdateOneofState(self, field): |
| 1059 """Sets field as the active field in its containing oneof. |
| 1060 |
| 1061 Will also delete currently active field in the oneof, if it is different |
| 1062 from the argument. Does not mark the message as modified. |
| 1063 """ |
| 1064 other_field = self._oneofs.setdefault(field.containing_oneof, field) |
| 1065 if other_field is not field: |
| 1066 del self._fields[other_field] |
| 1067 self._oneofs[field.containing_oneof] = field |
| 1068 |
| 1069 cls._Modified = Modified |
| 1070 cls.SetInParent = Modified |
| 1071 cls._UpdateOneofState = _UpdateOneofState |
| 1072 |
| 1073 |
| 1074 class _Listener(object): |
| 1075 |
| 1076 """MessageListener implementation that a parent message registers with its |
| 1077 child message. |
| 1078 |
| 1079 In order to support semantics like: |
| 1080 |
| 1081 foo.bar.baz.qux = 23 |
| 1082 assert foo.HasField('bar') |
| 1083 |
| 1084 ...child objects must have back references to their parents. |
| 1085 This helper class is at the heart of this support. |
| 1086 """ |
| 1087 |
| 1088 def __init__(self, parent_message): |
| 1089 """Args: |
| 1090 parent_message: The message whose _Modified() method we should call when |
| 1091 we receive Modified() messages. |
| 1092 """ |
| 1093 # This listener establishes a back reference from a child (contained) object |
| 1094 # to its parent (containing) object. We make this a weak reference to avoid |
| 1095 # creating cyclic garbage when the client finishes with the 'parent' object |
| 1096 # in the tree. |
| 1097 if isinstance(parent_message, weakref.ProxyType): |
| 1098 self._parent_message_weakref = parent_message |
| 1099 else: |
| 1100 self._parent_message_weakref = weakref.proxy(parent_message) |
| 1101 |
| 1102 # As an optimization, we also indicate directly on the listener whether |
| 1103 # or not the parent message is dirty. This way we can avoid traversing |
| 1104 # up the tree in the common case. |
| 1105 self.dirty = False |
| 1106 |
| 1107 def Modified(self): |
| 1108 if self.dirty: |
| 1109 return |
| 1110 try: |
| 1111 # Propagate the signal to our parents iff this is the first field set. |
| 1112 self._parent_message_weakref._Modified() |
| 1113 except ReferenceError: |
| 1114 # We can get here if a client has kept a reference to a child object, |
| 1115 # and is now setting a field on it, but the child's parent has been |
| 1116 # garbage-collected. This is not an error. |
| 1117 pass |
| 1118 |
| 1119 |
| 1120 class _OneofListener(_Listener): |
| 1121 """Special listener implementation for setting composite oneof fields.""" |
| 1122 |
| 1123 def __init__(self, parent_message, field): |
| 1124 """Args: |
| 1125 parent_message: The message whose _Modified() method we should call when |
| 1126 we receive Modified() messages. |
| 1127 field: The descriptor of the field being set in the parent message. |
| 1128 """ |
| 1129 super(_OneofListener, self).__init__(parent_message) |
| 1130 self._field = field |
| 1131 |
| 1132 def Modified(self): |
| 1133 """Also updates the state of the containing oneof in the parent message.""" |
| 1134 try: |
| 1135 self._parent_message_weakref._UpdateOneofState(self._field) |
| 1136 super(_OneofListener, self).Modified() |
| 1137 except ReferenceError: |
| 1138 pass |
| 1139 |
| 1140 |
| 1141 # TODO(robinson): Move elsewhere? This file is getting pretty ridiculous... |
| 1142 # TODO(robinson): Unify error handling of "unknown extension" crap. |
| 1143 # TODO(robinson): Support iteritems()-style iteration over all |
| 1144 # extensions with the "has" bits turned on? |
| 1145 class _ExtensionDict(object): |
| 1146 |
| 1147 """Dict-like container for supporting an indexable "Extensions" |
| 1148 field on proto instances. |
| 1149 |
| 1150 Note that in all cases we expect extension handles to be |
| 1151 FieldDescriptors. |
| 1152 """ |
| 1153 |
| 1154 def __init__(self, extended_message): |
| 1155 """extended_message: Message instance for which we are the Extensions dict. |
| 1156 """ |
| 1157 |
| 1158 self._extended_message = extended_message |
| 1159 |
| 1160 def __getitem__(self, extension_handle): |
| 1161 """Returns the current value of the given extension handle.""" |
| 1162 |
| 1163 _VerifyExtensionHandle(self._extended_message, extension_handle) |
| 1164 |
| 1165 result = self._extended_message._fields.get(extension_handle) |
| 1166 if result is not None: |
| 1167 return result |
| 1168 |
| 1169 if extension_handle.label == _FieldDescriptor.LABEL_REPEATED: |
| 1170 result = extension_handle._default_constructor(self._extended_message) |
| 1171 elif extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: |
| 1172 result = extension_handle.message_type._concrete_class() |
| 1173 try: |
| 1174 result._SetListener(self._extended_message._listener_for_children) |
| 1175 except ReferenceError: |
| 1176 pass |
| 1177 else: |
| 1178 # Singular scalar -- just return the default without inserting into the |
| 1179 # dict. |
| 1180 return extension_handle.default_value |
| 1181 |
| 1182 # Atomically check if another thread has preempted us and, if not, swap |
| 1183 # in the new object we just created. If someone has preempted us, we |
| 1184 # take that object and discard ours. |
| 1185 # WARNING: We are relying on setdefault() being atomic. This is true |
| 1186 # in CPython but we haven't investigated others. This warning appears |
| 1187 # in several other locations in this file. |
| 1188 result = self._extended_message._fields.setdefault( |
| 1189 extension_handle, result) |
| 1190 |
| 1191 return result |
| 1192 |
| 1193 def __eq__(self, other): |
| 1194 if not isinstance(other, self.__class__): |
| 1195 return False |
| 1196 |
| 1197 my_fields = self._extended_message.ListFields() |
| 1198 other_fields = other._extended_message.ListFields() |
| 1199 |
| 1200 # Get rid of non-extension fields. |
| 1201 my_fields = [ field for field in my_fields if field.is_extension ] |
| 1202 other_fields = [ field for field in other_fields if field.is_extension ] |
| 1203 |
| 1204 return my_fields == other_fields |
| 1205 |
| 1206 def __ne__(self, other): |
| 1207 return not self == other |
| 1208 |
| 1209 def __hash__(self): |
| 1210 raise TypeError('unhashable object') |
| 1211 |
| 1212 # Note that this is only meaningful for non-repeated, scalar extension |
| 1213 # fields. Note also that we may have to call _Modified() when we do |
| 1214 # successfully set a field this way, to set any necssary "has" bits in the |
| 1215 # ancestors of the extended message. |
| 1216 def __setitem__(self, extension_handle, value): |
| 1217 """If extension_handle specifies a non-repeated, scalar extension |
| 1218 field, sets the value of that field. |
| 1219 """ |
| 1220 |
| 1221 _VerifyExtensionHandle(self._extended_message, extension_handle) |
| 1222 |
| 1223 if (extension_handle.label == _FieldDescriptor.LABEL_REPEATED or |
| 1224 extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE): |
| 1225 raise TypeError( |
| 1226 'Cannot assign to extension "%s" because it is a repeated or ' |
| 1227 'composite type.' % extension_handle.full_name) |
| 1228 |
| 1229 # It's slightly wasteful to lookup the type checker each time, |
| 1230 # but we expect this to be a vanishingly uncommon case anyway. |
| 1231 type_checker = type_checkers.GetTypeChecker( |
| 1232 extension_handle) |
| 1233 # pylint: disable=protected-access |
| 1234 self._extended_message._fields[extension_handle] = ( |
| 1235 type_checker.CheckValue(value)) |
| 1236 self._extended_message._Modified() |
| 1237 |
| 1238 def _FindExtensionByName(self, name): |
| 1239 """Tries to find a known extension with the specified name. |
| 1240 |
| 1241 Args: |
| 1242 name: Extension full name. |
| 1243 |
| 1244 Returns: |
| 1245 Extension field descriptor. |
| 1246 """ |
| 1247 return self._extended_message._extensions_by_name.get(name, None) |
OLD | NEW |