OLD | NEW |
| (Empty) |
1 # -*- test-case-name: twisted.test.test_policies -*- | |
2 # Copyright (c) 2001-2007 Twisted Matrix Laboratories. | |
3 # See LICENSE for details. | |
4 | |
5 | |
6 """ | |
7 Resource limiting policies. | |
8 | |
9 @seealso: See also L{twisted.protocols.htb} for rate limiting. | |
10 """ | |
11 | |
12 # system imports | |
13 import sys, operator | |
14 | |
15 # twisted imports | |
16 from twisted.internet.protocol import ServerFactory, Protocol, ClientFactory | |
17 from twisted.internet import reactor, error | |
18 from twisted.python import log | |
19 from zope.interface import providedBy, directlyProvides | |
20 | |
21 | |
22 class ProtocolWrapper(Protocol): | |
23 """Wraps protocol instances and acts as their transport as well.""" | |
24 | |
25 disconnecting = 0 | |
26 | |
27 def __init__(self, factory, wrappedProtocol): | |
28 self.wrappedProtocol = wrappedProtocol | |
29 self.factory = factory | |
30 | |
31 def makeConnection(self, transport): | |
32 directlyProvides(self, *providedBy(self) + providedBy(transport)) | |
33 Protocol.makeConnection(self, transport) | |
34 | |
35 # Transport relaying | |
36 | |
37 def write(self, data): | |
38 self.transport.write(data) | |
39 | |
40 def writeSequence(self, data): | |
41 self.transport.writeSequence(data) | |
42 | |
43 def loseConnection(self): | |
44 self.disconnecting = 1 | |
45 self.transport.loseConnection() | |
46 | |
47 def getPeer(self): | |
48 return self.transport.getPeer() | |
49 | |
50 def getHost(self): | |
51 return self.transport.getHost() | |
52 | |
53 def registerProducer(self, producer, streaming): | |
54 self.transport.registerProducer(producer, streaming) | |
55 | |
56 def unregisterProducer(self): | |
57 self.transport.unregisterProducer() | |
58 | |
59 def stopConsuming(self): | |
60 self.transport.stopConsuming() | |
61 | |
62 def __getattr__(self, name): | |
63 return getattr(self.transport, name) | |
64 | |
65 # Protocol relaying | |
66 | |
67 def connectionMade(self): | |
68 self.factory.registerProtocol(self) | |
69 self.wrappedProtocol.makeConnection(self) | |
70 | |
71 def dataReceived(self, data): | |
72 self.wrappedProtocol.dataReceived(data) | |
73 | |
74 def connectionLost(self, reason): | |
75 self.factory.unregisterProtocol(self) | |
76 self.wrappedProtocol.connectionLost(reason) | |
77 | |
78 | |
79 class WrappingFactory(ClientFactory): | |
80 """Wraps a factory and its protocols, and keeps track of them.""" | |
81 | |
82 protocol = ProtocolWrapper | |
83 | |
84 def __init__(self, wrappedFactory): | |
85 self.wrappedFactory = wrappedFactory | |
86 self.protocols = {} | |
87 | |
88 def doStart(self): | |
89 self.wrappedFactory.doStart() | |
90 ClientFactory.doStart(self) | |
91 | |
92 def doStop(self): | |
93 self.wrappedFactory.doStop() | |
94 ClientFactory.doStop(self) | |
95 | |
96 def startedConnecting(self, connector): | |
97 self.wrappedFactory.startedConnecting(connector) | |
98 | |
99 def clientConnectionFailed(self, connector, reason): | |
100 self.wrappedFactory.clientConnectionFailed(connector, reason) | |
101 | |
102 def clientConnectionLost(self, connector, reason): | |
103 self.wrappedFactory.clientConnectionLost(connector, reason) | |
104 | |
105 def buildProtocol(self, addr): | |
106 return self.protocol(self, self.wrappedFactory.buildProtocol(addr)) | |
107 | |
108 def registerProtocol(self, p): | |
109 """Called by protocol to register itself.""" | |
110 self.protocols[p] = 1 | |
111 | |
112 def unregisterProtocol(self, p): | |
113 """Called by protocols when they go away.""" | |
114 del self.protocols[p] | |
115 | |
116 | |
117 class ThrottlingProtocol(ProtocolWrapper): | |
118 """Protocol for ThrottlingFactory.""" | |
119 | |
120 # wrap API for tracking bandwidth | |
121 | |
122 def write(self, data): | |
123 self.factory.registerWritten(len(data)) | |
124 ProtocolWrapper.write(self, data) | |
125 | |
126 def writeSequence(self, seq): | |
127 self.factory.registerWritten(reduce(operator.add, map(len, seq))) | |
128 ProtocolWrapper.writeSequence(self, seq) | |
129 | |
130 def dataReceived(self, data): | |
131 self.factory.registerRead(len(data)) | |
132 ProtocolWrapper.dataReceived(self, data) | |
133 | |
134 def registerProducer(self, producer, streaming): | |
135 self.producer = producer | |
136 ProtocolWrapper.registerProducer(self, producer, streaming) | |
137 | |
138 def unregisterProducer(self): | |
139 del self.producer | |
140 ProtocolWrapper.unregisterProducer(self) | |
141 | |
142 | |
143 def throttleReads(self): | |
144 self.transport.pauseProducing() | |
145 | |
146 def unthrottleReads(self): | |
147 self.transport.resumeProducing() | |
148 | |
149 def throttleWrites(self): | |
150 if hasattr(self, "producer"): | |
151 self.producer.pauseProducing() | |
152 | |
153 def unthrottleWrites(self): | |
154 if hasattr(self, "producer"): | |
155 self.producer.resumeProducing() | |
156 | |
157 | |
158 class ThrottlingFactory(WrappingFactory): | |
159 """ | |
160 Throttles bandwidth and number of connections. | |
161 | |
162 Write bandwidth will only be throttled if there is a producer | |
163 registered. | |
164 """ | |
165 | |
166 protocol = ThrottlingProtocol | |
167 | |
168 def __init__(self, wrappedFactory, maxConnectionCount=sys.maxint, | |
169 readLimit=None, writeLimit=None): | |
170 WrappingFactory.__init__(self, wrappedFactory) | |
171 self.connectionCount = 0 | |
172 self.maxConnectionCount = maxConnectionCount | |
173 self.readLimit = readLimit # max bytes we should read per second | |
174 self.writeLimit = writeLimit # max bytes we should write per second | |
175 self.readThisSecond = 0 | |
176 self.writtenThisSecond = 0 | |
177 self.unthrottleReadsID = None | |
178 self.checkReadBandwidthID = None | |
179 self.unthrottleWritesID = None | |
180 self.checkWriteBandwidthID = None | |
181 | |
182 | |
183 def callLater(self, period, func): | |
184 """ | |
185 Wrapper around L{reactor.callLater} for test purpose. | |
186 """ | |
187 return reactor.callLater(period, func) | |
188 | |
189 | |
190 def registerWritten(self, length): | |
191 """ | |
192 Called by protocol to tell us more bytes were written. | |
193 """ | |
194 self.writtenThisSecond += length | |
195 | |
196 | |
197 def registerRead(self, length): | |
198 """ | |
199 Called by protocol to tell us more bytes were read. | |
200 """ | |
201 self.readThisSecond += length | |
202 | |
203 | |
204 def checkReadBandwidth(self): | |
205 """ | |
206 Checks if we've passed bandwidth limits. | |
207 """ | |
208 if self.readThisSecond > self.readLimit: | |
209 self.throttleReads() | |
210 throttleTime = (float(self.readThisSecond) / self.readLimit) - 1.0 | |
211 self.unthrottleReadsID = self.callLater(throttleTime, | |
212 self.unthrottleReads) | |
213 self.readThisSecond = 0 | |
214 self.checkReadBandwidthID = self.callLater(1, self.checkReadBandwidth) | |
215 | |
216 | |
217 def checkWriteBandwidth(self): | |
218 if self.writtenThisSecond > self.writeLimit: | |
219 self.throttleWrites() | |
220 throttleTime = (float(self.writtenThisSecond) / self.writeLimit) - 1
.0 | |
221 self.unthrottleWritesID = self.callLater(throttleTime, | |
222 self.unthrottleWrites) | |
223 # reset for next round | |
224 self.writtenThisSecond = 0 | |
225 self.checkWriteBandwidthID = self.callLater(1, self.checkWriteBandwidth) | |
226 | |
227 | |
228 def throttleReads(self): | |
229 """ | |
230 Throttle reads on all protocols. | |
231 """ | |
232 log.msg("Throttling reads on %s" % self) | |
233 for p in self.protocols.keys(): | |
234 p.throttleReads() | |
235 | |
236 | |
237 def unthrottleReads(self): | |
238 """ | |
239 Stop throttling reads on all protocols. | |
240 """ | |
241 self.unthrottleReadsID = None | |
242 log.msg("Stopped throttling reads on %s" % self) | |
243 for p in self.protocols.keys(): | |
244 p.unthrottleReads() | |
245 | |
246 | |
247 def throttleWrites(self): | |
248 """ | |
249 Throttle writes on all protocols. | |
250 """ | |
251 log.msg("Throttling writes on %s" % self) | |
252 for p in self.protocols.keys(): | |
253 p.throttleWrites() | |
254 | |
255 | |
256 def unthrottleWrites(self): | |
257 """ | |
258 Stop throttling writes on all protocols. | |
259 """ | |
260 self.unthrottleWritesID = None | |
261 log.msg("Stopped throttling writes on %s" % self) | |
262 for p in self.protocols.keys(): | |
263 p.unthrottleWrites() | |
264 | |
265 | |
266 def buildProtocol(self, addr): | |
267 if self.connectionCount == 0: | |
268 if self.readLimit is not None: | |
269 self.checkReadBandwidth() | |
270 if self.writeLimit is not None: | |
271 self.checkWriteBandwidth() | |
272 | |
273 if self.connectionCount < self.maxConnectionCount: | |
274 self.connectionCount += 1 | |
275 return WrappingFactory.buildProtocol(self, addr) | |
276 else: | |
277 log.msg("Max connection count reached!") | |
278 return None | |
279 | |
280 | |
281 def unregisterProtocol(self, p): | |
282 WrappingFactory.unregisterProtocol(self, p) | |
283 self.connectionCount -= 1 | |
284 if self.connectionCount == 0: | |
285 if self.unthrottleReadsID is not None: | |
286 self.unthrottleReadsID.cancel() | |
287 if self.checkReadBandwidthID is not None: | |
288 self.checkReadBandwidthID.cancel() | |
289 if self.unthrottleWritesID is not None: | |
290 self.unthrottleWritesID.cancel() | |
291 if self.checkWriteBandwidthID is not None: | |
292 self.checkWriteBandwidthID.cancel() | |
293 | |
294 | |
295 | |
296 class SpewingProtocol(ProtocolWrapper): | |
297 def dataReceived(self, data): | |
298 log.msg("Received: %r" % data) | |
299 ProtocolWrapper.dataReceived(self,data) | |
300 | |
301 def write(self, data): | |
302 log.msg("Sending: %r" % data) | |
303 ProtocolWrapper.write(self,data) | |
304 | |
305 | |
306 | |
307 class SpewingFactory(WrappingFactory): | |
308 protocol = SpewingProtocol | |
309 | |
310 | |
311 | |
312 class LimitConnectionsByPeer(WrappingFactory): | |
313 | |
314 maxConnectionsPerPeer = 5 | |
315 | |
316 def startFactory(self): | |
317 self.peerConnections = {} | |
318 | |
319 def buildProtocol(self, addr): | |
320 peerHost = addr[0] | |
321 connectionCount = self.peerConnections.get(peerHost, 0) | |
322 if connectionCount >= self.maxConnectionsPerPeer: | |
323 return None | |
324 self.peerConnections[peerHost] = connectionCount + 1 | |
325 return WrappingFactory.buildProtocol(self, addr) | |
326 | |
327 def unregisterProtocol(self, p): | |
328 peerHost = p.getPeer()[1] | |
329 self.peerConnections[peerHost] -= 1 | |
330 if self.peerConnections[peerHost] == 0: | |
331 del self.peerConnections[peerHost] | |
332 | |
333 | |
334 class LimitTotalConnectionsFactory(ServerFactory): | |
335 """ | |
336 Factory that limits the number of simultaneous connections. | |
337 | |
338 @type connectionCount: C{int} | |
339 @ivar connectionCount: number of current connections. | |
340 @type connectionLimit: C{int} or C{None} | |
341 @cvar connectionLimit: maximum number of connections. | |
342 @type overflowProtocol: L{Protocol} or C{None} | |
343 @cvar overflowProtocol: Protocol to use for new connections when | |
344 connectionLimit is exceeded. If C{None} (the default value), excess | |
345 connections will be closed immediately. | |
346 """ | |
347 connectionCount = 0 | |
348 connectionLimit = None | |
349 overflowProtocol = None | |
350 | |
351 def buildProtocol(self, addr): | |
352 if (self.connectionLimit is None or | |
353 self.connectionCount < self.connectionLimit): | |
354 # Build the normal protocol | |
355 wrappedProtocol = self.protocol() | |
356 elif self.overflowProtocol is None: | |
357 # Just drop the connection | |
358 return None | |
359 else: | |
360 # Too many connections, so build the overflow protocol | |
361 wrappedProtocol = self.overflowProtocol() | |
362 | |
363 wrappedProtocol.factory = self | |
364 protocol = ProtocolWrapper(self, wrappedProtocol) | |
365 self.connectionCount += 1 | |
366 return protocol | |
367 | |
368 def registerProtocol(self, p): | |
369 pass | |
370 | |
371 def unregisterProtocol(self, p): | |
372 self.connectionCount -= 1 | |
373 | |
374 | |
375 | |
376 class TimeoutProtocol(ProtocolWrapper): | |
377 """ | |
378 Protocol that automatically disconnects when the connection is idle. | |
379 """ | |
380 | |
381 def __init__(self, factory, wrappedProtocol, timeoutPeriod): | |
382 """ | |
383 Constructor. | |
384 | |
385 @param factory: An L{IFactory}. | |
386 @param wrappedProtocol: A L{Protocol} to wrapp. | |
387 @param timeoutPeriod: Number of seconds to wait for activity before | |
388 timing out. | |
389 """ | |
390 ProtocolWrapper.__init__(self, factory, wrappedProtocol) | |
391 self.timeoutCall = None | |
392 self.setTimeout(timeoutPeriod) | |
393 | |
394 | |
395 def setTimeout(self, timeoutPeriod=None): | |
396 """ | |
397 Set a timeout. | |
398 | |
399 This will cancel any existing timeouts. | |
400 | |
401 @param timeoutPeriod: If not C{None}, change the timeout period. | |
402 Otherwise, use the existing value. | |
403 """ | |
404 self.cancelTimeout() | |
405 if timeoutPeriod is not None: | |
406 self.timeoutPeriod = timeoutPeriod | |
407 self.timeoutCall = self.factory.callLater(self.timeoutPeriod, self.timeo
utFunc) | |
408 | |
409 | |
410 def cancelTimeout(self): | |
411 """ | |
412 Cancel the timeout. | |
413 | |
414 If the timeout was already cancelled, this does nothing. | |
415 """ | |
416 if self.timeoutCall: | |
417 try: | |
418 self.timeoutCall.cancel() | |
419 except error.AlreadyCalled: | |
420 pass | |
421 self.timeoutCall = None | |
422 | |
423 | |
424 def resetTimeout(self): | |
425 """ | |
426 Reset the timeout, usually because some activity just happened. | |
427 """ | |
428 if self.timeoutCall: | |
429 self.timeoutCall.reset(self.timeoutPeriod) | |
430 | |
431 | |
432 def write(self, data): | |
433 self.resetTimeout() | |
434 ProtocolWrapper.write(self, data) | |
435 | |
436 | |
437 def writeSequence(self, seq): | |
438 self.resetTimeout() | |
439 ProtocolWrapper.writeSequence(self, seq) | |
440 | |
441 | |
442 def dataReceived(self, data): | |
443 self.resetTimeout() | |
444 ProtocolWrapper.dataReceived(self, data) | |
445 | |
446 | |
447 def connectionLost(self, reason): | |
448 self.cancelTimeout() | |
449 ProtocolWrapper.connectionLost(self, reason) | |
450 | |
451 | |
452 def timeoutFunc(self): | |
453 """ | |
454 This method is called when the timeout is triggered. | |
455 | |
456 By default it calls L{loseConnection}. Override this if you want | |
457 something else to happen. | |
458 """ | |
459 self.loseConnection() | |
460 | |
461 | |
462 | |
463 class TimeoutFactory(WrappingFactory): | |
464 """ | |
465 Factory for TimeoutWrapper. | |
466 """ | |
467 protocol = TimeoutProtocol | |
468 | |
469 | |
470 def __init__(self, wrappedFactory, timeoutPeriod=30*60): | |
471 self.timeoutPeriod = timeoutPeriod | |
472 WrappingFactory.__init__(self, wrappedFactory) | |
473 | |
474 | |
475 def buildProtocol(self, addr): | |
476 return self.protocol(self, self.wrappedFactory.buildProtocol(addr), | |
477 timeoutPeriod=self.timeoutPeriod) | |
478 | |
479 | |
480 def callLater(self, period, func): | |
481 """ | |
482 Wrapper around L{reactor.callLater} for test purpose. | |
483 """ | |
484 return reactor.callLater(period, func) | |
485 | |
486 | |
487 | |
488 class TrafficLoggingProtocol(ProtocolWrapper): | |
489 | |
490 def __init__(self, factory, wrappedProtocol, logfile, lengthLimit=None, | |
491 number=0): | |
492 """ | |
493 @param factory: factory which created this protocol. | |
494 @type factory: C{protocol.Factory}. | |
495 @param wrappedProtocol: the underlying protocol. | |
496 @type wrappedProtocol: C{protocol.Protocol}. | |
497 @param logfile: file opened for writing used to write log messages. | |
498 @type logfile: C{file} | |
499 @param lengthLimit: maximum size of the datareceived logged. | |
500 @type lengthLimit: C{int} | |
501 @param number: identifier of the connection. | |
502 @type number: C{int}. | |
503 """ | |
504 ProtocolWrapper.__init__(self, factory, wrappedProtocol) | |
505 self.logfile = logfile | |
506 self.lengthLimit = lengthLimit | |
507 self._number = number | |
508 | |
509 | |
510 def _log(self, line): | |
511 self.logfile.write(line + '\n') | |
512 self.logfile.flush() | |
513 | |
514 | |
515 def _mungeData(self, data): | |
516 if self.lengthLimit and len(data) > self.lengthLimit: | |
517 data = data[:self.lengthLimit - 12] + '<... elided>' | |
518 return data | |
519 | |
520 | |
521 # IProtocol | |
522 def connectionMade(self): | |
523 self._log('*') | |
524 return ProtocolWrapper.connectionMade(self) | |
525 | |
526 | |
527 def dataReceived(self, data): | |
528 self._log('C %d: %r' % (self._number, self._mungeData(data))) | |
529 return ProtocolWrapper.dataReceived(self, data) | |
530 | |
531 | |
532 def connectionLost(self, reason): | |
533 self._log('C %d: %r' % (self._number, reason)) | |
534 return ProtocolWrapper.connectionLost(self, reason) | |
535 | |
536 | |
537 # ITransport | |
538 def write(self, data): | |
539 self._log('S %d: %r' % (self._number, self._mungeData(data))) | |
540 return ProtocolWrapper.write(self, data) | |
541 | |
542 | |
543 def writeSequence(self, iovec): | |
544 self._log('SV %d: %r' % (self._number, [self._mungeData(d) for d in iove
c])) | |
545 return ProtocolWrapper.writeSequence(self, iovec) | |
546 | |
547 | |
548 def loseConnection(self): | |
549 self._log('S %d: *' % (self._number,)) | |
550 return ProtocolWrapper.loseConnection(self) | |
551 | |
552 | |
553 | |
554 class TrafficLoggingFactory(WrappingFactory): | |
555 protocol = TrafficLoggingProtocol | |
556 | |
557 _counter = 0 | |
558 | |
559 def __init__(self, wrappedFactory, logfilePrefix, lengthLimit=None): | |
560 self.logfilePrefix = logfilePrefix | |
561 self.lengthLimit = lengthLimit | |
562 WrappingFactory.__init__(self, wrappedFactory) | |
563 | |
564 | |
565 def open(self, name): | |
566 return file(name, 'w') | |
567 | |
568 | |
569 def buildProtocol(self, addr): | |
570 self._counter += 1 | |
571 logfile = self.open(self.logfilePrefix + '-' + str(self._counter)) | |
572 return self.protocol(self, self.wrappedFactory.buildProtocol(addr), | |
573 logfile, self.lengthLimit, self._counter) | |
574 | |
575 | |
576 def resetCounter(self): | |
577 """ | |
578 Reset the value of the counter used to identify connections. | |
579 """ | |
580 self._counter = 0 | |
581 | |
582 | |
583 | |
584 class TimeoutMixin: | |
585 """Mixin for protocols which wish to timeout connections | |
586 | |
587 @cvar timeOut: The number of seconds after which to timeout the connection. | |
588 """ | |
589 timeOut = None | |
590 | |
591 __timeoutCall = None | |
592 | |
593 def callLater(self, period, func): | |
594 return reactor.callLater(period, func) | |
595 | |
596 | |
597 def resetTimeout(self): | |
598 """Reset the timeout count down""" | |
599 if self.__timeoutCall is not None and self.timeOut is not None: | |
600 self.__timeoutCall.reset(self.timeOut) | |
601 | |
602 def setTimeout(self, period): | |
603 """Change the timeout period | |
604 | |
605 @type period: C{int} or C{NoneType} | |
606 @param period: The period, in seconds, to change the timeout to, or | |
607 C{None} to disable the timeout. | |
608 """ | |
609 prev = self.timeOut | |
610 self.timeOut = period | |
611 | |
612 if self.__timeoutCall is not None: | |
613 if period is None: | |
614 self.__timeoutCall.cancel() | |
615 self.__timeoutCall = None | |
616 else: | |
617 self.__timeoutCall.reset(period) | |
618 elif period is not None: | |
619 self.__timeoutCall = self.callLater(period, self.__timedOut) | |
620 | |
621 return prev | |
622 | |
623 def __timedOut(self): | |
624 self.__timeoutCall = None | |
625 self.timeoutConnection() | |
626 | |
627 def timeoutConnection(self): | |
628 """Called when the connection times out. | |
629 Override to define behavior other than dropping the connection. | |
630 """ | |
631 self.transport.loseConnection() | |
OLD | NEW |