OLD | NEW |
---|---|
1 # Authors: | 1 # Authors: |
2 # Trevor Perrin | 2 # Trevor Perrin |
3 # Google - handling CertificateRequest.certificate_types | 3 # Google - handling CertificateRequest.certificate_types |
4 # Google (adapted by Sam Rushing and Marcelo Fernandez) - NPN support | 4 # Google (adapted by Sam Rushing and Marcelo Fernandez) - NPN support |
5 # Dimitris Moraitis - Anon ciphersuites | 5 # Dimitris Moraitis - Anon ciphersuites |
6 # Yngve Pettersen (ported by Paul Sokolovsky) - TLS 1.2 | 6 # Yngve Pettersen (ported by Paul Sokolovsky) - TLS 1.2 |
7 # | 7 # |
8 # See the LICENSE file for legal information regarding use of this file. | 8 # See the LICENSE file for legal information regarding use of this file. |
9 | 9 |
10 """Classes representing TLS messages.""" | 10 """Classes representing TLS messages.""" |
(...skipping 81 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... | |
92 def __init__(self, handshakeType): | 92 def __init__(self, handshakeType): |
93 self.contentType = ContentType.handshake | 93 self.contentType = ContentType.handshake |
94 self.handshakeType = handshakeType | 94 self.handshakeType = handshakeType |
95 | 95 |
96 def postWrite(self, w): | 96 def postWrite(self, w): |
97 headerWriter = Writer() | 97 headerWriter = Writer() |
98 headerWriter.add(self.handshakeType, 1) | 98 headerWriter.add(self.handshakeType, 1) |
99 headerWriter.add(len(w.bytes), 3) | 99 headerWriter.add(len(w.bytes), 3) |
100 return headerWriter.bytes + w.bytes | 100 return headerWriter.bytes + w.bytes |
101 | 101 |
102 def parse_next_protos(self, b): | |
103 protos = [] | |
104 while True: | |
105 if len(b) == 0: | |
106 break | |
107 l = b[0] | |
108 b = b[1:] | |
109 if len(b) < l: | |
110 raise BadNextProtos(len(b)) | |
111 protos.append(b[:l]) | |
112 b = b[l:] | |
113 return protos | |
114 | |
115 def next_protos_encoded(self, protocol_list): | |
116 b = bytearray() | |
117 for e in protocol_list: | |
118 if len(e) > 255 or len(e) == 0: | |
119 raise BadNextProtos(len(e)) | |
120 b += bytearray( [len(e)] ) + bytearray(e) | |
121 return b | |
122 | |
102 class ClientHello(HandshakeMsg): | 123 class ClientHello(HandshakeMsg): |
103 def __init__(self, ssl2=False): | 124 def __init__(self, ssl2=False): |
104 HandshakeMsg.__init__(self, HandshakeType.client_hello) | 125 HandshakeMsg.__init__(self, HandshakeType.client_hello) |
105 self.ssl2 = ssl2 | 126 self.ssl2 = ssl2 |
106 self.client_version = (0,0) | 127 self.client_version = (0,0) |
107 self.random = bytearray(32) | 128 self.random = bytearray(32) |
108 self.session_id = bytearray(0) | 129 self.session_id = bytearray(0) |
109 self.cipher_suites = [] # a list of 16-bit values | 130 self.cipher_suites = [] # a list of 16-bit values |
110 self.certificate_types = [CertificateType.x509] | 131 self.certificate_types = [CertificateType.x509] |
111 self.compression_methods = [] # a list of 8-bit values | 132 self.compression_methods = [] # a list of 8-bit values |
112 self.srp_username = None # a string | 133 self.srp_username = None # a string |
113 self.tack = False | 134 self.tack = False |
135 self.alpn_protos_advertised = None | |
114 self.supports_npn = False | 136 self.supports_npn = False |
115 self.server_name = bytearray(0) | 137 self.server_name = bytearray(0) |
116 self.channel_id = False | 138 self.channel_id = False |
117 self.extended_master_secret = False | 139 self.extended_master_secret = False |
118 self.tb_client_params = [] | 140 self.tb_client_params = [] |
119 self.support_signed_cert_timestamps = False | 141 self.support_signed_cert_timestamps = False |
120 self.status_request = False | 142 self.status_request = False |
121 | 143 |
122 def create(self, version, random, session_id, cipher_suites, | 144 def create(self, version, random, session_id, cipher_suites, |
123 certificate_types=None, srpUsername=None, | 145 certificate_types=None, srpUsername=None, tack=False, |
124 tack=False, supports_npn=False, serverName=None): | 146 alpn_protos_advertised=None, supports_npn=False, serverName=None) : |
125 self.client_version = version | 147 self.client_version = version |
126 self.random = random | 148 self.random = random |
127 self.session_id = session_id | 149 self.session_id = session_id |
128 self.cipher_suites = cipher_suites | 150 self.cipher_suites = cipher_suites |
129 self.certificate_types = certificate_types | 151 self.certificate_types = certificate_types |
130 self.compression_methods = [0] | 152 self.compression_methods = [0] |
131 if srpUsername: | 153 if srpUsername: |
132 self.srp_username = bytearray(srpUsername, "utf-8") | 154 self.srp_username = bytearray(srpUsername, "utf-8") |
133 self.tack = tack | 155 self.tack = tack |
156 self.alpn_protos_advertised = alpn_protos_advertised | |
134 self.supports_npn = supports_npn | 157 self.supports_npn = supports_npn |
135 if serverName: | 158 if serverName: |
136 self.server_name = bytearray(serverName, "utf-8") | 159 self.server_name = bytearray(serverName, "utf-8") |
137 return self | 160 return self |
138 | 161 |
139 def parse(self, p): | 162 def parse(self, p): |
140 if self.ssl2: | 163 if self.ssl2: |
141 self.client_version = (p.get(1), p.get(1)) | 164 self.client_version = (p.get(1), p.get(1)) |
142 cipherSpecsLength = p.get(2) | 165 cipherSpecsLength = p.get(2) |
143 sessionIDLength = p.get(2) | 166 sessionIDLength = p.get(2) |
(...skipping 20 matching lines...) Expand all Loading... | |
164 while soFar != totalExtLength: | 187 while soFar != totalExtLength: |
165 extType = p.get(2) | 188 extType = p.get(2) |
166 extLength = p.get(2) | 189 extLength = p.get(2) |
167 index1 = p.index | 190 index1 = p.index |
168 if extType == ExtensionType.srp: | 191 if extType == ExtensionType.srp: |
169 self.srp_username = p.getVarBytes(1) | 192 self.srp_username = p.getVarBytes(1) |
170 elif extType == ExtensionType.cert_type: | 193 elif extType == ExtensionType.cert_type: |
171 self.certificate_types = p.getVarList(1, 1) | 194 self.certificate_types = p.getVarList(1, 1) |
172 elif extType == ExtensionType.tack: | 195 elif extType == ExtensionType.tack: |
173 self.tack = True | 196 self.tack = True |
197 elif extType == ExtensionType.alpn: | |
198 structLength = p.get(2) | |
199 if (structLength + 2 != extLength): | |
200 raise SyntaxError() | |
201 self.alpn_protos_advertised = self.parse_next_protos(p.g etFixBytes(structLength)) | |
174 elif extType == ExtensionType.supports_npn: | 202 elif extType == ExtensionType.supports_npn: |
175 self.supports_npn = True | 203 self.supports_npn = True |
176 elif extType == ExtensionType.server_name: | 204 elif extType == ExtensionType.server_name: |
177 serverNameListBytes = p.getFixBytes(extLength) | 205 serverNameListBytes = p.getFixBytes(extLength) |
178 p2 = Parser(serverNameListBytes) | 206 p2 = Parser(serverNameListBytes) |
179 p2.startLengthCheck(2) | 207 p2.startLengthCheck(2) |
180 while 1: | 208 while 1: |
181 if p2.atLengthCheck(): | 209 if p2.atLengthCheck(): |
182 break # no host_name, oh well | 210 break # no host_name, oh well |
183 name_type = p2.get(1) | 211 name_type = p2.get(1) |
(...skipping 52 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... | |
236 w2 = Writer() # For Extensions | 264 w2 = Writer() # For Extensions |
237 if self.certificate_types and self.certificate_types != \ | 265 if self.certificate_types and self.certificate_types != \ |
238 [CertificateType.x509]: | 266 [CertificateType.x509]: |
239 w2.add(ExtensionType.cert_type, 2) | 267 w2.add(ExtensionType.cert_type, 2) |
240 w2.add(len(self.certificate_types)+1, 2) | 268 w2.add(len(self.certificate_types)+1, 2) |
241 w2.addVarSeq(self.certificate_types, 1, 1) | 269 w2.addVarSeq(self.certificate_types, 1, 1) |
242 if self.srp_username: | 270 if self.srp_username: |
243 w2.add(ExtensionType.srp, 2) | 271 w2.add(ExtensionType.srp, 2) |
244 w2.add(len(self.srp_username)+1, 2) | 272 w2.add(len(self.srp_username)+1, 2) |
245 w2.addVarSeq(self.srp_username, 1, 1) | 273 w2.addVarSeq(self.srp_username, 1, 1) |
274 if self.alpn_protos_advertised is not None: | |
275 encoded_alpn_protos_advertised = self.next_protos_encoded(self.alpn_ protos_advertised) | |
276 w2.add(ExtensionType.alpn, 2) | |
277 w2.add(len(encoded_alpn_protos_advertised) + 2, 2) | |
278 w2.add(len(encoded_alpn_protos_advertised), 2) | |
279 w2.addFixSeq(encoded_alpn_protos_advertised, 1) | |
246 if self.supports_npn: | 280 if self.supports_npn: |
247 w2.add(ExtensionType.supports_npn, 2) | 281 w2.add(ExtensionType.supports_npn, 2) |
248 w2.add(0, 2) | 282 w2.add(0, 2) |
249 if self.server_name: | 283 if self.server_name: |
250 w2.add(ExtensionType.server_name, 2) | 284 w2.add(ExtensionType.server_name, 2) |
251 w2.add(len(self.server_name)+5, 2) | 285 w2.add(len(self.server_name)+5, 2) |
252 w2.add(len(self.server_name)+3, 2) | 286 w2.add(len(self.server_name)+3, 2) |
253 w2.add(NameType.host_name, 1) | 287 w2.add(NameType.host_name, 1) |
254 w2.addVarSeq(self.server_name, 1, 2) | 288 w2.addVarSeq(self.server_name, 1, 2) |
255 if self.tack: | 289 if self.tack: |
256 w2.add(ExtensionType.tack, 2) | 290 w2.add(ExtensionType.tack, 2) |
257 w2.add(0, 2) | 291 w2.add(0, 2) |
258 if len(w2.bytes): | 292 if len(w2.bytes): |
259 w.add(len(w2.bytes), 2) | 293 w.add(len(w2.bytes), 2) |
260 w.bytes += w2.bytes | 294 w.bytes += w2.bytes |
261 return self.postWrite(w) | 295 return self.postWrite(w) |
262 | 296 |
263 class BadNextProtos(Exception): | 297 class BadNextProtos(Exception): |
264 def __init__(self, l): | 298 def __init__(self, l): |
265 self.length = l | 299 self.length = l |
266 | 300 |
267 def __str__(self): | 301 def __str__(self): |
268 return 'Cannot encode a list of next protocols because it contains an el ement with invalid length %d. Element lengths must be 0 < x < 256' % self.length | 302 return 'Cannot encode a list of next protocols because it contains an el ement with invalid length %d. Element lengths must be 0 < x < 256' % self.length |
269 | 303 |
304 class InvalidAlpnResponse(Exception): | |
305 def __init__(self, l): | |
306 self.length = l | |
307 | |
308 def __str__(self): | |
309 return 'ALPN server response protocol list has invalid length %d. It mu st be of length one.' % self.length | |
310 | |
270 class ServerHello(HandshakeMsg): | 311 class ServerHello(HandshakeMsg): |
271 def __init__(self): | 312 def __init__(self): |
272 HandshakeMsg.__init__(self, HandshakeType.server_hello) | 313 HandshakeMsg.__init__(self, HandshakeType.server_hello) |
273 self.server_version = (0,0) | 314 self.server_version = (0,0) |
274 self.random = bytearray(32) | 315 self.random = bytearray(32) |
275 self.session_id = bytearray(0) | 316 self.session_id = bytearray(0) |
276 self.cipher_suite = 0 | 317 self.cipher_suite = 0 |
277 self.certificate_type = CertificateType.x509 | 318 self.certificate_type = CertificateType.x509 |
278 self.compression_method = 0 | 319 self.compression_method = 0 |
279 self.tackExt = None | 320 self.tackExt = None |
321 self.alpn_proto_selected = None | |
280 self.next_protos_advertised = None | 322 self.next_protos_advertised = None |
281 self.next_protos = None | 323 self.next_protos = None |
282 self.channel_id = False | 324 self.channel_id = False |
283 self.extended_master_secret = False | 325 self.extended_master_secret = False |
284 self.tb_params = None | 326 self.tb_params = None |
285 self.signed_cert_timestamps = None | 327 self.signed_cert_timestamps = None |
286 self.status_request = False | 328 self.status_request = False |
287 | 329 |
288 def create(self, version, random, session_id, cipher_suite, | 330 def create(self, version, random, session_id, cipher_suite, |
289 certificate_type, tackExt, next_protos_advertised): | 331 certificate_type, tackExt, alpn_proto_selected, |
332 next_protos_advertised): | |
290 self.server_version = version | 333 self.server_version = version |
291 self.random = random | 334 self.random = random |
292 self.session_id = session_id | 335 self.session_id = session_id |
293 self.cipher_suite = cipher_suite | 336 self.cipher_suite = cipher_suite |
294 self.certificate_type = certificate_type | 337 self.certificate_type = certificate_type |
295 self.compression_method = 0 | 338 self.compression_method = 0 |
296 self.tackExt = tackExt | 339 self.tackExt = tackExt |
340 self.alpn_proto_selected = alpn_proto_selected | |
297 self.next_protos_advertised = next_protos_advertised | 341 self.next_protos_advertised = next_protos_advertised |
298 return self | 342 return self |
299 | 343 |
300 def parse(self, p): | 344 def parse(self, p): |
301 p.startLengthCheck(3) | 345 p.startLengthCheck(3) |
302 self.server_version = (p.get(1), p.get(1)) | 346 self.server_version = (p.get(1), p.get(1)) |
303 self.random = p.getFixBytes(32) | 347 self.random = p.getFixBytes(32) |
304 self.session_id = p.getVarBytes(1) | 348 self.session_id = p.getVarBytes(1) |
305 self.cipher_suite = p.get(2) | 349 self.cipher_suite = p.get(2) |
306 self.compression_method = p.get(1) | 350 self.compression_method = p.get(1) |
307 if not p.atLengthCheck(): | 351 if not p.atLengthCheck(): |
308 totalExtLength = p.get(2) | 352 totalExtLength = p.get(2) |
309 soFar = 0 | 353 soFar = 0 |
310 while soFar != totalExtLength: | 354 while soFar != totalExtLength: |
311 extType = p.get(2) | 355 extType = p.get(2) |
312 extLength = p.get(2) | 356 extLength = p.get(2) |
313 if extType == ExtensionType.cert_type: | 357 if extType == ExtensionType.cert_type: |
314 if extLength != 1: | 358 if extLength != 1: |
315 raise SyntaxError() | 359 raise SyntaxError() |
316 self.certificate_type = p.get(1) | 360 self.certificate_type = p.get(1) |
317 elif extType == ExtensionType.tack and tackpyLoaded: | 361 elif extType == ExtensionType.tack and tackpyLoaded: |
318 self.tackExt = TackExtension(p.getFixBytes(extLength)) | 362 self.tackExt = TackExtension(p.getFixBytes(extLength)) |
363 elif extType == ExtensionType.alpn: | |
364 structLength = p.get(2) | |
365 if (structLength + 2 != extLength): | |
davidben
2016/08/03 23:34:22
Nit: No parens in Python
Bence
2016/08/04 18:41:44
Done.
| |
366 raise SyntaxError() | |
367 alpn_protos = self.parse_next_protos(p.getFixBytes(structLen gth)) | |
368 if (alpn_protos.len() != 1): | |
davidben
2016/08/03 23:34:22
Ditto.
Bence
2016/08/04 18:41:44
Done.
| |
369 raise InvalidAlpnResponse(alpn_protos.len()); | |
370 self.alpn_proto_selected = alpn_protos[0] | |
319 elif extType == ExtensionType.supports_npn: | 371 elif extType == ExtensionType.supports_npn: |
320 self.next_protos = self.__parse_next_protos(p.getFixBytes(ex tLength)) | 372 self.next_protos = self.parse_next_protos(p.getFixBytes(extL ength)) |
321 else: | 373 else: |
322 p.getFixBytes(extLength) | 374 p.getFixBytes(extLength) |
323 soFar += 4 + extLength | 375 soFar += 4 + extLength |
324 p.stopLengthCheck() | 376 p.stopLengthCheck() |
325 return self | 377 return self |
326 | 378 |
327 def __parse_next_protos(self, b): | |
328 protos = [] | |
329 while True: | |
330 if len(b) == 0: | |
331 break | |
332 l = b[0] | |
333 b = b[1:] | |
334 if len(b) < l: | |
335 raise BadNextProtos(len(b)) | |
336 protos.append(b[:l]) | |
337 b = b[l:] | |
338 return protos | |
339 | |
340 def __next_protos_encoded(self): | |
341 b = bytearray() | |
342 for e in self.next_protos_advertised: | |
343 if len(e) > 255 or len(e) == 0: | |
344 raise BadNextProtos(len(e)) | |
345 b += bytearray( [len(e)] ) + bytearray(e) | |
346 return b | |
347 | |
348 def write(self): | 379 def write(self): |
349 w = Writer() | 380 w = Writer() |
350 w.add(self.server_version[0], 1) | 381 w.add(self.server_version[0], 1) |
351 w.add(self.server_version[1], 1) | 382 w.add(self.server_version[1], 1) |
352 w.addFixSeq(self.random, 1) | 383 w.addFixSeq(self.random, 1) |
353 w.addVarSeq(self.session_id, 1, 1) | 384 w.addVarSeq(self.session_id, 1, 1) |
354 w.add(self.cipher_suite, 2) | 385 w.add(self.cipher_suite, 2) |
355 w.add(self.compression_method, 1) | 386 w.add(self.compression_method, 1) |
356 | 387 |
357 w2 = Writer() # For Extensions | 388 w2 = Writer() # For Extensions |
358 if self.certificate_type and self.certificate_type != \ | 389 if self.certificate_type and self.certificate_type != \ |
359 CertificateType.x509: | 390 CertificateType.x509: |
360 w2.add(ExtensionType.cert_type, 2) | 391 w2.add(ExtensionType.cert_type, 2) |
361 w2.add(1, 2) | 392 w2.add(1, 2) |
362 w2.add(self.certificate_type, 1) | 393 w2.add(self.certificate_type, 1) |
363 if self.tackExt: | 394 if self.tackExt: |
364 b = self.tackExt.serialize() | 395 b = self.tackExt.serialize() |
365 w2.add(ExtensionType.tack, 2) | 396 w2.add(ExtensionType.tack, 2) |
366 w2.add(len(b), 2) | 397 w2.add(len(b), 2) |
367 w2.bytes += b | 398 w2.bytes += b |
368 if self.next_protos_advertised is not None: | 399 if self.alpn_proto_selected is not None: |
369 encoded_next_protos_advertised = self.__next_protos_encoded() | 400 alpn_protos_single_element_list = [self.alpn_proto_selected] |
401 encoded_alpn_protos_advertised = self.next_protos_encoded(alpn_proto s_single_element_list) | |
402 w2.add(ExtensionType.alpn, 2) | |
403 w2.add(len(encoded_alpn_protos_advertised) + 2, 2) | |
404 w2.add(len(encoded_alpn_protos_advertised), 2) | |
405 w2.addFixSeq(encoded_alpn_protos_advertised, 1) | |
406 # Do not use NPN if ALPN is used. | |
407 elif self.next_protos_advertised is not None: | |
408 encoded_next_protos_advertised = self.next_protos_encoded(self.next_ protos_advertised) | |
370 w2.add(ExtensionType.supports_npn, 2) | 409 w2.add(ExtensionType.supports_npn, 2) |
371 w2.add(len(encoded_next_protos_advertised), 2) | 410 w2.add(len(encoded_next_protos_advertised), 2) |
372 w2.addFixSeq(encoded_next_protos_advertised, 1) | 411 w2.addFixSeq(encoded_next_protos_advertised, 1) |
373 if self.channel_id: | 412 if self.channel_id: |
374 w2.add(ExtensionType.channel_id, 2) | 413 w2.add(ExtensionType.channel_id, 2) |
375 w2.add(0, 2) | 414 w2.add(0, 2) |
376 if self.extended_master_secret: | 415 if self.extended_master_secret: |
377 w2.add(ExtensionType.extended_master_secret, 2) | 416 w2.add(ExtensionType.extended_master_secret, 2) |
378 w2.add(0, 2) | 417 w2.add(0, 2) |
379 if self.tb_params: | 418 if self.tb_params: |
(...skipping 445 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... | |
825 newMsg = ApplicationData().create(self.bytes[:1]) | 864 newMsg = ApplicationData().create(self.bytes[:1]) |
826 self.bytes = self.bytes[1:] | 865 self.bytes = self.bytes[1:] |
827 return newMsg | 866 return newMsg |
828 | 867 |
829 def parse(self, p): | 868 def parse(self, p): |
830 self.bytes = p.bytes | 869 self.bytes = p.bytes |
831 return self | 870 return self |
832 | 871 |
833 def write(self): | 872 def write(self): |
834 return self.bytes | 873 return self.bytes |
OLD | NEW |