Index: third_party/protobuf/python/google/protobuf/internal/decoder.py |
diff --git a/third_party/protobuf/python/google/protobuf/internal/decoder.py b/third_party/protobuf/python/google/protobuf/internal/decoder.py |
index cb6f5729dd765e7105a993c6f3060555d5abb6a4..31869e4575ca900440a8ce452ce2af5ed87a19aa 100755 |
--- a/third_party/protobuf/python/google/protobuf/internal/decoder.py |
+++ b/third_party/protobuf/python/google/protobuf/internal/decoder.py |
@@ -1,6 +1,6 @@ |
# Protocol Buffers - Google's data interchange format |
# Copyright 2008 Google Inc. All rights reserved. |
-# http://code.google.com/p/protobuf/ |
+# https://developers.google.com/protocol-buffers/ |
# |
# Redistribution and use in source and binary forms, with or without |
# modification, are permitted provided that the following conditions are |
@@ -81,6 +81,12 @@ we repeatedly read a tag, look up the corresponding decoder, and invoke it. |
__author__ = 'kenton@google.com (Kenton Varda)' |
import struct |
+ |
+import six |
+ |
+if six.PY3: |
+ long = int |
+ |
from google.protobuf.internal import encoder |
from google.protobuf.internal import wire_format |
from google.protobuf import message |
@@ -98,7 +104,7 @@ _NAN = _POS_INF * 0 |
_DecodeError = message.DecodeError |
-def _VarintDecoder(mask): |
+def _VarintDecoder(mask, result_type): |
"""Return an encoder for a basic varint value (does not include tag). |
Decoded values will be bitwise-anded with the given mask before being |
@@ -108,16 +114,16 @@ def _VarintDecoder(mask): |
decoder returns a (value, new_pos) pair. |
""" |
- local_ord = ord |
def DecodeVarint(buffer, pos): |
result = 0 |
shift = 0 |
while 1: |
- b = local_ord(buffer[pos]) |
+ b = six.indexbytes(buffer, pos) |
result |= ((b & 0x7f) << shift) |
pos += 1 |
if not (b & 0x80): |
result &= mask |
+ result = result_type(result) |
return (result, pos) |
shift += 7 |
if shift >= 64: |
@@ -125,15 +131,14 @@ def _VarintDecoder(mask): |
return DecodeVarint |
-def _SignedVarintDecoder(mask): |
+def _SignedVarintDecoder(mask, result_type): |
"""Like _VarintDecoder() but decodes signed values.""" |
- local_ord = ord |
def DecodeVarint(buffer, pos): |
result = 0 |
shift = 0 |
while 1: |
- b = local_ord(buffer[pos]) |
+ b = six.indexbytes(buffer, pos) |
result |= ((b & 0x7f) << shift) |
pos += 1 |
if not (b & 0x80): |
@@ -142,19 +147,23 @@ def _SignedVarintDecoder(mask): |
result |= ~mask |
else: |
result &= mask |
+ result = result_type(result) |
return (result, pos) |
shift += 7 |
if shift >= 64: |
raise _DecodeError('Too many bytes when decoding varint.') |
return DecodeVarint |
+# We force 32-bit values to int and 64-bit values to long to make |
+# alternate implementations where the distinction is more significant |
+# (e.g. the C++ implementation) simpler. |
-_DecodeVarint = _VarintDecoder((1 << 64) - 1) |
-_DecodeSignedVarint = _SignedVarintDecoder((1 << 64) - 1) |
+_DecodeVarint = _VarintDecoder((1 << 64) - 1, long) |
+_DecodeSignedVarint = _SignedVarintDecoder((1 << 64) - 1, long) |
# Use these versions for values which must be limited to 32 bits. |
-_DecodeVarint32 = _VarintDecoder((1 << 32) - 1) |
-_DecodeSignedVarint32 = _SignedVarintDecoder((1 << 32) - 1) |
+_DecodeVarint32 = _VarintDecoder((1 << 32) - 1, int) |
+_DecodeSignedVarint32 = _SignedVarintDecoder((1 << 32) - 1, int) |
def ReadTag(buffer, pos): |
@@ -169,7 +178,7 @@ def ReadTag(buffer, pos): |
""" |
start = pos |
- while ord(buffer[pos]) & 0x80: |
+ while six.indexbytes(buffer, pos) & 0x80: |
pos += 1 |
pos += 1 |
return (buffer[start:pos], pos) |
@@ -294,13 +303,12 @@ def _FloatDecoder(): |
# If this value has all its exponent bits set, then it's non-finite. |
# In Python 2.4, struct.unpack will convert it to a finite 64-bit value. |
# To avoid that, we parse it specially. |
- if ((float_bytes[3] in '\x7F\xFF') |
- and (float_bytes[2] >= '\x80')): |
+ if (float_bytes[3:4] in b'\x7F\xFF' and float_bytes[2:3] >= b'\x80'): |
# If at least one significand bit is set... |
- if float_bytes[0:3] != '\x00\x00\x80': |
+ if float_bytes[0:3] != b'\x00\x00\x80': |
return (_NAN, new_pos) |
# If sign bit is set... |
- if float_bytes[3] == '\xFF': |
+ if float_bytes[3:4] == b'\xFF': |
return (_NEG_INF, new_pos) |
return (_POS_INF, new_pos) |
@@ -329,9 +337,9 @@ def _DoubleDecoder(): |
# If this value has all its exponent bits set and at least one significand |
# bit set, it's not a number. In Python 2.4, struct.unpack will treat it |
# as inf or -inf. To avoid that, we treat it specially. |
- if ((double_bytes[7] in '\x7F\xFF') |
- and (double_bytes[6] >= '\xF0') |
- and (double_bytes[0:7] != '\x00\x00\x00\x00\x00\x00\xF0')): |
+ if ((double_bytes[7:8] in b'\x7F\xFF') |
+ and (double_bytes[6:7] >= b'\xF0') |
+ and (double_bytes[0:7] != b'\x00\x00\x00\x00\x00\x00\xF0')): |
return (_NAN, new_pos) |
# Note that we expect someone up-stack to catch struct.error and convert |
@@ -342,10 +350,86 @@ def _DoubleDecoder(): |
return _SimpleDecoder(wire_format.WIRETYPE_FIXED64, InnerDecode) |
+def EnumDecoder(field_number, is_repeated, is_packed, key, new_default): |
+ """Returns a decoder for an enum field.""" |
+ # One of three closures is returned, chosen by is_packed / is_repeated. |
+ # Each has the signature (buffer, pos, end, message, field_dict) and |
+ # returns the position just past the data it consumed. |
+ # |
+ # Numbers present in enum_type.values_by_number are stored in field_dict |
+ # under `key`. Any other number is kept out of the field and recorded in |
+ # message._unknown_fields as a (tag_bytes, raw_varint_bytes) tuple. |
+ enum_type = key.enum_type |
+ if is_packed: |
+ local_DecodeVarint = _DecodeVarint |
+ def DecodePackedField(buffer, pos, end, message, field_dict): |
+ # Fetch the existing container for this field, creating it on first use. |
+ value = field_dict.get(key) |
+ if value is None: |
+ value = field_dict.setdefault(key, new_default(message)) |
+ # Read the payload length and convert it to an absolute end offset. |
+ (endpoint, pos) = local_DecodeVarint(buffer, pos) |
+ endpoint += pos |
+ if endpoint > end: |
+ raise _DecodeError('Truncated message.') |
+ while pos < endpoint: |
+ # Remember where this element starts so its raw bytes can be sliced out |
+ # if the number turns out not to be a declared enum value. |
+ value_start_pos = pos |
+ (element, pos) = _DecodeSignedVarint32(buffer, pos) |
+ if element in enum_type.values_by_number: |
+ value.append(element) |
+ else: |
+ if not message._unknown_fields: |
+ message._unknown_fields = [] |
+ # Unrecognized elements are recorded one at a time under a VARINT tag, |
+ # not under the length-delimited tag of the packed run. |
+ tag_bytes = encoder.TagBytes(field_number, |
+ wire_format.WIRETYPE_VARINT) |
+ message._unknown_fields.append( |
+ (tag_bytes, buffer[value_start_pos:pos])) |
+ # The final varint ran past the declared payload end: undo whatever the |
+ # last iteration recorded (known value or unknown-field entry) and fail. |
+ if pos > endpoint: |
+ if element in enum_type.values_by_number: |
+ del value[-1] # Discard corrupt value. |
+ else: |
+ del message._unknown_fields[-1] |
+ raise _DecodeError('Packed element was truncated.') |
+ return pos |
+ return DecodePackedField |
+ elif is_repeated: |
+ # The tag is fixed for this field, so it is built once here and reused |
+ # both for unknown-field records and for next-tag prediction. |
+ tag_bytes = encoder.TagBytes(field_number, wire_format.WIRETYPE_VARINT) |
+ tag_len = len(tag_bytes) |
+ def DecodeRepeatedField(buffer, pos, end, message, field_dict): |
+ # Fetch the existing container for this field, creating it on first use. |
+ value = field_dict.get(key) |
+ if value is None: |
+ value = field_dict.setdefault(key, new_default(message)) |
+ while 1: |
+ (element, new_pos) = _DecodeSignedVarint32(buffer, pos) |
+ if element in enum_type.values_by_number: |
+ value.append(element) |
+ else: |
+ if not message._unknown_fields: |
+ message._unknown_fields = [] |
+ message._unknown_fields.append( |
+ (tag_bytes, buffer[pos:new_pos])) |
+ # Predict that the next tag is another copy of the same repeated |
+ # field. |
+ pos = new_pos + tag_len |
+ # new_pos >= end stops the loop even when the bytes that follow happen |
+ # to match the tag. |
+ if buffer[new_pos:pos] != tag_bytes or new_pos >= end: |
+ # Prediction failed. Return. |
+ if new_pos > end: |
+ raise _DecodeError('Truncated message.') |
+ return new_pos |
+ return DecodeRepeatedField |
+ else: |
+ def DecodeField(buffer, pos, end, message, field_dict): |
+ # Singular field: decode one varint and either store it (replacing any |
+ # earlier value under `key`) or record it as an unknown field. |
+ value_start_pos = pos |
+ (enum_value, pos) = _DecodeSignedVarint32(buffer, pos) |
+ if pos > end: |
+ raise _DecodeError('Truncated message.') |
+ if enum_value in enum_type.values_by_number: |
+ field_dict[key] = enum_value |
+ else: |
+ if not message._unknown_fields: |
+ message._unknown_fields = [] |
+ tag_bytes = encoder.TagBytes(field_number, |
+ wire_format.WIRETYPE_VARINT) |
+ message._unknown_fields.append( |
+ (tag_bytes, buffer[value_start_pos:pos])) |
+ return pos |
+ return DecodeField |
+ |
+ |
# -------------------------------------------------------------------- |
-Int32Decoder = EnumDecoder = _SimpleDecoder( |
+Int32Decoder = _SimpleDecoder( |
wire_format.WIRETYPE_VARINT, _DecodeSignedVarint32) |
Int64Decoder = _SimpleDecoder( |
@@ -378,7 +462,15 @@ def StringDecoder(field_number, is_repeated, is_packed, key, new_default): |
"""Returns a decoder for a string field.""" |
local_DecodeVarint = _DecodeVarint |
- local_unicode = unicode |
+ local_unicode = six.text_type |
+ |
+ def _ConvertToUnicode(byte_str): |
+ try: |
+ return local_unicode(byte_str, 'utf-8') |
+ except UnicodeDecodeError as e: |
+ # add more information to the error message and re-raise it. |
+ e.reason = '%s in field: %s' % (e, key.full_name) |
+ raise |
assert not is_packed |
if is_repeated: |
@@ -394,7 +486,7 @@ def StringDecoder(field_number, is_repeated, is_packed, key, new_default): |
new_pos = pos + size |
if new_pos > end: |
raise _DecodeError('Truncated string.') |
- value.append(local_unicode(buffer[pos:new_pos], 'utf-8')) |
+ value.append(_ConvertToUnicode(buffer[pos:new_pos])) |
# Predict that the next tag is another copy of the same repeated field. |
pos = new_pos + tag_len |
if buffer[new_pos:pos] != tag_bytes or new_pos == end: |
@@ -407,7 +499,7 @@ def StringDecoder(field_number, is_repeated, is_packed, key, new_default): |
new_pos = pos + size |
if new_pos > end: |
raise _DecodeError('Truncated string.') |
- field_dict[key] = local_unicode(buffer[pos:new_pos], 'utf-8') |
+ field_dict[key] = _ConvertToUnicode(buffer[pos:new_pos]) |
return new_pos |
return DecodeField |
@@ -511,9 +603,6 @@ def MessageDecoder(field_number, is_repeated, is_packed, key, new_default): |
if value is None: |
value = field_dict.setdefault(key, new_default(message)) |
while 1: |
- value = field_dict.get(key) |
- if value is None: |
- value = field_dict.setdefault(key, new_default(message)) |
# Read length. |
(size, pos) = local_DecodeVarint(buffer, pos) |
new_pos = pos + size |
@@ -626,13 +715,59 @@ def MessageSetItemDecoder(extensions_by_number): |
return DecodeItem |
# -------------------------------------------------------------------- |
+ |
+def MapDecoder(field_descriptor, new_default, is_message_map): |
+ """Returns a decoder for a map field.""" |
+ # The returned closure has the signature |
+ # (buffer, pos, end, message, field_dict) and returns the position just |
+ # past the last entry it consumed. Each entry is parsed as a |
+ # length-delimited sub-message whose `key` and `value` attributes are then |
+ # copied into the map container stored in field_dict. |
+ |
+ # The field descriptor itself is used as the field_dict key. |
+ key = field_descriptor |
+ tag_bytes = encoder.TagBytes(field_descriptor.number, |
+ wire_format.WIRETYPE_LENGTH_DELIMITED) |
+ tag_len = len(tag_bytes) |
+ local_DecodeVarint = _DecodeVarint |
+ # Can't read _concrete_class yet; might not be initialized. |
+ message_type = field_descriptor.message_type |
+ |
+ def DecodeMap(buffer, pos, end, message, field_dict): |
+ # One scratch entry message is created per call and reused for every |
+ # entry in the run; it is cleared before each parse below. |
+ submsg = message_type._concrete_class() |
+ # Fetch the existing map container, creating it on first use. |
+ value = field_dict.get(key) |
+ if value is None: |
+ value = field_dict.setdefault(key, new_default(message)) |
+ while 1: |
+ # Read length. |
+ (size, pos) = local_DecodeVarint(buffer, pos) |
+ new_pos = pos + size |
+ if new_pos > end: |
+ raise _DecodeError('Truncated message.') |
+ # Read sub-message. |
+ submsg.Clear() |
+ if submsg._InternalParse(buffer, pos, new_pos) != new_pos: |
+ # The only reason _InternalParse would return early is if it |
+ # encountered an end-group tag. |
+ raise _DecodeError('Unexpected end-group tag.') |
+ |
+ # Message-valued maps merge the parsed value into the entry for this key; |
+ # other maps assign the value directly. |
+ # NOTE(review): value[submsg.key] presumably creates a missing entry on |
+ # access -- confirm against the map container implementation. |
+ if is_message_map: |
+ value[submsg.key].MergeFrom(submsg.value) |
+ else: |
+ value[submsg.key] = submsg.value |
+ |
+ # Predict that the next tag is another copy of the same repeated field. |
+ pos = new_pos + tag_len |
+ if buffer[new_pos:pos] != tag_bytes or new_pos == end: |
+ # Prediction failed. Return. |
+ return new_pos |
+ |
+ return DecodeMap |
+ |
+# -------------------------------------------------------------------- |
# Optimization is not as heavy here because calls to SkipField() are rare, |
# except for handling end-group tags. |
def _SkipVarint(buffer, pos, end): |
"""Skip a varint value. Returns the new position.""" |
- |
- while ord(buffer[pos]) & 0x80: |
+ # Previously ord(buffer[pos]) raised IndexError when pos is out of range. |
+ # With this code, ord(b'') raises TypeError. Both are handled in |
+ # python_message.py to generate a 'Truncated message' error. |
+ while ord(buffer[pos:pos+1]) & 0x80: |
pos += 1 |
pos += 1 |
if pos > end: |
@@ -699,7 +834,6 @@ def _FieldSkipper(): |
] |
wiretype_mask = wire_format.TAG_TYPE_MASK |
- local_ord = ord |
def SkipField(buffer, pos, end, tag_bytes): |
"""Skips a field with the specified tag. |
@@ -712,7 +846,7 @@ def _FieldSkipper(): |
""" |
# The wire type is always in the first byte since varints are little-endian. |
- wire_type = local_ord(tag_bytes[0]) & wiretype_mask |
+ wire_type = ord(tag_bytes[0:1]) & wiretype_mask |
return WIRETYPE_TO_SKIPPER[wire_type](buffer, pos, end) |
return SkipField |