Index: net/tools/testserver/xmppserver_test.py |
diff --git a/net/tools/testserver/xmppserver_test.py b/net/tools/testserver/xmppserver_test.py |
index dd276e2d8db12bca14dff70462f0f423a34c66b7..338daec852390a39337c2997d1815f7aa6be44ec 100644 |
--- a/net/tools/testserver/xmppserver_test.py |
+++ b/net/tools/testserver/xmppserver_test.py |
@@ -164,11 +164,15 @@ class HandshakeTaskTest(unittest.TestCase): |
'foo', '', '', '') |
-class XmppConnectionTest(unittest.TestCase): |
+class FakeSocket(object): |
+ """A fake socket object used for testing. |
+ """ |
- def setUp(self): |
- self.connections = set() |
- self.data = [] |
+ def __init__(self): |
+ self._sent_data = [] |
+ |
+ def GetSentData(self): |
+ return self._sent_data |
# socket-like methods. |
def fileno(self): |
@@ -181,12 +185,19 @@ class XmppConnectionTest(unittest.TestCase): |
return ('', 0) |
def send(self, data): |
- self.data.append(data) |
+ self._sent_data.append(data) |
pass |
def close(self): |
pass |
+ |
+class XmppConnectionTest(unittest.TestCase): |
+ |
+ def setUp(self): |
+ self.connections = set() |
+ self.fake_socket = FakeSocket() |
+ |
# XmppConnection delegate methods. |
def OnXmppHandshakeDone(self, xmpp_connection): |
self.connections.add(xmpp_connection) |
@@ -201,27 +212,29 @@ class XmppConnectionTest(unittest.TestCase): |
def testBasic(self): |
socket_map = {} |
xmpp_connection = xmppserver.XmppConnection( |
- self, socket_map, self, ('', 0)) |
+ self.fake_socket, socket_map, self, ('', 0)) |
self.assertEqual(len(socket_map), 1) |
self.assertEqual(len(self.connections), 0) |
xmpp_connection.HandshakeDone(xmppserver.Jid('foo', 'bar')) |
self.assertEqual(len(socket_map), 1) |
self.assertEqual(len(self.connections), 1) |
+ sent_data = self.fake_socket.GetSentData() |
+ |
# Test subscription request. |
- self.assertEqual(len(self.data), 0) |
+ self.assertEqual(len(sent_data), 0) |
xmpp_connection.collect_incoming_data( |
'<iq><subscribe xmlns="google:push"></subscribe></iq>') |
- self.assertEqual(len(self.data), 1) |
+ self.assertEqual(len(sent_data), 1) |
# Test acks. |
xmpp_connection.collect_incoming_data('<iq type="result"/>') |
- self.assertEqual(len(self.data), 1) |
+ self.assertEqual(len(sent_data), 1) |
# Test notification. |
xmpp_connection.collect_incoming_data( |
'<message><push xmlns="google:push"/></message>') |
- self.assertEqual(len(self.data), 2) |
+ self.assertEqual(len(sent_data), 2) |
# Test unexpected stanza. |
def SendUnexpectedStanza(): |
@@ -240,35 +253,107 @@ class XmppConnectionTest(unittest.TestCase): |
self.assertEqual(len(socket_map), 0) |
self.assertEqual(len(self.connections), 0) |
-class XmppServerTest(unittest.TestCase): |
- # socket-like methods. |
- def fileno(self): |
- return 0 |
+class FakeXmppServer(xmppserver.XmppServer): |
+ """A fake XMPP server object used for testing. |
+ """ |
- def setblocking(self, int): |
- pass |
+ def __init__(self): |
+ self._socket_map = {} |
+ self._fake_sockets = set() |
+ self._next_jid_suffix = 1 |
+ xmppserver.XmppServer.__init__(self, self._socket_map, ('', 0)) |
- def getpeername(self): |
- return ('', 0) |
+ def GetSocketMap(self): |
+ return self._socket_map |
+ |
+ def GetFakeSockets(self): |
+ return self._fake_sockets |
+ |
+ def AddHandshakeCompletedConnection(self): |
+ """Creates a new XMPP connection and completes its handshake. |
+ """ |
+ xmpp_connection = self.handle_accept() |
+ jid = xmppserver.Jid('user%s' % self._next_jid_suffix, 'domain.com') |
+ self._next_jid_suffix += 1 |
+ xmpp_connection.HandshakeDone(jid) |
+ |
+ # XmppServer overrides. |
+ def accept(self): |
+ fake_socket = FakeSocket() |
+ self._fake_sockets.add(fake_socket) |
+ return (fake_socket, ('', 0)) |
def close(self): |
- pass |
+ self._fake_sockets.clear() |
+ xmppserver.XmppServer.close(self) |
- def testBasic(self): |
- class FakeXmppServer(xmppserver.XmppServer): |
- def accept(self2): |
- return (self, ('', 0)) |
- socket_map = {} |
- self.assertEqual(len(socket_map), 0) |
- xmpp_server = FakeXmppServer(socket_map, ('', 0)) |
+class XmppServerTest(unittest.TestCase): |
+ |
+ def setUp(self): |
+ self.xmpp_server = FakeXmppServer() |
+ |
+ def AssertSentDataLength(self, expected_length): |
+ for fake_socket in self.xmpp_server.GetFakeSockets(): |
+ self.assertEqual(len(fake_socket.GetSentData()), expected_length) |
+ |
+ def testBasic(self): |
+ socket_map = self.xmpp_server.GetSocketMap() |
self.assertEqual(len(socket_map), 1) |
- xmpp_server.handle_accept() |
+ self.xmpp_server.AddHandshakeCompletedConnection() |
self.assertEqual(len(socket_map), 2) |
- xmpp_server.close() |
+ self.xmpp_server.close() |
self.assertEqual(len(socket_map), 0) |
+ def testMakeNotification(self): |
+ notification = self.xmpp_server.MakeNotification('channel', 'data') |
+ expected_xml = ( |
+ '<message>' |
+ ' <push channel="channel" xmlns="google:push">' |
+ ' <data>%s</data>' |
+ ' </push>' |
+ '</message>' % base64.b64encode('data')) |
+ self.assertEqual(notification.toxml(), expected_xml) |
+ |
+ def testSendNotification(self): |
+ # Add a few connections. |
+ for _ in xrange(0, 7): |
+ self.xmpp_server.AddHandshakeCompletedConnection() |
+ |
+ self.assertEqual(len(self.xmpp_server.GetFakeSockets()), 7) |
+ |
+ self.AssertSentDataLength(0) |
+ self.xmpp_server.SendNotification('channel', 'data') |
+ self.AssertSentDataLength(1) |
+ |
+ def testEnableDisableNotifications(self): |
+ # Add a few connections. |
+ for _ in xrange(0, 5): |
+ self.xmpp_server.AddHandshakeCompletedConnection() |
+ |
+ self.assertEqual(len(self.xmpp_server.GetFakeSockets()), 5) |
+ |
+ self.AssertSentDataLength(0) |
+ self.xmpp_server.SendNotification('channel', 'data') |
+ self.AssertSentDataLength(1) |
+ |
+ self.xmpp_server.EnableNotifications() |
+ self.xmpp_server.SendNotification('channel', 'data') |
+ self.AssertSentDataLength(2) |
+ |
+ self.xmpp_server.DisableNotifications() |
+ self.xmpp_server.SendNotification('channel', 'data') |
+ self.AssertSentDataLength(2) |
+ |
+ self.xmpp_server.DisableNotifications() |
+ self.xmpp_server.SendNotification('channel', 'data') |
+ self.AssertSentDataLength(2) |
+ |
+ self.xmpp_server.EnableNotifications() |
+ self.xmpp_server.SendNotification('channel', 'data') |
+ self.AssertSentDataLength(3) |
+ |
if __name__ == '__main__': |
unittest.main() |