OLD | NEW |
| (Empty) |
1 # Protocol Buffers - Google's data interchange format | |
2 # Copyright 2008 Google Inc. All rights reserved. | |
3 # http://code.google.com/p/protobuf/ | |
4 # | |
5 # Redistribution and use in source and binary forms, with or without | |
6 # modification, are permitted provided that the following conditions are | |
7 # met: | |
8 # | |
9 # * Redistributions of source code must retain the above copyright | |
10 # notice, this list of conditions and the following disclaimer. | |
11 # * Redistributions in binary form must reproduce the above | |
12 # copyright notice, this list of conditions and the following disclaimer | |
13 # in the documentation and/or other materials provided with the | |
14 # distribution. | |
15 # * Neither the name of Google Inc. nor the names of its | |
16 # contributors may be used to endorse or promote products derived from | |
17 # this software without specific prior written permission. | |
18 # | |
19 # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS | |
20 # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT | |
21 # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR | |
22 # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT | |
23 # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, | |
24 # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT | |
25 # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, | |
26 # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY | |
27 # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT | |
28 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | |
29 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |
30 | |
31 #PY25 compatible for GAE. | |
32 # | |
33 # Copyright 2009 Google Inc. All Rights Reserved. | |
34 | |
35 """Code for decoding protocol buffer primitives. | |
36 | |
37 This code is very similar to encoder.py -- read the docs for that module first. | |
38 | |
39 A "decoder" is a function with the signature: | |
40 Decode(buffer, pos, end, message, field_dict) | |
41 The arguments are: | |
42 buffer: The string containing the encoded message. | |
43 pos: The current position in the string. | |
44 end: The position in the string where the current message ends. May be | |
45 less than len(buffer) if we're reading a sub-message. | |
46 message: The message object into which we're parsing. | |
47 field_dict: message._fields (avoids a hashtable lookup). | |
48 The decoder reads the field and stores it into field_dict, returning the new | |
49 buffer position. A decoder for a repeated field may proactively decode all of | |
50 the elements of that field, if they appear consecutively. | |
51 | |
52 Note that decoders may throw any of the following: | |
53 IndexError: Indicates a truncated message. | |
54 struct.error: Unpacking of a fixed-width field failed. | |
55 message.DecodeError: Other errors. | |
56 | |
57 Decoders are expected to raise an exception if they are called with pos > end. | |
58 This allows callers to be lax about bounds checking: it's fine to read past | |
59 "end" as long as you are sure that someone else will notice and throw an | |
60 exception later on. | |
61 | |
62 Something up the call stack is expected to catch IndexError and struct.error | |
63 and convert them to message.DecodeError. | |
64 | |
65 Decoders are constructed using decoder constructors with the signature: | |
66 MakeDecoder(field_number, is_repeated, is_packed, key, new_default) | |
67 The arguments are: | |
68 field_number: The field number of the field we want to decode. | |
69 is_repeated: Is the field a repeated field? (bool) | |
70 is_packed: Is the field a packed field? (bool) | |
71 key: The key to use when looking up the field within field_dict. | |
72 (This is actually the FieldDescriptor but nothing in this | |
73 file should depend on that.) | |
74 new_default: A function which takes a message object as a parameter and | |
75 returns a new instance of the default value for this field. | |
76 (This is called for repeated fields and sub-messages, when an | |
77 instance does not already exist.) | |
78 | |
79 As with encoders, we define a decoder constructor for every type of field. | |
80 Then, for every field of every message class we construct an actual decoder. | |
81 That decoder goes into a dict indexed by tag, so when we decode a message | |
82 we repeatedly read a tag, look up the corresponding decoder, and invoke it. | |
83 """ | |
84 | |
85 __author__ = 'kenton@google.com (Kenton Varda)' | |
86 | |
87 import struct | |
88 import sys ##PY25 | |
89 _PY2 = sys.version_info[0] < 3 ##PY25 | |
90 from google.protobuf.internal import encoder | |
91 from google.protobuf.internal import wire_format | |
92 from google.protobuf import message | |
93 | |
94 | |
# This will overflow and thus become IEEE-754 "infinity". We would use
# "float('inf')" but it doesn't work on Windows pre-Python-2.6.
_POS_INF = 1e10000
# Negative infinity, obtained by negating the overflowed literal above.
_NEG_INF = -_POS_INF
# Multiplying infinity by zero yields not-a-number.
_NAN = _POS_INF * 0


