OLD | NEW |
| (Empty) |
1 # test-case-name: twisted.names.test.test_dns | |
2 # Copyright (c) 2001-2007 Twisted Matrix Laboratories. | |
3 # See LICENSE for details. | |
4 | |
5 """ | |
6 Tests for twisted.names.dns. | |
7 """ | |
8 | |
9 try: | |
10 from cStringIO import StringIO | |
11 except ImportError: | |
12 from StringIO import StringIO | |
13 | |
14 import struct | |
15 | |
16 from twisted.internet import address, task | |
17 from twisted.internet.error import CannotListenError | |
18 from twisted.trial import unittest | |
19 from twisted.names import dns | |
20 | |
21 from twisted.test import proto_helpers | |
22 | |
23 | |
24 | |
25 class RoundtripDNSTestCase(unittest.TestCase): | |
26 """Encoding and then decoding various objects.""" | |
27 | |
28 names = ["example.org", "go-away.fish.tv", "23strikesback.net"] | |
29 | |
30 def testName(self): | |
31 for n in self.names: | |
32 # encode the name | |
33 f = StringIO() | |
34 dns.Name(n).encode(f) | |
35 | |
36 # decode the name | |
37 f.seek(0, 0) | |
38 result = dns.Name() | |
39 result.decode(f) | |
40 self.assertEquals(result.name, n) | |
41 | |
42 def testQuery(self): | |
43 for n in self.names: | |
44 for dnstype in range(1, 17): | |
45 for dnscls in range(1, 5): | |
46 # encode the query | |
47 f = StringIO() | |
48 dns.Query(n, dnstype, dnscls).encode(f) | |
49 | |
50 # decode the result | |
51 f.seek(0, 0) | |
52 result = dns.Query() | |
53 result.decode(f) | |
54 self.assertEquals(result.name.name, n) | |
55 self.assertEquals(result.type, dnstype) | |
56 self.assertEquals(result.cls, dnscls) | |
57 | |
58 def testRR(self): | |
59 # encode the RR | |
60 f = StringIO() | |
61 dns.RRHeader("test.org", 3, 4, 17).encode(f) | |
62 | |
63 # decode the result | |
64 f.seek(0, 0) | |
65 result = dns.RRHeader() | |
66 result.decode(f) | |
67 self.assertEquals(str(result.name), "test.org") | |
68 self.assertEquals(result.type, 3) | |
69 self.assertEquals(result.cls, 4) | |
70 self.assertEquals(result.ttl, 17) | |
71 | |
72 | |
73 def testResources(self): | |
74 names = ( | |
75 "this.are.test.name", | |
76 "will.compress.will.this.will.name.will.hopefully", | |
77 "test.CASE.preSErVatIOn.YeAH", | |
78 "a.s.h.o.r.t.c.a.s.e.t.o.t.e.s.t", | |
79 "singleton" | |
80 ) | |
81 for s in names: | |
82 f = StringIO() | |
83 dns.SimpleRecord(s).encode(f) | |
84 f.seek(0, 0) | |
85 result = dns.SimpleRecord() | |
86 result.decode(f) | |
87 self.assertEquals(str(result.name), s) | |
88 | |
89 def testHashable(self): | |
90 records = [ | |
91 dns.Record_NS, dns.Record_MD, dns.Record_MF, dns.Record_CNAME, | |
92 dns.Record_MB, dns.Record_MG, dns.Record_MR, dns.Record_PTR, | |
93 dns.Record_DNAME, dns.Record_A, dns.Record_SOA, dns.Record_NULL, | |
94 dns.Record_WKS, dns.Record_SRV, dns.Record_AFSDB, dns.Record_RP, | |
95 dns.Record_HINFO, dns.Record_MINFO, dns.Record_MX, dns.Record_TXT, | |
96 dns.Record_AAAA, dns.Record_A6 | |
97 ] | |
98 | |
99 for k in records: | |
100 k1, k2 = k(), k() | |
101 hk1 = hash(k1) | |
102 hk2 = hash(k2) | |
103 self.assertEquals(hk1, hk2, "%s != %s (for %s)" % (hk1,hk2,k)) | |
104 | |
105 | |
106 | |
107 class MessageTestCase(unittest.TestCase): | |
108 def testEmptyMessage(self): | |
109 """ | |
110 Test that a message which has been truncated causes an EOFError to | |
111 be raised when it is parsed. | |
112 """ | |
113 msg = dns.Message() | |
114 self.assertRaises(EOFError, msg.fromStr, '') | |
115 | |
116 | |
117 def testEmptyQuery(self): | |
118 """ | |
119 Test that bytes representing an empty query message can be decoded | |
120 as such. | |
121 """ | |
122 msg = dns.Message() | |
123 msg.fromStr( | |
124 '\x01\x00' # Message ID | |
125 '\x00' # answer bit, opCode nibble, auth bit, trunc bit, recursive b
it | |
126 '\x00' # recursion bit, empty bit, empty bit, empty bit, response co
de nibble | |
127 '\x00\x00' # number of queries | |
128 '\x00\x00' # number of answers | |
129 '\x00\x00' # number of authorities | |
130 '\x00\x00' # number of additionals | |
131 ) | |
132 self.assertEquals(msg.id, 256) | |
133 self.failIf(msg.answer, "Message was not supposed to be an answer.") | |
134 self.assertEquals(msg.opCode, dns.OP_QUERY) | |
135 self.failIf(msg.auth, "Message was not supposed to be authoritative.") | |
136 self.failIf(msg.trunc, "Message was not supposed to be truncated.") | |
137 self.assertEquals(msg.queries, []) | |
138 self.assertEquals(msg.answers, []) | |
139 self.assertEquals(msg.authority, []) | |
140 self.assertEquals(msg.additional, []) | |
141 | |
142 | |
143 def testNULL(self): | |
144 bytes = ''.join([chr(i) for i in range(256)]) | |
145 rec = dns.Record_NULL(bytes) | |
146 rr = dns.RRHeader('testname', dns.NULL, payload=rec) | |
147 msg1 = dns.Message() | |
148 msg1.answers.append(rr) | |
149 s = StringIO() | |
150 msg1.encode(s) | |
151 s.seek(0, 0) | |
152 msg2 = dns.Message() | |
153 msg2.decode(s) | |
154 | |
155 self.failUnless(isinstance(msg2.answers[0].payload, dns.Record_NULL)) | |
156 self.assertEquals(msg2.answers[0].payload.payload, bytes) | |
157 | |
158 | |
159 | |
160 class TestController(object): | |
161 """ | |
162 Pretend to be a DNS query processor for a DNSDatagramProtocol. | |
163 | |
164 @ivar messages: the list of received messages. | |
165 @type messages: C{list} of (msg, protocol, address) | |
166 """ | |
167 | |
168 def __init__(self): | |
169 """ | |
170 Initialize the controller: create a list of messages. | |
171 """ | |
172 self.messages = [] | |
173 | |
174 | |
175 def messageReceived(self, msg, proto, addr): | |
176 """ | |
177 Save the message so that it can be checked during the tests. | |
178 """ | |
179 self.messages.append((msg, proto, addr)) | |
180 | |
181 | |
182 | |
183 class DatagramProtocolTestCase(unittest.TestCase): | |
184 """ | |
185 Test various aspects of L{dns.DNSDatagramProtocol}. | |
186 """ | |
187 | |
188 def setUp(self): | |
189 """ | |
190 Create a L{dns.DNSDatagramProtocol} with a deterministic clock. | |
191 """ | |
192 self.clock = task.Clock() | |
193 self.controller = TestController() | |
194 self.proto = dns.DNSDatagramProtocol(self.controller) | |
195 transport = proto_helpers.FakeDatagramTransport() | |
196 self.proto.makeConnection(transport) | |
197 self.proto.callLater = self.clock.callLater | |
198 | |
199 | |
200 def test_truncatedPacket(self): | |
201 """ | |
202 Test that when a short datagram is received, datagramReceived does | |
203 not raise an exception while processing it. | |
204 """ | |
205 self.proto.datagramReceived('', | |
206 address.IPv4Address('UDP', '127.0.0.1', 12345)) | |
207 self.assertEquals(self.controller.messages, []) | |
208 | |
209 | |
210 def test_simpleQuery(self): | |
211 """ | |
212 Test content received after a query. | |
213 """ | |
214 d = self.proto.query(('127.0.0.1', 21345), [dns.Query('foo')]) | |
215 self.assertEquals(len(self.proto.liveMessages.keys()), 1) | |
216 m = dns.Message() | |
217 m.id = self.proto.liveMessages.items()[0][0] | |
218 m.answers = [dns.RRHeader(payload=dns.Record_A(address='1.2.3.4'))] | |
219 called = False | |
220 def cb(result): | |
221 self.assertEquals(result.answers[0].payload.dottedQuad(), '1.2.3.4') | |
222 d.addCallback(cb) | |
223 self.proto.datagramReceived(m.toStr(), ('127.0.0.1', 21345)) | |
224 return d | |
225 | |
226 | |
227 def test_queryTimeout(self): | |
228 """ | |
229 Test that query timeouts after some seconds. | |
230 """ | |
231 d = self.proto.query(('127.0.0.1', 21345), [dns.Query('foo')]) | |
232 self.assertEquals(len(self.proto.liveMessages), 1) | |
233 self.clock.advance(10) | |
234 self.assertFailure(d, dns.DNSQueryTimeoutError) | |
235 self.assertEquals(len(self.proto.liveMessages), 0) | |
236 return d | |
237 | |
238 | |
239 def test_writeError(self): | |
240 """ | |
241 Exceptions raised by the transport's write method should be turned into | |
242 C{Failure}s passed to errbacks of the C{Deferred} returned by | |
243 L{DNSDatagramProtocol.query}. | |
244 """ | |
245 def writeError(message, addr): | |
246 raise RuntimeError("bar") | |
247 self.proto.transport.write = writeError | |
248 | |
249 d = self.proto.query(('127.0.0.1', 21345), [dns.Query('foo')]) | |
250 return self.assertFailure(d, RuntimeError) | |
251 | |
252 | |
253 def test_listenError(self): | |
254 """ | |
255 Exception L{CannotListenError} raised by C{listenUDP} should be turned | |
256 into a C{Failure} passed to errback of the C{Deferred} returned by | |
257 L{DNSDatagramProtocol.query}. | |
258 """ | |
259 def startListeningError(): | |
260 raise CannotListenError(None, None, None) | |
261 self.proto.startListening = startListeningError | |
262 # Clean up transport so that the protocol calls startListening again | |
263 self.proto.transport = None | |
264 | |
265 d = self.proto.query(('127.0.0.1', 21345), [dns.Query('foo')]) | |
266 return self.assertFailure(d, CannotListenError) | |
267 | |
268 | |
269 | |
270 class TestTCPController(TestController): | |
271 """ | |
272 Pretend to be a DNS query processor for a DNSProtocol. | |
273 """ | |
274 def connectionMade(self, proto): | |
275 pass | |
276 | |
277 | |
278 | |
279 class DNSProtocolTestCase(unittest.TestCase): | |
280 """ | |
281 Test various aspects of L{dns.DNSProtocol}. | |
282 """ | |
283 | |
284 def setUp(self): | |
285 """ | |
286 Create a L{dns.DNSProtocol} with a deterministic clock. | |
287 """ | |
288 self.clock = task.Clock() | |
289 controller = TestTCPController() | |
290 self.proto = dns.DNSProtocol(controller) | |
291 self.proto.makeConnection(proto_helpers.StringTransport()) | |
292 self.proto.callLater = self.clock.callLater | |
293 | |
294 | |
295 def test_queryTimeout(self): | |
296 """ | |
297 Test that query timeouts after some seconds. | |
298 """ | |
299 d = self.proto.query([dns.Query('foo')]) | |
300 self.assertEquals(len(self.proto.liveMessages), 1) | |
301 self.clock.advance(60) | |
302 self.assertFailure(d, dns.DNSQueryTimeoutError) | |
303 self.assertEquals(len(self.proto.liveMessages), 0) | |
304 return d | |
305 | |
306 | |
307 def test_simpleQuery(self): | |
308 """ | |
309 Test content received after a query. | |
310 """ | |
311 d = self.proto.query([dns.Query('foo')]) | |
312 self.assertEquals(len(self.proto.liveMessages.keys()), 1) | |
313 m = dns.Message() | |
314 m.id = self.proto.liveMessages.items()[0][0] | |
315 m.answers = [dns.RRHeader(payload=dns.Record_A(address='1.2.3.4'))] | |
316 called = False | |
317 def cb(result): | |
318 self.assertEquals(result.answers[0].payload.dottedQuad(), '1.2.3.4') | |
319 d.addCallback(cb) | |
320 s = m.toStr() | |
321 s = struct.pack('!H', len(s)) + s | |
322 self.proto.dataReceived(s) | |
323 return d | |
324 | |
325 | |
326 def test_writeError(self): | |
327 """ | |
328 Exceptions raised by the transport's write method should be turned into | |
329 C{Failure}s passed to errbacks of the C{Deferred} returned by | |
330 L{DNSProtocol.query}. | |
331 """ | |
332 def writeError(message): | |
333 raise RuntimeError("bar") | |
334 self.proto.transport.write = writeError | |
335 | |
336 d = self.proto.query([dns.Query('foo')]) | |
337 return self.assertFailure(d, RuntimeError) | |
338 | |
339 | |
340 | |
OLD | NEW |