OLD | NEW |
(Empty) | |
| 1 # Protocol Buffers - Google's data interchange format |
| 2 # Copyright 2008 Google Inc. All rights reserved. |
| 3 # http://code.google.com/p/protobuf/ |
| 4 # |
| 5 # Redistribution and use in source and binary forms, with or without |
| 6 # modification, are permitted provided that the following conditions are |
| 7 # met: |
| 8 # |
| 9 # * Redistributions of source code must retain the above copyright |
| 10 # notice, this list of conditions and the following disclaimer. |
| 11 # * Redistributions in binary form must reproduce the above |
| 12 # copyright notice, this list of conditions and the following disclaimer |
| 13 # in the documentation and/or other materials provided with the |
| 14 # distribution. |
| 15 # * Neither the name of Google Inc. nor the names of its |
| 16 # contributors may be used to endorse or promote products derived from |
| 17 # this software without specific prior written permission. |
| 18 # |
| 19 # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS |
| 20 # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT |
| 21 # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR |
| 22 # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT |
| 23 # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, |
| 24 # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT |
| 25 # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, |
| 26 # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY |
| 27 # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT |
| 28 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE |
| 29 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
| 30 |
| 31 # This code is meant to work on Python 2.4 and above only. |
| 32 # |
| 33 # TODO(robinson): Helpers for verbose, common checks like seeing if a |
| 34 # descriptor's cpp_type is CPPTYPE_MESSAGE. |
| 35 |
| 36 """Contains a metaclass and helper functions used to create |
| 37 protocol message classes from Descriptor objects at runtime. |
| 38 |
| 39 Recall that a metaclass is the "type" of a class. |
| 40 (A class is to a metaclass what an instance is to a class.) |
| 41 |
| 42 In this case, we use the GeneratedProtocolMessageType metaclass |
| 43 to inject all the useful functionality into the classes |
| 44 output by the protocol compiler at compile-time. |
| 45 |
| 46 The upshot of all this is that the real implementation |
| 47 details for ALL pure-Python protocol buffers are *here in |
| 48 this file*. |
| 49 """ |
| 50 |
| 51 __author__ = 'robinson@google.com (Will Robinson)' |
| 52 |
| 53 try: |
| 54 from cStringIO import StringIO |
| 55 except ImportError: |
| 56 from StringIO import StringIO |
| 57 import copy_reg |
| 58 import struct |
| 59 import weakref |
| 60 |
| 61 # We use "as" to avoid name collisions with variables. |
| 62 from google.protobuf.internal import containers |
| 63 from google.protobuf.internal import decoder |
| 64 from google.protobuf.internal import encoder |
| 65 from google.protobuf.internal import enum_type_wrapper |
| 66 from google.protobuf.internal import message_listener as message_listener_mod |
| 67 from google.protobuf.internal import type_checkers |
| 68 from google.protobuf.internal import wire_format |
| 69 from google.protobuf import descriptor as descriptor_mod |
| 70 from google.protobuf import message as message_mod |
| 71 from google.protobuf import text_format |
| 72 |
| 73 _FieldDescriptor = descriptor_mod.FieldDescriptor |
| 74 |
| 75 |
| 76 def NewMessage(bases, descriptor, dictionary): |
| 77 _AddClassAttributesForNestedExtensions(descriptor, dictionary) |
| 78 _AddSlots(descriptor, dictionary) |
| 79 return bases |
| 80 |
| 81 |
| 82 def InitMessage(descriptor, cls): |
| 83 cls._decoders_by_tag = {} |
| 84 cls._extensions_by_name = {} |
| 85 cls._extensions_by_number = {} |
| 86 if (descriptor.has_options and |
| 87 descriptor.GetOptions().message_set_wire_format): |
| 88 cls._decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = ( |
| 89 decoder.MessageSetItemDecoder(cls._extensions_by_number)) |
| 90 |
| 91 # Attach stuff to each FieldDescriptor for quick lookup later on. |
| 92 for field in descriptor.fields: |
| 93 _AttachFieldHelpers(cls, field) |
| 94 |
| 95 _AddEnumValues(descriptor, cls) |
| 96 _AddInitMethod(descriptor, cls) |
| 97 _AddPropertiesForFields(descriptor, cls) |
| 98 _AddPropertiesForExtensions(descriptor, cls) |
| 99 _AddStaticMethods(cls) |
| 100 _AddMessageMethods(descriptor, cls) |
| 101 _AddPrivateHelperMethods(cls) |
| 102 copy_reg.pickle(cls, lambda obj: (cls, (), obj.__getstate__())) |
| 103 |
| 104 |
| 105 # Stateless helpers for GeneratedProtocolMessageType below. |
| 106 # Outside clients should not access these directly. |
| 107 # |
| 108 # I opted not to make any of these methods on the metaclass, to make it more |
| 109 # clear that I'm not really using any state there and to keep clients from |
| 110 # thinking that they have direct access to these construction helpers. |
| 111 |
| 112 |
| 113 def _PropertyName(proto_field_name): |
| 114 """Returns the name of the public property attribute which |
| 115 clients can use to get and (in some cases) set the value |
| 116 of a protocol message field. |
| 117 |
| 118 Args: |
| 119 proto_field_name: The protocol message field name, exactly |
| 120 as it appears (or would appear) in a .proto file. |
| 121 """ |
| 122 # TODO(robinson): Escape Python keywords (e.g., yield), and test this support. |
| 123 # nnorwitz makes my day by writing: |
| 124 # """ |
| 125 # FYI. See the keyword module in the stdlib. This could be as simple as: |
| 126 # |
| 127 # if keyword.iskeyword(proto_field_name): |
| 128 # return proto_field_name + "_" |
| 129 # return proto_field_name |
| 130 # """ |
| 131 # Kenton says: The above is a BAD IDEA. People rely on being able to use |
| 132 # getattr() and setattr() to reflectively manipulate field values. If we |
| 133 # rename the properties, then every such user has to also make sure to apply |
| 134 # the same transformation. Note that currently if you name a field "yield", |
| 135 # you can still access it just fine using getattr/setattr -- it's not even |
| 136 # that cumbersome to do so. |
| 137 # TODO(kenton): Remove this method entirely if/when everyone agrees with my |
| 138 # position. |
| 139 return proto_field_name |
| 140 |
| 141 |
| 142 def _VerifyExtensionHandle(message, extension_handle): |
| 143 """Verify that the given extension handle is valid.""" |
| 144 |
| 145 if not isinstance(extension_handle, _FieldDescriptor): |
| 146 raise KeyError('HasExtension() expects an extension handle, got: %s' % |
| 147 extension_handle) |
| 148 |
| 149 if not extension_handle.is_extension: |
| 150 raise KeyError('"%s" is not an extension.' % extension_handle.full_name) |
| 151 |
| 152 if not extension_handle.containing_type: |
| 153 raise KeyError('"%s" is missing a containing_type.' |
| 154 % extension_handle.full_name) |
| 155 |
| 156 if extension_handle.containing_type is not message.DESCRIPTOR: |
| 157 raise KeyError('Extension "%s" extends message type "%s", but this ' |
| 158 'message is of type "%s".' % |
| 159 (extension_handle.full_name, |
| 160 extension_handle.containing_type.full_name, |
| 161 message.DESCRIPTOR.full_name)) |
| 162 |
| 163 |
| 164 def _AddSlots(message_descriptor, dictionary): |
| 165 """Adds a __slots__ entry to dictionary, containing the names of all valid |
| 166 attributes for this message type. |
| 167 |
| 168 Args: |
| 169 message_descriptor: A Descriptor instance describing this message type. |
| 170 dictionary: Class dictionary to which we'll add a '__slots__' entry. |
| 171 """ |
| 172 dictionary['__slots__'] = ['_cached_byte_size', |
| 173 '_cached_byte_size_dirty', |
| 174 '_fields', |
| 175 '_unknown_fields', |
| 176 '_is_present_in_parent', |
| 177 '_listener', |
| 178 '_listener_for_children', |
| 179 '__weakref__'] |
| 180 |
| 181 |
| 182 def _IsMessageSetExtension(field): |
| 183 return (field.is_extension and |
| 184 field.containing_type.has_options and |
| 185 field.containing_type.GetOptions().message_set_wire_format and |
| 186 field.type == _FieldDescriptor.TYPE_MESSAGE and |
| 187 field.message_type == field.extension_scope and |
| 188 field.label == _FieldDescriptor.LABEL_OPTIONAL) |
| 189 |
| 190 |
| 191 def _AttachFieldHelpers(cls, field_descriptor): |
| 192 is_repeated = (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED) |
| 193 is_packed = (field_descriptor.has_options and |
| 194 field_descriptor.GetOptions().packed) |
| 195 |
| 196 if _IsMessageSetExtension(field_descriptor): |
| 197 field_encoder = encoder.MessageSetItemEncoder(field_descriptor.number) |
| 198 sizer = encoder.MessageSetItemSizer(field_descriptor.number) |
| 199 else: |
| 200 field_encoder = type_checkers.TYPE_TO_ENCODER[field_descriptor.type]( |
| 201 field_descriptor.number, is_repeated, is_packed) |
| 202 sizer = type_checkers.TYPE_TO_SIZER[field_descriptor.type]( |
| 203 field_descriptor.number, is_repeated, is_packed) |
| 204 |
| 205 field_descriptor._encoder = field_encoder |
| 206 field_descriptor._sizer = sizer |
| 207 field_descriptor._default_constructor = _DefaultValueConstructorForField( |
| 208 field_descriptor) |
| 209 |
| 210 def AddDecoder(wiretype, is_packed): |
| 211 tag_bytes = encoder.TagBytes(field_descriptor.number, wiretype) |
| 212 cls._decoders_by_tag[tag_bytes] = ( |
| 213 type_checkers.TYPE_TO_DECODER[field_descriptor.type]( |
| 214 field_descriptor.number, is_repeated, is_packed, |
| 215 field_descriptor, field_descriptor._default_constructor)) |
| 216 |
| 217 AddDecoder(type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type], |
| 218 False) |
| 219 |
| 220 if is_repeated and wire_format.IsTypePackable(field_descriptor.type): |
| 221 # To support wire compatibility of adding packed = true, add a decoder for |
| 222 # packed values regardless of the field's options. |
| 223 AddDecoder(wire_format.WIRETYPE_LENGTH_DELIMITED, True) |
| 224 |
| 225 |
| 226 def _AddClassAttributesForNestedExtensions(descriptor, dictionary): |
| 227 extension_dict = descriptor.extensions_by_name |
| 228 for extension_name, extension_field in extension_dict.iteritems(): |
| 229 assert extension_name not in dictionary |
| 230 dictionary[extension_name] = extension_field |
| 231 |
| 232 |
| 233 def _AddEnumValues(descriptor, cls): |
| 234 """Sets class-level attributes for all enum fields defined in this message. |
| 235 |
| 236 Also exporting a class-level object that can name enum values. |
| 237 |
| 238 Args: |
| 239 descriptor: Descriptor object for this message type. |
| 240 cls: Class we're constructing for this message type. |
| 241 """ |
| 242 for enum_type in descriptor.enum_types: |
| 243 setattr(cls, enum_type.name, enum_type_wrapper.EnumTypeWrapper(enum_type)) |
| 244 for enum_value in enum_type.values: |
| 245 setattr(cls, enum_value.name, enum_value.number) |
| 246 |
| 247 |
| 248 def _DefaultValueConstructorForField(field): |
| 249 """Returns a function which returns a default value for a field. |
| 250 |
| 251 Args: |
| 252 field: FieldDescriptor object for this field. |
| 253 |
| 254 The returned function has one argument: |
| 255 message: Message instance containing this field, or a weakref proxy |
| 256 of same. |
| 257 |
| 258 That function in turn returns a default value for this field. The default |
| 259 value may refer back to |message| via a weak reference. |
| 260 """ |
| 261 |
| 262 if field.label == _FieldDescriptor.LABEL_REPEATED: |
| 263 if field.has_default_value and field.default_value != []: |
| 264 raise ValueError('Repeated field default value not empty list: %s' % ( |
| 265 field.default_value)) |
| 266 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: |
| 267 # We can't look at _concrete_class yet since it might not have |
| 268 # been set. (Depends on order in which we initialize the classes). |
| 269 message_type = field.message_type |
| 270 def MakeRepeatedMessageDefault(message): |
| 271 return containers.RepeatedCompositeFieldContainer( |
| 272 message._listener_for_children, field.message_type) |
| 273 return MakeRepeatedMessageDefault |
| 274 else: |
| 275 type_checker = type_checkers.GetTypeChecker(field.cpp_type, field.type) |
| 276 def MakeRepeatedScalarDefault(message): |
| 277 return containers.RepeatedScalarFieldContainer( |
| 278 message._listener_for_children, type_checker) |
| 279 return MakeRepeatedScalarDefault |
| 280 |
| 281 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: |
| 282 # _concrete_class may not yet be initialized. |
| 283 message_type = field.message_type |
| 284 def MakeSubMessageDefault(message): |
| 285 result = message_type._concrete_class() |
| 286 result._SetListener(message._listener_for_children) |
| 287 return result |
| 288 return MakeSubMessageDefault |
| 289 |
| 290 def MakeScalarDefault(message): |
| 291 # TODO(protobuf-team): This may be broken since there may not be |
| 292 # default_value. Combine with has_default_value somehow. |
| 293 return field.default_value |
| 294 return MakeScalarDefault |
| 295 |
| 296 |
| 297 def _AddInitMethod(message_descriptor, cls): |
| 298 """Adds an __init__ method to cls.""" |
| 299 fields = message_descriptor.fields |
| 300 def init(self, **kwargs): |
| 301 self._cached_byte_size = 0 |
| 302 self._cached_byte_size_dirty = len(kwargs) > 0 |
| 303 self._fields = {} |
| 304 # _unknown_fields is () when empty for efficiency, and will be turned into |
| 305 # a list if fields are added. |
| 306 self._unknown_fields = () |
| 307 self._is_present_in_parent = False |
| 308 self._listener = message_listener_mod.NullMessageListener() |
| 309 self._listener_for_children = _Listener(self) |
| 310 for field_name, field_value in kwargs.iteritems(): |
| 311 field = _GetFieldByName(message_descriptor, field_name) |
| 312 if field is None: |
| 313 raise TypeError("%s() got an unexpected keyword argument '%s'" % |
| 314 (message_descriptor.name, field_name)) |
| 315 if field.label == _FieldDescriptor.LABEL_REPEATED: |
| 316 copy = field._default_constructor(self) |
| 317 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: # Composite |
| 318 for val in field_value: |
| 319 copy.add().MergeFrom(val) |
| 320 else: # Scalar |
| 321 copy.extend(field_value) |
| 322 self._fields[field] = copy |
| 323 elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: |
| 324 copy = field._default_constructor(self) |
| 325 copy.MergeFrom(field_value) |
| 326 self._fields[field] = copy |
| 327 else: |
| 328 setattr(self, field_name, field_value) |
| 329 |
| 330 init.__module__ = None |
| 331 init.__doc__ = None |
| 332 cls.__init__ = init |
| 333 |
| 334 |
| 335 def _GetFieldByName(message_descriptor, field_name): |
| 336 """Returns a field descriptor by field name. |
| 337 |
| 338 Args: |
| 339 message_descriptor: A Descriptor describing all fields in message. |
| 340 field_name: The name of the field to retrieve. |
| 341 Returns: |
| 342 The field descriptor associated with the field name. |
| 343 """ |
| 344 try: |
| 345 return message_descriptor.fields_by_name[field_name] |
| 346 except KeyError: |
| 347 raise ValueError('Protocol message has no "%s" field.' % field_name) |
| 348 |
| 349 |
| 350 def _AddPropertiesForFields(descriptor, cls): |
| 351 """Adds properties for all fields in this protocol message type.""" |
| 352 for field in descriptor.fields: |
| 353 _AddPropertiesForField(field, cls) |
| 354 |
| 355 if descriptor.is_extendable: |
| 356 # _ExtensionDict is just an adaptor with no state so we allocate a new one |
| 357 # every time it is accessed. |
| 358 cls.Extensions = property(lambda self: _ExtensionDict(self)) |
| 359 |
| 360 |
| 361 def _AddPropertiesForField(field, cls): |
| 362 """Adds a public property for a protocol message field. |
| 363 Clients can use this property to get and (in the case |
| 364 of non-repeated scalar fields) directly set the value |
| 365 of a protocol message field. |
| 366 |
| 367 Args: |
| 368 field: A FieldDescriptor for this field. |
| 369 cls: The class we're constructing. |
| 370 """ |
| 371 # Catch it if we add other types that we should |
| 372 # handle specially here. |
| 373 assert _FieldDescriptor.MAX_CPPTYPE == 10 |
| 374 |
| 375 constant_name = field.name.upper() + "_FIELD_NUMBER" |
| 376 setattr(cls, constant_name, field.number) |
| 377 |
| 378 if field.label == _FieldDescriptor.LABEL_REPEATED: |
| 379 _AddPropertiesForRepeatedField(field, cls) |
| 380 elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: |
| 381 _AddPropertiesForNonRepeatedCompositeField(field, cls) |
| 382 else: |
| 383 _AddPropertiesForNonRepeatedScalarField(field, cls) |
| 384 |
| 385 |
| 386 def _AddPropertiesForRepeatedField(field, cls): |
| 387 """Adds a public property for a "repeated" protocol message field. Clients |
| 388 can use this property to get the value of the field, which will be either a |
| 389 _RepeatedScalarFieldContainer or _RepeatedCompositeFieldContainer (see |
| 390 below). |
| 391 |
| 392 Note that when clients add values to these containers, we perform |
| 393 type-checking in the case of repeated scalar fields, and we also set any |
| 394 necessary "has" bits as a side-effect. |
| 395 |
| 396 Args: |
| 397 field: A FieldDescriptor for this field. |
| 398 cls: The class we're constructing. |
| 399 """ |
| 400 proto_field_name = field.name |
| 401 property_name = _PropertyName(proto_field_name) |
| 402 |
| 403 def getter(self): |
| 404 field_value = self._fields.get(field) |
| 405 if field_value is None: |
| 406 # Construct a new object to represent this field. |
| 407 field_value = field._default_constructor(self) |
| 408 |
| 409 # Atomically check if another thread has preempted us and, if not, swap |
| 410 # in the new object we just created. If someone has preempted us, we |
| 411 # take that object and discard ours. |
| 412 # WARNING: We are relying on setdefault() being atomic. This is true |
| 413 # in CPython but we haven't investigated others. This warning appears |
| 414 # in several other locations in this file. |
| 415 field_value = self._fields.setdefault(field, field_value) |
| 416 return field_value |
| 417 getter.__module__ = None |
| 418 getter.__doc__ = 'Getter for %s.' % proto_field_name |
| 419 |
| 420 # We define a setter just so we can throw an exception with a more |
| 421 # helpful error message. |
| 422 def setter(self, new_value): |
| 423 raise AttributeError('Assignment not allowed to repeated field ' |
| 424 '"%s" in protocol message object.' % proto_field_name) |
| 425 |
| 426 doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name |
| 427 setattr(cls, property_name, property(getter, setter, doc=doc)) |
| 428 |
| 429 |
| 430 def _AddPropertiesForNonRepeatedScalarField(field, cls): |
| 431 """Adds a public property for a nonrepeated, scalar protocol message field. |
| 432 Clients can use this property to get and directly set the value of the field. |
| 433 Note that when the client sets the value of a field by using this property, |
| 434 all necessary "has" bits are set as a side-effect, and we also perform |
| 435 type-checking. |
| 436 |
| 437 Args: |
| 438 field: A FieldDescriptor for this field. |
| 439 cls: The class we're constructing. |
| 440 """ |
| 441 proto_field_name = field.name |
| 442 property_name = _PropertyName(proto_field_name) |
| 443 type_checker = type_checkers.GetTypeChecker(field.cpp_type, field.type) |
| 444 default_value = field.default_value |
| 445 valid_values = set() |
| 446 |
| 447 def getter(self): |
| 448 # TODO(protobuf-team): This may be broken since there may not be |
| 449 # default_value. Combine with has_default_value somehow. |
| 450 return self._fields.get(field, default_value) |
| 451 getter.__module__ = None |
| 452 getter.__doc__ = 'Getter for %s.' % proto_field_name |
| 453 def setter(self, new_value): |
| 454 type_checker.CheckValue(new_value) |
| 455 self._fields[field] = new_value |
| 456 # Check _cached_byte_size_dirty inline to improve performance, since scalar |
| 457 # setters are called frequently. |
| 458 if not self._cached_byte_size_dirty: |
| 459 self._Modified() |
| 460 |
| 461 setter.__module__ = None |
| 462 setter.__doc__ = 'Setter for %s.' % proto_field_name |
| 463 |
| 464 # Add a property to encapsulate the getter/setter. |
| 465 doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name |
| 466 setattr(cls, property_name, property(getter, setter, doc=doc)) |
| 467 |
| 468 |
| 469 def _AddPropertiesForNonRepeatedCompositeField(field, cls): |
| 470 """Adds a public property for a nonrepeated, composite protocol message field. |
| 471 A composite field is a "group" or "message" field. |
| 472 |
| 473 Clients can use this property to get the value of the field, but cannot |
| 474 assign to the property directly. |
| 475 |
| 476 Args: |
| 477 field: A FieldDescriptor for this field. |
| 478 cls: The class we're constructing. |
| 479 """ |
| 480 # TODO(robinson): Remove duplication with similar method |
| 481 # for non-repeated scalars. |
| 482 proto_field_name = field.name |
| 483 property_name = _PropertyName(proto_field_name) |
| 484 |
| 485 # TODO(komarek): Can anyone explain to me why we cache the message_type this |
| 486 # way, instead of referring to field.message_type inside of getter(self)? |
| 487 # What if someone sets message_type later on (which makes for simpler |
| 488 # dyanmic proto descriptor and class creation code). |
| 489 message_type = field.message_type |
| 490 |
| 491 def getter(self): |
| 492 field_value = self._fields.get(field) |
| 493 if field_value is None: |
| 494 # Construct a new object to represent this field. |
| 495 field_value = message_type._concrete_class() # use field.message_type? |
| 496 field_value._SetListener(self._listener_for_children) |
| 497 |
| 498 # Atomically check if another thread has preempted us and, if not, swap |
| 499 # in the new object we just created. If someone has preempted us, we |
| 500 # take that object and discard ours. |
| 501 # WARNING: We are relying on setdefault() being atomic. This is true |
| 502 # in CPython but we haven't investigated others. This warning appears |
| 503 # in several other locations in this file. |
| 504 field_value = self._fields.setdefault(field, field_value) |
| 505 return field_value |
| 506 getter.__module__ = None |
| 507 getter.__doc__ = 'Getter for %s.' % proto_field_name |
| 508 |
| 509 # We define a setter just so we can throw an exception with a more |
| 510 # helpful error message. |
| 511 def setter(self, new_value): |
| 512 raise AttributeError('Assignment not allowed to composite field ' |
| 513 '"%s" in protocol message object.' % proto_field_name) |
| 514 |
| 515 # Add a property to encapsulate the getter. |
| 516 doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name |
| 517 setattr(cls, property_name, property(getter, setter, doc=doc)) |
| 518 |
| 519 |
| 520 def _AddPropertiesForExtensions(descriptor, cls): |
| 521 """Adds properties for all fields in this protocol message type.""" |
| 522 extension_dict = descriptor.extensions_by_name |
| 523 for extension_name, extension_field in extension_dict.iteritems(): |
| 524 constant_name = extension_name.upper() + "_FIELD_NUMBER" |
| 525 setattr(cls, constant_name, extension_field.number) |
| 526 |
| 527 |
| 528 def _AddStaticMethods(cls): |
| 529 # TODO(robinson): This probably needs to be thread-safe(?) |
| 530 def RegisterExtension(extension_handle): |
| 531 extension_handle.containing_type = cls.DESCRIPTOR |
| 532 _AttachFieldHelpers(cls, extension_handle) |
| 533 |
| 534 # Try to insert our extension, failing if an extension with the same number |
| 535 # already exists. |
| 536 actual_handle = cls._extensions_by_number.setdefault( |
| 537 extension_handle.number, extension_handle) |
| 538 if actual_handle is not extension_handle: |
| 539 raise AssertionError( |
| 540 'Extensions "%s" and "%s" both try to extend message type "%s" with ' |
| 541 'field number %d.' % |
| 542 (extension_handle.full_name, actual_handle.full_name, |
| 543 cls.DESCRIPTOR.full_name, extension_handle.number)) |
| 544 |
| 545 cls._extensions_by_name[extension_handle.full_name] = extension_handle |
| 546 |
| 547 handle = extension_handle # avoid line wrapping |
| 548 if _IsMessageSetExtension(handle): |
| 549 # MessageSet extension. Also register under type name. |
| 550 cls._extensions_by_name[ |
| 551 extension_handle.message_type.full_name] = extension_handle |
| 552 |
| 553 cls.RegisterExtension = staticmethod(RegisterExtension) |
| 554 |
| 555 def FromString(s): |
| 556 message = cls() |
| 557 message.MergeFromString(s) |
| 558 return message |
| 559 cls.FromString = staticmethod(FromString) |
| 560 |
| 561 |
| 562 def _IsPresent(item): |
| 563 """Given a (FieldDescriptor, value) tuple from _fields, return true if the |
| 564 value should be included in the list returned by ListFields().""" |
| 565 |
| 566 if item[0].label == _FieldDescriptor.LABEL_REPEATED: |
| 567 return bool(item[1]) |
| 568 elif item[0].cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: |
| 569 return item[1]._is_present_in_parent |
| 570 else: |
| 571 return True |
| 572 |
| 573 |
| 574 def _AddListFieldsMethod(message_descriptor, cls): |
| 575 """Helper for _AddMessageMethods().""" |
| 576 |
| 577 def ListFields(self): |
| 578 all_fields = [item for item in self._fields.iteritems() if _IsPresent(item)] |
| 579 all_fields.sort(key = lambda item: item[0].number) |
| 580 return all_fields |
| 581 |
| 582 cls.ListFields = ListFields |
| 583 |
| 584 |
| 585 def _AddHasFieldMethod(message_descriptor, cls): |
| 586 """Helper for _AddMessageMethods().""" |
| 587 |
| 588 singular_fields = {} |
| 589 for field in message_descriptor.fields: |
| 590 if field.label != _FieldDescriptor.LABEL_REPEATED: |
| 591 singular_fields[field.name] = field |
| 592 |
| 593 def HasField(self, field_name): |
| 594 try: |
| 595 field = singular_fields[field_name] |
| 596 except KeyError: |
| 597 raise ValueError( |
| 598 'Protocol message has no singular "%s" field.' % field_name) |
| 599 |
| 600 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: |
| 601 value = self._fields.get(field) |
| 602 return value is not None and value._is_present_in_parent |
| 603 else: |
| 604 return field in self._fields |
| 605 cls.HasField = HasField |
| 606 |
| 607 |
| 608 def _AddClearFieldMethod(message_descriptor, cls): |
| 609 """Helper for _AddMessageMethods().""" |
| 610 def ClearField(self, field_name): |
| 611 try: |
| 612 field = message_descriptor.fields_by_name[field_name] |
| 613 except KeyError: |
| 614 raise ValueError('Protocol message has no "%s" field.' % field_name) |
| 615 |
| 616 if field in self._fields: |
| 617 # Note: If the field is a sub-message, its listener will still point |
| 618 # at us. That's fine, because the worst than can happen is that it |
| 619 # will call _Modified() and invalidate our byte size. Big deal. |
| 620 del self._fields[field] |
| 621 |
| 622 # Always call _Modified() -- even if nothing was changed, this is |
| 623 # a mutating method, and thus calling it should cause the field to become |
| 624 # present in the parent message. |
| 625 self._Modified() |
| 626 |
| 627 cls.ClearField = ClearField |
| 628 |
| 629 |
| 630 def _AddClearExtensionMethod(cls): |
| 631 """Helper for _AddMessageMethods().""" |
| 632 def ClearExtension(self, extension_handle): |
| 633 _VerifyExtensionHandle(self, extension_handle) |
| 634 |
| 635 # Similar to ClearField(), above. |
| 636 if extension_handle in self._fields: |
| 637 del self._fields[extension_handle] |
| 638 self._Modified() |
| 639 cls.ClearExtension = ClearExtension |
| 640 |
| 641 |
| 642 def _AddClearMethod(message_descriptor, cls): |
| 643 """Helper for _AddMessageMethods().""" |
| 644 def Clear(self): |
| 645 # Clear fields. |
| 646 self._fields = {} |
| 647 self._unknown_fields = () |
| 648 self._Modified() |
| 649 cls.Clear = Clear |
| 650 |
| 651 |
| 652 def _AddHasExtensionMethod(cls): |
| 653 """Helper for _AddMessageMethods().""" |
| 654 def HasExtension(self, extension_handle): |
| 655 _VerifyExtensionHandle(self, extension_handle) |
| 656 if extension_handle.label == _FieldDescriptor.LABEL_REPEATED: |
| 657 raise KeyError('"%s" is repeated.' % extension_handle.full_name) |
| 658 |
| 659 if extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: |
| 660 value = self._fields.get(extension_handle) |
| 661 return value is not None and value._is_present_in_parent |
| 662 else: |
| 663 return extension_handle in self._fields |
| 664 cls.HasExtension = HasExtension |
| 665 |
| 666 |
| 667 def _AddEqualsMethod(message_descriptor, cls): |
| 668 """Helper for _AddMessageMethods().""" |
| 669 def __eq__(self, other): |
| 670 if (not isinstance(other, message_mod.Message) or |
| 671 other.DESCRIPTOR != self.DESCRIPTOR): |
| 672 return False |
| 673 |
| 674 if self is other: |
| 675 return True |
| 676 |
| 677 if not self.ListFields() == other.ListFields(): |
| 678 return False |
| 679 |
| 680 # Sort unknown fields because their order shouldn't affect equality test. |
| 681 unknown_fields = list(self._unknown_fields) |
| 682 unknown_fields.sort() |
| 683 other_unknown_fields = list(other._unknown_fields) |
| 684 other_unknown_fields.sort() |
| 685 |
| 686 return unknown_fields == other_unknown_fields |
| 687 |
| 688 cls.__eq__ = __eq__ |
| 689 |
| 690 |
| 691 def _AddStrMethod(message_descriptor, cls): |
| 692 """Helper for _AddMessageMethods().""" |
| 693 def __str__(self): |
| 694 return text_format.MessageToString(self) |
| 695 cls.__str__ = __str__ |
| 696 |
| 697 |
| 698 def _AddUnicodeMethod(unused_message_descriptor, cls): |
| 699 """Helper for _AddMessageMethods().""" |
| 700 |
| 701 def __unicode__(self): |
| 702 return text_format.MessageToString(self, as_utf8=True).decode('utf-8') |
| 703 cls.__unicode__ = __unicode__ |
| 704 |
| 705 |
| 706 def _AddSetListenerMethod(cls): |
| 707 """Helper for _AddMessageMethods().""" |
| 708 def SetListener(self, listener): |
| 709 if listener is None: |
| 710 self._listener = message_listener_mod.NullMessageListener() |
| 711 else: |
| 712 self._listener = listener |
| 713 cls._SetListener = SetListener |
| 714 |
| 715 |
| 716 def _BytesForNonRepeatedElement(value, field_number, field_type): |
| 717 """Returns the number of bytes needed to serialize a non-repeated element. |
| 718 The returned byte count includes space for tag information and any |
| 719 other additional space associated with serializing value. |
| 720 |
| 721 Args: |
| 722 value: Value we're serializing. |
| 723 field_number: Field number of this value. (Since the field number |
| 724 is stored as part of a varint-encoded tag, this has an impact |
| 725 on the total bytes required to serialize the value). |
| 726 field_type: The type of the field. One of the TYPE_* constants |
| 727 within FieldDescriptor. |
| 728 """ |
| 729 try: |
| 730 fn = type_checkers.TYPE_TO_BYTE_SIZE_FN[field_type] |
| 731 return fn(field_number, value) |
| 732 except KeyError: |
| 733 raise message_mod.EncodeError('Unrecognized field type: %d' % field_type) |
| 734 |
| 735 |
| 736 def _AddByteSizeMethod(message_descriptor, cls): |
| 737 """Helper for _AddMessageMethods().""" |
| 738 |
| 739 def ByteSize(self): |
| 740 if not self._cached_byte_size_dirty: |
| 741 return self._cached_byte_size |
| 742 |
| 743 size = 0 |
| 744 for field_descriptor, field_value in self.ListFields(): |
| 745 size += field_descriptor._sizer(field_value) |
| 746 |
| 747 for tag_bytes, value_bytes in self._unknown_fields: |
| 748 size += len(tag_bytes) + len(value_bytes) |
| 749 |
| 750 self._cached_byte_size = size |
| 751 self._cached_byte_size_dirty = False |
| 752 self._listener_for_children.dirty = False |
| 753 return size |
| 754 |
| 755 cls.ByteSize = ByteSize |
| 756 |
| 757 |
| 758 def _AddSerializeToStringMethod(message_descriptor, cls): |
| 759 """Helper for _AddMessageMethods().""" |
| 760 |
| 761 def SerializeToString(self): |
| 762 # Check if the message has all of its required fields set. |
| 763 errors = [] |
| 764 if not self.IsInitialized(): |
| 765 raise message_mod.EncodeError( |
| 766 'Message %s is missing required fields: %s' % ( |
| 767 self.DESCRIPTOR.full_name, ','.join(self.FindInitializationErrors()))) |
| 768 return self.SerializePartialToString() |
| 769 cls.SerializeToString = SerializeToString |
| 770 |
| 771 |
| 772 def _AddSerializePartialToStringMethod(message_descriptor, cls): |
| 773 """Helper for _AddMessageMethods().""" |
| 774 |
| 775 def SerializePartialToString(self): |
| 776 out = StringIO() |
| 777 self._InternalSerialize(out.write) |
| 778 return out.getvalue() |
| 779 cls.SerializePartialToString = SerializePartialToString |
| 780 |
| 781 def InternalSerialize(self, write_bytes): |
| 782 for field_descriptor, field_value in self.ListFields(): |
| 783 field_descriptor._encoder(write_bytes, field_value) |
| 784 for tag_bytes, value_bytes in self._unknown_fields: |
| 785 write_bytes(tag_bytes) |
| 786 write_bytes(value_bytes) |
| 787 cls._InternalSerialize = InternalSerialize |
| 788 |
| 789 |
| 790 def _AddMergeFromStringMethod(message_descriptor, cls): |
| 791 """Helper for _AddMessageMethods().""" |
| 792 def MergeFromString(self, serialized): |
| 793 length = len(serialized) |
| 794 try: |
| 795 if self._InternalParse(serialized, 0, length) != length: |
| 796 # The only reason _InternalParse would return early is if it |
| 797 # encountered an end-group tag. |
| 798 raise message_mod.DecodeError('Unexpected end-group tag.') |
| 799 except IndexError: |
| 800 raise message_mod.DecodeError('Truncated message.') |
| 801 except struct.error, e: |
| 802 raise message_mod.DecodeError(e) |
| 803 return length # Return this for legacy reasons. |
| 804 cls.MergeFromString = MergeFromString |
| 805 |
| 806 local_ReadTag = decoder.ReadTag |
| 807 local_SkipField = decoder.SkipField |
| 808 decoders_by_tag = cls._decoders_by_tag |
| 809 |
| 810 def InternalParse(self, buffer, pos, end): |
| 811 self._Modified() |
| 812 field_dict = self._fields |
| 813 unknown_field_list = self._unknown_fields |
| 814 while pos != end: |
| 815 (tag_bytes, new_pos) = local_ReadTag(buffer, pos) |
| 816 field_decoder = decoders_by_tag.get(tag_bytes) |
| 817 if field_decoder is None: |
| 818 value_start_pos = new_pos |
| 819 new_pos = local_SkipField(buffer, new_pos, end, tag_bytes) |
| 820 if new_pos == -1: |
| 821 return pos |
| 822 if not unknown_field_list: |
| 823 unknown_field_list = self._unknown_fields = [] |
| 824 unknown_field_list.append((tag_bytes, buffer[value_start_pos:new_pos])) |
| 825 pos = new_pos |
| 826 else: |
| 827 pos = field_decoder(buffer, new_pos, end, self, field_dict) |
| 828 return pos |
| 829 cls._InternalParse = InternalParse |
| 830 |
| 831 |
| 832 def _AddIsInitializedMethod(message_descriptor, cls): |
| 833 """Adds the IsInitialized and FindInitializationError methods to the |
| 834 protocol message class.""" |
| 835 |
| 836 required_fields = [field for field in message_descriptor.fields |
| 837 if field.label == _FieldDescriptor.LABEL_REQUIRED] |
| 838 |
| 839 def IsInitialized(self, errors=None): |
| 840 """Checks if all required fields of a message are set. |
| 841 |
| 842 Args: |
| 843 errors: A list which, if provided, will be populated with the field |
| 844 paths of all missing required fields. |
| 845 |
| 846 Returns: |
| 847 True iff the specified message has all required fields set. |
| 848 """ |
| 849 |
| 850 # Performance is critical so we avoid HasField() and ListFields(). |
| 851 |
| 852 for field in required_fields: |
| 853 if (field not in self._fields or |
| 854 (field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE and |
| 855 not self._fields[field]._is_present_in_parent)): |
| 856 if errors is not None: |
| 857 errors.extend(self.FindInitializationErrors()) |
| 858 return False |
| 859 |
| 860 for field, value in self._fields.iteritems(): |
| 861 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: |
| 862 if field.label == _FieldDescriptor.LABEL_REPEATED: |
| 863 for element in value: |
| 864 if not element.IsInitialized(): |
| 865 if errors is not None: |
| 866 errors.extend(self.FindInitializationErrors()) |
| 867 return False |
| 868 elif value._is_present_in_parent and not value.IsInitialized(): |
| 869 if errors is not None: |
| 870 errors.extend(self.FindInitializationErrors()) |
| 871 return False |
| 872 |
| 873 return True |
| 874 |
| 875 cls.IsInitialized = IsInitialized |
| 876 |
| 877 def FindInitializationErrors(self): |
| 878 """Finds required fields which are not initialized. |
| 879 |
| 880 Returns: |
| 881 A list of strings. Each string is a path to an uninitialized field from |
| 882 the top-level message, e.g. "foo.bar[5].baz". |
| 883 """ |
| 884 |
| 885 errors = [] # simplify things |
| 886 |
| 887 for field in required_fields: |
| 888 if not self.HasField(field.name): |
| 889 errors.append(field.name) |
| 890 |
| 891 for field, value in self.ListFields(): |
| 892 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: |
| 893 if field.is_extension: |
| 894 name = "(%s)" % field.full_name |
| 895 else: |
| 896 name = field.name |
| 897 |
| 898 if field.label == _FieldDescriptor.LABEL_REPEATED: |
| 899 for i in xrange(len(value)): |
| 900 element = value[i] |
| 901 prefix = "%s[%d]." % (name, i) |
| 902 sub_errors = element.FindInitializationErrors() |
| 903 errors += [ prefix + error for error in sub_errors ] |
| 904 else: |
| 905 prefix = name + "." |
| 906 sub_errors = value.FindInitializationErrors() |
| 907 errors += [ prefix + error for error in sub_errors ] |
| 908 |
| 909 return errors |
| 910 |
| 911 cls.FindInitializationErrors = FindInitializationErrors |
| 912 |
| 913 |
| 914 def _AddMergeFromMethod(cls): |
| 915 LABEL_REPEATED = _FieldDescriptor.LABEL_REPEATED |
| 916 CPPTYPE_MESSAGE = _FieldDescriptor.CPPTYPE_MESSAGE |
| 917 |
| 918 def MergeFrom(self, msg): |
| 919 if not isinstance(msg, cls): |
| 920 raise TypeError( |
| 921 "Parameter to MergeFrom() must be instance of same class: " |
| 922 "expected %s got %s." % (cls.__name__, type(msg).__name__)) |
| 923 |
| 924 assert msg is not self |
| 925 self._Modified() |
| 926 |
| 927 fields = self._fields |
| 928 |
| 929 for field, value in msg._fields.iteritems(): |
| 930 if field.label == LABEL_REPEATED: |
| 931 field_value = fields.get(field) |
| 932 if field_value is None: |
| 933 # Construct a new object to represent this field. |
| 934 field_value = field._default_constructor(self) |
| 935 fields[field] = field_value |
| 936 field_value.MergeFrom(value) |
| 937 elif field.cpp_type == CPPTYPE_MESSAGE: |
| 938 if value._is_present_in_parent: |
| 939 field_value = fields.get(field) |
| 940 if field_value is None: |
| 941 # Construct a new object to represent this field. |
| 942 field_value = field._default_constructor(self) |
| 943 fields[field] = field_value |
| 944 field_value.MergeFrom(value) |
| 945 else: |
| 946 self._fields[field] = value |
| 947 |
| 948 if msg._unknown_fields: |
| 949 if not self._unknown_fields: |
| 950 self._unknown_fields = [] |
| 951 self._unknown_fields.extend(msg._unknown_fields) |
| 952 |
| 953 cls.MergeFrom = MergeFrom |
| 954 |
| 955 |
| 956 def _AddMessageMethods(message_descriptor, cls): |
| 957 """Adds implementations of all Message methods to cls.""" |
| 958 _AddListFieldsMethod(message_descriptor, cls) |
| 959 _AddHasFieldMethod(message_descriptor, cls) |
| 960 _AddClearFieldMethod(message_descriptor, cls) |
| 961 if message_descriptor.is_extendable: |
| 962 _AddClearExtensionMethod(cls) |
| 963 _AddHasExtensionMethod(cls) |
| 964 _AddClearMethod(message_descriptor, cls) |
| 965 _AddEqualsMethod(message_descriptor, cls) |
| 966 _AddStrMethod(message_descriptor, cls) |
| 967 _AddUnicodeMethod(message_descriptor, cls) |
| 968 _AddSetListenerMethod(cls) |
| 969 _AddByteSizeMethod(message_descriptor, cls) |
| 970 _AddSerializeToStringMethod(message_descriptor, cls) |
| 971 _AddSerializePartialToStringMethod(message_descriptor, cls) |
| 972 _AddMergeFromStringMethod(message_descriptor, cls) |
| 973 _AddIsInitializedMethod(message_descriptor, cls) |
| 974 _AddMergeFromMethod(cls) |
| 975 |
| 976 |
| 977 def _AddPrivateHelperMethods(cls): |
| 978 """Adds implementation of private helper methods to cls.""" |
| 979 |
| 980 def Modified(self): |
| 981 """Sets the _cached_byte_size_dirty bit to true, |
| 982 and propagates this to our listener iff this was a state change. |
| 983 """ |
| 984 |
| 985 # Note: Some callers check _cached_byte_size_dirty before calling |
| 986 # _Modified() as an extra optimization. So, if this method is ever |
| 987 # changed such that it does stuff even when _cached_byte_size_dirty is |
| 988 # already true, the callers need to be updated. |
| 989 if not self._cached_byte_size_dirty: |
| 990 self._cached_byte_size_dirty = True |
| 991 self._listener_for_children.dirty = True |
| 992 self._is_present_in_parent = True |
| 993 self._listener.Modified() |
| 994 |
| 995 cls._Modified = Modified |
| 996 cls.SetInParent = Modified |
| 997 |
| 998 |
| 999 class _Listener(object): |
| 1000 |
| 1001 """MessageListener implementation that a parent message registers with its |
| 1002 child message. |
| 1003 |
| 1004 In order to support semantics like: |
| 1005 |
| 1006 foo.bar.baz.qux = 23 |
| 1007 assert foo.HasField('bar') |
| 1008 |
| 1009 ...child objects must have back references to their parents. |
| 1010 This helper class is at the heart of this support. |
| 1011 """ |
| 1012 |
| 1013 def __init__(self, parent_message): |
| 1014 """Args: |
| 1015 parent_message: The message whose _Modified() method we should call when |
| 1016 we receive Modified() messages. |
| 1017 """ |
| 1018 # This listener establishes a back reference from a child (contained) object |
| 1019 # to its parent (containing) object. We make this a weak reference to avoid |
| 1020 # creating cyclic garbage when the client finishes with the 'parent' object |
| 1021 # in the tree. |
| 1022 if isinstance(parent_message, weakref.ProxyType): |
| 1023 self._parent_message_weakref = parent_message |
| 1024 else: |
| 1025 self._parent_message_weakref = weakref.proxy(parent_message) |
| 1026 |
| 1027 # As an optimization, we also indicate directly on the listener whether |
| 1028 # or not the parent message is dirty. This way we can avoid traversing |
| 1029 # up the tree in the common case. |
| 1030 self.dirty = False |
| 1031 |
| 1032 def Modified(self): |
| 1033 if self.dirty: |
| 1034 return |
| 1035 try: |
| 1036 # Propagate the signal to our parents iff this is the first field set. |
| 1037 self._parent_message_weakref._Modified() |
| 1038 except ReferenceError: |
| 1039 # We can get here if a client has kept a reference to a child object, |
| 1040 # and is now setting a field on it, but the child's parent has been |
| 1041 # garbage-collected. This is not an error. |
| 1042 pass |
| 1043 |
| 1044 |
| 1045 # TODO(robinson): Move elsewhere? This file is getting pretty ridiculous... |
| 1046 # TODO(robinson): Unify error handling of "unknown extension" crap. |
| 1047 # TODO(robinson): Support iteritems()-style iteration over all |
| 1048 # extensions with the "has" bits turned on? |
| 1049 class _ExtensionDict(object): |
| 1050 |
| 1051 """Dict-like container for supporting an indexable "Extensions" |
| 1052 field on proto instances. |
| 1053 |
| 1054 Note that in all cases we expect extension handles to be |
| 1055 FieldDescriptors. |
| 1056 """ |
| 1057 |
| 1058 def __init__(self, extended_message): |
| 1059 """extended_message: Message instance for which we are the Extensions dict. |
| 1060 """ |
| 1061 |
| 1062 self._extended_message = extended_message |
| 1063 |
| 1064 def __getitem__(self, extension_handle): |
| 1065 """Returns the current value of the given extension handle.""" |
| 1066 |
| 1067 _VerifyExtensionHandle(self._extended_message, extension_handle) |
| 1068 |
| 1069 result = self._extended_message._fields.get(extension_handle) |
| 1070 if result is not None: |
| 1071 return result |
| 1072 |
| 1073 if extension_handle.label == _FieldDescriptor.LABEL_REPEATED: |
| 1074 result = extension_handle._default_constructor(self._extended_message) |
| 1075 elif extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: |
| 1076 result = extension_handle.message_type._concrete_class() |
| 1077 try: |
| 1078 result._SetListener(self._extended_message._listener_for_children) |
| 1079 except ReferenceError: |
| 1080 pass |
| 1081 else: |
| 1082 # Singular scalar -- just return the default without inserting into the |
| 1083 # dict. |
| 1084 return extension_handle.default_value |
| 1085 |
| 1086 # Atomically check if another thread has preempted us and, if not, swap |
| 1087 # in the new object we just created. If someone has preempted us, we |
| 1088 # take that object and discard ours. |
| 1089 # WARNING: We are relying on setdefault() being atomic. This is true |
| 1090 # in CPython but we haven't investigated others. This warning appears |
| 1091 # in several other locations in this file. |
| 1092 result = self._extended_message._fields.setdefault( |
| 1093 extension_handle, result) |
| 1094 |
| 1095 return result |
| 1096 |
| 1097 def __eq__(self, other): |
| 1098 if not isinstance(other, self.__class__): |
| 1099 return False |
| 1100 |
| 1101 my_fields = self._extended_message.ListFields() |
| 1102 other_fields = other._extended_message.ListFields() |
| 1103 |
| 1104 # Get rid of non-extension fields. |
| 1105 my_fields = [ field for field in my_fields if field.is_extension ] |
| 1106 other_fields = [ field for field in other_fields if field.is_extension ] |
| 1107 |
| 1108 return my_fields == other_fields |
| 1109 |
| 1110 def __ne__(self, other): |
| 1111 return not self == other |
| 1112 |
| 1113 def __hash__(self): |
| 1114 raise TypeError('unhashable object') |
| 1115 |
| 1116 # Note that this is only meaningful for non-repeated, scalar extension |
| 1117 # fields. Note also that we may have to call _Modified() when we do |
| 1118 # successfully set a field this way, to set any necssary "has" bits in the |
| 1119 # ancestors of the extended message. |
| 1120 def __setitem__(self, extension_handle, value): |
| 1121 """If extension_handle specifies a non-repeated, scalar extension |
| 1122 field, sets the value of that field. |
| 1123 """ |
| 1124 |
| 1125 _VerifyExtensionHandle(self._extended_message, extension_handle) |
| 1126 |
| 1127 if (extension_handle.label == _FieldDescriptor.LABEL_REPEATED or |
| 1128 extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE): |
| 1129 raise TypeError( |
| 1130 'Cannot assign to extension "%s" because it is a repeated or ' |
| 1131 'composite type.' % extension_handle.full_name) |
| 1132 |
| 1133 # It's slightly wasteful to lookup the type checker each time, |
| 1134 # but we expect this to be a vanishingly uncommon case anyway. |
| 1135 type_checker = type_checkers.GetTypeChecker( |
| 1136 extension_handle.cpp_type, extension_handle.type) |
| 1137 type_checker.CheckValue(value) |
| 1138 self._extended_message._fields[extension_handle] = value |
| 1139 self._extended_message._Modified() |
| 1140 |
| 1141 def _FindExtensionByName(self, name): |
| 1142 """Tries to find a known extension with the specified name. |
| 1143 |
| 1144 Args: |
| 1145 name: Extension full name. |
| 1146 |
| 1147 Returns: |
| 1148 Extension field descriptor. |
| 1149 """ |
| 1150 return self._extended_message._extensions_by_name.get(name, None) |
OLD | NEW |