# This is not for optimization, but rather to avoid conflicts with local
# variables named "message".
_DecodeError = message.DecodeError
105 | |
106 | |
def _VarintDecoder(mask, result_type):
  """Return a decoder for a basic varint value (does not include tag).

  The decoded integer is bitwise-anded with `mask` (e.g. to limit it to 32
  bits) and then converted with `result_type` before being returned.

  The returned function is called as DecodeVarint(buffer, pos) and returns a
  (value, new_pos) pair. It does not take the usual "end" parameter -- the
  caller is expected to do bounds checking after the fact (often the caller
  can defer such checking until later).

  Raises:
    _DecodeError: if the varint continues for too many bytes.
  """

  local_ord = ord
  py2 = _PY2  ##PY25
##!PY25  py2 = str is bytes
  def DecodeVarint(buffer, pos):
    accumulated = 0
    bit_offset = 0
    while 1:
      # On Python 2 an indexed byte must go through ord(); otherwise the
      # indexed element is used directly.
      byte = local_ord(buffer[pos]) if py2 else buffer[pos]
      pos += 1
      accumulated |= (byte & 0x7f) << bit_offset
      if byte & 0x80:
        # Continuation bit is set: another 7-bit group follows.
        bit_offset += 7
        if bit_offset >= 64:
          raise _DecodeError('Too many bytes when decoding varint.')
        continue
      return (result_type(accumulated & mask), pos)
  return DecodeVarint
135 | |
136 | |
def _SignedVarintDecoder(mask, result_type):
  """Like _VarintDecoder() but decodes signed values.

  A raw value larger than 0x7fffffffffffffff is reinterpreted as a negative
  64-bit two's-complement number and has every bit outside `mask` set;
  any other raw value is simply anded with `mask`. The result is converted
  with `result_type` and returned as a (value, new_pos) pair.

  Raises:
    _DecodeError: if the varint continues for too many bytes.
  """

  local_ord = ord
  py2 = _PY2  ##PY25
##!PY25  py2 = str is bytes
  def DecodeVarint(buffer, pos):
    accumulated = 0
    bit_offset = 0
    while 1:
      byte = local_ord(buffer[pos]) if py2 else buffer[pos]
      pos += 1
      accumulated |= (byte & 0x7f) << bit_offset
      if byte & 0x80:
        # Continuation bit is set: another 7-bit group follows.
        bit_offset += 7
        if bit_offset >= 64:
          raise _DecodeError('Too many bytes when decoding varint.')
        continue
      if accumulated > 0x7fffffffffffffff:
        # Top (64th) bit set: the value is negative.
        accumulated = (accumulated - (1 << 64)) | ~mask
      else:
        accumulated &= mask
      return (result_type(accumulated), pos)
  return DecodeVarint
162 | |
# We force 32-bit values to int and 64-bit values to long to make
# alternate implementations where the distinction is more significant
# (e.g. the C++ implementation) simpler.

# 64-bit decoders. Each is called as decoder(buffer, pos) and returns a
# (value, new_pos) pair.
_DecodeVarint = _VarintDecoder((1 << 64) - 1, long)
_DecodeSignedVarint = _SignedVarintDecoder((1 << 64) - 1, long)

# Use these versions for values which must be limited to 32 bits.
_DecodeVarint32 = _VarintDecoder((1 << 32) - 1, int)
_DecodeSignedVarint32 = _SignedVarintDecoder((1 << 32) - 1, int)
173 | |
174 | |
def ReadTag(buffer, pos):
  """Read a tag from the buffer, and return a (tag_bytes, new_pos) tuple.

  The tag is returned as the raw bytes it occupies in the buffer, not as a
  decoded integer. Those bytes can be used directly as a key to look up the
  proper decoder, which trades work that would be done in pure Python
  (decoding a varint) for work that is done in C (searching for a byte string
  in a hash table). In a low-level language it would be much cheaper to
  decode the varint and use that, but not in Python.
  """

  py2 = _PY2  ##PY25
##!PY25  py2 = str is bytes
  tag_start = pos
  while 1:
    byte = ord(buffer[pos]) if py2 else buffer[pos]
    pos += 1
    if not (byte & 0x80):
      # No continuation bit: this was the last byte of the tag.
      return (buffer[tag_start:pos], pos)
