OLD | NEW |
(Empty) | |
| 1 # |
| 2 # Copyright 2015 Google Inc. |
| 3 # |
| 4 # Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 # you may not use this file except in compliance with the License. |
| 6 # You may obtain a copy of the License at |
| 7 # |
| 8 # http://www.apache.org/licenses/LICENSE-2.0 |
| 9 # |
| 10 # Unless required by applicable law or agreed to in writing, software |
| 11 # distributed under the License is distributed on an "AS IS" BASIS, |
| 12 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 # See the License for the specific language governing permissions and |
| 14 # limitations under the License. |
| 15 |
| 16 """The mock module allows easy mocking of apitools clients. |
| 17 |
| 18 This module allows you to mock out the constructor of a particular apitools |
| 19 client, for a specific API and version. Then, when the client is created, it |
| 20 will be run against an expected session that you define. This way code that is |
| 21 not aware of the testing framework can construct new clients as normal, as long |
| 22 as it's all done within the context of a mock. |
| 23 """ |
| 24 |
| 25 import difflib |
| 26 |
| 27 import six |
| 28 |
| 29 from apitools.base.protorpclite import messages |
| 30 from apitools.base.py import base_api |
| 31 from apitools.base.py import encoding |
| 32 from apitools.base.py import exceptions |
| 33 |
| 34 |
| 35 class Error(Exception): |
| 36 |
| 37 """Exceptions for this module.""" |
| 38 |
| 39 |
| 40 def _MessagesEqual(msg1, msg2): |
| 41 """Compare two protorpc messages for equality. |
| 42 |
| 43 Using python's == operator does not work in all cases, specifically when |
| 44 there is a list involved. |
| 45 |
| 46 Args: |
| 47 msg1: protorpc.messages.Message or [protorpc.messages.Message] or number |
| 48 or string, One of the messages to compare. |
| 49 msg2: protorpc.messages.Message or [protorpc.messages.Message] or number |
| 50 or string, One of the messages to compare. |
| 51 |
| 52 Returns: |
| 53 If the messages are isomorphic. |
| 54 """ |
| 55 if isinstance(msg1, list) and isinstance(msg2, list): |
| 56 if len(msg1) != len(msg2): |
| 57 return False |
| 58 return all(_MessagesEqual(x, y) for x, y in zip(msg1, msg2)) |
| 59 |
| 60 if (not isinstance(msg1, messages.Message) or |
| 61 not isinstance(msg2, messages.Message)): |
| 62 return msg1 == msg2 |
| 63 for field in msg1.all_fields(): |
| 64 field1 = getattr(msg1, field.name) |
| 65 field2 = getattr(msg2, field.name) |
| 66 if not _MessagesEqual(field1, field2): |
| 67 return False |
| 68 return True |
| 69 |
| 70 |
| 71 class UnexpectedRequestException(Error): |
| 72 |
| 73 def __init__(self, received_call, expected_call): |
| 74 expected_key, expected_request = expected_call |
| 75 received_key, received_request = received_call |
| 76 |
| 77 expected_repr = encoding.MessageToRepr( |
| 78 expected_request, multiline=True) |
| 79 received_repr = encoding.MessageToRepr( |
| 80 received_request, multiline=True) |
| 81 |
| 82 expected_lines = expected_repr.splitlines() |
| 83 received_lines = received_repr.splitlines() |
| 84 |
| 85 diff_lines = difflib.unified_diff(expected_lines, received_lines) |
| 86 diff = '\n'.join(diff_lines) |
| 87 |
| 88 if expected_key != received_key: |
| 89 msg = '\n'.join(( |
| 90 'expected: {expected_key}({expected_request})', |
| 91 'received: {received_key}({received_request})', |
| 92 '', |
| 93 )).format( |
| 94 expected_key=expected_key, |
| 95 expected_request=expected_repr, |
| 96 received_key=received_key, |
| 97 received_request=received_repr) |
| 98 super(UnexpectedRequestException, self).__init__(msg) |
| 99 else: |
| 100 msg = '\n'.join(( |
| 101 'for request to {key},', |
| 102 'expected: {expected_request}', |
| 103 'received: {received_request}', |
| 104 'diff: {diff}', |
| 105 '', |
| 106 )).format( |
| 107 key=expected_key, |
| 108 expected_request=expected_repr, |
| 109 received_request=received_repr, |
| 110 diff=diff) |
| 111 super(UnexpectedRequestException, self).__init__(msg) |
| 112 |
| 113 |
| 114 class ExpectedRequestsException(Error): |
| 115 |
| 116 def __init__(self, expected_calls): |
| 117 msg = 'expected:\n' |
| 118 for (key, request) in expected_calls: |
| 119 msg += '{key}({request})\n'.format( |
| 120 key=key, |
| 121 request=encoding.MessageToRepr(request, multiline=True)) |
| 122 super(ExpectedRequestsException, self).__init__(msg) |
| 123 |
| 124 |
| 125 class _ExpectedRequestResponse(object): |
| 126 |
| 127 """Encapsulation of an expected request and corresponding response.""" |
| 128 |
| 129 def __init__(self, key, request, response=None, exception=None): |
| 130 self.__key = key |
| 131 self.__request = request |
| 132 |
| 133 if response and exception: |
| 134 raise exceptions.ConfigurationValueError( |
| 135 'Should specify at most one of response and exception') |
| 136 if response and isinstance(response, exceptions.Error): |
| 137 raise exceptions.ConfigurationValueError( |
| 138 'Responses should not be an instance of Error') |
| 139 if exception and not isinstance(exception, exceptions.Error): |
| 140 raise exceptions.ConfigurationValueError( |
| 141 'Exceptions must be instances of Error') |
| 142 |
| 143 self.__response = response |
| 144 self.__exception = exception |
| 145 |
| 146 @property |
| 147 def key(self): |
| 148 return self.__key |
| 149 |
| 150 @property |
| 151 def request(self): |
| 152 return self.__request |
| 153 |
| 154 def ValidateAndRespond(self, key, request): |
| 155 """Validate that key and request match expectations, and respond if so. |
| 156 |
| 157 Args: |
| 158 key: str, Actual key to compare against expectations. |
| 159 request: protorpc.messages.Message or [protorpc.messages.Message] |
| 160 or number or string, Actual request to compare againt expectations |
| 161 |
| 162 Raises: |
| 163 UnexpectedRequestException: If key or request dont match |
| 164 expectations. |
| 165 apitools_base.Error: If a non-None exception is specified to |
| 166 be thrown. |
| 167 |
| 168 Returns: |
| 169 The response that was specified to be returned. |
| 170 |
| 171 """ |
| 172 if key != self.__key or not _MessagesEqual(request, self.__request): |
| 173 raise UnexpectedRequestException((key, request), |
| 174 (self.__key, self.__request)) |
| 175 |
| 176 if self.__exception: |
| 177 # Can only throw apitools_base.Error. |
| 178 raise self.__exception # pylint: disable=raising-bad-type |
| 179 |
| 180 return self.__response |
| 181 |
| 182 |
| 183 class _MockedService(base_api.BaseApiService): |
| 184 |
| 185 def __init__(self, key, mocked_client, methods, real_service): |
| 186 super(_MockedService, self).__init__(mocked_client) |
| 187 self.__dict__.update(real_service.__dict__) |
| 188 for method in methods: |
| 189 real_method = None |
| 190 if real_service: |
| 191 real_method = getattr(real_service, method) |
| 192 setattr(self, method, |
| 193 _MockedMethod(key + '.' + method, |
| 194 mocked_client, |
| 195 real_method)) |
| 196 |
| 197 |
| 198 class _MockedMethod(object): |
| 199 |
| 200 """A mocked API service method.""" |
| 201 |
| 202 def __init__(self, key, mocked_client, real_method): |
| 203 self.__key = key |
| 204 self.__mocked_client = mocked_client |
| 205 self.__real_method = real_method |
| 206 |
| 207 def Expect(self, request, response=None, exception=None, **unused_kwargs): |
| 208 """Add an expectation on the mocked method. |
| 209 |
| 210 Exactly one of response and exception should be specified. |
| 211 |
| 212 Args: |
| 213 request: The request that should be expected |
| 214 response: The response that should be returned or None if |
| 215 exception is provided. |
| 216 exception: An exception that should be thrown, or None. |
| 217 |
| 218 """ |
| 219 # TODO(jasmuth): the unused_kwargs provides a placeholder for |
| 220 # future things that can be passed to Expect(), like special |
| 221 # params to the method call. |
| 222 |
| 223 # pylint: disable=protected-access |
| 224 # Class in same module. |
| 225 self.__mocked_client._request_responses.append( |
| 226 _ExpectedRequestResponse(self.__key, |
| 227 request, |
| 228 response=response, |
| 229 exception=exception)) |
| 230 # pylint: enable=protected-access |
| 231 |
| 232 def __call__(self, request, **unused_kwargs): |
| 233 # TODO(jasmuth): allow the testing code to expect certain |
| 234 # values in these currently unused_kwargs, especially the |
| 235 # upload parameter used by media-heavy services like bigquery |
| 236 # or bigstore. |
| 237 |
| 238 # pylint: disable=protected-access |
| 239 # Class in same module. |
| 240 if self.__mocked_client._request_responses: |
| 241 request_response = self.__mocked_client._request_responses.pop(0) |
| 242 else: |
| 243 raise UnexpectedRequestException( |
| 244 (self.__key, request), (None, None)) |
| 245 # pylint: enable=protected-access |
| 246 |
| 247 response = request_response.ValidateAndRespond(self.__key, request) |
| 248 |
| 249 if response is None and self.__real_method: |
| 250 response = self.__real_method(request) |
| 251 print(encoding.MessageToRepr( |
| 252 response, multiline=True, shortstrings=True)) |
| 253 return response |
| 254 |
| 255 return response |
| 256 |
| 257 |
| 258 def _MakeMockedServiceConstructor(mocked_service): |
| 259 def Constructor(unused_self, unused_client): |
| 260 return mocked_service |
| 261 return Constructor |
| 262 |
| 263 |
| 264 class Client(object): |
| 265 |
| 266 """Mock an apitools client.""" |
| 267 |
| 268 def __init__(self, client_class, real_client=None): |
| 269 """Mock an apitools API, given its class. |
| 270 |
| 271 Args: |
| 272 client_class: The class for the API. eg, if you |
| 273 from apis.sqladmin import v1beta3 |
| 274 then you can pass v1beta3.SqladminV1beta3 to this class |
| 275 and anything within its context will use your mocked |
| 276 version. |
| 277 real_client: apitools Client, The client to make requests |
| 278 against when the expected response is None. |
| 279 |
| 280 """ |
| 281 |
| 282 if not real_client: |
| 283 real_client = client_class(get_credentials=False) |
| 284 |
| 285 self.__client_class = client_class |
| 286 self.__real_service_classes = {} |
| 287 self.__real_client = real_client |
| 288 |
| 289 self._request_responses = [] |
| 290 self.__real_include_fields = None |
| 291 |
| 292 def __enter__(self): |
| 293 return self.Mock() |
| 294 |
| 295 def Mock(self): |
| 296 """Stub out the client class with mocked services.""" |
| 297 client = self.__real_client or self.__client_class( |
| 298 get_credentials=False) |
| 299 for name in dir(self.__client_class): |
| 300 service_class = getattr(self.__client_class, name) |
| 301 if not isinstance(service_class, type): |
| 302 continue |
| 303 if not issubclass(service_class, base_api.BaseApiService): |
| 304 continue |
| 305 self.__real_service_classes[name] = service_class |
| 306 service = service_class(client) |
| 307 # pylint: disable=protected-access |
| 308 # Some liberty is allowed with mocking. |
| 309 collection_name = service_class._NAME |
| 310 # pylint: enable=protected-access |
| 311 api_name = '%s_%s' % (self.__client_class._PACKAGE, |
| 312 self.__client_class._URL_VERSION) |
| 313 mocked_service = _MockedService( |
| 314 api_name + '.' + collection_name, self, |
| 315 service._method_configs.keys(), |
| 316 service if self.__real_client else None) |
| 317 mocked_constructor = _MakeMockedServiceConstructor(mocked_service) |
| 318 setattr(self.__client_class, name, mocked_constructor) |
| 319 |
| 320 setattr(self, collection_name, mocked_service) |
| 321 |
| 322 self.__real_include_fields = self.__client_class.IncludeFields |
| 323 self.__client_class.IncludeFields = self.IncludeFields |
| 324 |
| 325 return self |
| 326 |
| 327 def __exit__(self, exc_type, value, traceback): |
| 328 self.Unmock() |
| 329 if value: |
| 330 six.reraise(exc_type, value, traceback) |
| 331 return True |
| 332 |
| 333 def Unmock(self): |
| 334 for name, service_class in self.__real_service_classes.items(): |
| 335 setattr(self.__client_class, name, service_class) |
| 336 delattr(self, service_class._NAME) |
| 337 self.__real_service_classes = {} |
| 338 |
| 339 if self._request_responses: |
| 340 raise ExpectedRequestsException( |
| 341 [(rq_rs.key, rq_rs.request) for rq_rs |
| 342 in self._request_responses]) |
| 343 self._request_responses = [] |
| 344 |
| 345 self.__client_class.IncludeFields = self.__real_include_fields |
| 346 self.__real_include_fields = None |
| 347 |
| 348 def IncludeFields(self, include_fields): |
| 349 if self.__real_client: |
| 350 return self.__real_include_fields(self.__real_client, |
| 351 include_fields) |
OLD | NEW |