OLD | NEW |
| (Empty) |
1 # -*- test-case-name: twisted.names.test.test_srvconnect -*- | |
2 # Copyright (c) 2001-2007 Twisted Matrix Laboratories. | |
3 # See LICENSE for details. | |
4 | |
5 import random | |
6 | |
7 from zope.interface import implements | |
8 | |
9 from twisted.internet import error, interfaces | |
10 | |
11 from twisted.names import client, dns | |
12 from twisted.names.error import DNSNameError | |
13 | |
14 class _SRVConnector_ClientFactoryWrapper: | |
15 def __init__(self, connector, wrappedFactory): | |
16 self.__connector = connector | |
17 self.__wrappedFactory = wrappedFactory | |
18 | |
19 def startedConnecting(self, connector): | |
20 self.__wrappedFactory.startedConnecting(self.__connector) | |
21 | |
22 def clientConnectionFailed(self, connector, reason): | |
23 self.__connector.connectionFailed(reason) | |
24 | |
25 def clientConnectionLost(self, connector, reason): | |
26 self.__connector.connectionLost(reason) | |
27 | |
28 def __getattr__(self, key): | |
29 return getattr(self.__wrappedFactory, key) | |
30 | |
31 class SRVConnector: | |
32 """A connector that looks up DNS SRV records. See RFC2782.""" | |
33 | |
34 implements(interfaces.IConnector) | |
35 | |
36 stopAfterDNS=0 | |
37 | |
38 def __init__(self, reactor, service, domain, factory, | |
39 protocol='tcp', connectFuncName='connectTCP', | |
40 connectFuncArgs=(), | |
41 connectFuncKwArgs={}, | |
42 ): | |
43 self.reactor = reactor | |
44 self.service = service | |
45 self.domain = domain | |
46 self.factory = factory | |
47 | |
48 self.protocol = protocol | |
49 self.connectFuncName = connectFuncName | |
50 self.connectFuncArgs = connectFuncArgs | |
51 self.connectFuncKwArgs = connectFuncKwArgs | |
52 | |
53 self.connector = None | |
54 self.servers = None | |
55 self.orderedServers = None # list of servers already used in this round | |
56 | |
57 def connect(self): | |
58 """Start connection to remote server.""" | |
59 self.factory.doStart() | |
60 self.factory.startedConnecting(self) | |
61 | |
62 if not self.servers: | |
63 if self.domain is None: | |
64 self.connectionFailed(error.DNSLookupError("Domain is not define
d.")) | |
65 return | |
66 d = client.lookupService('_%s._%s.%s' % (self.service, | |
67 self.protocol, | |
68 self.domain)) | |
69 d.addCallbacks(self._cbGotServers, self._ebGotServers) | |
70 d.addCallback(lambda x, self=self: self._reallyConnect()) | |
71 d.addErrback(self.connectionFailed) | |
72 elif self.connector is None: | |
73 self._reallyConnect() | |
74 else: | |
75 self.connector.connect() | |
76 | |
77 def _ebGotServers(self, failure): | |
78 failure.trap(DNSNameError) | |
79 | |
80 # Some DNS servers reply with NXDOMAIN when in fact there are | |
81 # just no SRV records for that domain. Act as if we just got an | |
82 # empty response and use fallback. | |
83 | |
84 self.servers = [] | |
85 self.orderedServers = [] | |
86 | |
87 def _cbGotServers(self, (answers, auth, add)): | |
88 if len(answers) == 1 and answers[0].type == dns.SRV \ | |
89 and answers[0].payload \ | |
90 and answers[0].payload.target == dns.Name('.'): | |
91 # decidedly not available | |
92 raise error.DNSLookupError("Service %s not available for domain %s." | |
93 % (repr(self.service), repr(self.domain))
) | |
94 | |
95 self.servers = [] | |
96 self.orderedServers = [] | |
97 for a in answers: | |
98 if a.type != dns.SRV or not a.payload: | |
99 continue | |
100 | |
101 self.orderedServers.append((a.payload.priority, a.payload.weight, | |
102 str(a.payload.target), a.payload.port)) | |
103 | |
104 def _serverCmp(self, a, b): | |
105 if a[0]!=b[0]: | |
106 return cmp(a[0], b[0]) | |
107 else: | |
108 return cmp(a[1], b[1]) | |
109 | |
110 def pickServer(self): | |
111 assert self.servers is not None | |
112 assert self.orderedServers is not None | |
113 | |
114 if not self.servers and not self.orderedServers: | |
115 # no SRV record, fall back.. | |
116 return self.domain, self.service | |
117 | |
118 if not self.servers and self.orderedServers: | |
119 # start new round | |
120 self.servers = self.orderedServers | |
121 self.orderedServers = [] | |
122 | |
123 assert self.servers | |
124 | |
125 self.servers.sort(self._serverCmp) | |
126 minPriority=self.servers[0][0] | |
127 | |
128 weightIndex = zip(xrange(len(self.servers)), [x[1] for x in self.servers | |
129 if x[0]==minPriority]) | |
130 weightSum = reduce(lambda x, y: (None, x[1]+y[1]), weightIndex, (None, 0
))[1] | |
131 rand = random.randint(0, weightSum) | |
132 | |
133 for index, weight in weightIndex: | |
134 weightSum -= weight | |
135 if weightSum <= 0: | |
136 chosen = self.servers[index] | |
137 del self.servers[index] | |
138 self.orderedServers.append(chosen) | |
139 | |
140 p, w, host, port = chosen | |
141 return host, port | |
142 | |
143 raise RuntimeError, 'Impossible %s pickServer result.' % self.__class__.
__name__ | |
144 | |
145 def _reallyConnect(self): | |
146 if self.stopAfterDNS: | |
147 self.stopAfterDNS=0 | |
148 return | |
149 | |
150 self.host, self.port = self.pickServer() | |
151 assert self.host is not None, 'Must have a host to connect to.' | |
152 assert self.port is not None, 'Must have a port to connect to.' | |
153 | |
154 connectFunc = getattr(self.reactor, self.connectFuncName) | |
155 self.connector=connectFunc( | |
156 self.host, self.port, | |
157 _SRVConnector_ClientFactoryWrapper(self, self.factory), | |
158 *self.connectFuncArgs, **self.connectFuncKwArgs) | |
159 | |
160 def stopConnecting(self): | |
161 """Stop attempting to connect.""" | |
162 if self.connector: | |
163 self.connector.stopConnecting() | |
164 else: | |
165 self.stopAfterDNS=1 | |
166 | |
167 def disconnect(self): | |
168 """Disconnect whatever our are state is.""" | |
169 if self.connector is not None: | |
170 self.connector.disconnect() | |
171 else: | |
172 self.stopConnecting() | |
173 | |
174 def getDestination(self): | |
175 assert self.connector | |
176 return self.connector.getDestination() | |
177 | |
178 def connectionFailed(self, reason): | |
179 self.factory.clientConnectionFailed(self, reason) | |
180 self.factory.doStop() | |
181 | |
182 def connectionLost(self, reason): | |
183 self.factory.clientConnectionLost(self, reason) | |
184 self.factory.doStop() | |
185 | |
OLD | NEW |