Chromium Code Reviews
chromiumcodereview-hr@appspot.gserviceaccount.com (chromiumcodereview-hr) | Please choose your nickname with Settings | Help | Chromium Project | Gerrit Changes | Sign out
(2050)

Unified Diff: appengine/monorail/framework/sql.py

Issue 1868553004: Open Source Monorail (Closed) Base URL: https://chromium.googlesource.com/infra/infra.git@master
Patch Set: Rebase Created 4 years, 8 months ago
Use n/p to move between diff chunks; N/P to move between comments. Draft comments are only viewable by you.
Jump to:
View side-by-side diff with in-line comments
Download patch
« no previous file with comments | « appengine/monorail/framework/sorting.py ('k') | appengine/monorail/framework/table_view_helpers.py » ('j') | no next file with comments »
Expand Comments ('e') | Collapse Comments ('c') | Show Comments Hide Comments ('s')
Index: appengine/monorail/framework/sql.py
diff --git a/appengine/monorail/framework/sql.py b/appengine/monorail/framework/sql.py
new file mode 100644
index 0000000000000000000000000000000000000000..223912db397be67ab55dd899324b967c8998fdab
--- /dev/null
+++ b/appengine/monorail/framework/sql.py
@@ -0,0 +1,745 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""A set of classes for interacting with tables in SQL."""
+
+import logging
+import random
+import re
+import sys
+import time
+
+import settings
+
+if not settings.unit_test_mode:
+ import MySQLdb
+
+from framework import framework_helpers
+
+
# MonorailConnection maintains a dictionary of connections to SQL databases.
# Each replica connection is identified by an int shard ID, and there is one
# connection to the master DB identified by the key MASTER_CNXN.
MASTER_CNXN = 'master_cnxn'
+
+
@framework_helpers.retry(2, delay=1, backoff=2)
def MakeConnection(instance, database):
  """Return a new MySQLdb connection to the given instance and database.

  Retries (via the decorator) up to 2 times with exponential backoff,
  since Cloud SQL connections can fail transiently.
  """
  logging.info('About to connect to SQL instance %r db %r', instance, database)
  if settings.unit_test_mode:
    raise ValueError('unit tests should not need real database connections')

  if settings.dev_mode:
    # Local development connects over TCP to a local MySQL server.
    return MySQLdb.connect(
        host='127.0.0.1', port=3306, db=database, user='root', charset='utf8')

  # Production connects through the Cloud SQL unix socket.
  return MySQLdb.connect(
      unix_socket='/cloudsql/' + instance, db=database, user='root',
      charset='utf8')
+
+
class MonorailConnection(object):
  """Create and manage connections to the SQL servers.

  We only store connections in the context of a single user request, not
  across user requests. The main purpose of this class is to make using
  sharded tables easier.
  """

  def __init__(self):
    # Maps MASTER_CNXN or an int shard ID to an open MySQLdb connection.
    self.sql_cnxns = {}

  def GetMasterConnection(self):
    """Return a connection to the master SQL DB, creating it if needed."""
    if MASTER_CNXN not in self.sql_cnxns:
      self.sql_cnxns[MASTER_CNXN] = MakeConnection(
          settings.db_instance, settings.db_database_name)
      logging.info(
          'created a master connection %r', self.sql_cnxns[MASTER_CNXN])

    return self.sql_cnxns[MASTER_CNXN]

  def GetConnectionForShard(self, shard_id):
    """Return a connection to the DB replica that will be used for shard_id."""
    if settings.dev_mode:
      # Dev environments have no read replicas, so fall back to the master.
      return self.GetMasterConnection()

    if shard_id not in self.sql_cnxns:
      physical_shard_id = shard_id % settings.num_logical_shards
      shard_instance_name = (
          settings.physical_db_name_format % physical_shard_id)
      self.sql_cnxns[shard_id] = MakeConnection(
          shard_instance_name, settings.db_database_name)
      logging.info('created a replica connection for shard %d', shard_id)

    return self.sql_cnxns[shard_id]

  def Execute(self, stmt_str, stmt_args, shard_id=None, commit=True):
    """Execute the given SQL statement on one of the relevant databases.

    Args:
      stmt_str: SQL statement string with %s placeholders.
      stmt_args: Argument values for the placeholders (for INSERT/REPLACE,
          a list of rows).
      shard_id: Optional int shard ID; None means hit the master DB.
      commit: Set to False to defer the commit to a later statement.

    Returns:
      A MySQLdb cursor holding the statement results.
    """
    if shard_id is None:
      # No shard was specified, so hit the master.
      sql_cnxn = self.GetMasterConnection()
    else:
      sql_cnxn = self.GetConnectionForShard(shard_id)

    return self._ExecuteWithSQLConnection(
        sql_cnxn, stmt_str, stmt_args, commit=commit)

  def _ExecuteWithSQLConnection(
      self, sql_cnxn, stmt_str, stmt_args, commit=True):
    """Execute a statement on the given connection and return a cursor."""
    cursor = sql_cnxn.cursor()
    start_time = time.time()
    if stmt_str.startswith('INSERT') or stmt_str.startswith('REPLACE'):
      # For INSERT/REPLACE, stmt_args is a list of rows, so use executemany().
      logging.info('SQL stmt_str: \n%s', stmt_str)
      logging.info('SQL stmt_args: %r', stmt_args)
      cursor.executemany(stmt_str, stmt_args)
    else:
      logging.info('SQL stmt: \n%s', (stmt_str % tuple(stmt_args)))
      cursor.execute(stmt_str, args=stmt_args)
    logging.info('%d rows in %d ms', cursor.rowcount,
                 int((time.time() - start_time) * 1000))
    if commit and not stmt_str.startswith('SELECT'):
      start_time = time.time()
      try:
        sql_cnxn.commit()
      except MySQLdb.DatabaseError:
        # Log the failure before rolling back, consistent with Commit().
        logging.exception('Commit failed for cnxn, rolling back')
        sql_cnxn.rollback()
      logging.info('commit took %d ms',
                   int((time.time() - start_time) * 1000))

    return cursor

  def Commit(self):
    """Explicitly commit any pending txns.  Normally done automatically."""
    sql_cnxn = self.GetMasterConnection()
    start_time = time.time()
    try:
      sql_cnxn.commit()
    except MySQLdb.DatabaseError:
      logging.exception('Commit failed for cnxn, rolling back')
      sql_cnxn.rollback()
    logging.info('final commit took %d ms',
                 int((time.time() - start_time) * 1000))

  def Close(self):
    """Safely close any connections that are still open."""
    for sql_cnxn in self.sql_cnxns.itervalues():
      try:
        sql_cnxn.close()
      except MySQLdb.DatabaseError:
        # This might happen if the cnxn is somehow already closed.
        logging.exception('DatabaseError when trying to close cnxn')
+
+
class SQLTableManager(object):
  """Helper class to make it easier to deal with an SQL table."""

  def __init__(self, table_name):
    self.table_name = table_name

  def Select(
      self, cnxn, distinct=False, cols=None, left_joins=None,
      joins=None, where=None, or_where_conds=False, group_by=None,
      order_by=None, limit=None, offset=None, shard_id=None, use_clause=None,
      **kwargs):
    """Compose and execute an SQL SELECT statement on this table.

    Args:
      cnxn: MonorailConnection to the databases.
      distinct: If True, add DISTINCT keyword.
      cols: List of columns to retrieve, defaults to '*'.
      left_joins: List of LEFT JOIN (str, args) pairs.
      joins: List of regular JOIN (str, args) pairs.
      where: List of (str, args) for WHERE clause.
      or_where_conds: Set to True to use OR in the WHERE conds.
      group_by: List of strings for GROUP BY clause.
      order_by: List of (str, args) for ORDER BY clause.
      limit: Optional LIMIT on the number of rows returned.
      offset: Optional OFFSET when using LIMIT.
      shard_id: Int ID of the shard to query.
      use_clause: Optional string USE clause to tell the DB which index to use.
      **kwargs: WHERE-clause equality and set-membership conditions.

    Keyword args are used to build up more WHERE conditions that compare
    column values to constants.  Keyword argument foo='bar' translates to
    'foo = "bar"', and foo=[3, 4, 5] translates to 'foo IN (3, 4, 5)'.

    Returns:
      A list of rows, each row is a tuple of values for the requested cols.
    """
    cols = cols or ['*']  # If columns not specified, retrieve all columns.
    stmt = Statement.MakeSelect(
        self.table_name, cols, distinct=distinct,
        or_where_conds=or_where_conds)
    if use_clause:
      stmt.AddUseClause(use_clause)
    stmt.AddJoinClauses(left_joins or [], left=True)
    stmt.AddJoinClauses(joins or [])
    stmt.AddWhereTerms(where or [], **kwargs)
    stmt.AddGroupByTerms(group_by or [])
    stmt.AddOrderByTerms(order_by or [])
    stmt.SetLimitAndOffset(limit, offset)
    stmt_str, stmt_args = stmt.Generate()

    cursor = cnxn.Execute(stmt_str, stmt_args, shard_id=shard_id)
    rows = cursor.fetchall()
    return rows

  def SelectRow(
      self, cnxn, cols=None, default=None, where=None, **kwargs):
    """Run a query that is expected to return just one row.

    Args:
      cnxn: MonorailConnection object.
      cols: List of columns to retrieve, defaults to '*'.
      default: Value to return when no rows match.
      where: List of (str, args) for WHERE clause.
      **kwargs: WHERE-clause equality and set-membership conditions.

    Returns:
      The single matching row tuple, or default if nothing matched.

    Raises:
      ValueError: if more than one row matched.
    """
    rows = self.Select(cnxn, distinct=True, cols=cols, where=where, **kwargs)
    if len(rows) == 1:
      return rows[0]
    elif not rows:
      logging.info('SelectRow got 0 results, so using default %r', default)
      return default
    else:
      # Interpolate the row count into the message; previously it was passed
      # as a second exception argument and never formatted.
      raise ValueError(
          'SelectRow got %d results, expected only 1' % len(rows))

  def SelectValue(self, cnxn, col, default=None, where=None, **kwargs):
    """Run a query that is expected to return just one row w/ one value."""
    row = self.SelectRow(
        cnxn, cols=[col], default=[default], where=where, **kwargs)
    return row[0]

  def InsertRows(
      self, cnxn, cols, row_values, replace=False, ignore=False,
      commit=True, return_generated_ids=False):
    """Insert all the given rows.

    Args:
      cnxn: MonorailConnection object.
      cols: List of column names to set.
      row_values: List of lists with values to store. The length of each
          nested list should be equal to len(cols).
      replace: Set to True if inserted values should replace existing DB rows
          that have the same DB keys.
      ignore: Set to True to ignore rows that would duplicate existing DB keys.
      commit: Set to False if this operation is part of a series of operations
          that should not be committed until the final one is done.
      return_generated_ids: Set to True to return a list of generated
          autoincrement IDs for inserted rows.  This requires us to insert rows
          one at a time.

    Returns:
      If return_generated_ids is set to True, this method returns a list of the
      auto-increment IDs generated by the DB.  Otherwise, [] is returned.
      If row_values is empty, None is returned without touching the DB.
    """
    if not row_values:
      return None  # Nothing to insert

    generated_ids = []
    if return_generated_ids:
      # We must insert the rows one-at-a-time to know the generated IDs.
      for row_value in row_values:
        stmt = Statement.MakeInsert(
            self.table_name, cols, [row_value], replace=replace, ignore=ignore)
        stmt_str, stmt_args = stmt.Generate()
        cursor = cnxn.Execute(stmt_str, stmt_args, commit=commit)
        if cursor.lastrowid:
          generated_ids.append(cursor.lastrowid)
      return generated_ids

    stmt = Statement.MakeInsert(
        self.table_name, cols, row_values, replace=replace, ignore=ignore)
    stmt_str, stmt_args = stmt.Generate()
    cnxn.Execute(stmt_str, stmt_args, commit=commit)
    return []

  def InsertRow(
      self, cnxn, replace=False, ignore=False, commit=True, **kwargs):
    """Insert a single row into the table.

    Args:
      cnxn: MonorailConnection object.
      replace: Set to True if inserted values should replace existing DB rows
          that have the same DB keys.
      ignore: Set to True to ignore rows that would duplicate existing DB keys.
      commit: Set to False if this operation is part of a series of operations
          that should not be committed until the final one is done.
      **kwargs: column=value assignments to specify what to store in the DB.

    Returns:
      The generated autoincrement ID of the key column if one was generated.
      Otherwise, return None.
    """
    # Sort column names so the generated SQL is deterministic.
    cols = sorted(kwargs.keys())
    row = tuple(kwargs[col] for col in cols)
    generated_ids = self.InsertRows(
        cnxn, cols, [row], replace=replace, ignore=ignore,
        commit=commit, return_generated_ids=True)
    if generated_ids:
      return generated_ids[0]
    else:
      return None

  def Update(self, cnxn, delta, where=None, commit=True, **kwargs):
    """Update one or more rows.

    Args:
      cnxn: MonorailConnection object.
      delta: Dictionary of {column: new_value} assignments.
      where: Optional list of WHERE conditions saying which rows to update.
      commit: Set to False if this operation is part of a series of operations
          that should not be committed until the final one is done.
      **kwargs: WHERE-clause equality and set-membership conditions.

    Returns:
      Int number of rows updated.
    """
    if not delta:
      return 0   # Nothing is being changed

    stmt = Statement.MakeUpdate(self.table_name, delta)
    stmt.AddWhereTerms(where, **kwargs)
    stmt_str, stmt_args = stmt.Generate()

    cursor = cnxn.Execute(stmt_str, stmt_args, commit=commit)
    return cursor.rowcount

  def IncrementCounterValue(self, cnxn, col_name, where=None, **kwargs):
    """Atomically increment a counter stored in MySQL, return new value.

    Args:
      cnxn: MonorailConnection object.
      col_name: int column to increment.
      where: Optional list of WHERE conditions saying which rows to update.
      **kwargs: WHERE-clause equality and set-membership conditions.  The
          where and kwargs together should narrow the update down to exactly
          one row.

    Returns:
      The new, post-increment value of the counter.
    """
    stmt = Statement.MakeIncrement(self.table_name, col_name)
    stmt.AddWhereTerms(where, **kwargs)
    stmt_str, stmt_args = stmt.Generate()

    cursor = cnxn.Execute(stmt_str, stmt_args)
    assert cursor.rowcount == 1, (
        'missing or ambiguous counter: %r' % cursor.rowcount)
    # MakeIncrement uses LAST_INSERT_ID() so the new value is reported here.
    return cursor.lastrowid

  def Delete(self, cnxn, where=None, commit=True, **kwargs):
    """Delete the specified table rows.

    Args:
      cnxn: MonorailConnection object.
      where: Optional list of WHERE conditions saying which rows to delete.
      commit: Set to False if this operation is part of a series of operations
          that should not be committed until the final one is done.
      **kwargs: WHERE-clause equality and set-membership conditions.

    Returns:
      Int number of rows deleted.
    """
    # Deleting the whole table is never intended in Monorail.
    assert where or kwargs

    stmt = Statement.MakeDelete(self.table_name)
    stmt.AddWhereTerms(where, **kwargs)
    stmt_str, stmt_args = stmt.Generate()

    cursor = cnxn.Execute(stmt_str, stmt_args, commit=commit)
    return cursor.rowcount
+
+
class Statement(object):
  """A class to help build complex SQL statements w/ full escaping.

  Start with a Make*() method, then fill in additional clauses as needed,
  then call Generate() to return the SQL string and argument list.  We pass
  the string and args to MySQLdb separately so that it can do escaping on
  the arg values as appropriate to prevent SQL-injection attacks.

  The only values that are not escaped by MySQLdb are the table names
  and column names, and bits of SQL syntax, all of which is hard-coded
  in our application.
  """

  @classmethod
  def MakeSelect(cls, table_name, cols, distinct=False, or_where_conds=False):
    """Construct a SELECT statement."""
    assert _IsValidTableName(table_name)
    assert all(_IsValidColumnName(col) for col in cols)
    main_clause = 'SELECT%s %s FROM %s' % (
        (' DISTINCT' if distinct else ''), ', '.join(cols), table_name)
    return cls(main_clause, or_where_conds=or_where_conds)

  @classmethod
  def MakeInsert(
      cls, table_name, cols, new_values, replace=False, ignore=False):
    """Construct an INSERT statement, or delegate to MakeReplace."""
    # Any truthy value requests replace semantics; previously only the exact
    # value True did, and other truthy values silently made a plain INSERT.
    if replace:
      return cls.MakeReplace(table_name, cols, new_values, ignore)
    assert _IsValidTableName(table_name)
    assert all(_IsValidColumnName(col) for col in cols)
    ignore_word = ' IGNORE' if ignore else ''
    main_clause = 'INSERT%s INTO %s (%s)' % (
        ignore_word, table_name, ', '.join(cols))
    return cls(main_clause, insert_args=new_values)

  @classmethod
  def MakeReplace(
      cls, table_name, cols, new_values, ignore=False):
    """Construct an INSERT...ON DUPLICATE KEY UPDATE... statement.

    Uses the INSERT/UPDATE syntax because REPLACE is literally a DELETE
    followed by an INSERT, which doesn't play well with foreign keys.
    INSERT/UPDATE is an atomic check of whether the primary key exists,
    followed by an INSERT if it doesn't or an UPDATE if it does.
    """
    assert _IsValidTableName(table_name)
    assert all(_IsValidColumnName(col) for col in cols)
    ignore_word = ' IGNORE' if ignore else ''
    main_clause = 'INSERT%s INTO %s (%s)' % (
        ignore_word, table_name, ', '.join(cols))
    return cls(main_clause, insert_args=new_values, duplicate_update_cols=cols)

  @classmethod
  def MakeUpdate(cls, table_name, delta):
    """Construct an UPDATE statement."""
    assert _IsValidTableName(table_name)
    assert all(_IsValidColumnName(col) for col in delta.iterkeys())
    update_strs = []
    update_args = []
    for col, val in delta.iteritems():
      update_strs.append(col + '=%s')
      update_args.append(val)

    main_clause = 'UPDATE %s SET %s' % (
        table_name, ', '.join(update_strs))
    return cls(main_clause, update_args=update_args)

  @classmethod
  def MakeIncrement(cls, table_name, col_name, step=1):
    """Construct an UPDATE statement that increments and returns a counter."""
    assert _IsValidTableName(table_name)
    assert _IsValidColumnName(col_name)

    # LAST_INSERT_ID(expr) makes the post-increment value readable via
    # cursor.lastrowid on the same connection.
    main_clause = (
        'UPDATE %s SET %s = LAST_INSERT_ID(%s + %%s)' % (
            table_name, col_name, col_name))
    update_args = [step]
    return cls(main_clause, update_args=update_args)

  @classmethod
  def MakeDelete(cls, table_name):
    """Construct a DELETE statement."""
    assert _IsValidTableName(table_name)
    main_clause = 'DELETE FROM %s' % table_name
    return cls(main_clause)

  def __init__(
      self, main_clause, insert_args=None, update_args=None,
      duplicate_update_cols=None, or_where_conds=False):
    self.main_clause = main_clause  # E.g., SELECT or DELETE
    self.or_where_conds = or_where_conds
    self.insert_args = insert_args or []  # For INSERT statements
    self.update_args = update_args or []  # For UPDATEs
    self.duplicate_update_cols = duplicate_update_cols or []  # For REPLACE-ish

    self.use_clauses = []
    self.join_clauses, self.join_args = [], []
    self.where_conds, self.where_args = [], []
    self.group_by_terms, self.group_by_args = [], []
    self.order_by_terms, self.order_by_args = [], []
    self.limit, self.offset = None, None

  def Generate(self):
    """Return an SQL string having %s placeholders and args to fill them in."""
    clauses = [self.main_clause] + self.use_clauses + self.join_clauses
    if self.where_conds:
      if self.or_where_conds:
        clauses.append('WHERE ' + '\n  OR '.join(self.where_conds))
      else:
        clauses.append('WHERE ' + '\n  AND '.join(self.where_conds))
    if self.group_by_terms:
      clauses.append('GROUP BY ' + ', '.join(self.group_by_terms))
    if self.order_by_terms:
      clauses.append('ORDER BY ' + ', '.join(self.order_by_terms))

    if self.limit and self.offset:
      clauses.append('LIMIT %d OFFSET %d' % (self.limit, self.offset))
    elif self.limit:
      clauses.append('LIMIT %d' % self.limit)
    elif self.offset:
      # MySQL requires a LIMIT to use OFFSET, so supply an effectively
      # unbounded one.
      clauses.append('LIMIT %d OFFSET %d' % (sys.maxint, self.offset))

    if self.insert_args:
      clauses.append('VALUES (' + PlaceHolders(self.insert_args[0]) + ')')
      args = self.insert_args
      if self.duplicate_update_cols:
        clauses.append('ON DUPLICATE KEY UPDATE %s' % (
            ', '.join(['%s=VALUES(%s)' % (col, col)
                       for col in self.duplicate_update_cols])))
      assert not (self.join_args + self.update_args + self.where_args +
                  self.group_by_args + self.order_by_args)
    else:
      args = (self.join_args + self.update_args + self.where_args +
              self.group_by_args + self.order_by_args)
      assert not (self.insert_args + self.duplicate_update_cols)

    args = _BoolsToInts(args)
    stmt_str = '\n'.join(clause for clause in clauses if clause)

    assert _IsValidStatement(stmt_str), stmt_str
    return stmt_str, args

  def AddUseClause(self, use_clause):
    """Add a USE clause (giving the DB a hint about which indexes to use)."""
    assert _IsValidUseClause(use_clause), use_clause
    self.use_clauses.append(use_clause)

  def AddJoinClauses(self, join_pairs, left=False):
    """Save JOIN clauses based on the given list of join conditions."""
    for join, args in join_pairs:
      assert _IsValidJoin(join), join
      assert join.count('%s') == len(args), join
      self.join_clauses.append(
          '  %sJOIN %s' % (('LEFT ' if left else ''), join))
      self.join_args.extend(args)

  def AddGroupByTerms(self, group_by_term_list):
    """Save info needed to generate the GROUP BY clause."""
    assert all(_IsValidGroupByTerm(term) for term in group_by_term_list)
    self.group_by_terms.extend(group_by_term_list)

  def AddOrderByTerms(self, order_by_pairs):
    """Save info needed to generate the ORDER BY clause."""
    for term, args in order_by_pairs:
      assert _IsValidOrderByTerm(term), term
      assert term.count('%s') == len(args), term
      self.order_by_terms.append(term)
      self.order_by_args.extend(args)

  def SetLimitAndOffset(self, limit, offset):
    """Save info needed to generate the LIMIT OFFSET clause."""
    self.limit = limit
    self.offset = offset

  def AddWhereTerms(self, where_cond_pairs, **kwargs):
    """Save info needed to generate the WHERE clause."""
    where_cond_pairs = where_cond_pairs or []

    for cond, args in where_cond_pairs:
      assert _IsValidWhereCond(cond), cond
      assert cond.count('%s') == len(args), cond
      self.where_conds.append(cond)
      self.where_args.extend(args)

    # Sort kwargs so the generated SQL is deterministic.
    for col, val in sorted(kwargs.items()):
      assert _IsValidColumnName(col), col
      eq = True
      if col.endswith('_not'):
        col = col[:-4]
        eq = False

      if isinstance(val, set):
        val = list(val)  # MySQL interface cannot handle sets.

      if val is None or val == []:
        op = 'IS' if eq else 'IS NOT'
        self.where_conds.append(col + ' ' + op + ' NULL')
      elif isinstance(val, list):
        op = 'IN' if eq else 'NOT IN'
        # Sadly, MySQLdb cannot escape lists, so we flatten to multiple "%s"s
        self.where_conds.append(
            col + ' ' + op + ' (' + PlaceHolders(val) + ')')
        self.where_args.extend(val)
      else:
        op = '=' if eq else '!='
        self.where_conds.append(col + ' ' + op + ' %s')
        self.where_args.append(val)
+
+
def PlaceHolders(sql_args):
  """Return a comma-separated list of %s placeholders for the given args."""
  return ','.join(['%s'] * len(sql_args))
+
+
# Regex fragments for validating the hard-coded parts of SQL statements.
# Table names are capitalized (e.g., Issue); column names are lowercase.
TABLE_PAT = '[A-Z][_a-zA-Z0-9]+'
COLUMN_PAT = '[a-z][_a-z]+'
COMPARE_OP_PAT = '(<|>|=|!=|>=|<=|LIKE|NOT LIKE)'
# Named shorthands expanded into the {name} slots of patterns by _MakeRE().
SHORTHAND = {
    'table': TABLE_PAT,
    'column': COLUMN_PAT,
    'tab_col': r'(%s\.)?%s' % (TABLE_PAT, COLUMN_PAT),
    'placeholder': '%s',  # That's a literal %s that gets passed to MySQLdb
    'multi_placeholder': '%s(, ?%s)*',
    'compare_op': COMPARE_OP_PAT,
    'opt_asc_desc': '( ASC| DESC)?',
    'opt_alias': '( AS %s)?' % TABLE_PAT,
    'email_cond': (r'LOWER\(User\d+\.email\) '
                   r'(%s %%s|IN \(%%s(, ?%%s)*\))' % COMPARE_OP_PAT),
    }
+
+
def _MakeRE(regex_str):
  """Return a compiled regexp, expanding our {shorthand} names first."""
  expanded = regex_str.format(**SHORTHAND)
  return re.compile(expanded)
+
+
TABLE_RE = _MakeRE('^{table}$')
TAB_COL_RE = _MakeRE('^{tab_col}$')
USE_CLAUSE_RE = _MakeRE(
    r'^USE INDEX \({column}\) USE INDEX FOR ORDER BY \({column}\)$')
# A column expression may be a (qualified) column, '*', or an aggregate.
COLUMN_RE_LIST = [
    TAB_COL_RE,
    _MakeRE(r'\*'),
    _MakeRE(r'COUNT\(\*\)'),
    _MakeRE(r'COUNT\({tab_col}\)'),
    _MakeRE(r'MAX\({tab_col}\)'),
    _MakeRE(r'MIN\({tab_col}\)'),
    ]
# Whitelisted shapes for JOIN conditions used anywhere in the app.
JOIN_RE_LIST = [
    TABLE_RE,
    _MakeRE(
        r'^{table}{opt_alias} ON {tab_col} = {tab_col}'
        r'( AND {tab_col} = {tab_col})?'
        r'( AND {tab_col} IN \({multi_placeholder}\))?$'),
    _MakeRE(
        r'^{table}{opt_alias} ON {tab_col} = {tab_col}'
        r'( AND {tab_col} IN \({multi_placeholder}\))?$'),
    _MakeRE(
        r'^{table}{opt_alias} ON {tab_col} = {tab_col}'
        r'( AND {tab_col} = {tab_col})?'
        r' AND {tab_col} = {placeholder}$'),
    _MakeRE(
        r'^{table}{opt_alias} ON {tab_col} = {tab_col} AND {email_cond}$'),
    _MakeRE(
        r'^{table}{opt_alias} ON '
        r'\({tab_col} = {tab_col} OR {tab_col} = {tab_col}\)$'),
    _MakeRE(
        r'^\({table} AS {table} JOIN User AS {table} '
        r'ON {tab_col} = {tab_col} '
        r'AND {email_cond}\) ON Issue.id = {tab_col}'),
    _MakeRE(
        r'^{table} AS {table} ON {tab_col} = {tab_col} '
        r'LEFT JOIN {table} AS {table} ON {tab_col} = {tab_col}'),
    ]
# Whitelisted shapes for ORDER BY terms.
ORDER_BY_RE_LIST = [
    _MakeRE(r'^{tab_col}{opt_asc_desc}$'),
    _MakeRE(r'^LOWER\({tab_col}\){opt_asc_desc}$'),
    _MakeRE(r'^ISNULL\({tab_col}\){opt_asc_desc}$'),
    _MakeRE(r'^FIELD\({tab_col}, {multi_placeholder}\){opt_asc_desc}$'),
    _MakeRE(r'^FIELD\(IF\(ISNULL\({tab_col}\), {tab_col}, {tab_col}\), '
            r'{multi_placeholder}\){opt_asc_desc}$'),
    ]
GROUP_BY_RE_LIST = [
    TAB_COL_RE,
    ]
# Whitelisted shapes for WHERE conditions.
WHERE_COND_RE_LIST = [
    _MakeRE(r'^TRUE$'),
    _MakeRE(r'^FALSE$'),
    _MakeRE(r'^{tab_col} IS NULL$'),
    _MakeRE(r'^{tab_col} IS NOT NULL$'),
    _MakeRE(r'^{tab_col} {compare_op} {tab_col}$'),
    _MakeRE(r'^{tab_col} {compare_op} {placeholder}$'),
    _MakeRE(r'^{tab_col} %% {placeholder} = {placeholder}$'),
    _MakeRE(r'^{tab_col} IN \({multi_placeholder}\)$'),
    _MakeRE(r'^{tab_col} NOT IN \({multi_placeholder}\)$'),
    _MakeRE(r'^LOWER\({tab_col}\) IS NULL$'),
    _MakeRE(r'^LOWER\({tab_col}\) IS NOT NULL$'),
    _MakeRE(r'^LOWER\({tab_col}\) {compare_op} {placeholder}$'),
    _MakeRE(r'^LOWER\({tab_col}\) IN \({multi_placeholder}\)$'),
    _MakeRE(r'^LOWER\({tab_col}\) NOT IN \({multi_placeholder}\)$'),
    _MakeRE(r'^LOWER\({tab_col}\) LIKE {placeholder}$'),
    _MakeRE(r'^LOWER\({tab_col}\) NOT LIKE {placeholder}$'),
    _MakeRE(r'^timestep < \(SELECT MAX\(j.timestep\) FROM Invalidate AS j '
            r'WHERE j.kind = %s '
            r'AND j.cache_key = Invalidate.cache_key\)$'),
    _MakeRE(r'^\({tab_col} IS NULL OR {tab_col} {compare_op} {placeholder}\) '
            'AND \({tab_col} IS NULL OR {tab_col} {compare_op} {placeholder}'
            '\)$'),
    _MakeRE(r'^\({tab_col} IS NOT NULL AND {tab_col} {compare_op} '
            '{placeholder}\) OR \({tab_col} IS NOT NULL AND {tab_col} '
            '{compare_op} {placeholder}\)$'),
    ]
+
# Note: We never use ';' for multiple statements, '@' for SQL variables, or
# any quoted strings in stmt_str (quotes are put in by MySQLdb for args).
STMT_STR_RE = re.compile(
    r'\A(SELECT|UPDATE|DELETE|INSERT|REPLACE) [-+=!<>%*.,()\w\s]+\Z',
    re.MULTILINE)
+
+
def _IsValidTableName(table_name):
  """Return a truthy match object if table_name is whitelisted, else None."""
  return TABLE_RE.match(table_name)
+
+
def _IsValidColumnName(column_expr):
  """Return True if the column expression matches a whitelisted pattern."""
  for pattern in COLUMN_RE_LIST:
    if pattern.match(column_expr):
      return True
  return False
+
+
def _IsValidUseClause(use_clause):
  """Return a truthy match object if use_clause is whitelisted, else None."""
  return USE_CLAUSE_RE.match(use_clause)
+
+
def _IsValidJoin(join):
  """Return True if the JOIN clause matches a whitelisted pattern."""
  for pattern in JOIN_RE_LIST:
    if pattern.match(join):
      return True
  return False
+
+
def _IsValidOrderByTerm(term):
  """Return True if the ORDER BY term matches a whitelisted pattern."""
  for pattern in ORDER_BY_RE_LIST:
    if pattern.match(term):
      return True
  return False
+
+
def _IsValidGroupByTerm(term):
  """Return True if the GROUP BY term matches a whitelisted pattern."""
  for pattern in GROUP_BY_RE_LIST:
    if pattern.match(term):
      return True
  return False
+
+
def _IsValidWhereCond(cond):
  """Return True if the WHERE condition matches a whitelisted pattern."""
  # Strip an optional leading NOT and one optional layer of parens.
  if cond.startswith('NOT '):
    cond = cond[4:]
  if cond.startswith('(') and cond.endswith(')'):
    cond = cond[1:-1]

  for pattern in WHERE_COND_RE_LIST:
    if pattern.match(cond):
      return True

  # Otherwise, recursively validate each part of a disjunction/conjunction.
  if ' OR ' in cond:
    return all(_IsValidWhereCond(part) for part in cond.split(' OR '))
  if ' AND ' in cond:
    return all(_IsValidWhereCond(part) for part in cond.split(' AND '))

  return False
+
+
def _IsValidStatement(stmt_str):
  """Final check to make sure there is no funny junk sneaking in somehow."""
  has_sql_comment = '--' in stmt_str
  return STMT_STR_RE.match(stmt_str) and not has_sql_comment
+
+
+def _BoolsToInts(arg_list):
+ """Convert any True values to 1s and Falses to 0s.
+
+ Google's copy of MySQLdb has bool-to-int conversion disabled,
+ and yet it seems to be needed otherwise they are converted
+ to strings and always interpreted as 0 (which is FALSE).
+
+ Args:
+ arg_list: (nested) list of SQL statment argument values, which may
+ include some boolean values.
+
+ Returns:
+ The same list, but with True replaced by 1 and False replaced by 0.
+ """
+ result = []
+ for arg in arg_list:
+ if isinstance(arg, (list, tuple)):
+ result.append(_BoolsToInts(arg))
+ elif arg is True:
+ result.append(1)
+ elif arg is False:
+ result.append(0)
+ else:
+ result.append(arg)
+
+ return result
« no previous file with comments | « appengine/monorail/framework/sorting.py ('k') | appengine/monorail/framework/table_view_helpers.py » ('j') | no next file with comments »

Powered by Google App Engine
This is Rietveld 408576698