| OLD | NEW |
| (Empty) |
| 1 # Copyright (c) 2001-2007 Twisted Matrix Laboratories. | |
| 2 # See LICENSE for details. | |
| 3 | |
| 4 | |
| 5 from twisted.trial import unittest, util as trial_util | |
| 6 from twisted.internet import protocol, reactor, interfaces, defer | |
| 7 from twisted.protocols import basic | |
| 8 from twisted.python import util, log | |
| 9 from twisted.python.runtime import platform | |
| 10 from twisted.test.test_tcp import WriteDataTestCase, ProperlyCloseFilesMixin | |
| 11 | |
| 12 import os, errno | |
| 13 | |
| 14 try: | |
| 15 from OpenSSL import SSL, crypto | |
| 16 from twisted.internet import ssl | |
| 17 from twisted.test.ssl_helpers import ClientTLSContext | |
| 18 except ImportError: | |
| 19 def _noSSL(): | |
| 20 # ugh, make pyflakes happy. | |
| 21 global SSL | |
| 22 global ssl | |
| 23 SSL = ssl = None | |
| 24 _noSSL() | |
| 25 | |
| 26 certPath = util.sibpath(__file__, "server.pem") | |
| 27 | |
| 28 class UnintelligentProtocol(basic.LineReceiver): | |
| 29 pretext = [ | |
| 30 "first line", | |
| 31 "last thing before tls starts", | |
| 32 "STARTTLS"] | |
| 33 | |
| 34 posttext = [ | |
| 35 "first thing after tls started", | |
| 36 "last thing ever"] | |
| 37 | |
| 38 def connectionMade(self): | |
| 39 for l in self.pretext: | |
| 40 self.sendLine(l) | |
| 41 | |
| 42 def lineReceived(self, line): | |
| 43 if line == "READY": | |
| 44 self.transport.startTLS(ClientTLSContext(), self.factory.client) | |
| 45 for l in self.posttext: | |
| 46 self.sendLine(l) | |
| 47 self.transport.loseConnection() | |
| 48 | |
| 49 | |
| 50 class LineCollector(basic.LineReceiver): | |
| 51 def __init__(self, doTLS, fillBuffer=0): | |
| 52 self.doTLS = doTLS | |
| 53 self.fillBuffer = fillBuffer | |
| 54 | |
| 55 def connectionMade(self): | |
| 56 self.factory.rawdata = '' | |
| 57 self.factory.lines = [] | |
| 58 | |
| 59 def lineReceived(self, line): | |
| 60 self.factory.lines.append(line) | |
| 61 if line == 'STARTTLS': | |
| 62 if self.fillBuffer: | |
| 63 for x in range(500): | |
| 64 self.sendLine('X'*1000) | |
| 65 self.sendLine('READY') | |
| 66 if self.doTLS: | |
| 67 ctx = ServerTLSContext( | |
| 68 privateKeyFileName=certPath, | |
| 69 certificateFileName=certPath, | |
| 70 ) | |
| 71 self.transport.startTLS(ctx, self.factory.server) | |
| 72 else: | |
| 73 self.setRawMode() | |
| 74 | |
| 75 def rawDataReceived(self, data): | |
| 76 self.factory.rawdata += data | |
| 77 self.factory.done = 1 | |
| 78 | |
| 79 def connectionLost(self, reason): | |
| 80 self.factory.done = 1 | |
| 81 | |
| 82 | |
| 83 class SingleLineServerProtocol(protocol.Protocol): | |
| 84 def connectionMade(self): | |
| 85 self.transport.identifier = 'SERVER' | |
| 86 self.transport.write("+OK <some crap>\r\n") | |
| 87 self.transport.getPeerCertificate() | |
| 88 | |
| 89 | |
| 90 class RecordingClientProtocol(protocol.Protocol): | |
| 91 def connectionMade(self): | |
| 92 self.transport.identifier = 'CLIENT' | |
| 93 self.buffer = [] | |
| 94 self.transport.getPeerCertificate() | |
| 95 | |
| 96 def dataReceived(self, data): | |
| 97 self.factory.buffer.append(data) | |
| 98 | |
| 99 | |
| 100 class ImmediatelyDisconnectingProtocol(protocol.Protocol): | |
| 101 def connectionMade(self): | |
| 102 self.transport.loseConnection() | |
| 103 | |
| 104 def connectionLost(self, reason): | |
| 105 self.factory.connectionDisconnected.callback(None) | |
| 106 | |
| 107 | |
| 108 class AlmostImmediatelyDisconnectingProtocol(protocol.Protocol): | |
| 109 def connectionMade(self): | |
| 110 # Twisted's SSL support is terribly broken. | |
| 111 reactor.callLater(0.1, self.transport.loseConnection) | |
| 112 | |
| 113 def connectionLost(self, reason): | |
| 114 self.factory.connectionDisconnected.callback(reason) | |
| 115 | |
| 116 | |
| 117 def generateCertificateObjects(organization, organizationalUnit): | |
| 118 pkey = crypto.PKey() | |
| 119 pkey.generate_key(crypto.TYPE_RSA, 512) | |
| 120 req = crypto.X509Req() | |
| 121 subject = req.get_subject() | |
| 122 subject.O = organization | |
| 123 subject.OU = organizationalUnit | |
| 124 req.set_pubkey(pkey) | |
| 125 req.sign(pkey, "md5") | |
| 126 | |
| 127 # Here comes the actual certificate | |
| 128 cert = crypto.X509() | |
| 129 cert.set_serial_number(1) | |
| 130 cert.gmtime_adj_notBefore(0) | |
| 131 cert.gmtime_adj_notAfter(60) # Testing certificates need not be long lived | |
| 132 cert.set_issuer(req.get_subject()) | |
| 133 cert.set_subject(req.get_subject()) | |
| 134 cert.set_pubkey(req.get_pubkey()) | |
| 135 cert.sign(pkey, "md5") | |
| 136 | |
| 137 return pkey, req, cert | |
| 138 | |
| 139 | |
| 140 def generateCertificateFiles(basename, organization, organizationalUnit): | |
| 141 pkey, req, cert = generateCertificateObjects(organization, organizationalUni
t) | |
| 142 | |
| 143 for ext, obj, dumpFunc in [ | |
| 144 ('key', pkey, crypto.dump_privatekey), | |
| 145 ('req', req, crypto.dump_certificate_request), | |
| 146 ('cert', cert, crypto.dump_certificate)]: | |
| 147 fName = os.extsep.join((basename, ext)) | |
| 148 fObj = file(fName, 'w') | |
| 149 fObj.write(dumpFunc(crypto.FILETYPE_PEM, obj)) | |
| 150 fObj.close() | |
| 151 | |
| 152 | |
| 153 class ContextGeneratingMixin: | |
| 154 def makeContextFactory(self, org, orgUnit, *args, **kwArgs): | |
| 155 base = self.mktemp() | |
| 156 generateCertificateFiles(base, org, orgUnit) | |
| 157 serverCtxFactory = ssl.DefaultOpenSSLContextFactory( | |
| 158 os.extsep.join((base, 'key')), | |
| 159 os.extsep.join((base, 'cert')), | |
| 160 *args, **kwArgs) | |
| 161 | |
| 162 return base, serverCtxFactory | |
| 163 | |
| 164 def setupServerAndClient(self, clientArgs, clientKwArgs, serverArgs, serverK
wArgs): | |
| 165 self.clientBase, self.clientCtxFactory = self.makeContextFactory( | |
| 166 *clientArgs, **clientKwArgs) | |
| 167 self.serverBase, self.serverCtxFactory = self.makeContextFactory( | |
| 168 *serverArgs, **serverKwArgs) | |
| 169 | |
| 170 | |
| 171 if SSL is not None: | |
| 172 class ServerTLSContext(ssl.DefaultOpenSSLContextFactory): | |
| 173 isClient = 0 | |
| 174 def __init__(self, *args, **kw): | |
| 175 kw['sslmethod'] = SSL.TLSv1_METHOD | |
| 176 ssl.DefaultOpenSSLContextFactory.__init__(self, *args, **kw) | |
| 177 | |
| 178 | |
| 179 class StolenTCPTestCase(ProperlyCloseFilesMixin, WriteDataTestCase): | |
| 180 """ | |
| 181 For SSL transports, test many of the same things which are tested for | |
| 182 TCP transports. | |
| 183 """ | |
| 184 def createServer(self, address, portNumber, factory): | |
| 185 contextFactory = ssl.CertificateOptions() | |
| 186 return reactor.listenSSL( | |
| 187 portNumber, factory, contextFactory, interface=address) | |
| 188 | |
| 189 | |
| 190 def connectClient(self, address, portNumber, clientCreator): | |
| 191 contextFactory = ssl.CertificateOptions() | |
| 192 return clientCreator.connectSSL(address, portNumber, contextFactory) | |
| 193 | |
| 194 | |
| 195 def getHandleExceptionType(self): | |
| 196 return SSL.SysCallError | |
| 197 | |
| 198 | |
| 199 def getHandleErrorCode(self): | |
| 200 # Windows 2000 SP 4 and Windows XP SP 2 give back WSAENOTSOCK for | |
| 201 # SSL.Connection.write for some reason. | |
| 202 if platform.getType() == 'win32': | |
| 203 return errno.WSAENOTSOCK | |
| 204 return ProperlyCloseFilesMixin.getHandleErrorCode(self) | |
| 205 | |
| 206 | |
| 207 class TLSTestCase(unittest.TestCase): | |
| 208 fillBuffer = 0 | |
| 209 | |
| 210 port = None | |
| 211 clientProto = None | |
| 212 serverProto = None | |
| 213 | |
| 214 def tearDown(self): | |
| 215 if self.clientProto is not None and self.clientProto.transport is not No
ne: | |
| 216 self.clientProto.transport.loseConnection() | |
| 217 if self.serverProto is not None and self.serverProto.transport is not No
ne: | |
| 218 self.serverProto.transport.loseConnection() | |
| 219 | |
| 220 if self.port is not None: | |
| 221 return defer.maybeDeferred(self.port.stopListening) | |
| 222 | |
| 223 def _runTest(self, clientProto, serverProto, clientIsServer=False): | |
| 224 self.clientProto = clientProto | |
| 225 cf = self.clientFactory = protocol.ClientFactory() | |
| 226 cf.protocol = lambda: clientProto | |
| 227 if clientIsServer: | |
| 228 cf.server = 0 | |
| 229 else: | |
| 230 cf.client = 1 | |
| 231 | |
| 232 self.serverProto = serverProto | |
| 233 sf = self.serverFactory = protocol.ServerFactory() | |
| 234 sf.protocol = lambda: serverProto | |
| 235 if clientIsServer: | |
| 236 sf.client = 0 | |
| 237 else: | |
| 238 sf.server = 1 | |
| 239 | |
| 240 if clientIsServer: | |
| 241 inCharge = cf | |
| 242 else: | |
| 243 inCharge = sf | |
| 244 inCharge.done = 0 | |
| 245 | |
| 246 port = self.port = reactor.listenTCP(0, sf, interface="127.0.0.1") | |
| 247 portNo = port.getHost().port | |
| 248 | |
| 249 reactor.connectTCP('127.0.0.1', portNo, cf) | |
| 250 | |
| 251 i = 0 | |
| 252 while i < 1000 and not inCharge.done: | |
| 253 reactor.iterate(0.01) | |
| 254 i += 1 | |
| 255 self.failUnless( | |
| 256 inCharge.done, | |
| 257 "Never finished reading all lines: %s" % (inCharge.lines,)) | |
| 258 | |
| 259 | |
| 260 def testTLS(self): | |
| 261 self._runTest(UnintelligentProtocol(), LineCollector(1, self.fillBuffer)
) | |
| 262 self.assertEquals( | |
| 263 self.serverFactory.lines, | |
| 264 UnintelligentProtocol.pretext + UnintelligentProtocol.posttext | |
| 265 ) | |
| 266 | |
| 267 | |
| 268 def testUnTLS(self): | |
| 269 self._runTest(UnintelligentProtocol(), LineCollector(0, self.fillBuffer)
) | |
| 270 self.assertEquals( | |
| 271 self.serverFactory.lines, | |
| 272 UnintelligentProtocol.pretext | |
| 273 ) | |
| 274 self.failUnless(self.serverFactory.rawdata, "No encrypted bytes received
") | |
| 275 | |
| 276 | |
| 277 def testBackwardsTLS(self): | |
| 278 self._runTest(LineCollector(1, self.fillBuffer), UnintelligentProtocol()
, True) | |
| 279 self.assertEquals( | |
| 280 self.clientFactory.lines, | |
| 281 UnintelligentProtocol.pretext + UnintelligentProtocol.posttext | |
| 282 ) | |
| 283 | |
| 284 | |
| 285 | |
| 286 _bufferedSuppression = trial_util.suppress( | |
| 287 message="startTLS with unwritten buffered data currently doesn't work " | |
| 288 "right. See issue #686. Closing connection.", | |
| 289 category=RuntimeWarning) | |
| 290 | |
| 291 | |
| 292 class SpammyTLSTestCase(TLSTestCase): | |
| 293 """ | |
| 294 Test TLS features with bytes sitting in the out buffer. | |
| 295 """ | |
| 296 fillBuffer = 1 | |
| 297 | |
| 298 def testTLS(self): | |
| 299 return TLSTestCase.testTLS(self) | |
| 300 testTLS.suppress = [_bufferedSuppression] | |
| 301 testTLS.todo = "startTLS doesn't empty buffer before starting TLS. :(" | |
| 302 | |
| 303 | |
| 304 def testBackwardsTLS(self): | |
| 305 return TLSTestCase.testBackwardsTLS(self) | |
| 306 testBackwardsTLS.suppress = [_bufferedSuppression] | |
| 307 testBackwardsTLS.todo = "startTLS doesn't empty buffer before starting TLS.
:(" | |
| 308 | |
| 309 | |
| 310 class BufferingTestCase(unittest.TestCase): | |
| 311 port = None | |
| 312 connector = None | |
| 313 serverProto = None | |
| 314 clientProto = None | |
| 315 | |
| 316 def tearDown(self): | |
| 317 if self.serverProto is not None and self.serverProto.transport is not No
ne: | |
| 318 self.serverProto.transport.loseConnection() | |
| 319 if self.clientProto is not None and self.clientProto.transport is not No
ne: | |
| 320 self.clientProto.transport.loseConnection() | |
| 321 if self.port is not None: | |
| 322 return defer.maybeDeferred(self.port.stopListening) | |
| 323 | |
| 324 def testOpenSSLBuffering(self): | |
| 325 serverProto = self.serverProto = SingleLineServerProtocol() | |
| 326 clientProto = self.clientProto = RecordingClientProtocol() | |
| 327 | |
| 328 server = protocol.ServerFactory() | |
| 329 client = self.client = protocol.ClientFactory() | |
| 330 | |
| 331 server.protocol = lambda: serverProto | |
| 332 client.protocol = lambda: clientProto | |
| 333 client.buffer = [] | |
| 334 | |
| 335 sCTX = ssl.DefaultOpenSSLContextFactory(certPath, certPath) | |
| 336 cCTX = ssl.ClientContextFactory() | |
| 337 | |
| 338 port = self.port = reactor.listenSSL(0, server, sCTX, interface='127.0.0
.1') | |
| 339 reactor.connectSSL('127.0.0.1', port.getHost().port, client, cCTX) | |
| 340 | |
| 341 i = 0 | |
| 342 while i < 5000 and not client.buffer: | |
| 343 i += 1 | |
| 344 reactor.iterate() | |
| 345 | |
| 346 self.assertEquals(client.buffer, ["+OK <some crap>\r\n"]) | |
| 347 | |
| 348 | |
| 349 class ConnectionLostTestCase(unittest.TestCase, ContextGeneratingMixin): | |
| 350 | |
| 351 def testImmediateDisconnect(self): | |
| 352 org = "twisted.test.test_ssl" | |
| 353 self.setupServerAndClient( | |
| 354 (org, org + ", client"), {}, | |
| 355 (org, org + ", server"), {}) | |
| 356 | |
| 357 # Set up a server, connect to it with a client, which should work since
our verifiers | |
| 358 # allow anything, then disconnect. | |
| 359 serverProtocolFactory = protocol.ServerFactory() | |
| 360 serverProtocolFactory.protocol = protocol.Protocol | |
| 361 self.serverPort = serverPort = reactor.listenSSL(0, | |
| 362 serverProtocolFactory, self.serverCtxFactory) | |
| 363 | |
| 364 clientProtocolFactory = protocol.ClientFactory() | |
| 365 clientProtocolFactory.protocol = ImmediatelyDisconnectingProtocol | |
| 366 clientProtocolFactory.connectionDisconnected = defer.Deferred() | |
| 367 clientConnector = reactor.connectSSL('127.0.0.1', | |
| 368 serverPort.getHost().port, clientProtocolFactory, self.clientCtxFact
ory) | |
| 369 | |
| 370 return clientProtocolFactory.connectionDisconnected.addCallback( | |
| 371 lambda ignoredResult: self.serverPort.stopListening()) | |
| 372 | |
| 373 def testFailedVerify(self): | |
| 374 org = "twisted.test.test_ssl" | |
| 375 self.setupServerAndClient( | |
| 376 (org, org + ", client"), {}, | |
| 377 (org, org + ", server"), {}) | |
| 378 | |
| 379 def verify(*a): | |
| 380 return False | |
| 381 self.clientCtxFactory.getContext().set_verify(SSL.VERIFY_PEER, verify) | |
| 382 | |
| 383 serverConnLost = defer.Deferred() | |
| 384 serverProtocol = protocol.Protocol() | |
| 385 serverProtocol.connectionLost = serverConnLost.callback | |
| 386 serverProtocolFactory = protocol.ServerFactory() | |
| 387 serverProtocolFactory.protocol = lambda: serverProtocol | |
| 388 self.serverPort = serverPort = reactor.listenSSL(0, | |
| 389 serverProtocolFactory, self.serverCtxFactory) | |
| 390 | |
| 391 clientConnLost = defer.Deferred() | |
| 392 clientProtocol = protocol.Protocol() | |
| 393 clientProtocol.connectionLost = clientConnLost.callback | |
| 394 clientProtocolFactory = protocol.ClientFactory() | |
| 395 clientProtocolFactory.protocol = lambda: clientProtocol | |
| 396 clientConnector = reactor.connectSSL('127.0.0.1', | |
| 397 serverPort.getHost().port, clientProtocolFactory, self.clientCtxFact
ory) | |
| 398 | |
| 399 dl = defer.DeferredList([serverConnLost, clientConnLost], consumeErrors=
True) | |
| 400 return dl.addCallback(self._cbLostConns) | |
| 401 | |
| 402 def _cbLostConns(self, results): | |
| 403 (sSuccess, sResult), (cSuccess, cResult) = results | |
| 404 | |
| 405 self.failIf(sSuccess) | |
| 406 self.failIf(cSuccess) | |
| 407 | |
| 408 acceptableErrors = [SSL.Error] | |
| 409 | |
| 410 # Rather than getting a verification failure on Windows, we are getting | |
| 411 # a connection failure. Without something like sslverify proxying | |
| 412 # in-between we can't fix up the platform's errors, so let's just | |
| 413 # specifically say it is only OK in this one case to keep the tests | |
| 414 # passing. Normally we'd like to be as strict as possible here, so | |
| 415 # we're not going to allow this to report errors incorrectly on any | |
| 416 # other platforms. | |
| 417 | |
| 418 if platform.isWindows(): | |
| 419 from twisted.internet.error import ConnectionLost | |
| 420 acceptableErrors.append(ConnectionLost) | |
| 421 | |
| 422 sResult.trap(*acceptableErrors) | |
| 423 cResult.trap(*acceptableErrors) | |
| 424 | |
| 425 return self.serverPort.stopListening() | |
| 426 | |
| 427 | |
| 428 if interfaces.IReactorSSL(reactor, None) is None: | |
| 429 for tCase in [StolenTCPTestCase, TLSTestCase, SpammyTLSTestCase, | |
| 430 BufferingTestCase, ConnectionLostTestCase]: | |
| 431 tCase.skip = "Reactor does not support SSL, cannot run SSL tests" | |
| OLD | NEW |