Index: tests/rietveld_test.py |
diff --git a/tests/rietveld_test.py b/tests/rietveld_test.py |
index 7bcb9bcbcadc946c3b06660efa67ddb6260c836f..880d17e70637bdb740c499595bcd13f2860a021d 100755 |
--- a/tests/rietveld_test.py |
+++ b/tests/rietveld_test.py |
@@ -5,19 +5,24 @@ |
"""Unit tests for rietveld.py.""" |
+import httplib |
import logging |
import os |
import socket |
import ssl |
+import StringIO |
import sys |
+import tempfile |
import time |
import traceback |
import unittest |
+import urllib2 |
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) |
from testing_support.patches_data import GIT, RAW |
from testing_support import auto_stub |
+from third_party import httplib2 |
import patch |
import rietveld |
@@ -490,6 +495,104 @@ class DefaultTimeoutTest(auto_stub.TestCase): |
self.rietveld.post('/api/1234', [('key', 'data')]) |
self.assertNotEqual(self.sleep_time, 0) |
+ |
+class OAuthRpcServerTest(auto_stub.TestCase): |
+ def setUp(self): |
+ super(OAuthRpcServerTest, self).setUp() |
+ self.rpc_server = rietveld.OAuthRpcServer( |
+ 'http://www.example.com', 'foo', 'bar') |
+ |
+ def set_mock_response(self, status): |
+ def mock_http_request(*args, **kwargs): |
+ return (httplib2.Response({'status': status}), 'body') |
+ self.mock(self.rpc_server._http, 'request', mock_http_request) |
+ |
+ def test_404(self): |
+ self.set_mock_response(404) |
+ with self.assertRaises(urllib2.HTTPError) as ctx: |
+ self.rpc_server.Send('/foo') |
+ self.assertEquals(404, ctx.exception.code) |
+ |
+ def test_200(self): |
+ self.set_mock_response(200) |
+ ret = self.rpc_server.Send('/foo') |
+ self.assertEquals('body', ret) |
+ |
+ |
+class RietveldOAuthRpcServerTest(auto_stub.TestCase): |
+ def setUp(self): |
+ super(RietveldOAuthRpcServerTest, self).setUp() |
+ with tempfile.NamedTemporaryFile() as private_key_file: |
+ self.rietveld = rietveld.JwtOAuth2Rietveld( |
+ 'http://www.example.com', 'foo', private_key_file.name, maxtries=2) |
+ |
+ self.mock(time, 'sleep', lambda duration: None) |
+ |
+ def test_retries_500(self): |
+ urls = [] |
+ def mock_http_request(url, *args, **kwargs): |
+ urls.append(url) |
+ return (httplib2.Response({'status': 500}), 'body') |
+ self.mock(self.rietveld.rpc_server._http, 'request', mock_http_request) |
+ |
+ with self.assertRaises(urllib2.HTTPError) as ctx: |
+ self.rietveld.get('/foo') |
+ self.assertEquals(500, ctx.exception.code) |
+ |
+ self.assertEqual(2, len(urls)) # maxtries was 2 |
+ self.assertEqual(['https://www.example.com/foo'] * 2, urls) |
+ |
+ def test_does_not_retry_404(self): |
+ urls = [] |
+ def mock_http_request(url, *args, **kwargs): |
+ urls.append(url) |
+ return (httplib2.Response({'status': 404}), 'body') |
+ self.mock(self.rietveld.rpc_server._http, 'request', mock_http_request) |
+ |
+ with self.assertRaises(urllib2.HTTPError) as ctx: |
+ self.rietveld.get('/foo') |
+ self.assertEquals(404, ctx.exception.code) |
+ |
+ self.assertEqual(1, len(urls)) # doesn't retry |
+ |
+ def test_retries_404_when_requested(self): |
+ urls = [] |
+ def mock_http_request(url, *args, **kwargs): |
+ urls.append(url) |
+ return (httplib2.Response({'status': 404}), 'body') |
+ self.mock(self.rietveld.rpc_server._http, 'request', mock_http_request) |
+ |
+ with self.assertRaises(urllib2.HTTPError) as ctx: |
+ self.rietveld.get('/foo', retry_on_404=True) |
+ self.assertEquals(404, ctx.exception.code) |
+ |
+ self.assertEqual(2, len(urls)) # maxtries was 2 |
+ |
+ def test_socket_timeout(self): |
+ urls = [] |
+ def mock_http_request(url, *args, **kwargs): |
+ urls.append(url) |
+ raise socket.error('timed out') |
+ self.mock(self.rietveld.rpc_server._http, 'request', mock_http_request) |
+ |
+ with self.assertRaises(socket.error): |
+ self.rietveld.get('/foo') |
+ |
+ self.assertEqual(2, len(urls)) # maxtries was 2 |
+ |
+ def test_other_socket_error(self): |
+ urls = [] |
+ def mock_http_request(url, *args, **kwargs): |
+ urls.append(url) |
+ raise socket.error('something else') |
+ self.mock(self.rietveld.rpc_server._http, 'request', mock_http_request) |
+ |
+ with self.assertRaises(socket.error): |
+ self.rietveld.get('/foo') |
+ |
+ self.assertEqual(1, len(urls)) |
+ |
+ |
if __name__ == '__main__': |
logging.basicConfig(level=[ |
logging.ERROR, logging.INFO, logging.DEBUG][min(2, sys.argv.count('-v'))]) |