| OLD | NEW |
| (Empty) |
| 1 # Copyright (c) 2001-2004 Twisted Matrix Laboratories. | |
| 2 # See LICENSE for details. | |
| 3 | |
| 4 | |
| 5 """ | |
| 6 Test case for twisted.protocols.loopback | |
| 7 """ | |
| 8 | |
| 9 from zope.interface import implements | |
| 10 | |
| 11 from twisted.trial import unittest | |
| 12 from twisted.trial.util import suppress as SUPPRESS | |
| 13 from twisted.protocols import basic, loopback | |
| 14 from twisted.internet import defer | |
| 15 from twisted.internet.protocol import Protocol | |
| 16 from twisted.internet.defer import Deferred | |
| 17 from twisted.internet.interfaces import IAddress, IPushProducer, IPullProducer | |
| 18 from twisted.internet import reactor | |
| 19 | |
| 20 | |
| 21 class SimpleProtocol(basic.LineReceiver): | |
| 22 def __init__(self): | |
| 23 self.conn = defer.Deferred() | |
| 24 self.lines = [] | |
| 25 self.connLost = [] | |
| 26 | |
| 27 def connectionMade(self): | |
| 28 self.conn.callback(None) | |
| 29 | |
| 30 def lineReceived(self, line): | |
| 31 self.lines.append(line) | |
| 32 | |
| 33 def connectionLost(self, reason): | |
| 34 self.connLost.append(reason) | |
| 35 | |
| 36 | |
| 37 class DoomProtocol(SimpleProtocol): | |
| 38 i = 0 | |
| 39 def lineReceived(self, line): | |
| 40 self.i += 1 | |
| 41 if self.i < 4: | |
| 42 # by this point we should have connection closed, | |
| 43 # but just in case we didn't we won't ever send 'Hello 4' | |
| 44 self.sendLine("Hello %d" % self.i) | |
| 45 SimpleProtocol.lineReceived(self, line) | |
| 46 if self.lines[-1] == "Hello 3": | |
| 47 self.transport.loseConnection() | |
| 48 | |
| 49 | |
| 50 class LoopbackTestCaseMixin: | |
| 51 def testRegularFunction(self): | |
| 52 s = SimpleProtocol() | |
| 53 c = SimpleProtocol() | |
| 54 | |
| 55 def sendALine(result): | |
| 56 s.sendLine("THIS IS LINE ONE!") | |
| 57 s.transport.loseConnection() | |
| 58 s.conn.addCallback(sendALine) | |
| 59 | |
| 60 def check(ignored): | |
| 61 self.assertEquals(c.lines, ["THIS IS LINE ONE!"]) | |
| 62 self.assertEquals(len(s.connLost), 1) | |
| 63 self.assertEquals(len(c.connLost), 1) | |
| 64 d = defer.maybeDeferred(self.loopbackFunc, s, c) | |
| 65 d.addCallback(check) | |
| 66 return d | |
| 67 | |
| 68 def testSneakyHiddenDoom(self): | |
| 69 s = DoomProtocol() | |
| 70 c = DoomProtocol() | |
| 71 | |
| 72 def sendALine(result): | |
| 73 s.sendLine("DOOM LINE") | |
| 74 s.conn.addCallback(sendALine) | |
| 75 | |
| 76 def check(ignored): | |
| 77 self.assertEquals(s.lines, ['Hello 1', 'Hello 2', 'Hello 3']) | |
| 78 self.assertEquals(c.lines, ['DOOM LINE', 'Hello 1', 'Hello 2', 'Hell
o 3']) | |
| 79 self.assertEquals(len(s.connLost), 1) | |
| 80 self.assertEquals(len(c.connLost), 1) | |
| 81 d = defer.maybeDeferred(self.loopbackFunc, s, c) | |
| 82 d.addCallback(check) | |
| 83 return d | |
| 84 | |
| 85 | |
| 86 | |
| 87 class LoopbackTestCase(LoopbackTestCaseMixin, unittest.TestCase): | |
| 88 loopbackFunc = staticmethod(loopback.loopback) | |
| 89 | |
| 90 def testRegularFunction(self): | |
| 91 """ | |
| 92 Suppress loopback deprecation warning. | |
| 93 """ | |
| 94 return LoopbackTestCaseMixin.testRegularFunction(self) | |
| 95 testRegularFunction.suppress = [ | |
| 96 SUPPRESS(message="loopback\(\) is deprecated", | |
| 97 category=DeprecationWarning)] | |
| 98 | |
| 99 | |
| 100 | |
| 101 class LoopbackAsyncTestCase(LoopbackTestCase): | |
| 102 loopbackFunc = staticmethod(loopback.loopbackAsync) | |
| 103 | |
| 104 | |
| 105 def test_makeConnection(self): | |
| 106 """ | |
| 107 Test that the client and server protocol both have makeConnection | |
| 108 invoked on them by loopbackAsync. | |
| 109 """ | |
| 110 class TestProtocol(Protocol): | |
| 111 transport = None | |
| 112 def makeConnection(self, transport): | |
| 113 self.transport = transport | |
| 114 | |
| 115 server = TestProtocol() | |
| 116 client = TestProtocol() | |
| 117 loopback.loopbackAsync(server, client) | |
| 118 self.failIfEqual(client.transport, None) | |
| 119 self.failIfEqual(server.transport, None) | |
| 120 | |
| 121 | |
| 122 def _hostpeertest(self, get, testServer): | |
| 123 """ | |
| 124 Test one of the permutations of client/server host/peer. | |
| 125 """ | |
| 126 class TestProtocol(Protocol): | |
| 127 def makeConnection(self, transport): | |
| 128 Protocol.makeConnection(self, transport) | |
| 129 self.onConnection.callback(transport) | |
| 130 | |
| 131 if testServer: | |
| 132 server = TestProtocol() | |
| 133 d = server.onConnection = Deferred() | |
| 134 client = Protocol() | |
| 135 else: | |
| 136 server = Protocol() | |
| 137 client = TestProtocol() | |
| 138 d = client.onConnection = Deferred() | |
| 139 | |
| 140 loopback.loopbackAsync(server, client) | |
| 141 | |
| 142 def connected(transport): | |
| 143 host = getattr(transport, get)() | |
| 144 self.failUnless(IAddress.providedBy(host)) | |
| 145 | |
| 146 return d.addCallback(connected) | |
| 147 | |
| 148 | |
| 149 def test_serverHost(self): | |
| 150 """ | |
| 151 Test that the server gets a transport with a properly functioning | |
| 152 implementation of L{ITransport.getHost}. | |
| 153 """ | |
| 154 return self._hostpeertest("getHost", True) | |
| 155 | |
| 156 | |
| 157 def test_serverPeer(self): | |
| 158 """ | |
| 159 Like C{test_serverHost} but for L{ITransport.getPeer} | |
| 160 """ | |
| 161 return self._hostpeertest("getPeer", True) | |
| 162 | |
| 163 | |
| 164 def test_clientHost(self, get="getHost"): | |
| 165 """ | |
| 166 Test that the client gets a transport with a properly functioning | |
| 167 implementation of L{ITransport.getHost}. | |
| 168 """ | |
| 169 return self._hostpeertest("getHost", False) | |
| 170 | |
| 171 | |
| 172 def test_clientPeer(self): | |
| 173 """ | |
| 174 Like C{test_clientHost} but for L{ITransport.getPeer}. | |
| 175 """ | |
| 176 return self._hostpeertest("getPeer", False) | |
| 177 | |
| 178 | |
| 179 def _greetingtest(self, write, testServer): | |
| 180 """ | |
| 181 Test one of the permutations of write/writeSequence client/server. | |
| 182 """ | |
| 183 class GreeteeProtocol(Protocol): | |
| 184 bytes = "" | |
| 185 def dataReceived(self, bytes): | |
| 186 self.bytes += bytes | |
| 187 if self.bytes == "bytes": | |
| 188 self.received.callback(None) | |
| 189 | |
| 190 class GreeterProtocol(Protocol): | |
| 191 def connectionMade(self): | |
| 192 getattr(self.transport, write)("bytes") | |
| 193 | |
| 194 if testServer: | |
| 195 server = GreeterProtocol() | |
| 196 client = GreeteeProtocol() | |
| 197 d = client.received = Deferred() | |
| 198 else: | |
| 199 server = GreeteeProtocol() | |
| 200 d = server.received = Deferred() | |
| 201 client = GreeterProtocol() | |
| 202 | |
| 203 loopback.loopbackAsync(server, client) | |
| 204 return d | |
| 205 | |
| 206 | |
| 207 def test_clientGreeting(self): | |
| 208 """ | |
| 209 Test that on a connection where the client speaks first, the server | |
| 210 receives the bytes sent by the client. | |
| 211 """ | |
| 212 return self._greetingtest("write", False) | |
| 213 | |
| 214 | |
| 215 def test_clientGreetingSequence(self): | |
| 216 """ | |
| 217 Like C{test_clientGreeting}, but use C{writeSequence} instead of | |
| 218 C{write} to issue the greeting. | |
| 219 """ | |
| 220 return self._greetingtest("writeSequence", False) | |
| 221 | |
| 222 | |
| 223 def test_serverGreeting(self, write="write"): | |
| 224 """ | |
| 225 Test that on a connection where the server speaks first, the client | |
| 226 receives the bytes sent by the server. | |
| 227 """ | |
| 228 return self._greetingtest("write", True) | |
| 229 | |
| 230 | |
| 231 def test_serverGreetingSequence(self): | |
| 232 """ | |
| 233 Like C{test_serverGreeting}, but use C{writeSequence} instead of | |
| 234 C{write} to issue the greeting. | |
| 235 """ | |
| 236 return self._greetingtest("writeSequence", True) | |
| 237 | |
| 238 | |
| 239 def _producertest(self, producerClass): | |
| 240 toProduce = map(str, range(0, 10)) | |
| 241 | |
| 242 class ProducingProtocol(Protocol): | |
| 243 def connectionMade(self): | |
| 244 self.producer = producerClass(list(toProduce)) | |
| 245 self.producer.start(self.transport) | |
| 246 | |
| 247 class ReceivingProtocol(Protocol): | |
| 248 bytes = "" | |
| 249 def dataReceived(self, bytes): | |
| 250 self.bytes += bytes | |
| 251 if self.bytes == ''.join(toProduce): | |
| 252 self.received.callback((client, server)) | |
| 253 | |
| 254 server = ProducingProtocol() | |
| 255 client = ReceivingProtocol() | |
| 256 client.received = Deferred() | |
| 257 | |
| 258 loopback.loopbackAsync(server, client) | |
| 259 return client.received | |
| 260 | |
| 261 | |
| 262 def test_pushProducer(self): | |
| 263 """ | |
| 264 Test a push producer registered against a loopback transport. | |
| 265 """ | |
| 266 class PushProducer(object): | |
| 267 implements(IPushProducer) | |
| 268 resumed = False | |
| 269 | |
| 270 def __init__(self, toProduce): | |
| 271 self.toProduce = toProduce | |
| 272 | |
| 273 def resumeProducing(self): | |
| 274 self.resumed = True | |
| 275 | |
| 276 def start(self, consumer): | |
| 277 self.consumer = consumer | |
| 278 consumer.registerProducer(self, True) | |
| 279 self._produceAndSchedule() | |
| 280 | |
| 281 def _produceAndSchedule(self): | |
| 282 if self.toProduce: | |
| 283 self.consumer.write(self.toProduce.pop(0)) | |
| 284 reactor.callLater(0, self._produceAndSchedule) | |
| 285 else: | |
| 286 self.consumer.unregisterProducer() | |
| 287 d = self._producertest(PushProducer) | |
| 288 | |
| 289 def finished((client, server)): | |
| 290 self.failIf( | |
| 291 server.producer.resumed, | |
| 292 "Streaming producer should not have been resumed.") | |
| 293 d.addCallback(finished) | |
| 294 return d | |
| 295 | |
| 296 | |
| 297 def test_pullProducer(self): | |
| 298 """ | |
| 299 Test a pull producer registered against a loopback transport. | |
| 300 """ | |
| 301 class PullProducer(object): | |
| 302 implements(IPullProducer) | |
| 303 | |
| 304 def __init__(self, toProduce): | |
| 305 self.toProduce = toProduce | |
| 306 | |
| 307 def start(self, consumer): | |
| 308 self.consumer = consumer | |
| 309 self.consumer.registerProducer(self, False) | |
| 310 | |
| 311 def resumeProducing(self): | |
| 312 self.consumer.write(self.toProduce.pop(0)) | |
| 313 if not self.toProduce: | |
| 314 self.consumer.unregisterProducer() | |
| 315 return self._producertest(PullProducer) | |
| 316 | |
| 317 | |
| 318 class LoopbackTCPTestCase(LoopbackTestCase): | |
| 319 loopbackFunc = staticmethod(loopback.loopbackTCP) | |
| 320 | |
| 321 | |
| 322 class LoopbackUNIXTestCase(LoopbackTestCase): | |
| 323 loopbackFunc = staticmethod(loopback.loopbackUNIX) | |
| 324 | |
| 325 def setUp(self): | |
| 326 from twisted.internet import reactor, interfaces | |
| 327 if interfaces.IReactorUNIX(reactor, None) is None: | |
| 328 raise unittest.SkipTest("Current reactor does not support UNIX socke
ts") | |
| OLD | NEW |