193 | |
194 | |
195 # -------------------------------------------------------------------- | |
196 | |
197 | |
198 def _SimpleDecoder(wire_type, decode_value): | |
199 """Return a constructor for a decoder for fields of a particular type. | |
200 | |
201 Args: | |
202 wire_type: The field's wire type. | |
203 decode_value: A function which decodes an individual value, e.g. | |
204 _DecodeVarint() | |
205 """ | |
206 | |
207 def SpecificDecoder(field_number, is_repeated, is_packed, key, new_default): | |
208 if is_packed: | |
209 local_DecodeVarint = _DecodeVarint | |
210 def DecodePackedField(buffer, pos, end, message, field_dict): | |
211 value = field_dict.get(key) | |
212 if value is None: | |
213 value = field_dict.setdefault(key, new_default(message)) | |
214 (endpoint, pos) = local_DecodeVarint(buffer, pos) | |
215 endpoint += pos | |
216 if endpoint > end: | |
217 raise _DecodeError('Truncated message.') | |
218 while pos < endpoint: | |
219 (element, pos) = decode_value(buffer, pos) | |
220 value.append(element) | |
221 if pos > endpoint: | |
222 del value[-1] # Discard corrupt value. | |
223 raise _DecodeError('Packed element was truncated.') | |
224 return pos | |
225 return DecodePackedField | |
226 elif is_repeated: | |
227 tag_bytes = encoder.TagBytes(field_number, wire_type) | |
228 tag_len = len(tag_bytes) | |
229 def DecodeRepeatedField(buffer, pos, end, message, field_dict): | |
230 value = field_dict.get(key) | |
231 if value is None: | |
232 value = field_dict.setdefault(key, new_default(message)) | |
233 while 1: | |
234 (element, new_pos) = decode_value(buffer, pos) | |
235 value.append(element) | |
236 # Predict that the next tag is another copy of the same repeated | |
237 # field. | |
238 pos = new_pos + tag_len | |
239 if buffer[new_pos:pos] != tag_bytes or new_pos >= end: | |
240 # Prediction failed. Return. | |
241 if new_pos > end: | |
242 raise _DecodeError('Truncated message.') | |
243 return new_pos | |
244 return DecodeRepeatedField | |
245 else: | |
246 def DecodeField(buffer, pos, end, message, field_dict): | |
247 (field_dict[key], pos) = decode_value(buffer, pos) | |
248 if pos > end: | |
249 del field_dict[key] # Discard corrupt value. | |
250 raise _DecodeError('Truncated message.') | |
251 return pos | |
252 return DecodeField | |
253 | |
254 return SpecificDecoder | |
255 | |
256 | |
def _ModifiedDecoder(wire_type, decode_value, modify_value):
  """Like _SimpleDecoder but passes every value through modify_value.

  Each value produced by decode_value is replaced by modify_value(value)
  before it is stored. Usually modify_value is ZigZagDecode.
  """

  # Reusing _SimpleDecoder is slightly slower than copying a bunch of code, but
  # not enough to make a significant difference.

  def DecodeThenModify(buffer, pos):
    (raw_value, after_value) = decode_value(buffer, pos)
    return (modify_value(raw_value), after_value)

  return _SimpleDecoder(wire_type, DecodeThenModify)
269 | |
270 | |
def _StructPackDecoder(wire_type, format):
  """Return a constructor for a decoder for a fixed-width field.

  Args:
    wire_type: The field's wire type.
    format: The format string to pass to struct.unpack().
  """

  width = struct.calcsize(format)
  unpack = struct.unpack

  # Reusing _SimpleDecoder is slightly slower than copying a bunch of code, but
  # not enough to make a significant difference.

  def DecodeFixedWidth(buffer, pos):
    # struct.error from a short slice is deliberately not caught here: someone
    # up-stack is expected to convert it to _DecodeError, so that no
    # exception-handling block has to be set up for every parsed value.
    after_value = pos + width
    unpacked = unpack(format, buffer[pos:after_value])
    return (unpacked[0], after_value)

  return _SimpleDecoder(wire_type, DecodeFixedWidth)
294 | |
295 | |
def _FloatDecoder():
  """Returns a decoder for a float field.

  This code works around a bug in struct.unpack for non-finite 32-bit
  floating-point values.

  The result is the decoder constructor built by _SimpleDecoder() for
  WIRETYPE_FIXED32 around the InnerDecode function defined below.
  """

  local_unpack = struct.unpack
  # Converts a str literal to the type used for comparisons below: identity on
  # Python 2, latin1-encoded bytes otherwise.
  b = (lambda x:x) if _PY2 else lambda x:x.encode('latin1')  ##PY25

  def InnerDecode(buffer, pos):
    # We expect a 32-bit value in little-endian byte order. Bit 1 is the sign
    # bit, bits 2-9 represent the exponent, and bits 10-32 are the significand.
    new_pos = pos + 4
    float_bytes = buffer[pos:new_pos]

    # If this value has all its exponent bits set, then it's non-finite.
    # In Python 2.4, struct.unpack will convert it to a finite 64-bit value.
    # To avoid that, we parse it specially.
    # (Last byte 0x7F/0xFF carries the top seven exponent bits; the high bit of
    # the byte before it is the remaining exponent bit.)
    if ((float_bytes[3:4] in b('\x7F\xFF'))  ##PY25
##!PY25    if ((float_bytes[3:4] in b'\x7F\xFF')
        and (float_bytes[2:3] >= b('\x80'))):  ##PY25
