OLD | NEW |
1 #! /usr/bin/env python | 1 #! /usr/bin/env python |
2 # -*- coding: utf-8 -*- | 2 # -*- coding: utf-8 -*- |
3 # | 3 # |
4 # Protocol Buffers - Google's data interchange format | 4 # Protocol Buffers - Google's data interchange format |
5 # Copyright 2008 Google Inc. All rights reserved. | 5 # Copyright 2008 Google Inc. All rights reserved. |
6 # https://developers.google.com/protocol-buffers/ | 6 # https://developers.google.com/protocol-buffers/ |
7 # | 7 # |
8 # Redistribution and use in source and binary forms, with or without | 8 # Redistribution and use in source and binary forms, with or without |
9 # modification, are permitted provided that the following conditions are | 9 # modification, are permitted provided that the following conditions are |
10 # met: | 10 # met: |
(...skipping 29 matching lines...) Expand all Loading... |
40 except ImportError: | 40 except ImportError: |
41 import unittest | 41 import unittest |
42 from google.protobuf import unittest_mset_pb2 | 42 from google.protobuf import unittest_mset_pb2 |
43 from google.protobuf import unittest_pb2 | 43 from google.protobuf import unittest_pb2 |
44 from google.protobuf import unittest_proto3_arena_pb2 | 44 from google.protobuf import unittest_proto3_arena_pb2 |
45 from google.protobuf.internal import api_implementation | 45 from google.protobuf.internal import api_implementation |
46 from google.protobuf.internal import encoder | 46 from google.protobuf.internal import encoder |
47 from google.protobuf.internal import message_set_extensions_pb2 | 47 from google.protobuf.internal import message_set_extensions_pb2 |
48 from google.protobuf.internal import missing_enum_values_pb2 | 48 from google.protobuf.internal import missing_enum_values_pb2 |
49 from google.protobuf.internal import test_util | 49 from google.protobuf.internal import test_util |
50 from google.protobuf.internal import testing_refleaks | |
51 from google.protobuf.internal import type_checkers | 50 from google.protobuf.internal import type_checkers |
52 | 51 |
53 | 52 |
54 BaseTestCase = testing_refleaks.BaseTestCase | |
55 | |
56 | |
57 def SkipIfCppImplementation(func): | 53 def SkipIfCppImplementation(func): |
58 return unittest.skipIf( | 54 return unittest.skipIf( |
59 api_implementation.Type() == 'cpp' and api_implementation.Version() == 2, | 55 api_implementation.Type() == 'cpp' and api_implementation.Version() == 2, |
60 'C++ implementation does not expose unknown fields to Python')(func) | 56 'C++ implementation does not expose unknown fields to Python')(func) |
61 | 57 |
62 | 58 |
63 class UnknownFieldsTest(BaseTestCase): | 59 class UnknownFieldsTest(unittest.TestCase): |
64 | 60 |
65 def setUp(self): | 61 def setUp(self): |
66 self.descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR | 62 self.descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR |
67 self.all_fields = unittest_pb2.TestAllTypes() | 63 self.all_fields = unittest_pb2.TestAllTypes() |
68 test_util.SetAllFields(self.all_fields) | 64 test_util.SetAllFields(self.all_fields) |
69 self.all_fields_data = self.all_fields.SerializeToString() | 65 self.all_fields_data = self.all_fields.SerializeToString() |
70 self.empty_message = unittest_pb2.TestEmptyMessage() | 66 self.empty_message = unittest_pb2.TestEmptyMessage() |
71 self.empty_message.ParseFromString(self.all_fields_data) | 67 self.empty_message.ParseFromString(self.all_fields_data) |
72 | 68 |
73 def testSerialize(self): | 69 def testSerialize(self): |
(...skipping 63 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
137 self.assertNotEqual( | 133 self.assertNotEqual( |
138 b'', message.optional_nested_message.SerializeToString()) | 134 b'', message.optional_nested_message.SerializeToString()) |
139 self.assertNotEqual( | 135 self.assertNotEqual( |
140 b'', message.repeated_nested_message[0].SerializeToString()) | 136 b'', message.repeated_nested_message[0].SerializeToString()) |
141 message.DiscardUnknownFields() | 137 message.DiscardUnknownFields() |
142 self.assertEqual(b'', message.optional_nested_message.SerializeToString()) | 138 self.assertEqual(b'', message.optional_nested_message.SerializeToString()) |
143 self.assertEqual( | 139 self.assertEqual( |
144 b'', message.repeated_nested_message[0].SerializeToString()) | 140 b'', message.repeated_nested_message[0].SerializeToString()) |
145 | 141 |
146 | 142 |
147 class UnknownFieldsAccessorsTest(BaseTestCase): | 143 class UnknownFieldsAccessorsTest(unittest.TestCase): |
148 | 144 |
149 def setUp(self): | 145 def setUp(self): |
150 self.descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR | 146 self.descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR |
151 self.all_fields = unittest_pb2.TestAllTypes() | 147 self.all_fields = unittest_pb2.TestAllTypes() |
152 test_util.SetAllFields(self.all_fields) | 148 test_util.SetAllFields(self.all_fields) |
153 self.all_fields_data = self.all_fields.SerializeToString() | 149 self.all_fields_data = self.all_fields.SerializeToString() |
154 self.empty_message = unittest_pb2.TestEmptyMessage() | 150 self.empty_message = unittest_pb2.TestEmptyMessage() |
155 self.empty_message.ParseFromString(self.all_fields_data) | 151 self.empty_message.ParseFromString(self.all_fields_data) |
| 152 if api_implementation.Type() != 'cpp': |
| 153 # _unknown_fields is an implementation detail. |
| 154 self.unknown_fields = self.empty_message._unknown_fields |
156 | 155 |
157 # GetUnknownField() checks a detail of the Python implementation, which stores | 156 # All the tests that use GetField() check an implementation detail of the |
158 # unknown fields as serialized strings. It cannot be used by the C++ | 157 # Python implementation, which stores unknown fields as serialized strings. |
159 # implementation: it's enough to check that the message is correctly | 158 # These tests are skipped by the C++ implementation: it's enough to check that |
160 # serialized. | 159 # the message is correctly serialized. |
161 | 160 |
162 def GetUnknownField(self, name): | 161 def GetField(self, name): |
163 field_descriptor = self.descriptor.fields_by_name[name] | 162 field_descriptor = self.descriptor.fields_by_name[name] |
164 wire_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type] | 163 wire_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type] |
165 field_tag = encoder.TagBytes(field_descriptor.number, wire_type) | 164 field_tag = encoder.TagBytes(field_descriptor.number, wire_type) |
166 result_dict = {} | 165 result_dict = {} |
167 for tag_bytes, value in self.empty_message._unknown_fields: | 166 for tag_bytes, value in self.unknown_fields: |
168 if tag_bytes == field_tag: | 167 if tag_bytes == field_tag: |
169 decoder = unittest_pb2.TestAllTypes._decoders_by_tag[tag_bytes][0] | 168 decoder = unittest_pb2.TestAllTypes._decoders_by_tag[tag_bytes][0] |
170 decoder(value, 0, len(value), self.all_fields, result_dict) | 169 decoder(value, 0, len(value), self.all_fields, result_dict) |
171 return result_dict[field_descriptor] | 170 return result_dict[field_descriptor] |
172 | 171 |
173 @SkipIfCppImplementation | 172 @SkipIfCppImplementation |
174 def testEnum(self): | 173 def testEnum(self): |
175 value = self.GetUnknownField('optional_nested_enum') | 174 value = self.GetField('optional_nested_enum') |
176 self.assertEqual(self.all_fields.optional_nested_enum, value) | 175 self.assertEqual(self.all_fields.optional_nested_enum, value) |
177 | 176 |
178 @SkipIfCppImplementation | 177 @SkipIfCppImplementation |
179 def testRepeatedEnum(self): | 178 def testRepeatedEnum(self): |
180 value = self.GetUnknownField('repeated_nested_enum') | 179 value = self.GetField('repeated_nested_enum') |
181 self.assertEqual(self.all_fields.repeated_nested_enum, value) | 180 self.assertEqual(self.all_fields.repeated_nested_enum, value) |
182 | 181 |
183 @SkipIfCppImplementation | 182 @SkipIfCppImplementation |
184 def testVarint(self): | 183 def testVarint(self): |
185 value = self.GetUnknownField('optional_int32') | 184 value = self.GetField('optional_int32') |
186 self.assertEqual(self.all_fields.optional_int32, value) | 185 self.assertEqual(self.all_fields.optional_int32, value) |
187 | 186 |
188 @SkipIfCppImplementation | 187 @SkipIfCppImplementation |
189 def testFixed32(self): | 188 def testFixed32(self): |
190 value = self.GetUnknownField('optional_fixed32') | 189 value = self.GetField('optional_fixed32') |
191 self.assertEqual(self.all_fields.optional_fixed32, value) | 190 self.assertEqual(self.all_fields.optional_fixed32, value) |
192 | 191 |
193 @SkipIfCppImplementation | 192 @SkipIfCppImplementation |
194 def testFixed64(self): | 193 def testFixed64(self): |
195 value = self.GetUnknownField('optional_fixed64') | 194 value = self.GetField('optional_fixed64') |
196 self.assertEqual(self.all_fields.optional_fixed64, value) | 195 self.assertEqual(self.all_fields.optional_fixed64, value) |
197 | 196 |
198 @SkipIfCppImplementation | 197 @SkipIfCppImplementation |
199 def testLengthDelimited(self): | 198 def testLengthDelimited(self): |
200 value = self.GetUnknownField('optional_string') | 199 value = self.GetField('optional_string') |
201 self.assertEqual(self.all_fields.optional_string, value) | 200 self.assertEqual(self.all_fields.optional_string, value) |
202 | 201 |
203 @SkipIfCppImplementation | 202 @SkipIfCppImplementation |
204 def testGroup(self): | 203 def testGroup(self): |
205 value = self.GetUnknownField('optionalgroup') | 204 value = self.GetField('optionalgroup') |
206 self.assertEqual(self.all_fields.optionalgroup, value) | 205 self.assertEqual(self.all_fields.optionalgroup, value) |
207 | 206 |
208 def testCopyFrom(self): | 207 def testCopyFrom(self): |
209 message = unittest_pb2.TestEmptyMessage() | 208 message = unittest_pb2.TestEmptyMessage() |
210 message.CopyFrom(self.empty_message) | 209 message.CopyFrom(self.empty_message) |
211 self.assertEqual(message.SerializeToString(), self.all_fields_data) | 210 self.assertEqual(message.SerializeToString(), self.all_fields_data) |
212 | 211 |
213 def testMergeFrom(self): | 212 def testMergeFrom(self): |
214 message = unittest_pb2.TestAllTypes() | 213 message = unittest_pb2.TestAllTypes() |
215 message.optional_int32 = 1 | 214 message.optional_int32 = 1 |
(...skipping 19 matching lines...) Expand all Loading... |
235 self.empty_message.Clear() | 234 self.empty_message.Clear() |
236 # All cleared, even unknown fields. | 235 # All cleared, even unknown fields. |
237 self.assertEqual(self.empty_message.SerializeToString(), b'') | 236 self.assertEqual(self.empty_message.SerializeToString(), b'') |
238 | 237 |
239 def testUnknownExtensions(self): | 238 def testUnknownExtensions(self): |
240 message = unittest_pb2.TestEmptyMessageWithExtensions() | 239 message = unittest_pb2.TestEmptyMessageWithExtensions() |
241 message.ParseFromString(self.all_fields_data) | 240 message.ParseFromString(self.all_fields_data) |
242 self.assertEqual(message.SerializeToString(), self.all_fields_data) | 241 self.assertEqual(message.SerializeToString(), self.all_fields_data) |
243 | 242 |
244 | 243 |
245 class UnknownEnumValuesTest(BaseTestCase): | 244 class UnknownEnumValuesTest(unittest.TestCase): |
246 | 245 |
247 def setUp(self): | 246 def setUp(self): |
248 self.descriptor = missing_enum_values_pb2.TestEnumValues.DESCRIPTOR | 247 self.descriptor = missing_enum_values_pb2.TestEnumValues.DESCRIPTOR |
249 | 248 |
250 self.message = missing_enum_values_pb2.TestEnumValues() | 249 self.message = missing_enum_values_pb2.TestEnumValues() |
251 # TestEnumValues.ZERO = 0, but does not exist in the other NestedEnum. | |
252 self.message.optional_nested_enum = ( | 250 self.message.optional_nested_enum = ( |
253 missing_enum_values_pb2.TestEnumValues.ZERO) | 251 missing_enum_values_pb2.TestEnumValues.ZERO) |
254 self.message.repeated_nested_enum.extend([ | 252 self.message.repeated_nested_enum.extend([ |
255 missing_enum_values_pb2.TestEnumValues.ZERO, | 253 missing_enum_values_pb2.TestEnumValues.ZERO, |
256 missing_enum_values_pb2.TestEnumValues.ONE, | 254 missing_enum_values_pb2.TestEnumValues.ONE, |
257 ]) | 255 ]) |
258 self.message.packed_nested_enum.extend([ | 256 self.message.packed_nested_enum.extend([ |
259 missing_enum_values_pb2.TestEnumValues.ZERO, | 257 missing_enum_values_pb2.TestEnumValues.ZERO, |
260 missing_enum_values_pb2.TestEnumValues.ONE, | 258 missing_enum_values_pb2.TestEnumValues.ONE, |
261 ]) | 259 ]) |
262 self.message_data = self.message.SerializeToString() | 260 self.message_data = self.message.SerializeToString() |
263 self.missing_message = missing_enum_values_pb2.TestMissingEnumValues() | 261 self.missing_message = missing_enum_values_pb2.TestMissingEnumValues() |
264 self.missing_message.ParseFromString(self.message_data) | 262 self.missing_message.ParseFromString(self.message_data) |
| 263 if api_implementation.Type() != 'cpp': |
| 264 # _unknown_fields is an implementation detail. |
| 265 self.unknown_fields = self.missing_message._unknown_fields |
265 | 266 |
266 # GetUnknownField() checks a detail of the Python implementation, which stores | 267 # All the tests that use GetField() check an implementation detail of the |
267 # unknown fields as serialized strings. It cannot be used by the C++ | 268 # Python implementation, which stores unknown fields as serialized strings. |
268 # implementation: it's enough to check that the message is correctly | 269 # These tests are skipped by the C++ implementation: it's enough to check that |
269 # serialized. | 270 # the message is correctly serialized. |
270 | 271 |
271 def GetUnknownField(self, name): | 272 def GetField(self, name): |
272 field_descriptor = self.descriptor.fields_by_name[name] | 273 field_descriptor = self.descriptor.fields_by_name[name] |
273 wire_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type] | 274 wire_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type] |
274 field_tag = encoder.TagBytes(field_descriptor.number, wire_type) | 275 field_tag = encoder.TagBytes(field_descriptor.number, wire_type) |
275 result_dict = {} | 276 result_dict = {} |
276 for tag_bytes, value in self.missing_message._unknown_fields: | 277 for tag_bytes, value in self.unknown_fields: |
277 if tag_bytes == field_tag: | 278 if tag_bytes == field_tag: |
278 decoder = missing_enum_values_pb2.TestEnumValues._decoders_by_tag[ | 279 decoder = missing_enum_values_pb2.TestEnumValues._decoders_by_tag[ |
279 tag_bytes][0] | 280 tag_bytes][0] |
280 decoder(value, 0, len(value), self.message, result_dict) | 281 decoder(value, 0, len(value), self.message, result_dict) |
281 return result_dict[field_descriptor] | 282 return result_dict[field_descriptor] |
282 | 283 |
283 def testUnknownParseMismatchEnumValue(self): | 284 def testUnknownParseMismatchEnumValue(self): |
284 just_string = missing_enum_values_pb2.JustString() | 285 just_string = missing_enum_values_pb2.JustString() |
285 just_string.dummy = 'blah' | 286 just_string.dummy = 'blah' |
286 | 287 |
287 missing = missing_enum_values_pb2.TestEnumValues() | 288 missing = missing_enum_values_pb2.TestEnumValues() |
288 # The parse is invalid, storing the string proto into the set of | 289 # The parse is invalid, storing the string proto into the set of |
289 # unknown fields. | 290 # unknown fields. |
290 missing.ParseFromString(just_string.SerializeToString()) | 291 missing.ParseFromString(just_string.SerializeToString()) |
291 | 292 |
292 # Fetching the enum field shouldn't crash, instead returning the | 293 # Fetching the enum field shouldn't crash, instead returning the |
293 # default value. | 294 # default value. |
294 self.assertEqual(missing.optional_nested_enum, 0) | 295 self.assertEqual(missing.optional_nested_enum, 0) |
295 | 296 |
| 297 @SkipIfCppImplementation |
296 def testUnknownEnumValue(self): | 298 def testUnknownEnumValue(self): |
297 if api_implementation.Type() == 'cpp': | |
298 # The CPP implementation of protos (wrongly) allows unknown enum values | |
299 # for proto2. | |
300 self.assertTrue(self.missing_message.HasField('optional_nested_enum')) | |
301 self.assertEqual(self.message.optional_nested_enum, | |
302 self.missing_message.optional_nested_enum) | |
303 else: | |
304 # On the other hand, the Python implementation considers unknown values | |
305 # as unknown fields. This is the correct behavior. | |
306 self.assertFalse(self.missing_message.HasField('optional_nested_enum')) | |
307 value = self.GetUnknownField('optional_nested_enum') | |
308 self.assertEqual(self.message.optional_nested_enum, value) | |
309 self.missing_message.ClearField('optional_nested_enum') | |
310 self.assertFalse(self.missing_message.HasField('optional_nested_enum')) | 299 self.assertFalse(self.missing_message.HasField('optional_nested_enum')) |
| 300 value = self.GetField('optional_nested_enum') |
| 301 self.assertEqual(self.message.optional_nested_enum, value) |
311 | 302 |
| 303 @SkipIfCppImplementation |
312 def testUnknownRepeatedEnumValue(self): | 304 def testUnknownRepeatedEnumValue(self): |
313 if api_implementation.Type() == 'cpp': | 305 value = self.GetField('repeated_nested_enum') |
314 # For repeated enums, both implementations agree. | 306 self.assertEqual(self.message.repeated_nested_enum, value) |
315 self.assertEqual([], self.missing_message.repeated_nested_enum) | |
316 else: | |
317 self.assertEqual([], self.missing_message.repeated_nested_enum) | |
318 value = self.GetUnknownField('repeated_nested_enum') | |
319 self.assertEqual(self.message.repeated_nested_enum, value) | |
320 | 307 |
| 308 @SkipIfCppImplementation |
321 def testUnknownPackedEnumValue(self): | 309 def testUnknownPackedEnumValue(self): |
322 if api_implementation.Type() == 'cpp': | 310 value = self.GetField('packed_nested_enum') |
323 # For repeated enums, both implementations agree. | 311 self.assertEqual(self.message.packed_nested_enum, value) |
324 self.assertEqual([], self.missing_message.packed_nested_enum) | |
325 else: | |
326 self.assertEqual([], self.missing_message.packed_nested_enum) | |
327 value = self.GetUnknownField('packed_nested_enum') | |
328 self.assertEqual(self.message.packed_nested_enum, value) | |
329 | 312 |
330 def testRoundTrip(self): | 313 def testRoundTrip(self): |
331 new_message = missing_enum_values_pb2.TestEnumValues() | 314 new_message = missing_enum_values_pb2.TestEnumValues() |
332 new_message.ParseFromString(self.missing_message.SerializeToString()) | 315 new_message.ParseFromString(self.missing_message.SerializeToString()) |
333 self.assertEqual(self.message, new_message) | 316 self.assertEqual(self.message, new_message) |
334 | 317 |
335 | 318 |
336 if __name__ == '__main__': | 319 if __name__ == '__main__': |
337 unittest.main() | 320 unittest.main() |
OLD | NEW |