| OLD | NEW |
| (Empty) |
| 1 # Copyright (c) 2008 Twisted Matrix Laboratories. | |
| 2 # See LICENSE for details. | |
| 3 | |
| 4 | |
| 5 """ | |
| 6 TCP support for IOCP reactor | |
| 7 """ | |
| 8 | |
| 9 from twisted.internet import interfaces, error, address, main, defer | |
| 10 from twisted.internet.abstract import isIPAddress | |
| 11 from twisted.internet.tcp import _SocketCloser, Connector as TCPConnector | |
| 12 from twisted.persisted import styles | |
| 13 from twisted.python import log, failure, reflect, util | |
| 14 | |
| 15 from zope.interface import implements | |
| 16 import socket, operator, errno, struct | |
| 17 | |
| 18 from twisted.internet.iocpreactor import iocpsupport as _iocp, abstract | |
| 19 from twisted.internet.iocpreactor.interfaces import IReadWriteHandle | |
| 20 from twisted.internet.iocpreactor.const import ERROR_IO_PENDING | |
| 21 from twisted.internet.iocpreactor.const import SO_UPDATE_CONNECT_CONTEXT | |
| 22 from twisted.internet.iocpreactor.const import SO_UPDATE_ACCEPT_CONTEXT | |
| 23 from twisted.internet.iocpreactor.const import ERROR_CONNECTION_REFUSED | |
| 24 from twisted.internet.iocpreactor.const import ERROR_NETWORK_UNREACHABLE | |
| 25 | |
| 26 # ConnectEx returns these. XXX: find out what it does for timeout | |
| 27 connectExErrors = { | |
| 28 ERROR_CONNECTION_REFUSED: errno.WSAECONNREFUSED, | |
| 29 ERROR_NETWORK_UNREACHABLE: errno.WSAENETUNREACH, | |
| 30 } | |
| 31 | |
| 32 | |
| 33 | |
| 34 class Connection(abstract.FileHandle, _SocketCloser): | |
| 35 implements(IReadWriteHandle, interfaces.ITCPTransport, | |
| 36 interfaces.ISystemHandle) | |
| 37 | |
| 38 | |
| 39 def __init__(self, sock, proto, reactor=None): | |
| 40 abstract.FileHandle.__init__(self, reactor) | |
| 41 self.socket = sock | |
| 42 self.getFileHandle = sock.fileno | |
| 43 self.protocol = proto | |
| 44 | |
| 45 | |
| 46 def getHandle(self): | |
| 47 return self.socket | |
| 48 | |
| 49 | |
| 50 def dataReceived(self, rbuffer): | |
| 51 # XXX: some day, we'll have protocols that can handle raw buffers | |
| 52 self.protocol.dataReceived(str(rbuffer)) | |
| 53 | |
| 54 | |
| 55 def readFromHandle(self, bufflist, evt): | |
| 56 return _iocp.recv(self.getFileHandle(), bufflist, evt) | |
| 57 | |
| 58 | |
| 59 def writeToHandle(self, buff, evt): | |
| 60 return _iocp.send(self.getFileHandle(), buff, evt) | |
| 61 | |
| 62 | |
| 63 def _closeWriteConnection(self): | |
| 64 try: | |
| 65 getattr(self.socket, self._socketShutdownMethod)(1) | |
| 66 except socket.error: | |
| 67 pass | |
| 68 p = interfaces.IHalfCloseableProtocol(self.protocol, None) | |
| 69 if p: | |
| 70 try: | |
| 71 p.writeConnectionLost() | |
| 72 except: | |
| 73 f = failure.Failure() | |
| 74 log.err() | |
| 75 self.connectionLost(f) | |
| 76 | |
| 77 | |
| 78 def readConnectionLost(self, reason): | |
| 79 p = interfaces.IHalfCloseableProtocol(self.protocol, None) | |
| 80 if p: | |
| 81 try: | |
| 82 p.readConnectionLost() | |
| 83 except: | |
| 84 log.err() | |
| 85 self.connectionLost(failure.Failure()) | |
| 86 else: | |
| 87 self.connectionLost(reason) | |
| 88 | |
| 89 | |
| 90 def connectionLost(self, reason): | |
| 91 abstract.FileHandle.connectionLost(self, reason) | |
| 92 self._closeSocket() | |
| 93 protocol = self.protocol | |
| 94 del self.protocol | |
| 95 del self.socket | |
| 96 del self.getFileHandle | |
| 97 protocol.connectionLost(reason) | |
| 98 | |
| 99 | |
| 100 def logPrefix(self): | |
| 101 """ | |
| 102 Return the prefix to log with when I own the logging thread. | |
| 103 """ | |
| 104 return self.logstr | |
| 105 | |
| 106 | |
| 107 def getTcpNoDelay(self): | |
| 108 return operator.truth(self.socket.getsockopt(socket.IPPROTO_TCP, | |
| 109 socket.TCP_NODELAY)) | |
| 110 | |
| 111 | |
| 112 def setTcpNoDelay(self, enabled): | |
| 113 self.socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, enabled) | |
| 114 | |
| 115 | |
| 116 def getTcpKeepAlive(self): | |
| 117 return operator.truth(self.socket.getsockopt(socket.SOL_SOCKET, | |
| 118 socket.SO_KEEPALIVE)) | |
| 119 | |
| 120 | |
| 121 def setTcpKeepAlive(self, enabled): | |
| 122 self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, enabled) | |
| 123 | |
| 124 | |
| 125 | |
| 126 class Client(Connection): | |
| 127 addressFamily = socket.AF_INET | |
| 128 socketType = socket.SOCK_STREAM | |
| 129 | |
| 130 | |
| 131 def __init__(self, host, port, bindAddress, connector, reactor): | |
| 132 self.connector = connector | |
| 133 self.addr = (host, port) | |
| 134 self.reactor = reactor | |
| 135 # ConnectEx documentation says socket _has_ to be bound | |
| 136 if bindAddress is None: | |
| 137 bindAddress = ('', 0) | |
| 138 | |
| 139 try: | |
| 140 try: | |
| 141 skt = reactor.createSocket(self.addressFamily, self.socketType) | |
| 142 except socket.error, se: | |
| 143 raise error.ConnectBindError(se[0], se[1]) | |
| 144 else: | |
| 145 try: | |
| 146 skt.bind(bindAddress) | |
| 147 except socket.error, se: | |
| 148 raise error.ConnectBindError(se[0], se[1]) | |
| 149 self.socket = skt | |
| 150 Connection.__init__(self, skt, None) | |
| 151 reactor.callLater(0, self.resolveAddress) | |
| 152 except error.ConnectBindError, err: | |
| 153 reactor.callLater(0, self.failIfNotConnected, err) | |
| 154 | |
| 155 | |
| 156 def resolveAddress(self): | |
| 157 if isIPAddress(self.addr[0]): | |
| 158 self._setRealAddress(self.addr[0]) | |
| 159 else: | |
| 160 d = self.reactor.resolve(self.addr[0]) | |
| 161 d.addCallbacks(self._setRealAddress, self.failIfNotConnected) | |
| 162 | |
| 163 | |
| 164 def _setRealAddress(self, address): | |
| 165 self.realAddress = (address, self.addr[1]) | |
| 166 self.doConnect() | |
| 167 | |
| 168 | |
| 169 def failIfNotConnected(self, err): | |
| 170 if (self.connected or self.disconnected or | |
| 171 not hasattr(self, "connector")): | |
| 172 return | |
| 173 | |
| 174 try: | |
| 175 self._closeSocket() | |
| 176 except AttributeError: | |
| 177 pass | |
| 178 else: | |
| 179 del self.socket, self.getFileHandle | |
| 180 self.reactor.removeActiveHandle(self) | |
| 181 | |
| 182 self.connector.connectionFailed(failure.Failure(err)) | |
| 183 del self.connector | |
| 184 | |
| 185 | |
| 186 def stopConnecting(self): | |
| 187 """ | |
| 188 Stop attempt to connect. | |
| 189 """ | |
| 190 self.failIfNotConnected(error.UserError()) | |
| 191 | |
| 192 | |
| 193 def cbConnect(self, rc, bytes, evt): | |
| 194 if rc: | |
| 195 rc = connectExErrors.get(rc, rc) | |
| 196 self.failIfNotConnected(error.getConnectError((rc, | |
| 197 errno.errorcode.get(rc, 'Unknown error')))) | |
| 198 else: | |
| 199 self.socket.setsockopt(socket.SOL_SOCKET, | |
| 200 SO_UPDATE_CONNECT_CONTEXT, | |
| 201 struct.pack('I', self.socket.fileno())) | |
| 202 self.protocol = self.connector.buildProtocol(self.getPeer()) | |
| 203 self.connected = True | |
| 204 self.logstr = self.protocol.__class__.__name__+",client" | |
| 205 self.protocol.makeConnection(self) | |
| 206 self.startReading() | |
| 207 | |
| 208 | |
| 209 def doConnect(self): | |
| 210 if not hasattr(self, "connector"): | |
| 211 # this happens if we connector.stopConnecting in | |
| 212 # factory.startedConnecting | |
| 213 return | |
| 214 assert _iocp.have_connectex | |
| 215 self.reactor.addActiveHandle(self) | |
| 216 evt = _iocp.Event(self.cbConnect, self) | |
| 217 | |
| 218 rc = _iocp.connect(self.socket.fileno(), self.realAddress, evt) | |
| 219 if rc == ERROR_IO_PENDING: | |
| 220 return | |
| 221 else: | |
| 222 evt.ignore = True | |
| 223 self.cbConnect(rc, 0, 0, evt) | |
| 224 | |
| 225 | |
| 226 def getHost(self): | |
| 227 """ | |
| 228 Returns an IPv4Address. | |
| 229 | |
| 230 This indicates the address from which I am connecting. | |
| 231 """ | |
| 232 return address.IPv4Address('TCP', *(self.socket.getsockname() + | |
| 233 ('INET',))) | |
| 234 | |
| 235 | |
| 236 def getPeer(self): | |
| 237 """ | |
| 238 Returns an IPv4Address. | |
| 239 | |
| 240 This indicates the address that I am connected to. | |
| 241 """ | |
| 242 return address.IPv4Address('TCP', *(self.addr + ('INET',))) | |
| 243 | |
| 244 | |
| 245 def __repr__(self): | |
| 246 s = ('<%s to %s at %x>' % | |
| 247 (self.__class__, self.addr, util.unsignedID(self))) | |
| 248 return s | |
| 249 | |
| 250 | |
| 251 def connectionLost(self, reason): | |
| 252 if not self.connected: | |
| 253 self.failIfNotConnected(error.ConnectError(string=reason)) | |
| 254 else: | |
| 255 Connection.connectionLost(self, reason) | |
| 256 self.connector.connectionLost(reason) | |
| 257 | |
| 258 | |
| 259 | |
| 260 class Server(Connection): | |
| 261 """ | |
| 262 Serverside socket-stream connection class. | |
| 263 | |
| 264 I am a serverside network connection transport; a socket which came from an | |
| 265 accept() on a server. | |
| 266 """ | |
| 267 | |
| 268 | |
| 269 def __init__(self, sock, protocol, clientAddr, serverAddr, sessionno): | |
| 270 """ | |
| 271 Server(sock, protocol, client, server, sessionno) | |
| 272 | |
| 273 Initialize me with a socket, a protocol, a descriptor for my peer (a | |
| 274 tuple of host, port describing the other end of the connection), an | |
| 275 instance of Port, and a session number. | |
| 276 """ | |
| 277 Connection.__init__(self, sock, protocol) | |
| 278 self.serverAddr = serverAddr | |
| 279 self.clientAddr = clientAddr | |
| 280 self.sessionno = sessionno | |
| 281 self.logstr = "%s,%s,%s" % (self.protocol.__class__.__name__, | |
| 282 sessionno, self.clientAddr.host) | |
| 283 self.repstr = "<%s #%s on %s>" % (self.protocol.__class__.__name__, | |
| 284 self.sessionno, self.serverAddr.port) | |
| 285 self.connected = True | |
| 286 self.startReading() | |
| 287 | |
| 288 | |
| 289 def __repr__(self): | |
| 290 """ | |
| 291 A string representation of this connection. | |
| 292 """ | |
| 293 return self.repstr | |
| 294 | |
| 295 | |
| 296 def getHost(self): | |
| 297 """ | |
| 298 Returns an IPv4Address. | |
| 299 | |
| 300 This indicates the server's address. | |
| 301 """ | |
| 302 return self.serverAddr | |
| 303 | |
| 304 | |
| 305 def getPeer(self): | |
| 306 """ | |
| 307 Returns an IPv4Address. | |
| 308 | |
| 309 This indicates the client's address. | |
| 310 """ | |
| 311 return self.clientAddr | |
| 312 | |
| 313 | |
| 314 | |
| 315 class Connector(TCPConnector): | |
| 316 def _makeTransport(self): | |
| 317 return Client(self.host, self.port, self.bindAddress, self, | |
| 318 self.reactor) | |
| 319 | |
| 320 | |
| 321 | |
| 322 class Port(styles.Ephemeral, _SocketCloser): | |
| 323 implements(interfaces.IListeningPort) | |
| 324 | |
| 325 connected = False | |
| 326 disconnected = False | |
| 327 disconnecting = False | |
| 328 addressFamily = socket.AF_INET | |
| 329 socketType = socket.SOCK_STREAM | |
| 330 | |
| 331 sessionno = 0 | |
| 332 | |
| 333 maxAccepts = 100 | |
| 334 | |
| 335 # Actual port number being listened on, only set to a non-None | |
| 336 # value when we are actually listening. | |
| 337 _realPortNumber = None | |
| 338 | |
| 339 | |
| 340 def __init__(self, port, factory, backlog=50, interface='', reactor=None): | |
| 341 self.port = port | |
| 342 self.factory = factory | |
| 343 self.backlog = backlog | |
| 344 self.interface = interface | |
| 345 self.reactor = reactor | |
| 346 | |
| 347 skt = socket.socket(self.addressFamily, self.socketType) | |
| 348 self.addrLen = _iocp.maxAddrLen(skt.fileno()) | |
| 349 | |
| 350 | |
| 351 def __repr__(self): | |
| 352 if self._realPortNumber is not None: | |
| 353 return "<%s of %s on %s>" % (self.__class__, | |
| 354 self.factory.__class__, | |
| 355 self._realPortNumber) | |
| 356 else: | |
| 357 return "<%s of %s (not listening)>" % (self.__class__, | |
| 358 self.factory.__class__) | |
| 359 | |
| 360 | |
| 361 def startListening(self): | |
| 362 try: | |
| 363 skt = self.reactor.createSocket(self.addressFamily, | |
| 364 self.socketType) | |
| 365 # TODO: resolve self.interface if necessary | |
| 366 skt.bind((self.interface, self.port)) | |
| 367 except socket.error, le: | |
| 368 raise error.CannotListenError, (self.interface, self.port, le) | |
| 369 | |
| 370 # Make sure that if we listened on port 0, we update that to | |
| 371 # reflect what the OS actually assigned us. | |
| 372 self._realPortNumber = skt.getsockname()[1] | |
| 373 | |
| 374 log.msg("%s starting on %s" % (self.factory.__class__, | |
| 375 self._realPortNumber)) | |
| 376 | |
| 377 self.factory.doStart() | |
| 378 skt.listen(self.backlog) | |
| 379 self.connected = True | |
| 380 self.reactor.addActiveHandle(self) | |
| 381 self.socket = skt | |
| 382 self.getFileHandle = self.socket.fileno | |
| 383 self.doAccept() | |
| 384 | |
| 385 | |
| 386 def loseConnection(self, connDone=failure.Failure(main.CONNECTION_DONE)): | |
| 387 """ | |
| 388 Stop accepting connections on this port. | |
| 389 | |
| 390 This will shut down my socket and call self.connectionLost(). | |
| 391 It returns a deferred which will fire successfully when the | |
| 392 port is actually closed. | |
| 393 """ | |
| 394 self.disconnecting = True | |
| 395 if self.connected: | |
| 396 self.deferred = defer.Deferred() | |
| 397 self.reactor.callLater(0, self.connectionLost, connDone) | |
| 398 return self.deferred | |
| 399 | |
| 400 stopListening = loseConnection | |
| 401 | |
| 402 | |
| 403 def connectionLost(self, reason): | |
| 404 """ | |
| 405 Cleans up my socket. | |
| 406 """ | |
| 407 log.msg('(Port %s Closed)' % self._realPortNumber) | |
| 408 self._realPortNumber = None | |
| 409 self.disconnected = True | |
| 410 self.reactor.removeActiveHandle(self) | |
| 411 self.connected = False | |
| 412 self._closeSocket() | |
| 413 del self.socket | |
| 414 del self.getFileHandle | |
| 415 self.factory.doStop() | |
| 416 if hasattr(self, "deferred"): | |
| 417 self.deferred.callback(None) | |
| 418 del self.deferred | |
| 419 | |
| 420 | |
| 421 def logPrefix(self): | |
| 422 """ | |
| 423 Returns the name of my class, to prefix log entries with. | |
| 424 """ | |
| 425 return reflect.qual(self.factory.__class__) | |
| 426 | |
| 427 | |
| 428 def getHost(self): | |
| 429 """ | |
| 430 Returns an IPv4Address. | |
| 431 | |
| 432 This indicates the server's address. | |
| 433 """ | |
| 434 return address.IPv4Address('TCP', *(self.socket.getsockname() + | |
| 435 ('INET',))) | |
| 436 | |
| 437 | |
| 438 def cbAccept(self, rc, bytes, evt): | |
| 439 self.handleAccept(rc, evt) | |
| 440 if not (self.disconnecting or self.disconnected): | |
| 441 self.doAccept() | |
| 442 | |
| 443 | |
| 444 def handleAccept(self, rc, evt): | |
| 445 if self.disconnecting or self.disconnected: | |
| 446 return False | |
| 447 | |
| 448 # possible errors: | |
| 449 # (WSAEMFILE, WSAENOBUFS, WSAENFILE, WSAENOMEM, WSAECONNABORTED) | |
| 450 if rc: | |
| 451 log.msg("Could not accept new connection -- %s (%s)" % | |
| 452 (errno.errorcode.get(rc, 'unknown error'), rc)) | |
| 453 return False | |
| 454 else: | |
| 455 evt.newskt.setsockopt(socket.SOL_SOCKET, SO_UPDATE_ACCEPT_CONTEXT, | |
| 456 struct.pack('I', self.socket.fileno())) | |
| 457 family, lAddr, rAddr = _iocp.get_accept_addrs(evt.newskt.fileno(), | |
| 458 evt.buff) | |
| 459 assert family == self.addressFamily | |
| 460 | |
| 461 protocol = self.factory.buildProtocol( | |
| 462 address._ServerFactoryIPv4Address('TCP', rAddr[0], rAddr[1])) | |
| 463 if protocol is None: | |
| 464 evt.newskt.close() | |
| 465 else: | |
| 466 s = self.sessionno | |
| 467 self.sessionno = s+1 | |
| 468 transport = Server(evt.newskt, protocol, | |
| 469 address.IPv4Address('TCP', rAddr[0], rAddr[1], 'INET'), | |
| 470 address.IPv4Address('TCP', lAddr[0], lAddr[1], 'INET'), | |
| 471 s) | |
| 472 protocol.makeConnection(transport) | |
| 473 return True | |
| 474 | |
| 475 | |
| 476 def doAccept(self): | |
| 477 numAccepts = 0 | |
| 478 while 1: | |
| 479 evt = _iocp.Event(self.cbAccept, self) | |
| 480 | |
| 481 # see AcceptEx documentation | |
| 482 evt.buff = buff = _iocp.AllocateReadBuffer(2 * (self.addrLen + 16)) | |
| 483 | |
| 484 evt.newskt = newskt = self.reactor.createSocket(self.addressFamily, | |
| 485 self.socketType) | |
| 486 rc = _iocp.accept(self.socket.fileno(), newskt.fileno(), buff, evt) | |
| 487 | |
| 488 if (rc == ERROR_IO_PENDING | |
| 489 or (not rc and numAccepts >= self.maxAccepts)): | |
| 490 break | |
| 491 else: | |
| 492 evt.ignore = True | |
| 493 if not self.handleAccept(rc, evt): | |
| 494 break | |
| 495 numAccepts += 1 | |
| 496 | |
| 497 | |
| OLD | NEW |