| Index: third_party/google-endpoints/apitools/base/py/testing/mock.py
|
| diff --git a/third_party/google-endpoints/apitools/base/py/testing/mock.py b/third_party/google-endpoints/apitools/base/py/testing/mock.py
|
| new file mode 100644
|
| index 0000000000000000000000000000000000000000..91958a70c832ae6d114d3e9f89cda634d77c55c7
|
| --- /dev/null
|
| +++ b/third_party/google-endpoints/apitools/base/py/testing/mock.py
|
| @@ -0,0 +1,351 @@
|
| +#
|
| +# Copyright 2015 Google Inc.
|
| +#
|
| +# Licensed under the Apache License, Version 2.0 (the "License");
|
| +# you may not use this file except in compliance with the License.
|
| +# You may obtain a copy of the License at
|
| +#
|
| +# http://www.apache.org/licenses/LICENSE-2.0
|
| +#
|
| +# Unless required by applicable law or agreed to in writing, software
|
| +# distributed under the License is distributed on an "AS IS" BASIS,
|
| +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| +# See the License for the specific language governing permissions and
|
| +# limitations under the License.
|
| +
|
| +"""The mock module allows easy mocking of apitools clients.
|
| +
|
| +This module allows you to mock out the constructor of a particular apitools
|
| +client, for a specific API and version. Then, when the client is created, it
|
| +will be run against an expected session that you define. This way code that is
|
| +not aware of the testing framework can construct new clients as normal, as long
|
| +as it's all done within the context of a mock.
|
| +"""
|
| +
|
| +import difflib
|
| +
|
| +import six
|
| +
|
| +from apitools.base.protorpclite import messages
|
| +from apitools.base.py import base_api
|
| +from apitools.base.py import encoding
|
| +from apitools.base.py import exceptions
|
| +
|
| +
|
| +class Error(Exception):
|
| +
|
| + """Exceptions for this module."""
|
| +
|
| +
|
| +def _MessagesEqual(msg1, msg2):
|
| + """Compare two protorpc messages for equality.
|
| +
|
| + Using python's == operator does not work in all cases, specifically when
|
| + there is a list involved.
|
| +
|
| + Args:
|
| + msg1: protorpc.messages.Message or [protorpc.messages.Message] or number
|
| + or string, One of the messages to compare.
|
| + msg2: protorpc.messages.Message or [protorpc.messages.Message] or number
|
| + or string, One of the messages to compare.
|
| +
|
| + Returns:
|
| + If the messages are isomorphic.
|
| + """
|
| + if isinstance(msg1, list) and isinstance(msg2, list):
|
| + if len(msg1) != len(msg2):
|
| + return False
|
| + return all(_MessagesEqual(x, y) for x, y in zip(msg1, msg2))
|
| +
|
| + if (not isinstance(msg1, messages.Message) or
|
| + not isinstance(msg2, messages.Message)):
|
| + return msg1 == msg2
|
| + for field in msg1.all_fields():
|
| + field1 = getattr(msg1, field.name)
|
| + field2 = getattr(msg2, field.name)
|
| + if not _MessagesEqual(field1, field2):
|
| + return False
|
| + return True
|
| +
|
| +
|
| +class UnexpectedRequestException(Error):
|
| +
|
| + def __init__(self, received_call, expected_call):
|
| + expected_key, expected_request = expected_call
|
| + received_key, received_request = received_call
|
| +
|
| + expected_repr = encoding.MessageToRepr(
|
| + expected_request, multiline=True)
|
| + received_repr = encoding.MessageToRepr(
|
| + received_request, multiline=True)
|
| +
|
| + expected_lines = expected_repr.splitlines()
|
| + received_lines = received_repr.splitlines()
|
| +
|
| + diff_lines = difflib.unified_diff(expected_lines, received_lines)
|
| + diff = '\n'.join(diff_lines)
|
| +
|
| + if expected_key != received_key:
|
| + msg = '\n'.join((
|
| + 'expected: {expected_key}({expected_request})',
|
| + 'received: {received_key}({received_request})',
|
| + '',
|
| + )).format(
|
| + expected_key=expected_key,
|
| + expected_request=expected_repr,
|
| + received_key=received_key,
|
| + received_request=received_repr)
|
| + super(UnexpectedRequestException, self).__init__(msg)
|
| + else:
|
| + msg = '\n'.join((
|
| + 'for request to {key},',
|
| + 'expected: {expected_request}',
|
| + 'received: {received_request}',
|
| + 'diff: {diff}',
|
| + '',
|
| + )).format(
|
| + key=expected_key,
|
| + expected_request=expected_repr,
|
| + received_request=received_repr,
|
| + diff=diff)
|
| + super(UnexpectedRequestException, self).__init__(msg)
|
| +
|
| +
|
| +class ExpectedRequestsException(Error):
|
| +
|
| + def __init__(self, expected_calls):
|
| + msg = 'expected:\n'
|
| + for (key, request) in expected_calls:
|
| + msg += '{key}({request})\n'.format(
|
| + key=key,
|
| + request=encoding.MessageToRepr(request, multiline=True))
|
| + super(ExpectedRequestsException, self).__init__(msg)
|
| +
|
| +
|
| +class _ExpectedRequestResponse(object):
|
| +
|
| + """Encapsulation of an expected request and corresponding response."""
|
| +
|
| + def __init__(self, key, request, response=None, exception=None):
|
| + self.__key = key
|
| + self.__request = request
|
| +
|
| + if response and exception:
|
| + raise exceptions.ConfigurationValueError(
|
| + 'Should specify at most one of response and exception')
|
| + if response and isinstance(response, exceptions.Error):
|
| + raise exceptions.ConfigurationValueError(
|
| + 'Responses should not be an instance of Error')
|
| + if exception and not isinstance(exception, exceptions.Error):
|
| + raise exceptions.ConfigurationValueError(
|
| + 'Exceptions must be instances of Error')
|
| +
|
| + self.__response = response
|
| + self.__exception = exception
|
| +
|
| + @property
|
| + def key(self):
|
| + return self.__key
|
| +
|
| + @property
|
| + def request(self):
|
| + return self.__request
|
| +
|
| + def ValidateAndRespond(self, key, request):
|
| + """Validate that key and request match expectations, and respond if so.
|
| +
|
| + Args:
|
| + key: str, Actual key to compare against expectations.
|
| + request: protorpc.messages.Message or [protorpc.messages.Message]
|
| + or number or string, Actual request to compare againt expectations
|
| +
|
| + Raises:
|
| + UnexpectedRequestException: If key or request dont match
|
| + expectations.
|
| + apitools_base.Error: If a non-None exception is specified to
|
| + be thrown.
|
| +
|
| + Returns:
|
| + The response that was specified to be returned.
|
| +
|
| + """
|
| + if key != self.__key or not _MessagesEqual(request, self.__request):
|
| + raise UnexpectedRequestException((key, request),
|
| + (self.__key, self.__request))
|
| +
|
| + if self.__exception:
|
| + # Can only throw apitools_base.Error.
|
| + raise self.__exception # pylint: disable=raising-bad-type
|
| +
|
| + return self.__response
|
| +
|
| +
|
| +class _MockedService(base_api.BaseApiService):
|
| +
|
| + def __init__(self, key, mocked_client, methods, real_service):
|
| + super(_MockedService, self).__init__(mocked_client)
|
| + self.__dict__.update(real_service.__dict__)
|
| + for method in methods:
|
| + real_method = None
|
| + if real_service:
|
| + real_method = getattr(real_service, method)
|
| + setattr(self, method,
|
| + _MockedMethod(key + '.' + method,
|
| + mocked_client,
|
| + real_method))
|
| +
|
| +
|
| +class _MockedMethod(object):
|
| +
|
| + """A mocked API service method."""
|
| +
|
| + def __init__(self, key, mocked_client, real_method):
|
| + self.__key = key
|
| + self.__mocked_client = mocked_client
|
| + self.__real_method = real_method
|
| +
|
| + def Expect(self, request, response=None, exception=None, **unused_kwargs):
|
| + """Add an expectation on the mocked method.
|
| +
|
| + Exactly one of response and exception should be specified.
|
| +
|
| + Args:
|
| + request: The request that should be expected
|
| + response: The response that should be returned or None if
|
| + exception is provided.
|
| + exception: An exception that should be thrown, or None.
|
| +
|
| + """
|
| + # TODO(jasmuth): the unused_kwargs provides a placeholder for
|
| + # future things that can be passed to Expect(), like special
|
| + # params to the method call.
|
| +
|
| + # pylint: disable=protected-access
|
| + # Class in same module.
|
| + self.__mocked_client._request_responses.append(
|
| + _ExpectedRequestResponse(self.__key,
|
| + request,
|
| + response=response,
|
| + exception=exception))
|
| + # pylint: enable=protected-access
|
| +
|
| + def __call__(self, request, **unused_kwargs):
|
| + # TODO(jasmuth): allow the testing code to expect certain
|
| + # values in these currently unused_kwargs, especially the
|
| + # upload parameter used by media-heavy services like bigquery
|
| + # or bigstore.
|
| +
|
| + # pylint: disable=protected-access
|
| + # Class in same module.
|
| + if self.__mocked_client._request_responses:
|
| + request_response = self.__mocked_client._request_responses.pop(0)
|
| + else:
|
| + raise UnexpectedRequestException(
|
| + (self.__key, request), (None, None))
|
| + # pylint: enable=protected-access
|
| +
|
| + response = request_response.ValidateAndRespond(self.__key, request)
|
| +
|
| + if response is None and self.__real_method:
|
| + response = self.__real_method(request)
|
| + print(encoding.MessageToRepr(
|
| + response, multiline=True, shortstrings=True))
|
| + return response
|
| +
|
| + return response
|
| +
|
| +
|
| +def _MakeMockedServiceConstructor(mocked_service):
|
| + def Constructor(unused_self, unused_client):
|
| + return mocked_service
|
| + return Constructor
|
| +
|
| +
|
| +class Client(object):
|
| +
|
| + """Mock an apitools client."""
|
| +
|
| + def __init__(self, client_class, real_client=None):
|
| + """Mock an apitools API, given its class.
|
| +
|
| + Args:
|
| + client_class: The class for the API. eg, if you
|
| + from apis.sqladmin import v1beta3
|
| + then you can pass v1beta3.SqladminV1beta3 to this class
|
| + and anything within its context will use your mocked
|
| + version.
|
| + real_client: apitools Client, The client to make requests
|
| + against when the expected response is None.
|
| +
|
| + """
|
| +
|
| + if not real_client:
|
| + real_client = client_class(get_credentials=False)
|
| +
|
| + self.__client_class = client_class
|
| + self.__real_service_classes = {}
|
| + self.__real_client = real_client
|
| +
|
| + self._request_responses = []
|
| + self.__real_include_fields = None
|
| +
|
| + def __enter__(self):
|
| + return self.Mock()
|
| +
|
| + def Mock(self):
|
| + """Stub out the client class with mocked services."""
|
| + client = self.__real_client or self.__client_class(
|
| + get_credentials=False)
|
| + for name in dir(self.__client_class):
|
| + service_class = getattr(self.__client_class, name)
|
| + if not isinstance(service_class, type):
|
| + continue
|
| + if not issubclass(service_class, base_api.BaseApiService):
|
| + continue
|
| + self.__real_service_classes[name] = service_class
|
| + service = service_class(client)
|
| + # pylint: disable=protected-access
|
| + # Some liberty is allowed with mocking.
|
| + collection_name = service_class._NAME
|
| + # pylint: enable=protected-access
|
| + api_name = '%s_%s' % (self.__client_class._PACKAGE,
|
| + self.__client_class._URL_VERSION)
|
| + mocked_service = _MockedService(
|
| + api_name + '.' + collection_name, self,
|
| + service._method_configs.keys(),
|
| + service if self.__real_client else None)
|
| + mocked_constructor = _MakeMockedServiceConstructor(mocked_service)
|
| + setattr(self.__client_class, name, mocked_constructor)
|
| +
|
| + setattr(self, collection_name, mocked_service)
|
| +
|
| + self.__real_include_fields = self.__client_class.IncludeFields
|
| + self.__client_class.IncludeFields = self.IncludeFields
|
| +
|
| + return self
|
| +
|
| + def __exit__(self, exc_type, value, traceback):
|
| + self.Unmock()
|
| + if value:
|
| + six.reraise(exc_type, value, traceback)
|
| + return True
|
| +
|
| + def Unmock(self):
|
| + for name, service_class in self.__real_service_classes.items():
|
| + setattr(self.__client_class, name, service_class)
|
| + delattr(self, service_class._NAME)
|
| + self.__real_service_classes = {}
|
| +
|
| + if self._request_responses:
|
| + raise ExpectedRequestsException(
|
| + [(rq_rs.key, rq_rs.request) for rq_rs
|
| + in self._request_responses])
|
| + self._request_responses = []
|
| +
|
| + self.__client_class.IncludeFields = self.__real_include_fields
|
| + self.__real_include_fields = None
|
| +
|
| + def IncludeFields(self, include_fields):
|
| + if self.__real_client:
|
| + return self.__real_include_fields(self.__real_client,
|
| + include_fields)
|
|
|