OLD | NEW |
| (Empty) |
1 # Copyright (c) 2008 Twisted Matrix Laboratories. | |
2 # See LICENSE for details. | |
3 | |
4 | |
5 """ | |
6 TCP support for IOCP reactor | |
7 """ | |
8 | |
9 from twisted.internet import interfaces, error, address, main, defer | |
10 from twisted.internet.abstract import isIPAddress | |
11 from twisted.internet.tcp import _SocketCloser, Connector as TCPConnector | |
12 from twisted.persisted import styles | |
13 from twisted.python import log, failure, reflect, util | |
14 | |
15 from zope.interface import implements | |
16 import socket, operator, errno, struct | |
17 | |
18 from twisted.internet.iocpreactor import iocpsupport as _iocp, abstract | |
19 from twisted.internet.iocpreactor.interfaces import IReadWriteHandle | |
20 from twisted.internet.iocpreactor.const import ERROR_IO_PENDING | |
21 from twisted.internet.iocpreactor.const import SO_UPDATE_CONNECT_CONTEXT | |
22 from twisted.internet.iocpreactor.const import SO_UPDATE_ACCEPT_CONTEXT | |
23 from twisted.internet.iocpreactor.const import ERROR_CONNECTION_REFUSED | |
24 from twisted.internet.iocpreactor.const import ERROR_NETWORK_UNREACHABLE | |
25 | |
26 # ConnectEx returns these. XXX: find out what it does for timeout | |
27 connectExErrors = { | |
28 ERROR_CONNECTION_REFUSED: errno.WSAECONNREFUSED, | |
29 ERROR_NETWORK_UNREACHABLE: errno.WSAENETUNREACH, | |
30 } | |
31 | |
32 | |
33 | |
34 class Connection(abstract.FileHandle, _SocketCloser): | |
35 implements(IReadWriteHandle, interfaces.ITCPTransport, | |
36 interfaces.ISystemHandle) | |
37 | |
38 | |
39 def __init__(self, sock, proto, reactor=None): | |
40 abstract.FileHandle.__init__(self, reactor) | |
41 self.socket = sock | |
42 self.getFileHandle = sock.fileno | |
43 self.protocol = proto | |
44 | |
45 | |
46 def getHandle(self): | |
47 return self.socket | |
48 | |
49 | |
50 def dataReceived(self, rbuffer): | |
51 # XXX: some day, we'll have protocols that can handle raw buffers | |
52 self.protocol.dataReceived(str(rbuffer)) | |
53 | |
54 | |
55 def readFromHandle(self, bufflist, evt): | |
56 return _iocp.recv(self.getFileHandle(), bufflist, evt) | |
57 | |
58 | |
59 def writeToHandle(self, buff, evt): | |
60 return _iocp.send(self.getFileHandle(), buff, evt) | |
61 | |
62 | |
63 def _closeWriteConnection(self): | |
64 try: | |
65 getattr(self.socket, self._socketShutdownMethod)(1) | |
66 except socket.error: | |
67 pass | |
68 p = interfaces.IHalfCloseableProtocol(self.protocol, None) | |
69 if p: | |
70 try: | |
71 p.writeConnectionLost() | |
72 except: | |
73 f = failure.Failure() | |
74 log.err() | |
75 self.connectionLost(f) | |
76 | |
77 | |
78 def readConnectionLost(self, reason): | |
79 p = interfaces.IHalfCloseableProtocol(self.protocol, None) | |
80 if p: | |
81 try: | |
82 p.readConnectionLost() | |
83 except: | |
84 log.err() | |
85 self.connectionLost(failure.Failure()) | |
86 else: | |
87 self.connectionLost(reason) | |
88 | |
89 | |
90 def connectionLost(self, reason): | |
91 abstract.FileHandle.connectionLost(self, reason) | |
92 self._closeSocket() | |
93 protocol = self.protocol | |
94 del self.protocol | |
95 del self.socket | |
96 del self.getFileHandle | |
97 protocol.connectionLost(reason) | |
98 | |
99 | |
100 def logPrefix(self): | |
101 """ | |
102 Return the prefix to log with when I own the logging thread. | |
103 """ | |
104 return self.logstr | |
105 | |
106 | |
107 def getTcpNoDelay(self): | |
108 return operator.truth(self.socket.getsockopt(socket.IPPROTO_TCP, | |
109 socket.TCP_NODELAY)) | |
110 | |
111 | |
112 def setTcpNoDelay(self, enabled): | |
113 self.socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, enabled) | |
114 | |
115 | |
116 def getTcpKeepAlive(self): | |
117 return operator.truth(self.socket.getsockopt(socket.SOL_SOCKET, | |
118 socket.SO_KEEPALIVE)) | |
119 | |
120 | |
121 def setTcpKeepAlive(self, enabled): | |
122 self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, enabled) | |
123 | |
124 | |
125 | |
126 class Client(Connection): | |
127 addressFamily = socket.AF_INET | |
128 socketType = socket.SOCK_STREAM | |
129 | |
130 | |
131 def __init__(self, host, port, bindAddress, connector, reactor): | |
132 self.connector = connector | |
133 self.addr = (host, port) | |
134 self.reactor = reactor | |
135 # ConnectEx documentation says socket _has_ to be bound | |
136 if bindAddress is None: | |
137 bindAddress = ('', 0) | |
138 | |
139 try: | |
140 try: | |
141 skt = reactor.createSocket(self.addressFamily, self.socketType) | |
142 except socket.error, se: | |
143 raise error.ConnectBindError(se[0], se[1]) | |
144 else: | |
145 try: | |
146 skt.bind(bindAddress) | |
147 except socket.error, se: | |
148 raise error.ConnectBindError(se[0], se[1]) | |
149 self.socket = skt | |
150 Connection.__init__(self, skt, None) | |
151 reactor.callLater(0, self.resolveAddress) | |
152 except error.ConnectBindError, err: | |
153 reactor.callLater(0, self.failIfNotConnected, err) | |
154 | |
155 | |
156 def resolveAddress(self): | |
157 if isIPAddress(self.addr[0]): | |
158 self._setRealAddress(self.addr[0]) | |
159 else: | |
160 d = self.reactor.resolve(self.addr[0]) | |
161 d.addCallbacks(self._setRealAddress, self.failIfNotConnected) | |
162 | |
163 | |
164 def _setRealAddress(self, address): | |
165 self.realAddress = (address, self.addr[1]) | |
166 self.doConnect() | |
167 | |
168 | |
169 def failIfNotConnected(self, err): | |
170 if (self.connected or self.disconnected or | |
171 not hasattr(self, "connector")): | |
172 return | |
173 | |
174 try: | |
175 self._closeSocket() | |
176 except AttributeError: | |
177 pass | |
178 else: | |
179 del self.socket, self.getFileHandle | |
180 self.reactor.removeActiveHandle(self) | |
181 | |
182 self.connector.connectionFailed(failure.Failure(err)) | |
183 del self.connector | |
184 | |
185 | |
186 def stopConnecting(self): | |
187 """ | |
188 Stop attempt to connect. | |
189 """ | |
190 self.failIfNotConnected(error.UserError()) | |
191 | |
192 | |
193 def cbConnect(self, rc, bytes, evt): | |
194 if rc: | |
195 rc = connectExErrors.get(rc, rc) | |
196 self.failIfNotConnected(error.getConnectError((rc, | |
197 errno.errorcode.get(rc, 'Unknown error')))) | |
198 else: | |
199 self.socket.setsockopt(socket.SOL_SOCKET, | |
200 SO_UPDATE_CONNECT_CONTEXT, | |
201 struct.pack('I', self.socket.fileno())) | |
202 self.protocol = self.connector.buildProtocol(self.getPeer()) | |
203 self.connected = True | |
204 self.logstr = self.protocol.__class__.__name__+",client" | |
205 self.protocol.makeConnection(self) | |
206 self.startReading() | |
207 | |
208 | |
209 def doConnect(self): | |
210 if not hasattr(self, "connector"): | |
211 # this happens if we connector.stopConnecting in | |
212 # factory.startedConnecting | |
213 return | |
214 assert _iocp.have_connectex | |
215 self.reactor.addActiveHandle(self) | |
216 evt = _iocp.Event(self.cbConnect, self) | |
217 | |
218 rc = _iocp.connect(self.socket.fileno(), self.realAddress, evt) | |
219 if rc == ERROR_IO_PENDING: | |
220 return | |
221 else: | |
222 evt.ignore = True | |
223 self.cbConnect(rc, 0, 0, evt) | |
224 | |
225 | |
226 def getHost(self): | |
227 """ | |
228 Returns an IPv4Address. | |
229 | |
230 This indicates the address from which I am connecting. | |
231 """ | |
232 return address.IPv4Address('TCP', *(self.socket.getsockname() + | |
233 ('INET',))) | |
234 | |
235 | |
236 def getPeer(self): | |
237 """ | |
238 Returns an IPv4Address. | |
239 | |
240 This indicates the address that I am connected to. | |
241 """ | |
242 return address.IPv4Address('TCP', *(self.addr + ('INET',))) | |
243 | |
244 | |
245 def __repr__(self): | |
246 s = ('<%s to %s at %x>' % | |
247 (self.__class__, self.addr, util.unsignedID(self))) | |
248 return s | |
249 | |
250 | |
251 def connectionLost(self, reason): | |
252 if not self.connected: | |
253 self.failIfNotConnected(error.ConnectError(string=reason)) | |
254 else: | |
255 Connection.connectionLost(self, reason) | |
256 self.connector.connectionLost(reason) | |
257 | |
258 | |
259 | |
260 class Server(Connection): | |
261 """ | |
262 Serverside socket-stream connection class. | |
263 | |
264 I am a serverside network connection transport; a socket which came from an | |
265 accept() on a server. | |
266 """ | |
267 | |
268 | |
269 def __init__(self, sock, protocol, clientAddr, serverAddr, sessionno): | |
270 """ | |
271 Server(sock, protocol, client, server, sessionno) | |
272 | |
273 Initialize me with a socket, a protocol, a descriptor for my peer (a | |
274 tuple of host, port describing the other end of the connection), an | |
275 instance of Port, and a session number. | |
276 """ | |
277 Connection.__init__(self, sock, protocol) | |
278 self.serverAddr = serverAddr | |
279 self.clientAddr = clientAddr | |
280 self.sessionno = sessionno | |
281 self.logstr = "%s,%s,%s" % (self.protocol.__class__.__name__, | |
282 sessionno, self.clientAddr.host) | |
283 self.repstr = "<%s #%s on %s>" % (self.protocol.__class__.__name__, | |
284 self.sessionno, self.serverAddr.port) | |
285 self.connected = True | |
286 self.startReading() | |
287 | |
288 | |
289 def __repr__(self): | |
290 """ | |
291 A string representation of this connection. | |
292 """ | |
293 return self.repstr | |
294 | |
295 | |
296 def getHost(self): | |
297 """ | |
298 Returns an IPv4Address. | |
299 | |
300 This indicates the server's address. | |
301 """ | |
302 return self.serverAddr | |
303 | |
304 | |
305 def getPeer(self): | |
306 """ | |
307 Returns an IPv4Address. | |
308 | |
309 This indicates the client's address. | |
310 """ | |
311 return self.clientAddr | |
312 | |
313 | |
314 | |
315 class Connector(TCPConnector): | |
316 def _makeTransport(self): | |
317 return Client(self.host, self.port, self.bindAddress, self, | |
318 self.reactor) | |
319 | |
320 | |
321 | |
322 class Port(styles.Ephemeral, _SocketCloser): | |
323 implements(interfaces.IListeningPort) | |
324 | |
325 connected = False | |
326 disconnected = False | |
327 disconnecting = False | |
328 addressFamily = socket.AF_INET | |
329 socketType = socket.SOCK_STREAM | |
330 | |
331 sessionno = 0 | |
332 | |
333 maxAccepts = 100 | |
334 | |
335 # Actual port number being listened on, only set to a non-None | |
336 # value when we are actually listening. | |
337 _realPortNumber = None | |
338 | |
339 | |
340 def __init__(self, port, factory, backlog=50, interface='', reactor=None): | |
341 self.port = port | |
342 self.factory = factory | |
343 self.backlog = backlog | |
344 self.interface = interface | |
345 self.reactor = reactor | |
346 | |
347 skt = socket.socket(self.addressFamily, self.socketType) | |
348 self.addrLen = _iocp.maxAddrLen(skt.fileno()) | |
349 | |
350 | |
351 def __repr__(self): | |
352 if self._realPortNumber is not None: | |
353 return "<%s of %s on %s>" % (self.__class__, | |
354 self.factory.__class__, | |
355 self._realPortNumber) | |
356 else: | |
357 return "<%s of %s (not listening)>" % (self.__class__, | |
358 self.factory.__class__) | |
359 | |
360 | |
361 def startListening(self): | |
362 try: | |
363 skt = self.reactor.createSocket(self.addressFamily, | |
364 self.socketType) | |
365 # TODO: resolve self.interface if necessary | |
366 skt.bind((self.interface, self.port)) | |
367 except socket.error, le: | |
368 raise error.CannotListenError, (self.interface, self.port, le) | |
369 | |
370 # Make sure that if we listened on port 0, we update that to | |
371 # reflect what the OS actually assigned us. | |
372 self._realPortNumber = skt.getsockname()[1] | |
373 | |
374 log.msg("%s starting on %s" % (self.factory.__class__, | |
375 self._realPortNumber)) | |
376 | |
377 self.factory.doStart() | |
378 skt.listen(self.backlog) | |
379 self.connected = True | |
380 self.reactor.addActiveHandle(self) | |
381 self.socket = skt | |
382 self.getFileHandle = self.socket.fileno | |
383 self.doAccept() | |
384 | |
385 | |
386 def loseConnection(self, connDone=failure.Failure(main.CONNECTION_DONE)): | |
387 """ | |
388 Stop accepting connections on this port. | |
389 | |
390 This will shut down my socket and call self.connectionLost(). | |
391 It returns a deferred which will fire successfully when the | |
392 port is actually closed. | |
393 """ | |
394 self.disconnecting = True | |
395 if self.connected: | |
396 self.deferred = defer.Deferred() | |
397 self.reactor.callLater(0, self.connectionLost, connDone) | |
398 return self.deferred | |
399 | |
400 stopListening = loseConnection | |
401 | |
402 | |
403 def connectionLost(self, reason): | |
404 """ | |
405 Cleans up my socket. | |
406 """ | |
407 log.msg('(Port %s Closed)' % self._realPortNumber) | |
408 self._realPortNumber = None | |
409 self.disconnected = True | |
410 self.reactor.removeActiveHandle(self) | |
411 self.connected = False | |
412 self._closeSocket() | |
413 del self.socket | |
414 del self.getFileHandle | |
415 self.factory.doStop() | |
416 if hasattr(self, "deferred"): | |
417 self.deferred.callback(None) | |
418 del self.deferred | |
419 | |
420 | |
421 def logPrefix(self): | |
422 """ | |
423 Returns the name of my class, to prefix log entries with. | |
424 """ | |
425 return reflect.qual(self.factory.__class__) | |
426 | |
427 | |
428 def getHost(self): | |
429 """ | |
430 Returns an IPv4Address. | |
431 | |
432 This indicates the server's address. | |
433 """ | |
434 return address.IPv4Address('TCP', *(self.socket.getsockname() + | |
435 ('INET',))) | |
436 | |
437 | |
438 def cbAccept(self, rc, bytes, evt): | |
439 self.handleAccept(rc, evt) | |
440 if not (self.disconnecting or self.disconnected): | |
441 self.doAccept() | |
442 | |
443 | |
444 def handleAccept(self, rc, evt): | |
445 if self.disconnecting or self.disconnected: | |
446 return False | |
447 | |
448 # possible errors: | |
449 # (WSAEMFILE, WSAENOBUFS, WSAENFILE, WSAENOMEM, WSAECONNABORTED) | |
450 if rc: | |
451 log.msg("Could not accept new connection -- %s (%s)" % | |
452 (errno.errorcode.get(rc, 'unknown error'), rc)) | |
453 return False | |
454 else: | |
455 evt.newskt.setsockopt(socket.SOL_SOCKET, SO_UPDATE_ACCEPT_CONTEXT, | |
456 struct.pack('I', self.socket.fileno())) | |
457 family, lAddr, rAddr = _iocp.get_accept_addrs(evt.newskt.fileno(), | |
458 evt.buff) | |
459 assert family == self.addressFamily | |
460 | |
461 protocol = self.factory.buildProtocol( | |
462 address._ServerFactoryIPv4Address('TCP', rAddr[0], rAddr[1])) | |
463 if protocol is None: | |
464 evt.newskt.close() | |
465 else: | |
466 s = self.sessionno | |
467 self.sessionno = s+1 | |
468 transport = Server(evt.newskt, protocol, | |
469 address.IPv4Address('TCP', rAddr[0], rAddr[1], 'INET'), | |
470 address.IPv4Address('TCP', lAddr[0], lAddr[1], 'INET'), | |
471 s) | |
472 protocol.makeConnection(transport) | |
473 return True | |
474 | |
475 | |
476 def doAccept(self): | |
477 numAccepts = 0 | |
478 while 1: | |
479 evt = _iocp.Event(self.cbAccept, self) | |
480 | |
481 # see AcceptEx documentation | |
482 evt.buff = buff = _iocp.AllocateReadBuffer(2 * (self.addrLen + 16)) | |
483 | |
484 evt.newskt = newskt = self.reactor.createSocket(self.addressFamily, | |
485 self.socketType) | |
486 rc = _iocp.accept(self.socket.fileno(), newskt.fileno(), buff, evt) | |
487 | |
488 if (rc == ERROR_IO_PENDING | |
489 or (not rc and numAccepts >= self.maxAccepts)): | |
490 break | |
491 else: | |
492 evt.ignore = True | |
493 if not self.handleAccept(rc, evt): | |
494 break | |
495 numAccepts += 1 | |
496 | |
497 | |
OLD | NEW |