Chromium Code Reviews
chromiumcodereview-hr@appspot.gserviceaccount.com (chromiumcodereview-hr) | Please choose your nickname with Settings | Help | Chromium Project | Gerrit Changes | Sign out
(419)

Side by Side Diff: appengine/monorail/framework/sql.py

Issue 1868553004: Open Source Monorail (Closed) Base URL: https://chromium.googlesource.com/infra/infra.git@master
Patch Set: Rebase Created 4 years, 8 months ago
Use n/p to move between diff chunks; N/P to move between comments. Draft comments are only viewable by you.
Jump to:
View unified diff | Download patch
OLDNEW
(Empty)
1 # Copyright 2016 The Chromium Authors. All rights reserved.
2 # Use of this source code is governed by a BSD-style
3 # license that can be found in the LICENSE file or at
4 # https://developers.google.com/open-source/licenses/bsd
5
6 """A set of classes for interacting with tables in SQL."""
7
8 import logging
9 import random
10 import re
11 import sys
12 import time
13
14 import settings
15
16 if not settings.unit_test_mode:
17 import MySQLdb
18
19 from framework import framework_helpers
20
21
# MonorailConnection maintains a dictionary of connections to SQL databases.
# Each replica connection is identified by an int shard ID, and the single
# connection to the master DB is identified by the key MASTER_CNXN.
MASTER_CNXN = 'master_cnxn'
26
27
@framework_helpers.retry(2, delay=1, backoff=2)
def MakeConnection(instance, database):
  """Open and return a new MySQLdb connection to the given database.

  Args:
    instance: string name of the SQL instance.  It is appended to
        '/cloudsql/' to form the unix socket path when not in dev_mode.
    database: string name of the database to select.

  Returns:
    The connection object returned by MySQLdb.connect().

  Raises:
    ValueError: if called while settings.unit_test_mode is set.
  """
  logging.info('About to connect to SQL instance %r db %r', instance, database)
  if settings.unit_test_mode:
    raise ValueError('unit tests should not need real database connections')

  # Options shared by the dev_mode and the non-dev_mode connection.
  connect_kwargs = {'db': database, 'user': 'root', 'charset': 'utf8'}
  if settings.dev_mode:
    # In dev_mode, talk to a MySQL server on the loopback address.
    connect_kwargs.update(host='127.0.0.1', port=3306)
  else:
    connect_kwargs['unix_socket'] = '/cloudsql/' + instance
  return MySQLdb.connect(**connect_kwargs)
41
42
class MonorailConnection(object):
  """Create and manage connections to the SQL servers.

  We only store connections in the context of a single user request, not
  across user requests.  The main purpose of this class is to make using
  sharded tables easier.
  """

  def __init__(self):
    self.sql_cnxns = {}   # {MASTER_CNXN: cnxn, shard_id: cnxn, ...}

  def GetMasterConnection(self):
    """Return a connection to the master SQL DB, creating it on first use."""
    if MASTER_CNXN not in self.sql_cnxns:
      self.sql_cnxns[MASTER_CNXN] = MakeConnection(
          settings.db_instance, settings.db_database_name)
      logging.info(
          'created a master connection %r', self.sql_cnxns[MASTER_CNXN])

    return self.sql_cnxns[MASTER_CNXN]

  def GetConnectionForShard(self, shard_id):
    """Return a connection to the DB replica that will be used for shard_id.

    In dev_mode there are no replicas, so the master connection is returned.
    Otherwise the connection is created on first use and then cached under
    shard_id for later calls.
    """
    if settings.dev_mode:
      return self.GetMasterConnection()

    if shard_id not in self.sql_cnxns:
      physical_shard_id = shard_id % settings.num_logical_shards
      shard_instance_name = (
          settings.physical_db_name_format % physical_shard_id)
      self.sql_cnxns[shard_id] = MakeConnection(
          shard_instance_name, settings.db_database_name)
      logging.info('created a replica connection for shard %d', shard_id)

    return self.sql_cnxns[shard_id]

  def Execute(self, stmt_str, stmt_args, shard_id=None, commit=True):
    """Execute the given SQL statement on one of the relevant databases.

    Args:
      stmt_str: SQL statement string with %s placeholders.
      stmt_args: list of values for the placeholders.
      shard_id: optional int shard ID; when None, the master DB is used.
      commit: set to False to skip the commit after a non-SELECT statement.

    Returns:
      The cursor on which the statement was executed.
    """
    if shard_id is None:
      # No shard was specified, so hit the master.
      sql_cnxn = self.GetMasterConnection()
    else:
      sql_cnxn = self.GetConnectionForShard(shard_id)

    return self._ExecuteWithSQLConnection(
        sql_cnxn, stmt_str, stmt_args, commit=commit)

  def _ExecuteWithSQLConnection(
      self, sql_cnxn, stmt_str, stmt_args, commit=True):
    """Execute a statement on the given database and return a cursor.

    INSERT and REPLACE statements are run with executemany(), so for them
    stmt_args is a list of rows.  Every other statement is run with
    execute() and a flat list of args.
    """
    cursor = sql_cnxn.cursor()
    start_time = time.time()
    if stmt_str.startswith(('INSERT', 'REPLACE')):
      logging.info('SQL stmt_str: \n%s', stmt_str)
      logging.info('SQL stmt_args: %r', stmt_args)
      cursor.executemany(stmt_str, stmt_args)
    else:
      logging.info('SQL stmt: \n%s', (stmt_str % tuple(stmt_args)))
      cursor.execute(stmt_str, args=stmt_args)
    logging.info('%d rows in %d ms', cursor.rowcount,
                 int((time.time() - start_time) * 1000))
    if commit and not stmt_str.startswith('SELECT'):
      start_time = time.time()
      try:
        sql_cnxn.commit()
      except MySQLdb.DatabaseError:
        # Still best-effort (the error is not re-raised), but leave a trace
        # in the logs, just as Commit() does, instead of failing silently.
        logging.exception('Commit failed for cnxn, rolling back')
        sql_cnxn.rollback()
      logging.info('commit took %d ms',
                   int((time.time() - start_time) * 1000))

    return cursor

  def Commit(self):
    """Explicitly commit any pending txns.  Normally done automatically.

    Only the master connection is committed.  If the commit fails, the
    failure is logged and the transaction is rolled back.
    """
    sql_cnxn = self.GetMasterConnection()
    start_time = time.time()
    try:
      sql_cnxn.commit()
    except MySQLdb.DatabaseError:
      logging.exception('Commit failed for cnxn, rolling back')
      sql_cnxn.rollback()
    logging.info('final commit took %d ms',
                 int((time.time() - start_time) * 1000))

  def Close(self):
    """Safely close any connections that are still open."""
    # .values() rather than .itervalues() so this also runs on Python 3.
    for sql_cnxn in self.sql_cnxns.values():
      try:
        sql_cnxn.close()
      except MySQLdb.DatabaseError:
        # This might happen if the cnxn is somehow already closed.
        logging.exception('ProgrammingError when trying to close cnxn')
