Index: third_party/protobuf/python/google/protobuf/internal/message_test.py |
=================================================================== |
--- third_party/protobuf/python/google/protobuf/internal/message_test.py (revision 216642) |
+++ third_party/protobuf/python/google/protobuf/internal/message_test.py (working copy) |
@@ -45,10 +45,15 @@ |
import copy |
import math |
+import operator |
+import pickle |
+ |
import unittest |
from google.protobuf import unittest_import_pb2 |
from google.protobuf import unittest_pb2 |
+from google.protobuf.internal import api_implementation |
from google.protobuf.internal import test_util |
+from google.protobuf import message |
# Python pre-2.6 does not have isinf() or isnan() functions, so we have |
# to provide our own. |
@@ -70,9 +75,9 @@ |
golden_message = unittest_pb2.TestAllTypes() |
golden_message.ParseFromString(golden_data) |
test_util.ExpectAllFieldsSet(self, golden_message) |
- self.assertTrue(golden_message.SerializeToString() == golden_data) |
+ self.assertEqual(golden_data, golden_message.SerializeToString()) |
golden_copy = copy.deepcopy(golden_message) |
- self.assertTrue(golden_copy.SerializeToString() == golden_data) |
+ self.assertEqual(golden_data, golden_copy.SerializeToString()) |
def testGoldenExtensions(self): |
golden_data = test_util.GoldenFile('golden_message').read() |
@@ -81,9 +86,9 @@ |
all_set = unittest_pb2.TestAllExtensions() |
test_util.SetAllExtensions(all_set) |
self.assertEquals(all_set, golden_message) |
- self.assertTrue(golden_message.SerializeToString() == golden_data) |
+ self.assertEqual(golden_data, golden_message.SerializeToString()) |
golden_copy = copy.deepcopy(golden_message) |
- self.assertTrue(golden_copy.SerializeToString() == golden_data) |
+ self.assertEqual(golden_data, golden_copy.SerializeToString()) |
def testGoldenPackedMessage(self): |
golden_data = test_util.GoldenFile('golden_packed_fields_message').read() |
@@ -92,9 +97,9 @@ |
all_set = unittest_pb2.TestPackedTypes() |
test_util.SetAllPackedFields(all_set) |
self.assertEquals(all_set, golden_message) |
- self.assertTrue(all_set.SerializeToString() == golden_data) |
+ self.assertEqual(golden_data, all_set.SerializeToString()) |
golden_copy = copy.deepcopy(golden_message) |
- self.assertTrue(golden_copy.SerializeToString() == golden_data) |
+ self.assertEqual(golden_data, golden_copy.SerializeToString()) |
def testGoldenPackedExtensions(self): |
golden_data = test_util.GoldenFile('golden_packed_fields_message').read() |
@@ -103,10 +108,29 @@ |
all_set = unittest_pb2.TestPackedExtensions() |
test_util.SetAllPackedExtensions(all_set) |
self.assertEquals(all_set, golden_message) |
- self.assertTrue(all_set.SerializeToString() == golden_data) |
+ self.assertEqual(golden_data, all_set.SerializeToString()) |
golden_copy = copy.deepcopy(golden_message) |
- self.assertTrue(golden_copy.SerializeToString() == golden_data) |
+ self.assertEqual(golden_data, golden_copy.SerializeToString()) |
+ def testPickleSupport(self): |
+ golden_data = test_util.GoldenFile('golden_message').read() |
+ golden_message = unittest_pb2.TestAllTypes() |
+ golden_message.ParseFromString(golden_data) |
+ pickled_message = pickle.dumps(golden_message) |
+ |
+ unpickled_message = pickle.loads(pickled_message) |
+ self.assertEquals(unpickled_message, golden_message) |
+ |
+ def testPickleIncompleteProto(self): |
+ golden_message = unittest_pb2.TestRequired(a=1) |
+ pickled_message = pickle.dumps(golden_message) |
+ |
+ unpickled_message = pickle.loads(pickled_message) |
+ self.assertEquals(unpickled_message, golden_message) |
+ self.assertEquals(unpickled_message.a, 1) |
+ # This is still an incomplete proto - so serializing should fail |
+ self.assertRaises(message.EncodeError, unpickled_message.SerializeToString) |
+ |
def testPositiveInfinity(self): |
golden_data = ('\x5D\x00\x00\x80\x7F' |
'\x61\x00\x00\x00\x00\x00\x00\xF0\x7F' |
@@ -118,7 +142,7 @@ |
self.assertTrue(IsPosInf(golden_message.optional_double)) |
self.assertTrue(IsPosInf(golden_message.repeated_float[0])) |
self.assertTrue(IsPosInf(golden_message.repeated_double[0])) |
- self.assertTrue(golden_message.SerializeToString() == golden_data) |
+ self.assertEqual(golden_data, golden_message.SerializeToString()) |
def testNegativeInfinity(self): |
golden_data = ('\x5D\x00\x00\x80\xFF' |
@@ -131,7 +155,7 @@ |
self.assertTrue(IsNegInf(golden_message.optional_double)) |
self.assertTrue(IsNegInf(golden_message.repeated_float[0])) |
self.assertTrue(IsNegInf(golden_message.repeated_double[0])) |
- self.assertTrue(golden_message.SerializeToString() == golden_data) |
+ self.assertEqual(golden_data, golden_message.SerializeToString()) |
def testNotANumber(self): |
golden_data = ('\x5D\x00\x00\xC0\x7F' |
@@ -144,8 +168,19 @@ |
self.assertTrue(isnan(golden_message.optional_double)) |
self.assertTrue(isnan(golden_message.repeated_float[0])) |
self.assertTrue(isnan(golden_message.repeated_double[0])) |
- self.assertTrue(golden_message.SerializeToString() == golden_data) |
+ # The protocol buffer may serialize to any one of multiple different |
+ # representations of a NaN. Rather than verify a specific representation, |
+ # verify the serialized string can be converted into a correctly |
+ # behaving protocol buffer. |
+ serialized = golden_message.SerializeToString() |
+ message = unittest_pb2.TestAllTypes() |
+ message.ParseFromString(serialized) |
+ self.assertTrue(isnan(message.optional_float)) |
+ self.assertTrue(isnan(message.optional_double)) |
+ self.assertTrue(isnan(message.repeated_float[0])) |
+ self.assertTrue(isnan(message.repeated_double[0])) |
+ |
def testPositiveInfinityPacked(self): |
golden_data = ('\xA2\x06\x04\x00\x00\x80\x7F' |
'\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF0\x7F') |
@@ -153,7 +188,7 @@ |
golden_message.ParseFromString(golden_data) |
self.assertTrue(IsPosInf(golden_message.packed_float[0])) |
self.assertTrue(IsPosInf(golden_message.packed_double[0])) |
- self.assertTrue(golden_message.SerializeToString() == golden_data) |
+ self.assertEqual(golden_data, golden_message.SerializeToString()) |
def testNegativeInfinityPacked(self): |
golden_data = ('\xA2\x06\x04\x00\x00\x80\xFF' |
@@ -162,7 +197,7 @@ |
golden_message.ParseFromString(golden_data) |
self.assertTrue(IsNegInf(golden_message.packed_float[0])) |
self.assertTrue(IsNegInf(golden_message.packed_double[0])) |
- self.assertTrue(golden_message.SerializeToString() == golden_data) |
+ self.assertEqual(golden_data, golden_message.SerializeToString()) |
def testNotANumberPacked(self): |
golden_data = ('\xA2\x06\x04\x00\x00\xC0\x7F' |
@@ -171,8 +206,13 @@ |
golden_message.ParseFromString(golden_data) |
self.assertTrue(isnan(golden_message.packed_float[0])) |
self.assertTrue(isnan(golden_message.packed_double[0])) |
- self.assertTrue(golden_message.SerializeToString() == golden_data) |
+ serialized = golden_message.SerializeToString() |
+ message = unittest_pb2.TestPackedTypes() |
+ message.ParseFromString(serialized) |
+ self.assertTrue(isnan(message.packed_float[0])) |
+ self.assertTrue(isnan(message.packed_double[0])) |
+ |
def testExtremeFloatValues(self): |
message = unittest_pb2.TestAllTypes() |
@@ -218,7 +258,7 @@ |
message.ParseFromString(message.SerializeToString()) |
self.assertTrue(message.optional_float == -kMostNegExponentOneSigBit) |
- def testExtremeFloatValues(self): |
+ def testExtremeDoubleValues(self): |
message = unittest_pb2.TestAllTypes() |
# Most positive exponent, no significand bits set. |
@@ -338,6 +378,117 @@ |
self.assertEqual(message.repeated_nested_message[4].bb, 5) |
self.assertEqual(message.repeated_nested_message[5].bb, 6) |
+ def testRepeatedCompositeFieldSortArguments(self): |
+ """Check sorting a repeated composite field using list.sort() arguments.""" |
+ message = unittest_pb2.TestAllTypes() |
+ get_bb = operator.attrgetter('bb') |
+ cmp_bb = lambda a, b: cmp(a.bb, b.bb) |
+ message.repeated_nested_message.add().bb = 1 |
+ message.repeated_nested_message.add().bb = 3 |
+ message.repeated_nested_message.add().bb = 2 |
+ message.repeated_nested_message.add().bb = 6 |
+ message.repeated_nested_message.add().bb = 5 |
+ message.repeated_nested_message.add().bb = 4 |
+ message.repeated_nested_message.sort(key=get_bb) |
+ self.assertEqual([k.bb for k in message.repeated_nested_message], |
+ [1, 2, 3, 4, 5, 6]) |
+ message.repeated_nested_message.sort(key=get_bb, reverse=True) |
+ self.assertEqual([k.bb for k in message.repeated_nested_message], |
+ [6, 5, 4, 3, 2, 1]) |
+ message.repeated_nested_message.sort(sort_function=cmp_bb) |
+ self.assertEqual([k.bb for k in message.repeated_nested_message], |
+ [1, 2, 3, 4, 5, 6]) |
+ message.repeated_nested_message.sort(cmp=cmp_bb, reverse=True) |
+ self.assertEqual([k.bb for k in message.repeated_nested_message], |
+ [6, 5, 4, 3, 2, 1]) |
+ |
+ def testRepeatedScalarFieldSortArguments(self): |
+ """Check sorting a scalar field using list.sort() arguments.""" |
+ message = unittest_pb2.TestAllTypes() |
+ |
+ abs_cmp = lambda a, b: cmp(abs(a), abs(b)) |
+ message.repeated_int32.append(-3) |
+ message.repeated_int32.append(-2) |
+ message.repeated_int32.append(-1) |
+ message.repeated_int32.sort(key=abs) |
+ self.assertEqual(list(message.repeated_int32), [-1, -2, -3]) |
+ message.repeated_int32.sort(key=abs, reverse=True) |
+ self.assertEqual(list(message.repeated_int32), [-3, -2, -1]) |
+ message.repeated_int32.sort(sort_function=abs_cmp) |
+ self.assertEqual(list(message.repeated_int32), [-1, -2, -3]) |
+ message.repeated_int32.sort(cmp=abs_cmp, reverse=True) |
+ self.assertEqual(list(message.repeated_int32), [-3, -2, -1]) |
+ |
+ len_cmp = lambda a, b: cmp(len(a), len(b)) |
+ message.repeated_string.append('aaa') |
+ message.repeated_string.append('bb') |
+ message.repeated_string.append('c') |
+ message.repeated_string.sort(key=len) |
+ self.assertEqual(list(message.repeated_string), ['c', 'bb', 'aaa']) |
+ message.repeated_string.sort(key=len, reverse=True) |
+ self.assertEqual(list(message.repeated_string), ['aaa', 'bb', 'c']) |
+ message.repeated_string.sort(sort_function=len_cmp) |
+ self.assertEqual(list(message.repeated_string), ['c', 'bb', 'aaa']) |
+ message.repeated_string.sort(cmp=len_cmp, reverse=True) |
+ self.assertEqual(list(message.repeated_string), ['aaa', 'bb', 'c']) |
+ |
+ def testParsingMerge(self): |
+ """Check the merge behavior when a required or optional field appears |
+ multiple times in the input.""" |
+ messages = [ |
+ unittest_pb2.TestAllTypes(), |
+ unittest_pb2.TestAllTypes(), |
+ unittest_pb2.TestAllTypes() ] |
+ messages[0].optional_int32 = 1 |
+ messages[1].optional_int64 = 2 |
+ messages[2].optional_int32 = 3 |
+ messages[2].optional_string = 'hello' |
+ |
+ merged_message = unittest_pb2.TestAllTypes() |
+ merged_message.optional_int32 = 3 |
+ merged_message.optional_int64 = 2 |
+ merged_message.optional_string = 'hello' |
+ |
+ generator = unittest_pb2.TestParsingMerge.RepeatedFieldsGenerator() |
+ generator.field1.extend(messages) |
+ generator.field2.extend(messages) |
+ generator.field3.extend(messages) |
+ generator.ext1.extend(messages) |
+ generator.ext2.extend(messages) |
+ generator.group1.add().field1.MergeFrom(messages[0]) |
+ generator.group1.add().field1.MergeFrom(messages[1]) |
+ generator.group1.add().field1.MergeFrom(messages[2]) |
+ generator.group2.add().field1.MergeFrom(messages[0]) |
+ generator.group2.add().field1.MergeFrom(messages[1]) |
+ generator.group2.add().field1.MergeFrom(messages[2]) |
+ |
+ data = generator.SerializeToString() |
+ parsing_merge = unittest_pb2.TestParsingMerge() |
+ parsing_merge.ParseFromString(data) |
+ |
+ # Required and optional fields should be merged. |
+ self.assertEqual(parsing_merge.required_all_types, merged_message) |
+ self.assertEqual(parsing_merge.optional_all_types, merged_message) |
+ self.assertEqual(parsing_merge.optionalgroup.optional_group_all_types, |
+ merged_message) |
+ self.assertEqual(parsing_merge.Extensions[ |
+ unittest_pb2.TestParsingMerge.optional_ext], |
+ merged_message) |
+ |
+ # Repeated fields should not be merged. |
+ self.assertEqual(len(parsing_merge.repeated_all_types), 3) |
+ self.assertEqual(len(parsing_merge.repeatedgroup), 3) |
+ self.assertEqual(len(parsing_merge.Extensions[ |
+ unittest_pb2.TestParsingMerge.repeated_ext]), 3) |
+ |
+ |
+ def testSortEmptyRepeatedCompositeContainer(self): |
+ """Exercise a scenario that has led to segfaults in the past. |
+ """ |
+ m = unittest_pb2.TestAllTypes() |
+ m.repeated_nested_message.sort() |
+ |
+ |
if __name__ == '__main__': |
unittest.main() |