##!PY25        and (float_bytes[2:3] >= b'\x80')):
      # If at least one significand bit is set...
      if float_bytes[0:3] != b('\x00\x00\x80'):  ##PY25
##!PY25      if float_bytes[0:3] != b'\x00\x00\x80':
        return (_NAN, new_pos)
      # If sign bit is set...
      if float_bytes[3:4] == b('\xFF'):  ##PY25
##!PY25      if float_bytes[3:4] == b'\xFF':
        return (_NEG_INF, new_pos)
      return (_POS_INF, new_pos)

    # Note that we expect someone up-stack to catch struct.error and convert
    # it to _DecodeError -- this way we don't have to set up exception-
    # handling blocks every time we parse one value.
    result = local_unpack('<f', float_bytes)[0]
    return (result, new_pos)
  return _SimpleDecoder(wire_format.WIRETYPE_FIXED32, InnerDecode)
335 | |
336 | |
def _DoubleDecoder():
  """Returns a decoder for a double field.

  This code works around a bug in struct.unpack for not-a-number.

  The result is the decoder constructor built by _SimpleDecoder() for
  WIRETYPE_FIXED64 around the InnerDecode function defined below.
  """

  local_unpack = struct.unpack
  # Converts a str literal to the type used for comparisons below: identity on
  # Python 2, latin1-encoded bytes otherwise.
  b = (lambda x:x) if _PY2 else lambda x:x.encode('latin1')  ##PY25

  def InnerDecode(buffer, pos):
    # We expect a 64-bit value in little-endian byte order. Bit 1 is the sign
    # bit, bits 2-12 represent the exponent, and bits 13-64 are the significand.
    new_pos = pos + 8
    double_bytes = buffer[pos:new_pos]

    # If this value has all its exponent bits set and at least one significand
    # bit set, it's not a number. In Python 2.4, struct.unpack will treat it
    # as inf or -inf. To avoid that, we treat it specially.
##!PY25    if ((double_bytes[7:8] in b'\x7F\xFF')
##!PY25        and (double_bytes[6:7] >= b'\xF0')
##!PY25        and (double_bytes[0:7] != b'\x00\x00\x00\x00\x00\x00\xF0')):
    if ((double_bytes[7:8] in b('\x7F\xFF'))  ##PY25
        and (double_bytes[6:7] >= b('\xF0'))  ##PY25
        and (double_bytes[0:7] != b('\x00\x00\x00\x00\x00\x00\xF0'))):  ##PY25
      return (_NAN, new_pos)

    # Note that we expect someone up-stack to catch struct.error and convert
    # it to _DecodeError -- this way we don't have to set up exception-
    # handling blocks every time we parse one value.
    result = local_unpack('<d', double_bytes)[0]
    return (result, new_pos)
  return _SimpleDecoder(wire_format.WIRETYPE_FIXED64, InnerDecode)
369 | |
370 | |
def EnumDecoder(field_number, is_repeated, is_packed, key, new_default):
  """Returns a decoder for an enum field.

  A decoded number that appears in key.enum_type.values_by_number is stored in
  the field. Any other number is not stored in the field; instead its
  original encoded bytes are appended to message._unknown_fields as a
  (tag_bytes, value_bytes) pair, where tag_bytes is this field's varint tag.

  Args:
    field_number: The field number of the field we want to decode.
    is_repeated: Is the field a repeated field? (bool)
    is_packed: Is the field a packed field? (bool)
    key: The key to use when looking up the field within field_dict; its
      enum_type attribute supplies the set of known values.
    new_default: A function which takes a message object and returns a new
      container for a repeated field.

  Raises (from the returned decoder):
    _DecodeError: if the value or the packed payload is truncated.
  """
  enum_type = key.enum_type
  # Unknown values are always recorded under the same varint tag, so build it
  # once here instead of on every unknown value that is decoded.
  tag_bytes = encoder.TagBytes(field_number, wire_format.WIRETYPE_VARINT)
  if is_packed:
    local_DecodeVarint = _DecodeVarint
    def DecodePackedField(buffer, pos, end, message, field_dict):
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      # The packed payload is prefixed with its length in bytes.
      (endpoint, pos) = local_DecodeVarint(buffer, pos)
      endpoint += pos
      if endpoint > end:
        raise _DecodeError('Truncated message.')
      while pos < endpoint:
        value_start_pos = pos
        (element, pos) = _DecodeSignedVarint32(buffer, pos)
        if element in enum_type.values_by_number:
          value.append(element)
        else:
          if not message._unknown_fields:
            message._unknown_fields = []
          message._unknown_fields.append(
              (tag_bytes, buffer[value_start_pos:pos]))
      if pos > endpoint:
        # The last element ran past the payload; undo whichever list it was
        # added to.
        if element in enum_type.values_by_number:
          del value[-1]  # Discard corrupt value.
        else:
          del message._unknown_fields[-1]
        raise _DecodeError('Packed element was truncated.')
      return pos
    return DecodePackedField
  elif is_repeated:
    tag_len = len(tag_bytes)
    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      while 1:
        (element, new_pos) = _DecodeSignedVarint32(buffer, pos)
        if element in enum_type.values_by_number:
          value.append(element)
        else:
          if not message._unknown_fields:
            message._unknown_fields = []
          message._unknown_fields.append(
              (tag_bytes, buffer[pos:new_pos]))
        # Predict that the next tag is another copy of the same repeated
        # field.
        pos = new_pos + tag_len
        if buffer[new_pos:pos] != tag_bytes or new_pos >= end:
          # Prediction failed. Return.
          if new_pos > end:
            raise _DecodeError('Truncated message.')
          return new_pos
    return DecodeRepeatedField
  else:
    def DecodeField(buffer, pos, end, message, field_dict):
      value_start_pos = pos
      (enum_value, pos) = _DecodeSignedVarint32(buffer, pos)
      if pos > end:
        raise _DecodeError('Truncated message.')
      if enum_value in enum_type.values_by_number:
        field_dict[key] = enum_value
      else:
        if not message._unknown_fields:
          message._unknown_fields = []
        message._unknown_fields.append(
            (tag_bytes, buffer[value_start_pos:pos]))
      return pos
    return DecodeField
