| OLD | NEW |
| (Empty) |
| 1 # Copyright (c) 2001-2004 Twisted Matrix Laboratories. | |
| 2 # See LICENSE for details. | |
| 3 | |
| 4 | |
| 5 """Testing for twisted.persisted.journal.""" | |
| 6 | |
| 7 from twisted.trial import unittest | |
| 8 from twisted.persisted.journal.base import ICommand, MemoryJournal, serviceComma
nd, ServiceWrapperCommand, command, Wrappable | |
| 9 from twisted.persisted.journal.picklelog import DirDBMLog | |
| 10 from zope.interface import implements | |
| 11 | |
| 12 import shutil, os.path | |
| 13 | |
| 14 | |
| 15 | |
| 16 class AddTime: | |
| 17 | |
| 18 implements(ICommand) | |
| 19 | |
| 20 def execute(self, svc, cmdtime): | |
| 21 svc.values["time"] = cmdtime | |
| 22 | |
| 23 | |
| 24 class Counter(Wrappable): | |
| 25 | |
| 26 objectType = "counter" | |
| 27 | |
| 28 def __init__(self, uid): | |
| 29 self.uid = uid | |
| 30 self.x = 0 | |
| 31 | |
| 32 def getUid(self): | |
| 33 return self.uid | |
| 34 | |
| 35 def _increment(self): | |
| 36 self.x += 1 | |
| 37 | |
| 38 increment = command("_increment") | |
| 39 | |
| 40 | |
| 41 class Service: | |
| 42 | |
| 43 def __init__(self, logpath, journalpath): | |
| 44 log = DirDBMLog(logpath) | |
| 45 self.journal = MemoryJournal(log, self, journalpath, self._gotData) | |
| 46 self.journal.updateFromLog() | |
| 47 | |
| 48 def _gotData(self, result): | |
| 49 if result is None: | |
| 50 self.values = {} | |
| 51 self.counters = {} | |
| 52 else: | |
| 53 self.values, self.counters = result | |
| 54 | |
| 55 def _makeCounter(self, id): | |
| 56 c = Counter(id) | |
| 57 self.counters[id] = c | |
| 58 return c | |
| 59 | |
| 60 makeCounter = serviceCommand("_makeCounter") | |
| 61 | |
| 62 def loadObject(self, type, id): | |
| 63 if type != "counter": raise ValueError | |
| 64 return self.counters[id] | |
| 65 | |
| 66 def _add(self, key, value): | |
| 67 """Add a new entry.""" | |
| 68 self.values[key] = value | |
| 69 | |
| 70 def _delete(self, key): | |
| 71 """Delete an entry.""" | |
| 72 del self.values[key] | |
| 73 | |
| 74 def get(self, key): | |
| 75 """Return value of an entry.""" | |
| 76 return self.values[key] | |
| 77 | |
| 78 def addtime(self, journal): | |
| 79 """Set a key 'time' with the current time.""" | |
| 80 journal.executeCommand(AddTime()) | |
| 81 | |
| 82 # and now the command wrappers | |
| 83 | |
| 84 add = serviceCommand("_add") | |
| 85 | |
| 86 delete = serviceCommand("_delete") | |
| 87 | |
| 88 | |
| 89 class JournalTestCase(unittest.TestCase): | |
| 90 | |
| 91 def setUp(self): | |
| 92 self.logpath = self.mktemp() | |
| 93 self.journalpath = self.mktemp() | |
| 94 self.svc = Service(self.logpath, self.journalpath) | |
| 95 | |
| 96 def tearDown(self): | |
| 97 if hasattr(self, "svc"): | |
| 98 del self.svc | |
| 99 # delete stuff? ... | |
| 100 if os.path.isdir(self.logpath): | |
| 101 shutil.rmtree(self.logpath) | |
| 102 if os.path.exists(self.logpath): | |
| 103 os.unlink(self.logpath) | |
| 104 if os.path.isdir(self.journalpath): | |
| 105 shutil.rmtree(self.journalpath) | |
| 106 if os.path.exists(self.journalpath): | |
| 107 os.unlink(self.journalpath) | |
| 108 | |
| 109 def testCommandExecution(self): | |
| 110 svc = self.svc | |
| 111 svc.add(svc.journal, "foo", "bar") | |
| 112 self.assertEquals(svc.get("foo"), "bar") | |
| 113 | |
| 114 svc.delete(svc.journal, "foo") | |
| 115 self.assertRaises(KeyError, svc.get, "foo") | |
| 116 | |
| 117 def testLogging(self): | |
| 118 svc = self.svc | |
| 119 log = self.svc.journal.log | |
| 120 j = self.svc.journal | |
| 121 svc.add(j, "foo", "bar") | |
| 122 svc.add(j, 1, "hello") | |
| 123 svc.delete(j, "foo") | |
| 124 | |
| 125 commands = [ServiceWrapperCommand("_add", ("foo", "bar")), | |
| 126 ServiceWrapperCommand("_add", (1, "hello")), | |
| 127 ServiceWrapperCommand("_delete", ("foo",))] | |
| 128 | |
| 129 self.assertEquals(log.getCurrentIndex(), 3) | |
| 130 for i in range(1, 4): | |
| 131 for a, b in zip(commands[i-1:], [c for t, c in log.getCommandsSince(
i)]): | |
| 132 self.assertEquals(a, b) | |
| 133 | |
| 134 def testRecovery(self): | |
| 135 svc = self.svc | |
| 136 j = svc.journal | |
| 137 svc.add(j, "foo", "bar") | |
| 138 svc.add(j, 1, "hello") | |
| 139 # we sync *before* delete to make sure commands get executed | |
| 140 svc.journal.sync((svc.values, svc.counters)) | |
| 141 svc.delete(j, "foo") | |
| 142 d = svc.makeCounter(j, 1) | |
| 143 d.addCallback(lambda c, j=j: c.increment(j)) | |
| 144 del svc, self.svc | |
| 145 | |
| 146 # first, load from snapshot | |
| 147 svc = Service(self.logpath, self.journalpath) | |
| 148 self.assertEquals(svc.values, {1: "hello"}) | |
| 149 self.assertEquals(svc.counters[1].x, 1) | |
| 150 del svc | |
| 151 | |
| 152 # now, tamper with log, and then try | |
| 153 f = open(self.journalpath, "w") | |
| 154 f.write("sfsdfsdfsd") | |
| 155 f.close() | |
| 156 svc = Service(self.logpath, self.journalpath) | |
| 157 self.assertEquals(svc.values, {1: "hello"}) | |
| 158 self.assertEquals(svc.counters[1].x, 1) | |
| 159 | |
| 160 def testTime(self): | |
| 161 svc = self.svc | |
| 162 svc.addtime(svc.journal) | |
| 163 t = svc.get("time") | |
| 164 | |
| 165 log = self.svc.journal.log | |
| 166 (t2, c), = log.getCommandsSince(1) | |
| 167 self.assertEquals(t, t2) | |
| 168 | |
| 169 | |
| OLD | NEW |