| OLD | NEW |
| (Empty) |
| 1 # Copyright (c) 2012 The Chromium Authors. All rights reserved. | |
| 2 # Use of this source code is governed by a BSD-style license that can be | |
| 3 # found in the LICENSE file. | |
| 4 """A very very simple mock object harness.""" | |
| 5 | |
| 6 DONT_CARE = '' | |
| 7 | |
| 8 class MockFunctionCall(object): | |
| 9 def __init__(self, name): | |
| 10 self.name = name | |
| 11 self.args = tuple() | |
| 12 self.return_value = None | |
| 13 self.when_called_handlers = [] | |
| 14 | |
| 15 def WithArgs(self, *args): | |
| 16 self.args = args | |
| 17 return self | |
| 18 | |
| 19 def WillReturn(self, value): | |
| 20 self.return_value = value | |
| 21 return self | |
| 22 | |
| 23 def WhenCalled(self, handler): | |
| 24 self.when_called_handlers.append(handler) | |
| 25 | |
| 26 def VerifyEquals(self, got): | |
| 27 if self.name != got.name: | |
| 28 raise Exception('Self %s, got %s' % (repr(self), repr(got))) | |
| 29 if len(self.args) != len(got.args): | |
| 30 raise Exception('Self %s, got %s' % (repr(self), repr(got))) | |
| 31 for i in range(len(self.args)): | |
| 32 self_a = self.args[i] | |
| 33 got_a = got.args[i] | |
| 34 if self_a == DONT_CARE: | |
| 35 continue | |
| 36 if self_a != got_a: | |
| 37 raise Exception('Self %s, got %s' % (repr(self), repr(got))) | |
| 38 | |
| 39 def __repr__(self): | |
| 40 def arg_to_text(a): | |
| 41 if a == DONT_CARE: | |
| 42 return '_' | |
| 43 return repr(a) | |
| 44 args_text = ', '.join([arg_to_text(a) for a in self.args]) | |
| 45 if self.return_value in (None, DONT_CARE): | |
| 46 return '%s(%s)' % (self.name, args_text) | |
| 47 return '%s(%s)->%s' % (self.name, args_text, repr(self.return_value)) | |
| 48 | |
| 49 class MockTrace(object): | |
| 50 def __init__(self): | |
| 51 self.expected_calls = [] | |
| 52 self.next_call_index = 0 | |
| 53 | |
| 54 class MockObject(object): | |
| 55 def __init__(self, parent_mock = None): | |
| 56 if parent_mock: | |
| 57 self._trace = parent_mock._trace # pylint: disable=W0212 | |
| 58 else: | |
| 59 self._trace = MockTrace() | |
| 60 | |
| 61 def __setattr__(self, name, value): | |
| 62 if (not hasattr(self, '_trace') or | |
| 63 hasattr(value, 'is_hook')): | |
| 64 object.__setattr__(self, name, value) | |
| 65 return | |
| 66 assert isinstance(value, MockObject) | |
| 67 object.__setattr__(self, name, value) | |
| 68 | |
| 69 def ExpectCall(self, func_name, *args): | |
| 70 assert self._trace.next_call_index == 0 | |
| 71 if not hasattr(self, func_name): | |
| 72 self._install_hook(func_name) | |
| 73 | |
| 74 call = MockFunctionCall(func_name) | |
| 75 self._trace.expected_calls.append(call) | |
| 76 call.WithArgs(*args) | |
| 77 return call | |
| 78 | |
| 79 def _install_hook(self, func_name): | |
| 80 def handler(*args): | |
| 81 got_call = MockFunctionCall( | |
| 82 func_name).WithArgs(*args).WillReturn(DONT_CARE) | |
| 83 if self._trace.next_call_index >= len(self._trace.expected_calls): | |
| 84 raise Exception( | |
| 85 'Call to %s was not expected, at end of programmed trace.' % | |
| 86 repr(got_call)) | |
| 87 expected_call = self._trace.expected_calls[ | |
| 88 self._trace.next_call_index] | |
| 89 expected_call.VerifyEquals(got_call) | |
| 90 self._trace.next_call_index += 1 | |
| 91 for h in expected_call.when_called_handlers: | |
| 92 h(*args) | |
| 93 return expected_call.return_value | |
| 94 handler.is_hook = True | |
| 95 setattr(self, func_name, handler) | |
| 96 | |
| 97 | |
| 98 | |
| OLD | NEW |