OLD | NEW |
| (Empty) |
1 # Copyright (c) 2001-2008 Twisted Matrix Laboratories. | |
2 # See LICENSE for details. | |
3 | |
4 """ | |
5 Tests for implementations of L{IReactorTCP}. | |
6 """ | |
7 | |
8 import socket, random, errno | |
9 | |
10 from zope.interface import implements | |
11 | |
12 from twisted.trial import unittest | |
13 | |
14 from twisted.python.log import msg | |
15 from twisted.internet import protocol, reactor, defer, interfaces | |
16 from twisted.internet import error | |
17 from twisted.internet.address import IPv4Address | |
18 from twisted.internet.interfaces import IHalfCloseableProtocol, IPullProducer | |
19 from twisted.protocols import policies | |
20 | |
21 def loopUntil(predicate, interval=0): | |
22 """ | |
23 Poor excuse for an event notification helper. This polls a condition and | |
24 calls back a Deferred when it is seen to be true. | |
25 | |
26 Do not use this function. | |
27 """ | |
28 from twisted.internet import task | |
29 d = defer.Deferred() | |
30 def check(): | |
31 res = predicate() | |
32 if res: | |
33 d.callback(res) | |
34 call = task.LoopingCall(check) | |
35 def stop(result): | |
36 call.stop() | |
37 return result | |
38 d.addCallback(stop) | |
39 d2 = call.start(interval) | |
40 d2.addErrback(d.errback) | |
41 return d | |
42 | |
43 | |
44 class ClosingProtocol(protocol.Protocol): | |
45 | |
46 def connectionMade(self): | |
47 self.transport.loseConnection() | |
48 | |
49 def connectionLost(self, reason): | |
50 reason.trap(error.ConnectionDone) | |
51 | |
52 class ClosingFactory(protocol.ServerFactory): | |
53 """Factory that closes port immediatley.""" | |
54 | |
55 def buildProtocol(self, conn): | |
56 self.port.stopListening() | |
57 return ClosingProtocol() | |
58 | |
59 | |
60 class MyProtocol(protocol.Protocol): | |
61 made = closed = failed = 0 | |
62 | |
63 closedDeferred = None | |
64 | |
65 data = "" | |
66 | |
67 factory = None | |
68 | |
69 def connectionMade(self): | |
70 self.made = 1 | |
71 if (self.factory is not None and | |
72 self.factory.protocolConnectionMade is not None): | |
73 d = self.factory.protocolConnectionMade | |
74 self.factory.protocolConnectionMade = None | |
75 d.callback(self) | |
76 | |
77 def dataReceived(self, data): | |
78 self.data += data | |
79 | |
80 def connectionLost(self, reason): | |
81 self.closed = 1 | |
82 if self.closedDeferred is not None: | |
83 d, self.closedDeferred = self.closedDeferred, None | |
84 d.callback(None) | |
85 | |
86 | |
87 class MyProtocolFactoryMixin(object): | |
88 """ | |
89 Mixin for factories which create L{MyProtocol} instances. | |
90 | |
91 @type protocolFactory: no-argument callable | |
92 @ivar protocolFactory: Factory for protocols - takes the place of the | |
93 typical C{protocol} attribute of factories (but that name is used by | |
94 this class for something else). | |
95 | |
96 @type protocolConnectionMade: L{NoneType} or L{defer.Deferred} | |
97 @ivar protocolConnectionMade: When an instance of L{MyProtocol} is | |
98 connected, if this is not C{None}, the L{Deferred} will be called | |
99 back with the protocol instance and the attribute set to C{None}. | |
100 | |
101 @type protocolConnectionLost: L{NoneType} or L{defer.Deferred} | |
102 @ivar protocolConnectionLost: When an instance of L{MyProtocol} is | |
103 created, this will be set as its C{closedDeferred} attribute and | |
104 then this attribute will be set to C{None} so the L{defer.Deferred} | |
105 is not used by more than one protocol. | |
106 | |
107 @ivar protocol: The most recently created L{MyProtocol} instance which | |
108 was returned from C{buildProtocol}. | |
109 | |
110 @type called: C{int} | |
111 @ivar called: A counter which is incremented each time C{buildProtocol} | |
112 is called. | |
113 | |
114 @ivar peerAddresses: A C{list} of the addresses passed to C{buildProtocol}. | |
115 """ | |
116 protocolFactory = MyProtocol | |
117 | |
118 protocolConnectionMade = None | |
119 protocolConnectionLost = None | |
120 protocol = None | |
121 called = 0 | |
122 | |
123 def __init__(self): | |
124 self.peerAddresses = [] | |
125 | |
126 | |
127 def buildProtocol(self, addr): | |
128 """ | |
129 Create a L{MyProtocol} and set it up to be able to perform | |
130 callbacks. | |
131 """ | |
132 self.peerAddresses.append(addr) | |
133 self.called += 1 | |
134 p = self.protocolFactory() | |
135 p.factory = self | |
136 p.closedDeferred = self.protocolConnectionLost | |
137 self.protocolConnectionLost = None | |
138 self.protocol = p | |
139 return p | |
140 | |
141 | |
142 | |
143 class MyServerFactory(MyProtocolFactoryMixin, protocol.ServerFactory): | |
144 """ | |
145 Server factory which creates L{MyProtocol} instances. | |
146 """ | |
147 | |
148 | |
149 | |
150 class MyClientFactory(MyProtocolFactoryMixin, protocol.ClientFactory): | |
151 """ | |
152 Client factory which creates L{MyProtocol} instances. | |
153 """ | |
154 failed = 0 | |
155 stopped = 0 | |
156 | |
157 def __init__(self): | |
158 MyProtocolFactoryMixin.__init__(self) | |
159 self.deferred = defer.Deferred() | |
160 self.failDeferred = defer.Deferred() | |
161 | |
162 def clientConnectionFailed(self, connector, reason): | |
163 self.failed = 1 | |
164 self.reason = reason | |
165 self.failDeferred.callback(None) | |
166 | |
167 def clientConnectionLost(self, connector, reason): | |
168 self.lostReason = reason | |
169 self.deferred.callback(None) | |
170 | |
171 def stopFactory(self): | |
172 self.stopped = 1 | |
173 | |
174 | |
175 | |
176 class ListeningTestCase(unittest.TestCase): | |
177 | |
178 def test_listen(self): | |
179 """ | |
180 L{IReactorTCP.listenTCP} returns an object which provides | |
181 L{IListeningPort}. | |
182 """ | |
183 f = MyServerFactory() | |
184 p1 = reactor.listenTCP(0, f, interface="127.0.0.1") | |
185 self.addCleanup(p1.stopListening) | |
186 self.failUnless(interfaces.IListeningPort.providedBy(p1)) | |
187 | |
188 | |
189 def testStopListening(self): | |
190 """ | |
191 The L{IListeningPort} returned by L{IReactorTCP.listenTCP} can be | |
192 stopped with its C{stopListening} method. After the L{Deferred} it | |
193 (optionally) returns has been called back, the port number can be bound | |
194 to a new server. | |
195 """ | |
196 f = MyServerFactory() | |
197 port = reactor.listenTCP(0, f, interface="127.0.0.1") | |
198 n = port.getHost().port | |
199 | |
200 def cbStopListening(ignored): | |
201 # Make sure we can rebind the port right away | |
202 port = reactor.listenTCP(n, f, interface="127.0.0.1") | |
203 return port.stopListening() | |
204 | |
205 d = defer.maybeDeferred(port.stopListening) | |
206 d.addCallback(cbStopListening) | |
207 return d | |
208 | |
209 | |
210 def testNumberedInterface(self): | |
211 f = MyServerFactory() | |
212 # listen only on the loopback interface | |
213 p1 = reactor.listenTCP(0, f, interface='127.0.0.1') | |
214 return p1.stopListening() | |
215 | |
216 def testPortRepr(self): | |
217 f = MyServerFactory() | |
218 p = reactor.listenTCP(0, f) | |
219 portNo = str(p.getHost().port) | |
220 self.failIf(repr(p).find(portNo) == -1) | |
221 def stoppedListening(ign): | |
222 self.failIf(repr(p).find(portNo) != -1) | |
223 d = defer.maybeDeferred(p.stopListening) | |
224 return d.addCallback(stoppedListening) | |
225 | |
226 | |
227 def test_serverRepr(self): | |
228 """ | |
229 Check that the repr string of the server transport get the good port | |
230 number if the server listens on 0. | |
231 """ | |
232 server = MyServerFactory() | |
233 serverConnMade = server.protocolConnectionMade = defer.Deferred() | |
234 port = reactor.listenTCP(0, server) | |
235 self.addCleanup(port.stopListening) | |
236 | |
237 client = MyClientFactory() | |
238 clientConnMade = client.protocolConnectionMade = defer.Deferred() | |
239 connector = reactor.connectTCP("127.0.0.1", | |
240 port.getHost().port, client) | |
241 self.addCleanup(connector.disconnect) | |
242 def check((serverProto, clientProto)): | |
243 portNumber = port.getHost().port | |
244 self.assertEquals(repr(serverProto.transport), | |
245 "<MyProtocol #0 on %s>" % (portNumber,)) | |
246 serverProto.transport.loseConnection() | |
247 clientProto.transport.loseConnection() | |
248 return defer.gatherResults([serverConnMade, clientConnMade] | |
249 ).addCallback(check) | |
250 | |
251 | |
252 | |
253 def callWithSpew(f): | |
254 from twisted.python.util import spewerWithLinenums as spewer | |
255 import sys | |
256 sys.settrace(spewer) | |
257 try: | |
258 f() | |
259 finally: | |
260 sys.settrace(None) | |
261 | |
262 class LoopbackTestCase(unittest.TestCase): | |
263 """ | |
264 Test loopback connections. | |
265 """ | |
266 def test_closePortInProtocolFactory(self): | |
267 """ | |
268 A port created with L{IReactorTCP.listenTCP} can be connected to with | |
269 L{IReactorTCP.connectTCP}. | |
270 """ | |
271 f = ClosingFactory() | |
272 port = reactor.listenTCP(0, f, interface="127.0.0.1") | |
273 self.addCleanup(port.stopListening) | |
274 portNumber = port.getHost().port | |
275 f.port = port | |
276 clientF = MyClientFactory() | |
277 reactor.connectTCP("127.0.0.1", portNumber, clientF) | |
278 def check(x): | |
279 self.assertTrue(clientF.protocol.made) | |
280 self.assertTrue(port.disconnected) | |
281 clientF.lostReason.trap(error.ConnectionDone) | |
282 return clientF.deferred.addCallback(check) | |
283 | |
284 def _trapCnxDone(self, obj): | |
285 getattr(obj, 'trap', lambda x: None)(error.ConnectionDone) | |
286 | |
287 | |
288 def _connectedClientAndServerTest(self, callback): | |
289 """ | |
290 Invoke the given callback with a client protocol and a server protocol | |
291 which have been connected to each other. | |
292 """ | |
293 serverFactory = MyServerFactory() | |
294 serverConnMade = defer.Deferred() | |
295 serverFactory.protocolConnectionMade = serverConnMade | |
296 port = reactor.listenTCP(0, serverFactory, interface="127.0.0.1") | |
297 self.addCleanup(port.stopListening) | |
298 | |
299 portNumber = port.getHost().port | |
300 clientF = MyClientFactory() | |
301 clientConnMade = defer.Deferred() | |
302 clientF.protocolConnectionMade = clientConnMade | |
303 reactor.connectTCP("127.0.0.1", portNumber, clientF) | |
304 | |
305 connsMade = defer.gatherResults([serverConnMade, clientConnMade]) | |
306 def connected((serverProtocol, clientProtocol)): | |
307 callback(serverProtocol, clientProtocol) | |
308 serverProtocol.transport.loseConnection() | |
309 clientProtocol.transport.loseConnection() | |
310 connsMade.addCallback(connected) | |
311 return connsMade | |
312 | |
313 | |
314 def test_tcpNoDelay(self): | |
315 """ | |
316 The transport of a protocol connected with L{IReactorTCP.connectTCP} or | |
317 L{IReactor.TCP.listenTCP} can have its I{TCP_NODELAY} state inspected | |
318 and manipulated with L{ITCPTransport.getTcpNoDelay} and | |
319 L{ITCPTransport.setTcpNoDelay}. | |
320 """ | |
321 def check(serverProtocol, clientProtocol): | |
322 for p in [serverProtocol, clientProtocol]: | |
323 transport = p.transport | |
324 self.assertEquals(transport.getTcpNoDelay(), 0) | |
325 transport.setTcpNoDelay(1) | |
326 self.assertEquals(transport.getTcpNoDelay(), 1) | |
327 transport.setTcpNoDelay(0) | |
328 self.assertEquals(transport.getTcpNoDelay(), 0) | |
329 return self._connectedClientAndServerTest(check) | |
330 | |
331 | |
332 def test_tcpKeepAlive(self): | |
333 """ | |
334 The transport of a protocol connected with L{IReactorTCP.connectTCP} or | |
335 L{IReactor.TCP.listenTCP} can have its I{SO_KEEPALIVE} state inspected | |
336 and manipulated with L{ITCPTransport.getTcpKeepAlive} and | |
337 L{ITCPTransport.setTcpKeepAlive}. | |
338 """ | |
339 def check(serverProtocol, clientProtocol): | |
340 for p in [serverProtocol, clientProtocol]: | |
341 transport = p.transport | |
342 self.assertEquals(transport.getTcpKeepAlive(), 0) | |
343 transport.setTcpKeepAlive(1) | |
344 self.assertEquals(transport.getTcpKeepAlive(), 1) | |
345 transport.setTcpKeepAlive(0) | |
346 self.assertEquals(transport.getTcpKeepAlive(), 0) | |
347 return self._connectedClientAndServerTest(check) | |
348 | |
349 | |
350 def testFailing(self): | |
351 clientF = MyClientFactory() | |
352 # XXX we assume no one is listening on TCP port 69 | |
353 reactor.connectTCP("127.0.0.1", 69, clientF, timeout=5) | |
354 def check(ignored): | |
355 clientF.reason.trap(error.ConnectionRefusedError) | |
356 return clientF.failDeferred.addCallback(check) | |
357 | |
358 | |
359 def test_connectionRefusedErrorNumber(self): | |
360 """ | |
361 Assert that the error number of the ConnectionRefusedError is | |
362 ECONNREFUSED, and not some other socket related error. | |
363 """ | |
364 | |
365 # Bind a number of ports in the operating system. We will attempt | |
366 # to connect to these in turn immediately after closing them, in the | |
367 # hopes that no one else has bound them in the mean time. Any | |
368 # connection which succeeds is ignored and causes us to move on to | |
369 # the next port. As soon as a connection attempt fails, we move on | |
370 # to making an assertion about how it failed. If they all succeed, | |
371 # the test will fail. | |
372 | |
373 # It would be nice to have a simpler, reliable way to cause a | |
374 # connection failure from the platform. | |
375 # | |
376 # On Linux (2.6.15), connecting to port 0 always fails. FreeBSD | |
377 # (5.4) rejects the connection attempt with EADDRNOTAVAIL. | |
378 # | |
379 # On FreeBSD (5.4), listening on a port and then repeatedly | |
380 # connecting to it without ever accepting any connections eventually | |
381 # leads to an ECONNREFUSED. On Linux (2.6.15), a seemingly | |
382 # unbounded number of connections succeed. | |
383 | |
384 serverSockets = [] | |
385 for i in xrange(10): | |
386 serverSocket = socket.socket() | |
387 serverSocket.bind(('127.0.0.1', 0)) | |
388 serverSocket.listen(1) | |
389 serverSockets.append(serverSocket) | |
390 random.shuffle(serverSockets) | |
391 | |
392 clientCreator = protocol.ClientCreator(reactor, protocol.Protocol) | |
393 | |
394 def tryConnectFailure(): | |
395 def connected(proto): | |
396 """ | |
397 Darn. Kill it and try again, if there are any tries left. | |
398 """ | |
399 proto.transport.loseConnection() | |
400 if serverSockets: | |
401 return tryConnectFailure() | |
402 self.fail("Could not fail to connect - could not test errno for
that case.") | |
403 | |
404 serverSocket = serverSockets.pop() | |
405 serverHost, serverPort = serverSocket.getsockname() | |
406 serverSocket.close() | |
407 | |
408 connectDeferred = clientCreator.connectTCP(serverHost, serverPort) | |
409 connectDeferred.addCallback(connected) | |
410 return connectDeferred | |
411 | |
412 refusedDeferred = tryConnectFailure() | |
413 self.assertFailure(refusedDeferred, error.ConnectionRefusedError) | |
414 def connRefused(exc): | |
415 self.assertEqual(exc.osError, errno.ECONNREFUSED) | |
416 refusedDeferred.addCallback(connRefused) | |
417 def cleanup(passthrough): | |
418 while serverSockets: | |
419 serverSockets.pop().close() | |
420 return passthrough | |
421 refusedDeferred.addBoth(cleanup) | |
422 return refusedDeferred | |
423 | |
424 | |
425 def test_connectByServiceFail(self): | |
426 """ | |
427 Connecting to a named service which does not exist raises | |
428 L{error.ServiceNameUnknownError}. | |
429 """ | |
430 self.assertRaises( | |
431 error.ServiceNameUnknownError, | |
432 reactor.connectTCP, | |
433 "127.0.0.1", "thisbetternotexist", MyClientFactory()) | |
434 | |
435 | |
436 def test_connectByService(self): | |
437 """ | |
438 L{IReactorTCP.connectTCP} accepts the name of a service instead of a | |
439 port number and connects to the port number associated with that | |
440 service, as defined by L{socket.getservbyname}. | |
441 """ | |
442 serverFactory = MyServerFactory() | |
443 serverConnMade = defer.Deferred() | |
444 serverFactory.protocolConnectionMade = serverConnMade | |
445 port = reactor.listenTCP(0, serverFactory, interface="127.0.0.1") | |
446 self.addCleanup(port.stopListening) | |
447 portNumber = port.getHost().port | |
448 clientFactory = MyClientFactory() | |
449 clientConnMade = defer.Deferred() | |
450 clientFactory.protocolConnectionMade = clientConnMade | |
451 | |
452 def fakeGetServicePortByName(serviceName, protocolName): | |
453 if serviceName == 'http' and protocolName == 'tcp': | |
454 return portNumber | |
455 return 10 | |
456 self.patch(socket, 'getservbyname', fakeGetServicePortByName) | |
457 | |
458 c = reactor.connectTCP('127.0.0.1', 'http', clientFactory) | |
459 | |
460 connMade = defer.gatherResults([serverConnMade, clientConnMade]) | |
461 def connected((serverProtocol, clientProtocol)): | |
462 self.assertTrue( | |
463 serverFactory.called, | |
464 "Server factory was not called upon to build a protocol.") | |
465 serverProtocol.transport.loseConnection() | |
466 clientProtocol.transport.loseConnection() | |
467 connMade.addCallback(connected) | |
468 return connMade | |
469 | |
470 | |
471 class StartStopFactory(protocol.Factory): | |
472 | |
473 started = 0 | |
474 stopped = 0 | |
475 | |
476 def startFactory(self): | |
477 if self.started or self.stopped: | |
478 raise RuntimeError | |
479 self.started = 1 | |
480 | |
481 def stopFactory(self): | |
482 if not self.started or self.stopped: | |
483 raise RuntimeError | |
484 self.stopped = 1 | |
485 | |
486 | |
487 class ClientStartStopFactory(MyClientFactory): | |
488 | |
489 started = 0 | |
490 stopped = 0 | |
491 | |
492 def startFactory(self): | |
493 if self.started or self.stopped: | |
494 raise RuntimeError | |
495 self.started = 1 | |
496 | |
497 def stopFactory(self): | |
498 if not self.started or self.stopped: | |
499 raise RuntimeError | |
500 self.stopped = 1 | |
501 | |
502 | |
503 class FactoryTestCase(unittest.TestCase): | |
504 """Tests for factories.""" | |
505 | |
506 def test_serverStartStop(self): | |
507 """ | |
508 The factory passed to L{IReactorTCP.listenTCP} should be started only | |
509 when it transitions from being used on no ports to being used on one | |
510 port and should be stopped only when it transitions from being used on | |
511 one port to being used on no ports. | |
512 """ | |
513 # Note - this test doesn't need to use listenTCP. It is exercising | |
514 # logic implemented in Factory.doStart and Factory.doStop, so it could | |
515 # just call that directly. Some other test can make sure that | |
516 # listenTCP and stopListening correctly call doStart and | |
517 # doStop. -exarkun | |
518 | |
519 f = StartStopFactory() | |
520 | |
521 # listen on port | |
522 p1 = reactor.listenTCP(0, f, interface='127.0.0.1') | |
523 self.addCleanup(p1.stopListening) | |
524 | |
525 self.assertEqual((f.started, f.stopped), (1, 0)) | |
526 | |
527 # listen on two more ports | |
528 p2 = reactor.listenTCP(0, f, interface='127.0.0.1') | |
529 p3 = reactor.listenTCP(0, f, interface='127.0.0.1') | |
530 | |
531 self.assertEqual((f.started, f.stopped), (1, 0)) | |
532 | |
533 # close two ports | |
534 d1 = defer.maybeDeferred(p1.stopListening) | |
535 d2 = defer.maybeDeferred(p2.stopListening) | |
536 closedDeferred = defer.gatherResults([d1, d2]) | |
537 def cbClosed(ignored): | |
538 self.assertEqual((f.started, f.stopped), (1, 0)) | |
539 # Close the last port | |
540 return p3.stopListening() | |
541 closedDeferred.addCallback(cbClosed) | |
542 | |
543 def cbClosedAll(ignored): | |
544 self.assertEquals((f.started, f.stopped), (1, 1)) | |
545 closedDeferred.addCallback(cbClosedAll) | |
546 return closedDeferred | |
547 | |
548 | |
549 def test_clientStartStop(self): | |
550 """ | |
551 The factory passed to L{IReactorTCP.connectTCP} should be started when | |
552 the connection attempt starts and stopped when it is over. | |
553 """ | |
554 f = ClosingFactory() | |
555 p = reactor.listenTCP(0, f, interface="127.0.0.1") | |
556 self.addCleanup(p.stopListening) | |
557 portNumber = p.getHost().port | |
558 f.port = p | |
559 | |
560 factory = ClientStartStopFactory() | |
561 reactor.connectTCP("127.0.0.1", portNumber, factory) | |
562 self.assertTrue(factory.started) | |
563 return loopUntil(lambda: factory.stopped) | |
564 | |
565 | |
566 | |
567 class ConnectorTestCase(unittest.TestCase): | |
568 | |
569 def test_connectorIdentity(self): | |
570 """ | |
571 L{IReactorTCP.connectTCP} returns an object which provides | |
572 L{IConnector}. The destination of the connector is the address which | |
573 was passed to C{connectTCP}. The same connector object is passed to | |
574 the factory's C{startedConnecting} method as to the factory's | |
575 C{clientConnectionLost} method. | |
576 """ | |
577 serverFactory = ClosingFactory() | |
578 tcpPort = reactor.listenTCP(0, serverFactory, interface="127.0.0.1") | |
579 self.addCleanup(tcpPort.stopListening) | |
580 portNumber = tcpPort.getHost().port | |
581 serverFactory.port = tcpPort | |
582 | |
583 seenConnectors = [] | |
584 seenFailures = [] | |
585 | |
586 clientFactory = ClientStartStopFactory() | |
587 clientFactory.clientConnectionLost = ( | |
588 lambda connector, reason: (seenConnectors.append(connector), | |
589 seenFailures.append(reason))) | |
590 clientFactory.startedConnecting = seenConnectors.append | |
591 | |
592 connector = reactor.connectTCP("127.0.0.1", portNumber, clientFactory) | |
593 self.assertTrue(interfaces.IConnector.providedBy(connector)) | |
594 dest = connector.getDestination() | |
595 self.assertEquals(dest.type, "TCP") | |
596 self.assertEquals(dest.host, "127.0.0.1") | |
597 self.assertEquals(dest.port, portNumber) | |
598 | |
599 d = loopUntil(lambda: clientFactory.stopped) | |
600 def clientFactoryStopped(ignored): | |
601 seenFailures[0].trap(error.ConnectionDone) | |
602 self.assertEqual(seenConnectors, [connector, connector]) | |
603 d.addCallback(clientFactoryStopped) | |
604 return d | |
605 | |
606 | |
607 def test_userFail(self): | |
608 """ | |
609 Calling L{IConnector.stopConnecting} in C{Factory.startedConnecting} | |
610 results in C{Factory.clientConnectionFailed} being called with | |
611 L{error.UserError} as the reason. | |
612 """ | |
613 serverFactory = MyServerFactory() | |
614 tcpPort = reactor.listenTCP(0, serverFactory, interface="127.0.0.1") | |
615 self.addCleanup(tcpPort.stopListening) | |
616 portNumber = tcpPort.getHost().port | |
617 | |
618 def startedConnecting(connector): | |
619 connector.stopConnecting() | |
620 | |
621 clientFactory = ClientStartStopFactory() | |
622 clientFactory.startedConnecting = startedConnecting | |
623 reactor.connectTCP("127.0.0.1", portNumber, clientFactory) | |
624 | |
625 d = loopUntil(lambda: clientFactory.stopped) | |
626 def check(ignored): | |
627 self.assertEquals(clientFactory.failed, 1) | |
628 clientFactory.reason.trap(error.UserError) | |
629 return d.addCallback(check) | |
630 | |
631 | |
632 def test_reconnect(self): | |
633 """ | |
634 Calling L{IConnector.connect} in C{Factory.clientConnectionLost} causes | |
635 a new connection attempt to be made. | |
636 """ | |
637 serverFactory = ClosingFactory() | |
638 tcpPort = reactor.listenTCP(0, serverFactory, interface="127.0.0.1") | |
639 self.addCleanup(tcpPort.stopListening) | |
640 portNumber = tcpPort.getHost().port | |
641 serverFactory.port = tcpPort | |
642 | |
643 clientFactory = MyClientFactory() | |
644 | |
645 def clientConnectionLost(connector, reason): | |
646 connector.connect() | |
647 clientFactory.clientConnectionLost = clientConnectionLost | |
648 reactor.connectTCP("127.0.0.1", portNumber, clientFactory) | |
649 | |
650 d = loopUntil(lambda: clientFactory.failed) | |
651 def reconnectFailed(ignored): | |
652 p = clientFactory.protocol | |
653 self.assertEqual((p.made, p.closed), (1, 1)) | |
654 clientFactory.reason.trap(error.ConnectionRefusedError) | |
655 self.assertEqual(clientFactory.stopped, 1) | |
656 return d.addCallback(reconnectFailed) | |
657 | |
658 | |
659 | |
660 class CannotBindTestCase(unittest.TestCase): | |
661 """ | |
662 Tests for correct behavior when a reactor cannot bind to the required TCP | |
663 port. | |
664 """ | |
665 | |
666 def test_cannotBind(self): | |
667 """ | |
668 L{IReactorTCP.listenTCP} raises L{error.CannotListenError} if the | |
669 address to listen on is already in use. | |
670 """ | |
671 f = MyServerFactory() | |
672 | |
673 p1 = reactor.listenTCP(0, f, interface='127.0.0.1') | |
674 self.addCleanup(p1.stopListening) | |
675 n = p1.getHost().port | |
676 dest = p1.getHost() | |
677 self.assertEquals(dest.type, "TCP") | |
678 self.assertEquals(dest.host, "127.0.0.1") | |
679 self.assertEquals(dest.port, n) | |
680 | |
681 # make sure new listen raises error | |
682 self.assertRaises(error.CannotListenError, | |
683 reactor.listenTCP, n, f, interface='127.0.0.1') | |
684 | |
685 | |
686 | |
687 def _fireWhenDoneFunc(self, d, f): | |
688 """Returns closure that when called calls f and then callbacks d. | |
689 """ | |
690 from twisted.python import util as tputil | |
691 def newf(*args, **kw): | |
692 rtn = f(*args, **kw) | |
693 d.callback('') | |
694 return rtn | |
695 return tputil.mergeFunctionMetadata(f, newf) | |
696 | |
697 | |
698 def test_clientBind(self): | |
699 """ | |
700 L{IReactorTCP.connectTCP} calls C{Factory.clientConnectionFailed} with | |
701 L{error.ConnectBindError} if the bind address specified is already in | |
702 use. | |
703 """ | |
704 theDeferred = defer.Deferred() | |
705 sf = MyServerFactory() | |
706 sf.startFactory = self._fireWhenDoneFunc(theDeferred, sf.startFactory) | |
707 p = reactor.listenTCP(0, sf, interface="127.0.0.1") | |
708 self.addCleanup(p.stopListening) | |
709 | |
710 def _connect1(results): | |
711 d = defer.Deferred() | |
712 cf1 = MyClientFactory() | |
713 cf1.buildProtocol = self._fireWhenDoneFunc(d, cf1.buildProtocol) | |
714 reactor.connectTCP("127.0.0.1", p.getHost().port, cf1, | |
715 bindAddress=("127.0.0.1", 0)) | |
716 d.addCallback(_conmade, cf1) | |
717 return d | |
718 | |
719 def _conmade(results, cf1): | |
720 d = defer.Deferred() | |
721 cf1.protocol.connectionMade = self._fireWhenDoneFunc( | |
722 d, cf1.protocol.connectionMade) | |
723 d.addCallback(_check1connect2, cf1) | |
724 return d | |
725 | |
726 def _check1connect2(results, cf1): | |
727 self.assertEquals(cf1.protocol.made, 1) | |
728 | |
729 d1 = defer.Deferred() | |
730 d2 = defer.Deferred() | |
731 port = cf1.protocol.transport.getHost().port | |
732 cf2 = MyClientFactory() | |
733 cf2.clientConnectionFailed = self._fireWhenDoneFunc( | |
734 d1, cf2.clientConnectionFailed) | |
735 cf2.stopFactory = self._fireWhenDoneFunc(d2, cf2.stopFactory) | |
736 reactor.connectTCP("127.0.0.1", p.getHost().port, cf2, | |
737 bindAddress=("127.0.0.1", port)) | |
738 d1.addCallback(_check2failed, cf1, cf2) | |
739 d2.addCallback(_check2stopped, cf1, cf2) | |
740 dl = defer.DeferredList([d1, d2]) | |
741 dl.addCallback(_stop, cf1, cf2) | |
742 return dl | |
743 | |
744 def _check2failed(results, cf1, cf2): | |
745 self.assertEquals(cf2.failed, 1) | |
746 cf2.reason.trap(error.ConnectBindError) | |
747 self.assertTrue(cf2.reason.check(error.ConnectBindError)) | |
748 return results | |
749 | |
750 def _check2stopped(results, cf1, cf2): | |
751 self.assertEquals(cf2.stopped, 1) | |
752 return results | |
753 | |
754 def _stop(results, cf1, cf2): | |
755 d = defer.Deferred() | |
756 d.addCallback(_check1cleanup, cf1) | |
757 cf1.stopFactory = self._fireWhenDoneFunc(d, cf1.stopFactory) | |
758 cf1.protocol.transport.loseConnection() | |
759 return d | |
760 | |
761 def _check1cleanup(results, cf1): | |
762 self.assertEquals(cf1.stopped, 1) | |
763 | |
764 theDeferred.addCallback(_connect1) | |
765 return theDeferred | |
766 | |
767 | |
768 | |
769 class MyOtherClientFactory(protocol.ClientFactory): | |
770 def buildProtocol(self, address): | |
771 self.address = address | |
772 self.protocol = MyProtocol() | |
773 return self.protocol | |
774 | |
775 | |
776 | |
777 class LocalRemoteAddressTestCase(unittest.TestCase): | |
778 """ | |
779 Tests for correct getHost/getPeer values and that the correct address is | |
780 passed to buildProtocol. | |
781 """ | |
782 def test_hostAddress(self): | |
783 """ | |
784 L{IListeningPort.getHost} returns the same address as a client | |
785 connection's L{ITCPTransport.getPeer}. | |
786 """ | |
787 f1 = MyServerFactory() | |
788 p1 = reactor.listenTCP(0, f1, interface='127.0.0.1') | |
789 self.addCleanup(p1.stopListening) | |
790 n = p1.getHost().port | |
791 | |
792 f2 = MyOtherClientFactory() | |
793 p2 = reactor.connectTCP('127.0.0.1', n, f2) | |
794 | |
795 d = loopUntil(lambda :p2.state == "connected") | |
796 def check(ignored): | |
797 self.assertEquals(p1.getHost(), f2.address) | |
798 self.assertEquals(p1.getHost(), f2.protocol.transport.getPeer()) | |
799 return p1.stopListening() | |
800 def cleanup(ignored): | |
801 p2.transport.loseConnection() | |
802 return d.addCallback(check).addCallback(cleanup) | |
803 | |
804 | |
805 class WriterProtocol(protocol.Protocol): | |
806 def connectionMade(self): | |
807 # use everything ITransport claims to provide. If something here | |
808 # fails, the exception will be written to the log, but it will not | |
809 # directly flunk the test. The test will fail when maximum number of | |
810 # iterations have passed and the writer's factory.done has not yet | |
811 # been set. | |
812 self.transport.write("Hello Cleveland!\n") | |
813 seq = ["Goodbye", " cruel", " world", "\n"] | |
814 self.transport.writeSequence(seq) | |
815 peer = self.transport.getPeer() | |
816 if peer.type != "TCP": | |
817 print "getPeer returned non-TCP socket:", peer | |
818 self.factory.problem = 1 | |
819 us = self.transport.getHost() | |
820 if us.type != "TCP": | |
821 print "getHost returned non-TCP socket:", us | |
822 self.factory.problem = 1 | |
823 self.factory.done = 1 | |
824 | |
825 self.transport.loseConnection() | |
826 | |
827 class ReaderProtocol(protocol.Protocol): | |
828 def dataReceived(self, data): | |
829 self.factory.data += data | |
830 def connectionLost(self, reason): | |
831 self.factory.done = 1 | |
832 | |
833 class WriterClientFactory(protocol.ClientFactory): | |
834 def __init__(self): | |
835 self.done = 0 | |
836 self.data = "" | |
837 def buildProtocol(self, addr): | |
838 p = ReaderProtocol() | |
839 p.factory = self | |
840 self.protocol = p | |
841 return p | |
842 | |
843 class WriteDataTestCase(unittest.TestCase): | |
844 """ | |
845 Test that connected TCP sockets can actually write data. Try to exercise | |
846 the entire ITransport interface. | |
847 """ | |
848 | |
849 def test_writer(self): | |
850 """ | |
851 L{ITCPTransport.write} and L{ITCPTransport.writeSequence} send bytes to | |
852 the other end of the connection. | |
853 """ | |
854 f = protocol.Factory() | |
855 f.protocol = WriterProtocol | |
856 f.done = 0 | |
857 f.problem = 0 | |
858 wrappedF = WiredFactory(f) | |
859 p = reactor.listenTCP(0, wrappedF, interface="127.0.0.1") | |
860 self.addCleanup(p.stopListening) | |
861 n = p.getHost().port | |
862 clientF = WriterClientFactory() | |
863 wrappedClientF = WiredFactory(clientF) | |
864 reactor.connectTCP("127.0.0.1", n, wrappedClientF) | |
865 | |
866 def check(ignored): | |
867 self.failUnless(f.done, "writer didn't finish, it probably died") | |
868 self.failUnless(f.problem == 0, "writer indicated an error") | |
869 self.failUnless(clientF.done, | |
870 "client didn't see connection dropped") | |
871 expected = "".join(["Hello Cleveland!\n", | |
872 "Goodbye", " cruel", " world", "\n"]) | |
873 self.failUnless(clientF.data == expected, | |
874 "client didn't receive all the data it expected") | |
875 d = defer.gatherResults([wrappedF.onDisconnect, | |
876 wrappedClientF.onDisconnect]) | |
877 return d.addCallback(check) | |
878 | |
879 | |
880 def test_writeAfterShutdownWithoutReading(self): | |
881 """ | |
882 A TCP transport which is written to after the connection has been shut | |
883 down should notify its protocol that the connection has been lost, even | |
884 if the TCP transport is not actively being monitored for read events | |
885 (ie, pauseProducing was called on it). | |
886 """ | |
887 # This is an unpleasant thing. Generally tests shouldn't skip or | |
888 # run based on the name of the reactor being used (most tests | |
889 # shouldn't care _at all_ what reactor is being used, in fact). The | |
890 # Gtk reactor cannot pass this test, though, because it fails to | |
891 # implement IReactorTCP entirely correctly. Gtk is quite old at | |
892 # this point, so it's more likely that gtkreactor will be deprecated | |
893 # and removed rather than fixed to handle this case correctly. | |
894 # Since this is a pre-existing (and very long-standing) issue with | |
895 # the Gtk reactor, there's no reason for it to prevent this test | |
896 # being added to exercise the other reactors, for which the behavior | |
897 # was also untested but at least works correctly (now). See #2833 | |
898 # for information on the status of gtkreactor. | |
899 if reactor.__class__.__name__ == 'IOCPReactor': | |
900 raise unittest.SkipTest( | |
901 "iocpreactor does not, in fact, stop reading immediately after " | |
902 "pauseProducing is called. This results in a bonus disconnection
" | |
903 "notification. Under some circumstances, it might be possible to
" | |
904 "not receive this notifications (specifically, pauseProducing, " | |
905 "deliver some data, proceed with this test).") | |
906 if reactor.__class__.__name__ == 'GtkReactor': | |
907 raise unittest.SkipTest( | |
908 "gtkreactor does not implement unclean disconnection " | |
909 "notification correctly. This might more properly be " | |
910 "a todo, but due to technical limitations it cannot be.") | |
911 | |
912 # Called back after the protocol for the client side of the connection | |
913 # has paused its transport, preventing it from reading, therefore | |
914 # preventing it from noticing the disconnection before the rest of the | |
915 # actions which are necessary to trigger the case this test is for have | |
916 # been taken. | |
917 clientPaused = defer.Deferred() | |
918 | |
919 # Called back when the protocol for the server side of the connection | |
920 # has received connection lost notification. | |
921 serverLost = defer.Deferred() | |
922 | |
923 class Disconnecter(protocol.Protocol): | |
924 """ | |
925 Protocol for the server side of the connection which disconnects | |
926 itself in a callback on clientPaused and publishes notification | |
927 when its connection is actually lost. | |
928 """ | |
929 def connectionMade(self): | |
930 """ | |
931 Set up a callback on clientPaused to lose the connection. | |
932 """ | |
933 msg('Disconnector.connectionMade') | |
934 def disconnect(ignored): | |
935 msg('Disconnector.connectionMade disconnect') | |
936 self.transport.loseConnection() | |
937 msg('loseConnection called') | |
938 clientPaused.addCallback(disconnect) | |
939 | |
940 def connectionLost(self, reason): | |
941 """ | |
942 Notify observers that the server side of the connection has | |
943 ended. | |
944 """ | |
945 msg('Disconnecter.connectionLost') | |
946 serverLost.callback(None) | |
947 msg('serverLost called back') | |
948 | |
949 # Create the server port to which a connection will be made. | |
950 server = protocol.ServerFactory() | |
951 server.protocol = Disconnecter | |
952 port = reactor.listenTCP(0, server, interface='127.0.0.1') | |
953 self.addCleanup(port.stopListening) | |
954 addr = port.getHost() | |
955 | |
956 class Infinite(object): | |
957 """ | |
958 A producer which will write to its consumer as long as | |
959 resumeProducing is called. | |
960 | |
961 @ivar consumer: The L{IConsumer} which will be written to. | |
962 """ | |
963 implements(IPullProducer) | |
964 | |
965 def __init__(self, consumer): | |
966 self.consumer = consumer | |
967 | |
968 def resumeProducing(self): | |
969 msg('Infinite.resumeProducing') | |
970 self.consumer.write('x') | |
971 msg('Infinite.resumeProducing wrote to consumer') | |
972 | |
973 def stopProducing(self): | |
974 msg('Infinite.stopProducing') | |
975 | |
976 | |
977 class UnreadingWriter(protocol.Protocol): | |
978 """ | |
979 Trivial protocol which pauses its transport immediately and then | |
980 writes some bytes to it. | |
981 """ | |
982 def connectionMade(self): | |
983 msg('UnreadingWriter.connectionMade') | |
984 self.transport.pauseProducing() | |
985 clientPaused.callback(None) | |
986 msg('clientPaused called back') | |
987 def write(ignored): | |
988 msg('UnreadingWriter.connectionMade write') | |
989 # This needs to be enough bytes to spill over into the | |
990 # userspace Twisted send buffer - if it all fits into | |
991 # the kernel, Twisted won't even poll for OUT events, | |
992 # which means it won't poll for any events at all, so | |
993 # the disconnection is never noticed. This is due to | |
994 # #1662. When #1662 is fixed, this test will likely | |
995 # need to be adjusted, otherwise connection lost | |
996 # notification will happen too soon and the test will | |
997 # probably begin to fail with ConnectionDone instead of | |
998 # ConnectionLost (in any case, it will no longer be | |
999 # entirely correct). | |
1000 producer = Infinite(self.transport) | |
1001 msg('UnreadingWriter.connectionMade write created producer') | |
1002 self.transport.registerProducer(producer, False) | |
1003 msg('UnreadingWriter.connectionMade write registered produce
r') | |
1004 serverLost.addCallback(write) | |
1005 | |
1006 # Create the client and initiate the connection | |
1007 client = MyClientFactory() | |
1008 client.protocolFactory = UnreadingWriter | |
1009 clientConnectionLost = client.deferred | |
1010 def cbClientLost(ignored): | |
1011 msg('cbClientLost') | |
1012 return client.lostReason | |
1013 clientConnectionLost.addCallback(cbClientLost) | |
1014 msg('Connecting to %s:%s' % (addr.host, addr.port)) | |
1015 connector = reactor.connectTCP(addr.host, addr.port, client) | |
1016 | |
1017 # By the end of the test, the client should have received notification | |
1018 # of unclean disconnection. | |
1019 msg('Returning Deferred') | |
1020 return self.assertFailure(clientConnectionLost, error.ConnectionLost) | |
1021 | |
1022 | |
1023 | |
1024 class ConnectionLosingProtocol(protocol.Protocol): | |
1025 def connectionMade(self): | |
1026 self.transport.write("1") | |
1027 self.transport.loseConnection() | |
1028 self.master._connectionMade() | |
1029 self.master.ports.append(self.transport) | |
1030 | |
1031 | |
1032 | |
1033 class NoopProtocol(protocol.Protocol): | |
1034 def connectionMade(self): | |
1035 self.d = defer.Deferred() | |
1036 self.master.serverConns.append(self.d) | |
1037 | |
1038 def connectionLost(self, reason): | |
1039 self.d.callback(True) | |
1040 | |
1041 | |
1042 | |
1043 class ConnectionLostNotifyingProtocol(protocol.Protocol): | |
1044 """ | |
1045 Protocol which fires a Deferred which was previously passed to | |
1046 its initializer when the connection is lost. | |
1047 """ | |
1048 def __init__(self, onConnectionLost): | |
1049 self.onConnectionLost = onConnectionLost | |
1050 | |
1051 | |
1052 def connectionLost(self, reason): | |
1053 self.onConnectionLost.callback(self) | |
1054 | |
1055 | |
1056 | |
1057 class HandleSavingProtocol(ConnectionLostNotifyingProtocol): | |
1058 """ | |
1059 Protocol which grabs the platform-specific socket handle and | |
1060 saves it as an attribute on itself when the connection is | |
1061 established. | |
1062 """ | |
1063 def makeConnection(self, transport): | |
1064 """ | |
1065 Save the platform-specific socket handle for future | |
1066 introspection. | |
1067 """ | |
1068 self.handle = transport.getHandle() | |
1069 return protocol.Protocol.makeConnection(self, transport) | |
1070 | |
1071 | |
1072 | |
1073 class ProperlyCloseFilesMixin: | |
1074 """ | |
1075 Tests for platform resources properly being cleaned up. | |
1076 """ | |
1077 def createServer(self, address, portNumber, factory): | |
1078 """ | |
1079 Bind a server port to which connections will be made. The server | |
1080 should use the given protocol factory. | |
1081 | |
1082 @return: The L{IListeningPort} for the server created. | |
1083 """ | |
1084 raise NotImplementedError() | |
1085 | |
1086 | |
1087 def connectClient(self, address, portNumber, clientCreator): | |
1088 """ | |
1089 Establish a connection to the given address using the given | |
1090 L{ClientCreator} instance. | |
1091 | |
1092 @return: A Deferred which will fire with the connected protocol instance
. | |
1093 """ | |
1094 raise NotImplementedError() | |
1095 | |
1096 | |
1097 def getHandleExceptionType(self): | |
1098 """ | |
1099 Return the exception class which will be raised when an operation is | |
1100 attempted on a closed platform handle. | |
1101 """ | |
1102 raise NotImplementedError() | |
1103 | |
1104 | |
1105 def getHandleErrorCode(self): | |
1106 """ | |
1107 Return the errno expected to result from writing to a closed | |
1108 platform socket handle. | |
1109 """ | |
1110 # These platforms have been seen to give EBADF: | |
1111 # | |
1112 # Linux 2.4.26, Linux 2.6.15, OS X 10.4, FreeBSD 5.4 | |
1113 # Windows 2000 SP 4, Windows XP SP 2 | |
1114 return errno.EBADF | |
1115 | |
1116 | |
1117 def test_properlyCloseFiles(self): | |
1118 """ | |
1119 Test that lost connections properly have their underlying socket | |
1120 resources cleaned up. | |
1121 """ | |
1122 onServerConnectionLost = defer.Deferred() | |
1123 serverFactory = protocol.ServerFactory() | |
1124 serverFactory.protocol = lambda: ConnectionLostNotifyingProtocol( | |
1125 onServerConnectionLost) | |
1126 serverPort = self.createServer('127.0.0.1', 0, serverFactory) | |
1127 | |
1128 onClientConnectionLost = defer.Deferred() | |
1129 serverAddr = serverPort.getHost() | |
1130 clientCreator = protocol.ClientCreator( | |
1131 reactor, lambda: HandleSavingProtocol(onClientConnectionLost)) | |
1132 clientDeferred = self.connectClient( | |
1133 serverAddr.host, serverAddr.port, clientCreator) | |
1134 | |
1135 def clientConnected(client): | |
1136 """ | |
1137 Disconnect the client. Return a Deferred which fires when both | |
1138 the client and the server have received disconnect notification. | |
1139 """ | |
1140 client.transport.loseConnection() | |
1141 return defer.gatherResults([ | |
1142 onClientConnectionLost, onServerConnectionLost]) | |
1143 clientDeferred.addCallback(clientConnected) | |
1144 | |
1145 def clientDisconnected((client, server)): | |
1146 """ | |
1147 Verify that the underlying platform socket handle has been | |
1148 cleaned up. | |
1149 """ | |
1150 expectedErrorCode = self.getHandleErrorCode() | |
1151 err = self.assertRaises( | |
1152 self.getHandleExceptionType(), client.handle.send, 'bytes') | |
1153 self.assertEqual(err.args[0], expectedErrorCode) | |
1154 clientDeferred.addCallback(clientDisconnected) | |
1155 | |
1156 def cleanup(passthrough): | |
1157 """ | |
1158 Shut down the server port. Return a Deferred which fires when | |
1159 this has completed. | |
1160 """ | |
1161 result = defer.maybeDeferred(serverPort.stopListening) | |
1162 result.addCallback(lambda ign: passthrough) | |
1163 return result | |
1164 clientDeferred.addBoth(cleanup) | |
1165 | |
1166 return clientDeferred | |
1167 | |
1168 | |
1169 | |
1170 class ProperlyCloseFilesTestCase(unittest.TestCase, ProperlyCloseFilesMixin): | |
1171 def createServer(self, address, portNumber, factory): | |
1172 return reactor.listenTCP(portNumber, factory, interface=address) | |
1173 | |
1174 | |
1175 def connectClient(self, address, portNumber, clientCreator): | |
1176 return clientCreator.connectTCP(address, portNumber) | |
1177 | |
1178 | |
1179 def getHandleExceptionType(self): | |
1180 return socket.error | |
1181 | |
1182 | |
1183 | |
1184 class WiredForDeferreds(policies.ProtocolWrapper): | |
1185 def __init__(self, factory, wrappedProtocol): | |
1186 policies.ProtocolWrapper.__init__(self, factory, wrappedProtocol) | |
1187 | |
1188 def connectionMade(self): | |
1189 policies.ProtocolWrapper.connectionMade(self) | |
1190 self.factory.onConnect.callback(None) | |
1191 | |
1192 def connectionLost(self, reason): | |
1193 policies.ProtocolWrapper.connectionLost(self, reason) | |
1194 self.factory.onDisconnect.callback(None) | |
1195 | |
1196 | |
1197 | |
1198 class WiredFactory(policies.WrappingFactory): | |
1199 protocol = WiredForDeferreds | |
1200 | |
1201 def __init__(self, wrappedFactory): | |
1202 policies.WrappingFactory.__init__(self, wrappedFactory) | |
1203 self.onConnect = defer.Deferred() | |
1204 self.onDisconnect = defer.Deferred() | |
1205 | |
1206 | |
1207 | |
1208 class AddressTestCase(unittest.TestCase): | |
1209 """ | |
1210 Tests for address-related interactions with client and server protocols. | |
1211 """ | |
1212 def setUp(self): | |
1213 """ | |
1214 Create a port and connected client/server pair which can be used | |
1215 to test factory behavior related to addresses. | |
1216 | |
1217 @return: A L{defer.Deferred} which will be called back when both the | |
1218 client and server protocols have received their connection made | |
1219 callback. | |
1220 """ | |
1221 class RememberingWrapper(protocol.ClientFactory): | |
1222 """ | |
1223 Simple wrapper factory which records the addresses which are | |
1224 passed to its L{buildProtocol} method and delegates actual | |
1225 protocol creation to another factory. | |
1226 | |
1227 @ivar addresses: A list of the objects passed to buildProtocol. | |
1228 @ivar factory: The wrapped factory to which protocol creation is | |
1229 delegated. | |
1230 """ | |
1231 def __init__(self, factory): | |
1232 self.addresses = [] | |
1233 self.factory = factory | |
1234 | |
1235 # Only bother to pass on buildProtocol calls to the wrapped | |
1236 # factory - doStart, doStop, etc aren't necessary for this test | |
1237 # to pass. | |
1238 def buildProtocol(self, addr): | |
1239 """ | |
1240 Append the given address to C{self.addresses} and forward | |
1241 the call to C{self.factory}. | |
1242 """ | |
1243 self.addresses.append(addr) | |
1244 return self.factory.buildProtocol(addr) | |
1245 | |
1246 # Make a server which we can receive connection and disconnection | |
1247 # notification for, and which will record the address passed to its | |
1248 # buildProtocol. | |
1249 self.server = MyServerFactory() | |
1250 self.serverConnMade = self.server.protocolConnectionMade = defer.Deferre
d() | |
1251 self.serverConnLost = self.server.protocolConnectionLost = defer.Deferre
d() | |
1252 # RememberingWrapper is a ClientFactory, but ClientFactory is-a | |
1253 # ServerFactory, so this is okay. | |
1254 self.serverWrapper = RememberingWrapper(self.server) | |
1255 | |
1256 # Do something similar for a client. | |
1257 self.client = MyClientFactory() | |
1258 self.clientConnMade = self.client.protocolConnectionMade = defer.Deferre
d() | |
1259 self.clientConnLost = self.client.protocolConnectionLost = defer.Deferre
d() | |
1260 self.clientWrapper = RememberingWrapper(self.client) | |
1261 | |
1262 self.port = reactor.listenTCP(0, self.serverWrapper, interface='127.0.0.
1') | |
1263 self.connector = reactor.connectTCP( | |
1264 self.port.getHost().host, self.port.getHost().port, self.clientWrapp
er) | |
1265 | |
1266 return defer.gatherResults([self.serverConnMade, self.clientConnMade]) | |
1267 | |
1268 | |
1269 def tearDown(self): | |
1270 """ | |
1271 Disconnect the client/server pair and shutdown the port created in | |
1272 L{setUp}. | |
1273 """ | |
1274 self.connector.disconnect() | |
1275 return defer.gatherResults([ | |
1276 self.serverConnLost, self.clientConnLost, | |
1277 defer.maybeDeferred(self.port.stopListening)]) | |
1278 | |
1279 | |
1280 def test_buildProtocolClient(self): | |
1281 """ | |
1282 L{ClientFactory.buildProtocol} should be invoked with the address of | |
1283 the server to which a connection has been established, which should | |
1284 be the same as the address reported by the C{getHost} method of the | |
1285 transport of the server protocol and as the C{getPeer} method of the | |
1286 transport of the client protocol. | |
1287 """ | |
1288 serverHost = self.server.protocol.transport.getHost() | |
1289 clientPeer = self.client.protocol.transport.getPeer() | |
1290 | |
1291 self.assertEqual( | |
1292 self.clientWrapper.addresses, | |
1293 [IPv4Address('TCP', serverHost.host, serverHost.port)]) | |
1294 self.assertEqual( | |
1295 self.clientWrapper.addresses, | |
1296 [IPv4Address('TCP', clientPeer.host, clientPeer.port)]) | |
1297 | |
1298 | |
1299 def test_buildProtocolServer(self): | |
1300 """ | |
1301 L{ServerFactory.buildProtocol} should be invoked with the address of | |
1302 the client which has connected to the port the factory is listening on, | |
1303 which should be the same as the address reported by the C{getPeer} | |
1304 method of the transport of the server protocol and as the C{getHost} | |
1305 method of the transport of the client protocol. | |
1306 """ | |
1307 clientHost = self.client.protocol.transport.getHost() | |
1308 serverPeer = self.server.protocol.transport.getPeer() | |
1309 | |
1310 self.assertEqual( | |
1311 self.serverWrapper.addresses, | |
1312 [IPv4Address('TCP', serverPeer.host, serverPeer.port)]) | |
1313 self.assertEqual( | |
1314 self.serverWrapper.addresses, | |
1315 [IPv4Address('TCP', clientHost.host, clientHost.port)]) | |
1316 | |
1317 | |
1318 | |
1319 class LargeBufferWriterProtocol(protocol.Protocol): | |
1320 | |
1321 # Win32 sockets cannot handle single huge chunks of bytes. Write one | |
1322 # massive string to make sure Twisted deals with this fact. | |
1323 | |
1324 def connectionMade(self): | |
1325 # write 60MB | |
1326 self.transport.write('X'*self.factory.len) | |
1327 self.factory.done = 1 | |
1328 self.transport.loseConnection() | |
1329 | |
1330 class LargeBufferReaderProtocol(protocol.Protocol): | |
1331 def dataReceived(self, data): | |
1332 self.factory.len += len(data) | |
1333 def connectionLost(self, reason): | |
1334 self.factory.done = 1 | |
1335 | |
1336 class LargeBufferReaderClientFactory(protocol.ClientFactory): | |
1337 def __init__(self): | |
1338 self.done = 0 | |
1339 self.len = 0 | |
1340 def buildProtocol(self, addr): | |
1341 p = LargeBufferReaderProtocol() | |
1342 p.factory = self | |
1343 self.protocol = p | |
1344 return p | |
1345 | |
1346 | |
1347 class FireOnClose(policies.ProtocolWrapper): | |
1348 """A wrapper around a protocol that makes it fire a deferred when | |
1349 connectionLost is called. | |
1350 """ | |
1351 def connectionLost(self, reason): | |
1352 policies.ProtocolWrapper.connectionLost(self, reason) | |
1353 self.factory.deferred.callback(None) | |
1354 | |
1355 | |
1356 class FireOnCloseFactory(policies.WrappingFactory): | |
1357 protocol = FireOnClose | |
1358 | |
1359 def __init__(self, wrappedFactory): | |
1360 policies.WrappingFactory.__init__(self, wrappedFactory) | |
1361 self.deferred = defer.Deferred() | |
1362 | |
1363 | |
1364 class LargeBufferTestCase(unittest.TestCase): | |
1365 """Test that buffering large amounts of data works. | |
1366 """ | |
1367 | |
1368 datalen = 60*1024*1024 | |
1369 def testWriter(self): | |
1370 f = protocol.Factory() | |
1371 f.protocol = LargeBufferWriterProtocol | |
1372 f.done = 0 | |
1373 f.problem = 0 | |
1374 f.len = self.datalen | |
1375 wrappedF = FireOnCloseFactory(f) | |
1376 p = reactor.listenTCP(0, wrappedF, interface="127.0.0.1") | |
1377 self.addCleanup(p.stopListening) | |
1378 n = p.getHost().port | |
1379 clientF = LargeBufferReaderClientFactory() | |
1380 wrappedClientF = FireOnCloseFactory(clientF) | |
1381 reactor.connectTCP("127.0.0.1", n, wrappedClientF) | |
1382 | |
1383 d = defer.gatherResults([wrappedF.deferred, wrappedClientF.deferred]) | |
1384 def check(ignored): | |
1385 self.failUnless(f.done, "writer didn't finish, it probably died") | |
1386 self.failUnless(clientF.len == self.datalen, | |
1387 "client didn't receive all the data it expected " | |
1388 "(%d != %d)" % (clientF.len, self.datalen)) | |
1389 self.failUnless(clientF.done, | |
1390 "client didn't see connection dropped") | |
1391 return d.addCallback(check) | |
1392 | |
1393 | |
1394 class MyHCProtocol(MyProtocol): | |
1395 | |
1396 implements(IHalfCloseableProtocol) | |
1397 | |
1398 readHalfClosed = False | |
1399 writeHalfClosed = False | |
1400 | |
1401 def readConnectionLost(self): | |
1402 self.readHalfClosed = True | |
1403 # Invoke notification logic from the base class to simplify testing. | |
1404 if self.writeHalfClosed: | |
1405 self.connectionLost(None) | |
1406 | |
1407 def writeConnectionLost(self): | |
1408 self.writeHalfClosed = True | |
1409 # Invoke notification logic from the base class to simplify testing. | |
1410 if self.readHalfClosed: | |
1411 self.connectionLost(None) | |
1412 | |
1413 | |
1414 class MyHCFactory(protocol.ServerFactory): | |
1415 | |
1416 called = 0 | |
1417 protocolConnectionMade = None | |
1418 | |
1419 def buildProtocol(self, addr): | |
1420 self.called += 1 | |
1421 p = MyHCProtocol() | |
1422 p.factory = self | |
1423 self.protocol = p | |
1424 return p | |
1425 | |
1426 | |
1427 class HalfCloseTestCase(unittest.TestCase): | |
1428 """Test half-closing connections.""" | |
1429 | |
1430 def setUp(self): | |
1431 self.f = f = MyHCFactory() | |
1432 self.p = p = reactor.listenTCP(0, f, interface="127.0.0.1") | |
1433 self.addCleanup(p.stopListening) | |
1434 d = loopUntil(lambda :p.connected) | |
1435 | |
1436 self.cf = protocol.ClientCreator(reactor, MyHCProtocol) | |
1437 | |
1438 d.addCallback(lambda _: self.cf.connectTCP(p.getHost().host, | |
1439 p.getHost().port)) | |
1440 d.addCallback(self._setUp) | |
1441 return d | |
1442 | |
1443 def _setUp(self, client): | |
1444 self.client = client | |
1445 self.clientProtoConnectionLost = self.client.closedDeferred = defer.Defe
rred() | |
1446 self.assertEquals(self.client.transport.connected, 1) | |
1447 # Wait for the server to notice there is a connection, too. | |
1448 return loopUntil(lambda: getattr(self.f, 'protocol', None) is not None) | |
1449 | |
1450 def tearDown(self): | |
1451 self.assertEquals(self.client.closed, 0) | |
1452 self.client.transport.loseConnection() | |
1453 d = defer.maybeDeferred(self.p.stopListening) | |
1454 d.addCallback(lambda ign: self.clientProtoConnectionLost) | |
1455 d.addCallback(self._tearDown) | |
1456 return d | |
1457 | |
1458 def _tearDown(self, ignored): | |
1459 self.assertEquals(self.client.closed, 1) | |
1460 # because we did half-close, the server also needs to | |
1461 # closed explicitly. | |
1462 self.assertEquals(self.f.protocol.closed, 0) | |
1463 d = defer.Deferred() | |
1464 def _connectionLost(reason): | |
1465 self.f.protocol.closed = 1 | |
1466 d.callback(None) | |
1467 self.f.protocol.connectionLost = _connectionLost | |
1468 self.f.protocol.transport.loseConnection() | |
1469 d.addCallback(lambda x:self.assertEquals(self.f.protocol.closed, 1)) | |
1470 return d | |
1471 | |
1472 def testCloseWriteCloser(self): | |
1473 client = self.client | |
1474 f = self.f | |
1475 t = client.transport | |
1476 | |
1477 t.write("hello") | |
1478 d = loopUntil(lambda :len(t._tempDataBuffer) == 0) | |
1479 def loseWrite(ignored): | |
1480 t.loseWriteConnection() | |
1481 return loopUntil(lambda :t._writeDisconnected) | |
1482 def check(ignored): | |
1483 self.assertEquals(client.closed, False) | |
1484 self.assertEquals(client.writeHalfClosed, True) | |
1485 self.assertEquals(client.readHalfClosed, False) | |
1486 return loopUntil(lambda :f.protocol.readHalfClosed) | |
1487 def write(ignored): | |
1488 w = client.transport.write | |
1489 w(" world") | |
1490 w("lalala fooled you") | |
1491 self.assertEquals(0, len(client.transport._tempDataBuffer)) | |
1492 self.assertEquals(f.protocol.data, "hello") | |
1493 self.assertEquals(f.protocol.closed, False) | |
1494 self.assertEquals(f.protocol.readHalfClosed, True) | |
1495 return d.addCallback(loseWrite).addCallback(check).addCallback(write) | |
1496 | |
1497 def testWriteCloseNotification(self): | |
1498 f = self.f | |
1499 f.protocol.transport.loseWriteConnection() | |
1500 | |
1501 d = defer.gatherResults([ | |
1502 loopUntil(lambda :f.protocol.writeHalfClosed), | |
1503 loopUntil(lambda :self.client.readHalfClosed)]) | |
1504 d.addCallback(lambda _: self.assertEquals( | |
1505 f.protocol.readHalfClosed, False)) | |
1506 return d | |
1507 | |
1508 | |
1509 class HalfClose2TestCase(unittest.TestCase): | |
1510 | |
1511 def setUp(self): | |
1512 self.f = f = MyServerFactory() | |
1513 self.f.protocolConnectionMade = defer.Deferred() | |
1514 self.p = p = reactor.listenTCP(0, f, interface="127.0.0.1") | |
1515 | |
1516 # XXX we don't test server side yet since we don't do it yet | |
1517 d = protocol.ClientCreator(reactor, MyProtocol).connectTCP( | |
1518 p.getHost().host, p.getHost().port) | |
1519 d.addCallback(self._gotClient) | |
1520 return d | |
1521 | |
1522 def _gotClient(self, client): | |
1523 self.client = client | |
1524 # Now wait for the server to catch up - it doesn't matter if this | |
1525 # Deferred has already fired and gone away, in that case we'll | |
1526 # return None and not wait at all, which is precisely correct. | |
1527 return self.f.protocolConnectionMade | |
1528 | |
1529 def tearDown(self): | |
1530 self.client.transport.loseConnection() | |
1531 return self.p.stopListening() | |
1532 | |
1533 def testNoNotification(self): | |
1534 """ | |
1535 TCP protocols support half-close connections, but not all of them | |
1536 support being notified of write closes. In this case, test that | |
1537 half-closing the connection causes the peer's connection to be | |
1538 closed. | |
1539 """ | |
1540 self.client.transport.write("hello") | |
1541 self.client.transport.loseWriteConnection() | |
1542 self.f.protocol.closedDeferred = d = defer.Deferred() | |
1543 self.client.closedDeferred = d2 = defer.Deferred() | |
1544 d.addCallback(lambda x: | |
1545 self.assertEqual(self.f.protocol.data, 'hello')) | |
1546 d.addCallback(lambda x: self.assertEqual(self.f.protocol.closed, True)) | |
1547 return defer.gatherResults([d, d2]) | |
1548 | |
1549 def testShutdownException(self): | |
1550 """ | |
1551 If the other side has already closed its connection, | |
1552 loseWriteConnection should pass silently. | |
1553 """ | |
1554 self.f.protocol.transport.loseConnection() | |
1555 self.client.transport.write("X") | |
1556 self.client.transport.loseWriteConnection() | |
1557 self.f.protocol.closedDeferred = d = defer.Deferred() | |
1558 self.client.closedDeferred = d2 = defer.Deferred() | |
1559 d.addCallback(lambda x: | |
1560 self.failUnlessEqual(self.f.protocol.closed, True)) | |
1561 return defer.gatherResults([d, d2]) | |
1562 | |
1563 | |
1564 class HalfCloseBuggyApplicationTests(unittest.TestCase): | |
1565 """ | |
1566 Test half-closing connections where notification code has bugs. | |
1567 """ | |
1568 | |
1569 def setUp(self): | |
1570 """ | |
1571 Set up a server and connect a client to it. Return a Deferred which | |
1572 only fires once this is done. | |
1573 """ | |
1574 self.serverFactory = MyHCFactory() | |
1575 self.serverFactory.protocolConnectionMade = defer.Deferred() | |
1576 self.port = reactor.listenTCP( | |
1577 0, self.serverFactory, interface="127.0.0.1") | |
1578 self.addCleanup(self.port.stopListening) | |
1579 addr = self.port.getHost() | |
1580 creator = protocol.ClientCreator(reactor, MyHCProtocol) | |
1581 clientDeferred = creator.connectTCP(addr.host, addr.port) | |
1582 def setClient(clientProtocol): | |
1583 self.clientProtocol = clientProtocol | |
1584 clientDeferred.addCallback(setClient) | |
1585 return defer.gatherResults([ | |
1586 self.serverFactory.protocolConnectionMade, | |
1587 clientDeferred]) | |
1588 | |
1589 | |
1590 def aBug(self, *args): | |
1591 """ | |
1592 Fake implementation of a callback which illegally raises an | |
1593 exception. | |
1594 """ | |
1595 raise RuntimeError("ONO I AM BUGGY CODE") | |
1596 | |
1597 | |
1598 def _notificationRaisesTest(self): | |
1599 """ | |
1600 Helper for testing that an exception is logged by the time the | |
1601 client protocol loses its connection. | |
1602 """ | |
1603 closed = self.clientProtocol.closedDeferred = defer.Deferred() | |
1604 self.clientProtocol.transport.loseWriteConnection() | |
1605 def check(ignored): | |
1606 errors = self.flushLoggedErrors(RuntimeError) | |
1607 self.assertEqual(len(errors), 1) | |
1608 closed.addCallback(check) | |
1609 return closed | |
1610 | |
1611 | |
1612 def test_readNotificationRaises(self): | |
1613 """ | |
1614 If C{readConnectionLost} raises an exception when the transport | |
1615 calls it to notify the protocol of that event, the exception should | |
1616 be logged and the protocol should be disconnected completely. | |
1617 """ | |
1618 self.serverFactory.protocol.readConnectionLost = self.aBug | |
1619 return self._notificationRaisesTest() | |
1620 | |
1621 | |
1622 def test_writeNotificationRaises(self): | |
1623 """ | |
1624 If C{writeConnectionLost} raises an exception when the transport | |
1625 calls it to notify the protocol of that event, the exception should | |
1626 be logged and the protocol should be disconnected completely. | |
1627 """ | |
1628 self.clientProtocol.writeConnectionLost = self.aBug | |
1629 return self._notificationRaisesTest() | |
1630 | |
1631 | |
1632 | |
1633 class LogTestCase(unittest.TestCase): | |
1634 """ | |
1635 Test logging facility of TCP base classes. | |
1636 """ | |
1637 | |
1638 def test_logstrClientSetup(self): | |
1639 """ | |
1640 Check that the log customization of the client transport happens | |
1641 once the client is connected. | |
1642 """ | |
1643 server = MyServerFactory() | |
1644 | |
1645 client = MyClientFactory() | |
1646 client.protocolConnectionMade = defer.Deferred() | |
1647 | |
1648 port = reactor.listenTCP(0, server, interface='127.0.0.1') | |
1649 self.addCleanup(port.stopListening) | |
1650 | |
1651 connector = reactor.connectTCP( | |
1652 port.getHost().host, port.getHost().port, client) | |
1653 self.addCleanup(connector.disconnect) | |
1654 | |
1655 # It should still have the default value | |
1656 self.assertEquals(connector.transport.logstr, | |
1657 "Uninitialized") | |
1658 | |
1659 def cb(ign): | |
1660 self.assertEquals(connector.transport.logstr, | |
1661 "MyProtocol,client") | |
1662 client.protocolConnectionMade.addCallback(cb) | |
1663 return client.protocolConnectionMade | |
1664 | |
1665 | |
1666 | |
1667 class PauseProducingTestCase(unittest.TestCase): | |
1668 """ | |
1669 Test some behaviors of pausing the production of a transport. | |
1670 """ | |
1671 | |
1672 def test_pauseProducingInConnectionMade(self): | |
1673 """ | |
1674 In C{connectionMade} of a client protocol, C{pauseProducing} used to be | |
1675 ignored: this test is here to ensure it's not ignored. | |
1676 """ | |
1677 server = MyServerFactory() | |
1678 | |
1679 client = MyClientFactory() | |
1680 client.protocolConnectionMade = defer.Deferred() | |
1681 | |
1682 port = reactor.listenTCP(0, server, interface='127.0.0.1') | |
1683 self.addCleanup(port.stopListening) | |
1684 | |
1685 connector = reactor.connectTCP( | |
1686 port.getHost().host, port.getHost().port, client) | |
1687 self.addCleanup(connector.disconnect) | |
1688 | |
1689 def checkInConnectionMade(proto): | |
1690 tr = proto.transport | |
1691 # The transport should already be monitored | |
1692 self.assertIn(tr, reactor.getReaders() + | |
1693 reactor.getWriters()) | |
1694 proto.transport.pauseProducing() | |
1695 self.assertNotIn(tr, reactor.getReaders() + | |
1696 reactor.getWriters()) | |
1697 d = defer.Deferred() | |
1698 d.addCallback(checkAfterConnectionMade) | |
1699 reactor.callLater(0, d.callback, proto) | |
1700 return d | |
1701 def checkAfterConnectionMade(proto): | |
1702 tr = proto.transport | |
1703 # The transport should still not be monitored | |
1704 self.assertNotIn(tr, reactor.getReaders() + | |
1705 reactor.getWriters()) | |
1706 client.protocolConnectionMade.addCallback(checkInConnectionMade) | |
1707 return client.protocolConnectionMade | |
1708 | |
1709 if not interfaces.IReactorFDSet.providedBy(reactor): | |
1710 test_pauseProducingInConnectionMade.skip = "Reactor not providing IReact
orFDSet" | |
1711 | |
1712 | |
1713 | |
1714 class CallBackOrderTestCase(unittest.TestCase): | |
1715 """ | |
1716 Test the order of reactor callbacks | |
1717 """ | |
1718 | |
1719 def test_loseOrder(self): | |
1720 """ | |
1721 Check that Protocol.connectionLost is called before factory's | |
1722 clientConnectionLost | |
1723 """ | |
1724 server = MyServerFactory() | |
1725 server.protocolConnectionMade = (defer.Deferred() | |
1726 .addCallback(lambda proto: self.addCleanup( | |
1727 proto.transport.loseConnection))) | |
1728 | |
1729 client = MyClientFactory() | |
1730 client.protocolConnectionLost = defer.Deferred() | |
1731 client.protocolConnectionMade = defer.Deferred() | |
1732 | |
1733 def _cbCM(res): | |
1734 """ | |
1735 protocol.connectionMade callback | |
1736 """ | |
1737 reactor.callLater(0, client.protocol.transport.loseConnection) | |
1738 | |
1739 client.protocolConnectionMade.addCallback(_cbCM) | |
1740 | |
1741 port = reactor.listenTCP(0, server, interface='127.0.0.1') | |
1742 self.addCleanup(port.stopListening) | |
1743 | |
1744 connector = reactor.connectTCP( | |
1745 port.getHost().host, port.getHost().port, client) | |
1746 self.addCleanup(connector.disconnect) | |
1747 | |
1748 def _cbCCL(res): | |
1749 """ | |
1750 factory.clientConnectionLost callback | |
1751 """ | |
1752 return 'CCL' | |
1753 | |
1754 def _cbCL(res): | |
1755 """ | |
1756 protocol.connectionLost callback | |
1757 """ | |
1758 return 'CL' | |
1759 | |
1760 def _cbGather(res): | |
1761 self.assertEquals(res, ['CL', 'CCL']) | |
1762 | |
1763 d = defer.gatherResults([ | |
1764 client.protocolConnectionLost.addCallback(_cbCL), | |
1765 client.deferred.addCallback(_cbCCL)]) | |
1766 return d.addCallback(_cbGather) | |
1767 | |
1768 | |
1769 | |
1770 try: | |
1771 import resource | |
1772 except ImportError: | |
1773 pass | |
1774 else: | |
1775 numRounds = resource.getrlimit(resource.RLIMIT_NOFILE)[0] + 10 | |
1776 ProperlyCloseFilesTestCase.numberRounds = numRounds | |
OLD | NEW |