OLD | NEW |
| (Empty) |
1 #!/usr/bin/env python | |
2 # Copyright 2014 The Chromium Authors. All rights reserved. | |
3 # Use of this source code is governed by a BSD-style license that can be | |
4 # found in the LICENSE file. | |
5 | |
6 """ | |
7 Collection of unit tests for 'infra.libs.memoize' library. | |
8 """ | |
9 | |
10 import unittest | |
11 | |
12 from infra_libs import memoize | |
13 | |
14 class MemoTestCase(unittest.TestCase): | |
15 | |
16 def setUp(self): | |
17 self.tagged = set() | |
18 | |
19 def tag(self, tag=None): | |
20 self.assertNotIn(tag, self.tagged) | |
21 self.tagged.add(tag) | |
22 | |
23 def assertTagged(self, *tags): | |
24 if len(tags) == 0: | |
25 tags = (None,) | |
26 self.assertEqual(set(tags), self.tagged) | |
27 | |
28 def clearTagged(self): | |
29 self.tagged.clear() | |
30 | |
31 | |
32 class FunctionTestCase(MemoTestCase): | |
33 | |
34 def testFuncNoArgs(self): | |
35 @memoize.memo() | |
36 def func(): | |
37 self.tag() | |
38 return 'foo' | |
39 | |
40 for _ in xrange(10): | |
41 self.assertEqual(func(), 'foo') | |
42 self.assertTagged() | |
43 | |
44 def testFuncAllArgs(self): | |
45 @memoize.memo() | |
46 def func(a, b): | |
47 self.tag((a, b)) | |
48 return a + b | |
49 | |
50 # Execute multiple rounds of two unique function executions. | |
51 for _ in xrange(10): | |
52 self.assertEqual(func(1, 2), 3) | |
53 self.assertEqual(func(3, 4), 7) | |
54 self.assertTagged( | |
55 (1, 2), | |
56 (3, 4), | |
57 ) | |
58 | |
59 def testFuncIgnoreArgs(self): | |
60 @memoize.memo(ignore=('b')) | |
61 def func(a, b): | |
62 self.tag(a) | |
63 return a + b | |
64 | |
65 # Execute multiple rounds of two unique function executions. | |
66 for _ in xrange(10): | |
67 self.assertEqual(func(1, 1), 2) | |
68 self.assertEqual(func(1, 2), 2) | |
69 self.assertEqual(func(2, 1), 3) | |
70 self.assertEqual(func(2, 2), 3) | |
71 self.assertTagged( | |
72 1, | |
73 2, | |
74 ) | |
75 | |
76 def testOldClassMethod(self): | |
77 class Test: | |
78 # Disable 'no __init__ method' warning | pylint: disable=W0232 | |
79 # pylint: disable=old-style-class | |
80 | |
81 @classmethod | |
82 @memoize.memo() | |
83 def func(cls, a): | |
84 self.tag(a) | |
85 return a | |
86 | |
87 # Execute multiple rounds of two unique function executions. | |
88 for _ in xrange(10): | |
89 self.assertEqual(Test.func(1), 1) | |
90 self.assertEqual(Test.func(2), 2) | |
91 self.assertTagged( | |
92 1, | |
93 2, | |
94 ) | |
95 | |
96 def testNewClassMethod(self): | |
97 class Test(object): | |
98 # Disable 'no __init__ method' warning | pylint: disable=W0232 | |
99 | |
100 @classmethod | |
101 @memoize.memo() | |
102 def func(cls, a): | |
103 self.tag(a) | |
104 return a | |
105 | |
106 # Execute multiple rounds of two unique function executions. | |
107 for _ in xrange(10): | |
108 self.assertEqual(Test.func(1), 1) | |
109 self.assertEqual(Test.func(2), 2) | |
110 self.assertTagged( | |
111 1, | |
112 2, | |
113 ) | |
114 | |
115 def testOldClassStaticMethod(self): | |
116 class Test: | |
117 # Disable 'no __init__ method' warning | pylint: disable=W0232 | |
118 # pylint: disable=old-style-class | |
119 | |
120 @staticmethod | |
121 @memoize.memo() | |
122 def func(a): | |
123 self.tag(a) | |
124 return a | |
125 | |
126 # Execute multiple rounds of two unique function executions. | |
127 for _ in xrange(10): | |
128 self.assertEqual(Test.func(1), 1) | |
129 self.assertEqual(Test.func(2), 2) | |
130 self.assertTagged( | |
131 1, | |
132 2, | |
133 ) | |
134 | |
135 def testNewClassStaticMethod(self): | |
136 class Test(object): | |
137 # Disable 'no __init__ method' warning | pylint: disable=W0232 | |
138 | |
139 @staticmethod | |
140 @memoize.memo() | |
141 def func(a): | |
142 self.tag(a) | |
143 return a | |
144 | |
145 # Execute multiple rounds of two unique function executions. | |
146 for _ in xrange(10): | |
147 self.assertEqual(Test.func(1), 1) | |
148 self.assertEqual(Test.func(2), 2) | |
149 self.assertTagged( | |
150 1, | |
151 2, | |
152 ) | |
153 | |
154 def testClearAllArgs(self): | |
155 @memoize.memo() | |
156 def func(a, b=10): | |
157 self.tag((a, b)) | |
158 return a + b | |
159 | |
160 # First round | |
161 self.assertEqual(func(1), 11) | |
162 self.assertEqual(func(1, b=0), 1) | |
163 self.assertTagged( | |
164 (1, 10), | |
165 (1, 0), | |
166 ) | |
167 | |
168 # Clear (1) | |
169 self.clearTagged() | |
170 func.memo_clear(1) | |
171 | |
172 self.assertEqual(func(1), 11) | |
173 self.assertEqual(func(1, b=0), 1) | |
174 self.assertTagged( | |
175 (1, 10), | |
176 ) | |
177 | |
178 # Clear (1, b=0) | |
179 self.clearTagged() | |
180 func.memo_clear(1, b=0) | |
181 | |
182 self.assertEqual(func(1), 11) | |
183 self.assertEqual(func(1, b=0), 1) | |
184 self.assertTagged( | |
185 (1, 0), | |
186 ) | |
187 | |
188 | |
189 class MemoInstanceMethodTestCase(MemoTestCase): | |
190 | |
191 class TestBaseOld: | |
192 # pylint: disable=old-style-class | |
193 def __init__(self, test_case, name): | |
194 self.test_case = test_case | |
195 self.name = name | |
196 | |
197 def __hash__(self): | |
198 # Prevent this instance from being used as a memo key | |
199 raise NotImplementedError() | |
200 | |
201 | |
202 class TestBaseNew(object): | |
203 def __init__(self, test_case, name): | |
204 self.test_case = test_case | |
205 self.name = name | |
206 | |
207 def __hash__(self): | |
208 # Prevent this instance from being used as a memo key | |
209 raise NotImplementedError() | |
210 | |
211 | |
212 class TestHash(object): | |
213 | |
214 def __init__(self): | |
215 self._counter = 0 | |
216 | |
217 @memoize.memo() | |
218 def __hash__(self): | |
219 assert self._counter == 0 | |
220 self._counter += 1 | |
221 return self._counter | |
222 | |
223 | |
224 def testOldClassNoArgs(self): | |
225 class Test(self.TestBaseOld): | |
226 # Disable 'hash not overridden' warning | pylint: disable=W0223 | |
227 | |
228 @memoize.memo() | |
229 def func(self): | |
230 self.test_case.tag(self.name) | |
231 return 'foo' | |
232 | |
233 t0 = Test(self, 't0') | |
234 t1 = Test(self, 't1') | |
235 for _ in xrange(10): | |
236 self.assertEqual(t0.func(), 'foo') | |
237 self.assertEqual(t1.func(), 'foo') | |
238 self.assertTagged( | |
239 't0', | |
240 't1', | |
241 ) | |
242 | |
243 def testNewClassNoArgs(self): | |
244 class Test(self.TestBaseNew): | |
245 # Disable 'hash not overridden' warning | pylint: disable=W0223 | |
246 | |
247 @memoize.memo() | |
248 def func(self): | |
249 self.test_case.tag(self.name) | |
250 return 'foo' | |
251 | |
252 t0 = Test(self, 't0') | |
253 t1 = Test(self, 't1') | |
254 for _ in xrange(10): | |
255 self.assertEqual(t0.func(), 'foo') | |
256 self.assertEqual(t1.func(), 'foo') | |
257 self.assertTagged( | |
258 't0', | |
259 't1', | |
260 ) | |
261 | |
262 def testOldClassArgs(self): | |
263 class Test(self.TestBaseOld): | |
264 # Disable 'hash not overridden' warning | pylint: disable=W0223 | |
265 | |
266 @memoize.memo() | |
267 def func(self, a, b): | |
268 self.test_case.tag((self.name, a, b)) | |
269 return a + b | |
270 | |
271 t0 = Test(self, 't0') | |
272 t1 = Test(self, 't1') | |
273 for _ in xrange(10): | |
274 self.assertEqual(t0.func(1, 2), 3) | |
275 self.assertEqual(t0.func(1, 3), 4) | |
276 self.assertEqual(t1.func(1, 2), 3) | |
277 self.assertEqual(t1.func(1, 3), 4) | |
278 self.assertTagged( | |
279 ('t0', 1, 2), | |
280 ('t0', 1, 3), | |
281 ('t1', 1, 2), | |
282 ('t1', 1, 3), | |
283 ) | |
284 | |
285 def testNewClassArgs(self): | |
286 class Test(self.TestBaseNew): | |
287 # Disable 'hash not overridden' warning | pylint: disable=W0223 | |
288 | |
289 @memoize.memo() | |
290 def func(self, a, b): | |
291 self.test_case.tag((self.name, a, b)) | |
292 return a + b | |
293 | |
294 t0 = Test(self, 't0') | |
295 t1 = Test(self, 't1') | |
296 for _ in xrange(10): | |
297 self.assertEqual(t0.func(1, 2), 3) | |
298 self.assertEqual(t0.func(1, 3), 4) | |
299 self.assertEqual(t1.func(1, 2), 3) | |
300 self.assertEqual(t1.func(1, 3), 4) | |
301 self.assertTagged( | |
302 ('t0', 1, 2), | |
303 ('t0', 1, 3), | |
304 ('t1', 1, 2), | |
305 ('t1', 1, 3), | |
306 ) | |
307 | |
308 def testNewClassDirectCall(self): | |
309 class Test(self.TestBaseNew): | |
310 # Disable 'hash not overridden' warning | pylint: disable=W0223 | |
311 | |
312 @memoize.memo() | |
313 def func(self, a, b): | |
314 self.test_case.tag((self.name, a, b)) | |
315 return a + b | |
316 | |
317 t0 = Test(self, 't0') | |
318 for _ in xrange(10): | |
319 self.assertEqual(t0.func.__get__(t0)(1,2), 3) | |
320 self.assertTagged( | |
321 ('t0', 1, 2), | |
322 ) | |
323 | |
324 def testClear(self): | |
325 class Test(self.TestBaseNew): | |
326 # Disable 'hash not overridden' warning | pylint: disable=W0223 | |
327 | |
328 @memoize.memo() | |
329 def func(self, a): | |
330 self.test_case.tag((self.name, a)) | |
331 return a | |
332 | |
333 # Call '10' and '20' | |
334 t = Test(self, 'test') | |
335 t.func(10) | |
336 self.assertTagged( | |
337 ('test', 10), | |
338 ) | |
339 | |
340 # Clear | |
341 self.clearTagged() | |
342 t.func.memo_clear(10) | |
343 | |
344 # Call '10'; it should be tagged | |
345 t.func(10) | |
346 self.assertTagged( | |
347 ('test', 10), | |
348 ) | |
349 | |
350 | |
351 def testOverrideHash(self): | |
352 t = self.TestHash() | |
353 self.assertEquals(hash(t), 1) | |
354 self.assertEquals(hash(t), 1) | |
355 | |
356 | |
357 class MemoClassMethodTestCase(MemoTestCase): | |
358 """Tests handling of the 'cls' and 'self' parameters""" | |
359 | |
360 class Test(object): | |
361 | |
362 def __init__(self, test_case, name): | |
363 self.test_case = test_case | |
364 self.name = name | |
365 self._value = 0 | |
366 | |
367 @memoize.memo(ignore=('tag',)) | |
368 def func(self, a, tag): | |
369 self.test_case.tag(tag) | |
370 return a | |
371 | |
372 @classmethod | |
373 @memoize.memo(ignore=('test_case', 'tag')) | |
374 def class_func(cls, test_case, memo_value, tag): | |
375 test_case.tag(tag) | |
376 return memo_value | |
377 | |
378 @property | |
379 @memoize.memo() | |
380 def prop(self): | |
381 self.test_case.tag(self.name) | |
382 return self._value | |
383 | |
384 @prop.setter | |
385 def prop(self, value): | |
386 self._value = value | |
387 | |
388 @prop.deleter | |
389 def prop(self): | |
390 self._value = 0 | |
391 | |
392 | |
393 class TestWithEquals(Test): | |
394 | |
395 def __hash__(self): | |
396 return 7 | |
397 | |
398 def __eq__(self, other): | |
399 return type(other) == type(self) | |
400 | |
401 | |
402 def testClassMethodNoEquals(self): | |
403 self.assertEqual(self.Test.class_func(self, 1, 't0'), 1) | |
404 self.assertEqual(self.Test.class_func(self, 1, 't1'), 1) | |
405 self.assertEqual(self.Test.class_func(self, 2, 't2'), 2) | |
406 | |
407 self.Test.class_func.memo_clear(self.Test, self, 2, None) | |
408 self.assertEqual(self.Test.class_func(self, 1, 't3'), 1) | |
409 self.assertEqual(self.Test.class_func(self, 2, 't4'), 2) | |
410 | |
411 self.Test.class_func.memo_clear() | |
412 self.assertEqual(self.Test.class_func(self, 1, 't5'), 1) | |
413 | |
414 self.assertTagged( | |
415 't0', | |
416 't2', | |
417 't4', | |
418 't5', | |
419 ) | |
420 | |
421 def testInstanceMethodNoEquals(self): | |
422 t0 = self.Test(self, 't0') | |
423 t1 = self.Test(self, 't1') | |
424 | |
425 self.assertEqual(t0.func(1, 't0.0'), 1) | |
426 self.assertEqual(t1.func(1, 't1.0'), 1) | |
427 | |
428 t1.func.memo_clear(1, None) | |
429 self.assertEqual(t0.func(1, 't0.1'), 1) | |
430 self.assertEqual(t1.func(1, 't1.1'), 1) | |
431 self.assertTagged( | |
432 't0.0', | |
433 't1.0', | |
434 't1.1', | |
435 ) | |
436 | |
437 def testInstanceMethodWithEquals(self): | |
438 t0 = self.TestWithEquals(self, 't0') | |
439 t1 = self.TestWithEquals(self, 't1') | |
440 | |
441 self.assertEqual(hash(t0), 7) | |
442 self.assertTrue(t0 == t1) | |
443 self.assertEqual(t0.func(1, 't0.0'), 1) | |
444 self.assertEqual(t1.func(1, 't1.0'), 1) | |
445 | |
446 t1.func.memo_clear(1, None) | |
447 self.assertEqual(t0.func(1, 't0.1'), 1) | |
448 self.assertEqual(t1.func(1, 't1.1'), 1) | |
449 self.assertTagged( | |
450 't0.0', | |
451 't1.0', | |
452 't1.1', | |
453 ) | |
454 | |
455 def testProperty(self): | |
456 t0 = self.Test(self, 't0') | |
457 t1 = self.Test(self, 't1') | |
458 | |
459 # The property can be set. | |
460 t0.prop = 1024 | |
461 t1.prop = 1337 | |
462 del(t1.prop) | |
463 for _ in xrange(10): | |
464 self.assertEqual(t0.prop, 1024) | |
465 self.assertEqual(t1.prop, 0) | |
466 | |
467 self.assertTagged( | |
468 't0', | |
469 't1', | |
470 ) | |
471 | |
472 | |
473 if __name__ == '__main__': | |
474 unittest.main() | |
OLD | NEW |