OLD | NEW |
| (Empty) |
1 # -*- test-case-name: twisted.test.test_amp.TLSTest -*- | |
2 """Utilities and helpers for simulating a network | |
3 """ | |
4 | |
5 import itertools | |
6 | |
7 | |
8 from zope.interface import implements, directlyProvides | |
9 | |
10 from twisted.internet import error | |
11 from twisted.internet import interfaces | |
12 from OpenSSL.SSL import Error as NativeOpenSSLError | |
13 | |
14 from twisted.internet._sslverify import OpenSSLVerifyError | |
15 | |
16 class TLSNegotiation: | |
17 def __init__(self, obj, connectState): | |
18 self.obj = obj | |
19 self.connectState = connectState | |
20 self.sent = False | |
21 self.readyToSend = connectState | |
22 | |
23 def __repr__(self): | |
24 return 'TLSNegotiation(%r)' % (self.obj,) | |
25 | |
26 def pretendToVerify(self, other, tpt): | |
27 # Set the transport problems list here? disconnections? | |
28 # hmmmmm... need some negative path tests. | |
29 | |
30 if not self.obj.iosimVerify(other.obj): | |
31 tpt.problems.append(OpenSSLVerifyError("fake cert", "fake errno", "f
ake depth")) | |
32 tpt.disconnectReason = NativeOpenSSLError() | |
33 tpt.loseConnection() | |
34 | |
35 | |
36 class FakeTransport: | |
37 """A wrapper around a file-like object to make it behave as a Transport. | |
38 | |
39 This doesn't actually stream the file to the attached protocol, | |
40 and is thus useful mainly as a utility for debugging protocols. | |
41 """ | |
42 | |
43 implements(interfaces.ITransport, | |
44 interfaces.ITLSTransport) # ha ha not really | |
45 | |
46 _nextserial = itertools.count().next | |
47 closed = 0 | |
48 disconnecting = 0 | |
49 disconnected = 0 | |
50 disconnectReason = error.ConnectionDone("Connection done") | |
51 producer = None | |
52 streamingProducer = 0 | |
53 tls = None | |
54 | |
55 def __init__(self): | |
56 self.stream = [] | |
57 self.problems = [] | |
58 self.serial = self._nextserial() | |
59 | |
60 def __repr__(self): | |
61 return 'FakeTransport<%s,%s,%s>' % ( | |
62 self.isServer and 'S' or 'C', self.serial, | |
63 self.protocol.__class__.__name__) | |
64 | |
65 def write(self, data): | |
66 if self.tls is not None: | |
67 self.tlsbuf.append(data) | |
68 else: | |
69 self.stream.append(data) | |
70 | |
71 def _checkProducer(self): | |
72 # Cheating; this is called at "idle" times to allow producers to be | |
73 # found and dealt with | |
74 if self.producer: | |
75 self.producer.resumeProducing() | |
76 | |
77 def registerProducer(self, producer, streaming): | |
78 """From abstract.FileDescriptor | |
79 """ | |
80 self.producer = producer | |
81 self.streamingProducer = streaming | |
82 if not streaming: | |
83 producer.resumeProducing() | |
84 | |
85 def unregisterProducer(self): | |
86 self.producer = None | |
87 | |
88 def stopConsuming(self): | |
89 self.unregisterProducer() | |
90 self.loseConnection() | |
91 | |
92 def writeSequence(self, iovec): | |
93 self.write("".join(iovec)) | |
94 | |
95 def loseConnection(self): | |
96 self.disconnecting = True | |
97 | |
98 def reportDisconnect(self): | |
99 if self.tls is not None: | |
100 # We were in the middle of negotiating! Must have been a TLS proble
m. | |
101 err = NativeOpenSSLError() | |
102 else: | |
103 err = self.disconnectReason | |
104 self.protocol.connectionLost(err) | |
105 | |
106 def getPeer(self): | |
107 # XXX: According to ITransport, this should return an IAddress! | |
108 return 'file', 'file' | |
109 | |
110 def getHost(self): | |
111 # XXX: According to ITransport, this should return an IAddress! | |
112 return 'file' | |
113 | |
114 def resumeProducing(self): | |
115 # Never sends data anyways | |
116 pass | |
117 | |
118 def pauseProducing(self): | |
119 # Never sends data anyways | |
120 pass | |
121 | |
122 def stopProducing(self): | |
123 self.loseConnection() | |
124 | |
125 def startTLS(self, contextFactory, beNormal=True): | |
126 # Nothing's using this feature yet, but startTLS has an undocumented | |
127 # second argument which defaults to true; if set to False, servers will | |
128 # behave like clients and clients will behave like servers. | |
129 connectState = self.isServer ^ beNormal | |
130 self.tls = TLSNegotiation(contextFactory, connectState) | |
131 self.tlsbuf = [] | |
132 | |
133 def getOutBuffer(self): | |
134 S = self.stream | |
135 if S: | |
136 self.stream = [] | |
137 return ''.join(S) | |
138 elif self.tls is not None: | |
139 if self.tls.readyToSend: | |
140 # Only _send_ the TLS negotiation "packet" if I'm ready to. | |
141 self.tls.sent = True | |
142 return self.tls | |
143 else: | |
144 return None | |
145 else: | |
146 return None | |
147 | |
148 def bufferReceived(self, buf): | |
149 if isinstance(buf, TLSNegotiation): | |
150 assert self.tls is not None # By the time you're receiving a | |
151 # negotiation, you have to have called | |
152 # startTLS already. | |
153 if self.tls.sent: | |
154 self.tls.pretendToVerify(buf, self) | |
155 self.tls = None # we're done with the handshake if we've gotten | |
156 # this far... although maybe it failed...? | |
157 # TLS started! Unbuffer... | |
158 b, self.tlsbuf = self.tlsbuf, None | |
159 self.writeSequence(b) | |
160 directlyProvides(self, interfaces.ISSLTransport) | |
161 else: | |
162 # We haven't sent our own TLS negotiation: time to do that! | |
163 self.tls.readyToSend = True | |
164 else: | |
165 self.protocol.dataReceived(buf) | |
166 | |
167 | |
168 # this next bit is just to fake out problemsFromTransport, which is an | |
169 # ultra-shitty API anyway. remove it when we manage to remove that. -glyph | |
170 def getHandle(self): | |
171 return self | |
172 | |
173 get_context = getHandle | |
174 get_app_data = getHandle | |
175 | |
176 # end of gross problemsFromTransport stuff | |
177 | |
178 def makeFakeClient(c): | |
179 ft = FakeTransport() | |
180 ft.isServer = False | |
181 ft.protocol = c | |
182 return ft | |
183 | |
184 def makeFakeServer(s): | |
185 ft = FakeTransport() | |
186 ft.isServer = True | |
187 ft.protocol = s | |
188 return ft | |
189 | |
190 class IOPump: | |
191 """Utility to pump data between clients and servers for protocol testing. | |
192 | |
193 Perhaps this is a utility worthy of being in protocol.py? | |
194 """ | |
195 def __init__(self, client, server, clientIO, serverIO, debug): | |
196 self.client = client | |
197 self.server = server | |
198 self.clientIO = clientIO | |
199 self.serverIO = serverIO | |
200 self.debug = debug | |
201 | |
202 def flush(self, debug=False): | |
203 """Pump until there is no more input or output. | |
204 | |
205 Returns whether any data was moved. | |
206 """ | |
207 result = False | |
208 for x in range(1000): | |
209 if self.pump(debug): | |
210 result = True | |
211 else: | |
212 break | |
213 else: | |
214 assert 0, "Too long" | |
215 return result | |
216 | |
217 | |
218 def pump(self, debug=False): | |
219 """Move data back and forth. | |
220 | |
221 Returns whether any data was moved. | |
222 """ | |
223 if self.debug or debug: | |
224 print '-- GLUG --' | |
225 sData = self.serverIO.getOutBuffer() | |
226 cData = self.clientIO.getOutBuffer() | |
227 self.clientIO._checkProducer() | |
228 self.serverIO._checkProducer() | |
229 if self.debug or debug: | |
230 print '.' | |
231 # XXX slightly buggy in the face of incremental output | |
232 if cData: | |
233 print 'C: '+repr(cData) | |
234 if sData: | |
235 print 'S: '+repr(sData) | |
236 if cData: | |
237 self.serverIO.bufferReceived(cData) | |
238 if sData: | |
239 self.clientIO.bufferReceived(sData) | |
240 if cData or sData: | |
241 return True | |
242 if (self.serverIO.disconnecting and | |
243 not self.serverIO.disconnected): | |
244 if self.debug or debug: | |
245 print '* C' | |
246 self.serverIO.disconnected = True | |
247 self.clientIO.disconnecting = True | |
248 self.clientIO.reportDisconnect() | |
249 return True | |
250 if self.clientIO.disconnecting and not self.clientIO.disconnected: | |
251 if self.debug or debug: | |
252 print '* S' | |
253 self.clientIO.disconnected = True | |
254 self.serverIO.disconnecting = True | |
255 self.serverIO.reportDisconnect() | |
256 return True | |
257 return False | |
258 | |
259 | |
260 def connectedServerAndClient(ServerClass, ClientClass, | |
261 clientTransportFactory=makeFakeClient, | |
262 serverTransportFactory=makeFakeServer, | |
263 debug=False): | |
264 """Returns a 3-tuple: (client, server, pump) | |
265 """ | |
266 c = ClientClass() | |
267 s = ServerClass() | |
268 cio = clientTransportFactory(c) | |
269 sio = serverTransportFactory(s) | |
270 c.makeConnection(cio) | |
271 s.makeConnection(sio) | |
272 pump = IOPump(c, s, cio, sio, debug) | |
273 # kick off server greeting, etc | |
274 pump.flush() | |
275 return c, s, pump | |
OLD | NEW |