135
136
class SQLTableManager(object):
  """Helper class to make it easier to deal with an SQL table."""

  def __init__(self, table_name):
    self.table_name = table_name

  def Select(
      self, cnxn, distinct=False, cols=None, left_joins=None,
      joins=None, where=None, or_where_conds=False, group_by=None,
      order_by=None, limit=None, offset=None, shard_id=None, use_clause=None,
      **kwargs):
    """Compose and execute an SQL SELECT statement on this table.

    Args:
      cnxn: MonorailConnection to the databases.
      distinct: If True, add DISTINCT keyword.
      cols: List of columns to retrieve, defaults to '*'.
      left_joins: List of LEFT JOIN (str, args) pairs.
      joins: List of regular JOIN (str, args) pairs.
      where: List of (str, args) for WHERE clause.
      or_where_conds: Set to True to use OR in the WHERE conds.
      group_by: List of strings for GROUP BY clause.
      order_by: List of (str, args) for ORDER BY clause.
      limit: Optional LIMIT on the number of rows returned.
      offset: Optional OFFSET when using LIMIT.
      shard_id: Int ID of the shard to query.
      use_clause: Optional string USE clause to tell the DB which index to use.
      **kwargs: WHERE-clause equality and set-membership conditions.

    Keyword args are used to build up more WHERE conditions that compare
    column values to constants.  Keyword argument foo='bar' translates to
    'foo = "bar"', and foo=[3, 4, 5] translates to 'foo IN (3, 4, 5)'.

    Returns:
      A list of rows, each row is a tuple of values for the requested cols.
    """
    cols = cols or ['*']  # If columns not specified, retrieve all columns.
    stmt = Statement.MakeSelect(
        self.table_name, cols, distinct=distinct,
        or_where_conds=or_where_conds)
    if use_clause:
      stmt.AddUseClause(use_clause)
    stmt.AddJoinClauses(left_joins or [], left=True)
    stmt.AddJoinClauses(joins or [])
    stmt.AddWhereTerms(where or [], **kwargs)
    stmt.AddGroupByTerms(group_by or [])
    stmt.AddOrderByTerms(order_by or [])
    stmt.SetLimitAndOffset(limit, offset)
    stmt_str, stmt_args = stmt.Generate()

    cursor = cnxn.Execute(stmt_str, stmt_args, shard_id=shard_id)
    rows = cursor.fetchall()
    return rows

  def SelectRow(
      self, cnxn, cols=None, default=None, where=None, **kwargs):
    """Run a query that is expected to return just one row.

    Args:
      cnxn: MonorailConnection to the databases.
      cols: List of columns to retrieve, defaults to '*'.
      default: Value to return when no row matches.
      where: List of (str, args) for WHERE clause.
      **kwargs: WHERE-clause equality and set-membership conditions.

    Returns:
      The single matching row, or default if there was no matching row.

    Raises:
      ValueError: if more than one row matched.
    """
    rows = self.Select(cnxn, distinct=True, cols=cols, where=where, **kwargs)
    if len(rows) == 1:
      return rows[0]
    elif not rows:
      logging.info('SelectRow got 0 results, so using default %r', default)
      return default
    else:
      # Interpolate here: exceptions do not apply %-formatting to extra
      # args the way the logging functions do.
      raise ValueError(
          'SelectRow got %d results, expected only 1' % len(rows))

  def SelectValue(self, cnxn, col, default=None, where=None, **kwargs):
    """Run a query that is expected to return just one row w/ one value."""
    # The default is wrapped in a list so that row[0] works in either case.
    row = self.SelectRow(
        cnxn, cols=[col], default=[default], where=where, **kwargs)
    return row[0]

  def InsertRows(
      self, cnxn, cols, row_values, replace=False, ignore=False,
      commit=True, return_generated_ids=False):
    """Insert all the given rows.

    Args:
      cnxn: MonorailConnection object.
      cols: List of column names to set.
      row_values: List of lists with values to store.  The length of each
          nested list should be equal to len(cols).
      replace: Set to True if inserted values should replace existing DB rows
          that have the same DB keys.
      ignore: Set to True to ignore rows that would duplicate existing DB keys.
      commit: Set to False if this operation is part of a series of operations
          that should not be committed until the final one is done.
      return_generated_ids: Set to True to return a list of generated
          autoincrement IDs for inserted rows.  This requires us to insert rows
          one at a time.

    Returns:
      None if row_values is empty, because nothing is executed.  Otherwise,
      if return_generated_ids is set to True, this method returns a list of
      the auto-increment IDs generated by the DB (rows for which the cursor
      reports no lastrowid are omitted).  Otherwise, [] is returned.
    """
    if not row_values:
      return None  # Nothing to insert

    generated_ids = []
    if return_generated_ids:
      # We must insert the rows one-at-a-time to know the generated IDs.
      for row_value in row_values:
        stmt = Statement.MakeInsert(
            self.table_name, cols, [row_value], replace=replace, ignore=ignore)
        stmt_str, stmt_args = stmt.Generate()
        cursor = cnxn.Execute(stmt_str, stmt_args, commit=commit)
        if cursor.lastrowid:
          generated_ids.append(cursor.lastrowid)
      return generated_ids

    stmt = Statement.MakeInsert(
        self.table_name, cols, row_values, replace=replace, ignore=ignore)
    stmt_str, stmt_args = stmt.Generate()
    cnxn.Execute(stmt_str, stmt_args, commit=commit)
    return []

  def InsertRow(
      self, cnxn, replace=False, ignore=False, commit=True, **kwargs):
    """Insert a single row into the table.

    Args:
      cnxn: MonorailConnection object.
      replace: Set to True if inserted values should replace existing DB rows
          that have the same DB keys.
      ignore: Set to True to ignore rows that would duplicate existing DB keys.
      commit: Set to False if this operation is part of a series of operations
          that should not be committed until the final one is done.
      **kwargs: column=value assignments to specify what to store in the DB.

    Returns:
      The generated autoincrement ID of the key column if one was generated.
      Otherwise, return None.
    """
    # Sort the column names so that the generated statement is deterministic.
    cols = sorted(kwargs.keys())
    row = tuple(kwargs[col] for col in cols)
    generated_ids = self.InsertRows(
        cnxn, cols, [row], replace=replace, ignore=ignore,
        commit=commit, return_generated_ids=True)
    if generated_ids:
      return generated_ids[0]
    else:
      return None

  def Update(self, cnxn, delta, where=None, commit=True, **kwargs):
    """Update one or more rows.

    Args:
      cnxn: MonorailConnection object.
      delta: Dictionary of {column: new_value} assignments.
      where: Optional list of WHERE conditions saying which rows to update.
      commit: Set to False if this operation is part of a series of operations
          that should not be committed until the final one is done.
      **kwargs: WHERE-clause equality and set-membership conditions.

    Returns:
      Int number of rows updated.
    """
    if not delta:
      return 0   # Nothing is being changed

    stmt = Statement.MakeUpdate(self.table_name, delta)
    stmt.AddWhereTerms(where, **kwargs)
    stmt_str, stmt_args = stmt.Generate()

    cursor = cnxn.Execute(stmt_str, stmt_args, commit=commit)
    return cursor.rowcount

  def IncrementCounterValue(self, cnxn, col_name, where=None, **kwargs):
    """Atomically increment a counter stored in MySQL, return new value.

    Args:
      cnxn: MonorailConnection object.
      col_name: int column to increment.
      where: Optional list of WHERE conditions saying which rows to update.
      **kwargs: WHERE-clause equality and set-membership conditions.  The
          where and kwargs together should narrow the update down to exactly
          one row.

    Returns:
      The new, post-increment value of the counter.
    """
    stmt = Statement.MakeIncrement(self.table_name, col_name)
    stmt.AddWhereTerms(where, **kwargs)
    stmt_str, stmt_args = stmt.Generate()

    cursor = cnxn.Execute(stmt_str, stmt_args)
    assert cursor.rowcount == 1, (
        'missing or ambiguous counter: %r' % cursor.rowcount)
    # Statement.MakeIncrement() wraps the new value in LAST_INSERT_ID(...),
    # which is why it can be read back from the cursor's lastrowid.
    return cursor.lastrowid

  def Delete(self, cnxn, where=None, commit=True, **kwargs):
    """Delete the specified table rows.

    Args:
      cnxn: MonorailConnection object.
      where: Optional list of WHERE conditions saying which rows to delete.
      commit: Set to False if this operation is part of a series of operations
          that should not be committed until the final one is done.
      **kwargs: WHERE-clause equality and set-membership conditions.

    Returns:
      Int number of rows deleted.
    """
    # Deleting the whole table is never intended in Monorail.
    assert where or kwargs

    stmt = Statement.MakeDelete(self.table_name)
    stmt.AddWhereTerms(where, **kwargs)
    stmt_str, stmt_args = stmt.Generate()

    cursor = cnxn.Execute(stmt_str, stmt_args, commit=commit)
    return cursor.rowcount
