OLD | NEW |
| (Empty) |
1 # Copyright (c) 2008 Twisted Matrix Laboratories. | |
2 # See LICENSE for details. | |
3 | |
4 | |
5 """ | |
6 UDP support for IOCP reactor | |
7 """ | |
8 | |
9 from twisted.internet import defer, address, error, interfaces | |
10 from twisted.internet.abstract import isIPAddress | |
11 from twisted.python import log, reflect, failure | |
12 | |
13 from zope.interface import implements | |
14 import socket, operator, struct, warnings, errno | |
15 | |
16 from twisted.internet.iocpreactor.const import ERROR_IO_PENDING | |
17 from twisted.internet.iocpreactor.const import ERROR_CONNECTION_REFUSED | |
18 from twisted.internet.iocpreactor.const import ERROR_PORT_UNREACHABLE | |
19 from twisted.internet.iocpreactor.interfaces import IReadWriteHandle | |
20 from twisted.internet.iocpreactor import iocpsupport as _iocp, abstract | |
21 | |
22 | |
23 | |
24 class Port(abstract.FileHandle): | |
25 """ | |
26 UDP port, listening for packets. | |
27 """ | |
28 | |
29 implements(IReadWriteHandle, interfaces.IUDPTransport, | |
30 interfaces.ISystemHandle) | |
31 | |
32 addressFamily = socket.AF_INET | |
33 socketType = socket.SOCK_DGRAM | |
34 maxThroughput = 256 * 1024 # max bytes we read in one eventloop iteration | |
35 dynamicReadBuffers = False | |
36 | |
37 # Actual port number being listened on, only set to a non-None | |
38 # value when we are actually listening. | |
39 _realPortNumber = None | |
40 | |
41 | |
42 def __init__(self, port, proto, interface='', maxPacketSize=8192, | |
43 reactor=None): | |
44 """ | |
45 Initialize with a numeric port to listen on. | |
46 """ | |
47 self.port = port | |
48 self.protocol = proto | |
49 self.readBufferSize = maxPacketSize | |
50 self.interface = interface | |
51 self.setLogStr() | |
52 self._connectedAddr = None | |
53 | |
54 abstract.FileHandle.__init__(self, reactor) | |
55 | |
56 skt = socket.socket(self.addressFamily, self.socketType) | |
57 addrLen = _iocp.maxAddrLen(skt.fileno()) | |
58 self.addressBuffer = _iocp.AllocateReadBuffer(addrLen) | |
59 | |
60 | |
61 def __repr__(self): | |
62 if self._realPortNumber is not None: | |
63 return ("<%s on %s>" % | |
64 (self.protocol.__class__, self._realPortNumber)) | |
65 else: | |
66 return "<%s not connected>" % (self.protocol.__class__,) | |
67 | |
68 | |
69 def getHandle(self): | |
70 """ | |
71 Return a socket object. | |
72 """ | |
73 return self.socket | |
74 | |
75 | |
76 def startListening(self): | |
77 """ | |
78 Create and bind my socket, and begin listening on it. | |
79 | |
80 This is called on unserialization, and must be called after creating a | |
81 server to begin listening on the specified port. | |
82 """ | |
83 self._bindSocket() | |
84 self._connectToProtocol() | |
85 | |
86 | |
87 def createSocket(self): | |
88 return self.reactor.createSocket(self.addressFamily, self.socketType) | |
89 | |
90 | |
91 def _bindSocket(self): | |
92 try: | |
93 skt = self.createSocket() | |
94 skt.bind((self.interface, self.port)) | |
95 except socket.error, le: | |
96 raise error.CannotListenError, (self.interface, self.port, le) | |
97 | |
98 # Make sure that if we listened on port 0, we update that to | |
99 # reflect what the OS actually assigned us. | |
100 self._realPortNumber = skt.getsockname()[1] | |
101 | |
102 log.msg("%s starting on %s" % | |
103 (self.protocol.__class__, self._realPortNumber)) | |
104 | |
105 self.connected = True | |
106 self.socket = skt | |
107 self.getFileHandle = self.socket.fileno | |
108 | |
109 | |
110 def _connectToProtocol(self): | |
111 self.protocol.makeConnection(self) | |
112 self.startReading() | |
113 self.reactor.addActiveHandle(self) | |
114 | |
115 | |
116 def cbRead(self, rc, bytes, evt): | |
117 if self.reading: | |
118 self.handleRead(rc, bytes, evt) | |
119 self.doRead() | |
120 | |
121 | |
122 def handleRead(self, rc, bytes, evt): | |
123 if rc in (errno.WSAECONNREFUSED, errno.WSAECONNRESET, | |
124 ERROR_CONNECTION_REFUSED, ERROR_PORT_UNREACHABLE): | |
125 if self._connectedAddr: | |
126 self.protocol.connectionRefused() | |
127 elif rc: | |
128 log.msg("error in recvfrom -- %s (%s)" % | |
129 (errno.errorcode.get(rc, 'unknown error'), rc)) | |
130 else: | |
131 try: | |
132 self.protocol.datagramReceived(str(evt.buff[:bytes]), | |
133 _iocp.makesockaddr(evt.addr_buff)) | |
134 except: | |
135 log.err() | |
136 | |
137 | |
138 def doRead(self): | |
139 read = 0 | |
140 while self.reading: | |
141 evt = _iocp.Event(self.cbRead, self) | |
142 | |
143 evt.buff = buff = self._readBuffers[0] | |
144 evt.addr_buff = addr_buff = self.addressBuffer | |
145 rc, bytes = _iocp.recvfrom(self.getFileHandle(), buff, | |
146 addr_buff, evt) | |
147 | |
148 if (rc == ERROR_IO_PENDING | |
149 or (not rc and read >= self.maxThroughput)): | |
150 break | |
151 else: | |
152 evt.ignore = True | |
153 self.handleRead(rc, bytes, evt) | |
154 read += bytes | |
155 | |
156 | |
157 def write(self, datagram, addr=None): | |
158 """ | |
159 Write a datagram. | |
160 | |
161 @param addr: should be a tuple (ip, port), can be None in connected | |
162 mode. | |
163 """ | |
164 if self._connectedAddr: | |
165 assert addr in (None, self._connectedAddr) | |
166 try: | |
167 return self.socket.send(datagram) | |
168 except socket.error, se: | |
169 no = se.args[0] | |
170 if no == errno.WSAEINTR: | |
171 return self.write(datagram) | |
172 elif no == errno.WSAEMSGSIZE: | |
173 raise error.MessageLengthError, "message too long" | |
174 elif no in (errno.WSAECONNREFUSED, errno.WSAECONNRESET, | |
175 ERROR_CONNECTION_REFUSED, ERROR_PORT_UNREACHABLE): | |
176 self.protocol.connectionRefused() | |
177 else: | |
178 raise | |
179 else: | |
180 assert addr != None | |
181 if not addr[0].replace(".", "").isdigit(): | |
182 warnings.warn("Please only pass IPs to write(), not hostnames", | |
183 DeprecationWarning, stacklevel=2) | |
184 try: | |
185 return self.socket.sendto(datagram, addr) | |
186 except socket.error, se: | |
187 no = se.args[0] | |
188 if no == errno.WSAEINTR: | |
189 return self.write(datagram, addr) | |
190 elif no == errno.WSAEMSGSIZE: | |
191 raise error.MessageLengthError, "message too long" | |
192 elif no in (errno.WSAECONNREFUSED, errno.WSAECONNRESET, | |
193 ERROR_CONNECTION_REFUSED, ERROR_PORT_UNREACHABLE): | |
194 # in non-connected UDP ECONNREFUSED is platform dependent, | |
195 # I think and the info is not necessarily useful. | |
196 # Nevertheless maybe we should call connectionRefused? XXX | |
197 return | |
198 else: | |
199 raise | |
200 | |
201 | |
202 def writeSequence(self, seq, addr): | |
203 self.write("".join(seq), addr) | |
204 | |
205 | |
206 def connect(self, host, port): | |
207 """ | |
208 'Connect' to remote server. | |
209 """ | |
210 if self._connectedAddr: | |
211 raise RuntimeError( | |
212 "already connected, reconnecting is not currently supported " | |
213 "(talk to itamar if you want this)") | |
214 if not isIPAddress(host): | |
215 raise ValueError, "please pass only IP addresses, not domain names" | |
216 self._connectedAddr = (host, port) | |
217 self.socket.connect((host, port)) | |
218 | |
219 | |
220 def _loseConnection(self): | |
221 self.stopReading() | |
222 self.reactor.removeActiveHandle(self) | |
223 if self.connected: # actually means if we are *listening* | |
224 from twisted.internet import reactor | |
225 reactor.callLater(0, self.connectionLost) | |
226 | |
227 | |
228 def stopListening(self): | |
229 if self.connected: | |
230 result = self.d = defer.Deferred() | |
231 else: | |
232 result = None | |
233 self._loseConnection() | |
234 return result | |
235 | |
236 | |
237 def loseConnection(self): | |
238 warnings.warn("Please use stopListening() to disconnect port", | |
239 DeprecationWarning, stacklevel=2) | |
240 self.stopListening() | |
241 | |
242 | |
243 def connectionLost(self, reason=None): | |
244 """ | |
245 Cleans up my socket. | |
246 """ | |
247 log.msg('(Port %s Closed)' % self._realPortNumber) | |
248 self._realPortNumber = None | |
249 self.stopReading() | |
250 if hasattr(self, "protocol"): | |
251 # we won't have attribute in ConnectedPort, in cases | |
252 # where there was an error in connection process | |
253 self.protocol.doStop() | |
254 self.connected = False | |
255 self.disconnected = True | |
256 self.socket.close() | |
257 del self.socket | |
258 del self.getFileHandle | |
259 if hasattr(self, "d"): | |
260 self.d.callback(None) | |
261 del self.d | |
262 | |
263 | |
264 def setLogStr(self): | |
265 self.logstr = reflect.qual(self.protocol.__class__) + " (UDP)" | |
266 | |
267 | |
268 def logPrefix(self): | |
269 """ | |
270 Returns the name of my class, to prefix log entries with. | |
271 """ | |
272 return self.logstr | |
273 | |
274 | |
275 def getHost(self): | |
276 """ | |
277 Returns an IPv4Address. | |
278 | |
279 This indicates the address from which I am connecting. | |
280 """ | |
281 return address.IPv4Address('UDP', *(self.socket.getsockname() + | |
282 ('INET_UDP',))) | |
283 | |
284 | |
285 | |
286 class MulticastMixin: | |
287 """ | |
288 Implement multicast functionality. | |
289 """ | |
290 | |
291 | |
292 def getOutgoingInterface(self): | |
293 i = self.socket.getsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_IF) | |
294 return socket.inet_ntoa(struct.pack("@i", i)) | |
295 | |
296 | |
297 def setOutgoingInterface(self, addr): | |
298 """ | |
299 Returns Deferred of success. | |
300 """ | |
301 return self.reactor.resolve(addr).addCallback(self._setInterface) | |
302 | |
303 | |
304 def _setInterface(self, addr): | |
305 i = socket.inet_aton(addr) | |
306 self.socket.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_IF, i) | |
307 return 1 | |
308 | |
309 | |
310 def getLoopbackMode(self): | |
311 return self.socket.getsockopt(socket.IPPROTO_IP, | |
312 socket.IP_MULTICAST_LOOP) | |
313 | |
314 | |
315 def setLoopbackMode(self, mode): | |
316 mode = struct.pack("b", operator.truth(mode)) | |
317 self.socket.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_LOOP, | |
318 mode) | |
319 | |
320 | |
321 def getTTL(self): | |
322 return self.socket.getsockopt(socket.IPPROTO_IP, | |
323 socket.IP_MULTICAST_TTL) | |
324 | |
325 | |
326 def setTTL(self, ttl): | |
327 ttl = struct.pack("B", ttl) | |
328 self.socket.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, ttl) | |
329 | |
330 | |
331 def joinGroup(self, addr, interface=""): | |
332 """ | |
333 Join a multicast group. Returns Deferred of success. | |
334 """ | |
335 return self.reactor.resolve(addr).addCallback(self._joinAddr1, | |
336 interface, 1) | |
337 | |
338 | |
339 def _joinAddr1(self, addr, interface, join): | |
340 return self.reactor.resolve(interface).addCallback(self._joinAddr2, | |
341 addr, join) | |
342 | |
343 | |
344 def _joinAddr2(self, interface, addr, join): | |
345 addr = socket.inet_aton(addr) | |
346 interface = socket.inet_aton(interface) | |
347 if join: | |
348 cmd = socket.IP_ADD_MEMBERSHIP | |
349 else: | |
350 cmd = socket.IP_DROP_MEMBERSHIP | |
351 try: | |
352 self.socket.setsockopt(socket.IPPROTO_IP, cmd, addr + interface) | |
353 except socket.error, e: | |
354 return failure.Failure(error.MulticastJoinError(addr, interface, | |
355 *e.args)) | |
356 | |
357 | |
358 def leaveGroup(self, addr, interface=""): | |
359 """ | |
360 Leave multicast group, return Deferred of success. | |
361 """ | |
362 return self.reactor.resolve(addr).addCallback(self._joinAddr1, | |
363 interface, 0) | |
364 | |
365 | |
366 | |
367 class MulticastPort(MulticastMixin, Port): | |
368 """ | |
369 UDP Port that supports multicasting. | |
370 """ | |
371 | |
372 implements(interfaces.IMulticastTransport) | |
373 | |
374 | |
375 def __init__(self, port, proto, interface='', maxPacketSize=8192, | |
376 reactor=None, listenMultiple=False): | |
377 Port.__init__(self, port, proto, interface, maxPacketSize, reactor) | |
378 self.listenMultiple = listenMultiple | |
379 | |
380 | |
381 def createSocket(self): | |
382 skt = Port.createSocket(self) | |
383 if self.listenMultiple: | |
384 skt.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) | |
385 if hasattr(socket, "SO_REUSEPORT"): | |
386 skt.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) | |
387 return skt | |
388 | |
389 | |
OLD | NEW |