Chromium Code Reviews
chromiumcodereview-hr@appspot.gserviceaccount.com (chromiumcodereview-hr) | Please choose your nickname with Settings | Help | Chromium Project | Gerrit Changes | Sign out
(30)

Unified Diff: third_party/tlslite/tlslite/messages.py

Issue 210323002: Update tlslite to 0.4.6. (Closed) Base URL: svn://svn.chromium.org/chrome/trunk/src
Patch Set: Executable bit and --similarity=80 Created 6 years, 8 months ago
Use n/p to move between diff chunks; N/P to move between comments. Draft comments are only viewable by you.
Jump to:
View side-by-side diff with in-line comments
Download patch
« no previous file with comments | « third_party/tlslite/tlslite/mathtls.py ('k') | third_party/tlslite/tlslite/session.py » ('j') | no next file with comments »
Expand Comments ('e') | Collapse Comments ('c') | Show Comments Hide Comments ('s')
Index: third_party/tlslite/tlslite/messages.py
diff --git a/third_party/tlslite/tlslite/messages.py b/third_party/tlslite/tlslite/messages.py
index fd6265f294e8603184fbf830ff5ea23b7c136f84..532d86bb13c9977834ec1f48e7dd33306339aa3e 100644
--- a/third_party/tlslite/tlslite/messages.py
+++ b/third_party/tlslite/tlslite/messages.py
@@ -1,26 +1,23 @@
+# Authors:
+# Trevor Perrin
+# Google - handling CertificateRequest.certificate_types
+# Google (adapted by Sam Rushing and Marcelo Fernandez) - NPN support
+# Dimitris Moraitis - Anon ciphersuites
+#
+# See the LICENSE file for legal information regarding use of this file.
+
"""Classes representing TLS messages."""
-from utils.compat import *
-from utils.cryptomath import *
-from errors import *
-from utils.codec import *
-from constants import *
-from x509 import X509
-from x509certchain import X509CertChain
-
-# The sha module is deprecated in Python 2.6
-try:
- import sha
-except ImportError:
- from hashlib import sha1 as sha
-
-# The md5 module is deprecated in Python 2.6
-try:
- import md5
-except ImportError:
- from hashlib import md5
-
-class RecordHeader3:
+from .utils.compat import *
+from .utils.cryptomath import *
+from .errors import *
+from .utils.codec import *
+from .constants import *
+from .x509 import X509
+from .x509certchain import X509CertChain
+from .utils.tackwrapper import *
+
+class RecordHeader3(object):
def __init__(self):
self.type = 0
self.version = (0,0)
@@ -34,7 +31,7 @@ class RecordHeader3:
return self
def write(self):
- w = Writer(5)
+ w = Writer()
w.add(self.type, 1)
w.add(self.version[0], 1)
w.add(self.version[1], 1)
@@ -48,7 +45,7 @@ class RecordHeader3:
self.ssl2 = False
return self
-class RecordHeader2:
+class RecordHeader2(object):
def __init__(self):
self.type = 0
self.version = (0,0)
@@ -65,22 +62,7 @@ class RecordHeader2:
return self
-class Msg:
- def preWrite(self, trial):
- if trial:
- w = Writer()
- else:
- length = self.write(True)
- w = Writer(length)
- return w
-
- def postWrite(self, w, trial):
- if trial:
- return w.index
- else:
- return w.bytes
-
-class Alert(Msg):
+class Alert(object):
def __init__(self):
self.contentType = ContentType.alert
self.level = 0
@@ -99,50 +81,56 @@ class Alert(Msg):
return self
def write(self):
- w = Writer(2)
+ w = Writer()
w.add(self.level, 1)
w.add(self.description, 1)
return w.bytes
-class HandshakeMsg(Msg):
- def preWrite(self, handshakeType, trial):
- if trial:
- w = Writer()
- w.add(handshakeType, 1)
- w.add(0, 3)
- else:
- length = self.write(True)
- w = Writer(length)
- w.add(handshakeType, 1)
- w.add(length-4, 3)
- return w
-
+class HandshakeMsg(object):
+ def __init__(self, handshakeType):
+ self.contentType = ContentType.handshake
+ self.handshakeType = handshakeType
+
+ def postWrite(self, w):
+ headerWriter = Writer()
+ headerWriter.add(self.handshakeType, 1)
+ headerWriter.add(len(w.bytes), 3)
+ return headerWriter.bytes + w.bytes
class ClientHello(HandshakeMsg):
def __init__(self, ssl2=False):
- self.contentType = ContentType.handshake
+ HandshakeMsg.__init__(self, HandshakeType.client_hello)
self.ssl2 = ssl2
self.client_version = (0,0)
- self.random = createByteArrayZeros(32)
- self.session_id = createByteArraySequence([])
+ self.random = bytearray(32)
+ self.session_id = bytearray(0)
self.cipher_suites = [] # a list of 16-bit values
self.certificate_types = [CertificateType.x509]
self.compression_methods = [] # a list of 8-bit values
self.srp_username = None # a string
+ self.tack = False
+ self.supports_npn = False
+ self.server_name = bytearray(0)
self.channel_id = False
self.support_signed_cert_timestamps = False
self.status_request = False
def create(self, version, random, session_id, cipher_suites,
- certificate_types=None, srp_username=None):
+ certificate_types=None, srpUsername=None,
+ tack=False, supports_npn=False, serverName=None):
self.client_version = version
self.random = random
self.session_id = session_id
self.cipher_suites = cipher_suites
self.certificate_types = certificate_types
self.compression_methods = [0]
- self.srp_username = srp_username
+ if srpUsername:
+ self.srp_username = bytearray(srpUsername, "utf-8")
+ self.tack = tack
+ self.supports_npn = supports_npn
+ if serverName:
+ self.server_name = bytearray(serverName, "utf-8")
return self
def parse(self, p):
@@ -151,12 +139,12 @@ class ClientHello(HandshakeMsg):
cipherSpecsLength = p.get(2)
sessionIDLength = p.get(2)
randomLength = p.get(2)
- self.cipher_suites = p.getFixList(3, int(cipherSpecsLength/3))
+ self.cipher_suites = p.getFixList(3, cipherSpecsLength//3)
self.session_id = p.getFixBytes(sessionIDLength)
self.random = p.getFixBytes(randomLength)
if len(self.random) < 32:
zeroBytes = 32-len(self.random)
- self.random = createByteArrayZeros(zeroBytes) + self.random
+ self.random = bytearray(zeroBytes) + self.random
self.compression_methods = [0]#Fake this value
#We're not doing a stopLengthCheck() for SSLv2, oh well..
@@ -173,10 +161,27 @@ class ClientHello(HandshakeMsg):
while soFar != totalExtLength:
extType = p.get(2)
extLength = p.get(2)
- if extType == 6:
- self.srp_username = bytesToString(p.getVarBytes(1))
- elif extType == 7:
+ index1 = p.index
+ if extType == ExtensionType.srp:
+ self.srp_username = p.getVarBytes(1)
+ elif extType == ExtensionType.cert_type:
self.certificate_types = p.getVarList(1, 1)
+ elif extType == ExtensionType.tack:
+ self.tack = True
+ elif extType == ExtensionType.supports_npn:
+ self.supports_npn = True
+ elif extType == ExtensionType.server_name:
+ serverNameListBytes = p.getFixBytes(extLength)
+ p2 = Parser(serverNameListBytes)
+ p2.startLengthCheck(2)
+ while 1:
+ if p2.atLengthCheck():
+ break # no host_name, oh well
+ name_type = p2.get(1)
+ hostNameBytes = p2.getVarBytes(2)
+ if name_type == NameType.host_name:
+ self.server_name = hostNameBytes
+ break
elif extType == ExtensionType.channel_id:
self.channel_id = True
elif extType == ExtensionType.signed_cert_timestamps:
@@ -197,13 +202,16 @@ class ClientHello(HandshakeMsg):
p.getFixBytes(extLength)
self.status_request = True
else:
- p.getFixBytes(extLength)
+ _ = p.getFixBytes(extLength)
+ index2 = p.index
+ if index2 - index1 != extLength:
+ raise SyntaxError("Bad length for extension_data")
soFar += 4 + extLength
p.stopLengthCheck()
return self
- def write(self, trial=False):
- w = HandshakeMsg.preWrite(self, HandshakeType.client_hello, trial)
+ def write(self):
+ w = Writer()
w.add(self.client_version[0], 1)
w.add(self.client_version[1], 1)
w.addFixSeq(self.random, 1)
@@ -211,49 +219,66 @@ class ClientHello(HandshakeMsg):
w.addVarSeq(self.cipher_suites, 2, 2)
w.addVarSeq(self.compression_methods, 1, 1)
- extLength = 0
- if self.certificate_types and self.certificate_types != \
- [CertificateType.x509]:
- extLength += 5 + len(self.certificate_types)
- if self.srp_username:
- extLength += 5 + len(self.srp_username)
- if extLength > 0:
- w.add(extLength, 2)
-
+ w2 = Writer() # For Extensions
if self.certificate_types and self.certificate_types != \
[CertificateType.x509]:
- w.add(7, 2)
- w.add(len(self.certificate_types)+1, 2)
- w.addVarSeq(self.certificate_types, 1, 1)
+ w2.add(ExtensionType.cert_type, 2)
+ w2.add(len(self.certificate_types)+1, 2)
+ w2.addVarSeq(self.certificate_types, 1, 1)
if self.srp_username:
- w.add(6, 2)
- w.add(len(self.srp_username)+1, 2)
- w.addVarSeq(stringToBytes(self.srp_username), 1, 1)
-
- return HandshakeMsg.postWrite(self, w, trial)
-
+ w2.add(ExtensionType.srp, 2)
+ w2.add(len(self.srp_username)+1, 2)
+ w2.addVarSeq(self.srp_username, 1, 1)
+ if self.supports_npn:
+ w2.add(ExtensionType.supports_npn, 2)
+ w2.add(0, 2)
+ if self.server_name:
+ w2.add(ExtensionType.server_name, 2)
+ w2.add(len(self.server_name)+5, 2)
+ w2.add(len(self.server_name)+3, 2)
+ w2.add(NameType.host_name, 1)
+ w2.addVarSeq(self.server_name, 1, 2)
+ if self.tack:
+ w2.add(ExtensionType.tack, 2)
+ w2.add(0, 2)
+ if len(w2.bytes):
+ w.add(len(w2.bytes), 2)
+ w.bytes += w2.bytes
+ return self.postWrite(w)
+
+class BadNextProtos(Exception):
+ def __init__(self, l):
+ self.length = l
+
+ def __str__(self):
+ return 'Cannot encode a list of next protocols because it contains an element with invalid length %d. Element lengths must be 0 < x < 256' % self.length
class ServerHello(HandshakeMsg):
def __init__(self):
- self.contentType = ContentType.handshake
+ HandshakeMsg.__init__(self, HandshakeType.server_hello)
self.server_version = (0,0)
- self.random = createByteArrayZeros(32)
- self.session_id = createByteArraySequence([])
+ self.random = bytearray(32)
+ self.session_id = bytearray(0)
self.cipher_suite = 0
self.certificate_type = CertificateType.x509
self.compression_method = 0
+ self.tackExt = None
+ self.next_protos_advertised = None
+ self.next_protos = None
self.channel_id = False
self.signed_cert_timestamps = None
self.status_request = False
def create(self, version, random, session_id, cipher_suite,
- certificate_type):
+ certificate_type, tackExt, next_protos_advertised):
self.server_version = version
self.random = random
self.session_id = session_id
self.cipher_suite = cipher_suite
self.certificate_type = certificate_type
self.compression_method = 0
+ self.tackExt = tackExt
+ self.next_protos_advertised = next_protos_advertised
return self
def parse(self, p):
@@ -269,16 +294,43 @@ class ServerHello(HandshakeMsg):
while soFar != totalExtLength:
extType = p.get(2)
extLength = p.get(2)
- if extType == 7:
+ if extType == ExtensionType.cert_type:
+ if extLength != 1:
+ raise SyntaxError()
self.certificate_type = p.get(1)
+ elif extType == ExtensionType.tack and tackpyLoaded:
+ self.tackExt = TackExtension(p.getFixBytes(extLength))
+ elif extType == ExtensionType.supports_npn:
+ self.next_protos = self.__parse_next_protos(p.getFixBytes(extLength))
else:
p.getFixBytes(extLength)
soFar += 4 + extLength
p.stopLengthCheck()
return self
- def write(self, trial=False):
- w = HandshakeMsg.preWrite(self, HandshakeType.server_hello, trial)
+ def __parse_next_protos(self, b):
+ protos = []
+ while True:
+ if len(b) == 0:
+ break
+ l = b[0]
+ b = b[1:]
+ if len(b) < l:
+ raise BadNextProtos(len(b))
+ protos.append(b[:l])
+ b = b[l:]
+ return protos
+
+ def __next_protos_encoded(self):
+ b = bytearray()
+ for e in self.next_protos_advertised:
+ if len(e) > 255 or len(e) == 0:
+ raise BadNextProtos(len(e))
+ b += bytearray( [len(e)] ) + bytearray(e)
+ return b
+
+ def write(self):
+ w = Writer()
w.add(self.server_version[0], 1)
w.add(self.server_version[1], 1)
w.addFixSeq(self.random, 1)
@@ -286,47 +338,41 @@ class ServerHello(HandshakeMsg):
w.add(self.cipher_suite, 2)
w.add(self.compression_method, 1)
- extLength = 0
- if self.certificate_type and self.certificate_type != \
- CertificateType.x509:
- extLength += 5
-
- if self.channel_id:
- extLength += 4
-
- if self.signed_cert_timestamps:
- extLength += 4 + len(self.signed_cert_timestamps)
-
- if self.status_request:
- extLength += 4
-
- if extLength != 0:
- w.add(extLength, 2)
-
+ w2 = Writer() # For Extensions
if self.certificate_type and self.certificate_type != \
CertificateType.x509:
- w.add(7, 2)
- w.add(1, 2)
- w.add(self.certificate_type, 1)
-
+ w2.add(ExtensionType.cert_type, 2)
+ w2.add(1, 2)
+ w2.add(self.certificate_type, 1)
+ if self.tackExt:
+ b = self.tackExt.serialize()
+ w2.add(ExtensionType.tack, 2)
+ w2.add(len(b), 2)
+ w2.bytes += b
+ if self.next_protos_advertised is not None:
+ encoded_next_protos_advertised = self.__next_protos_encoded()
+ w2.add(ExtensionType.supports_npn, 2)
+ w2.add(len(encoded_next_protos_advertised), 2)
+ w2.addFixSeq(encoded_next_protos_advertised, 1)
if self.channel_id:
- w.add(ExtensionType.channel_id, 2)
- w.add(0, 2)
-
+ w2.add(ExtensionType.channel_id, 2)
+ w2.add(0, 2)
if self.signed_cert_timestamps:
- w.add(ExtensionType.signed_cert_timestamps, 2)
- w.addVarSeq(stringToBytes(self.signed_cert_timestamps), 1, 2)
-
+ w2.add(ExtensionType.signed_cert_timestamps, 2)
+ w2.addVarSeq(bytearray(self.signed_cert_timestamps), 1, 2)
if self.status_request:
- w.add(ExtensionType.status_request, 2)
- w.add(0, 2)
+ w2.add(ExtensionType.status_request, 2)
+ w2.add(0, 2)
+ if len(w2.bytes):
+ w.add(len(w2.bytes), 2)
+ w.bytes += w2.bytes
+ return self.postWrite(w)
- return HandshakeMsg.postWrite(self, w, trial)
class Certificate(HandshakeMsg):
def __init__(self, certificateType):
+ HandshakeMsg.__init__(self, HandshakeType.certificate)
self.certificateType = certificateType
- self.contentType = ContentType.handshake
self.certChain = None
def create(self, certChain):
@@ -347,23 +393,14 @@ class Certificate(HandshakeMsg):
index += len(certBytes)+3
if certificate_list:
self.certChain = X509CertChain(certificate_list)
- elif self.certificateType == CertificateType.cryptoID:
- s = bytesToString(p.getVarBytes(2))
- if s:
- try:
- import cryptoIDlib.CertChain
- except ImportError:
- raise SyntaxError(\
- "cryptoID cert chain received, cryptoIDlib not present")
- self.certChain = cryptoIDlib.CertChain.CertChain().parse(s)
else:
raise AssertionError()
p.stopLengthCheck()
return self
- def write(self, trial=False):
- w = HandshakeMsg.preWrite(self, HandshakeType.certificate, trial)
+ def write(self):
+ w = Writer()
if self.certificateType == CertificateType.x509:
chainLength = 0
if self.certChain:
@@ -379,19 +416,13 @@ class Certificate(HandshakeMsg):
for cert in certificate_list:
bytes = cert.writeBytes()
w.addVarSeq(bytes, 1, 3)
- elif self.certificateType == CertificateType.cryptoID:
- if self.certChain:
- bytes = stringToBytes(self.certChain.write())
- else:
- bytes = createByteArraySequence([])
- w.addVarSeq(bytes, 1, 2)
else:
raise AssertionError()
- return HandshakeMsg.postWrite(self, w, trial)
+ return self.postWrite(w)
class CertificateStatus(HandshakeMsg):
def __init__(self):
- self.contentType = ContentType.handshake
+ HandshakeMsg.__init__(self, HandshakeType.certificate_status)
def create(self, ocsp_response):
self.ocsp_response = ocsp_response
@@ -411,18 +442,18 @@ class CertificateStatus(HandshakeMsg):
# Can't be empty
raise SyntaxError()
self.ocsp_response = ocsp_response
+ p.stopLengthCheck()
return self
- def write(self, trial=False):
- w = HandshakeMsg.preWrite(self, HandshakeType.certificate_status,
- trial)
+ def write(self):
+ w = Writer()
w.add(CertificateStatusType.ocsp, 1)
- w.addVarSeq(stringToBytes(self.ocsp_response), 1, 3)
- return HandshakeMsg.postWrite(self, w, trial)
+ w.addVarSeq(bytearray(self.ocsp_response), 1, 3)
+ return self.postWrite(w)
class CertificateRequest(HandshakeMsg):
def __init__(self):
- self.contentType = ContentType.handshake
+ HandshakeMsg.__init__(self, HandshakeType.certificate_request)
#Apple's Secure Transport library rejects empty certificate_types, so
#default to rsa_sign.
self.certificate_types = [ClientCertificateType.rsa_sign]
@@ -446,9 +477,8 @@ class CertificateRequest(HandshakeMsg):
p.stopLengthCheck()
return self
- def write(self, trial=False):
- w = HandshakeMsg.preWrite(self, HandshakeType.certificate_request,
- trial)
+ def write(self):
+ w = Writer()
w.addVarSeq(self.certificate_types, 1, 1)
caLength = 0
#determine length
@@ -458,17 +488,21 @@ class CertificateRequest(HandshakeMsg):
#add bytes
for ca_dn in self.certificate_authorities:
w.addVarSeq(ca_dn, 1, 2)
- return HandshakeMsg.postWrite(self, w, trial)
+ return self.postWrite(w)
class ServerKeyExchange(HandshakeMsg):
def __init__(self, cipherSuite):
+ HandshakeMsg.__init__(self, HandshakeType.server_key_exchange)
self.cipherSuite = cipherSuite
- self.contentType = ContentType.handshake
- self.srp_N = 0L
- self.srp_g = 0L
- self.srp_s = createByteArraySequence([])
- self.srp_B = 0L
- self.signature = createByteArraySequence([])
+ self.srp_N = 0
+ self.srp_g = 0
+ self.srp_s = bytearray(0)
+ self.srp_B = 0
+ # Anon DH params:
+ self.dh_p = 0
+ self.dh_g = 0
+ self.dh_Ys = 0
+ self.signature = bytearray(0)
def createSRP(self, srp_N, srp_g, srp_s, srp_B):
self.srp_N = srp_N
@@ -476,42 +510,58 @@ class ServerKeyExchange(HandshakeMsg):
self.srp_s = srp_s
self.srp_B = srp_B
return self
+
+ def createDH(self, dh_p, dh_g, dh_Ys):
+ self.dh_p = dh_p
+ self.dh_g = dh_g
+ self.dh_Ys = dh_Ys
+ return self
def parse(self, p):
p.startLengthCheck(3)
- self.srp_N = bytesToNumber(p.getVarBytes(2))
- self.srp_g = bytesToNumber(p.getVarBytes(2))
- self.srp_s = p.getVarBytes(1)
- self.srp_B = bytesToNumber(p.getVarBytes(2))
- if self.cipherSuite in CipherSuite.srpRsaSuites:
- self.signature = p.getVarBytes(2)
+ if self.cipherSuite in CipherSuite.srpAllSuites:
+ self.srp_N = bytesToNumber(p.getVarBytes(2))
+ self.srp_g = bytesToNumber(p.getVarBytes(2))
+ self.srp_s = p.getVarBytes(1)
+ self.srp_B = bytesToNumber(p.getVarBytes(2))
+ if self.cipherSuite in CipherSuite.srpCertSuites:
+ self.signature = p.getVarBytes(2)
+ elif self.cipherSuite in CipherSuite.anonSuites:
+ self.dh_p = bytesToNumber(p.getVarBytes(2))
+ self.dh_g = bytesToNumber(p.getVarBytes(2))
+ self.dh_Ys = bytesToNumber(p.getVarBytes(2))
p.stopLengthCheck()
return self
- def write(self, trial=False):
- w = HandshakeMsg.preWrite(self, HandshakeType.server_key_exchange,
- trial)
- w.addVarSeq(numberToBytes(self.srp_N), 1, 2)
- w.addVarSeq(numberToBytes(self.srp_g), 1, 2)
- w.addVarSeq(self.srp_s, 1, 1)
- w.addVarSeq(numberToBytes(self.srp_B), 1, 2)
- if self.cipherSuite in CipherSuite.srpRsaSuites:
- w.addVarSeq(self.signature, 1, 2)
- return HandshakeMsg.postWrite(self, w, trial)
+ def write(self):
+ w = Writer()
+ if self.cipherSuite in CipherSuite.srpAllSuites:
+ w.addVarSeq(numberToByteArray(self.srp_N), 1, 2)
+ w.addVarSeq(numberToByteArray(self.srp_g), 1, 2)
+ w.addVarSeq(self.srp_s, 1, 1)
+ w.addVarSeq(numberToByteArray(self.srp_B), 1, 2)
+ if self.cipherSuite in CipherSuite.srpCertSuites:
+ w.addVarSeq(self.signature, 1, 2)
+ elif self.cipherSuite in CipherSuite.anonSuites:
+ w.addVarSeq(numberToByteArray(self.dh_p), 1, 2)
+ w.addVarSeq(numberToByteArray(self.dh_g), 1, 2)
+ w.addVarSeq(numberToByteArray(self.dh_Ys), 1, 2)
+ if self.cipherSuite in []: # TODO support for signed_params
+ w.addVarSeq(self.signature, 1, 2)
+ return self.postWrite(w)
def hash(self, clientRandom, serverRandom):
oldCipherSuite = self.cipherSuite
self.cipherSuite = None
try:
bytes = clientRandom + serverRandom + self.write()[4:]
- s = bytesToString(bytes)
- return stringToBytes(md5.md5(s).digest() + sha.sha(s).digest())
+ return MD5(bytes) + SHA1(bytes)
finally:
self.cipherSuite = oldCipherSuite
class ServerHelloDone(HandshakeMsg):
def __init__(self):
- self.contentType = ContentType.handshake
+ HandshakeMsg.__init__(self, HandshakeType.server_hello_done)
def create(self):
return self
@@ -521,17 +571,17 @@ class ServerHelloDone(HandshakeMsg):
p.stopLengthCheck()
return self
- def write(self, trial=False):
- w = HandshakeMsg.preWrite(self, HandshakeType.server_hello_done, trial)
- return HandshakeMsg.postWrite(self, w, trial)
+ def write(self):
+ w = Writer()
+ return self.postWrite(w)
class ClientKeyExchange(HandshakeMsg):
def __init__(self, cipherSuite, version=None):
+ HandshakeMsg.__init__(self, HandshakeType.client_key_exchange)
self.cipherSuite = cipherSuite
self.version = version
- self.contentType = ContentType.handshake
self.srp_A = 0
- self.encryptedPreMasterSecret = createByteArraySequence([])
+ self.encryptedPreMasterSecret = bytearray(0)
def createSRP(self, srp_A):
self.srp_A = srp_A
@@ -540,13 +590,16 @@ class ClientKeyExchange(HandshakeMsg):
def createRSA(self, encryptedPreMasterSecret):
self.encryptedPreMasterSecret = encryptedPreMasterSecret
return self
-
+
+ def createDH(self, dh_Yc):
+ self.dh_Yc = dh_Yc
+ return self
+
def parse(self, p):
p.startLengthCheck(3)
- if self.cipherSuite in CipherSuite.srpSuites + \
- CipherSuite.srpRsaSuites:
+ if self.cipherSuite in CipherSuite.srpAllSuites:
self.srp_A = bytesToNumber(p.getVarBytes(2))
- elif self.cipherSuite in CipherSuite.rsaSuites:
+ elif self.cipherSuite in CipherSuite.certSuites:
if self.version in ((3,1), (3,2)):
self.encryptedPreMasterSecret = p.getVarBytes(2)
elif self.version == (3,0):
@@ -554,32 +607,34 @@ class ClientKeyExchange(HandshakeMsg):
p.getFixBytes(len(p.bytes)-p.index)
else:
raise AssertionError()
+ elif self.cipherSuite in CipherSuite.anonSuites:
+ self.dh_Yc = bytesToNumber(p.getVarBytes(2))
else:
raise AssertionError()
p.stopLengthCheck()
return self
- def write(self, trial=False):
- w = HandshakeMsg.preWrite(self, HandshakeType.client_key_exchange,
- trial)
- if self.cipherSuite in CipherSuite.srpSuites + \
- CipherSuite.srpRsaSuites:
- w.addVarSeq(numberToBytes(self.srp_A), 1, 2)
- elif self.cipherSuite in CipherSuite.rsaSuites:
+ def write(self):
+ w = Writer()
+ if self.cipherSuite in CipherSuite.srpAllSuites:
+ w.addVarSeq(numberToByteArray(self.srp_A), 1, 2)
+ elif self.cipherSuite in CipherSuite.certSuites:
if self.version in ((3,1), (3,2)):
w.addVarSeq(self.encryptedPreMasterSecret, 1, 2)
elif self.version == (3,0):
w.addFixSeq(self.encryptedPreMasterSecret, 1)
else:
raise AssertionError()
+ elif self.cipherSuite in CipherSuite.anonSuites:
+ w.addVarSeq(numberToByteArray(self.dh_Yc), 1, 2)
else:
raise AssertionError()
- return HandshakeMsg.postWrite(self, w, trial)
+ return self.postWrite(w)
class CertificateVerify(HandshakeMsg):
def __init__(self):
- self.contentType = ContentType.handshake
- self.signature = createByteArraySequence([])
+ HandshakeMsg.__init__(self, HandshakeType.certificate_verify)
+ self.signature = bytearray(0)
def create(self, signature):
self.signature = signature
@@ -591,13 +646,12 @@ class CertificateVerify(HandshakeMsg):
p.stopLengthCheck()
return self
- def write(self, trial=False):
- w = HandshakeMsg.preWrite(self, HandshakeType.certificate_verify,
- trial)
+ def write(self):
+ w = Writer()
w.addVarSeq(self.signature, 1, 2)
- return HandshakeMsg.postWrite(self, w, trial)
+ return self.postWrite(w)
-class ChangeCipherSpec(Msg):
+class ChangeCipherSpec(object):
def __init__(self):
self.contentType = ContentType.change_cipher_spec
self.type = 1
@@ -612,17 +666,40 @@ class ChangeCipherSpec(Msg):
p.stopLengthCheck()
return self
- def write(self, trial=False):
- w = Msg.preWrite(self, trial)
+ def write(self):
+ w = Writer()
w.add(self.type,1)
- return Msg.postWrite(self, w, trial)
+ return w.bytes
+
+
+class NextProtocol(HandshakeMsg):
+ def __init__(self):
+ HandshakeMsg.__init__(self, HandshakeType.next_protocol)
+ self.next_proto = None
+
+ def create(self, next_proto):
+ self.next_proto = next_proto
+ return self
+
+ def parse(self, p):
+ p.startLengthCheck(3)
+ self.next_proto = p.getVarBytes(1)
+ _ = p.getVarBytes(1)
+ p.stopLengthCheck()
+ return self
+ def write(self, trial=False):
+ w = Writer()
+ w.addVarSeq(self.next_proto, 1, 1)
+ paddingLen = 32 - ((len(self.next_proto) + 2) % 32)
+ w.addVarSeq(bytearray(paddingLen), 1, 1)
+ return self.postWrite(w)
class Finished(HandshakeMsg):
def __init__(self, version):
- self.contentType = ContentType.handshake
+ HandshakeMsg.__init__(self, HandshakeType.finished)
self.version = version
- self.verify_data = createByteArraySequence([])
+ self.verify_data = bytearray(0)
def create(self, verify_data):
self.verify_data = verify_data
@@ -639,10 +716,10 @@ class Finished(HandshakeMsg):
p.stopLengthCheck()
return self
- def write(self, trial=False):
- w = HandshakeMsg.preWrite(self, HandshakeType.finished, trial)
+ def write(self):
+ w = Writer()
w.addFixSeq(self.verify_data, 1)
- return HandshakeMsg.postWrite(self, w, trial)
+ return self.postWrite(w)
class EncryptedExtensions(HandshakeMsg):
def __init__(self):
@@ -666,14 +743,19 @@ class EncryptedExtensions(HandshakeMsg):
p.stopLengthCheck()
return self
-class ApplicationData(Msg):
+class ApplicationData(object):
def __init__(self):
self.contentType = ContentType.application_data
- self.bytes = createByteArraySequence([])
+ self.bytes = bytearray(0)
def create(self, bytes):
self.bytes = bytes
return self
+
+ def splitFirstByte(self):
+ newMsg = ApplicationData().create(self.bytes[:1])
+ self.bytes = self.bytes[1:]
+ return newMsg
def parse(self, p):
self.bytes = p.bytes
« no previous file with comments | « third_party/tlslite/tlslite/mathtls.py ('k') | third_party/tlslite/tlslite/session.py » ('j') | no next file with comments »

Powered by Google App Engine
This is Rietveld 408576698