Index: third_party/tlslite/tlslite/messages.py |
diff --git a/third_party/tlslite/tlslite/messages.py b/third_party/tlslite/tlslite/messages.py |
index fd6265f294e8603184fbf830ff5ea23b7c136f84..532d86bb13c9977834ec1f48e7dd33306339aa3e 100644 |
--- a/third_party/tlslite/tlslite/messages.py |
+++ b/third_party/tlslite/tlslite/messages.py |
@@ -1,26 +1,23 @@ |
+# Authors: |
+# Trevor Perrin |
+# Google - handling CertificateRequest.certificate_types |
+# Google (adapted by Sam Rushing and Marcelo Fernandez) - NPN support |
+# Dimitris Moraitis - Anon ciphersuites |
+# |
+# See the LICENSE file for legal information regarding use of this file. |
+ |
"""Classes representing TLS messages.""" |
-from utils.compat import * |
-from utils.cryptomath import * |
-from errors import * |
-from utils.codec import * |
-from constants import * |
-from x509 import X509 |
-from x509certchain import X509CertChain |
- |
-# The sha module is deprecated in Python 2.6 |
-try: |
- import sha |
-except ImportError: |
- from hashlib import sha1 as sha |
- |
-# The md5 module is deprecated in Python 2.6 |
-try: |
- import md5 |
-except ImportError: |
- from hashlib import md5 |
- |
-class RecordHeader3: |
+from .utils.compat import * |
+from .utils.cryptomath import * |
+from .errors import * |
+from .utils.codec import * |
+from .constants import * |
+from .x509 import X509 |
+from .x509certchain import X509CertChain |
+from .utils.tackwrapper import * |
+ |
+class RecordHeader3(object): |
def __init__(self): |
self.type = 0 |
self.version = (0,0) |
@@ -34,7 +31,7 @@ class RecordHeader3: |
return self |
def write(self): |
- w = Writer(5) |
+ w = Writer() |
w.add(self.type, 1) |
w.add(self.version[0], 1) |
w.add(self.version[1], 1) |
@@ -48,7 +45,7 @@ class RecordHeader3: |
self.ssl2 = False |
return self |
-class RecordHeader2: |
+class RecordHeader2(object): |
def __init__(self): |
self.type = 0 |
self.version = (0,0) |
@@ -65,22 +62,7 @@ class RecordHeader2: |
return self |
-class Msg: |
- def preWrite(self, trial): |
- if trial: |
- w = Writer() |
- else: |
- length = self.write(True) |
- w = Writer(length) |
- return w |
- |
- def postWrite(self, w, trial): |
- if trial: |
- return w.index |
- else: |
- return w.bytes |
- |
-class Alert(Msg): |
+class Alert(object): |
def __init__(self): |
self.contentType = ContentType.alert |
self.level = 0 |
@@ -99,50 +81,56 @@ class Alert(Msg): |
return self |
def write(self): |
- w = Writer(2) |
+ w = Writer() |
w.add(self.level, 1) |
w.add(self.description, 1) |
return w.bytes |
-class HandshakeMsg(Msg): |
- def preWrite(self, handshakeType, trial): |
- if trial: |
- w = Writer() |
- w.add(handshakeType, 1) |
- w.add(0, 3) |
- else: |
- length = self.write(True) |
- w = Writer(length) |
- w.add(handshakeType, 1) |
- w.add(length-4, 3) |
- return w |
- |
+class HandshakeMsg(object): |
+ def __init__(self, handshakeType): |
+ self.contentType = ContentType.handshake |
+ self.handshakeType = handshakeType |
+ |
+ def postWrite(self, w): |
+ headerWriter = Writer() |
+ headerWriter.add(self.handshakeType, 1) |
+ headerWriter.add(len(w.bytes), 3) |
+ return headerWriter.bytes + w.bytes |
class ClientHello(HandshakeMsg): |
def __init__(self, ssl2=False): |
- self.contentType = ContentType.handshake |
+ HandshakeMsg.__init__(self, HandshakeType.client_hello) |
self.ssl2 = ssl2 |
self.client_version = (0,0) |
- self.random = createByteArrayZeros(32) |
- self.session_id = createByteArraySequence([]) |
+ self.random = bytearray(32) |
+ self.session_id = bytearray(0) |
self.cipher_suites = [] # a list of 16-bit values |
self.certificate_types = [CertificateType.x509] |
self.compression_methods = [] # a list of 8-bit values |
self.srp_username = None # a string |
+ self.tack = False |
+ self.supports_npn = False |
+ self.server_name = bytearray(0) |
self.channel_id = False |
self.support_signed_cert_timestamps = False |
self.status_request = False |
def create(self, version, random, session_id, cipher_suites, |
- certificate_types=None, srp_username=None): |
+ certificate_types=None, srpUsername=None, |
+ tack=False, supports_npn=False, serverName=None): |
self.client_version = version |
self.random = random |
self.session_id = session_id |
self.cipher_suites = cipher_suites |
self.certificate_types = certificate_types |
self.compression_methods = [0] |
- self.srp_username = srp_username |
+ if srpUsername: |
+ self.srp_username = bytearray(srpUsername, "utf-8") |
+ self.tack = tack |
+ self.supports_npn = supports_npn |
+ if serverName: |
+ self.server_name = bytearray(serverName, "utf-8") |
return self |
def parse(self, p): |
@@ -151,12 +139,12 @@ class ClientHello(HandshakeMsg): |
cipherSpecsLength = p.get(2) |
sessionIDLength = p.get(2) |
randomLength = p.get(2) |
- self.cipher_suites = p.getFixList(3, int(cipherSpecsLength/3)) |
+ self.cipher_suites = p.getFixList(3, cipherSpecsLength//3) |
self.session_id = p.getFixBytes(sessionIDLength) |
self.random = p.getFixBytes(randomLength) |
if len(self.random) < 32: |
zeroBytes = 32-len(self.random) |
- self.random = createByteArrayZeros(zeroBytes) + self.random |
+ self.random = bytearray(zeroBytes) + self.random |
self.compression_methods = [0]#Fake this value |
#We're not doing a stopLengthCheck() for SSLv2, oh well.. |
@@ -173,10 +161,27 @@ class ClientHello(HandshakeMsg): |
while soFar != totalExtLength: |
extType = p.get(2) |
extLength = p.get(2) |
- if extType == 6: |
- self.srp_username = bytesToString(p.getVarBytes(1)) |
- elif extType == 7: |
+ index1 = p.index |
+ if extType == ExtensionType.srp: |
+ self.srp_username = p.getVarBytes(1) |
+ elif extType == ExtensionType.cert_type: |
self.certificate_types = p.getVarList(1, 1) |
+ elif extType == ExtensionType.tack: |
+ self.tack = True |
+ elif extType == ExtensionType.supports_npn: |
+ self.supports_npn = True |
+ elif extType == ExtensionType.server_name: |
+ serverNameListBytes = p.getFixBytes(extLength) |
+ p2 = Parser(serverNameListBytes) |
+ p2.startLengthCheck(2) |
+ while 1: |
+ if p2.atLengthCheck(): |
+ break # no host_name, oh well |
+ name_type = p2.get(1) |
+ hostNameBytes = p2.getVarBytes(2) |
+ if name_type == NameType.host_name: |
+ self.server_name = hostNameBytes |
+ break |
elif extType == ExtensionType.channel_id: |
self.channel_id = True |
elif extType == ExtensionType.signed_cert_timestamps: |
@@ -197,13 +202,16 @@ class ClientHello(HandshakeMsg): |
p.getFixBytes(extLength) |
self.status_request = True |
else: |
- p.getFixBytes(extLength) |
+ _ = p.getFixBytes(extLength) |
+ index2 = p.index |
+ if index2 - index1 != extLength: |
+ raise SyntaxError("Bad length for extension_data") |
soFar += 4 + extLength |
p.stopLengthCheck() |
return self |
- def write(self, trial=False): |
- w = HandshakeMsg.preWrite(self, HandshakeType.client_hello, trial) |
+ def write(self): |
+ w = Writer() |
w.add(self.client_version[0], 1) |
w.add(self.client_version[1], 1) |
w.addFixSeq(self.random, 1) |
@@ -211,49 +219,66 @@ class ClientHello(HandshakeMsg): |
w.addVarSeq(self.cipher_suites, 2, 2) |
w.addVarSeq(self.compression_methods, 1, 1) |
- extLength = 0 |
- if self.certificate_types and self.certificate_types != \ |
- [CertificateType.x509]: |
- extLength += 5 + len(self.certificate_types) |
- if self.srp_username: |
- extLength += 5 + len(self.srp_username) |
- if extLength > 0: |
- w.add(extLength, 2) |
- |
+ w2 = Writer() # For Extensions |
if self.certificate_types and self.certificate_types != \ |
[CertificateType.x509]: |
- w.add(7, 2) |
- w.add(len(self.certificate_types)+1, 2) |
- w.addVarSeq(self.certificate_types, 1, 1) |
+ w2.add(ExtensionType.cert_type, 2) |
+ w2.add(len(self.certificate_types)+1, 2) |
+ w2.addVarSeq(self.certificate_types, 1, 1) |
if self.srp_username: |
- w.add(6, 2) |
- w.add(len(self.srp_username)+1, 2) |
- w.addVarSeq(stringToBytes(self.srp_username), 1, 1) |
- |
- return HandshakeMsg.postWrite(self, w, trial) |
- |
+ w2.add(ExtensionType.srp, 2) |
+ w2.add(len(self.srp_username)+1, 2) |
+ w2.addVarSeq(self.srp_username, 1, 1) |
+ if self.supports_npn: |
+ w2.add(ExtensionType.supports_npn, 2) |
+ w2.add(0, 2) |
+ if self.server_name: |
+ w2.add(ExtensionType.server_name, 2) |
+ w2.add(len(self.server_name)+5, 2) |
+ w2.add(len(self.server_name)+3, 2) |
+ w2.add(NameType.host_name, 1) |
+ w2.addVarSeq(self.server_name, 1, 2) |
+ if self.tack: |
+ w2.add(ExtensionType.tack, 2) |
+ w2.add(0, 2) |
+ if len(w2.bytes): |
+ w.add(len(w2.bytes), 2) |
+ w.bytes += w2.bytes |
+ return self.postWrite(w) |
+ |
+class BadNextProtos(Exception): |
+ def __init__(self, l): |
+ self.length = l |
+ |
+ def __str__(self): |
+ return 'Cannot encode a list of next protocols because it contains an element with invalid length %d. Element lengths must be 0 < x < 256' % self.length |
class ServerHello(HandshakeMsg): |
def __init__(self): |
- self.contentType = ContentType.handshake |
+ HandshakeMsg.__init__(self, HandshakeType.server_hello) |
self.server_version = (0,0) |
- self.random = createByteArrayZeros(32) |
- self.session_id = createByteArraySequence([]) |
+ self.random = bytearray(32) |
+ self.session_id = bytearray(0) |
self.cipher_suite = 0 |
self.certificate_type = CertificateType.x509 |
self.compression_method = 0 |
+ self.tackExt = None |
+ self.next_protos_advertised = None |
+ self.next_protos = None |
self.channel_id = False |
self.signed_cert_timestamps = None |
self.status_request = False |
def create(self, version, random, session_id, cipher_suite, |
- certificate_type): |
+ certificate_type, tackExt, next_protos_advertised): |
self.server_version = version |
self.random = random |
self.session_id = session_id |
self.cipher_suite = cipher_suite |
self.certificate_type = certificate_type |
self.compression_method = 0 |
+ self.tackExt = tackExt |
+ self.next_protos_advertised = next_protos_advertised |
return self |
def parse(self, p): |
@@ -269,16 +294,43 @@ class ServerHello(HandshakeMsg): |
while soFar != totalExtLength: |
extType = p.get(2) |
extLength = p.get(2) |
- if extType == 7: |
+ if extType == ExtensionType.cert_type: |
+ if extLength != 1: |
+ raise SyntaxError() |
self.certificate_type = p.get(1) |
+ elif extType == ExtensionType.tack and tackpyLoaded: |
+ self.tackExt = TackExtension(p.getFixBytes(extLength)) |
+ elif extType == ExtensionType.supports_npn: |
+ self.next_protos = self.__parse_next_protos(p.getFixBytes(extLength)) |
else: |
p.getFixBytes(extLength) |
soFar += 4 + extLength |
p.stopLengthCheck() |
return self |
- def write(self, trial=False): |
- w = HandshakeMsg.preWrite(self, HandshakeType.server_hello, trial) |
+ def __parse_next_protos(self, b): |
+ protos = [] |
+ while True: |
+ if len(b) == 0: |
+ break |
+ l = b[0] |
+ b = b[1:] |
+ if len(b) < l: |
+ raise BadNextProtos(len(b)) |
+ protos.append(b[:l]) |
+ b = b[l:] |
+ return protos |
+ |
+ def __next_protos_encoded(self): |
+ b = bytearray() |
+ for e in self.next_protos_advertised: |
+ if len(e) > 255 or len(e) == 0: |
+ raise BadNextProtos(len(e)) |
+ b += bytearray( [len(e)] ) + bytearray(e) |
+ return b |
+ |
+ def write(self): |
+ w = Writer() |
w.add(self.server_version[0], 1) |
w.add(self.server_version[1], 1) |
w.addFixSeq(self.random, 1) |
@@ -286,47 +338,41 @@ class ServerHello(HandshakeMsg): |
w.add(self.cipher_suite, 2) |
w.add(self.compression_method, 1) |
- extLength = 0 |
- if self.certificate_type and self.certificate_type != \ |
- CertificateType.x509: |
- extLength += 5 |
- |
- if self.channel_id: |
- extLength += 4 |
- |
- if self.signed_cert_timestamps: |
- extLength += 4 + len(self.signed_cert_timestamps) |
- |
- if self.status_request: |
- extLength += 4 |
- |
- if extLength != 0: |
- w.add(extLength, 2) |
- |
+ w2 = Writer() # For Extensions |
if self.certificate_type and self.certificate_type != \ |
CertificateType.x509: |
- w.add(7, 2) |
- w.add(1, 2) |
- w.add(self.certificate_type, 1) |
- |
+ w2.add(ExtensionType.cert_type, 2) |
+ w2.add(1, 2) |
+ w2.add(self.certificate_type, 1) |
+ if self.tackExt: |
+ b = self.tackExt.serialize() |
+ w2.add(ExtensionType.tack, 2) |
+ w2.add(len(b), 2) |
+ w2.bytes += b |
+ if self.next_protos_advertised is not None: |
+ encoded_next_protos_advertised = self.__next_protos_encoded() |
+ w2.add(ExtensionType.supports_npn, 2) |
+ w2.add(len(encoded_next_protos_advertised), 2) |
+ w2.addFixSeq(encoded_next_protos_advertised, 1) |
if self.channel_id: |
- w.add(ExtensionType.channel_id, 2) |
- w.add(0, 2) |
- |
+ w2.add(ExtensionType.channel_id, 2) |
+ w2.add(0, 2) |
if self.signed_cert_timestamps: |
- w.add(ExtensionType.signed_cert_timestamps, 2) |
- w.addVarSeq(stringToBytes(self.signed_cert_timestamps), 1, 2) |
- |
+ w2.add(ExtensionType.signed_cert_timestamps, 2) |
+ w2.addVarSeq(bytearray(self.signed_cert_timestamps), 1, 2) |
if self.status_request: |
- w.add(ExtensionType.status_request, 2) |
- w.add(0, 2) |
+ w2.add(ExtensionType.status_request, 2) |
+ w2.add(0, 2) |
+ if len(w2.bytes): |
+ w.add(len(w2.bytes), 2) |
+ w.bytes += w2.bytes |
+ return self.postWrite(w) |
- return HandshakeMsg.postWrite(self, w, trial) |
class Certificate(HandshakeMsg): |
def __init__(self, certificateType): |
+ HandshakeMsg.__init__(self, HandshakeType.certificate) |
self.certificateType = certificateType |
- self.contentType = ContentType.handshake |
self.certChain = None |
def create(self, certChain): |
@@ -347,23 +393,14 @@ class Certificate(HandshakeMsg): |
index += len(certBytes)+3 |
if certificate_list: |
self.certChain = X509CertChain(certificate_list) |
- elif self.certificateType == CertificateType.cryptoID: |
- s = bytesToString(p.getVarBytes(2)) |
- if s: |
- try: |
- import cryptoIDlib.CertChain |
- except ImportError: |
- raise SyntaxError(\ |
- "cryptoID cert chain received, cryptoIDlib not present") |
- self.certChain = cryptoIDlib.CertChain.CertChain().parse(s) |
else: |
raise AssertionError() |
p.stopLengthCheck() |
return self |
- def write(self, trial=False): |
- w = HandshakeMsg.preWrite(self, HandshakeType.certificate, trial) |
+ def write(self): |
+ w = Writer() |
if self.certificateType == CertificateType.x509: |
chainLength = 0 |
if self.certChain: |
@@ -379,19 +416,13 @@ class Certificate(HandshakeMsg): |
for cert in certificate_list: |
bytes = cert.writeBytes() |
w.addVarSeq(bytes, 1, 3) |
- elif self.certificateType == CertificateType.cryptoID: |
- if self.certChain: |
- bytes = stringToBytes(self.certChain.write()) |
- else: |
- bytes = createByteArraySequence([]) |
- w.addVarSeq(bytes, 1, 2) |
else: |
raise AssertionError() |
- return HandshakeMsg.postWrite(self, w, trial) |
+ return self.postWrite(w) |
class CertificateStatus(HandshakeMsg): |
def __init__(self): |
- self.contentType = ContentType.handshake |
+ HandshakeMsg.__init__(self, HandshakeType.certificate_status) |
def create(self, ocsp_response): |
self.ocsp_response = ocsp_response |
@@ -411,18 +442,18 @@ class CertificateStatus(HandshakeMsg): |
# Can't be empty |
raise SyntaxError() |
self.ocsp_response = ocsp_response |
+ p.stopLengthCheck() |
return self |
- def write(self, trial=False): |
- w = HandshakeMsg.preWrite(self, HandshakeType.certificate_status, |
- trial) |
+ def write(self): |
+ w = Writer() |
w.add(CertificateStatusType.ocsp, 1) |
- w.addVarSeq(stringToBytes(self.ocsp_response), 1, 3) |
- return HandshakeMsg.postWrite(self, w, trial) |
+ w.addVarSeq(bytearray(self.ocsp_response), 1, 3) |
+ return self.postWrite(w) |
class CertificateRequest(HandshakeMsg): |
def __init__(self): |
- self.contentType = ContentType.handshake |
+ HandshakeMsg.__init__(self, HandshakeType.certificate_request) |
#Apple's Secure Transport library rejects empty certificate_types, so |
#default to rsa_sign. |
self.certificate_types = [ClientCertificateType.rsa_sign] |
@@ -446,9 +477,8 @@ class CertificateRequest(HandshakeMsg): |
p.stopLengthCheck() |
return self |
- def write(self, trial=False): |
- w = HandshakeMsg.preWrite(self, HandshakeType.certificate_request, |
- trial) |
+ def write(self): |
+ w = Writer() |
w.addVarSeq(self.certificate_types, 1, 1) |
caLength = 0 |
#determine length |
@@ -458,17 +488,21 @@ class CertificateRequest(HandshakeMsg): |
#add bytes |
for ca_dn in self.certificate_authorities: |
w.addVarSeq(ca_dn, 1, 2) |
- return HandshakeMsg.postWrite(self, w, trial) |
+ return self.postWrite(w) |
class ServerKeyExchange(HandshakeMsg): |
def __init__(self, cipherSuite): |
+ HandshakeMsg.__init__(self, HandshakeType.server_key_exchange) |
self.cipherSuite = cipherSuite |
- self.contentType = ContentType.handshake |
- self.srp_N = 0L |
- self.srp_g = 0L |
- self.srp_s = createByteArraySequence([]) |
- self.srp_B = 0L |
- self.signature = createByteArraySequence([]) |
+ self.srp_N = 0 |
+ self.srp_g = 0 |
+ self.srp_s = bytearray(0) |
+ self.srp_B = 0 |
+ # Anon DH params: |
+ self.dh_p = 0 |
+ self.dh_g = 0 |
+ self.dh_Ys = 0 |
+ self.signature = bytearray(0) |
def createSRP(self, srp_N, srp_g, srp_s, srp_B): |
self.srp_N = srp_N |
@@ -476,42 +510,58 @@ class ServerKeyExchange(HandshakeMsg): |
self.srp_s = srp_s |
self.srp_B = srp_B |
return self |
+ |
+ def createDH(self, dh_p, dh_g, dh_Ys): |
+ self.dh_p = dh_p |
+ self.dh_g = dh_g |
+ self.dh_Ys = dh_Ys |
+ return self |
def parse(self, p): |
p.startLengthCheck(3) |
- self.srp_N = bytesToNumber(p.getVarBytes(2)) |
- self.srp_g = bytesToNumber(p.getVarBytes(2)) |
- self.srp_s = p.getVarBytes(1) |
- self.srp_B = bytesToNumber(p.getVarBytes(2)) |
- if self.cipherSuite in CipherSuite.srpRsaSuites: |
- self.signature = p.getVarBytes(2) |
+ if self.cipherSuite in CipherSuite.srpAllSuites: |
+ self.srp_N = bytesToNumber(p.getVarBytes(2)) |
+ self.srp_g = bytesToNumber(p.getVarBytes(2)) |
+ self.srp_s = p.getVarBytes(1) |
+ self.srp_B = bytesToNumber(p.getVarBytes(2)) |
+ if self.cipherSuite in CipherSuite.srpCertSuites: |
+ self.signature = p.getVarBytes(2) |
+ elif self.cipherSuite in CipherSuite.anonSuites: |
+ self.dh_p = bytesToNumber(p.getVarBytes(2)) |
+ self.dh_g = bytesToNumber(p.getVarBytes(2)) |
+ self.dh_Ys = bytesToNumber(p.getVarBytes(2)) |
p.stopLengthCheck() |
return self |
- def write(self, trial=False): |
- w = HandshakeMsg.preWrite(self, HandshakeType.server_key_exchange, |
- trial) |
- w.addVarSeq(numberToBytes(self.srp_N), 1, 2) |
- w.addVarSeq(numberToBytes(self.srp_g), 1, 2) |
- w.addVarSeq(self.srp_s, 1, 1) |
- w.addVarSeq(numberToBytes(self.srp_B), 1, 2) |
- if self.cipherSuite in CipherSuite.srpRsaSuites: |
- w.addVarSeq(self.signature, 1, 2) |
- return HandshakeMsg.postWrite(self, w, trial) |
+ def write(self): |
+ w = Writer() |
+ if self.cipherSuite in CipherSuite.srpAllSuites: |
+ w.addVarSeq(numberToByteArray(self.srp_N), 1, 2) |
+ w.addVarSeq(numberToByteArray(self.srp_g), 1, 2) |
+ w.addVarSeq(self.srp_s, 1, 1) |
+ w.addVarSeq(numberToByteArray(self.srp_B), 1, 2) |
+ if self.cipherSuite in CipherSuite.srpCertSuites: |
+ w.addVarSeq(self.signature, 1, 2) |
+ elif self.cipherSuite in CipherSuite.anonSuites: |
+ w.addVarSeq(numberToByteArray(self.dh_p), 1, 2) |
+ w.addVarSeq(numberToByteArray(self.dh_g), 1, 2) |
+ w.addVarSeq(numberToByteArray(self.dh_Ys), 1, 2) |
+ if self.cipherSuite in []: # TODO support for signed_params |
+ w.addVarSeq(self.signature, 1, 2) |
+ return self.postWrite(w) |
def hash(self, clientRandom, serverRandom): |
oldCipherSuite = self.cipherSuite |
self.cipherSuite = None |
try: |
bytes = clientRandom + serverRandom + self.write()[4:] |
- s = bytesToString(bytes) |
- return stringToBytes(md5.md5(s).digest() + sha.sha(s).digest()) |
+ return MD5(bytes) + SHA1(bytes) |
finally: |
self.cipherSuite = oldCipherSuite |
class ServerHelloDone(HandshakeMsg): |
def __init__(self): |
- self.contentType = ContentType.handshake |
+ HandshakeMsg.__init__(self, HandshakeType.server_hello_done) |
def create(self): |
return self |
@@ -521,17 +571,17 @@ class ServerHelloDone(HandshakeMsg): |
p.stopLengthCheck() |
return self |
- def write(self, trial=False): |
- w = HandshakeMsg.preWrite(self, HandshakeType.server_hello_done, trial) |
- return HandshakeMsg.postWrite(self, w, trial) |
+ def write(self): |
+ w = Writer() |
+ return self.postWrite(w) |
class ClientKeyExchange(HandshakeMsg): |
def __init__(self, cipherSuite, version=None): |
+ HandshakeMsg.__init__(self, HandshakeType.client_key_exchange) |
self.cipherSuite = cipherSuite |
self.version = version |
- self.contentType = ContentType.handshake |
self.srp_A = 0 |
- self.encryptedPreMasterSecret = createByteArraySequence([]) |
+ self.encryptedPreMasterSecret = bytearray(0) |
def createSRP(self, srp_A): |
self.srp_A = srp_A |
@@ -540,13 +590,16 @@ class ClientKeyExchange(HandshakeMsg): |
def createRSA(self, encryptedPreMasterSecret): |
self.encryptedPreMasterSecret = encryptedPreMasterSecret |
return self |
- |
+ |
+ def createDH(self, dh_Yc): |
+ self.dh_Yc = dh_Yc |
+ return self |
+ |
def parse(self, p): |
p.startLengthCheck(3) |
- if self.cipherSuite in CipherSuite.srpSuites + \ |
- CipherSuite.srpRsaSuites: |
+ if self.cipherSuite in CipherSuite.srpAllSuites: |
self.srp_A = bytesToNumber(p.getVarBytes(2)) |
- elif self.cipherSuite in CipherSuite.rsaSuites: |
+ elif self.cipherSuite in CipherSuite.certSuites: |
if self.version in ((3,1), (3,2)): |
self.encryptedPreMasterSecret = p.getVarBytes(2) |
elif self.version == (3,0): |
@@ -554,32 +607,34 @@ class ClientKeyExchange(HandshakeMsg): |
p.getFixBytes(len(p.bytes)-p.index) |
else: |
raise AssertionError() |
+ elif self.cipherSuite in CipherSuite.anonSuites: |
+ self.dh_Yc = bytesToNumber(p.getVarBytes(2)) |
else: |
raise AssertionError() |
p.stopLengthCheck() |
return self |
- def write(self, trial=False): |
- w = HandshakeMsg.preWrite(self, HandshakeType.client_key_exchange, |
- trial) |
- if self.cipherSuite in CipherSuite.srpSuites + \ |
- CipherSuite.srpRsaSuites: |
- w.addVarSeq(numberToBytes(self.srp_A), 1, 2) |
- elif self.cipherSuite in CipherSuite.rsaSuites: |
+ def write(self): |
+ w = Writer() |
+ if self.cipherSuite in CipherSuite.srpAllSuites: |
+ w.addVarSeq(numberToByteArray(self.srp_A), 1, 2) |
+ elif self.cipherSuite in CipherSuite.certSuites: |
if self.version in ((3,1), (3,2)): |
w.addVarSeq(self.encryptedPreMasterSecret, 1, 2) |
elif self.version == (3,0): |
w.addFixSeq(self.encryptedPreMasterSecret, 1) |
else: |
raise AssertionError() |
+ elif self.cipherSuite in CipherSuite.anonSuites: |
+ w.addVarSeq(numberToByteArray(self.dh_Yc), 1, 2) |
else: |
raise AssertionError() |
- return HandshakeMsg.postWrite(self, w, trial) |
+ return self.postWrite(w) |
class CertificateVerify(HandshakeMsg): |
def __init__(self): |
- self.contentType = ContentType.handshake |
- self.signature = createByteArraySequence([]) |
+ HandshakeMsg.__init__(self, HandshakeType.certificate_verify) |
+ self.signature = bytearray(0) |
def create(self, signature): |
self.signature = signature |
@@ -591,13 +646,12 @@ class CertificateVerify(HandshakeMsg): |
p.stopLengthCheck() |
return self |
- def write(self, trial=False): |
- w = HandshakeMsg.preWrite(self, HandshakeType.certificate_verify, |
- trial) |
+ def write(self): |
+ w = Writer() |
w.addVarSeq(self.signature, 1, 2) |
- return HandshakeMsg.postWrite(self, w, trial) |
+ return self.postWrite(w) |
-class ChangeCipherSpec(Msg): |
+class ChangeCipherSpec(object): |
def __init__(self): |
self.contentType = ContentType.change_cipher_spec |
self.type = 1 |
@@ -612,17 +666,40 @@ class ChangeCipherSpec(Msg): |
p.stopLengthCheck() |
return self |
- def write(self, trial=False): |
- w = Msg.preWrite(self, trial) |
+ def write(self): |
+ w = Writer() |
w.add(self.type,1) |
- return Msg.postWrite(self, w, trial) |
+ return w.bytes |
+ |
+ |
+class NextProtocol(HandshakeMsg): |
+ def __init__(self): |
+ HandshakeMsg.__init__(self, HandshakeType.next_protocol) |
+ self.next_proto = None |
+ |
+ def create(self, next_proto): |
+ self.next_proto = next_proto |
+ return self |
+ |
+ def parse(self, p): |
+ p.startLengthCheck(3) |
+ self.next_proto = p.getVarBytes(1) |
+ _ = p.getVarBytes(1) |
+ p.stopLengthCheck() |
+ return self |
+ def write(self, trial=False): |
+ w = Writer() |
+ w.addVarSeq(self.next_proto, 1, 1) |
+ paddingLen = 32 - ((len(self.next_proto) + 2) % 32) |
+ w.addVarSeq(bytearray(paddingLen), 1, 1) |
+ return self.postWrite(w) |
class Finished(HandshakeMsg): |
def __init__(self, version): |
- self.contentType = ContentType.handshake |
+ HandshakeMsg.__init__(self, HandshakeType.finished) |
self.version = version |
- self.verify_data = createByteArraySequence([]) |
+ self.verify_data = bytearray(0) |
def create(self, verify_data): |
self.verify_data = verify_data |
@@ -639,10 +716,10 @@ class Finished(HandshakeMsg): |
p.stopLengthCheck() |
return self |
- def write(self, trial=False): |
- w = HandshakeMsg.preWrite(self, HandshakeType.finished, trial) |
+ def write(self): |
+ w = Writer() |
w.addFixSeq(self.verify_data, 1) |
- return HandshakeMsg.postWrite(self, w, trial) |
+ return self.postWrite(w) |
class EncryptedExtensions(HandshakeMsg): |
def __init__(self): |
@@ -666,14 +743,19 @@ class EncryptedExtensions(HandshakeMsg): |
p.stopLengthCheck() |
return self |
-class ApplicationData(Msg): |
+class ApplicationData(object): |
def __init__(self): |
self.contentType = ContentType.application_data |
- self.bytes = createByteArraySequence([]) |
+ self.bytes = bytearray(0) |
def create(self, bytes): |
self.bytes = bytes |
return self |
+ |
+ def splitFirstByte(self): |
+ newMsg = ApplicationData().create(self.bytes[:1]) |
+ self.bytes = self.bytes[1:] |
+ return newMsg |
def parse(self, p): |
self.bytes = p.bytes |