| Index: third_party/tlslite/tlslite/messages.py
|
| diff --git a/third_party/tlslite/tlslite/messages.py b/third_party/tlslite/tlslite/messages.py
|
| index fd6265f294e8603184fbf830ff5ea23b7c136f84..532d86bb13c9977834ec1f48e7dd33306339aa3e 100644
|
| --- a/third_party/tlslite/tlslite/messages.py
|
| +++ b/third_party/tlslite/tlslite/messages.py
|
| @@ -1,26 +1,23 @@
|
| +# Authors:
|
| +# Trevor Perrin
|
| +# Google - handling CertificateRequest.certificate_types
|
| +# Google (adapted by Sam Rushing and Marcelo Fernandez) - NPN support
|
| +# Dimitris Moraitis - Anon ciphersuites
|
| +#
|
| +# See the LICENSE file for legal information regarding use of this file.
|
| +
|
| """Classes representing TLS messages."""
|
|
|
| -from utils.compat import *
|
| -from utils.cryptomath import *
|
| -from errors import *
|
| -from utils.codec import *
|
| -from constants import *
|
| -from x509 import X509
|
| -from x509certchain import X509CertChain
|
| -
|
| -# The sha module is deprecated in Python 2.6
|
| -try:
|
| - import sha
|
| -except ImportError:
|
| - from hashlib import sha1 as sha
|
| -
|
| -# The md5 module is deprecated in Python 2.6
|
| -try:
|
| - import md5
|
| -except ImportError:
|
| - from hashlib import md5
|
| -
|
| -class RecordHeader3:
|
| +from .utils.compat import *
|
| +from .utils.cryptomath import *
|
| +from .errors import *
|
| +from .utils.codec import *
|
| +from .constants import *
|
| +from .x509 import X509
|
| +from .x509certchain import X509CertChain
|
| +from .utils.tackwrapper import *
|
| +
|
| +class RecordHeader3(object):
|
| def __init__(self):
|
| self.type = 0
|
| self.version = (0,0)
|
| @@ -34,7 +31,7 @@ class RecordHeader3:
|
| return self
|
|
|
| def write(self):
|
| - w = Writer(5)
|
| + w = Writer()
|
| w.add(self.type, 1)
|
| w.add(self.version[0], 1)
|
| w.add(self.version[1], 1)
|
| @@ -48,7 +45,7 @@ class RecordHeader3:
|
| self.ssl2 = False
|
| return self
|
|
|
| -class RecordHeader2:
|
| +class RecordHeader2(object):
|
| def __init__(self):
|
| self.type = 0
|
| self.version = (0,0)
|
| @@ -65,22 +62,7 @@ class RecordHeader2:
|
| return self
|
|
|
|
|
| -class Msg:
|
| - def preWrite(self, trial):
|
| - if trial:
|
| - w = Writer()
|
| - else:
|
| - length = self.write(True)
|
| - w = Writer(length)
|
| - return w
|
| -
|
| - def postWrite(self, w, trial):
|
| - if trial:
|
| - return w.index
|
| - else:
|
| - return w.bytes
|
| -
|
| -class Alert(Msg):
|
| +class Alert(object):
|
| def __init__(self):
|
| self.contentType = ContentType.alert
|
| self.level = 0
|
| @@ -99,50 +81,56 @@ class Alert(Msg):
|
| return self
|
|
|
| def write(self):
|
| - w = Writer(2)
|
| + w = Writer()
|
| w.add(self.level, 1)
|
| w.add(self.description, 1)
|
| return w.bytes
|
|
|
|
|
| -class HandshakeMsg(Msg):
|
| - def preWrite(self, handshakeType, trial):
|
| - if trial:
|
| - w = Writer()
|
| - w.add(handshakeType, 1)
|
| - w.add(0, 3)
|
| - else:
|
| - length = self.write(True)
|
| - w = Writer(length)
|
| - w.add(handshakeType, 1)
|
| - w.add(length-4, 3)
|
| - return w
|
| -
|
| +class HandshakeMsg(object):
|
| + def __init__(self, handshakeType):
|
| + self.contentType = ContentType.handshake
|
| + self.handshakeType = handshakeType
|
| +
|
| + def postWrite(self, w):
|
| + headerWriter = Writer()
|
| + headerWriter.add(self.handshakeType, 1)
|
| + headerWriter.add(len(w.bytes), 3)
|
| + return headerWriter.bytes + w.bytes
|
|
|
| class ClientHello(HandshakeMsg):
|
| def __init__(self, ssl2=False):
|
| - self.contentType = ContentType.handshake
|
| + HandshakeMsg.__init__(self, HandshakeType.client_hello)
|
| self.ssl2 = ssl2
|
| self.client_version = (0,0)
|
| - self.random = createByteArrayZeros(32)
|
| - self.session_id = createByteArraySequence([])
|
| + self.random = bytearray(32)
|
| + self.session_id = bytearray(0)
|
| self.cipher_suites = [] # a list of 16-bit values
|
| self.certificate_types = [CertificateType.x509]
|
| self.compression_methods = [] # a list of 8-bit values
|
| self.srp_username = None # a string
|
| + self.tack = False
|
| + self.supports_npn = False
|
| + self.server_name = bytearray(0)
|
| self.channel_id = False
|
| self.support_signed_cert_timestamps = False
|
| self.status_request = False
|
|
|
| def create(self, version, random, session_id, cipher_suites,
|
| - certificate_types=None, srp_username=None):
|
| + certificate_types=None, srpUsername=None,
|
| + tack=False, supports_npn=False, serverName=None):
|
| self.client_version = version
|
| self.random = random
|
| self.session_id = session_id
|
| self.cipher_suites = cipher_suites
|
| self.certificate_types = certificate_types
|
| self.compression_methods = [0]
|
| - self.srp_username = srp_username
|
| + if srpUsername:
|
| + self.srp_username = bytearray(srpUsername, "utf-8")
|
| + self.tack = tack
|
| + self.supports_npn = supports_npn
|
| + if serverName:
|
| + self.server_name = bytearray(serverName, "utf-8")
|
| return self
|
|
|
| def parse(self, p):
|
| @@ -151,12 +139,12 @@ class ClientHello(HandshakeMsg):
|
| cipherSpecsLength = p.get(2)
|
| sessionIDLength = p.get(2)
|
| randomLength = p.get(2)
|
| - self.cipher_suites = p.getFixList(3, int(cipherSpecsLength/3))
|
| + self.cipher_suites = p.getFixList(3, cipherSpecsLength//3)
|
| self.session_id = p.getFixBytes(sessionIDLength)
|
| self.random = p.getFixBytes(randomLength)
|
| if len(self.random) < 32:
|
| zeroBytes = 32-len(self.random)
|
| - self.random = createByteArrayZeros(zeroBytes) + self.random
|
| + self.random = bytearray(zeroBytes) + self.random
|
| self.compression_methods = [0]#Fake this value
|
|
|
| #We're not doing a stopLengthCheck() for SSLv2, oh well..
|
| @@ -173,10 +161,27 @@ class ClientHello(HandshakeMsg):
|
| while soFar != totalExtLength:
|
| extType = p.get(2)
|
| extLength = p.get(2)
|
| - if extType == 6:
|
| - self.srp_username = bytesToString(p.getVarBytes(1))
|
| - elif extType == 7:
|
| + index1 = p.index
|
| + if extType == ExtensionType.srp:
|
| + self.srp_username = p.getVarBytes(1)
|
| + elif extType == ExtensionType.cert_type:
|
| self.certificate_types = p.getVarList(1, 1)
|
| + elif extType == ExtensionType.tack:
|
| + self.tack = True
|
| + elif extType == ExtensionType.supports_npn:
|
| + self.supports_npn = True
|
| + elif extType == ExtensionType.server_name:
|
| + serverNameListBytes = p.getFixBytes(extLength)
|
| + p2 = Parser(serverNameListBytes)
|
| + p2.startLengthCheck(2)
|
| + while 1:
|
| + if p2.atLengthCheck():
|
| + break # no host_name, oh well
|
| + name_type = p2.get(1)
|
| + hostNameBytes = p2.getVarBytes(2)
|
| + if name_type == NameType.host_name:
|
| + self.server_name = hostNameBytes
|
| + break
|
| elif extType == ExtensionType.channel_id:
|
| self.channel_id = True
|
| elif extType == ExtensionType.signed_cert_timestamps:
|
| @@ -197,13 +202,16 @@ class ClientHello(HandshakeMsg):
|
| p.getFixBytes(extLength)
|
| self.status_request = True
|
| else:
|
| - p.getFixBytes(extLength)
|
| + _ = p.getFixBytes(extLength)
|
| + index2 = p.index
|
| + if index2 - index1 != extLength:
|
| + raise SyntaxError("Bad length for extension_data")
|
| soFar += 4 + extLength
|
| p.stopLengthCheck()
|
| return self
|
|
|
| - def write(self, trial=False):
|
| - w = HandshakeMsg.preWrite(self, HandshakeType.client_hello, trial)
|
| + def write(self):
|
| + w = Writer()
|
| w.add(self.client_version[0], 1)
|
| w.add(self.client_version[1], 1)
|
| w.addFixSeq(self.random, 1)
|
| @@ -211,49 +219,66 @@ class ClientHello(HandshakeMsg):
|
| w.addVarSeq(self.cipher_suites, 2, 2)
|
| w.addVarSeq(self.compression_methods, 1, 1)
|
|
|
| - extLength = 0
|
| - if self.certificate_types and self.certificate_types != \
|
| - [CertificateType.x509]:
|
| - extLength += 5 + len(self.certificate_types)
|
| - if self.srp_username:
|
| - extLength += 5 + len(self.srp_username)
|
| - if extLength > 0:
|
| - w.add(extLength, 2)
|
| -
|
| + w2 = Writer() # For Extensions
|
| if self.certificate_types and self.certificate_types != \
|
| [CertificateType.x509]:
|
| - w.add(7, 2)
|
| - w.add(len(self.certificate_types)+1, 2)
|
| - w.addVarSeq(self.certificate_types, 1, 1)
|
| + w2.add(ExtensionType.cert_type, 2)
|
| + w2.add(len(self.certificate_types)+1, 2)
|
| + w2.addVarSeq(self.certificate_types, 1, 1)
|
| if self.srp_username:
|
| - w.add(6, 2)
|
| - w.add(len(self.srp_username)+1, 2)
|
| - w.addVarSeq(stringToBytes(self.srp_username), 1, 1)
|
| -
|
| - return HandshakeMsg.postWrite(self, w, trial)
|
| -
|
| + w2.add(ExtensionType.srp, 2)
|
| + w2.add(len(self.srp_username)+1, 2)
|
| + w2.addVarSeq(self.srp_username, 1, 1)
|
| + if self.supports_npn:
|
| + w2.add(ExtensionType.supports_npn, 2)
|
| + w2.add(0, 2)
|
| + if self.server_name:
|
| + w2.add(ExtensionType.server_name, 2)
|
| + w2.add(len(self.server_name)+5, 2)
|
| + w2.add(len(self.server_name)+3, 2)
|
| + w2.add(NameType.host_name, 1)
|
| + w2.addVarSeq(self.server_name, 1, 2)
|
| + if self.tack:
|
| + w2.add(ExtensionType.tack, 2)
|
| + w2.add(0, 2)
|
| + if len(w2.bytes):
|
| + w.add(len(w2.bytes), 2)
|
| + w.bytes += w2.bytes
|
| + return self.postWrite(w)
|
| +
|
| +class BadNextProtos(Exception):
|
| + def __init__(self, l):
|
| + self.length = l
|
| +
|
| + def __str__(self):
|
| + return 'Cannot encode a list of next protocols because it contains an element with invalid length %d. Element lengths must be 0 < x < 256' % self.length
|
|
|
| class ServerHello(HandshakeMsg):
|
| def __init__(self):
|
| - self.contentType = ContentType.handshake
|
| + HandshakeMsg.__init__(self, HandshakeType.server_hello)
|
| self.server_version = (0,0)
|
| - self.random = createByteArrayZeros(32)
|
| - self.session_id = createByteArraySequence([])
|
| + self.random = bytearray(32)
|
| + self.session_id = bytearray(0)
|
| self.cipher_suite = 0
|
| self.certificate_type = CertificateType.x509
|
| self.compression_method = 0
|
| + self.tackExt = None
|
| + self.next_protos_advertised = None
|
| + self.next_protos = None
|
| self.channel_id = False
|
| self.signed_cert_timestamps = None
|
| self.status_request = False
|
|
|
| def create(self, version, random, session_id, cipher_suite,
|
| - certificate_type):
|
| + certificate_type, tackExt, next_protos_advertised):
|
| self.server_version = version
|
| self.random = random
|
| self.session_id = session_id
|
| self.cipher_suite = cipher_suite
|
| self.certificate_type = certificate_type
|
| self.compression_method = 0
|
| + self.tackExt = tackExt
|
| + self.next_protos_advertised = next_protos_advertised
|
| return self
|
|
|
| def parse(self, p):
|
| @@ -269,16 +294,43 @@ class ServerHello(HandshakeMsg):
|
| while soFar != totalExtLength:
|
| extType = p.get(2)
|
| extLength = p.get(2)
|
| - if extType == 7:
|
| + if extType == ExtensionType.cert_type:
|
| + if extLength != 1:
|
| + raise SyntaxError()
|
| self.certificate_type = p.get(1)
|
| + elif extType == ExtensionType.tack and tackpyLoaded:
|
| + self.tackExt = TackExtension(p.getFixBytes(extLength))
|
| + elif extType == ExtensionType.supports_npn:
|
| + self.next_protos = self.__parse_next_protos(p.getFixBytes(extLength))
|
| else:
|
| p.getFixBytes(extLength)
|
| soFar += 4 + extLength
|
| p.stopLengthCheck()
|
| return self
|
|
|
| - def write(self, trial=False):
|
| - w = HandshakeMsg.preWrite(self, HandshakeType.server_hello, trial)
|
| + def __parse_next_protos(self, b):
|
| + protos = []
|
| + while True:
|
| + if len(b) == 0:
|
| + break
|
| + l = b[0]
|
| + b = b[1:]
|
| + if len(b) < l:
|
| + raise BadNextProtos(len(b))
|
| + protos.append(b[:l])
|
| + b = b[l:]
|
| + return protos
|
| +
|
| + def __next_protos_encoded(self):
|
| + b = bytearray()
|
| + for e in self.next_protos_advertised:
|
| + if len(e) > 255 or len(e) == 0:
|
| + raise BadNextProtos(len(e))
|
| + b += bytearray( [len(e)] ) + bytearray(e)
|
| + return b
|
| +
|
| + def write(self):
|
| + w = Writer()
|
| w.add(self.server_version[0], 1)
|
| w.add(self.server_version[1], 1)
|
| w.addFixSeq(self.random, 1)
|
| @@ -286,47 +338,41 @@ class ServerHello(HandshakeMsg):
|
| w.add(self.cipher_suite, 2)
|
| w.add(self.compression_method, 1)
|
|
|
| - extLength = 0
|
| - if self.certificate_type and self.certificate_type != \
|
| - CertificateType.x509:
|
| - extLength += 5
|
| -
|
| - if self.channel_id:
|
| - extLength += 4
|
| -
|
| - if self.signed_cert_timestamps:
|
| - extLength += 4 + len(self.signed_cert_timestamps)
|
| -
|
| - if self.status_request:
|
| - extLength += 4
|
| -
|
| - if extLength != 0:
|
| - w.add(extLength, 2)
|
| -
|
| + w2 = Writer() # For Extensions
|
| if self.certificate_type and self.certificate_type != \
|
| CertificateType.x509:
|
| - w.add(7, 2)
|
| - w.add(1, 2)
|
| - w.add(self.certificate_type, 1)
|
| -
|
| + w2.add(ExtensionType.cert_type, 2)
|
| + w2.add(1, 2)
|
| + w2.add(self.certificate_type, 1)
|
| + if self.tackExt:
|
| + b = self.tackExt.serialize()
|
| + w2.add(ExtensionType.tack, 2)
|
| + w2.add(len(b), 2)
|
| + w2.bytes += b
|
| + if self.next_protos_advertised is not None:
|
| + encoded_next_protos_advertised = self.__next_protos_encoded()
|
| + w2.add(ExtensionType.supports_npn, 2)
|
| + w2.add(len(encoded_next_protos_advertised), 2)
|
| + w2.addFixSeq(encoded_next_protos_advertised, 1)
|
| if self.channel_id:
|
| - w.add(ExtensionType.channel_id, 2)
|
| - w.add(0, 2)
|
| -
|
| + w2.add(ExtensionType.channel_id, 2)
|
| + w2.add(0, 2)
|
| if self.signed_cert_timestamps:
|
| - w.add(ExtensionType.signed_cert_timestamps, 2)
|
| - w.addVarSeq(stringToBytes(self.signed_cert_timestamps), 1, 2)
|
| -
|
| + w2.add(ExtensionType.signed_cert_timestamps, 2)
|
| + w2.addVarSeq(bytearray(self.signed_cert_timestamps), 1, 2)
|
| if self.status_request:
|
| - w.add(ExtensionType.status_request, 2)
|
| - w.add(0, 2)
|
| + w2.add(ExtensionType.status_request, 2)
|
| + w2.add(0, 2)
|
| + if len(w2.bytes):
|
| + w.add(len(w2.bytes), 2)
|
| + w.bytes += w2.bytes
|
| + return self.postWrite(w)
|
|
|
| - return HandshakeMsg.postWrite(self, w, trial)
|
|
|
| class Certificate(HandshakeMsg):
|
| def __init__(self, certificateType):
|
| + HandshakeMsg.__init__(self, HandshakeType.certificate)
|
| self.certificateType = certificateType
|
| - self.contentType = ContentType.handshake
|
| self.certChain = None
|
|
|
| def create(self, certChain):
|
| @@ -347,23 +393,14 @@ class Certificate(HandshakeMsg):
|
| index += len(certBytes)+3
|
| if certificate_list:
|
| self.certChain = X509CertChain(certificate_list)
|
| - elif self.certificateType == CertificateType.cryptoID:
|
| - s = bytesToString(p.getVarBytes(2))
|
| - if s:
|
| - try:
|
| - import cryptoIDlib.CertChain
|
| - except ImportError:
|
| - raise SyntaxError(\
|
| - "cryptoID cert chain received, cryptoIDlib not present")
|
| - self.certChain = cryptoIDlib.CertChain.CertChain().parse(s)
|
| else:
|
| raise AssertionError()
|
|
|
| p.stopLengthCheck()
|
| return self
|
|
|
| - def write(self, trial=False):
|
| - w = HandshakeMsg.preWrite(self, HandshakeType.certificate, trial)
|
| + def write(self):
|
| + w = Writer()
|
| if self.certificateType == CertificateType.x509:
|
| chainLength = 0
|
| if self.certChain:
|
| @@ -379,19 +416,13 @@ class Certificate(HandshakeMsg):
|
| for cert in certificate_list:
|
| bytes = cert.writeBytes()
|
| w.addVarSeq(bytes, 1, 3)
|
| - elif self.certificateType == CertificateType.cryptoID:
|
| - if self.certChain:
|
| - bytes = stringToBytes(self.certChain.write())
|
| - else:
|
| - bytes = createByteArraySequence([])
|
| - w.addVarSeq(bytes, 1, 2)
|
| else:
|
| raise AssertionError()
|
| - return HandshakeMsg.postWrite(self, w, trial)
|
| + return self.postWrite(w)
|
|
|
| class CertificateStatus(HandshakeMsg):
|
| def __init__(self):
|
| - self.contentType = ContentType.handshake
|
| + HandshakeMsg.__init__(self, HandshakeType.certificate_status)
|
|
|
| def create(self, ocsp_response):
|
| self.ocsp_response = ocsp_response
|
| @@ -411,18 +442,18 @@ class CertificateStatus(HandshakeMsg):
|
| # Can't be empty
|
| raise SyntaxError()
|
| self.ocsp_response = ocsp_response
|
| + p.stopLengthCheck()
|
| return self
|
|
|
| - def write(self, trial=False):
|
| - w = HandshakeMsg.preWrite(self, HandshakeType.certificate_status,
|
| - trial)
|
| + def write(self):
|
| + w = Writer()
|
| w.add(CertificateStatusType.ocsp, 1)
|
| - w.addVarSeq(stringToBytes(self.ocsp_response), 1, 3)
|
| - return HandshakeMsg.postWrite(self, w, trial)
|
| + w.addVarSeq(bytearray(self.ocsp_response), 1, 3)
|
| + return self.postWrite(w)
|
|
|
| class CertificateRequest(HandshakeMsg):
|
| def __init__(self):
|
| - self.contentType = ContentType.handshake
|
| + HandshakeMsg.__init__(self, HandshakeType.certificate_request)
|
| #Apple's Secure Transport library rejects empty certificate_types, so
|
| #default to rsa_sign.
|
| self.certificate_types = [ClientCertificateType.rsa_sign]
|
| @@ -446,9 +477,8 @@ class CertificateRequest(HandshakeMsg):
|
| p.stopLengthCheck()
|
| return self
|
|
|
| - def write(self, trial=False):
|
| - w = HandshakeMsg.preWrite(self, HandshakeType.certificate_request,
|
| - trial)
|
| + def write(self):
|
| + w = Writer()
|
| w.addVarSeq(self.certificate_types, 1, 1)
|
| caLength = 0
|
| #determine length
|
| @@ -458,17 +488,21 @@ class CertificateRequest(HandshakeMsg):
|
| #add bytes
|
| for ca_dn in self.certificate_authorities:
|
| w.addVarSeq(ca_dn, 1, 2)
|
| - return HandshakeMsg.postWrite(self, w, trial)
|
| + return self.postWrite(w)
|
|
|
| class ServerKeyExchange(HandshakeMsg):
|
| def __init__(self, cipherSuite):
|
| + HandshakeMsg.__init__(self, HandshakeType.server_key_exchange)
|
| self.cipherSuite = cipherSuite
|
| - self.contentType = ContentType.handshake
|
| - self.srp_N = 0L
|
| - self.srp_g = 0L
|
| - self.srp_s = createByteArraySequence([])
|
| - self.srp_B = 0L
|
| - self.signature = createByteArraySequence([])
|
| + self.srp_N = 0
|
| + self.srp_g = 0
|
| + self.srp_s = bytearray(0)
|
| + self.srp_B = 0
|
| + # Anon DH params:
|
| + self.dh_p = 0
|
| + self.dh_g = 0
|
| + self.dh_Ys = 0
|
| + self.signature = bytearray(0)
|
|
|
| def createSRP(self, srp_N, srp_g, srp_s, srp_B):
|
| self.srp_N = srp_N
|
| @@ -476,42 +510,58 @@ class ServerKeyExchange(HandshakeMsg):
|
| self.srp_s = srp_s
|
| self.srp_B = srp_B
|
| return self
|
| +
|
| + def createDH(self, dh_p, dh_g, dh_Ys):
|
| + self.dh_p = dh_p
|
| + self.dh_g = dh_g
|
| + self.dh_Ys = dh_Ys
|
| + return self
|
|
|
| def parse(self, p):
|
| p.startLengthCheck(3)
|
| - self.srp_N = bytesToNumber(p.getVarBytes(2))
|
| - self.srp_g = bytesToNumber(p.getVarBytes(2))
|
| - self.srp_s = p.getVarBytes(1)
|
| - self.srp_B = bytesToNumber(p.getVarBytes(2))
|
| - if self.cipherSuite in CipherSuite.srpRsaSuites:
|
| - self.signature = p.getVarBytes(2)
|
| + if self.cipherSuite in CipherSuite.srpAllSuites:
|
| + self.srp_N = bytesToNumber(p.getVarBytes(2))
|
| + self.srp_g = bytesToNumber(p.getVarBytes(2))
|
| + self.srp_s = p.getVarBytes(1)
|
| + self.srp_B = bytesToNumber(p.getVarBytes(2))
|
| + if self.cipherSuite in CipherSuite.srpCertSuites:
|
| + self.signature = p.getVarBytes(2)
|
| + elif self.cipherSuite in CipherSuite.anonSuites:
|
| + self.dh_p = bytesToNumber(p.getVarBytes(2))
|
| + self.dh_g = bytesToNumber(p.getVarBytes(2))
|
| + self.dh_Ys = bytesToNumber(p.getVarBytes(2))
|
| p.stopLengthCheck()
|
| return self
|
|
|
| - def write(self, trial=False):
|
| - w = HandshakeMsg.preWrite(self, HandshakeType.server_key_exchange,
|
| - trial)
|
| - w.addVarSeq(numberToBytes(self.srp_N), 1, 2)
|
| - w.addVarSeq(numberToBytes(self.srp_g), 1, 2)
|
| - w.addVarSeq(self.srp_s, 1, 1)
|
| - w.addVarSeq(numberToBytes(self.srp_B), 1, 2)
|
| - if self.cipherSuite in CipherSuite.srpRsaSuites:
|
| - w.addVarSeq(self.signature, 1, 2)
|
| - return HandshakeMsg.postWrite(self, w, trial)
|
| + def write(self):
|
| + w = Writer()
|
| + if self.cipherSuite in CipherSuite.srpAllSuites:
|
| + w.addVarSeq(numberToByteArray(self.srp_N), 1, 2)
|
| + w.addVarSeq(numberToByteArray(self.srp_g), 1, 2)
|
| + w.addVarSeq(self.srp_s, 1, 1)
|
| + w.addVarSeq(numberToByteArray(self.srp_B), 1, 2)
|
| + if self.cipherSuite in CipherSuite.srpCertSuites:
|
| + w.addVarSeq(self.signature, 1, 2)
|
| + elif self.cipherSuite in CipherSuite.anonSuites:
|
| + w.addVarSeq(numberToByteArray(self.dh_p), 1, 2)
|
| + w.addVarSeq(numberToByteArray(self.dh_g), 1, 2)
|
| + w.addVarSeq(numberToByteArray(self.dh_Ys), 1, 2)
|
| + if self.cipherSuite in []: # TODO support for signed_params
|
| + w.addVarSeq(self.signature, 1, 2)
|
| + return self.postWrite(w)
|
|
|
| def hash(self, clientRandom, serverRandom):
|
| oldCipherSuite = self.cipherSuite
|
| self.cipherSuite = None
|
| try:
|
| bytes = clientRandom + serverRandom + self.write()[4:]
|
| - s = bytesToString(bytes)
|
| - return stringToBytes(md5.md5(s).digest() + sha.sha(s).digest())
|
| + return MD5(bytes) + SHA1(bytes)
|
| finally:
|
| self.cipherSuite = oldCipherSuite
|
|
|
| class ServerHelloDone(HandshakeMsg):
|
| def __init__(self):
|
| - self.contentType = ContentType.handshake
|
| + HandshakeMsg.__init__(self, HandshakeType.server_hello_done)
|
|
|
| def create(self):
|
| return self
|
| @@ -521,17 +571,17 @@ class ServerHelloDone(HandshakeMsg):
|
| p.stopLengthCheck()
|
| return self
|
|
|
| - def write(self, trial=False):
|
| - w = HandshakeMsg.preWrite(self, HandshakeType.server_hello_done, trial)
|
| - return HandshakeMsg.postWrite(self, w, trial)
|
| + def write(self):
|
| + w = Writer()
|
| + return self.postWrite(w)
|
|
|
| class ClientKeyExchange(HandshakeMsg):
|
| def __init__(self, cipherSuite, version=None):
|
| + HandshakeMsg.__init__(self, HandshakeType.client_key_exchange)
|
| self.cipherSuite = cipherSuite
|
| self.version = version
|
| - self.contentType = ContentType.handshake
|
| self.srp_A = 0
|
| - self.encryptedPreMasterSecret = createByteArraySequence([])
|
| + self.encryptedPreMasterSecret = bytearray(0)
|
|
|
| def createSRP(self, srp_A):
|
| self.srp_A = srp_A
|
| @@ -540,13 +590,16 @@ class ClientKeyExchange(HandshakeMsg):
|
| def createRSA(self, encryptedPreMasterSecret):
|
| self.encryptedPreMasterSecret = encryptedPreMasterSecret
|
| return self
|
| -
|
| +
|
| + def createDH(self, dh_Yc):
|
| + self.dh_Yc = dh_Yc
|
| + return self
|
| +
|
| def parse(self, p):
|
| p.startLengthCheck(3)
|
| - if self.cipherSuite in CipherSuite.srpSuites + \
|
| - CipherSuite.srpRsaSuites:
|
| + if self.cipherSuite in CipherSuite.srpAllSuites:
|
| self.srp_A = bytesToNumber(p.getVarBytes(2))
|
| - elif self.cipherSuite in CipherSuite.rsaSuites:
|
| + elif self.cipherSuite in CipherSuite.certSuites:
|
| if self.version in ((3,1), (3,2)):
|
| self.encryptedPreMasterSecret = p.getVarBytes(2)
|
| elif self.version == (3,0):
|
| @@ -554,32 +607,34 @@ class ClientKeyExchange(HandshakeMsg):
|
| p.getFixBytes(len(p.bytes)-p.index)
|
| else:
|
| raise AssertionError()
|
| + elif self.cipherSuite in CipherSuite.anonSuites:
|
| + self.dh_Yc = bytesToNumber(p.getVarBytes(2))
|
| else:
|
| raise AssertionError()
|
| p.stopLengthCheck()
|
| return self
|
|
|
| - def write(self, trial=False):
|
| - w = HandshakeMsg.preWrite(self, HandshakeType.client_key_exchange,
|
| - trial)
|
| - if self.cipherSuite in CipherSuite.srpSuites + \
|
| - CipherSuite.srpRsaSuites:
|
| - w.addVarSeq(numberToBytes(self.srp_A), 1, 2)
|
| - elif self.cipherSuite in CipherSuite.rsaSuites:
|
| + def write(self):
|
| + w = Writer()
|
| + if self.cipherSuite in CipherSuite.srpAllSuites:
|
| + w.addVarSeq(numberToByteArray(self.srp_A), 1, 2)
|
| + elif self.cipherSuite in CipherSuite.certSuites:
|
| if self.version in ((3,1), (3,2)):
|
| w.addVarSeq(self.encryptedPreMasterSecret, 1, 2)
|
| elif self.version == (3,0):
|
| w.addFixSeq(self.encryptedPreMasterSecret, 1)
|
| else:
|
| raise AssertionError()
|
| + elif self.cipherSuite in CipherSuite.anonSuites:
|
| + w.addVarSeq(numberToByteArray(self.dh_Yc), 1, 2)
|
| else:
|
| raise AssertionError()
|
| - return HandshakeMsg.postWrite(self, w, trial)
|
| + return self.postWrite(w)
|
|
|
| class CertificateVerify(HandshakeMsg):
|
| def __init__(self):
|
| - self.contentType = ContentType.handshake
|
| - self.signature = createByteArraySequence([])
|
| + HandshakeMsg.__init__(self, HandshakeType.certificate_verify)
|
| + self.signature = bytearray(0)
|
|
|
| def create(self, signature):
|
| self.signature = signature
|
| @@ -591,13 +646,12 @@ class CertificateVerify(HandshakeMsg):
|
| p.stopLengthCheck()
|
| return self
|
|
|
| - def write(self, trial=False):
|
| - w = HandshakeMsg.preWrite(self, HandshakeType.certificate_verify,
|
| - trial)
|
| + def write(self):
|
| + w = Writer()
|
| w.addVarSeq(self.signature, 1, 2)
|
| - return HandshakeMsg.postWrite(self, w, trial)
|
| + return self.postWrite(w)
|
|
|
| -class ChangeCipherSpec(Msg):
|
| +class ChangeCipherSpec(object):
|
| def __init__(self):
|
| self.contentType = ContentType.change_cipher_spec
|
| self.type = 1
|
| @@ -612,17 +666,40 @@ class ChangeCipherSpec(Msg):
|
| p.stopLengthCheck()
|
| return self
|
|
|
| - def write(self, trial=False):
|
| - w = Msg.preWrite(self, trial)
|
| + def write(self):
|
| + w = Writer()
|
| w.add(self.type,1)
|
| - return Msg.postWrite(self, w, trial)
|
| + return w.bytes
|
| +
|
| +
|
| +class NextProtocol(HandshakeMsg):
|
| + def __init__(self):
|
| + HandshakeMsg.__init__(self, HandshakeType.next_protocol)
|
| + self.next_proto = None
|
| +
|
| + def create(self, next_proto):
|
| + self.next_proto = next_proto
|
| + return self
|
| +
|
| + def parse(self, p):
|
| + p.startLengthCheck(3)
|
| + self.next_proto = p.getVarBytes(1)
|
| + _ = p.getVarBytes(1)
|
| + p.stopLengthCheck()
|
| + return self
|
|
|
| + def write(self, trial=False):
|
| + w = Writer()
|
| + w.addVarSeq(self.next_proto, 1, 1)
|
| + paddingLen = 32 - ((len(self.next_proto) + 2) % 32)
|
| + w.addVarSeq(bytearray(paddingLen), 1, 1)
|
| + return self.postWrite(w)
|
|
|
| class Finished(HandshakeMsg):
|
| def __init__(self, version):
|
| - self.contentType = ContentType.handshake
|
| + HandshakeMsg.__init__(self, HandshakeType.finished)
|
| self.version = version
|
| - self.verify_data = createByteArraySequence([])
|
| + self.verify_data = bytearray(0)
|
|
|
| def create(self, verify_data):
|
| self.verify_data = verify_data
|
| @@ -639,10 +716,10 @@ class Finished(HandshakeMsg):
|
| p.stopLengthCheck()
|
| return self
|
|
|
| - def write(self, trial=False):
|
| - w = HandshakeMsg.preWrite(self, HandshakeType.finished, trial)
|
| + def write(self):
|
| + w = Writer()
|
| w.addFixSeq(self.verify_data, 1)
|
| - return HandshakeMsg.postWrite(self, w, trial)
|
| + return self.postWrite(w)
|
|
|
| class EncryptedExtensions(HandshakeMsg):
|
| def __init__(self):
|
| @@ -666,14 +743,19 @@ class EncryptedExtensions(HandshakeMsg):
|
| p.stopLengthCheck()
|
| return self
|
|
|
| -class ApplicationData(Msg):
|
| +class ApplicationData(object):
|
| def __init__(self):
|
| self.contentType = ContentType.application_data
|
| - self.bytes = createByteArraySequence([])
|
| + self.bytes = bytearray(0)
|
|
|
| def create(self, bytes):
|
| self.bytes = bytes
|
| return self
|
| +
|
| + def splitFirstByte(self):
|
| + newMsg = ApplicationData().create(self.bytes[:1])
|
| + self.bytes = self.bytes[1:]
|
| + return newMsg
|
|
|
| def parse(self, p):
|
| self.bytes = p.bytes
|
|
|