Index: tools/telemetry/third_party/gsutilz/third_party/protorpc/protorpc/remote_test.py |
diff --git a/tools/telemetry/third_party/gsutilz/third_party/protorpc/protorpc/remote_test.py b/tools/telemetry/third_party/gsutilz/third_party/protorpc/protorpc/remote_test.py |
new file mode 100755 |
index 0000000000000000000000000000000000000000..354763326913dd739af2536d69e617c6ab71daff |
--- /dev/null |
+++ b/tools/telemetry/third_party/gsutilz/third_party/protorpc/protorpc/remote_test.py |
@@ -0,0 +1,926 @@ |
+#!/usr/bin/env python |
+# |
+# Copyright 2010 Google Inc. |
+# |
+# Licensed under the Apache License, Version 2.0 (the "License"); |
+# you may not use this file except in compliance with the License. |
+# You may obtain a copy of the License at |
+# |
+# http://www.apache.org/licenses/LICENSE-2.0 |
+# |
+# Unless required by applicable law or agreed to in writing, software |
+# distributed under the License is distributed on an "AS IS" BASIS, |
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
+# See the License for the specific language governing permissions and |
+# limitations under the License. |
+# |
+ |
+"""Tests for protorpc.remote.""" |
+ |
+__author__ = 'rafek@google.com (Rafe Kaplan)' |
+ |
+ |
+import sys |
+import types |
+import unittest |
+from wsgiref import headers |
+ |
+from protorpc import descriptor |
+from protorpc import message_types |
+from protorpc import messages |
+from protorpc import protobuf |
+from protorpc import protojson |
+from protorpc import remote |
+from protorpc import test_util |
+from protorpc import transport |
+ |
+import mox |
+ |
+ |
+class ModuleInterfaceTest(test_util.ModuleInterfaceTest, |
+ test_util.TestCase): |
+ |
+ MODULE = remote |
+ |
+ |
+class Request(messages.Message): |
+ """Test request message.""" |
+ |
+ value = messages.StringField(1) |
+ |
+ |
+class Response(messages.Message): |
+ """Test response message.""" |
+ |
+ value = messages.StringField(1) |
+ |
+ |
+class MyService(remote.Service): |
+ |
+ @remote.method(Request, Response) |
+ def remote_method(self, request): |
+ response = Response() |
+ response.value = request.value |
+ return response |
+ |
+ |
+class SimpleRequest(messages.Message): |
+ """Simple request message type used for tests.""" |
+ |
+ param1 = messages.StringField(1) |
+ param2 = messages.StringField(2) |
+ |
+ |
+class SimpleResponse(messages.Message): |
+ """Simple response message type used for tests.""" |
+ |
+ |
+class BasicService(remote.Service): |
+ """A basic service with decorated remote method.""" |
+ |
+ def __init__(self): |
+ self.request_ids = [] |
+ |
+ @remote.method(SimpleRequest, SimpleResponse) |
+ def remote_method(self, request): |
+ self.request_ids.append(id(request)) |
+ return SimpleResponse() |
+ |
+ |
+class RpcErrorTest(test_util.TestCase): |
+ |
+ def testFromStatus(self): |
+ for state in remote.RpcState: |
+ exception = remote.RpcError.from_state |
+ self.assertEquals(remote.ServerError, |
+ remote.RpcError.from_state('SERVER_ERROR')) |
+ |
+ |
+class ApplicationErrorTest(test_util.TestCase): |
+ |
+ def testErrorCode(self): |
+ self.assertEquals('blam', |
+ remote.ApplicationError('an error', 'blam').error_name) |
+ |
+ def testStr(self): |
+ self.assertEquals('an error', str(remote.ApplicationError('an error', 1))) |
+ |
+ def testRepr(self): |
+ self.assertEquals("ApplicationError('an error', 1)", |
+ repr(remote.ApplicationError('an error', 1))) |
+ |
+ self.assertEquals("ApplicationError('an error')", |
+ repr(remote.ApplicationError('an error'))) |
+ |
+ |
+class MethodTest(test_util.TestCase): |
+ """Test remote method decorator.""" |
+ |
+ def testMethod(self): |
+ """Test use of remote decorator.""" |
+ self.assertEquals(SimpleRequest, |
+ BasicService.remote_method.remote.request_type) |
+ self.assertEquals(SimpleResponse, |
+ BasicService.remote_method.remote.response_type) |
+ self.assertTrue(isinstance(BasicService.remote_method.remote.method, |
+ types.FunctionType)) |
+ |
+ def testMethodMessageResolution(self): |
+ """Test use of remote decorator to resolve message types by name.""" |
+ class OtherService(remote.Service): |
+ |
+ @remote.method('SimpleRequest', 'SimpleResponse') |
+ def remote_method(self, request): |
+ pass |
+ |
+ self.assertEquals(SimpleRequest, |
+ OtherService.remote_method.remote.request_type) |
+ self.assertEquals(SimpleResponse, |
+ OtherService.remote_method.remote.response_type) |
+ |
+ def testMethodMessageResolution_NotFound(self): |
+ """Test failure to find message types.""" |
+ class OtherService(remote.Service): |
+ |
+ @remote.method('NoSuchRequest', 'NoSuchResponse') |
+ def remote_method(self, request): |
+ pass |
+ |
+ self.assertRaisesWithRegexpMatch( |
+ messages.DefinitionNotFoundError, |
+ 'Could not find definition for NoSuchRequest', |
+ getattr, |
+ OtherService.remote_method.remote, |
+ 'request_type') |
+ |
+ self.assertRaisesWithRegexpMatch( |
+ messages.DefinitionNotFoundError, |
+ 'Could not find definition for NoSuchResponse', |
+ getattr, |
+ OtherService.remote_method.remote, |
+ 'response_type') |
+ |
+ def testInvocation(self): |
+ """Test that invocation passes request through properly.""" |
+ service = BasicService() |
+ request = SimpleRequest() |
+ self.assertEquals(SimpleResponse(), service.remote_method(request)) |
+ self.assertEquals([id(request)], service.request_ids) |
+ |
+ def testInvocation_WrongRequestType(self): |
+ """Wrong request type passed to remote method.""" |
+ service = BasicService() |
+ |
+ self.assertRaises(remote.RequestError, |
+ service.remote_method, |
+ 'wrong') |
+ |
+ self.assertRaises(remote.RequestError, |
+ service.remote_method, |
+ None) |
+ |
+ self.assertRaises(remote.RequestError, |
+ service.remote_method, |
+ SimpleResponse()) |
+ |
+ def testInvocation_WrongResponseType(self): |
+ """Wrong response type returned from remote method.""" |
+ |
+ class AnotherService(object): |
+ |
+ @remote.method(SimpleRequest, SimpleResponse) |
+ def remote_method(self, unused_request): |
+ return self.return_this |
+ |
+ service = AnotherService() |
+ |
+ service.return_this = 'wrong' |
+ self.assertRaises(remote.ServerError, |
+ service.remote_method, |
+ SimpleRequest()) |
+ service.return_this = None |
+ self.assertRaises(remote.ServerError, |
+ service.remote_method, |
+ SimpleRequest()) |
+ service.return_this = SimpleRequest() |
+ self.assertRaises(remote.ServerError, |
+ service.remote_method, |
+ SimpleRequest()) |
+ |
+ def testBadRequestType(self): |
+ """Test bad request types used in remote definition.""" |
+ |
+ for request_type in (None, 1020, messages.Message, str): |
+ |
+ def declare(): |
+ class BadService(object): |
+ |
+ @remote.method(request_type, SimpleResponse) |
+ def remote_method(self, request): |
+ pass |
+ |
+ self.assertRaises(TypeError, declare) |
+ |
+ def testBadResponseType(self): |
+ """Test bad response types used in remote definition.""" |
+ |
+ for response_type in (None, 1020, messages.Message, str): |
+ |
+ def declare(): |
+ class BadService(object): |
+ |
+ @remote.method(SimpleRequest, response_type) |
+ def remote_method(self, request): |
+ pass |
+ |
+ self.assertRaises(TypeError, declare) |
+ |
+ |
+class GetRemoteMethodTest(test_util.TestCase): |
+ """Test for is_remote_method.""" |
+ |
+ def testGetRemoteMethod(self): |
+ """Test valid remote method detection.""" |
+ |
+ class Service(object): |
+ |
+ @remote.method(Request, Response) |
+ def remote_method(self, request): |
+ pass |
+ |
+ self.assertEquals(Service.remote_method.remote, |
+ remote.get_remote_method_info(Service.remote_method)) |
+ self.assertTrue(Service.remote_method.remote, |
+ remote.get_remote_method_info(Service().remote_method)) |
+ |
+ def testGetNotRemoteMethod(self): |
+ """Test positive result on a remote method.""" |
+ |
+ class NotService(object): |
+ |
+ def not_remote_method(self, request): |
+ pass |
+ |
+ def fn(self): |
+ pass |
+ |
+ class NotReallyRemote(object): |
+ """Test negative result on many bad values for remote methods.""" |
+ |
+ def not_really(self, request): |
+ pass |
+ |
+ not_really.remote = 'something else' |
+ |
+ for not_remote in [NotService.not_remote_method, |
+ NotService().not_remote_method, |
+ NotReallyRemote.not_really, |
+ NotReallyRemote().not_really, |
+ None, |
+ 1, |
+ 'a string', |
+ fn]: |
+ self.assertEquals(None, remote.get_remote_method_info(not_remote)) |
+ |
+ |
+class RequestStateTest(test_util.TestCase): |
+ """Test request state.""" |
+ |
+ STATE_CLASS = remote.RequestState |
+ |
+ def testConstructor(self): |
+ """Test constructor.""" |
+ state = self.STATE_CLASS(remote_host='remote-host', |
+ remote_address='remote-address', |
+ server_host='server-host', |
+ server_port=10) |
+ self.assertEquals('remote-host', state.remote_host) |
+ self.assertEquals('remote-address', state.remote_address) |
+ self.assertEquals('server-host', state.server_host) |
+ self.assertEquals(10, state.server_port) |
+ |
+ state = self.STATE_CLASS() |
+ self.assertEquals(None, state.remote_host) |
+ self.assertEquals(None, state.remote_address) |
+ self.assertEquals(None, state.server_host) |
+ self.assertEquals(None, state.server_port) |
+ |
+ def testConstructorError(self): |
+ """Test unexpected keyword argument.""" |
+ self.assertRaises(TypeError, |
+ self.STATE_CLASS, |
+ x=10) |
+ |
+ def testRepr(self): |
+ """Test string representation.""" |
+ self.assertEquals('<%s>' % self.STATE_CLASS.__name__, |
+ repr(self.STATE_CLASS())) |
+ self.assertEquals("<%s remote_host='abc'>" % self.STATE_CLASS.__name__, |
+ repr(self.STATE_CLASS(remote_host='abc'))) |
+ self.assertEquals("<%s remote_host='abc' " |
+ "remote_address='def'>" % self.STATE_CLASS.__name__, |
+ repr(self.STATE_CLASS(remote_host='abc', |
+ remote_address='def'))) |
+ self.assertEquals("<%s remote_host='abc' " |
+ "remote_address='def' " |
+ "server_host='ghi'>" % self.STATE_CLASS.__name__, |
+ repr(self.STATE_CLASS(remote_host='abc', |
+ remote_address='def', |
+ server_host='ghi'))) |
+ self.assertEquals("<%s remote_host='abc' " |
+ "remote_address='def' " |
+ "server_host='ghi' " |
+ 'server_port=102>' % self.STATE_CLASS.__name__, |
+ repr(self.STATE_CLASS(remote_host='abc', |
+ remote_address='def', |
+ server_host='ghi', |
+ server_port=102))) |
+ |
+ |
+class HttpRequestStateTest(RequestStateTest): |
+ |
+ STATE_CLASS = remote.HttpRequestState |
+ |
+ def testHttpMethod(self): |
+ state = remote.HttpRequestState(http_method='GET') |
+ self.assertEquals('GET', state.http_method) |
+ |
+ def testHttpMethod(self): |
+ state = remote.HttpRequestState(service_path='/bar') |
+ self.assertEquals('/bar', state.service_path) |
+ |
+ def testHeadersList(self): |
+ state = remote.HttpRequestState( |
+ headers=[('a', 'b'), ('c', 'd'), ('c', 'e')]) |
+ |
+ self.assertEquals(['a', 'c', 'c'], list(state.headers.keys())) |
+ self.assertEquals(['b'], state.headers.get_all('a')) |
+ self.assertEquals(['d', 'e'], state.headers.get_all('c')) |
+ |
+ def testHeadersDict(self): |
+ state = remote.HttpRequestState(headers={'a': 'b', 'c': ['d', 'e']}) |
+ |
+ self.assertEquals(['a', 'c', 'c'], sorted(state.headers.keys())) |
+ self.assertEquals(['b'], state.headers.get_all('a')) |
+ self.assertEquals(['d', 'e'], state.headers.get_all('c')) |
+ |
+ def testRepr(self): |
+ super(HttpRequestStateTest, self).testRepr() |
+ |
+ self.assertEquals("<%s remote_host='abc' " |
+ "remote_address='def' " |
+ "server_host='ghi' " |
+ 'server_port=102 ' |
+ "http_method='POST' " |
+ "service_path='/bar' " |
+ "headers=[('a', 'b'), ('c', 'd')]>" % |
+ self.STATE_CLASS.__name__, |
+ repr(self.STATE_CLASS(remote_host='abc', |
+ remote_address='def', |
+ server_host='ghi', |
+ server_port=102, |
+ http_method='POST', |
+ service_path='/bar', |
+ headers={'a': 'b', 'c': 'd'}, |
+ ))) |
+ |
+ |
+class ServiceTest(test_util.TestCase): |
+ """Test Service class.""" |
+ |
+ def testServiceBase_AllRemoteMethods(self): |
+ """Test that service base class has no remote methods.""" |
+ self.assertEquals({}, remote.Service.all_remote_methods()) |
+ |
+ def testAllRemoteMethods(self): |
+ """Test all_remote_methods with properly Service subclass.""" |
+ self.assertEquals({'remote_method': MyService.remote_method}, |
+ MyService.all_remote_methods()) |
+ |
+ def testAllRemoteMethods_SubClass(self): |
+ """Test all_remote_methods on a sub-class of a service.""" |
+ class SubClass(MyService): |
+ |
+ @remote.method(Request, Response) |
+ def sub_class_method(self, request): |
+ pass |
+ |
+ self.assertEquals({'remote_method': SubClass.remote_method, |
+ 'sub_class_method': SubClass.sub_class_method, |
+ }, |
+ SubClass.all_remote_methods()) |
+ |
+ def testOverrideMethod(self): |
+ """Test that trying to override a remote method with remote decorator.""" |
+ class SubClass(MyService): |
+ |
+ def remote_method(self, request): |
+ response = super(SubClass, self).remote_method(request) |
+ response.value = '(%s)' % response.value |
+ return response |
+ |
+ self.assertEquals({'remote_method': SubClass.remote_method, |
+ }, |
+ SubClass.all_remote_methods()) |
+ |
+ instance = SubClass() |
+ self.assertEquals('(Hello)', |
+ instance.remote_method(Request(value='Hello')).value) |
+ self.assertEquals(Request, SubClass.remote_method.remote.request_type) |
+ self.assertEquals(Response, SubClass.remote_method.remote.response_type) |
+ |
+ def testOverrideMethodWithRemote(self): |
+ """Test trying to override a remote method with remote decorator.""" |
+ def do_override(): |
+ class SubClass(MyService): |
+ |
+ @remote.method(Request, Response) |
+ def remote_method(self, request): |
+ pass |
+ |
+ self.assertRaisesWithRegexpMatch(remote.ServiceDefinitionError, |
+ 'Do not use method decorator when ' |
+ 'overloading remote method remote_method ' |
+ 'on service SubClass', |
+ do_override) |
+ |
+ def testOverrideMethodWithInvalidValue(self): |
+ """Test trying to override a remote method with remote decorator.""" |
+ def do_override(bad_value): |
+ class SubClass(MyService): |
+ |
+ remote_method = bad_value |
+ |
+ for bad_value in [None, 1, 'string', {}]: |
+ self.assertRaisesWithRegexpMatch(remote.ServiceDefinitionError, |
+ 'Must override remote_method in ' |
+ 'SubClass with a method', |
+ do_override, bad_value) |
+ |
+ def testCallingRemoteMethod(self): |
+ """Test invoking a remote method.""" |
+ expected = Response() |
+ expected.value = 'what was passed in' |
+ |
+ request = Request() |
+ request.value = 'what was passed in' |
+ |
+ service = MyService() |
+ self.assertEquals(expected, service.remote_method(request)) |
+ |
+ def testFactory(self): |
+ """Test using factory to pass in state.""" |
+ class StatefulService(remote.Service): |
+ |
+ def __init__(self, a, b, c=None): |
+ self.a = a |
+ self.b = b |
+ self.c = c |
+ |
+ state = [1, 2, 3] |
+ |
+ factory = StatefulService.new_factory(1, state) |
+ |
+ module_name = ServiceTest.__module__ |
+ pattern = ('Creates new instances of service StatefulService.\n\n' |
+ 'Returns:\n' |
+ ' New instance of %s.StatefulService.' % module_name) |
+ self.assertEqual(pattern, factory.__doc__) |
+ self.assertEquals('StatefulService_service_factory', factory.__name__) |
+ self.assertEquals(StatefulService, factory.service_class) |
+ |
+ service = factory() |
+ self.assertEquals(1, service.a) |
+ self.assertEquals(id(state), id(service.b)) |
+ self.assertEquals(None, service.c) |
+ |
+ factory = StatefulService.new_factory(2, b=3, c=4) |
+ service = factory() |
+ self.assertEquals(2, service.a) |
+ self.assertEquals(3, service.b) |
+ self.assertEquals(4, service.c) |
+ |
+ def testFactoryError(self): |
+ """Test misusing a factory.""" |
+ # Passing positional argument that is not accepted by class. |
+ self.assertRaises(TypeError, remote.Service.new_factory(1)) |
+ |
+ # Passing keyword argument that is not accepted by class. |
+ self.assertRaises(TypeError, remote.Service.new_factory(x=1)) |
+ |
+ class StatefulService(remote.Service): |
+ |
+ def __init__(self, a): |
+ pass |
+ |
+ # Missing required parameter. |
+ self.assertRaises(TypeError, StatefulService.new_factory()) |
+ |
+ def testDefinitionName(self): |
+ """Test getting service definition name.""" |
+ class TheService(remote.Service): |
+ pass |
+ |
+ module_name = test_util.get_module_name(ServiceTest) |
+ self.assertEqual(TheService.definition_name(), |
+ '%s.TheService' % module_name) |
+ self.assertTrue(TheService.outer_definition_name(), |
+ module_name) |
+ self.assertTrue(TheService.definition_package(), |
+ module_name) |
+ |
+ def testDefinitionNameWithPackage(self): |
+ """Test getting service definition name when package defined.""" |
+ global package |
+ package = 'my.package' |
+ try: |
+ class TheService(remote.Service): |
+ pass |
+ |
+ self.assertEquals('my.package.TheService', TheService.definition_name()) |
+ self.assertEquals('my.package', TheService.outer_definition_name()) |
+ self.assertEquals('my.package', TheService.definition_package()) |
+ finally: |
+ del package |
+ |
+ def testDefinitionNameWithNoModule(self): |
+ """Test getting service definition name when package defined.""" |
+ module = sys.modules[__name__] |
+ try: |
+ del sys.modules[__name__] |
+ class TheService(remote.Service): |
+ pass |
+ |
+ self.assertEquals('TheService', TheService.definition_name()) |
+ self.assertEquals(None, TheService.outer_definition_name()) |
+ self.assertEquals(None, TheService.definition_package()) |
+ finally: |
+ sys.modules[__name__] = module |
+ |
+ |
+class StubTest(test_util.TestCase): |
+ |
+ def setUp(self): |
+ self.mox = mox.Mox() |
+ self.transport = self.mox.CreateMockAnything() |
+ |
+ def testDefinitionName(self): |
+ self.assertEquals(BasicService.definition_name(), |
+ BasicService.Stub.definition_name()) |
+ self.assertEquals(BasicService.outer_definition_name(), |
+ BasicService.Stub.outer_definition_name()) |
+ self.assertEquals(BasicService.definition_package(), |
+ BasicService.Stub.definition_package()) |
+ |
+ def testRemoteMethods(self): |
+ self.assertEquals(BasicService.all_remote_methods(), |
+ BasicService.Stub.all_remote_methods()) |
+ |
+ def testSync_WithRequest(self): |
+ stub = BasicService.Stub(self.transport) |
+ |
+ request = SimpleRequest() |
+ request.param1 = 'val1' |
+ request.param2 = 'val2' |
+ response = SimpleResponse() |
+ |
+ rpc = transport.Rpc(request) |
+ rpc.set_response(response) |
+ self.transport.send_rpc(BasicService.remote_method.remote, |
+ request).AndReturn(rpc) |
+ |
+ self.mox.ReplayAll() |
+ |
+ self.assertEquals(SimpleResponse(), stub.remote_method(request)) |
+ |
+ self.mox.VerifyAll() |
+ |
+ def testSync_WithKwargs(self): |
+ stub = BasicService.Stub(self.transport) |
+ |
+ |
+ request = SimpleRequest() |
+ request.param1 = 'val1' |
+ request.param2 = 'val2' |
+ response = SimpleResponse() |
+ |
+ rpc = transport.Rpc(request) |
+ rpc.set_response(response) |
+ self.transport.send_rpc(BasicService.remote_method.remote, |
+ request).AndReturn(rpc) |
+ |
+ self.mox.ReplayAll() |
+ |
+ self.assertEquals(SimpleResponse(), stub.remote_method(param1='val1', |
+ param2='val2')) |
+ |
+ self.mox.VerifyAll() |
+ |
+ def testAsync_WithRequest(self): |
+ stub = BasicService.Stub(self.transport) |
+ |
+ request = SimpleRequest() |
+ request.param1 = 'val1' |
+ request.param2 = 'val2' |
+ response = SimpleResponse() |
+ |
+ rpc = transport.Rpc(request) |
+ |
+ self.transport.send_rpc(BasicService.remote_method.remote, |
+ request).AndReturn(rpc) |
+ |
+ self.mox.ReplayAll() |
+ |
+ self.assertEquals(rpc, stub.async.remote_method(request)) |
+ |
+ self.mox.VerifyAll() |
+ |
+ def testAsync_WithKwargs(self): |
+ stub = BasicService.Stub(self.transport) |
+ |
+ request = SimpleRequest() |
+ request.param1 = 'val1' |
+ request.param2 = 'val2' |
+ response = SimpleResponse() |
+ |
+ rpc = transport.Rpc(request) |
+ |
+ self.transport.send_rpc(BasicService.remote_method.remote, |
+ request).AndReturn(rpc) |
+ |
+ self.mox.ReplayAll() |
+ |
+ self.assertEquals(rpc, stub.async.remote_method(param1='val1', |
+ param2='val2')) |
+ |
+ self.mox.VerifyAll() |
+ |
+ def testAsync_WithRequestAndKwargs(self): |
+ stub = BasicService.Stub(self.transport) |
+ |
+ request = SimpleRequest() |
+ request.param1 = 'val1' |
+ request.param2 = 'val2' |
+ response = SimpleResponse() |
+ |
+ self.mox.ReplayAll() |
+ |
+ self.assertRaisesWithRegexpMatch( |
+ TypeError, |
+ r'May not provide both args and kwargs', |
+ stub.async.remote_method, |
+ request, |
+ param1='val1', |
+ param2='val2') |
+ |
+ self.mox.VerifyAll() |
+ |
+ def testAsync_WithTooManyPositionals(self): |
+ stub = BasicService.Stub(self.transport) |
+ |
+ request = SimpleRequest() |
+ request.param1 = 'val1' |
+ request.param2 = 'val2' |
+ response = SimpleResponse() |
+ |
+ self.mox.ReplayAll() |
+ |
+ self.assertRaisesWithRegexpMatch( |
+ TypeError, |
+ r'remote_method\(\) takes at most 2 positional arguments \(3 given\)', |
+ stub.async.remote_method, |
+ request, 'another value') |
+ |
+ self.mox.VerifyAll() |
+ |
+ |
+class IsErrorStatusTest(test_util.TestCase): |
+ |
+ def testIsError(self): |
+ for state in (s for s in remote.RpcState if s > remote.RpcState.RUNNING): |
+ status = remote.RpcStatus(state=state) |
+ self.assertTrue(remote.is_error_status(status)) |
+ |
+ def testIsNotError(self): |
+ for state in (s for s in remote.RpcState if s <= remote.RpcState.RUNNING): |
+ status = remote.RpcStatus(state=state) |
+ self.assertFalse(remote.is_error_status(status)) |
+ |
+ def testStateNone(self): |
+ self.assertRaises(messages.ValidationError, |
+ remote.is_error_status, remote.RpcStatus()) |
+ |
+ |
+class CheckRpcStatusTest(test_util.TestCase): |
+ |
+ def testStateNone(self): |
+ self.assertRaises(messages.ValidationError, |
+ remote.check_rpc_status, remote.RpcStatus()) |
+ |
+ def testNoError(self): |
+ for state in (remote.RpcState.OK, remote.RpcState.RUNNING): |
+ remote.check_rpc_status(remote.RpcStatus(state=state)) |
+ |
+ def testErrorState(self): |
+ status = remote.RpcStatus(state=remote.RpcState.REQUEST_ERROR, |
+ error_message='a request error') |
+ self.assertRaisesWithRegexpMatch(remote.RequestError, |
+ 'a request error', |
+ remote.check_rpc_status, status) |
+ |
+ def testApplicationErrorState(self): |
+ status = remote.RpcStatus(state=remote.RpcState.APPLICATION_ERROR, |
+ error_message='an application error', |
+ error_name='blam') |
+ try: |
+ remote.check_rpc_status(status) |
+ self.fail('Should have raised application error.') |
+ except remote.ApplicationError as err: |
+ self.assertEquals('an application error', str(err)) |
+ self.assertEquals('blam', err.error_name) |
+ |
+ |
+class ProtocolConfigTest(test_util.TestCase): |
+ |
+ def testConstructor(self): |
+ config = remote.ProtocolConfig( |
+ protojson, |
+ 'proto1', |
+ 'application/X-Json', |
+ iter(['text/Json', 'text/JavaScript'])) |
+ self.assertEquals(protojson, config.protocol) |
+ self.assertEquals('proto1', config.name) |
+ self.assertEquals('application/x-json', config.default_content_type) |
+ self.assertEquals(('text/json', 'text/javascript'), |
+ config.alternate_content_types) |
+ self.assertEquals(('application/x-json', 'text/json', 'text/javascript'), |
+ config.content_types) |
+ |
+ def testConstructorDefaults(self): |
+ config = remote.ProtocolConfig(protojson, 'proto2') |
+ self.assertEquals(protojson, config.protocol) |
+ self.assertEquals('proto2', config.name) |
+ self.assertEquals('application/json', config.default_content_type) |
+ self.assertEquals(('application/x-javascript', |
+ 'text/javascript', |
+ 'text/x-javascript', |
+ 'text/x-json', |
+ 'text/json'), |
+ config.alternate_content_types) |
+ self.assertEquals(('application/json', |
+ 'application/x-javascript', |
+ 'text/javascript', |
+ 'text/x-javascript', |
+ 'text/x-json', |
+ 'text/json'), config.content_types) |
+ |
+ def testEmptyAlternativeTypes(self): |
+ config = remote.ProtocolConfig(protojson, 'proto2', |
+ alternative_content_types=()) |
+ self.assertEquals(protojson, config.protocol) |
+ self.assertEquals('proto2', config.name) |
+ self.assertEquals('application/json', config.default_content_type) |
+ self.assertEquals((), config.alternate_content_types) |
+ self.assertEquals(('application/json',), config.content_types) |
+ |
+ def testDuplicateContentTypes(self): |
+ self.assertRaises(remote.ServiceConfigurationError, |
+ remote.ProtocolConfig, |
+ protojson, |
+ 'json', |
+ 'text/plain', |
+ ('text/plain',)) |
+ |
+ self.assertRaises(remote.ServiceConfigurationError, |
+ remote.ProtocolConfig, |
+ protojson, |
+ 'json', |
+ 'text/plain', |
+ ('text/html', 'text/html')) |
+ |
+ def testEncodeMessage(self): |
+ config = remote.ProtocolConfig(protojson, 'proto2') |
+ encoded_message = config.encode_message( |
+ remote.RpcStatus(state=remote.RpcState.SERVER_ERROR, |
+ error_message='bad error')) |
+ |
+ # Convert back to a dictionary from JSON. |
+ dict_message = protojson.json.loads(encoded_message) |
+ self.assertEquals({'state': 'SERVER_ERROR', 'error_message': 'bad error'}, |
+ dict_message) |
+ |
+ def testDecodeMessage(self): |
+ config = remote.ProtocolConfig(protojson, 'proto2') |
+ self.assertEquals( |
+ remote.RpcStatus(state=remote.RpcState.SERVER_ERROR, |
+ error_message="bad error"), |
+ config.decode_message( |
+ remote.RpcStatus, |
+ '{"state": "SERVER_ERROR", "error_message": "bad error"}')) |
+ |
+ |
+class ProtocolsTest(test_util.TestCase): |
+ |
+ def setUp(self): |
+ self.protocols = remote.Protocols() |
+ |
+ def testEmpty(self): |
+ self.assertEquals((), self.protocols.names) |
+ self.assertEquals((), self.protocols.content_types) |
+ |
+ def testAddProtocolAllDefaults(self): |
+ self.protocols.add_protocol(protojson, 'json') |
+ self.assertEquals(('json',), self.protocols.names) |
+ self.assertEquals(('application/json', |
+ 'application/x-javascript', |
+ 'text/javascript', |
+ 'text/json', |
+ 'text/x-javascript', |
+ 'text/x-json'), |
+ self.protocols.content_types) |
+ |
+ def testAddProtocolNoDefaultAlternatives(self): |
+ class Protocol(object): |
+ CONTENT_TYPE = 'text/plain' |
+ |
+ self.protocols.add_protocol(Protocol, 'text') |
+ self.assertEquals(('text',), self.protocols.names) |
+ self.assertEquals(('text/plain',), self.protocols.content_types) |
+ |
+ def testAddProtocolOverrideDefaults(self): |
+ self.protocols.add_protocol(protojson, 'json', |
+ default_content_type='text/blar', |
+ alternative_content_types=('text/blam', |
+ 'text/blim')) |
+ self.assertEquals(('json',), self.protocols.names) |
+ self.assertEquals(('text/blam', 'text/blar', 'text/blim'), |
+ self.protocols.content_types) |
+ |
+ def testLookupByName(self): |
+ self.protocols.add_protocol(protojson, 'json') |
+ self.protocols.add_protocol(protojson, 'json2', |
+ default_content_type='text/plain', |
+ alternative_content_types=()) |
+ |
+ self.assertEquals('json', self.protocols.lookup_by_name('JsOn').name) |
+ self.assertEquals('json2', self.protocols.lookup_by_name('Json2').name) |
+ |
+ def testLookupByContentType(self): |
+ self.protocols.add_protocol(protojson, 'json') |
+ self.protocols.add_protocol(protojson, 'json2', |
+ default_content_type='text/plain', |
+ alternative_content_types=()) |
+ |
+ self.assertEquals( |
+ 'json', |
+ self.protocols.lookup_by_content_type('AppliCation/Json').name) |
+ |
+ self.assertEquals( |
+ 'json', |
+ self.protocols.lookup_by_content_type('text/x-Json').name) |
+ |
+ self.assertEquals( |
+ 'json2', |
+ self.protocols.lookup_by_content_type('text/Plain').name) |
+ |
+ def testNewDefault(self): |
+ protocols = remote.Protocols.new_default() |
+ self.assertEquals(('protobuf', 'protojson'), protocols.names) |
+ |
+ protobuf_protocol = protocols.lookup_by_name('protobuf') |
+ self.assertEquals(protobuf, protobuf_protocol.protocol) |
+ |
+ protojson_protocol = protocols.lookup_by_name('protojson') |
+ self.assertEquals(protojson.ProtoJson.get_default(), |
+ protojson_protocol.protocol) |
+ |
+ def testGetDefaultProtocols(self): |
+ protocols = remote.Protocols.get_default() |
+ self.assertEquals(('protobuf', 'protojson'), protocols.names) |
+ |
+ protobuf_protocol = protocols.lookup_by_name('protobuf') |
+ self.assertEquals(protobuf, protobuf_protocol.protocol) |
+ |
+ protojson_protocol = protocols.lookup_by_name('protojson') |
+ self.assertEquals(protojson.ProtoJson.get_default(), |
+ protojson_protocol.protocol) |
+ |
+ self.assertTrue(protocols is remote.Protocols.get_default()) |
+ |
+ def testSetDefaultProtocols(self): |
+ protocols = remote.Protocols() |
+ remote.Protocols.set_default(protocols) |
+ self.assertTrue(protocols is remote.Protocols.get_default()) |
+ |
+ def testSetDefaultWithoutProtocols(self): |
+ self.assertRaises(TypeError, remote.Protocols.set_default, None) |
+ self.assertRaises(TypeError, remote.Protocols.set_default, 'hi protocols') |
+ self.assertRaises(TypeError, remote.Protocols.set_default, {}) |
+ |
+ |
+def main(): |
+ unittest.main() |
+ |
+ |
+if __name__ == '__main__': |
+ main() |