| OLD | NEW |
| (Empty) |
| 1 #!/usr/bin/python2.4 | |
| 2 # | |
| 3 # Copyright 2008 Google Inc. | |
| 4 # | |
| 5 # Licensed under the Apache License, Version 2.0 (the "License"); | |
| 6 # you may not use this file except in compliance with the License. | |
| 7 # You may obtain a copy of the License at | |
| 8 # | |
| 9 # http://www.apache.org/licenses/LICENSE-2.0 | |
| 10 # | |
| 11 # Unless required by applicable law or agreed to in writing, software | |
| 12 # distributed under the License is distributed on an "AS IS" BASIS, | |
| 13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| 14 # See the License for the specific language governing permissions and | |
| 15 # limitations under the License. | |
| 16 | |
| 17 """Mox, an object-mocking framework for Python. | |
| 18 | |
| 19 Mox works in the record-replay-verify paradigm. When you first create | |
| 20 a mock object, it is in record mode. You then programmatically set | |
| 21 the expected behavior of the mock object (what methods are to be | |
| 22 called on it, with what parameters, what they should return, and in | |
| 23 what order). | |
| 24 | |
| 25 Once you have set up the expected mock behavior, you put it in replay | |
| 26 mode. Now the mock responds to method calls just as you told it to. | |
| 27 If an unexpected method (or an expected method with unexpected | |
| 28 parameters) is called, then an exception will be raised. | |
| 29 | |
| 30 Once you are done interacting with the mock, you need to verify that | |
| 31 all the expected interactions occured. (Maybe your code exited | |
| 32 prematurely without calling some cleanup method!) The verify phase | |
| 33 ensures that every expected method was called; otherwise, an exception | |
| 34 will be raised. | |
| 35 | |
| 36 WARNING! Mock objects created by Mox are not thread-safe. If you are | |
| 37 call a mock in multiple threads, it should be guarded by a mutex. | |
| 38 | |
| 39 TODO(stevepm): Add the option to make mocks thread-safe! | |
| 40 | |
| 41 Suggested usage / workflow: | |
| 42 | |
| 43 # Create Mox factory | |
| 44 my_mox = Mox() | |
| 45 | |
| 46 # Create a mock data access object | |
| 47 mock_dao = my_mox.CreateMock(DAOClass) | |
| 48 | |
| 49 # Set up expected behavior | |
| 50 mock_dao.RetrievePersonWithIdentifier('1').AndReturn(person) | |
| 51 mock_dao.DeletePerson(person) | |
| 52 | |
| 53 # Put mocks in replay mode | |
| 54 my_mox.ReplayAll() | |
| 55 | |
| 56 # Inject mock object and run test | |
| 57 controller.SetDao(mock_dao) | |
| 58 controller.DeletePersonById('1') | |
| 59 | |
| 60 # Verify all methods were called as expected | |
| 61 my_mox.VerifyAll() | |
| 62 """ | |
| 63 | |
| 64 from collections import deque | |
| 65 import difflib | |
| 66 import inspect | |
| 67 import re | |
| 68 import types | |
| 69 import unittest | |
| 70 | |
| 71 import stubout | |
| 72 | |
| 73 class Error(AssertionError): | |
| 74 """Base exception for this module.""" | |
| 75 | |
| 76 pass | |
| 77 | |
| 78 | |
| 79 class ExpectedMethodCallsError(Error): | |
| 80 """Raised when Verify() is called before all expected methods have been called | |
| 81 """ | |
| 82 | |
| 83 def __init__(self, expected_methods): | |
| 84 """Init exception. | |
| 85 | |
| 86 Args: | |
| 87 # expected_methods: A sequence of MockMethod objects that should have been | |
| 88 # called. | |
| 89 expected_methods: [MockMethod] | |
| 90 | |
| 91 Raises: | |
| 92 ValueError: if expected_methods contains no methods. | |
| 93 """ | |
| 94 | |
| 95 if not expected_methods: | |
| 96 raise ValueError("There must be at least one expected method") | |
| 97 Error.__init__(self) | |
| 98 self._expected_methods = expected_methods | |
| 99 | |
| 100 def __str__(self): | |
| 101 calls = "\n".join(["%3d. %s" % (i, m) | |
| 102 for i, m in enumerate(self._expected_methods)]) | |
| 103 return "Verify: Expected methods never called:\n%s" % (calls,) | |
| 104 | |
| 105 | |
| 106 class UnexpectedMethodCallError(Error): | |
| 107 """Raised when an unexpected method is called. | |
| 108 | |
| 109 This can occur if a method is called with incorrect parameters, or out of the | |
| 110 specified order. | |
| 111 """ | |
| 112 | |
| 113 def __init__(self, unexpected_method, expected): | |
| 114 """Init exception. | |
| 115 | |
| 116 Args: | |
| 117 # unexpected_method: MockMethod that was called but was not at the head of | |
| 118 # the expected_method queue. | |
| 119 # expected: MockMethod or UnorderedGroup the method should have | |
| 120 # been in. | |
| 121 unexpected_method: MockMethod | |
| 122 expected: MockMethod or UnorderedGroup | |
| 123 """ | |
| 124 | |
| 125 Error.__init__(self) | |
| 126 if expected is None: | |
| 127 self._str = "Unexpected method call %s" % (unexpected_method,) | |
| 128 else: | |
| 129 differ = difflib.Differ() | |
| 130 diff = differ.compare(str(unexpected_method).splitlines(True), | |
| 131 str(expected).splitlines(True)) | |
| 132 self._str = ("Unexpected method call. unexpected:- expected:+\n%s" | |
| 133 % ("\n".join(diff),)) | |
| 134 | |
| 135 def __str__(self): | |
| 136 return self._str | |
| 137 | |
| 138 | |
| 139 class UnknownMethodCallError(Error): | |
| 140 """Raised if an unknown method is requested of the mock object.""" | |
| 141 | |
| 142 def __init__(self, unknown_method_name): | |
| 143 """Init exception. | |
| 144 | |
| 145 Args: | |
| 146 # unknown_method_name: Method call that is not part of the mocked class's | |
| 147 # public interface. | |
| 148 unknown_method_name: str | |
| 149 """ | |
| 150 | |
| 151 Error.__init__(self) | |
| 152 self._unknown_method_name = unknown_method_name | |
| 153 | |
| 154 def __str__(self): | |
| 155 return "Method called is not a member of the object: %s" % \ | |
| 156 self._unknown_method_name | |
| 157 | |
| 158 | |
| 159 class Mox(object): | |
| 160 """Mox: a factory for creating mock objects.""" | |
| 161 | |
| 162 # A list of types that should be stubbed out with MockObjects (as | |
| 163 # opposed to MockAnythings). | |
| 164 _USE_MOCK_OBJECT = [types.ClassType, types.InstanceType, types.ModuleType, | |
| 165 types.ObjectType, types.TypeType] | |
| 166 | |
| 167 def __init__(self): | |
| 168 """Initialize a new Mox.""" | |
| 169 | |
| 170 self._mock_objects = [] | |
| 171 self.stubs = stubout.StubOutForTesting() | |
| 172 | |
| 173 def CreateMock(self, class_to_mock): | |
| 174 """Create a new mock object. | |
| 175 | |
| 176 Args: | |
| 177 # class_to_mock: the class to be mocked | |
| 178 class_to_mock: class | |
| 179 | |
| 180 Returns: | |
| 181 MockObject that can be used as the class_to_mock would be. | |
| 182 """ | |
| 183 | |
| 184 new_mock = MockObject(class_to_mock) | |
| 185 self._mock_objects.append(new_mock) | |
| 186 return new_mock | |
| 187 | |
| 188 def CreateMockAnything(self, description=None): | |
| 189 """Create a mock that will accept any method calls. | |
| 190 | |
| 191 This does not enforce an interface. | |
| 192 | |
| 193 Args: | |
| 194 description: str. Optionally, a descriptive name for the mock object being | |
| 195 created, for debugging output purposes. | |
| 196 """ | |
| 197 new_mock = MockAnything(description=description) | |
| 198 self._mock_objects.append(new_mock) | |
| 199 return new_mock | |
| 200 | |
| 201 def ReplayAll(self): | |
| 202 """Set all mock objects to replay mode.""" | |
| 203 | |
| 204 for mock_obj in self._mock_objects: | |
| 205 mock_obj._Replay() | |
| 206 | |
| 207 | |
| 208 def VerifyAll(self): | |
| 209 """Call verify on all mock objects created.""" | |
| 210 | |
| 211 for mock_obj in self._mock_objects: | |
| 212 mock_obj._Verify() | |
| 213 | |
| 214 def ResetAll(self): | |
| 215 """Call reset on all mock objects. This does not unset stubs.""" | |
| 216 | |
| 217 for mock_obj in self._mock_objects: | |
| 218 mock_obj._Reset() | |
| 219 | |
| 220 def StubOutWithMock(self, obj, attr_name, use_mock_anything=False): | |
| 221 """Replace a method, attribute, etc. with a Mock. | |
| 222 | |
| 223 This will replace a class or module with a MockObject, and everything else | |
| 224 (method, function, etc) with a MockAnything. This can be overridden to | |
| 225 always use a MockAnything by setting use_mock_anything to True. | |
| 226 | |
| 227 Args: | |
| 228 obj: A Python object (class, module, instance, callable). | |
| 229 attr_name: str. The name of the attribute to replace with a mock. | |
| 230 use_mock_anything: bool. True if a MockAnything should be used regardless | |
| 231 of the type of attribute. | |
| 232 """ | |
| 233 | |
| 234 attr_to_replace = getattr(obj, attr_name) | |
| 235 | |
| 236 # Check for a MockAnything. This could cause confusing problems later on. | |
| 237 if attr_to_replace == MockAnything(): | |
| 238 raise TypeError('Cannot mock a MockAnything! Did you remember to ' | |
| 239 'call UnsetStubs in your previous test?') | |
| 240 | |
| 241 if type(attr_to_replace) in self._USE_MOCK_OBJECT and not use_mock_anything: | |
| 242 stub = self.CreateMock(attr_to_replace) | |
| 243 else: | |
| 244 stub = self.CreateMockAnything(description='Stub for %s' % attr_to_replace
) | |
| 245 | |
| 246 self.stubs.Set(obj, attr_name, stub) | |
| 247 | |
| 248 def UnsetStubs(self): | |
| 249 """Restore stubs to their original state.""" | |
| 250 | |
| 251 self.stubs.UnsetAll() | |
| 252 | |
| 253 def Replay(*args): | |
| 254 """Put mocks into Replay mode. | |
| 255 | |
| 256 Args: | |
| 257 # args is any number of mocks to put into replay mode. | |
| 258 """ | |
| 259 | |
| 260 for mock in args: | |
| 261 mock._Replay() | |
| 262 | |
| 263 | |
| 264 def Verify(*args): | |
| 265 """Verify mocks. | |
| 266 | |
| 267 Args: | |
| 268 # args is any number of mocks to be verified. | |
| 269 """ | |
| 270 | |
| 271 for mock in args: | |
| 272 mock._Verify() | |
| 273 | |
| 274 | |
| 275 def Reset(*args): | |
| 276 """Reset mocks. | |
| 277 | |
| 278 Args: | |
| 279 # args is any number of mocks to be reset. | |
| 280 """ | |
| 281 | |
| 282 for mock in args: | |
| 283 mock._Reset() | |
| 284 | |
| 285 | |
| 286 class MockAnything: | |
| 287 """A mock that can be used to mock anything. | |
| 288 | |
| 289 This is helpful for mocking classes that do not provide a public interface. | |
| 290 """ | |
| 291 | |
| 292 def __init__(self, description=None): | |
| 293 """Initialize a new MockAnything. | |
| 294 | |
| 295 Args: | |
| 296 description: str. Optionally, a descriptive name for the mock object being | |
| 297 created, for debugging output purposes. | |
| 298 """ | |
| 299 self._description = description | |
| 300 self._Reset() | |
| 301 | |
| 302 def __str__(self): | |
| 303 return "<MockAnything instance at %s>" % id(self) | |
| 304 | |
| 305 def __repr__(self): | |
| 306 return '<MockAnything instance>' | |
| 307 | |
| 308 def __getattr__(self, method_name): | |
| 309 """Intercept method calls on this object. | |
| 310 | |
| 311 A new MockMethod is returned that is aware of the MockAnything's | |
| 312 state (record or replay). The call will be recorded or replayed | |
| 313 by the MockMethod's __call__. | |
| 314 | |
| 315 Args: | |
| 316 # method name: the name of the method being called. | |
| 317 method_name: str | |
| 318 | |
| 319 Returns: | |
| 320 A new MockMethod aware of MockAnything's state (record or replay). | |
| 321 """ | |
| 322 | |
| 323 return self._CreateMockMethod(method_name) | |
| 324 | |
| 325 def _CreateMockMethod(self, method_name, method_to_mock=None): | |
| 326 """Create a new mock method call and return it. | |
| 327 | |
| 328 Args: | |
| 329 # method_name: the name of the method being called. | |
| 330 # method_to_mock: The actual method being mocked, used for introspection. | |
| 331 method_name: str | |
| 332 method_to_mock: a method object | |
| 333 | |
| 334 Returns: | |
| 335 A new MockMethod aware of MockAnything's state (record or replay). | |
| 336 """ | |
| 337 | |
| 338 return MockMethod(method_name, self._expected_calls_queue, | |
| 339 self._replay_mode, method_to_mock=method_to_mock, | |
| 340 description=self._description) | |
| 341 | |
| 342 def __nonzero__(self): | |
| 343 """Return 1 for nonzero so the mock can be used as a conditional.""" | |
| 344 | |
| 345 return 1 | |
| 346 | |
| 347 def __eq__(self, rhs): | |
| 348 """Provide custom logic to compare objects.""" | |
| 349 | |
| 350 return (isinstance(rhs, MockAnything) and | |
| 351 self._replay_mode == rhs._replay_mode and | |
| 352 self._expected_calls_queue == rhs._expected_calls_queue) | |
| 353 | |
| 354 def __ne__(self, rhs): | |
| 355 """Provide custom logic to compare objects.""" | |
| 356 | |
| 357 return not self == rhs | |
| 358 | |
| 359 def _Replay(self): | |
| 360 """Start replaying expected method calls.""" | |
| 361 | |
| 362 self._replay_mode = True | |
| 363 | |
| 364 def _Verify(self): | |
| 365 """Verify that all of the expected calls have been made. | |
| 366 | |
| 367 Raises: | |
| 368 ExpectedMethodCallsError: if there are still more method calls in the | |
| 369 expected queue. | |
| 370 """ | |
| 371 | |
| 372 # If the list of expected calls is not empty, raise an exception | |
| 373 if self._expected_calls_queue: | |
| 374 # The last MultipleTimesGroup is not popped from the queue. | |
| 375 if (len(self._expected_calls_queue) == 1 and | |
| 376 isinstance(self._expected_calls_queue[0], MultipleTimesGroup) and | |
| 377 self._expected_calls_queue[0].IsSatisfied()): | |
| 378 pass | |
| 379 else: | |
| 380 raise ExpectedMethodCallsError(self._expected_calls_queue) | |
| 381 | |
| 382 def _Reset(self): | |
| 383 """Reset the state of this mock to record mode with an empty queue.""" | |
| 384 | |
| 385 # Maintain a list of method calls we are expecting | |
| 386 self._expected_calls_queue = deque() | |
| 387 | |
| 388 # Make sure we are in setup mode, not replay mode | |
| 389 self._replay_mode = False | |
| 390 | |
| 391 | |
| 392 class MockObject(MockAnything, object): | |
| 393 """A mock object that simulates the public/protected interface of a class.""" | |
| 394 | |
| 395 def __init__(self, class_to_mock): | |
| 396 """Initialize a mock object. | |
| 397 | |
| 398 This determines the methods and properties of the class and stores them. | |
| 399 | |
| 400 Args: | |
| 401 # class_to_mock: class to be mocked | |
| 402 class_to_mock: class | |
| 403 """ | |
| 404 | |
| 405 # This is used to hack around the mixin/inheritance of MockAnything, which | |
| 406 # is not a proper object (it can be anything. :-) | |
| 407 MockAnything.__dict__['__init__'](self) | |
| 408 | |
| 409 # Get a list of all the public and special methods we should mock. | |
| 410 self._known_methods = set() | |
| 411 self._known_vars = set() | |
| 412 self._class_to_mock = class_to_mock | |
| 413 for method in dir(class_to_mock): | |
| 414 if callable(getattr(class_to_mock, method)): | |
| 415 self._known_methods.add(method) | |
| 416 else: | |
| 417 self._known_vars.add(method) | |
| 418 | |
| 419 def __getattr__(self, name): | |
| 420 """Intercept attribute request on this object. | |
| 421 | |
| 422 If the attribute is a public class variable, it will be returned and not | |
| 423 recorded as a call. | |
| 424 | |
| 425 If the attribute is not a variable, it is handled like a method | |
| 426 call. The method name is checked against the set of mockable | |
| 427 methods, and a new MockMethod is returned that is aware of the | |
| 428 MockObject's state (record or replay). The call will be recorded | |
| 429 or replayed by the MockMethod's __call__. | |
| 430 | |
| 431 Args: | |
| 432 # name: the name of the attribute being requested. | |
| 433 name: str | |
| 434 | |
| 435 Returns: | |
| 436 Either a class variable or a new MockMethod that is aware of the state | |
| 437 of the mock (record or replay). | |
| 438 | |
| 439 Raises: | |
| 440 UnknownMethodCallError if the MockObject does not mock the requested | |
| 441 method. | |
| 442 """ | |
| 443 | |
| 444 if name in self._known_vars: | |
| 445 return getattr(self._class_to_mock, name) | |
| 446 | |
| 447 if name in self._known_methods: | |
| 448 return self._CreateMockMethod( | |
| 449 name, | |
| 450 method_to_mock=getattr(self._class_to_mock, name)) | |
| 451 | |
| 452 raise UnknownMethodCallError(name) | |
| 453 | |
| 454 def __eq__(self, rhs): | |
| 455 """Provide custom logic to compare objects.""" | |
| 456 | |
| 457 return (isinstance(rhs, MockObject) and | |
| 458 self._class_to_mock == rhs._class_to_mock and | |
| 459 self._replay_mode == rhs._replay_mode and | |
| 460 self._expected_calls_queue == rhs._expected_calls_queue) | |
| 461 | |
| 462 def __setitem__(self, key, value): | |
| 463 """Provide custom logic for mocking classes that support item assignment. | |
| 464 | |
| 465 Args: | |
| 466 key: Key to set the value for. | |
| 467 value: Value to set. | |
| 468 | |
| 469 Returns: | |
| 470 Expected return value in replay mode. A MockMethod object for the | |
| 471 __setitem__ method that has already been called if not in replay mode. | |
| 472 | |
| 473 Raises: | |
| 474 TypeError if the underlying class does not support item assignment. | |
| 475 UnexpectedMethodCallError if the object does not expect the call to | |
| 476 __setitem__. | |
| 477 | |
| 478 """ | |
| 479 # Verify the class supports item assignment. | |
| 480 if '__setitem__' not in dir(self._class_to_mock): | |
| 481 raise TypeError('object does not support item assignment') | |
| 482 | |
| 483 # If we are in replay mode then simply call the mock __setitem__ method. | |
| 484 if self._replay_mode: | |
| 485 return MockMethod('__setitem__', self._expected_calls_queue, | |
| 486 self._replay_mode)(key, value) | |
| 487 | |
| 488 | |
| 489 # Otherwise, create a mock method __setitem__. | |
| 490 return self._CreateMockMethod('__setitem__')(key, value) | |
| 491 | |
| 492 def __getitem__(self, key): | |
| 493 """Provide custom logic for mocking classes that are subscriptable. | |
| 494 | |
| 495 Args: | |
| 496 key: Key to return the value for. | |
| 497 | |
| 498 Returns: | |
| 499 Expected return value in replay mode. A MockMethod object for the | |
| 500 __getitem__ method that has already been called if not in replay mode. | |
| 501 | |
| 502 Raises: | |
| 503 TypeError if the underlying class is not subscriptable. | |
| 504 UnexpectedMethodCallError if the object does not expect the call to | |
| 505 __getitem__. | |
| 506 | |
| 507 """ | |
| 508 # Verify the class supports item assignment. | |
| 509 if '__getitem__' not in dir(self._class_to_mock): | |
| 510 raise TypeError('unsubscriptable object') | |
| 511 | |
| 512 # If we are in replay mode then simply call the mock __getitem__ method. | |
| 513 if self._replay_mode: | |
| 514 return MockMethod('__getitem__', self._expected_calls_queue, | |
| 515 self._replay_mode)(key) | |
| 516 | |
| 517 | |
| 518 # Otherwise, create a mock method __getitem__. | |
| 519 return self._CreateMockMethod('__getitem__')(key) | |
| 520 | |
| 521 def __iter__(self): | |
| 522 """Provide custom logic for mocking classes that are iterable. | |
| 523 | |
| 524 Returns: | |
| 525 Expected return value in replay mode. A MockMethod object for the | |
| 526 __iter__ method that has already been called if not in replay mode. | |
| 527 | |
| 528 Raises: | |
| 529 TypeError if the underlying class is not iterable. | |
| 530 UnexpectedMethodCallError if the object does not expect the call to | |
| 531 __iter__. | |
| 532 | |
| 533 """ | |
| 534 methods = dir(self._class_to_mock) | |
| 535 | |
| 536 # Verify the class supports iteration. | |
| 537 if '__iter__' not in methods: | |
| 538 # If it doesn't have iter method and we are in replay method, then try to | |
| 539 # iterate using subscripts. | |
| 540 if '__getitem__' not in methods or not self._replay_mode: | |
| 541 raise TypeError('not iterable object') | |
| 542 else: | |
| 543 results = [] | |
| 544 index = 0 | |
| 545 try: | |
| 546 while True: | |
| 547 results.append(self[index]) | |
| 548 index += 1 | |
| 549 except IndexError: | |
| 550 return iter(results) | |
| 551 | |
| 552 # If we are in replay mode then simply call the mock __iter__ method. | |
| 553 if self._replay_mode: | |
| 554 return MockMethod('__iter__', self._expected_calls_queue, | |
| 555 self._replay_mode)() | |
| 556 | |
| 557 | |
| 558 # Otherwise, create a mock method __iter__. | |
| 559 return self._CreateMockMethod('__iter__')() | |
| 560 | |
| 561 | |
| 562 def __contains__(self, key): | |
| 563 """Provide custom logic for mocking classes that contain items. | |
| 564 | |
| 565 Args: | |
| 566 key: Key to look in container for. | |
| 567 | |
| 568 Returns: | |
| 569 Expected return value in replay mode. A MockMethod object for the | |
| 570 __contains__ method that has already been called if not in replay mode. | |
| 571 | |
| 572 Raises: | |
| 573 TypeError if the underlying class does not implement __contains__ | |
| 574 UnexpectedMethodCaller if the object does not expect the call to | |
| 575 __contains__. | |
| 576 | |
| 577 """ | |
| 578 contains = self._class_to_mock.__dict__.get('__contains__', None) | |
| 579 | |
| 580 if contains is None: | |
| 581 raise TypeError('unsubscriptable object') | |
| 582 | |
| 583 if self._replay_mode: | |
| 584 return MockMethod('__contains__', self._expected_calls_queue, | |
| 585 self._replay_mode)(key) | |
| 586 | |
| 587 return self._CreateMockMethod('__contains__')(key) | |
| 588 | |
| 589 def __call__(self, *params, **named_params): | |
| 590 """Provide custom logic for mocking classes that are callable.""" | |
| 591 | |
| 592 # Verify the class we are mocking is callable. | |
| 593 callable = hasattr(self._class_to_mock, '__call__') | |
| 594 if not callable: | |
| 595 raise TypeError('Not callable') | |
| 596 | |
| 597 # Because the call is happening directly on this object instead of a method, | |
| 598 # the call on the mock method is made right here | |
| 599 mock_method = self._CreateMockMethod('__call__') | |
| 600 return mock_method(*params, **named_params) | |
| 601 | |
| 602 @property | |
| 603 def __class__(self): | |
| 604 """Return the class that is being mocked.""" | |
| 605 | |
| 606 return self._class_to_mock | |
| 607 | |
| 608 | |
| 609 class MethodCallChecker(object): | |
| 610 """Ensures that methods are called correctly.""" | |
| 611 | |
| 612 _NEEDED, _DEFAULT, _GIVEN = range(3) | |
| 613 | |
| 614 def __init__(self, method): | |
| 615 """Creates a checker. | |
| 616 | |
| 617 Args: | |
| 618 # method: A method to check. | |
| 619 method: function | |
| 620 | |
| 621 Raises: | |
| 622 ValueError: method could not be inspected, so checks aren't possible. | |
| 623 Some methods and functions like built-ins can't be inspected. | |
| 624 """ | |
| 625 try: | |
| 626 self._args, varargs, varkw, defaults = inspect.getargspec(method) | |
| 627 except TypeError: | |
| 628 raise ValueError('Could not get argument specification for %r' | |
| 629 % (method,)) | |
| 630 if inspect.ismethod(method): | |
| 631 self._args = self._args[1:] # Skip 'self'. | |
| 632 self._method = method | |
| 633 | |
| 634 self._has_varargs = varargs is not None | |
| 635 self._has_varkw = varkw is not None | |
| 636 if defaults is None: | |
| 637 self._required_args = self._args | |
| 638 self._default_args = [] | |
| 639 else: | |
| 640 self._required_args = self._args[:-len(defaults)] | |
| 641 self._default_args = self._args[-len(defaults):] | |
| 642 | |
| 643 def _RecordArgumentGiven(self, arg_name, arg_status): | |
| 644 """Mark an argument as being given. | |
| 645 | |
| 646 Args: | |
| 647 # arg_name: The name of the argument to mark in arg_status. | |
| 648 # arg_status: Maps argument names to one of _NEEDED, _DEFAULT, _GIVEN. | |
| 649 arg_name: string | |
| 650 arg_status: dict | |
| 651 | |
| 652 Raises: | |
| 653 AttributeError: arg_name is already marked as _GIVEN. | |
| 654 """ | |
| 655 if arg_status.get(arg_name, None) == MethodCallChecker._GIVEN: | |
| 656 raise AttributeError('%s provided more than once' % (arg_name,)) | |
| 657 arg_status[arg_name] = MethodCallChecker._GIVEN | |
| 658 | |
| 659 def Check(self, params, named_params): | |
| 660 """Ensures that the parameters used while recording a call are valid. | |
| 661 | |
| 662 Args: | |
| 663 # params: A list of positional parameters. | |
| 664 # named_params: A dict of named parameters. | |
| 665 params: list | |
| 666 named_params: dict | |
| 667 | |
| 668 Raises: | |
| 669 AttributeError: the given parameters don't work with the given method. | |
| 670 """ | |
| 671 arg_status = dict((a, MethodCallChecker._NEEDED) | |
| 672 for a in self._required_args) | |
| 673 for arg in self._default_args: | |
| 674 arg_status[arg] = MethodCallChecker._DEFAULT | |
| 675 | |
| 676 # Check that each positional param is valid. | |
| 677 for i in range(len(params)): | |
| 678 try: | |
| 679 arg_name = self._args[i] | |
| 680 except IndexError: | |
| 681 if not self._has_varargs: | |
| 682 raise AttributeError('%s does not take %d or more positional ' | |
| 683 'arguments' % (self._method.__name__, i)) | |
| 684 else: | |
| 685 self._RecordArgumentGiven(arg_name, arg_status) | |
| 686 | |
| 687 # Check each keyword argument. | |
| 688 for arg_name in named_params: | |
| 689 if arg_name not in arg_status and not self._has_varkw: | |
| 690 raise AttributeError('%s is not expecting keyword argument %s' | |
| 691 % (self._method.__name__, arg_name)) | |
| 692 self._RecordArgumentGiven(arg_name, arg_status) | |
| 693 | |
| 694 # Ensure all the required arguments have been given. | |
| 695 still_needed = [k for k, v in arg_status.iteritems() | |
| 696 if v == MethodCallChecker._NEEDED] | |
| 697 if still_needed: | |
| 698 raise AttributeError('No values given for arguments %s' | |
| 699 % (' '.join(sorted(still_needed)))) | |
| 700 | |
| 701 | |
| 702 class MockMethod(object): | |
| 703 """Callable mock method. | |
| 704 | |
| 705 A MockMethod should act exactly like the method it mocks, accepting parameters | |
| 706 and returning a value, or throwing an exception (as specified). When this | |
| 707 method is called, it can optionally verify whether the called method (name and | |
| 708 signature) matches the expected method. | |
| 709 """ | |
| 710 | |
| 711 def __init__(self, method_name, call_queue, replay_mode, | |
| 712 method_to_mock=None, description=None): | |
| 713 """Construct a new mock method. | |
| 714 | |
| 715 Args: | |
| 716 # method_name: the name of the method | |
| 717 # call_queue: deque of calls, verify this call against the head, or add | |
| 718 # this call to the queue. | |
| 719 # replay_mode: False if we are recording, True if we are verifying calls | |
| 720 # against the call queue. | |
| 721 # method_to_mock: The actual method being mocked, used for introspection. | |
| 722 # description: optionally, a descriptive name for this method. Typically | |
| 723 # this is equal to the descriptive name of the method's class. | |
| 724 method_name: str | |
| 725 call_queue: list or deque | |
| 726 replay_mode: bool | |
| 727 method_to_mock: a method object | |
| 728 description: str or None | |
| 729 """ | |
| 730 | |
| 731 self._name = method_name | |
| 732 self._call_queue = call_queue | |
| 733 if not isinstance(call_queue, deque): | |
| 734 self._call_queue = deque(self._call_queue) | |
| 735 self._replay_mode = replay_mode | |
| 736 self._description = description | |
| 737 | |
| 738 self._params = None | |
| 739 self._named_params = None | |
| 740 self._return_value = None | |
| 741 self._exception = None | |
| 742 self._side_effects = None | |
| 743 | |
| 744 try: | |
| 745 self._checker = MethodCallChecker(method_to_mock) | |
| 746 except ValueError: | |
| 747 self._checker = None | |
| 748 | |
| 749 def __call__(self, *params, **named_params): | |
| 750 """Log parameters and return the specified return value. | |
| 751 | |
| 752 If the Mock(Anything/Object) associated with this call is in record mode, | |
| 753 this MockMethod will be pushed onto the expected call queue. If the mock | |
| 754 is in replay mode, this will pop a MockMethod off the top of the queue and | |
| 755 verify this call is equal to the expected call. | |
| 756 | |
| 757 Raises: | |
| 758 UnexpectedMethodCall if this call is supposed to match an expected method | |
| 759 call and it does not. | |
| 760 """ | |
| 761 | |
| 762 self._params = params | |
| 763 self._named_params = named_params | |
| 764 | |
| 765 if not self._replay_mode: | |
| 766 if self._checker is not None: | |
| 767 self._checker.Check(params, named_params) | |
| 768 self._call_queue.append(self) | |
| 769 return self | |
| 770 | |
| 771 expected_method = self._VerifyMethodCall() | |
| 772 | |
| 773 if expected_method._side_effects: | |
| 774 expected_method._side_effects(*params, **named_params) | |
| 775 | |
| 776 if expected_method._exception: | |
| 777 raise expected_method._exception | |
| 778 | |
| 779 return expected_method._return_value | |
| 780 | |
| 781 def __getattr__(self, name): | |
| 782 """Raise an AttributeError with a helpful message.""" | |
| 783 | |
| 784 raise AttributeError('MockMethod has no attribute "%s". ' | |
| 785 'Did you remember to put your mocks in replay mode?' % name) | |
| 786 | |
| 787 def __iter__(self): | |
| 788 """Raise a TypeError with a helpful message.""" | |
| 789 raise TypeError('MockMethod cannot be iterated. ' | |
| 790 'Did you remember to put your mocks in replay mode?') | |
| 791 | |
| 792 def next(self): | |
| 793 """Raise a TypeError with a helpful message.""" | |
| 794 raise TypeError('MockMethod cannot be iterated. ' | |
| 795 'Did you remember to put your mocks in replay mode?') | |
| 796 | |
| 797 def _PopNextMethod(self): | |
| 798 """Pop the next method from our call queue.""" | |
| 799 try: | |
| 800 return self._call_queue.popleft() | |
| 801 except IndexError: | |
| 802 raise UnexpectedMethodCallError(self, None) | |
| 803 | |
| 804 def _VerifyMethodCall(self): | |
| 805 """Verify the called method is expected. | |
| 806 | |
| 807 This can be an ordered method, or part of an unordered set. | |
| 808 | |
| 809 Returns: | |
| 810 The expected mock method. | |
| 811 | |
| 812 Raises: | |
| 813 UnexpectedMethodCall if the method called was not expected. | |
| 814 """ | |
| 815 | |
| 816 expected = self._PopNextMethod() | |
| 817 | |
| 818 # Loop here, because we might have a MethodGroup followed by another | |
| 819 # group. | |
| 820 while isinstance(expected, MethodGroup): | |
| 821 expected, method = expected.MethodCalled(self) | |
| 822 if method is not None: | |
| 823 return method | |
| 824 | |
| 825 # This is a mock method, so just check equality. | |
| 826 if expected != self: | |
| 827 raise UnexpectedMethodCallError(self, expected) | |
| 828 | |
| 829 return expected | |
| 830 | |
| 831 def __str__(self): | |
| 832 params = ', '.join( | |
| 833 [repr(p) for p in self._params or []] + | |
| 834 ['%s=%r' % x for x in sorted((self._named_params or {}).items())]) | |
| 835 full_desc = "%s(%s) -> %r" % (self._name, params, self._return_value) | |
| 836 if self._description: | |
| 837 full_desc = "%s.%s" % (self._description, full_desc) | |
| 838 return full_desc | |
| 839 | |
| 840 def __eq__(self, rhs): | |
| 841 """Test whether this MockMethod is equivalent to another MockMethod. | |
| 842 | |
| 843 Args: | |
| 844 # rhs: the right hand side of the test | |
| 845 rhs: MockMethod | |
| 846 """ | |
| 847 | |
| 848 return (isinstance(rhs, MockMethod) and | |
| 849 self._name == rhs._name and | |
| 850 self._params == rhs._params and | |
| 851 self._named_params == rhs._named_params) | |
| 852 | |
| 853 def __ne__(self, rhs): | |
| 854 """Test whether this MockMethod is not equivalent to another MockMethod. | |
| 855 | |
| 856 Args: | |
| 857 # rhs: the right hand side of the test | |
| 858 rhs: MockMethod | |
| 859 """ | |
| 860 | |
| 861 return not self == rhs | |
| 862 | |
| 863 def GetPossibleGroup(self): | |
| 864 """Returns a possible group from the end of the call queue or None if no | |
| 865 other methods are on the stack. | |
| 866 """ | |
| 867 | |
| 868 # Remove this method from the tail of the queue so we can add it to a group. | |
| 869 this_method = self._call_queue.pop() | |
| 870 assert this_method == self | |
| 871 | |
| 872 # Determine if the tail of the queue is a group, or just a regular ordered | |
| 873 # mock method. | |
| 874 group = None | |
| 875 try: | |
| 876 group = self._call_queue[-1] | |
| 877 except IndexError: | |
| 878 pass | |
| 879 | |
| 880 return group | |
| 881 | |
| 882 def _CheckAndCreateNewGroup(self, group_name, group_class): | |
| 883 """Checks if the last method (a possible group) is an instance of our | |
| 884 group_class. Adds the current method to this group or creates a new one. | |
| 885 | |
| 886 Args: | |
| 887 | |
| 888 group_name: the name of the group. | |
| 889 group_class: the class used to create instance of this new group | |
| 890 """ | |
| 891 group = self.GetPossibleGroup() | |
| 892 | |
| 893 # If this is a group, and it is the correct group, add the method. | |
| 894 if isinstance(group, group_class) and group.group_name() == group_name: | |
| 895 group.AddMethod(self) | |
| 896 return self | |
| 897 | |
| 898 # Create a new group and add the method. | |
| 899 new_group = group_class(group_name) | |
| 900 new_group.AddMethod(self) | |
| 901 self._call_queue.append(new_group) | |
| 902 return self | |
| 903 | |
| 904 def InAnyOrder(self, group_name="default"): | |
| 905 """Move this method into a group of unordered calls. | |
| 906 | |
| 907 A group of unordered calls must be defined together, and must be executed | |
| 908 in full before the next expected method can be called. There can be | |
| 909 multiple groups that are expected serially, if they are given | |
| 910 different group names. The same group name can be reused if there is a | |
| 911 standard method call, or a group with a different name, spliced between | |
| 912 usages. | |
| 913 | |
| 914 Args: | |
| 915 group_name: the name of the unordered group. | |
| 916 | |
| 917 Returns: | |
| 918 self | |
| 919 """ | |
| 920 return self._CheckAndCreateNewGroup(group_name, UnorderedGroup) | |
| 921 | |
| 922 def MultipleTimes(self, group_name="default"): | |
| 923 """Move this method into group of calls which may be called multiple times. | |
| 924 | |
| 925 A group of repeating calls must be defined together, and must be executed in | |
| 926 full before the next expected mehtod can be called. | |
| 927 | |
| 928 Args: | |
| 929 group_name: the name of the unordered group. | |
| 930 | |
| 931 Returns: | |
| 932 self | |
| 933 """ | |
| 934 return self._CheckAndCreateNewGroup(group_name, MultipleTimesGroup) | |
| 935 | |
| 936 def AndReturn(self, return_value): | |
| 937 """Set the value to return when this method is called. | |
| 938 | |
| 939 Args: | |
| 940 # return_value can be anything. | |
| 941 """ | |
| 942 | |
| 943 self._return_value = return_value | |
| 944 return return_value | |
| 945 | |
| 946 def AndRaise(self, exception): | |
| 947 """Set the exception to raise when this method is called. | |
| 948 | |
| 949 Args: | |
| 950 # exception: the exception to raise when this method is called. | |
| 951 exception: Exception | |
| 952 """ | |
| 953 | |
| 954 self._exception = exception | |
| 955 | |
| 956 def WithSideEffects(self, side_effects): | |
| 957 """Set the side effects that are simulated when this method is called. | |
| 958 | |
| 959 Args: | |
| 960 side_effects: A callable which modifies the parameters or other relevant | |
| 961 state which a given test case depends on. | |
| 962 | |
| 963 Returns: | |
| 964 Self for chaining with AndReturn and AndRaise. | |
| 965 """ | |
| 966 self._side_effects = side_effects | |
| 967 return self | |
| 968 | |
| 969 class Comparator: | |
| 970 """Base class for all Mox comparators. | |
| 971 | |
| 972 A Comparator can be used as a parameter to a mocked method when the exact | |
| 973 value is not known. For example, the code you are testing might build up a | |
| 974 long SQL string that is passed to your mock DAO. You're only interested that | |
| 975 the IN clause contains the proper primary keys, so you can set your mock | |
| 976 up as follows: | |
| 977 | |
| 978 mock_dao.RunQuery(StrContains('IN (1, 2, 4, 5)')).AndReturn(mock_result) | |
| 979 | |
| 980 Now whatever query is passed in must contain the string 'IN (1, 2, 4, 5)'. | |
| 981 | |
| 982 A Comparator may replace one or more parameters, for example: | |
| 983 # return at most 10 rows | |
| 984 mock_dao.RunQuery(StrContains('SELECT'), 10) | |
| 985 | |
| 986 or | |
| 987 | |
| 988 # Return some non-deterministic number of rows | |
| 989 mock_dao.RunQuery(StrContains('SELECT'), IsA(int)) | |
| 990 """ | |
| 991 | |
| 992 def equals(self, rhs): | |
| 993 """Special equals method that all comparators must implement. | |
| 994 | |
| 995 Args: | |
| 996 rhs: any python object | |
| 997 """ | |
| 998 | |
| 999 raise NotImplementedError, 'method must be implemented by a subclass.' | |
| 1000 | |
| 1001 def __eq__(self, rhs): | |
| 1002 return self.equals(rhs) | |
| 1003 | |
| 1004 def __ne__(self, rhs): | |
| 1005 return not self.equals(rhs) | |
| 1006 | |
| 1007 | |
| 1008 class IsA(Comparator): | |
| 1009 """This class wraps a basic Python type or class. It is used to verify | |
| 1010 that a parameter is of the given type or class. | |
| 1011 | |
| 1012 Example: | |
| 1013 mock_dao.Connect(IsA(DbConnectInfo)) | |
| 1014 """ | |
| 1015 | |
| 1016 def __init__(self, class_name): | |
| 1017 """Initialize IsA | |
| 1018 | |
| 1019 Args: | |
| 1020 class_name: basic python type or a class | |
| 1021 """ | |
| 1022 | |
| 1023 self._class_name = class_name | |
| 1024 | |
| 1025 def equals(self, rhs): | |
| 1026 """Check to see if the RHS is an instance of class_name. | |
| 1027 | |
| 1028 Args: | |
| 1029 # rhs: the right hand side of the test | |
| 1030 rhs: object | |
| 1031 | |
| 1032 Returns: | |
| 1033 bool | |
| 1034 """ | |
| 1035 | |
| 1036 try: | |
| 1037 return isinstance(rhs, self._class_name) | |
| 1038 except TypeError: | |
| 1039 # Check raw types if there was a type error. This is helpful for | |
| 1040 # things like cStringIO.StringIO. | |
| 1041 return type(rhs) == type(self._class_name) | |
| 1042 | |
| 1043 def __repr__(self): | |
| 1044 return str(self._class_name) | |
| 1045 | |
| 1046 class IsAlmost(Comparator): | |
| 1047 """Comparison class used to check whether a parameter is nearly equal | |
| 1048 to a given value. Generally useful for floating point numbers. | |
| 1049 | |
| 1050 Example mock_dao.SetTimeout((IsAlmost(3.9))) | |
| 1051 """ | |
| 1052 | |
| 1053 def __init__(self, float_value, places=7): | |
| 1054 """Initialize IsAlmost. | |
| 1055 | |
| 1056 Args: | |
| 1057 float_value: The value for making the comparison. | |
| 1058 places: The number of decimal places to round to. | |
| 1059 """ | |
| 1060 | |
| 1061 self._float_value = float_value | |
| 1062 self._places = places | |
| 1063 | |
| 1064 def equals(self, rhs): | |
| 1065 """Check to see if RHS is almost equal to float_value | |
| 1066 | |
| 1067 Args: | |
| 1068 rhs: the value to compare to float_value | |
| 1069 | |
| 1070 Returns: | |
| 1071 bool | |
| 1072 """ | |
| 1073 | |
| 1074 try: | |
| 1075 return round(rhs-self._float_value, self._places) == 0 | |
| 1076 except TypeError: | |
| 1077 # This is probably because either float_value or rhs is not a number. | |
| 1078 return False | |
| 1079 | |
| 1080 def __repr__(self): | |
| 1081 return str(self._float_value) | |
| 1082 | |
| 1083 class StrContains(Comparator): | |
| 1084 """Comparison class used to check whether a substring exists in a | |
| 1085 string parameter. This can be useful in mocking a database with SQL | |
| 1086 passed in as a string parameter, for example. | |
| 1087 | |
| 1088 Example: | |
| 1089 mock_dao.RunQuery(StrContains('IN (1, 2, 4, 5)')).AndReturn(mock_result) | |
| 1090 """ | |
| 1091 | |
| 1092 def __init__(self, search_string): | |
| 1093 """Initialize. | |
| 1094 | |
| 1095 Args: | |
| 1096 # search_string: the string you are searching for | |
| 1097 search_string: str | |
| 1098 """ | |
| 1099 | |
| 1100 self._search_string = search_string | |
| 1101 | |
| 1102 def equals(self, rhs): | |
| 1103 """Check to see if the search_string is contained in the rhs string. | |
| 1104 | |
| 1105 Args: | |
| 1106 # rhs: the right hand side of the test | |
| 1107 rhs: object | |
| 1108 | |
| 1109 Returns: | |
| 1110 bool | |
| 1111 """ | |
| 1112 | |
| 1113 try: | |
| 1114 return rhs.find(self._search_string) > -1 | |
| 1115 except Exception: | |
| 1116 return False | |
| 1117 | |
| 1118 def __repr__(self): | |
| 1119 return '<str containing \'%s\'>' % self._search_string | |
| 1120 | |
| 1121 | |
| 1122 class Regex(Comparator): | |
| 1123 """Checks if a string matches a regular expression. | |
| 1124 | |
| 1125 This uses a given regular expression to determine equality. | |
| 1126 """ | |
| 1127 | |
| 1128 def __init__(self, pattern, flags=0): | |
| 1129 """Initialize. | |
| 1130 | |
| 1131 Args: | |
| 1132 # pattern is the regular expression to search for | |
| 1133 pattern: str | |
| 1134 # flags passed to re.compile function as the second argument | |
| 1135 flags: int | |
| 1136 """ | |
| 1137 | |
| 1138 self.regex = re.compile(pattern, flags=flags) | |
| 1139 | |
| 1140 def equals(self, rhs): | |
| 1141 """Check to see if rhs matches regular expression pattern. | |
| 1142 | |
| 1143 Returns: | |
| 1144 bool | |
| 1145 """ | |
| 1146 | |
| 1147 return self.regex.search(rhs) is not None | |
| 1148 | |
| 1149 def __repr__(self): | |
| 1150 s = '<regular expression \'%s\'' % self.regex.pattern | |
| 1151 if self.regex.flags: | |
| 1152 s += ', flags=%d' % self.regex.flags | |
| 1153 s += '>' | |
| 1154 return s | |
| 1155 | |
| 1156 | |
| 1157 class In(Comparator): | |
| 1158 """Checks whether an item (or key) is in a list (or dict) parameter. | |
| 1159 | |
| 1160 Example: | |
| 1161 mock_dao.GetUsersInfo(In('expectedUserName')).AndReturn(mock_result) | |
| 1162 """ | |
| 1163 | |
| 1164 def __init__(self, key): | |
| 1165 """Initialize. | |
| 1166 | |
| 1167 Args: | |
| 1168 # key is any thing that could be in a list or a key in a dict | |
| 1169 """ | |
| 1170 | |
| 1171 self._key = key | |
| 1172 | |
| 1173 def equals(self, rhs): | |
| 1174 """Check to see whether key is in rhs. | |
| 1175 | |
| 1176 Args: | |
| 1177 rhs: dict | |
| 1178 | |
| 1179 Returns: | |
| 1180 bool | |
| 1181 """ | |
| 1182 | |
| 1183 return self._key in rhs | |
| 1184 | |
| 1185 def __repr__(self): | |
| 1186 return '<sequence or map containing \'%s\'>' % self._key | |
| 1187 | |
| 1188 | |
| 1189 class Not(Comparator): | |
| 1190 """Checks whether a predicates is False. | |
| 1191 | |
| 1192 Example: | |
| 1193 mock_dao.UpdateUsers(Not(ContainsKeyValue('stevepm', stevepm_user_info))) | |
| 1194 """ | |
| 1195 | |
| 1196 def __init__(self, predicate): | |
| 1197 """Initialize. | |
| 1198 | |
| 1199 Args: | |
| 1200 # predicate: a Comparator instance. | |
| 1201 """ | |
| 1202 | |
| 1203 assert isinstance(predicate, Comparator), ("predicate %r must be a" | |
| 1204 " Comparator." % predicate) | |
| 1205 self._predicate = predicate | |
| 1206 | |
| 1207 def equals(self, rhs): | |
| 1208 """Check to see whether the predicate is False. | |
| 1209 | |
| 1210 Args: | |
| 1211 rhs: A value that will be given in argument of the predicate. | |
| 1212 | |
| 1213 Returns: | |
| 1214 bool | |
| 1215 """ | |
| 1216 | |
| 1217 return not self._predicate.equals(rhs) | |
| 1218 | |
| 1219 def __repr__(self): | |
| 1220 return '<not \'%s\'>' % self._predicate | |
| 1221 | |
| 1222 | |
| 1223 class ContainsKeyValue(Comparator): | |
| 1224 """Checks whether a key/value pair is in a dict parameter. | |
| 1225 | |
| 1226 Example: | |
| 1227 mock_dao.UpdateUsers(ContainsKeyValue('stevepm', stevepm_user_info)) | |
| 1228 """ | |
| 1229 | |
| 1230 def __init__(self, key, value): | |
| 1231 """Initialize. | |
| 1232 | |
| 1233 Args: | |
| 1234 # key: a key in a dict | |
| 1235 # value: the corresponding value | |
| 1236 """ | |
| 1237 | |
| 1238 self._key = key | |
| 1239 self._value = value | |
| 1240 | |
| 1241 def equals(self, rhs): | |
| 1242 """Check whether the given key/value pair is in the rhs dict. | |
| 1243 | |
| 1244 Returns: | |
| 1245 bool | |
| 1246 """ | |
| 1247 | |
| 1248 try: | |
| 1249 return rhs[self._key] == self._value | |
| 1250 except Exception: | |
| 1251 return False | |
| 1252 | |
| 1253 def __repr__(self): | |
| 1254 return '<map containing the entry \'%s: %s\'>' % (self._key, self._value) | |
| 1255 | |
| 1256 | |
| 1257 class SameElementsAs(Comparator): | |
| 1258 """Checks whether iterables contain the same elements (ignoring order). | |
| 1259 | |
| 1260 Example: | |
| 1261 mock_dao.ProcessUsers(SameElementsAs('stevepm', 'salomaki')) | |
| 1262 """ | |
| 1263 | |
| 1264 def __init__(self, expected_seq): | |
| 1265 """Initialize. | |
| 1266 | |
| 1267 Args: | |
| 1268 expected_seq: a sequence | |
| 1269 """ | |
| 1270 | |
| 1271 self._expected_seq = expected_seq | |
| 1272 | |
| 1273 def equals(self, actual_seq): | |
| 1274 """Check to see whether actual_seq has same elements as expected_seq. | |
| 1275 | |
| 1276 Args: | |
| 1277 actual_seq: sequence | |
| 1278 | |
| 1279 Returns: | |
| 1280 bool | |
| 1281 """ | |
| 1282 | |
| 1283 try: | |
| 1284 expected = dict([(element, None) for element in self._expected_seq]) | |
| 1285 actual = dict([(element, None) for element in actual_seq]) | |
| 1286 except TypeError: | |
| 1287 # Fall back to slower list-compare if any of the objects are unhashable. | |
| 1288 expected = list(self._expected_seq) | |
| 1289 actual = list(actual_seq) | |
| 1290 expected.sort() | |
| 1291 actual.sort() | |
| 1292 return expected == actual | |
| 1293 | |
| 1294 def __repr__(self): | |
| 1295 return '<sequence with same elements as \'%s\'>' % self._expected_seq | |
| 1296 | |
| 1297 | |
| 1298 class And(Comparator): | |
| 1299 """Evaluates one or more Comparators on RHS and returns an AND of the results. | |
| 1300 """ | |
| 1301 | |
| 1302 def __init__(self, *args): | |
| 1303 """Initialize. | |
| 1304 | |
| 1305 Args: | |
| 1306 *args: One or more Comparator | |
| 1307 """ | |
| 1308 | |
| 1309 self._comparators = args | |
| 1310 | |
| 1311 def equals(self, rhs): | |
| 1312 """Checks whether all Comparators are equal to rhs. | |
| 1313 | |
| 1314 Args: | |
| 1315 # rhs: can be anything | |
| 1316 | |
| 1317 Returns: | |
| 1318 bool | |
| 1319 """ | |
| 1320 | |
| 1321 for comparator in self._comparators: | |
| 1322 if not comparator.equals(rhs): | |
| 1323 return False | |
| 1324 | |
| 1325 return True | |
| 1326 | |
| 1327 def __repr__(self): | |
| 1328 return '<AND %s>' % str(self._comparators) | |
| 1329 | |
| 1330 | |
| 1331 class Or(Comparator): | |
| 1332 """Evaluates one or more Comparators on RHS and returns an OR of the results. | |
| 1333 """ | |
| 1334 | |
| 1335 def __init__(self, *args): | |
| 1336 """Initialize. | |
| 1337 | |
| 1338 Args: | |
| 1339 *args: One or more Mox comparators | |
| 1340 """ | |
| 1341 | |
| 1342 self._comparators = args | |
| 1343 | |
| 1344 def equals(self, rhs): | |
| 1345 """Checks whether any Comparator is equal to rhs. | |
| 1346 | |
| 1347 Args: | |
| 1348 # rhs: can be anything | |
| 1349 | |
| 1350 Returns: | |
| 1351 bool | |
| 1352 """ | |
| 1353 | |
| 1354 for comparator in self._comparators: | |
| 1355 if comparator.equals(rhs): | |
| 1356 return True | |
| 1357 | |
| 1358 return False | |
| 1359 | |
| 1360 def __repr__(self): | |
| 1361 return '<OR %s>' % str(self._comparators) | |
| 1362 | |
| 1363 | |
| 1364 class Func(Comparator): | |
| 1365 """Call a function that should verify the parameter passed in is correct. | |
| 1366 | |
| 1367 You may need the ability to perform more advanced operations on the parameter | |
| 1368 in order to validate it. You can use this to have a callable validate any | |
| 1369 parameter. The callable should return either True or False. | |
| 1370 | |
| 1371 | |
| 1372 Example: | |
| 1373 | |
| 1374 def myParamValidator(param): | |
| 1375 # Advanced logic here | |
| 1376 return True | |
| 1377 | |
| 1378 mock_dao.DoSomething(Func(myParamValidator), true) | |
| 1379 """ | |
| 1380 | |
| 1381 def __init__(self, func): | |
| 1382 """Initialize. | |
| 1383 | |
| 1384 Args: | |
| 1385 func: callable that takes one parameter and returns a bool | |
| 1386 """ | |
| 1387 | |
| 1388 self._func = func | |
| 1389 | |
| 1390 def equals(self, rhs): | |
| 1391 """Test whether rhs passes the function test. | |
| 1392 | |
| 1393 rhs is passed into func. | |
| 1394 | |
| 1395 Args: | |
| 1396 rhs: any python object | |
| 1397 | |
| 1398 Returns: | |
| 1399 the result of func(rhs) | |
| 1400 """ | |
| 1401 | |
| 1402 return self._func(rhs) | |
| 1403 | |
| 1404 def __repr__(self): | |
| 1405 return str(self._func) | |
| 1406 | |
| 1407 | |
| 1408 class IgnoreArg(Comparator): | |
| 1409 """Ignore an argument. | |
| 1410 | |
| 1411 This can be used when we don't care about an argument of a method call. | |
| 1412 | |
| 1413 Example: | |
| 1414 # Check if CastMagic is called with 3 as first arg and 'disappear' as third. | |
| 1415 mymock.CastMagic(3, IgnoreArg(), 'disappear') | |
| 1416 """ | |
| 1417 | |
| 1418 def equals(self, unused_rhs): | |
| 1419 """Ignores arguments and returns True. | |
| 1420 | |
| 1421 Args: | |
| 1422 unused_rhs: any python object | |
| 1423 | |
| 1424 Returns: | |
| 1425 always returns True | |
| 1426 """ | |
| 1427 | |
| 1428 return True | |
| 1429 | |
| 1430 def __repr__(self): | |
| 1431 return '<IgnoreArg>' | |
| 1432 | |
| 1433 | |
| 1434 class MethodGroup(object): | |
| 1435 """Base class containing common behaviour for MethodGroups.""" | |
| 1436 | |
| 1437 def __init__(self, group_name): | |
| 1438 self._group_name = group_name | |
| 1439 | |
| 1440 def group_name(self): | |
| 1441 return self._group_name | |
| 1442 | |
| 1443 def __str__(self): | |
| 1444 return '<%s "%s">' % (self.__class__.__name__, self._group_name) | |
| 1445 | |
| 1446 def AddMethod(self, mock_method): | |
| 1447 raise NotImplementedError | |
| 1448 | |
| 1449 def MethodCalled(self, mock_method): | |
| 1450 raise NotImplementedError | |
| 1451 | |
| 1452 def IsSatisfied(self): | |
| 1453 raise NotImplementedError | |
| 1454 | |
| 1455 class UnorderedGroup(MethodGroup): | |
| 1456 """UnorderedGroup holds a set of method calls that may occur in any order. | |
| 1457 | |
| 1458 This construct is helpful for non-deterministic events, such as iterating | |
| 1459 over the keys of a dict. | |
| 1460 """ | |
| 1461 | |
| 1462 def __init__(self, group_name): | |
| 1463 super(UnorderedGroup, self).__init__(group_name) | |
| 1464 self._methods = [] | |
| 1465 | |
| 1466 def AddMethod(self, mock_method): | |
| 1467 """Add a method to this group. | |
| 1468 | |
| 1469 Args: | |
| 1470 mock_method: A mock method to be added to this group. | |
| 1471 """ | |
| 1472 | |
| 1473 self._methods.append(mock_method) | |
| 1474 | |
| 1475 def MethodCalled(self, mock_method): | |
| 1476 """Remove a method call from the group. | |
| 1477 | |
| 1478 If the method is not in the set, an UnexpectedMethodCallError will be | |
| 1479 raised. | |
| 1480 | |
| 1481 Args: | |
| 1482 mock_method: a mock method that should be equal to a method in the group. | |
| 1483 | |
| 1484 Returns: | |
| 1485 The mock method from the group | |
| 1486 | |
| 1487 Raises: | |
| 1488 UnexpectedMethodCallError if the mock_method was not in the group. | |
| 1489 """ | |
| 1490 | |
| 1491 # Check to see if this method exists, and if so, remove it from the set | |
| 1492 # and return it. | |
| 1493 for method in self._methods: | |
| 1494 if method == mock_method: | |
| 1495 # Remove the called mock_method instead of the method in the group. | |
| 1496 # The called method will match any comparators when equality is checked | |
| 1497 # during removal. The method in the group could pass a comparator to | |
| 1498 # another comparator during the equality check. | |
| 1499 self._methods.remove(mock_method) | |
| 1500 | |
| 1501 # If this group is not empty, put it back at the head of the queue. | |
| 1502 if not self.IsSatisfied(): | |
| 1503 mock_method._call_queue.appendleft(self) | |
| 1504 | |
| 1505 return self, method | |
| 1506 | |
| 1507 raise UnexpectedMethodCallError(mock_method, self) | |
| 1508 | |
| 1509 def IsSatisfied(self): | |
| 1510 """Return True if there are not any methods in this group.""" | |
| 1511 | |
| 1512 return len(self._methods) == 0 | |
| 1513 | |
| 1514 | |
| 1515 class MultipleTimesGroup(MethodGroup): | |
| 1516 """MultipleTimesGroup holds methods that may be called any number of times. | |
| 1517 | |
| 1518 Note: Each method must be called at least once. | |
| 1519 | |
| 1520 This is helpful, if you don't know or care how many times a method is called. | |
| 1521 """ | |
| 1522 | |
| 1523 def __init__(self, group_name): | |
| 1524 super(MultipleTimesGroup, self).__init__(group_name) | |
| 1525 self._methods = set() | |
| 1526 self._methods_left = set() | |
| 1527 | |
| 1528 def AddMethod(self, mock_method): | |
| 1529 """Add a method to this group. | |
| 1530 | |
| 1531 Args: | |
| 1532 mock_method: A mock method to be added to this group. | |
| 1533 """ | |
| 1534 | |
| 1535 self._methods.add(mock_method) | |
| 1536 self._methods_left.add(mock_method) | |
| 1537 | |
| 1538 def MethodCalled(self, mock_method): | |
| 1539 """Remove a method call from the group. | |
| 1540 | |
| 1541 If the method is not in the set, an UnexpectedMethodCallError will be | |
| 1542 raised. | |
| 1543 | |
| 1544 Args: | |
| 1545 mock_method: a mock method that should be equal to a method in the group. | |
| 1546 | |
| 1547 Returns: | |
| 1548 The mock method from the group | |
| 1549 | |
| 1550 Raises: | |
| 1551 UnexpectedMethodCallError if the mock_method was not in the group. | |
| 1552 """ | |
| 1553 | |
| 1554 # Check to see if this method exists, and if so add it to the set of | |
| 1555 # called methods. | |
| 1556 for method in self._methods: | |
| 1557 if method == mock_method: | |
| 1558 self._methods_left.discard(method) | |
| 1559 # Always put this group back on top of the queue, because we don't know | |
| 1560 # when we are done. | |
| 1561 mock_method._call_queue.appendleft(self) | |
| 1562 return self, method | |
| 1563 | |
| 1564 if self.IsSatisfied(): | |
| 1565 next_method = mock_method._PopNextMethod(); | |
| 1566 return next_method, None | |
| 1567 else: | |
| 1568 raise UnexpectedMethodCallError(mock_method, self) | |
| 1569 | |
| 1570 def IsSatisfied(self): | |
| 1571 """Return True if all methods in this group are called at least once.""" | |
| 1572 return len(self._methods_left) == 0 | |
| 1573 | |
| 1574 | |
| 1575 class MoxMetaTestBase(type): | |
| 1576 """Metaclass to add mox cleanup and verification to every test. | |
| 1577 | |
| 1578 As the mox unit testing class is being constructed (MoxTestBase or a | |
| 1579 subclass), this metaclass will modify all test functions to call the | |
| 1580 CleanUpMox method of the test class after they finish. This means that | |
| 1581 unstubbing and verifying will happen for every test with no additional code, | |
| 1582 and any failures will result in test failures as opposed to errors. | |
| 1583 """ | |
| 1584 | |
| 1585 def __init__(cls, name, bases, d): | |
| 1586 type.__init__(cls, name, bases, d) | |
| 1587 | |
| 1588 # also get all the attributes from the base classes to account | |
| 1589 # for a case when test class is not the immediate child of MoxTestBase | |
| 1590 for base in bases: | |
| 1591 for attr_name in dir(base): | |
| 1592 d[attr_name] = getattr(base, attr_name) | |
| 1593 | |
| 1594 for func_name, func in d.items(): | |
| 1595 if func_name.startswith('test') and callable(func): | |
| 1596 setattr(cls, func_name, MoxMetaTestBase.CleanUpTest(cls, func)) | |
| 1597 | |
| 1598 @staticmethod | |
| 1599 def CleanUpTest(cls, func): | |
| 1600 """Adds Mox cleanup code to any MoxTestBase method. | |
| 1601 | |
| 1602 Always unsets stubs after a test. Will verify all mocks for tests that | |
| 1603 otherwise pass. | |
| 1604 | |
| 1605 Args: | |
| 1606 cls: MoxTestBase or subclass; the class whose test method we are altering. | |
| 1607 func: method; the method of the MoxTestBase test class we wish to alter. | |
| 1608 | |
| 1609 Returns: | |
| 1610 The modified method. | |
| 1611 """ | |
| 1612 def new_method(self, *args, **kwargs): | |
| 1613 mox_obj = getattr(self, 'mox', None) | |
| 1614 cleanup_mox = False | |
| 1615 if mox_obj and isinstance(mox_obj, Mox): | |
| 1616 cleanup_mox = True | |
| 1617 try: | |
| 1618 func(self, *args, **kwargs) | |
| 1619 finally: | |
| 1620 if cleanup_mox: | |
| 1621 mox_obj.UnsetStubs() | |
| 1622 if cleanup_mox: | |
| 1623 mox_obj.VerifyAll() | |
| 1624 new_method.__name__ = func.__name__ | |
| 1625 new_method.__doc__ = func.__doc__ | |
| 1626 new_method.__module__ = func.__module__ | |
| 1627 return new_method | |
| 1628 | |
| 1629 | |
| 1630 class MoxTestBase(unittest.TestCase): | |
| 1631 """Convenience test class to make stubbing easier. | |
| 1632 | |
| 1633 Sets up a "mox" attribute which is an instance of Mox - any mox tests will | |
| 1634 want this. Also automatically unsets any stubs and verifies that all mock | |
| 1635 methods have been called at the end of each test, eliminating boilerplate | |
| 1636 code. | |
| 1637 """ | |
| 1638 | |
| 1639 __metaclass__ = MoxMetaTestBase | |
| 1640 | |
| 1641 def setUp(self): | |
| 1642 super(MoxTestBase, self).setUp() | |
| 1643 self.mox = Mox() | |
| OLD | NEW |