OLD | NEW |
| (Empty) |
1 # Copyright (c) 2001-2008 Twisted Matrix Laboratories. | |
2 # See LICENSE for details. | |
3 | |
4 | |
5 """ | |
6 Tests for twisted.enterprise.adbapi. | |
7 """ | |
8 | |
9 from twisted.trial import unittest | |
10 | |
11 import os, stat | |
12 | |
13 from twisted.enterprise.adbapi import ConnectionPool, ConnectionLost, safe | |
14 from twisted.enterprise.adbapi import _unreleasedVersion | |
15 from twisted.internet import reactor, defer, interfaces | |
16 | |
17 | |
18 simple_table_schema = """ | |
19 CREATE TABLE simple ( | |
20 x integer | |
21 ) | |
22 """ | |
23 | |
24 | |
25 class ADBAPITestBase: | |
26 """Test the asynchronous DB-API code.""" | |
27 | |
28 openfun_called = {} | |
29 | |
30 if interfaces.IReactorThreads(reactor, None) is None: | |
31 skip = "ADB-API requires threads, no way to test without them" | |
32 | |
33 def setUp(self): | |
34 self.startDB() | |
35 self.dbpool = self.makePool(cp_openfun=self.openfun) | |
36 self.dbpool.start() | |
37 | |
38 def tearDown(self): | |
39 d = self.dbpool.runOperation('DROP TABLE simple') | |
40 d.addCallback(lambda res: self.dbpool.close()) | |
41 d.addCallback(lambda res: self.stopDB()) | |
42 return d | |
43 | |
44 def openfun(self, conn): | |
45 self.openfun_called[conn] = True | |
46 | |
47 def checkOpenfunCalled(self, conn=None): | |
48 if not conn: | |
49 self.failUnless(self.openfun_called) | |
50 else: | |
51 self.failUnless(self.openfun_called.has_key(conn)) | |
52 | |
53 def testPool(self): | |
54 d = self.dbpool.runOperation(simple_table_schema) | |
55 if self.test_failures: | |
56 d.addCallback(self._testPool_1_1) | |
57 d.addCallback(self._testPool_1_2) | |
58 d.addCallback(self._testPool_1_3) | |
59 d.addCallback(self._testPool_1_4) | |
60 d.addCallback(lambda res: self.flushLoggedErrors()) | |
61 d.addCallback(self._testPool_2) | |
62 d.addCallback(self._testPool_3) | |
63 d.addCallback(self._testPool_4) | |
64 d.addCallback(self._testPool_5) | |
65 d.addCallback(self._testPool_6) | |
66 d.addCallback(self._testPool_7) | |
67 d.addCallback(self._testPool_8) | |
68 d.addCallback(self._testPool_9) | |
69 return d | |
70 | |
71 def _testPool_1_1(self, res): | |
72 d = defer.maybeDeferred(self.dbpool.runQuery, "select * from NOTABLE") | |
73 d.addCallbacks(lambda res: self.fail('no exception'), | |
74 lambda f: None) | |
75 return d | |
76 | |
77 def _testPool_1_2(self, res): | |
78 d = defer.maybeDeferred(self.dbpool.runOperation, | |
79 "deletexxx from NOTABLE") | |
80 d.addCallbacks(lambda res: self.fail('no exception'), | |
81 lambda f: None) | |
82 return d | |
83 | |
84 def _testPool_1_3(self, res): | |
85 d = defer.maybeDeferred(self.dbpool.runInteraction, | |
86 self.bad_interaction) | |
87 d.addCallbacks(lambda res: self.fail('no exception'), | |
88 lambda f: None) | |
89 return d | |
90 | |
91 def _testPool_1_4(self, res): | |
92 d = defer.maybeDeferred(self.dbpool.runWithConnection, | |
93 self.bad_withConnection) | |
94 d.addCallbacks(lambda res: self.fail('no exception'), | |
95 lambda f: None) | |
96 return d | |
97 | |
98 def _testPool_2(self, res): | |
99 # verify simple table is empty | |
100 sql = "select count(1) from simple" | |
101 d = self.dbpool.runQuery(sql) | |
102 def _check(row): | |
103 self.failUnless(int(row[0][0]) == 0, "Interaction not rolled back") | |
104 self.checkOpenfunCalled() | |
105 d.addCallback(_check) | |
106 return d | |
107 | |
108 def _testPool_3(self, res): | |
109 sql = "select count(1) from simple" | |
110 inserts = [] | |
111 # add some rows to simple table (runOperation) | |
112 for i in range(self.num_iterations): | |
113 sql = "insert into simple(x) values(%d)" % i | |
114 inserts.append(self.dbpool.runOperation(sql)) | |
115 d = defer.gatherResults(inserts) | |
116 | |
117 def _select(res): | |
118 # make sure they were added (runQuery) | |
119 sql = "select x from simple order by x"; | |
120 d = self.dbpool.runQuery(sql) | |
121 return d | |
122 d.addCallback(_select) | |
123 | |
124 def _check(rows): | |
125 self.failUnless(len(rows) == self.num_iterations, | |
126 "Wrong number of rows") | |
127 for i in range(self.num_iterations): | |
128 self.failUnless(len(rows[i]) == 1, "Wrong size row") | |
129 self.failUnless(rows[i][0] == i, "Values not returned.") | |
130 d.addCallback(_check) | |
131 | |
132 return d | |
133 | |
134 def _testPool_4(self, res): | |
135 # runInteraction | |
136 d = self.dbpool.runInteraction(self.interaction) | |
137 d.addCallback(lambda res: self.assertEquals(res, "done")) | |
138 return d | |
139 | |
140 def _testPool_5(self, res): | |
141 # withConnection | |
142 d = self.dbpool.runWithConnection(self.withConnection) | |
143 d.addCallback(lambda res: self.assertEquals(res, "done")) | |
144 return d | |
145 | |
146 def _testPool_6(self, res): | |
147 # Test a withConnection cannot be closed | |
148 d = self.dbpool.runWithConnection(self.close_withConnection) | |
149 return d | |
150 | |
151 def _testPool_7(self, res): | |
152 # give the pool a workout | |
153 ds = [] | |
154 for i in range(self.num_iterations): | |
155 sql = "select x from simple where x = %d" % i | |
156 ds.append(self.dbpool.runQuery(sql)) | |
157 dlist = defer.DeferredList(ds, fireOnOneErrback=True) | |
158 def _check(result): | |
159 for i in range(self.num_iterations): | |
160 self.failUnless(result[i][1][0][0] == i, "Value not returned") | |
161 dlist.addCallback(_check) | |
162 return dlist | |
163 | |
164 def _testPool_8(self, res): | |
165 # now delete everything | |
166 ds = [] | |
167 for i in range(self.num_iterations): | |
168 sql = "delete from simple where x = %d" % i | |
169 ds.append(self.dbpool.runOperation(sql)) | |
170 dlist = defer.DeferredList(ds, fireOnOneErrback=True) | |
171 return dlist | |
172 | |
173 def _testPool_9(self, res): | |
174 # verify simple table is empty | |
175 sql = "select count(1) from simple" | |
176 d = self.dbpool.runQuery(sql) | |
177 def _check(row): | |
178 self.failUnless(int(row[0][0]) == 0, | |
179 "Didn't successfully delete table contents") | |
180 self.checkConnect() | |
181 d.addCallback(_check) | |
182 return d | |
183 | |
184 def checkConnect(self): | |
185 """Check the connect/disconnect synchronous calls.""" | |
186 conn = self.dbpool.connect() | |
187 self.checkOpenfunCalled(conn) | |
188 curs = conn.cursor() | |
189 curs.execute("insert into simple(x) values(1)") | |
190 curs.execute("select x from simple") | |
191 res = curs.fetchall() | |
192 self.failUnlessEqual(len(res), 1) | |
193 self.failUnlessEqual(len(res[0]), 1) | |
194 self.failUnlessEqual(res[0][0], 1) | |
195 curs.execute("delete from simple") | |
196 curs.execute("select x from simple") | |
197 self.failUnlessEqual(len(curs.fetchall()), 0) | |
198 curs.close() | |
199 self.dbpool.disconnect(conn) | |
200 | |
201 def interaction(self, transaction): | |
202 transaction.execute("select x from simple order by x") | |
203 for i in range(self.num_iterations): | |
204 row = transaction.fetchone() | |
205 self.failUnless(len(row) == 1, "Wrong size row") | |
206 self.failUnless(row[0] == i, "Value not returned.") | |
207 # should test this, but gadfly throws an exception instead | |
208 #self.failUnless(transaction.fetchone() is None, "Too many rows") | |
209 return "done" | |
210 | |
211 def bad_interaction(self, transaction): | |
212 if self.can_rollback: | |
213 transaction.execute("insert into simple(x) values(0)") | |
214 | |
215 transaction.execute("select * from NOTABLE") | |
216 | |
217 def withConnection(self, conn): | |
218 curs = conn.cursor() | |
219 try: | |
220 curs.execute("select x from simple order by x") | |
221 for i in range(self.num_iterations): | |
222 row = curs.fetchone() | |
223 self.failUnless(len(row) == 1, "Wrong size row") | |
224 self.failUnless(row[0] == i, "Value not returned.") | |
225 # should test this, but gadfly throws an exception instead | |
226 #self.failUnless(transaction.fetchone() is None, "Too many rows") | |
227 finally: | |
228 curs.close() | |
229 return "done" | |
230 | |
231 def close_withConnection(self, conn): | |
232 conn.close() | |
233 | |
234 def bad_withConnection(self, conn): | |
235 curs = conn.cursor() | |
236 try: | |
237 curs.execute("select * from NOTABLE") | |
238 finally: | |
239 curs.close() | |
240 | |
241 | |
242 class ReconnectTestBase: | |
243 """Test the asynchronous DB-API code with reconnect.""" | |
244 | |
245 if interfaces.IReactorThreads(reactor, None) is None: | |
246 skip = "ADB-API requires threads, no way to test without them" | |
247 | |
248 def setUp(self): | |
249 if self.good_sql is None: | |
250 raise unittest.SkipTest('no good sql for reconnect test') | |
251 self.startDB() | |
252 self.dbpool = self.makePool(cp_max=1, cp_reconnect=True, | |
253 cp_good_sql=self.good_sql) | |
254 self.dbpool.start() | |
255 return self.dbpool.runOperation(simple_table_schema) | |
256 | |
257 def tearDown(self): | |
258 d = self.dbpool.runOperation('DROP TABLE simple') | |
259 d.addCallback(lambda res: self.dbpool.close()) | |
260 d.addCallback(lambda res: self.stopDB()) | |
261 return d | |
262 | |
263 def testPool(self): | |
264 d = defer.succeed(None) | |
265 d.addCallback(self._testPool_1) | |
266 d.addCallback(self._testPool_2) | |
267 if not self.early_reconnect: | |
268 d.addCallback(self._testPool_3) | |
269 d.addCallback(self._testPool_4) | |
270 d.addCallback(self._testPool_5) | |
271 return d | |
272 | |
273 def _testPool_1(self, res): | |
274 sql = "select count(1) from simple" | |
275 d = self.dbpool.runQuery(sql) | |
276 def _check(row): | |
277 self.failUnless(int(row[0][0]) == 0, "Table not empty") | |
278 d.addCallback(_check) | |
279 return d | |
280 | |
281 def _testPool_2(self, res): | |
282 # reach in and close the connection manually | |
283 self.dbpool.connections.values()[0].close() | |
284 | |
285 def _testPool_3(self, res): | |
286 sql = "select count(1) from simple" | |
287 d = defer.maybeDeferred(self.dbpool.runQuery, sql) | |
288 d.addCallbacks(lambda res: self.fail('no exception'), | |
289 lambda f: f.trap(ConnectionLost)) | |
290 return d | |
291 | |
292 def _testPool_4(self, res): | |
293 sql = "select count(1) from simple" | |
294 d = self.dbpool.runQuery(sql) | |
295 def _check(row): | |
296 self.failUnless(int(row[0][0]) == 0, "Table not empty") | |
297 d.addCallback(_check) | |
298 return d | |
299 | |
300 def _testPool_5(self, res): | |
301 sql = "select * from NOTABLE" # bad sql | |
302 d = defer.maybeDeferred(self.dbpool.runQuery, sql) | |
303 d.addCallbacks(lambda res: self.fail('no exception'), | |
304 lambda f: self.failIf(f.check(ConnectionLost))) | |
305 return d | |
306 | |
307 | |
308 class DBTestConnector: | |
309 """A class which knows how to test for the presence of | |
310 and establish a connection to a relational database. | |
311 | |
312 To enable test cases which use a central, system database, | |
313 you must create a database named DB_NAME with a user DB_USER | |
314 and password DB_PASS with full access rights to database DB_NAME. | |
315 """ | |
316 | |
317 TEST_PREFIX = None # used for creating new test cases | |
318 | |
319 DB_NAME = "twisted_test" | |
320 DB_USER = 'twisted_test' | |
321 DB_PASS = 'twisted_test' | |
322 | |
323 DB_DIR = None # directory for database storage | |
324 | |
325 nulls_ok = True # nulls supported | |
326 trailing_spaces_ok = True # trailing spaces in strings preserved | |
327 can_rollback = True # rollback supported | |
328 test_failures = True # test bad sql? | |
329 escape_slashes = True # escape \ in sql? | |
330 good_sql = ConnectionPool.good_sql | |
331 early_reconnect = True # cursor() will fail on closed connection | |
332 can_clear = True # can try to clear out tables when starting | |
333 needs_dbdir = False # if a temporary directory is needed for the db | |
334 | |
335 num_iterations = 50 # number of iterations for test loops | |
336 # (lower this for slow db's) | |
337 | |
338 def setUpClass(self): | |
339 if self.needs_dbdir: | |
340 self.DB_DIR = self.mktemp() | |
341 os.mkdir(self.DB_DIR) | |
342 | |
343 if not self.can_connect(): | |
344 raise unittest.SkipTest('%s: Cannot access db' % self.TEST_PREFIX) | |
345 | |
346 def can_connect(self): | |
347 """Return true if this database is present on the system | |
348 and can be used in a test.""" | |
349 raise NotImplementedError() | |
350 | |
351 def startDB(self): | |
352 """Take any steps needed to bring database up.""" | |
353 pass | |
354 | |
355 def stopDB(self): | |
356 """Bring database down, if needed.""" | |
357 pass | |
358 | |
359 def makePool(self, **newkw): | |
360 """Create a connection pool with additional keyword arguments.""" | |
361 args, kw = self.getPoolArgs() | |
362 kw = kw.copy() | |
363 kw.update(newkw) | |
364 return ConnectionPool(*args, **kw) | |
365 | |
366 def getPoolArgs(self): | |
367 """Return a tuple (args, kw) of list and keyword arguments | |
368 that need to be passed to ConnectionPool to create a connection | |
369 to this database.""" | |
370 raise NotImplementedError() | |
371 | |
372 class GadflyConnector(DBTestConnector): | |
373 TEST_PREFIX = 'Gadfly' | |
374 | |
375 nulls_ok = False | |
376 can_rollback = False | |
377 escape_slashes = False | |
378 good_sql = 'select * from simple where 1=0' | |
379 needs_dbdir = True | |
380 | |
381 num_iterations = 1 # slow | |
382 | |
383 def can_connect(self): | |
384 try: import gadfly | |
385 except: return False | |
386 if not getattr(gadfly, 'connect', None): | |
387 gadfly.connect = gadfly.gadfly | |
388 return True | |
389 | |
390 def startDB(self): | |
391 import gadfly | |
392 conn = gadfly.gadfly() | |
393 conn.startup(self.DB_NAME, self.DB_DIR) | |
394 | |
395 # gadfly seems to want us to create something to get the db going | |
396 cursor = conn.cursor() | |
397 cursor.execute("create table x (x integer)") | |
398 conn.commit() | |
399 conn.close() | |
400 | |
401 def getPoolArgs(self): | |
402 args = ('gadfly', self.DB_NAME, self.DB_DIR) | |
403 kw = {'cp_max': 1} | |
404 return args, kw | |
405 | |
406 class SQLiteConnector(DBTestConnector): | |
407 TEST_PREFIX = 'SQLite' | |
408 | |
409 escape_slashes = False | |
410 needs_dbdir = True | |
411 | |
412 num_iterations = 1 # slow | |
413 | |
414 def can_connect(self): | |
415 try: import sqlite | |
416 except: return False | |
417 return True | |
418 | |
419 def startDB(self): | |
420 self.database = os.path.join(self.DB_DIR, self.DB_NAME) | |
421 if os.path.exists(self.database): | |
422 os.unlink(self.database) | |
423 | |
424 def getPoolArgs(self): | |
425 args = ('sqlite',) | |
426 kw = {'database': self.database, 'cp_max': 1} | |
427 return args, kw | |
428 | |
429 class PyPgSQLConnector(DBTestConnector): | |
430 TEST_PREFIX = "PyPgSQL" | |
431 | |
432 def can_connect(self): | |
433 try: from pyPgSQL import PgSQL | |
434 except: return False | |
435 try: | |
436 conn = PgSQL.connect(database=self.DB_NAME, user=self.DB_USER, | |
437 password=self.DB_PASS) | |
438 conn.close() | |
439 return True | |
440 except: | |
441 return False | |
442 | |
443 def getPoolArgs(self): | |
444 args = ('pyPgSQL.PgSQL',) | |
445 kw = {'database': self.DB_NAME, 'user': self.DB_USER, | |
446 'password': self.DB_PASS, 'cp_min': 0} | |
447 return args, kw | |
448 | |
449 class PsycopgConnector(DBTestConnector): | |
450 TEST_PREFIX = 'Psycopg' | |
451 | |
452 def can_connect(self): | |
453 try: import psycopg | |
454 except: return False | |
455 try: | |
456 conn = psycopg.connect(database=self.DB_NAME, user=self.DB_USER, | |
457 password=self.DB_PASS) | |
458 conn.close() | |
459 return True | |
460 except: | |
461 return False | |
462 | |
463 def getPoolArgs(self): | |
464 args = ('psycopg',) | |
465 kw = {'database': self.DB_NAME, 'user': self.DB_USER, | |
466 'password': self.DB_PASS, 'cp_min': 0} | |
467 return args, kw | |
468 | |
469 class MySQLConnector(DBTestConnector): | |
470 TEST_PREFIX = 'MySQL' | |
471 | |
472 trailing_spaces_ok = False | |
473 can_rollback = False | |
474 early_reconnect = False | |
475 | |
476 def can_connect(self): | |
477 try: import MySQLdb | |
478 except: return False | |
479 try: | |
480 conn = MySQLdb.connect(db=self.DB_NAME, user=self.DB_USER, | |
481 passwd=self.DB_PASS) | |
482 conn.close() | |
483 return True | |
484 except: | |
485 return False | |
486 | |
487 def getPoolArgs(self): | |
488 args = ('MySQLdb',) | |
489 kw = {'db': self.DB_NAME, 'user': self.DB_USER, 'passwd': self.DB_PASS} | |
490 return args, kw | |
491 | |
492 class FirebirdConnector(DBTestConnector): | |
493 TEST_PREFIX = 'Firebird' | |
494 | |
495 test_failures = False # failure testing causes problems | |
496 escape_slashes = False | |
497 good_sql = None # firebird doesn't handle failed sql well | |
498 can_clear = False # firebird is not so good | |
499 needs_dbdir = True | |
500 | |
501 num_iterations = 5 # slow | |
502 | |
503 def can_connect(self): | |
504 try: import kinterbasdb | |
505 except: return False | |
506 try: | |
507 self.startDB() | |
508 self.stopDB() | |
509 return True | |
510 except: | |
511 return False | |
512 | |
513 def startDB(self): | |
514 import kinterbasdb | |
515 self.DB_NAME = os.path.join(self.DB_DIR, DBTestConnector.DB_NAME) | |
516 os.chmod(self.DB_DIR, stat.S_IRWXU + stat.S_IRWXG + stat.S_IRWXO) | |
517 sql = 'create database "%s" user "%s" password "%s"' | |
518 sql %= (self.DB_NAME, self.DB_USER, self.DB_PASS); | |
519 conn = kinterbasdb.create_database(sql) | |
520 conn.close() | |
521 | |
522 def getPoolArgs(self): | |
523 args = ('kinterbasdb',) | |
524 kw = {'database': self.DB_NAME, 'host': '127.0.0.1', | |
525 'user': self.DB_USER, 'password': self.DB_PASS} | |
526 return args, kw | |
527 | |
528 def stopDB(self): | |
529 import kinterbasdb | |
530 conn = kinterbasdb.connect(database=self.DB_NAME, | |
531 host='127.0.0.1', user=self.DB_USER, | |
532 password=self.DB_PASS) | |
533 conn.drop_database() | |
534 | |
535 def makeSQLTests(base, suffix, globals): | |
536 """ | |
537 Make a test case for every db connector which can connect. | |
538 | |
539 @param base: Base class for test case. Additional base classes | |
540 will be a DBConnector subclass and unittest.TestCase | |
541 @param suffix: A suffix used to create test case names. Prefixes | |
542 are defined in the DBConnector subclasses. | |
543 """ | |
544 connectors = [GadflyConnector, SQLiteConnector, PyPgSQLConnector, | |
545 PsycopgConnector, MySQLConnector, FirebirdConnector] | |
546 for connclass in connectors: | |
547 name = connclass.TEST_PREFIX + suffix | |
548 import new | |
549 klass = new.classobj(name, (connclass, base, unittest.TestCase), base.__
dict__) | |
550 globals[name] = klass | |
551 | |
552 # GadflyADBAPITestCase SQLiteADBAPITestCase PyPgSQLADBAPITestCase | |
553 # PsycopgADBAPITestCase MySQLADBAPITestCase FirebirdADBAPITestCase | |
554 makeSQLTests(ADBAPITestBase, 'ADBAPITestCase', globals()) | |
555 | |
556 # GadflyReconnectTestCase SQLiteReconnectTestCase PyPgSQLReconnectTestCase | |
557 # PsycopgReconnectTestCase MySQLReconnectTestCase FirebirdReconnectTestCase | |
558 makeSQLTests(ReconnectTestBase, 'ReconnectTestCase', globals()) | |
559 | |
560 | |
561 | |
562 class DeprecationTestCase(unittest.TestCase): | |
563 """ | |
564 Test deprecations in twisted.enterprise.adbapi | |
565 """ | |
566 | |
567 def test_safe(self): | |
568 """ | |
569 Test deprecation of twisted.enterprise.adbapi.safe() | |
570 """ | |
571 result = self.callDeprecated(_unreleasedVersion, | |
572 safe, "test'") | |
573 | |
574 # make sure safe still behaves like the original | |
575 self.assertEqual(result, "test''") | |
OLD | NEW |