OLD | NEW |
| (Empty) |
1 # Copyright 2014 The Chromium Authors. All rights reserved. | |
2 # Use of this source code is governed by a BSD-style license that can be | |
3 # found in the LICENSE file. | |
4 | |
5 """Utility classes for serialization""" | |
6 | |
7 import struct | |
8 | |
9 | |
10 # Format of a header for a struct, array or union. | |
11 HEADER_STRUCT = struct.Struct("<II") | |
12 | |
13 # Format for a pointer. | |
14 POINTER_STRUCT = struct.Struct("<Q") | |
15 | |
16 | |
17 def Flatten(value): | |
18 """Flattens nested lists/tuples into an one-level list. If value is not a | |
19 list/tuple, it is converted to an one-item list. For example, | |
20 (1, 2, [3, 4, ('56', '7')]) is converted to [1, 2, 3, 4, '56', '7']; | |
21 1 is converted to [1]. | |
22 """ | |
23 if isinstance(value, (list, tuple)): | |
24 result = [] | |
25 for item in value: | |
26 result.extend(Flatten(item)) | |
27 return result | |
28 return [value] | |
29 | |
30 | |
31 class SerializationException(Exception): | |
32 """Error when strying to serialize a struct.""" | |
33 pass | |
34 | |
35 | |
36 class DeserializationException(Exception): | |
37 """Error when strying to deserialize a struct.""" | |
38 pass | |
39 | |
40 | |
41 class DeserializationContext(object): | |
42 | |
43 def ClaimHandle(self, handle): | |
44 raise NotImplementedError() | |
45 | |
46 def ClaimMemory(self, start, size): | |
47 raise NotImplementedError() | |
48 | |
49 def GetSubContext(self, offset): | |
50 raise NotImplementedError() | |
51 | |
52 def IsInitialContext(self): | |
53 raise NotImplementedError() | |
54 | |
55 | |
56 class RootDeserializationContext(DeserializationContext): | |
57 def __init__(self, data, handles): | |
58 if isinstance(data, buffer): | |
59 self.data = data | |
60 else: | |
61 self.data = buffer(data) | |
62 self._handles = handles | |
63 self._next_handle = 0; | |
64 self._next_memory = 0; | |
65 | |
66 def ClaimHandle(self, handle): | |
67 if handle < self._next_handle: | |
68 raise DeserializationException('Accessing handles out of order.') | |
69 self._next_handle = handle + 1 | |
70 return self._handles[handle] | |
71 | |
72 def ClaimMemory(self, start, size): | |
73 if start < self._next_memory: | |
74 raise DeserializationException('Accessing buffer out of order.') | |
75 self._next_memory = start + size | |
76 | |
77 def GetSubContext(self, offset): | |
78 return _ChildDeserializationContext(self, offset) | |
79 | |
80 def IsInitialContext(self): | |
81 return True | |
82 | |
83 | |
84 class _ChildDeserializationContext(DeserializationContext): | |
85 def __init__(self, parent, offset): | |
86 self._parent = parent | |
87 self._offset = offset | |
88 self.data = buffer(parent.data, offset) | |
89 | |
90 def ClaimHandle(self, handle): | |
91 return self._parent.ClaimHandle(handle) | |
92 | |
93 def ClaimMemory(self, start, size): | |
94 return self._parent.ClaimMemory(self._offset + start, size) | |
95 | |
96 def GetSubContext(self, offset): | |
97 return self._parent.GetSubContext(self._offset + offset) | |
98 | |
99 def IsInitialContext(self): | |
100 return False | |
101 | |
102 | |
103 class Serialization(object): | |
104 """ | |
105 Helper class to serialize/deserialize a struct. | |
106 """ | |
107 def __init__(self, groups): | |
108 self.version = _GetVersion(groups) | |
109 self._groups = groups | |
110 main_struct = _GetStruct(groups) | |
111 self.size = HEADER_STRUCT.size + main_struct.size | |
112 self._struct_per_version = { | |
113 self.version: main_struct, | |
114 } | |
115 self._groups_per_version = { | |
116 self.version: groups, | |
117 } | |
118 | |
119 def _GetMainStruct(self): | |
120 return self._GetStruct(self.version) | |
121 | |
122 def _GetGroups(self, version): | |
123 # If asking for a version greater than the last known. | |
124 version = min(version, self.version) | |
125 if version not in self._groups_per_version: | |
126 self._groups_per_version[version] = _FilterGroups(self._groups, version) | |
127 return self._groups_per_version[version] | |
128 | |
129 def _GetStruct(self, version): | |
130 # If asking for a version greater than the last known. | |
131 version = min(version, self.version) | |
132 if version not in self._struct_per_version: | |
133 self._struct_per_version[version] = _GetStruct(self._GetGroups(version)) | |
134 return self._struct_per_version[version] | |
135 | |
136 def Serialize(self, obj, handle_offset): | |
137 """ | |
138 Serialize the given obj. handle_offset is the the first value to use when | |
139 encoding handles. | |
140 """ | |
141 handles = [] | |
142 data = bytearray(self.size) | |
143 HEADER_STRUCT.pack_into(data, 0, self.size, self.version) | |
144 position = HEADER_STRUCT.size | |
145 to_pack = [] | |
146 for group in self._groups: | |
147 position = position + NeededPaddingForAlignment(position, | |
148 group.GetAlignment()) | |
149 (entry, new_handles) = group.Serialize( | |
150 obj, | |
151 len(data) - position, | |
152 data, | |
153 handle_offset + len(handles)) | |
154 to_pack.extend(Flatten(entry)) | |
155 handles.extend(new_handles) | |
156 position = position + group.GetByteSize() | |
157 self._GetMainStruct().pack_into(data, HEADER_STRUCT.size, *to_pack) | |
158 return (data, handles) | |
159 | |
160 def Deserialize(self, fields, context): | |
161 if len(context.data) < HEADER_STRUCT.size: | |
162 raise DeserializationException( | |
163 'Available data too short to contain header.') | |
164 (size, version) = HEADER_STRUCT.unpack_from(context.data) | |
165 if len(context.data) < size or size < HEADER_STRUCT.size: | |
166 raise DeserializationException('Header size is incorrect.') | |
167 if context.IsInitialContext(): | |
168 context.ClaimMemory(0, size) | |
169 version_struct = self._GetStruct(version) | |
170 entities = version_struct.unpack_from(context.data, HEADER_STRUCT.size) | |
171 filtered_groups = self._GetGroups(version) | |
172 if ((version <= self.version and | |
173 size != version_struct.size + HEADER_STRUCT.size) or | |
174 size < version_struct.size + HEADER_STRUCT.size): | |
175 raise DeserializationException('Struct size in incorrect.') | |
176 position = HEADER_STRUCT.size | |
177 enties_index = 0 | |
178 for group in filtered_groups: | |
179 position = position + NeededPaddingForAlignment(position, | |
180 group.GetAlignment()) | |
181 enties_count = len(group.GetTypeCode()) | |
182 if enties_count == 1: | |
183 value = entities[enties_index] | |
184 else: | |
185 value = tuple(entities[enties_index:enties_index+enties_count]) | |
186 fields.update(group.Deserialize(value, context.GetSubContext(position))) | |
187 position += group.GetByteSize() | |
188 enties_index += enties_count | |
189 | |
190 | |
191 def NeededPaddingForAlignment(value, alignment=8): | |
192 """Returns the padding necessary to align value with the given alignment.""" | |
193 if value % alignment: | |
194 return alignment - (value % alignment) | |
195 return 0 | |
196 | |
197 | |
198 def _GetVersion(groups): | |
199 if not len(groups): | |
200 return 0 | |
201 return max([x.GetMaxVersion() for x in groups]) | |
202 | |
203 | |
204 def _FilterGroups(groups, version): | |
205 return [group.Filter(version) for | |
206 group in groups if group.GetMinVersion() <= version] | |
207 | |
208 | |
209 def _GetStruct(groups): | |
210 index = 0 | |
211 codes = [ '<' ] | |
212 for group in groups: | |
213 code = group.GetTypeCode() | |
214 needed_padding = NeededPaddingForAlignment(index, group.GetAlignment()) | |
215 if needed_padding: | |
216 codes.append('x' * needed_padding) | |
217 index = index + needed_padding | |
218 codes.append(code) | |
219 index = index + group.GetByteSize() | |
220 alignment_needed = NeededPaddingForAlignment(index) | |
221 if alignment_needed: | |
222 codes.append('x' * alignment_needed) | |
223 return struct.Struct(''.join(codes)) | |
224 | |
225 | |
226 class UnionSerializer(object): | |
227 """ | |
228 Helper class to serialize/deserialize a union. | |
229 """ | |
230 def __init__(self, fields): | |
231 self._fields = {field.index: field for field in fields} | |
232 | |
233 def SerializeInline(self, union, handle_offset): | |
234 data = bytearray() | |
235 field = self._fields[union.tag] | |
236 | |
237 # If the union value is a simple type or a nested union, it is returned as | |
238 # entry. | |
239 # Otherwise, the serialized value is appended to data and the value of entry | |
240 # is -1. The caller will need to set entry to the location where the | |
241 # caller will append data. | |
242 (entry, handles) = field.field_type.Serialize( | |
243 union.data, -1, data, handle_offset) | |
244 | |
245 # If the value contained in the union is itself a union, we append its | |
246 # serialized value to data and set entry to -1. The caller will need to set | |
247 # entry to the location where the caller will append data. | |
248 if field.field_type.IsUnion(): | |
249 nested_union = bytearray(16) | |
250 HEADER_STRUCT.pack_into(nested_union, 0, entry[0], entry[1]) | |
251 POINTER_STRUCT.pack_into(nested_union, 8, entry[2]) | |
252 | |
253 data = nested_union + data | |
254 | |
255 # Since we do not know where the caller will append the nested union, | |
256 # we set entry to an invalid value and let the caller figure out the right | |
257 # value. | |
258 entry = -1 | |
259 | |
260 return (16, union.tag, entry, data), handles | |
261 | |
262 def Serialize(self, union, handle_offset): | |
263 (size, tag, entry, extra_data), handles = self.SerializeInline( | |
264 union, handle_offset) | |
265 data = bytearray(16) | |
266 if extra_data: | |
267 entry = 8 | |
268 data.extend(extra_data) | |
269 | |
270 field = self._fields[union.tag] | |
271 | |
272 HEADER_STRUCT.pack_into(data, 0, size, tag) | |
273 typecode = field.GetTypeCode() | |
274 | |
275 # If the value is a nested union, we store a 64 bits pointer to it. | |
276 if field.field_type.IsUnion(): | |
277 typecode = 'Q' | |
278 | |
279 struct.pack_into('<%s' % typecode, data, 8, entry) | |
280 return data, handles | |
281 | |
282 def Deserialize(self, context, union_class): | |
283 if len(context.data) < HEADER_STRUCT.size: | |
284 raise DeserializationException( | |
285 'Available data too short to contain header.') | |
286 (size, tag) = HEADER_STRUCT.unpack_from(context.data) | |
287 | |
288 if size == 0: | |
289 return None | |
290 | |
291 if size != 16: | |
292 raise DeserializationException('Invalid union size %s' % size) | |
293 | |
294 union = union_class.__new__(union_class) | |
295 if tag not in self._fields: | |
296 union.SetInternals(None, None) | |
297 return union | |
298 | |
299 field = self._fields[tag] | |
300 if field.field_type.IsUnion(): | |
301 ptr = POINTER_STRUCT.unpack_from(context.data, 8)[0] | |
302 value = field.field_type.Deserialize(ptr, context.GetSubContext(ptr+8)) | |
303 else: | |
304 raw_value = struct.unpack_from( | |
305 field.GetTypeCode(), context.data, 8)[0] | |
306 value = field.field_type.Deserialize(raw_value, context.GetSubContext(8)) | |
307 | |
308 union.SetInternals(field, value) | |
309 return union | |
OLD | NEW |