350
351
class Statement(object):
  """A class to help build complex SQL statements w/ full escaping.

  Start with a Make*() method, then fill in additional clauses as needed,
  then call Generate() to return the SQL string and argument list.  We pass
  the string and args to MySQLdb separately so that it can do escaping on
  the arg values as appropriate to prevent SQL-injection attacks.

  The only values that are not escaped by MySQLdb are the table names
  and column names, and bits of SQL syntax, all of which is hard-coded
  in our application.

  Note that those names and bits of syntax are checked with assert
  statements, so the checks are skipped when Python runs with -O.
  """

  @classmethod
  def MakeSelect(cls, table_name, cols, distinct=False, or_where_conds=False):
    """Construct a SELECT statement."""
    assert _IsValidTableName(table_name)
    assert all(_IsValidColumnName(col) for col in cols)
    main_clause = 'SELECT%s %s FROM %s' % (
        (' DISTINCT' if distinct else ''), ', '.join(cols), table_name)
    return cls(main_clause, or_where_conds=or_where_conds)

  @classmethod
  def MakeInsert(
      cls, table_name, cols, new_values, replace=False, ignore=False):
    """Construct an INSERT statement.

    When replace is true, this delegates to MakeReplace() instead.
    """
    if replace:
      return cls.MakeReplace(table_name, cols, new_values, ignore)
    assert _IsValidTableName(table_name)
    assert all(_IsValidColumnName(col) for col in cols)
    ignore_word = ' IGNORE' if ignore else ''
    main_clause = 'INSERT%s INTO %s (%s)' % (
        ignore_word, table_name, ', '.join(cols))
    return cls(main_clause, insert_args=new_values)

  @classmethod
  def MakeReplace(
      cls, table_name, cols, new_values, ignore=False):
    """Construct an INSERT...ON DUPLICATE KEY UPDATE... statement.

    Uses the INSERT/UPDATE syntax because REPLACE is literally a DELETE
    followed by an INSERT, which doesn't play well with foreign keys.
    INSERT/UPDATE is an atomic check of whether the primary key exists,
    followed by an INSERT if it doesn't or an UPDATE if it does.
    """
    assert _IsValidTableName(table_name)
    assert all(_IsValidColumnName(col) for col in cols)
    ignore_word = ' IGNORE' if ignore else ''
    main_clause = 'INSERT%s INTO %s (%s)' % (
        ignore_word, table_name, ', '.join(cols))
    return cls(main_clause, insert_args=new_values, duplicate_update_cols=cols)

  @classmethod
  def MakeUpdate(cls, table_name, delta):
    """Construct an UPDATE statement from a {column: new_value} dict."""
    assert _IsValidTableName(table_name)
    assert all(_IsValidColumnName(col) for col in delta)
    update_strs = []
    update_args = []
    # .items() rather than .iteritems() so this also runs on Python 3; the
    # assignments and their args are collected in the same iteration order.
    for col, val in delta.items():
      update_strs.append(col + '=%s')
      update_args.append(val)

    main_clause = 'UPDATE %s SET %s' % (
        table_name, ', '.join(update_strs))
    return cls(main_clause, update_args=update_args)

  @classmethod
  def MakeIncrement(cls, table_name, col_name, step=1):
    """Construct an UPDATE statement that increments and returns a counter."""
    assert _IsValidTableName(table_name)
    assert _IsValidColumnName(col_name)

    # %%s becomes a literal %s placeholder that is filled in with step.
    main_clause = (
        'UPDATE %s SET %s = LAST_INSERT_ID(%s + %%s)' % (
            table_name, col_name, col_name))
    update_args = [step]
    return cls(main_clause, update_args=update_args)

  @classmethod
  def MakeDelete(cls, table_name):
    """Construct a DELETE statement."""
    assert _IsValidTableName(table_name)
    main_clause = 'DELETE FROM %s' % table_name
    return cls(main_clause)

  def __init__(
      self, main_clause, insert_args=None, update_args=None,
      duplicate_update_cols=None, or_where_conds=False):
    self.main_clause = main_clause  # E.g., SELECT or DELETE
    self.or_where_conds = or_where_conds
    self.insert_args = insert_args or []  # For INSERT statements
    self.update_args = update_args or []  # For UPDATEs
    self.duplicate_update_cols = duplicate_update_cols or []  # For REPLACE-ish

    self.use_clauses = []
    self.join_clauses, self.join_args = [], []
    self.where_conds, self.where_args = [], []
    self.group_by_terms, self.group_by_args = [], []
    self.order_by_terms, self.order_by_args = [], []
    self.limit, self.offset = None, None

  def Generate(self):
    """Return an SQL string having %s placeholders and args to fill them in.

    Returns:
      A pair (stmt_str, args).  For INSERT statements args is the list of
      rows to insert; otherwise it is the flat list of join, update, where,
      group-by and order-by args, in that order.  Booleans in args are
      converted to ints.
    """
    clauses = [self.main_clause] + self.use_clauses + self.join_clauses
    if self.where_conds:
      if self.or_where_conds:
        clauses.append('WHERE ' + '\n OR '.join(self.where_conds))
      else:
        clauses.append('WHERE ' + '\n AND '.join(self.where_conds))
    if self.group_by_terms:
      clauses.append('GROUP BY ' + ', '.join(self.group_by_terms))
    if self.order_by_terms:
      clauses.append('ORDER BY ' + ', '.join(self.order_by_terms))

    if self.limit and self.offset:
      clauses.append('LIMIT %d OFFSET %d' % (self.limit, self.offset))
    elif self.limit:
      clauses.append('LIMIT %d' % self.limit)
    elif self.offset:
      # An OFFSET is only emitted together with a LIMIT, so use the largest
      # int as the limit.  sys.maxint only exists on Python 2; fall back to
      # sys.maxsize elsewhere.
      max_rows = getattr(sys, 'maxint', sys.maxsize)
      clauses.append('LIMIT %d OFFSET %d' % (max_rows, self.offset))

    if self.insert_args:
      # One placeholder per column; the row values are supplied as args.
      clauses.append('VALUES (' + PlaceHolders(self.insert_args[0]) + ')')
      args = self.insert_args
      if self.duplicate_update_cols:
        clauses.append('ON DUPLICATE KEY UPDATE %s' % (
            ', '.join(['%s=VALUES(%s)' % (col, col)
                       for col in self.duplicate_update_cols])))
      assert not (self.join_args + self.update_args + self.where_args +
                  self.group_by_args + self.order_by_args)
    else:
      args = (self.join_args + self.update_args + self.where_args +
              self.group_by_args + self.order_by_args)
      assert not (self.insert_args + self.duplicate_update_cols)

    args = _BoolsToInts(args)
    stmt_str = '\n'.join(clause for clause in clauses if clause)

    assert _IsValidStatement(stmt_str), stmt_str
    return stmt_str, args

  def AddUseClause(self, use_clause):
    """Add a USE clause (giving the DB a hint about which indexes to use)."""
    assert _IsValidUseClause(use_clause), use_clause
    self.use_clauses.append(use_clause)

  def AddJoinClauses(self, join_pairs, left=False):
    """Save JOIN clauses based on the given list of join conditions.

    Args:
      join_pairs: list of (join_str, args) pairs, where args holds one value
          for each %s placeholder in join_str.
      left: set to True to generate LEFT JOINs.
    """
    for join, args in join_pairs:
      assert _IsValidJoin(join), join
      assert join.count('%s') == len(args), join
      self.join_clauses.append(
          ' %sJOIN %s' % (('LEFT ' if left else ''), join))
      self.join_args.extend(args)

  def AddGroupByTerms(self, group_by_term_list):
    """Save info needed to generate the GROUP BY clause."""
    assert all(_IsValidGroupByTerm(term) for term in group_by_term_list)
    self.group_by_terms.extend(group_by_term_list)

  def AddOrderByTerms(self, order_by_pairs):
    """Save info needed to generate the ORDER BY clause.

    Args:
      order_by_pairs: list of (term_str, args) pairs, where args holds one
          value for each %s placeholder in term_str.
    """
    for term, args in order_by_pairs:
      assert _IsValidOrderByTerm(term), term
      assert term.count('%s') == len(args), term
      self.order_by_terms.append(term)
      self.order_by_args.extend(args)

  def SetLimitAndOffset(self, limit, offset):
    """Save info needed to generate the LIMIT OFFSET clause."""
    self.limit = limit
    self.offset = offset

  def AddWhereTerms(self, where_cond_pairs, **kwargs):
    """Generate a WHERE clause.

    Args:
      where_cond_pairs: optional list of (cond_str, args) pairs, where args
          holds one value for each %s placeholder in cond_str.
      **kwargs: column=value conditions.  A list or set value becomes an IN
          condition, None or an empty list becomes IS NULL, and anything else
          becomes an equality test.  A keyword ending in '_not' negates the
          condition on the column named by the rest of the keyword.
    """
    where_cond_pairs = where_cond_pairs or []

    for cond, args in where_cond_pairs:
      assert _IsValidWhereCond(cond), cond
      assert cond.count('%s') == len(args), cond
      self.where_conds.append(cond)
      self.where_args.extend(args)

    # Sorted so that the generated statement is deterministic.
    for col, val in sorted(kwargs.items()):
      assert _IsValidColumnName(col), col
      eq = True
      if col.endswith('_not'):
        col = col[:-4]
        eq = False

      if isinstance(val, set):
        val = list(val)  # MySQL interface cannot handle sets.

      if val is None or val == []:
        op = 'IS' if eq else 'IS NOT'
        self.where_conds.append(col + ' ' + op + ' NULL')
      elif isinstance(val, list):
        op = 'IN' if eq else 'NOT IN'
        # Sadly, MySQLdb cannot escape lists, so we flatten to multiple "%s"s
        self.where_conds.append(
            col + ' ' + op + ' (' + PlaceHolders(val) + ')')
        self.where_args.extend(val)
      else:
        op = '=' if eq else '!='
        self.where_conds.append(col + ' ' + op + ' %s')
        self.where_args.append(val)
