OLD | NEW |
| (Empty) |
1 # -*- test-case-name: twisted.test.test_persisted -*- | |
2 # Copyright (c) 2001-2004 Twisted Matrix Laboratories. | |
3 # See LICENSE for details. | |
4 | |
5 | |
6 | |
7 """ | |
8 Different styles of persisted objects. | |
9 """ | |
10 | |
11 # System Imports | |
12 import types | |
13 import copy_reg | |
14 import copy | |
15 | |
16 try: | |
17 import cStringIO as StringIO | |
18 except ImportError: | |
19 import StringIO | |
20 | |
21 # Twisted Imports | |
22 from twisted.python import log | |
23 | |
24 try: | |
25 from new import instancemethod | |
26 except: | |
27 from org.python.core import PyMethod | |
28 instancemethod = PyMethod | |
29 | |
30 oldModules = {} | |
31 | |
32 ## First, let's register support for some stuff that really ought to | |
33 ## be registerable... | |
34 | |
35 def pickleMethod(method): | |
36 'support function for copy_reg to pickle method refs' | |
37 return unpickleMethod, (method.im_func.__name__, | |
38 method.im_self, | |
39 method.im_class) | |
40 | |
41 def unpickleMethod(im_name, | |
42 im_self, | |
43 im_class): | |
44 'support function for copy_reg to unpickle method refs' | |
45 try: | |
46 unbound = getattr(im_class,im_name) | |
47 if im_self is None: | |
48 return unbound | |
49 bound=instancemethod(unbound.im_func, | |
50 im_self, | |
51 im_class) | |
52 return bound | |
53 except AttributeError: | |
54 log.msg("Method",im_name,"not on class",im_class) | |
55 assert im_self is not None,"No recourse: no instance to guess from." | |
56 # Attempt a common fix before bailing -- if classes have | |
57 # changed around since we pickled this method, we may still be | |
58 # able to get it by looking on the instance's current class. | |
59 unbound = getattr(im_self.__class__,im_name) | |
60 log.msg("Attempting fixup with",unbound) | |
61 if im_self is None: | |
62 return unbound | |
63 bound=instancemethod(unbound.im_func, | |
64 im_self, | |
65 im_self.__class__) | |
66 return bound | |
67 | |
68 copy_reg.pickle(types.MethodType, | |
69 pickleMethod, | |
70 unpickleMethod) | |
71 | |
72 def pickleModule(module): | |
73 'support function for copy_reg to pickle module refs' | |
74 return unpickleModule, (module.__name__,) | |
75 | |
76 def unpickleModule(name): | |
77 'support function for copy_reg to unpickle module refs' | |
78 if oldModules.has_key(name): | |
79 log.msg("Module has moved: %s" % name) | |
80 name = oldModules[name] | |
81 log.msg(name) | |
82 return __import__(name,{},{},'x') | |
83 | |
84 | |
85 copy_reg.pickle(types.ModuleType, | |
86 pickleModule, | |
87 unpickleModule) | |
88 | |
89 def pickleStringO(stringo): | |
90 'support function for copy_reg to pickle StringIO.OutputTypes' | |
91 return unpickleStringO, (stringo.getvalue(), stringo.tell()) | |
92 | |
93 def unpickleStringO(val, sek): | |
94 x = StringIO.StringIO() | |
95 x.write(val) | |
96 x.seek(sek) | |
97 return x | |
98 | |
99 if hasattr(StringIO, 'OutputType'): | |
100 copy_reg.pickle(StringIO.OutputType, | |
101 pickleStringO, | |
102 unpickleStringO) | |
103 | |
104 def pickleStringI(stringi): | |
105 return unpickleStringI, (stringi.getvalue(), stringi.tell()) | |
106 | |
107 def unpickleStringI(val, sek): | |
108 x = StringIO.StringIO(val) | |
109 x.seek(sek) | |
110 return x | |
111 | |
112 | |
113 if hasattr(StringIO, 'InputType'): | |
114 copy_reg.pickle(StringIO.InputType, | |
115 pickleStringI, | |
116 unpickleStringI) | |
117 | |
118 class Ephemeral: | |
119 """ | |
120 This type of object is never persisted; if possible, even references to it | |
121 are eliminated. | |
122 """ | |
123 | |
124 def __getstate__(self): | |
125 log.msg( "WARNING: serializing ephemeral %s" % self ) | |
126 import gc | |
127 for r in gc.get_referrers(self): | |
128 log.msg( " referred to by %s" % (r,)) | |
129 return None | |
130 | |
131 def __setstate__(self, state): | |
132 log.msg( "WARNING: unserializing ephemeral %s" % self.__class__ ) | |
133 self.__class__ = Ephemeral | |
134 | |
135 | |
136 versionedsToUpgrade = {} | |
137 upgraded = {} | |
138 | |
139 def doUpgrade(): | |
140 global versionedsToUpgrade, upgraded | |
141 for versioned in versionedsToUpgrade.values(): | |
142 requireUpgrade(versioned) | |
143 versionedsToUpgrade = {} | |
144 upgraded = {} | |
145 | |
146 def requireUpgrade(obj): | |
147 """Require that a Versioned instance be upgraded completely first. | |
148 """ | |
149 objID = id(obj) | |
150 if objID in versionedsToUpgrade and objID not in upgraded: | |
151 upgraded[objID] = 1 | |
152 obj.versionUpgrade() | |
153 return obj | |
154 | |
155 from twisted.python import reflect | |
156 | |
157 def _aybabtu(c): | |
158 l = [] | |
159 for b in reflect.allYourBase(c, Versioned): | |
160 if b not in l and b is not Versioned: | |
161 l.append(b) | |
162 return l | |
163 | |
164 class Versioned: | |
165 """ | |
166 This type of object is persisted with versioning information. | |
167 | |
168 I have a single class attribute, the int persistenceVersion. After I am | |
169 unserialized (and styles.doUpgrade() is called), self.upgradeToVersionX() | |
170 will be called for each version upgrade I must undergo. | |
171 | |
172 For example, if I serialize an instance of a Foo(Versioned) at version 4 | |
173 and then unserialize it when the code is at version 9, the calls:: | |
174 | |
175 self.upgradeToVersion5() | |
176 self.upgradeToVersion6() | |
177 self.upgradeToVersion7() | |
178 self.upgradeToVersion8() | |
179 self.upgradeToVersion9() | |
180 | |
181 will be made. If any of these methods are undefined, a warning message | |
182 will be printed. | |
183 """ | |
184 persistenceVersion = 0 | |
185 persistenceForgets = () | |
186 | |
187 def __setstate__(self, state): | |
188 versionedsToUpgrade[id(self)] = self | |
189 self.__dict__ = state | |
190 | |
191 def __getstate__(self, dict=None): | |
192 """Get state, adding a version number to it on its way out. | |
193 """ | |
194 dct = copy.copy(dict or self.__dict__) | |
195 bases = _aybabtu(self.__class__) | |
196 bases.reverse() | |
197 bases.append(self.__class__) # don't forget me!! | |
198 for base in bases: | |
199 if base.__dict__.has_key('persistenceForgets'): | |
200 for slot in base.persistenceForgets: | |
201 if dct.has_key(slot): | |
202 del dct[slot] | |
203 if base.__dict__.has_key('persistenceVersion'): | |
204 dct['%s.persistenceVersion' % reflect.qual(base)] = base.persist
enceVersion | |
205 return dct | |
206 | |
207 def versionUpgrade(self): | |
208 """(internal) Do a version upgrade. | |
209 """ | |
210 bases = _aybabtu(self.__class__) | |
211 # put the bases in order so superclasses' persistenceVersion methods | |
212 # will be called first. | |
213 bases.reverse() | |
214 bases.append(self.__class__) # don't forget me!! | |
215 # first let's look for old-skool versioned's | |
216 if self.__dict__.has_key("persistenceVersion"): | |
217 | |
218 # Hacky heuristic: if more than one class subclasses Versioned, | |
219 # we'll assume that the higher version number wins for the older | |
220 # class, so we'll consider the attribute the version of the older | |
221 # class. There are obviously possibly times when this will | |
222 # eventually be an incorrect assumption, but hopefully old-school | |
223 # persistenceVersion stuff won't make it that far into multiple | |
224 # classes inheriting from Versioned. | |
225 | |
226 pver = self.__dict__['persistenceVersion'] | |
227 del self.__dict__['persistenceVersion'] | |
228 highestVersion = 0 | |
229 highestBase = None | |
230 for base in bases: | |
231 if not base.__dict__.has_key('persistenceVersion'): | |
232 continue | |
233 if base.persistenceVersion > highestVersion: | |
234 highestBase = base | |
235 highestVersion = base.persistenceVersion | |
236 if highestBase: | |
237 self.__dict__['%s.persistenceVersion' % reflect.qual(highestBase
)] = pver | |
238 for base in bases: | |
239 # ugly hack, but it's what the user expects, really | |
240 if (Versioned not in base.__bases__ and | |
241 not base.__dict__.has_key('persistenceVersion')): | |
242 continue | |
243 currentVers = base.persistenceVersion | |
244 pverName = '%s.persistenceVersion' % reflect.qual(base) | |
245 persistVers = (self.__dict__.get(pverName) or 0) | |
246 if persistVers: | |
247 del self.__dict__[pverName] | |
248 assert persistVers <= currentVers, "Sorry, can't go backwards in ti
me." | |
249 while persistVers < currentVers: | |
250 persistVers = persistVers + 1 | |
251 method = base.__dict__.get('upgradeToVersion%s' % persistVers, N
one) | |
252 if method: | |
253 log.msg( "Upgrading %s (of %s @ %s) to version %s" % (reflec
t.qual(base), reflect.qual(self.__class__), id(self), persistVers) ) | |
254 method(self) | |
255 else: | |
256 log.msg( 'Warning: cannot upgrade %s to version %s' % (base,
persistVers) ) | |
OLD | NEW |