| OLD | NEW |
| (Empty) |
| 1 # Copyright 2014 The Chromium Authors. All rights reserved. | |
| 2 # Use of this source code is governed by a BSD-style license that can be | |
| 3 # found in the LICENSE file. | |
| 4 | |
| 5 """Utility classes for serialization""" | |
| 6 | |
| 7 import struct | |
| 8 | |
| 9 | |
| 10 # Format of a header for a struct, array or union. | |
| 11 HEADER_STRUCT = struct.Struct("<II") | |
| 12 | |
| 13 # Format for a pointer. | |
| 14 POINTER_STRUCT = struct.Struct("<Q") | |
| 15 | |
| 16 | |
| 17 def Flatten(value): | |
| 18 """Flattens nested lists/tuples into an one-level list. If value is not a | |
| 19 list/tuple, it is converted to an one-item list. For example, | |
| 20 (1, 2, [3, 4, ('56', '7')]) is converted to [1, 2, 3, 4, '56', '7']; | |
| 21 1 is converted to [1]. | |
| 22 """ | |
| 23 if isinstance(value, (list, tuple)): | |
| 24 result = [] | |
| 25 for item in value: | |
| 26 result.extend(Flatten(item)) | |
| 27 return result | |
| 28 return [value] | |
| 29 | |
| 30 | |
| 31 class SerializationException(Exception): | |
| 32 """Error when strying to serialize a struct.""" | |
| 33 pass | |
| 34 | |
| 35 | |
| 36 class DeserializationException(Exception): | |
| 37 """Error when strying to deserialize a struct.""" | |
| 38 pass | |
| 39 | |
| 40 | |
| 41 class DeserializationContext(object): | |
| 42 | |
| 43 def ClaimHandle(self, handle): | |
| 44 raise NotImplementedError() | |
| 45 | |
| 46 def ClaimMemory(self, start, size): | |
| 47 raise NotImplementedError() | |
| 48 | |
| 49 def GetSubContext(self, offset): | |
| 50 raise NotImplementedError() | |
| 51 | |
| 52 def IsInitialContext(self): | |
| 53 raise NotImplementedError() | |
| 54 | |
| 55 | |
| 56 class RootDeserializationContext(DeserializationContext): | |
| 57 def __init__(self, data, handles): | |
| 58 if isinstance(data, buffer): | |
| 59 self.data = data | |
| 60 else: | |
| 61 self.data = buffer(data) | |
| 62 self._handles = handles | |
| 63 self._next_handle = 0; | |
| 64 self._next_memory = 0; | |
| 65 | |
| 66 def ClaimHandle(self, handle): | |
| 67 if handle < self._next_handle: | |
| 68 raise DeserializationException('Accessing handles out of order.') | |
| 69 self._next_handle = handle + 1 | |
| 70 return self._handles[handle] | |
| 71 | |
| 72 def ClaimMemory(self, start, size): | |
| 73 if start < self._next_memory: | |
| 74 raise DeserializationException('Accessing buffer out of order.') | |
| 75 self._next_memory = start + size | |
| 76 | |
| 77 def GetSubContext(self, offset): | |
| 78 return _ChildDeserializationContext(self, offset) | |
| 79 | |
| 80 def IsInitialContext(self): | |
| 81 return True | |
| 82 | |
| 83 | |
| 84 class _ChildDeserializationContext(DeserializationContext): | |
| 85 def __init__(self, parent, offset): | |
| 86 self._parent = parent | |
| 87 self._offset = offset | |
| 88 self.data = buffer(parent.data, offset) | |
| 89 | |
| 90 def ClaimHandle(self, handle): | |
| 91 return self._parent.ClaimHandle(handle) | |
| 92 | |
| 93 def ClaimMemory(self, start, size): | |
| 94 return self._parent.ClaimMemory(self._offset + start, size) | |
| 95 | |
| 96 def GetSubContext(self, offset): | |
| 97 return self._parent.GetSubContext(self._offset + offset) | |
| 98 | |
| 99 def IsInitialContext(self): | |
| 100 return False | |
| 101 | |
| 102 | |
| 103 class Serialization(object): | |
| 104 """ | |
| 105 Helper class to serialize/deserialize a struct. | |
| 106 """ | |
| 107 def __init__(self, groups): | |
| 108 self.version = _GetVersion(groups) | |
| 109 self._groups = groups | |
| 110 main_struct = _GetStruct(groups) | |
| 111 self.size = HEADER_STRUCT.size + main_struct.size | |
| 112 self._struct_per_version = { | |
| 113 self.version: main_struct, | |
| 114 } | |
| 115 self._groups_per_version = { | |
| 116 self.version: groups, | |
| 117 } | |
| 118 | |
| 119 def _GetMainStruct(self): | |
| 120 return self._GetStruct(self.version) | |
| 121 | |
| 122 def _GetGroups(self, version): | |
| 123 # If asking for a version greater than the last known. | |
| 124 version = min(version, self.version) | |
| 125 if version not in self._groups_per_version: | |
| 126 self._groups_per_version[version] = _FilterGroups(self._groups, version) | |
| 127 return self._groups_per_version[version] | |
| 128 | |
| 129 def _GetStruct(self, version): | |
| 130 # If asking for a version greater than the last known. | |
| 131 version = min(version, self.version) | |
| 132 if version not in self._struct_per_version: | |
| 133 self._struct_per_version[version] = _GetStruct(self._GetGroups(version)) | |
| 134 return self._struct_per_version[version] | |
| 135 | |
| 136 def Serialize(self, obj, handle_offset): | |
| 137 """ | |
| 138 Serialize the given obj. handle_offset is the the first value to use when | |
| 139 encoding handles. | |
| 140 """ | |
| 141 handles = [] | |
| 142 data = bytearray(self.size) | |
| 143 HEADER_STRUCT.pack_into(data, 0, self.size, self.version) | |
| 144 position = HEADER_STRUCT.size | |
| 145 to_pack = [] | |
| 146 for group in self._groups: | |
| 147 position = position + NeededPaddingForAlignment(position, | |
| 148 group.GetAlignment()) | |
| 149 (entry, new_handles) = group.Serialize( | |
| 150 obj, | |
| 151 len(data) - position, | |
| 152 data, | |
| 153 handle_offset + len(handles)) | |
| 154 to_pack.extend(Flatten(entry)) | |
| 155 handles.extend(new_handles) | |
| 156 position = position + group.GetByteSize() | |
| 157 self._GetMainStruct().pack_into(data, HEADER_STRUCT.size, *to_pack) | |
| 158 return (data, handles) | |
| 159 | |
| 160 def Deserialize(self, fields, context): | |
| 161 if len(context.data) < HEADER_STRUCT.size: | |
| 162 raise DeserializationException( | |
| 163 'Available data too short to contain header.') | |
| 164 (size, version) = HEADER_STRUCT.unpack_from(context.data) | |
| 165 if len(context.data) < size or size < HEADER_STRUCT.size: | |
| 166 raise DeserializationException('Header size is incorrect.') | |
| 167 if context.IsInitialContext(): | |
| 168 context.ClaimMemory(0, size) | |
| 169 version_struct = self._GetStruct(version) | |
| 170 entities = version_struct.unpack_from(context.data, HEADER_STRUCT.size) | |
| 171 filtered_groups = self._GetGroups(version) | |
| 172 if ((version <= self.version and | |
| 173 size != version_struct.size + HEADER_STRUCT.size) or | |
| 174 size < version_struct.size + HEADER_STRUCT.size): | |
| 175 raise DeserializationException('Struct size in incorrect.') | |
| 176 position = HEADER_STRUCT.size | |
| 177 enties_index = 0 | |
| 178 for group in filtered_groups: | |
| 179 position = position + NeededPaddingForAlignment(position, | |
| 180 group.GetAlignment()) | |
| 181 enties_count = len(group.GetTypeCode()) | |
| 182 if enties_count == 1: | |
| 183 value = entities[enties_index] | |
| 184 else: | |
| 185 value = tuple(entities[enties_index:enties_index+enties_count]) | |
| 186 fields.update(group.Deserialize(value, context.GetSubContext(position))) | |
| 187 position += group.GetByteSize() | |
| 188 enties_index += enties_count | |
| 189 | |
| 190 | |
| 191 def NeededPaddingForAlignment(value, alignment=8): | |
| 192 """Returns the padding necessary to align value with the given alignment.""" | |
| 193 if value % alignment: | |
| 194 return alignment - (value % alignment) | |
| 195 return 0 | |
| 196 | |
| 197 | |
| 198 def _GetVersion(groups): | |
| 199 if not len(groups): | |
| 200 return 0 | |
| 201 return max([x.GetMaxVersion() for x in groups]) | |
| 202 | |
| 203 | |
| 204 def _FilterGroups(groups, version): | |
| 205 return [group.Filter(version) for | |
| 206 group in groups if group.GetMinVersion() <= version] | |
| 207 | |
| 208 | |
| 209 def _GetStruct(groups): | |
| 210 index = 0 | |
| 211 codes = [ '<' ] | |
| 212 for group in groups: | |
| 213 code = group.GetTypeCode() | |
| 214 needed_padding = NeededPaddingForAlignment(index, group.GetAlignment()) | |
| 215 if needed_padding: | |
| 216 codes.append('x' * needed_padding) | |
| 217 index = index + needed_padding | |
| 218 codes.append(code) | |
| 219 index = index + group.GetByteSize() | |
| 220 alignment_needed = NeededPaddingForAlignment(index) | |
| 221 if alignment_needed: | |
| 222 codes.append('x' * alignment_needed) | |
| 223 return struct.Struct(''.join(codes)) | |
| 224 | |
| 225 | |
| 226 class UnionSerializer(object): | |
| 227 """ | |
| 228 Helper class to serialize/deserialize a union. | |
| 229 """ | |
| 230 def __init__(self, fields): | |
| 231 self._fields = {field.index: field for field in fields} | |
| 232 | |
| 233 def SerializeInline(self, union, handle_offset): | |
| 234 data = bytearray() | |
| 235 field = self._fields[union.tag] | |
| 236 | |
| 237 # If the union value is a simple type or a nested union, it is returned as | |
| 238 # entry. | |
| 239 # Otherwise, the serialized value is appended to data and the value of entry | |
| 240 # is -1. The caller will need to set entry to the location where the | |
| 241 # caller will append data. | |
| 242 (entry, handles) = field.field_type.Serialize( | |
| 243 union.data, -1, data, handle_offset) | |
| 244 | |
| 245 # If the value contained in the union is itself a union, we append its | |
| 246 # serialized value to data and set entry to -1. The caller will need to set | |
| 247 # entry to the location where the caller will append data. | |
| 248 if field.field_type.IsUnion(): | |
| 249 nested_union = bytearray(16) | |
| 250 HEADER_STRUCT.pack_into(nested_union, 0, entry[0], entry[1]) | |
| 251 POINTER_STRUCT.pack_into(nested_union, 8, entry[2]) | |
| 252 | |
| 253 data = nested_union + data | |
| 254 | |
| 255 # Since we do not know where the caller will append the nested union, | |
| 256 # we set entry to an invalid value and let the caller figure out the right | |
| 257 # value. | |
| 258 entry = -1 | |
| 259 | |
| 260 return (16, union.tag, entry, data), handles | |
| 261 | |
| 262 def Serialize(self, union, handle_offset): | |
| 263 (size, tag, entry, extra_data), handles = self.SerializeInline( | |
| 264 union, handle_offset) | |
| 265 data = bytearray(16) | |
| 266 if extra_data: | |
| 267 entry = 8 | |
| 268 data.extend(extra_data) | |
| 269 | |
| 270 field = self._fields[union.tag] | |
| 271 | |
| 272 HEADER_STRUCT.pack_into(data, 0, size, tag) | |
| 273 typecode = field.GetTypeCode() | |
| 274 | |
| 275 # If the value is a nested union, we store a 64 bits pointer to it. | |
| 276 if field.field_type.IsUnion(): | |
| 277 typecode = 'Q' | |
| 278 | |
| 279 struct.pack_into('<%s' % typecode, data, 8, entry) | |
| 280 return data, handles | |
| 281 | |
| 282 def Deserialize(self, context, union_class): | |
| 283 if len(context.data) < HEADER_STRUCT.size: | |
| 284 raise DeserializationException( | |
| 285 'Available data too short to contain header.') | |
| 286 (size, tag) = HEADER_STRUCT.unpack_from(context.data) | |
| 287 | |
| 288 if size == 0: | |
| 289 return None | |
| 290 | |
| 291 if size != 16: | |
| 292 raise DeserializationException('Invalid union size %s' % size) | |
| 293 | |
| 294 union = union_class.__new__(union_class) | |
| 295 if tag not in self._fields: | |
| 296 union.SetInternals(None, None) | |
| 297 return union | |
| 298 | |
| 299 field = self._fields[tag] | |
| 300 if field.field_type.IsUnion(): | |
| 301 ptr = POINTER_STRUCT.unpack_from(context.data, 8)[0] | |
| 302 value = field.field_type.Deserialize(ptr, context.GetSubContext(ptr+8)) | |
| 303 else: | |
| 304 raw_value = struct.unpack_from( | |
| 305 field.GetTypeCode(), context.data, 8)[0] | |
| 306 value = field.field_type.Deserialize(raw_value, context.GetSubContext(8)) | |
| 307 | |
| 308 union.SetInternals(field, value) | |
| 309 return union | |
| OLD | NEW |