445 | |
446 | |
447 # -------------------------------------------------------------------- | |
448 | |
449 | |
# Decoder constructors for the varint-encoded integer types.
Int32Decoder = _SimpleDecoder(
    wire_format.WIRETYPE_VARINT, _DecodeSignedVarint32)

Int64Decoder = _SimpleDecoder(
    wire_format.WIRETYPE_VARINT, _DecodeSignedVarint)

UInt32Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT, _DecodeVarint32)
UInt64Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT, _DecodeVarint)

# The sint types are read as unsigned varints and then passed through
# ZigZagDecode.
SInt32Decoder = _ModifiedDecoder(
    wire_format.WIRETYPE_VARINT, _DecodeVarint32, wire_format.ZigZagDecode)
SInt64Decoder = _ModifiedDecoder(
    wire_format.WIRETYPE_VARINT, _DecodeVarint, wire_format.ZigZagDecode)

# Note that Python conveniently guarantees that when using the '<' prefix on
# formats, they will also have the same size across all platforms (as opposed
# to without the prefix, where their sizes depend on the C compiler's basic
# type sizes).
Fixed32Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<I')
Fixed64Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<Q')
SFixed32Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<i')
SFixed64Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<q')
FloatDecoder = _FloatDecoder()
DoubleDecoder = _DoubleDecoder()

# Bools are read as varints and converted with bool().
BoolDecoder = _ModifiedDecoder(
    wire_format.WIRETYPE_VARINT, _DecodeVarint, bool)
477 | |
478 | |
def StringDecoder(field_number, is_repeated, is_packed, key, new_default):
  """Returns a decoder for a string field.

  Each value is a varint length followed by that many bytes, which are
  converted to unicode as UTF-8. Packed encoding is not supported.

  Raises (from the returned decoder):
    _DecodeError: if the declared length runs past `end`.
    UnicodeDecodeError: if the bytes are not valid UTF-8; the error's reason
      is extended with the field's full name before it is re-raised.
  """

  decode_length = _DecodeVarint
  to_unicode = unicode

  def _ConvertToUnicode(byte_str):
    try:
      return to_unicode(byte_str, 'utf-8')
    except UnicodeDecodeError:
      # Fetch the active exception without version-specific "except" syntax.
      e = sys.exc_info()[1]
      # add more information to the error message and re-raise it.
      e.reason = '%s in field: %s' % (e, key.full_name)
      raise

  assert not is_packed
  if not is_repeated:
    def DecodeField(buffer, pos, end, message, field_dict):
      (size, pos) = decode_length(buffer, pos)
      after_value = pos + size
      if after_value > end:
        raise _DecodeError('Truncated string.')
      field_dict[key] = _ConvertToUnicode(buffer[pos:after_value])
      return after_value
    return DecodeField

  tag_bytes = encoder.TagBytes(field_number,
                               wire_format.WIRETYPE_LENGTH_DELIMITED)
  tag_len = len(tag_bytes)
  def DecodeRepeatedField(buffer, pos, end, message, field_dict):
    strings = field_dict.get(key)
    if strings is None:
      strings = field_dict.setdefault(key, new_default(message))
    while 1:
      (size, pos) = decode_length(buffer, pos)
      after_value = pos + size
      if after_value > end:
        raise _DecodeError('Truncated string.')
      strings.append(_ConvertToUnicode(buffer[pos:after_value]))
      # Predict that the next tag is another copy of the same repeated field.
      pos = after_value + tag_len
      if after_value != end and buffer[after_value:pos] == tag_bytes:
        continue
      # Prediction failed. Return.
      return after_value
  return DecodeRepeatedField