559
560
def PlaceHolders(sql_args):
  """Return '%s' repeated once per item of sql_args, joined by commas.

  Args:
    sql_args: iterable of argument values; only the item count matters.

  Returns:
    A string such as '%s,%s,%s', or '' if sql_args is empty.
  """
  markers = ['%s' for _arg in sql_args]
  return ','.join(markers)
564
565
# Regex fragments used to build the validation patterns in this module.
# A table name starts with an uppercase letter and is followed by one or more
# letters, digits or underscores.
TABLE_PAT = '[A-Z][_a-zA-Z0-9]+'
# A column name is at least two characters long: a lowercase letter followed
# by lowercase letters or underscores (no digits).
COLUMN_PAT = '[a-z][_a-z]+'
COMPARE_OP_PAT = '(<|>|=|!=|>=|<=|LIKE|NOT LIKE)'
# Named fragments that _MakeRE() substitutes for {name} markers in a pattern.
SHORTHAND = {
    'table': TABLE_PAT,
    'column': COLUMN_PAT,
    # A column name, optionally qualified with a table name.
    'tab_col': r'(%s\.)?%s' % (TABLE_PAT, COLUMN_PAT),
    'placeholder': '%s',  # That's a literal %s that gets passed to MySQLdb
    'multi_placeholder': '%s(, ?%s)*',
    'compare_op': COMPARE_OP_PAT,
    'opt_asc_desc': '( ASC| DESC)?',
    'opt_alias': '( AS %s)?' % TABLE_PAT,
    # In the line below, %%s yields a literal %s once COMPARE_OP_PAT has been
    # substituted in.
    'email_cond': (r'LOWER\(User\d+\.email\) '
                   r'(%s %%s|IN \(%%s(, ?%%s)*\))' % COMPARE_OP_PAT),
}
581
582
def _MakeRE(regex_str):
  """Compile regex_str after expanding its {name} markers from SHORTHAND.

  Args:
    regex_str: pattern string that may contain markers such as {table} or
        {tab_col}, each of which must be a key of SHORTHAND.

  Returns:
    The compiled regular expression object.
  """
  expanded = regex_str.format(**SHORTHAND)
  return re.compile(expanded)
