OLD | NEW |
| (Empty) |
1 """Helper class for TLSConnection.""" | |
2 from __future__ import generators | |
3 | |
4 from utils.compat import * | |
5 from utils.cryptomath import * | |
6 from utils.cipherfactory import createAES, createRC4, createTripleDES | |
7 from utils.codec import * | |
8 from errors import * | |
9 from messages import * | |
10 from mathtls import * | |
11 from constants import * | |
12 from utils.cryptomath import getRandomBytes | |
13 from utils import hmac | |
14 from FileObject import FileObject | |
15 | |
16 # The sha module is deprecated in Python 2.6 | |
17 try: | |
18 import sha | |
19 except ImportError: | |
20 from hashlib import sha1 as sha | |
21 | |
22 # The md5 module is deprecated in Python 2.6 | |
23 try: | |
24 import md5 | |
25 except ImportError: | |
26 from hashlib import md5 | |
27 | |
28 import socket | |
29 import errno | |
30 import traceback | |
31 | |
32 class _ConnectionState: | |
33 def __init__(self): | |
34 self.macContext = None | |
35 self.encContext = None | |
36 self.seqnum = 0 | |
37 | |
38 def getSeqNumStr(self): | |
39 w = Writer(8) | |
40 w.add(self.seqnum, 8) | |
41 seqnumStr = bytesToString(w.bytes) | |
42 self.seqnum += 1 | |
43 return seqnumStr | |
44 | |
45 | |
46 class TLSRecordLayer: | |
47 """ | |
48 This class handles data transmission for a TLS connection. | |
49 | |
50 Its only subclass is L{tlslite.TLSConnection.TLSConnection}. We've | |
51 separated the code in this class from TLSConnection to make things | |
52 more readable. | |
53 | |
54 | |
55 @type sock: socket.socket | |
56 @ivar sock: The underlying socket object. | |
57 | |
58 @type session: L{tlslite.Session.Session} | |
59 @ivar session: The session corresponding to this connection. | |
60 | |
61 Due to TLS session resumption, multiple connections can correspond | |
62 to the same underlying session. | |
63 | |
64 @type version: tuple | |
65 @ivar version: The TLS version being used for this connection. | |
66 | |
67 (3,0) means SSL 3.0, and (3,1) means TLS 1.0. | |
68 | |
69 @type closed: bool | |
70 @ivar closed: If this connection is closed. | |
71 | |
72 @type resumed: bool | |
73 @ivar resumed: If this connection is based on a resumed session. | |
74 | |
75 @type allegedSharedKeyUsername: str or None | |
76 @ivar allegedSharedKeyUsername: This is set to the shared-key | |
77 username asserted by the client, whether the handshake succeeded or | |
78 not. If the handshake fails, this can be inspected to | |
79 determine if a guessing attack is in progress against a particular | |
80 user account. | |
81 | |
82 @type allegedSrpUsername: str or None | |
83 @ivar allegedSrpUsername: This is set to the SRP username | |
84 asserted by the client, whether the handshake succeeded or not. | |
85 If the handshake fails, this can be inspected to determine | |
86 if a guessing attack is in progress against a particular user | |
87 account. | |
88 | |
89 @type closeSocket: bool | |
90 @ivar closeSocket: If the socket should be closed when the | |
91 connection is closed (writable). | |
92 | |
93 If you set this to True, TLS Lite will assume the responsibility of | |
94 closing the socket when the TLS Connection is shutdown (either | |
95 through an error or through the user calling close()). The default | |
96 is False. | |
97 | |
98 @type ignoreAbruptClose: bool | |
99 @ivar ignoreAbruptClose: If an abrupt close of the socket should | |
100 raise an error (writable). | |
101 | |
102 If you set this to True, TLS Lite will not raise a | |
103 L{tlslite.errors.TLSAbruptCloseError} exception if the underlying | |
104 socket is unexpectedly closed. Such an unexpected closure could be | |
105 caused by an attacker. However, it also occurs with some incorrect | |
106 TLS implementations. | |
107 | |
108 You should set this to True only if you're not worried about an | |
109 attacker truncating the connection, and only if necessary to avoid | |
110 spurious errors. The default is False. | |
111 | |
112 @sort: __init__, read, readAsync, write, writeAsync, close, closeAsync, | |
113 getCipherImplementation, getCipherName | |
114 """ | |
115 | |
116 def __init__(self, sock): | |
117 self.sock = sock | |
118 | |
119 #My session object (Session instance; read-only) | |
120 self.session = None | |
121 | |
122 #Am I a client or server? | |
123 self._client = None | |
124 | |
125 #Buffers for processing messages | |
126 self._handshakeBuffer = [] | |
127 self._readBuffer = "" | |
128 | |
129 #Handshake digests | |
130 self._handshake_md5 = md5.md5() | |
131 self._handshake_sha = sha.sha() | |
132 | |
133 #TLS Protocol Version | |
134 self.version = (0,0) #read-only | |
135 self._versionCheck = False #Once we choose a version, this is True | |
136 | |
137 #Current and Pending connection states | |
138 self._writeState = _ConnectionState() | |
139 self._readState = _ConnectionState() | |
140 self._pendingWriteState = _ConnectionState() | |
141 self._pendingReadState = _ConnectionState() | |
142 | |
143 #Is the connection open? | |
144 self.closed = True #read-only | |
145 self._refCount = 0 #Used to trigger closure | |
146 | |
147 #Is this a resumed (or shared-key) session? | |
148 self.resumed = False #read-only | |
149 | |
150 #What username did the client claim in his handshake? | |
151 self.allegedSharedKeyUsername = None | |
152 self.allegedSrpUsername = None | |
153 | |
154 #On a call to close(), do we close the socket? (writeable) | |
155 self.closeSocket = False | |
156 | |
157 #If the socket is abruptly closed, do we ignore it | |
158 #and pretend the connection was shut down properly? (writeable) | |
159 self.ignoreAbruptClose = False | |
160 | |
161 #Fault we will induce, for testing purposes | |
162 self.fault = None | |
163 | |
164 #********************************************************* | |
165 # Public Functions START | |
166 #********************************************************* | |
167 | |
168 def read(self, max=None, min=1): | |
169 """Read some data from the TLS connection. | |
170 | |
171 This function will block until at least 'min' bytes are | |
172 available (or the connection is closed). | |
173 | |
174 If an exception is raised, the connection will have been | |
175 automatically closed. | |
176 | |
177 @type max: int | |
178 @param max: The maximum number of bytes to return. | |
179 | |
180 @type min: int | |
181 @param min: The minimum number of bytes to return | |
182 | |
183 @rtype: str | |
184 @return: A string of no more than 'max' bytes, and no fewer | |
185 than 'min' (unless the connection has been closed, in which | |
186 case fewer than 'min' bytes may be returned). | |
187 | |
188 @raise socket.error: If a socket error occurs. | |
189 @raise tlslite.errors.TLSAbruptCloseError: If the socket is closed | |
190 without a preceding alert. | |
191 @raise tlslite.errors.TLSAlert: If a TLS alert is signalled. | |
192 """ | |
193 for result in self.readAsync(max, min): | |
194 pass | |
195 return result | |
196 | |
197 def readAsync(self, max=None, min=1): | |
198 """Start a read operation on the TLS connection. | |
199 | |
200 This function returns a generator which behaves similarly to | |
201 read(). Successive invocations of the generator will return 0 | |
202 if it is waiting to read from the socket, 1 if it is waiting | |
203 to write to the socket, or a string if the read operation has | |
204 completed. | |
205 | |
206 @rtype: iterable | |
207 @return: A generator; see above for details. | |
208 """ | |
209 try: | |
210 while len(self._readBuffer)<min and not self.closed: | |
211 try: | |
212 for result in self._getMsg(ContentType.application_data): | |
213 if result in (0,1): | |
214 yield result | |
215 applicationData = result | |
216 self._readBuffer += bytesToString(applicationData.write()) | |
217 except TLSRemoteAlert, alert: | |
218 if alert.description != AlertDescription.close_notify: | |
219 raise | |
220 except TLSAbruptCloseError: | |
221 if not self.ignoreAbruptClose: | |
222 raise | |
223 else: | |
224 self._shutdown(True) | |
225 | |
226 if max == None: | |
227 max = len(self._readBuffer) | |
228 | |
229 returnStr = self._readBuffer[:max] | |
230 self._readBuffer = self._readBuffer[max:] | |
231 yield returnStr | |
232 except: | |
233 self._shutdown(False) | |
234 raise | |
235 | |
236 def write(self, s): | |
237 """Write some data to the TLS connection. | |
238 | |
239 This function will block until all the data has been sent. | |
240 | |
241 If an exception is raised, the connection will have been | |
242 automatically closed. | |
243 | |
244 @type s: str | |
245 @param s: The data to transmit to the other party. | |
246 | |
247 @raise socket.error: If a socket error occurs. | |
248 """ | |
249 for result in self.writeAsync(s): | |
250 pass | |
251 | |
252 def writeAsync(self, s): | |
253 """Start a write operation on the TLS connection. | |
254 | |
255 This function returns a generator which behaves similarly to | |
256 write(). Successive invocations of the generator will return | |
257 1 if it is waiting to write to the socket, or will raise | |
258 StopIteration if the write operation has completed. | |
259 | |
260 @rtype: iterable | |
261 @return: A generator; see above for details. | |
262 """ | |
263 try: | |
264 if self.closed: | |
265 raise ValueError() | |
266 | |
267 index = 0 | |
268 blockSize = 16384 | |
269 skipEmptyFrag = False | |
270 while 1: | |
271 startIndex = index * blockSize | |
272 endIndex = startIndex + blockSize | |
273 if startIndex >= len(s): | |
274 break | |
275 if endIndex > len(s): | |
276 endIndex = len(s) | |
277 block = stringToBytes(s[startIndex : endIndex]) | |
278 applicationData = ApplicationData().create(block) | |
279 for result in self._sendMsg(applicationData, skipEmptyFrag): | |
280 yield result | |
281 skipEmptyFrag = True #only send an empy fragment on 1st message | |
282 index += 1 | |
283 except: | |
284 self._shutdown(False) | |
285 raise | |
286 | |
287 def close(self): | |
288 """Close the TLS connection. | |
289 | |
290 This function will block until it has exchanged close_notify | |
291 alerts with the other party. After doing so, it will shut down the | |
292 TLS connection. Further attempts to read through this connection | |
293 will return "". Further attempts to write through this connection | |
294 will raise ValueError. | |
295 | |
296 If makefile() has been called on this connection, the connection | |
297 will be not be closed until the connection object and all file | |
298 objects have been closed. | |
299 | |
300 Even if an exception is raised, the connection will have been | |
301 closed. | |
302 | |
303 @raise socket.error: If a socket error occurs. | |
304 @raise tlslite.errors.TLSAbruptCloseError: If the socket is closed | |
305 without a preceding alert. | |
306 @raise tlslite.errors.TLSAlert: If a TLS alert is signalled. | |
307 """ | |
308 if not self.closed: | |
309 for result in self._decrefAsync(): | |
310 pass | |
311 | |
312 def closeAsync(self): | |
313 """Start a close operation on the TLS connection. | |
314 | |
315 This function returns a generator which behaves similarly to | |
316 close(). Successive invocations of the generator will return 0 | |
317 if it is waiting to read from the socket, 1 if it is waiting | |
318 to write to the socket, or will raise StopIteration if the | |
319 close operation has completed. | |
320 | |
321 @rtype: iterable | |
322 @return: A generator; see above for details. | |
323 """ | |
324 if not self.closed: | |
325 for result in self._decrefAsync(): | |
326 yield result | |
327 | |
328 def _decrefAsync(self): | |
329 self._refCount -= 1 | |
330 if self._refCount == 0 and not self.closed: | |
331 try: | |
332 for result in self._sendMsg(Alert().create(\ | |
333 AlertDescription.close_notify, AlertLevel.warning)): | |
334 yield result | |
335 alert = None | |
336 # Forcing a shutdown as WinHTTP does not seem to be | |
337 # responsive to the close notify. | |
338 prevCloseSocket = self.closeSocket | |
339 self.closeSocket = True | |
340 self._shutdown(True) | |
341 self.closeSocket = prevCloseSocket | |
342 while not alert: | |
343 for result in self._getMsg((ContentType.alert, \ | |
344 ContentType.application_data)): | |
345 if result in (0,1): | |
346 yield result | |
347 if result.contentType == ContentType.alert: | |
348 alert = result | |
349 if alert.description == AlertDescription.close_notify: | |
350 self._shutdown(True) | |
351 else: | |
352 raise TLSRemoteAlert(alert) | |
353 except (socket.error, TLSAbruptCloseError): | |
354 #If the other side closes the socket, that's okay | |
355 self._shutdown(True) | |
356 except: | |
357 self._shutdown(False) | |
358 raise | |
359 | |
360 def getCipherName(self): | |
361 """Get the name of the cipher used with this connection. | |
362 | |
363 @rtype: str | |
364 @return: The name of the cipher used with this connection. | |
365 Either 'aes128', 'aes256', 'rc4', or '3des'. | |
366 """ | |
367 if not self._writeState.encContext: | |
368 return None | |
369 return self._writeState.encContext.name | |
370 | |
371 def getCipherImplementation(self): | |
372 """Get the name of the cipher implementation used with | |
373 this connection. | |
374 | |
375 @rtype: str | |
376 @return: The name of the cipher implementation used with | |
377 this connection. Either 'python', 'cryptlib', 'openssl', | |
378 or 'pycrypto'. | |
379 """ | |
380 if not self._writeState.encContext: | |
381 return None | |
382 return self._writeState.encContext.implementation | |
383 | |
384 | |
385 | |
386 #Emulate a socket, somewhat - | |
387 def send(self, s): | |
388 """Send data to the TLS connection (socket emulation). | |
389 | |
390 @raise socket.error: If a socket error occurs. | |
391 """ | |
392 self.write(s) | |
393 return len(s) | |
394 | |
395 def sendall(self, s): | |
396 """Send data to the TLS connection (socket emulation). | |
397 | |
398 @raise socket.error: If a socket error occurs. | |
399 """ | |
400 self.write(s) | |
401 | |
402 def recv(self, bufsize): | |
403 """Get some data from the TLS connection (socket emulation). | |
404 | |
405 @raise socket.error: If a socket error occurs. | |
406 @raise tlslite.errors.TLSAbruptCloseError: If the socket is closed | |
407 without a preceding alert. | |
408 @raise tlslite.errors.TLSAlert: If a TLS alert is signalled. | |
409 """ | |
410 return self.read(bufsize) | |
411 | |
412 def makefile(self, mode='r', bufsize=-1): | |
413 """Create a file object for the TLS connection (socket emulation). | |
414 | |
415 @rtype: L{tlslite.FileObject.FileObject} | |
416 """ | |
417 self._refCount += 1 | |
418 return FileObject(self, mode, bufsize) | |
419 | |
420 def getsockname(self): | |
421 """Return the socket's own address (socket emulation).""" | |
422 return self.sock.getsockname() | |
423 | |
424 def getpeername(self): | |
425 """Return the remote address to which the socket is connected | |
426 (socket emulation).""" | |
427 return self.sock.getpeername() | |
428 | |
429 def settimeout(self, value): | |
430 """Set a timeout on blocking socket operations (socket emulation).""" | |
431 return self.sock.settimeout(value) | |
432 | |
433 def gettimeout(self): | |
434 """Return the timeout associated with socket operations (socket | |
435 emulation).""" | |
436 return self.sock.gettimeout() | |
437 | |
438 def setsockopt(self, level, optname, value): | |
439 """Set the value of the given socket option (socket emulation).""" | |
440 return self.sock.setsockopt(level, optname, value) | |
441 | |
442 | |
443 #********************************************************* | |
444 # Public Functions END | |
445 #********************************************************* | |
446 | |
447 def _shutdown(self, resumable): | |
448 self._writeState = _ConnectionState() | |
449 self._readState = _ConnectionState() | |
450 #Don't do this: self._readBuffer = "" | |
451 self.version = (0,0) | |
452 self._versionCheck = False | |
453 self.closed = True | |
454 if self.closeSocket: | |
455 self.sock.close() | |
456 | |
457 #Even if resumable is False, we'll never toggle this on | |
458 if not resumable and self.session: | |
459 self.session.resumable = False | |
460 | |
461 | |
462 def _sendError(self, alertDescription, errorStr=None): | |
463 alert = Alert().create(alertDescription, AlertLevel.fatal) | |
464 for result in self._sendMsg(alert): | |
465 yield result | |
466 self._shutdown(False) | |
467 raise TLSLocalAlert(alert, errorStr) | |
468 | |
469 def _sendMsgs(self, msgs): | |
470 skipEmptyFrag = False | |
471 for msg in msgs: | |
472 for result in self._sendMsg(msg, skipEmptyFrag): | |
473 yield result | |
474 skipEmptyFrag = True | |
475 | |
476 def _sendMsg(self, msg, skipEmptyFrag=False): | |
477 bytes = msg.write() | |
478 contentType = msg.contentType | |
479 | |
480 #Whenever we're connected and asked to send a message, | |
481 #we first send an empty Application Data message. This prevents | |
482 #an attacker from launching a chosen-plaintext attack based on | |
483 #knowing the next IV. | |
484 if not self.closed and not skipEmptyFrag and self.version == (3,1): | |
485 if self._writeState.encContext: | |
486 if self._writeState.encContext.isBlockCipher: | |
487 for result in self._sendMsg(ApplicationData(), | |
488 skipEmptyFrag=True): | |
489 yield result | |
490 | |
491 #Update handshake hashes | |
492 if contentType == ContentType.handshake: | |
493 bytesStr = bytesToString(bytes) | |
494 self._handshake_md5.update(bytesStr) | |
495 self._handshake_sha.update(bytesStr) | |
496 | |
497 #Calculate MAC | |
498 if self._writeState.macContext: | |
499 seqnumStr = self._writeState.getSeqNumStr() | |
500 bytesStr = bytesToString(bytes) | |
501 mac = self._writeState.macContext.copy() | |
502 mac.update(seqnumStr) | |
503 mac.update(chr(contentType)) | |
504 if self.version == (3,0): | |
505 mac.update( chr( int(len(bytes)/256) ) ) | |
506 mac.update( chr( int(len(bytes)%256) ) ) | |
507 elif self.version in ((3,1), (3,2)): | |
508 mac.update(chr(self.version[0])) | |
509 mac.update(chr(self.version[1])) | |
510 mac.update( chr( int(len(bytes)/256) ) ) | |
511 mac.update( chr( int(len(bytes)%256) ) ) | |
512 else: | |
513 raise AssertionError() | |
514 mac.update(bytesStr) | |
515 macString = mac.digest() | |
516 macBytes = stringToBytes(macString) | |
517 if self.fault == Fault.badMAC: | |
518 macBytes[0] = (macBytes[0]+1) % 256 | |
519 | |
520 #Encrypt for Block or Stream Cipher | |
521 if self._writeState.encContext: | |
522 #Add padding and encrypt (for Block Cipher): | |
523 if self._writeState.encContext.isBlockCipher: | |
524 | |
525 #Add TLS 1.1 fixed block | |
526 if self.version == (3,2): | |
527 bytes = self.fixedIVBlock + bytes | |
528 | |
529 #Add padding: bytes = bytes + (macBytes + paddingBytes) | |
530 currentLength = len(bytes) + len(macBytes) + 1 | |
531 blockLength = self._writeState.encContext.block_size | |
532 paddingLength = blockLength-(currentLength % blockLength) | |
533 | |
534 paddingBytes = createByteArraySequence([paddingLength] * \ | |
535 (paddingLength+1)) | |
536 if self.fault == Fault.badPadding: | |
537 paddingBytes[0] = (paddingBytes[0]+1) % 256 | |
538 endBytes = concatArrays(macBytes, paddingBytes) | |
539 bytes = concatArrays(bytes, endBytes) | |
540 #Encrypt | |
541 plaintext = stringToBytes(bytes) | |
542 ciphertext = self._writeState.encContext.encrypt(plaintext) | |
543 bytes = stringToBytes(ciphertext) | |
544 | |
545 #Encrypt (for Stream Cipher) | |
546 else: | |
547 bytes = concatArrays(bytes, macBytes) | |
548 plaintext = bytesToString(bytes) | |
549 ciphertext = self._writeState.encContext.encrypt(plaintext) | |
550 bytes = stringToBytes(ciphertext) | |
551 | |
552 #Add record header and send | |
553 r = RecordHeader3().create(self.version, contentType, len(bytes)) | |
554 s = bytesToString(concatArrays(r.write(), bytes)) | |
555 while 1: | |
556 try: | |
557 bytesSent = self.sock.send(s) #Might raise socket.error | |
558 except socket.error, why: | |
559 if why[0] == errno.EWOULDBLOCK: | |
560 yield 1 | |
561 continue | |
562 else: | |
563 raise | |
564 if bytesSent == len(s): | |
565 return | |
566 s = s[bytesSent:] | |
567 yield 1 | |
568 | |
569 | |
570 def _getMsg(self, expectedType, secondaryType=None, constructorType=None): | |
571 try: | |
572 if not isinstance(expectedType, tuple): | |
573 expectedType = (expectedType,) | |
574 | |
575 #Spin in a loop, until we've got a non-empty record of a type we | |
576 #expect. The loop will be repeated if: | |
577 # - we receive a renegotiation attempt; we send no_renegotiation, | |
578 # then try again | |
579 # - we receive an empty application-data fragment; we try again | |
580 while 1: | |
581 for result in self._getNextRecord(): | |
582 if result in (0,1): | |
583 yield result | |
584 recordHeader, p = result | |
585 | |
586 #If this is an empty application-data fragment, try again | |
587 if recordHeader.type == ContentType.application_data: | |
588 if p.index == len(p.bytes): | |
589 continue | |
590 | |
591 #If we received an unexpected record type... | |
592 if recordHeader.type not in expectedType: | |
593 | |
594 #If we received an alert... | |
595 if recordHeader.type == ContentType.alert: | |
596 alert = Alert().parse(p) | |
597 | |
598 #We either received a fatal error, a warning, or a | |
599 #close_notify. In any case, we're going to close the | |
600 #connection. In the latter two cases we respond with | |
601 #a close_notify, but ignore any socket errors, since | |
602 #the other side might have already closed the socket. | |
603 if alert.level == AlertLevel.warning or \ | |
604 alert.description == AlertDescription.close_notify: | |
605 | |
606 #If the sendMsg() call fails because the socket has | |
607 #already been closed, we will be forgiving and not | |
608 #report the error nor invalidate the "resumability" | |
609 #of the session. | |
610 try: | |
611 alertMsg = Alert() | |
612 alertMsg.create(AlertDescription.close_notify, | |
613 AlertLevel.warning) | |
614 for result in self._sendMsg(alertMsg): | |
615 yield result | |
616 except socket.error: | |
617 pass | |
618 | |
619 if alert.description == \ | |
620 AlertDescription.close_notify: | |
621 self._shutdown(True) | |
622 elif alert.level == AlertLevel.warning: | |
623 self._shutdown(False) | |
624 | |
625 else: #Fatal alert: | |
626 self._shutdown(False) | |
627 | |
628 #Raise the alert as an exception | |
629 raise TLSRemoteAlert(alert) | |
630 | |
631 #If we received a renegotiation attempt... | |
632 if recordHeader.type == ContentType.handshake: | |
633 subType = p.get(1) | |
634 reneg = False | |
635 if self._client: | |
636 if subType == HandshakeType.hello_request: | |
637 reneg = True | |
638 else: | |
639 if subType == HandshakeType.client_hello: | |
640 reneg = True | |
641 #Send no_renegotiation, then try again | |
642 if reneg: | |
643 alertMsg = Alert() | |
644 alertMsg.create(AlertDescription.no_renegotiation, | |
645 AlertLevel.warning) | |
646 for result in self._sendMsg(alertMsg): | |
647 yield result | |
648 continue | |
649 | |
650 #Otherwise: this is an unexpected record, but neither an | |
651 #alert nor renegotiation | |
652 for result in self._sendError(\ | |
653 AlertDescription.unexpected_message, | |
654 "received type=%d" % recordHeader.type): | |
655 yield result | |
656 | |
657 break | |
658 | |
659 #Parse based on content_type | |
660 if recordHeader.type == ContentType.change_cipher_spec: | |
661 yield ChangeCipherSpec().parse(p) | |
662 elif recordHeader.type == ContentType.alert: | |
663 yield Alert().parse(p) | |
664 elif recordHeader.type == ContentType.application_data: | |
665 yield ApplicationData().parse(p) | |
666 elif recordHeader.type == ContentType.handshake: | |
667 #Convert secondaryType to tuple, if it isn't already | |
668 if not isinstance(secondaryType, tuple): | |
669 secondaryType = (secondaryType,) | |
670 | |
671 #If it's a handshake message, check handshake header | |
672 if recordHeader.ssl2: | |
673 subType = p.get(1) | |
674 if subType != HandshakeType.client_hello: | |
675 for result in self._sendError(\ | |
676 AlertDescription.unexpected_message, | |
677 "Can only handle SSLv2 ClientHello messages"): | |
678 yield result | |
679 if HandshakeType.client_hello not in secondaryType: | |
680 for result in self._sendError(\ | |
681 AlertDescription.unexpected_message): | |
682 yield result | |
683 subType = HandshakeType.client_hello | |
684 else: | |
685 subType = p.get(1) | |
686 if subType not in secondaryType: | |
687 for result in self._sendError(\ | |
688 AlertDescription.unexpected_message, | |
689 "Expecting %s, got %s" % (str(secondaryType), su
bType)): | |
690 yield result | |
691 | |
692 #Update handshake hashes | |
693 sToHash = bytesToString(p.bytes) | |
694 self._handshake_md5.update(sToHash) | |
695 self._handshake_sha.update(sToHash) | |
696 | |
697 #Parse based on handshake type | |
698 if subType == HandshakeType.client_hello: | |
699 yield ClientHello(recordHeader.ssl2).parse(p) | |
700 elif subType == HandshakeType.server_hello: | |
701 yield ServerHello().parse(p) | |
702 elif subType == HandshakeType.certificate: | |
703 yield Certificate(constructorType).parse(p) | |
704 elif subType == HandshakeType.certificate_request: | |
705 yield CertificateRequest().parse(p) | |
706 elif subType == HandshakeType.certificate_verify: | |
707 yield CertificateVerify().parse(p) | |
708 elif subType == HandshakeType.server_key_exchange: | |
709 yield ServerKeyExchange(constructorType).parse(p) | |
710 elif subType == HandshakeType.server_hello_done: | |
711 yield ServerHelloDone().parse(p) | |
712 elif subType == HandshakeType.client_key_exchange: | |
713 yield ClientKeyExchange(constructorType, \ | |
714 self.version).parse(p) | |
715 elif subType == HandshakeType.finished: | |
716 yield Finished(self.version).parse(p) | |
717 elif subType == HandshakeType.encrypted_extensions: | |
718 yield EncryptedExtensions().parse(p) | |
719 else: | |
720 raise AssertionError() | |
721 | |
722 #If an exception was raised by a Parser or Message instance: | |
723 except SyntaxError, e: | |
724 for result in self._sendError(AlertDescription.decode_error, | |
725 formatExceptionTrace(e)): | |
726 yield result | |
727 | |
728 | |
729 #Returns next record or next handshake message | |
730 def _getNextRecord(self): | |
731 | |
732 #If there's a handshake message waiting, return it | |
733 if self._handshakeBuffer: | |
734 recordHeader, bytes = self._handshakeBuffer[0] | |
735 self._handshakeBuffer = self._handshakeBuffer[1:] | |
736 yield (recordHeader, Parser(bytes)) | |
737 return | |
738 | |
739 #Otherwise... | |
740 #Read the next record header | |
741 bytes = createByteArraySequence([]) | |
742 recordHeaderLength = 1 | |
743 ssl2 = False | |
744 while 1: | |
745 try: | |
746 s = self.sock.recv(recordHeaderLength-len(bytes)) | |
747 except socket.error, why: | |
748 if why[0] == errno.EWOULDBLOCK: | |
749 yield 0 | |
750 continue | |
751 else: | |
752 raise | |
753 | |
754 #If the connection was abruptly closed, raise an error | |
755 if len(s)==0: | |
756 raise TLSAbruptCloseError() | |
757 | |
758 bytes += stringToBytes(s) | |
759 if len(bytes)==1: | |
760 if bytes[0] in ContentType.all: | |
761 ssl2 = False | |
762 recordHeaderLength = 5 | |
763 elif bytes[0] == 128: | |
764 ssl2 = True | |
765 recordHeaderLength = 2 | |
766 else: | |
767 raise SyntaxError() | |
768 if len(bytes) == recordHeaderLength: | |
769 break | |
770 | |
771 #Parse the record header | |
772 if ssl2: | |
773 r = RecordHeader2().parse(Parser(bytes)) | |
774 else: | |
775 r = RecordHeader3().parse(Parser(bytes)) | |
776 | |
777 #Check the record header fields | |
778 if r.length > 18432: | |
779 for result in self._sendError(AlertDescription.record_overflow): | |
780 yield result | |
781 | |
782 #Read the record contents | |
783 bytes = createByteArraySequence([]) | |
784 while 1: | |
785 try: | |
786 s = self.sock.recv(r.length - len(bytes)) | |
787 except socket.error, why: | |
788 if why[0] == errno.EWOULDBLOCK: | |
789 yield 0 | |
790 continue | |
791 else: | |
792 raise | |
793 | |
794 #If the connection is closed, raise a socket error | |
795 if len(s)==0: | |
796 raise TLSAbruptCloseError() | |
797 | |
798 bytes += stringToBytes(s) | |
799 if len(bytes) == r.length: | |
800 break | |
801 | |
802 #Check the record header fields (2) | |
803 #We do this after reading the contents from the socket, so that | |
804 #if there's an error, we at least don't leave extra bytes in the | |
805 #socket.. | |
806 # | |
807 # THIS CHECK HAS NO SECURITY RELEVANCE (?), BUT COULD HURT INTEROP. | |
808 # SO WE LEAVE IT OUT FOR NOW. | |
809 # | |
810 #if self._versionCheck and r.version != self.version: | |
811 # for result in self._sendError(AlertDescription.protocol_version, | |
812 # "Version in header field: %s, should be %s" % (str(r.versio
n), | |
813 # str(self.version
))): | |
814 # yield result | |
815 | |
816 #Decrypt the record | |
817 for result in self._decryptRecord(r.type, bytes): | |
818 if result in (0,1): | |
819 yield result | |
820 else: | |
821 break | |
822 bytes = result | |
823 p = Parser(bytes) | |
824 | |
825 #If it doesn't contain handshake messages, we can just return it | |
826 if r.type != ContentType.handshake: | |
827 yield (r, p) | |
828 #If it's an SSLv2 ClientHello, we can return it as well | |
829 elif r.ssl2: | |
830 yield (r, p) | |
831 else: | |
832 #Otherwise, we loop through and add the handshake messages to the | |
833 #handshake buffer | |
834 while 1: | |
835 if p.index == len(bytes): #If we're at the end | |
836 if not self._handshakeBuffer: | |
837 for result in self._sendError(\ | |
838 AlertDescription.decode_error, \ | |
839 "Received empty handshake record"): | |
840 yield result | |
841 break | |
842 #There needs to be at least 4 bytes to get a header | |
843 if p.index+4 > len(bytes): | |
844 for result in self._sendError(\ | |
845 AlertDescription.decode_error, | |
846 "A record has a partial handshake message (1)"): | |
847 yield result | |
848 p.get(1) # skip handshake type | |
849 msgLength = p.get(3) | |
850 if p.index+msgLength > len(bytes): | |
851 for result in self._sendError(\ | |
852 AlertDescription.decode_error, | |
853 "A record has a partial handshake message (2)"): | |
854 yield result | |
855 | |
856 handshakePair = (r, bytes[p.index-4 : p.index+msgLength]) | |
857 self._handshakeBuffer.append(handshakePair) | |
858 p.index += msgLength | |
859 | |
860 #We've moved at least one handshake message into the | |
861 #handshakeBuffer, return the first one | |
862 recordHeader, bytes = self._handshakeBuffer[0] | |
863 self._handshakeBuffer = self._handshakeBuffer[1:] | |
864 yield (recordHeader, Parser(bytes)) | |
865 | |
866 | |
867 def _decryptRecord(self, recordType, bytes): | |
868 if self._readState.encContext: | |
869 | |
870 #Decrypt if it's a block cipher | |
871 if self._readState.encContext.isBlockCipher: | |
872 blockLength = self._readState.encContext.block_size | |
873 if len(bytes) % blockLength != 0: | |
874 for result in self._sendError(\ | |
875 AlertDescription.decryption_failed, | |
876 "Encrypted data not a multiple of blocksize"): | |
877 yield result | |
878 ciphertext = bytesToString(bytes) | |
879 plaintext = self._readState.encContext.decrypt(ciphertext) | |
880 if self.version == (3,2): #For TLS 1.1, remove explicit IV | |
881 plaintext = plaintext[self._readState.encContext.block_size
: ] | |
882 bytes = stringToBytes(plaintext) | |
883 | |
884 #Check padding | |
885 paddingGood = True | |
886 paddingLength = bytes[-1] | |
887 if (paddingLength+1) > len(bytes): | |
888 paddingGood=False | |
889 totalPaddingLength = 0 | |
890 else: | |
891 if self.version == (3,0): | |
892 totalPaddingLength = paddingLength+1 | |
893 elif self.version in ((3,1), (3,2)): | |
894 totalPaddingLength = paddingLength+1 | |
895 paddingBytes = bytes[-totalPaddingLength:-1] | |
896 for byte in paddingBytes: | |
897 if byte != paddingLength: | |
898 paddingGood = False | |
899 totalPaddingLength = 0 | |
900 else: | |
901 raise AssertionError() | |
902 | |
903 #Decrypt if it's a stream cipher | |
904 else: | |
905 paddingGood = True | |
906 ciphertext = bytesToString(bytes) | |
907 plaintext = self._readState.encContext.decrypt(ciphertext) | |
908 bytes = stringToBytes(plaintext) | |
909 totalPaddingLength = 0 | |
910 | |
911 #Check MAC | |
912 macGood = True | |
913 macLength = self._readState.macContext.digest_size | |
914 endLength = macLength + totalPaddingLength | |
915 if endLength > len(bytes): | |
916 macGood = False | |
917 else: | |
918 #Read MAC | |
919 startIndex = len(bytes) - endLength | |
920 endIndex = startIndex + macLength | |
921 checkBytes = bytes[startIndex : endIndex] | |
922 | |
923 #Calculate MAC | |
924 seqnumStr = self._readState.getSeqNumStr() | |
925 bytes = bytes[:-endLength] | |
926 bytesStr = bytesToString(bytes) | |
927 mac = self._readState.macContext.copy() | |
928 mac.update(seqnumStr) | |
929 mac.update(chr(recordType)) | |
930 if self.version == (3,0): | |
931 mac.update( chr( int(len(bytes)/256) ) ) | |
932 mac.update( chr( int(len(bytes)%256) ) ) | |
933 elif self.version in ((3,1), (3,2)): | |
934 mac.update(chr(self.version[0])) | |
935 mac.update(chr(self.version[1])) | |
936 mac.update( chr( int(len(bytes)/256) ) ) | |
937 mac.update( chr( int(len(bytes)%256) ) ) | |
938 else: | |
939 raise AssertionError() | |
940 mac.update(bytesStr) | |
941 macString = mac.digest() | |
942 macBytes = stringToBytes(macString) | |
943 | |
944 #Compare MACs | |
945 if macBytes != checkBytes: | |
946 macGood = False | |
947 | |
948 if not (paddingGood and macGood): | |
949 for result in self._sendError(AlertDescription.bad_record_mac, | |
950 "MAC failure (or padding failure)"): | |
951 yield result | |
952 | |
953 yield bytes | |
954 | |
955 def _handshakeStart(self, client): | |
956 self._client = client | |
957 self._handshake_md5 = md5.md5() | |
958 self._handshake_sha = sha.sha() | |
959 self._handshakeBuffer = [] | |
960 self.allegedSharedKeyUsername = None | |
961 self.allegedSrpUsername = None | |
962 self._refCount = 1 | |
963 | |
964 def _handshakeDone(self, resumed): | |
965 self.resumed = resumed | |
966 self.closed = False | |
967 | |
968 def _calcPendingStates(self, clientRandom, serverRandom, implementations): | |
969 if self.session.cipherSuite in CipherSuite.aes128Suites: | |
970 macLength = 20 | |
971 keyLength = 16 | |
972 ivLength = 16 | |
973 createCipherFunc = createAES | |
974 elif self.session.cipherSuite in CipherSuite.aes256Suites: | |
975 macLength = 20 | |
976 keyLength = 32 | |
977 ivLength = 16 | |
978 createCipherFunc = createAES | |
979 elif self.session.cipherSuite in CipherSuite.rc4Suites: | |
980 macLength = 20 | |
981 keyLength = 16 | |
982 ivLength = 0 | |
983 createCipherFunc = createRC4 | |
984 elif self.session.cipherSuite in CipherSuite.tripleDESSuites: | |
985 macLength = 20 | |
986 keyLength = 24 | |
987 ivLength = 8 | |
988 createCipherFunc = createTripleDES | |
989 else: | |
990 raise AssertionError() | |
991 | |
992 if self.version == (3,0): | |
993 createMACFunc = MAC_SSL | |
994 elif self.version in ((3,1), (3,2)): | |
995 createMACFunc = hmac.HMAC | |
996 | |
997 outputLength = (macLength*2) + (keyLength*2) + (ivLength*2) | |
998 | |
999 #Calculate Keying Material from Master Secret | |
1000 if self.version == (3,0): | |
1001 keyBlock = PRF_SSL(self.session.masterSecret, | |
1002 concatArrays(serverRandom, clientRandom), | |
1003 outputLength) | |
1004 elif self.version in ((3,1), (3,2)): | |
1005 keyBlock = PRF(self.session.masterSecret, | |
1006 "key expansion", | |
1007 concatArrays(serverRandom,clientRandom), | |
1008 outputLength) | |
1009 else: | |
1010 raise AssertionError() | |
1011 | |
1012 #Slice up Keying Material | |
1013 clientPendingState = _ConnectionState() | |
1014 serverPendingState = _ConnectionState() | |
1015 p = Parser(keyBlock) | |
1016 clientMACBlock = bytesToString(p.getFixBytes(macLength)) | |
1017 serverMACBlock = bytesToString(p.getFixBytes(macLength)) | |
1018 clientKeyBlock = bytesToString(p.getFixBytes(keyLength)) | |
1019 serverKeyBlock = bytesToString(p.getFixBytes(keyLength)) | |
1020 clientIVBlock = bytesToString(p.getFixBytes(ivLength)) | |
1021 serverIVBlock = bytesToString(p.getFixBytes(ivLength)) | |
1022 clientPendingState.macContext = createMACFunc(clientMACBlock, | |
1023 digestmod=sha) | |
1024 serverPendingState.macContext = createMACFunc(serverMACBlock, | |
1025 digestmod=sha) | |
1026 clientPendingState.encContext = createCipherFunc(clientKeyBlock, | |
1027 clientIVBlock, | |
1028 implementations) | |
1029 serverPendingState.encContext = createCipherFunc(serverKeyBlock, | |
1030 serverIVBlock, | |
1031 implementations) | |
1032 | |
1033 #Assign new connection states to pending states | |
1034 if self._client: | |
1035 self._pendingWriteState = clientPendingState | |
1036 self._pendingReadState = serverPendingState | |
1037 else: | |
1038 self._pendingWriteState = serverPendingState | |
1039 self._pendingReadState = clientPendingState | |
1040 | |
1041 if self.version == (3,2) and ivLength: | |
1042 #Choose fixedIVBlock for TLS 1.1 (this is encrypted with the CBC | |
1043 #residue to create the IV for each sent block) | |
1044 self.fixedIVBlock = getRandomBytes(ivLength) | |
1045 | |
1046 def _changeWriteState(self): | |
1047 self._writeState = self._pendingWriteState | |
1048 self._pendingWriteState = _ConnectionState() | |
1049 | |
1050 def _changeReadState(self): | |
1051 self._readState = self._pendingReadState | |
1052 self._pendingReadState = _ConnectionState() | |
1053 | |
1054 def _sendFinished(self): | |
1055 #Send ChangeCipherSpec | |
1056 for result in self._sendMsg(ChangeCipherSpec()): | |
1057 yield result | |
1058 | |
1059 #Switch to pending write state | |
1060 self._changeWriteState() | |
1061 | |
1062 #Calculate verification data | |
1063 verifyData = self._calcFinished(True) | |
1064 if self.fault == Fault.badFinished: | |
1065 verifyData[0] = (verifyData[0]+1)%256 | |
1066 | |
1067 #Send Finished message under new state | |
1068 finished = Finished(self.version).create(verifyData) | |
1069 for result in self._sendMsg(finished): | |
1070 yield result | |
1071 | |
1072 def _getChangeCipherSpec(self): | |
1073 #Get and check ChangeCipherSpec | |
1074 for result in self._getMsg(ContentType.change_cipher_spec): | |
1075 if result in (0,1): | |
1076 yield result | |
1077 changeCipherSpec = result | |
1078 | |
1079 if changeCipherSpec.type != 1: | |
1080 for result in self._sendError(AlertDescription.illegal_parameter, | |
1081 "ChangeCipherSpec type incorrect"): | |
1082 yield result | |
1083 | |
1084 #Switch to pending read state | |
1085 self._changeReadState() | |
1086 | |
1087 def _getEncryptedExtensions(self): | |
1088 for result in self._getMsg(ContentType.handshake, | |
1089 HandshakeType.encrypted_extensions): | |
1090 if result in (0,1): | |
1091 yield result | |
1092 encrypted_extensions = result | |
1093 self.channel_id = encrypted_extensions.channel_id_key | |
1094 | |
1095 def _getFinished(self): | |
1096 #Calculate verification data | |
1097 verifyData = self._calcFinished(False) | |
1098 | |
1099 #Get and check Finished message under new state | |
1100 for result in self._getMsg(ContentType.handshake, | |
1101 HandshakeType.finished): | |
1102 if result in (0,1): | |
1103 yield result | |
1104 finished = result | |
1105 if finished.verify_data != verifyData: | |
1106 for result in self._sendError(AlertDescription.decrypt_error, | |
1107 "Finished message is incorrect"): | |
1108 yield result | |
1109 | |
1110 def _calcFinished(self, send=True): | |
1111 if self.version == (3,0): | |
1112 if (self._client and send) or (not self._client and not send): | |
1113 senderStr = "\x43\x4C\x4E\x54" | |
1114 else: | |
1115 senderStr = "\x53\x52\x56\x52" | |
1116 | |
1117 verifyData = self._calcSSLHandshakeHash(self.session.masterSecret, | |
1118 senderStr) | |
1119 return verifyData | |
1120 | |
1121 elif self.version in ((3,1), (3,2)): | |
1122 if (self._client and send) or (not self._client and not send): | |
1123 label = "client finished" | |
1124 else: | |
1125 label = "server finished" | |
1126 | |
1127 handshakeHashes = stringToBytes(self._handshake_md5.digest() + \ | |
1128 self._handshake_sha.digest()) | |
1129 verifyData = PRF(self.session.masterSecret, label, handshakeHashes, | |
1130 12) | |
1131 return verifyData | |
1132 else: | |
1133 raise AssertionError() | |
1134 | |
1135 #Used for Finished messages and CertificateVerify messages in SSL v3 | |
1136 def _calcSSLHandshakeHash(self, masterSecret, label): | |
1137 masterSecretStr = bytesToString(masterSecret) | |
1138 | |
1139 imac_md5 = self._handshake_md5.copy() | |
1140 imac_sha = self._handshake_sha.copy() | |
1141 | |
1142 imac_md5.update(label + masterSecretStr + '\x36'*48) | |
1143 imac_sha.update(label + masterSecretStr + '\x36'*40) | |
1144 | |
1145 md5Str = md5.md5(masterSecretStr + ('\x5c'*48) + \ | |
1146 imac_md5.digest()).digest() | |
1147 shaStr = sha.sha(masterSecretStr + ('\x5c'*40) + \ | |
1148 imac_sha.digest()).digest() | |
1149 | |
1150 return stringToBytes(md5Str + shaStr) | |
OLD | NEW |