OLD | NEW |
(Empty) | |
| 1 # Copyright 2016 The Chromium Authors. All rights reserved. |
| 2 # Use of this source code is govered by a BSD-style |
| 3 # license that can be found in the LICENSE file or at |
| 4 # https://developers.google.com/open-source/licenses/bsd |
| 5 |
| 6 """Unit tests for the sql module.""" |
| 7 |
| 8 import unittest |
| 9 |
| 10 import settings |
| 11 from framework import sql |
| 12 |
| 13 |
| 14 class MockSQLCnxn(object): |
| 15 """This class mocks the connection and cursor classes.""" |
| 16 |
| 17 def __init__(self, instance, database): |
| 18 self.instance = instance |
| 19 self.database = database |
| 20 self.last_executed = None |
| 21 self.last_executed_args = None |
| 22 self.result_rows = None |
| 23 self.rowcount = 0 |
| 24 self.lastrowid = None |
| 25 |
| 26 def execute(self, stmt_str, args=None): |
| 27 self.last_executed = stmt_str % tuple(args or []) |
| 28 |
| 29 def executemany(self, stmt_str, args): |
| 30 # We cannot format the string because args has many values for each %s. |
| 31 self.last_executed = stmt_str |
| 32 self.last_executed_args = tuple(args) |
| 33 |
| 34 if stmt_str.startswith('INSERT'): |
| 35 self.lastrowid = 123 |
| 36 |
| 37 def fetchall(self): |
| 38 return self.result_rows |
| 39 |
| 40 def cursor(self): |
| 41 return self |
| 42 |
| 43 def close(self): |
| 44 pass |
| 45 |
| 46 def commit(self): |
| 47 pass |
| 48 |
| 49 |
| 50 sql.MakeConnection = MockSQLCnxn |
| 51 |
| 52 |
| 53 class MonorailConnectionTest(unittest.TestCase): |
| 54 |
| 55 def setUp(self): |
| 56 self.cnxn = sql.MonorailConnection() |
| 57 self.orig_dev_mode = settings.dev_mode |
| 58 self.orig_num_logical_shards = settings.num_logical_shards |
| 59 settings.dev_mode = False |
| 60 |
| 61 def tearDown(self): |
| 62 settings.dev_mode = self.orig_dev_mode |
| 63 settings.num_logical_shards = self.orig_num_logical_shards |
| 64 |
| 65 def testGetMasterConnection(self): |
| 66 sql_cnxn = self.cnxn.GetMasterConnection() |
| 67 self.assertEqual(settings.db_instance, sql_cnxn.instance) |
| 68 self.assertEqual(settings.db_database_name, sql_cnxn.database) |
| 69 |
| 70 sql_cnxn2 = self.cnxn.GetMasterConnection() |
| 71 self.assertIs(sql_cnxn2, sql_cnxn) |
| 72 |
| 73 def testGetConnectionForShard(self): |
| 74 sql_cnxn = self.cnxn.GetConnectionForShard(1) |
| 75 self.assertEqual(settings.physical_db_name_format % 1, |
| 76 sql_cnxn.instance) |
| 77 self.assertEqual(settings.db_database_name, sql_cnxn.database) |
| 78 |
| 79 sql_cnxn2 = self.cnxn.GetConnectionForShard(1) |
| 80 self.assertIs(sql_cnxn2, sql_cnxn) |
| 81 |
| 82 |
| 83 class TableManagerTest(unittest.TestCase): |
| 84 |
| 85 def setUp(self): |
| 86 self.emp_tbl = sql.SQLTableManager('Employee') |
| 87 self.cnxn = sql.MonorailConnection() |
| 88 self.master_cnxn = self.cnxn.GetMasterConnection() |
| 89 |
| 90 def testSelect_Trivial(self): |
| 91 self.master_cnxn.result_rows = [(111, True), (222, False)] |
| 92 rows = self.emp_tbl.Select(self.cnxn) |
| 93 self.assertEqual('SELECT * FROM Employee', self.master_cnxn.last_executed) |
| 94 self.assertEqual([(111, True), (222, False)], rows) |
| 95 |
| 96 def testSelect_Conditions(self): |
| 97 self.master_cnxn.result_rows = [(111,)] |
| 98 rows = self.emp_tbl.Select( |
| 99 self.cnxn, cols=['emp_id'], fulltime=True, dept_id=[10, 20]) |
| 100 self.assertEqual( |
| 101 'SELECT emp_id FROM Employee' |
| 102 '\nWHERE dept_id IN (10,20)' |
| 103 '\n AND fulltime = 1', |
| 104 self.master_cnxn.last_executed) |
| 105 self.assertEqual([(111,)], rows) |
| 106 |
| 107 def testSelectRow(self): |
| 108 self.master_cnxn.result_rows = [(111,)] |
| 109 row = self.emp_tbl.SelectRow( |
| 110 self.cnxn, cols=['emp_id'], fulltime=True, dept_id=[10, 20]) |
| 111 self.assertEqual( |
| 112 'SELECT DISTINCT emp_id FROM Employee' |
| 113 '\nWHERE dept_id IN (10,20)' |
| 114 '\n AND fulltime = 1', |
| 115 self.master_cnxn.last_executed) |
| 116 self.assertEqual((111,), row) |
| 117 |
| 118 def testSelectRow_NoMatches(self): |
| 119 self.master_cnxn.result_rows = [] |
| 120 row = self.emp_tbl.SelectRow( |
| 121 self.cnxn, cols=['emp_id'], fulltime=True, dept_id=[99]) |
| 122 self.assertEqual( |
| 123 'SELECT DISTINCT emp_id FROM Employee' |
| 124 '\nWHERE dept_id IN (99)' |
| 125 '\n AND fulltime = 1', |
| 126 self.master_cnxn.last_executed) |
| 127 self.assertEqual(None, row) |
| 128 |
| 129 row = self.emp_tbl.SelectRow( |
| 130 self.cnxn, cols=['emp_id'], fulltime=True, dept_id=[99], |
| 131 default=(-1,)) |
| 132 self.assertEqual((-1,), row) |
| 133 |
| 134 def testSelectValue(self): |
| 135 self.master_cnxn.result_rows = [(111,)] |
| 136 val = self.emp_tbl.SelectValue( |
| 137 self.cnxn, 'emp_id', fulltime=True, dept_id=[10, 20]) |
| 138 self.assertEqual( |
| 139 'SELECT DISTINCT emp_id FROM Employee' |
| 140 '\nWHERE dept_id IN (10,20)' |
| 141 '\n AND fulltime = 1', |
| 142 self.master_cnxn.last_executed) |
| 143 self.assertEqual(111, val) |
| 144 |
| 145 def testSelectValue_NoMatches(self): |
| 146 self.master_cnxn.result_rows = [] |
| 147 val = self.emp_tbl.SelectValue( |
| 148 self.cnxn, 'emp_id', fulltime=True, dept_id=[99]) |
| 149 self.assertEqual( |
| 150 'SELECT DISTINCT emp_id FROM Employee' |
| 151 '\nWHERE dept_id IN (99)' |
| 152 '\n AND fulltime = 1', |
| 153 self.master_cnxn.last_executed) |
| 154 self.assertEqual(None, val) |
| 155 |
| 156 val = self.emp_tbl.SelectValue( |
| 157 self.cnxn, 'emp_id', fulltime=True, dept_id=[99], |
| 158 default=-1) |
| 159 self.assertEqual(-1, val) |
| 160 |
| 161 def testInsertRow(self): |
| 162 self.master_cnxn.rowcount = 1 |
| 163 generated_id = self.emp_tbl.InsertRow(self.cnxn, emp_id=111, fulltime=True) |
| 164 self.assertEqual( |
| 165 'INSERT INTO Employee (emp_id, fulltime)' |
| 166 '\nVALUES (%s,%s)', |
| 167 self.master_cnxn.last_executed) |
| 168 self.assertEqual( |
| 169 ([111, 1],), |
| 170 self.master_cnxn.last_executed_args) |
| 171 self.assertEqual(123, generated_id) |
| 172 |
| 173 def testInsertRows_Empty(self): |
| 174 generated_id = self.emp_tbl.InsertRows( |
| 175 self.cnxn, ['emp_id', 'fulltime'], []) |
| 176 self.assertIsNone(self.master_cnxn.last_executed) |
| 177 self.assertIsNone(self.master_cnxn.last_executed_args) |
| 178 self.assertEqual(None, generated_id) |
| 179 |
| 180 def testInsertRows(self): |
| 181 self.master_cnxn.rowcount = 2 |
| 182 generated_ids = self.emp_tbl.InsertRows( |
| 183 self.cnxn, ['emp_id', 'fulltime'], [(111, True), (222, False)]) |
| 184 self.assertEqual( |
| 185 'INSERT INTO Employee (emp_id, fulltime)' |
| 186 '\nVALUES (%s,%s)', |
| 187 self.master_cnxn.last_executed) |
| 188 self.assertEqual( |
| 189 ([111, 1], [222, 0]), |
| 190 self.master_cnxn.last_executed_args) |
| 191 self.assertEqual([], generated_ids) |
| 192 |
| 193 def testUpdate(self): |
| 194 self.master_cnxn.rowcount = 2 |
| 195 rowcount = self.emp_tbl.Update( |
| 196 self.cnxn, {'fulltime': True}, emp_id=[111, 222]) |
| 197 self.assertEqual( |
| 198 'UPDATE Employee SET fulltime=1' |
| 199 '\nWHERE emp_id IN (111,222)', |
| 200 self.master_cnxn.last_executed) |
| 201 self.assertEqual(2, rowcount) |
| 202 |
| 203 def testIncrementCounterValue(self): |
| 204 self.master_cnxn.rowcount = 1 |
| 205 self.master_cnxn.lastrowid = 9 |
| 206 new_counter_val = self.emp_tbl.IncrementCounterValue( |
| 207 self.cnxn, 'years_worked', emp_id=111) |
| 208 self.assertEqual( |
| 209 'UPDATE Employee SET years_worked = LAST_INSERT_ID(years_worked + 1)' |
| 210 '\nWHERE emp_id = 111', |
| 211 self.master_cnxn.last_executed) |
| 212 self.assertEqual(9, new_counter_val) |
| 213 |
| 214 def testDelete(self): |
| 215 self.master_cnxn.rowcount = 1 |
| 216 rowcount = self.emp_tbl.Delete(self.cnxn, fulltime=True) |
| 217 self.assertEqual( |
| 218 'DELETE FROM Employee' |
| 219 '\nWHERE fulltime = 1', |
| 220 self.master_cnxn.last_executed) |
| 221 self.assertEqual(1, rowcount) |
| 222 |
| 223 |
| 224 class StatementTest(unittest.TestCase): |
| 225 |
| 226 def testMakeSelect(self): |
| 227 stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime']) |
| 228 stmt_str, args = stmt.Generate() |
| 229 self.assertEqual( |
| 230 'SELECT emp_id, fulltime FROM Employee', |
| 231 stmt_str) |
| 232 self.assertEqual([], args) |
| 233 |
| 234 stmt = sql.Statement.MakeSelect( |
| 235 'Employee', ['emp_id', 'fulltime'], distinct=True) |
| 236 stmt_str, args = stmt.Generate() |
| 237 self.assertEqual( |
| 238 'SELECT DISTINCT emp_id, fulltime FROM Employee', |
| 239 stmt_str) |
| 240 self.assertEqual([], args) |
| 241 |
| 242 def testMakeInsert(self): |
| 243 stmt = sql.Statement.MakeInsert( |
| 244 'Employee', ['emp_id', 'fulltime'], [(111, True), (222, False)]) |
| 245 stmt_str, args = stmt.Generate() |
| 246 self.assertEqual( |
| 247 'INSERT INTO Employee (emp_id, fulltime)' |
| 248 '\nVALUES (%s,%s)', |
| 249 stmt_str) |
| 250 self.assertEqual([[111, 1], [222, 0]], args) |
| 251 |
| 252 stmt = sql.Statement.MakeInsert( |
| 253 'Employee', ['emp_id', 'fulltime'], [(111, False)], replace=True) |
| 254 stmt_str, args = stmt.Generate() |
| 255 self.assertEqual( |
| 256 'INSERT INTO Employee (emp_id, fulltime)' |
| 257 '\nVALUES (%s,%s)' |
| 258 '\nON DUPLICATE KEY UPDATE ' |
| 259 'emp_id=VALUES(emp_id), fulltime=VALUES(fulltime)', |
| 260 stmt_str) |
| 261 self.assertEqual([[111, 0]], args) |
| 262 |
| 263 stmt = sql.Statement.MakeInsert( |
| 264 'Employee', ['emp_id', 'fulltime'], [(111, False)], ignore=True) |
| 265 stmt_str, args = stmt.Generate() |
| 266 self.assertEqual( |
| 267 'INSERT IGNORE INTO Employee (emp_id, fulltime)' |
| 268 '\nVALUES (%s,%s)', |
| 269 stmt_str) |
| 270 self.assertEqual([[111, 0]], args) |
| 271 |
| 272 def testMakeUpdate(self): |
| 273 stmt = sql.Statement.MakeUpdate('Employee', {'fulltime': True}) |
| 274 stmt_str, args = stmt.Generate() |
| 275 self.assertEqual( |
| 276 'UPDATE Employee SET fulltime=%s', |
| 277 stmt_str) |
| 278 self.assertEqual([1], args) |
| 279 |
| 280 def testMakeIncrement(self): |
| 281 stmt = sql.Statement.MakeIncrement('Employee', 'years_worked') |
| 282 stmt_str, args = stmt.Generate() |
| 283 self.assertEqual( |
| 284 'UPDATE Employee SET years_worked = LAST_INSERT_ID(years_worked + %s)', |
| 285 stmt_str) |
| 286 self.assertEqual([1], args) |
| 287 |
| 288 stmt = sql.Statement.MakeIncrement('Employee', 'years_worked', step=5) |
| 289 stmt_str, args = stmt.Generate() |
| 290 self.assertEqual( |
| 291 'UPDATE Employee SET years_worked = LAST_INSERT_ID(years_worked + %s)', |
| 292 stmt_str) |
| 293 self.assertEqual([5], args) |
| 294 |
| 295 def testMakeDelete(self): |
| 296 stmt = sql.Statement.MakeDelete('Employee') |
| 297 stmt_str, args = stmt.Generate() |
| 298 self.assertEqual( |
| 299 'DELETE FROM Employee', |
| 300 stmt_str) |
| 301 self.assertEqual([], args) |
| 302 |
| 303 def testAddUseClause(self): |
| 304 stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime']) |
| 305 stmt.AddUseClause('USE INDEX (emp_id) USE INDEX FOR ORDER BY (emp_id)') |
| 306 stmt.AddOrderByTerms([('emp_id', [])]) |
| 307 stmt_str, args = stmt.Generate() |
| 308 self.assertEqual( |
| 309 'SELECT emp_id, fulltime FROM Employee' |
| 310 '\nUSE INDEX (emp_id) USE INDEX FOR ORDER BY (emp_id)' |
| 311 '\nORDER BY emp_id', |
| 312 stmt_str) |
| 313 self.assertEqual([], args) |
| 314 |
| 315 def testAddJoinClause_Empty(self): |
| 316 stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime']) |
| 317 stmt.AddJoinClauses([]) |
| 318 stmt_str, args = stmt.Generate() |
| 319 self.assertEqual( |
| 320 'SELECT emp_id, fulltime FROM Employee', |
| 321 stmt_str) |
| 322 self.assertEqual([], args) |
| 323 |
| 324 def testAddJoinClause(self): |
| 325 stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime']) |
| 326 stmt.AddJoinClauses([('CorporateHoliday', [])]) |
| 327 stmt.AddJoinClauses( |
| 328 [('Product ON Project.inventor_id = emp_id', [])], left=True) |
| 329 stmt_str, args = stmt.Generate() |
| 330 self.assertEqual( |
| 331 'SELECT emp_id, fulltime FROM Employee' |
| 332 '\n JOIN CorporateHoliday' |
| 333 '\n LEFT JOIN Product ON Project.inventor_id = emp_id', |
| 334 stmt_str) |
| 335 self.assertEqual([], args) |
| 336 |
| 337 def testAddGroupByTerms_Empty(self): |
| 338 stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime']) |
| 339 stmt.AddGroupByTerms([]) |
| 340 stmt_str, args = stmt.Generate() |
| 341 self.assertEqual( |
| 342 'SELECT emp_id, fulltime FROM Employee', |
| 343 stmt_str) |
| 344 self.assertEqual([], args) |
| 345 |
| 346 def testAddGroupByTerms(self): |
| 347 stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime']) |
| 348 stmt.AddGroupByTerms(['dept_id', 'location_id']) |
| 349 stmt_str, args = stmt.Generate() |
| 350 self.assertEqual( |
| 351 'SELECT emp_id, fulltime FROM Employee' |
| 352 '\nGROUP BY dept_id, location_id', |
| 353 stmt_str) |
| 354 self.assertEqual([], args) |
| 355 |
| 356 def testAddOrderByTerms_Empty(self): |
| 357 stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime']) |
| 358 stmt.AddOrderByTerms([]) |
| 359 stmt_str, args = stmt.Generate() |
| 360 self.assertEqual( |
| 361 'SELECT emp_id, fulltime FROM Employee', |
| 362 stmt_str) |
| 363 self.assertEqual([], args) |
| 364 |
| 365 def testAddOrderByTerms(self): |
| 366 stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime']) |
| 367 stmt.AddOrderByTerms([('dept_id', []), ('emp_id DESC', [])]) |
| 368 stmt_str, args = stmt.Generate() |
| 369 self.assertEqual( |
| 370 'SELECT emp_id, fulltime FROM Employee' |
| 371 '\nORDER BY dept_id, emp_id DESC', |
| 372 stmt_str) |
| 373 self.assertEqual([], args) |
| 374 |
| 375 def testSetLimitAndOffset(self): |
| 376 stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime']) |
| 377 stmt.SetLimitAndOffset(100, 0) |
| 378 stmt_str, args = stmt.Generate() |
| 379 self.assertEqual( |
| 380 'SELECT emp_id, fulltime FROM Employee' |
| 381 '\nLIMIT 100', |
| 382 stmt_str) |
| 383 self.assertEqual([], args) |
| 384 |
| 385 stmt.SetLimitAndOffset(100, 500) |
| 386 stmt_str, args = stmt.Generate() |
| 387 self.assertEqual( |
| 388 'SELECT emp_id, fulltime FROM Employee' |
| 389 '\nLIMIT 100 OFFSET 500', |
| 390 stmt_str) |
| 391 self.assertEqual([], args) |
| 392 |
| 393 def testAddWhereTerms_Select(self): |
| 394 stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime']) |
| 395 stmt.AddWhereTerms([], emp_id=[111, 222]) |
| 396 stmt_str, args = stmt.Generate() |
| 397 self.assertEqual( |
| 398 'SELECT emp_id, fulltime FROM Employee' |
| 399 '\nWHERE emp_id IN (%s,%s)', |
| 400 stmt_str) |
| 401 self.assertEqual([111, 222], args) |
| 402 |
| 403 def testAddWhereTerms_Update(self): |
| 404 stmt = sql.Statement.MakeUpdate('Employee', {'fulltime': True}) |
| 405 stmt.AddWhereTerms([], emp_id=[111, 222]) |
| 406 stmt_str, args = stmt.Generate() |
| 407 self.assertEqual( |
| 408 'UPDATE Employee SET fulltime=%s' |
| 409 '\nWHERE emp_id IN (%s,%s)', |
| 410 stmt_str) |
| 411 self.assertEqual([1, 111, 222], args) |
| 412 |
| 413 def testAddWhereTerms_Delete(self): |
| 414 stmt = sql.Statement.MakeDelete('Employee') |
| 415 stmt.AddWhereTerms([], emp_id=[111, 222]) |
| 416 stmt_str, args = stmt.Generate() |
| 417 self.assertEqual( |
| 418 'DELETE FROM Employee' |
| 419 '\nWHERE emp_id IN (%s,%s)', |
| 420 stmt_str) |
| 421 self.assertEqual([111, 222], args) |
| 422 |
| 423 def testAddWhereTerms_Empty(self): |
| 424 """Add empty terms should have no effect.""" |
| 425 stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime']) |
| 426 stmt.AddWhereTerms([]) |
| 427 stmt_str, args = stmt.Generate() |
| 428 self.assertEqual( |
| 429 'SELECT emp_id, fulltime FROM Employee', |
| 430 stmt_str) |
| 431 self.assertEqual([], args) |
| 432 |
| 433 def testAddWhereTerms_MulitpleTerms(self): |
| 434 stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime']) |
| 435 stmt.AddWhereTerms( |
| 436 [('emp_id %% %s = %s', [2, 0])], fulltime=True, emp_id_not=222) |
| 437 stmt_str, args = stmt.Generate() |
| 438 self.assertEqual( |
| 439 'SELECT emp_id, fulltime FROM Employee' |
| 440 '\nWHERE emp_id %% %s = %s' |
| 441 '\n AND emp_id != %s' |
| 442 '\n AND fulltime = %s', |
| 443 stmt_str) |
| 444 self.assertEqual([2, 0, 222, 1], args) |
| 445 |
| 446 |
| 447 |
| 448 class FunctionsTest(unittest.TestCase): |
| 449 |
| 450 def testBoolsToInts_NoChanges(self): |
| 451 self.assertEqual(['hello'], sql._BoolsToInts(['hello'])) |
| 452 self.assertEqual([['hello']], sql._BoolsToInts([['hello']])) |
| 453 self.assertEqual([['hello']], sql._BoolsToInts([('hello',)])) |
| 454 self.assertEqual([12], sql._BoolsToInts([12])) |
| 455 self.assertEqual([[12]], sql._BoolsToInts([[12]])) |
| 456 self.assertEqual([[12]], sql._BoolsToInts([(12,)])) |
| 457 self.assertEqual( |
| 458 [12, 13, 'hi', [99, 'yo']], |
| 459 sql._BoolsToInts([12, 13, 'hi', [99, 'yo']])) |
| 460 |
| 461 def testBoolsToInts_WithChanges(self): |
| 462 self.assertEqual([1, 0], sql._BoolsToInts([True, False])) |
| 463 self.assertEqual([[1, 0]], sql._BoolsToInts([[True, False]])) |
| 464 self.assertEqual([[1, 0]], sql._BoolsToInts([(True, False)])) |
| 465 self.assertEqual( |
| 466 [12, 1, 'hi', [0, 'yo']], |
| 467 sql._BoolsToInts([12, True, 'hi', [False, 'yo']])) |
| 468 |
| 469 |
| 470 if __name__ == '__main__': |
| 471 unittest.main() |
OLD | NEW |