Index: tests/pymox/mox.py |
=================================================================== |
--- tests/pymox/mox.py (revision 0) |
+++ tests/pymox/mox.py (revision 0) |
@@ -0,0 +1,1643 @@ |
+#!/usr/bin/python2.4 |
+# |
+# Copyright 2008 Google Inc. |
+# |
+# Licensed under the Apache License, Version 2.0 (the "License"); |
+# you may not use this file except in compliance with the License. |
+# You may obtain a copy of the License at |
+# |
+# http://www.apache.org/licenses/LICENSE-2.0 |
+# |
+# Unless required by applicable law or agreed to in writing, software |
+# distributed under the License is distributed on an "AS IS" BASIS, |
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
+# See the License for the specific language governing permissions and |
+# limitations under the License. |
+ |
+"""Mox, an object-mocking framework for Python. |
+ |
+Mox works in the record-replay-verify paradigm. When you first create |
+a mock object, it is in record mode. You then programmatically set |
+the expected behavior of the mock object (what methods are to be |
+called on it, with what parameters, what they should return, and in |
+what order). |
+ |
+Once you have set up the expected mock behavior, you put it in replay |
+mode. Now the mock responds to method calls just as you told it to. |
+If an unexpected method (or an expected method with unexpected |
+parameters) is called, then an exception will be raised. |
+ |
+Once you are done interacting with the mock, you need to verify that |
+all the expected interactions occured. (Maybe your code exited |
+prematurely without calling some cleanup method!) The verify phase |
+ensures that every expected method was called; otherwise, an exception |
+will be raised. |
+ |
+WARNING! Mock objects created by Mox are not thread-safe. If you are |
+call a mock in multiple threads, it should be guarded by a mutex. |
+ |
+TODO(stevepm): Add the option to make mocks thread-safe! |
+ |
+Suggested usage / workflow: |
+ |
+ # Create Mox factory |
+ my_mox = Mox() |
+ |
+ # Create a mock data access object |
+ mock_dao = my_mox.CreateMock(DAOClass) |
+ |
+ # Set up expected behavior |
+ mock_dao.RetrievePersonWithIdentifier('1').AndReturn(person) |
+ mock_dao.DeletePerson(person) |
+ |
+ # Put mocks in replay mode |
+ my_mox.ReplayAll() |
+ |
+ # Inject mock object and run test |
+ controller.SetDao(mock_dao) |
+ controller.DeletePersonById('1') |
+ |
+ # Verify all methods were called as expected |
+ my_mox.VerifyAll() |
+""" |
+ |
+from collections import deque |
+import difflib |
+import inspect |
+import re |
+import types |
+import unittest |
+ |
+import stubout |
+ |
+class Error(AssertionError): |
+ """Base exception for this module.""" |
+ |
+ pass |
+ |
+ |
+class ExpectedMethodCallsError(Error): |
+ """Raised when Verify() is called before all expected methods have been called |
+ """ |
+ |
+ def __init__(self, expected_methods): |
+ """Init exception. |
+ |
+ Args: |
+ # expected_methods: A sequence of MockMethod objects that should have been |
+ # called. |
+ expected_methods: [MockMethod] |
+ |
+ Raises: |
+ ValueError: if expected_methods contains no methods. |
+ """ |
+ |
+ if not expected_methods: |
+ raise ValueError("There must be at least one expected method") |
+ Error.__init__(self) |
+ self._expected_methods = expected_methods |
+ |
+ def __str__(self): |
+ calls = "\n".join(["%3d. %s" % (i, m) |
+ for i, m in enumerate(self._expected_methods)]) |
+ return "Verify: Expected methods never called:\n%s" % (calls,) |
+ |
+ |
+class UnexpectedMethodCallError(Error): |
+ """Raised when an unexpected method is called. |
+ |
+ This can occur if a method is called with incorrect parameters, or out of the |
+ specified order. |
+ """ |
+ |
+ def __init__(self, unexpected_method, expected): |
+ """Init exception. |
+ |
+ Args: |
+ # unexpected_method: MockMethod that was called but was not at the head of |
+ # the expected_method queue. |
+ # expected: MockMethod or UnorderedGroup the method should have |
+ # been in. |
+ unexpected_method: MockMethod |
+ expected: MockMethod or UnorderedGroup |
+ """ |
+ |
+ Error.__init__(self) |
+ if expected is None: |
+ self._str = "Unexpected method call %s" % (unexpected_method,) |
+ else: |
+ differ = difflib.Differ() |
+ diff = differ.compare(str(unexpected_method).splitlines(True), |
+ str(expected).splitlines(True)) |
+ self._str = ("Unexpected method call. unexpected:- expected:+\n%s" |
+ % ("\n".join(diff),)) |
+ |
+ def __str__(self): |
+ return self._str |
+ |
+ |
+class UnknownMethodCallError(Error): |
+ """Raised if an unknown method is requested of the mock object.""" |
+ |
+ def __init__(self, unknown_method_name): |
+ """Init exception. |
+ |
+ Args: |
+ # unknown_method_name: Method call that is not part of the mocked class's |
+ # public interface. |
+ unknown_method_name: str |
+ """ |
+ |
+ Error.__init__(self) |
+ self._unknown_method_name = unknown_method_name |
+ |
+ def __str__(self): |
+ return "Method called is not a member of the object: %s" % \ |
+ self._unknown_method_name |
+ |
+ |
+class Mox(object): |
+ """Mox: a factory for creating mock objects.""" |
+ |
+ # A list of types that should be stubbed out with MockObjects (as |
+ # opposed to MockAnythings). |
+ _USE_MOCK_OBJECT = [types.ClassType, types.InstanceType, types.ModuleType, |
+ types.ObjectType, types.TypeType] |
+ |
+ def __init__(self): |
+ """Initialize a new Mox.""" |
+ |
+ self._mock_objects = [] |
+ self.stubs = stubout.StubOutForTesting() |
+ |
+ def CreateMock(self, class_to_mock): |
+ """Create a new mock object. |
+ |
+ Args: |
+ # class_to_mock: the class to be mocked |
+ class_to_mock: class |
+ |
+ Returns: |
+ MockObject that can be used as the class_to_mock would be. |
+ """ |
+ |
+ new_mock = MockObject(class_to_mock) |
+ self._mock_objects.append(new_mock) |
+ return new_mock |
+ |
+ def CreateMockAnything(self, description=None): |
+ """Create a mock that will accept any method calls. |
+ |
+ This does not enforce an interface. |
+ |
+ Args: |
+ description: str. Optionally, a descriptive name for the mock object being |
+ created, for debugging output purposes. |
+ """ |
+ new_mock = MockAnything(description=description) |
+ self._mock_objects.append(new_mock) |
+ return new_mock |
+ |
+ def ReplayAll(self): |
+ """Set all mock objects to replay mode.""" |
+ |
+ for mock_obj in self._mock_objects: |
+ mock_obj._Replay() |
+ |
+ |
+ def VerifyAll(self): |
+ """Call verify on all mock objects created.""" |
+ |
+ for mock_obj in self._mock_objects: |
+ mock_obj._Verify() |
+ |
+ def ResetAll(self): |
+ """Call reset on all mock objects. This does not unset stubs.""" |
+ |
+ for mock_obj in self._mock_objects: |
+ mock_obj._Reset() |
+ |
+ def StubOutWithMock(self, obj, attr_name, use_mock_anything=False): |
+ """Replace a method, attribute, etc. with a Mock. |
+ |
+ This will replace a class or module with a MockObject, and everything else |
+ (method, function, etc) with a MockAnything. This can be overridden to |
+ always use a MockAnything by setting use_mock_anything to True. |
+ |
+ Args: |
+ obj: A Python object (class, module, instance, callable). |
+ attr_name: str. The name of the attribute to replace with a mock. |
+ use_mock_anything: bool. True if a MockAnything should be used regardless |
+ of the type of attribute. |
+ """ |
+ |
+ attr_to_replace = getattr(obj, attr_name) |
+ |
+ # Check for a MockAnything. This could cause confusing problems later on. |
+ if attr_to_replace == MockAnything(): |
+ raise TypeError('Cannot mock a MockAnything! Did you remember to ' |
+ 'call UnsetStubs in your previous test?') |
+ |
+ if type(attr_to_replace) in self._USE_MOCK_OBJECT and not use_mock_anything: |
+ stub = self.CreateMock(attr_to_replace) |
+ else: |
+ stub = self.CreateMockAnything(description='Stub for %s' % attr_to_replace) |
+ |
+ self.stubs.Set(obj, attr_name, stub) |
+ |
+ def UnsetStubs(self): |
+ """Restore stubs to their original state.""" |
+ |
+ self.stubs.UnsetAll() |
+ |
+def Replay(*args): |
+ """Put mocks into Replay mode. |
+ |
+ Args: |
+ # args is any number of mocks to put into replay mode. |
+ """ |
+ |
+ for mock in args: |
+ mock._Replay() |
+ |
+ |
+def Verify(*args): |
+ """Verify mocks. |
+ |
+ Args: |
+ # args is any number of mocks to be verified. |
+ """ |
+ |
+ for mock in args: |
+ mock._Verify() |
+ |
+ |
+def Reset(*args): |
+ """Reset mocks. |
+ |
+ Args: |
+ # args is any number of mocks to be reset. |
+ """ |
+ |
+ for mock in args: |
+ mock._Reset() |
+ |
+ |
+class MockAnything: |
+ """A mock that can be used to mock anything. |
+ |
+ This is helpful for mocking classes that do not provide a public interface. |
+ """ |
+ |
+ def __init__(self, description=None): |
+ """Initialize a new MockAnything. |
+ |
+ Args: |
+ description: str. Optionally, a descriptive name for the mock object being |
+ created, for debugging output purposes. |
+ """ |
+ self._description = description |
+ self._Reset() |
+ |
+ def __str__(self): |
+ return "<MockAnything instance at %s>" % id(self) |
+ |
+ def __repr__(self): |
+ return '<MockAnything instance>' |
+ |
+ def __getattr__(self, method_name): |
+ """Intercept method calls on this object. |
+ |
+ A new MockMethod is returned that is aware of the MockAnything's |
+ state (record or replay). The call will be recorded or replayed |
+ by the MockMethod's __call__. |
+ |
+ Args: |
+ # method name: the name of the method being called. |
+ method_name: str |
+ |
+ Returns: |
+ A new MockMethod aware of MockAnything's state (record or replay). |
+ """ |
+ |
+ return self._CreateMockMethod(method_name) |
+ |
+ def _CreateMockMethod(self, method_name, method_to_mock=None): |
+ """Create a new mock method call and return it. |
+ |
+ Args: |
+ # method_name: the name of the method being called. |
+ # method_to_mock: The actual method being mocked, used for introspection. |
+ method_name: str |
+ method_to_mock: a method object |
+ |
+ Returns: |
+ A new MockMethod aware of MockAnything's state (record or replay). |
+ """ |
+ |
+ return MockMethod(method_name, self._expected_calls_queue, |
+ self._replay_mode, method_to_mock=method_to_mock, |
+ description=self._description) |
+ |
+ def __nonzero__(self): |
+ """Return 1 for nonzero so the mock can be used as a conditional.""" |
+ |
+ return 1 |
+ |
+ def __eq__(self, rhs): |
+ """Provide custom logic to compare objects.""" |
+ |
+ return (isinstance(rhs, MockAnything) and |
+ self._replay_mode == rhs._replay_mode and |
+ self._expected_calls_queue == rhs._expected_calls_queue) |
+ |
+ def __ne__(self, rhs): |
+ """Provide custom logic to compare objects.""" |
+ |
+ return not self == rhs |
+ |
+ def _Replay(self): |
+ """Start replaying expected method calls.""" |
+ |
+ self._replay_mode = True |
+ |
+ def _Verify(self): |
+ """Verify that all of the expected calls have been made. |
+ |
+ Raises: |
+ ExpectedMethodCallsError: if there are still more method calls in the |
+ expected queue. |
+ """ |
+ |
+ # If the list of expected calls is not empty, raise an exception |
+ if self._expected_calls_queue: |
+ # The last MultipleTimesGroup is not popped from the queue. |
+ if (len(self._expected_calls_queue) == 1 and |
+ isinstance(self._expected_calls_queue[0], MultipleTimesGroup) and |
+ self._expected_calls_queue[0].IsSatisfied()): |
+ pass |
+ else: |
+ raise ExpectedMethodCallsError(self._expected_calls_queue) |
+ |
+ def _Reset(self): |
+ """Reset the state of this mock to record mode with an empty queue.""" |
+ |
+ # Maintain a list of method calls we are expecting |
+ self._expected_calls_queue = deque() |
+ |
+ # Make sure we are in setup mode, not replay mode |
+ self._replay_mode = False |
+ |
+ |
+class MockObject(MockAnything, object): |
+ """A mock object that simulates the public/protected interface of a class.""" |
+ |
+ def __init__(self, class_to_mock): |
+ """Initialize a mock object. |
+ |
+ This determines the methods and properties of the class and stores them. |
+ |
+ Args: |
+ # class_to_mock: class to be mocked |
+ class_to_mock: class |
+ """ |
+ |
+ # This is used to hack around the mixin/inheritance of MockAnything, which |
+ # is not a proper object (it can be anything. :-) |
+ MockAnything.__dict__['__init__'](self) |
+ |
+ # Get a list of all the public and special methods we should mock. |
+ self._known_methods = set() |
+ self._known_vars = set() |
+ self._class_to_mock = class_to_mock |
+ for method in dir(class_to_mock): |
+ if callable(getattr(class_to_mock, method)): |
+ self._known_methods.add(method) |
+ else: |
+ self._known_vars.add(method) |
+ |
+ def __getattr__(self, name): |
+ """Intercept attribute request on this object. |
+ |
+ If the attribute is a public class variable, it will be returned and not |
+ recorded as a call. |
+ |
+ If the attribute is not a variable, it is handled like a method |
+ call. The method name is checked against the set of mockable |
+ methods, and a new MockMethod is returned that is aware of the |
+ MockObject's state (record or replay). The call will be recorded |
+ or replayed by the MockMethod's __call__. |
+ |
+ Args: |
+ # name: the name of the attribute being requested. |
+ name: str |
+ |
+ Returns: |
+ Either a class variable or a new MockMethod that is aware of the state |
+ of the mock (record or replay). |
+ |
+ Raises: |
+ UnknownMethodCallError if the MockObject does not mock the requested |
+ method. |
+ """ |
+ |
+ if name in self._known_vars: |
+ return getattr(self._class_to_mock, name) |
+ |
+ if name in self._known_methods: |
+ return self._CreateMockMethod( |
+ name, |
+ method_to_mock=getattr(self._class_to_mock, name)) |
+ |
+ raise UnknownMethodCallError(name) |
+ |
+ def __eq__(self, rhs): |
+ """Provide custom logic to compare objects.""" |
+ |
+ return (isinstance(rhs, MockObject) and |
+ self._class_to_mock == rhs._class_to_mock and |
+ self._replay_mode == rhs._replay_mode and |
+ self._expected_calls_queue == rhs._expected_calls_queue) |
+ |
+ def __setitem__(self, key, value): |
+ """Provide custom logic for mocking classes that support item assignment. |
+ |
+ Args: |
+ key: Key to set the value for. |
+ value: Value to set. |
+ |
+ Returns: |
+ Expected return value in replay mode. A MockMethod object for the |
+ __setitem__ method that has already been called if not in replay mode. |
+ |
+ Raises: |
+ TypeError if the underlying class does not support item assignment. |
+ UnexpectedMethodCallError if the object does not expect the call to |
+ __setitem__. |
+ |
+ """ |
+ # Verify the class supports item assignment. |
+ if '__setitem__' not in dir(self._class_to_mock): |
+ raise TypeError('object does not support item assignment') |
+ |
+ # If we are in replay mode then simply call the mock __setitem__ method. |
+ if self._replay_mode: |
+ return MockMethod('__setitem__', self._expected_calls_queue, |
+ self._replay_mode)(key, value) |
+ |
+ |
+ # Otherwise, create a mock method __setitem__. |
+ return self._CreateMockMethod('__setitem__')(key, value) |
+ |
+ def __getitem__(self, key): |
+ """Provide custom logic for mocking classes that are subscriptable. |
+ |
+ Args: |
+ key: Key to return the value for. |
+ |
+ Returns: |
+ Expected return value in replay mode. A MockMethod object for the |
+ __getitem__ method that has already been called if not in replay mode. |
+ |
+ Raises: |
+ TypeError if the underlying class is not subscriptable. |
+ UnexpectedMethodCallError if the object does not expect the call to |
+ __getitem__. |
+ |
+ """ |
+ # Verify the class supports item assignment. |
+ if '__getitem__' not in dir(self._class_to_mock): |
+ raise TypeError('unsubscriptable object') |
+ |
+ # If we are in replay mode then simply call the mock __getitem__ method. |
+ if self._replay_mode: |
+ return MockMethod('__getitem__', self._expected_calls_queue, |
+ self._replay_mode)(key) |
+ |
+ |
+ # Otherwise, create a mock method __getitem__. |
+ return self._CreateMockMethod('__getitem__')(key) |
+ |
+ def __iter__(self): |
+ """Provide custom logic for mocking classes that are iterable. |
+ |
+ Returns: |
+ Expected return value in replay mode. A MockMethod object for the |
+ __iter__ method that has already been called if not in replay mode. |
+ |
+ Raises: |
+ TypeError if the underlying class is not iterable. |
+ UnexpectedMethodCallError if the object does not expect the call to |
+ __iter__. |
+ |
+ """ |
+ methods = dir(self._class_to_mock) |
+ |
+ # Verify the class supports iteration. |
+ if '__iter__' not in methods: |
+ # If it doesn't have iter method and we are in replay method, then try to |
+ # iterate using subscripts. |
+ if '__getitem__' not in methods or not self._replay_mode: |
+ raise TypeError('not iterable object') |
+ else: |
+ results = [] |
+ index = 0 |
+ try: |
+ while True: |
+ results.append(self[index]) |
+ index += 1 |
+ except IndexError: |
+ return iter(results) |
+ |
+ # If we are in replay mode then simply call the mock __iter__ method. |
+ if self._replay_mode: |
+ return MockMethod('__iter__', self._expected_calls_queue, |
+ self._replay_mode)() |
+ |
+ |
+ # Otherwise, create a mock method __iter__. |
+ return self._CreateMockMethod('__iter__')() |
+ |
+ |
+ def __contains__(self, key): |
+ """Provide custom logic for mocking classes that contain items. |
+ |
+ Args: |
+ key: Key to look in container for. |
+ |
+ Returns: |
+ Expected return value in replay mode. A MockMethod object for the |
+ __contains__ method that has already been called if not in replay mode. |
+ |
+ Raises: |
+ TypeError if the underlying class does not implement __contains__ |
+ UnexpectedMethodCaller if the object does not expect the call to |
+ __contains__. |
+ |
+ """ |
+ contains = self._class_to_mock.__dict__.get('__contains__', None) |
+ |
+ if contains is None: |
+ raise TypeError('unsubscriptable object') |
+ |
+ if self._replay_mode: |
+ return MockMethod('__contains__', self._expected_calls_queue, |
+ self._replay_mode)(key) |
+ |
+ return self._CreateMockMethod('__contains__')(key) |
+ |
+ def __call__(self, *params, **named_params): |
+ """Provide custom logic for mocking classes that are callable.""" |
+ |
+ # Verify the class we are mocking is callable. |
+ callable = hasattr(self._class_to_mock, '__call__') |
+ if not callable: |
+ raise TypeError('Not callable') |
+ |
+ # Because the call is happening directly on this object instead of a method, |
+ # the call on the mock method is made right here |
+ mock_method = self._CreateMockMethod('__call__') |
+ return mock_method(*params, **named_params) |
+ |
+ @property |
+ def __class__(self): |
+ """Return the class that is being mocked.""" |
+ |
+ return self._class_to_mock |
+ |
+ |
+class MethodCallChecker(object): |
+ """Ensures that methods are called correctly.""" |
+ |
+ _NEEDED, _DEFAULT, _GIVEN = range(3) |
+ |
+ def __init__(self, method): |
+ """Creates a checker. |
+ |
+ Args: |
+ # method: A method to check. |
+ method: function |
+ |
+ Raises: |
+ ValueError: method could not be inspected, so checks aren't possible. |
+ Some methods and functions like built-ins can't be inspected. |
+ """ |
+ try: |
+ self._args, varargs, varkw, defaults = inspect.getargspec(method) |
+ except TypeError: |
+ raise ValueError('Could not get argument specification for %r' |
+ % (method,)) |
+ if inspect.ismethod(method): |
+ self._args = self._args[1:] # Skip 'self'. |
+ self._method = method |
+ |
+ self._has_varargs = varargs is not None |
+ self._has_varkw = varkw is not None |
+ if defaults is None: |
+ self._required_args = self._args |
+ self._default_args = [] |
+ else: |
+ self._required_args = self._args[:-len(defaults)] |
+ self._default_args = self._args[-len(defaults):] |
+ |
+ def _RecordArgumentGiven(self, arg_name, arg_status): |
+ """Mark an argument as being given. |
+ |
+ Args: |
+ # arg_name: The name of the argument to mark in arg_status. |
+ # arg_status: Maps argument names to one of _NEEDED, _DEFAULT, _GIVEN. |
+ arg_name: string |
+ arg_status: dict |
+ |
+ Raises: |
+ AttributeError: arg_name is already marked as _GIVEN. |
+ """ |
+ if arg_status.get(arg_name, None) == MethodCallChecker._GIVEN: |
+ raise AttributeError('%s provided more than once' % (arg_name,)) |
+ arg_status[arg_name] = MethodCallChecker._GIVEN |
+ |
+ def Check(self, params, named_params): |
+ """Ensures that the parameters used while recording a call are valid. |
+ |
+ Args: |
+ # params: A list of positional parameters. |
+ # named_params: A dict of named parameters. |
+ params: list |
+ named_params: dict |
+ |
+ Raises: |
+ AttributeError: the given parameters don't work with the given method. |
+ """ |
+ arg_status = dict((a, MethodCallChecker._NEEDED) |
+ for a in self._required_args) |
+ for arg in self._default_args: |
+ arg_status[arg] = MethodCallChecker._DEFAULT |
+ |
+ # Check that each positional param is valid. |
+ for i in range(len(params)): |
+ try: |
+ arg_name = self._args[i] |
+ except IndexError: |
+ if not self._has_varargs: |
+ raise AttributeError('%s does not take %d or more positional ' |
+ 'arguments' % (self._method.__name__, i)) |
+ else: |
+ self._RecordArgumentGiven(arg_name, arg_status) |
+ |
+ # Check each keyword argument. |
+ for arg_name in named_params: |
+ if arg_name not in arg_status and not self._has_varkw: |
+ raise AttributeError('%s is not expecting keyword argument %s' |
+ % (self._method.__name__, arg_name)) |
+ self._RecordArgumentGiven(arg_name, arg_status) |
+ |
+ # Ensure all the required arguments have been given. |
+ still_needed = [k for k, v in arg_status.iteritems() |
+ if v == MethodCallChecker._NEEDED] |
+ if still_needed: |
+ raise AttributeError('No values given for arguments %s' |
+ % (' '.join(sorted(still_needed)))) |
+ |
+ |
+class MockMethod(object): |
+ """Callable mock method. |
+ |
+ A MockMethod should act exactly like the method it mocks, accepting parameters |
+ and returning a value, or throwing an exception (as specified). When this |
+ method is called, it can optionally verify whether the called method (name and |
+ signature) matches the expected method. |
+ """ |
+ |
+ def __init__(self, method_name, call_queue, replay_mode, |
+ method_to_mock=None, description=None): |
+ """Construct a new mock method. |
+ |
+ Args: |
+ # method_name: the name of the method |
+ # call_queue: deque of calls, verify this call against the head, or add |
+ # this call to the queue. |
+ # replay_mode: False if we are recording, True if we are verifying calls |
+ # against the call queue. |
+ # method_to_mock: The actual method being mocked, used for introspection. |
+ # description: optionally, a descriptive name for this method. Typically |
+ # this is equal to the descriptive name of the method's class. |
+ method_name: str |
+ call_queue: list or deque |
+ replay_mode: bool |
+ method_to_mock: a method object |
+ description: str or None |
+ """ |
+ |
+ self._name = method_name |
+ self._call_queue = call_queue |
+ if not isinstance(call_queue, deque): |
+ self._call_queue = deque(self._call_queue) |
+ self._replay_mode = replay_mode |
+ self._description = description |
+ |
+ self._params = None |
+ self._named_params = None |
+ self._return_value = None |
+ self._exception = None |
+ self._side_effects = None |
+ |
+ try: |
+ self._checker = MethodCallChecker(method_to_mock) |
+ except ValueError: |
+ self._checker = None |
+ |
+ def __call__(self, *params, **named_params): |
+ """Log parameters and return the specified return value. |
+ |
+ If the Mock(Anything/Object) associated with this call is in record mode, |
+ this MockMethod will be pushed onto the expected call queue. If the mock |
+ is in replay mode, this will pop a MockMethod off the top of the queue and |
+ verify this call is equal to the expected call. |
+ |
+ Raises: |
+ UnexpectedMethodCall if this call is supposed to match an expected method |
+ call and it does not. |
+ """ |
+ |
+ self._params = params |
+ self._named_params = named_params |
+ |
+ if not self._replay_mode: |
+ if self._checker is not None: |
+ self._checker.Check(params, named_params) |
+ self._call_queue.append(self) |
+ return self |
+ |
+ expected_method = self._VerifyMethodCall() |
+ |
+ if expected_method._side_effects: |
+ expected_method._side_effects(*params, **named_params) |
+ |
+ if expected_method._exception: |
+ raise expected_method._exception |
+ |
+ return expected_method._return_value |
+ |
+ def __getattr__(self, name): |
+ """Raise an AttributeError with a helpful message.""" |
+ |
+ raise AttributeError('MockMethod has no attribute "%s". ' |
+ 'Did you remember to put your mocks in replay mode?' % name) |
+ |
+ def __iter__(self): |
+ """Raise a TypeError with a helpful message.""" |
+ raise TypeError('MockMethod cannot be iterated. ' |
+ 'Did you remember to put your mocks in replay mode?') |
+ |
+ def next(self): |
+ """Raise a TypeError with a helpful message.""" |
+ raise TypeError('MockMethod cannot be iterated. ' |
+ 'Did you remember to put your mocks in replay mode?') |
+ |
+ def _PopNextMethod(self): |
+ """Pop the next method from our call queue.""" |
+ try: |
+ return self._call_queue.popleft() |
+ except IndexError: |
+ raise UnexpectedMethodCallError(self, None) |
+ |
+ def _VerifyMethodCall(self): |
+ """Verify the called method is expected. |
+ |
+ This can be an ordered method, or part of an unordered set. |
+ |
+ Returns: |
+ The expected mock method. |
+ |
+ Raises: |
+ UnexpectedMethodCall if the method called was not expected. |
+ """ |
+ |
+ expected = self._PopNextMethod() |
+ |
+ # Loop here, because we might have a MethodGroup followed by another |
+ # group. |
+ while isinstance(expected, MethodGroup): |
+ expected, method = expected.MethodCalled(self) |
+ if method is not None: |
+ return method |
+ |
+ # This is a mock method, so just check equality. |
+ if expected != self: |
+ raise UnexpectedMethodCallError(self, expected) |
+ |
+ return expected |
+ |
+ def __str__(self): |
+ params = ', '.join( |
+ [repr(p) for p in self._params or []] + |
+ ['%s=%r' % x for x in sorted((self._named_params or {}).items())]) |
+ full_desc = "%s(%s) -> %r" % (self._name, params, self._return_value) |
+ if self._description: |
+ full_desc = "%s.%s" % (self._description, full_desc) |
+ return full_desc |
+ |
+ def __eq__(self, rhs): |
+ """Test whether this MockMethod is equivalent to another MockMethod. |
+ |
+ Args: |
+ # rhs: the right hand side of the test |
+ rhs: MockMethod |
+ """ |
+ |
+ return (isinstance(rhs, MockMethod) and |
+ self._name == rhs._name and |
+ self._params == rhs._params and |
+ self._named_params == rhs._named_params) |
+ |
+ def __ne__(self, rhs): |
+ """Test whether this MockMethod is not equivalent to another MockMethod. |
+ |
+ Args: |
+ # rhs: the right hand side of the test |
+ rhs: MockMethod |
+ """ |
+ |
+ return not self == rhs |
+ |
+ def GetPossibleGroup(self): |
+ """Returns a possible group from the end of the call queue or None if no |
+ other methods are on the stack. |
+ """ |
+ |
+ # Remove this method from the tail of the queue so we can add it to a group. |
+ this_method = self._call_queue.pop() |
+ assert this_method == self |
+ |
+ # Determine if the tail of the queue is a group, or just a regular ordered |
+ # mock method. |
+ group = None |
+ try: |
+ group = self._call_queue[-1] |
+ except IndexError: |
+ pass |
+ |
+ return group |
+ |
+ def _CheckAndCreateNewGroup(self, group_name, group_class): |
+ """Checks if the last method (a possible group) is an instance of our |
+ group_class. Adds the current method to this group or creates a new one. |
+ |
+ Args: |
+ |
+ group_name: the name of the group. |
+ group_class: the class used to create instance of this new group |
+ """ |
+ group = self.GetPossibleGroup() |
+ |
+ # If this is a group, and it is the correct group, add the method. |
+ if isinstance(group, group_class) and group.group_name() == group_name: |
+ group.AddMethod(self) |
+ return self |
+ |
+ # Create a new group and add the method. |
+ new_group = group_class(group_name) |
+ new_group.AddMethod(self) |
+ self._call_queue.append(new_group) |
+ return self |
+ |
+ def InAnyOrder(self, group_name="default"): |
+ """Move this method into a group of unordered calls. |
+ |
+ A group of unordered calls must be defined together, and must be executed |
+ in full before the next expected method can be called. There can be |
+ multiple groups that are expected serially, if they are given |
+ different group names. The same group name can be reused if there is a |
+ standard method call, or a group with a different name, spliced between |
+ usages. |
+ |
+ Args: |
+ group_name: the name of the unordered group. |
+ |
+ Returns: |
+ self |
+ """ |
+ return self._CheckAndCreateNewGroup(group_name, UnorderedGroup) |
+ |
+ def MultipleTimes(self, group_name="default"): |
+ """Move this method into group of calls which may be called multiple times. |
+ |
+ A group of repeating calls must be defined together, and must be executed in |
+ full before the next expected mehtod can be called. |
+ |
+ Args: |
+ group_name: the name of the unordered group. |
+ |
+ Returns: |
+ self |
+ """ |
+ return self._CheckAndCreateNewGroup(group_name, MultipleTimesGroup) |
+ |
+ def AndReturn(self, return_value): |
+ """Set the value to return when this method is called. |
+ |
+ Args: |
+ # return_value can be anything. |
+ """ |
+ |
+ self._return_value = return_value |
+ return return_value |
+ |
+ def AndRaise(self, exception): |
+ """Set the exception to raise when this method is called. |
+ |
+ Args: |
+ # exception: the exception to raise when this method is called. |
+ exception: Exception |
+ """ |
+ |
+ self._exception = exception |
+ |
+ def WithSideEffects(self, side_effects): |
+ """Set the side effects that are simulated when this method is called. |
+ |
+ Args: |
+ side_effects: A callable which modifies the parameters or other relevant |
+ state which a given test case depends on. |
+ |
+ Returns: |
+ Self for chaining with AndReturn and AndRaise. |
+ """ |
+ self._side_effects = side_effects |
+ return self |
+ |
+class Comparator: |
+ """Base class for all Mox comparators. |
+ |
+ A Comparator can be used as a parameter to a mocked method when the exact |
+ value is not known. For example, the code you are testing might build up a |
+ long SQL string that is passed to your mock DAO. You're only interested that |
+ the IN clause contains the proper primary keys, so you can set your mock |
+ up as follows: |
+ |
+ mock_dao.RunQuery(StrContains('IN (1, 2, 4, 5)')).AndReturn(mock_result) |
+ |
+ Now whatever query is passed in must contain the string 'IN (1, 2, 4, 5)'. |
+ |
+ A Comparator may replace one or more parameters, for example: |
+ # return at most 10 rows |
+ mock_dao.RunQuery(StrContains('SELECT'), 10) |
+ |
+ or |
+ |
+ # Return some non-deterministic number of rows |
+ mock_dao.RunQuery(StrContains('SELECT'), IsA(int)) |
+ """ |
+ |
+ def equals(self, rhs): |
+ """Special equals method that all comparators must implement. |
+ |
+ Args: |
+ rhs: any python object |
+ """ |
+ |
+ raise NotImplementedError, 'method must be implemented by a subclass.' |
+ |
+ def __eq__(self, rhs): |
+ return self.equals(rhs) |
+ |
+ def __ne__(self, rhs): |
+ return not self.equals(rhs) |
+ |
+ |
+class IsA(Comparator): |
+ """This class wraps a basic Python type or class. It is used to verify |
+ that a parameter is of the given type or class. |
+ |
+ Example: |
+ mock_dao.Connect(IsA(DbConnectInfo)) |
+ """ |
+ |
+ def __init__(self, class_name): |
+ """Initialize IsA |
+ |
+ Args: |
+ class_name: basic python type or a class |
+ """ |
+ |
+ self._class_name = class_name |
+ |
+ def equals(self, rhs): |
+ """Check to see if the RHS is an instance of class_name. |
+ |
+ Args: |
+ # rhs: the right hand side of the test |
+ rhs: object |
+ |
+ Returns: |
+ bool |
+ """ |
+ |
+ try: |
+ return isinstance(rhs, self._class_name) |
+ except TypeError: |
+ # Check raw types if there was a type error. This is helpful for |
+ # things like cStringIO.StringIO. |
+ return type(rhs) == type(self._class_name) |
+ |
+ def __repr__(self): |
+ return str(self._class_name) |
+ |
+class IsAlmost(Comparator): |
+ """Comparison class used to check whether a parameter is nearly equal |
+ to a given value. Generally useful for floating point numbers. |
+ |
+ Example mock_dao.SetTimeout((IsAlmost(3.9))) |
+ """ |
+ |
+ def __init__(self, float_value, places=7): |
+ """Initialize IsAlmost. |
+ |
+ Args: |
+ float_value: The value for making the comparison. |
+ places: The number of decimal places to round to. |
+ """ |
+ |
+ self._float_value = float_value |
+ self._places = places |
+ |
+ def equals(self, rhs): |
+ """Check to see if RHS is almost equal to float_value |
+ |
+ Args: |
+ rhs: the value to compare to float_value |
+ |
+ Returns: |
+ bool |
+ """ |
+ |
+ try: |
+ return round(rhs-self._float_value, self._places) == 0 |
+ except TypeError: |
+ # This is probably because either float_value or rhs is not a number. |
+ return False |
+ |
+ def __repr__(self): |
+ return str(self._float_value) |
+ |
+class StrContains(Comparator): |
+ """Comparison class used to check whether a substring exists in a |
+ string parameter. This can be useful in mocking a database with SQL |
+ passed in as a string parameter, for example. |
+ |
+ Example: |
+ mock_dao.RunQuery(StrContains('IN (1, 2, 4, 5)')).AndReturn(mock_result) |
+ """ |
+ |
+ def __init__(self, search_string): |
+ """Initialize. |
+ |
+ Args: |
+ # search_string: the string you are searching for |
+ search_string: str |
+ """ |
+ |
+ self._search_string = search_string |
+ |
+ def equals(self, rhs): |
+ """Check to see if the search_string is contained in the rhs string. |
+ |
+ Args: |
+ # rhs: the right hand side of the test |
+ rhs: object |
+ |
+ Returns: |
+ bool |
+ """ |
+ |
+ try: |
+ return rhs.find(self._search_string) > -1 |
+ except Exception: |
+ return False |
+ |
+ def __repr__(self): |
+ return '<str containing \'%s\'>' % self._search_string |
+ |
+ |
+class Regex(Comparator): |
+ """Checks if a string matches a regular expression. |
+ |
+ This uses a given regular expression to determine equality. |
+ """ |
+ |
+ def __init__(self, pattern, flags=0): |
+ """Initialize. |
+ |
+ Args: |
+ # pattern is the regular expression to search for |
+ pattern: str |
+ # flags passed to re.compile function as the second argument |
+ flags: int |
+ """ |
+ |
+ self.regex = re.compile(pattern, flags=flags) |
+ |
+ def equals(self, rhs): |
+ """Check to see if rhs matches regular expression pattern. |
+ |
+ Returns: |
+ bool |
+ """ |
+ |
+ return self.regex.search(rhs) is not None |
+ |
+ def __repr__(self): |
+ s = '<regular expression \'%s\'' % self.regex.pattern |
+ if self.regex.flags: |
+ s += ', flags=%d' % self.regex.flags |
+ s += '>' |
+ return s |
+ |
+ |
+class In(Comparator): |
+ """Checks whether an item (or key) is in a list (or dict) parameter. |
+ |
+ Example: |
+ mock_dao.GetUsersInfo(In('expectedUserName')).AndReturn(mock_result) |
+ """ |
+ |
+ def __init__(self, key): |
+ """Initialize. |
+ |
+ Args: |
+ # key is any thing that could be in a list or a key in a dict |
+ """ |
+ |
+ self._key = key |
+ |
+ def equals(self, rhs): |
+ """Check to see whether key is in rhs. |
+ |
+ Args: |
+ rhs: dict |
+ |
+ Returns: |
+ bool |
+ """ |
+ |
+ return self._key in rhs |
+ |
+ def __repr__(self): |
+ return '<sequence or map containing \'%s\'>' % self._key |
+ |
+ |
+class Not(Comparator): |
+ """Checks whether a predicates is False. |
+ |
+ Example: |
+ mock_dao.UpdateUsers(Not(ContainsKeyValue('stevepm', stevepm_user_info))) |
+ """ |
+ |
+ def __init__(self, predicate): |
+ """Initialize. |
+ |
+ Args: |
+ # predicate: a Comparator instance. |
+ """ |
+ |
+ assert isinstance(predicate, Comparator), ("predicate %r must be a" |
+ " Comparator." % predicate) |
+ self._predicate = predicate |
+ |
+ def equals(self, rhs): |
+ """Check to see whether the predicate is False. |
+ |
+ Args: |
+ rhs: A value that will be given in argument of the predicate. |
+ |
+ Returns: |
+ bool |
+ """ |
+ |
+ return not self._predicate.equals(rhs) |
+ |
+ def __repr__(self): |
+ return '<not \'%s\'>' % self._predicate |
+ |
+ |
+class ContainsKeyValue(Comparator): |
+ """Checks whether a key/value pair is in a dict parameter. |
+ |
+ Example: |
+ mock_dao.UpdateUsers(ContainsKeyValue('stevepm', stevepm_user_info)) |
+ """ |
+ |
+ def __init__(self, key, value): |
+ """Initialize. |
+ |
+ Args: |
+ # key: a key in a dict |
+ # value: the corresponding value |
+ """ |
+ |
+ self._key = key |
+ self._value = value |
+ |
+ def equals(self, rhs): |
+ """Check whether the given key/value pair is in the rhs dict. |
+ |
+ Returns: |
+ bool |
+ """ |
+ |
+ try: |
+ return rhs[self._key] == self._value |
+ except Exception: |
+ return False |
+ |
+ def __repr__(self): |
+ return '<map containing the entry \'%s: %s\'>' % (self._key, self._value) |
+ |
+ |
+class SameElementsAs(Comparator): |
+ """Checks whether iterables contain the same elements (ignoring order). |
+ |
+ Example: |
+ mock_dao.ProcessUsers(SameElementsAs('stevepm', 'salomaki')) |
+ """ |
+ |
+ def __init__(self, expected_seq): |
+ """Initialize. |
+ |
+ Args: |
+ expected_seq: a sequence |
+ """ |
+ |
+ self._expected_seq = expected_seq |
+ |
+ def equals(self, actual_seq): |
+ """Check to see whether actual_seq has same elements as expected_seq. |
+ |
+ Args: |
+ actual_seq: sequence |
+ |
+ Returns: |
+ bool |
+ """ |
+ |
+ try: |
+ expected = dict([(element, None) for element in self._expected_seq]) |
+ actual = dict([(element, None) for element in actual_seq]) |
+ except TypeError: |
+ # Fall back to slower list-compare if any of the objects are unhashable. |
+ expected = list(self._expected_seq) |
+ actual = list(actual_seq) |
+ expected.sort() |
+ actual.sort() |
+ return expected == actual |
+ |
+ def __repr__(self): |
+ return '<sequence with same elements as \'%s\'>' % self._expected_seq |
+ |
+ |
+class And(Comparator): |
+ """Evaluates one or more Comparators on RHS and returns an AND of the results. |
+ """ |
+ |
+ def __init__(self, *args): |
+ """Initialize. |
+ |
+ Args: |
+ *args: One or more Comparator |
+ """ |
+ |
+ self._comparators = args |
+ |
+ def equals(self, rhs): |
+ """Checks whether all Comparators are equal to rhs. |
+ |
+ Args: |
+ # rhs: can be anything |
+ |
+ Returns: |
+ bool |
+ """ |
+ |
+ for comparator in self._comparators: |
+ if not comparator.equals(rhs): |
+ return False |
+ |
+ return True |
+ |
+ def __repr__(self): |
+ return '<AND %s>' % str(self._comparators) |
+ |
+ |
+class Or(Comparator): |
+ """Evaluates one or more Comparators on RHS and returns an OR of the results. |
+ """ |
+ |
+ def __init__(self, *args): |
+ """Initialize. |
+ |
+ Args: |
+ *args: One or more Mox comparators |
+ """ |
+ |
+ self._comparators = args |
+ |
+ def equals(self, rhs): |
+ """Checks whether any Comparator is equal to rhs. |
+ |
+ Args: |
+ # rhs: can be anything |
+ |
+ Returns: |
+ bool |
+ """ |
+ |
+ for comparator in self._comparators: |
+ if comparator.equals(rhs): |
+ return True |
+ |
+ return False |
+ |
+ def __repr__(self): |
+ return '<OR %s>' % str(self._comparators) |
+ |
+ |
+class Func(Comparator): |
+ """Call a function that should verify the parameter passed in is correct. |
+ |
+ You may need the ability to perform more advanced operations on the parameter |
+ in order to validate it. You can use this to have a callable validate any |
+ parameter. The callable should return either True or False. |
+ |
+ |
+ Example: |
+ |
+ def myParamValidator(param): |
+ # Advanced logic here |
+ return True |
+ |
+ mock_dao.DoSomething(Func(myParamValidator), true) |
+ """ |
+ |
+ def __init__(self, func): |
+ """Initialize. |
+ |
+ Args: |
+ func: callable that takes one parameter and returns a bool |
+ """ |
+ |
+ self._func = func |
+ |
+ def equals(self, rhs): |
+ """Test whether rhs passes the function test. |
+ |
+ rhs is passed into func. |
+ |
+ Args: |
+ rhs: any python object |
+ |
+ Returns: |
+ the result of func(rhs) |
+ """ |
+ |
+ return self._func(rhs) |
+ |
+ def __repr__(self): |
+ return str(self._func) |
+ |
+ |
+class IgnoreArg(Comparator): |
+ """Ignore an argument. |
+ |
+ This can be used when we don't care about an argument of a method call. |
+ |
+ Example: |
+ # Check if CastMagic is called with 3 as first arg and 'disappear' as third. |
+ mymock.CastMagic(3, IgnoreArg(), 'disappear') |
+ """ |
+ |
+ def equals(self, unused_rhs): |
+ """Ignores arguments and returns True. |
+ |
+ Args: |
+ unused_rhs: any python object |
+ |
+ Returns: |
+ always returns True |
+ """ |
+ |
+ return True |
+ |
+ def __repr__(self): |
+ return '<IgnoreArg>' |
+ |
+ |
+class MethodGroup(object): |
+ """Base class containing common behaviour for MethodGroups.""" |
+ |
+ def __init__(self, group_name): |
+ self._group_name = group_name |
+ |
+ def group_name(self): |
+ return self._group_name |
+ |
+ def __str__(self): |
+ return '<%s "%s">' % (self.__class__.__name__, self._group_name) |
+ |
+ def AddMethod(self, mock_method): |
+ raise NotImplementedError |
+ |
+ def MethodCalled(self, mock_method): |
+ raise NotImplementedError |
+ |
+ def IsSatisfied(self): |
+ raise NotImplementedError |
+ |
+class UnorderedGroup(MethodGroup): |
+ """UnorderedGroup holds a set of method calls that may occur in any order. |
+ |
+ This construct is helpful for non-deterministic events, such as iterating |
+ over the keys of a dict. |
+ """ |
+ |
+ def __init__(self, group_name): |
+ super(UnorderedGroup, self).__init__(group_name) |
+ self._methods = [] |
+ |
+ def AddMethod(self, mock_method): |
+ """Add a method to this group. |
+ |
+ Args: |
+ mock_method: A mock method to be added to this group. |
+ """ |
+ |
+ self._methods.append(mock_method) |
+ |
+ def MethodCalled(self, mock_method): |
+ """Remove a method call from the group. |
+ |
+ If the method is not in the set, an UnexpectedMethodCallError will be |
+ raised. |
+ |
+ Args: |
+ mock_method: a mock method that should be equal to a method in the group. |
+ |
+ Returns: |
+ The mock method from the group |
+ |
+ Raises: |
+ UnexpectedMethodCallError if the mock_method was not in the group. |
+ """ |
+ |
+ # Check to see if this method exists, and if so, remove it from the set |
+ # and return it. |
+ for method in self._methods: |
+ if method == mock_method: |
+ # Remove the called mock_method instead of the method in the group. |
+ # The called method will match any comparators when equality is checked |
+ # during removal. The method in the group could pass a comparator to |
+ # another comparator during the equality check. |
+ self._methods.remove(mock_method) |
+ |
+ # If this group is not empty, put it back at the head of the queue. |
+ if not self.IsSatisfied(): |
+ mock_method._call_queue.appendleft(self) |
+ |
+ return self, method |
+ |
+ raise UnexpectedMethodCallError(mock_method, self) |
+ |
+ def IsSatisfied(self): |
+ """Return True if there are not any methods in this group.""" |
+ |
+ return len(self._methods) == 0 |
+ |
+ |
+class MultipleTimesGroup(MethodGroup): |
+ """MultipleTimesGroup holds methods that may be called any number of times. |
+ |
+ Note: Each method must be called at least once. |
+ |
+ This is helpful, if you don't know or care how many times a method is called. |
+ """ |
+ |
+ def __init__(self, group_name): |
+ super(MultipleTimesGroup, self).__init__(group_name) |
+ self._methods = set() |
+ self._methods_left = set() |
+ |
+ def AddMethod(self, mock_method): |
+ """Add a method to this group. |
+ |
+ Args: |
+ mock_method: A mock method to be added to this group. |
+ """ |
+ |
+ self._methods.add(mock_method) |
+ self._methods_left.add(mock_method) |
+ |
+ def MethodCalled(self, mock_method): |
+ """Remove a method call from the group. |
+ |
+ If the method is not in the set, an UnexpectedMethodCallError will be |
+ raised. |
+ |
+ Args: |
+ mock_method: a mock method that should be equal to a method in the group. |
+ |
+ Returns: |
+ The mock method from the group |
+ |
+ Raises: |
+ UnexpectedMethodCallError if the mock_method was not in the group. |
+ """ |
+ |
+ # Check to see if this method exists, and if so add it to the set of |
+ # called methods. |
+ for method in self._methods: |
+ if method == mock_method: |
+ self._methods_left.discard(method) |
+ # Always put this group back on top of the queue, because we don't know |
+ # when we are done. |
+ mock_method._call_queue.appendleft(self) |
+ return self, method |
+ |
+ if self.IsSatisfied(): |
+ next_method = mock_method._PopNextMethod(); |
+ return next_method, None |
+ else: |
+ raise UnexpectedMethodCallError(mock_method, self) |
+ |
+ def IsSatisfied(self): |
+ """Return True if all methods in this group are called at least once.""" |
+ return len(self._methods_left) == 0 |
+ |
+ |
+class MoxMetaTestBase(type): |
+ """Metaclass to add mox cleanup and verification to every test. |
+ |
+ As the mox unit testing class is being constructed (MoxTestBase or a |
+ subclass), this metaclass will modify all test functions to call the |
+ CleanUpMox method of the test class after they finish. This means that |
+ unstubbing and verifying will happen for every test with no additional code, |
+ and any failures will result in test failures as opposed to errors. |
+ """ |
+ |
+ def __init__(cls, name, bases, d): |
+ type.__init__(cls, name, bases, d) |
+ |
+ # also get all the attributes from the base classes to account |
+ # for a case when test class is not the immediate child of MoxTestBase |
+ for base in bases: |
+ for attr_name in dir(base): |
+ d[attr_name] = getattr(base, attr_name) |
+ |
+ for func_name, func in d.items(): |
+ if func_name.startswith('test') and callable(func): |
+ setattr(cls, func_name, MoxMetaTestBase.CleanUpTest(cls, func)) |
+ |
+ @staticmethod |
+ def CleanUpTest(cls, func): |
+ """Adds Mox cleanup code to any MoxTestBase method. |
+ |
+ Always unsets stubs after a test. Will verify all mocks for tests that |
+ otherwise pass. |
+ |
+ Args: |
+ cls: MoxTestBase or subclass; the class whose test method we are altering. |
+ func: method; the method of the MoxTestBase test class we wish to alter. |
+ |
+ Returns: |
+ The modified method. |
+ """ |
+ def new_method(self, *args, **kwargs): |
+ mox_obj = getattr(self, 'mox', None) |
+ cleanup_mox = False |
+ if mox_obj and isinstance(mox_obj, Mox): |
+ cleanup_mox = True |
+ try: |
+ func(self, *args, **kwargs) |
+ finally: |
+ if cleanup_mox: |
+ mox_obj.UnsetStubs() |
+ if cleanup_mox: |
+ mox_obj.VerifyAll() |
+ new_method.__name__ = func.__name__ |
+ new_method.__doc__ = func.__doc__ |
+ new_method.__module__ = func.__module__ |
+ return new_method |
+ |
+ |
+class MoxTestBase(unittest.TestCase): |
+ """Convenience test class to make stubbing easier. |
+ |
+ Sets up a "mox" attribute which is an instance of Mox - any mox tests will |
+ want this. Also automatically unsets any stubs and verifies that all mock |
+ methods have been called at the end of each test, eliminating boilerplate |
+ code. |
+ """ |
+ |
+ __metaclass__ = MoxMetaTestBase |
+ |
+ def setUp(self): |
+ super(MoxTestBase, self).setUp() |
+ self.mox = Mox() |
Property changes on: tests\pymox\mox.py |
___________________________________________________________________ |
Added: svn:eol-style |
+ LF |