Index: third_party/protobuf/python/google/protobuf/internal/message_test.py |
diff --git a/third_party/protobuf/python/google/protobuf/internal/message_test.py b/third_party/protobuf/python/google/protobuf/internal/message_test.py |
index d03f2d25db50e523c4fd20b7603592e6b84ec2f0..4ee31d8ed97cf4255d613672b942a094fafe7239 100755 |
--- a/third_party/protobuf/python/google/protobuf/internal/message_test.py |
+++ b/third_party/protobuf/python/google/protobuf/internal/message_test.py |
@@ -53,22 +53,27 @@ import six |
import sys |
try: |
- import unittest2 as unittest |
+ import unittest2 as unittest #PY26 |
except ImportError: |
import unittest |
-from google.protobuf.internal import _parameterized |
+ |
from google.protobuf import map_unittest_pb2 |
from google.protobuf import unittest_pb2 |
from google.protobuf import unittest_proto3_arena_pb2 |
-from google.protobuf.internal import any_test_pb2 |
+from google.protobuf import descriptor_pb2 |
+from google.protobuf import descriptor_pool |
+from google.protobuf import message_factory |
+from google.protobuf import text_format |
from google.protobuf.internal import api_implementation |
from google.protobuf.internal import packed_field_test_pb2 |
from google.protobuf.internal import test_util |
from google.protobuf import message |
+from google.protobuf.internal import _parameterized |
if six.PY3: |
long = int |
+ |
# Python pre-2.6 does not have isinf() or isnan() functions, so we have |
# to provide our own. |
def isnan(val): |
@@ -1157,6 +1162,7 @@ class Proto2Test(unittest.TestCase): |
unittest_pb2.TestAllTypes(repeated_nested_enum='FOO') |
+ |
# Class to test proto3-only features/behavior (updated field presence & enums) |
class Proto3Test(unittest.TestCase): |
@@ -1259,7 +1265,10 @@ class Proto3Test(unittest.TestCase): |
self.assertFalse(-2**33 in msg.map_int64_int64) |
self.assertFalse(123 in msg.map_uint32_uint32) |
self.assertFalse(2**33 in msg.map_uint64_uint64) |
+ self.assertFalse(123 in msg.map_int32_double) |
+ self.assertFalse(False in msg.map_bool_bool) |
self.assertFalse('abc' in msg.map_string_string) |
+ self.assertFalse(111 in msg.map_int32_bytes) |
self.assertFalse(888 in msg.map_int32_enum) |
# Accessing an unset key returns the default. |
@@ -1267,7 +1276,12 @@ class Proto3Test(unittest.TestCase): |
self.assertEqual(0, msg.map_int64_int64[-2**33]) |
self.assertEqual(0, msg.map_uint32_uint32[123]) |
self.assertEqual(0, msg.map_uint64_uint64[2**33]) |
+ self.assertEqual(0.0, msg.map_int32_double[123]) |
+ self.assertTrue(isinstance(msg.map_int32_double[123], float)) |
+ self.assertEqual(False, msg.map_bool_bool[False]) |
+ self.assertTrue(isinstance(msg.map_bool_bool[False], bool)) |
self.assertEqual('', msg.map_string_string['abc']) |
+ self.assertEqual(b'', msg.map_int32_bytes[111]) |
self.assertEqual(0, msg.map_int32_enum[888]) |
# It also sets the value in the map |
@@ -1275,7 +1289,10 @@ class Proto3Test(unittest.TestCase): |
self.assertTrue(-2**33 in msg.map_int64_int64) |
self.assertTrue(123 in msg.map_uint32_uint32) |
self.assertTrue(2**33 in msg.map_uint64_uint64) |
+ self.assertTrue(123 in msg.map_int32_double) |
+ self.assertTrue(False in msg.map_bool_bool) |
self.assertTrue('abc' in msg.map_string_string) |
+ self.assertTrue(111 in msg.map_int32_bytes) |
self.assertTrue(888 in msg.map_int32_enum) |
self.assertIsInstance(msg.map_string_string['abc'], six.text_type) |
@@ -1448,6 +1465,22 @@ class Proto3Test(unittest.TestCase): |
del msg2.map_int32_foreign_message[222] |
self.assertFalse(222 in msg2.map_int32_foreign_message) |
+ def testMergeFromBadType(self): |
+ msg = map_unittest_pb2.TestMap() |
+ with self.assertRaisesRegexp( |
+ TypeError, |
+ r'Parameter to MergeFrom\(\) must be instance of same class: expected ' |
+ r'.*TestMap got int\.'): |
+ msg.MergeFrom(1) |
+ |
+ def testCopyFromBadType(self): |
+ msg = map_unittest_pb2.TestMap() |
+ with self.assertRaisesRegexp( |
+ TypeError, |
+ r'Parameter to [A-Za-z]*From\(\) must be instance of same class: ' |
+ r'expected .*TestMap got int\.'): |
+ msg.CopyFrom(1) |
+ |
def testIntegerMapWithLongs(self): |
msg = map_unittest_pb2.TestMap() |
msg.map_int32_int32[long(-123)] = long(-456) |
@@ -1565,6 +1598,21 @@ class Proto3Test(unittest.TestCase): |
matching_dict = {2: 4, 3: 6, 4: 8} |
self.assertMapIterEquals(msg.map_int32_int32.items(), matching_dict) |
+ def testMapItems(self): |
+ # Map items used to have strange behaviors when use c extension. Because |
+ # [] may reorder the map and invalidate any exsting iterators. |
+ # TODO(jieluo): Check if [] reordering the map is a bug or intended |
+ # behavior. |
+ msg = map_unittest_pb2.TestMap() |
+ msg.map_string_string['local_init_op'] = '' |
+ msg.map_string_string['trainable_variables'] = '' |
+ msg.map_string_string['variables'] = '' |
+ msg.map_string_string['init_op'] = '' |
+ msg.map_string_string['summaries'] = '' |
+ items1 = msg.map_string_string.items() |
+ items2 = msg.map_string_string.items() |
+ self.assertEqual(items1, items2) |
+ |
def testMapIterationClearMessage(self): |
# Iterator needs to work even if message and map are deleted. |
msg = map_unittest_pb2.TestMap() |
@@ -1666,37 +1714,6 @@ class Proto3Test(unittest.TestCase): |
msg.map_string_foreign_message['foo'].c = 5 |
self.assertEqual(0, len(msg.FindInitializationErrors())) |
- def testAnyMessage(self): |
- # Creates and sets message. |
- msg = any_test_pb2.TestAny() |
- msg_descriptor = msg.DESCRIPTOR |
- all_types = unittest_pb2.TestAllTypes() |
- all_descriptor = all_types.DESCRIPTOR |
- all_types.repeated_string.append(u'\u00fc\ua71f') |
- # Packs to Any. |
- msg.value.Pack(all_types) |
- self.assertEqual(msg.value.type_url, |
- 'type.googleapis.com/%s' % all_descriptor.full_name) |
- self.assertEqual(msg.value.value, |
- all_types.SerializeToString()) |
- # Tests Is() method. |
- self.assertTrue(msg.value.Is(all_descriptor)) |
- self.assertFalse(msg.value.Is(msg_descriptor)) |
- # Unpacks Any. |
- unpacked_message = unittest_pb2.TestAllTypes() |
- self.assertTrue(msg.value.Unpack(unpacked_message)) |
- self.assertEqual(all_types, unpacked_message) |
- # Unpacks to different type. |
- self.assertFalse(msg.value.Unpack(msg)) |
- # Only Any messages have Pack method. |
- try: |
- msg.Pack(all_types) |
- except AttributeError: |
- pass |
- else: |
- raise AttributeError('%s should not have Pack method.' % |
- msg_descriptor.full_name) |
- |
class ValidTypeNamesTest(unittest.TestCase): |
@@ -1776,5 +1793,60 @@ class PackedFieldTest(unittest.TestCase): |
b'\x70\x01') |
self.assertEqual(golden_data, message.SerializeToString()) |
+ |
+@unittest.skipIf(api_implementation.Type() != 'cpp', |
+ 'explicit tests of the C++ implementation') |
+class OversizeProtosTest(unittest.TestCase): |
+ |
+ def setUp(self): |
+ self.file_desc = """ |
+ name: "f/f.msg2" |
+ package: "f" |
+ message_type { |
+ name: "msg1" |
+ field { |
+ name: "payload" |
+ number: 1 |
+ label: LABEL_OPTIONAL |
+ type: TYPE_STRING |
+ } |
+ } |
+ message_type { |
+ name: "msg2" |
+ field { |
+ name: "field" |
+ number: 1 |
+ label: LABEL_OPTIONAL |
+ type: TYPE_MESSAGE |
+ type_name: "msg1" |
+ } |
+ } |
+ """ |
+ pool = descriptor_pool.DescriptorPool() |
+ desc = descriptor_pb2.FileDescriptorProto() |
+ text_format.Parse(self.file_desc, desc) |
+ pool.Add(desc) |
+ self.proto_cls = message_factory.MessageFactory(pool).GetPrototype( |
+ pool.FindMessageTypeByName('f.msg2')) |
+ self.p = self.proto_cls() |
+ self.p.field.payload = 'c' * (1024 * 1024 * 64 + 1) |
+ self.p_serialized = self.p.SerializeToString() |
+ |
+ def testAssertOversizeProto(self): |
+ from google.protobuf.pyext._message import SetAllowOversizeProtos |
+ SetAllowOversizeProtos(False) |
+ q = self.proto_cls() |
+ try: |
+ q.ParseFromString(self.p_serialized) |
+ except message.DecodeError as e: |
+ self.assertEqual(str(e), 'Error parsing message') |
+ |
+ def testSucceedOversizeProto(self): |
+ from google.protobuf.pyext._message import SetAllowOversizeProtos |
+ SetAllowOversizeProtos(True) |
+ q = self.proto_cls() |
+ q.ParseFromString(self.p_serialized) |
+ self.assertEqual(self.p.field.payload, q.field.payload) |
+ |
if __name__ == '__main__': |
unittest.main() |