Chromium Code Reviews
chromiumcodereview-hr@appspot.gserviceaccount.com (chromiumcodereview-hr) | Please choose your nickname with Settings | Help | Chromium Project | Gerrit Changes | Sign out
(2)

Unified Diff: third_party/gsutil/third_party/protorpc/protorpc/remote_test.py

Issue 1377933002: [catapult] - Copy Telemetry's gsutilz over to third_party. (Closed) Base URL: https://github.com/catapult-project/catapult.git@master
Patch Set: Rename to gsutil. Created 5 years, 3 months ago
Use n/p to move between diff chunks; N/P to move between comments. Draft comments are only viewable by you.
Jump to:
View side-by-side diff with in-line comments
Download patch
Index: third_party/gsutil/third_party/protorpc/protorpc/remote_test.py
diff --git a/third_party/gsutil/third_party/protorpc/protorpc/remote_test.py b/third_party/gsutil/third_party/protorpc/protorpc/remote_test.py
new file mode 100755
index 0000000000000000000000000000000000000000..354763326913dd739af2536d69e617c6ab71daff
--- /dev/null
+++ b/third_party/gsutil/third_party/protorpc/protorpc/remote_test.py
@@ -0,0 +1,926 @@
+#!/usr/bin/env python
+#
+# Copyright 2010 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Tests for protorpc.remote."""
+
+__author__ = 'rafek@google.com (Rafe Kaplan)'
+
+
+import sys
+import types
+import unittest
+from wsgiref import headers
+
+from protorpc import descriptor
+from protorpc import message_types
+from protorpc import messages
+from protorpc import protobuf
+from protorpc import protojson
+from protorpc import remote
+from protorpc import test_util
+from protorpc import transport
+
+import mox
+
+
+class ModuleInterfaceTest(test_util.ModuleInterfaceTest,
+ test_util.TestCase):
+
+ MODULE = remote
+
+
+class Request(messages.Message):
+ """Test request message."""
+
+ value = messages.StringField(1)
+
+
+class Response(messages.Message):
+ """Test response message."""
+
+ value = messages.StringField(1)
+
+
+class MyService(remote.Service):
+
+ @remote.method(Request, Response)
+ def remote_method(self, request):
+ response = Response()
+ response.value = request.value
+ return response
+
+
+class SimpleRequest(messages.Message):
+ """Simple request message type used for tests."""
+
+ param1 = messages.StringField(1)
+ param2 = messages.StringField(2)
+
+
+class SimpleResponse(messages.Message):
+ """Simple response message type used for tests."""
+
+
+class BasicService(remote.Service):
+ """A basic service with decorated remote method."""
+
+ def __init__(self):
+ self.request_ids = []
+
+ @remote.method(SimpleRequest, SimpleResponse)
+ def remote_method(self, request):
+ self.request_ids.append(id(request))
+ return SimpleResponse()
+
+
+class RpcErrorTest(test_util.TestCase):
+
+ def testFromStatus(self):
+ for state in remote.RpcState:
+ exception = remote.RpcError.from_state
+ self.assertEquals(remote.ServerError,
+ remote.RpcError.from_state('SERVER_ERROR'))
+
+
+class ApplicationErrorTest(test_util.TestCase):
+
+ def testErrorCode(self):
+ self.assertEquals('blam',
+ remote.ApplicationError('an error', 'blam').error_name)
+
+ def testStr(self):
+ self.assertEquals('an error', str(remote.ApplicationError('an error', 1)))
+
+ def testRepr(self):
+ self.assertEquals("ApplicationError('an error', 1)",
+ repr(remote.ApplicationError('an error', 1)))
+
+ self.assertEquals("ApplicationError('an error')",
+ repr(remote.ApplicationError('an error')))
+
+
+class MethodTest(test_util.TestCase):
+ """Test remote method decorator."""
+
+ def testMethod(self):
+ """Test use of remote decorator."""
+ self.assertEquals(SimpleRequest,
+ BasicService.remote_method.remote.request_type)
+ self.assertEquals(SimpleResponse,
+ BasicService.remote_method.remote.response_type)
+ self.assertTrue(isinstance(BasicService.remote_method.remote.method,
+ types.FunctionType))
+
+ def testMethodMessageResolution(self):
+ """Test use of remote decorator to resolve message types by name."""
+ class OtherService(remote.Service):
+
+ @remote.method('SimpleRequest', 'SimpleResponse')
+ def remote_method(self, request):
+ pass
+
+ self.assertEquals(SimpleRequest,
+ OtherService.remote_method.remote.request_type)
+ self.assertEquals(SimpleResponse,
+ OtherService.remote_method.remote.response_type)
+
+ def testMethodMessageResolution_NotFound(self):
+ """Test failure to find message types."""
+ class OtherService(remote.Service):
+
+ @remote.method('NoSuchRequest', 'NoSuchResponse')
+ def remote_method(self, request):
+ pass
+
+ self.assertRaisesWithRegexpMatch(
+ messages.DefinitionNotFoundError,
+ 'Could not find definition for NoSuchRequest',
+ getattr,
+ OtherService.remote_method.remote,
+ 'request_type')
+
+ self.assertRaisesWithRegexpMatch(
+ messages.DefinitionNotFoundError,
+ 'Could not find definition for NoSuchResponse',
+ getattr,
+ OtherService.remote_method.remote,
+ 'response_type')
+
+ def testInvocation(self):
+ """Test that invocation passes request through properly."""
+ service = BasicService()
+ request = SimpleRequest()
+ self.assertEquals(SimpleResponse(), service.remote_method(request))
+ self.assertEquals([id(request)], service.request_ids)
+
+ def testInvocation_WrongRequestType(self):
+ """Wrong request type passed to remote method."""
+ service = BasicService()
+
+ self.assertRaises(remote.RequestError,
+ service.remote_method,
+ 'wrong')
+
+ self.assertRaises(remote.RequestError,
+ service.remote_method,
+ None)
+
+ self.assertRaises(remote.RequestError,
+ service.remote_method,
+ SimpleResponse())
+
+ def testInvocation_WrongResponseType(self):
+ """Wrong response type returned from remote method."""
+
+ class AnotherService(object):
+
+ @remote.method(SimpleRequest, SimpleResponse)
+ def remote_method(self, unused_request):
+ return self.return_this
+
+ service = AnotherService()
+
+ service.return_this = 'wrong'
+ self.assertRaises(remote.ServerError,
+ service.remote_method,
+ SimpleRequest())
+ service.return_this = None
+ self.assertRaises(remote.ServerError,
+ service.remote_method,
+ SimpleRequest())
+ service.return_this = SimpleRequest()
+ self.assertRaises(remote.ServerError,
+ service.remote_method,
+ SimpleRequest())
+
+ def testBadRequestType(self):
+ """Test bad request types used in remote definition."""
+
+ for request_type in (None, 1020, messages.Message, str):
+
+ def declare():
+ class BadService(object):
+
+ @remote.method(request_type, SimpleResponse)
+ def remote_method(self, request):
+ pass
+
+ self.assertRaises(TypeError, declare)
+
+ def testBadResponseType(self):
+ """Test bad response types used in remote definition."""
+
+ for response_type in (None, 1020, messages.Message, str):
+
+ def declare():
+ class BadService(object):
+
+ @remote.method(SimpleRequest, response_type)
+ def remote_method(self, request):
+ pass
+
+ self.assertRaises(TypeError, declare)
+
+
+class GetRemoteMethodTest(test_util.TestCase):
+ """Test for is_remote_method."""
+
+ def testGetRemoteMethod(self):
+ """Test valid remote method detection."""
+
+ class Service(object):
+
+ @remote.method(Request, Response)
+ def remote_method(self, request):
+ pass
+
+ self.assertEquals(Service.remote_method.remote,
+ remote.get_remote_method_info(Service.remote_method))
+ self.assertTrue(Service.remote_method.remote,
+ remote.get_remote_method_info(Service().remote_method))
+
+ def testGetNotRemoteMethod(self):
+ """Test positive result on a remote method."""
+
+ class NotService(object):
+
+ def not_remote_method(self, request):
+ pass
+
+ def fn(self):
+ pass
+
+ class NotReallyRemote(object):
+ """Test negative result on many bad values for remote methods."""
+
+ def not_really(self, request):
+ pass
+
+ not_really.remote = 'something else'
+
+ for not_remote in [NotService.not_remote_method,
+ NotService().not_remote_method,
+ NotReallyRemote.not_really,
+ NotReallyRemote().not_really,
+ None,
+ 1,
+ 'a string',
+ fn]:
+ self.assertEquals(None, remote.get_remote_method_info(not_remote))
+
+
+class RequestStateTest(test_util.TestCase):
+ """Test request state."""
+
+ STATE_CLASS = remote.RequestState
+
+ def testConstructor(self):
+ """Test constructor."""
+ state = self.STATE_CLASS(remote_host='remote-host',
+ remote_address='remote-address',
+ server_host='server-host',
+ server_port=10)
+ self.assertEquals('remote-host', state.remote_host)
+ self.assertEquals('remote-address', state.remote_address)
+ self.assertEquals('server-host', state.server_host)
+ self.assertEquals(10, state.server_port)
+
+ state = self.STATE_CLASS()
+ self.assertEquals(None, state.remote_host)
+ self.assertEquals(None, state.remote_address)
+ self.assertEquals(None, state.server_host)
+ self.assertEquals(None, state.server_port)
+
+ def testConstructorError(self):
+ """Test unexpected keyword argument."""
+ self.assertRaises(TypeError,
+ self.STATE_CLASS,
+ x=10)
+
+ def testRepr(self):
+ """Test string representation."""
+ self.assertEquals('<%s>' % self.STATE_CLASS.__name__,
+ repr(self.STATE_CLASS()))
+ self.assertEquals("<%s remote_host='abc'>" % self.STATE_CLASS.__name__,
+ repr(self.STATE_CLASS(remote_host='abc')))
+ self.assertEquals("<%s remote_host='abc' "
+ "remote_address='def'>" % self.STATE_CLASS.__name__,
+ repr(self.STATE_CLASS(remote_host='abc',
+ remote_address='def')))
+ self.assertEquals("<%s remote_host='abc' "
+ "remote_address='def' "
+ "server_host='ghi'>" % self.STATE_CLASS.__name__,
+ repr(self.STATE_CLASS(remote_host='abc',
+ remote_address='def',
+ server_host='ghi')))
+ self.assertEquals("<%s remote_host='abc' "
+ "remote_address='def' "
+ "server_host='ghi' "
+ 'server_port=102>' % self.STATE_CLASS.__name__,
+ repr(self.STATE_CLASS(remote_host='abc',
+ remote_address='def',
+ server_host='ghi',
+ server_port=102)))
+
+
+class HttpRequestStateTest(RequestStateTest):
+
+ STATE_CLASS = remote.HttpRequestState
+
+ def testHttpMethod(self):
+ state = remote.HttpRequestState(http_method='GET')
+ self.assertEquals('GET', state.http_method)
+
+ def testHttpMethod(self):
+ state = remote.HttpRequestState(service_path='/bar')
+ self.assertEquals('/bar', state.service_path)
+
+ def testHeadersList(self):
+ state = remote.HttpRequestState(
+ headers=[('a', 'b'), ('c', 'd'), ('c', 'e')])
+
+ self.assertEquals(['a', 'c', 'c'], list(state.headers.keys()))
+ self.assertEquals(['b'], state.headers.get_all('a'))
+ self.assertEquals(['d', 'e'], state.headers.get_all('c'))
+
+ def testHeadersDict(self):
+ state = remote.HttpRequestState(headers={'a': 'b', 'c': ['d', 'e']})
+
+ self.assertEquals(['a', 'c', 'c'], sorted(state.headers.keys()))
+ self.assertEquals(['b'], state.headers.get_all('a'))
+ self.assertEquals(['d', 'e'], state.headers.get_all('c'))
+
+ def testRepr(self):
+ super(HttpRequestStateTest, self).testRepr()
+
+ self.assertEquals("<%s remote_host='abc' "
+ "remote_address='def' "
+ "server_host='ghi' "
+ 'server_port=102 '
+ "http_method='POST' "
+ "service_path='/bar' "
+ "headers=[('a', 'b'), ('c', 'd')]>" %
+ self.STATE_CLASS.__name__,
+ repr(self.STATE_CLASS(remote_host='abc',
+ remote_address='def',
+ server_host='ghi',
+ server_port=102,
+ http_method='POST',
+ service_path='/bar',
+ headers={'a': 'b', 'c': 'd'},
+ )))
+
+
+class ServiceTest(test_util.TestCase):
+ """Test Service class."""
+
+ def testServiceBase_AllRemoteMethods(self):
+ """Test that service base class has no remote methods."""
+ self.assertEquals({}, remote.Service.all_remote_methods())
+
+ def testAllRemoteMethods(self):
+ """Test all_remote_methods with properly Service subclass."""
+ self.assertEquals({'remote_method': MyService.remote_method},
+ MyService.all_remote_methods())
+
+ def testAllRemoteMethods_SubClass(self):
+ """Test all_remote_methods on a sub-class of a service."""
+ class SubClass(MyService):
+
+ @remote.method(Request, Response)
+ def sub_class_method(self, request):
+ pass
+
+ self.assertEquals({'remote_method': SubClass.remote_method,
+ 'sub_class_method': SubClass.sub_class_method,
+ },
+ SubClass.all_remote_methods())
+
+ def testOverrideMethod(self):
+ """Test that trying to override a remote method with remote decorator."""
+ class SubClass(MyService):
+
+ def remote_method(self, request):
+ response = super(SubClass, self).remote_method(request)
+ response.value = '(%s)' % response.value
+ return response
+
+ self.assertEquals({'remote_method': SubClass.remote_method,
+ },
+ SubClass.all_remote_methods())
+
+ instance = SubClass()
+ self.assertEquals('(Hello)',
+ instance.remote_method(Request(value='Hello')).value)
+ self.assertEquals(Request, SubClass.remote_method.remote.request_type)
+ self.assertEquals(Response, SubClass.remote_method.remote.response_type)
+
+ def testOverrideMethodWithRemote(self):
+ """Test trying to override a remote method with remote decorator."""
+ def do_override():
+ class SubClass(MyService):
+
+ @remote.method(Request, Response)
+ def remote_method(self, request):
+ pass
+
+ self.assertRaisesWithRegexpMatch(remote.ServiceDefinitionError,
+ 'Do not use method decorator when '
+ 'overloading remote method remote_method '
+ 'on service SubClass',
+ do_override)
+
+ def testOverrideMethodWithInvalidValue(self):
+ """Test trying to override a remote method with remote decorator."""
+ def do_override(bad_value):
+ class SubClass(MyService):
+
+ remote_method = bad_value
+
+ for bad_value in [None, 1, 'string', {}]:
+ self.assertRaisesWithRegexpMatch(remote.ServiceDefinitionError,
+ 'Must override remote_method in '
+ 'SubClass with a method',
+ do_override, bad_value)
+
+ def testCallingRemoteMethod(self):
+ """Test invoking a remote method."""
+ expected = Response()
+ expected.value = 'what was passed in'
+
+ request = Request()
+ request.value = 'what was passed in'
+
+ service = MyService()
+ self.assertEquals(expected, service.remote_method(request))
+
+ def testFactory(self):
+ """Test using factory to pass in state."""
+ class StatefulService(remote.Service):
+
+ def __init__(self, a, b, c=None):
+ self.a = a
+ self.b = b
+ self.c = c
+
+ state = [1, 2, 3]
+
+ factory = StatefulService.new_factory(1, state)
+
+ module_name = ServiceTest.__module__
+ pattern = ('Creates new instances of service StatefulService.\n\n'
+ 'Returns:\n'
+ ' New instance of %s.StatefulService.' % module_name)
+ self.assertEqual(pattern, factory.__doc__)
+ self.assertEquals('StatefulService_service_factory', factory.__name__)
+ self.assertEquals(StatefulService, factory.service_class)
+
+ service = factory()
+ self.assertEquals(1, service.a)
+ self.assertEquals(id(state), id(service.b))
+ self.assertEquals(None, service.c)
+
+ factory = StatefulService.new_factory(2, b=3, c=4)
+ service = factory()
+ self.assertEquals(2, service.a)
+ self.assertEquals(3, service.b)
+ self.assertEquals(4, service.c)
+
+ def testFactoryError(self):
+ """Test misusing a factory."""
+ # Passing positional argument that is not accepted by class.
+ self.assertRaises(TypeError, remote.Service.new_factory(1))
+
+ # Passing keyword argument that is not accepted by class.
+ self.assertRaises(TypeError, remote.Service.new_factory(x=1))
+
+ class StatefulService(remote.Service):
+
+ def __init__(self, a):
+ pass
+
+ # Missing required parameter.
+ self.assertRaises(TypeError, StatefulService.new_factory())
+
+ def testDefinitionName(self):
+ """Test getting service definition name."""
+ class TheService(remote.Service):
+ pass
+
+ module_name = test_util.get_module_name(ServiceTest)
+ self.assertEqual(TheService.definition_name(),
+ '%s.TheService' % module_name)
+ self.assertTrue(TheService.outer_definition_name(),
+ module_name)
+ self.assertTrue(TheService.definition_package(),
+ module_name)
+
+ def testDefinitionNameWithPackage(self):
+ """Test getting service definition name when package defined."""
+ global package
+ package = 'my.package'
+ try:
+ class TheService(remote.Service):
+ pass
+
+ self.assertEquals('my.package.TheService', TheService.definition_name())
+ self.assertEquals('my.package', TheService.outer_definition_name())
+ self.assertEquals('my.package', TheService.definition_package())
+ finally:
+ del package
+
+ def testDefinitionNameWithNoModule(self):
+ """Test getting service definition name when package defined."""
+ module = sys.modules[__name__]
+ try:
+ del sys.modules[__name__]
+ class TheService(remote.Service):
+ pass
+
+ self.assertEquals('TheService', TheService.definition_name())
+ self.assertEquals(None, TheService.outer_definition_name())
+ self.assertEquals(None, TheService.definition_package())
+ finally:
+ sys.modules[__name__] = module
+
+
+class StubTest(test_util.TestCase):
+
+ def setUp(self):
+ self.mox = mox.Mox()
+ self.transport = self.mox.CreateMockAnything()
+
+ def testDefinitionName(self):
+ self.assertEquals(BasicService.definition_name(),
+ BasicService.Stub.definition_name())
+ self.assertEquals(BasicService.outer_definition_name(),
+ BasicService.Stub.outer_definition_name())
+ self.assertEquals(BasicService.definition_package(),
+ BasicService.Stub.definition_package())
+
+ def testRemoteMethods(self):
+ self.assertEquals(BasicService.all_remote_methods(),
+ BasicService.Stub.all_remote_methods())
+
+ def testSync_WithRequest(self):
+ stub = BasicService.Stub(self.transport)
+
+ request = SimpleRequest()
+ request.param1 = 'val1'
+ request.param2 = 'val2'
+ response = SimpleResponse()
+
+ rpc = transport.Rpc(request)
+ rpc.set_response(response)
+ self.transport.send_rpc(BasicService.remote_method.remote,
+ request).AndReturn(rpc)
+
+ self.mox.ReplayAll()
+
+ self.assertEquals(SimpleResponse(), stub.remote_method(request))
+
+ self.mox.VerifyAll()
+
+ def testSync_WithKwargs(self):
+ stub = BasicService.Stub(self.transport)
+
+
+ request = SimpleRequest()
+ request.param1 = 'val1'
+ request.param2 = 'val2'
+ response = SimpleResponse()
+
+ rpc = transport.Rpc(request)
+ rpc.set_response(response)
+ self.transport.send_rpc(BasicService.remote_method.remote,
+ request).AndReturn(rpc)
+
+ self.mox.ReplayAll()
+
+ self.assertEquals(SimpleResponse(), stub.remote_method(param1='val1',
+ param2='val2'))
+
+ self.mox.VerifyAll()
+
+ def testAsync_WithRequest(self):
+ stub = BasicService.Stub(self.transport)
+
+ request = SimpleRequest()
+ request.param1 = 'val1'
+ request.param2 = 'val2'
+ response = SimpleResponse()
+
+ rpc = transport.Rpc(request)
+
+ self.transport.send_rpc(BasicService.remote_method.remote,
+ request).AndReturn(rpc)
+
+ self.mox.ReplayAll()
+
+ self.assertEquals(rpc, stub.async.remote_method(request))
+
+ self.mox.VerifyAll()
+
+ def testAsync_WithKwargs(self):
+ stub = BasicService.Stub(self.transport)
+
+ request = SimpleRequest()
+ request.param1 = 'val1'
+ request.param2 = 'val2'
+ response = SimpleResponse()
+
+ rpc = transport.Rpc(request)
+
+ self.transport.send_rpc(BasicService.remote_method.remote,
+ request).AndReturn(rpc)
+
+ self.mox.ReplayAll()
+
+ self.assertEquals(rpc, stub.async.remote_method(param1='val1',
+ param2='val2'))
+
+ self.mox.VerifyAll()
+
+ def testAsync_WithRequestAndKwargs(self):
+ stub = BasicService.Stub(self.transport)
+
+ request = SimpleRequest()
+ request.param1 = 'val1'
+ request.param2 = 'val2'
+ response = SimpleResponse()
+
+ self.mox.ReplayAll()
+
+ self.assertRaisesWithRegexpMatch(
+ TypeError,
+ r'May not provide both args and kwargs',
+ stub.async.remote_method,
+ request,
+ param1='val1',
+ param2='val2')
+
+ self.mox.VerifyAll()
+
+ def testAsync_WithTooManyPositionals(self):
+ stub = BasicService.Stub(self.transport)
+
+ request = SimpleRequest()
+ request.param1 = 'val1'
+ request.param2 = 'val2'
+ response = SimpleResponse()
+
+ self.mox.ReplayAll()
+
+ self.assertRaisesWithRegexpMatch(
+ TypeError,
+ r'remote_method\(\) takes at most 2 positional arguments \(3 given\)',
+ stub.async.remote_method,
+ request, 'another value')
+
+ self.mox.VerifyAll()
+
+
+class IsErrorStatusTest(test_util.TestCase):
+
+ def testIsError(self):
+ for state in (s for s in remote.RpcState if s > remote.RpcState.RUNNING):
+ status = remote.RpcStatus(state=state)
+ self.assertTrue(remote.is_error_status(status))
+
+ def testIsNotError(self):
+ for state in (s for s in remote.RpcState if s <= remote.RpcState.RUNNING):
+ status = remote.RpcStatus(state=state)
+ self.assertFalse(remote.is_error_status(status))
+
+ def testStateNone(self):
+ self.assertRaises(messages.ValidationError,
+ remote.is_error_status, remote.RpcStatus())
+
+
+class CheckRpcStatusTest(test_util.TestCase):
+
+ def testStateNone(self):
+ self.assertRaises(messages.ValidationError,
+ remote.check_rpc_status, remote.RpcStatus())
+
+ def testNoError(self):
+ for state in (remote.RpcState.OK, remote.RpcState.RUNNING):
+ remote.check_rpc_status(remote.RpcStatus(state=state))
+
+ def testErrorState(self):
+ status = remote.RpcStatus(state=remote.RpcState.REQUEST_ERROR,
+ error_message='a request error')
+ self.assertRaisesWithRegexpMatch(remote.RequestError,
+ 'a request error',
+ remote.check_rpc_status, status)
+
+ def testApplicationErrorState(self):
+ status = remote.RpcStatus(state=remote.RpcState.APPLICATION_ERROR,
+ error_message='an application error',
+ error_name='blam')
+ try:
+ remote.check_rpc_status(status)
+ self.fail('Should have raised application error.')
+ except remote.ApplicationError as err:
+ self.assertEquals('an application error', str(err))
+ self.assertEquals('blam', err.error_name)
+
+
+class ProtocolConfigTest(test_util.TestCase):
+
+ def testConstructor(self):
+ config = remote.ProtocolConfig(
+ protojson,
+ 'proto1',
+ 'application/X-Json',
+ iter(['text/Json', 'text/JavaScript']))
+ self.assertEquals(protojson, config.protocol)
+ self.assertEquals('proto1', config.name)
+ self.assertEquals('application/x-json', config.default_content_type)
+ self.assertEquals(('text/json', 'text/javascript'),
+ config.alternate_content_types)
+ self.assertEquals(('application/x-json', 'text/json', 'text/javascript'),
+ config.content_types)
+
+ def testConstructorDefaults(self):
+ config = remote.ProtocolConfig(protojson, 'proto2')
+ self.assertEquals(protojson, config.protocol)
+ self.assertEquals('proto2', config.name)
+ self.assertEquals('application/json', config.default_content_type)
+ self.assertEquals(('application/x-javascript',
+ 'text/javascript',
+ 'text/x-javascript',
+ 'text/x-json',
+ 'text/json'),
+ config.alternate_content_types)
+ self.assertEquals(('application/json',
+ 'application/x-javascript',
+ 'text/javascript',
+ 'text/x-javascript',
+ 'text/x-json',
+ 'text/json'), config.content_types)
+
+ def testEmptyAlternativeTypes(self):
+ config = remote.ProtocolConfig(protojson, 'proto2',
+ alternative_content_types=())
+ self.assertEquals(protojson, config.protocol)
+ self.assertEquals('proto2', config.name)
+ self.assertEquals('application/json', config.default_content_type)
+ self.assertEquals((), config.alternate_content_types)
+ self.assertEquals(('application/json',), config.content_types)
+
+ def testDuplicateContentTypes(self):
+ self.assertRaises(remote.ServiceConfigurationError,
+ remote.ProtocolConfig,
+ protojson,
+ 'json',
+ 'text/plain',
+ ('text/plain',))
+
+ self.assertRaises(remote.ServiceConfigurationError,
+ remote.ProtocolConfig,
+ protojson,
+ 'json',
+ 'text/plain',
+ ('text/html', 'text/html'))
+
+ def testEncodeMessage(self):
+ config = remote.ProtocolConfig(protojson, 'proto2')
+ encoded_message = config.encode_message(
+ remote.RpcStatus(state=remote.RpcState.SERVER_ERROR,
+ error_message='bad error'))
+
+ # Convert back to a dictionary from JSON.
+ dict_message = protojson.json.loads(encoded_message)
+ self.assertEquals({'state': 'SERVER_ERROR', 'error_message': 'bad error'},
+ dict_message)
+
+ def testDecodeMessage(self):
+ config = remote.ProtocolConfig(protojson, 'proto2')
+ self.assertEquals(
+ remote.RpcStatus(state=remote.RpcState.SERVER_ERROR,
+ error_message="bad error"),
+ config.decode_message(
+ remote.RpcStatus,
+ '{"state": "SERVER_ERROR", "error_message": "bad error"}'))
+
+
+class ProtocolsTest(test_util.TestCase):
+
+ def setUp(self):
+ self.protocols = remote.Protocols()
+
+ def testEmpty(self):
+ self.assertEquals((), self.protocols.names)
+ self.assertEquals((), self.protocols.content_types)
+
+ def testAddProtocolAllDefaults(self):
+ self.protocols.add_protocol(protojson, 'json')
+ self.assertEquals(('json',), self.protocols.names)
+ self.assertEquals(('application/json',
+ 'application/x-javascript',
+ 'text/javascript',
+ 'text/json',
+ 'text/x-javascript',
+ 'text/x-json'),
+ self.protocols.content_types)
+
+ def testAddProtocolNoDefaultAlternatives(self):
+ class Protocol(object):
+ CONTENT_TYPE = 'text/plain'
+
+ self.protocols.add_protocol(Protocol, 'text')
+ self.assertEquals(('text',), self.protocols.names)
+ self.assertEquals(('text/plain',), self.protocols.content_types)
+
+ def testAddProtocolOverrideDefaults(self):
+ self.protocols.add_protocol(protojson, 'json',
+ default_content_type='text/blar',
+ alternative_content_types=('text/blam',
+ 'text/blim'))
+ self.assertEquals(('json',), self.protocols.names)
+ self.assertEquals(('text/blam', 'text/blar', 'text/blim'),
+ self.protocols.content_types)
+
+ def testLookupByName(self):
+ self.protocols.add_protocol(protojson, 'json')
+ self.protocols.add_protocol(protojson, 'json2',
+ default_content_type='text/plain',
+ alternative_content_types=())
+
+ self.assertEquals('json', self.protocols.lookup_by_name('JsOn').name)
+ self.assertEquals('json2', self.protocols.lookup_by_name('Json2').name)
+
+ def testLookupByContentType(self):
+ self.protocols.add_protocol(protojson, 'json')
+ self.protocols.add_protocol(protojson, 'json2',
+ default_content_type='text/plain',
+ alternative_content_types=())
+
+ self.assertEquals(
+ 'json',
+ self.protocols.lookup_by_content_type('AppliCation/Json').name)
+
+ self.assertEquals(
+ 'json',
+ self.protocols.lookup_by_content_type('text/x-Json').name)
+
+ self.assertEquals(
+ 'json2',
+ self.protocols.lookup_by_content_type('text/Plain').name)
+
+ def testNewDefault(self):
+ protocols = remote.Protocols.new_default()
+ self.assertEquals(('protobuf', 'protojson'), protocols.names)
+
+ protobuf_protocol = protocols.lookup_by_name('protobuf')
+ self.assertEquals(protobuf, protobuf_protocol.protocol)
+
+ protojson_protocol = protocols.lookup_by_name('protojson')
+ self.assertEquals(protojson.ProtoJson.get_default(),
+ protojson_protocol.protocol)
+
+ def testGetDefaultProtocols(self):
+ protocols = remote.Protocols.get_default()
+ self.assertEquals(('protobuf', 'protojson'), protocols.names)
+
+ protobuf_protocol = protocols.lookup_by_name('protobuf')
+ self.assertEquals(protobuf, protobuf_protocol.protocol)
+
+ protojson_protocol = protocols.lookup_by_name('protojson')
+ self.assertEquals(protojson.ProtoJson.get_default(),
+ protojson_protocol.protocol)
+
+ self.assertTrue(protocols is remote.Protocols.get_default())
+
+ def testSetDefaultProtocols(self):
+ protocols = remote.Protocols()
+ remote.Protocols.set_default(protocols)
+ self.assertTrue(protocols is remote.Protocols.get_default())
+
+ def testSetDefaultWithoutProtocols(self):
+ self.assertRaises(TypeError, remote.Protocols.set_default, None)
+ self.assertRaises(TypeError, remote.Protocols.set_default, 'hi protocols')
+ self.assertRaises(TypeError, remote.Protocols.set_default, {})
+
+
+def main():
+ unittest.main()
+
+
+if __name__ == '__main__':
+ main()

Powered by Google App Engine
This is Rietveld 408576698