OLD | NEW |
| (Empty) |
1 # -*- test-case-name: twisted.test.test_protocols -*- | |
2 # Copyright (c) 2001-2008 Twisted Matrix Laboratories. | |
3 # See LICENSE for details. | |
4 | |
5 | |
6 """ | |
7 Basic protocols, such as line-oriented, netstring, and int prefixed strings. | |
8 | |
9 Maintainer: U{Itamar Shtull-Trauring<mailto:twisted@itamarst.org>} | |
10 """ | |
11 | |
12 # System imports | |
13 import re | |
14 import struct | |
15 | |
16 from zope.interface import implements | |
17 | |
18 # Twisted imports | |
19 from twisted.internet import protocol, defer, interfaces, error | |
20 from twisted.python import log | |
21 | |
22 LENGTH, DATA, COMMA = range(3) | |
23 NUMBER = re.compile('(\d*)(:?)') | |
24 DEBUG = 0 | |
25 | |
26 class NetstringParseError(ValueError): | |
27 """The incoming data is not in valid Netstring format.""" | |
28 pass | |
29 | |
30 | |
31 class NetstringReceiver(protocol.Protocol): | |
32 """This uses djb's Netstrings protocol to break up the input into strings. | |
33 | |
34 Each string makes a callback to stringReceived, with a single | |
35 argument of that string. | |
36 | |
37 Security features: | |
38 1. Messages are limited in size, useful if you don't want someone | |
39 sending you a 500MB netstring (change MAX_LENGTH to the maximum | |
40 length you wish to accept). | |
41 2. The connection is lost if an illegal message is received. | |
42 """ | |
43 | |
44 MAX_LENGTH = 99999 | |
45 brokenPeer = 0 | |
46 _readerState = LENGTH | |
47 _readerLength = 0 | |
48 | |
49 def stringReceived(self, line): | |
50 """ | |
51 Override this. | |
52 """ | |
53 raise NotImplementedError | |
54 | |
55 def doData(self): | |
56 buffer,self.__data = self.__data[:int(self._readerLength)],self.__data[i
nt(self._readerLength):] | |
57 self._readerLength = self._readerLength - len(buffer) | |
58 self.__buffer = self.__buffer + buffer | |
59 if self._readerLength != 0: | |
60 return | |
61 self.stringReceived(self.__buffer) | |
62 self._readerState = COMMA | |
63 | |
64 def doComma(self): | |
65 self._readerState = LENGTH | |
66 if self.__data[0] != ',': | |
67 if DEBUG: | |
68 raise NetstringParseError(repr(self.__data)) | |
69 else: | |
70 raise NetstringParseError | |
71 self.__data = self.__data[1:] | |
72 | |
73 | |
74 def doLength(self): | |
75 m = NUMBER.match(self.__data) | |
76 if not m.end(): | |
77 if DEBUG: | |
78 raise NetstringParseError(repr(self.__data)) | |
79 else: | |
80 raise NetstringParseError | |
81 self.__data = self.__data[m.end():] | |
82 if m.group(1): | |
83 try: | |
84 self._readerLength = self._readerLength * (10**len(m.group(1)))
+ long(m.group(1)) | |
85 except OverflowError: | |
86 raise NetstringParseError, "netstring too long" | |
87 if self._readerLength > self.MAX_LENGTH: | |
88 raise NetstringParseError, "netstring too long" | |
89 if m.group(2): | |
90 self.__buffer = '' | |
91 self._readerState = DATA | |
92 | |
93 def dataReceived(self, data): | |
94 self.__data = data | |
95 try: | |
96 while self.__data: | |
97 if self._readerState == DATA: | |
98 self.doData() | |
99 elif self._readerState == COMMA: | |
100 self.doComma() | |
101 elif self._readerState == LENGTH: | |
102 self.doLength() | |
103 else: | |
104 raise RuntimeError, "mode is not DATA, COMMA or LENGTH" | |
105 except NetstringParseError: | |
106 self.transport.loseConnection() | |
107 self.brokenPeer = 1 | |
108 | |
109 def sendString(self, data): | |
110 self.transport.write('%d:%s,' % (len(data), data)) | |
111 | |
112 | |
113 class SafeNetstringReceiver(NetstringReceiver): | |
114 """This class is deprecated, use NetstringReceiver instead. | |
115 """ | |
116 | |
117 | |
118 class LineOnlyReceiver(protocol.Protocol): | |
119 """A protocol that receives only lines. | |
120 | |
121 This is purely a speed optimisation over LineReceiver, for the | |
122 cases that raw mode is known to be unnecessary. | |
123 | |
124 @cvar delimiter: The line-ending delimiter to use. By default this is | |
125 '\\r\\n'. | |
126 @cvar MAX_LENGTH: The maximum length of a line to allow (If a | |
127 sent line is longer than this, the connection is dropped). | |
128 Default is 16384. | |
129 """ | |
130 _buffer = '' | |
131 delimiter = '\r\n' | |
132 MAX_LENGTH = 16384 | |
133 | |
134 def dataReceived(self, data): | |
135 """Translates bytes into lines, and calls lineReceived.""" | |
136 lines = (self._buffer+data).split(self.delimiter) | |
137 self._buffer = lines.pop(-1) | |
138 for line in lines: | |
139 if self.transport.disconnecting: | |
140 # this is necessary because the transport may be told to lose | |
141 # the connection by a line within a larger packet, and it is | |
142 # important to disregard all the lines in that packet following | |
143 # the one that told it to close. | |
144 return | |
145 if len(line) > self.MAX_LENGTH: | |
146 return self.lineLengthExceeded(line) | |
147 else: | |
148 self.lineReceived(line) | |
149 if len(self._buffer) > self.MAX_LENGTH: | |
150 return self.lineLengthExceeded(self._buffer) | |
151 | |
152 def lineReceived(self, line): | |
153 """Override this for when each line is received. | |
154 """ | |
155 raise NotImplementedError | |
156 | |
157 def sendLine(self, line): | |
158 """Sends a line to the other end of the connection. | |
159 """ | |
160 return self.transport.writeSequence((line,self.delimiter)) | |
161 | |
162 def lineLengthExceeded(self, line): | |
163 """Called when the maximum line length has been reached. | |
164 Override if it needs to be dealt with in some special way. | |
165 """ | |
166 return error.ConnectionLost('Line length exceeded') | |
167 | |
168 | |
169 class _PauseableMixin: | |
170 paused = False | |
171 | |
172 def pauseProducing(self): | |
173 self.paused = True | |
174 self.transport.pauseProducing() | |
175 | |
176 def resumeProducing(self): | |
177 self.paused = False | |
178 self.transport.resumeProducing() | |
179 self.dataReceived('') | |
180 | |
181 def stopProducing(self): | |
182 self.paused = True | |
183 self.transport.stopProducing() | |
184 | |
185 | |
186 class LineReceiver(protocol.Protocol, _PauseableMixin): | |
187 """A protocol that receives lines and/or raw data, depending on mode. | |
188 | |
189 In line mode, each line that's received becomes a callback to | |
190 L{lineReceived}. In raw data mode, each chunk of raw data becomes a | |
191 callback to L{rawDataReceived}. The L{setLineMode} and L{setRawMode} | |
192 methods switch between the two modes. | |
193 | |
194 This is useful for line-oriented protocols such as IRC, HTTP, POP, etc. | |
195 | |
196 @cvar delimiter: The line-ending delimiter to use. By default this is | |
197 '\\r\\n'. | |
198 @cvar MAX_LENGTH: The maximum length of a line to allow (If a | |
199 sent line is longer than this, the connection is dropped). | |
200 Default is 16384. | |
201 """ | |
202 line_mode = 1 | |
203 __buffer = '' | |
204 delimiter = '\r\n' | |
205 MAX_LENGTH = 16384 | |
206 | |
207 def clearLineBuffer(self): | |
208 """Clear buffered data.""" | |
209 self.__buffer = "" | |
210 | |
211 def dataReceived(self, data): | |
212 """Protocol.dataReceived. | |
213 Translates bytes into lines, and calls lineReceived (or | |
214 rawDataReceived, depending on mode.) | |
215 """ | |
216 self.__buffer = self.__buffer+data | |
217 while self.line_mode and not self.paused: | |
218 try: | |
219 line, self.__buffer = self.__buffer.split(self.delimiter, 1) | |
220 except ValueError: | |
221 if len(self.__buffer) > self.MAX_LENGTH: | |
222 line, self.__buffer = self.__buffer, '' | |
223 return self.lineLengthExceeded(line) | |
224 break | |
225 else: | |
226 linelength = len(line) | |
227 if linelength > self.MAX_LENGTH: | |
228 exceeded = line + self.__buffer | |
229 self.__buffer = '' | |
230 return self.lineLengthExceeded(exceeded) | |
231 why = self.lineReceived(line) | |
232 if why or self.transport and self.transport.disconnecting: | |
233 return why | |
234 else: | |
235 if not self.paused: | |
236 data=self.__buffer | |
237 self.__buffer='' | |
238 if data: | |
239 return self.rawDataReceived(data) | |
240 | |
241 def setLineMode(self, extra=''): | |
242 """Sets the line-mode of this receiver. | |
243 | |
244 If you are calling this from a rawDataReceived callback, | |
245 you can pass in extra unhandled data, and that data will | |
246 be parsed for lines. Further data received will be sent | |
247 to lineReceived rather than rawDataReceived. | |
248 | |
249 Do not pass extra data if calling this function from | |
250 within a lineReceived callback. | |
251 """ | |
252 self.line_mode = 1 | |
253 if extra: | |
254 return self.dataReceived(extra) | |
255 | |
256 def setRawMode(self): | |
257 """Sets the raw mode of this receiver. | |
258 Further data received will be sent to rawDataReceived rather | |
259 than lineReceived. | |
260 """ | |
261 self.line_mode = 0 | |
262 | |
263 def rawDataReceived(self, data): | |
264 """Override this for when raw data is received. | |
265 """ | |
266 raise NotImplementedError | |
267 | |
268 def lineReceived(self, line): | |
269 """Override this for when each line is received. | |
270 """ | |
271 raise NotImplementedError | |
272 | |
273 def sendLine(self, line): | |
274 """Sends a line to the other end of the connection. | |
275 """ | |
276 return self.transport.write(line + self.delimiter) | |
277 | |
278 def lineLengthExceeded(self, line): | |
279 """Called when the maximum line length has been reached. | |
280 Override if it needs to be dealt with in some special way. | |
281 | |
282 The argument 'line' contains the remainder of the buffer, starting | |
283 with (at least some part) of the line which is too long. This may | |
284 be more than one line, or may be only the initial portion of the | |
285 line. | |
286 """ | |
287 return self.transport.loseConnection() | |
288 | |
289 | |
290 class StringTooLongError(AssertionError): | |
291 """ | |
292 Raised when trying to send a string too long for a length prefixed | |
293 protocol. | |
294 """ | |
295 | |
296 | |
297 class IntNStringReceiver(protocol.Protocol, _PauseableMixin): | |
298 """ | |
299 Generic class for length prefixed protocols. | |
300 | |
301 @ivar recvd: buffer holding received data when splitted. | |
302 @type recvd: C{str} | |
303 | |
304 @ivar structFormat: format used for struct packing/unpacking. Define it in | |
305 subclass. | |
306 @type structFormat: C{str} | |
307 | |
308 @ivar prefixLength: length of the prefix, in bytes. Define it in subclass, | |
309 using C{struct.calcSize(structFormat)} | |
310 @type prefixLength: C{int} | |
311 """ | |
312 MAX_LENGTH = 99999 | |
313 recvd = "" | |
314 | |
315 def stringReceived(self, msg): | |
316 """ | |
317 Override this. | |
318 """ | |
319 raise NotImplementedError | |
320 | |
321 def dataReceived(self, recd): | |
322 """ | |
323 Convert int prefixed strings into calls to stringReceived. | |
324 """ | |
325 self.recvd = self.recvd + recd | |
326 while len(self.recvd) >= self.prefixLength and not self.paused: | |
327 length ,= struct.unpack( | |
328 self.structFormat, self.recvd[:self.prefixLength]) | |
329 if length > self.MAX_LENGTH: | |
330 self.transport.loseConnection() | |
331 return | |
332 if len(self.recvd) < length + self.prefixLength: | |
333 break | |
334 packet = self.recvd[self.prefixLength:length + self.prefixLength] | |
335 self.recvd = self.recvd[length + self.prefixLength:] | |
336 self.stringReceived(packet) | |
337 | |
338 def sendString(self, data): | |
339 """ | |
340 Send an prefixed string to the other end of the connection. | |
341 | |
342 @type data: C{str} | |
343 """ | |
344 if len(data) >= 2 ** (8 * self.prefixLength): | |
345 raise StringTooLongError( | |
346 "Try to send %s bytes whereas maximum is %s" % ( | |
347 len(data), 2 ** (8 * self.prefixLength))) | |
348 self.transport.write(struct.pack(self.structFormat, len(data)) + data) | |
349 | |
350 | |
351 class Int32StringReceiver(IntNStringReceiver): | |
352 """ | |
353 A receiver for int32-prefixed strings. | |
354 | |
355 An int32 string is a string prefixed by 4 bytes, the 32-bit length of | |
356 the string encoded in network byte order. | |
357 | |
358 This class publishes the same interface as NetstringReceiver. | |
359 """ | |
360 structFormat = "!I" | |
361 prefixLength = struct.calcsize(structFormat) | |
362 | |
363 | |
364 class Int16StringReceiver(IntNStringReceiver): | |
365 """ | |
366 A receiver for int16-prefixed strings. | |
367 | |
368 An int16 string is a string prefixed by 2 bytes, the 16-bit length of | |
369 the string encoded in network byte order. | |
370 | |
371 This class publishes the same interface as NetstringReceiver. | |
372 """ | |
373 structFormat = "!H" | |
374 prefixLength = struct.calcsize(structFormat) | |
375 | |
376 | |
377 class Int8StringReceiver(IntNStringReceiver): | |
378 """ | |
379 A receiver for int8-prefixed strings. | |
380 | |
381 An int8 string is a string prefixed by 1 byte, the 8-bit length of | |
382 the string. | |
383 | |
384 This class publishes the same interface as NetstringReceiver. | |
385 """ | |
386 structFormat = "!B" | |
387 prefixLength = struct.calcsize(structFormat) | |
388 | |
389 | |
390 class StatefulStringProtocol: | |
391 """ | |
392 A stateful string protocol. | |
393 | |
394 This is a mixin for string protocols (Int32StringReceiver, | |
395 NetstringReceiver) which translates stringReceived into a callback | |
396 (prefixed with 'proto_') depending on state. | |
397 | |
398 The state 'done' is special; if a proto_* method returns it, the | |
399 connection will be closed immediately. | |
400 """ | |
401 | |
402 state = 'init' | |
403 | |
404 def stringReceived(self,string): | |
405 """Choose a protocol phase function and call it. | |
406 | |
407 Call back to the appropriate protocol phase; this begins with | |
408 the function proto_init and moves on to proto_* depending on | |
409 what each proto_* function returns. (For example, if | |
410 self.proto_init returns 'foo', then self.proto_foo will be the | |
411 next function called when a protocol message is received. | |
412 """ | |
413 try: | |
414 pto = 'proto_'+self.state | |
415 statehandler = getattr(self,pto) | |
416 except AttributeError: | |
417 log.msg('callback',self.state,'not found') | |
418 else: | |
419 self.state = statehandler(string) | |
420 if self.state == 'done': | |
421 self.transport.loseConnection() | |
422 | |
423 class FileSender: | |
424 """A producer that sends the contents of a file to a consumer. | |
425 | |
426 This is a helper for protocols that, at some point, will take a | |
427 file-like object, read its contents, and write them out to the network, | |
428 optionally performing some transformation on the bytes in between. | |
429 """ | |
430 implements(interfaces.IProducer) | |
431 | |
432 CHUNK_SIZE = 2 ** 14 | |
433 | |
434 lastSent = '' | |
435 deferred = None | |
436 | |
437 def beginFileTransfer(self, file, consumer, transform = None): | |
438 """Begin transferring a file | |
439 | |
440 @type file: Any file-like object | |
441 @param file: The file object to read data from | |
442 | |
443 @type consumer: Any implementor of IConsumer | |
444 @param consumer: The object to write data to | |
445 | |
446 @param transform: A callable taking one string argument and returning | |
447 the same. All bytes read from the file are passed through this before | |
448 being written to the consumer. | |
449 | |
450 @rtype: C{Deferred} | |
451 @return: A deferred whose callback will be invoked when the file has bee
n | |
452 completely written to the consumer. The last byte written to the consum
er | |
453 is passed to the callback. | |
454 """ | |
455 self.file = file | |
456 self.consumer = consumer | |
457 self.transform = transform | |
458 | |
459 self.deferred = deferred = defer.Deferred() | |
460 self.consumer.registerProducer(self, False) | |
461 return deferred | |
462 | |
463 def resumeProducing(self): | |
464 chunk = '' | |
465 if self.file: | |
466 chunk = self.file.read(self.CHUNK_SIZE) | |
467 if not chunk: | |
468 self.file = None | |
469 self.consumer.unregisterProducer() | |
470 if self.deferred: | |
471 self.deferred.callback(self.lastSent) | |
472 self.deferred = None | |
473 return | |
474 | |
475 if self.transform: | |
476 chunk = self.transform(chunk) | |
477 self.consumer.write(chunk) | |
478 self.lastSent = chunk[-1] | |
479 | |
480 def pauseProducing(self): | |
481 pass | |
482 | |
483 def stopProducing(self): | |
484 if self.deferred: | |
485 self.deferred.errback(Exception("Consumer asked us to stop producing
")) | |
486 self.deferred = None | |
OLD | NEW |