OLD | NEW |
| (Empty) |
1 #!/usr/bin/python2.4 | |
2 # | |
3 # Copyright 2008 Google Inc. | |
4 # | |
5 # Licensed under the Apache License, Version 2.0 (the "License"); | |
6 # you may not use this file except in compliance with the License. | |
7 # You may obtain a copy of the License at | |
8 # | |
9 # http://www.apache.org/licenses/LICENSE-2.0 | |
10 # | |
11 # Unless required by applicable law or agreed to in writing, software | |
12 # distributed under the License is distributed on an "AS IS" BASIS, | |
13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
14 # See the License for the specific language governing permissions and | |
15 # limitations under the License. | |
16 | |
17 """Mox, an object-mocking framework for Python. | |
18 | |
19 Mox works in the record-replay-verify paradigm. When you first create | |
20 a mock object, it is in record mode. You then programmatically set | |
21 the expected behavior of the mock object (what methods are to be | |
22 called on it, with what parameters, what they should return, and in | |
23 what order). | |
24 | |
25 Once you have set up the expected mock behavior, you put it in replay | |
26 mode. Now the mock responds to method calls just as you told it to. | |
27 If an unexpected method (or an expected method with unexpected | |
28 parameters) is called, then an exception will be raised. | |
29 | |
30 Once you are done interacting with the mock, you need to verify that | |
31 all the expected interactions occured. (Maybe your code exited | |
32 prematurely without calling some cleanup method!) The verify phase | |
33 ensures that every expected method was called; otherwise, an exception | |
34 will be raised. | |
35 | |
36 WARNING! Mock objects created by Mox are not thread-safe. If you are | |
37 call a mock in multiple threads, it should be guarded by a mutex. | |
38 | |
39 TODO(stevepm): Add the option to make mocks thread-safe! | |
40 | |
41 Suggested usage / workflow: | |
42 | |
43 # Create Mox factory | |
44 my_mox = Mox() | |
45 | |
46 # Create a mock data access object | |
47 mock_dao = my_mox.CreateMock(DAOClass) | |
48 | |
49 # Set up expected behavior | |
50 mock_dao.RetrievePersonWithIdentifier('1').AndReturn(person) | |
51 mock_dao.DeletePerson(person) | |
52 | |
53 # Put mocks in replay mode | |
54 my_mox.ReplayAll() | |
55 | |
56 # Inject mock object and run test | |
57 controller.SetDao(mock_dao) | |
58 controller.DeletePersonById('1') | |
59 | |
60 # Verify all methods were called as expected | |
61 my_mox.VerifyAll() | |
62 """ | |
63 | |
64 from collections import deque | |
65 import difflib | |
66 import inspect | |
67 import re | |
68 import types | |
69 import unittest | |
70 | |
71 import stubout | |
72 | |
73 class Error(AssertionError): | |
74 """Base exception for this module.""" | |
75 | |
76 pass | |
77 | |
78 | |
79 class ExpectedMethodCallsError(Error): | |
80 """Raised when Verify() is called before all expected methods have been called | |
81 """ | |
82 | |
83 def __init__(self, expected_methods): | |
84 """Init exception. | |
85 | |
86 Args: | |
87 # expected_methods: A sequence of MockMethod objects that should have been | |
88 # called. | |
89 expected_methods: [MockMethod] | |
90 | |
91 Raises: | |
92 ValueError: if expected_methods contains no methods. | |
93 """ | |
94 | |
95 if not expected_methods: | |
96 raise ValueError("There must be at least one expected method") | |
97 Error.__init__(self) | |
98 self._expected_methods = expected_methods | |
99 | |
100 def __str__(self): | |
101 calls = "\n".join(["%3d. %s" % (i, m) | |
102 for i, m in enumerate(self._expected_methods)]) | |
103 return "Verify: Expected methods never called:\n%s" % (calls,) | |
104 | |
105 | |
106 class UnexpectedMethodCallError(Error): | |
107 """Raised when an unexpected method is called. | |
108 | |
109 This can occur if a method is called with incorrect parameters, or out of the | |
110 specified order. | |
111 """ | |
112 | |
113 def __init__(self, unexpected_method, expected): | |
114 """Init exception. | |
115 | |
116 Args: | |
117 # unexpected_method: MockMethod that was called but was not at the head of | |
118 # the expected_method queue. | |
119 # expected: MockMethod or UnorderedGroup the method should have | |
120 # been in. | |
121 unexpected_method: MockMethod | |
122 expected: MockMethod or UnorderedGroup | |
123 """ | |
124 | |
125 Error.__init__(self) | |
126 if expected is None: | |
127 self._str = "Unexpected method call %s" % (unexpected_method,) | |
128 else: | |
129 differ = difflib.Differ() | |
130 diff = differ.compare(str(unexpected_method).splitlines(True), | |
131 str(expected).splitlines(True)) | |
132 self._str = ("Unexpected method call. unexpected:- expected:+\n%s" | |
133 % ("\n".join(diff),)) | |
134 | |
135 def __str__(self): | |
136 return self._str | |
137 | |
138 | |
139 class UnknownMethodCallError(Error): | |
140 """Raised if an unknown method is requested of the mock object.""" | |
141 | |
142 def __init__(self, unknown_method_name): | |
143 """Init exception. | |
144 | |
145 Args: | |
146 # unknown_method_name: Method call that is not part of the mocked class's | |
147 # public interface. | |
148 unknown_method_name: str | |
149 """ | |
150 | |
151 Error.__init__(self) | |
152 self._unknown_method_name = unknown_method_name | |
153 | |
154 def __str__(self): | |
155 return "Method called is not a member of the object: %s" % \ | |
156 self._unknown_method_name | |
157 | |
158 | |
159 class Mox(object): | |
160 """Mox: a factory for creating mock objects.""" | |
161 | |
162 # A list of types that should be stubbed out with MockObjects (as | |
163 # opposed to MockAnythings). | |
164 _USE_MOCK_OBJECT = [types.ClassType, types.InstanceType, types.ModuleType, | |
165 types.ObjectType, types.TypeType] | |
166 | |
167 def __init__(self): | |
168 """Initialize a new Mox.""" | |
169 | |
170 self._mock_objects = [] | |
171 self.stubs = stubout.StubOutForTesting() | |
172 | |
173 def CreateMock(self, class_to_mock): | |
174 """Create a new mock object. | |
175 | |
176 Args: | |
177 # class_to_mock: the class to be mocked | |
178 class_to_mock: class | |
179 | |
180 Returns: | |
181 MockObject that can be used as the class_to_mock would be. | |
182 """ | |
183 | |
184 new_mock = MockObject(class_to_mock) | |
185 self._mock_objects.append(new_mock) | |
186 return new_mock | |
187 | |
188 def CreateMockAnything(self, description=None): | |
189 """Create a mock that will accept any method calls. | |
190 | |
191 This does not enforce an interface. | |
192 | |
193 Args: | |
194 description: str. Optionally, a descriptive name for the mock object being | |
195 created, for debugging output purposes. | |
196 """ | |
197 new_mock = MockAnything(description=description) | |
198 self._mock_objects.append(new_mock) | |
199 return new_mock | |
200 | |
201 def ReplayAll(self): | |
202 """Set all mock objects to replay mode.""" | |
203 | |
204 for mock_obj in self._mock_objects: | |
205 mock_obj._Replay() | |
206 | |
207 | |
208 def VerifyAll(self): | |
209 """Call verify on all mock objects created.""" | |
210 | |
211 for mock_obj in self._mock_objects: | |
212 mock_obj._Verify() | |
213 | |
214 def ResetAll(self): | |
215 """Call reset on all mock objects. This does not unset stubs.""" | |
216 | |
217 for mock_obj in self._mock_objects: | |
218 mock_obj._Reset() | |
219 | |
220 def StubOutWithMock(self, obj, attr_name, use_mock_anything=False): | |
221 """Replace a method, attribute, etc. with a Mock. | |
222 | |
223 This will replace a class or module with a MockObject, and everything else | |
224 (method, function, etc) with a MockAnything. This can be overridden to | |
225 always use a MockAnything by setting use_mock_anything to True. | |
226 | |
227 Args: | |
228 obj: A Python object (class, module, instance, callable). | |
229 attr_name: str. The name of the attribute to replace with a mock. | |
230 use_mock_anything: bool. True if a MockAnything should be used regardless | |
231 of the type of attribute. | |
232 """ | |
233 | |
234 attr_to_replace = getattr(obj, attr_name) | |
235 | |
236 # Check for a MockAnything. This could cause confusing problems later on. | |
237 if attr_to_replace == MockAnything(): | |
238 raise TypeError('Cannot mock a MockAnything! Did you remember to ' | |
239 'call UnsetStubs in your previous test?') | |
240 | |
241 if type(attr_to_replace) in self._USE_MOCK_OBJECT and not use_mock_anything: | |
242 stub = self.CreateMock(attr_to_replace) | |
243 else: | |
244 stub = self.CreateMockAnything(description='Stub for %s' % attr_to_replace
) | |
245 | |
246 self.stubs.Set(obj, attr_name, stub) | |
247 | |
248 def UnsetStubs(self): | |
249 """Restore stubs to their original state.""" | |
250 | |
251 self.stubs.UnsetAll() | |
252 | |
253 def Replay(*args): | |
254 """Put mocks into Replay mode. | |
255 | |
256 Args: | |
257 # args is any number of mocks to put into replay mode. | |
258 """ | |
259 | |
260 for mock in args: | |
261 mock._Replay() | |
262 | |
263 | |
264 def Verify(*args): | |
265 """Verify mocks. | |
266 | |
267 Args: | |
268 # args is any number of mocks to be verified. | |
269 """ | |
270 | |
271 for mock in args: | |
272 mock._Verify() | |
273 | |
274 | |
275 def Reset(*args): | |
276 """Reset mocks. | |
277 | |
278 Args: | |
279 # args is any number of mocks to be reset. | |
280 """ | |
281 | |
282 for mock in args: | |
283 mock._Reset() | |
284 | |
285 | |
286 class MockAnything: | |
287 """A mock that can be used to mock anything. | |
288 | |
289 This is helpful for mocking classes that do not provide a public interface. | |
290 """ | |
291 | |
292 def __init__(self, description=None): | |
293 """Initialize a new MockAnything. | |
294 | |
295 Args: | |
296 description: str. Optionally, a descriptive name for the mock object being | |
297 created, for debugging output purposes. | |
298 """ | |
299 self._description = description | |
300 self._Reset() | |
301 | |
302 def __str__(self): | |
303 return "<MockAnything instance at %s>" % id(self) | |
304 | |
305 def __repr__(self): | |
306 return '<MockAnything instance>' | |
307 | |
308 def __getattr__(self, method_name): | |
309 """Intercept method calls on this object. | |
310 | |
311 A new MockMethod is returned that is aware of the MockAnything's | |
312 state (record or replay). The call will be recorded or replayed | |
313 by the MockMethod's __call__. | |
314 | |
315 Args: | |
316 # method name: the name of the method being called. | |
317 method_name: str | |
318 | |
319 Returns: | |
320 A new MockMethod aware of MockAnything's state (record or replay). | |
321 """ | |
322 | |
323 return self._CreateMockMethod(method_name) | |
324 | |
325 def _CreateMockMethod(self, method_name, method_to_mock=None): | |
326 """Create a new mock method call and return it. | |
327 | |
328 Args: | |
329 # method_name: the name of the method being called. | |
330 # method_to_mock: The actual method being mocked, used for introspection. | |
331 method_name: str | |
332 method_to_mock: a method object | |
333 | |
334 Returns: | |
335 A new MockMethod aware of MockAnything's state (record or replay). | |
336 """ | |
337 | |
338 return MockMethod(method_name, self._expected_calls_queue, | |
339 self._replay_mode, method_to_mock=method_to_mock, | |
340 description=self._description) | |
341 | |
342 def __nonzero__(self): | |
343 """Return 1 for nonzero so the mock can be used as a conditional.""" | |
344 | |
345 return 1 | |
346 | |
347 def __eq__(self, rhs): | |
348 """Provide custom logic to compare objects.""" | |
349 | |
350 return (isinstance(rhs, MockAnything) and | |
351 self._replay_mode == rhs._replay_mode and | |
352 self._expected_calls_queue == rhs._expected_calls_queue) | |
353 | |
354 def __ne__(self, rhs): | |
355 """Provide custom logic to compare objects.""" | |
356 | |
357 return not self == rhs | |
358 | |
359 def _Replay(self): | |
360 """Start replaying expected method calls.""" | |
361 | |
362 self._replay_mode = True | |
363 | |
364 def _Verify(self): | |
365 """Verify that all of the expected calls have been made. | |
366 | |
367 Raises: | |
368 ExpectedMethodCallsError: if there are still more method calls in the | |
369 expected queue. | |
370 """ | |
371 | |
372 # If the list of expected calls is not empty, raise an exception | |
373 if self._expected_calls_queue: | |
374 # The last MultipleTimesGroup is not popped from the queue. | |
375 if (len(self._expected_calls_queue) == 1 and | |
376 isinstance(self._expected_calls_queue[0], MultipleTimesGroup) and | |
377 self._expected_calls_queue[0].IsSatisfied()): | |
378 pass | |
379 else: | |
380 raise ExpectedMethodCallsError(self._expected_calls_queue) | |
381 | |
382 def _Reset(self): | |
383 """Reset the state of this mock to record mode with an empty queue.""" | |
384 | |
385 # Maintain a list of method calls we are expecting | |
386 self._expected_calls_queue = deque() | |
387 | |
388 # Make sure we are in setup mode, not replay mode | |
389 self._replay_mode = False | |
390 | |
391 | |
392 class MockObject(MockAnything, object): | |
393 """A mock object that simulates the public/protected interface of a class.""" | |
394 | |
395 def __init__(self, class_to_mock): | |
396 """Initialize a mock object. | |
397 | |
398 This determines the methods and properties of the class and stores them. | |
399 | |
400 Args: | |
401 # class_to_mock: class to be mocked | |
402 class_to_mock: class | |
403 """ | |
404 | |
405 # This is used to hack around the mixin/inheritance of MockAnything, which | |
406 # is not a proper object (it can be anything. :-) | |
407 MockAnything.__dict__['__init__'](self) | |
408 | |
409 # Get a list of all the public and special methods we should mock. | |
410 self._known_methods = set() | |
411 self._known_vars = set() | |
412 self._class_to_mock = class_to_mock | |
413 for method in dir(class_to_mock): | |
414 if callable(getattr(class_to_mock, method)): | |
415 self._known_methods.add(method) | |
416 else: | |
417 self._known_vars.add(method) | |
418 | |
419 def __getattr__(self, name): | |
420 """Intercept attribute request on this object. | |
421 | |
422 If the attribute is a public class variable, it will be returned and not | |
423 recorded as a call. | |
424 | |
425 If the attribute is not a variable, it is handled like a method | |
426 call. The method name is checked against the set of mockable | |
427 methods, and a new MockMethod is returned that is aware of the | |
428 MockObject's state (record or replay). The call will be recorded | |
429 or replayed by the MockMethod's __call__. | |
430 | |
431 Args: | |
432 # name: the name of the attribute being requested. | |
433 name: str | |
434 | |
435 Returns: | |
436 Either a class variable or a new MockMethod that is aware of the state | |
437 of the mock (record or replay). | |
438 | |
439 Raises: | |
440 UnknownMethodCallError if the MockObject does not mock the requested | |
441 method. | |
442 """ | |
443 | |
444 if name in self._known_vars: | |
445 return getattr(self._class_to_mock, name) | |
446 | |
447 if name in self._known_methods: | |
448 return self._CreateMockMethod( | |
449 name, | |
450 method_to_mock=getattr(self._class_to_mock, name)) | |
451 | |
452 raise UnknownMethodCallError(name) | |
453 | |
454 def __eq__(self, rhs): | |
455 """Provide custom logic to compare objects.""" | |
456 | |
457 return (isinstance(rhs, MockObject) and | |
458 self._class_to_mock == rhs._class_to_mock and | |
459 self._replay_mode == rhs._replay_mode and | |
460 self._expected_calls_queue == rhs._expected_calls_queue) | |
461 | |
462 def __setitem__(self, key, value): | |
463 """Provide custom logic for mocking classes that support item assignment. | |
464 | |
465 Args: | |
466 key: Key to set the value for. | |
467 value: Value to set. | |
468 | |
469 Returns: | |
470 Expected return value in replay mode. A MockMethod object for the | |
471 __setitem__ method that has already been called if not in replay mode. | |
472 | |
473 Raises: | |
474 TypeError if the underlying class does not support item assignment. | |
475 UnexpectedMethodCallError if the object does not expect the call to | |
476 __setitem__. | |
477 | |
478 """ | |
479 # Verify the class supports item assignment. | |
480 if '__setitem__' not in dir(self._class_to_mock): | |
481 raise TypeError('object does not support item assignment') | |
482 | |
483 # If we are in replay mode then simply call the mock __setitem__ method. | |
484 if self._replay_mode: | |
485 return MockMethod('__setitem__', self._expected_calls_queue, | |
486 self._replay_mode)(key, value) | |
487 | |
488 | |
489 # Otherwise, create a mock method __setitem__. | |
490 return self._CreateMockMethod('__setitem__')(key, value) | |
491 | |
492 def __getitem__(self, key): | |
493 """Provide custom logic for mocking classes that are subscriptable. | |
494 | |
495 Args: | |
496 key: Key to return the value for. | |
497 | |
498 Returns: | |
499 Expected return value in replay mode. A MockMethod object for the | |
500 __getitem__ method that has already been called if not in replay mode. | |
501 | |
502 Raises: | |
503 TypeError if the underlying class is not subscriptable. | |
504 UnexpectedMethodCallError if the object does not expect the call to | |
505 __getitem__. | |
506 | |
507 """ | |
508 # Verify the class supports item assignment. | |
509 if '__getitem__' not in dir(self._class_to_mock): | |
510 raise TypeError('unsubscriptable object') | |
511 | |
512 # If we are in replay mode then simply call the mock __getitem__ method. | |
513 if self._replay_mode: | |
514 return MockMethod('__getitem__', self._expected_calls_queue, | |
515 self._replay_mode)(key) | |
516 | |
517 | |
518 # Otherwise, create a mock method __getitem__. | |
519 return self._CreateMockMethod('__getitem__')(key) | |
520 | |
521 def __iter__(self): | |
522 """Provide custom logic for mocking classes that are iterable. | |
523 | |
524 Returns: | |
525 Expected return value in replay mode. A MockMethod object for the | |
526 __iter__ method that has already been called if not in replay mode. | |
527 | |
528 Raises: | |
529 TypeError if the underlying class is not iterable. | |
530 UnexpectedMethodCallError if the object does not expect the call to | |
531 __iter__. | |
532 | |
533 """ | |
534 methods = dir(self._class_to_mock) | |
535 | |
536 # Verify the class supports iteration. | |
537 if '__iter__' not in methods: | |
538 # If it doesn't have iter method and we are in replay method, then try to | |
539 # iterate using subscripts. | |
540 if '__getitem__' not in methods or not self._replay_mode: | |
541 raise TypeError('not iterable object') | |
542 else: | |
543 results = [] | |
544 index = 0 | |
545 try: | |
546 while True: | |
547 results.append(self[index]) | |
548 index += 1 | |
549 except IndexError: | |
550 return iter(results) | |
551 | |
552 # If we are in replay mode then simply call the mock __iter__ method. | |
553 if self._replay_mode: | |
554 return MockMethod('__iter__', self._expected_calls_queue, | |
555 self._replay_mode)() | |
556 | |
557 | |
558 # Otherwise, create a mock method __iter__. | |
559 return self._CreateMockMethod('__iter__')() | |
560 | |
561 | |
562 def __contains__(self, key): | |
563 """Provide custom logic for mocking classes that contain items. | |
564 | |
565 Args: | |
566 key: Key to look in container for. | |
567 | |
568 Returns: | |
569 Expected return value in replay mode. A MockMethod object for the | |
570 __contains__ method that has already been called if not in replay mode. | |
571 | |
572 Raises: | |
573 TypeError if the underlying class does not implement __contains__ | |
574 UnexpectedMethodCaller if the object does not expect the call to | |
575 __contains__. | |
576 | |
577 """ | |
578 contains = self._class_to_mock.__dict__.get('__contains__', None) | |
579 | |
580 if contains is None: | |
581 raise TypeError('unsubscriptable object') | |
582 | |
583 if self._replay_mode: | |
584 return MockMethod('__contains__', self._expected_calls_queue, | |
585 self._replay_mode)(key) | |
586 | |
587 return self._CreateMockMethod('__contains__')(key) | |
588 | |
589 def __call__(self, *params, **named_params): | |
590 """Provide custom logic for mocking classes that are callable.""" | |
591 | |
592 # Verify the class we are mocking is callable. | |
593 callable = hasattr(self._class_to_mock, '__call__') | |
594 if not callable: | |
595 raise TypeError('Not callable') | |
596 | |
597 # Because the call is happening directly on this object instead of a method, | |
598 # the call on the mock method is made right here | |
599 mock_method = self._CreateMockMethod('__call__') | |
600 return mock_method(*params, **named_params) | |
601 | |
602 @property | |
603 def __class__(self): | |
604 """Return the class that is being mocked.""" | |
605 | |
606 return self._class_to_mock | |
607 | |
608 | |
609 class MethodCallChecker(object): | |
610 """Ensures that methods are called correctly.""" | |
611 | |
612 _NEEDED, _DEFAULT, _GIVEN = range(3) | |
613 | |
614 def __init__(self, method): | |
615 """Creates a checker. | |
616 | |
617 Args: | |
618 # method: A method to check. | |
619 method: function | |
620 | |
621 Raises: | |
622 ValueError: method could not be inspected, so checks aren't possible. | |
623 Some methods and functions like built-ins can't be inspected. | |
624 """ | |
625 try: | |
626 self._args, varargs, varkw, defaults = inspect.getargspec(method) | |
627 except TypeError: | |
628 raise ValueError('Could not get argument specification for %r' | |
629 % (method,)) | |
630 if inspect.ismethod(method): | |
631 self._args = self._args[1:] # Skip 'self'. | |
632 self._method = method | |
633 | |
634 self._has_varargs = varargs is not None | |
635 self._has_varkw = varkw is not None | |
636 if defaults is None: | |
637 self._required_args = self._args | |
638 self._default_args = [] | |
639 else: | |
640 self._required_args = self._args[:-len(defaults)] | |
641 self._default_args = self._args[-len(defaults):] | |
642 | |
643 def _RecordArgumentGiven(self, arg_name, arg_status): | |
644 """Mark an argument as being given. | |
645 | |
646 Args: | |
647 # arg_name: The name of the argument to mark in arg_status. | |
648 # arg_status: Maps argument names to one of _NEEDED, _DEFAULT, _GIVEN. | |
649 arg_name: string | |
650 arg_status: dict | |
651 | |
652 Raises: | |
653 AttributeError: arg_name is already marked as _GIVEN. | |
654 """ | |
655 if arg_status.get(arg_name, None) == MethodCallChecker._GIVEN: | |
656 raise AttributeError('%s provided more than once' % (arg_name,)) | |
657 arg_status[arg_name] = MethodCallChecker._GIVEN | |
658 | |
659 def Check(self, params, named_params): | |
660 """Ensures that the parameters used while recording a call are valid. | |
661 | |
662 Args: | |
663 # params: A list of positional parameters. | |
664 # named_params: A dict of named parameters. | |
665 params: list | |
666 named_params: dict | |
667 | |
668 Raises: | |
669 AttributeError: the given parameters don't work with the given method. | |
670 """ | |
671 arg_status = dict((a, MethodCallChecker._NEEDED) | |
672 for a in self._required_args) | |
673 for arg in self._default_args: | |
674 arg_status[arg] = MethodCallChecker._DEFAULT | |
675 | |
676 # Check that each positional param is valid. | |
677 for i in range(len(params)): | |
678 try: | |
679 arg_name = self._args[i] | |
680 except IndexError: | |
681 if not self._has_varargs: | |
682 raise AttributeError('%s does not take %d or more positional ' | |
683 'arguments' % (self._method.__name__, i)) | |
684 else: | |
685 self._RecordArgumentGiven(arg_name, arg_status) | |
686 | |
687 # Check each keyword argument. | |
688 for arg_name in named_params: | |
689 if arg_name not in arg_status and not self._has_varkw: | |
690 raise AttributeError('%s is not expecting keyword argument %s' | |
691 % (self._method.__name__, arg_name)) | |
692 self._RecordArgumentGiven(arg_name, arg_status) | |
693 | |
694 # Ensure all the required arguments have been given. | |
695 still_needed = [k for k, v in arg_status.iteritems() | |
696 if v == MethodCallChecker._NEEDED] | |
697 if still_needed: | |
698 raise AttributeError('No values given for arguments %s' | |
699 % (' '.join(sorted(still_needed)))) | |
700 | |
701 | |
702 class MockMethod(object): | |
703 """Callable mock method. | |
704 | |
705 A MockMethod should act exactly like the method it mocks, accepting parameters | |
706 and returning a value, or throwing an exception (as specified). When this | |
707 method is called, it can optionally verify whether the called method (name and | |
708 signature) matches the expected method. | |
709 """ | |
710 | |
711 def __init__(self, method_name, call_queue, replay_mode, | |
712 method_to_mock=None, description=None): | |
713 """Construct a new mock method. | |
714 | |
715 Args: | |
716 # method_name: the name of the method | |
717 # call_queue: deque of calls, verify this call against the head, or add | |
718 # this call to the queue. | |
719 # replay_mode: False if we are recording, True if we are verifying calls | |
720 # against the call queue. | |
721 # method_to_mock: The actual method being mocked, used for introspection. | |
722 # description: optionally, a descriptive name for this method. Typically | |
723 # this is equal to the descriptive name of the method's class. | |
724 method_name: str | |
725 call_queue: list or deque | |
726 replay_mode: bool | |
727 method_to_mock: a method object | |
728 description: str or None | |
729 """ | |
730 | |
731 self._name = method_name | |
732 self._call_queue = call_queue | |
733 if not isinstance(call_queue, deque): | |
734 self._call_queue = deque(self._call_queue) | |
735 self._replay_mode = replay_mode | |
736 self._description = description | |
737 | |
738 self._params = None | |
739 self._named_params = None | |
740 self._return_value = None | |
741 self._exception = None | |
742 self._side_effects = None | |
743 | |
744 try: | |
745 self._checker = MethodCallChecker(method_to_mock) | |
746 except ValueError: | |
747 self._checker = None | |
748 | |
749 def __call__(self, *params, **named_params): | |
750 """Log parameters and return the specified return value. | |
751 | |
752 If the Mock(Anything/Object) associated with this call is in record mode, | |
753 this MockMethod will be pushed onto the expected call queue. If the mock | |
754 is in replay mode, this will pop a MockMethod off the top of the queue and | |
755 verify this call is equal to the expected call. | |
756 | |
757 Raises: | |
758 UnexpectedMethodCall if this call is supposed to match an expected method | |
759 call and it does not. | |
760 """ | |
761 | |
762 self._params = params | |
763 self._named_params = named_params | |
764 | |
765 if not self._replay_mode: | |
766 if self._checker is not None: | |
767 self._checker.Check(params, named_params) | |
768 self._call_queue.append(self) | |
769 return self | |
770 | |
771 expected_method = self._VerifyMethodCall() | |
772 | |
773 if expected_method._side_effects: | |
774 expected_method._side_effects(*params, **named_params) | |
775 | |
776 if expected_method._exception: | |
777 raise expected_method._exception | |
778 | |
779 return expected_method._return_value | |
780 | |
781 def __getattr__(self, name): | |
782 """Raise an AttributeError with a helpful message.""" | |
783 | |
784 raise AttributeError('MockMethod has no attribute "%s". ' | |
785 'Did you remember to put your mocks in replay mode?' % name) | |
786 | |
787 def __iter__(self): | |
788 """Raise a TypeError with a helpful message.""" | |
789 raise TypeError('MockMethod cannot be iterated. ' | |
790 'Did you remember to put your mocks in replay mode?') | |
791 | |
792 def next(self): | |
793 """Raise a TypeError with a helpful message.""" | |
794 raise TypeError('MockMethod cannot be iterated. ' | |
795 'Did you remember to put your mocks in replay mode?') | |
796 | |
797 def _PopNextMethod(self): | |
798 """Pop the next method from our call queue.""" | |
799 try: | |
800 return self._call_queue.popleft() | |
801 except IndexError: | |
802 raise UnexpectedMethodCallError(self, None) | |
803 | |
804 def _VerifyMethodCall(self): | |
805 """Verify the called method is expected. | |
806 | |
807 This can be an ordered method, or part of an unordered set. | |
808 | |
809 Returns: | |
810 The expected mock method. | |
811 | |
812 Raises: | |
813 UnexpectedMethodCall if the method called was not expected. | |
814 """ | |
815 | |
816 expected = self._PopNextMethod() | |
817 | |
818 # Loop here, because we might have a MethodGroup followed by another | |
819 # group. | |
820 while isinstance(expected, MethodGroup): | |
821 expected, method = expected.MethodCalled(self) | |
822 if method is not None: | |
823 return method | |
824 | |
825 # This is a mock method, so just check equality. | |
826 if expected != self: | |
827 raise UnexpectedMethodCallError(self, expected) | |
828 | |
829 return expected | |
830 | |
831 def __str__(self): | |
832 params = ', '.join( | |
833 [repr(p) for p in self._params or []] + | |
834 ['%s=%r' % x for x in sorted((self._named_params or {}).items())]) | |
835 full_desc = "%s(%s) -> %r" % (self._name, params, self._return_value) | |
836 if self._description: | |
837 full_desc = "%s.%s" % (self._description, full_desc) | |
838 return full_desc | |
839 | |
840 def __eq__(self, rhs): | |
841 """Test whether this MockMethod is equivalent to another MockMethod. | |
842 | |
843 Args: | |
844 # rhs: the right hand side of the test | |
845 rhs: MockMethod | |
846 """ | |
847 | |
848 return (isinstance(rhs, MockMethod) and | |
849 self._name == rhs._name and | |
850 self._params == rhs._params and | |
851 self._named_params == rhs._named_params) | |
852 | |
853 def __ne__(self, rhs): | |
854 """Test whether this MockMethod is not equivalent to another MockMethod. | |
855 | |
856 Args: | |
857 # rhs: the right hand side of the test | |
858 rhs: MockMethod | |
859 """ | |
860 | |
861 return not self == rhs | |
862 | |
863 def GetPossibleGroup(self): | |
864 """Returns a possible group from the end of the call queue or None if no | |
865 other methods are on the stack. | |
866 """ | |
867 | |
868 # Remove this method from the tail of the queue so we can add it to a group. | |
869 this_method = self._call_queue.pop() | |
870 assert this_method == self | |
871 | |
872 # Determine if the tail of the queue is a group, or just a regular ordered | |
873 # mock method. | |
874 group = None | |
875 try: | |
876 group = self._call_queue[-1] | |
877 except IndexError: | |
878 pass | |
879 | |
880 return group | |
881 | |
882 def _CheckAndCreateNewGroup(self, group_name, group_class): | |
883 """Checks if the last method (a possible group) is an instance of our | |
884 group_class. Adds the current method to this group or creates a new one. | |
885 | |
886 Args: | |
887 | |
888 group_name: the name of the group. | |
889 group_class: the class used to create instance of this new group | |
890 """ | |
891 group = self.GetPossibleGroup() | |
892 | |
893 # If this is a group, and it is the correct group, add the method. | |
894 if isinstance(group, group_class) and group.group_name() == group_name: | |
895 group.AddMethod(self) | |
896 return self | |
897 | |
898 # Create a new group and add the method. | |
899 new_group = group_class(group_name) | |
900 new_group.AddMethod(self) | |
901 self._call_queue.append(new_group) | |
902 return self | |
903 | |
904 def InAnyOrder(self, group_name="default"): | |
905 """Move this method into a group of unordered calls. | |
906 | |
907 A group of unordered calls must be defined together, and must be executed | |
908 in full before the next expected method can be called. There can be | |
909 multiple groups that are expected serially, if they are given | |
910 different group names. The same group name can be reused if there is a | |
911 standard method call, or a group with a different name, spliced between | |
912 usages. | |
913 | |
914 Args: | |
915 group_name: the name of the unordered group. | |
916 | |
917 Returns: | |
918 self | |
919 """ | |
920 return self._CheckAndCreateNewGroup(group_name, UnorderedGroup) | |
921 | |
922 def MultipleTimes(self, group_name="default"): | |
923 """Move this method into group of calls which may be called multiple times. | |
924 | |
925 A group of repeating calls must be defined together, and must be executed in | |
926 full before the next expected mehtod can be called. | |
927 | |
928 Args: | |
929 group_name: the name of the unordered group. | |
930 | |
931 Returns: | |
932 self | |
933 """ | |
934 return self._CheckAndCreateNewGroup(group_name, MultipleTimesGroup) | |
935 | |
936 def AndReturn(self, return_value): | |
937 """Set the value to return when this method is called. | |
938 | |
939 Args: | |
940 # return_value can be anything. | |
941 """ | |
942 | |
943 self._return_value = return_value | |
944 return return_value | |
945 | |
946 def AndRaise(self, exception): | |
947 """Set the exception to raise when this method is called. | |
948 | |
949 Args: | |
950 # exception: the exception to raise when this method is called. | |
951 exception: Exception | |
952 """ | |
953 | |
954 self._exception = exception | |
955 | |
956 def WithSideEffects(self, side_effects): | |
957 """Set the side effects that are simulated when this method is called. | |
958 | |
959 Args: | |
960 side_effects: A callable which modifies the parameters or other relevant | |
961 state which a given test case depends on. | |
962 | |
963 Returns: | |
964 Self for chaining with AndReturn and AndRaise. | |
965 """ | |
966 self._side_effects = side_effects | |
967 return self | |
968 | |
969 class Comparator: | |
970 """Base class for all Mox comparators. | |
971 | |
972 A Comparator can be used as a parameter to a mocked method when the exact | |
973 value is not known. For example, the code you are testing might build up a | |
974 long SQL string that is passed to your mock DAO. You're only interested that | |
975 the IN clause contains the proper primary keys, so you can set your mock | |
976 up as follows: | |
977 | |
978 mock_dao.RunQuery(StrContains('IN (1, 2, 4, 5)')).AndReturn(mock_result) | |
979 | |
980 Now whatever query is passed in must contain the string 'IN (1, 2, 4, 5)'. | |
981 | |
982 A Comparator may replace one or more parameters, for example: | |
983 # return at most 10 rows | |
984 mock_dao.RunQuery(StrContains('SELECT'), 10) | |
985 | |
986 or | |
987 | |
988 # Return some non-deterministic number of rows | |
989 mock_dao.RunQuery(StrContains('SELECT'), IsA(int)) | |
990 """ | |
991 | |
992 def equals(self, rhs): | |
993 """Special equals method that all comparators must implement. | |
994 | |
995 Args: | |
996 rhs: any python object | |
997 """ | |
998 | |
999 raise NotImplementedError, 'method must be implemented by a subclass.' | |
1000 | |
1001 def __eq__(self, rhs): | |
1002 return self.equals(rhs) | |
1003 | |
1004 def __ne__(self, rhs): | |
1005 return not self.equals(rhs) | |
1006 | |
1007 | |
1008 class IsA(Comparator): | |
1009 """This class wraps a basic Python type or class. It is used to verify | |
1010 that a parameter is of the given type or class. | |
1011 | |
1012 Example: | |
1013 mock_dao.Connect(IsA(DbConnectInfo)) | |
1014 """ | |
1015 | |
1016 def __init__(self, class_name): | |
1017 """Initialize IsA | |
1018 | |
1019 Args: | |
1020 class_name: basic python type or a class | |
1021 """ | |
1022 | |
1023 self._class_name = class_name | |
1024 | |
1025 def equals(self, rhs): | |
1026 """Check to see if the RHS is an instance of class_name. | |
1027 | |
1028 Args: | |
1029 # rhs: the right hand side of the test | |
1030 rhs: object | |
1031 | |
1032 Returns: | |
1033 bool | |
1034 """ | |
1035 | |
1036 try: | |
1037 return isinstance(rhs, self._class_name) | |
1038 except TypeError: | |
1039 # Check raw types if there was a type error. This is helpful for | |
1040 # things like cStringIO.StringIO. | |
1041 return type(rhs) == type(self._class_name) | |
1042 | |
1043 def __repr__(self): | |
1044 return str(self._class_name) | |
1045 | |
1046 class IsAlmost(Comparator): | |
1047 """Comparison class used to check whether a parameter is nearly equal | |
1048 to a given value. Generally useful for floating point numbers. | |
1049 | |
1050 Example mock_dao.SetTimeout((IsAlmost(3.9))) | |
1051 """ | |
1052 | |
1053 def __init__(self, float_value, places=7): | |
1054 """Initialize IsAlmost. | |
1055 | |
1056 Args: | |
1057 float_value: The value for making the comparison. | |
1058 places: The number of decimal places to round to. | |
1059 """ | |
1060 | |
1061 self._float_value = float_value | |
1062 self._places = places | |
1063 | |
1064 def equals(self, rhs): | |
1065 """Check to see if RHS is almost equal to float_value | |
1066 | |
1067 Args: | |
1068 rhs: the value to compare to float_value | |
1069 | |
1070 Returns: | |
1071 bool | |
1072 """ | |
1073 | |
1074 try: | |
1075 return round(rhs-self._float_value, self._places) == 0 | |
1076 except TypeError: | |
1077 # This is probably because either float_value or rhs is not a number. | |
1078 return False | |
1079 | |
1080 def __repr__(self): | |
1081 return str(self._float_value) | |
1082 | |
1083 class StrContains(Comparator): | |
1084 """Comparison class used to check whether a substring exists in a | |
1085 string parameter. This can be useful in mocking a database with SQL | |
1086 passed in as a string parameter, for example. | |
1087 | |
1088 Example: | |
1089 mock_dao.RunQuery(StrContains('IN (1, 2, 4, 5)')).AndReturn(mock_result) | |
1090 """ | |
1091 | |
1092 def __init__(self, search_string): | |
1093 """Initialize. | |
1094 | |
1095 Args: | |
1096 # search_string: the string you are searching for | |
1097 search_string: str | |
1098 """ | |
1099 | |
1100 self._search_string = search_string | |
1101 | |
1102 def equals(self, rhs): | |
1103 """Check to see if the search_string is contained in the rhs string. | |
1104 | |
1105 Args: | |
1106 # rhs: the right hand side of the test | |
1107 rhs: object | |
1108 | |
1109 Returns: | |
1110 bool | |
1111 """ | |
1112 | |
1113 try: | |
1114 return rhs.find(self._search_string) > -1 | |
1115 except Exception: | |
1116 return False | |
1117 | |
1118 def __repr__(self): | |
1119 return '<str containing \'%s\'>' % self._search_string | |
1120 | |
1121 | |
1122 class Regex(Comparator): | |
1123 """Checks if a string matches a regular expression. | |
1124 | |
1125 This uses a given regular expression to determine equality. | |
1126 """ | |
1127 | |
1128 def __init__(self, pattern, flags=0): | |
1129 """Initialize. | |
1130 | |
1131 Args: | |
1132 # pattern is the regular expression to search for | |
1133 pattern: str | |
1134 # flags passed to re.compile function as the second argument | |
1135 flags: int | |
1136 """ | |
1137 | |
1138 self.regex = re.compile(pattern, flags=flags) | |
1139 | |
1140 def equals(self, rhs): | |
1141 """Check to see if rhs matches regular expression pattern. | |
1142 | |
1143 Returns: | |
1144 bool | |
1145 """ | |
1146 | |
1147 return self.regex.search(rhs) is not None | |
1148 | |
1149 def __repr__(self): | |
1150 s = '<regular expression \'%s\'' % self.regex.pattern | |
1151 if self.regex.flags: | |
1152 s += ', flags=%d' % self.regex.flags | |
1153 s += '>' | |
1154 return s | |
1155 | |
1156 | |
1157 class In(Comparator): | |
1158 """Checks whether an item (or key) is in a list (or dict) parameter. | |
1159 | |
1160 Example: | |
1161 mock_dao.GetUsersInfo(In('expectedUserName')).AndReturn(mock_result) | |
1162 """ | |
1163 | |
1164 def __init__(self, key): | |
1165 """Initialize. | |
1166 | |
1167 Args: | |
1168 # key is any thing that could be in a list or a key in a dict | |
1169 """ | |
1170 | |
1171 self._key = key | |
1172 | |
1173 def equals(self, rhs): | |
1174 """Check to see whether key is in rhs. | |
1175 | |
1176 Args: | |
1177 rhs: dict | |
1178 | |
1179 Returns: | |
1180 bool | |
1181 """ | |
1182 | |
1183 return self._key in rhs | |
1184 | |
1185 def __repr__(self): | |
1186 return '<sequence or map containing \'%s\'>' % self._key | |
1187 | |
1188 | |
1189 class Not(Comparator): | |
1190 """Checks whether a predicates is False. | |
1191 | |
1192 Example: | |
1193 mock_dao.UpdateUsers(Not(ContainsKeyValue('stevepm', stevepm_user_info))) | |
1194 """ | |
1195 | |
1196 def __init__(self, predicate): | |
1197 """Initialize. | |
1198 | |
1199 Args: | |
1200 # predicate: a Comparator instance. | |
1201 """ | |
1202 | |
1203 assert isinstance(predicate, Comparator), ("predicate %r must be a" | |
1204 " Comparator." % predicate) | |
1205 self._predicate = predicate | |
1206 | |
1207 def equals(self, rhs): | |
1208 """Check to see whether the predicate is False. | |
1209 | |
1210 Args: | |
1211 rhs: A value that will be given in argument of the predicate. | |
1212 | |
1213 Returns: | |
1214 bool | |
1215 """ | |
1216 | |
1217 return not self._predicate.equals(rhs) | |
1218 | |
1219 def __repr__(self): | |
1220 return '<not \'%s\'>' % self._predicate | |
1221 | |
1222 | |
1223 class ContainsKeyValue(Comparator): | |
1224 """Checks whether a key/value pair is in a dict parameter. | |
1225 | |
1226 Example: | |
1227 mock_dao.UpdateUsers(ContainsKeyValue('stevepm', stevepm_user_info)) | |
1228 """ | |
1229 | |
1230 def __init__(self, key, value): | |
1231 """Initialize. | |
1232 | |
1233 Args: | |
1234 # key: a key in a dict | |
1235 # value: the corresponding value | |
1236 """ | |
1237 | |
1238 self._key = key | |
1239 self._value = value | |
1240 | |
1241 def equals(self, rhs): | |
1242 """Check whether the given key/value pair is in the rhs dict. | |
1243 | |
1244 Returns: | |
1245 bool | |
1246 """ | |
1247 | |
1248 try: | |
1249 return rhs[self._key] == self._value | |
1250 except Exception: | |
1251 return False | |
1252 | |
1253 def __repr__(self): | |
1254 return '<map containing the entry \'%s: %s\'>' % (self._key, self._value) | |
1255 | |
1256 | |
1257 class SameElementsAs(Comparator): | |
1258 """Checks whether iterables contain the same elements (ignoring order). | |
1259 | |
1260 Example: | |
1261 mock_dao.ProcessUsers(SameElementsAs('stevepm', 'salomaki')) | |
1262 """ | |
1263 | |
1264 def __init__(self, expected_seq): | |
1265 """Initialize. | |
1266 | |
1267 Args: | |
1268 expected_seq: a sequence | |
1269 """ | |
1270 | |
1271 self._expected_seq = expected_seq | |
1272 | |
1273 def equals(self, actual_seq): | |
1274 """Check to see whether actual_seq has same elements as expected_seq. | |
1275 | |
1276 Args: | |
1277 actual_seq: sequence | |
1278 | |
1279 Returns: | |
1280 bool | |
1281 """ | |
1282 | |
1283 try: | |
1284 expected = dict([(element, None) for element in self._expected_seq]) | |
1285 actual = dict([(element, None) for element in actual_seq]) | |
1286 except TypeError: | |
1287 # Fall back to slower list-compare if any of the objects are unhashable. | |
1288 expected = list(self._expected_seq) | |
1289 actual = list(actual_seq) | |
1290 expected.sort() | |
1291 actual.sort() | |
1292 return expected == actual | |
1293 | |
1294 def __repr__(self): | |
1295 return '<sequence with same elements as \'%s\'>' % self._expected_seq | |
1296 | |
1297 | |
1298 class And(Comparator): | |
1299 """Evaluates one or more Comparators on RHS and returns an AND of the results. | |
1300 """ | |
1301 | |
1302 def __init__(self, *args): | |
1303 """Initialize. | |
1304 | |
1305 Args: | |
1306 *args: One or more Comparator | |
1307 """ | |
1308 | |
1309 self._comparators = args | |
1310 | |
1311 def equals(self, rhs): | |
1312 """Checks whether all Comparators are equal to rhs. | |
1313 | |
1314 Args: | |
1315 # rhs: can be anything | |
1316 | |
1317 Returns: | |
1318 bool | |
1319 """ | |
1320 | |
1321 for comparator in self._comparators: | |
1322 if not comparator.equals(rhs): | |
1323 return False | |
1324 | |
1325 return True | |
1326 | |
1327 def __repr__(self): | |
1328 return '<AND %s>' % str(self._comparators) | |
1329 | |
1330 | |
1331 class Or(Comparator): | |
1332 """Evaluates one or more Comparators on RHS and returns an OR of the results. | |
1333 """ | |
1334 | |
1335 def __init__(self, *args): | |
1336 """Initialize. | |
1337 | |
1338 Args: | |
1339 *args: One or more Mox comparators | |
1340 """ | |
1341 | |
1342 self._comparators = args | |
1343 | |
1344 def equals(self, rhs): | |
1345 """Checks whether any Comparator is equal to rhs. | |
1346 | |
1347 Args: | |
1348 # rhs: can be anything | |
1349 | |
1350 Returns: | |
1351 bool | |
1352 """ | |
1353 | |
1354 for comparator in self._comparators: | |
1355 if comparator.equals(rhs): | |
1356 return True | |
1357 | |
1358 return False | |
1359 | |
1360 def __repr__(self): | |
1361 return '<OR %s>' % str(self._comparators) | |
1362 | |
1363 | |
1364 class Func(Comparator): | |
1365 """Call a function that should verify the parameter passed in is correct. | |
1366 | |
1367 You may need the ability to perform more advanced operations on the parameter | |
1368 in order to validate it. You can use this to have a callable validate any | |
1369 parameter. The callable should return either True or False. | |
1370 | |
1371 | |
1372 Example: | |
1373 | |
1374 def myParamValidator(param): | |
1375 # Advanced logic here | |
1376 return True | |
1377 | |
1378 mock_dao.DoSomething(Func(myParamValidator), true) | |
1379 """ | |
1380 | |
1381 def __init__(self, func): | |
1382 """Initialize. | |
1383 | |
1384 Args: | |
1385 func: callable that takes one parameter and returns a bool | |
1386 """ | |
1387 | |
1388 self._func = func | |
1389 | |
1390 def equals(self, rhs): | |
1391 """Test whether rhs passes the function test. | |
1392 | |
1393 rhs is passed into func. | |
1394 | |
1395 Args: | |
1396 rhs: any python object | |
1397 | |
1398 Returns: | |
1399 the result of func(rhs) | |
1400 """ | |
1401 | |
1402 return self._func(rhs) | |
1403 | |
1404 def __repr__(self): | |
1405 return str(self._func) | |
1406 | |
1407 | |
1408 class IgnoreArg(Comparator): | |
1409 """Ignore an argument. | |
1410 | |
1411 This can be used when we don't care about an argument of a method call. | |
1412 | |
1413 Example: | |
1414 # Check if CastMagic is called with 3 as first arg and 'disappear' as third. | |
1415 mymock.CastMagic(3, IgnoreArg(), 'disappear') | |
1416 """ | |
1417 | |
1418 def equals(self, unused_rhs): | |
1419 """Ignores arguments and returns True. | |
1420 | |
1421 Args: | |
1422 unused_rhs: any python object | |
1423 | |
1424 Returns: | |
1425 always returns True | |
1426 """ | |
1427 | |
1428 return True | |
1429 | |
1430 def __repr__(self): | |
1431 return '<IgnoreArg>' | |
1432 | |
1433 | |
1434 class MethodGroup(object): | |
1435 """Base class containing common behaviour for MethodGroups.""" | |
1436 | |
1437 def __init__(self, group_name): | |
1438 self._group_name = group_name | |
1439 | |
1440 def group_name(self): | |
1441 return self._group_name | |
1442 | |
1443 def __str__(self): | |
1444 return '<%s "%s">' % (self.__class__.__name__, self._group_name) | |
1445 | |
1446 def AddMethod(self, mock_method): | |
1447 raise NotImplementedError | |
1448 | |
1449 def MethodCalled(self, mock_method): | |
1450 raise NotImplementedError | |
1451 | |
1452 def IsSatisfied(self): | |
1453 raise NotImplementedError | |
1454 | |
1455 class UnorderedGroup(MethodGroup): | |
1456 """UnorderedGroup holds a set of method calls that may occur in any order. | |
1457 | |
1458 This construct is helpful for non-deterministic events, such as iterating | |
1459 over the keys of a dict. | |
1460 """ | |
1461 | |
1462 def __init__(self, group_name): | |
1463 super(UnorderedGroup, self).__init__(group_name) | |
1464 self._methods = [] | |
1465 | |
1466 def AddMethod(self, mock_method): | |
1467 """Add a method to this group. | |
1468 | |
1469 Args: | |
1470 mock_method: A mock method to be added to this group. | |
1471 """ | |
1472 | |
1473 self._methods.append(mock_method) | |
1474 | |
1475 def MethodCalled(self, mock_method): | |
1476 """Remove a method call from the group. | |
1477 | |
1478 If the method is not in the set, an UnexpectedMethodCallError will be | |
1479 raised. | |
1480 | |
1481 Args: | |
1482 mock_method: a mock method that should be equal to a method in the group. | |
1483 | |
1484 Returns: | |
1485 The mock method from the group | |
1486 | |
1487 Raises: | |
1488 UnexpectedMethodCallError if the mock_method was not in the group. | |
1489 """ | |
1490 | |
1491 # Check to see if this method exists, and if so, remove it from the set | |
1492 # and return it. | |
1493 for method in self._methods: | |
1494 if method == mock_method: | |
1495 # Remove the called mock_method instead of the method in the group. | |
1496 # The called method will match any comparators when equality is checked | |
1497 # during removal. The method in the group could pass a comparator to | |
1498 # another comparator during the equality check. | |
1499 self._methods.remove(mock_method) | |
1500 | |
1501 # If this group is not empty, put it back at the head of the queue. | |
1502 if not self.IsSatisfied(): | |
1503 mock_method._call_queue.appendleft(self) | |
1504 | |
1505 return self, method | |
1506 | |
1507 raise UnexpectedMethodCallError(mock_method, self) | |
1508 | |
1509 def IsSatisfied(self): | |
1510 """Return True if there are not any methods in this group.""" | |
1511 | |
1512 return len(self._methods) == 0 | |
1513 | |
1514 | |
1515 class MultipleTimesGroup(MethodGroup): | |
1516 """MultipleTimesGroup holds methods that may be called any number of times. | |
1517 | |
1518 Note: Each method must be called at least once. | |
1519 | |
1520 This is helpful, if you don't know or care how many times a method is called. | |
1521 """ | |
1522 | |
1523 def __init__(self, group_name): | |
1524 super(MultipleTimesGroup, self).__init__(group_name) | |
1525 self._methods = set() | |
1526 self._methods_left = set() | |
1527 | |
1528 def AddMethod(self, mock_method): | |
1529 """Add a method to this group. | |
1530 | |
1531 Args: | |
1532 mock_method: A mock method to be added to this group. | |
1533 """ | |
1534 | |
1535 self._methods.add(mock_method) | |
1536 self._methods_left.add(mock_method) | |
1537 | |
1538 def MethodCalled(self, mock_method): | |
1539 """Remove a method call from the group. | |
1540 | |
1541 If the method is not in the set, an UnexpectedMethodCallError will be | |
1542 raised. | |
1543 | |
1544 Args: | |
1545 mock_method: a mock method that should be equal to a method in the group. | |
1546 | |
1547 Returns: | |
1548 The mock method from the group | |
1549 | |
1550 Raises: | |
1551 UnexpectedMethodCallError if the mock_method was not in the group. | |
1552 """ | |
1553 | |
1554 # Check to see if this method exists, and if so add it to the set of | |
1555 # called methods. | |
1556 for method in self._methods: | |
1557 if method == mock_method: | |
1558 self._methods_left.discard(method) | |
1559 # Always put this group back on top of the queue, because we don't know | |
1560 # when we are done. | |
1561 mock_method._call_queue.appendleft(self) | |
1562 return self, method | |
1563 | |
1564 if self.IsSatisfied(): | |
1565 next_method = mock_method._PopNextMethod(); | |
1566 return next_method, None | |
1567 else: | |
1568 raise UnexpectedMethodCallError(mock_method, self) | |
1569 | |
1570 def IsSatisfied(self): | |
1571 """Return True if all methods in this group are called at least once.""" | |
1572 return len(self._methods_left) == 0 | |
1573 | |
1574 | |
1575 class MoxMetaTestBase(type): | |
1576 """Metaclass to add mox cleanup and verification to every test. | |
1577 | |
1578 As the mox unit testing class is being constructed (MoxTestBase or a | |
1579 subclass), this metaclass will modify all test functions to call the | |
1580 CleanUpMox method of the test class after they finish. This means that | |
1581 unstubbing and verifying will happen for every test with no additional code, | |
1582 and any failures will result in test failures as opposed to errors. | |
1583 """ | |
1584 | |
1585 def __init__(cls, name, bases, d): | |
1586 type.__init__(cls, name, bases, d) | |
1587 | |
1588 # also get all the attributes from the base classes to account | |
1589 # for a case when test class is not the immediate child of MoxTestBase | |
1590 for base in bases: | |
1591 for attr_name in dir(base): | |
1592 d[attr_name] = getattr(base, attr_name) | |
1593 | |
1594 for func_name, func in d.items(): | |
1595 if func_name.startswith('test') and callable(func): | |
1596 setattr(cls, func_name, MoxMetaTestBase.CleanUpTest(cls, func)) | |
1597 | |
1598 @staticmethod | |
1599 def CleanUpTest(cls, func): | |
1600 """Adds Mox cleanup code to any MoxTestBase method. | |
1601 | |
1602 Always unsets stubs after a test. Will verify all mocks for tests that | |
1603 otherwise pass. | |
1604 | |
1605 Args: | |
1606 cls: MoxTestBase or subclass; the class whose test method we are altering. | |
1607 func: method; the method of the MoxTestBase test class we wish to alter. | |
1608 | |
1609 Returns: | |
1610 The modified method. | |
1611 """ | |
1612 def new_method(self, *args, **kwargs): | |
1613 mox_obj = getattr(self, 'mox', None) | |
1614 cleanup_mox = False | |
1615 if mox_obj and isinstance(mox_obj, Mox): | |
1616 cleanup_mox = True | |
1617 try: | |
1618 func(self, *args, **kwargs) | |
1619 finally: | |
1620 if cleanup_mox: | |
1621 mox_obj.UnsetStubs() | |
1622 if cleanup_mox: | |
1623 mox_obj.VerifyAll() | |
1624 new_method.__name__ = func.__name__ | |
1625 new_method.__doc__ = func.__doc__ | |
1626 new_method.__module__ = func.__module__ | |
1627 return new_method | |
1628 | |
1629 | |
1630 class MoxTestBase(unittest.TestCase): | |
1631 """Convenience test class to make stubbing easier. | |
1632 | |
1633 Sets up a "mox" attribute which is an instance of Mox - any mox tests will | |
1634 want this. Also automatically unsets any stubs and verifies that all mock | |
1635 methods have been called at the end of each test, eliminating boilerplate | |
1636 code. | |
1637 """ | |
1638 | |
1639 __metaclass__ = MoxMetaTestBase | |
1640 | |
1641 def setUp(self): | |
1642 super(MoxTestBase, self).setUp() | |
1643 self.mox = Mox() | |
OLD | NEW |