OLD | NEW |
| (Empty) |
1 # -*- test-case-name: twisted.test.test_loopback -*- | |
2 # Copyright (c) 2001-2004 Twisted Matrix Laboratories. | |
3 # See LICENSE for details. | |
4 | |
5 | |
6 """ | |
7 Testing support for protocols -- loopback between client and server. | |
8 """ | |
9 | |
10 # system imports | |
11 import tempfile | |
12 from zope.interface import implements | |
13 | |
14 # Twisted Imports | |
15 from twisted.protocols import policies | |
16 from twisted.internet import interfaces, protocol, main, defer | |
17 from twisted.python import failure | |
18 from twisted.internet.interfaces import IAddress | |
19 | |
20 | |
21 class _LoopbackQueue(object): | |
22 """ | |
23 Trivial wrapper around a list to give it an interface like a queue, which | |
24 the addition of also sending notifications by way of a Deferred whenever | |
25 the list has something added to it. | |
26 """ | |
27 | |
28 _notificationDeferred = None | |
29 disconnect = False | |
30 | |
31 def __init__(self): | |
32 self._queue = [] | |
33 | |
34 | |
35 def put(self, v): | |
36 self._queue.append(v) | |
37 if self._notificationDeferred is not None: | |
38 d, self._notificationDeferred = self._notificationDeferred, None | |
39 d.callback(None) | |
40 | |
41 | |
42 def __nonzero__(self): | |
43 return bool(self._queue) | |
44 | |
45 | |
46 def get(self): | |
47 return self._queue.pop(0) | |
48 | |
49 | |
50 | |
51 class _LoopbackAddress(object): | |
52 implements(IAddress) | |
53 | |
54 | |
55 class _LoopbackTransport(object): | |
56 implements(interfaces.ITransport, interfaces.IConsumer) | |
57 | |
58 disconnecting = False | |
59 producer = None | |
60 | |
61 # ITransport | |
62 def __init__(self, q): | |
63 self.q = q | |
64 | |
65 def write(self, bytes): | |
66 self.q.put(bytes) | |
67 | |
68 def writeSequence(self, iovec): | |
69 self.q.put(''.join(iovec)) | |
70 | |
71 def loseConnection(self): | |
72 self.q.disconnect = True | |
73 self.q.put('') | |
74 | |
75 def getPeer(self): | |
76 return _LoopbackAddress() | |
77 | |
78 def getHost(self): | |
79 return _LoopbackAddress() | |
80 | |
81 # IConsumer | |
82 def registerProducer(self, producer, streaming): | |
83 assert self.producer is None | |
84 self.producer = producer | |
85 self.streamingProducer = streaming | |
86 self._pollProducer() | |
87 | |
88 def unregisterProducer(self): | |
89 assert self.producer is not None | |
90 self.producer = None | |
91 | |
92 def _pollProducer(self): | |
93 if self.producer is not None and not self.streamingProducer: | |
94 self.producer.resumeProducing() | |
95 | |
96 | |
97 | |
98 def loopbackAsync(server, client): | |
99 """ | |
100 Establish a connection between C{server} and C{client} then transfer data | |
101 between them until the connection is closed. This is often useful for | |
102 testing a protocol. | |
103 | |
104 @param server: The protocol instance representing the server-side of this | |
105 connection. | |
106 | |
107 @param client: The protocol instance representing the client-side of this | |
108 connection. | |
109 | |
110 @return: A L{Deferred} which fires when the connection has been closed and | |
111 both sides have received notification of this. | |
112 """ | |
113 serverToClient = _LoopbackQueue() | |
114 clientToServer = _LoopbackQueue() | |
115 | |
116 server.makeConnection(_LoopbackTransport(serverToClient)) | |
117 client.makeConnection(_LoopbackTransport(clientToServer)) | |
118 | |
119 return _loopbackAsyncBody(server, serverToClient, client, clientToServer) | |
120 | |
121 | |
122 | |
123 def _loopbackAsyncBody(server, serverToClient, client, clientToServer): | |
124 """ | |
125 Transfer bytes from the output queue of each protocol to the input of the ot
her. | |
126 | |
127 @param server: The protocol instance representing the server-side of this | |
128 connection. | |
129 | |
130 @param serverToClient: The L{_LoopbackQueue} holding the server's output. | |
131 | |
132 @param client: The protocol instance representing the client-side of this | |
133 connection. | |
134 | |
135 @param clientToServer: The L{_LoopbackQueue} holding the client's output. | |
136 | |
137 @return: A L{Deferred} which fires when the connection has been closed and | |
138 both sides have received notification of this. | |
139 """ | |
140 def pump(source, q, target): | |
141 sent = False | |
142 while q: | |
143 sent = True | |
144 bytes = q.get() | |
145 if bytes: | |
146 target.dataReceived(bytes) | |
147 | |
148 # A write buffer has now been emptied. Give any producer on that side | |
149 # an opportunity to produce more data. | |
150 source.transport._pollProducer() | |
151 | |
152 return sent | |
153 | |
154 while 1: | |
155 disconnect = clientSent = serverSent = False | |
156 | |
157 # Deliver the data which has been written. | |
158 serverSent = pump(server, serverToClient, client) | |
159 clientSent = pump(client, clientToServer, server) | |
160 | |
161 if not clientSent and not serverSent: | |
162 # Neither side wrote any data. Wait for some new data to be added | |
163 # before trying to do anything further. | |
164 d = clientToServer._notificationDeferred = serverToClient._notificat
ionDeferred = defer.Deferred() | |
165 d.addCallback(_loopbackAsyncContinue, server, serverToClient, client
, clientToServer) | |
166 return d | |
167 if serverToClient.disconnect: | |
168 # The server wants to drop the connection. Flush any remaining | |
169 # data it has. | |
170 disconnect = True | |
171 pump(server, serverToClient, client) | |
172 elif clientToServer.disconnect: | |
173 # The client wants to drop the connection. Flush any remaining | |
174 # data it has. | |
175 disconnect = True | |
176 pump(client, clientToServer, server) | |
177 if disconnect: | |
178 # Someone wanted to disconnect, so okay, the connection is gone. | |
179 server.connectionLost(failure.Failure(main.CONNECTION_DONE)) | |
180 client.connectionLost(failure.Failure(main.CONNECTION_DONE)) | |
181 return defer.succeed(None) | |
182 | |
183 | |
184 | |
185 def _loopbackAsyncContinue(ignored, server, serverToClient, client, clientToServ
er): | |
186 # Clear the Deferred from each message queue, since it has already fired | |
187 # and cannot be used again. | |
188 clientToServer._notificationDeferred = serverToClient._notificationDeferred
= None | |
189 | |
190 # Push some more bytes around. | |
191 return _loopbackAsyncBody(server, serverToClient, client, clientToServer) | |
192 | |
193 | |
194 | |
195 class LoopbackRelay: | |
196 | |
197 implements(interfaces.ITransport, interfaces.IConsumer) | |
198 | |
199 buffer = '' | |
200 shouldLose = 0 | |
201 disconnecting = 0 | |
202 producer = None | |
203 | |
204 def __init__(self, target, logFile=None): | |
205 self.target = target | |
206 self.logFile = logFile | |
207 | |
208 def write(self, data): | |
209 self.buffer = self.buffer + data | |
210 if self.logFile: | |
211 self.logFile.write("loopback writing %s\n" % repr(data)) | |
212 | |
213 def writeSequence(self, iovec): | |
214 self.write("".join(iovec)) | |
215 | |
216 def clearBuffer(self): | |
217 if self.shouldLose == -1: | |
218 return | |
219 | |
220 if self.producer: | |
221 self.producer.resumeProducing() | |
222 if self.buffer: | |
223 if self.logFile: | |
224 self.logFile.write("loopback receiving %s\n" % repr(self.buffer)
) | |
225 buffer = self.buffer | |
226 self.buffer = '' | |
227 self.target.dataReceived(buffer) | |
228 if self.shouldLose == 1: | |
229 self.shouldLose = -1 | |
230 self.target.connectionLost(failure.Failure(main.CONNECTION_DONE)) | |
231 | |
232 def loseConnection(self): | |
233 if self.shouldLose != -1: | |
234 self.shouldLose = 1 | |
235 | |
236 def getHost(self): | |
237 return 'loopback' | |
238 | |
239 def getPeer(self): | |
240 return 'loopback' | |
241 | |
242 def registerProducer(self, producer, streaming): | |
243 self.producer = producer | |
244 | |
245 def unregisterProducer(self): | |
246 self.producer = None | |
247 | |
248 def logPrefix(self): | |
249 return 'Loopback(%r)' % (self.target.__class__.__name__,) | |
250 | |
251 def loopback(server, client, logFile=None): | |
252 """Run session between server and client. | |
253 DEPRECATED in Twisted 2.5. Use loopbackAsync instead. | |
254 """ | |
255 import warnings | |
256 warnings.warn('loopback() is deprecated (since Twisted 2.5). ' | |
257 'Use loopbackAsync() instead.', | |
258 stacklevel=2, category=DeprecationWarning) | |
259 from twisted.internet import reactor | |
260 serverToClient = LoopbackRelay(client, logFile) | |
261 clientToServer = LoopbackRelay(server, logFile) | |
262 server.makeConnection(serverToClient) | |
263 client.makeConnection(clientToServer) | |
264 while 1: | |
265 reactor.iterate(0.01) # this is to clear any deferreds | |
266 serverToClient.clearBuffer() | |
267 clientToServer.clearBuffer() | |
268 if serverToClient.shouldLose: | |
269 serverToClient.clearBuffer() | |
270 server.connectionLost(failure.Failure(main.CONNECTION_DONE)) | |
271 break | |
272 elif clientToServer.shouldLose: | |
273 client.connectionLost(failure.Failure(main.CONNECTION_DONE)) | |
274 break | |
275 reactor.iterate() # last gasp before I go away | |
276 | |
277 | |
278 class LoopbackClientFactory(protocol.ClientFactory): | |
279 | |
280 def __init__(self, protocol): | |
281 self.disconnected = 0 | |
282 self.deferred = defer.Deferred() | |
283 self.protocol = protocol | |
284 | |
285 def buildProtocol(self, addr): | |
286 return self.protocol | |
287 | |
288 def clientConnectionLost(self, connector, reason): | |
289 self.disconnected = 1 | |
290 self.deferred.callback(None) | |
291 | |
292 | |
293 class _FireOnClose(policies.ProtocolWrapper): | |
294 def __init__(self, protocol, factory): | |
295 policies.ProtocolWrapper.__init__(self, protocol, factory) | |
296 self.deferred = defer.Deferred() | |
297 | |
298 def connectionLost(self, reason): | |
299 policies.ProtocolWrapper.connectionLost(self, reason) | |
300 self.deferred.callback(None) | |
301 | |
302 | |
303 def loopbackTCP(server, client, port=0, noisy=True): | |
304 """Run session between server and client protocol instances over TCP.""" | |
305 from twisted.internet import reactor | |
306 f = policies.WrappingFactory(protocol.Factory()) | |
307 serverWrapper = _FireOnClose(f, server) | |
308 f.noisy = noisy | |
309 f.buildProtocol = lambda addr: serverWrapper | |
310 serverPort = reactor.listenTCP(port, f, interface='127.0.0.1') | |
311 clientF = LoopbackClientFactory(client) | |
312 clientF.noisy = noisy | |
313 reactor.connectTCP('127.0.0.1', serverPort.getHost().port, clientF) | |
314 d = clientF.deferred | |
315 d.addCallback(lambda x: serverWrapper.deferred) | |
316 d.addCallback(lambda x: serverPort.stopListening()) | |
317 return d | |
318 | |
319 | |
320 def loopbackUNIX(server, client, noisy=True): | |
321 """Run session between server and client protocol instances over UNIX socket
.""" | |
322 path = tempfile.mktemp() | |
323 from twisted.internet import reactor | |
324 f = policies.WrappingFactory(protocol.Factory()) | |
325 serverWrapper = _FireOnClose(f, server) | |
326 f.noisy = noisy | |
327 f.buildProtocol = lambda addr: serverWrapper | |
328 serverPort = reactor.listenUNIX(path, f) | |
329 clientF = LoopbackClientFactory(client) | |
330 clientF.noisy = noisy | |
331 reactor.connectUNIX(path, clientF) | |
332 d = clientF.deferred | |
333 d.addCallback(lambda x: serverWrapper.deferred) | |
334 d.addCallback(lambda x: serverPort.stopListening()) | |
335 return d | |
OLD | NEW |