586
587
# Whitelists of the SQL fragments that callers may pass in.  The _IsValid*()
# functions check a fragment with regex.match(), which anchors at the start of
# the string only; a pattern must end in '$' to reject trailing text.
TABLE_RE = _MakeRE('^{table}$')
TAB_COL_RE = _MakeRE('^{tab_col}$')
USE_CLAUSE_RE = _MakeRE(
    r'^USE INDEX \({column}\) USE INDEX FOR ORDER BY \({column}\)$')
# NOTE(review): apart from TAB_COL_RE, these patterns have no '$', so they
# accept any string that merely begins with the pattern -- confirm whether
# that is intended.
COLUMN_RE_LIST = [
    TAB_COL_RE,
    _MakeRE(r'\*'),
    _MakeRE(r'COUNT\(\*\)'),
    _MakeRE(r'COUNT\({tab_col}\)'),
    _MakeRE(r'MAX\({tab_col}\)'),
    _MakeRE(r'MIN\({tab_col}\)'),
    ]
# NOTE(review): the last two JOIN patterns have no '$' either -- confirm
# whether that is intended.
JOIN_RE_LIST = [
    TABLE_RE,
    _MakeRE(
        r'^{table}{opt_alias} ON {tab_col} = {tab_col}'
        r'( AND {tab_col} = {tab_col})?'
        r'( AND {tab_col} IN \({multi_placeholder}\))?$'),
    _MakeRE(
        r'^{table}{opt_alias} ON {tab_col} = {tab_col}'
        r'( AND {tab_col} IN \({multi_placeholder}\))?$'),
    _MakeRE(
        r'^{table}{opt_alias} ON {tab_col} = {tab_col}'
        r'( AND {tab_col} = {tab_col})?'
        r' AND {tab_col} = {placeholder}$'),
    _MakeRE(
        r'^{table}{opt_alias} ON {tab_col} = {tab_col} AND {email_cond}$'),
    _MakeRE(
        r'^{table}{opt_alias} ON '
        r'\({tab_col} = {tab_col} OR {tab_col} = {tab_col}\)$'),
    _MakeRE(
        r'^\({table} AS {table} JOIN User AS {table} '
        r'ON {tab_col} = {tab_col} '
        r'AND {email_cond}\) ON Issue.id = {tab_col}'),
    _MakeRE(
        r'^{table} AS {table} ON {tab_col} = {tab_col} '
        r'LEFT JOIN {table} AS {table} ON {tab_col} = {tab_col}'),
    ]
