Index: third_party/tlslite/tlslite/tlsrecordlayer.py |
diff --git a/third_party/tlslite/tlslite/tlsrecordlayer.py b/third_party/tlslite/tlslite/tlsrecordlayer.py |
index bcfd64061ac4f98650ae0156a73e778d6ef2fb48..8b92221ae2437a18c4c7aa6cb4cf7c13b20bb524 100644 |
--- a/third_party/tlslite/tlslite/tlsrecordlayer.py |
+++ b/third_party/tlslite/tlslite/tlsrecordlayer.py |
@@ -1,49 +1,41 @@ |
+# Authors: |
+# Trevor Perrin |
+# Google (adapted by Sam Rushing) - NPN support |
+# Martin von Loewis - python 3 port |
+# |
+# See the LICENSE file for legal information regarding use of this file. |
+ |
"""Helper class for TLSConnection.""" |
from __future__ import generators |
-from utils.compat import * |
-from utils.cryptomath import * |
-from utils.cipherfactory import createAES, createRC4, createTripleDES |
-from utils.codec import * |
-from errors import * |
-from messages import * |
-from mathtls import * |
-from constants import * |
-from utils.cryptomath import getRandomBytes |
-from utils import hmac |
-from fileobject import FileObject |
- |
-# The sha module is deprecated in Python 2.6 |
-try: |
- import sha |
-except ImportError: |
- from hashlib import sha1 as sha |
- |
-# The md5 module is deprecated in Python 2.6 |
-try: |
- import md5 |
-except ImportError: |
- from hashlib import md5 |
+from .utils.compat import * |
+from .utils.cryptomath import * |
+from .utils.cipherfactory import createAES, createRC4, createTripleDES |
+from .utils.codec import * |
+from .errors import * |
+from .messages import * |
+from .mathtls import * |
+from .constants import * |
+from .utils.cryptomath import getRandomBytes |
import socket |
import errno |
import traceback |
-class _ConnectionState: |
+class _ConnectionState(object): |
def __init__(self): |
self.macContext = None |
self.encContext = None |
self.seqnum = 0 |
- def getSeqNumStr(self): |
- w = Writer(8) |
+ def getSeqNumBytes(self): |
+ w = Writer() |
w.add(self.seqnum, 8) |
- seqnumStr = bytesToString(w.bytes) |
self.seqnum += 1 |
- return seqnumStr |
+ return w.bytes |
-class TLSRecordLayer: |
+class TLSRecordLayer(object): |
""" |
This class handles data transmission for a TLS connection. |
@@ -72,13 +64,6 @@ class TLSRecordLayer: |
@type resumed: bool |
@ivar resumed: If this connection is based on a resumed session. |
- @type allegedSharedKeyUsername: str or None |
- @ivar allegedSharedKeyUsername: This is set to the shared-key |
- username asserted by the client, whether the handshake succeeded or |
- not. If the handshake fails, this can be inspected to |
- determine if a guessing attack is in progress against a particular |
- user account. |
- |
@type allegedSrpUsername: str or None |
@ivar allegedSrpUsername: This is set to the SRP username |
asserted by the client, whether the handshake succeeded or not. |
@@ -88,7 +73,7 @@ class TLSRecordLayer: |
@type closeSocket: bool |
@ivar closeSocket: If the socket should be closed when the |
- connection is closed (writable). |
+ connection is closed, defaults to True (writable). |
If you set this to True, TLS Lite will assume the responsibility of |
closing the socket when the TLS Connection is shutdown (either |
@@ -124,11 +109,12 @@ class TLSRecordLayer: |
#Buffers for processing messages |
self._handshakeBuffer = [] |
- self._readBuffer = "" |
+ self.clearReadBuffer() |
+ self.clearWriteBuffer() |
#Handshake digests |
- self._handshake_md5 = md5.md5() |
- self._handshake_sha = sha.sha() |
+ self._handshake_md5 = hashlib.md5() |
+ self._handshake_sha = hashlib.sha1() |
#TLS Protocol Version |
self.version = (0,0) #read-only |
@@ -144,15 +130,14 @@ class TLSRecordLayer: |
self.closed = True #read-only |
self._refCount = 0 #Used to trigger closure |
- #Is this a resumed (or shared-key) session? |
+ #Is this a resumed session? |
self.resumed = False #read-only |
#What username did the client claim in his handshake? |
- self.allegedSharedKeyUsername = None |
self.allegedSrpUsername = None |
#On a call to close(), do we close the socket? (writeable) |
- self.closeSocket = False |
+ self.closeSocket = True |
#If the socket is abruptly closed, do we ignore it |
#and pretend the connection was shut down properly? (writeable) |
@@ -161,6 +146,13 @@ class TLSRecordLayer: |
#Fault we will induce, for testing purposes |
self.fault = None |
+ def clearReadBuffer(self): |
+ self._readBuffer = b'' |
+ |
+ def clearWriteBuffer(self): |
+ self._send_writer = None |
+ |
+ |
#********************************************************* |
# Public Functions START |
#********************************************************* |
@@ -213,8 +205,8 @@ class TLSRecordLayer: |
if result in (0,1): |
yield result |
applicationData = result |
- self._readBuffer += bytesToString(applicationData.write()) |
- except TLSRemoteAlert, alert: |
+ self._readBuffer += applicationData.write() |
+ except TLSRemoteAlert as alert: |
if alert.description != AlertDescription.close_notify: |
raise |
except TLSAbruptCloseError: |
@@ -226,13 +218,23 @@ class TLSRecordLayer: |
if max == None: |
max = len(self._readBuffer) |
- returnStr = self._readBuffer[:max] |
+ returnBytes = self._readBuffer[:max] |
self._readBuffer = self._readBuffer[max:] |
- yield returnStr |
+ yield bytes(returnBytes) |
+ except GeneratorExit: |
+ raise |
except: |
self._shutdown(False) |
raise |
+ def unread(self, b): |
+ """Add bytes to the front of the socket read buffer for future |
+ reading. Be careful using this in the context of select(...): if you |
+ unread the last data from a socket, that won't wake up selected waiters, |
+ and those waiters may hang forever. |
+ """ |
+ self._readBuffer = b + self._readBuffer |
+ |
def write(self, s): |
"""Write some data to the TLS connection. |
@@ -262,11 +264,11 @@ class TLSRecordLayer: |
""" |
try: |
if self.closed: |
- raise ValueError() |
+ raise TLSClosedConnectionError("attempt to write to closed connection") |
index = 0 |
blockSize = 16384 |
- skipEmptyFrag = False |
+ randomizeFirstBlock = True |
while 1: |
startIndex = index * blockSize |
endIndex = startIndex + blockSize |
@@ -274,13 +276,16 @@ class TLSRecordLayer: |
break |
if endIndex > len(s): |
endIndex = len(s) |
- block = stringToBytes(s[startIndex : endIndex]) |
+ block = bytearray(s[startIndex : endIndex]) |
applicationData = ApplicationData().create(block) |
- for result in self._sendMsg(applicationData, skipEmptyFrag): |
+ for result in self._sendMsg(applicationData, \ |
+ randomizeFirstBlock): |
yield result |
- skipEmptyFrag = True #only send an empy fragment on 1st message |
+ randomizeFirstBlock = False #only on 1st message |
index += 1 |
- except: |
+ except GeneratorExit: |
+ raise |
+ except Exception: |
self._shutdown(False) |
raise |
@@ -309,6 +314,9 @@ class TLSRecordLayer: |
for result in self._decrefAsync(): |
pass |
+ # Python 3 callback |
+ _decref_socketios = close |
+ |
def closeAsync(self): |
"""Start a close operation on the TLS connection. |
@@ -333,30 +341,49 @@ class TLSRecordLayer: |
AlertDescription.close_notify, AlertLevel.warning)): |
yield result |
alert = None |
- # Forcing a shutdown as WinHTTP does not seem to be |
- # responsive to the close notify. |
- prevCloseSocket = self.closeSocket |
- self.closeSocket = True |
- self._shutdown(True) |
- self.closeSocket = prevCloseSocket |
- while not alert: |
- for result in self._getMsg((ContentType.alert, \ |
- ContentType.application_data)): |
- if result in (0,1): |
- yield result |
- if result.contentType == ContentType.alert: |
- alert = result |
- if alert.description == AlertDescription.close_notify: |
+ # By default close the socket, since it's been observed |
+ # that some other libraries will not respond to the |
+ # close_notify alert, thus leaving us hanging if we're |
+ # expecting it |
+ if self.closeSocket: |
self._shutdown(True) |
else: |
- raise TLSRemoteAlert(alert) |
+ while not alert: |
+ for result in self._getMsg((ContentType.alert, \ |
+ ContentType.application_data)): |
+ if result in (0,1): |
+ yield result |
+ if result.contentType == ContentType.alert: |
+ alert = result |
+ if alert.description == AlertDescription.close_notify: |
+ self._shutdown(True) |
+ else: |
+ raise TLSRemoteAlert(alert) |
except (socket.error, TLSAbruptCloseError): |
#If the other side closes the socket, that's okay |
self._shutdown(True) |
+ except GeneratorExit: |
+ raise |
except: |
self._shutdown(False) |
raise |
+ def getVersionName(self): |
+ """Get the name of this TLS version. |
+ |
+ @rtype: str |
+ @return: The name of the TLS version used with this connection. |
+ Either None, 'SSL 3.0', 'TLS 1.0', or 'TLS 1.1'. |
+ """ |
+ if self.version == (3,0): |
+ return "SSL 3.0" |
+ elif self.version == (3,1): |
+ return "TLS 1.0" |
+ elif self.version == (3,2): |
+ return "TLS 1.1" |
+ else: |
+ return None |
+ |
def getCipherName(self): |
"""Get the name of the cipher used with this connection. |
@@ -374,8 +401,7 @@ class TLSRecordLayer: |
@rtype: str |
@return: The name of the cipher implementation used with |
- this connection. Either 'python', 'cryptlib', 'openssl', |
- or 'pycrypto'. |
+ this connection. Either 'python', 'openssl', or 'pycrypto'. |
""" |
if not self._writeState.encContext: |
return None |
@@ -409,13 +435,35 @@ class TLSRecordLayer: |
""" |
return self.read(bufsize) |
+ def recv_into(self, b): |
+ # XXX doc string |
+ data = self.read(len(b)) |
+ if not data: |
+ return None |
+ b[:len(data)] = data |
+ return len(data) |
+ |
def makefile(self, mode='r', bufsize=-1): |
"""Create a file object for the TLS connection (socket emulation). |
- @rtype: L{tlslite.FileObject.FileObject} |
+ @rtype: L{socket._fileobject} |
""" |
self._refCount += 1 |
- return FileObject(self, mode, bufsize) |
+ # So, it is pretty fragile to be using Python internal objects |
+ # like this, but it is probably the best/easiest way to provide |
+ # matching behavior for socket emulation purposes. The 'close' |
+ # argument is nice, its apparently a recent addition to this |
+ # class, so that when fileobject.close() gets called, it will |
+ # close() us, causing the refcount to be decremented (decrefAsync). |
+ # |
+ # If this is the last close() on the outstanding fileobjects / |
+ # TLSConnection, then the "actual" close alerts will be sent, |
+ # socket closed, etc. |
+ if sys.version_info < (3,): |
+ return socket._fileobject(self, mode, bufsize, close=True) |
+ else: |
+ # XXX need to wrap this further if buffering is requested |
+ return socket.SocketIO(self, mode) |
def getsockname(self): |
"""Return the socket's own address (socket emulation).""" |
@@ -439,6 +487,14 @@ class TLSRecordLayer: |
"""Set the value of the given socket option (socket emulation).""" |
return self.sock.setsockopt(level, optname, value) |
+ def shutdown(self, how): |
+ """Shutdown the underlying socket.""" |
+ return self.sock.shutdown(how) |
+ |
+ def fileno(self): |
+ """Not implement in TLS Lite.""" |
+ raise NotImplementedError() |
+ |
#********************************************************* |
# Public Functions END |
@@ -447,7 +503,6 @@ class TLSRecordLayer: |
def _shutdown(self, resumable): |
self._writeState = _ConnectionState() |
self._readState = _ConnectionState() |
- #Don't do this: self._readBuffer = "" |
self.version = (0,0) |
self._versionCheck = False |
self.closed = True |
@@ -467,53 +522,58 @@ class TLSRecordLayer: |
raise TLSLocalAlert(alert, errorStr) |
def _sendMsgs(self, msgs): |
- skipEmptyFrag = False |
+ randomizeFirstBlock = True |
for msg in msgs: |
- for result in self._sendMsg(msg, skipEmptyFrag): |
+ for result in self._sendMsg(msg, randomizeFirstBlock): |
yield result |
- skipEmptyFrag = True |
+ randomizeFirstBlock = True |
- def _sendMsg(self, msg, skipEmptyFrag=False): |
- bytes = msg.write() |
- contentType = msg.contentType |
- |
- #Whenever we're connected and asked to send a message, |
- #we first send an empty Application Data message. This prevents |
+ def _sendMsg(self, msg, randomizeFirstBlock = True): |
+ #Whenever we're connected and asked to send an app data message, |
+ #we first send the first byte of the message. This prevents |
#an attacker from launching a chosen-plaintext attack based on |
- #knowing the next IV. |
- if not self.closed and not skipEmptyFrag and self.version == (3,1): |
- if self._writeState.encContext: |
- if self._writeState.encContext.isBlockCipher: |
- for result in self._sendMsg(ApplicationData(), |
- skipEmptyFrag=True): |
- yield result |
+ #knowing the next IV (a la BEAST). |
+ if not self.closed and randomizeFirstBlock and self.version <= (3,1) \ |
+ and self._writeState.encContext \ |
+ and self._writeState.encContext.isBlockCipher \ |
+ and isinstance(msg, ApplicationData): |
+ msgFirstByte = msg.splitFirstByte() |
+ for result in self._sendMsg(msgFirstByte, |
+ randomizeFirstBlock = False): |
+ yield result |
+ |
+ b = msg.write() |
+ |
+ # If a 1-byte message was passed in, and we "split" the |
+ # first(only) byte off above, we may have a 0-length msg: |
+ if len(b) == 0: |
+ return |
+ |
+ contentType = msg.contentType |
#Update handshake hashes |
if contentType == ContentType.handshake: |
- bytesStr = bytesToString(bytes) |
- self._handshake_md5.update(bytesStr) |
- self._handshake_sha.update(bytesStr) |
+ self._handshake_md5.update(compat26Str(b)) |
+ self._handshake_sha.update(compat26Str(b)) |
#Calculate MAC |
if self._writeState.macContext: |
- seqnumStr = self._writeState.getSeqNumStr() |
- bytesStr = bytesToString(bytes) |
+ seqnumBytes = self._writeState.getSeqNumBytes() |
mac = self._writeState.macContext.copy() |
- mac.update(seqnumStr) |
- mac.update(chr(contentType)) |
+ mac.update(compatHMAC(seqnumBytes)) |
+ mac.update(compatHMAC(bytearray([contentType]))) |
if self.version == (3,0): |
- mac.update( chr( int(len(bytes)/256) ) ) |
- mac.update( chr( int(len(bytes)%256) ) ) |
+ mac.update( compatHMAC( bytearray([len(b)//256] ))) |
+ mac.update( compatHMAC( bytearray([len(b)%256] ))) |
elif self.version in ((3,1), (3,2)): |
- mac.update(chr(self.version[0])) |
- mac.update(chr(self.version[1])) |
- mac.update( chr( int(len(bytes)/256) ) ) |
- mac.update( chr( int(len(bytes)%256) ) ) |
+ mac.update(compatHMAC( bytearray([self.version[0]] ))) |
+ mac.update(compatHMAC( bytearray([self.version[1]] ))) |
+ mac.update( compatHMAC( bytearray([len(b)//256] ))) |
+ mac.update( compatHMAC( bytearray([len(b)%256] ))) |
else: |
raise AssertionError() |
- mac.update(bytesStr) |
- macString = mac.digest() |
- macBytes = stringToBytes(macString) |
+ mac.update(compatHMAC(b)) |
+ macBytes = bytearray(mac.digest()) |
if self.fault == Fault.badMAC: |
macBytes[0] = (macBytes[0]+1) % 256 |
@@ -524,43 +584,68 @@ class TLSRecordLayer: |
#Add TLS 1.1 fixed block |
if self.version == (3,2): |
- bytes = self.fixedIVBlock + bytes |
+ b = self.fixedIVBlock + b |
- #Add padding: bytes = bytes + (macBytes + paddingBytes) |
- currentLength = len(bytes) + len(macBytes) + 1 |
+ #Add padding: b = b + (macBytes + paddingBytes) |
+ currentLength = len(b) + len(macBytes) |
blockLength = self._writeState.encContext.block_size |
- paddingLength = blockLength-(currentLength % blockLength) |
+ paddingLength = blockLength - 1 - (currentLength % blockLength) |
- paddingBytes = createByteArraySequence([paddingLength] * \ |
- (paddingLength+1)) |
+ paddingBytes = bytearray([paddingLength] * (paddingLength+1)) |
if self.fault == Fault.badPadding: |
paddingBytes[0] = (paddingBytes[0]+1) % 256 |
- endBytes = concatArrays(macBytes, paddingBytes) |
- bytes = concatArrays(bytes, endBytes) |
+ endBytes = macBytes + paddingBytes |
+ b += endBytes |
#Encrypt |
- plaintext = stringToBytes(bytes) |
- ciphertext = self._writeState.encContext.encrypt(plaintext) |
- bytes = stringToBytes(ciphertext) |
+ b = self._writeState.encContext.encrypt(b) |
#Encrypt (for Stream Cipher) |
else: |
- bytes = concatArrays(bytes, macBytes) |
- plaintext = bytesToString(bytes) |
- ciphertext = self._writeState.encContext.encrypt(plaintext) |
- bytes = stringToBytes(ciphertext) |
+ b += macBytes |
+ b = self._writeState.encContext.encrypt(b) |
#Add record header and send |
- r = RecordHeader3().create(self.version, contentType, len(bytes)) |
- s = bytesToString(concatArrays(r.write(), bytes)) |
+ r = RecordHeader3().create(self.version, contentType, len(b)) |
+ s = r.write() + b |
while 1: |
try: |
bytesSent = self.sock.send(s) #Might raise socket.error |
- except socket.error, why: |
- if why[0] == errno.EWOULDBLOCK: |
+ except socket.error as why: |
+ if why.args[0] in (errno.EWOULDBLOCK, errno.EAGAIN): |
yield 1 |
continue |
else: |
- raise |
+ # The socket was unexpectedly closed. The tricky part |
+ # is that there may be an alert sent by the other party |
+ # sitting in the read buffer. So, if we get here after |
+ # handshaking, we will just raise the error and let the |
+ # caller read more data if it would like, thus stumbling |
+ # upon the error. |
+ # |
+ # However, if we get here DURING handshaking, we take |
+ # it upon ourselves to see if the next message is an |
+ # Alert. |
+ if contentType == ContentType.handshake: |
+ |
+ # See if there's an alert record |
+ # Could raise socket.error or TLSAbruptCloseError |
+ for result in self._getNextRecord(): |
+ if result in (0,1): |
+ yield result |
+ |
+ # Closes the socket |
+ self._shutdown(False) |
+ |
+ # If we got an alert, raise it |
+ recordHeader, p = result |
+ if recordHeader.type == ContentType.alert: |
+ alert = Alert().parse(p) |
+ raise TLSRemoteAlert(alert) |
+ else: |
+ # If we got some other message who know what |
+ # the remote side is doing, just go ahead and |
+ # raise the socket.error |
+ raise |
if bytesSent == len(s): |
return |
s = s[bytesSent:] |
@@ -690,9 +775,8 @@ class TLSRecordLayer: |
yield result |
#Update handshake hashes |
- sToHash = bytesToString(p.bytes) |
- self._handshake_md5.update(sToHash) |
- self._handshake_sha.update(sToHash) |
+ self._handshake_md5.update(compat26Str(p.bytes)) |
+ self._handshake_sha.update(compat26Str(p.bytes)) |
#Parse based on handshake type |
if subType == HandshakeType.client_hello: |
@@ -714,13 +798,15 @@ class TLSRecordLayer: |
self.version).parse(p) |
elif subType == HandshakeType.finished: |
yield Finished(self.version).parse(p) |
+ elif subType == HandshakeType.next_protocol: |
+ yield NextProtocol().parse(p) |
elif subType == HandshakeType.encrypted_extensions: |
yield EncryptedExtensions().parse(p) |
else: |
raise AssertionError() |
#If an exception was raised by a Parser or Message instance: |
- except SyntaxError, e: |
+ except SyntaxError as e: |
for result in self._sendError(AlertDescription.decode_error, |
formatExceptionTrace(e)): |
yield result |
@@ -731,21 +817,21 @@ class TLSRecordLayer: |
#If there's a handshake message waiting, return it |
if self._handshakeBuffer: |
- recordHeader, bytes = self._handshakeBuffer[0] |
+ recordHeader, b = self._handshakeBuffer[0] |
self._handshakeBuffer = self._handshakeBuffer[1:] |
- yield (recordHeader, Parser(bytes)) |
+ yield (recordHeader, Parser(b)) |
return |
#Otherwise... |
#Read the next record header |
- bytes = createByteArraySequence([]) |
+ b = bytearray(0) |
recordHeaderLength = 1 |
ssl2 = False |
while 1: |
try: |
- s = self.sock.recv(recordHeaderLength-len(bytes)) |
- except socket.error, why: |
- if why[0] == errno.EWOULDBLOCK: |
+ s = self.sock.recv(recordHeaderLength-len(b)) |
+ except socket.error as why: |
+ if why.args[0] in (errno.EWOULDBLOCK, errno.EAGAIN): |
yield 0 |
continue |
else: |
@@ -755,24 +841,24 @@ class TLSRecordLayer: |
if len(s)==0: |
raise TLSAbruptCloseError() |
- bytes += stringToBytes(s) |
- if len(bytes)==1: |
- if bytes[0] in ContentType.all: |
+ b += bytearray(s) |
+ if len(b)==1: |
+ if b[0] in ContentType.all: |
ssl2 = False |
recordHeaderLength = 5 |
- elif bytes[0] == 128: |
+ elif b[0] == 128: |
ssl2 = True |
recordHeaderLength = 2 |
else: |
raise SyntaxError() |
- if len(bytes) == recordHeaderLength: |
+ if len(b) == recordHeaderLength: |
break |
#Parse the record header |
if ssl2: |
- r = RecordHeader2().parse(Parser(bytes)) |
+ r = RecordHeader2().parse(Parser(b)) |
else: |
- r = RecordHeader3().parse(Parser(bytes)) |
+ r = RecordHeader3().parse(Parser(b)) |
#Check the record header fields |
if r.length > 18432: |
@@ -780,12 +866,12 @@ class TLSRecordLayer: |
yield result |
#Read the record contents |
- bytes = createByteArraySequence([]) |
+ b = bytearray(0) |
while 1: |
try: |
- s = self.sock.recv(r.length - len(bytes)) |
- except socket.error, why: |
- if why[0] == errno.EWOULDBLOCK: |
+ s = self.sock.recv(r.length - len(b)) |
+ except socket.error as why: |
+ if why.args[0] in (errno.EWOULDBLOCK, errno.EAGAIN): |
yield 0 |
continue |
else: |
@@ -795,8 +881,8 @@ class TLSRecordLayer: |
if len(s)==0: |
raise TLSAbruptCloseError() |
- bytes += stringToBytes(s) |
- if len(bytes) == r.length: |
+ b += bytearray(s) |
+ if len(b) == r.length: |
break |
#Check the record header fields (2) |
@@ -814,13 +900,11 @@ class TLSRecordLayer: |
# yield result |
#Decrypt the record |
- for result in self._decryptRecord(r.type, bytes): |
- if result in (0,1): |
- yield result |
- else: |
- break |
- bytes = result |
- p = Parser(bytes) |
+ for result in self._decryptRecord(r.type, b): |
+ if result in (0,1): yield result |
+ else: break |
+ b = result |
+ p = Parser(b) |
#If it doesn't contain handshake messages, we can just return it |
if r.type != ContentType.handshake: |
@@ -832,7 +916,7 @@ class TLSRecordLayer: |
#Otherwise, we loop through and add the handshake messages to the |
#handshake buffer |
while 1: |
- if p.index == len(bytes): #If we're at the end |
+ if p.index == len(b): #If we're at the end |
if not self._handshakeBuffer: |
for result in self._sendError(\ |
AlertDescription.decode_error, \ |
@@ -840,51 +924,49 @@ class TLSRecordLayer: |
yield result |
break |
#There needs to be at least 4 bytes to get a header |
- if p.index+4 > len(bytes): |
+ if p.index+4 > len(b): |
for result in self._sendError(\ |
AlertDescription.decode_error, |
"A record has a partial handshake message (1)"): |
yield result |
p.get(1) # skip handshake type |
msgLength = p.get(3) |
- if p.index+msgLength > len(bytes): |
+ if p.index+msgLength > len(b): |
for result in self._sendError(\ |
AlertDescription.decode_error, |
"A record has a partial handshake message (2)"): |
yield result |
- handshakePair = (r, bytes[p.index-4 : p.index+msgLength]) |
+ handshakePair = (r, b[p.index-4 : p.index+msgLength]) |
self._handshakeBuffer.append(handshakePair) |
p.index += msgLength |
#We've moved at least one handshake message into the |
#handshakeBuffer, return the first one |
- recordHeader, bytes = self._handshakeBuffer[0] |
+ recordHeader, b = self._handshakeBuffer[0] |
self._handshakeBuffer = self._handshakeBuffer[1:] |
- yield (recordHeader, Parser(bytes)) |
+ yield (recordHeader, Parser(b)) |
- def _decryptRecord(self, recordType, bytes): |
+ def _decryptRecord(self, recordType, b): |
if self._readState.encContext: |
#Decrypt if it's a block cipher |
if self._readState.encContext.isBlockCipher: |
blockLength = self._readState.encContext.block_size |
- if len(bytes) % blockLength != 0: |
+ if len(b) % blockLength != 0: |
for result in self._sendError(\ |
AlertDescription.decryption_failed, |
"Encrypted data not a multiple of blocksize"): |
yield result |
- ciphertext = bytesToString(bytes) |
- plaintext = self._readState.encContext.decrypt(ciphertext) |
+ b = self._readState.encContext.decrypt(b) |
if self.version == (3,2): #For TLS 1.1, remove explicit IV |
- plaintext = plaintext[self._readState.encContext.block_size : ] |
- bytes = stringToBytes(plaintext) |
+ b = b[self._readState.encContext.block_size : ] |
#Check padding |
paddingGood = True |
- paddingLength = bytes[-1] |
- if (paddingLength+1) > len(bytes): |
+ paddingLength = b[-1] |
+ if (paddingLength+1) > len(b): |
paddingGood=False |
totalPaddingLength = 0 |
else: |
@@ -892,7 +974,7 @@ class TLSRecordLayer: |
totalPaddingLength = paddingLength+1 |
elif self.version in ((3,1), (3,2)): |
totalPaddingLength = paddingLength+1 |
- paddingBytes = bytes[-totalPaddingLength:-1] |
+ paddingBytes = b[-totalPaddingLength:-1] |
for byte in paddingBytes: |
if byte != paddingLength: |
paddingGood = False |
@@ -903,43 +985,39 @@ class TLSRecordLayer: |
#Decrypt if it's a stream cipher |
else: |
paddingGood = True |
- ciphertext = bytesToString(bytes) |
- plaintext = self._readState.encContext.decrypt(ciphertext) |
- bytes = stringToBytes(plaintext) |
+ b = self._readState.encContext.decrypt(b) |
totalPaddingLength = 0 |
#Check MAC |
macGood = True |
macLength = self._readState.macContext.digest_size |
endLength = macLength + totalPaddingLength |
- if endLength > len(bytes): |
+ if endLength > len(b): |
macGood = False |
else: |
#Read MAC |
- startIndex = len(bytes) - endLength |
+ startIndex = len(b) - endLength |
endIndex = startIndex + macLength |
- checkBytes = bytes[startIndex : endIndex] |
+ checkBytes = b[startIndex : endIndex] |
#Calculate MAC |
- seqnumStr = self._readState.getSeqNumStr() |
- bytes = bytes[:-endLength] |
- bytesStr = bytesToString(bytes) |
+ seqnumBytes = self._readState.getSeqNumBytes() |
+ b = b[:-endLength] |
mac = self._readState.macContext.copy() |
- mac.update(seqnumStr) |
- mac.update(chr(recordType)) |
+ mac.update(compatHMAC(seqnumBytes)) |
+ mac.update(compatHMAC(bytearray([recordType]))) |
if self.version == (3,0): |
- mac.update( chr( int(len(bytes)/256) ) ) |
- mac.update( chr( int(len(bytes)%256) ) ) |
+ mac.update( compatHMAC(bytearray( [len(b)//256] ) )) |
+ mac.update( compatHMAC(bytearray( [len(b)%256] ) )) |
elif self.version in ((3,1), (3,2)): |
- mac.update(chr(self.version[0])) |
- mac.update(chr(self.version[1])) |
- mac.update( chr( int(len(bytes)/256) ) ) |
- mac.update( chr( int(len(bytes)%256) ) ) |
+ mac.update(compatHMAC(bytearray( [self.version[0]] ) )) |
+ mac.update(compatHMAC(bytearray( [self.version[1]] ) )) |
+ mac.update(compatHMAC(bytearray( [len(b)//256] ) )) |
+ mac.update(compatHMAC(bytearray( [len(b)%256] ) )) |
else: |
raise AssertionError() |
- mac.update(bytesStr) |
- macString = mac.digest() |
- macBytes = stringToBytes(macString) |
+ mac.update(compatHMAC(b)) |
+ macBytes = bytearray(mac.digest()) |
#Compare MACs |
if macBytes != checkBytes: |
@@ -950,14 +1028,15 @@ class TLSRecordLayer: |
"MAC failure (or padding failure)"): |
yield result |
- yield bytes |
+ yield b |
def _handshakeStart(self, client): |
+ if not self.closed: |
+ raise ValueError("Renegotiation disallowed for security reasons") |
self._client = client |
- self._handshake_md5 = md5.md5() |
- self._handshake_sha = sha.sha() |
+ self._handshake_md5 = hashlib.md5() |
+ self._handshake_sha = hashlib.sha1() |
self._handshakeBuffer = [] |
- self.allegedSharedKeyUsername = None |
self.allegedSrpUsername = None |
self._refCount = 1 |
@@ -965,46 +1044,50 @@ class TLSRecordLayer: |
self.resumed = resumed |
self.closed = False |
- def _calcPendingStates(self, clientRandom, serverRandom, implementations): |
- if self.session.cipherSuite in CipherSuite.aes128Suites: |
- macLength = 20 |
+ def _calcPendingStates(self, cipherSuite, masterSecret, |
+ clientRandom, serverRandom, implementations): |
+ if cipherSuite in CipherSuite.aes128Suites: |
keyLength = 16 |
ivLength = 16 |
createCipherFunc = createAES |
- elif self.session.cipherSuite in CipherSuite.aes256Suites: |
- macLength = 20 |
+ elif cipherSuite in CipherSuite.aes256Suites: |
keyLength = 32 |
ivLength = 16 |
createCipherFunc = createAES |
- elif self.session.cipherSuite in CipherSuite.rc4Suites: |
- macLength = 20 |
+ elif cipherSuite in CipherSuite.rc4Suites: |
keyLength = 16 |
ivLength = 0 |
createCipherFunc = createRC4 |
- elif self.session.cipherSuite in CipherSuite.tripleDESSuites: |
- macLength = 20 |
+ elif cipherSuite in CipherSuite.tripleDESSuites: |
keyLength = 24 |
ivLength = 8 |
createCipherFunc = createTripleDES |
else: |
raise AssertionError() |
+ |
+ if cipherSuite in CipherSuite.shaSuites: |
+ macLength = 20 |
+ digestmod = hashlib.sha1 |
+ elif cipherSuite in CipherSuite.md5Suites: |
+ macLength = 16 |
+ digestmod = hashlib.md5 |
if self.version == (3,0): |
- createMACFunc = MAC_SSL |
+ createMACFunc = createMAC_SSL |
elif self.version in ((3,1), (3,2)): |
- createMACFunc = hmac.HMAC |
+ createMACFunc = createHMAC |
outputLength = (macLength*2) + (keyLength*2) + (ivLength*2) |
#Calculate Keying Material from Master Secret |
if self.version == (3,0): |
- keyBlock = PRF_SSL(self.session.masterSecret, |
- concatArrays(serverRandom, clientRandom), |
+ keyBlock = PRF_SSL(masterSecret, |
+ serverRandom + clientRandom, |
outputLength) |
elif self.version in ((3,1), (3,2)): |
- keyBlock = PRF(self.session.masterSecret, |
- "key expansion", |
- concatArrays(serverRandom,clientRandom), |
+ keyBlock = PRF(masterSecret, |
+ b"key expansion", |
+ serverRandom + clientRandom, |
outputLength) |
else: |
raise AssertionError() |
@@ -1013,16 +1096,16 @@ class TLSRecordLayer: |
clientPendingState = _ConnectionState() |
serverPendingState = _ConnectionState() |
p = Parser(keyBlock) |
- clientMACBlock = bytesToString(p.getFixBytes(macLength)) |
- serverMACBlock = bytesToString(p.getFixBytes(macLength)) |
- clientKeyBlock = bytesToString(p.getFixBytes(keyLength)) |
- serverKeyBlock = bytesToString(p.getFixBytes(keyLength)) |
- clientIVBlock = bytesToString(p.getFixBytes(ivLength)) |
- serverIVBlock = bytesToString(p.getFixBytes(ivLength)) |
- clientPendingState.macContext = createMACFunc(clientMACBlock, |
- digestmod=sha) |
- serverPendingState.macContext = createMACFunc(serverMACBlock, |
- digestmod=sha) |
+ clientMACBlock = p.getFixBytes(macLength) |
+ serverMACBlock = p.getFixBytes(macLength) |
+ clientKeyBlock = p.getFixBytes(keyLength) |
+ serverKeyBlock = p.getFixBytes(keyLength) |
+ clientIVBlock = p.getFixBytes(ivLength) |
+ serverIVBlock = p.getFixBytes(ivLength) |
+ clientPendingState.macContext = createMACFunc( |
+ compatHMAC(clientMACBlock), digestmod=digestmod) |
+ serverPendingState.macContext = createMACFunc( |
+ compatHMAC(serverMACBlock), digestmod=digestmod) |
clientPendingState.encContext = createCipherFunc(clientKeyBlock, |
clientIVBlock, |
implementations) |
@@ -1051,100 +1134,18 @@ class TLSRecordLayer: |
self._readState = self._pendingReadState |
self._pendingReadState = _ConnectionState() |
- def _sendFinished(self): |
- #Send ChangeCipherSpec |
- for result in self._sendMsg(ChangeCipherSpec()): |
- yield result |
- |
- #Switch to pending write state |
- self._changeWriteState() |
- |
- #Calculate verification data |
- verifyData = self._calcFinished(True) |
- if self.fault == Fault.badFinished: |
- verifyData[0] = (verifyData[0]+1)%256 |
- |
- #Send Finished message under new state |
- finished = Finished(self.version).create(verifyData) |
- for result in self._sendMsg(finished): |
- yield result |
- |
- def _getChangeCipherSpec(self): |
- #Get and check ChangeCipherSpec |
- for result in self._getMsg(ContentType.change_cipher_spec): |
- if result in (0,1): |
- yield result |
- changeCipherSpec = result |
- |
- if changeCipherSpec.type != 1: |
- for result in self._sendError(AlertDescription.illegal_parameter, |
- "ChangeCipherSpec type incorrect"): |
- yield result |
- |
- #Switch to pending read state |
- self._changeReadState() |
- |
- def _getEncryptedExtensions(self): |
- for result in self._getMsg(ContentType.handshake, |
- HandshakeType.encrypted_extensions): |
- if result in (0,1): |
- yield result |
- encrypted_extensions = result |
- self.channel_id = encrypted_extensions.channel_id_key |
- |
- def _getFinished(self): |
- #Calculate verification data |
- verifyData = self._calcFinished(False) |
- |
- #Get and check Finished message under new state |
- for result in self._getMsg(ContentType.handshake, |
- HandshakeType.finished): |
- if result in (0,1): |
- yield result |
- finished = result |
- if finished.verify_data != verifyData: |
- for result in self._sendError(AlertDescription.decrypt_error, |
- "Finished message is incorrect"): |
- yield result |
- |
- def _calcFinished(self, send=True): |
- if self.version == (3,0): |
- if (self._client and send) or (not self._client and not send): |
- senderStr = "\x43\x4C\x4E\x54" |
- else: |
- senderStr = "\x53\x52\x56\x52" |
- |
- verifyData = self._calcSSLHandshakeHash(self.session.masterSecret, |
- senderStr) |
- return verifyData |
- |
- elif self.version in ((3,1), (3,2)): |
- if (self._client and send) or (not self._client and not send): |
- label = "client finished" |
- else: |
- label = "server finished" |
- |
- handshakeHashes = stringToBytes(self._handshake_md5.digest() + \ |
- self._handshake_sha.digest()) |
- verifyData = PRF(self.session.masterSecret, label, handshakeHashes, |
- 12) |
- return verifyData |
- else: |
- raise AssertionError() |
- |
#Used for Finished messages and CertificateVerify messages in SSL v3 |
def _calcSSLHandshakeHash(self, masterSecret, label): |
- masterSecretStr = bytesToString(masterSecret) |
- |
imac_md5 = self._handshake_md5.copy() |
imac_sha = self._handshake_sha.copy() |
- imac_md5.update(label + masterSecretStr + '\x36'*48) |
- imac_sha.update(label + masterSecretStr + '\x36'*40) |
+ imac_md5.update(compatHMAC(label + masterSecret + bytearray([0x36]*48))) |
+ imac_sha.update(compatHMAC(label + masterSecret + bytearray([0x36]*40))) |
+ |
+ md5Bytes = MD5(masterSecret + bytearray([0x5c]*48) + \ |
+ bytearray(imac_md5.digest())) |
+ shaBytes = SHA1(masterSecret + bytearray([0x5c]*40) + \ |
+ bytearray(imac_sha.digest())) |
- md5Str = md5.md5(masterSecretStr + ('\x5c'*48) + \ |
- imac_md5.digest()).digest() |
- shaStr = sha.sha(masterSecretStr + ('\x5c'*40) + \ |
- imac_sha.digest()).digest() |
+ return md5Bytes + shaBytes |
- return stringToBytes(md5Str + shaStr) |