| OLD | NEW | 
|---|
|  | (Empty) | 
| 1 # Copyright (c) 2008 Twisted Matrix Laboratories. |  | 
| 2 # See LICENSE for details. |  | 
| 3 |  | 
| 4 |  | 
| 5 """ |  | 
| 6 TCP support for IOCP reactor |  | 
| 7 """ |  | 
| 8 |  | 
| 9 from twisted.internet import interfaces, error, address, main, defer |  | 
| 10 from twisted.internet.abstract import isIPAddress |  | 
| 11 from twisted.internet.tcp import _SocketCloser, Connector as TCPConnector |  | 
| 12 from twisted.persisted import styles |  | 
| 13 from twisted.python import log, failure, reflect, util |  | 
| 14 |  | 
| 15 from zope.interface import implements |  | 
| 16 import socket, operator, errno, struct |  | 
| 17 |  | 
| 18 from twisted.internet.iocpreactor import iocpsupport as _iocp, abstract |  | 
| 19 from twisted.internet.iocpreactor.interfaces import IReadWriteHandle |  | 
| 20 from twisted.internet.iocpreactor.const import ERROR_IO_PENDING |  | 
| 21 from twisted.internet.iocpreactor.const import SO_UPDATE_CONNECT_CONTEXT |  | 
| 22 from twisted.internet.iocpreactor.const import SO_UPDATE_ACCEPT_CONTEXT |  | 
| 23 from twisted.internet.iocpreactor.const import ERROR_CONNECTION_REFUSED |  | 
| 24 from twisted.internet.iocpreactor.const import ERROR_NETWORK_UNREACHABLE |  | 
| 25 |  | 
| 26 # ConnectEx returns these. XXX: find out what it does for timeout |  | 
| 27 connectExErrors = { |  | 
| 28         ERROR_CONNECTION_REFUSED: errno.WSAECONNREFUSED, |  | 
| 29         ERROR_NETWORK_UNREACHABLE: errno.WSAENETUNREACH, |  | 
| 30         } |  | 
| 31 |  | 
| 32 |  | 
| 33 |  | 
| 34 class Connection(abstract.FileHandle, _SocketCloser): |  | 
| 35     implements(IReadWriteHandle, interfaces.ITCPTransport, |  | 
| 36                interfaces.ISystemHandle) |  | 
| 37 |  | 
| 38 |  | 
| 39     def __init__(self, sock, proto, reactor=None): |  | 
| 40         abstract.FileHandle.__init__(self, reactor) |  | 
| 41         self.socket = sock |  | 
| 42         self.getFileHandle = sock.fileno |  | 
| 43         self.protocol = proto |  | 
| 44 |  | 
| 45 |  | 
| 46     def getHandle(self): |  | 
| 47         return self.socket |  | 
| 48 |  | 
| 49 |  | 
| 50     def dataReceived(self, rbuffer): |  | 
| 51         # XXX: some day, we'll have protocols that can handle raw buffers |  | 
| 52         self.protocol.dataReceived(str(rbuffer)) |  | 
| 53 |  | 
| 54 |  | 
| 55     def readFromHandle(self, bufflist, evt): |  | 
| 56         return _iocp.recv(self.getFileHandle(), bufflist, evt) |  | 
| 57 |  | 
| 58 |  | 
| 59     def writeToHandle(self, buff, evt): |  | 
| 60         return _iocp.send(self.getFileHandle(), buff, evt) |  | 
| 61 |  | 
| 62 |  | 
| 63     def _closeWriteConnection(self): |  | 
| 64         try: |  | 
| 65             getattr(self.socket, self._socketShutdownMethod)(1) |  | 
| 66         except socket.error: |  | 
| 67             pass |  | 
| 68         p = interfaces.IHalfCloseableProtocol(self.protocol, None) |  | 
| 69         if p: |  | 
| 70             try: |  | 
| 71                 p.writeConnectionLost() |  | 
| 72             except: |  | 
| 73                 f = failure.Failure() |  | 
| 74                 log.err() |  | 
| 75                 self.connectionLost(f) |  | 
| 76 |  | 
| 77 |  | 
| 78     def readConnectionLost(self, reason): |  | 
| 79         p = interfaces.IHalfCloseableProtocol(self.protocol, None) |  | 
| 80         if p: |  | 
| 81             try: |  | 
| 82                 p.readConnectionLost() |  | 
| 83             except: |  | 
| 84                 log.err() |  | 
| 85                 self.connectionLost(failure.Failure()) |  | 
| 86         else: |  | 
| 87             self.connectionLost(reason) |  | 
| 88 |  | 
| 89 |  | 
| 90     def connectionLost(self, reason): |  | 
| 91         abstract.FileHandle.connectionLost(self, reason) |  | 
| 92         self._closeSocket() |  | 
| 93         protocol = self.protocol |  | 
| 94         del self.protocol |  | 
| 95         del self.socket |  | 
| 96         del self.getFileHandle |  | 
| 97         protocol.connectionLost(reason) |  | 
| 98 |  | 
| 99 |  | 
| 100     def logPrefix(self): |  | 
| 101         """ |  | 
| 102         Return the prefix to log with when I own the logging thread. |  | 
| 103         """ |  | 
| 104         return self.logstr |  | 
| 105 |  | 
| 106 |  | 
| 107     def getTcpNoDelay(self): |  | 
| 108         return operator.truth(self.socket.getsockopt(socket.IPPROTO_TCP, |  | 
| 109                                                      socket.TCP_NODELAY)) |  | 
| 110 |  | 
| 111 |  | 
| 112     def setTcpNoDelay(self, enabled): |  | 
| 113         self.socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, enabled) |  | 
| 114 |  | 
| 115 |  | 
| 116     def getTcpKeepAlive(self): |  | 
| 117         return operator.truth(self.socket.getsockopt(socket.SOL_SOCKET, |  | 
| 118                                                      socket.SO_KEEPALIVE)) |  | 
| 119 |  | 
| 120 |  | 
| 121     def setTcpKeepAlive(self, enabled): |  | 
| 122         self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, enabled) |  | 
| 123 |  | 
| 124 |  | 
| 125 |  | 
| 126 class Client(Connection): |  | 
| 127     addressFamily = socket.AF_INET |  | 
| 128     socketType = socket.SOCK_STREAM |  | 
| 129 |  | 
| 130 |  | 
| 131     def __init__(self, host, port, bindAddress, connector, reactor): |  | 
| 132         self.connector = connector |  | 
| 133         self.addr = (host, port) |  | 
| 134         self.reactor = reactor |  | 
| 135         # ConnectEx documentation says socket _has_ to be bound |  | 
| 136         if bindAddress is None: |  | 
| 137             bindAddress = ('', 0) |  | 
| 138 |  | 
| 139         try: |  | 
| 140             try: |  | 
| 141                 skt = reactor.createSocket(self.addressFamily, self.socketType) |  | 
| 142             except socket.error, se: |  | 
| 143                 raise error.ConnectBindError(se[0], se[1]) |  | 
| 144             else: |  | 
| 145                 try: |  | 
| 146                     skt.bind(bindAddress) |  | 
| 147                 except socket.error, se: |  | 
| 148                     raise error.ConnectBindError(se[0], se[1]) |  | 
| 149                 self.socket = skt |  | 
| 150                 Connection.__init__(self, skt, None) |  | 
| 151                 reactor.callLater(0, self.resolveAddress) |  | 
| 152         except error.ConnectBindError, err: |  | 
| 153             reactor.callLater(0, self.failIfNotConnected, err) |  | 
| 154 |  | 
| 155 |  | 
| 156     def resolveAddress(self): |  | 
| 157         if isIPAddress(self.addr[0]): |  | 
| 158             self._setRealAddress(self.addr[0]) |  | 
| 159         else: |  | 
| 160             d = self.reactor.resolve(self.addr[0]) |  | 
| 161             d.addCallbacks(self._setRealAddress, self.failIfNotConnected) |  | 
| 162 |  | 
| 163 |  | 
| 164     def _setRealAddress(self, address): |  | 
| 165         self.realAddress = (address, self.addr[1]) |  | 
| 166         self.doConnect() |  | 
| 167 |  | 
| 168 |  | 
| 169     def failIfNotConnected(self, err): |  | 
| 170         if (self.connected or self.disconnected or |  | 
| 171             not hasattr(self, "connector")): |  | 
| 172             return |  | 
| 173 |  | 
| 174         try: |  | 
| 175             self._closeSocket() |  | 
| 176         except AttributeError: |  | 
| 177             pass |  | 
| 178         else: |  | 
| 179             del self.socket, self.getFileHandle |  | 
| 180         self.reactor.removeActiveHandle(self) |  | 
| 181 |  | 
| 182         self.connector.connectionFailed(failure.Failure(err)) |  | 
| 183         del self.connector |  | 
| 184 |  | 
| 185 |  | 
| 186     def stopConnecting(self): |  | 
| 187         """ |  | 
| 188         Stop attempt to connect. |  | 
| 189         """ |  | 
| 190         self.failIfNotConnected(error.UserError()) |  | 
| 191 |  | 
| 192 |  | 
| 193     def cbConnect(self, rc, bytes, evt): |  | 
| 194         if rc: |  | 
| 195             rc = connectExErrors.get(rc, rc) |  | 
| 196             self.failIfNotConnected(error.getConnectError((rc, |  | 
| 197                                     errno.errorcode.get(rc, 'Unknown error')))) |  | 
| 198         else: |  | 
| 199             self.socket.setsockopt(socket.SOL_SOCKET, |  | 
| 200                                    SO_UPDATE_CONNECT_CONTEXT, |  | 
| 201                                    struct.pack('I', self.socket.fileno())) |  | 
| 202             self.protocol = self.connector.buildProtocol(self.getPeer()) |  | 
| 203             self.connected = True |  | 
| 204             self.logstr = self.protocol.__class__.__name__+",client" |  | 
| 205             self.protocol.makeConnection(self) |  | 
| 206             self.startReading() |  | 
| 207 |  | 
| 208 |  | 
| 209     def doConnect(self): |  | 
| 210         if not hasattr(self, "connector"): |  | 
| 211             # this happens if we connector.stopConnecting in |  | 
| 212             # factory.startedConnecting |  | 
| 213             return |  | 
| 214         assert _iocp.have_connectex |  | 
| 215         self.reactor.addActiveHandle(self) |  | 
| 216         evt = _iocp.Event(self.cbConnect, self) |  | 
| 217 |  | 
| 218         rc = _iocp.connect(self.socket.fileno(), self.realAddress, evt) |  | 
| 219         if rc == ERROR_IO_PENDING: |  | 
| 220             return |  | 
| 221         else: |  | 
| 222             evt.ignore = True |  | 
| 223             self.cbConnect(rc, 0, 0, evt) |  | 
| 224 |  | 
| 225 |  | 
| 226     def getHost(self): |  | 
| 227         """ |  | 
| 228         Returns an IPv4Address. |  | 
| 229 |  | 
| 230         This indicates the address from which I am connecting. |  | 
| 231         """ |  | 
| 232         return address.IPv4Address('TCP', *(self.socket.getsockname() + |  | 
| 233                                             ('INET',))) |  | 
| 234 |  | 
| 235 |  | 
| 236     def getPeer(self): |  | 
| 237         """ |  | 
| 238         Returns an IPv4Address. |  | 
| 239 |  | 
| 240         This indicates the address that I am connected to. |  | 
| 241         """ |  | 
| 242         return address.IPv4Address('TCP', *(self.addr + ('INET',))) |  | 
| 243 |  | 
| 244 |  | 
| 245     def __repr__(self): |  | 
| 246         s = ('<%s to %s at %x>' % |  | 
| 247                 (self.__class__, self.addr, util.unsignedID(self))) |  | 
| 248         return s |  | 
| 249 |  | 
| 250 |  | 
| 251     def connectionLost(self, reason): |  | 
| 252         if not self.connected: |  | 
| 253             self.failIfNotConnected(error.ConnectError(string=reason)) |  | 
| 254         else: |  | 
| 255             Connection.connectionLost(self, reason) |  | 
| 256             self.connector.connectionLost(reason) |  | 
| 257 |  | 
| 258 |  | 
| 259 |  | 
| 260 class Server(Connection): |  | 
| 261     """ |  | 
| 262     Serverside socket-stream connection class. |  | 
| 263 |  | 
| 264     I am a serverside network connection transport; a socket which came from an |  | 
| 265     accept() on a server. |  | 
| 266     """ |  | 
| 267 |  | 
| 268 |  | 
| 269     def __init__(self, sock, protocol, clientAddr, serverAddr, sessionno): |  | 
| 270         """ |  | 
| 271         Server(sock, protocol, client, server, sessionno) |  | 
| 272 |  | 
| 273         Initialize me with a socket, a protocol, a descriptor for my peer (a |  | 
| 274         tuple of host, port describing the other end of the connection), an |  | 
| 275         instance of Port, and a session number. |  | 
| 276         """ |  | 
| 277         Connection.__init__(self, sock, protocol) |  | 
| 278         self.serverAddr = serverAddr |  | 
| 279         self.clientAddr = clientAddr |  | 
| 280         self.sessionno = sessionno |  | 
| 281         self.logstr = "%s,%s,%s" % (self.protocol.__class__.__name__, |  | 
| 282                                     sessionno, self.clientAddr.host) |  | 
| 283         self.repstr = "<%s #%s on %s>" % (self.protocol.__class__.__name__, |  | 
| 284                                           self.sessionno, self.serverAddr.port) |  | 
| 285         self.connected = True |  | 
| 286         self.startReading() |  | 
| 287 |  | 
| 288 |  | 
| 289     def __repr__(self): |  | 
| 290         """ |  | 
| 291         A string representation of this connection. |  | 
| 292         """ |  | 
| 293         return self.repstr |  | 
| 294 |  | 
| 295 |  | 
| 296     def getHost(self): |  | 
| 297         """ |  | 
| 298         Returns an IPv4Address. |  | 
| 299 |  | 
| 300         This indicates the server's address. |  | 
| 301         """ |  | 
| 302         return self.serverAddr |  | 
| 303 |  | 
| 304 |  | 
| 305     def getPeer(self): |  | 
| 306         """ |  | 
| 307         Returns an IPv4Address. |  | 
| 308 |  | 
| 309         This indicates the client's address. |  | 
| 310         """ |  | 
| 311         return self.clientAddr |  | 
| 312 |  | 
| 313 |  | 
| 314 |  | 
| 315 class Connector(TCPConnector): |  | 
| 316     def _makeTransport(self): |  | 
| 317         return Client(self.host, self.port, self.bindAddress, self, |  | 
| 318                       self.reactor) |  | 
| 319 |  | 
| 320 |  | 
| 321 |  | 
| 322 class Port(styles.Ephemeral, _SocketCloser): |  | 
| 323     implements(interfaces.IListeningPort) |  | 
| 324 |  | 
| 325     connected = False |  | 
| 326     disconnected = False |  | 
| 327     disconnecting = False |  | 
| 328     addressFamily = socket.AF_INET |  | 
| 329     socketType = socket.SOCK_STREAM |  | 
| 330 |  | 
| 331     sessionno = 0 |  | 
| 332 |  | 
| 333     maxAccepts = 100 |  | 
| 334 |  | 
| 335     # Actual port number being listened on, only set to a non-None |  | 
| 336     # value when we are actually listening. |  | 
| 337     _realPortNumber = None |  | 
| 338 |  | 
| 339 |  | 
| 340     def __init__(self, port, factory, backlog=50, interface='', reactor=None): |  | 
| 341         self.port = port |  | 
| 342         self.factory = factory |  | 
| 343         self.backlog = backlog |  | 
| 344         self.interface = interface |  | 
| 345         self.reactor = reactor |  | 
| 346 |  | 
| 347         skt = socket.socket(self.addressFamily, self.socketType) |  | 
| 348         self.addrLen = _iocp.maxAddrLen(skt.fileno()) |  | 
| 349 |  | 
| 350 |  | 
| 351     def __repr__(self): |  | 
| 352         if self._realPortNumber is not None: |  | 
| 353             return "<%s of %s on %s>" % (self.__class__, |  | 
| 354                                          self.factory.__class__, |  | 
| 355                                          self._realPortNumber) |  | 
| 356         else: |  | 
| 357             return "<%s of %s (not listening)>" % (self.__class__, |  | 
| 358                                                    self.factory.__class__) |  | 
| 359 |  | 
| 360 |  | 
| 361     def startListening(self): |  | 
| 362         try: |  | 
| 363             skt = self.reactor.createSocket(self.addressFamily, |  | 
| 364                                             self.socketType) |  | 
| 365             # TODO: resolve self.interface if necessary |  | 
| 366             skt.bind((self.interface, self.port)) |  | 
| 367         except socket.error, le: |  | 
| 368             raise error.CannotListenError, (self.interface, self.port, le) |  | 
| 369 |  | 
| 370         # Make sure that if we listened on port 0, we update that to |  | 
| 371         # reflect what the OS actually assigned us. |  | 
| 372         self._realPortNumber = skt.getsockname()[1] |  | 
| 373 |  | 
| 374         log.msg("%s starting on %s" % (self.factory.__class__, |  | 
| 375                                        self._realPortNumber)) |  | 
| 376 |  | 
| 377         self.factory.doStart() |  | 
| 378         skt.listen(self.backlog) |  | 
| 379         self.connected = True |  | 
| 380         self.reactor.addActiveHandle(self) |  | 
| 381         self.socket = skt |  | 
| 382         self.getFileHandle = self.socket.fileno |  | 
| 383         self.doAccept() |  | 
| 384 |  | 
| 385 |  | 
| 386     def loseConnection(self, connDone=failure.Failure(main.CONNECTION_DONE)): |  | 
| 387         """ |  | 
| 388         Stop accepting connections on this port. |  | 
| 389 |  | 
| 390         This will shut down my socket and call self.connectionLost(). |  | 
| 391         It returns a deferred which will fire successfully when the |  | 
| 392         port is actually closed. |  | 
| 393         """ |  | 
| 394         self.disconnecting = True |  | 
| 395         if self.connected: |  | 
| 396             self.deferred = defer.Deferred() |  | 
| 397             self.reactor.callLater(0, self.connectionLost, connDone) |  | 
| 398             return self.deferred |  | 
| 399 |  | 
| 400     stopListening = loseConnection |  | 
| 401 |  | 
| 402 |  | 
| 403     def connectionLost(self, reason): |  | 
| 404         """ |  | 
| 405         Cleans up my socket. |  | 
| 406         """ |  | 
| 407         log.msg('(Port %s Closed)' % self._realPortNumber) |  | 
| 408         self._realPortNumber = None |  | 
| 409         self.disconnected = True |  | 
| 410         self.reactor.removeActiveHandle(self) |  | 
| 411         self.connected = False |  | 
| 412         self._closeSocket() |  | 
| 413         del self.socket |  | 
| 414         del self.getFileHandle |  | 
| 415         self.factory.doStop() |  | 
| 416         if hasattr(self, "deferred"): |  | 
| 417             self.deferred.callback(None) |  | 
| 418             del self.deferred |  | 
| 419 |  | 
| 420 |  | 
| 421     def logPrefix(self): |  | 
| 422         """ |  | 
| 423         Returns the name of my class, to prefix log entries with. |  | 
| 424         """ |  | 
| 425         return reflect.qual(self.factory.__class__) |  | 
| 426 |  | 
| 427 |  | 
| 428     def getHost(self): |  | 
| 429         """ |  | 
| 430         Returns an IPv4Address. |  | 
| 431 |  | 
| 432         This indicates the server's address. |  | 
| 433         """ |  | 
| 434         return address.IPv4Address('TCP', *(self.socket.getsockname() + |  | 
| 435                                             ('INET',))) |  | 
| 436 |  | 
| 437 |  | 
| 438     def cbAccept(self, rc, bytes, evt): |  | 
| 439         self.handleAccept(rc, evt) |  | 
| 440         if not (self.disconnecting or self.disconnected): |  | 
| 441             self.doAccept() |  | 
| 442 |  | 
| 443 |  | 
| 444     def handleAccept(self, rc, evt): |  | 
| 445         if self.disconnecting or self.disconnected: |  | 
| 446             return False |  | 
| 447 |  | 
| 448         # possible errors: |  | 
| 449         # (WSAEMFILE, WSAENOBUFS, WSAENFILE, WSAENOMEM, WSAECONNABORTED) |  | 
| 450         if rc: |  | 
| 451             log.msg("Could not accept new connection -- %s (%s)" % |  | 
| 452                     (errno.errorcode.get(rc, 'unknown error'), rc)) |  | 
| 453             return False |  | 
| 454         else: |  | 
| 455             evt.newskt.setsockopt(socket.SOL_SOCKET, SO_UPDATE_ACCEPT_CONTEXT, |  | 
| 456                                   struct.pack('I', self.socket.fileno())) |  | 
| 457             family, lAddr, rAddr = _iocp.get_accept_addrs(evt.newskt.fileno(), |  | 
| 458                                                           evt.buff) |  | 
| 459             assert family == self.addressFamily |  | 
| 460 |  | 
| 461             protocol = self.factory.buildProtocol( |  | 
| 462                 address._ServerFactoryIPv4Address('TCP', rAddr[0], rAddr[1])) |  | 
| 463             if protocol is None: |  | 
| 464                 evt.newskt.close() |  | 
| 465             else: |  | 
| 466                 s = self.sessionno |  | 
| 467                 self.sessionno = s+1 |  | 
| 468                 transport = Server(evt.newskt, protocol, |  | 
| 469                         address.IPv4Address('TCP', rAddr[0], rAddr[1], 'INET'), |  | 
| 470                         address.IPv4Address('TCP', lAddr[0], lAddr[1], 'INET'), |  | 
| 471                         s) |  | 
| 472                 protocol.makeConnection(transport) |  | 
| 473             return True |  | 
| 474 |  | 
| 475 |  | 
| 476     def doAccept(self): |  | 
| 477         numAccepts = 0 |  | 
| 478         while 1: |  | 
| 479             evt = _iocp.Event(self.cbAccept, self) |  | 
| 480 |  | 
| 481             # see AcceptEx documentation |  | 
| 482             evt.buff = buff = _iocp.AllocateReadBuffer(2 * (self.addrLen + 16)) |  | 
| 483 |  | 
| 484             evt.newskt = newskt = self.reactor.createSocket(self.addressFamily, |  | 
| 485                                                             self.socketType) |  | 
| 486             rc = _iocp.accept(self.socket.fileno(), newskt.fileno(), buff, evt) |  | 
| 487 |  | 
| 488             if (rc == ERROR_IO_PENDING |  | 
| 489                 or (not rc and numAccepts >= self.maxAccepts)): |  | 
| 490                 break |  | 
| 491             else: |  | 
| 492                 evt.ignore = True |  | 
| 493                 if not self.handleAccept(rc, evt): |  | 
| 494                     break |  | 
| 495             numAccepts += 1 |  | 
| 496 |  | 
| 497 |  | 
| OLD | NEW | 
|---|