| Index: third_party/protobuf/python/google/protobuf/internal/reflection_test.py
|
| diff --git a/third_party/protobuf/python/google/protobuf/internal/reflection_test.py b/third_party/protobuf/python/google/protobuf/internal/reflection_test.py
|
| index 6dc2fffe2b927afd0efb6947bca3f8b268df3c75..0e8810157825a72ba007d601407f07f12190fc45 100755
|
| --- a/third_party/protobuf/python/google/protobuf/internal/reflection_test.py
|
| +++ b/third_party/protobuf/python/google/protobuf/internal/reflection_test.py
|
| @@ -60,9 +60,13 @@ from google.protobuf.internal import more_messages_pb2
|
| from google.protobuf.internal import message_set_extensions_pb2
|
| from google.protobuf.internal import wire_format
|
| from google.protobuf.internal import test_util
|
| +from google.protobuf.internal import testing_refleaks
|
| from google.protobuf.internal import decoder
|
|
|
|
|
| +BaseTestCase = testing_refleaks.BaseTestCase
|
| +
|
| +
|
| class _MiniDecoder(object):
|
| """Decodes a stream of values from a string.
|
|
|
| @@ -95,12 +99,12 @@ class _MiniDecoder(object):
|
| return wire_format.UnpackTag(self.ReadVarint())
|
|
|
| def ReadFloat(self):
|
| - result = struct.unpack("<f", self._bytes[self._pos:self._pos+4])[0]
|
| + result = struct.unpack('<f', self._bytes[self._pos:self._pos+4])[0]
|
| self._pos += 4
|
| return result
|
|
|
| def ReadDouble(self):
|
| - result = struct.unpack("<d", self._bytes[self._pos:self._pos+8])[0]
|
| + result = struct.unpack('<d', self._bytes[self._pos:self._pos+8])[0]
|
| self._pos += 8
|
| return result
|
|
|
| @@ -108,7 +112,7 @@ class _MiniDecoder(object):
|
| return self._pos == len(self._bytes)
|
|
|
|
|
| -class ReflectionTest(unittest.TestCase):
|
| +class ReflectionTest(BaseTestCase):
|
|
|
| def assertListsEqual(self, values, others):
|
| self.assertEqual(len(values), len(others))
|
| @@ -617,9 +621,15 @@ class ReflectionTest(unittest.TestCase):
|
| self.assertRaises(TypeError, setattr, proto, 'optional_string', 10)
|
| self.assertRaises(TypeError, setattr, proto, 'optional_bytes', 10)
|
|
|
| - def testIntegerTypes(self):
|
| + def assertIntegerTypes(self, integer_fn):
|
| + """Verifies setting of scalar integers.
|
| +
|
| + Args:
|
| + integer_fn: A function to wrap the integers that will be assigned.
|
| + """
|
| def TestGetAndDeserialize(field_name, value, expected_type):
|
| proto = unittest_pb2.TestAllTypes()
|
| + value = integer_fn(value)
|
| setattr(proto, field_name, value)
|
| self.assertIsInstance(getattr(proto, field_name), expected_type)
|
| proto2 = unittest_pb2.TestAllTypes()
|
| @@ -631,12 +641,12 @@ class ReflectionTest(unittest.TestCase):
|
| TestGetAndDeserialize('optional_uint32', 1 << 30, int)
|
| try:
|
| integer_64 = long
|
| - except NameError: # Python3
|
| + except NameError: # Python3
|
| integer_64 = int
|
| if struct.calcsize('L') == 4:
|
| # Python only has signed ints, so 32-bit python can't fit an uint32
|
| # in an int.
|
| - TestGetAndDeserialize('optional_uint32', 1 << 31, long)
|
| + TestGetAndDeserialize('optional_uint32', 1 << 31, integer_64)
|
| else:
|
| # 64-bit python can fit uint32 inside an int
|
| TestGetAndDeserialize('optional_uint32', 1 << 31, int)
|
| @@ -645,9 +655,33 @@ class ReflectionTest(unittest.TestCase):
|
| TestGetAndDeserialize('optional_uint64', 1 << 30, integer_64)
|
| TestGetAndDeserialize('optional_uint64', 1 << 60, integer_64)
|
|
|
| - def testSingleScalarBoundsChecking(self):
|
| + def testIntegerTypes(self):
|
| + self.assertIntegerTypes(lambda x: x)
|
| +
|
| + def testNonStandardIntegerTypes(self):
|
| + self.assertIntegerTypes(test_util.NonStandardInteger)
|
| +
|
| + def testIllegalValuesForIntegers(self):
|
| + pb = unittest_pb2.TestAllTypes()
|
| +
|
| + # Strings are illegal, even when the represent an integer.
|
| + with self.assertRaises(TypeError):
|
| + pb.optional_uint64 = '2'
|
| +
|
| + # The exact error should propagate with a poorly written custom integer.
|
| + with self.assertRaisesRegexp(RuntimeError, 'my_error'):
|
| + pb.optional_uint64 = test_util.NonStandardInteger(5, 'my_error')
|
| +
|
| + def assetIntegerBoundsChecking(self, integer_fn):
|
| + """Verifies bounds checking for scalar integer fields.
|
| +
|
| + Args:
|
| + integer_fn: A function to wrap the integers that will be assigned.
|
| + """
|
| def TestMinAndMaxIntegers(field_name, expected_min, expected_max):
|
| pb = unittest_pb2.TestAllTypes()
|
| + expected_min = integer_fn(expected_min)
|
| + expected_max = integer_fn(expected_max)
|
| setattr(pb, field_name, expected_min)
|
| self.assertEqual(expected_min, getattr(pb, field_name))
|
| setattr(pb, field_name, expected_max)
|
| @@ -659,11 +693,22 @@ class ReflectionTest(unittest.TestCase):
|
| TestMinAndMaxIntegers('optional_uint32', 0, 0xffffffff)
|
| TestMinAndMaxIntegers('optional_int64', -(1 << 63), (1 << 63) - 1)
|
| TestMinAndMaxIntegers('optional_uint64', 0, 0xffffffffffffffff)
|
| + # A bit of white-box testing since -1 is an int and not a long in C++ and
|
| + # so goes down a different path.
|
| + pb = unittest_pb2.TestAllTypes()
|
| + with self.assertRaises(ValueError):
|
| + pb.optional_uint64 = integer_fn(-(1 << 63))
|
|
|
| pb = unittest_pb2.TestAllTypes()
|
| - pb.optional_nested_enum = 1
|
| + pb.optional_nested_enum = integer_fn(1)
|
| self.assertEqual(1, pb.optional_nested_enum)
|
|
|
| + def testSingleScalarBoundsChecking(self):
|
| + self.assetIntegerBoundsChecking(lambda x: x)
|
| +
|
| + def testNonStandardSingleScalarBoundsChecking(self):
|
| + self.assetIntegerBoundsChecking(test_util.NonStandardInteger)
|
| +
|
| def testRepeatedScalarTypeSafety(self):
|
| proto = unittest_pb2.TestAllTypes()
|
| self.assertRaises(TypeError, proto.repeated_int32.append, 1.1)
|
| @@ -972,6 +1017,7 @@ class ReflectionTest(unittest.TestCase):
|
| proto.repeated_nested_message.add(bb=23)
|
| self.assertEqual(1, len(proto.repeated_nested_message))
|
| self.assertEqual(23, proto.repeated_nested_message[0].bb)
|
| + self.assertRaises(TypeError, proto.repeated_nested_message.add, 23)
|
|
|
| def testRepeatedCompositeRemove(self):
|
| proto = unittest_pb2.TestAllTypes()
|
| @@ -1182,12 +1228,18 @@ class ReflectionTest(unittest.TestCase):
|
| self.assertTrue(not extendee_proto.HasExtension(extension))
|
|
|
| def testRegisteredExtensions(self):
|
| - self.assertTrue('protobuf_unittest.optional_int32_extension' in
|
| - unittest_pb2.TestAllExtensions._extensions_by_name)
|
| - self.assertTrue(1 in unittest_pb2.TestAllExtensions._extensions_by_number)
|
| + pool = unittest_pb2.DESCRIPTOR.pool
|
| + self.assertTrue(
|
| + pool.FindExtensionByNumber(
|
| + unittest_pb2.TestAllExtensions.DESCRIPTOR, 1))
|
| + self.assertIs(
|
| + pool.FindExtensionByName(
|
| + 'protobuf_unittest.optional_int32_extension').containing_type,
|
| + unittest_pb2.TestAllExtensions.DESCRIPTOR)
|
| # Make sure extensions haven't been registered into types that shouldn't
|
| # have any.
|
| - self.assertEqual(0, len(unittest_pb2.TestAllTypes._extensions_by_name))
|
| + self.assertEqual(0, len(
|
| + pool.FindAllExtensions(unittest_pb2.TestAllTypes.DESCRIPTOR)))
|
|
|
| # If message A directly contains message B, and
|
| # a.HasField('b') is currently False, then mutating any
|
| @@ -1551,6 +1603,20 @@ class ReflectionTest(unittest.TestCase):
|
| self.assertFalse(proto.HasField('optional_foreign_message'))
|
| self.assertEqual(0, proto.optional_foreign_message.c)
|
|
|
| + def testDisconnectingInOneof(self):
|
| + m = unittest_pb2.TestOneof2() # This message has two messages in a oneof.
|
| + m.foo_message.qux_int = 5
|
| + sub_message = m.foo_message
|
| + # Accessing another message's field does not clear the first one
|
| + self.assertEqual(m.foo_lazy_message.qux_int, 0)
|
| + self.assertEqual(m.foo_message.qux_int, 5)
|
| + # But mutating another message in the oneof detaches the first one.
|
| + m.foo_lazy_message.qux_int = 6
|
| + self.assertEqual(m.foo_message.qux_int, 0)
|
| + # The reference we got above was detached and is still valid.
|
| + self.assertEqual(sub_message.qux_int, 5)
|
| + sub_message.qux_int = 7
|
| +
|
| def testOneOf(self):
|
| proto = unittest_pb2.TestAllTypes()
|
| proto.oneof_uint32 = 10
|
| @@ -1809,7 +1875,7 @@ class ReflectionTest(unittest.TestCase):
|
| # into separate TestCase classes.
|
|
|
|
|
| -class TestAllTypesEqualityTest(unittest.TestCase):
|
| +class TestAllTypesEqualityTest(BaseTestCase):
|
|
|
| def setUp(self):
|
| self.first_proto = unittest_pb2.TestAllTypes()
|
| @@ -1825,7 +1891,7 @@ class TestAllTypesEqualityTest(unittest.TestCase):
|
| self.assertEqual(self.first_proto, self.second_proto)
|
|
|
|
|
| -class FullProtosEqualityTest(unittest.TestCase):
|
| +class FullProtosEqualityTest(BaseTestCase):
|
|
|
| """Equality tests using completely-full protos as a starting point."""
|
|
|
| @@ -1911,7 +1977,7 @@ class FullProtosEqualityTest(unittest.TestCase):
|
| self.assertEqual(self.first_proto, self.second_proto)
|
|
|
|
|
| -class ExtensionEqualityTest(unittest.TestCase):
|
| +class ExtensionEqualityTest(BaseTestCase):
|
|
|
| def testExtensionEquality(self):
|
| first_proto = unittest_pb2.TestAllExtensions()
|
| @@ -1944,7 +2010,7 @@ class ExtensionEqualityTest(unittest.TestCase):
|
| self.assertEqual(first_proto, second_proto)
|
|
|
|
|
| -class MutualRecursionEqualityTest(unittest.TestCase):
|
| +class MutualRecursionEqualityTest(BaseTestCase):
|
|
|
| def testEqualityWithMutualRecursion(self):
|
| first_proto = unittest_pb2.TestMutualRecursionA()
|
| @@ -1956,7 +2022,7 @@ class MutualRecursionEqualityTest(unittest.TestCase):
|
| self.assertEqual(first_proto, second_proto)
|
|
|
|
|
| -class ByteSizeTest(unittest.TestCase):
|
| +class ByteSizeTest(BaseTestCase):
|
|
|
| def setUp(self):
|
| self.proto = unittest_pb2.TestAllTypes()
|
| @@ -2252,7 +2318,7 @@ class ByteSizeTest(unittest.TestCase):
|
| # * Handling of empty submessages (with and without "has"
|
| # bits set).
|
|
|
| -class SerializationTest(unittest.TestCase):
|
| +class SerializationTest(BaseTestCase):
|
|
|
| def testSerializeEmtpyMessage(self):
|
| first_proto = unittest_pb2.TestAllTypes()
|
| @@ -2813,7 +2879,7 @@ class SerializationTest(unittest.TestCase):
|
| self.assertEqual(3, proto.repeated_int32[2])
|
|
|
|
|
| -class OptionsTest(unittest.TestCase):
|
| +class OptionsTest(BaseTestCase):
|
|
|
| def testMessageOptions(self):
|
| proto = message_set_extensions_pb2.TestMessageSet()
|
| @@ -2840,7 +2906,7 @@ class OptionsTest(unittest.TestCase):
|
|
|
|
|
|
|
| -class ClassAPITest(unittest.TestCase):
|
| +class ClassAPITest(BaseTestCase):
|
|
|
| @unittest.skipIf(
|
| api_implementation.Type() == 'cpp' and api_implementation.Version() == 2,
|
| @@ -2923,6 +2989,9 @@ class ClassAPITest(unittest.TestCase):
|
| text_format.Merge(file_descriptor_str, file_descriptor)
|
| return file_descriptor.SerializeToString()
|
|
|
| + @testing_refleaks.SkipReferenceLeakChecker('MakeDescriptor is not repeatable')
|
| + # This test can only run once; the second time, it raises errors about
|
| + # conflicting message descriptors.
|
| def testParsingFlatClassWithExplicitClassDeclaration(self):
|
| """Test that the generated class can parse a flat message."""
|
| # TODO(xiaofeng): This test fails with cpp implemetnation in the call
|
| @@ -2947,6 +3016,7 @@ class ClassAPITest(unittest.TestCase):
|
| text_format.Merge(msg_str, msg)
|
| self.assertEqual(msg.flat, [0, 1, 2])
|
|
|
| + @testing_refleaks.SkipReferenceLeakChecker('MakeDescriptor is not repeatable')
|
| def testParsingFlatClass(self):
|
| """Test that the generated class can parse a flat message."""
|
| file_descriptor = descriptor_pb2.FileDescriptorProto()
|
| @@ -2962,6 +3032,7 @@ class ClassAPITest(unittest.TestCase):
|
| text_format.Merge(msg_str, msg)
|
| self.assertEqual(msg.flat, [0, 1, 2])
|
|
|
| + @testing_refleaks.SkipReferenceLeakChecker('MakeDescriptor is not repeatable')
|
| def testParsingNestedClass(self):
|
| """Test that the generated class can parse a nested message."""
|
| file_descriptor = descriptor_pb2.FileDescriptorProto()
|
|
|