ORDER_BY_RE_LIST = [
    _MakeRE(r'^{tab_col}{opt_asc_desc}$'),
    _MakeRE(r'^LOWER\({tab_col}\){opt_asc_desc}$'),
    _MakeRE(r'^ISNULL\({tab_col}\){opt_asc_desc}$'),
    _MakeRE(r'^FIELD\({tab_col}, {multi_placeholder}\){opt_asc_desc}$'),
    _MakeRE(r'^FIELD\(IF\(ISNULL\({tab_col}\), {tab_col}, {tab_col}\), '
            r'{multi_placeholder}\){opt_asc_desc}$'),
    ]
GROUP_BY_RE_LIST = [
    TAB_COL_RE,
    ]
WHERE_COND_RE_LIST = [
    _MakeRE(r'^TRUE$'),
    _MakeRE(r'^FALSE$'),
    _MakeRE(r'^{tab_col} IS NULL$'),
    _MakeRE(r'^{tab_col} IS NOT NULL$'),
    _MakeRE(r'^{tab_col} {compare_op} {tab_col}$'),
    _MakeRE(r'^{tab_col} {compare_op} {placeholder}$'),
    _MakeRE(r'^{tab_col} %% {placeholder} = {placeholder}$'),
    _MakeRE(r'^{tab_col} IN \({multi_placeholder}\)$'),
    _MakeRE(r'^{tab_col} NOT IN \({multi_placeholder}\)$'),
    _MakeRE(r'^LOWER\({tab_col}\) IS NULL$'),
    _MakeRE(r'^LOWER\({tab_col}\) IS NOT NULL$'),
    _MakeRE(r'^LOWER\({tab_col}\) {compare_op} {placeholder}$'),
    _MakeRE(r'^LOWER\({tab_col}\) IN \({multi_placeholder}\)$'),
    _MakeRE(r'^LOWER\({tab_col}\) NOT IN \({multi_placeholder}\)$'),
    _MakeRE(r'^LOWER\({tab_col}\) LIKE {placeholder}$'),
    _MakeRE(r'^LOWER\({tab_col}\) NOT LIKE {placeholder}$'),
    _MakeRE(r'^timestep < \(SELECT MAX\(j.timestep\) FROM Invalidate AS j '
            r'WHERE j.kind = %s '
            r'AND j.cache_key = Invalidate.cache_key\)$'),
    # Every piece of the two patterns below is a raw string so that '\(' and
    # '\)' are not read as (invalid) string escape sequences.
    _MakeRE(r'^\({tab_col} IS NULL OR {tab_col} {compare_op} {placeholder}\) '
            r'AND \({tab_col} IS NULL OR {tab_col} {compare_op} {placeholder}'
            r'\)$'),
    _MakeRE(r'^\({tab_col} IS NOT NULL AND {tab_col} {compare_op} '
            r'{placeholder}\) OR \({tab_col} IS NOT NULL AND {tab_col} '
            r'{compare_op} {placeholder}\)$'),
    ]