523 | |
524 | |
def BytesDecoder(field_number, is_repeated, is_packed, key, new_default):
  """Returns a decoder for a bytes field.

  Each value is a varint length followed by that many bytes, which are stored
  as the corresponding slice of the buffer. Packed encoding is not supported.

  Raises (from the returned decoder):
    _DecodeError: if the declared length runs past `end`.
  """

  decode_length = _DecodeVarint

  assert not is_packed
  if not is_repeated:
    def DecodeField(buffer, pos, end, message, field_dict):
      (size, pos) = decode_length(buffer, pos)
      after_value = pos + size
      if after_value > end:
        raise _DecodeError('Truncated string.')
      field_dict[key] = buffer[pos:after_value]
      return after_value
    return DecodeField

  tag_bytes = encoder.TagBytes(field_number,
                               wire_format.WIRETYPE_LENGTH_DELIMITED)
  tag_len = len(tag_bytes)
  def DecodeRepeatedField(buffer, pos, end, message, field_dict):
    chunks = field_dict.get(key)
    if chunks is None:
      chunks = field_dict.setdefault(key, new_default(message))
    while 1:
      (size, pos) = decode_length(buffer, pos)
      after_value = pos + size
      if after_value > end:
        raise _DecodeError('Truncated string.')
      chunks.append(buffer[pos:after_value])
      # Predict that the next tag is another copy of the same repeated field.
      pos = after_value + tag_len
      if after_value != end and buffer[after_value:pos] == tag_bytes:
        continue
      # Prediction failed. Return.
      return after_value
  return DecodeRepeatedField
560 | |
561 | |
def GroupDecoder(field_number, is_repeated, is_packed, key, new_default):
  """Returns a decoder for a group field.

  The sub-message is parsed with _InternalParse() starting at the current
  position, and must be followed immediately by this field's end-group tag.
  Packed encoding is not supported.

  Args:
    field_number: The field number of the field we want to decode.
    is_repeated: Is the field a repeated field? (bool)
    is_packed: Must be false.
    key: The key to use when looking up the field within field_dict.
    new_default: A function which takes a message object and returns a new
      default value (sub-message or repeated container) for this field.

  Raises (from the returned decoder):
    _DecodeError: if the end-group tag is missing or lies past `end`.
  """

  end_tag_bytes = encoder.TagBytes(field_number,
                                   wire_format.WIRETYPE_END_GROUP)
  end_tag_len = len(end_tag_bytes)

  assert not is_packed
  if is_repeated:
    tag_bytes = encoder.TagBytes(field_number,
                                 wire_format.WIRETYPE_START_GROUP)
    tag_len = len(tag_bytes)
    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
      while 1:
        # The container is looked up (and created if absent) at the top of
        # every iteration, so no separate lookup is needed before the loop.
        value = field_dict.get(key)
        if value is None:
          value = field_dict.setdefault(key, new_default(message))
        # Read sub-message.
        pos = value.add()._InternalParse(buffer, pos, end)
        # Read end tag.
        new_pos = pos+end_tag_len
        if buffer[pos:new_pos] != end_tag_bytes or new_pos > end:
          raise _DecodeError('Missing group end tag.')
        # Predict that the next tag is another copy of the same repeated field.
        pos = new_pos + tag_len
        if buffer[new_pos:pos] != tag_bytes or new_pos == end:
          # Prediction failed. Return.
          return new_pos
    return DecodeRepeatedField
  else:
    def DecodeField(buffer, pos, end, message, field_dict):
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      # Read sub-message.
      pos = value._InternalParse(buffer, pos, end)
      # Read end tag.
      new_pos = pos+end_tag_len
      if buffer[pos:new_pos] != end_tag_bytes or new_pos > end:
        raise _DecodeError('Missing group end tag.')
      return new_pos
    return DecodeField
