| Index: tests/rietveld_test.py
|
| diff --git a/tests/rietveld_test.py b/tests/rietveld_test.py
|
| index 7b8c5baeca590e59c71420bf1c8a14817a061a99..e136deb411987482b3757760af1c49b56aebaf71 100755
|
| --- a/tests/rietveld_test.py
|
| +++ b/tests/rietveld_test.py
|
| @@ -7,6 +7,7 @@
|
|
|
| import logging
|
| import os
|
| +import socket
|
| import ssl
|
| import sys
|
| import time
|
| @@ -434,23 +435,21 @@ class ProbeException(Exception):
|
| self.value = value
|
|
|
|
|
| -def MockSend(request_path, payload=None,
|
| - content_type="application/octet-stream",
|
| - timeout=None,
|
| - extra_headers=None,
|
| - **kwargs):
|
| +def MockSend(*args, **kwargs):
|
| """Mock upload.py's Send() to probe the timeout value"""
|
| - raise ProbeException(timeout)
|
| + raise ProbeException(kwargs['timeout'])
|
|
|
| -def MockSendTimeout(request_path, payload=None,
|
| - content_type="application/octet-stream",
|
| - timeout=None,
|
| - extra_headers=None,
|
| - **kwargs):
|
| +
|
| +def MockSendTimeout(*args, **kwargs):
|
| """Mock upload.py's Send() to raise SSLError"""
|
| raise ssl.SSLError('The read operation timed out')
|
|
|
|
|
| +def MockSocketConnectTimeout(*args, **kwargs):
|
| + """Mock upload.py's Send() to raise socket.timeout"""
|
| + raise socket.timeout('timed out')
|
| +
|
| +
|
| class DefaultTimeoutTest(auto_stub.TestCase):
|
| TESTED_CLASS = rietveld.Rietveld
|
|
|
| @@ -480,11 +479,17 @@ class DefaultTimeoutTest(auto_stub.TestCase):
|
| def test_ssl_timeout_post(self):
|
| self.mock(self.rietveld.rpc_server, 'Send', MockSendTimeout)
|
| self.mock(time, 'sleep', self.MockSleep)
|
| - self.sleep_time = 0
|
| with self.assertRaises(ssl.SSLError):
|
| self.rietveld.post('/api/1234', [('key', 'data')])
|
| self.assertNotEqual(self.sleep_time, 0)
|
|
|
| + def test_socket_connect_timeout_post(self):
|
| + self.mock(self.rietveld.rpc_server, 'Send', MockSocketConnectTimeout)
|
| + self.mock(time, 'sleep', self.MockSleep)
|
| + with self.assertRaises(socket.timeout):
|
| + self.rietveld.post('/api/1234', [('key', 'data')])
|
| + self.assertNotEqual(self.sleep_time, 0)
|
| +
|
| if __name__ == '__main__':
|
| logging.basicConfig(level=[
|
| logging.ERROR, logging.INFO, logging.DEBUG][min(2, sys.argv.count('-v'))])
|
|
|