Index: appengine/monorail/framework/sql.py |
diff --git a/appengine/monorail/framework/sql.py b/appengine/monorail/framework/sql.py |
new file mode 100644 |
index 0000000000000000000000000000000000000000..223912db397be67ab55dd899324b967c8998fdab |
--- /dev/null |
+++ b/appengine/monorail/framework/sql.py |
@@ -0,0 +1,745 @@ |
+# Copyright 2016 The Chromium Authors. All rights reserved. |
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at |
+# https://developers.google.com/open-source/licenses/bsd |
+ |
+"""A set of classes for interacting with tables in SQL.""" |
+ |
+import logging |
+import random |
+import re |
+import sys |
+import time |
+ |
+import settings |
+ |
+if not settings.unit_test_mode: |
+ import MySQLdb |
+ |
+from framework import framework_helpers |
+ |
+ |
+# MonorailConnection maintains a dictionary of connections to SQL databases.
+# Each replica connection is identified by an int shard ID, and the single
+# connection to the master DB is identified by the key MASTER_CNXN.
+MASTER_CNXN = 'master_cnxn' |
+ |
+ |
+@framework_helpers.retry(2, delay=1, backoff=2) |
+def MakeConnection(instance, database): |
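+  """Open a connection to the given SQL instance and database."""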
+ logging.info('About to connect to SQL instance %r db %r', instance, database) |
+ if settings.unit_test_mode: |
+ raise ValueError('unit tests should not need real database connections') |
+ if settings.dev_mode: |
+ cnxn = MySQLdb.connect( |
+ host='127.0.0.1', port=3306, db=database, user='root', charset='utf8') |
+ else: |
+ cnxn = MySQLdb.connect( |
+ unix_socket='/cloudsql/' + instance, db=database, user='root', |
+ charset='utf8') |
+ return cnxn |
+ |
+ |
+class MonorailConnection(object): |
+ """Create and manage connections to the SQL servers. |
+ |
+ We only store connections in the context of a single user request, not |
+ across user requests. The main purpose of this class is to make using |
+ sharded tables easier. |
+ """ |
+ |
+ def __init__(self): |
+ self.sql_cnxns = {} # {MASTER_CNXN: cnxn, shard_id: cnxn, ...} |
+ |
+ def GetMasterConnection(self): |
+ """Return a connection to the master SQL DB.""" |
+ if MASTER_CNXN not in self.sql_cnxns: |
+ self.sql_cnxns[MASTER_CNXN] = MakeConnection( |
+ settings.db_instance, settings.db_database_name) |
+ logging.info( |
+ 'created a master connection %r', self.sql_cnxns[MASTER_CNXN]) |
+ |
+ return self.sql_cnxns[MASTER_CNXN] |
+ |
+ def GetConnectionForShard(self, shard_id): |
+ """Return a connection to the DB replica that will be used for shard_id.""" |
+ if settings.dev_mode: |
+ return self.GetMasterConnection() |
+ |
+ if shard_id not in self.sql_cnxns: |
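+      # E.g., with settings.num_logical_shards == 10 and a
+      # physical_db_name_format of 'replica-%d' (both values illustrative),
+      # shard_id 23 maps to physical shard 3, i.e. instance 'replica-3'.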
+ physical_shard_id = shard_id % settings.num_logical_shards |
+ shard_instance_name = ( |
+ settings.physical_db_name_format % physical_shard_id) |
+ self.sql_cnxns[shard_id] = MakeConnection( |
+ shard_instance_name, settings.db_database_name) |
+ logging.info('created a replica connection for shard %d', shard_id) |
+ |
+ return self.sql_cnxns[shard_id] |
+ |
+ def Execute(self, stmt_str, stmt_args, shard_id=None, commit=True): |
+ """Execute the given SQL statement on one of the relevant databases.""" |
+ if shard_id is None: |
+ # No shard was specified, so hit the master. |
+ sql_cnxn = self.GetMasterConnection() |
+ else: |
+ sql_cnxn = self.GetConnectionForShard(shard_id) |
+ |
+ return self._ExecuteWithSQLConnection( |
+ sql_cnxn, stmt_str, stmt_args, commit=commit) |
+ |
+ def _ExecuteWithSQLConnection( |
+ self, sql_cnxn, stmt_str, stmt_args, commit=True): |
+ """Execute a statement on the given database and return a cursor.""" |
+ cursor = sql_cnxn.cursor() |
+ start_time = time.time() |
+ if stmt_str.startswith('INSERT') or stmt_str.startswith('REPLACE'): |
+ logging.info('SQL stmt_str: \n%s', stmt_str) |
+ logging.info('SQL stmt_args: %r', stmt_args) |
+ cursor.executemany(stmt_str, stmt_args) |
+ else: |
+ logging.info('SQL stmt: \n%s', (stmt_str % tuple(stmt_args))) |
+ cursor.execute(stmt_str, args=stmt_args) |
+ logging.info('%d rows in %d ms', cursor.rowcount, |
+ int((time.time() - start_time) * 1000)) |
+ if commit and not stmt_str.startswith('SELECT'): |
+ start_time = time.time() |
+ try: |
+ sql_cnxn.commit() |
+ except MySQLdb.DatabaseError: |
+ sql_cnxn.rollback() |
+ logging.info('commit took %d ms', |
+ int((time.time() - start_time) * 1000)) |
+ |
+ return cursor |
+ |
+ def Commit(self): |
+ """Explicitly commit any pending txns. Normally done automatically.""" |
+ sql_cnxn = self.GetMasterConnection() |
+ start_time = time.time() |
+ try: |
+ sql_cnxn.commit() |
+ except MySQLdb.DatabaseError: |
+ logging.exception('Commit failed for cnxn, rolling back') |
+ sql_cnxn.rollback() |
+ logging.info('final commit took %d ms', |
+ int((time.time() - start_time) * 1000)) |
+ |
+ def Close(self): |
+ """Safely close any connections that are still open.""" |
+ for sql_cnxn in self.sql_cnxns.itervalues(): |
+ try: |
+ sql_cnxn.close() |
+ except MySQLdb.DatabaseError: |
+ # This might happen if the cnxn is somehow already closed. |
+        logging.exception('DatabaseError when trying to close cnxn')
+ |
+ |
+class SQLTableManager(object): |
+ """Helper class to make it easier to deal with an SQL table.""" |
+ |
+ def __init__(self, table_name): |
+ self.table_name = table_name |
+ |
+ def Select( |
+ self, cnxn, distinct=False, cols=None, left_joins=None, |
+ joins=None, where=None, or_where_conds=False, group_by=None, |
+ order_by=None, limit=None, offset=None, shard_id=None, use_clause=None, |
+ **kwargs): |
+ """Compose and execute an SQL SELECT statement on this table. |
+ |
+ Args: |
+ cnxn: MonorailConnection to the databases. |
+ distinct: If True, add DISTINCT keyword. |
+ cols: List of columns to retrieve, defaults to '*'. |
+ left_joins: List of LEFT JOIN (str, args) pairs. |
+ joins: List of regular JOIN (str, args) pairs. |
+ where: List of (str, args) for WHERE clause. |
+ or_where_conds: Set to True to use OR in the WHERE conds. |
+ group_by: List of strings for GROUP BY clause. |
+ order_by: List of (str, args) for ORDER BY clause. |
+ limit: Optional LIMIT on the number of rows returned. |
+ offset: Optional OFFSET when using LIMIT. |
+ shard_id: Int ID of the shard to query. |
+ use_clause: Optional string USE clause to tell the DB which index to use. |
+ **kwargs: WHERE-clause equality and set-membership conditions. |
+ |
+ Keyword args are used to build up more WHERE conditions that compare |
+    column values to constants. The keyword argument foo='bar' translates to
+    'foo = "bar"', and foo=[3, 4, 5] translates to 'foo IN (3, 4, 5)'.
+ |
+ Returns: |
+ A list of rows, each row is a tuple of values for the requested cols. |
+ """ |
+ cols = cols or ['*'] # If columns not specified, retrieve all columns. |
+ stmt = Statement.MakeSelect( |
+ self.table_name, cols, distinct=distinct, |
+ or_where_conds=or_where_conds) |
+ if use_clause: |
+ stmt.AddUseClause(use_clause) |
+ stmt.AddJoinClauses(left_joins or [], left=True) |
+ stmt.AddJoinClauses(joins or []) |
+ stmt.AddWhereTerms(where or [], **kwargs) |
+ stmt.AddGroupByTerms(group_by or []) |
+ stmt.AddOrderByTerms(order_by or []) |
+ stmt.SetLimitAndOffset(limit, offset) |
+ stmt_str, stmt_args = stmt.Generate() |
+ |
+ cursor = cnxn.Execute(stmt_str, stmt_args, shard_id=shard_id) |
+ rows = cursor.fetchall() |
+ return rows |
+ |
+ def SelectRow( |
+ self, cnxn, cols=None, default=None, where=None, **kwargs): |
+ """Run a query that is expected to return just one row.""" |
+ rows = self.Select(cnxn, distinct=True, cols=cols, where=where, **kwargs) |
+ if len(rows) == 1: |
+ return rows[0] |
+ elif not rows: |
+ logging.info('SelectRow got 0 results, so using default %r', default) |
+ return default |
+ else: |
+      raise ValueError(
+          'SelectRow got %d results, expected only 1' % len(rows))
+ |
+ def SelectValue(self, cnxn, col, default=None, where=None, **kwargs): |
+ """Run a query that is expected to return just one row w/ one value.""" |
+ row = self.SelectRow( |
+ cnxn, cols=[col], default=[default], where=where, **kwargs) |
+ return row[0] |
+ |
+ def InsertRows( |
+ self, cnxn, cols, row_values, replace=False, ignore=False, |
+ commit=True, return_generated_ids=False): |
+ """Insert all the given rows. |
+ |
+ Args: |
+ cnxn: MonorailConnection object. |
+ cols: List of column names to set. |
+ row_values: List of lists with values to store. The length of each |
+ nested list should be equal to len(cols). |
+ replace: Set to True if inserted values should replace existing DB rows |
+ that have the same DB keys. |
+ ignore: Set to True to ignore rows that would duplicate existing DB keys. |
+ commit: Set to False if this operation is part of a series of operations |
+ that should not be committed until the final one is done. |
+ return_generated_ids: Set to True to return a list of generated |
+ autoincrement IDs for inserted rows. This requires us to insert rows |
+ one at a time. |
+ |
+ Returns: |
+ If return_generated_ids is set to True, this method returns a list of the |
+ auto-increment IDs generated by the DB. Otherwise, [] is returned. |
+ """ |
+ if not row_values: |
+ return None # Nothing to insert |
+ |
+ generated_ids = [] |
+ if return_generated_ids: |
+ # We must insert the rows one-at-a-time to know the generated IDs. |
+ for row_value in row_values: |
+ stmt = Statement.MakeInsert( |
+ self.table_name, cols, [row_value], replace=replace, ignore=ignore) |
+ stmt_str, stmt_args = stmt.Generate() |
+ cursor = cnxn.Execute(stmt_str, stmt_args, commit=commit) |
+ if cursor.lastrowid: |
+ generated_ids.append(cursor.lastrowid) |
+ return generated_ids |
+ |
+ stmt = Statement.MakeInsert( |
+ self.table_name, cols, row_values, replace=replace, ignore=ignore) |
+ stmt_str, stmt_args = stmt.Generate() |
+ cnxn.Execute(stmt_str, stmt_args, commit=commit) |
+ return [] |
+
+ def InsertRow( |
+ self, cnxn, replace=False, ignore=False, commit=True, **kwargs): |
+ """Insert a single row into the table. |
+ |
+ Args: |
+ cnxn: MonorailConnection object. |
+ replace: Set to True if inserted values should replace existing DB rows |
+ that have the same DB keys. |
+ ignore: Set to True to ignore rows that would duplicate existing DB keys. |
+ commit: Set to False if this operation is part of a series of operations |
+ that should not be committed until the final one is done. |
+ **kwargs: column=value assignments to specify what to store in the DB. |
+ |
+ Returns: |
+ The generated autoincrement ID of the key column if one was generated. |
+ Otherwise, return None. |
+ """ |
+ cols = sorted(kwargs.keys()) |
+ row = tuple(kwargs[col] for col in cols) |
+ generated_ids = self.InsertRows( |
+ cnxn, cols, [row], replace=replace, ignore=ignore, |
+ commit=commit, return_generated_ids=True) |
+ if generated_ids: |
+ return generated_ids[0] |
+ else: |
+ return None |
+ |
+ def Update(self, cnxn, delta, where=None, commit=True, **kwargs): |
+ """Update one or more rows. |
+ |
+ Args: |
+ cnxn: MonorailConnection object. |
+ delta: Dictionary of {column: new_value} assignments. |
+ where: Optional list of WHERE conditions saying which rows to update. |
+ commit: Set to False if this operation is part of a series of operations |
+ that should not be committed until the final one is done. |
+ **kwargs: WHERE-clause equality and set-membership conditions. |
+ |
+ Returns: |
+ Int number of rows updated. |
+ """ |
+ if not delta: |
+ return 0 # Nothing is being changed |
+ |
+ stmt = Statement.MakeUpdate(self.table_name, delta) |
+ stmt.AddWhereTerms(where, **kwargs) |
+ stmt_str, stmt_args = stmt.Generate() |
+ |
+ cursor = cnxn.Execute(stmt_str, stmt_args, commit=commit) |
+ return cursor.rowcount |
+ |
+ def IncrementCounterValue(self, cnxn, col_name, where=None, **kwargs): |
+ """Atomically increment a counter stored in MySQL, return new value. |
+ |
+ Args: |
+ cnxn: MonorailConnection object. |
+ col_name: int column to increment. |
+ where: Optional list of WHERE conditions saying which rows to update. |
+ **kwargs: WHERE-clause equality and set-membership conditions. The |
+ where and kwargs together should narrow the update down to exactly |
+ one row. |
+ |
+ Returns: |
+ The new, post-increment value of the counter. |
+ """ |
+ stmt = Statement.MakeIncrement(self.table_name, col_name) |
+ stmt.AddWhereTerms(where, **kwargs) |
+ stmt_str, stmt_args = stmt.Generate() |
+ |
+ cursor = cnxn.Execute(stmt_str, stmt_args) |
+ assert cursor.rowcount == 1, ( |
+ 'missing or ambiguous counter: %r' % cursor.rowcount) |
+ return cursor.lastrowid |
+ |
+ def Delete(self, cnxn, where=None, commit=True, **kwargs): |
+ """Delete the specified table rows. |
+ |
+ Args: |
+ cnxn: MonorailConnection object. |
+      where: Optional list of WHERE conditions saying which rows to delete.
+ commit: Set to False if this operation is part of a series of operations |
+ that should not be committed until the final one is done. |
+ **kwargs: WHERE-clause equality and set-membership conditions. |
+ |
+ Returns: |
+      Int number of rows deleted.
+ """ |
+ # Deleting the whole table is never intended in Monorail. |
+ assert where or kwargs |
+ |
+ stmt = Statement.MakeDelete(self.table_name) |
+ stmt.AddWhereTerms(where, **kwargs) |
+ stmt_str, stmt_args = stmt.Generate() |
+ |
+ cursor = cnxn.Execute(stmt_str, stmt_args, commit=commit) |
+ return cursor.rowcount |
+ |
+ |
+class Statement(object): |
+ """A class to help build complex SQL statements w/ full escaping. |
+ |
+ Start with a Make*() method, then fill in additional clauses as needed, |
+ then call Generate() to return the SQL string and argument list. We pass |
+ the string and args to MySQLdb separately so that it can do escaping on |
+ the arg values as appropriate to prevent SQL-injection attacks. |
+ |
+  The only values that are not escaped by MySQLdb are the table names,
+  column names, and bits of SQL syntax, all of which are hard-coded
+  in our application.
+ """ |
+ |
+ @classmethod |
+ def MakeSelect(cls, table_name, cols, distinct=False, or_where_conds=False): |
+ """Constuct a SELECT statement.""" |
+ assert _IsValidTableName(table_name) |
+ assert all(_IsValidColumnName(col) for col in cols) |
+ main_clause = 'SELECT%s %s FROM %s' % ( |
+ (' DISTINCT' if distinct else ''), ', '.join(cols), table_name) |
+ return cls(main_clause, or_where_conds=or_where_conds) |
+ |
+ @classmethod |
+ def MakeInsert( |
+ cls, table_name, cols, new_values, replace=False, ignore=False): |
+ """Constuct an INSERT statement.""" |
+ if replace == True: |
+ return cls.MakeReplace(table_name, cols, new_values, ignore) |
+ assert _IsValidTableName(table_name) |
+ assert all(_IsValidColumnName(col) for col in cols) |
+ ignore_word = ' IGNORE' if ignore else '' |
+ main_clause = 'INSERT%s INTO %s (%s)' % ( |
+ ignore_word, table_name, ', '.join(cols)) |
+ return cls(main_clause, insert_args=new_values) |
+ |
+ @classmethod |
+ def MakeReplace( |
+ cls, table_name, cols, new_values, ignore=False): |
+ """Construct an INSERT...ON DUPLICATE KEY UPDATE... statement. |
+ |
+ Uses the INSERT/UPDATE syntax because REPLACE is literally a DELETE |
+ followed by an INSERT, which doesn't play well with foreign keys. |
+ INSERT/UPDATE is an atomic check of whether the primary key exists, |
+ followed by an INSERT if it doesn't or an UPDATE if it does. |
+ """ |
+ assert _IsValidTableName(table_name) |
+ assert all(_IsValidColumnName(col) for col in cols) |
+ ignore_word = ' IGNORE' if ignore else '' |
+ main_clause = 'INSERT%s INTO %s (%s)' % ( |
+ ignore_word, table_name, ', '.join(cols)) |
+ return cls(main_clause, insert_args=new_values, duplicate_update_cols=cols) |
+ |
+ @classmethod |
+ def MakeUpdate(cls, table_name, delta): |
+ """Constuct an UPDATE statement.""" |
+ assert _IsValidTableName(table_name) |
+ assert all(_IsValidColumnName(col) for col in delta.iterkeys()) |
+ update_strs = [] |
+ update_args = [] |
+ for col, val in delta.iteritems(): |
+ update_strs.append(col + '=%s') |
+ update_args.append(val) |
+ |
+ main_clause = 'UPDATE %s SET %s' % ( |
+ table_name, ', '.join(update_strs)) |
+ return cls(main_clause, update_args=update_args) |
+ |
+ @classmethod |
+ def MakeIncrement(cls, table_name, col_name, step=1): |
+ """Constuct an UPDATE statement that increments and returns a counter.""" |
+ assert _IsValidTableName(table_name) |
+ assert _IsValidColumnName(col_name) |
+ |
+ main_clause = ( |
+ 'UPDATE %s SET %s = LAST_INSERT_ID(%s + %%s)' % ( |
+ table_name, col_name, col_name)) |
+ update_args = [step] |
+ return cls(main_clause, update_args=update_args) |
+ |
+ @classmethod |
+ def MakeDelete(cls, table_name): |
+ """Constuct a DELETE statement.""" |
+ assert _IsValidTableName(table_name) |
+ main_clause = 'DELETE FROM %s' % table_name |
+ return cls(main_clause) |
+ |
+ def __init__( |
+ self, main_clause, insert_args=None, update_args=None, |
+ duplicate_update_cols=None, or_where_conds=False): |
+ self.main_clause = main_clause # E.g., SELECT or DELETE |
+ self.or_where_conds = or_where_conds |
+ self.insert_args = insert_args or [] # For INSERT statements |
+ self.update_args = update_args or [] # For UPDATEs |
+ self.duplicate_update_cols = duplicate_update_cols or [] # For REPLACE-ish |
+ |
+ self.use_clauses = [] |
+ self.join_clauses, self.join_args = [], [] |
+ self.where_conds, self.where_args = [], [] |
+ self.group_by_terms, self.group_by_args = [], [] |
+ self.order_by_terms, self.order_by_args = [], [] |
+ self.limit, self.offset = None, None |
+ |
+ def Generate(self): |
+ """Return an SQL string having %s placeholders and args to fill them in.""" |
+ clauses = [self.main_clause] + self.use_clauses + self.join_clauses |
+ if self.where_conds: |
+ if self.or_where_conds: |
+ clauses.append('WHERE ' + '\n OR '.join(self.where_conds)) |
+ else: |
+ clauses.append('WHERE ' + '\n AND '.join(self.where_conds)) |
+ if self.group_by_terms: |
+ clauses.append('GROUP BY ' + ', '.join(self.group_by_terms)) |
+ if self.order_by_terms: |
+ clauses.append('ORDER BY ' + ', '.join(self.order_by_terms)) |
+ |
+ if self.limit and self.offset: |
+ clauses.append('LIMIT %d OFFSET %d' % (self.limit, self.offset)) |
+ elif self.limit: |
+ clauses.append('LIMIT %d' % self.limit) |
+ elif self.offset: |
+ clauses.append('LIMIT %d OFFSET %d' % (sys.maxint, self.offset)) |
+ |
+ if self.insert_args: |
+ clauses.append('VALUES (' + PlaceHolders(self.insert_args[0]) + ')') |
+ args = self.insert_args |
+ if self.duplicate_update_cols: |
+ clauses.append('ON DUPLICATE KEY UPDATE %s' % ( |
+ ', '.join(['%s=VALUES(%s)' % (col, col) |
+ for col in self.duplicate_update_cols]))) |
+ assert not (self.join_args + self.update_args + self.where_args + |
+ self.group_by_args + self.order_by_args) |
+ else: |
+ args = (self.join_args + self.update_args + self.where_args + |
+ self.group_by_args + self.order_by_args) |
+ assert not (self.insert_args + self.duplicate_update_cols) |
+ |
+ args = _BoolsToInts(args) |
+ stmt_str = '\n'.join(clause for clause in clauses if clause) |
+ |
+ assert _IsValidStatement(stmt_str), stmt_str |
+ return stmt_str, args |
+ |
+ def AddUseClause(self, use_clause): |
+ """Add a USE clause (giving the DB a hint about which indexes to use).""" |
+ assert _IsValidUseClause(use_clause), use_clause |
+ self.use_clauses.append(use_clause) |
+ |
+ def AddJoinClauses(self, join_pairs, left=False): |
+ """Save JOIN clauses based on the given list of join conditions.""" |
+ for join, args in join_pairs: |
+ assert _IsValidJoin(join), join |
+ assert join.count('%s') == len(args), join |
+ self.join_clauses.append( |
+ ' %sJOIN %s' % (('LEFT ' if left else ''), join)) |
+ self.join_args.extend(args) |
+ |
+ def AddGroupByTerms(self, group_by_term_list): |
+ """Save info needed to generate the GROUP BY clause.""" |
+ assert all(_IsValidGroupByTerm(term) for term in group_by_term_list) |
+ self.group_by_terms.extend(group_by_term_list) |
+ |
+ def AddOrderByTerms(self, order_by_pairs): |
+ """Save info needed to generate the ORDER BY clause.""" |
+ for term, args in order_by_pairs: |
+ assert _IsValidOrderByTerm(term), term |
+ assert term.count('%s') == len(args), term |
+ self.order_by_terms.append(term) |
+ self.order_by_args.extend(args) |
+ |
+ def SetLimitAndOffset(self, limit, offset): |
+ """Save info needed to generate the LIMIT OFFSET clause.""" |
+ self.limit = limit |
+ self.offset = offset |
+ |
+ def AddWhereTerms(self, where_cond_pairs, **kwargs): |
+ """Gererate a WHERE clause.""" |
+ where_cond_pairs = where_cond_pairs or [] |
+ |
+ for cond, args in where_cond_pairs: |
+ assert _IsValidWhereCond(cond), cond |
+ assert cond.count('%s') == len(args), cond |
+ self.where_conds.append(cond) |
+ self.where_args.extend(args) |
+ |
+ for col, val in sorted(kwargs.items()): |
+ assert _IsValidColumnName(col), col |
+ eq = True |
+ if col.endswith('_not'): |
+ col = col[:-4] |
+ eq = False |
+ |
+ if isinstance(val, set): |
+        val = list(val)  # MySQL interface cannot handle sets.
+ |
+ if val is None or val == []: |
+ op = 'IS' if eq else 'IS NOT' |
+ self.where_conds.append(col + ' ' + op + ' NULL') |
+ elif isinstance(val, list): |
+ op = 'IN' if eq else 'NOT IN' |
+ # Sadly, MySQLdb cannot escape lists, so we flatten to multiple "%s"s |
+ self.where_conds.append( |
+ col + ' ' + op + ' (' + PlaceHolders(val) + ')') |
+ self.where_args.extend(val) |
+ else: |
+ op = '=' if eq else '!=' |
+ self.where_conds.append(col + ' ' + op + ' %s') |
+ self.where_args.append(val) |
+ |
+ |
+def PlaceHolders(sql_args): |
+ """Return a comma-separated list of %s placeholders for the given args.""" |
+ return ','.join('%s' for _ in sql_args) |
+ |
+ |
+TABLE_PAT = '[A-Z][_a-zA-Z0-9]+' |
+COLUMN_PAT = '[a-z][_a-z]+' |
+COMPARE_OP_PAT = '(<|>|=|!=|>=|<=|LIKE|NOT LIKE)' |
+SHORTHAND = { |
+ 'table': TABLE_PAT, |
+ 'column': COLUMN_PAT, |
+ 'tab_col': r'(%s\.)?%s' % (TABLE_PAT, COLUMN_PAT), |
+ 'placeholder': '%s', # That's a literal %s that gets passed to MySQLdb |
+ 'multi_placeholder': '%s(, ?%s)*', |
+ 'compare_op': COMPARE_OP_PAT, |
+ 'opt_asc_desc': '( ASC| DESC)?', |
+ 'opt_alias': '( AS %s)?' % TABLE_PAT, |
+ 'email_cond': (r'LOWER\(User\d+\.email\) ' |
+ r'(%s %%s|IN \(%%s(, ?%%s)*\))' % COMPARE_OP_PAT), |
+ } |
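+# E.g., the 'tab_col' shorthand above expands to
+# r'([A-Z][_a-zA-Z0-9]+\.)?[a-z][_a-z]+': an optional Table. qualifier
+# followed by a column name.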
+ |
+ |
+def _MakeRE(regex_str): |
+ """Return a regular expression object, expanding our shorthand as needed.""" |
+ return re.compile(regex_str.format(**SHORTHAND)) |
+ |
+ |
+TABLE_RE = _MakeRE('^{table}$') |
+TAB_COL_RE = _MakeRE('^{tab_col}$') |
+USE_CLAUSE_RE = _MakeRE( |
+ r'^USE INDEX \({column}\) USE INDEX FOR ORDER BY \({column}\)$') |
+COLUMN_RE_LIST = [ |
+ TAB_COL_RE, |
+ _MakeRE(r'\*'), |
+ _MakeRE(r'COUNT\(\*\)'), |
+ _MakeRE(r'COUNT\({tab_col}\)'), |
+ _MakeRE(r'MAX\({tab_col}\)'), |
+ _MakeRE(r'MIN\({tab_col}\)'), |
+ ] |
+JOIN_RE_LIST = [ |
+ TABLE_RE, |
+ _MakeRE( |
+ r'^{table}{opt_alias} ON {tab_col} = {tab_col}' |
+ r'( AND {tab_col} = {tab_col})?' |
+ r'( AND {tab_col} IN \({multi_placeholder}\))?$'), |
+ _MakeRE( |
+ r'^{table}{opt_alias} ON {tab_col} = {tab_col}' |
+ r'( AND {tab_col} IN \({multi_placeholder}\))?$'), |
+ _MakeRE( |
+ r'^{table}{opt_alias} ON {tab_col} = {tab_col}' |
+ r'( AND {tab_col} = {tab_col})?' |
+ r' AND {tab_col} = {placeholder}$'), |
+ _MakeRE( |
+ r'^{table}{opt_alias} ON {tab_col} = {tab_col} AND {email_cond}$'), |
+ _MakeRE( |
+ r'^{table}{opt_alias} ON ' |
+ r'\({tab_col} = {tab_col} OR {tab_col} = {tab_col}\)$'), |
+ _MakeRE( |
+ r'^\({table} AS {table} JOIN User AS {table} ' |
+ r'ON {tab_col} = {tab_col} ' |
+ r'AND {email_cond}\) ON Issue.id = {tab_col}'), |
+ _MakeRE( |
+ r'^{table} AS {table} ON {tab_col} = {tab_col} ' |
+ r'LEFT JOIN {table} AS {table} ON {tab_col} = {tab_col}'), |
+ ] |
+ORDER_BY_RE_LIST = [ |
+ _MakeRE(r'^{tab_col}{opt_asc_desc}$'), |
+ _MakeRE(r'^LOWER\({tab_col}\){opt_asc_desc}$'), |
+ _MakeRE(r'^ISNULL\({tab_col}\){opt_asc_desc}$'), |
+ _MakeRE(r'^FIELD\({tab_col}, {multi_placeholder}\){opt_asc_desc}$'), |
+ _MakeRE(r'^FIELD\(IF\(ISNULL\({tab_col}\), {tab_col}, {tab_col}\), ' |
+ r'{multi_placeholder}\){opt_asc_desc}$'), |
+ ] |
+GROUP_BY_RE_LIST = [ |
+ TAB_COL_RE, |
+ ] |
+WHERE_COND_RE_LIST = [ |
+ _MakeRE(r'^TRUE$'), |
+ _MakeRE(r'^FALSE$'), |
+ _MakeRE(r'^{tab_col} IS NULL$'), |
+ _MakeRE(r'^{tab_col} IS NOT NULL$'), |
+ _MakeRE(r'^{tab_col} {compare_op} {tab_col}$'), |
+ _MakeRE(r'^{tab_col} {compare_op} {placeholder}$'), |
+ _MakeRE(r'^{tab_col} %% {placeholder} = {placeholder}$'), |
+ _MakeRE(r'^{tab_col} IN \({multi_placeholder}\)$'), |
+ _MakeRE(r'^{tab_col} NOT IN \({multi_placeholder}\)$'), |
+ _MakeRE(r'^LOWER\({tab_col}\) IS NULL$'), |
+ _MakeRE(r'^LOWER\({tab_col}\) IS NOT NULL$'), |
+ _MakeRE(r'^LOWER\({tab_col}\) {compare_op} {placeholder}$'), |
+ _MakeRE(r'^LOWER\({tab_col}\) IN \({multi_placeholder}\)$'), |
+ _MakeRE(r'^LOWER\({tab_col}\) NOT IN \({multi_placeholder}\)$'), |
+ _MakeRE(r'^LOWER\({tab_col}\) LIKE {placeholder}$'), |
+ _MakeRE(r'^LOWER\({tab_col}\) NOT LIKE {placeholder}$'), |
+ _MakeRE(r'^timestep < \(SELECT MAX\(j.timestep\) FROM Invalidate AS j ' |
+ r'WHERE j.kind = %s ' |
+ r'AND j.cache_key = Invalidate.cache_key\)$'), |
+    _MakeRE(r'^\({tab_col} IS NULL OR {tab_col} {compare_op} {placeholder}\) '
+            r'AND \({tab_col} IS NULL OR {tab_col} {compare_op} {placeholder}'
+            r'\)$'),
+    _MakeRE(r'^\({tab_col} IS NOT NULL AND {tab_col} {compare_op} '
+            r'{placeholder}\) OR \({tab_col} IS NOT NULL AND {tab_col} '
+            r'{compare_op} {placeholder}\)$'),
+ ] |
+ |
+# Note: We never use ';' for multiple statements, '@' for SQL variables, or |
+# any quoted strings in stmt_str (quotes are put in by MySQLdb for args).
+STMT_STR_RE = re.compile( |
+ r'\A(SELECT|UPDATE|DELETE|INSERT|REPLACE) [-+=!<>%*.,()\w\s]+\Z', |
+ re.MULTILINE) |
+ |
+ |
+def _IsValidTableName(table_name): |
+ return TABLE_RE.match(table_name) |
+ |
+ |
+def _IsValidColumnName(column_expr): |
+ return any(regex.match(column_expr) for regex in COLUMN_RE_LIST) |
+ |
+ |
+def _IsValidUseClause(use_clause): |
+ return USE_CLAUSE_RE.match(use_clause) |
+ |
+ |
+def _IsValidJoin(join): |
+ return any(regex.match(join) for regex in JOIN_RE_LIST) |
+ |
+ |
+def _IsValidOrderByTerm(term): |
+ return any(regex.match(term) for regex in ORDER_BY_RE_LIST) |
+ |
+ |
+def _IsValidGroupByTerm(term): |
+ return any(regex.match(term) for regex in GROUP_BY_RE_LIST) |
+ |
+ |
+def _IsValidWhereCond(cond): |
+ if cond.startswith('NOT '): |
+ cond = cond[4:] |
+ if cond.startswith('(') and cond.endswith(')'): |
+ cond = cond[1:-1] |
+ |
+ if any(regex.match(cond) for regex in WHERE_COND_RE_LIST): |
+ return True |
+ |
+ if ' OR ' in cond: |
+ return all(_IsValidWhereCond(c) for c in cond.split(' OR ')) |
+ |
+ if ' AND ' in cond: |
+ return all(_IsValidWhereCond(c) for c in cond.split(' AND ')) |
+ |
+ return False |
+ |
+ |
+def _IsValidStatement(stmt_str): |
+ """Final check to make sure there is no funny junk sneaking in somehow.""" |
+ return (STMT_STR_RE.match(stmt_str) and |
+ '--' not in stmt_str) |
+ |
+ |
+def _BoolsToInts(arg_list): |
+ """Convert any True values to 1s and Falses to 0s. |
+ |
+  Google's copy of MySQLdb has bool-to-int conversion disabled, and yet it
+  seems to be needed: otherwise booleans are converted to strings and are
+  always interpreted as 0 (which is FALSE).
+ |
+ Args: |
+    arg_list: (nested) list of SQL statement argument values, which may
+ include some boolean values. |
+ |
+ Returns: |
+ The same list, but with True replaced by 1 and False replaced by 0. |
+ """ |
+ result = [] |
+ for arg in arg_list: |
+ if isinstance(arg, (list, tuple)): |
+ result.append(_BoolsToInts(arg)) |
+ elif arg is True: |
+ result.append(1) |
+ elif arg is False: |
+ result.append(0) |
+ else: |
+ result.append(arg) |
+ |
+ return result |