OLD | NEW |
| (Empty) |
1 # Copyright (c) 2001-2004 Twisted Matrix Laboratories. | |
2 # See LICENSE for details. | |
3 | |
4 | |
5 """ | |
6 Test case for twisted.protocols.loopback | |
7 """ | |
8 | |
9 from zope.interface import implements | |
10 | |
11 from twisted.trial import unittest | |
12 from twisted.trial.util import suppress as SUPPRESS | |
13 from twisted.protocols import basic, loopback | |
14 from twisted.internet import defer | |
15 from twisted.internet.protocol import Protocol | |
16 from twisted.internet.defer import Deferred | |
17 from twisted.internet.interfaces import IAddress, IPushProducer, IPullProducer | |
18 from twisted.internet import reactor | |
19 | |
20 | |
21 class SimpleProtocol(basic.LineReceiver): | |
22 def __init__(self): | |
23 self.conn = defer.Deferred() | |
24 self.lines = [] | |
25 self.connLost = [] | |
26 | |
27 def connectionMade(self): | |
28 self.conn.callback(None) | |
29 | |
30 def lineReceived(self, line): | |
31 self.lines.append(line) | |
32 | |
33 def connectionLost(self, reason): | |
34 self.connLost.append(reason) | |
35 | |
36 | |
37 class DoomProtocol(SimpleProtocol): | |
38 i = 0 | |
39 def lineReceived(self, line): | |
40 self.i += 1 | |
41 if self.i < 4: | |
42 # by this point we should have connection closed, | |
43 # but just in case we didn't we won't ever send 'Hello 4' | |
44 self.sendLine("Hello %d" % self.i) | |
45 SimpleProtocol.lineReceived(self, line) | |
46 if self.lines[-1] == "Hello 3": | |
47 self.transport.loseConnection() | |
48 | |
49 | |
50 class LoopbackTestCaseMixin: | |
51 def testRegularFunction(self): | |
52 s = SimpleProtocol() | |
53 c = SimpleProtocol() | |
54 | |
55 def sendALine(result): | |
56 s.sendLine("THIS IS LINE ONE!") | |
57 s.transport.loseConnection() | |
58 s.conn.addCallback(sendALine) | |
59 | |
60 def check(ignored): | |
61 self.assertEquals(c.lines, ["THIS IS LINE ONE!"]) | |
62 self.assertEquals(len(s.connLost), 1) | |
63 self.assertEquals(len(c.connLost), 1) | |
64 d = defer.maybeDeferred(self.loopbackFunc, s, c) | |
65 d.addCallback(check) | |
66 return d | |
67 | |
68 def testSneakyHiddenDoom(self): | |
69 s = DoomProtocol() | |
70 c = DoomProtocol() | |
71 | |
72 def sendALine(result): | |
73 s.sendLine("DOOM LINE") | |
74 s.conn.addCallback(sendALine) | |
75 | |
76 def check(ignored): | |
77 self.assertEquals(s.lines, ['Hello 1', 'Hello 2', 'Hello 3']) | |
78 self.assertEquals(c.lines, ['DOOM LINE', 'Hello 1', 'Hello 2', 'Hell
o 3']) | |
79 self.assertEquals(len(s.connLost), 1) | |
80 self.assertEquals(len(c.connLost), 1) | |
81 d = defer.maybeDeferred(self.loopbackFunc, s, c) | |
82 d.addCallback(check) | |
83 return d | |
84 | |
85 | |
86 | |
87 class LoopbackTestCase(LoopbackTestCaseMixin, unittest.TestCase): | |
88 loopbackFunc = staticmethod(loopback.loopback) | |
89 | |
90 def testRegularFunction(self): | |
91 """ | |
92 Suppress loopback deprecation warning. | |
93 """ | |
94 return LoopbackTestCaseMixin.testRegularFunction(self) | |
95 testRegularFunction.suppress = [ | |
96 SUPPRESS(message="loopback\(\) is deprecated", | |
97 category=DeprecationWarning)] | |
98 | |
99 | |
100 | |
101 class LoopbackAsyncTestCase(LoopbackTestCase): | |
102 loopbackFunc = staticmethod(loopback.loopbackAsync) | |
103 | |
104 | |
105 def test_makeConnection(self): | |
106 """ | |
107 Test that the client and server protocol both have makeConnection | |
108 invoked on them by loopbackAsync. | |
109 """ | |
110 class TestProtocol(Protocol): | |
111 transport = None | |
112 def makeConnection(self, transport): | |
113 self.transport = transport | |
114 | |
115 server = TestProtocol() | |
116 client = TestProtocol() | |
117 loopback.loopbackAsync(server, client) | |
118 self.failIfEqual(client.transport, None) | |
119 self.failIfEqual(server.transport, None) | |
120 | |
121 | |
122 def _hostpeertest(self, get, testServer): | |
123 """ | |
124 Test one of the permutations of client/server host/peer. | |
125 """ | |
126 class TestProtocol(Protocol): | |
127 def makeConnection(self, transport): | |
128 Protocol.makeConnection(self, transport) | |
129 self.onConnection.callback(transport) | |
130 | |
131 if testServer: | |
132 server = TestProtocol() | |
133 d = server.onConnection = Deferred() | |
134 client = Protocol() | |
135 else: | |
136 server = Protocol() | |
137 client = TestProtocol() | |
138 d = client.onConnection = Deferred() | |
139 | |
140 loopback.loopbackAsync(server, client) | |
141 | |
142 def connected(transport): | |
143 host = getattr(transport, get)() | |
144 self.failUnless(IAddress.providedBy(host)) | |
145 | |
146 return d.addCallback(connected) | |
147 | |
148 | |
149 def test_serverHost(self): | |
150 """ | |
151 Test that the server gets a transport with a properly functioning | |
152 implementation of L{ITransport.getHost}. | |
153 """ | |
154 return self._hostpeertest("getHost", True) | |
155 | |
156 | |
157 def test_serverPeer(self): | |
158 """ | |
159 Like C{test_serverHost} but for L{ITransport.getPeer} | |
160 """ | |
161 return self._hostpeertest("getPeer", True) | |
162 | |
163 | |
164 def test_clientHost(self, get="getHost"): | |
165 """ | |
166 Test that the client gets a transport with a properly functioning | |
167 implementation of L{ITransport.getHost}. | |
168 """ | |
169 return self._hostpeertest("getHost", False) | |
170 | |
171 | |
172 def test_clientPeer(self): | |
173 """ | |
174 Like C{test_clientHost} but for L{ITransport.getPeer}. | |
175 """ | |
176 return self._hostpeertest("getPeer", False) | |
177 | |
178 | |
179 def _greetingtest(self, write, testServer): | |
180 """ | |
181 Test one of the permutations of write/writeSequence client/server. | |
182 """ | |
183 class GreeteeProtocol(Protocol): | |
184 bytes = "" | |
185 def dataReceived(self, bytes): | |
186 self.bytes += bytes | |
187 if self.bytes == "bytes": | |
188 self.received.callback(None) | |
189 | |
190 class GreeterProtocol(Protocol): | |
191 def connectionMade(self): | |
192 getattr(self.transport, write)("bytes") | |
193 | |
194 if testServer: | |
195 server = GreeterProtocol() | |
196 client = GreeteeProtocol() | |
197 d = client.received = Deferred() | |
198 else: | |
199 server = GreeteeProtocol() | |
200 d = server.received = Deferred() | |
201 client = GreeterProtocol() | |
202 | |
203 loopback.loopbackAsync(server, client) | |
204 return d | |
205 | |
206 | |
207 def test_clientGreeting(self): | |
208 """ | |
209 Test that on a connection where the client speaks first, the server | |
210 receives the bytes sent by the client. | |
211 """ | |
212 return self._greetingtest("write", False) | |
213 | |
214 | |
215 def test_clientGreetingSequence(self): | |
216 """ | |
217 Like C{test_clientGreeting}, but use C{writeSequence} instead of | |
218 C{write} to issue the greeting. | |
219 """ | |
220 return self._greetingtest("writeSequence", False) | |
221 | |
222 | |
223 def test_serverGreeting(self, write="write"): | |
224 """ | |
225 Test that on a connection where the server speaks first, the client | |
226 receives the bytes sent by the server. | |
227 """ | |
228 return self._greetingtest("write", True) | |
229 | |
230 | |
231 def test_serverGreetingSequence(self): | |
232 """ | |
233 Like C{test_serverGreeting}, but use C{writeSequence} instead of | |
234 C{write} to issue the greeting. | |
235 """ | |
236 return self._greetingtest("writeSequence", True) | |
237 | |
238 | |
239 def _producertest(self, producerClass): | |
240 toProduce = map(str, range(0, 10)) | |
241 | |
242 class ProducingProtocol(Protocol): | |
243 def connectionMade(self): | |
244 self.producer = producerClass(list(toProduce)) | |
245 self.producer.start(self.transport) | |
246 | |
247 class ReceivingProtocol(Protocol): | |
248 bytes = "" | |
249 def dataReceived(self, bytes): | |
250 self.bytes += bytes | |
251 if self.bytes == ''.join(toProduce): | |
252 self.received.callback((client, server)) | |
253 | |
254 server = ProducingProtocol() | |
255 client = ReceivingProtocol() | |
256 client.received = Deferred() | |
257 | |
258 loopback.loopbackAsync(server, client) | |
259 return client.received | |
260 | |
261 | |
262 def test_pushProducer(self): | |
263 """ | |
264 Test a push producer registered against a loopback transport. | |
265 """ | |
266 class PushProducer(object): | |
267 implements(IPushProducer) | |
268 resumed = False | |
269 | |
270 def __init__(self, toProduce): | |
271 self.toProduce = toProduce | |
272 | |
273 def resumeProducing(self): | |
274 self.resumed = True | |
275 | |
276 def start(self, consumer): | |
277 self.consumer = consumer | |
278 consumer.registerProducer(self, True) | |
279 self._produceAndSchedule() | |
280 | |
281 def _produceAndSchedule(self): | |
282 if self.toProduce: | |
283 self.consumer.write(self.toProduce.pop(0)) | |
284 reactor.callLater(0, self._produceAndSchedule) | |
285 else: | |
286 self.consumer.unregisterProducer() | |
287 d = self._producertest(PushProducer) | |
288 | |
289 def finished((client, server)): | |
290 self.failIf( | |
291 server.producer.resumed, | |
292 "Streaming producer should not have been resumed.") | |
293 d.addCallback(finished) | |
294 return d | |
295 | |
296 | |
297 def test_pullProducer(self): | |
298 """ | |
299 Test a pull producer registered against a loopback transport. | |
300 """ | |
301 class PullProducer(object): | |
302 implements(IPullProducer) | |
303 | |
304 def __init__(self, toProduce): | |
305 self.toProduce = toProduce | |
306 | |
307 def start(self, consumer): | |
308 self.consumer = consumer | |
309 self.consumer.registerProducer(self, False) | |
310 | |
311 def resumeProducing(self): | |
312 self.consumer.write(self.toProduce.pop(0)) | |
313 if not self.toProduce: | |
314 self.consumer.unregisterProducer() | |
315 return self._producertest(PullProducer) | |
316 | |
317 | |
318 class LoopbackTCPTestCase(LoopbackTestCase): | |
319 loopbackFunc = staticmethod(loopback.loopbackTCP) | |
320 | |
321 | |
322 class LoopbackUNIXTestCase(LoopbackTestCase): | |
323 loopbackFunc = staticmethod(loopback.loopbackUNIX) | |
324 | |
325 def setUp(self): | |
326 from twisted.internet import reactor, interfaces | |
327 if interfaces.IReactorUNIX(reactor, None) is None: | |
328 raise unittest.SkipTest("Current reactor does not support UNIX socke
ts") | |
OLD | NEW |