664
# Note: We never use ';' for multiple statements, '@' for SQL variables, or
# any quoted strings in stmt_str (quotes are put in by MySQLdb for args).
# After the leading keyword and a space, the pattern only allows word
# characters, whitespace and the punctuation listed in the brackets, and \A
# and \Z make it cover the entire statement string.
STMT_STR_RE = re.compile(
    r'\A(SELECT|UPDATE|DELETE|INSERT|REPLACE) [-+=!<>%*.,()\w\s]+\Z',
    re.MULTILINE)
670
671
def _IsValidTableName(table_name):
  """Return a match object if table_name matches TABLE_RE, otherwise None."""
  table_match = TABLE_RE.match(table_name)
  return table_match
674
675
def _IsValidColumnName(column_expr):
  """Return True if column_expr matches any pattern in COLUMN_RE_LIST."""
  for pattern in COLUMN_RE_LIST:
    if pattern.match(column_expr):
      return True
  return False
678
679
def _IsValidUseClause(use_clause):
  """Return a match object if use_clause matches USE_CLAUSE_RE, else None."""
  clause_match = USE_CLAUSE_RE.match(use_clause)
  return clause_match
682
683
def _IsValidJoin(join):
  """Return True if join matches any pattern in JOIN_RE_LIST."""
  for pattern in JOIN_RE_LIST:
    if pattern.match(join):
      return True
  return False
