Index: tools/telemetry/third_party/gsutilz/third_party/apitools/apitools/base/py/testing/mock.py |
diff --git a/tools/telemetry/third_party/gsutilz/third_party/apitools/apitools/base/py/testing/mock.py b/tools/telemetry/third_party/gsutilz/third_party/apitools/apitools/base/py/testing/mock.py |
new file mode 100644 |
index 0000000000000000000000000000000000000000..75a8b0791381d5bbea14c7d5295cee9416a1d33a |
--- /dev/null |
+++ b/tools/telemetry/third_party/gsutilz/third_party/apitools/apitools/base/py/testing/mock.py |
@@ -0,0 +1,329 @@ |
+"""The mock module allows easy mocking of apitools clients. |
+ |
+This module allows you to mock out the constructor of a particular apitools |
+client, for a specific API and version. Then, when the client is created, it |
+will be run against an expected session that you define. This way code that is |
+not aware of the testing framework can construct new clients as normal, as long |
+as it's all done within the context of a mock. |
+""" |
+ |
+import difflib |
+ |
+from protorpc import messages |
+import six |
+ |
+import apitools.base.py as apitools_base |
+ |
+ |
+class Error(Exception): |
+ |
+ """Exceptions for this module.""" |
+ |
+ |
+def _MessagesEqual(msg1, msg2): |
+ """Compare two protorpc messages for equality. |
+ |
+ Using python's == operator does not work in all cases, specifically when |
+ there is a list involved. |
+ |
+ Args: |
+ msg1: protorpc.messages.Message or [protorpc.messages.Message] or number |
+ or string, One of the messages to compare. |
+ msg2: protorpc.messages.Message or [protorpc.messages.Message] or number |
+ or string, One of the messages to compare. |
+ |
+ Returns: |
+ If the messages are isomorphic. |
+ """ |
+ if isinstance(msg1, list) and isinstance(msg2, list): |
+ if len(msg1) != len(msg2): |
+ return False |
+ return all(_MessagesEqual(x, y) for x, y in zip(msg1, msg2)) |
+ |
+ if (not isinstance(msg1, messages.Message) or |
+ not isinstance(msg2, messages.Message)): |
+ return msg1 == msg2 |
+ for field in msg1.all_fields(): |
+ field1 = getattr(msg1, field.name) |
+ field2 = getattr(msg2, field.name) |
+ if not _MessagesEqual(field1, field2): |
+ return False |
+ return True |
+ |
+ |
+class UnexpectedRequestException(Error): |
+ |
+ def __init__(self, received_call, expected_call): |
+ expected_key, expected_request = expected_call |
+ received_key, received_request = received_call |
+ |
+ expected_repr = apitools_base.MessageToRepr( |
+ expected_request, multiline=True) |
+ received_repr = apitools_base.MessageToRepr( |
+ received_request, multiline=True) |
+ |
+ expected_lines = expected_repr.splitlines() |
+ received_lines = received_repr.splitlines() |
+ |
+ diff_lines = difflib.unified_diff(expected_lines, received_lines) |
+ diff = '\n'.join(diff_lines) |
+ |
+ if expected_key != received_key: |
+ msg = '\n'.join(( |
+ 'expected: {expected_key}({expected_request})', |
+ 'received: {received_key}({received_request})', |
+ '', |
+ )).format( |
+ expected_key=expected_key, |
+ expected_request=expected_repr, |
+ received_key=received_key, |
+ received_request=received_repr) |
+ super(UnexpectedRequestException, self).__init__(msg) |
+ else: |
+ msg = '\n'.join(( |
+ 'for request to {key},', |
+ 'expected: {expected_request}', |
+ 'received: {received_request}', |
+ 'diff: {diff}', |
+ '', |
+ )).format( |
+ key=expected_key, |
+ expected_request=expected_repr, |
+ received_request=received_repr, |
+ diff=diff) |
+ super(UnexpectedRequestException, self).__init__(msg) |
+ |
+ |
+class ExpectedRequestsException(Error): |
+ |
+ def __init__(self, expected_calls): |
+ msg = 'expected:\n' |
+ for (key, request) in expected_calls: |
+ msg += '{key}({request})\n'.format( |
+ key=key, |
+ request=apitools_base.MessageToRepr(request, multiline=True)) |
+ super(ExpectedRequestsException, self).__init__(msg) |
+ |
+ |
+class _ExpectedRequestResponse(object): |
+ |
+ """Encapsulation of an expected request and corresponding response.""" |
+ |
+ def __init__(self, key, request, response=None, exception=None): |
+ self.__key = key |
+ self.__request = request |
+ |
+ if response and exception: |
+ raise apitools_base.ConfigurationValueError( |
+ 'Should specify at most one of response and exception') |
+ if response and isinstance(response, apitools_base.Error): |
+ raise apitools_base.ConfigurationValueError( |
+ 'Responses should not be an instance of Error') |
+ if exception and not isinstance(exception, apitools_base.Error): |
+ raise apitools_base.ConfigurationValueError( |
+ 'Exceptions must be instances of Error') |
+ |
+ self.__response = response |
+ self.__exception = exception |
+ |
+ @property |
+ def key(self): |
+ return self.__key |
+ |
+ @property |
+ def request(self): |
+ return self.__request |
+ |
+ def ValidateAndRespond(self, key, request): |
+ """Validate that key and request match expectations, and respond if so. |
+ |
+ Args: |
+ key: str, Actual key to compare against expectations. |
+ request: protorpc.messages.Message or [protorpc.messages.Message] |
+ or number or string, Actual request to compare againt expectations |
+ |
+ Raises: |
+ UnexpectedRequestException: If key or request dont match |
+ expectations. |
+ apitools_base.Error: If a non-None exception is specified to |
+ be thrown. |
+ |
+ Returns: |
+ The response that was specified to be returned. |
+ |
+ """ |
+ if key != self.__key or not _MessagesEqual(request, self.__request): |
+ raise UnexpectedRequestException((key, request), |
+ (self.__key, self.__request)) |
+ |
+ if self.__exception: |
+ # Can only throw apitools_base.Error. |
+ raise self.__exception # pylint: disable=raising-bad-type |
+ |
+ return self.__response |
+ |
+ |
+class _MockedService(apitools_base.BaseApiService): |
+ |
+ def __init__(self, key, mocked_client, methods, real_service): |
+ super(_MockedService, self).__init__(mocked_client) |
+ self.__dict__.update(real_service.__dict__) |
+ for method in methods: |
+ real_method = None |
+ if real_service: |
+ real_method = getattr(real_service, method) |
+ setattr(self, method, |
+ _MockedMethod(key + '.' + method, |
+ mocked_client, |
+ real_method)) |
+ |
+ |
+class _MockedMethod(object): |
+ |
+ """A mocked API service method.""" |
+ |
+ def __init__(self, key, mocked_client, real_method): |
+ self.__key = key |
+ self.__mocked_client = mocked_client |
+ self.__real_method = real_method |
+ |
+ def Expect(self, request, response=None, exception=None, **unused_kwargs): |
+ """Add an expectation on the mocked method. |
+ |
+ Exactly one of response and exception should be specified. |
+ |
+ Args: |
+ request: The request that should be expected |
+ response: The response that should be returned or None if |
+ exception is provided. |
+ exception: An exception that should be thrown, or None. |
+ |
+ """ |
+ # TODO(jasmuth): the unused_kwargs provides a placeholder for |
+ # future things that can be passed to Expect(), like special |
+ # params to the method call. |
+ |
+ # pylint: disable=protected-access |
+ # Class in same module. |
+ self.__mocked_client._request_responses.append( |
+ _ExpectedRequestResponse(self.__key, |
+ request, |
+ response=response, |
+ exception=exception)) |
+ # pylint: enable=protected-access |
+ |
+ def __call__(self, request, **unused_kwargs): |
+ # TODO(jasmuth): allow the testing code to expect certain |
+ # values in these currently unused_kwargs, especially the |
+ # upload parameter used by media-heavy services like bigquery |
+ # or bigstore. |
+ |
+ # pylint: disable=protected-access |
+ # Class in same module. |
+ if self.__mocked_client._request_responses: |
+ request_response = self.__mocked_client._request_responses.pop(0) |
+ else: |
+ raise UnexpectedRequestException( |
+ (self.__key, request), (None, None)) |
+ # pylint: enable=protected-access |
+ |
+ response = request_response.ValidateAndRespond(self.__key, request) |
+ |
+ if response is None and self.__real_method: |
+ response = self.__real_method(request) |
+ print(apitools_base.MessageToRepr( |
+ response, multiline=True, shortstrings=True)) |
+ return response |
+ |
+ return response |
+ |
+ |
+def _MakeMockedServiceConstructor(mocked_service): |
+ def Constructor(unused_self, unused_client): |
+ return mocked_service |
+ return Constructor |
+ |
+ |
+class Client(object): |
+ |
+ """Mock an apitools client.""" |
+ |
+ def __init__(self, client_class, real_client=None): |
+ """Mock an apitools API, given its class. |
+ |
+ Args: |
+ client_class: The class for the API. eg, if you |
+ from apis.sqladmin import v1beta3 |
+ then you can pass v1beta3.SqladminV1beta3 to this class |
+ and anything within its context will use your mocked |
+ version. |
+ real_client: apitools Client, The client to make requests |
+ against when the expected response is None. |
+ |
+ """ |
+ |
+ if not real_client: |
+ real_client = client_class(get_credentials=False) |
+ |
+ self.__client_class = client_class |
+ self.__real_service_classes = {} |
+ self.__real_client = real_client |
+ |
+ self._request_responses = [] |
+ |
+ def __enter__(self): |
+ return self.Mock() |
+ |
+ def Mock(self): |
+ """Stub out the client class with mocked services.""" |
+ client = self.__real_client or self.__client_class( |
+ get_credentials=False) |
+ for name in dir(self.__client_class): |
+ service_class = getattr(self.__client_class, name) |
+ if not isinstance(service_class, type): |
+ continue |
+ if not issubclass(service_class, apitools_base.BaseApiService): |
+ continue |
+ self.__real_service_classes[name] = service_class |
+ service = service_class(client) |
+ # pylint: disable=protected-access |
+ # Some liberty is allowed with mocking. |
+ collection_name = service_class._NAME |
+ # pylint: enable=protected-access |
+ api_name = '%s_%s' % (self.__client_class._PACKAGE, |
+ self.__client_class._URL_VERSION) |
+ mocked_service = _MockedService( |
+ api_name + '.' + collection_name, self, |
+ service._method_configs.keys(), |
+ service if self.__real_client else None) |
+ mocked_constructor = _MakeMockedServiceConstructor(mocked_service) |
+ setattr(self.__client_class, name, mocked_constructor) |
+ |
+ setattr(self, collection_name, mocked_service) |
+ |
+ self.__real_include_fields = self.__client_class.IncludeFields |
+ self.__client_class.IncludeFields = self.IncludeFields |
+ |
+ return self |
+ |
+ def __exit__(self, exc_type, value, traceback): |
+ self.Unmock() |
+ if value: |
+ six.reraise(exc_type, value, traceback) |
+ return True |
+ |
+ def Unmock(self): |
+ for name, service_class in self.__real_service_classes.items(): |
+ setattr(self.__client_class, name, service_class) |
+ |
+ if self._request_responses: |
+ raise ExpectedRequestsException( |
+ [(rq_rs.key, rq_rs.request) for rq_rs |
+ in self._request_responses]) |
+ |
+ self.__client_class.IncludeFields = self.__real_include_fields |
+ |
+ def IncludeFields(self, include_fields): |
+ if self.__real_client: |
+ return self.__real_include_fields(self.__real_client, |
+ include_fields) |