OLD | NEW |
| (Empty) |
1 #!/usr/bin/env python | |
2 # Copyright (c) 2011 The Chromium Authors. All rights reserved. | |
3 # Use of this source code is governed by a BSD-style license that can be | |
4 # found in the LICENSE file. | |
5 | |
6 """Tests exercising the various classes in xmppserver.py.""" | |
7 | |
8 import unittest | |
9 | |
10 import base64 | |
11 import xmppserver | |
12 | |
13 class XmlUtilsTest(unittest.TestCase): | |
14 | |
15 def testParseXml(self): | |
16 xml_text = """<foo xmlns=""><bar xmlns=""><baz/></bar></foo>""" | |
17 xml = xmppserver.ParseXml(xml_text) | |
18 self.assertEqual(xml.toxml(), xml_text) | |
19 | |
20 def testCloneXml(self): | |
21 xml = xmppserver.ParseXml('<foo/>') | |
22 xml_clone = xmppserver.CloneXml(xml) | |
23 xml_clone.setAttribute('bar', 'baz') | |
24 self.assertEqual(xml, xml) | |
25 self.assertEqual(xml_clone, xml_clone) | |
26 self.assertNotEqual(xml, xml_clone) | |
27 | |
28 def testCloneXmlUnlink(self): | |
29 xml_text = '<foo/>' | |
30 xml = xmppserver.ParseXml(xml_text) | |
31 xml_clone = xmppserver.CloneXml(xml) | |
32 xml.unlink() | |
33 self.assertEqual(xml.parentNode, None) | |
34 self.assertNotEqual(xml_clone.parentNode, None) | |
35 self.assertEqual(xml_clone.toxml(), xml_text) | |
36 | |
37 class StanzaParserTest(unittest.TestCase): | |
38 | |
39 def setUp(self): | |
40 self.stanzas = [] | |
41 | |
42 def FeedStanza(self, stanza): | |
43 # We can't append stanza directly because it is unlinked after | |
44 # this callback. | |
45 self.stanzas.append(stanza.toxml()) | |
46 | |
47 def testBasic(self): | |
48 parser = xmppserver.StanzaParser(self) | |
49 parser.FeedString('<foo') | |
50 self.assertEqual(len(self.stanzas), 0) | |
51 parser.FeedString('/><bar></bar>') | |
52 self.assertEqual(self.stanzas[0], '<foo/>') | |
53 self.assertEqual(self.stanzas[1], '<bar/>') | |
54 | |
55 def testStream(self): | |
56 parser = xmppserver.StanzaParser(self) | |
57 parser.FeedString('<stream') | |
58 self.assertEqual(len(self.stanzas), 0) | |
59 parser.FeedString(':stream foo="bar" xmlns:stream="baz">') | |
60 self.assertEqual(self.stanzas[0], | |
61 '<stream:stream foo="bar" xmlns:stream="baz"/>') | |
62 | |
63 def testNested(self): | |
64 parser = xmppserver.StanzaParser(self) | |
65 parser.FeedString('<foo') | |
66 self.assertEqual(len(self.stanzas), 0) | |
67 parser.FeedString(' bar="baz"') | |
68 parser.FeedString('><baz/><blah>meh</blah></foo>') | |
69 self.assertEqual(self.stanzas[0], | |
70 '<foo bar="baz"><baz/><blah>meh</blah></foo>') | |
71 | |
72 | |
73 class JidTest(unittest.TestCase): | |
74 | |
75 def testBasic(self): | |
76 jid = xmppserver.Jid('foo', 'bar.com') | |
77 self.assertEqual(str(jid), 'foo@bar.com') | |
78 | |
79 def testResource(self): | |
80 jid = xmppserver.Jid('foo', 'bar.com', 'resource') | |
81 self.assertEqual(str(jid), 'foo@bar.com/resource') | |
82 | |
83 def testGetBareJid(self): | |
84 jid = xmppserver.Jid('foo', 'bar.com', 'resource') | |
85 self.assertEqual(str(jid.GetBareJid()), 'foo@bar.com') | |
86 | |
87 | |
88 class IdGeneratorTest(unittest.TestCase): | |
89 | |
90 def testBasic(self): | |
91 id_generator = xmppserver.IdGenerator('foo') | |
92 for i in xrange(0, 100): | |
93 self.assertEqual('foo.%d' % i, id_generator.GetNextId()) | |
94 | |
95 | |
96 class HandshakeTaskTest(unittest.TestCase): | |
97 | |
98 def setUp(self): | |
99 self.Reset() | |
100 | |
101 def Reset(self): | |
102 self.data_received = 0 | |
103 self.handshake_done = False | |
104 self.jid = None | |
105 | |
106 def SendData(self, _): | |
107 self.data_received += 1 | |
108 | |
109 def SendStanza(self, _, unused=True): | |
110 self.data_received += 1 | |
111 | |
112 def HandshakeDone(self, jid): | |
113 self.handshake_done = True | |
114 self.jid = jid | |
115 | |
116 def DoHandshake(self, resource_prefix, resource, username, | |
117 initial_stream_domain, auth_domain, auth_stream_domain): | |
118 self.Reset() | |
119 handshake_task = ( | |
120 xmppserver.HandshakeTask(self, resource_prefix, True)) | |
121 stream_xml = xmppserver.ParseXml('<stream:stream xmlns:stream="foo"/>') | |
122 stream_xml.setAttribute('to', initial_stream_domain) | |
123 self.assertEqual(self.data_received, 0) | |
124 handshake_task.FeedStanza(stream_xml) | |
125 self.assertEqual(self.data_received, 2) | |
126 | |
127 if auth_domain: | |
128 username_domain = '%s@%s' % (username, auth_domain) | |
129 else: | |
130 username_domain = username | |
131 auth_string = base64.b64encode('\0%s\0bar' % username_domain) | |
132 auth_xml = xmppserver.ParseXml('<auth>%s</auth>'% auth_string) | |
133 handshake_task.FeedStanza(auth_xml) | |
134 self.assertEqual(self.data_received, 3) | |
135 | |
136 stream_xml = xmppserver.ParseXml('<stream:stream xmlns:stream="foo"/>') | |
137 stream_xml.setAttribute('to', auth_stream_domain) | |
138 handshake_task.FeedStanza(stream_xml) | |
139 self.assertEqual(self.data_received, 5) | |
140 | |
141 bind_xml = xmppserver.ParseXml( | |
142 '<iq type="set"><bind><resource>%s</resource></bind></iq>' % resource) | |
143 handshake_task.FeedStanza(bind_xml) | |
144 self.assertEqual(self.data_received, 6) | |
145 | |
146 self.assertFalse(self.handshake_done) | |
147 | |
148 session_xml = xmppserver.ParseXml( | |
149 '<iq type="set"><session></session></iq>') | |
150 handshake_task.FeedStanza(session_xml) | |
151 self.assertEqual(self.data_received, 7) | |
152 | |
153 self.assertTrue(self.handshake_done) | |
154 | |
155 self.assertEqual(self.jid.username, username) | |
156 self.assertEqual(self.jid.domain, | |
157 auth_stream_domain or auth_domain or | |
158 initial_stream_domain) | |
159 self.assertEqual(self.jid.resource, | |
160 '%s.%s' % (resource_prefix, resource)) | |
161 | |
162 handshake_task.FeedStanza('<ignored/>') | |
163 self.assertEqual(self.data_received, 7) | |
164 | |
165 def DoHandshakeUnauthenticated(self, resource_prefix, resource, username, | |
166 initial_stream_domain): | |
167 self.Reset() | |
168 handshake_task = ( | |
169 xmppserver.HandshakeTask(self, resource_prefix, False)) | |
170 stream_xml = xmppserver.ParseXml('<stream:stream xmlns:stream="foo"/>') | |
171 stream_xml.setAttribute('to', initial_stream_domain) | |
172 self.assertEqual(self.data_received, 0) | |
173 handshake_task.FeedStanza(stream_xml) | |
174 self.assertEqual(self.data_received, 2) | |
175 | |
176 self.assertFalse(self.handshake_done) | |
177 | |
178 auth_string = base64.b64encode('\0%s\0bar' % username) | |
179 auth_xml = xmppserver.ParseXml('<auth>%s</auth>'% auth_string) | |
180 handshake_task.FeedStanza(auth_xml) | |
181 self.assertEqual(self.data_received, 3) | |
182 | |
183 self.assertTrue(self.handshake_done) | |
184 | |
185 self.assertEqual(self.jid, None) | |
186 | |
187 handshake_task.FeedStanza('<ignored/>') | |
188 self.assertEqual(self.data_received, 3) | |
189 | |
190 def testBasic(self): | |
191 self.DoHandshake('resource_prefix', 'resource', | |
192 'foo', 'bar.com', 'baz.com', 'quux.com') | |
193 | |
194 def testDomainBehavior(self): | |
195 self.DoHandshake('resource_prefix', 'resource', | |
196 'foo', 'bar.com', 'baz.com', 'quux.com') | |
197 self.DoHandshake('resource_prefix', 'resource', | |
198 'foo', 'bar.com', 'baz.com', '') | |
199 self.DoHandshake('resource_prefix', 'resource', | |
200 'foo', 'bar.com', '', '') | |
201 self.DoHandshake('resource_prefix', 'resource', | |
202 'foo', '', '', '') | |
203 | |
204 def testBasicUnauthenticated(self): | |
205 self.DoHandshakeUnauthenticated('resource_prefix', 'resource', | |
206 'foo', 'bar.com') | |
207 | |
208 | |
209 class FakeSocket(object): | |
210 """A fake socket object used for testing. | |
211 """ | |
212 | |
213 def __init__(self): | |
214 self._sent_data = [] | |
215 | |
216 def GetSentData(self): | |
217 return self._sent_data | |
218 | |
219 # socket-like methods. | |
220 def fileno(self): | |
221 return 0 | |
222 | |
223 def setblocking(self, int): | |
224 pass | |
225 | |
226 def getpeername(self): | |
227 return ('', 0) | |
228 | |
229 def send(self, data): | |
230 self._sent_data.append(data) | |
231 pass | |
232 | |
233 def close(self): | |
234 pass | |
235 | |
236 | |
237 class XmppConnectionTest(unittest.TestCase): | |
238 | |
239 def setUp(self): | |
240 self.connections = set() | |
241 self.fake_socket = FakeSocket() | |
242 | |
243 # XmppConnection delegate methods. | |
244 def OnXmppHandshakeDone(self, xmpp_connection): | |
245 self.connections.add(xmpp_connection) | |
246 | |
247 def OnXmppConnectionClosed(self, xmpp_connection): | |
248 self.connections.discard(xmpp_connection) | |
249 | |
250 def ForwardNotification(self, unused_xmpp_connection, notification_stanza): | |
251 for connection in self.connections: | |
252 connection.ForwardNotification(notification_stanza) | |
253 | |
254 def testBasic(self): | |
255 socket_map = {} | |
256 xmpp_connection = xmppserver.XmppConnection( | |
257 self.fake_socket, socket_map, self, ('', 0), True) | |
258 self.assertEqual(len(socket_map), 1) | |
259 self.assertEqual(len(self.connections), 0) | |
260 xmpp_connection.HandshakeDone(xmppserver.Jid('foo', 'bar')) | |
261 self.assertEqual(len(socket_map), 1) | |
262 self.assertEqual(len(self.connections), 1) | |
263 | |
264 sent_data = self.fake_socket.GetSentData() | |
265 | |
266 # Test subscription request. | |
267 self.assertEqual(len(sent_data), 0) | |
268 xmpp_connection.collect_incoming_data( | |
269 '<iq><subscribe xmlns="google:push"></subscribe></iq>') | |
270 self.assertEqual(len(sent_data), 1) | |
271 | |
272 # Test acks. | |
273 xmpp_connection.collect_incoming_data('<iq type="result"/>') | |
274 self.assertEqual(len(sent_data), 1) | |
275 | |
276 # Test notification. | |
277 xmpp_connection.collect_incoming_data( | |
278 '<message><push xmlns="google:push"/></message>') | |
279 self.assertEqual(len(sent_data), 2) | |
280 | |
281 # Test unexpected stanza. | |
282 def SendUnexpectedStanza(): | |
283 xmpp_connection.collect_incoming_data('<foo/>') | |
284 self.assertRaises(xmppserver.UnexpectedXml, SendUnexpectedStanza) | |
285 | |
286 # Test unexpected notifier command. | |
287 def SendUnexpectedNotifierCommand(): | |
288 xmpp_connection.collect_incoming_data( | |
289 '<iq><foo xmlns="google:notifier"/></iq>') | |
290 self.assertRaises(xmppserver.UnexpectedXml, | |
291 SendUnexpectedNotifierCommand) | |
292 | |
293 # Test close. | |
294 xmpp_connection.close() | |
295 self.assertEqual(len(socket_map), 0) | |
296 self.assertEqual(len(self.connections), 0) | |
297 | |
298 def testBasicUnauthenticated(self): | |
299 socket_map = {} | |
300 xmpp_connection = xmppserver.XmppConnection( | |
301 self.fake_socket, socket_map, self, ('', 0), False) | |
302 self.assertEqual(len(socket_map), 1) | |
303 self.assertEqual(len(self.connections), 0) | |
304 xmpp_connection.HandshakeDone(None) | |
305 self.assertEqual(len(socket_map), 0) | |
306 self.assertEqual(len(self.connections), 0) | |
307 | |
308 # Test unexpected stanza. | |
309 def SendUnexpectedStanza(): | |
310 xmpp_connection.collect_incoming_data('<foo/>') | |
311 self.assertRaises(xmppserver.UnexpectedXml, SendUnexpectedStanza) | |
312 | |
313 # Test redundant close. | |
314 xmpp_connection.close() | |
315 self.assertEqual(len(socket_map), 0) | |
316 self.assertEqual(len(self.connections), 0) | |
317 | |
318 | |
319 class FakeXmppServer(xmppserver.XmppServer): | |
320 """A fake XMPP server object used for testing. | |
321 """ | |
322 | |
323 def __init__(self): | |
324 self._socket_map = {} | |
325 self._fake_sockets = set() | |
326 self._next_jid_suffix = 1 | |
327 xmppserver.XmppServer.__init__(self, self._socket_map, ('', 0)) | |
328 | |
329 def GetSocketMap(self): | |
330 return self._socket_map | |
331 | |
332 def GetFakeSockets(self): | |
333 return self._fake_sockets | |
334 | |
335 def AddHandshakeCompletedConnection(self): | |
336 """Creates a new XMPP connection and completes its handshake. | |
337 """ | |
338 xmpp_connection = self.handle_accept() | |
339 jid = xmppserver.Jid('user%s' % self._next_jid_suffix, 'domain.com') | |
340 self._next_jid_suffix += 1 | |
341 xmpp_connection.HandshakeDone(jid) | |
342 | |
343 # XmppServer overrides. | |
344 def accept(self): | |
345 fake_socket = FakeSocket() | |
346 self._fake_sockets.add(fake_socket) | |
347 return (fake_socket, ('', 0)) | |
348 | |
349 def close(self): | |
350 self._fake_sockets.clear() | |
351 xmppserver.XmppServer.close(self) | |
352 | |
353 | |
354 class XmppServerTest(unittest.TestCase): | |
355 | |
356 def setUp(self): | |
357 self.xmpp_server = FakeXmppServer() | |
358 | |
359 def AssertSentDataLength(self, expected_length): | |
360 for fake_socket in self.xmpp_server.GetFakeSockets(): | |
361 self.assertEqual(len(fake_socket.GetSentData()), expected_length) | |
362 | |
363 def testBasic(self): | |
364 socket_map = self.xmpp_server.GetSocketMap() | |
365 self.assertEqual(len(socket_map), 1) | |
366 self.xmpp_server.AddHandshakeCompletedConnection() | |
367 self.assertEqual(len(socket_map), 2) | |
368 self.xmpp_server.close() | |
369 self.assertEqual(len(socket_map), 0) | |
370 | |
371 def testMakeNotification(self): | |
372 notification = self.xmpp_server.MakeNotification('channel', 'data') | |
373 expected_xml = ( | |
374 '<message>' | |
375 ' <push channel="channel" xmlns="google:push">' | |
376 ' <data>%s</data>' | |
377 ' </push>' | |
378 '</message>' % base64.b64encode('data')) | |
379 self.assertEqual(notification.toxml(), expected_xml) | |
380 | |
381 def testSendNotification(self): | |
382 # Add a few connections. | |
383 for _ in xrange(0, 7): | |
384 self.xmpp_server.AddHandshakeCompletedConnection() | |
385 | |
386 self.assertEqual(len(self.xmpp_server.GetFakeSockets()), 7) | |
387 | |
388 self.AssertSentDataLength(0) | |
389 self.xmpp_server.SendNotification('channel', 'data') | |
390 self.AssertSentDataLength(1) | |
391 | |
392 def testEnableDisableNotifications(self): | |
393 # Add a few connections. | |
394 for _ in xrange(0, 5): | |
395 self.xmpp_server.AddHandshakeCompletedConnection() | |
396 | |
397 self.assertEqual(len(self.xmpp_server.GetFakeSockets()), 5) | |
398 | |
399 self.AssertSentDataLength(0) | |
400 self.xmpp_server.SendNotification('channel', 'data') | |
401 self.AssertSentDataLength(1) | |
402 | |
403 self.xmpp_server.EnableNotifications() | |
404 self.xmpp_server.SendNotification('channel', 'data') | |
405 self.AssertSentDataLength(2) | |
406 | |
407 self.xmpp_server.DisableNotifications() | |
408 self.xmpp_server.SendNotification('channel', 'data') | |
409 self.AssertSentDataLength(2) | |
410 | |
411 self.xmpp_server.DisableNotifications() | |
412 self.xmpp_server.SendNotification('channel', 'data') | |
413 self.AssertSentDataLength(2) | |
414 | |
415 self.xmpp_server.EnableNotifications() | |
416 self.xmpp_server.SendNotification('channel', 'data') | |
417 self.AssertSentDataLength(3) | |
418 | |
419 | |
420 if __name__ == '__main__': | |
421 unittest.main() | |
OLD | NEW |