607 | |
608 | |
def MessageDecoder(field_number, is_repeated, is_packed, key, new_default):
  """Returns a decoder for a message field.

  Each value is a varint length followed by that many bytes, which are parsed
  as a sub-message with _InternalParse(). Packed encoding is not supported.

  Args:
    field_number: The field number of the field we want to decode.
    is_repeated: Is the field a repeated field? (bool)
    is_packed: Must be false.
    key: The key to use when looking up the field within field_dict.
    new_default: A function which takes a message object and returns a new
      default value (sub-message or repeated container) for this field.

  Raises (from the returned decoder):
    _DecodeError: if the declared length runs past `end`, or if the
      sub-message parse stops before consuming the declared length.
  """

  local_DecodeVarint = _DecodeVarint

  assert not is_packed
  if is_repeated:
    tag_bytes = encoder.TagBytes(field_number,
                                 wire_format.WIRETYPE_LENGTH_DELIMITED)
    tag_len = len(tag_bytes)
    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
      while 1:
        # The container is looked up (and created if absent) at the top of
        # every iteration, so no separate lookup is needed before the loop.
        value = field_dict.get(key)
        if value is None:
          value = field_dict.setdefault(key, new_default(message))
        # Read length.
        (size, pos) = local_DecodeVarint(buffer, pos)
        new_pos = pos + size
        if new_pos > end:
          raise _DecodeError('Truncated message.')
        # Read sub-message.
        if value.add()._InternalParse(buffer, pos, new_pos) != new_pos:
          # The only reason _InternalParse would return early is if it
          # encountered an end-group tag.
          raise _DecodeError('Unexpected end-group tag.')
        # Predict that the next tag is another copy of the same repeated field.
        pos = new_pos + tag_len
        if buffer[new_pos:pos] != tag_bytes or new_pos == end:
          # Prediction failed. Return.
          return new_pos
    return DecodeRepeatedField
  else:
    def DecodeField(buffer, pos, end, message, field_dict):
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      # Read length.
      (size, pos) = local_DecodeVarint(buffer, pos)
      new_pos = pos + size
      if new_pos > end:
        raise _DecodeError('Truncated message.')
      # Read sub-message.
      if value._InternalParse(buffer, pos, new_pos) != new_pos:
        # The only reason _InternalParse would return early is if it encountered
        # an end-group tag.
        raise _DecodeError('Unexpected end-group tag.')
      return new_pos
    return DecodeField
660 | |
661 | |
662 # -------------------------------------------------------------------- | |
663 | |
# Encoded start-group tag for field number 1, i.e. the tag that opens a
# MessageSet item. It is also the tag under which unrecognized items are
# recorded in a message's _unknown_fields.
MESSAGE_SET_ITEM_TAG = encoder.TagBytes(1, wire_format.WIRETYPE_START_GROUP)
665 | |
def MessageSetItemDecoder(extensions_by_number):
  """Returns a decoder for a MessageSet item.

  The parameter is the _extensions_by_number map for the message class.

  The message set message looks like this:
    message MessageSet {
      repeated group Item = 1 {
        required int32 type_id = 2;
        required string message = 3;
      }
    }

  The returned decoder parses the payload into the extension registered for
  type_id when there is one; otherwise the raw bytes of the whole item are
  appended to message._unknown_fields under MESSAGE_SET_ITEM_TAG.

  Raises (from the returned decoder):
    _DecodeError: if the item is truncated, lacks its end-group tag, lacks
      type_id or message, or the payload contains an unexpected end-group tag.
  """

  type_id_tag_bytes = encoder.TagBytes(2, wire_format.WIRETYPE_VARINT)
  message_tag_bytes = encoder.TagBytes(3, wire_format.WIRETYPE_LENGTH_DELIMITED)
  item_end_tag_bytes = encoder.TagBytes(1, wire_format.WIRETYPE_END_GROUP)

  local_ReadTag = ReadTag
  local_DecodeVarint = _DecodeVarint
  local_SkipField = SkipField

  def DecodeItem(buffer, pos, end, message, field_dict):
    message_set_item_start = pos
    # -1 marks "not seen yet" for each of the required parts.
    type_id = -1
    message_start = -1
    message_end = -1

    # Technically, type_id and message can appear in any order, so we need
    # a little loop here.
    while 1:
      (tag_bytes, pos) = local_ReadTag(buffer, pos)
      if tag_bytes == type_id_tag_bytes:
        (type_id, pos) = local_DecodeVarint(buffer, pos)
      elif tag_bytes == message_tag_bytes:
        # Only remember where the payload lies; it is parsed after the loop,
        # once type_id is known.
        (size, message_start) = local_DecodeVarint(buffer, pos)
        pos = message_end = message_start + size
      elif tag_bytes == item_end_tag_bytes:
        break
      else:
        pos = local_SkipField(buffer, pos, end, tag_bytes)
        if pos == -1:
          raise _DecodeError('Missing group end tag.')

    if pos > end:
      raise _DecodeError('Truncated message.')

    if type_id == -1:
      raise _DecodeError('MessageSet item missing type_id.')
    if message_start == -1:
      raise _DecodeError('MessageSet item missing message.')

    extension = extensions_by_number.get(type_id)
    if extension is not None:
      value = field_dict.get(extension)
      if value is None:
        value = field_dict.setdefault(
            extension, extension.message_type._concrete_class())
      if value._InternalParse(buffer, message_start, message_end) != message_end:
        # The only reason _InternalParse would return early is if it encountered
        # an end-group tag.
        raise _DecodeError('Unexpected end-group tag.')
    else:
      if not message._unknown_fields:
        message._unknown_fields = []
      message._unknown_fields.append((MESSAGE_SET_ITEM_TAG,
                                      buffer[message_set_item_start:pos]))

    return pos

  return DecodeItem
