OLD | NEW |
1 #! /usr/bin/env python | 1 #! /usr/bin/env python |
2 # -*- coding: utf-8 -*- | 2 # -*- coding: utf-8 -*- |
3 # | 3 # |
4 # Protocol Buffers - Google's data interchange format | 4 # Protocol Buffers - Google's data interchange format |
5 # Copyright 2008 Google Inc. All rights reserved. | 5 # Copyright 2008 Google Inc. All rights reserved. |
6 # https://developers.google.com/protocol-buffers/ | 6 # https://developers.google.com/protocol-buffers/ |
7 # | 7 # |
8 # Redistribution and use in source and binary forms, with or without | 8 # Redistribution and use in source and binary forms, with or without |
9 # modification, are permitted provided that the following conditions are | 9 # modification, are permitted provided that the following conditions are |
10 # met: | 10 # met: |
(...skipping 29 matching lines...) Expand all Loading... |
40 except ImportError: | 40 except ImportError: |
41 import unittest | 41 import unittest |
42 from google.protobuf import unittest_mset_pb2 | 42 from google.protobuf import unittest_mset_pb2 |
43 from google.protobuf import unittest_pb2 | 43 from google.protobuf import unittest_pb2 |
44 from google.protobuf import unittest_proto3_arena_pb2 | 44 from google.protobuf import unittest_proto3_arena_pb2 |
45 from google.protobuf.internal import api_implementation | 45 from google.protobuf.internal import api_implementation |
46 from google.protobuf.internal import encoder | 46 from google.protobuf.internal import encoder |
47 from google.protobuf.internal import message_set_extensions_pb2 | 47 from google.protobuf.internal import message_set_extensions_pb2 |
48 from google.protobuf.internal import missing_enum_values_pb2 | 48 from google.protobuf.internal import missing_enum_values_pb2 |
49 from google.protobuf.internal import test_util | 49 from google.protobuf.internal import test_util |
| 50 from google.protobuf.internal import testing_refleaks |
50 from google.protobuf.internal import type_checkers | 51 from google.protobuf.internal import type_checkers |
51 | 52 |
52 | 53 |
| 54 BaseTestCase = testing_refleaks.BaseTestCase |
| 55 |
| 56 |
53 def SkipIfCppImplementation(func): | 57 def SkipIfCppImplementation(func): |
54 return unittest.skipIf( | 58 return unittest.skipIf( |
55 api_implementation.Type() == 'cpp' and api_implementation.Version() == 2, | 59 api_implementation.Type() == 'cpp' and api_implementation.Version() == 2, |
56 'C++ implementation does not expose unknown fields to Python')(func) | 60 'C++ implementation does not expose unknown fields to Python')(func) |
57 | 61 |
58 | 62 |
59 class UnknownFieldsTest(unittest.TestCase): | 63 class UnknownFieldsTest(BaseTestCase): |
60 | 64 |
61 def setUp(self): | 65 def setUp(self): |
62 self.descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR | 66 self.descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR |
63 self.all_fields = unittest_pb2.TestAllTypes() | 67 self.all_fields = unittest_pb2.TestAllTypes() |
64 test_util.SetAllFields(self.all_fields) | 68 test_util.SetAllFields(self.all_fields) |
65 self.all_fields_data = self.all_fields.SerializeToString() | 69 self.all_fields_data = self.all_fields.SerializeToString() |
66 self.empty_message = unittest_pb2.TestEmptyMessage() | 70 self.empty_message = unittest_pb2.TestEmptyMessage() |
67 self.empty_message.ParseFromString(self.all_fields_data) | 71 self.empty_message.ParseFromString(self.all_fields_data) |
68 | 72 |
69 def testSerialize(self): | 73 def testSerialize(self): |
(...skipping 63 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
133 self.assertNotEqual( | 137 self.assertNotEqual( |
134 b'', message.optional_nested_message.SerializeToString()) | 138 b'', message.optional_nested_message.SerializeToString()) |
135 self.assertNotEqual( | 139 self.assertNotEqual( |
136 b'', message.repeated_nested_message[0].SerializeToString()) | 140 b'', message.repeated_nested_message[0].SerializeToString()) |
137 message.DiscardUnknownFields() | 141 message.DiscardUnknownFields() |
138 self.assertEqual(b'', message.optional_nested_message.SerializeToString()) | 142 self.assertEqual(b'', message.optional_nested_message.SerializeToString()) |
139 self.assertEqual( | 143 self.assertEqual( |
140 b'', message.repeated_nested_message[0].SerializeToString()) | 144 b'', message.repeated_nested_message[0].SerializeToString()) |
141 | 145 |
142 | 146 |
143 class UnknownFieldsAccessorsTest(unittest.TestCase): | 147 class UnknownFieldsAccessorsTest(BaseTestCase): |
144 | 148 |
145 def setUp(self): | 149 def setUp(self): |
146 self.descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR | 150 self.descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR |
147 self.all_fields = unittest_pb2.TestAllTypes() | 151 self.all_fields = unittest_pb2.TestAllTypes() |
148 test_util.SetAllFields(self.all_fields) | 152 test_util.SetAllFields(self.all_fields) |
149 self.all_fields_data = self.all_fields.SerializeToString() | 153 self.all_fields_data = self.all_fields.SerializeToString() |
150 self.empty_message = unittest_pb2.TestEmptyMessage() | 154 self.empty_message = unittest_pb2.TestEmptyMessage() |
151 self.empty_message.ParseFromString(self.all_fields_data) | 155 self.empty_message.ParseFromString(self.all_fields_data) |
152 if api_implementation.Type() != 'cpp': | |
153 # _unknown_fields is an implementation detail. | |
154 self.unknown_fields = self.empty_message._unknown_fields | |
155 | 156 |
156 # All the tests that use GetField() check an implementation detail of the | 157 # GetUnknownField() checks a detail of the Python implementation, which stores |
157 # Python implementation, which stores unknown fields as serialized strings. | 158 # unknown fields as serialized strings. It cannot be used by the C++ |
158 # These tests are skipped by the C++ implementation: it's enough to check that | 159 # implementation: it's enough to check that the message is correctly |
159 # the message is correctly serialized. | 160 # serialized. |
160 | 161 |
161 def GetField(self, name): | 162 def GetUnknownField(self, name): |
162 field_descriptor = self.descriptor.fields_by_name[name] | 163 field_descriptor = self.descriptor.fields_by_name[name] |
163 wire_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type] | 164 wire_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type] |
164 field_tag = encoder.TagBytes(field_descriptor.number, wire_type) | 165 field_tag = encoder.TagBytes(field_descriptor.number, wire_type) |
165 result_dict = {} | 166 result_dict = {} |
166 for tag_bytes, value in self.unknown_fields: | 167 for tag_bytes, value in self.empty_message._unknown_fields: |
167 if tag_bytes == field_tag: | 168 if tag_bytes == field_tag: |
168 decoder = unittest_pb2.TestAllTypes._decoders_by_tag[tag_bytes][0] | 169 decoder = unittest_pb2.TestAllTypes._decoders_by_tag[tag_bytes][0] |
169 decoder(value, 0, len(value), self.all_fields, result_dict) | 170 decoder(value, 0, len(value), self.all_fields, result_dict) |
170 return result_dict[field_descriptor] | 171 return result_dict[field_descriptor] |
171 | 172 |
172 @SkipIfCppImplementation | 173 @SkipIfCppImplementation |
173 def testEnum(self): | 174 def testEnum(self): |
174 value = self.GetField('optional_nested_enum') | 175 value = self.GetUnknownField('optional_nested_enum') |
175 self.assertEqual(self.all_fields.optional_nested_enum, value) | 176 self.assertEqual(self.all_fields.optional_nested_enum, value) |
176 | 177 |
177 @SkipIfCppImplementation | 178 @SkipIfCppImplementation |
178 def testRepeatedEnum(self): | 179 def testRepeatedEnum(self): |
179 value = self.GetField('repeated_nested_enum') | 180 value = self.GetUnknownField('repeated_nested_enum') |
180 self.assertEqual(self.all_fields.repeated_nested_enum, value) | 181 self.assertEqual(self.all_fields.repeated_nested_enum, value) |
181 | 182 |
182 @SkipIfCppImplementation | 183 @SkipIfCppImplementation |
183 def testVarint(self): | 184 def testVarint(self): |
184 value = self.GetField('optional_int32') | 185 value = self.GetUnknownField('optional_int32') |
185 self.assertEqual(self.all_fields.optional_int32, value) | 186 self.assertEqual(self.all_fields.optional_int32, value) |
186 | 187 |
187 @SkipIfCppImplementation | 188 @SkipIfCppImplementation |
188 def testFixed32(self): | 189 def testFixed32(self): |
189 value = self.GetField('optional_fixed32') | 190 value = self.GetUnknownField('optional_fixed32') |
190 self.assertEqual(self.all_fields.optional_fixed32, value) | 191 self.assertEqual(self.all_fields.optional_fixed32, value) |
191 | 192 |
192 @SkipIfCppImplementation | 193 @SkipIfCppImplementation |
193 def testFixed64(self): | 194 def testFixed64(self): |
194 value = self.GetField('optional_fixed64') | 195 value = self.GetUnknownField('optional_fixed64') |
195 self.assertEqual(self.all_fields.optional_fixed64, value) | 196 self.assertEqual(self.all_fields.optional_fixed64, value) |
196 | 197 |
197 @SkipIfCppImplementation | 198 @SkipIfCppImplementation |
198 def testLengthDelimited(self): | 199 def testLengthDelimited(self): |
199 value = self.GetField('optional_string') | 200 value = self.GetUnknownField('optional_string') |
200 self.assertEqual(self.all_fields.optional_string, value) | 201 self.assertEqual(self.all_fields.optional_string, value) |
201 | 202 |
202 @SkipIfCppImplementation | 203 @SkipIfCppImplementation |
203 def testGroup(self): | 204 def testGroup(self): |
204 value = self.GetField('optionalgroup') | 205 value = self.GetUnknownField('optionalgroup') |
205 self.assertEqual(self.all_fields.optionalgroup, value) | 206 self.assertEqual(self.all_fields.optionalgroup, value) |
206 | 207 |
207 def testCopyFrom(self): | 208 def testCopyFrom(self): |
208 message = unittest_pb2.TestEmptyMessage() | 209 message = unittest_pb2.TestEmptyMessage() |
209 message.CopyFrom(self.empty_message) | 210 message.CopyFrom(self.empty_message) |
210 self.assertEqual(message.SerializeToString(), self.all_fields_data) | 211 self.assertEqual(message.SerializeToString(), self.all_fields_data) |
211 | 212 |
212 def testMergeFrom(self): | 213 def testMergeFrom(self): |
213 message = unittest_pb2.TestAllTypes() | 214 message = unittest_pb2.TestAllTypes() |
214 message.optional_int32 = 1 | 215 message.optional_int32 = 1 |
(...skipping 19 matching lines...) Expand all Loading... |
234 self.empty_message.Clear() | 235 self.empty_message.Clear() |
235 # All cleared, even unknown fields. | 236 # All cleared, even unknown fields. |
236 self.assertEqual(self.empty_message.SerializeToString(), b'') | 237 self.assertEqual(self.empty_message.SerializeToString(), b'') |
237 | 238 |
238 def testUnknownExtensions(self): | 239 def testUnknownExtensions(self): |
239 message = unittest_pb2.TestEmptyMessageWithExtensions() | 240 message = unittest_pb2.TestEmptyMessageWithExtensions() |
240 message.ParseFromString(self.all_fields_data) | 241 message.ParseFromString(self.all_fields_data) |
241 self.assertEqual(message.SerializeToString(), self.all_fields_data) | 242 self.assertEqual(message.SerializeToString(), self.all_fields_data) |
242 | 243 |
243 | 244 |
244 class UnknownEnumValuesTest(unittest.TestCase): | 245 class UnknownEnumValuesTest(BaseTestCase): |
245 | 246 |
246 def setUp(self): | 247 def setUp(self): |
247 self.descriptor = missing_enum_values_pb2.TestEnumValues.DESCRIPTOR | 248 self.descriptor = missing_enum_values_pb2.TestEnumValues.DESCRIPTOR |
248 | 249 |
249 self.message = missing_enum_values_pb2.TestEnumValues() | 250 self.message = missing_enum_values_pb2.TestEnumValues() |
| 251 # TestEnumValues.ZERO = 0, but does not exist in the other NestedEnum. |
250 self.message.optional_nested_enum = ( | 252 self.message.optional_nested_enum = ( |
251 missing_enum_values_pb2.TestEnumValues.ZERO) | 253 missing_enum_values_pb2.TestEnumValues.ZERO) |
252 self.message.repeated_nested_enum.extend([ | 254 self.message.repeated_nested_enum.extend([ |
253 missing_enum_values_pb2.TestEnumValues.ZERO, | 255 missing_enum_values_pb2.TestEnumValues.ZERO, |
254 missing_enum_values_pb2.TestEnumValues.ONE, | 256 missing_enum_values_pb2.TestEnumValues.ONE, |
255 ]) | 257 ]) |
256 self.message.packed_nested_enum.extend([ | 258 self.message.packed_nested_enum.extend([ |
257 missing_enum_values_pb2.TestEnumValues.ZERO, | 259 missing_enum_values_pb2.TestEnumValues.ZERO, |
258 missing_enum_values_pb2.TestEnumValues.ONE, | 260 missing_enum_values_pb2.TestEnumValues.ONE, |
259 ]) | 261 ]) |
260 self.message_data = self.message.SerializeToString() | 262 self.message_data = self.message.SerializeToString() |
261 self.missing_message = missing_enum_values_pb2.TestMissingEnumValues() | 263 self.missing_message = missing_enum_values_pb2.TestMissingEnumValues() |
262 self.missing_message.ParseFromString(self.message_data) | 264 self.missing_message.ParseFromString(self.message_data) |
263 if api_implementation.Type() != 'cpp': | |
264 # _unknown_fields is an implementation detail. | |
265 self.unknown_fields = self.missing_message._unknown_fields | |
266 | 265 |
267 # All the tests that use GetField() check an implementation detail of the | 266 # GetUnknownField() checks a detail of the Python implementation, which stores |
268 # Python implementation, which stores unknown fields as serialized strings. | 267 # unknown fields as serialized strings. It cannot be used by the C++ |
269 # These tests are skipped by the C++ implementation: it's enough to check that | 268 # implementation: it's enough to check that the message is correctly |
270 # the message is correctly serialized. | 269 # serialized. |
271 | 270 |
272 def GetField(self, name): | 271 def GetUnknownField(self, name): |
273 field_descriptor = self.descriptor.fields_by_name[name] | 272 field_descriptor = self.descriptor.fields_by_name[name] |
274 wire_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type] | 273 wire_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type] |
275 field_tag = encoder.TagBytes(field_descriptor.number, wire_type) | 274 field_tag = encoder.TagBytes(field_descriptor.number, wire_type) |
276 result_dict = {} | 275 result_dict = {} |
277 for tag_bytes, value in self.unknown_fields: | 276 for tag_bytes, value in self.missing_message._unknown_fields: |
278 if tag_bytes == field_tag: | 277 if tag_bytes == field_tag: |
279 decoder = missing_enum_values_pb2.TestEnumValues._decoders_by_tag[ | 278 decoder = missing_enum_values_pb2.TestEnumValues._decoders_by_tag[ |
280 tag_bytes][0] | 279 tag_bytes][0] |
281 decoder(value, 0, len(value), self.message, result_dict) | 280 decoder(value, 0, len(value), self.message, result_dict) |
282 return result_dict[field_descriptor] | 281 return result_dict[field_descriptor] |
283 | 282 |
284 def testUnknownParseMismatchEnumValue(self): | 283 def testUnknownParseMismatchEnumValue(self): |
285 just_string = missing_enum_values_pb2.JustString() | 284 just_string = missing_enum_values_pb2.JustString() |
286 just_string.dummy = 'blah' | 285 just_string.dummy = 'blah' |
287 | 286 |
288 missing = missing_enum_values_pb2.TestEnumValues() | 287 missing = missing_enum_values_pb2.TestEnumValues() |
289 # The parse is invalid, storing the string proto into the set of | 288 # The parse is invalid, storing the string proto into the set of |
290 # unknown fields. | 289 # unknown fields. |
291 missing.ParseFromString(just_string.SerializeToString()) | 290 missing.ParseFromString(just_string.SerializeToString()) |
292 | 291 |
293 # Fetching the enum field shouldn't crash, instead returning the | 292 # Fetching the enum field shouldn't crash, instead returning the |
294 # default value. | 293 # default value. |
295 self.assertEqual(missing.optional_nested_enum, 0) | 294 self.assertEqual(missing.optional_nested_enum, 0) |
296 | 295 |
297 @SkipIfCppImplementation | |
298 def testUnknownEnumValue(self): | 296 def testUnknownEnumValue(self): |
| 297 if api_implementation.Type() == 'cpp': |
| 298 # The CPP implementation of protos (wrongly) allows unknown enum values |
| 299 # for proto2. |
| 300 self.assertTrue(self.missing_message.HasField('optional_nested_enum')) |
| 301 self.assertEqual(self.message.optional_nested_enum, |
| 302 self.missing_message.optional_nested_enum) |
| 303 else: |
| 304 # On the other hand, the Python implementation considers unknown values |
| 305 # as unknown fields. This is the correct behavior. |
| 306 self.assertFalse(self.missing_message.HasField('optional_nested_enum')) |
| 307 value = self.GetUnknownField('optional_nested_enum') |
| 308 self.assertEqual(self.message.optional_nested_enum, value) |
| 309 self.missing_message.ClearField('optional_nested_enum') |
299 self.assertFalse(self.missing_message.HasField('optional_nested_enum')) | 310 self.assertFalse(self.missing_message.HasField('optional_nested_enum')) |
300 value = self.GetField('optional_nested_enum') | |
301 self.assertEqual(self.message.optional_nested_enum, value) | |
302 | 311 |
303 @SkipIfCppImplementation | |
304 def testUnknownRepeatedEnumValue(self): | 312 def testUnknownRepeatedEnumValue(self): |
305 value = self.GetField('repeated_nested_enum') | 313 if api_implementation.Type() == 'cpp': |
306 self.assertEqual(self.message.repeated_nested_enum, value) | 314 # For repeated enums, both implementations agree. |
| 315 self.assertEqual([], self.missing_message.repeated_nested_enum) |
| 316 else: |
| 317 self.assertEqual([], self.missing_message.repeated_nested_enum) |
| 318 value = self.GetUnknownField('repeated_nested_enum') |
| 319 self.assertEqual(self.message.repeated_nested_enum, value) |
307 | 320 |
308 @SkipIfCppImplementation | |
309 def testUnknownPackedEnumValue(self): | 321 def testUnknownPackedEnumValue(self): |
310 value = self.GetField('packed_nested_enum') | 322 if api_implementation.Type() == 'cpp': |
311 self.assertEqual(self.message.packed_nested_enum, value) | 323 # For repeated enums, both implementations agree. |
| 324 self.assertEqual([], self.missing_message.packed_nested_enum) |
| 325 else: |
| 326 self.assertEqual([], self.missing_message.packed_nested_enum) |
| 327 value = self.GetUnknownField('packed_nested_enum') |
| 328 self.assertEqual(self.message.packed_nested_enum, value) |
312 | 329 |
313 def testRoundTrip(self): | 330 def testRoundTrip(self): |
314 new_message = missing_enum_values_pb2.TestEnumValues() | 331 new_message = missing_enum_values_pb2.TestEnumValues() |
315 new_message.ParseFromString(self.missing_message.SerializeToString()) | 332 new_message.ParseFromString(self.missing_message.SerializeToString()) |
316 self.assertEqual(self.message, new_message) | 333 self.assertEqual(self.message, new_message) |
317 | 334 |
318 | 335 |
319 if __name__ == '__main__': | 336 if __name__ == '__main__': |
320 unittest.main() | 337 unittest.main() |
OLD | NEW |