| OLD | NEW |
| (Empty) |
| 1 # -*- test-case-name: twisted.test.test_loopback -*- | |
| 2 # Copyright (c) 2001-2004 Twisted Matrix Laboratories. | |
| 3 # See LICENSE for details. | |
| 4 | |
| 5 | |
| 6 """ | |
| 7 Testing support for protocols -- loopback between client and server. | |
| 8 """ | |
| 9 | |
| 10 # system imports | |
| 11 import tempfile | |
| 12 from zope.interface import implements | |
| 13 | |
| 14 # Twisted Imports | |
| 15 from twisted.protocols import policies | |
| 16 from twisted.internet import interfaces, protocol, main, defer | |
| 17 from twisted.python import failure | |
| 18 from twisted.internet.interfaces import IAddress | |
| 19 | |
| 20 | |
| 21 class _LoopbackQueue(object): | |
| 22 """ | |
| 23 Trivial wrapper around a list to give it an interface like a queue, which | |
| 24 the addition of also sending notifications by way of a Deferred whenever | |
| 25 the list has something added to it. | |
| 26 """ | |
| 27 | |
| 28 _notificationDeferred = None | |
| 29 disconnect = False | |
| 30 | |
| 31 def __init__(self): | |
| 32 self._queue = [] | |
| 33 | |
| 34 | |
| 35 def put(self, v): | |
| 36 self._queue.append(v) | |
| 37 if self._notificationDeferred is not None: | |
| 38 d, self._notificationDeferred = self._notificationDeferred, None | |
| 39 d.callback(None) | |
| 40 | |
| 41 | |
| 42 def __nonzero__(self): | |
| 43 return bool(self._queue) | |
| 44 | |
| 45 | |
| 46 def get(self): | |
| 47 return self._queue.pop(0) | |
| 48 | |
| 49 | |
| 50 | |
| 51 class _LoopbackAddress(object): | |
| 52 implements(IAddress) | |
| 53 | |
| 54 | |
| 55 class _LoopbackTransport(object): | |
| 56 implements(interfaces.ITransport, interfaces.IConsumer) | |
| 57 | |
| 58 disconnecting = False | |
| 59 producer = None | |
| 60 | |
| 61 # ITransport | |
| 62 def __init__(self, q): | |
| 63 self.q = q | |
| 64 | |
| 65 def write(self, bytes): | |
| 66 self.q.put(bytes) | |
| 67 | |
| 68 def writeSequence(self, iovec): | |
| 69 self.q.put(''.join(iovec)) | |
| 70 | |
| 71 def loseConnection(self): | |
| 72 self.q.disconnect = True | |
| 73 self.q.put('') | |
| 74 | |
| 75 def getPeer(self): | |
| 76 return _LoopbackAddress() | |
| 77 | |
| 78 def getHost(self): | |
| 79 return _LoopbackAddress() | |
| 80 | |
| 81 # IConsumer | |
| 82 def registerProducer(self, producer, streaming): | |
| 83 assert self.producer is None | |
| 84 self.producer = producer | |
| 85 self.streamingProducer = streaming | |
| 86 self._pollProducer() | |
| 87 | |
| 88 def unregisterProducer(self): | |
| 89 assert self.producer is not None | |
| 90 self.producer = None | |
| 91 | |
| 92 def _pollProducer(self): | |
| 93 if self.producer is not None and not self.streamingProducer: | |
| 94 self.producer.resumeProducing() | |
| 95 | |
| 96 | |
| 97 | |
| 98 def loopbackAsync(server, client): | |
| 99 """ | |
| 100 Establish a connection between C{server} and C{client} then transfer data | |
| 101 between them until the connection is closed. This is often useful for | |
| 102 testing a protocol. | |
| 103 | |
| 104 @param server: The protocol instance representing the server-side of this | |
| 105 connection. | |
| 106 | |
| 107 @param client: The protocol instance representing the client-side of this | |
| 108 connection. | |
| 109 | |
| 110 @return: A L{Deferred} which fires when the connection has been closed and | |
| 111 both sides have received notification of this. | |
| 112 """ | |
| 113 serverToClient = _LoopbackQueue() | |
| 114 clientToServer = _LoopbackQueue() | |
| 115 | |
| 116 server.makeConnection(_LoopbackTransport(serverToClient)) | |
| 117 client.makeConnection(_LoopbackTransport(clientToServer)) | |
| 118 | |
| 119 return _loopbackAsyncBody(server, serverToClient, client, clientToServer) | |
| 120 | |
| 121 | |
| 122 | |
| 123 def _loopbackAsyncBody(server, serverToClient, client, clientToServer): | |
| 124 """ | |
| 125 Transfer bytes from the output queue of each protocol to the input of the ot
her. | |
| 126 | |
| 127 @param server: The protocol instance representing the server-side of this | |
| 128 connection. | |
| 129 | |
| 130 @param serverToClient: The L{_LoopbackQueue} holding the server's output. | |
| 131 | |
| 132 @param client: The protocol instance representing the client-side of this | |
| 133 connection. | |
| 134 | |
| 135 @param clientToServer: The L{_LoopbackQueue} holding the client's output. | |
| 136 | |
| 137 @return: A L{Deferred} which fires when the connection has been closed and | |
| 138 both sides have received notification of this. | |
| 139 """ | |
| 140 def pump(source, q, target): | |
| 141 sent = False | |
| 142 while q: | |
| 143 sent = True | |
| 144 bytes = q.get() | |
| 145 if bytes: | |
| 146 target.dataReceived(bytes) | |
| 147 | |
| 148 # A write buffer has now been emptied. Give any producer on that side | |
| 149 # an opportunity to produce more data. | |
| 150 source.transport._pollProducer() | |
| 151 | |
| 152 return sent | |
| 153 | |
| 154 while 1: | |
| 155 disconnect = clientSent = serverSent = False | |
| 156 | |
| 157 # Deliver the data which has been written. | |
| 158 serverSent = pump(server, serverToClient, client) | |
| 159 clientSent = pump(client, clientToServer, server) | |
| 160 | |
| 161 if not clientSent and not serverSent: | |
| 162 # Neither side wrote any data. Wait for some new data to be added | |
| 163 # before trying to do anything further. | |
| 164 d = clientToServer._notificationDeferred = serverToClient._notificat
ionDeferred = defer.Deferred() | |
| 165 d.addCallback(_loopbackAsyncContinue, server, serverToClient, client
, clientToServer) | |
| 166 return d | |
| 167 if serverToClient.disconnect: | |
| 168 # The server wants to drop the connection. Flush any remaining | |
| 169 # data it has. | |
| 170 disconnect = True | |
| 171 pump(server, serverToClient, client) | |
| 172 elif clientToServer.disconnect: | |
| 173 # The client wants to drop the connection. Flush any remaining | |
| 174 # data it has. | |
| 175 disconnect = True | |
| 176 pump(client, clientToServer, server) | |
| 177 if disconnect: | |
| 178 # Someone wanted to disconnect, so okay, the connection is gone. | |
| 179 server.connectionLost(failure.Failure(main.CONNECTION_DONE)) | |
| 180 client.connectionLost(failure.Failure(main.CONNECTION_DONE)) | |
| 181 return defer.succeed(None) | |
| 182 | |
| 183 | |
| 184 | |
| 185 def _loopbackAsyncContinue(ignored, server, serverToClient, client, clientToServ
er): | |
| 186 # Clear the Deferred from each message queue, since it has already fired | |
| 187 # and cannot be used again. | |
| 188 clientToServer._notificationDeferred = serverToClient._notificationDeferred
= None | |
| 189 | |
| 190 # Push some more bytes around. | |
| 191 return _loopbackAsyncBody(server, serverToClient, client, clientToServer) | |
| 192 | |
| 193 | |
| 194 | |
| 195 class LoopbackRelay: | |
| 196 | |
| 197 implements(interfaces.ITransport, interfaces.IConsumer) | |
| 198 | |
| 199 buffer = '' | |
| 200 shouldLose = 0 | |
| 201 disconnecting = 0 | |
| 202 producer = None | |
| 203 | |
| 204 def __init__(self, target, logFile=None): | |
| 205 self.target = target | |
| 206 self.logFile = logFile | |
| 207 | |
| 208 def write(self, data): | |
| 209 self.buffer = self.buffer + data | |
| 210 if self.logFile: | |
| 211 self.logFile.write("loopback writing %s\n" % repr(data)) | |
| 212 | |
| 213 def writeSequence(self, iovec): | |
| 214 self.write("".join(iovec)) | |
| 215 | |
| 216 def clearBuffer(self): | |
| 217 if self.shouldLose == -1: | |
| 218 return | |
| 219 | |
| 220 if self.producer: | |
| 221 self.producer.resumeProducing() | |
| 222 if self.buffer: | |
| 223 if self.logFile: | |
| 224 self.logFile.write("loopback receiving %s\n" % repr(self.buffer)
) | |
| 225 buffer = self.buffer | |
| 226 self.buffer = '' | |
| 227 self.target.dataReceived(buffer) | |
| 228 if self.shouldLose == 1: | |
| 229 self.shouldLose = -1 | |
| 230 self.target.connectionLost(failure.Failure(main.CONNECTION_DONE)) | |
| 231 | |
| 232 def loseConnection(self): | |
| 233 if self.shouldLose != -1: | |
| 234 self.shouldLose = 1 | |
| 235 | |
| 236 def getHost(self): | |
| 237 return 'loopback' | |
| 238 | |
| 239 def getPeer(self): | |
| 240 return 'loopback' | |
| 241 | |
| 242 def registerProducer(self, producer, streaming): | |
| 243 self.producer = producer | |
| 244 | |
| 245 def unregisterProducer(self): | |
| 246 self.producer = None | |
| 247 | |
| 248 def logPrefix(self): | |
| 249 return 'Loopback(%r)' % (self.target.__class__.__name__,) | |
| 250 | |
| 251 def loopback(server, client, logFile=None): | |
| 252 """Run session between server and client. | |
| 253 DEPRECATED in Twisted 2.5. Use loopbackAsync instead. | |
| 254 """ | |
| 255 import warnings | |
| 256 warnings.warn('loopback() is deprecated (since Twisted 2.5). ' | |
| 257 'Use loopbackAsync() instead.', | |
| 258 stacklevel=2, category=DeprecationWarning) | |
| 259 from twisted.internet import reactor | |
| 260 serverToClient = LoopbackRelay(client, logFile) | |
| 261 clientToServer = LoopbackRelay(server, logFile) | |
| 262 server.makeConnection(serverToClient) | |
| 263 client.makeConnection(clientToServer) | |
| 264 while 1: | |
| 265 reactor.iterate(0.01) # this is to clear any deferreds | |
| 266 serverToClient.clearBuffer() | |
| 267 clientToServer.clearBuffer() | |
| 268 if serverToClient.shouldLose: | |
| 269 serverToClient.clearBuffer() | |
| 270 server.connectionLost(failure.Failure(main.CONNECTION_DONE)) | |
| 271 break | |
| 272 elif clientToServer.shouldLose: | |
| 273 client.connectionLost(failure.Failure(main.CONNECTION_DONE)) | |
| 274 break | |
| 275 reactor.iterate() # last gasp before I go away | |
| 276 | |
| 277 | |
| 278 class LoopbackClientFactory(protocol.ClientFactory): | |
| 279 | |
| 280 def __init__(self, protocol): | |
| 281 self.disconnected = 0 | |
| 282 self.deferred = defer.Deferred() | |
| 283 self.protocol = protocol | |
| 284 | |
| 285 def buildProtocol(self, addr): | |
| 286 return self.protocol | |
| 287 | |
| 288 def clientConnectionLost(self, connector, reason): | |
| 289 self.disconnected = 1 | |
| 290 self.deferred.callback(None) | |
| 291 | |
| 292 | |
| 293 class _FireOnClose(policies.ProtocolWrapper): | |
| 294 def __init__(self, protocol, factory): | |
| 295 policies.ProtocolWrapper.__init__(self, protocol, factory) | |
| 296 self.deferred = defer.Deferred() | |
| 297 | |
| 298 def connectionLost(self, reason): | |
| 299 policies.ProtocolWrapper.connectionLost(self, reason) | |
| 300 self.deferred.callback(None) | |
| 301 | |
| 302 | |
| 303 def loopbackTCP(server, client, port=0, noisy=True): | |
| 304 """Run session between server and client protocol instances over TCP.""" | |
| 305 from twisted.internet import reactor | |
| 306 f = policies.WrappingFactory(protocol.Factory()) | |
| 307 serverWrapper = _FireOnClose(f, server) | |
| 308 f.noisy = noisy | |
| 309 f.buildProtocol = lambda addr: serverWrapper | |
| 310 serverPort = reactor.listenTCP(port, f, interface='127.0.0.1') | |
| 311 clientF = LoopbackClientFactory(client) | |
| 312 clientF.noisy = noisy | |
| 313 reactor.connectTCP('127.0.0.1', serverPort.getHost().port, clientF) | |
| 314 d = clientF.deferred | |
| 315 d.addCallback(lambda x: serverWrapper.deferred) | |
| 316 d.addCallback(lambda x: serverPort.stopListening()) | |
| 317 return d | |
| 318 | |
| 319 | |
| 320 def loopbackUNIX(server, client, noisy=True): | |
| 321 """Run session between server and client protocol instances over UNIX socket
.""" | |
| 322 path = tempfile.mktemp() | |
| 323 from twisted.internet import reactor | |
| 324 f = policies.WrappingFactory(protocol.Factory()) | |
| 325 serverWrapper = _FireOnClose(f, server) | |
| 326 f.noisy = noisy | |
| 327 f.buildProtocol = lambda addr: serverWrapper | |
| 328 serverPort = reactor.listenUNIX(path, f) | |
| 329 clientF = LoopbackClientFactory(client) | |
| 330 clientF.noisy = noisy | |
| 331 reactor.connectUNIX(path, clientF) | |
| 332 d = clientF.deferred | |
| 333 d.addCallback(lambda x: serverWrapper.deferred) | |
| 334 d.addCallback(lambda x: serverPort.stopListening()) | |
| 335 return d | |
| OLD | NEW |