737 | |
738 # -------------------------------------------------------------------- | |
739 # Optimization is not as heavy here because calls to SkipField() are rare, | |
740 # except for handling end-group tags. | |
741 | |
742 def _SkipVarint(buffer, pos, end): | |
743 """Skip a varint value. Returns the new position.""" | |
744 # Previously ord(buffer[pos]) raised IndexError when pos is out of range. | |
745 # With this code, ord(b'') raises TypeError. Both are handled in | |
746 # python_message.py to generate a 'Truncated message' error. | |
747 while ord(buffer[pos:pos+1]) & 0x80: | |
748 pos += 1 | |
749 pos += 1 | |
750 if pos > end: | |
751 raise _DecodeError('Truncated message.') | |
752 return pos | |
753 | |
754 def _SkipFixed64(buffer, pos, end): | |
755 """Skip a fixed64 value. Returns the new position.""" | |
756 | |
757 pos += 8 | |
758 if pos > end: | |
759 raise _DecodeError('Truncated message.') | |
760 return pos | |
761 | |
def _SkipLengthDelimited(buffer, pos, end):
  """Skip a length-delimited value.  Returns the new position.

  The value is a varint length followed by that many bytes. Raises
  _DecodeError if the declared length extends past `end`.
  """

  (size, payload_start) = _DecodeVarint(buffer, pos)
  after_value = payload_start + size
  if after_value <= end:
    return after_value
  raise _DecodeError('Truncated message.')
770 | |
def _SkipGroup(buffer, pos, end):
  """Skip sub-group.  Returns the new position.

  Fields are read and skipped one at a time until SkipField() returns -1 (its
  signal for an end-group tag); the position just past that tag is returned.
  """

  while 1:
    (tag_bytes, pos) = ReadTag(buffer, pos)
    after_field = SkipField(buffer, pos, end, tag_bytes)
    if after_field != -1:
      pos = after_field
      continue
    return pos
780 | |
781 def _EndGroup(buffer, pos, end): | |
782 """Skipping an END_GROUP tag returns -1 to tell the parent loop to break.""" | |
783 | |
784 return -1 | |
785 | |
786 def _SkipFixed32(buffer, pos, end): | |
787 """Skip a fixed32 value. Returns the new position.""" | |
788 | |
789 pos += 4 | |
790 if pos > end: | |
791 raise _DecodeError('Truncated message.') | |
792 return pos | |
793 | |
def _RaiseInvalidWireType(buffer, pos, end):
  """Skip function for unknown wire types.

  Never returns: always raises _DecodeError. The arguments are not used.
  """

  raise _DecodeError('Tag had invalid wire type.')
798 | |
def _FieldSkipper():
  """Constructs the SkipField function."""

  # Indexed by wire type; the position in this list is the wire type value.
  skippers_by_wire_type = [
      _SkipVarint,
      _SkipFixed64,
      _SkipLengthDelimited,
      _SkipGroup,
      _EndGroup,
      _SkipFixed32,
      _RaiseInvalidWireType,
      _RaiseInvalidWireType,
      ]

  type_mask = wire_format.TAG_TYPE_MASK

  def SkipField(buffer, pos, end, tag_bytes):
    """Skips a field with the specified tag.

    |pos| should point to the byte immediately after the tag.

    Returns:
        The new position (after the tag value), or -1 if the tag is an end-group
        tag (in which case the calling loop should break).
    """

    # The wire type is always in the first byte since varints are little-endian.
    first_tag_byte = ord(tag_bytes[0:1])
    skipper = skippers_by_wire_type[first_tag_byte & type_mask]
    return skipper(buffer, pos, end)

  return SkipField
830 | |
# Module-level SkipField(buffer, pos, end, tag_bytes), built once at import.
# It returns the position after the skipped field, or -1 for an end-group tag.
SkipField = _FieldSkipper()
OLD | NEW |