OLD | NEW |
| (Empty) |
1 # Copyright (c) 2001-2007 Twisted Matrix Laboratories. | |
2 # See LICENSE for details. | |
3 | |
4 | |
5 from twisted.trial import unittest, util as trial_util | |
6 from twisted.internet import protocol, reactor, interfaces, defer | |
7 from twisted.protocols import basic | |
8 from twisted.python import util, log | |
9 from twisted.python.runtime import platform | |
10 from twisted.test.test_tcp import WriteDataTestCase, ProperlyCloseFilesMixin | |
11 | |
12 import os, errno | |
13 | |
14 try: | |
15 from OpenSSL import SSL, crypto | |
16 from twisted.internet import ssl | |
17 from twisted.test.ssl_helpers import ClientTLSContext | |
18 except ImportError: | |
19 def _noSSL(): | |
20 # ugh, make pyflakes happy. | |
21 global SSL | |
22 global ssl | |
23 SSL = ssl = None | |
24 _noSSL() | |
25 | |
26 certPath = util.sibpath(__file__, "server.pem") | |
27 | |
28 class UnintelligentProtocol(basic.LineReceiver): | |
29 pretext = [ | |
30 "first line", | |
31 "last thing before tls starts", | |
32 "STARTTLS"] | |
33 | |
34 posttext = [ | |
35 "first thing after tls started", | |
36 "last thing ever"] | |
37 | |
38 def connectionMade(self): | |
39 for l in self.pretext: | |
40 self.sendLine(l) | |
41 | |
42 def lineReceived(self, line): | |
43 if line == "READY": | |
44 self.transport.startTLS(ClientTLSContext(), self.factory.client) | |
45 for l in self.posttext: | |
46 self.sendLine(l) | |
47 self.transport.loseConnection() | |
48 | |
49 | |
50 class LineCollector(basic.LineReceiver): | |
51 def __init__(self, doTLS, fillBuffer=0): | |
52 self.doTLS = doTLS | |
53 self.fillBuffer = fillBuffer | |
54 | |
55 def connectionMade(self): | |
56 self.factory.rawdata = '' | |
57 self.factory.lines = [] | |
58 | |
59 def lineReceived(self, line): | |
60 self.factory.lines.append(line) | |
61 if line == 'STARTTLS': | |
62 if self.fillBuffer: | |
63 for x in range(500): | |
64 self.sendLine('X'*1000) | |
65 self.sendLine('READY') | |
66 if self.doTLS: | |
67 ctx = ServerTLSContext( | |
68 privateKeyFileName=certPath, | |
69 certificateFileName=certPath, | |
70 ) | |
71 self.transport.startTLS(ctx, self.factory.server) | |
72 else: | |
73 self.setRawMode() | |
74 | |
75 def rawDataReceived(self, data): | |
76 self.factory.rawdata += data | |
77 self.factory.done = 1 | |
78 | |
79 def connectionLost(self, reason): | |
80 self.factory.done = 1 | |
81 | |
82 | |
83 class SingleLineServerProtocol(protocol.Protocol): | |
84 def connectionMade(self): | |
85 self.transport.identifier = 'SERVER' | |
86 self.transport.write("+OK <some crap>\r\n") | |
87 self.transport.getPeerCertificate() | |
88 | |
89 | |
90 class RecordingClientProtocol(protocol.Protocol): | |
91 def connectionMade(self): | |
92 self.transport.identifier = 'CLIENT' | |
93 self.buffer = [] | |
94 self.transport.getPeerCertificate() | |
95 | |
96 def dataReceived(self, data): | |
97 self.factory.buffer.append(data) | |
98 | |
99 | |
100 class ImmediatelyDisconnectingProtocol(protocol.Protocol): | |
101 def connectionMade(self): | |
102 self.transport.loseConnection() | |
103 | |
104 def connectionLost(self, reason): | |
105 self.factory.connectionDisconnected.callback(None) | |
106 | |
107 | |
108 class AlmostImmediatelyDisconnectingProtocol(protocol.Protocol): | |
109 def connectionMade(self): | |
110 # Twisted's SSL support is terribly broken. | |
111 reactor.callLater(0.1, self.transport.loseConnection) | |
112 | |
113 def connectionLost(self, reason): | |
114 self.factory.connectionDisconnected.callback(reason) | |
115 | |
116 | |
117 def generateCertificateObjects(organization, organizationalUnit): | |
118 pkey = crypto.PKey() | |
119 pkey.generate_key(crypto.TYPE_RSA, 512) | |
120 req = crypto.X509Req() | |
121 subject = req.get_subject() | |
122 subject.O = organization | |
123 subject.OU = organizationalUnit | |
124 req.set_pubkey(pkey) | |
125 req.sign(pkey, "md5") | |
126 | |
127 # Here comes the actual certificate | |
128 cert = crypto.X509() | |
129 cert.set_serial_number(1) | |
130 cert.gmtime_adj_notBefore(0) | |
131 cert.gmtime_adj_notAfter(60) # Testing certificates need not be long lived | |
132 cert.set_issuer(req.get_subject()) | |
133 cert.set_subject(req.get_subject()) | |
134 cert.set_pubkey(req.get_pubkey()) | |
135 cert.sign(pkey, "md5") | |
136 | |
137 return pkey, req, cert | |
138 | |
139 | |
140 def generateCertificateFiles(basename, organization, organizationalUnit): | |
141 pkey, req, cert = generateCertificateObjects(organization, organizationalUni
t) | |
142 | |
143 for ext, obj, dumpFunc in [ | |
144 ('key', pkey, crypto.dump_privatekey), | |
145 ('req', req, crypto.dump_certificate_request), | |
146 ('cert', cert, crypto.dump_certificate)]: | |
147 fName = os.extsep.join((basename, ext)) | |
148 fObj = file(fName, 'w') | |
149 fObj.write(dumpFunc(crypto.FILETYPE_PEM, obj)) | |
150 fObj.close() | |
151 | |
152 | |
153 class ContextGeneratingMixin: | |
154 def makeContextFactory(self, org, orgUnit, *args, **kwArgs): | |
155 base = self.mktemp() | |
156 generateCertificateFiles(base, org, orgUnit) | |
157 serverCtxFactory = ssl.DefaultOpenSSLContextFactory( | |
158 os.extsep.join((base, 'key')), | |
159 os.extsep.join((base, 'cert')), | |
160 *args, **kwArgs) | |
161 | |
162 return base, serverCtxFactory | |
163 | |
164 def setupServerAndClient(self, clientArgs, clientKwArgs, serverArgs, serverK
wArgs): | |
165 self.clientBase, self.clientCtxFactory = self.makeContextFactory( | |
166 *clientArgs, **clientKwArgs) | |
167 self.serverBase, self.serverCtxFactory = self.makeContextFactory( | |
168 *serverArgs, **serverKwArgs) | |
169 | |
170 | |
171 if SSL is not None: | |
172 class ServerTLSContext(ssl.DefaultOpenSSLContextFactory): | |
173 isClient = 0 | |
174 def __init__(self, *args, **kw): | |
175 kw['sslmethod'] = SSL.TLSv1_METHOD | |
176 ssl.DefaultOpenSSLContextFactory.__init__(self, *args, **kw) | |
177 | |
178 | |
179 class StolenTCPTestCase(ProperlyCloseFilesMixin, WriteDataTestCase): | |
180 """ | |
181 For SSL transports, test many of the same things which are tested for | |
182 TCP transports. | |
183 """ | |
184 def createServer(self, address, portNumber, factory): | |
185 contextFactory = ssl.CertificateOptions() | |
186 return reactor.listenSSL( | |
187 portNumber, factory, contextFactory, interface=address) | |
188 | |
189 | |
190 def connectClient(self, address, portNumber, clientCreator): | |
191 contextFactory = ssl.CertificateOptions() | |
192 return clientCreator.connectSSL(address, portNumber, contextFactory) | |
193 | |
194 | |
195 def getHandleExceptionType(self): | |
196 return SSL.SysCallError | |
197 | |
198 | |
199 def getHandleErrorCode(self): | |
200 # Windows 2000 SP 4 and Windows XP SP 2 give back WSAENOTSOCK for | |
201 # SSL.Connection.write for some reason. | |
202 if platform.getType() == 'win32': | |
203 return errno.WSAENOTSOCK | |
204 return ProperlyCloseFilesMixin.getHandleErrorCode(self) | |
205 | |
206 | |
207 class TLSTestCase(unittest.TestCase): | |
208 fillBuffer = 0 | |
209 | |
210 port = None | |
211 clientProto = None | |
212 serverProto = None | |
213 | |
214 def tearDown(self): | |
215 if self.clientProto is not None and self.clientProto.transport is not No
ne: | |
216 self.clientProto.transport.loseConnection() | |
217 if self.serverProto is not None and self.serverProto.transport is not No
ne: | |
218 self.serverProto.transport.loseConnection() | |
219 | |
220 if self.port is not None: | |
221 return defer.maybeDeferred(self.port.stopListening) | |
222 | |
223 def _runTest(self, clientProto, serverProto, clientIsServer=False): | |
224 self.clientProto = clientProto | |
225 cf = self.clientFactory = protocol.ClientFactory() | |
226 cf.protocol = lambda: clientProto | |
227 if clientIsServer: | |
228 cf.server = 0 | |
229 else: | |
230 cf.client = 1 | |
231 | |
232 self.serverProto = serverProto | |
233 sf = self.serverFactory = protocol.ServerFactory() | |
234 sf.protocol = lambda: serverProto | |
235 if clientIsServer: | |
236 sf.client = 0 | |
237 else: | |
238 sf.server = 1 | |
239 | |
240 if clientIsServer: | |
241 inCharge = cf | |
242 else: | |
243 inCharge = sf | |
244 inCharge.done = 0 | |
245 | |
246 port = self.port = reactor.listenTCP(0, sf, interface="127.0.0.1") | |
247 portNo = port.getHost().port | |
248 | |
249 reactor.connectTCP('127.0.0.1', portNo, cf) | |
250 | |
251 i = 0 | |
252 while i < 1000 and not inCharge.done: | |
253 reactor.iterate(0.01) | |
254 i += 1 | |
255 self.failUnless( | |
256 inCharge.done, | |
257 "Never finished reading all lines: %s" % (inCharge.lines,)) | |
258 | |
259 | |
260 def testTLS(self): | |
261 self._runTest(UnintelligentProtocol(), LineCollector(1, self.fillBuffer)
) | |
262 self.assertEquals( | |
263 self.serverFactory.lines, | |
264 UnintelligentProtocol.pretext + UnintelligentProtocol.posttext | |
265 ) | |
266 | |
267 | |
268 def testUnTLS(self): | |
269 self._runTest(UnintelligentProtocol(), LineCollector(0, self.fillBuffer)
) | |
270 self.assertEquals( | |
271 self.serverFactory.lines, | |
272 UnintelligentProtocol.pretext | |
273 ) | |
274 self.failUnless(self.serverFactory.rawdata, "No encrypted bytes received
") | |
275 | |
276 | |
277 def testBackwardsTLS(self): | |
278 self._runTest(LineCollector(1, self.fillBuffer), UnintelligentProtocol()
, True) | |
279 self.assertEquals( | |
280 self.clientFactory.lines, | |
281 UnintelligentProtocol.pretext + UnintelligentProtocol.posttext | |
282 ) | |
283 | |
284 | |
285 | |
286 _bufferedSuppression = trial_util.suppress( | |
287 message="startTLS with unwritten buffered data currently doesn't work " | |
288 "right. See issue #686. Closing connection.", | |
289 category=RuntimeWarning) | |
290 | |
291 | |
292 class SpammyTLSTestCase(TLSTestCase): | |
293 """ | |
294 Test TLS features with bytes sitting in the out buffer. | |
295 """ | |
296 fillBuffer = 1 | |
297 | |
298 def testTLS(self): | |
299 return TLSTestCase.testTLS(self) | |
300 testTLS.suppress = [_bufferedSuppression] | |
301 testTLS.todo = "startTLS doesn't empty buffer before starting TLS. :(" | |
302 | |
303 | |
304 def testBackwardsTLS(self): | |
305 return TLSTestCase.testBackwardsTLS(self) | |
306 testBackwardsTLS.suppress = [_bufferedSuppression] | |
307 testBackwardsTLS.todo = "startTLS doesn't empty buffer before starting TLS.
:(" | |
308 | |
309 | |
310 class BufferingTestCase(unittest.TestCase): | |
311 port = None | |
312 connector = None | |
313 serverProto = None | |
314 clientProto = None | |
315 | |
316 def tearDown(self): | |
317 if self.serverProto is not None and self.serverProto.transport is not No
ne: | |
318 self.serverProto.transport.loseConnection() | |
319 if self.clientProto is not None and self.clientProto.transport is not No
ne: | |
320 self.clientProto.transport.loseConnection() | |
321 if self.port is not None: | |
322 return defer.maybeDeferred(self.port.stopListening) | |
323 | |
324 def testOpenSSLBuffering(self): | |
325 serverProto = self.serverProto = SingleLineServerProtocol() | |
326 clientProto = self.clientProto = RecordingClientProtocol() | |
327 | |
328 server = protocol.ServerFactory() | |
329 client = self.client = protocol.ClientFactory() | |
330 | |
331 server.protocol = lambda: serverProto | |
332 client.protocol = lambda: clientProto | |
333 client.buffer = [] | |
334 | |
335 sCTX = ssl.DefaultOpenSSLContextFactory(certPath, certPath) | |
336 cCTX = ssl.ClientContextFactory() | |
337 | |
338 port = self.port = reactor.listenSSL(0, server, sCTX, interface='127.0.0
.1') | |
339 reactor.connectSSL('127.0.0.1', port.getHost().port, client, cCTX) | |
340 | |
341 i = 0 | |
342 while i < 5000 and not client.buffer: | |
343 i += 1 | |
344 reactor.iterate() | |
345 | |
346 self.assertEquals(client.buffer, ["+OK <some crap>\r\n"]) | |
347 | |
348 | |
349 class ConnectionLostTestCase(unittest.TestCase, ContextGeneratingMixin): | |
350 | |
351 def testImmediateDisconnect(self): | |
352 org = "twisted.test.test_ssl" | |
353 self.setupServerAndClient( | |
354 (org, org + ", client"), {}, | |
355 (org, org + ", server"), {}) | |
356 | |
357 # Set up a server, connect to it with a client, which should work since
our verifiers | |
358 # allow anything, then disconnect. | |
359 serverProtocolFactory = protocol.ServerFactory() | |
360 serverProtocolFactory.protocol = protocol.Protocol | |
361 self.serverPort = serverPort = reactor.listenSSL(0, | |
362 serverProtocolFactory, self.serverCtxFactory) | |
363 | |
364 clientProtocolFactory = protocol.ClientFactory() | |
365 clientProtocolFactory.protocol = ImmediatelyDisconnectingProtocol | |
366 clientProtocolFactory.connectionDisconnected = defer.Deferred() | |
367 clientConnector = reactor.connectSSL('127.0.0.1', | |
368 serverPort.getHost().port, clientProtocolFactory, self.clientCtxFact
ory) | |
369 | |
370 return clientProtocolFactory.connectionDisconnected.addCallback( | |
371 lambda ignoredResult: self.serverPort.stopListening()) | |
372 | |
373 def testFailedVerify(self): | |
374 org = "twisted.test.test_ssl" | |
375 self.setupServerAndClient( | |
376 (org, org + ", client"), {}, | |
377 (org, org + ", server"), {}) | |
378 | |
379 def verify(*a): | |
380 return False | |
381 self.clientCtxFactory.getContext().set_verify(SSL.VERIFY_PEER, verify) | |
382 | |
383 serverConnLost = defer.Deferred() | |
384 serverProtocol = protocol.Protocol() | |
385 serverProtocol.connectionLost = serverConnLost.callback | |
386 serverProtocolFactory = protocol.ServerFactory() | |
387 serverProtocolFactory.protocol = lambda: serverProtocol | |
388 self.serverPort = serverPort = reactor.listenSSL(0, | |
389 serverProtocolFactory, self.serverCtxFactory) | |
390 | |
391 clientConnLost = defer.Deferred() | |
392 clientProtocol = protocol.Protocol() | |
393 clientProtocol.connectionLost = clientConnLost.callback | |
394 clientProtocolFactory = protocol.ClientFactory() | |
395 clientProtocolFactory.protocol = lambda: clientProtocol | |
396 clientConnector = reactor.connectSSL('127.0.0.1', | |
397 serverPort.getHost().port, clientProtocolFactory, self.clientCtxFact
ory) | |
398 | |
399 dl = defer.DeferredList([serverConnLost, clientConnLost], consumeErrors=
True) | |
400 return dl.addCallback(self._cbLostConns) | |
401 | |
402 def _cbLostConns(self, results): | |
403 (sSuccess, sResult), (cSuccess, cResult) = results | |
404 | |
405 self.failIf(sSuccess) | |
406 self.failIf(cSuccess) | |
407 | |
408 acceptableErrors = [SSL.Error] | |
409 | |
410 # Rather than getting a verification failure on Windows, we are getting | |
411 # a connection failure. Without something like sslverify proxying | |
412 # in-between we can't fix up the platform's errors, so let's just | |
413 # specifically say it is only OK in this one case to keep the tests | |
414 # passing. Normally we'd like to be as strict as possible here, so | |
415 # we're not going to allow this to report errors incorrectly on any | |
416 # other platforms. | |
417 | |
418 if platform.isWindows(): | |
419 from twisted.internet.error import ConnectionLost | |
420 acceptableErrors.append(ConnectionLost) | |
421 | |
422 sResult.trap(*acceptableErrors) | |
423 cResult.trap(*acceptableErrors) | |
424 | |
425 return self.serverPort.stopListening() | |
426 | |
427 | |
428 if interfaces.IReactorSSL(reactor, None) is None: | |
429 for tCase in [StolenTCPTestCase, TLSTestCase, SpammyTLSTestCase, | |
430 BufferingTestCase, ConnectionLostTestCase]: | |
431 tCase.skip = "Reactor does not support SSL, cannot run SSL tests" | |
OLD | NEW |