686
687
def _IsValidOrderByTerm(term):
  """Return True if term matches any pattern in ORDER_BY_RE_LIST."""
  for pattern in ORDER_BY_RE_LIST:
    if pattern.match(term):
      return True
  return False
690
691
def _IsValidGroupByTerm(term):
  """Return True if term matches any pattern in GROUP_BY_RE_LIST."""
  for pattern in GROUP_BY_RE_LIST:
    if pattern.match(term):
      return True
  return False
694
695
def _IsValidWhereCond(cond):
  """Return True if cond is an acceptable WHERE condition.

  A condition is acceptable if, after removing an optional leading 'NOT ',
  it matches a pattern in WHERE_COND_RE_LIST, either as given or with one
  pair of enclosing parentheses removed.  Failing that, it is split on
  ' OR ', or else on ' AND ', and is acceptable if every piece is.

  Args:
    cond: string WHERE condition, with %s placeholders for any values.

  Returns:
    True if the condition is acceptable, otherwise False.
  """
  if cond.startswith('NOT '):
    cond = cond[4:]

  # Consult the whitelist before touching any parentheses: some whitelisted
  # patterns themselves begin with '(' and end with ')', and the stripping
  # below would otherwise make them impossible to match.
  if any(regex.match(cond) for regex in WHERE_COND_RE_LIST):
    return True

  if cond.startswith('(') and cond.endswith(')'):
    cond = cond[1:-1]
    if any(regex.match(cond) for regex in WHERE_COND_RE_LIST):
      return True

  if ' OR ' in cond:
    return all(_IsValidWhereCond(c) for c in cond.split(' OR '))

  if ' AND ' in cond:
    return all(_IsValidWhereCond(c) for c in cond.split(' AND '))

  return False
712
713
def _IsValidStatement(stmt_str):
  """Final check to make sure there is no funny junk sneaking in somehow.

  Returns:
    A falsy value if stmt_str does not match STMT_STR_RE or if it contains
    '--', otherwise True.
  """
  shape_match = STMT_STR_RE.match(stmt_str)
  if not shape_match:
    return shape_match
  return '--' not in stmt_str
718
719
720 def _BoolsToInts(arg_list):
721 """Convert any True values to 1s and Falses to 0s.
722
723 Google's copy of MySQLdb has bool-to-int conversion disabled,
724 and yet it seems to be needed otherwise they are converted
725 to strings and always interpreted as 0 (which is FALSE).
726
727 Args:
728 arg_list: (nested) list of SQL statment argument values, which may
729 include some boolean values.
730
731 Returns:
732 The same list, but with True replaced by 1 and False replaced by 0.
733 """
734 result = []
735 for arg in arg_list:
736 if isinstance(arg, (list, tuple)):
737 result.append(_BoolsToInts(arg))
738 elif arg is True:
739 result.append(1)
740 elif arg is False:
741 result.append(0)
742 else:
743 result.append(arg)
744
745 return result
OLDNEW
« no previous file with comments | « appengine/monorail/framework/sorting.py ('k') | appengine/monorail/framework/table_view_helpers.py » ('j') | no next file with comments »

Powered by Google App Engine
This is Rietveld 408576698