OLD | NEW |
| (Empty) |
1 # Copyright (c) 2001-2004 Twisted Matrix Laboratories. | |
2 # See LICENSE for details. | |
3 | |
4 | |
5 """Testing for twisted.persisted.journal.""" | |
6 | |
7 from twisted.trial import unittest | |
8 from twisted.persisted.journal.base import ICommand, MemoryJournal, serviceComma
nd, ServiceWrapperCommand, command, Wrappable | |
9 from twisted.persisted.journal.picklelog import DirDBMLog | |
10 from zope.interface import implements | |
11 | |
12 import shutil, os.path | |
13 | |
14 | |
15 | |
16 class AddTime: | |
17 | |
18 implements(ICommand) | |
19 | |
20 def execute(self, svc, cmdtime): | |
21 svc.values["time"] = cmdtime | |
22 | |
23 | |
24 class Counter(Wrappable): | |
25 | |
26 objectType = "counter" | |
27 | |
28 def __init__(self, uid): | |
29 self.uid = uid | |
30 self.x = 0 | |
31 | |
32 def getUid(self): | |
33 return self.uid | |
34 | |
35 def _increment(self): | |
36 self.x += 1 | |
37 | |
38 increment = command("_increment") | |
39 | |
40 | |
41 class Service: | |
42 | |
43 def __init__(self, logpath, journalpath): | |
44 log = DirDBMLog(logpath) | |
45 self.journal = MemoryJournal(log, self, journalpath, self._gotData) | |
46 self.journal.updateFromLog() | |
47 | |
48 def _gotData(self, result): | |
49 if result is None: | |
50 self.values = {} | |
51 self.counters = {} | |
52 else: | |
53 self.values, self.counters = result | |
54 | |
55 def _makeCounter(self, id): | |
56 c = Counter(id) | |
57 self.counters[id] = c | |
58 return c | |
59 | |
60 makeCounter = serviceCommand("_makeCounter") | |
61 | |
62 def loadObject(self, type, id): | |
63 if type != "counter": raise ValueError | |
64 return self.counters[id] | |
65 | |
66 def _add(self, key, value): | |
67 """Add a new entry.""" | |
68 self.values[key] = value | |
69 | |
70 def _delete(self, key): | |
71 """Delete an entry.""" | |
72 del self.values[key] | |
73 | |
74 def get(self, key): | |
75 """Return value of an entry.""" | |
76 return self.values[key] | |
77 | |
78 def addtime(self, journal): | |
79 """Set a key 'time' with the current time.""" | |
80 journal.executeCommand(AddTime()) | |
81 | |
82 # and now the command wrappers | |
83 | |
84 add = serviceCommand("_add") | |
85 | |
86 delete = serviceCommand("_delete") | |
87 | |
88 | |
89 class JournalTestCase(unittest.TestCase): | |
90 | |
91 def setUp(self): | |
92 self.logpath = self.mktemp() | |
93 self.journalpath = self.mktemp() | |
94 self.svc = Service(self.logpath, self.journalpath) | |
95 | |
96 def tearDown(self): | |
97 if hasattr(self, "svc"): | |
98 del self.svc | |
99 # delete stuff? ... | |
100 if os.path.isdir(self.logpath): | |
101 shutil.rmtree(self.logpath) | |
102 if os.path.exists(self.logpath): | |
103 os.unlink(self.logpath) | |
104 if os.path.isdir(self.journalpath): | |
105 shutil.rmtree(self.journalpath) | |
106 if os.path.exists(self.journalpath): | |
107 os.unlink(self.journalpath) | |
108 | |
109 def testCommandExecution(self): | |
110 svc = self.svc | |
111 svc.add(svc.journal, "foo", "bar") | |
112 self.assertEquals(svc.get("foo"), "bar") | |
113 | |
114 svc.delete(svc.journal, "foo") | |
115 self.assertRaises(KeyError, svc.get, "foo") | |
116 | |
117 def testLogging(self): | |
118 svc = self.svc | |
119 log = self.svc.journal.log | |
120 j = self.svc.journal | |
121 svc.add(j, "foo", "bar") | |
122 svc.add(j, 1, "hello") | |
123 svc.delete(j, "foo") | |
124 | |
125 commands = [ServiceWrapperCommand("_add", ("foo", "bar")), | |
126 ServiceWrapperCommand("_add", (1, "hello")), | |
127 ServiceWrapperCommand("_delete", ("foo",))] | |
128 | |
129 self.assertEquals(log.getCurrentIndex(), 3) | |
130 for i in range(1, 4): | |
131 for a, b in zip(commands[i-1:], [c for t, c in log.getCommandsSince(
i)]): | |
132 self.assertEquals(a, b) | |
133 | |
134 def testRecovery(self): | |
135 svc = self.svc | |
136 j = svc.journal | |
137 svc.add(j, "foo", "bar") | |
138 svc.add(j, 1, "hello") | |
139 # we sync *before* delete to make sure commands get executed | |
140 svc.journal.sync((svc.values, svc.counters)) | |
141 svc.delete(j, "foo") | |
142 d = svc.makeCounter(j, 1) | |
143 d.addCallback(lambda c, j=j: c.increment(j)) | |
144 del svc, self.svc | |
145 | |
146 # first, load from snapshot | |
147 svc = Service(self.logpath, self.journalpath) | |
148 self.assertEquals(svc.values, {1: "hello"}) | |
149 self.assertEquals(svc.counters[1].x, 1) | |
150 del svc | |
151 | |
152 # now, tamper with log, and then try | |
153 f = open(self.journalpath, "w") | |
154 f.write("sfsdfsdfsd") | |
155 f.close() | |
156 svc = Service(self.logpath, self.journalpath) | |
157 self.assertEquals(svc.values, {1: "hello"}) | |
158 self.assertEquals(svc.counters[1].x, 1) | |
159 | |
160 def testTime(self): | |
161 svc = self.svc | |
162 svc.addtime(svc.journal) | |
163 t = svc.get("time") | |
164 | |
165 log = self.svc.journal.log | |
166 (t2, c), = log.getCommandsSince(1) | |
167 self.assertEquals(t, t2) | |
168 | |
169 | |
OLD | NEW |