| OLD | NEW |
| (Empty) |
| 1 # Protocol Buffers - Google's data interchange format | |
| 2 # Copyright 2008 Google Inc. All rights reserved. | |
| 3 # http://code.google.com/p/protobuf/ | |
| 4 # | |
| 5 # Redistribution and use in source and binary forms, with or without | |
| 6 # modification, are permitted provided that the following conditions are | |
| 7 # met: | |
| 8 # | |
| 9 # * Redistributions of source code must retain the above copyright | |
| 10 # notice, this list of conditions and the following disclaimer. | |
| 11 # * Redistributions in binary form must reproduce the above | |
| 12 # copyright notice, this list of conditions and the following disclaimer | |
| 13 # in the documentation and/or other materials provided with the | |
| 14 # distribution. | |
| 15 # * Neither the name of Google Inc. nor the names of its | |
| 16 # contributors may be used to endorse or promote products derived from | |
| 17 # this software without specific prior written permission. | |
| 18 # | |
| 19 # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS | |
| 20 # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT | |
| 21 # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR | |
| 22 # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT | |
| 23 # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, | |
| 24 # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT | |
| 25 # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, | |
| 26 # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY | |
| 27 # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT | |
| 28 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | |
| 29 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |
| 30 | |
| 31 # Keep it Python2.5 compatible for GAE. | |
| 32 # | |
| 33 # Copyright 2007 Google Inc. All Rights Reserved. | |
| 34 # | |
| 35 # This code is meant to work on Python 2.4 and above only. | |
| 36 # | |
| 37 # TODO(robinson): Helpers for verbose, common checks like seeing if a | |
| 38 # descriptor's cpp_type is CPPTYPE_MESSAGE. | |
| 39 | |
| 40 """Contains a metaclass and helper functions used to create | |
| 41 protocol message classes from Descriptor objects at runtime. | |
| 42 | |
| 43 Recall that a metaclass is the "type" of a class. | |
| 44 (A class is to a metaclass what an instance is to a class.) | |
| 45 | |
| 46 In this case, we use the GeneratedProtocolMessageType metaclass | |
| 47 to inject all the useful functionality into the classes | |
| 48 output by the protocol compiler at compile-time. | |
| 49 | |
| 50 The upshot of all this is that the real implementation | |
| 51 details for ALL pure-Python protocol buffers are *here in | |
| 52 this file*. | |
| 53 """ | |
| 54 | |
| 55 __author__ = 'robinson@google.com (Will Robinson)' | |
| 56 | |
| 57 import sys | |
| 58 if sys.version_info[0] < 3: | |
| 59 try: | |
| 60 from cStringIO import StringIO as BytesIO | |
| 61 except ImportError: | |
| 62 from StringIO import StringIO as BytesIO | |
| 63 import copy_reg as copyreg | |
| 64 else: | |
| 65 from io import BytesIO | |
| 66 import copyreg | |
| 67 import struct | |
| 68 import weakref | |
| 69 | |
| 70 # We use "as" to avoid name collisions with variables. | |
| 71 from google.protobuf.internal import containers | |
| 72 from google.protobuf.internal import decoder | |
| 73 from google.protobuf.internal import encoder | |
| 74 from google.protobuf.internal import enum_type_wrapper | |
| 75 from google.protobuf.internal import message_listener as message_listener_mod | |
| 76 from google.protobuf.internal import type_checkers | |
| 77 from google.protobuf.internal import wire_format | |
| 78 from google.protobuf import descriptor as descriptor_mod | |
| 79 from google.protobuf import message as message_mod | |
| 80 from google.protobuf import text_format | |
| 81 | |
| 82 _FieldDescriptor = descriptor_mod.FieldDescriptor | |
| 83 | |
| 84 | |
| 85 def NewMessage(bases, descriptor, dictionary): | |
| 86 _AddClassAttributesForNestedExtensions(descriptor, dictionary) | |
| 87 _AddSlots(descriptor, dictionary) | |
| 88 return bases | |
| 89 | |
| 90 | |
| 91 def InitMessage(descriptor, cls): | |
| 92 cls._decoders_by_tag = {} | |
| 93 cls._extensions_by_name = {} | |
| 94 cls._extensions_by_number = {} | |
| 95 if (descriptor.has_options and | |
| 96 descriptor.GetOptions().message_set_wire_format): | |
| 97 cls._decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = ( | |
| 98 decoder.MessageSetItemDecoder(cls._extensions_by_number)) | |
| 99 | |
| 100 # Attach stuff to each FieldDescriptor for quick lookup later on. | |
| 101 for field in descriptor.fields: | |
| 102 _AttachFieldHelpers(cls, field) | |
| 103 | |
| 104 _AddEnumValues(descriptor, cls) | |
| 105 _AddInitMethod(descriptor, cls) | |
| 106 _AddPropertiesForFields(descriptor, cls) | |
| 107 _AddPropertiesForExtensions(descriptor, cls) | |
| 108 _AddStaticMethods(cls) | |
| 109 _AddMessageMethods(descriptor, cls) | |
| 110 _AddPrivateHelperMethods(descriptor, cls) | |
| 111 copyreg.pickle(cls, lambda obj: (cls, (), obj.__getstate__())) | |
| 112 | |
| 113 | |
| 114 # Stateless helpers for GeneratedProtocolMessageType below. | |
| 115 # Outside clients should not access these directly. | |
| 116 # | |
| 117 # I opted not to make any of these methods on the metaclass, to make it more | |
| 118 # clear that I'm not really using any state there and to keep clients from | |
| 119 # thinking that they have direct access to these construction helpers. | |
| 120 | |
| 121 | |
| 122 def _PropertyName(proto_field_name): | |
| 123 """Returns the name of the public property attribute which | |
| 124 clients can use to get and (in some cases) set the value | |
| 125 of a protocol message field. | |
| 126 | |
| 127 Args: | |
| 128 proto_field_name: The protocol message field name, exactly | |
| 129 as it appears (or would appear) in a .proto file. | |
| 130 """ | |
| 131 # TODO(robinson): Escape Python keywords (e.g., yield), and test this support. | |
| 132 # nnorwitz makes my day by writing: | |
| 133 # """ | |
| 134 # FYI. See the keyword module in the stdlib. This could be as simple as: | |
| 135 # | |
| 136 # if keyword.iskeyword(proto_field_name): | |
| 137 # return proto_field_name + "_" | |
| 138 # return proto_field_name | |
| 139 # """ | |
| 140 # Kenton says: The above is a BAD IDEA. People rely on being able to use | |
| 141 # getattr() and setattr() to reflectively manipulate field values. If we | |
| 142 # rename the properties, then every such user has to also make sure to apply | |
| 143 # the same transformation. Note that currently if you name a field "yield", | |
| 144 # you can still access it just fine using getattr/setattr -- it's not even | |
| 145 # that cumbersome to do so. | |
| 146 # TODO(kenton): Remove this method entirely if/when everyone agrees with my | |
| 147 # position. | |
| 148 return proto_field_name | |
| 149 | |
| 150 | |
| 151 def _VerifyExtensionHandle(message, extension_handle): | |
| 152 """Verify that the given extension handle is valid.""" | |
| 153 | |
| 154 if not isinstance(extension_handle, _FieldDescriptor): | |
| 155 raise KeyError('HasExtension() expects an extension handle, got: %s' % | |
| 156 extension_handle) | |
| 157 | |
| 158 if not extension_handle.is_extension: | |
| 159 raise KeyError('"%s" is not an extension.' % extension_handle.full_name) | |
| 160 | |
| 161 if not extension_handle.containing_type: | |
| 162 raise KeyError('"%s" is missing a containing_type.' | |
| 163 % extension_handle.full_name) | |
| 164 | |
| 165 if extension_handle.containing_type is not message.DESCRIPTOR: | |
| 166 raise KeyError('Extension "%s" extends message type "%s", but this ' | |
| 167 'message is of type "%s".' % | |
| 168 (extension_handle.full_name, | |
| 169 extension_handle.containing_type.full_name, | |
| 170 message.DESCRIPTOR.full_name)) | |
| 171 | |
| 172 | |
| 173 def _AddSlots(message_descriptor, dictionary): | |
| 174 """Adds a __slots__ entry to dictionary, containing the names of all valid | |
| 175 attributes for this message type. | |
| 176 | |
| 177 Args: | |
| 178 message_descriptor: A Descriptor instance describing this message type. | |
| 179 dictionary: Class dictionary to which we'll add a '__slots__' entry. | |
| 180 """ | |
| 181 dictionary['__slots__'] = ['_cached_byte_size', | |
| 182 '_cached_byte_size_dirty', | |
| 183 '_fields', | |
| 184 '_unknown_fields', | |
| 185 '_is_present_in_parent', | |
| 186 '_listener', | |
| 187 '_listener_for_children', | |
| 188 '__weakref__', | |
| 189 '_oneofs'] | |
| 190 | |
| 191 | |
| 192 def _IsMessageSetExtension(field): | |
| 193 return (field.is_extension and | |
| 194 field.containing_type.has_options and | |
| 195 field.containing_type.GetOptions().message_set_wire_format and | |
| 196 field.type == _FieldDescriptor.TYPE_MESSAGE and | |
| 197 field.message_type == field.extension_scope and | |
| 198 field.label == _FieldDescriptor.LABEL_OPTIONAL) | |
| 199 | |
| 200 | |
| 201 def _AttachFieldHelpers(cls, field_descriptor): | |
| 202 is_repeated = (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED) | |
| 203 is_packed = (field_descriptor.has_options and | |
| 204 field_descriptor.GetOptions().packed) | |
| 205 | |
| 206 if _IsMessageSetExtension(field_descriptor): | |
| 207 field_encoder = encoder.MessageSetItemEncoder(field_descriptor.number) | |
| 208 sizer = encoder.MessageSetItemSizer(field_descriptor.number) | |
| 209 else: | |
| 210 field_encoder = type_checkers.TYPE_TO_ENCODER[field_descriptor.type]( | |
| 211 field_descriptor.number, is_repeated, is_packed) | |
| 212 sizer = type_checkers.TYPE_TO_SIZER[field_descriptor.type]( | |
| 213 field_descriptor.number, is_repeated, is_packed) | |
| 214 | |
| 215 field_descriptor._encoder = field_encoder | |
| 216 field_descriptor._sizer = sizer | |
| 217 field_descriptor._default_constructor = _DefaultValueConstructorForField( | |
| 218 field_descriptor) | |
| 219 | |
| 220 def AddDecoder(wiretype, is_packed): | |
| 221 tag_bytes = encoder.TagBytes(field_descriptor.number, wiretype) | |
| 222 cls._decoders_by_tag[tag_bytes] = ( | |
| 223 type_checkers.TYPE_TO_DECODER[field_descriptor.type]( | |
| 224 field_descriptor.number, is_repeated, is_packed, | |
| 225 field_descriptor, field_descriptor._default_constructor)) | |
| 226 | |
| 227 AddDecoder(type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type], | |
| 228 False) | |
| 229 | |
| 230 if is_repeated and wire_format.IsTypePackable(field_descriptor.type): | |
| 231 # To support wire compatibility of adding packed = true, add a decoder for | |
| 232 # packed values regardless of the field's options. | |
| 233 AddDecoder(wire_format.WIRETYPE_LENGTH_DELIMITED, True) | |
| 234 | |
| 235 | |
| 236 def _AddClassAttributesForNestedExtensions(descriptor, dictionary): | |
| 237 extension_dict = descriptor.extensions_by_name | |
| 238 for extension_name, extension_field in extension_dict.iteritems(): | |
| 239 assert extension_name not in dictionary | |
| 240 dictionary[extension_name] = extension_field | |
| 241 | |
| 242 | |
| 243 def _AddEnumValues(descriptor, cls): | |
| 244 """Sets class-level attributes for all enum fields defined in this message. | |
| 245 | |
| 246 Also exporting a class-level object that can name enum values. | |
| 247 | |
| 248 Args: | |
| 249 descriptor: Descriptor object for this message type. | |
| 250 cls: Class we're constructing for this message type. | |
| 251 """ | |
| 252 for enum_type in descriptor.enum_types: | |
| 253 setattr(cls, enum_type.name, enum_type_wrapper.EnumTypeWrapper(enum_type)) | |
| 254 for enum_value in enum_type.values: | |
| 255 setattr(cls, enum_value.name, enum_value.number) | |
| 256 | |
| 257 | |
| 258 def _DefaultValueConstructorForField(field): | |
| 259 """Returns a function which returns a default value for a field. | |
| 260 | |
| 261 Args: | |
| 262 field: FieldDescriptor object for this field. | |
| 263 | |
| 264 The returned function has one argument: | |
| 265 message: Message instance containing this field, or a weakref proxy | |
| 266 of same. | |
| 267 | |
| 268 That function in turn returns a default value for this field. The default | |
| 269 value may refer back to |message| via a weak reference. | |
| 270 """ | |
| 271 | |
| 272 if field.label == _FieldDescriptor.LABEL_REPEATED: | |
| 273 if field.has_default_value and field.default_value != []: | |
| 274 raise ValueError('Repeated field default value not empty list: %s' % ( | |
| 275 field.default_value)) | |
| 276 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: | |
| 277 # We can't look at _concrete_class yet since it might not have | |
| 278 # been set. (Depends on order in which we initialize the classes). | |
| 279 message_type = field.message_type | |
| 280 def MakeRepeatedMessageDefault(message): | |
| 281 return containers.RepeatedCompositeFieldContainer( | |
| 282 message._listener_for_children, field.message_type) | |
| 283 return MakeRepeatedMessageDefault | |
| 284 else: | |
| 285 type_checker = type_checkers.GetTypeChecker(field) | |
| 286 def MakeRepeatedScalarDefault(message): | |
| 287 return containers.RepeatedScalarFieldContainer( | |
| 288 message._listener_for_children, type_checker) | |
| 289 return MakeRepeatedScalarDefault | |
| 290 | |
| 291 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: | |
| 292 # _concrete_class may not yet be initialized. | |
| 293 message_type = field.message_type | |
| 294 def MakeSubMessageDefault(message): | |
| 295 result = message_type._concrete_class() | |
| 296 result._SetListener(message._listener_for_children) | |
| 297 return result | |
| 298 return MakeSubMessageDefault | |
| 299 | |
| 300 def MakeScalarDefault(message): | |
| 301 # TODO(protobuf-team): This may be broken since there may not be | |
| 302 # default_value. Combine with has_default_value somehow. | |
| 303 return field.default_value | |
| 304 return MakeScalarDefault | |
| 305 | |
| 306 | |
| 307 def _AddInitMethod(message_descriptor, cls): | |
| 308 """Adds an __init__ method to cls.""" | |
| 309 fields = message_descriptor.fields | |
| 310 def init(self, **kwargs): | |
| 311 self._cached_byte_size = 0 | |
| 312 self._cached_byte_size_dirty = len(kwargs) > 0 | |
| 313 self._fields = {} | |
| 314 # Contains a mapping from oneof field descriptors to the descriptor | |
| 315 # of the currently set field in that oneof field. | |
| 316 self._oneofs = {} | |
| 317 | |
| 318 # _unknown_fields is () when empty for efficiency, and will be turned into | |
| 319 # a list if fields are added. | |
| 320 self._unknown_fields = () | |
| 321 self._is_present_in_parent = False | |
| 322 self._listener = message_listener_mod.NullMessageListener() | |
| 323 self._listener_for_children = _Listener(self) | |
| 324 for field_name, field_value in kwargs.iteritems(): | |
| 325 field = _GetFieldByName(message_descriptor, field_name) | |
| 326 if field is None: | |
| 327 raise TypeError("%s() got an unexpected keyword argument '%s'" % | |
| 328 (message_descriptor.name, field_name)) | |
| 329 if field.label == _FieldDescriptor.LABEL_REPEATED: | |
| 330 copy = field._default_constructor(self) | |
| 331 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: # Composite | |
| 332 for val in field_value: | |
| 333 copy.add().MergeFrom(val) | |
| 334 else: # Scalar | |
| 335 copy.extend(field_value) | |
| 336 self._fields[field] = copy | |
| 337 elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: | |
| 338 copy = field._default_constructor(self) | |
| 339 copy.MergeFrom(field_value) | |
| 340 self._fields[field] = copy | |
| 341 else: | |
| 342 setattr(self, field_name, field_value) | |
| 343 | |
| 344 init.__module__ = None | |
| 345 init.__doc__ = None | |
| 346 cls.__init__ = init | |
| 347 | |
| 348 | |
| 349 def _GetFieldByName(message_descriptor, field_name): | |
| 350 """Returns a field descriptor by field name. | |
| 351 | |
| 352 Args: | |
| 353 message_descriptor: A Descriptor describing all fields in message. | |
| 354 field_name: The name of the field to retrieve. | |
| 355 Returns: | |
| 356 The field descriptor associated with the field name. | |
| 357 """ | |
| 358 try: | |
| 359 return message_descriptor.fields_by_name[field_name] | |
| 360 except KeyError: | |
| 361 raise ValueError('Protocol message has no "%s" field.' % field_name) | |
| 362 | |
| 363 | |
| 364 def _AddPropertiesForFields(descriptor, cls): | |
| 365 """Adds properties for all fields in this protocol message type.""" | |
| 366 for field in descriptor.fields: | |
| 367 _AddPropertiesForField(field, cls) | |
| 368 | |
| 369 if descriptor.is_extendable: | |
| 370 # _ExtensionDict is just an adaptor with no state so we allocate a new one | |
| 371 # every time it is accessed. | |
| 372 cls.Extensions = property(lambda self: _ExtensionDict(self)) | |
| 373 | |
| 374 | |
| 375 def _AddPropertiesForField(field, cls): | |
| 376 """Adds a public property for a protocol message field. | |
| 377 Clients can use this property to get and (in the case | |
| 378 of non-repeated scalar fields) directly set the value | |
| 379 of a protocol message field. | |
| 380 | |
| 381 Args: | |
| 382 field: A FieldDescriptor for this field. | |
| 383 cls: The class we're constructing. | |
| 384 """ | |
| 385 # Catch it if we add other types that we should | |
| 386 # handle specially here. | |
| 387 assert _FieldDescriptor.MAX_CPPTYPE == 10 | |
| 388 | |
| 389 constant_name = field.name.upper() + "_FIELD_NUMBER" | |
| 390 setattr(cls, constant_name, field.number) | |
| 391 | |
| 392 if field.label == _FieldDescriptor.LABEL_REPEATED: | |
| 393 _AddPropertiesForRepeatedField(field, cls) | |
| 394 elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: | |
| 395 _AddPropertiesForNonRepeatedCompositeField(field, cls) | |
| 396 else: | |
| 397 _AddPropertiesForNonRepeatedScalarField(field, cls) | |
| 398 | |
| 399 | |
| 400 def _AddPropertiesForRepeatedField(field, cls): | |
| 401 """Adds a public property for a "repeated" protocol message field. Clients | |
| 402 can use this property to get the value of the field, which will be either a | |
| 403 _RepeatedScalarFieldContainer or _RepeatedCompositeFieldContainer (see | |
| 404 below). | |
| 405 | |
| 406 Note that when clients add values to these containers, we perform | |
| 407 type-checking in the case of repeated scalar fields, and we also set any | |
| 408 necessary "has" bits as a side-effect. | |
| 409 | |
| 410 Args: | |
| 411 field: A FieldDescriptor for this field. | |
| 412 cls: The class we're constructing. | |
| 413 """ | |
| 414 proto_field_name = field.name | |
| 415 property_name = _PropertyName(proto_field_name) | |
| 416 | |
| 417 def getter(self): | |
| 418 field_value = self._fields.get(field) | |
| 419 if field_value is None: | |
| 420 # Construct a new object to represent this field. | |
| 421 field_value = field._default_constructor(self) | |
| 422 | |
| 423 # Atomically check if another thread has preempted us and, if not, swap | |
| 424 # in the new object we just created. If someone has preempted us, we | |
| 425 # take that object and discard ours. | |
| 426 # WARNING: We are relying on setdefault() being atomic. This is true | |
| 427 # in CPython but we haven't investigated others. This warning appears | |
| 428 # in several other locations in this file. | |
| 429 field_value = self._fields.setdefault(field, field_value) | |
| 430 return field_value | |
| 431 getter.__module__ = None | |
| 432 getter.__doc__ = 'Getter for %s.' % proto_field_name | |
| 433 | |
| 434 # We define a setter just so we can throw an exception with a more | |
| 435 # helpful error message. | |
| 436 def setter(self, new_value): | |
| 437 raise AttributeError('Assignment not allowed to repeated field ' | |
| 438 '"%s" in protocol message object.' % proto_field_name) | |
| 439 | |
| 440 doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name | |
| 441 setattr(cls, property_name, property(getter, setter, doc=doc)) | |
| 442 | |
| 443 | |
| 444 def _AddPropertiesForNonRepeatedScalarField(field, cls): | |
| 445 """Adds a public property for a nonrepeated, scalar protocol message field. | |
| 446 Clients can use this property to get and directly set the value of the field. | |
| 447 Note that when the client sets the value of a field by using this property, | |
| 448 all necessary "has" bits are set as a side-effect, and we also perform | |
| 449 type-checking. | |
| 450 | |
| 451 Args: | |
| 452 field: A FieldDescriptor for this field. | |
| 453 cls: The class we're constructing. | |
| 454 """ | |
| 455 proto_field_name = field.name | |
| 456 property_name = _PropertyName(proto_field_name) | |
| 457 type_checker = type_checkers.GetTypeChecker(field) | |
| 458 default_value = field.default_value | |
| 459 valid_values = set() | |
| 460 | |
| 461 def getter(self): | |
| 462 # TODO(protobuf-team): This may be broken since there may not be | |
| 463 # default_value. Combine with has_default_value somehow. | |
| 464 return self._fields.get(field, default_value) | |
| 465 getter.__module__ = None | |
| 466 getter.__doc__ = 'Getter for %s.' % proto_field_name | |
| 467 def field_setter(self, new_value): | |
| 468 # pylint: disable=protected-access | |
| 469 self._fields[field] = type_checker.CheckValue(new_value) | |
| 470 # Check _cached_byte_size_dirty inline to improve performance, since scalar | |
| 471 # setters are called frequently. | |
| 472 if not self._cached_byte_size_dirty: | |
| 473 self._Modified() | |
| 474 | |
| 475 if field.containing_oneof is not None: | |
| 476 def setter(self, new_value): | |
| 477 field_setter(self, new_value) | |
| 478 self._UpdateOneofState(field) | |
| 479 else: | |
| 480 setter = field_setter | |
| 481 | |
| 482 setter.__module__ = None | |
| 483 setter.__doc__ = 'Setter for %s.' % proto_field_name | |
| 484 | |
| 485 # Add a property to encapsulate the getter/setter. | |
| 486 doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name | |
| 487 setattr(cls, property_name, property(getter, setter, doc=doc)) | |
| 488 | |
| 489 | |
| 490 def _AddPropertiesForNonRepeatedCompositeField(field, cls): | |
| 491 """Adds a public property for a nonrepeated, composite protocol message field. | |
| 492 A composite field is a "group" or "message" field. | |
| 493 | |
| 494 Clients can use this property to get the value of the field, but cannot | |
| 495 assign to the property directly. | |
| 496 | |
| 497 Args: | |
| 498 field: A FieldDescriptor for this field. | |
| 499 cls: The class we're constructing. | |
| 500 """ | |
| 501 # TODO(robinson): Remove duplication with similar method | |
| 502 # for non-repeated scalars. | |
| 503 proto_field_name = field.name | |
| 504 property_name = _PropertyName(proto_field_name) | |
| 505 | |
| 506 # TODO(komarek): Can anyone explain to me why we cache the message_type this | |
| 507 # way, instead of referring to field.message_type inside of getter(self)? | |
| 508 # What if someone sets message_type later on (which makes for simpler | |
| 509 # dyanmic proto descriptor and class creation code). | |
| 510 message_type = field.message_type | |
| 511 | |
| 512 def getter(self): | |
| 513 field_value = self._fields.get(field) | |
| 514 if field_value is None: | |
| 515 # Construct a new object to represent this field. | |
| 516 field_value = message_type._concrete_class() # use field.message_type? | |
| 517 field_value._SetListener( | |
| 518 _OneofListener(self, field) | |
| 519 if field.containing_oneof is not None | |
| 520 else self._listener_for_children) | |
| 521 | |
| 522 # Atomically check if another thread has preempted us and, if not, swap | |
| 523 # in the new object we just created. If someone has preempted us, we | |
| 524 # take that object and discard ours. | |
| 525 # WARNING: We are relying on setdefault() being atomic. This is true | |
| 526 # in CPython but we haven't investigated others. This warning appears | |
| 527 # in several other locations in this file. | |
| 528 field_value = self._fields.setdefault(field, field_value) | |
| 529 return field_value | |
| 530 getter.__module__ = None | |
| 531 getter.__doc__ = 'Getter for %s.' % proto_field_name | |
| 532 | |
| 533 # We define a setter just so we can throw an exception with a more | |
| 534 # helpful error message. | |
| 535 def setter(self, new_value): | |
| 536 raise AttributeError('Assignment not allowed to composite field ' | |
| 537 '"%s" in protocol message object.' % proto_field_name) | |
| 538 | |
| 539 # Add a property to encapsulate the getter. | |
| 540 doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name | |
| 541 setattr(cls, property_name, property(getter, setter, doc=doc)) | |
| 542 | |
| 543 | |
| 544 def _AddPropertiesForExtensions(descriptor, cls): | |
| 545 """Adds properties for all fields in this protocol message type.""" | |
| 546 extension_dict = descriptor.extensions_by_name | |
| 547 for extension_name, extension_field in extension_dict.iteritems(): | |
| 548 constant_name = extension_name.upper() + "_FIELD_NUMBER" | |
| 549 setattr(cls, constant_name, extension_field.number) | |
| 550 | |
| 551 | |
| 552 def _AddStaticMethods(cls): | |
| 553 # TODO(robinson): This probably needs to be thread-safe(?) | |
| 554 def RegisterExtension(extension_handle): | |
| 555 extension_handle.containing_type = cls.DESCRIPTOR | |
| 556 _AttachFieldHelpers(cls, extension_handle) | |
| 557 | |
| 558 # Try to insert our extension, failing if an extension with the same number | |
| 559 # already exists. | |
| 560 actual_handle = cls._extensions_by_number.setdefault( | |
| 561 extension_handle.number, extension_handle) | |
| 562 if actual_handle is not extension_handle: | |
| 563 raise AssertionError( | |
| 564 'Extensions "%s" and "%s" both try to extend message type "%s" with ' | |
| 565 'field number %d.' % | |
| 566 (extension_handle.full_name, actual_handle.full_name, | |
| 567 cls.DESCRIPTOR.full_name, extension_handle.number)) | |
| 568 | |
| 569 cls._extensions_by_name[extension_handle.full_name] = extension_handle | |
| 570 | |
| 571 handle = extension_handle # avoid line wrapping | |
| 572 if _IsMessageSetExtension(handle): | |
| 573 # MessageSet extension. Also register under type name. | |
| 574 cls._extensions_by_name[ | |
| 575 extension_handle.message_type.full_name] = extension_handle | |
| 576 | |
| 577 cls.RegisterExtension = staticmethod(RegisterExtension) | |
| 578 | |
| 579 def FromString(s): | |
| 580 message = cls() | |
| 581 message.MergeFromString(s) | |
| 582 return message | |
| 583 cls.FromString = staticmethod(FromString) | |
| 584 | |
| 585 | |
| 586 def _IsPresent(item): | |
| 587 """Given a (FieldDescriptor, value) tuple from _fields, return true if the | |
| 588 value should be included in the list returned by ListFields().""" | |
| 589 | |
| 590 if item[0].label == _FieldDescriptor.LABEL_REPEATED: | |
| 591 return bool(item[1]) | |
| 592 elif item[0].cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: | |
| 593 return item[1]._is_present_in_parent | |
| 594 else: | |
| 595 return True | |
| 596 | |
| 597 | |
| 598 def _AddListFieldsMethod(message_descriptor, cls): | |
| 599 """Helper for _AddMessageMethods().""" | |
| 600 | |
| 601 def ListFields(self): | |
| 602 all_fields = [item for item in self._fields.iteritems() if _IsPresent(item)] | |
| 603 all_fields.sort(key = lambda item: item[0].number) | |
| 604 return all_fields | |
| 605 | |
| 606 cls.ListFields = ListFields | |
| 607 | |
| 608 | |
| 609 def _AddHasFieldMethod(message_descriptor, cls): | |
| 610 """Helper for _AddMessageMethods().""" | |
| 611 | |
| 612 singular_fields = {} | |
| 613 for field in message_descriptor.fields: | |
| 614 if field.label != _FieldDescriptor.LABEL_REPEATED: | |
| 615 singular_fields[field.name] = field | |
| 616 # Fields inside oneofs are never repeated (enforced by the compiler). | |
| 617 for field in message_descriptor.oneofs: | |
| 618 singular_fields[field.name] = field | |
| 619 | |
| 620 def HasField(self, field_name): | |
| 621 try: | |
| 622 field = singular_fields[field_name] | |
| 623 except KeyError: | |
| 624 raise ValueError( | |
| 625 'Protocol message has no singular "%s" field.' % field_name) | |
| 626 | |
| 627 if isinstance(field, descriptor_mod.OneofDescriptor): | |
| 628 try: | |
| 629 return HasField(self, self._oneofs[field].name) | |
| 630 except KeyError: | |
| 631 return False | |
| 632 else: | |
| 633 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: | |
| 634 value = self._fields.get(field) | |
| 635 return value is not None and value._is_present_in_parent | |
| 636 else: | |
| 637 return field in self._fields | |
| 638 | |
| 639 cls.HasField = HasField | |
| 640 | |
| 641 | |
| 642 def _AddClearFieldMethod(message_descriptor, cls): | |
| 643 """Helper for _AddMessageMethods().""" | |
| 644 def ClearField(self, field_name): | |
| 645 try: | |
| 646 field = message_descriptor.fields_by_name[field_name] | |
| 647 except KeyError: | |
| 648 try: | |
| 649 field = message_descriptor.oneofs_by_name[field_name] | |
| 650 if field in self._oneofs: | |
| 651 field = self._oneofs[field] | |
| 652 else: | |
| 653 return | |
| 654 except KeyError: | |
| 655 raise ValueError('Protocol message has no "%s" field.' % field_name) | |
| 656 | |
| 657 if field in self._fields: | |
| 658 # Note: If the field is a sub-message, its listener will still point | |
| 659 # at us. That's fine, because the worst than can happen is that it | |
| 660 # will call _Modified() and invalidate our byte size. Big deal. | |
| 661 del self._fields[field] | |
| 662 | |
| 663 if self._oneofs.get(field.containing_oneof, None) is field: | |
| 664 del self._oneofs[field.containing_oneof] | |
| 665 | |
| 666 # Always call _Modified() -- even if nothing was changed, this is | |
| 667 # a mutating method, and thus calling it should cause the field to become | |
| 668 # present in the parent message. | |
| 669 self._Modified() | |
| 670 | |
| 671 cls.ClearField = ClearField | |
| 672 | |
| 673 | |
| 674 def _AddClearExtensionMethod(cls): | |
| 675 """Helper for _AddMessageMethods().""" | |
| 676 def ClearExtension(self, extension_handle): | |
| 677 _VerifyExtensionHandle(self, extension_handle) | |
| 678 | |
| 679 # Similar to ClearField(), above. | |
| 680 if extension_handle in self._fields: | |
| 681 del self._fields[extension_handle] | |
| 682 self._Modified() | |
| 683 cls.ClearExtension = ClearExtension | |
| 684 | |
| 685 | |
| 686 def _AddClearMethod(message_descriptor, cls): | |
| 687 """Helper for _AddMessageMethods().""" | |
| 688 def Clear(self): | |
| 689 # Clear fields. | |
| 690 self._fields = {} | |
| 691 self._unknown_fields = () | |
| 692 self._Modified() | |
| 693 cls.Clear = Clear | |
| 694 | |
| 695 | |
| 696 def _AddHasExtensionMethod(cls): | |
| 697 """Helper for _AddMessageMethods().""" | |
| 698 def HasExtension(self, extension_handle): | |
| 699 _VerifyExtensionHandle(self, extension_handle) | |
| 700 if extension_handle.label == _FieldDescriptor.LABEL_REPEATED: | |
| 701 raise KeyError('"%s" is repeated.' % extension_handle.full_name) | |
| 702 | |
| 703 if extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: | |
| 704 value = self._fields.get(extension_handle) | |
| 705 return value is not None and value._is_present_in_parent | |
| 706 else: | |
| 707 return extension_handle in self._fields | |
| 708 cls.HasExtension = HasExtension | |
| 709 | |
| 710 | |
| 711 def _AddEqualsMethod(message_descriptor, cls): | |
| 712 """Helper for _AddMessageMethods().""" | |
| 713 def __eq__(self, other): | |
| 714 if (not isinstance(other, message_mod.Message) or | |
| 715 other.DESCRIPTOR != self.DESCRIPTOR): | |
| 716 return False | |
| 717 | |
| 718 if self is other: | |
| 719 return True | |
| 720 | |
| 721 if not self.ListFields() == other.ListFields(): | |
| 722 return False | |
| 723 | |
| 724 # Sort unknown fields because their order shouldn't affect equality test. | |
| 725 unknown_fields = list(self._unknown_fields) | |
| 726 unknown_fields.sort() | |
| 727 other_unknown_fields = list(other._unknown_fields) | |
| 728 other_unknown_fields.sort() | |
| 729 | |
| 730 return unknown_fields == other_unknown_fields | |
| 731 | |
| 732 cls.__eq__ = __eq__ | |
| 733 | |
| 734 | |
| 735 def _AddStrMethod(message_descriptor, cls): | |
| 736 """Helper for _AddMessageMethods().""" | |
| 737 def __str__(self): | |
| 738 return text_format.MessageToString(self) | |
| 739 cls.__str__ = __str__ | |
| 740 | |
| 741 | |
| 742 def _AddUnicodeMethod(unused_message_descriptor, cls): | |
| 743 """Helper for _AddMessageMethods().""" | |
| 744 | |
| 745 def __unicode__(self): | |
| 746 return text_format.MessageToString(self, as_utf8=True).decode('utf-8') | |
| 747 cls.__unicode__ = __unicode__ | |
| 748 | |
| 749 | |
| 750 def _AddSetListenerMethod(cls): | |
| 751 """Helper for _AddMessageMethods().""" | |
| 752 def SetListener(self, listener): | |
| 753 if listener is None: | |
| 754 self._listener = message_listener_mod.NullMessageListener() | |
| 755 else: | |
| 756 self._listener = listener | |
| 757 cls._SetListener = SetListener | |
| 758 | |
| 759 | |
| 760 def _BytesForNonRepeatedElement(value, field_number, field_type): | |
| 761 """Returns the number of bytes needed to serialize a non-repeated element. | |
| 762 The returned byte count includes space for tag information and any | |
| 763 other additional space associated with serializing value. | |
| 764 | |
| 765 Args: | |
| 766 value: Value we're serializing. | |
| 767 field_number: Field number of this value. (Since the field number | |
| 768 is stored as part of a varint-encoded tag, this has an impact | |
| 769 on the total bytes required to serialize the value). | |
| 770 field_type: The type of the field. One of the TYPE_* constants | |
| 771 within FieldDescriptor. | |
| 772 """ | |
| 773 try: | |
| 774 fn = type_checkers.TYPE_TO_BYTE_SIZE_FN[field_type] | |
| 775 return fn(field_number, value) | |
| 776 except KeyError: | |
| 777 raise message_mod.EncodeError('Unrecognized field type: %d' % field_type) | |
| 778 | |
| 779 | |
| 780 def _AddByteSizeMethod(message_descriptor, cls): | |
| 781 """Helper for _AddMessageMethods().""" | |
| 782 | |
| 783 def ByteSize(self): | |
| 784 if not self._cached_byte_size_dirty: | |
| 785 return self._cached_byte_size | |
| 786 | |
| 787 size = 0 | |
| 788 for field_descriptor, field_value in self.ListFields(): | |
| 789 size += field_descriptor._sizer(field_value) | |
| 790 | |
| 791 for tag_bytes, value_bytes in self._unknown_fields: | |
| 792 size += len(tag_bytes) + len(value_bytes) | |
| 793 | |
| 794 self._cached_byte_size = size | |
| 795 self._cached_byte_size_dirty = False | |
| 796 self._listener_for_children.dirty = False | |
| 797 return size | |
| 798 | |
| 799 cls.ByteSize = ByteSize | |
| 800 | |
| 801 | |
| 802 def _AddSerializeToStringMethod(message_descriptor, cls): | |
| 803 """Helper for _AddMessageMethods().""" | |
| 804 | |
| 805 def SerializeToString(self): | |
| 806 # Check if the message has all of its required fields set. | |
| 807 errors = [] | |
| 808 if not self.IsInitialized(): | |
| 809 raise message_mod.EncodeError( | |
| 810 'Message %s is missing required fields: %s' % ( | |
| 811 self.DESCRIPTOR.full_name, ','.join(self.FindInitializationErrors()))) | |
| 812 return self.SerializePartialToString() | |
| 813 cls.SerializeToString = SerializeToString | |
| 814 | |
| 815 | |
| 816 def _AddSerializePartialToStringMethod(message_descriptor, cls): | |
| 817 """Helper for _AddMessageMethods().""" | |
| 818 | |
| 819 def SerializePartialToString(self): | |
| 820 out = BytesIO() | |
| 821 self._InternalSerialize(out.write) | |
| 822 return out.getvalue() | |
| 823 cls.SerializePartialToString = SerializePartialToString | |
| 824 | |
| 825 def InternalSerialize(self, write_bytes): | |
| 826 for field_descriptor, field_value in self.ListFields(): | |
| 827 field_descriptor._encoder(write_bytes, field_value) | |
| 828 for tag_bytes, value_bytes in self._unknown_fields: | |
| 829 write_bytes(tag_bytes) | |
| 830 write_bytes(value_bytes) | |
| 831 cls._InternalSerialize = InternalSerialize | |
| 832 | |
| 833 | |
| 834 def _AddMergeFromStringMethod(message_descriptor, cls): | |
| 835 """Helper for _AddMessageMethods().""" | |
| 836 def MergeFromString(self, serialized): | |
| 837 length = len(serialized) | |
| 838 try: | |
| 839 if self._InternalParse(serialized, 0, length) != length: | |
| 840 # The only reason _InternalParse would return early is if it | |
| 841 # encountered an end-group tag. | |
| 842 raise message_mod.DecodeError('Unexpected end-group tag.') | |
| 843 except (IndexError, TypeError): | |
| 844 # Now ord(buf[p:p+1]) == ord('') gets TypeError. | |
| 845 raise message_mod.DecodeError('Truncated message.') | |
| 846 except struct.error, e: | |
| 847 raise message_mod.DecodeError(e) | |
| 848 return length # Return this for legacy reasons. | |
| 849 cls.MergeFromString = MergeFromString | |
| 850 | |
| 851 local_ReadTag = decoder.ReadTag | |
| 852 local_SkipField = decoder.SkipField | |
| 853 decoders_by_tag = cls._decoders_by_tag | |
| 854 | |
| 855 def InternalParse(self, buffer, pos, end): | |
| 856 self._Modified() | |
| 857 field_dict = self._fields | |
| 858 unknown_field_list = self._unknown_fields | |
| 859 while pos != end: | |
| 860 (tag_bytes, new_pos) = local_ReadTag(buffer, pos) | |
| 861 field_decoder = decoders_by_tag.get(tag_bytes) | |
| 862 if field_decoder is None: | |
| 863 value_start_pos = new_pos | |
| 864 new_pos = local_SkipField(buffer, new_pos, end, tag_bytes) | |
| 865 if new_pos == -1: | |
| 866 return pos | |
| 867 if not unknown_field_list: | |
| 868 unknown_field_list = self._unknown_fields = [] | |
| 869 unknown_field_list.append((tag_bytes, buffer[value_start_pos:new_pos])) | |
| 870 pos = new_pos | |
| 871 else: | |
| 872 pos = field_decoder(buffer, new_pos, end, self, field_dict) | |
| 873 return pos | |
| 874 cls._InternalParse = InternalParse | |
| 875 | |
| 876 | |
| 877 def _AddIsInitializedMethod(message_descriptor, cls): | |
| 878 """Adds the IsInitialized and FindInitializationError methods to the | |
| 879 protocol message class.""" | |
| 880 | |
| 881 required_fields = [field for field in message_descriptor.fields | |
| 882 if field.label == _FieldDescriptor.LABEL_REQUIRED] | |
| 883 | |
| 884 def IsInitialized(self, errors=None): | |
| 885 """Checks if all required fields of a message are set. | |
| 886 | |
| 887 Args: | |
| 888 errors: A list which, if provided, will be populated with the field | |
| 889 paths of all missing required fields. | |
| 890 | |
| 891 Returns: | |
| 892 True iff the specified message has all required fields set. | |
| 893 """ | |
| 894 | |
| 895 # Performance is critical so we avoid HasField() and ListFields(). | |
| 896 | |
| 897 for field in required_fields: | |
| 898 if (field not in self._fields or | |
| 899 (field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE and | |
| 900 not self._fields[field]._is_present_in_parent)): | |
| 901 if errors is not None: | |
| 902 errors.extend(self.FindInitializationErrors()) | |
| 903 return False | |
| 904 | |
| 905 for field, value in list(self._fields.items()): # dict can change size! | |
| 906 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: | |
| 907 if field.label == _FieldDescriptor.LABEL_REPEATED: | |
| 908 for element in value: | |
| 909 if not element.IsInitialized(): | |
| 910 if errors is not None: | |
| 911 errors.extend(self.FindInitializationErrors()) | |
| 912 return False | |
| 913 elif value._is_present_in_parent and not value.IsInitialized(): | |
| 914 if errors is not None: | |
| 915 errors.extend(self.FindInitializationErrors()) | |
| 916 return False | |
| 917 | |
| 918 return True | |
| 919 | |
| 920 cls.IsInitialized = IsInitialized | |
| 921 | |
| 922 def FindInitializationErrors(self): | |
| 923 """Finds required fields which are not initialized. | |
| 924 | |
| 925 Returns: | |
| 926 A list of strings. Each string is a path to an uninitialized field from | |
| 927 the top-level message, e.g. "foo.bar[5].baz". | |
| 928 """ | |
| 929 | |
| 930 errors = [] # simplify things | |
| 931 | |
| 932 for field in required_fields: | |
| 933 if not self.HasField(field.name): | |
| 934 errors.append(field.name) | |
| 935 | |
| 936 for field, value in self.ListFields(): | |
| 937 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: | |
| 938 if field.is_extension: | |
| 939 name = "(%s)" % field.full_name | |
| 940 else: | |
| 941 name = field.name | |
| 942 | |
| 943 if field.label == _FieldDescriptor.LABEL_REPEATED: | |
| 944 for i in xrange(len(value)): | |
| 945 element = value[i] | |
| 946 prefix = "%s[%d]." % (name, i) | |
| 947 sub_errors = element.FindInitializationErrors() | |
| 948 errors += [ prefix + error for error in sub_errors ] | |
| 949 else: | |
| 950 prefix = name + "." | |
| 951 sub_errors = value.FindInitializationErrors() | |
| 952 errors += [ prefix + error for error in sub_errors ] | |
| 953 | |
| 954 return errors | |
| 955 | |
| 956 cls.FindInitializationErrors = FindInitializationErrors | |
| 957 | |
| 958 | |
| 959 def _AddMergeFromMethod(cls): | |
| 960 LABEL_REPEATED = _FieldDescriptor.LABEL_REPEATED | |
| 961 CPPTYPE_MESSAGE = _FieldDescriptor.CPPTYPE_MESSAGE | |
| 962 | |
| 963 def MergeFrom(self, msg): | |
| 964 if not isinstance(msg, cls): | |
| 965 raise TypeError( | |
| 966 "Parameter to MergeFrom() must be instance of same class: " | |
| 967 "expected %s got %s." % (cls.__name__, type(msg).__name__)) | |
| 968 | |
| 969 assert msg is not self | |
| 970 self._Modified() | |
| 971 | |
| 972 fields = self._fields | |
| 973 | |
| 974 for field, value in msg._fields.iteritems(): | |
| 975 if field.label == LABEL_REPEATED: | |
| 976 field_value = fields.get(field) | |
| 977 if field_value is None: | |
| 978 # Construct a new object to represent this field. | |
| 979 field_value = field._default_constructor(self) | |
| 980 fields[field] = field_value | |
| 981 field_value.MergeFrom(value) | |
| 982 elif field.cpp_type == CPPTYPE_MESSAGE: | |
| 983 if value._is_present_in_parent: | |
| 984 field_value = fields.get(field) | |
| 985 if field_value is None: | |
| 986 # Construct a new object to represent this field. | |
| 987 field_value = field._default_constructor(self) | |
| 988 fields[field] = field_value | |
| 989 field_value.MergeFrom(value) | |
| 990 else: | |
| 991 self._fields[field] = value | |
| 992 | |
| 993 if msg._unknown_fields: | |
| 994 if not self._unknown_fields: | |
| 995 self._unknown_fields = [] | |
| 996 self._unknown_fields.extend(msg._unknown_fields) | |
| 997 | |
| 998 cls.MergeFrom = MergeFrom | |
| 999 | |
| 1000 | |
| 1001 def _AddWhichOneofMethod(message_descriptor, cls): | |
| 1002 def WhichOneof(self, oneof_name): | |
| 1003 """Returns the name of the currently set field inside a oneof, or None.""" | |
| 1004 try: | |
| 1005 field = message_descriptor.oneofs_by_name[oneof_name] | |
| 1006 except KeyError: | |
| 1007 raise ValueError( | |
| 1008 'Protocol message has no oneof "%s" field.' % oneof_name) | |
| 1009 | |
| 1010 nested_field = self._oneofs.get(field, None) | |
| 1011 if nested_field is not None and self.HasField(nested_field.name): | |
| 1012 return nested_field.name | |
| 1013 else: | |
| 1014 return None | |
| 1015 | |
| 1016 cls.WhichOneof = WhichOneof | |
| 1017 | |
| 1018 | |
| 1019 def _AddMessageMethods(message_descriptor, cls): | |
| 1020 """Adds implementations of all Message methods to cls.""" | |
| 1021 _AddListFieldsMethod(message_descriptor, cls) | |
| 1022 _AddHasFieldMethod(message_descriptor, cls) | |
| 1023 _AddClearFieldMethod(message_descriptor, cls) | |
| 1024 if message_descriptor.is_extendable: | |
| 1025 _AddClearExtensionMethod(cls) | |
| 1026 _AddHasExtensionMethod(cls) | |
| 1027 _AddClearMethod(message_descriptor, cls) | |
| 1028 _AddEqualsMethod(message_descriptor, cls) | |
| 1029 _AddStrMethod(message_descriptor, cls) | |
| 1030 _AddUnicodeMethod(message_descriptor, cls) | |
| 1031 _AddSetListenerMethod(cls) | |
| 1032 _AddByteSizeMethod(message_descriptor, cls) | |
| 1033 _AddSerializeToStringMethod(message_descriptor, cls) | |
| 1034 _AddSerializePartialToStringMethod(message_descriptor, cls) | |
| 1035 _AddMergeFromStringMethod(message_descriptor, cls) | |
| 1036 _AddIsInitializedMethod(message_descriptor, cls) | |
| 1037 _AddMergeFromMethod(cls) | |
| 1038 _AddWhichOneofMethod(message_descriptor, cls) | |
| 1039 | |
| 1040 def _AddPrivateHelperMethods(message_descriptor, cls): | |
| 1041 """Adds implementation of private helper methods to cls.""" | |
| 1042 | |
| 1043 def Modified(self): | |
| 1044 """Sets the _cached_byte_size_dirty bit to true, | |
| 1045 and propagates this to our listener iff this was a state change. | |
| 1046 """ | |
| 1047 | |
| 1048 # Note: Some callers check _cached_byte_size_dirty before calling | |
| 1049 # _Modified() as an extra optimization. So, if this method is ever | |
| 1050 # changed such that it does stuff even when _cached_byte_size_dirty is | |
| 1051 # already true, the callers need to be updated. | |
| 1052 if not self._cached_byte_size_dirty: | |
| 1053 self._cached_byte_size_dirty = True | |
| 1054 self._listener_for_children.dirty = True | |
| 1055 self._is_present_in_parent = True | |
| 1056 self._listener.Modified() | |
| 1057 | |
| 1058 def _UpdateOneofState(self, field): | |
| 1059 """Sets field as the active field in its containing oneof. | |
| 1060 | |
| 1061 Will also delete currently active field in the oneof, if it is different | |
| 1062 from the argument. Does not mark the message as modified. | |
| 1063 """ | |
| 1064 other_field = self._oneofs.setdefault(field.containing_oneof, field) | |
| 1065 if other_field is not field: | |
| 1066 del self._fields[other_field] | |
| 1067 self._oneofs[field.containing_oneof] = field | |
| 1068 | |
| 1069 cls._Modified = Modified | |
| 1070 cls.SetInParent = Modified | |
| 1071 cls._UpdateOneofState = _UpdateOneofState | |
| 1072 | |
| 1073 | |
| 1074 class _Listener(object): | |
| 1075 | |
| 1076 """MessageListener implementation that a parent message registers with its | |
| 1077 child message. | |
| 1078 | |
| 1079 In order to support semantics like: | |
| 1080 | |
| 1081 foo.bar.baz.qux = 23 | |
| 1082 assert foo.HasField('bar') | |
| 1083 | |
| 1084 ...child objects must have back references to their parents. | |
| 1085 This helper class is at the heart of this support. | |
| 1086 """ | |
| 1087 | |
| 1088 def __init__(self, parent_message): | |
| 1089 """Args: | |
| 1090 parent_message: The message whose _Modified() method we should call when | |
| 1091 we receive Modified() messages. | |
| 1092 """ | |
| 1093 # This listener establishes a back reference from a child (contained) object | |
| 1094 # to its parent (containing) object. We make this a weak reference to avoid | |
| 1095 # creating cyclic garbage when the client finishes with the 'parent' object | |
| 1096 # in the tree. | |
| 1097 if isinstance(parent_message, weakref.ProxyType): | |
| 1098 self._parent_message_weakref = parent_message | |
| 1099 else: | |
| 1100 self._parent_message_weakref = weakref.proxy(parent_message) | |
| 1101 | |
| 1102 # As an optimization, we also indicate directly on the listener whether | |
| 1103 # or not the parent message is dirty. This way we can avoid traversing | |
| 1104 # up the tree in the common case. | |
| 1105 self.dirty = False | |
| 1106 | |
| 1107 def Modified(self): | |
| 1108 if self.dirty: | |
| 1109 return | |
| 1110 try: | |
| 1111 # Propagate the signal to our parents iff this is the first field set. | |
| 1112 self._parent_message_weakref._Modified() | |
| 1113 except ReferenceError: | |
| 1114 # We can get here if a client has kept a reference to a child object, | |
| 1115 # and is now setting a field on it, but the child's parent has been | |
| 1116 # garbage-collected. This is not an error. | |
| 1117 pass | |
| 1118 | |
| 1119 | |
| 1120 class _OneofListener(_Listener): | |
| 1121 """Special listener implementation for setting composite oneof fields.""" | |
| 1122 | |
| 1123 def __init__(self, parent_message, field): | |
| 1124 """Args: | |
| 1125 parent_message: The message whose _Modified() method we should call when | |
| 1126 we receive Modified() messages. | |
| 1127 field: The descriptor of the field being set in the parent message. | |
| 1128 """ | |
| 1129 super(_OneofListener, self).__init__(parent_message) | |
| 1130 self._field = field | |
| 1131 | |
| 1132 def Modified(self): | |
| 1133 """Also updates the state of the containing oneof in the parent message.""" | |
| 1134 try: | |
| 1135 self._parent_message_weakref._UpdateOneofState(self._field) | |
| 1136 super(_OneofListener, self).Modified() | |
| 1137 except ReferenceError: | |
| 1138 pass | |
| 1139 | |
| 1140 | |
| 1141 # TODO(robinson): Move elsewhere? This file is getting pretty ridiculous... | |
| 1142 # TODO(robinson): Unify error handling of "unknown extension" crap. | |
| 1143 # TODO(robinson): Support iteritems()-style iteration over all | |
| 1144 # extensions with the "has" bits turned on? | |
| 1145 class _ExtensionDict(object): | |
| 1146 | |
| 1147 """Dict-like container for supporting an indexable "Extensions" | |
| 1148 field on proto instances. | |
| 1149 | |
| 1150 Note that in all cases we expect extension handles to be | |
| 1151 FieldDescriptors. | |
| 1152 """ | |
| 1153 | |
| 1154 def __init__(self, extended_message): | |
| 1155 """extended_message: Message instance for which we are the Extensions dict. | |
| 1156 """ | |
| 1157 | |
| 1158 self._extended_message = extended_message | |
| 1159 | |
| 1160 def __getitem__(self, extension_handle): | |
| 1161 """Returns the current value of the given extension handle.""" | |
| 1162 | |
| 1163 _VerifyExtensionHandle(self._extended_message, extension_handle) | |
| 1164 | |
| 1165 result = self._extended_message._fields.get(extension_handle) | |
| 1166 if result is not None: | |
| 1167 return result | |
| 1168 | |
| 1169 if extension_handle.label == _FieldDescriptor.LABEL_REPEATED: | |
| 1170 result = extension_handle._default_constructor(self._extended_message) | |
| 1171 elif extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: | |
| 1172 result = extension_handle.message_type._concrete_class() | |
| 1173 try: | |
| 1174 result._SetListener(self._extended_message._listener_for_children) | |
| 1175 except ReferenceError: | |
| 1176 pass | |
| 1177 else: | |
| 1178 # Singular scalar -- just return the default without inserting into the | |
| 1179 # dict. | |
| 1180 return extension_handle.default_value | |
| 1181 | |
| 1182 # Atomically check if another thread has preempted us and, if not, swap | |
| 1183 # in the new object we just created. If someone has preempted us, we | |
| 1184 # take that object and discard ours. | |
| 1185 # WARNING: We are relying on setdefault() being atomic. This is true | |
| 1186 # in CPython but we haven't investigated others. This warning appears | |
| 1187 # in several other locations in this file. | |
| 1188 result = self._extended_message._fields.setdefault( | |
| 1189 extension_handle, result) | |
| 1190 | |
| 1191 return result | |
| 1192 | |
| 1193 def __eq__(self, other): | |
| 1194 if not isinstance(other, self.__class__): | |
| 1195 return False | |
| 1196 | |
| 1197 my_fields = self._extended_message.ListFields() | |
| 1198 other_fields = other._extended_message.ListFields() | |
| 1199 | |
| 1200 # Get rid of non-extension fields. | |
| 1201 my_fields = [ field for field in my_fields if field.is_extension ] | |
| 1202 other_fields = [ field for field in other_fields if field.is_extension ] | |
| 1203 | |
| 1204 return my_fields == other_fields | |
| 1205 | |
| 1206 def __ne__(self, other): | |
| 1207 return not self == other | |
| 1208 | |
| 1209 def __hash__(self): | |
| 1210 raise TypeError('unhashable object') | |
| 1211 | |
| 1212 # Note that this is only meaningful for non-repeated, scalar extension | |
| 1213 # fields. Note also that we may have to call _Modified() when we do | |
| 1214 # successfully set a field this way, to set any necssary "has" bits in the | |
| 1215 # ancestors of the extended message. | |
| 1216 def __setitem__(self, extension_handle, value): | |
| 1217 """If extension_handle specifies a non-repeated, scalar extension | |
| 1218 field, sets the value of that field. | |
| 1219 """ | |
| 1220 | |
| 1221 _VerifyExtensionHandle(self._extended_message, extension_handle) | |
| 1222 | |
| 1223 if (extension_handle.label == _FieldDescriptor.LABEL_REPEATED or | |
| 1224 extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE): | |
| 1225 raise TypeError( | |
| 1226 'Cannot assign to extension "%s" because it is a repeated or ' | |
| 1227 'composite type.' % extension_handle.full_name) | |
| 1228 | |
| 1229 # It's slightly wasteful to lookup the type checker each time, | |
| 1230 # but we expect this to be a vanishingly uncommon case anyway. | |
| 1231 type_checker = type_checkers.GetTypeChecker( | |
| 1232 extension_handle) | |
| 1233 # pylint: disable=protected-access | |
| 1234 self._extended_message._fields[extension_handle] = ( | |
| 1235 type_checker.CheckValue(value)) | |
| 1236 self._extended_message._Modified() | |
| 1237 | |
| 1238 def _FindExtensionByName(self, name): | |
| 1239 """Tries to find a known extension with the specified name. | |
| 1240 | |
| 1241 Args: | |
| 1242 name: Extension full name. | |
| 1243 | |
| 1244 Returns: | |
| 1245 Extension field descriptor. | |
| 1246 """ | |
| 1247 return self._extended_message._extensions_by_name.get(name, None) | |
| OLD | NEW |