| Index: third_party/protobuf/python/google/protobuf/internal/decoder.py
|
| diff --git a/third_party/protobuf/python/google/protobuf/internal/decoder.py b/third_party/protobuf/python/google/protobuf/internal/decoder.py
|
| index cb6f5729dd765e7105a993c6f3060555d5abb6a4..31869e4575ca900440a8ce452ce2af5ed87a19aa 100755
|
| --- a/third_party/protobuf/python/google/protobuf/internal/decoder.py
|
| +++ b/third_party/protobuf/python/google/protobuf/internal/decoder.py
|
| @@ -1,6 +1,6 @@
|
| # Protocol Buffers - Google's data interchange format
|
| # Copyright 2008 Google Inc. All rights reserved.
|
| -# http://code.google.com/p/protobuf/
|
| +# https://developers.google.com/protocol-buffers/
|
| #
|
| # Redistribution and use in source and binary forms, with or without
|
| # modification, are permitted provided that the following conditions are
|
| @@ -81,6 +81,12 @@ we repeatedly read a tag, look up the corresponding decoder, and invoke it.
|
| __author__ = 'kenton@google.com (Kenton Varda)'
|
|
|
| import struct
|
| +
|
| +import six
|
| +
|
| +if six.PY3:
|
| + long = int
|
| +
|
| from google.protobuf.internal import encoder
|
| from google.protobuf.internal import wire_format
|
| from google.protobuf import message
|
| @@ -98,7 +104,7 @@ _NAN = _POS_INF * 0
|
| _DecodeError = message.DecodeError
|
|
|
|
|
| -def _VarintDecoder(mask):
|
| +def _VarintDecoder(mask, result_type):
|
| """Return an encoder for a basic varint value (does not include tag).
|
|
|
| Decoded values will be bitwise-anded with the given mask before being
|
| @@ -108,16 +114,16 @@ def _VarintDecoder(mask):
|
| decoder returns a (value, new_pos) pair.
|
| """
|
|
|
| - local_ord = ord
|
| def DecodeVarint(buffer, pos):
|
| result = 0
|
| shift = 0
|
| while 1:
|
| - b = local_ord(buffer[pos])
|
| + b = six.indexbytes(buffer, pos)
|
| result |= ((b & 0x7f) << shift)
|
| pos += 1
|
| if not (b & 0x80):
|
| result &= mask
|
| + result = result_type(result)
|
| return (result, pos)
|
| shift += 7
|
| if shift >= 64:
|
| @@ -125,15 +131,14 @@ def _VarintDecoder(mask):
|
| return DecodeVarint
|
|
|
|
|
| -def _SignedVarintDecoder(mask):
|
| +def _SignedVarintDecoder(mask, result_type):
|
| """Like _VarintDecoder() but decodes signed values."""
|
|
|
| - local_ord = ord
|
| def DecodeVarint(buffer, pos):
|
| result = 0
|
| shift = 0
|
| while 1:
|
| - b = local_ord(buffer[pos])
|
| + b = six.indexbytes(buffer, pos)
|
| result |= ((b & 0x7f) << shift)
|
| pos += 1
|
| if not (b & 0x80):
|
| @@ -142,19 +147,23 @@ def _SignedVarintDecoder(mask):
|
| result |= ~mask
|
| else:
|
| result &= mask
|
| + result = result_type(result)
|
| return (result, pos)
|
| shift += 7
|
| if shift >= 64:
|
| raise _DecodeError('Too many bytes when decoding varint.')
|
| return DecodeVarint
|
|
|
| +# We force 32-bit values to int and 64-bit values to long to make
|
| +# alternate implementations where the distinction is more significant
|
| +# (e.g. the C++ implementation) simpler.
|
|
|
| -_DecodeVarint = _VarintDecoder((1 << 64) - 1)
|
| -_DecodeSignedVarint = _SignedVarintDecoder((1 << 64) - 1)
|
| +_DecodeVarint = _VarintDecoder((1 << 64) - 1, long)
|
| +_DecodeSignedVarint = _SignedVarintDecoder((1 << 64) - 1, long)
|
|
|
| # Use these versions for values which must be limited to 32 bits.
|
| -_DecodeVarint32 = _VarintDecoder((1 << 32) - 1)
|
| -_DecodeSignedVarint32 = _SignedVarintDecoder((1 << 32) - 1)
|
| +_DecodeVarint32 = _VarintDecoder((1 << 32) - 1, int)
|
| +_DecodeSignedVarint32 = _SignedVarintDecoder((1 << 32) - 1, int)
|
|
|
|
|
| def ReadTag(buffer, pos):
|
| @@ -169,7 +178,7 @@ def ReadTag(buffer, pos):
|
| """
|
|
|
| start = pos
|
| - while ord(buffer[pos]) & 0x80:
|
| + while six.indexbytes(buffer, pos) & 0x80:
|
| pos += 1
|
| pos += 1
|
| return (buffer[start:pos], pos)
|
| @@ -294,13 +303,12 @@ def _FloatDecoder():
|
| # If this value has all its exponent bits set, then it's non-finite.
|
| # In Python 2.4, struct.unpack will convert it to a finite 64-bit value.
|
| # To avoid that, we parse it specially.
|
| - if ((float_bytes[3] in '\x7F\xFF')
|
| - and (float_bytes[2] >= '\x80')):
|
| + if (float_bytes[3:4] in b'\x7F\xFF' and float_bytes[2:3] >= b'\x80'):
|
| # If at least one significand bit is set...
|
| - if float_bytes[0:3] != '\x00\x00\x80':
|
| + if float_bytes[0:3] != b'\x00\x00\x80':
|
| return (_NAN, new_pos)
|
| # If sign bit is set...
|
| - if float_bytes[3] == '\xFF':
|
| + if float_bytes[3:4] == b'\xFF':
|
| return (_NEG_INF, new_pos)
|
| return (_POS_INF, new_pos)
|
|
|
| @@ -329,9 +337,9 @@ def _DoubleDecoder():
|
| # If this value has all its exponent bits set and at least one significand
|
| # bit set, it's not a number. In Python 2.4, struct.unpack will treat it
|
| # as inf or -inf. To avoid that, we treat it specially.
|
| - if ((double_bytes[7] in '\x7F\xFF')
|
| - and (double_bytes[6] >= '\xF0')
|
| - and (double_bytes[0:7] != '\x00\x00\x00\x00\x00\x00\xF0')):
|
| + if ((double_bytes[7:8] in b'\x7F\xFF')
|
| + and (double_bytes[6:7] >= b'\xF0')
|
| + and (double_bytes[0:7] != b'\x00\x00\x00\x00\x00\x00\xF0')):
|
| return (_NAN, new_pos)
|
|
|
| # Note that we expect someone up-stack to catch struct.error and convert
|
| @@ -342,10 +350,86 @@ def _DoubleDecoder():
|
| return _SimpleDecoder(wire_format.WIRETYPE_FIXED64, InnerDecode)
|
|
|
|
|
| +def EnumDecoder(field_number, is_repeated, is_packed, key, new_default):
|
| + enum_type = key.enum_type
|
| + if is_packed:
|
| + local_DecodeVarint = _DecodeVarint
|
| + def DecodePackedField(buffer, pos, end, message, field_dict):
|
| + value = field_dict.get(key)
|
| + if value is None:
|
| + value = field_dict.setdefault(key, new_default(message))
|
| + (endpoint, pos) = local_DecodeVarint(buffer, pos)
|
| + endpoint += pos
|
| + if endpoint > end:
|
| + raise _DecodeError('Truncated message.')
|
| + while pos < endpoint:
|
| + value_start_pos = pos
|
| + (element, pos) = _DecodeSignedVarint32(buffer, pos)
|
| + if element in enum_type.values_by_number:
|
| + value.append(element)
|
| + else:
|
| + if not message._unknown_fields:
|
| + message._unknown_fields = []
|
| + tag_bytes = encoder.TagBytes(field_number,
|
| + wire_format.WIRETYPE_VARINT)
|
| + message._unknown_fields.append(
|
| + (tag_bytes, buffer[value_start_pos:pos]))
|
| + if pos > endpoint:
|
| + if element in enum_type.values_by_number:
|
| + del value[-1] # Discard corrupt value.
|
| + else:
|
| + del message._unknown_fields[-1]
|
| + raise _DecodeError('Packed element was truncated.')
|
| + return pos
|
| + return DecodePackedField
|
| + elif is_repeated:
|
| + tag_bytes = encoder.TagBytes(field_number, wire_format.WIRETYPE_VARINT)
|
| + tag_len = len(tag_bytes)
|
| + def DecodeRepeatedField(buffer, pos, end, message, field_dict):
|
| + value = field_dict.get(key)
|
| + if value is None:
|
| + value = field_dict.setdefault(key, new_default(message))
|
| + while 1:
|
| + (element, new_pos) = _DecodeSignedVarint32(buffer, pos)
|
| + if element in enum_type.values_by_number:
|
| + value.append(element)
|
| + else:
|
| + if not message._unknown_fields:
|
| + message._unknown_fields = []
|
| + message._unknown_fields.append(
|
| + (tag_bytes, buffer[pos:new_pos]))
|
| + # Predict that the next tag is another copy of the same repeated
|
| + # field.
|
| + pos = new_pos + tag_len
|
| + if buffer[new_pos:pos] != tag_bytes or new_pos >= end:
|
| + # Prediction failed. Return.
|
| + if new_pos > end:
|
| + raise _DecodeError('Truncated message.')
|
| + return new_pos
|
| + return DecodeRepeatedField
|
| + else:
|
| + def DecodeField(buffer, pos, end, message, field_dict):
|
| + value_start_pos = pos
|
| + (enum_value, pos) = _DecodeSignedVarint32(buffer, pos)
|
| + if pos > end:
|
| + raise _DecodeError('Truncated message.')
|
| + if enum_value in enum_type.values_by_number:
|
| + field_dict[key] = enum_value
|
| + else:
|
| + if not message._unknown_fields:
|
| + message._unknown_fields = []
|
| + tag_bytes = encoder.TagBytes(field_number,
|
| + wire_format.WIRETYPE_VARINT)
|
| + message._unknown_fields.append(
|
| + (tag_bytes, buffer[value_start_pos:pos]))
|
| + return pos
|
| + return DecodeField
|
| +
|
| +
|
| # --------------------------------------------------------------------
|
|
|
|
|
| -Int32Decoder = EnumDecoder = _SimpleDecoder(
|
| +Int32Decoder = _SimpleDecoder(
|
| wire_format.WIRETYPE_VARINT, _DecodeSignedVarint32)
|
|
|
| Int64Decoder = _SimpleDecoder(
|
| @@ -378,7 +462,15 @@ def StringDecoder(field_number, is_repeated, is_packed, key, new_default):
|
| """Returns a decoder for a string field."""
|
|
|
| local_DecodeVarint = _DecodeVarint
|
| - local_unicode = unicode
|
| + local_unicode = six.text_type
|
| +
|
| + def _ConvertToUnicode(byte_str):
|
| + try:
|
| + return local_unicode(byte_str, 'utf-8')
|
| + except UnicodeDecodeError as e:
|
| + # add more information to the error message and re-raise it.
|
| + e.reason = '%s in field: %s' % (e, key.full_name)
|
| + raise
|
|
|
| assert not is_packed
|
| if is_repeated:
|
| @@ -394,7 +486,7 @@ def StringDecoder(field_number, is_repeated, is_packed, key, new_default):
|
| new_pos = pos + size
|
| if new_pos > end:
|
| raise _DecodeError('Truncated string.')
|
| - value.append(local_unicode(buffer[pos:new_pos], 'utf-8'))
|
| + value.append(_ConvertToUnicode(buffer[pos:new_pos]))
|
| # Predict that the next tag is another copy of the same repeated field.
|
| pos = new_pos + tag_len
|
| if buffer[new_pos:pos] != tag_bytes or new_pos == end:
|
| @@ -407,7 +499,7 @@ def StringDecoder(field_number, is_repeated, is_packed, key, new_default):
|
| new_pos = pos + size
|
| if new_pos > end:
|
| raise _DecodeError('Truncated string.')
|
| - field_dict[key] = local_unicode(buffer[pos:new_pos], 'utf-8')
|
| + field_dict[key] = _ConvertToUnicode(buffer[pos:new_pos])
|
| return new_pos
|
| return DecodeField
|
|
|
| @@ -511,9 +603,6 @@ def MessageDecoder(field_number, is_repeated, is_packed, key, new_default):
|
| if value is None:
|
| value = field_dict.setdefault(key, new_default(message))
|
| while 1:
|
| - value = field_dict.get(key)
|
| - if value is None:
|
| - value = field_dict.setdefault(key, new_default(message))
|
| # Read length.
|
| (size, pos) = local_DecodeVarint(buffer, pos)
|
| new_pos = pos + size
|
| @@ -626,13 +715,59 @@ def MessageSetItemDecoder(extensions_by_number):
|
| return DecodeItem
|
|
|
| # --------------------------------------------------------------------
|
| +
|
| +def MapDecoder(field_descriptor, new_default, is_message_map):
|
| + """Returns a decoder for a map field."""
|
| +
|
| + key = field_descriptor
|
| + tag_bytes = encoder.TagBytes(field_descriptor.number,
|
| + wire_format.WIRETYPE_LENGTH_DELIMITED)
|
| + tag_len = len(tag_bytes)
|
| + local_DecodeVarint = _DecodeVarint
|
| + # Can't read _concrete_class yet; might not be initialized.
|
| + message_type = field_descriptor.message_type
|
| +
|
| + def DecodeMap(buffer, pos, end, message, field_dict):
|
| + submsg = message_type._concrete_class()
|
| + value = field_dict.get(key)
|
| + if value is None:
|
| + value = field_dict.setdefault(key, new_default(message))
|
| + while 1:
|
| + # Read length.
|
| + (size, pos) = local_DecodeVarint(buffer, pos)
|
| + new_pos = pos + size
|
| + if new_pos > end:
|
| + raise _DecodeError('Truncated message.')
|
| + # Read sub-message.
|
| + submsg.Clear()
|
| + if submsg._InternalParse(buffer, pos, new_pos) != new_pos:
|
| + # The only reason _InternalParse would return early is if it
|
| + # encountered an end-group tag.
|
| + raise _DecodeError('Unexpected end-group tag.')
|
| +
|
| + if is_message_map:
|
| + value[submsg.key].MergeFrom(submsg.value)
|
| + else:
|
| + value[submsg.key] = submsg.value
|
| +
|
| + # Predict that the next tag is another copy of the same repeated field.
|
| + pos = new_pos + tag_len
|
| + if buffer[new_pos:pos] != tag_bytes or new_pos == end:
|
| + # Prediction failed. Return.
|
| + return new_pos
|
| +
|
| + return DecodeMap
|
| +
|
| +# --------------------------------------------------------------------
|
| # Optimization is not as heavy here because calls to SkipField() are rare,
|
| # except for handling end-group tags.
|
|
|
| def _SkipVarint(buffer, pos, end):
|
| """Skip a varint value. Returns the new position."""
|
| -
|
| - while ord(buffer[pos]) & 0x80:
|
| + # Previously ord(buffer[pos]) raised IndexError when pos is out of range.
|
| + # With this code, ord(b'') raises TypeError. Both are handled in
|
| + # python_message.py to generate a 'Truncated message' error.
|
| + while ord(buffer[pos:pos+1]) & 0x80:
|
| pos += 1
|
| pos += 1
|
| if pos > end:
|
| @@ -699,7 +834,6 @@ def _FieldSkipper():
|
| ]
|
|
|
| wiretype_mask = wire_format.TAG_TYPE_MASK
|
| - local_ord = ord
|
|
|
| def SkipField(buffer, pos, end, tag_bytes):
|
| """Skips a field with the specified tag.
|
| @@ -712,7 +846,7 @@ def _FieldSkipper():
|
| """
|
|
|
| # The wire type is always in the first byte since varints are little-endian.
|
| - wire_type = local_ord(tag_bytes[0]) & wiretype_mask
|
| + wire_type = ord(tag_bytes[0:1]) & wiretype_mask
|
| return WIRETYPE_TO_SKIPPER[wire_type](buffer, pos, end)
|
|
|
| return SkipField
|
|
|