Index: third_party/protobuf/python/google/protobuf/internal/reflection_test.py |
diff --git a/third_party/protobuf/python/google/protobuf/internal/reflection_test.py b/third_party/protobuf/python/google/protobuf/internal/reflection_test.py |
index ed2864613582780f3c4d370601df05227a2455d4..752f2f5d90cb6a750c9cd27cb0153d7db9165252 100755 |
--- a/third_party/protobuf/python/google/protobuf/internal/reflection_test.py |
+++ b/third_party/protobuf/python/google/protobuf/internal/reflection_test.py |
@@ -1,9 +1,9 @@ |
-#! /usr/bin/python |
+#! /usr/bin/env python |
# -*- coding: utf-8 -*- |
# |
# Protocol Buffers - Google's data interchange format |
# Copyright 2008 Google Inc. All rights reserved. |
-# http://code.google.com/p/protobuf/ |
+# https://developers.google.com/protocol-buffers/ |
# |
# Redistribution and use in source and binary forms, with or without |
# modification, are permitted provided that the following conditions are |
@@ -35,13 +35,16 @@ |
pure-Python protocol compiler. |
""" |
-__author__ = 'robinson@google.com (Will Robinson)' |
- |
+import copy |
import gc |
import operator |
+import six |
import struct |
-import unittest |
+try: |
+ import unittest2 as unittest |
+except ImportError: |
+ import unittest |
from google.protobuf import unittest_import_pb2 |
from google.protobuf import unittest_mset_pb2 |
from google.protobuf import unittest_pb2 |
@@ -49,9 +52,11 @@ from google.protobuf import descriptor_pb2 |
from google.protobuf import descriptor |
from google.protobuf import message |
from google.protobuf import reflection |
+from google.protobuf import text_format |
from google.protobuf.internal import api_implementation |
from google.protobuf.internal import more_extensions_pb2 |
from google.protobuf.internal import more_messages_pb2 |
+from google.protobuf.internal import message_set_extensions_pb2 |
from google.protobuf.internal import wire_format |
from google.protobuf.internal import test_util |
from google.protobuf.internal import decoder |
@@ -128,10 +133,10 @@ class ReflectionTest(unittest.TestCase): |
repeated_bool=[True, False, False], |
repeated_string=["optional_string"]) |
- self.assertEquals([1, 2, 3, 4], list(proto.repeated_int32)) |
- self.assertEquals([1.23, 54.321], list(proto.repeated_double)) |
- self.assertEquals([True, False, False], list(proto.repeated_bool)) |
- self.assertEquals(["optional_string"], list(proto.repeated_string)) |
+ self.assertEqual([1, 2, 3, 4], list(proto.repeated_int32)) |
+ self.assertEqual([1.23, 54.321], list(proto.repeated_double)) |
+ self.assertEqual([True, False, False], list(proto.repeated_bool)) |
+ self.assertEqual(["optional_string"], list(proto.repeated_string)) |
def testRepeatedCompositeConstructor(self): |
# Constructor with only repeated composite types should succeed. |
@@ -150,18 +155,18 @@ class ReflectionTest(unittest.TestCase): |
unittest_pb2.TestAllTypes.RepeatedGroup(a=1), |
unittest_pb2.TestAllTypes.RepeatedGroup(a=2)]) |
- self.assertEquals( |
+ self.assertEqual( |
[unittest_pb2.TestAllTypes.NestedMessage( |
bb=unittest_pb2.TestAllTypes.FOO), |
unittest_pb2.TestAllTypes.NestedMessage( |
bb=unittest_pb2.TestAllTypes.BAR)], |
list(proto.repeated_nested_message)) |
- self.assertEquals( |
+ self.assertEqual( |
[unittest_pb2.ForeignMessage(c=-43), |
unittest_pb2.ForeignMessage(c=45324), |
unittest_pb2.ForeignMessage(c=12)], |
list(proto.repeated_foreign_message)) |
- self.assertEquals( |
+ self.assertEqual( |
[unittest_pb2.TestAllTypes.RepeatedGroup(), |
unittest_pb2.TestAllTypes.RepeatedGroup(a=1), |
unittest_pb2.TestAllTypes.RepeatedGroup(a=2)], |
@@ -186,15 +191,15 @@ class ReflectionTest(unittest.TestCase): |
self.assertEqual(24, proto.optional_int32) |
self.assertEqual('optional_string', proto.optional_string) |
- self.assertEquals([1.23, 54.321], list(proto.repeated_double)) |
- self.assertEquals([True, False, False], list(proto.repeated_bool)) |
- self.assertEquals( |
+ self.assertEqual([1.23, 54.321], list(proto.repeated_double)) |
+ self.assertEqual([True, False, False], list(proto.repeated_bool)) |
+ self.assertEqual( |
[unittest_pb2.TestAllTypes.NestedMessage( |
bb=unittest_pb2.TestAllTypes.FOO), |
unittest_pb2.TestAllTypes.NestedMessage( |
bb=unittest_pb2.TestAllTypes.BAR)], |
list(proto.repeated_nested_message)) |
- self.assertEquals( |
+ self.assertEqual( |
[unittest_pb2.ForeignMessage(c=-43), |
unittest_pb2.ForeignMessage(c=45324), |
unittest_pb2.ForeignMessage(c=12)], |
@@ -222,18 +227,18 @@ class ReflectionTest(unittest.TestCase): |
def testConstructorInvalidatesCachedByteSize(self): |
message = unittest_pb2.TestAllTypes(optional_int32 = 12) |
- self.assertEquals(2, message.ByteSize()) |
+ self.assertEqual(2, message.ByteSize()) |
message = unittest_pb2.TestAllTypes( |
optional_nested_message = unittest_pb2.TestAllTypes.NestedMessage()) |
- self.assertEquals(3, message.ByteSize()) |
+ self.assertEqual(3, message.ByteSize()) |
message = unittest_pb2.TestAllTypes(repeated_int32 = [12]) |
- self.assertEquals(3, message.ByteSize()) |
+ self.assertEqual(3, message.ByteSize()) |
message = unittest_pb2.TestAllTypes( |
repeated_nested_message = [unittest_pb2.TestAllTypes.NestedMessage()]) |
- self.assertEquals(3, message.ByteSize()) |
+ self.assertEqual(3, message.ByteSize()) |
def testSimpleHasBits(self): |
# Test a scalar. |
@@ -467,7 +472,7 @@ class ReflectionTest(unittest.TestCase): |
proto.repeated_string.extend(['foo', 'bar']) |
proto.repeated_string.extend([]) |
proto.repeated_string.append('baz') |
- proto.repeated_string.extend(str(x) for x in xrange(2)) |
+ proto.repeated_string.extend(str(x) for x in range(2)) |
proto.optional_int32 = 21 |
proto.repeated_bool # Access but don't set anything; should not be listed. |
self.assertEqual( |
@@ -533,7 +538,7 @@ class ReflectionTest(unittest.TestCase): |
self.assertEqual(0.0, proto.optional_double) |
self.assertEqual(False, proto.optional_bool) |
self.assertEqual('', proto.optional_string) |
- self.assertEqual('', proto.optional_bytes) |
+ self.assertEqual(b'', proto.optional_bytes) |
self.assertEqual(41, proto.default_int32) |
self.assertEqual(42, proto.default_int64) |
@@ -549,7 +554,7 @@ class ReflectionTest(unittest.TestCase): |
self.assertEqual(52e3, proto.default_double) |
self.assertEqual(True, proto.default_bool) |
self.assertEqual('hello', proto.default_string) |
- self.assertEqual('world', proto.default_bytes) |
+ self.assertEqual(b'world', proto.default_bytes) |
self.assertEqual(unittest_pb2.TestAllTypes.BAR, proto.default_nested_enum) |
self.assertEqual(unittest_pb2.FOREIGN_BAR, proto.default_foreign_enum) |
self.assertEqual(unittest_import_pb2.IMPORT_BAR, |
@@ -566,6 +571,17 @@ class ReflectionTest(unittest.TestCase): |
proto = unittest_pb2.TestAllTypes() |
self.assertRaises(ValueError, proto.ClearField, 'nonexistent_field') |
+ def testClearRemovesChildren(self): |
+ # Make sure there aren't any implementation bugs that are only partially |
+ # clearing the message (which can happen in the more complex C++ |
+ # implementation which has parallel message lists). |
+ proto = unittest_pb2.TestRequiredForeign() |
+ for i in range(10): |
+ proto.repeated_message.add() |
+ proto2 = unittest_pb2.TestRequiredForeign() |
+ proto.CopyFrom(proto2) |
+ self.assertRaises(IndexError, lambda: proto.repeated_message[5]) |
+ |
def testDisallowedAssignments(self): |
# It's illegal to assign values directly to repeated fields |
# or to nonrepeated composite fields. Ensure that this fails. |
@@ -594,6 +610,34 @@ class ReflectionTest(unittest.TestCase): |
self.assertRaises(TypeError, setattr, proto, 'optional_string', 10) |
self.assertRaises(TypeError, setattr, proto, 'optional_bytes', 10) |
+ def testIntegerTypes(self): |
+ def TestGetAndDeserialize(field_name, value, expected_type): |
+ proto = unittest_pb2.TestAllTypes() |
+ setattr(proto, field_name, value) |
+ self.assertIsInstance(getattr(proto, field_name), expected_type) |
+ proto2 = unittest_pb2.TestAllTypes() |
+ proto2.ParseFromString(proto.SerializeToString()) |
+ self.assertIsInstance(getattr(proto2, field_name), expected_type) |
+ |
+ TestGetAndDeserialize('optional_int32', 1, int) |
+ TestGetAndDeserialize('optional_int32', 1 << 30, int) |
+ TestGetAndDeserialize('optional_uint32', 1 << 30, int) |
+ try: |
+ integer_64 = long |
+ except NameError: # Python3 |
+ integer_64 = int |
+ if struct.calcsize('L') == 4: |
+ # Python only has signed ints, so 32-bit python can't fit an uint32 |
+ # in an int. |
+ TestGetAndDeserialize('optional_uint32', 1 << 31, long) |
+ else: |
+ # 64-bit python can fit uint32 inside an int |
+ TestGetAndDeserialize('optional_uint32', 1 << 31, int) |
+ TestGetAndDeserialize('optional_int64', 1 << 30, integer_64) |
+ TestGetAndDeserialize('optional_int64', 1 << 60, integer_64) |
+ TestGetAndDeserialize('optional_uint64', 1 << 30, integer_64) |
+ TestGetAndDeserialize('optional_uint64', 1 << 60, integer_64) |
+ |
def testSingleScalarBoundsChecking(self): |
def TestMinAndMaxIntegers(field_name, expected_min, expected_max): |
pb = unittest_pb2.TestAllTypes() |
@@ -613,29 +657,6 @@ class ReflectionTest(unittest.TestCase): |
pb.optional_nested_enum = 1 |
self.assertEqual(1, pb.optional_nested_enum) |
- # Invalid enum values. |
- pb.optional_nested_enum = 0 |
- self.assertEqual(0, pb.optional_nested_enum) |
- |
- bytes_size_before = pb.ByteSize() |
- |
- pb.optional_nested_enum = 4 |
- self.assertEqual(4, pb.optional_nested_enum) |
- |
- pb.optional_nested_enum = 0 |
- self.assertEqual(0, pb.optional_nested_enum) |
- |
- # Make sure that setting the same enum field doesn't just add unknown |
- # fields (but overwrites them). |
- self.assertEqual(bytes_size_before, pb.ByteSize()) |
- |
- # Is the invalid value preserved after serialization? |
- serialized = pb.SerializeToString() |
- pb2 = unittest_pb2.TestAllTypes() |
- pb2.ParseFromString(serialized) |
- self.assertEqual(0, pb2.optional_nested_enum) |
- self.assertEqual(pb, pb2) |
- |
def testRepeatedScalarTypeSafety(self): |
proto = unittest_pb2.TestAllTypes() |
self.assertRaises(TypeError, proto.repeated_int32.append, 1.1) |
@@ -741,18 +762,18 @@ class ReflectionTest(unittest.TestCase): |
def testEnum_KeysAndValues(self): |
self.assertEqual(['FOREIGN_FOO', 'FOREIGN_BAR', 'FOREIGN_BAZ'], |
- unittest_pb2.ForeignEnum.keys()) |
+ list(unittest_pb2.ForeignEnum.keys())) |
self.assertEqual([4, 5, 6], |
- unittest_pb2.ForeignEnum.values()) |
+ list(unittest_pb2.ForeignEnum.values())) |
self.assertEqual([('FOREIGN_FOO', 4), ('FOREIGN_BAR', 5), |
('FOREIGN_BAZ', 6)], |
- unittest_pb2.ForeignEnum.items()) |
+ list(unittest_pb2.ForeignEnum.items())) |
proto = unittest_pb2.TestAllTypes() |
- self.assertEqual(['FOO', 'BAR', 'BAZ'], proto.NestedEnum.keys()) |
- self.assertEqual([1, 2, 3], proto.NestedEnum.values()) |
- self.assertEqual([('FOO', 1), ('BAR', 2), ('BAZ', 3)], |
- proto.NestedEnum.items()) |
+ self.assertEqual(['FOO', 'BAR', 'BAZ', 'NEG'], list(proto.NestedEnum.keys())) |
+ self.assertEqual([1, 2, 3, -1], list(proto.NestedEnum.values())) |
+ self.assertEqual([('FOO', 1), ('BAR', 2), ('BAZ', 3), ('NEG', -1)], |
+ list(proto.NestedEnum.items())) |
def testRepeatedScalars(self): |
proto = unittest_pb2.TestAllTypes() |
@@ -791,7 +812,7 @@ class ReflectionTest(unittest.TestCase): |
self.assertEqual([5, 25, 20, 15, 30], proto.repeated_int32[:]) |
# Test slice assignment with an iterator |
- proto.repeated_int32[1:4] = (i for i in xrange(3)) |
+ proto.repeated_int32[1:4] = (i for i in range(3)) |
self.assertEqual([5, 0, 1, 2, 30], proto.repeated_int32) |
# Test slice assignment. |
@@ -882,7 +903,7 @@ class ReflectionTest(unittest.TestCase): |
self.assertTrue(proto.repeated_nested_message) |
self.assertEqual(2, len(proto.repeated_nested_message)) |
self.assertListsEqual([m0, m1], proto.repeated_nested_message) |
- self.assertTrue(isinstance(m0, unittest_pb2.TestAllTypes.NestedMessage)) |
+ self.assertIsInstance(m0, unittest_pb2.TestAllTypes.NestedMessage) |
# Test out-of-bounds indices. |
self.assertRaises(IndexError, proto.repeated_nested_message.__getitem__, |
@@ -994,9 +1015,8 @@ class ReflectionTest(unittest.TestCase): |
containing_type=None, nested_types=[], enum_types=[], |
fields=[foo_field_descriptor], extensions=[], |
options=descriptor_pb2.MessageOptions()) |
- class MyProtoClass(message.Message): |
+ class MyProtoClass(six.with_metaclass(reflection.GeneratedProtocolMessageType, message.Message)): |
DESCRIPTOR = mydescriptor |
- __metaclass__ = reflection.GeneratedProtocolMessageType |
myproto_instance = MyProtoClass() |
self.assertEqual(0, myproto_instance.foo_field) |
self.assertTrue(not myproto_instance.HasField('foo_field')) |
@@ -1036,14 +1056,13 @@ class ReflectionTest(unittest.TestCase): |
new_field.label = descriptor_pb2.FieldDescriptorProto.LABEL_REPEATED |
desc = descriptor.MakeDescriptor(desc_proto) |
- self.assertTrue(desc.fields_by_name.has_key('name')) |
- self.assertTrue(desc.fields_by_name.has_key('year')) |
- self.assertTrue(desc.fields_by_name.has_key('automatic')) |
- self.assertTrue(desc.fields_by_name.has_key('price')) |
- self.assertTrue(desc.fields_by_name.has_key('owners')) |
- |
- class CarMessage(message.Message): |
- __metaclass__ = reflection.GeneratedProtocolMessageType |
+ self.assertTrue('name' in desc.fields_by_name) |
+ self.assertTrue('year' in desc.fields_by_name) |
+ self.assertTrue('automatic' in desc.fields_by_name) |
+ self.assertTrue('price' in desc.fields_by_name) |
+ self.assertTrue('owners' in desc.fields_by_name) |
+ |
+ class CarMessage(six.with_metaclass(reflection.GeneratedProtocolMessageType, message.Message)): |
DESCRIPTOR = desc |
prius = CarMessage() |
@@ -1155,6 +1174,14 @@ class ReflectionTest(unittest.TestCase): |
self.assertTrue(required is not extendee_proto.Extensions[extension]) |
self.assertTrue(not extendee_proto.HasExtension(extension)) |
+ def testRegisteredExtensions(self): |
+ self.assertTrue('protobuf_unittest.optional_int32_extension' in |
+ unittest_pb2.TestAllExtensions._extensions_by_name) |
+ self.assertTrue(1 in unittest_pb2.TestAllExtensions._extensions_by_number) |
+ # Make sure extensions haven't been registered into types that shouldn't |
+ # have any. |
+ self.assertEqual(0, len(unittest_pb2.TestAllTypes._extensions_by_name)) |
+ |
# If message A directly contains message B, and |
# a.HasField('b') is currently False, then mutating any |
# extension in B should change a.HasField('b') to True |
@@ -1230,15 +1257,18 @@ class ReflectionTest(unittest.TestCase): |
# Try something that *is* an extension handle, just not for |
# this message... |
- unknown_handle = more_extensions_pb2.optional_int_extension |
- self.assertRaises(KeyError, extendee_proto.HasExtension, |
- unknown_handle) |
- self.assertRaises(KeyError, extendee_proto.ClearExtension, |
- unknown_handle) |
- self.assertRaises(KeyError, extendee_proto.Extensions.__getitem__, |
- unknown_handle) |
- self.assertRaises(KeyError, extendee_proto.Extensions.__setitem__, |
- unknown_handle, 5) |
+ for unknown_handle in (more_extensions_pb2.optional_int_extension, |
+ more_extensions_pb2.optional_message_extension, |
+ more_extensions_pb2.repeated_int_extension, |
+ more_extensions_pb2.repeated_message_extension): |
+ self.assertRaises(KeyError, extendee_proto.HasExtension, |
+ unknown_handle) |
+ self.assertRaises(KeyError, extendee_proto.ClearExtension, |
+ unknown_handle) |
+ self.assertRaises(KeyError, extendee_proto.Extensions.__getitem__, |
+ unknown_handle) |
+ self.assertRaises(KeyError, extendee_proto.Extensions.__setitem__, |
+ unknown_handle, 5) |
# Try call HasExtension() with a valid handle, but for a |
# *repeated* field. (Just as with non-extension repeated |
@@ -1451,6 +1481,19 @@ class ReflectionTest(unittest.TestCase): |
proto2 = unittest_pb2.TestAllExtensions() |
self.assertRaises(TypeError, proto1.CopyFrom, proto2) |
+ def testDeepCopy(self): |
+ proto1 = unittest_pb2.TestAllTypes() |
+ proto1.optional_int32 = 1 |
+ proto2 = copy.deepcopy(proto1) |
+ self.assertEqual(1, proto2.optional_int32) |
+ |
+ proto1.repeated_int32.append(2) |
+ proto1.repeated_int32.append(3) |
+ container = copy.deepcopy(proto1.repeated_int32) |
+ self.assertEqual([2, 3], container) |
+ |
+ # TODO(anuraag): Implement deepcopy for repeated composite / extension dict |
+ |
def testClear(self): |
proto = unittest_pb2.TestAllTypes() |
# C++ implementation does not support lazy fields right now so leave it |
@@ -1461,18 +1504,18 @@ class ReflectionTest(unittest.TestCase): |
test_util.SetAllNonLazyFields(proto) |
# Clear the message. |
proto.Clear() |
- self.assertEquals(proto.ByteSize(), 0) |
+ self.assertEqual(proto.ByteSize(), 0) |
empty_proto = unittest_pb2.TestAllTypes() |
- self.assertEquals(proto, empty_proto) |
+ self.assertEqual(proto, empty_proto) |
# Test if extensions which were set are cleared. |
proto = unittest_pb2.TestAllExtensions() |
test_util.SetAllExtensions(proto) |
# Clear the message. |
proto.Clear() |
- self.assertEquals(proto.ByteSize(), 0) |
+ self.assertEqual(proto.ByteSize(), 0) |
empty_proto = unittest_pb2.TestAllExtensions() |
- self.assertEquals(proto, empty_proto) |
+ self.assertEqual(proto, empty_proto) |
def testDisconnectingBeforeClear(self): |
proto = unittest_pb2.TestAllTypes() |
@@ -1496,11 +1539,23 @@ class ReflectionTest(unittest.TestCase): |
self.assertEqual(6, foreign.c) |
nested.bb = 15 |
foreign.c = 16 |
- self.assertTrue(not proto.HasField('optional_nested_message')) |
+ self.assertFalse(proto.HasField('optional_nested_message')) |
self.assertEqual(0, proto.optional_nested_message.bb) |
- self.assertTrue(not proto.HasField('optional_foreign_message')) |
+ self.assertFalse(proto.HasField('optional_foreign_message')) |
self.assertEqual(0, proto.optional_foreign_message.c) |
+ def testOneOf(self): |
+ proto = unittest_pb2.TestAllTypes() |
+ proto.oneof_uint32 = 10 |
+ proto.oneof_nested_message.bb = 11 |
+ self.assertEqual(11, proto.oneof_nested_message.bb) |
+ self.assertFalse(proto.HasField('oneof_uint32')) |
+ nested = proto.oneof_nested_message |
+ proto.oneof_string = 'abc' |
+ self.assertEqual('abc', proto.oneof_string) |
+ self.assertEqual(11, nested.bb) |
+ self.assertFalse(proto.HasField('oneof_nested_message')) |
+ |
def assertInitialized(self, proto): |
self.assertTrue(proto.IsInitialized()) |
# Neither method should raise an exception. |
@@ -1571,6 +1626,40 @@ class ReflectionTest(unittest.TestCase): |
self.assertFalse(proto.IsInitialized(errors)) |
self.assertEqual(errors, ['a', 'b', 'c']) |
+ @unittest.skipIf( |
+ api_implementation.Type() != 'cpp' or api_implementation.Version() != 2, |
+ 'Errors are only available from the most recent C++ implementation.') |
+ def testFileDescriptorErrors(self): |
+ file_name = 'test_file_descriptor_errors.proto' |
+ package_name = 'test_file_descriptor_errors.proto' |
+ file_descriptor_proto = descriptor_pb2.FileDescriptorProto() |
+ file_descriptor_proto.name = file_name |
+ file_descriptor_proto.package = package_name |
+ m1 = file_descriptor_proto.message_type.add() |
+ m1.name = 'msg1' |
+ # Compiles the proto into the C++ descriptor pool |
+ descriptor.FileDescriptor( |
+ file_name, |
+ package_name, |
+ serialized_pb=file_descriptor_proto.SerializeToString()) |
+ # Add a FileDescriptorProto that has duplicate symbols |
+ another_file_name = 'another_test_file_descriptor_errors.proto' |
+ file_descriptor_proto.name = another_file_name |
+ m2 = file_descriptor_proto.message_type.add() |
+ m2.name = 'msg2' |
+ with self.assertRaises(TypeError) as cm: |
+ descriptor.FileDescriptor( |
+ another_file_name, |
+ package_name, |
+ serialized_pb=file_descriptor_proto.SerializeToString()) |
+ self.assertTrue(hasattr(cm, 'exception'), '%s not raised' % |
+ getattr(cm.expected, '__name__', cm.expected)) |
+ self.assertIn('test_file_descriptor_errors.proto', str(cm.exception)) |
+ # Error message will say something about this definition being a |
+ # duplicate, though we don't check the message exactly to avoid a |
+ # dependency on the C++ logging code. |
+ self.assertIn('test_file_descriptor_errors.msg1', str(cm.exception)) |
+ |
def testStringUTF8Encoding(self): |
proto = unittest_pb2.TestAllTypes() |
@@ -1579,32 +1668,29 @@ class ReflectionTest(unittest.TestCase): |
setattr, proto, 'optional_bytes', u'unicode object') |
# Check that the default value is of python's 'unicode' type. |
- self.assertEqual(type(proto.optional_string), unicode) |
+ self.assertEqual(type(proto.optional_string), six.text_type) |
- proto.optional_string = unicode('Testing') |
+ proto.optional_string = six.text_type('Testing') |
self.assertEqual(proto.optional_string, str('Testing')) |
# Assign a value of type 'str' which can be encoded in UTF-8. |
proto.optional_string = str('Testing') |
- self.assertEqual(proto.optional_string, unicode('Testing')) |
- |
- if api_implementation.Type() == 'python': |
- # Values of type 'str' are also accepted as long as they can be |
- # encoded in UTF-8. |
- self.assertEqual(type(proto.optional_string), str) |
+ self.assertEqual(proto.optional_string, six.text_type('Testing')) |
- # Try to assign a 'str' value which contains bytes that aren't 7-bit ASCII. |
- self.assertRaises(ValueError, |
- setattr, proto, 'optional_string', str('a\x80a')) |
- # Assign a 'str' object which contains a UTF-8 encoded string. |
+ # Try to assign a 'bytes' object which contains non-UTF-8. |
self.assertRaises(ValueError, |
- setattr, proto, 'optional_string', 'Тест') |
- # No exception thrown. |
+ setattr, proto, 'optional_string', b'a\x80a') |
+ # No exception: Assign already encoded UTF-8 bytes to a string field. |
+ utf8_bytes = u'Тест'.encode('utf-8') |
+ proto.optional_string = utf8_bytes |
+ # No exception: Assign the a non-ascii unicode object. |
+ proto.optional_string = u'Тест' |
+ # No exception thrown (normal str assignment containing ASCII). |
proto.optional_string = 'abc' |
def testStringUTF8Serialization(self): |
- proto = unittest_mset_pb2.TestMessageSet() |
- extension_message = unittest_mset_pb2.TestMessageSetExtension2 |
+ proto = message_set_extensions_pb2.TestMessageSet() |
+ extension_message = message_set_extensions_pb2.TestMessageSetExtension2 |
extension = extension_message.message_set_extension |
test_utf8 = u'Тест' |
@@ -1621,20 +1707,21 @@ class ReflectionTest(unittest.TestCase): |
self.assertEqual(proto.ByteSize(), len(serialized)) |
raw = unittest_mset_pb2.RawMessageSet() |
- raw.MergeFromString(serialized) |
+ bytes_read = raw.MergeFromString(serialized) |
+ self.assertEqual(len(serialized), bytes_read) |
- message2 = unittest_mset_pb2.TestMessageSetExtension2() |
+ message2 = message_set_extensions_pb2.TestMessageSetExtension2() |
self.assertEqual(1, len(raw.item)) |
# Check that the type_id is the same as the tag ID in the .proto file. |
- self.assertEqual(raw.item[0].type_id, 1547769) |
+ self.assertEqual(raw.item[0].type_id, 98418634) |
# Check the actual bytes on the wire. |
- self.assertTrue( |
- raw.item[0].message.endswith(test_utf8_bytes)) |
- message2.MergeFromString(raw.item[0].message) |
+ self.assertTrue(raw.item[0].message.endswith(test_utf8_bytes)) |
+ bytes_read = message2.MergeFromString(raw.item[0].message) |
+ self.assertEqual(len(raw.item[0].message), bytes_read) |
- self.assertEqual(type(message2.str), unicode) |
+ self.assertEqual(type(message2.str), six.text_type) |
self.assertEqual(message2.str, test_utf8) |
# The pure Python API throws an exception on MergeFromString(), |
@@ -1643,17 +1730,22 @@ class ReflectionTest(unittest.TestCase): |
# MergeFromString and thus has no way to throw the exception. |
# |
# The pure Python API always returns objects of type 'unicode' (UTF-8 |
- # encoded), or 'str' (in 7 bit ASCII). |
- bytes = raw.item[0].message.replace( |
- test_utf8_bytes, len(test_utf8_bytes) * '\xff') |
+ # encoded), or 'bytes' (in 7 bit ASCII). |
+ badbytes = raw.item[0].message.replace( |
+ test_utf8_bytes, len(test_utf8_bytes) * b'\xff') |
unicode_decode_failed = False |
try: |
- message2.MergeFromString(bytes) |
- except UnicodeDecodeError as e: |
+ message2.MergeFromString(badbytes) |
+ except UnicodeDecodeError: |
unicode_decode_failed = True |
string_field = message2.str |
- self.assertTrue(unicode_decode_failed or type(string_field) == str) |
+ self.assertTrue(unicode_decode_failed or type(string_field) is bytes) |
+ |
+ def testBytesInTextFormat(self): |
+ proto = unittest_pb2.TestAllTypes(optional_bytes=b'\x00\x7f\x80\xff') |
+ self.assertEqual(u'optional_bytes: "\\000\\177\\200\\377"\n', |
+ six.text_type(proto)) |
def testEmptyNestedMessage(self): |
proto = unittest_pb2.TestAllTypes() |
@@ -1667,16 +1759,19 @@ class ReflectionTest(unittest.TestCase): |
self.assertTrue(proto.HasField('optional_nested_message')) |
proto = unittest_pb2.TestAllTypes() |
- proto.optional_nested_message.MergeFromString('') |
+ bytes_read = proto.optional_nested_message.MergeFromString(b'') |
+ self.assertEqual(0, bytes_read) |
self.assertTrue(proto.HasField('optional_nested_message')) |
proto = unittest_pb2.TestAllTypes() |
- proto.optional_nested_message.ParseFromString('') |
+ proto.optional_nested_message.ParseFromString(b'') |
self.assertTrue(proto.HasField('optional_nested_message')) |
serialized = proto.SerializeToString() |
proto2 = unittest_pb2.TestAllTypes() |
- proto2.MergeFromString(serialized) |
+ self.assertEqual( |
+ len(serialized), |
+ proto2.MergeFromString(serialized)) |
self.assertTrue(proto2.HasField('optional_nested_message')) |
def testSetInParent(self): |
@@ -1685,6 +1780,23 @@ class ReflectionTest(unittest.TestCase): |
proto.optionalgroup.SetInParent() |
self.assertTrue(proto.HasField('optionalgroup')) |
+ def testPackageInitializationImport(self): |
+ """Test that we can import nested messages from their __init__.py. |
+ |
+ Such setup is not trivial since at the time of processing of __init__.py one |
+ can't refer to its submodules by name in code, so expressions like |
+ google.protobuf.internal.import_test_package.inner_pb2 |
+ don't work. They do work in imports, so we have assign an alias at import |
+ and then use that alias in generated code. |
+ """ |
+ # We import here since it's the import that used to fail, and we want |
+ # the failure to have the right context. |
+ # pylint: disable=g-import-not-at-top |
+ from google.protobuf.internal import import_test_package |
+ # pylint: enable=g-import-not-at-top |
+ msg = import_test_package.myproto.Outer() |
+ # Just check the default value. |
+ self.assertEqual(57, msg.inner.value) |
# Since we had so many tests for protocol buffer equality, we broke these out |
# into separate TestCase classes. |
@@ -2140,7 +2252,9 @@ class SerializationTest(unittest.TestCase): |
second_proto = unittest_pb2.TestAllTypes() |
serialized = first_proto.SerializeToString() |
self.assertEqual(first_proto.ByteSize(), len(serialized)) |
- second_proto.MergeFromString(serialized) |
+ self.assertEqual( |
+ len(serialized), |
+ second_proto.MergeFromString(serialized)) |
self.assertEqual(first_proto, second_proto) |
def testSerializeAllFields(self): |
@@ -2149,7 +2263,9 @@ class SerializationTest(unittest.TestCase): |
test_util.SetAllFields(first_proto) |
serialized = first_proto.SerializeToString() |
self.assertEqual(first_proto.ByteSize(), len(serialized)) |
- second_proto.MergeFromString(serialized) |
+ self.assertEqual( |
+ len(serialized), |
+ second_proto.MergeFromString(serialized)) |
self.assertEqual(first_proto, second_proto) |
def testSerializeAllExtensions(self): |
@@ -2157,7 +2273,19 @@ class SerializationTest(unittest.TestCase): |
second_proto = unittest_pb2.TestAllExtensions() |
test_util.SetAllExtensions(first_proto) |
serialized = first_proto.SerializeToString() |
- second_proto.MergeFromString(serialized) |
+ self.assertEqual( |
+ len(serialized), |
+ second_proto.MergeFromString(serialized)) |
+ self.assertEqual(first_proto, second_proto) |
+ |
+ def testSerializeWithOptionalGroup(self): |
+ first_proto = unittest_pb2.TestAllTypes() |
+ second_proto = unittest_pb2.TestAllTypes() |
+ first_proto.optionalgroup.a = 242 |
+ serialized = first_proto.SerializeToString() |
+ self.assertEqual( |
+ len(serialized), |
+ second_proto.MergeFromString(serialized)) |
self.assertEqual(first_proto, second_proto) |
def testSerializeNegativeValues(self): |
@@ -2184,7 +2312,7 @@ class SerializationTest(unittest.TestCase): |
test_util.SetAllFields(first_proto) |
serialized = first_proto.SerializeToString() |
- for truncation_point in xrange(len(serialized) + 1): |
+ for truncation_point in range(len(serialized) + 1): |
try: |
second_proto = unittest_pb2.TestAllTypes() |
unknown_fields = unittest_pb2.TestEmptyMessage() |
@@ -2249,7 +2377,9 @@ class SerializationTest(unittest.TestCase): |
second_proto.optional_int32 = 100 |
second_proto.optional_nested_message.bb = 999 |
- second_proto.MergeFromString(serialized) |
+ bytes_parsed = second_proto.MergeFromString(serialized) |
+ self.assertEqual(len(serialized), bytes_parsed) |
+ |
# Ensure that we append to repeated fields. |
self.assertEqual(['baz', 'foobar'], list(second_proto.repeated_string)) |
# Ensure that we overwrite nonrepeatd scalars. |
@@ -2259,13 +2389,15 @@ class SerializationTest(unittest.TestCase): |
self.assertEqual(42, second_proto.optional_nested_message.bb) |
def testMessageSetWireFormat(self): |
- proto = unittest_mset_pb2.TestMessageSet() |
- extension_message1 = unittest_mset_pb2.TestMessageSetExtension1 |
- extension_message2 = unittest_mset_pb2.TestMessageSetExtension2 |
+ proto = message_set_extensions_pb2.TestMessageSet() |
+ extension_message1 = message_set_extensions_pb2.TestMessageSetExtension1 |
+ extension_message2 = message_set_extensions_pb2.TestMessageSetExtension2 |
extension1 = extension_message1.message_set_extension |
extension2 = extension_message2.message_set_extension |
+ extension3 = message_set_extensions_pb2.message_set_extension3 |
proto.Extensions[extension1].i = 123 |
proto.Extensions[extension2].str = 'foo' |
+ proto.Extensions[extension3].text = 'bar' |
# Serialize using the MessageSet wire format (this is specified in the |
# .proto file). |
@@ -2274,22 +2406,37 @@ class SerializationTest(unittest.TestCase): |
raw = unittest_mset_pb2.RawMessageSet() |
self.assertEqual(False, |
raw.DESCRIPTOR.GetOptions().message_set_wire_format) |
- raw.MergeFromString(serialized) |
- self.assertEqual(2, len(raw.item)) |
+ self.assertEqual( |
+ len(serialized), |
+ raw.MergeFromString(serialized)) |
+ self.assertEqual(3, len(raw.item)) |
- message1 = unittest_mset_pb2.TestMessageSetExtension1() |
- message1.MergeFromString(raw.item[0].message) |
+ message1 = message_set_extensions_pb2.TestMessageSetExtension1() |
+ self.assertEqual( |
+ len(raw.item[0].message), |
+ message1.MergeFromString(raw.item[0].message)) |
self.assertEqual(123, message1.i) |
- message2 = unittest_mset_pb2.TestMessageSetExtension2() |
- message2.MergeFromString(raw.item[1].message) |
+ message2 = message_set_extensions_pb2.TestMessageSetExtension2() |
+ self.assertEqual( |
+ len(raw.item[1].message), |
+ message2.MergeFromString(raw.item[1].message)) |
self.assertEqual('foo', message2.str) |
+ message3 = message_set_extensions_pb2.TestMessageSetExtension3() |
+ self.assertEqual( |
+ len(raw.item[2].message), |
+ message3.MergeFromString(raw.item[2].message)) |
+ self.assertEqual('bar', message3.text) |
+ |
# Deserialize using the MessageSet wire format. |
- proto2 = unittest_mset_pb2.TestMessageSet() |
- proto2.MergeFromString(serialized) |
+ proto2 = message_set_extensions_pb2.TestMessageSet() |
+ self.assertEqual( |
+ len(serialized), |
+ proto2.MergeFromString(serialized)) |
self.assertEqual(123, proto2.Extensions[extension1].i) |
self.assertEqual('foo', proto2.Extensions[extension2].str) |
+ self.assertEqual('bar', proto2.Extensions[extension3].text) |
# Check byte size. |
self.assertEqual(proto2.ByteSize(), len(serialized)) |
@@ -2302,37 +2449,39 @@ class SerializationTest(unittest.TestCase): |
# Add an item. |
item = raw.item.add() |
- item.type_id = 1545008 |
- extension_message1 = unittest_mset_pb2.TestMessageSetExtension1 |
- message1 = unittest_mset_pb2.TestMessageSetExtension1() |
+ item.type_id = 98418603 |
+ extension_message1 = message_set_extensions_pb2.TestMessageSetExtension1 |
+ message1 = message_set_extensions_pb2.TestMessageSetExtension1() |
message1.i = 12345 |
item.message = message1.SerializeToString() |
# Add a second, unknown extension. |
item = raw.item.add() |
- item.type_id = 1545009 |
- extension_message1 = unittest_mset_pb2.TestMessageSetExtension1 |
- message1 = unittest_mset_pb2.TestMessageSetExtension1() |
+ item.type_id = 98418604 |
+ extension_message1 = message_set_extensions_pb2.TestMessageSetExtension1 |
+ message1 = message_set_extensions_pb2.TestMessageSetExtension1() |
message1.i = 12346 |
item.message = message1.SerializeToString() |
# Add another unknown extension. |
item = raw.item.add() |
- item.type_id = 1545010 |
- message1 = unittest_mset_pb2.TestMessageSetExtension2() |
+ item.type_id = 98418605 |
+ message1 = message_set_extensions_pb2.TestMessageSetExtension2() |
message1.str = 'foo' |
item.message = message1.SerializeToString() |
serialized = raw.SerializeToString() |
# Parse message using the message set wire format. |
- proto = unittest_mset_pb2.TestMessageSet() |
- proto.MergeFromString(serialized) |
+ proto = message_set_extensions_pb2.TestMessageSet() |
+ self.assertEqual( |
+ len(serialized), |
+ proto.MergeFromString(serialized)) |
# Check that the message parsed well. |
- extension_message1 = unittest_mset_pb2.TestMessageSetExtension1 |
+ extension_message1 = message_set_extensions_pb2.TestMessageSetExtension1 |
extension1 = extension_message1.message_set_extension |
- self.assertEquals(12345, proto.Extensions[extension1].i) |
+ self.assertEqual(12345, proto.Extensions[extension1].i) |
def testUnknownFields(self): |
proto = unittest_pb2.TestAllTypes() |
@@ -2345,7 +2494,9 @@ class SerializationTest(unittest.TestCase): |
proto2 = unittest_pb2.TestEmptyMessage() |
# Parsing this message should succeed. |
- proto2.MergeFromString(serialized) |
+ self.assertEqual( |
+ len(serialized), |
+ proto2.MergeFromString(serialized)) |
# Now test with a int64 field set. |
proto = unittest_pb2.TestAllTypes() |
@@ -2355,7 +2506,9 @@ class SerializationTest(unittest.TestCase): |
# unknown. |
proto2 = unittest_pb2.TestEmptyMessage() |
# Parsing this message should succeed. |
- proto2.MergeFromString(serialized) |
+ self.assertEqual( |
+ len(serialized), |
+ proto2.MergeFromString(serialized)) |
def _CheckRaises(self, exc_class, callable_obj, exception): |
"""This method checks if the excpetion type and message are as expected.""" |
@@ -2406,11 +2559,15 @@ class SerializationTest(unittest.TestCase): |
partial = proto.SerializePartialToString() |
proto2 = unittest_pb2.TestRequired() |
- proto2.MergeFromString(serialized) |
+ self.assertEqual( |
+ len(serialized), |
+ proto2.MergeFromString(serialized)) |
self.assertEqual(1, proto2.a) |
self.assertEqual(2, proto2.b) |
self.assertEqual(3, proto2.c) |
- proto2.ParseFromString(partial) |
+ self.assertEqual( |
+ len(partial), |
+ proto2.MergeFromString(partial)) |
self.assertEqual(1, proto2.a) |
self.assertEqual(2, proto2.b) |
self.assertEqual(3, proto2.c) |
@@ -2478,7 +2635,9 @@ class SerializationTest(unittest.TestCase): |
second_proto.packed_double.extend([1.0, 2.0]) |
second_proto.packed_sint32.append(4) |
- second_proto.MergeFromString(serialized) |
+ self.assertEqual( |
+ len(serialized), |
+ second_proto.MergeFromString(serialized)) |
self.assertEqual([3, 1, 2], second_proto.packed_int32) |
self.assertEqual([1.0, 2.0, 3.0], second_proto.packed_double) |
self.assertEqual([4], second_proto.packed_sint32) |
@@ -2511,7 +2670,10 @@ class SerializationTest(unittest.TestCase): |
unpacked = unittest_pb2.TestUnpackedTypes() |
test_util.SetAllUnpackedFields(unpacked) |
packed = unittest_pb2.TestPackedTypes() |
- packed.MergeFromString(unpacked.SerializeToString()) |
+ serialized = unpacked.SerializeToString() |
+ self.assertEqual( |
+ len(serialized), |
+ packed.MergeFromString(serialized)) |
expected = unittest_pb2.TestPackedTypes() |
test_util.SetAllPackedFields(expected) |
self.assertEqual(expected, packed) |
@@ -2520,7 +2682,10 @@ class SerializationTest(unittest.TestCase): |
packed = unittest_pb2.TestPackedTypes() |
test_util.SetAllPackedFields(packed) |
unpacked = unittest_pb2.TestUnpackedTypes() |
- unpacked.MergeFromString(packed.SerializeToString()) |
+ serialized = packed.SerializeToString() |
+ self.assertEqual( |
+ len(serialized), |
+ unpacked.MergeFromString(serialized)) |
expected = unittest_pb2.TestUnpackedTypes() |
test_util.SetAllUnpackedFields(expected) |
self.assertEqual(expected, unpacked) |
@@ -2572,7 +2737,7 @@ class SerializationTest(unittest.TestCase): |
optional_int32=1, |
optional_string='foo', |
optional_bool=True, |
- optional_bytes='bar', |
+ optional_bytes=b'bar', |
optional_nested_message=unittest_pb2.TestAllTypes.NestedMessage(bb=1), |
optional_foreign_message=unittest_pb2.ForeignMessage(c=1), |
optional_nested_enum=unittest_pb2.TestAllTypes.FOO, |
@@ -2590,7 +2755,7 @@ class SerializationTest(unittest.TestCase): |
self.assertEqual(1, proto.optional_int32) |
self.assertEqual('foo', proto.optional_string) |
self.assertEqual(True, proto.optional_bool) |
- self.assertEqual('bar', proto.optional_bytes) |
+ self.assertEqual(b'bar', proto.optional_bytes) |
self.assertEqual(1, proto.optional_nested_message.bb) |
self.assertEqual(1, proto.optional_foreign_message.c) |
self.assertEqual(unittest_pb2.TestAllTypes.FOO, |
@@ -2601,9 +2766,10 @@ class SerializationTest(unittest.TestCase): |
def testInitArgsUnknownFieldName(self): |
def InitalizeEmptyMessageWithExtraKeywordArg(): |
unused_proto = unittest_pb2.TestEmptyMessage(unknown='unknown') |
- self._CheckRaises(ValueError, |
- InitalizeEmptyMessageWithExtraKeywordArg, |
- 'Protocol message has no "unknown" field.') |
+ self._CheckRaises( |
+ ValueError, |
+ InitalizeEmptyMessageWithExtraKeywordArg, |
+ 'Protocol message TestEmptyMessage has no "unknown" field.') |
def testInitRequiredKwargs(self): |
proto = unittest_pb2.TestRequired(a=1, b=1, c=1) |
@@ -2643,7 +2809,7 @@ class SerializationTest(unittest.TestCase): |
class OptionsTest(unittest.TestCase): |
def testMessageOptions(self): |
- proto = unittest_mset_pb2.TestMessageSet() |
+ proto = message_set_extensions_pb2.TestMessageSet() |
self.assertEqual(True, |
proto.DESCRIPTOR.GetOptions().message_set_wire_format) |
proto = unittest_pb2.TestAllTypes() |
@@ -2662,10 +2828,149 @@ class OptionsTest(unittest.TestCase): |
proto.packed_double.append(3.0) |
for field_descriptor, _ in proto.ListFields(): |
self.assertEqual(True, field_descriptor.GetOptions().packed) |
- self.assertEqual(reflection._FieldDescriptor.LABEL_REPEATED, |
+ self.assertEqual(descriptor.FieldDescriptor.LABEL_REPEATED, |
field_descriptor.label) |
+class ClassAPITest(unittest.TestCase): |
+ |
+ @unittest.skipIf( |
+ api_implementation.Type() == 'cpp' and api_implementation.Version() == 2, |
+ 'C++ implementation requires a call to MakeDescriptor()') |
+ def testMakeClassWithNestedDescriptor(self): |
+ leaf_desc = descriptor.Descriptor('leaf', 'package.parent.child.leaf', '', |
+ containing_type=None, fields=[], |
+ nested_types=[], enum_types=[], |
+ extensions=[]) |
+ child_desc = descriptor.Descriptor('child', 'package.parent.child', '', |
+ containing_type=None, fields=[], |
+ nested_types=[leaf_desc], enum_types=[], |
+ extensions=[]) |
+ sibling_desc = descriptor.Descriptor('sibling', 'package.parent.sibling', |
+ '', containing_type=None, fields=[], |
+ nested_types=[], enum_types=[], |
+ extensions=[]) |
+ parent_desc = descriptor.Descriptor('parent', 'package.parent', '', |
+ containing_type=None, fields=[], |
+ nested_types=[child_desc, sibling_desc], |
+ enum_types=[], extensions=[]) |
+ message_class = reflection.MakeClass(parent_desc) |
+ self.assertIn('child', message_class.__dict__) |
+ self.assertIn('sibling', message_class.__dict__) |
+ self.assertIn('leaf', message_class.child.__dict__) |
+ |
+ def _GetSerializedFileDescriptor(self, name): |
+ """Get a serialized representation of a test FileDescriptorProto. |
+ |
+ Args: |
+ name: All calls to this must use a unique message name, to avoid |
+ collisions in the cpp descriptor pool. |
+ Returns: |
+ A string containing the serialized form of a test FileDescriptorProto. |
+ """ |
+ file_descriptor_str = ( |
+ 'message_type {' |
+ ' name: "' + name + '"' |
+ ' field {' |
+ ' name: "flat"' |
+ ' number: 1' |
+ ' label: LABEL_REPEATED' |
+ ' type: TYPE_UINT32' |
+ ' }' |
+ ' field {' |
+ ' name: "bar"' |
+ ' number: 2' |
+ ' label: LABEL_OPTIONAL' |
+ ' type: TYPE_MESSAGE' |
+ ' type_name: "Bar"' |
+ ' }' |
+ ' nested_type {' |
+ ' name: "Bar"' |
+ ' field {' |
+ ' name: "baz"' |
+ ' number: 3' |
+ ' label: LABEL_OPTIONAL' |
+ ' type: TYPE_MESSAGE' |
+ ' type_name: "Baz"' |
+ ' }' |
+ ' nested_type {' |
+ ' name: "Baz"' |
+ ' enum_type {' |
+ ' name: "deep_enum"' |
+ ' value {' |
+ ' name: "VALUE_A"' |
+ ' number: 0' |
+ ' }' |
+ ' }' |
+ ' field {' |
+ ' name: "deep"' |
+ ' number: 4' |
+ ' label: LABEL_OPTIONAL' |
+ ' type: TYPE_UINT32' |
+ ' }' |
+ ' }' |
+ ' }' |
+ '}') |
+ file_descriptor = descriptor_pb2.FileDescriptorProto() |
+ text_format.Merge(file_descriptor_str, file_descriptor) |
+ return file_descriptor.SerializeToString() |
+ |
+ def testParsingFlatClassWithExplicitClassDeclaration(self): |
+ """Test that the generated class can parse a flat message.""" |
+ # TODO(xiaofeng): This test fails with cpp implemetnation in the call |
+ # of six.with_metaclass(). The other two callsites of with_metaclass |
+ # in this file are both excluded from cpp test, so it might be expected |
+ # to fail. Need someone more familiar with the python code to take a |
+ # look at this. |
+ if api_implementation.Type() != 'python': |
+ return |
+ file_descriptor = descriptor_pb2.FileDescriptorProto() |
+ file_descriptor.ParseFromString(self._GetSerializedFileDescriptor('A')) |
+ msg_descriptor = descriptor.MakeDescriptor( |
+ file_descriptor.message_type[0]) |
+ |
+ class MessageClass(six.with_metaclass(reflection.GeneratedProtocolMessageType, message.Message)): |
+ DESCRIPTOR = msg_descriptor |
+ msg = MessageClass() |
+ msg_str = ( |
+ 'flat: 0 ' |
+ 'flat: 1 ' |
+ 'flat: 2 ') |
+ text_format.Merge(msg_str, msg) |
+ self.assertEqual(msg.flat, [0, 1, 2]) |
+ |
+ def testParsingFlatClass(self): |
+ """Test that the generated class can parse a flat message.""" |
+ file_descriptor = descriptor_pb2.FileDescriptorProto() |
+ file_descriptor.ParseFromString(self._GetSerializedFileDescriptor('B')) |
+ msg_descriptor = descriptor.MakeDescriptor( |
+ file_descriptor.message_type[0]) |
+ msg_class = reflection.MakeClass(msg_descriptor) |
+ msg = msg_class() |
+ msg_str = ( |
+ 'flat: 0 ' |
+ 'flat: 1 ' |
+ 'flat: 2 ') |
+ text_format.Merge(msg_str, msg) |
+ self.assertEqual(msg.flat, [0, 1, 2]) |
+ |
+ def testParsingNestedClass(self): |
+ """Test that the generated class can parse a nested message.""" |
+ file_descriptor = descriptor_pb2.FileDescriptorProto() |
+ file_descriptor.ParseFromString(self._GetSerializedFileDescriptor('C')) |
+ msg_descriptor = descriptor.MakeDescriptor( |
+ file_descriptor.message_type[0]) |
+ msg_class = reflection.MakeClass(msg_descriptor) |
+ msg = msg_class() |
+ msg_str = ( |
+ 'bar {' |
+ ' baz {' |
+ ' deep: 4' |
+ ' }' |
+ '}') |
+ text_format.Merge(msg_str, msg) |
+ self.assertEqual(msg.bar.baz.deep, 4) |
+ |
if __name__ == '__main__': |
unittest.main() |