Index: third_party/protobuf/python/google/protobuf/internal/reflection_test.py |
diff --git a/third_party/protobuf/python/google/protobuf/internal/reflection_test.py b/third_party/protobuf/python/google/protobuf/internal/reflection_test.py |
index 6dc2fffe2b927afd0efb6947bca3f8b268df3c75..0e8810157825a72ba007d601407f07f12190fc45 100755 |
--- a/third_party/protobuf/python/google/protobuf/internal/reflection_test.py |
+++ b/third_party/protobuf/python/google/protobuf/internal/reflection_test.py |
@@ -60,9 +60,13 @@ from google.protobuf.internal import more_messages_pb2 |
from google.protobuf.internal import message_set_extensions_pb2 |
from google.protobuf.internal import wire_format |
from google.protobuf.internal import test_util |
+from google.protobuf.internal import testing_refleaks |
from google.protobuf.internal import decoder |
+BaseTestCase = testing_refleaks.BaseTestCase |
+ |
+ |
class _MiniDecoder(object): |
"""Decodes a stream of values from a string. |
@@ -95,12 +99,12 @@ class _MiniDecoder(object): |
return wire_format.UnpackTag(self.ReadVarint()) |
def ReadFloat(self): |
- result = struct.unpack("<f", self._bytes[self._pos:self._pos+4])[0] |
+ result = struct.unpack('<f', self._bytes[self._pos:self._pos+4])[0] |
self._pos += 4 |
return result |
def ReadDouble(self): |
- result = struct.unpack("<d", self._bytes[self._pos:self._pos+8])[0] |
+ result = struct.unpack('<d', self._bytes[self._pos:self._pos+8])[0] |
self._pos += 8 |
return result |
@@ -108,7 +112,7 @@ class _MiniDecoder(object): |
return self._pos == len(self._bytes) |
-class ReflectionTest(unittest.TestCase): |
+class ReflectionTest(BaseTestCase): |
def assertListsEqual(self, values, others): |
self.assertEqual(len(values), len(others)) |
@@ -617,9 +621,15 @@ class ReflectionTest(unittest.TestCase): |
self.assertRaises(TypeError, setattr, proto, 'optional_string', 10) |
self.assertRaises(TypeError, setattr, proto, 'optional_bytes', 10) |
- def testIntegerTypes(self): |
+ def assertIntegerTypes(self, integer_fn): |
+ """Verifies setting of scalar integers. |
+ |
+ Args: |
+ integer_fn: A function to wrap the integers that will be assigned. |
+ """ |
def TestGetAndDeserialize(field_name, value, expected_type): |
proto = unittest_pb2.TestAllTypes() |
+ value = integer_fn(value) |
setattr(proto, field_name, value) |
self.assertIsInstance(getattr(proto, field_name), expected_type) |
proto2 = unittest_pb2.TestAllTypes() |
@@ -631,12 +641,12 @@ class ReflectionTest(unittest.TestCase): |
TestGetAndDeserialize('optional_uint32', 1 << 30, int) |
try: |
integer_64 = long |
- except NameError: # Python3 |
+ except NameError: # Python3 |
integer_64 = int |
if struct.calcsize('L') == 4: |
# Python only has signed ints, so 32-bit python can't fit an uint32 |
# in an int. |
- TestGetAndDeserialize('optional_uint32', 1 << 31, long) |
+ TestGetAndDeserialize('optional_uint32', 1 << 31, integer_64) |
else: |
# 64-bit python can fit uint32 inside an int |
TestGetAndDeserialize('optional_uint32', 1 << 31, int) |
@@ -645,9 +655,33 @@ class ReflectionTest(unittest.TestCase): |
TestGetAndDeserialize('optional_uint64', 1 << 30, integer_64) |
TestGetAndDeserialize('optional_uint64', 1 << 60, integer_64) |
- def testSingleScalarBoundsChecking(self): |
+ def testIntegerTypes(self): |
+ self.assertIntegerTypes(lambda x: x) |
+ |
+ def testNonStandardIntegerTypes(self): |
+ self.assertIntegerTypes(test_util.NonStandardInteger) |
+ |
+ def testIllegalValuesForIntegers(self): |
+ pb = unittest_pb2.TestAllTypes() |
+ |
+ # Strings are illegal, even when the represent an integer. |
+ with self.assertRaises(TypeError): |
+ pb.optional_uint64 = '2' |
+ |
+ # The exact error should propagate with a poorly written custom integer. |
+ with self.assertRaisesRegexp(RuntimeError, 'my_error'): |
+ pb.optional_uint64 = test_util.NonStandardInteger(5, 'my_error') |
+ |
+ def assetIntegerBoundsChecking(self, integer_fn): |
+ """Verifies bounds checking for scalar integer fields. |
+ |
+ Args: |
+ integer_fn: A function to wrap the integers that will be assigned. |
+ """ |
def TestMinAndMaxIntegers(field_name, expected_min, expected_max): |
pb = unittest_pb2.TestAllTypes() |
+ expected_min = integer_fn(expected_min) |
+ expected_max = integer_fn(expected_max) |
setattr(pb, field_name, expected_min) |
self.assertEqual(expected_min, getattr(pb, field_name)) |
setattr(pb, field_name, expected_max) |
@@ -659,11 +693,22 @@ class ReflectionTest(unittest.TestCase): |
TestMinAndMaxIntegers('optional_uint32', 0, 0xffffffff) |
TestMinAndMaxIntegers('optional_int64', -(1 << 63), (1 << 63) - 1) |
TestMinAndMaxIntegers('optional_uint64', 0, 0xffffffffffffffff) |
+ # A bit of white-box testing since -1 is an int and not a long in C++ and |
+ # so goes down a different path. |
+ pb = unittest_pb2.TestAllTypes() |
+ with self.assertRaises(ValueError): |
+ pb.optional_uint64 = integer_fn(-(1 << 63)) |
pb = unittest_pb2.TestAllTypes() |
- pb.optional_nested_enum = 1 |
+ pb.optional_nested_enum = integer_fn(1) |
self.assertEqual(1, pb.optional_nested_enum) |
+ def testSingleScalarBoundsChecking(self): |
+ self.assetIntegerBoundsChecking(lambda x: x) |
+ |
+ def testNonStandardSingleScalarBoundsChecking(self): |
+ self.assetIntegerBoundsChecking(test_util.NonStandardInteger) |
+ |
def testRepeatedScalarTypeSafety(self): |
proto = unittest_pb2.TestAllTypes() |
self.assertRaises(TypeError, proto.repeated_int32.append, 1.1) |
@@ -972,6 +1017,7 @@ class ReflectionTest(unittest.TestCase): |
proto.repeated_nested_message.add(bb=23) |
self.assertEqual(1, len(proto.repeated_nested_message)) |
self.assertEqual(23, proto.repeated_nested_message[0].bb) |
+ self.assertRaises(TypeError, proto.repeated_nested_message.add, 23) |
def testRepeatedCompositeRemove(self): |
proto = unittest_pb2.TestAllTypes() |
@@ -1182,12 +1228,18 @@ class ReflectionTest(unittest.TestCase): |
self.assertTrue(not extendee_proto.HasExtension(extension)) |
def testRegisteredExtensions(self): |
- self.assertTrue('protobuf_unittest.optional_int32_extension' in |
- unittest_pb2.TestAllExtensions._extensions_by_name) |
- self.assertTrue(1 in unittest_pb2.TestAllExtensions._extensions_by_number) |
+ pool = unittest_pb2.DESCRIPTOR.pool |
+ self.assertTrue( |
+ pool.FindExtensionByNumber( |
+ unittest_pb2.TestAllExtensions.DESCRIPTOR, 1)) |
+ self.assertIs( |
+ pool.FindExtensionByName( |
+ 'protobuf_unittest.optional_int32_extension').containing_type, |
+ unittest_pb2.TestAllExtensions.DESCRIPTOR) |
# Make sure extensions haven't been registered into types that shouldn't |
# have any. |
- self.assertEqual(0, len(unittest_pb2.TestAllTypes._extensions_by_name)) |
+ self.assertEqual(0, len( |
+ pool.FindAllExtensions(unittest_pb2.TestAllTypes.DESCRIPTOR))) |
# If message A directly contains message B, and |
# a.HasField('b') is currently False, then mutating any |
@@ -1551,6 +1603,20 @@ class ReflectionTest(unittest.TestCase): |
self.assertFalse(proto.HasField('optional_foreign_message')) |
self.assertEqual(0, proto.optional_foreign_message.c) |
+ def testDisconnectingInOneof(self): |
+ m = unittest_pb2.TestOneof2() # This message has two messages in a oneof. |
+ m.foo_message.qux_int = 5 |
+ sub_message = m.foo_message |
+ # Accessing another message's field does not clear the first one |
+ self.assertEqual(m.foo_lazy_message.qux_int, 0) |
+ self.assertEqual(m.foo_message.qux_int, 5) |
+ # But mutating another message in the oneof detaches the first one. |
+ m.foo_lazy_message.qux_int = 6 |
+ self.assertEqual(m.foo_message.qux_int, 0) |
+ # The reference we got above was detached and is still valid. |
+ self.assertEqual(sub_message.qux_int, 5) |
+ sub_message.qux_int = 7 |
+ |
def testOneOf(self): |
proto = unittest_pb2.TestAllTypes() |
proto.oneof_uint32 = 10 |
@@ -1809,7 +1875,7 @@ class ReflectionTest(unittest.TestCase): |
# into separate TestCase classes. |
-class TestAllTypesEqualityTest(unittest.TestCase): |
+class TestAllTypesEqualityTest(BaseTestCase): |
def setUp(self): |
self.first_proto = unittest_pb2.TestAllTypes() |
@@ -1825,7 +1891,7 @@ class TestAllTypesEqualityTest(unittest.TestCase): |
self.assertEqual(self.first_proto, self.second_proto) |
-class FullProtosEqualityTest(unittest.TestCase): |
+class FullProtosEqualityTest(BaseTestCase): |
"""Equality tests using completely-full protos as a starting point.""" |
@@ -1911,7 +1977,7 @@ class FullProtosEqualityTest(unittest.TestCase): |
self.assertEqual(self.first_proto, self.second_proto) |
-class ExtensionEqualityTest(unittest.TestCase): |
+class ExtensionEqualityTest(BaseTestCase): |
def testExtensionEquality(self): |
first_proto = unittest_pb2.TestAllExtensions() |
@@ -1944,7 +2010,7 @@ class ExtensionEqualityTest(unittest.TestCase): |
self.assertEqual(first_proto, second_proto) |
-class MutualRecursionEqualityTest(unittest.TestCase): |
+class MutualRecursionEqualityTest(BaseTestCase): |
def testEqualityWithMutualRecursion(self): |
first_proto = unittest_pb2.TestMutualRecursionA() |
@@ -1956,7 +2022,7 @@ class MutualRecursionEqualityTest(unittest.TestCase): |
self.assertEqual(first_proto, second_proto) |
-class ByteSizeTest(unittest.TestCase): |
+class ByteSizeTest(BaseTestCase): |
def setUp(self): |
self.proto = unittest_pb2.TestAllTypes() |
@@ -2252,7 +2318,7 @@ class ByteSizeTest(unittest.TestCase): |
# * Handling of empty submessages (with and without "has" |
# bits set). |
-class SerializationTest(unittest.TestCase): |
+class SerializationTest(BaseTestCase): |
def testSerializeEmtpyMessage(self): |
first_proto = unittest_pb2.TestAllTypes() |
@@ -2813,7 +2879,7 @@ class SerializationTest(unittest.TestCase): |
self.assertEqual(3, proto.repeated_int32[2]) |
-class OptionsTest(unittest.TestCase): |
+class OptionsTest(BaseTestCase): |
def testMessageOptions(self): |
proto = message_set_extensions_pb2.TestMessageSet() |
@@ -2840,7 +2906,7 @@ class OptionsTest(unittest.TestCase): |
-class ClassAPITest(unittest.TestCase): |
+class ClassAPITest(BaseTestCase): |
@unittest.skipIf( |
api_implementation.Type() == 'cpp' and api_implementation.Version() == 2, |
@@ -2923,6 +2989,9 @@ class ClassAPITest(unittest.TestCase): |
text_format.Merge(file_descriptor_str, file_descriptor) |
return file_descriptor.SerializeToString() |
+ @testing_refleaks.SkipReferenceLeakChecker('MakeDescriptor is not repeatable') |
+ # This test can only run once; the second time, it raises errors about |
+ # conflicting message descriptors. |
def testParsingFlatClassWithExplicitClassDeclaration(self): |
"""Test that the generated class can parse a flat message.""" |
# TODO(xiaofeng): This test fails with cpp implemetnation in the call |
@@ -2947,6 +3016,7 @@ class ClassAPITest(unittest.TestCase): |
text_format.Merge(msg_str, msg) |
self.assertEqual(msg.flat, [0, 1, 2]) |
+ @testing_refleaks.SkipReferenceLeakChecker('MakeDescriptor is not repeatable') |
def testParsingFlatClass(self): |
"""Test that the generated class can parse a flat message.""" |
file_descriptor = descriptor_pb2.FileDescriptorProto() |
@@ -2962,6 +3032,7 @@ class ClassAPITest(unittest.TestCase): |
text_format.Merge(msg_str, msg) |
self.assertEqual(msg.flat, [0, 1, 2]) |
+ @testing_refleaks.SkipReferenceLeakChecker('MakeDescriptor is not repeatable') |
def testParsingNestedClass(self): |
"""Test that the generated class can parse a nested message.""" |
file_descriptor = descriptor_pb2.FileDescriptorProto() |