OLD | NEW |
(Empty) | |
| 1 #!/usr/bin/python2.4 |
| 2 # |
| 3 # Copyright 2008 Google Inc. |
| 4 # |
| 5 # Licensed under the Apache License, Version 2.0 (the "License"); |
| 6 # you may not use this file except in compliance with the License. |
| 7 # You may obtain a copy of the License at |
| 8 # |
| 9 # http://www.apache.org/licenses/LICENSE-2.0 |
| 10 # |
| 11 # Unless required by applicable law or agreed to in writing, software |
| 12 # distributed under the License is distributed on an "AS IS" BASIS, |
| 13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 14 # See the License for the specific language governing permissions and |
| 15 # limitations under the License. |
| 16 |
| 17 """Mox, an object-mocking framework for Python. |
| 18 |
| 19 Mox works in the record-replay-verify paradigm. When you first create |
| 20 a mock object, it is in record mode. You then programmatically set |
| 21 the expected behavior of the mock object (what methods are to be |
| 22 called on it, with what parameters, what they should return, and in |
| 23 what order). |
| 24 |
| 25 Once you have set up the expected mock behavior, you put it in replay |
| 26 mode. Now the mock responds to method calls just as you told it to. |
| 27 If an unexpected method (or an expected method with unexpected |
| 28 parameters) is called, then an exception will be raised. |
| 29 |
| 30 Once you are done interacting with the mock, you need to verify that |
| 31 all the expected interactions occured. (Maybe your code exited |
| 32 prematurely without calling some cleanup method!) The verify phase |
| 33 ensures that every expected method was called; otherwise, an exception |
| 34 will be raised. |
| 35 |
| 36 WARNING! Mock objects created by Mox are not thread-safe. If you are |
| 37 call a mock in multiple threads, it should be guarded by a mutex. |
| 38 |
| 39 TODO(stevepm): Add the option to make mocks thread-safe! |
| 40 |
| 41 Suggested usage / workflow: |
| 42 |
| 43 # Create Mox factory |
| 44 my_mox = Mox() |
| 45 |
| 46 # Create a mock data access object |
| 47 mock_dao = my_mox.CreateMock(DAOClass) |
| 48 |
| 49 # Set up expected behavior |
| 50 mock_dao.RetrievePersonWithIdentifier('1').AndReturn(person) |
| 51 mock_dao.DeletePerson(person) |
| 52 |
| 53 # Put mocks in replay mode |
| 54 my_mox.ReplayAll() |
| 55 |
| 56 # Inject mock object and run test |
| 57 controller.SetDao(mock_dao) |
| 58 controller.DeletePersonById('1') |
| 59 |
| 60 # Verify all methods were called as expected |
| 61 my_mox.VerifyAll() |
| 62 """ |
| 63 |
| 64 from collections import deque |
| 65 import difflib |
| 66 import inspect |
| 67 import re |
| 68 import types |
| 69 import unittest |
| 70 |
| 71 import stubout |
| 72 |
| 73 class Error(AssertionError): |
| 74 """Base exception for this module.""" |
| 75 |
| 76 pass |
| 77 |
| 78 |
| 79 class ExpectedMethodCallsError(Error): |
| 80 """Raised when Verify() is called before all expected methods have been called |
| 81 """ |
| 82 |
| 83 def __init__(self, expected_methods): |
| 84 """Init exception. |
| 85 |
| 86 Args: |
| 87 # expected_methods: A sequence of MockMethod objects that should have been |
| 88 # called. |
| 89 expected_methods: [MockMethod] |
| 90 |
| 91 Raises: |
| 92 ValueError: if expected_methods contains no methods. |
| 93 """ |
| 94 |
| 95 if not expected_methods: |
| 96 raise ValueError("There must be at least one expected method") |
| 97 Error.__init__(self) |
| 98 self._expected_methods = expected_methods |
| 99 |
| 100 def __str__(self): |
| 101 calls = "\n".join(["%3d. %s" % (i, m) |
| 102 for i, m in enumerate(self._expected_methods)]) |
| 103 return "Verify: Expected methods never called:\n%s" % (calls,) |
| 104 |
| 105 |
| 106 class UnexpectedMethodCallError(Error): |
| 107 """Raised when an unexpected method is called. |
| 108 |
| 109 This can occur if a method is called with incorrect parameters, or out of the |
| 110 specified order. |
| 111 """ |
| 112 |
| 113 def __init__(self, unexpected_method, expected): |
| 114 """Init exception. |
| 115 |
| 116 Args: |
| 117 # unexpected_method: MockMethod that was called but was not at the head of |
| 118 # the expected_method queue. |
| 119 # expected: MockMethod or UnorderedGroup the method should have |
| 120 # been in. |
| 121 unexpected_method: MockMethod |
| 122 expected: MockMethod or UnorderedGroup |
| 123 """ |
| 124 |
| 125 Error.__init__(self) |
| 126 if expected is None: |
| 127 self._str = "Unexpected method call %s" % (unexpected_method,) |
| 128 else: |
| 129 differ = difflib.Differ() |
| 130 diff = differ.compare(str(unexpected_method).splitlines(True), |
| 131 str(expected).splitlines(True)) |
| 132 self._str = ("Unexpected method call. unexpected:- expected:+\n%s" |
| 133 % ("\n".join(diff),)) |
| 134 |
| 135 def __str__(self): |
| 136 return self._str |
| 137 |
| 138 |
| 139 class UnknownMethodCallError(Error): |
| 140 """Raised if an unknown method is requested of the mock object.""" |
| 141 |
| 142 def __init__(self, unknown_method_name): |
| 143 """Init exception. |
| 144 |
| 145 Args: |
| 146 # unknown_method_name: Method call that is not part of the mocked class's |
| 147 # public interface. |
| 148 unknown_method_name: str |
| 149 """ |
| 150 |
| 151 Error.__init__(self) |
| 152 self._unknown_method_name = unknown_method_name |
| 153 |
| 154 def __str__(self): |
| 155 return "Method called is not a member of the object: %s" % \ |
| 156 self._unknown_method_name |
| 157 |
| 158 |
| 159 class Mox(object): |
| 160 """Mox: a factory for creating mock objects.""" |
| 161 |
| 162 # A list of types that should be stubbed out with MockObjects (as |
| 163 # opposed to MockAnythings). |
| 164 _USE_MOCK_OBJECT = [types.ClassType, types.InstanceType, types.ModuleType, |
| 165 types.ObjectType, types.TypeType] |
| 166 |
| 167 def __init__(self): |
| 168 """Initialize a new Mox.""" |
| 169 |
| 170 self._mock_objects = [] |
| 171 self.stubs = stubout.StubOutForTesting() |
| 172 |
| 173 def CreateMock(self, class_to_mock): |
| 174 """Create a new mock object. |
| 175 |
| 176 Args: |
| 177 # class_to_mock: the class to be mocked |
| 178 class_to_mock: class |
| 179 |
| 180 Returns: |
| 181 MockObject that can be used as the class_to_mock would be. |
| 182 """ |
| 183 |
| 184 new_mock = MockObject(class_to_mock) |
| 185 self._mock_objects.append(new_mock) |
| 186 return new_mock |
| 187 |
| 188 def CreateMockAnything(self, description=None): |
| 189 """Create a mock that will accept any method calls. |
| 190 |
| 191 This does not enforce an interface. |
| 192 |
| 193 Args: |
| 194 description: str. Optionally, a descriptive name for the mock object being |
| 195 created, for debugging output purposes. |
| 196 """ |
| 197 new_mock = MockAnything(description=description) |
| 198 self._mock_objects.append(new_mock) |
| 199 return new_mock |
| 200 |
| 201 def ReplayAll(self): |
| 202 """Set all mock objects to replay mode.""" |
| 203 |
| 204 for mock_obj in self._mock_objects: |
| 205 mock_obj._Replay() |
| 206 |
| 207 |
| 208 def VerifyAll(self): |
| 209 """Call verify on all mock objects created.""" |
| 210 |
| 211 for mock_obj in self._mock_objects: |
| 212 mock_obj._Verify() |
| 213 |
| 214 def ResetAll(self): |
| 215 """Call reset on all mock objects. This does not unset stubs.""" |
| 216 |
| 217 for mock_obj in self._mock_objects: |
| 218 mock_obj._Reset() |
| 219 |
| 220 def StubOutWithMock(self, obj, attr_name, use_mock_anything=False): |
| 221 """Replace a method, attribute, etc. with a Mock. |
| 222 |
| 223 This will replace a class or module with a MockObject, and everything else |
| 224 (method, function, etc) with a MockAnything. This can be overridden to |
| 225 always use a MockAnything by setting use_mock_anything to True. |
| 226 |
| 227 Args: |
| 228 obj: A Python object (class, module, instance, callable). |
| 229 attr_name: str. The name of the attribute to replace with a mock. |
| 230 use_mock_anything: bool. True if a MockAnything should be used regardless |
| 231 of the type of attribute. |
| 232 """ |
| 233 |
| 234 attr_to_replace = getattr(obj, attr_name) |
| 235 |
| 236 # Check for a MockAnything. This could cause confusing problems later on. |
| 237 if attr_to_replace == MockAnything(): |
| 238 raise TypeError('Cannot mock a MockAnything! Did you remember to ' |
| 239 'call UnsetStubs in your previous test?') |
| 240 |
| 241 if type(attr_to_replace) in self._USE_MOCK_OBJECT and not use_mock_anything: |
| 242 stub = self.CreateMock(attr_to_replace) |
| 243 else: |
| 244 stub = self.CreateMockAnything(description='Stub for %s' % attr_to_replace
) |
| 245 |
| 246 self.stubs.Set(obj, attr_name, stub) |
| 247 |
| 248 def UnsetStubs(self): |
| 249 """Restore stubs to their original state.""" |
| 250 |
| 251 self.stubs.UnsetAll() |
| 252 |
| 253 def Replay(*args): |
| 254 """Put mocks into Replay mode. |
| 255 |
| 256 Args: |
| 257 # args is any number of mocks to put into replay mode. |
| 258 """ |
| 259 |
| 260 for mock in args: |
| 261 mock._Replay() |
| 262 |
| 263 |
| 264 def Verify(*args): |
| 265 """Verify mocks. |
| 266 |
| 267 Args: |
| 268 # args is any number of mocks to be verified. |
| 269 """ |
| 270 |
| 271 for mock in args: |
| 272 mock._Verify() |
| 273 |
| 274 |
| 275 def Reset(*args): |
| 276 """Reset mocks. |
| 277 |
| 278 Args: |
| 279 # args is any number of mocks to be reset. |
| 280 """ |
| 281 |
| 282 for mock in args: |
| 283 mock._Reset() |
| 284 |
| 285 |
| 286 class MockAnything: |
| 287 """A mock that can be used to mock anything. |
| 288 |
| 289 This is helpful for mocking classes that do not provide a public interface. |
| 290 """ |
| 291 |
| 292 def __init__(self, description=None): |
| 293 """Initialize a new MockAnything. |
| 294 |
| 295 Args: |
| 296 description: str. Optionally, a descriptive name for the mock object being |
| 297 created, for debugging output purposes. |
| 298 """ |
| 299 self._description = description |
| 300 self._Reset() |
| 301 |
| 302 def __str__(self): |
| 303 return "<MockAnything instance at %s>" % id(self) |
| 304 |
| 305 def __repr__(self): |
| 306 return '<MockAnything instance>' |
| 307 |
| 308 def __getattr__(self, method_name): |
| 309 """Intercept method calls on this object. |
| 310 |
| 311 A new MockMethod is returned that is aware of the MockAnything's |
| 312 state (record or replay). The call will be recorded or replayed |
| 313 by the MockMethod's __call__. |
| 314 |
| 315 Args: |
| 316 # method name: the name of the method being called. |
| 317 method_name: str |
| 318 |
| 319 Returns: |
| 320 A new MockMethod aware of MockAnything's state (record or replay). |
| 321 """ |
| 322 |
| 323 return self._CreateMockMethod(method_name) |
| 324 |
| 325 def _CreateMockMethod(self, method_name, method_to_mock=None): |
| 326 """Create a new mock method call and return it. |
| 327 |
| 328 Args: |
| 329 # method_name: the name of the method being called. |
| 330 # method_to_mock: The actual method being mocked, used for introspection. |
| 331 method_name: str |
| 332 method_to_mock: a method object |
| 333 |
| 334 Returns: |
| 335 A new MockMethod aware of MockAnything's state (record or replay). |
| 336 """ |
| 337 |
| 338 return MockMethod(method_name, self._expected_calls_queue, |
| 339 self._replay_mode, method_to_mock=method_to_mock, |
| 340 description=self._description) |
| 341 |
| 342 def __nonzero__(self): |
| 343 """Return 1 for nonzero so the mock can be used as a conditional.""" |
| 344 |
| 345 return 1 |
| 346 |
| 347 def __eq__(self, rhs): |
| 348 """Provide custom logic to compare objects.""" |
| 349 |
| 350 return (isinstance(rhs, MockAnything) and |
| 351 self._replay_mode == rhs._replay_mode and |
| 352 self._expected_calls_queue == rhs._expected_calls_queue) |
| 353 |
| 354 def __ne__(self, rhs): |
| 355 """Provide custom logic to compare objects.""" |
| 356 |
| 357 return not self == rhs |
| 358 |
| 359 def _Replay(self): |
| 360 """Start replaying expected method calls.""" |
| 361 |
| 362 self._replay_mode = True |
| 363 |
| 364 def _Verify(self): |
| 365 """Verify that all of the expected calls have been made. |
| 366 |
| 367 Raises: |
| 368 ExpectedMethodCallsError: if there are still more method calls in the |
| 369 expected queue. |
| 370 """ |
| 371 |
| 372 # If the list of expected calls is not empty, raise an exception |
| 373 if self._expected_calls_queue: |
| 374 # The last MultipleTimesGroup is not popped from the queue. |
| 375 if (len(self._expected_calls_queue) == 1 and |
| 376 isinstance(self._expected_calls_queue[0], MultipleTimesGroup) and |
| 377 self._expected_calls_queue[0].IsSatisfied()): |
| 378 pass |
| 379 else: |
| 380 raise ExpectedMethodCallsError(self._expected_calls_queue) |
| 381 |
| 382 def _Reset(self): |
| 383 """Reset the state of this mock to record mode with an empty queue.""" |
| 384 |
| 385 # Maintain a list of method calls we are expecting |
| 386 self._expected_calls_queue = deque() |
| 387 |
| 388 # Make sure we are in setup mode, not replay mode |
| 389 self._replay_mode = False |
| 390 |
| 391 |
| 392 class MockObject(MockAnything, object): |
| 393 """A mock object that simulates the public/protected interface of a class.""" |
| 394 |
| 395 def __init__(self, class_to_mock): |
| 396 """Initialize a mock object. |
| 397 |
| 398 This determines the methods and properties of the class and stores them. |
| 399 |
| 400 Args: |
| 401 # class_to_mock: class to be mocked |
| 402 class_to_mock: class |
| 403 """ |
| 404 |
| 405 # This is used to hack around the mixin/inheritance of MockAnything, which |
| 406 # is not a proper object (it can be anything. :-) |
| 407 MockAnything.__dict__['__init__'](self) |
| 408 |
| 409 # Get a list of all the public and special methods we should mock. |
| 410 self._known_methods = set() |
| 411 self._known_vars = set() |
| 412 self._class_to_mock = class_to_mock |
| 413 for method in dir(class_to_mock): |
| 414 if callable(getattr(class_to_mock, method)): |
| 415 self._known_methods.add(method) |
| 416 else: |
| 417 self._known_vars.add(method) |
| 418 |
| 419 def __getattr__(self, name): |
| 420 """Intercept attribute request on this object. |
| 421 |
| 422 If the attribute is a public class variable, it will be returned and not |
| 423 recorded as a call. |
| 424 |
| 425 If the attribute is not a variable, it is handled like a method |
| 426 call. The method name is checked against the set of mockable |
| 427 methods, and a new MockMethod is returned that is aware of the |
| 428 MockObject's state (record or replay). The call will be recorded |
| 429 or replayed by the MockMethod's __call__. |
| 430 |
| 431 Args: |
| 432 # name: the name of the attribute being requested. |
| 433 name: str |
| 434 |
| 435 Returns: |
| 436 Either a class variable or a new MockMethod that is aware of the state |
| 437 of the mock (record or replay). |
| 438 |
| 439 Raises: |
| 440 UnknownMethodCallError if the MockObject does not mock the requested |
| 441 method. |
| 442 """ |
| 443 |
| 444 if name in self._known_vars: |
| 445 return getattr(self._class_to_mock, name) |
| 446 |
| 447 if name in self._known_methods: |
| 448 return self._CreateMockMethod( |
| 449 name, |
| 450 method_to_mock=getattr(self._class_to_mock, name)) |
| 451 |
| 452 raise UnknownMethodCallError(name) |
| 453 |
| 454 def __eq__(self, rhs): |
| 455 """Provide custom logic to compare objects.""" |
| 456 |
| 457 return (isinstance(rhs, MockObject) and |
| 458 self._class_to_mock == rhs._class_to_mock and |
| 459 self._replay_mode == rhs._replay_mode and |
| 460 self._expected_calls_queue == rhs._expected_calls_queue) |
| 461 |
| 462 def __setitem__(self, key, value): |
| 463 """Provide custom logic for mocking classes that support item assignment. |
| 464 |
| 465 Args: |
| 466 key: Key to set the value for. |
| 467 value: Value to set. |
| 468 |
| 469 Returns: |
| 470 Expected return value in replay mode. A MockMethod object for the |
| 471 __setitem__ method that has already been called if not in replay mode. |
| 472 |
| 473 Raises: |
| 474 TypeError if the underlying class does not support item assignment. |
| 475 UnexpectedMethodCallError if the object does not expect the call to |
| 476 __setitem__. |
| 477 |
| 478 """ |
| 479 # Verify the class supports item assignment. |
| 480 if '__setitem__' not in dir(self._class_to_mock): |
| 481 raise TypeError('object does not support item assignment') |
| 482 |
| 483 # If we are in replay mode then simply call the mock __setitem__ method. |
| 484 if self._replay_mode: |
| 485 return MockMethod('__setitem__', self._expected_calls_queue, |
| 486 self._replay_mode)(key, value) |
| 487 |
| 488 |
| 489 # Otherwise, create a mock method __setitem__. |
| 490 return self._CreateMockMethod('__setitem__')(key, value) |
| 491 |
| 492 def __getitem__(self, key): |
| 493 """Provide custom logic for mocking classes that are subscriptable. |
| 494 |
| 495 Args: |
| 496 key: Key to return the value for. |
| 497 |
| 498 Returns: |
| 499 Expected return value in replay mode. A MockMethod object for the |
| 500 __getitem__ method that has already been called if not in replay mode. |
| 501 |
| 502 Raises: |
| 503 TypeError if the underlying class is not subscriptable. |
| 504 UnexpectedMethodCallError if the object does not expect the call to |
| 505 __getitem__. |
| 506 |
| 507 """ |
| 508 # Verify the class supports item assignment. |
| 509 if '__getitem__' not in dir(self._class_to_mock): |
| 510 raise TypeError('unsubscriptable object') |
| 511 |
| 512 # If we are in replay mode then simply call the mock __getitem__ method. |
| 513 if self._replay_mode: |
| 514 return MockMethod('__getitem__', self._expected_calls_queue, |
| 515 self._replay_mode)(key) |
| 516 |
| 517 |
| 518 # Otherwise, create a mock method __getitem__. |
| 519 return self._CreateMockMethod('__getitem__')(key) |
| 520 |
| 521 def __iter__(self): |
| 522 """Provide custom logic for mocking classes that are iterable. |
| 523 |
| 524 Returns: |
| 525 Expected return value in replay mode. A MockMethod object for the |
| 526 __iter__ method that has already been called if not in replay mode. |
| 527 |
| 528 Raises: |
| 529 TypeError if the underlying class is not iterable. |
| 530 UnexpectedMethodCallError if the object does not expect the call to |
| 531 __iter__. |
| 532 |
| 533 """ |
| 534 methods = dir(self._class_to_mock) |
| 535 |
| 536 # Verify the class supports iteration. |
| 537 if '__iter__' not in methods: |
| 538 # If it doesn't have iter method and we are in replay method, then try to |
| 539 # iterate using subscripts. |
| 540 if '__getitem__' not in methods or not self._replay_mode: |
| 541 raise TypeError('not iterable object') |
| 542 else: |
| 543 results = [] |
| 544 index = 0 |
| 545 try: |
| 546 while True: |
| 547 results.append(self[index]) |
| 548 index += 1 |
| 549 except IndexError: |
| 550 return iter(results) |
| 551 |
| 552 # If we are in replay mode then simply call the mock __iter__ method. |
| 553 if self._replay_mode: |
| 554 return MockMethod('__iter__', self._expected_calls_queue, |
| 555 self._replay_mode)() |
| 556 |
| 557 |
| 558 # Otherwise, create a mock method __iter__. |
| 559 return self._CreateMockMethod('__iter__')() |
| 560 |
| 561 |
| 562 def __contains__(self, key): |
| 563 """Provide custom logic for mocking classes that contain items. |
| 564 |
| 565 Args: |
| 566 key: Key to look in container for. |
| 567 |
| 568 Returns: |
| 569 Expected return value in replay mode. A MockMethod object for the |
| 570 __contains__ method that has already been called if not in replay mode. |
| 571 |
| 572 Raises: |
| 573 TypeError if the underlying class does not implement __contains__ |
| 574 UnexpectedMethodCaller if the object does not expect the call to |
| 575 __contains__. |
| 576 |
| 577 """ |
| 578 contains = self._class_to_mock.__dict__.get('__contains__', None) |
| 579 |
| 580 if contains is None: |
| 581 raise TypeError('unsubscriptable object') |
| 582 |
| 583 if self._replay_mode: |
| 584 return MockMethod('__contains__', self._expected_calls_queue, |
| 585 self._replay_mode)(key) |
| 586 |
| 587 return self._CreateMockMethod('__contains__')(key) |
| 588 |
| 589 def __call__(self, *params, **named_params): |
| 590 """Provide custom logic for mocking classes that are callable.""" |
| 591 |
| 592 # Verify the class we are mocking is callable. |
| 593 callable = hasattr(self._class_to_mock, '__call__') |
| 594 if not callable: |
| 595 raise TypeError('Not callable') |
| 596 |
| 597 # Because the call is happening directly on this object instead of a method, |
| 598 # the call on the mock method is made right here |
| 599 mock_method = self._CreateMockMethod('__call__') |
| 600 return mock_method(*params, **named_params) |
| 601 |
| 602 @property |
| 603 def __class__(self): |
| 604 """Return the class that is being mocked.""" |
| 605 |
| 606 return self._class_to_mock |
| 607 |
| 608 |
| 609 class MethodCallChecker(object): |
| 610 """Ensures that methods are called correctly.""" |
| 611 |
| 612 _NEEDED, _DEFAULT, _GIVEN = range(3) |
| 613 |
| 614 def __init__(self, method): |
| 615 """Creates a checker. |
| 616 |
| 617 Args: |
| 618 # method: A method to check. |
| 619 method: function |
| 620 |
| 621 Raises: |
| 622 ValueError: method could not be inspected, so checks aren't possible. |
| 623 Some methods and functions like built-ins can't be inspected. |
| 624 """ |
| 625 try: |
| 626 self._args, varargs, varkw, defaults = inspect.getargspec(method) |
| 627 except TypeError: |
| 628 raise ValueError('Could not get argument specification for %r' |
| 629 % (method,)) |
| 630 if inspect.ismethod(method): |
| 631 self._args = self._args[1:] # Skip 'self'. |
| 632 self._method = method |
| 633 |
| 634 self._has_varargs = varargs is not None |
| 635 self._has_varkw = varkw is not None |
| 636 if defaults is None: |
| 637 self._required_args = self._args |
| 638 self._default_args = [] |
| 639 else: |
| 640 self._required_args = self._args[:-len(defaults)] |
| 641 self._default_args = self._args[-len(defaults):] |
| 642 |
| 643 def _RecordArgumentGiven(self, arg_name, arg_status): |
| 644 """Mark an argument as being given. |
| 645 |
| 646 Args: |
| 647 # arg_name: The name of the argument to mark in arg_status. |
| 648 # arg_status: Maps argument names to one of _NEEDED, _DEFAULT, _GIVEN. |
| 649 arg_name: string |
| 650 arg_status: dict |
| 651 |
| 652 Raises: |
| 653 AttributeError: arg_name is already marked as _GIVEN. |
| 654 """ |
| 655 if arg_status.get(arg_name, None) == MethodCallChecker._GIVEN: |
| 656 raise AttributeError('%s provided more than once' % (arg_name,)) |
| 657 arg_status[arg_name] = MethodCallChecker._GIVEN |
| 658 |
| 659 def Check(self, params, named_params): |
| 660 """Ensures that the parameters used while recording a call are valid. |
| 661 |
| 662 Args: |
| 663 # params: A list of positional parameters. |
| 664 # named_params: A dict of named parameters. |
| 665 params: list |
| 666 named_params: dict |
| 667 |
| 668 Raises: |
| 669 AttributeError: the given parameters don't work with the given method. |
| 670 """ |
| 671 arg_status = dict((a, MethodCallChecker._NEEDED) |
| 672 for a in self._required_args) |
| 673 for arg in self._default_args: |
| 674 arg_status[arg] = MethodCallChecker._DEFAULT |
| 675 |
| 676 # Check that each positional param is valid. |
| 677 for i in range(len(params)): |
| 678 try: |
| 679 arg_name = self._args[i] |
| 680 except IndexError: |
| 681 if not self._has_varargs: |
| 682 raise AttributeError('%s does not take %d or more positional ' |
| 683 'arguments' % (self._method.__name__, i)) |
| 684 else: |
| 685 self._RecordArgumentGiven(arg_name, arg_status) |
| 686 |
| 687 # Check each keyword argument. |
| 688 for arg_name in named_params: |
| 689 if arg_name not in arg_status and not self._has_varkw: |
| 690 raise AttributeError('%s is not expecting keyword argument %s' |
| 691 % (self._method.__name__, arg_name)) |
| 692 self._RecordArgumentGiven(arg_name, arg_status) |
| 693 |
| 694 # Ensure all the required arguments have been given. |
| 695 still_needed = [k for k, v in arg_status.iteritems() |
| 696 if v == MethodCallChecker._NEEDED] |
| 697 if still_needed: |
| 698 raise AttributeError('No values given for arguments %s' |
| 699 % (' '.join(sorted(still_needed)))) |
| 700 |
| 701 |
| 702 class MockMethod(object): |
| 703 """Callable mock method. |
| 704 |
| 705 A MockMethod should act exactly like the method it mocks, accepting parameters |
| 706 and returning a value, or throwing an exception (as specified). When this |
| 707 method is called, it can optionally verify whether the called method (name and |
| 708 signature) matches the expected method. |
| 709 """ |
| 710 |
| 711 def __init__(self, method_name, call_queue, replay_mode, |
| 712 method_to_mock=None, description=None): |
| 713 """Construct a new mock method. |
| 714 |
| 715 Args: |
| 716 # method_name: the name of the method |
| 717 # call_queue: deque of calls, verify this call against the head, or add |
| 718 # this call to the queue. |
| 719 # replay_mode: False if we are recording, True if we are verifying calls |
| 720 # against the call queue. |
| 721 # method_to_mock: The actual method being mocked, used for introspection. |
| 722 # description: optionally, a descriptive name for this method. Typically |
| 723 # this is equal to the descriptive name of the method's class. |
| 724 method_name: str |
| 725 call_queue: list or deque |
| 726 replay_mode: bool |
| 727 method_to_mock: a method object |
| 728 description: str or None |
| 729 """ |
| 730 |
| 731 self._name = method_name |
| 732 self._call_queue = call_queue |
| 733 if not isinstance(call_queue, deque): |
| 734 self._call_queue = deque(self._call_queue) |
| 735 self._replay_mode = replay_mode |
| 736 self._description = description |
| 737 |
| 738 self._params = None |
| 739 self._named_params = None |
| 740 self._return_value = None |
| 741 self._exception = None |
| 742 self._side_effects = None |
| 743 |
| 744 try: |
| 745 self._checker = MethodCallChecker(method_to_mock) |
| 746 except ValueError: |
| 747 self._checker = None |
| 748 |
| 749 def __call__(self, *params, **named_params): |
| 750 """Log parameters and return the specified return value. |
| 751 |
| 752 If the Mock(Anything/Object) associated with this call is in record mode, |
| 753 this MockMethod will be pushed onto the expected call queue. If the mock |
| 754 is in replay mode, this will pop a MockMethod off the top of the queue and |
| 755 verify this call is equal to the expected call. |
| 756 |
| 757 Raises: |
| 758 UnexpectedMethodCall if this call is supposed to match an expected method |
| 759 call and it does not. |
| 760 """ |
| 761 |
| 762 self._params = params |
| 763 self._named_params = named_params |
| 764 |
| 765 if not self._replay_mode: |
| 766 if self._checker is not None: |
| 767 self._checker.Check(params, named_params) |
| 768 self._call_queue.append(self) |
| 769 return self |
| 770 |
| 771 expected_method = self._VerifyMethodCall() |
| 772 |
| 773 if expected_method._side_effects: |
| 774 expected_method._side_effects(*params, **named_params) |
| 775 |
| 776 if expected_method._exception: |
| 777 raise expected_method._exception |
| 778 |
| 779 return expected_method._return_value |
| 780 |
| 781 def __getattr__(self, name): |
| 782 """Raise an AttributeError with a helpful message.""" |
| 783 |
| 784 raise AttributeError('MockMethod has no attribute "%s". ' |
| 785 'Did you remember to put your mocks in replay mode?' % name) |
| 786 |
| 787 def __iter__(self): |
| 788 """Raise a TypeError with a helpful message.""" |
| 789 raise TypeError('MockMethod cannot be iterated. ' |
| 790 'Did you remember to put your mocks in replay mode?') |
| 791 |
| 792 def next(self): |
| 793 """Raise a TypeError with a helpful message.""" |
| 794 raise TypeError('MockMethod cannot be iterated. ' |
| 795 'Did you remember to put your mocks in replay mode?') |
| 796 |
| 797 def _PopNextMethod(self): |
| 798 """Pop the next method from our call queue.""" |
| 799 try: |
| 800 return self._call_queue.popleft() |
| 801 except IndexError: |
| 802 raise UnexpectedMethodCallError(self, None) |
| 803 |
| 804 def _VerifyMethodCall(self): |
| 805 """Verify the called method is expected. |
| 806 |
| 807 This can be an ordered method, or part of an unordered set. |
| 808 |
| 809 Returns: |
| 810 The expected mock method. |
| 811 |
| 812 Raises: |
| 813 UnexpectedMethodCall if the method called was not expected. |
| 814 """ |
| 815 |
| 816 expected = self._PopNextMethod() |
| 817 |
| 818 # Loop here, because we might have a MethodGroup followed by another |
| 819 # group. |
| 820 while isinstance(expected, MethodGroup): |
| 821 expected, method = expected.MethodCalled(self) |
| 822 if method is not None: |
| 823 return method |
| 824 |
| 825 # This is a mock method, so just check equality. |
| 826 if expected != self: |
| 827 raise UnexpectedMethodCallError(self, expected) |
| 828 |
| 829 return expected |
| 830 |
| 831 def __str__(self): |
| 832 params = ', '.join( |
| 833 [repr(p) for p in self._params or []] + |
| 834 ['%s=%r' % x for x in sorted((self._named_params or {}).items())]) |
| 835 full_desc = "%s(%s) -> %r" % (self._name, params, self._return_value) |
| 836 if self._description: |
| 837 full_desc = "%s.%s" % (self._description, full_desc) |
| 838 return full_desc |
| 839 |
| 840 def __eq__(self, rhs): |
| 841 """Test whether this MockMethod is equivalent to another MockMethod. |
| 842 |
| 843 Args: |
| 844 # rhs: the right hand side of the test |
| 845 rhs: MockMethod |
| 846 """ |
| 847 |
| 848 return (isinstance(rhs, MockMethod) and |
| 849 self._name == rhs._name and |
| 850 self._params == rhs._params and |
| 851 self._named_params == rhs._named_params) |
| 852 |
| 853 def __ne__(self, rhs): |
| 854 """Test whether this MockMethod is not equivalent to another MockMethod. |
| 855 |
| 856 Args: |
| 857 # rhs: the right hand side of the test |
| 858 rhs: MockMethod |
| 859 """ |
| 860 |
| 861 return not self == rhs |
| 862 |
| 863 def GetPossibleGroup(self): |
| 864 """Returns a possible group from the end of the call queue or None if no |
| 865 other methods are on the stack. |
| 866 """ |
| 867 |
| 868 # Remove this method from the tail of the queue so we can add it to a group. |
| 869 this_method = self._call_queue.pop() |
| 870 assert this_method == self |
| 871 |
| 872 # Determine if the tail of the queue is a group, or just a regular ordered |
| 873 # mock method. |
| 874 group = None |
| 875 try: |
| 876 group = self._call_queue[-1] |
| 877 except IndexError: |
| 878 pass |
| 879 |
| 880 return group |
| 881 |
| 882 def _CheckAndCreateNewGroup(self, group_name, group_class): |
| 883 """Checks if the last method (a possible group) is an instance of our |
| 884 group_class. Adds the current method to this group or creates a new one. |
| 885 |
| 886 Args: |
| 887 |
| 888 group_name: the name of the group. |
| 889 group_class: the class used to create instance of this new group |
| 890 """ |
| 891 group = self.GetPossibleGroup() |
| 892 |
| 893 # If this is a group, and it is the correct group, add the method. |
| 894 if isinstance(group, group_class) and group.group_name() == group_name: |
| 895 group.AddMethod(self) |
| 896 return self |
| 897 |
| 898 # Create a new group and add the method. |
| 899 new_group = group_class(group_name) |
| 900 new_group.AddMethod(self) |
| 901 self._call_queue.append(new_group) |
| 902 return self |
| 903 |
| 904 def InAnyOrder(self, group_name="default"): |
| 905 """Move this method into a group of unordered calls. |
| 906 |
| 907 A group of unordered calls must be defined together, and must be executed |
| 908 in full before the next expected method can be called. There can be |
| 909 multiple groups that are expected serially, if they are given |
| 910 different group names. The same group name can be reused if there is a |
| 911 standard method call, or a group with a different name, spliced between |
| 912 usages. |
| 913 |
| 914 Args: |
| 915 group_name: the name of the unordered group. |
| 916 |
| 917 Returns: |
| 918 self |
| 919 """ |
| 920 return self._CheckAndCreateNewGroup(group_name, UnorderedGroup) |
| 921 |
| 922 def MultipleTimes(self, group_name="default"): |
| 923 """Move this method into group of calls which may be called multiple times. |
| 924 |
| 925 A group of repeating calls must be defined together, and must be executed in |
| 926 full before the next expected mehtod can be called. |
| 927 |
| 928 Args: |
| 929 group_name: the name of the unordered group. |
| 930 |
| 931 Returns: |
| 932 self |
| 933 """ |
| 934 return self._CheckAndCreateNewGroup(group_name, MultipleTimesGroup) |
| 935 |
| 936 def AndReturn(self, return_value): |
| 937 """Set the value to return when this method is called. |
| 938 |
| 939 Args: |
| 940 # return_value can be anything. |
| 941 """ |
| 942 |
| 943 self._return_value = return_value |
| 944 return return_value |
| 945 |
| 946 def AndRaise(self, exception): |
| 947 """Set the exception to raise when this method is called. |
| 948 |
| 949 Args: |
| 950 # exception: the exception to raise when this method is called. |
| 951 exception: Exception |
| 952 """ |
| 953 |
| 954 self._exception = exception |
| 955 |
| 956 def WithSideEffects(self, side_effects): |
| 957 """Set the side effects that are simulated when this method is called. |
| 958 |
| 959 Args: |
| 960 side_effects: A callable which modifies the parameters or other relevant |
| 961 state which a given test case depends on. |
| 962 |
| 963 Returns: |
| 964 Self for chaining with AndReturn and AndRaise. |
| 965 """ |
| 966 self._side_effects = side_effects |
| 967 return self |
| 968 |
| 969 class Comparator: |
| 970 """Base class for all Mox comparators. |
| 971 |
| 972 A Comparator can be used as a parameter to a mocked method when the exact |
| 973 value is not known. For example, the code you are testing might build up a |
| 974 long SQL string that is passed to your mock DAO. You're only interested that |
| 975 the IN clause contains the proper primary keys, so you can set your mock |
| 976 up as follows: |
| 977 |
| 978 mock_dao.RunQuery(StrContains('IN (1, 2, 4, 5)')).AndReturn(mock_result) |
| 979 |
| 980 Now whatever query is passed in must contain the string 'IN (1, 2, 4, 5)'. |
| 981 |
| 982 A Comparator may replace one or more parameters, for example: |
| 983 # return at most 10 rows |
| 984 mock_dao.RunQuery(StrContains('SELECT'), 10) |
| 985 |
| 986 or |
| 987 |
| 988 # Return some non-deterministic number of rows |
| 989 mock_dao.RunQuery(StrContains('SELECT'), IsA(int)) |
| 990 """ |
| 991 |
| 992 def equals(self, rhs): |
| 993 """Special equals method that all comparators must implement. |
| 994 |
| 995 Args: |
| 996 rhs: any python object |
| 997 """ |
| 998 |
| 999 raise NotImplementedError, 'method must be implemented by a subclass.' |
| 1000 |
| 1001 def __eq__(self, rhs): |
| 1002 return self.equals(rhs) |
| 1003 |
| 1004 def __ne__(self, rhs): |
| 1005 return not self.equals(rhs) |
| 1006 |
| 1007 |
| 1008 class IsA(Comparator): |
| 1009 """This class wraps a basic Python type or class. It is used to verify |
| 1010 that a parameter is of the given type or class. |
| 1011 |
| 1012 Example: |
| 1013 mock_dao.Connect(IsA(DbConnectInfo)) |
| 1014 """ |
| 1015 |
| 1016 def __init__(self, class_name): |
| 1017 """Initialize IsA |
| 1018 |
| 1019 Args: |
| 1020 class_name: basic python type or a class |
| 1021 """ |
| 1022 |
| 1023 self._class_name = class_name |
| 1024 |
| 1025 def equals(self, rhs): |
| 1026 """Check to see if the RHS is an instance of class_name. |
| 1027 |
| 1028 Args: |
| 1029 # rhs: the right hand side of the test |
| 1030 rhs: object |
| 1031 |
| 1032 Returns: |
| 1033 bool |
| 1034 """ |
| 1035 |
| 1036 try: |
| 1037 return isinstance(rhs, self._class_name) |
| 1038 except TypeError: |
| 1039 # Check raw types if there was a type error. This is helpful for |
| 1040 # things like cStringIO.StringIO. |
| 1041 return type(rhs) == type(self._class_name) |
| 1042 |
| 1043 def __repr__(self): |
| 1044 return str(self._class_name) |
| 1045 |
| 1046 class IsAlmost(Comparator): |
| 1047 """Comparison class used to check whether a parameter is nearly equal |
| 1048 to a given value. Generally useful for floating point numbers. |
| 1049 |
| 1050 Example mock_dao.SetTimeout((IsAlmost(3.9))) |
| 1051 """ |
| 1052 |
| 1053 def __init__(self, float_value, places=7): |
| 1054 """Initialize IsAlmost. |
| 1055 |
| 1056 Args: |
| 1057 float_value: The value for making the comparison. |
| 1058 places: The number of decimal places to round to. |
| 1059 """ |
| 1060 |
| 1061 self._float_value = float_value |
| 1062 self._places = places |
| 1063 |
| 1064 def equals(self, rhs): |
| 1065 """Check to see if RHS is almost equal to float_value |
| 1066 |
| 1067 Args: |
| 1068 rhs: the value to compare to float_value |
| 1069 |
| 1070 Returns: |
| 1071 bool |
| 1072 """ |
| 1073 |
| 1074 try: |
| 1075 return round(rhs-self._float_value, self._places) == 0 |
| 1076 except TypeError: |
| 1077 # This is probably because either float_value or rhs is not a number. |
| 1078 return False |
| 1079 |
| 1080 def __repr__(self): |
| 1081 return str(self._float_value) |
| 1082 |
| 1083 class StrContains(Comparator): |
| 1084 """Comparison class used to check whether a substring exists in a |
| 1085 string parameter. This can be useful in mocking a database with SQL |
| 1086 passed in as a string parameter, for example. |
| 1087 |
| 1088 Example: |
| 1089 mock_dao.RunQuery(StrContains('IN (1, 2, 4, 5)')).AndReturn(mock_result) |
| 1090 """ |
| 1091 |
| 1092 def __init__(self, search_string): |
| 1093 """Initialize. |
| 1094 |
| 1095 Args: |
| 1096 # search_string: the string you are searching for |
| 1097 search_string: str |
| 1098 """ |
| 1099 |
| 1100 self._search_string = search_string |
| 1101 |
| 1102 def equals(self, rhs): |
| 1103 """Check to see if the search_string is contained in the rhs string. |
| 1104 |
| 1105 Args: |
| 1106 # rhs: the right hand side of the test |
| 1107 rhs: object |
| 1108 |
| 1109 Returns: |
| 1110 bool |
| 1111 """ |
| 1112 |
| 1113 try: |
| 1114 return rhs.find(self._search_string) > -1 |
| 1115 except Exception: |
| 1116 return False |
| 1117 |
| 1118 def __repr__(self): |
| 1119 return '<str containing \'%s\'>' % self._search_string |
| 1120 |
| 1121 |
| 1122 class Regex(Comparator): |
| 1123 """Checks if a string matches a regular expression. |
| 1124 |
| 1125 This uses a given regular expression to determine equality. |
| 1126 """ |
| 1127 |
| 1128 def __init__(self, pattern, flags=0): |
| 1129 """Initialize. |
| 1130 |
| 1131 Args: |
| 1132 # pattern is the regular expression to search for |
| 1133 pattern: str |
| 1134 # flags passed to re.compile function as the second argument |
| 1135 flags: int |
| 1136 """ |
| 1137 |
| 1138 self.regex = re.compile(pattern, flags=flags) |
| 1139 |
| 1140 def equals(self, rhs): |
| 1141 """Check to see if rhs matches regular expression pattern. |
| 1142 |
| 1143 Returns: |
| 1144 bool |
| 1145 """ |
| 1146 |
| 1147 return self.regex.search(rhs) is not None |
| 1148 |
| 1149 def __repr__(self): |
| 1150 s = '<regular expression \'%s\'' % self.regex.pattern |
| 1151 if self.regex.flags: |
| 1152 s += ', flags=%d' % self.regex.flags |
| 1153 s += '>' |
| 1154 return s |
| 1155 |
| 1156 |
| 1157 class In(Comparator): |
| 1158 """Checks whether an item (or key) is in a list (or dict) parameter. |
| 1159 |
| 1160 Example: |
| 1161 mock_dao.GetUsersInfo(In('expectedUserName')).AndReturn(mock_result) |
| 1162 """ |
| 1163 |
| 1164 def __init__(self, key): |
| 1165 """Initialize. |
| 1166 |
| 1167 Args: |
| 1168 # key is any thing that could be in a list or a key in a dict |
| 1169 """ |
| 1170 |
| 1171 self._key = key |
| 1172 |
| 1173 def equals(self, rhs): |
| 1174 """Check to see whether key is in rhs. |
| 1175 |
| 1176 Args: |
| 1177 rhs: dict |
| 1178 |
| 1179 Returns: |
| 1180 bool |
| 1181 """ |
| 1182 |
| 1183 return self._key in rhs |
| 1184 |
| 1185 def __repr__(self): |
| 1186 return '<sequence or map containing \'%s\'>' % self._key |
| 1187 |
| 1188 |
| 1189 class Not(Comparator): |
| 1190 """Checks whether a predicates is False. |
| 1191 |
| 1192 Example: |
| 1193 mock_dao.UpdateUsers(Not(ContainsKeyValue('stevepm', stevepm_user_info))) |
| 1194 """ |
| 1195 |
| 1196 def __init__(self, predicate): |
| 1197 """Initialize. |
| 1198 |
| 1199 Args: |
| 1200 # predicate: a Comparator instance. |
| 1201 """ |
| 1202 |
| 1203 assert isinstance(predicate, Comparator), ("predicate %r must be a" |
| 1204 " Comparator." % predicate) |
| 1205 self._predicate = predicate |
| 1206 |
| 1207 def equals(self, rhs): |
| 1208 """Check to see whether the predicate is False. |
| 1209 |
| 1210 Args: |
| 1211 rhs: A value that will be given in argument of the predicate. |
| 1212 |
| 1213 Returns: |
| 1214 bool |
| 1215 """ |
| 1216 |
| 1217 return not self._predicate.equals(rhs) |
| 1218 |
| 1219 def __repr__(self): |
| 1220 return '<not \'%s\'>' % self._predicate |
| 1221 |
| 1222 |
| 1223 class ContainsKeyValue(Comparator): |
| 1224 """Checks whether a key/value pair is in a dict parameter. |
| 1225 |
| 1226 Example: |
| 1227 mock_dao.UpdateUsers(ContainsKeyValue('stevepm', stevepm_user_info)) |
| 1228 """ |
| 1229 |
| 1230 def __init__(self, key, value): |
| 1231 """Initialize. |
| 1232 |
| 1233 Args: |
| 1234 # key: a key in a dict |
| 1235 # value: the corresponding value |
| 1236 """ |
| 1237 |
| 1238 self._key = key |
| 1239 self._value = value |
| 1240 |
| 1241 def equals(self, rhs): |
| 1242 """Check whether the given key/value pair is in the rhs dict. |
| 1243 |
| 1244 Returns: |
| 1245 bool |
| 1246 """ |
| 1247 |
| 1248 try: |
| 1249 return rhs[self._key] == self._value |
| 1250 except Exception: |
| 1251 return False |
| 1252 |
| 1253 def __repr__(self): |
| 1254 return '<map containing the entry \'%s: %s\'>' % (self._key, self._value) |
| 1255 |
| 1256 |
| 1257 class SameElementsAs(Comparator): |
| 1258 """Checks whether iterables contain the same elements (ignoring order). |
| 1259 |
| 1260 Example: |
| 1261 mock_dao.ProcessUsers(SameElementsAs('stevepm', 'salomaki')) |
| 1262 """ |
| 1263 |
| 1264 def __init__(self, expected_seq): |
| 1265 """Initialize. |
| 1266 |
| 1267 Args: |
| 1268 expected_seq: a sequence |
| 1269 """ |
| 1270 |
| 1271 self._expected_seq = expected_seq |
| 1272 |
| 1273 def equals(self, actual_seq): |
| 1274 """Check to see whether actual_seq has same elements as expected_seq. |
| 1275 |
| 1276 Args: |
| 1277 actual_seq: sequence |
| 1278 |
| 1279 Returns: |
| 1280 bool |
| 1281 """ |
| 1282 |
| 1283 try: |
| 1284 expected = dict([(element, None) for element in self._expected_seq]) |
| 1285 actual = dict([(element, None) for element in actual_seq]) |
| 1286 except TypeError: |
| 1287 # Fall back to slower list-compare if any of the objects are unhashable. |
| 1288 expected = list(self._expected_seq) |
| 1289 actual = list(actual_seq) |
| 1290 expected.sort() |
| 1291 actual.sort() |
| 1292 return expected == actual |
| 1293 |
| 1294 def __repr__(self): |
| 1295 return '<sequence with same elements as \'%s\'>' % self._expected_seq |
| 1296 |
| 1297 |
| 1298 class And(Comparator): |
| 1299 """Evaluates one or more Comparators on RHS and returns an AND of the results. |
| 1300 """ |
| 1301 |
| 1302 def __init__(self, *args): |
| 1303 """Initialize. |
| 1304 |
| 1305 Args: |
| 1306 *args: One or more Comparator |
| 1307 """ |
| 1308 |
| 1309 self._comparators = args |
| 1310 |
| 1311 def equals(self, rhs): |
| 1312 """Checks whether all Comparators are equal to rhs. |
| 1313 |
| 1314 Args: |
| 1315 # rhs: can be anything |
| 1316 |
| 1317 Returns: |
| 1318 bool |
| 1319 """ |
| 1320 |
| 1321 for comparator in self._comparators: |
| 1322 if not comparator.equals(rhs): |
| 1323 return False |
| 1324 |
| 1325 return True |
| 1326 |
| 1327 def __repr__(self): |
| 1328 return '<AND %s>' % str(self._comparators) |
| 1329 |
| 1330 |
| 1331 class Or(Comparator): |
| 1332 """Evaluates one or more Comparators on RHS and returns an OR of the results. |
| 1333 """ |
| 1334 |
| 1335 def __init__(self, *args): |
| 1336 """Initialize. |
| 1337 |
| 1338 Args: |
| 1339 *args: One or more Mox comparators |
| 1340 """ |
| 1341 |
| 1342 self._comparators = args |
| 1343 |
| 1344 def equals(self, rhs): |
| 1345 """Checks whether any Comparator is equal to rhs. |
| 1346 |
| 1347 Args: |
| 1348 # rhs: can be anything |
| 1349 |
| 1350 Returns: |
| 1351 bool |
| 1352 """ |
| 1353 |
| 1354 for comparator in self._comparators: |
| 1355 if comparator.equals(rhs): |
| 1356 return True |
| 1357 |
| 1358 return False |
| 1359 |
| 1360 def __repr__(self): |
| 1361 return '<OR %s>' % str(self._comparators) |
| 1362 |
| 1363 |
| 1364 class Func(Comparator): |
| 1365 """Call a function that should verify the parameter passed in is correct. |
| 1366 |
| 1367 You may need the ability to perform more advanced operations on the parameter |
| 1368 in order to validate it. You can use this to have a callable validate any |
| 1369 parameter. The callable should return either True or False. |
| 1370 |
| 1371 |
| 1372 Example: |
| 1373 |
| 1374 def myParamValidator(param): |
| 1375 # Advanced logic here |
| 1376 return True |
| 1377 |
| 1378 mock_dao.DoSomething(Func(myParamValidator), true) |
| 1379 """ |
| 1380 |
| 1381 def __init__(self, func): |
| 1382 """Initialize. |
| 1383 |
| 1384 Args: |
| 1385 func: callable that takes one parameter and returns a bool |
| 1386 """ |
| 1387 |
| 1388 self._func = func |
| 1389 |
| 1390 def equals(self, rhs): |
| 1391 """Test whether rhs passes the function test. |
| 1392 |
| 1393 rhs is passed into func. |
| 1394 |
| 1395 Args: |
| 1396 rhs: any python object |
| 1397 |
| 1398 Returns: |
| 1399 the result of func(rhs) |
| 1400 """ |
| 1401 |
| 1402 return self._func(rhs) |
| 1403 |
| 1404 def __repr__(self): |
| 1405 return str(self._func) |
| 1406 |
| 1407 |
| 1408 class IgnoreArg(Comparator): |
| 1409 """Ignore an argument. |
| 1410 |
| 1411 This can be used when we don't care about an argument of a method call. |
| 1412 |
| 1413 Example: |
| 1414 # Check if CastMagic is called with 3 as first arg and 'disappear' as third. |
| 1415 mymock.CastMagic(3, IgnoreArg(), 'disappear') |
| 1416 """ |
| 1417 |
| 1418 def equals(self, unused_rhs): |
| 1419 """Ignores arguments and returns True. |
| 1420 |
| 1421 Args: |
| 1422 unused_rhs: any python object |
| 1423 |
| 1424 Returns: |
| 1425 always returns True |
| 1426 """ |
| 1427 |
| 1428 return True |
| 1429 |
| 1430 def __repr__(self): |
| 1431 return '<IgnoreArg>' |
| 1432 |
| 1433 |
| 1434 class MethodGroup(object): |
| 1435 """Base class containing common behaviour for MethodGroups.""" |
| 1436 |
| 1437 def __init__(self, group_name): |
| 1438 self._group_name = group_name |
| 1439 |
| 1440 def group_name(self): |
| 1441 return self._group_name |
| 1442 |
| 1443 def __str__(self): |
| 1444 return '<%s "%s">' % (self.__class__.__name__, self._group_name) |
| 1445 |
| 1446 def AddMethod(self, mock_method): |
| 1447 raise NotImplementedError |
| 1448 |
| 1449 def MethodCalled(self, mock_method): |
| 1450 raise NotImplementedError |
| 1451 |
| 1452 def IsSatisfied(self): |
| 1453 raise NotImplementedError |
| 1454 |
| 1455 class UnorderedGroup(MethodGroup): |
| 1456 """UnorderedGroup holds a set of method calls that may occur in any order. |
| 1457 |
| 1458 This construct is helpful for non-deterministic events, such as iterating |
| 1459 over the keys of a dict. |
| 1460 """ |
| 1461 |
| 1462 def __init__(self, group_name): |
| 1463 super(UnorderedGroup, self).__init__(group_name) |
| 1464 self._methods = [] |
| 1465 |
| 1466 def AddMethod(self, mock_method): |
| 1467 """Add a method to this group. |
| 1468 |
| 1469 Args: |
| 1470 mock_method: A mock method to be added to this group. |
| 1471 """ |
| 1472 |
| 1473 self._methods.append(mock_method) |
| 1474 |
| 1475 def MethodCalled(self, mock_method): |
| 1476 """Remove a method call from the group. |
| 1477 |
| 1478 If the method is not in the set, an UnexpectedMethodCallError will be |
| 1479 raised. |
| 1480 |
| 1481 Args: |
| 1482 mock_method: a mock method that should be equal to a method in the group. |
| 1483 |
| 1484 Returns: |
| 1485 The mock method from the group |
| 1486 |
| 1487 Raises: |
| 1488 UnexpectedMethodCallError if the mock_method was not in the group. |
| 1489 """ |
| 1490 |
| 1491 # Check to see if this method exists, and if so, remove it from the set |
| 1492 # and return it. |
| 1493 for method in self._methods: |
| 1494 if method == mock_method: |
| 1495 # Remove the called mock_method instead of the method in the group. |
| 1496 # The called method will match any comparators when equality is checked |
| 1497 # during removal. The method in the group could pass a comparator to |
| 1498 # another comparator during the equality check. |
| 1499 self._methods.remove(mock_method) |
| 1500 |
| 1501 # If this group is not empty, put it back at the head of the queue. |
| 1502 if not self.IsSatisfied(): |
| 1503 mock_method._call_queue.appendleft(self) |
| 1504 |
| 1505 return self, method |
| 1506 |
| 1507 raise UnexpectedMethodCallError(mock_method, self) |
| 1508 |
| 1509 def IsSatisfied(self): |
| 1510 """Return True if there are not any methods in this group.""" |
| 1511 |
| 1512 return len(self._methods) == 0 |
| 1513 |
| 1514 |
| 1515 class MultipleTimesGroup(MethodGroup): |
| 1516 """MultipleTimesGroup holds methods that may be called any number of times. |
| 1517 |
| 1518 Note: Each method must be called at least once. |
| 1519 |
| 1520 This is helpful, if you don't know or care how many times a method is called. |
| 1521 """ |
| 1522 |
| 1523 def __init__(self, group_name): |
| 1524 super(MultipleTimesGroup, self).__init__(group_name) |
| 1525 self._methods = set() |
| 1526 self._methods_left = set() |
| 1527 |
| 1528 def AddMethod(self, mock_method): |
| 1529 """Add a method to this group. |
| 1530 |
| 1531 Args: |
| 1532 mock_method: A mock method to be added to this group. |
| 1533 """ |
| 1534 |
| 1535 self._methods.add(mock_method) |
| 1536 self._methods_left.add(mock_method) |
| 1537 |
| 1538 def MethodCalled(self, mock_method): |
| 1539 """Remove a method call from the group. |
| 1540 |
| 1541 If the method is not in the set, an UnexpectedMethodCallError will be |
| 1542 raised. |
| 1543 |
| 1544 Args: |
| 1545 mock_method: a mock method that should be equal to a method in the group. |
| 1546 |
| 1547 Returns: |
| 1548 The mock method from the group |
| 1549 |
| 1550 Raises: |
| 1551 UnexpectedMethodCallError if the mock_method was not in the group. |
| 1552 """ |
| 1553 |
| 1554 # Check to see if this method exists, and if so add it to the set of |
| 1555 # called methods. |
| 1556 for method in self._methods: |
| 1557 if method == mock_method: |
| 1558 self._methods_left.discard(method) |
| 1559 # Always put this group back on top of the queue, because we don't know |
| 1560 # when we are done. |
| 1561 mock_method._call_queue.appendleft(self) |
| 1562 return self, method |
| 1563 |
| 1564 if self.IsSatisfied(): |
| 1565 next_method = mock_method._PopNextMethod(); |
| 1566 return next_method, None |
| 1567 else: |
| 1568 raise UnexpectedMethodCallError(mock_method, self) |
| 1569 |
| 1570 def IsSatisfied(self): |
| 1571 """Return True if all methods in this group are called at least once.""" |
| 1572 return len(self._methods_left) == 0 |
| 1573 |
| 1574 |
| 1575 class MoxMetaTestBase(type): |
| 1576 """Metaclass to add mox cleanup and verification to every test. |
| 1577 |
| 1578 As the mox unit testing class is being constructed (MoxTestBase or a |
| 1579 subclass), this metaclass will modify all test functions to call the |
| 1580 CleanUpMox method of the test class after they finish. This means that |
| 1581 unstubbing and verifying will happen for every test with no additional code, |
| 1582 and any failures will result in test failures as opposed to errors. |
| 1583 """ |
| 1584 |
| 1585 def __init__(cls, name, bases, d): |
| 1586 type.__init__(cls, name, bases, d) |
| 1587 |
| 1588 # also get all the attributes from the base classes to account |
| 1589 # for a case when test class is not the immediate child of MoxTestBase |
| 1590 for base in bases: |
| 1591 for attr_name in dir(base): |
| 1592 d[attr_name] = getattr(base, attr_name) |
| 1593 |
| 1594 for func_name, func in d.items(): |
| 1595 if func_name.startswith('test') and callable(func): |
| 1596 setattr(cls, func_name, MoxMetaTestBase.CleanUpTest(cls, func)) |
| 1597 |
| 1598 @staticmethod |
| 1599 def CleanUpTest(cls, func): |
| 1600 """Adds Mox cleanup code to any MoxTestBase method. |
| 1601 |
| 1602 Always unsets stubs after a test. Will verify all mocks for tests that |
| 1603 otherwise pass. |
| 1604 |
| 1605 Args: |
| 1606 cls: MoxTestBase or subclass; the class whose test method we are altering. |
| 1607 func: method; the method of the MoxTestBase test class we wish to alter. |
| 1608 |
| 1609 Returns: |
| 1610 The modified method. |
| 1611 """ |
| 1612 def new_method(self, *args, **kwargs): |
| 1613 mox_obj = getattr(self, 'mox', None) |
| 1614 cleanup_mox = False |
| 1615 if mox_obj and isinstance(mox_obj, Mox): |
| 1616 cleanup_mox = True |
| 1617 try: |
| 1618 func(self, *args, **kwargs) |
| 1619 finally: |
| 1620 if cleanup_mox: |
| 1621 mox_obj.UnsetStubs() |
| 1622 if cleanup_mox: |
| 1623 mox_obj.VerifyAll() |
| 1624 new_method.__name__ = func.__name__ |
| 1625 new_method.__doc__ = func.__doc__ |
| 1626 new_method.__module__ = func.__module__ |
| 1627 return new_method |
| 1628 |
| 1629 |
| 1630 class MoxTestBase(unittest.TestCase): |
| 1631 """Convenience test class to make stubbing easier. |
| 1632 |
| 1633 Sets up a "mox" attribute which is an instance of Mox - any mox tests will |
| 1634 want this. Also automatically unsets any stubs and verifies that all mock |
| 1635 methods have been called at the end of each test, eliminating boilerplate |
| 1636 code. |
| 1637 """ |
| 1638 |
| 1639 __metaclass__ = MoxMetaTestBase |
| 1640 |
| 1641 def setUp(self): |
| 1642 super(MoxTestBase, self).setUp() |
| 1643 self.mox = Mox() |
OLD | NEW |