Chromium Code Reviews
chromiumcodereview-hr@appspot.gserviceaccount.com (chromiumcodereview-hr) | Please choose your nickname with Settings | Help | Chromium Project | Gerrit Changes | Sign out
(4)

Unified Diff: tools/telemetry/third_party/rope/rope/refactor/extract.py

Issue 1132103009: Example of refactoring using rope library. (Closed) Base URL: https://chromium.googlesource.com/chromium/src.git@master
Patch Set: Created 5 years, 7 months ago
Use n/p to move between diff chunks; N/P to move between comments. Draft comments are only viewable by you.
Jump to:
View side-by-side diff with in-line comments
Download patch
Index: tools/telemetry/third_party/rope/rope/refactor/extract.py
diff --git a/tools/telemetry/third_party/rope/rope/refactor/extract.py b/tools/telemetry/third_party/rope/rope/refactor/extract.py
new file mode 100644
index 0000000000000000000000000000000000000000..c6b3648e47aade5774475446950e136ad227a811
--- /dev/null
+++ b/tools/telemetry/third_party/rope/rope/refactor/extract.py
@@ -0,0 +1,804 @@
+import re
+
+from rope.base import ast, codeanalyze
+from rope.base.change import ChangeSet, ChangeContents
+from rope.base.exceptions import RefactoringError
+from rope.refactor import (sourceutils, similarfinder,
+ patchedast, suites, usefunction)
+
+
+# Extract refactoring has lots of special cases. I tried to split it
+# to smaller parts to make it more manageable:
+#
+# _ExtractInfo: holds information about the refactoring; it is passed
+# to the parts that need to have information about the refactoring
+#
+# _ExtractCollector: merely saves all of the information necessary for
+# performing the refactoring.
+#
+# _DefinitionLocationFinder: finds where to insert the definition.
+#
+# _ExceptionalConditionChecker: checks for exceptional conditions in
+# which the refactoring cannot be applied.
+#
+# _ExtractMethodParts: generates the pieces of code (like definition)
+# needed for performing extract method.
+#
+# _ExtractVariableParts: like _ExtractMethodParts for variables.
+#
+# _ExtractPerformer: Uses above classes to collect refactoring
+# changes.
+#
+# There are a few more helper functions and classes used by above
+# classes.
+class _ExtractRefactoring(object):
+
+ def __init__(self, project, resource, start_offset, end_offset,
+ variable=False):
+ self.project = project
+ self.resource = resource
+ self.start_offset = self._fix_start(resource.read(), start_offset)
+ self.end_offset = self._fix_end(resource.read(), end_offset)
+
+ def _fix_start(self, source, offset):
+ while offset < len(source) and source[offset].isspace():
+ offset += 1
+ return offset
+
+ def _fix_end(self, source, offset):
+ while offset > 0 and source[offset - 1].isspace():
+ offset -= 1
+ return offset
+
+ def get_changes(self, extracted_name, similar=False, global_=False):
+ """Get the changes this refactoring makes
+
+ :parameters:
+ - `similar`: if `True`, similar expressions/statements are also
+ replaced.
+ - `global_`: if `True`, the extracted method/variable will
+ be global.
+
+ """
+ info = _ExtractInfo(
+ self.project, self.resource, self.start_offset, self.end_offset,
+ extracted_name, variable=self.kind == 'variable',
+ similar=similar, make_global=global_)
+ new_contents = _ExtractPerformer(info).extract()
+ changes = ChangeSet('Extract %s <%s>' % (self.kind,
+ extracted_name))
+ changes.add_change(ChangeContents(self.resource, new_contents))
+ return changes
+
+
+class ExtractMethod(_ExtractRefactoring):
+
+ def __init__(self, *args, **kwds):
+ super(ExtractMethod, self).__init__(*args, **kwds)
+
+ kind = 'method'
+
+
+class ExtractVariable(_ExtractRefactoring):
+
+ def __init__(self, *args, **kwds):
+ kwds = dict(kwds)
+ kwds['variable'] = True
+ super(ExtractVariable, self).__init__(*args, **kwds)
+
+ kind = 'variable'
+
+
+class _ExtractInfo(object):
+ """Holds information about the extract to be performed"""
+
+ def __init__(self, project, resource, start, end, new_name,
+ variable, similar, make_global):
+ self.project = project
+ self.resource = resource
+ self.pymodule = project.get_pymodule(resource)
+ self.global_scope = self.pymodule.get_scope()
+ self.source = self.pymodule.source_code
+ self.lines = self.pymodule.lines
+ self.new_name = new_name
+ self.variable = variable
+ self.similar = similar
+ self._init_parts(start, end)
+ self._init_scope()
+ self.make_global = make_global
+
+ def _init_parts(self, start, end):
+ self.region = (self._choose_closest_line_end(start),
+ self._choose_closest_line_end(end, end=True))
+
+ start = self.logical_lines.logical_line_in(
+ self.lines.get_line_number(self.region[0]))[0]
+ end = self.logical_lines.logical_line_in(
+ self.lines.get_line_number(self.region[1]))[1]
+ self.region_lines = (start, end)
+
+ self.lines_region = (self.lines.get_line_start(self.region_lines[0]),
+ self.lines.get_line_end(self.region_lines[1]))
+
+ @property
+ def logical_lines(self):
+ return self.pymodule.logical_lines
+
+ def _init_scope(self):
+ start_line = self.region_lines[0]
+ scope = self.global_scope.get_inner_scope_for_line(start_line)
+ if scope.get_kind() != 'Module' and scope.get_start() == start_line:
+ scope = scope.parent
+ self.scope = scope
+ self.scope_region = self._get_scope_region(self.scope)
+
+ def _get_scope_region(self, scope):
+ return (self.lines.get_line_start(scope.get_start()),
+ self.lines.get_line_end(scope.get_end()) + 1)
+
+ def _choose_closest_line_end(self, offset, end=False):
+ lineno = self.lines.get_line_number(offset)
+ line_start = self.lines.get_line_start(lineno)
+ line_end = self.lines.get_line_end(lineno)
+ if self.source[line_start:offset].strip() == '':
+ if end:
+ return line_start - 1
+ else:
+ return line_start
+ elif self.source[offset:line_end].strip() == '':
+ return min(line_end, len(self.source))
+ return offset
+
+ @property
+ def one_line(self):
+ return self.region != self.lines_region and \
+ (self.logical_lines.logical_line_in(self.region_lines[0]) ==
+ self.logical_lines.logical_line_in(self.region_lines[1]))
+
+ @property
+ def global_(self):
+ return self.scope.parent is None
+
+ @property
+ def method(self):
+ return self.scope.parent is not None and \
+ self.scope.parent.get_kind() == 'Class'
+
+ @property
+ def indents(self):
+ return sourceutils.get_indents(self.pymodule.lines,
+ self.region_lines[0])
+
+ @property
+ def scope_indents(self):
+ if self.global_:
+ return 0
+ return sourceutils.get_indents(self.pymodule.lines,
+ self.scope.get_start())
+
+ @property
+ def extracted(self):
+ return self.source[self.region[0]:self.region[1]]
+
+ _returned = None
+
+ @property
+ def returned(self):
+ """Does the extracted piece contain return statement"""
+ if self._returned is None:
+ node = _parse_text(self.extracted)
+ self._returned = usefunction._returns_last(node)
+ return self._returned
+
+
+class _ExtractCollector(object):
+ """Collects information needed for performing the extract"""
+
+ def __init__(self, info):
+ self.definition = None
+ self.body_pattern = None
+ self.checks = {}
+ self.replacement_pattern = None
+ self.matches = None
+ self.replacements = None
+ self.definition_location = None
+
+
+class _ExtractPerformer(object):
+
+ def __init__(self, info):
+ self.info = info
+ _ExceptionalConditionChecker()(self.info)
+
+ def extract(self):
+ extract_info = self._collect_info()
+ content = codeanalyze.ChangeCollector(self.info.source)
+ definition = extract_info.definition
+ lineno, indents = extract_info.definition_location
+ offset = self.info.lines.get_line_start(lineno)
+ indented = sourceutils.fix_indentation(definition, indents)
+ content.add_change(offset, offset, indented)
+ self._replace_occurrences(content, extract_info)
+ return content.get_changed()
+
+ def _replace_occurrences(self, content, extract_info):
+ for match in extract_info.matches:
+ replacement = similarfinder.CodeTemplate(
+ extract_info.replacement_pattern)
+ mapping = {}
+ for name in replacement.get_names():
+ node = match.get_ast(name)
+ if node:
+ start, end = patchedast.node_region(match.get_ast(name))
+ mapping[name] = self.info.source[start:end]
+ else:
+ mapping[name] = name
+ region = match.get_region()
+ content.add_change(region[0], region[1],
+ replacement.substitute(mapping))
+
+ def _collect_info(self):
+ extract_collector = _ExtractCollector(self.info)
+ self._find_definition(extract_collector)
+ self._find_matches(extract_collector)
+ self._find_definition_location(extract_collector)
+ return extract_collector
+
+ def _find_matches(self, collector):
+ regions = self._where_to_search()
+ finder = similarfinder.SimilarFinder(self.info.pymodule)
+ matches = []
+ for start, end in regions:
+ matches.extend((finder.get_matches(collector.body_pattern,
+ collector.checks, start, end)))
+ collector.matches = matches
+
+ def _where_to_search(self):
+ if self.info.similar:
+ if self.info.make_global or self.info.global_:
+ return [(0, len(self.info.pymodule.source_code))]
+ if self.info.method and not self.info.variable:
+ class_scope = self.info.scope.parent
+ regions = []
+ method_kind = _get_function_kind(self.info.scope)
+ for scope in class_scope.get_scopes():
+ if method_kind == 'method' and \
+ _get_function_kind(scope) != 'method':
+ continue
+ start = self.info.lines.get_line_start(scope.get_start())
+ end = self.info.lines.get_line_end(scope.get_end())
+ regions.append((start, end))
+ return regions
+ else:
+ if self.info.variable:
+ return [self.info.scope_region]
+ else:
+ return [self.info._get_scope_region(
+ self.info.scope.parent)]
+ else:
+ return [self.info.region]
+
+ def _find_definition_location(self, collector):
+ matched_lines = []
+ for match in collector.matches:
+ start = self.info.lines.get_line_number(match.get_region()[0])
+ start_line = self.info.logical_lines.logical_line_in(start)[0]
+ matched_lines.append(start_line)
+ location_finder = _DefinitionLocationFinder(self.info, matched_lines)
+ collector.definition_location = (location_finder.find_lineno(),
+ location_finder.find_indents())
+
+ def _find_definition(self, collector):
+ if self.info.variable:
+ parts = _ExtractVariableParts(self.info)
+ else:
+ parts = _ExtractMethodParts(self.info)
+ collector.definition = parts.get_definition()
+ collector.body_pattern = parts.get_body_pattern()
+ collector.replacement_pattern = parts.get_replacement_pattern()
+ collector.checks = parts.get_checks()
+
+
+class _DefinitionLocationFinder(object):
+
+ def __init__(self, info, matched_lines):
+ self.info = info
+ self.matched_lines = matched_lines
+ # This only happens when subexpressions cannot be matched
+ if not matched_lines:
+ self.matched_lines.append(self.info.region_lines[0])
+
+ def find_lineno(self):
+ if self.info.variable and not self.info.make_global:
+ return self._get_before_line()
+ if self.info.make_global or self.info.global_:
+ toplevel = self._find_toplevel(self.info.scope)
+ ast = self.info.pymodule.get_ast()
+ newlines = sorted(self.matched_lines + [toplevel.get_end() + 1])
+ return suites.find_visible(ast, newlines)
+ return self._get_after_scope()
+
+ def _find_toplevel(self, scope):
+ toplevel = scope
+ if toplevel.parent is not None:
+ while toplevel.parent.parent is not None:
+ toplevel = toplevel.parent
+ return toplevel
+
+ def find_indents(self):
+ if self.info.variable and not self.info.make_global:
+ return sourceutils.get_indents(self.info.lines,
+ self._get_before_line())
+ else:
+ if self.info.global_ or self.info.make_global:
+ return 0
+ return self.info.scope_indents
+
+ def _get_before_line(self):
+ ast = self.info.scope.pyobject.get_ast()
+ return suites.find_visible(ast, self.matched_lines)
+
+ def _get_after_scope(self):
+ return self.info.scope.get_end() + 1
+
+
+class _ExceptionalConditionChecker(object):
+
+ def __call__(self, info):
+ self.base_conditions(info)
+ if info.one_line:
+ self.one_line_conditions(info)
+ else:
+ self.multi_line_conditions(info)
+
+ def base_conditions(self, info):
+ if info.region[1] > info.scope_region[1]:
+ raise RefactoringError('Bad region selected for extract method')
+ end_line = info.region_lines[1]
+ end_scope = info.global_scope.get_inner_scope_for_line(end_line)
+ if end_scope != info.scope and end_scope.get_end() != end_line:
+ raise RefactoringError('Bad region selected for extract method')
+ try:
+ extracted = info.source[info.region[0]:info.region[1]]
+ if info.one_line:
+ extracted = '(%s)' % extracted
+ if _UnmatchedBreakOrContinueFinder.has_errors(extracted):
+ raise RefactoringError('A break/continue without having a '
+ 'matching for/while loop.')
+ except SyntaxError:
+ raise RefactoringError('Extracted piece should '
+ 'contain complete statements.')
+
+ def one_line_conditions(self, info):
+ if self._is_region_on_a_word(info):
+ raise RefactoringError('Should extract complete statements.')
+ if info.variable and not info.one_line:
+ raise RefactoringError('Extract variable should not '
+ 'span multiple lines.')
+
+ def multi_line_conditions(self, info):
+ node = _parse_text(info.source[info.region[0]:info.region[1]])
+ count = usefunction._return_count(node)
+ if count > 1:
+ raise RefactoringError('Extracted piece can have only one '
+ 'return statement.')
+ if usefunction._yield_count(node):
+ raise RefactoringError('Extracted piece cannot '
+ 'have yield statements.')
+ if count == 1 and not usefunction._returns_last(node):
+ raise RefactoringError('Return should be the last statement.')
+ if info.region != info.lines_region:
+ raise RefactoringError('Extracted piece should '
+ 'contain complete statements.')
+
+ def _is_region_on_a_word(self, info):
+ if info.region[0] > 0 and \
+ self._is_on_a_word(info, info.region[0] - 1) or \
+ self._is_on_a_word(info, info.region[1] - 1):
+ return True
+
+ def _is_on_a_word(self, info, offset):
+ prev = info.source[offset]
+ if not (prev.isalnum() or prev == '_') or \
+ offset + 1 == len(info.source):
+ return False
+ next = info.source[offset + 1]
+ return next.isalnum() or next == '_'
+
+
+class _ExtractMethodParts(object):
+
+ def __init__(self, info):
+ self.info = info
+ self.info_collector = self._create_info_collector()
+
+ def get_definition(self):
+ if self.info.global_:
+ return '\n%s\n' % self._get_function_definition()
+ else:
+ return '\n%s' % self._get_function_definition()
+
+ def get_replacement_pattern(self):
+ variables = []
+ variables.extend(self._find_function_arguments())
+ variables.extend(self._find_function_returns())
+ return similarfinder.make_pattern(self._get_call(), variables)
+
+ def get_body_pattern(self):
+ variables = []
+ variables.extend(self._find_function_arguments())
+ variables.extend(self._find_function_returns())
+ variables.extend(self._find_temps())
+ return similarfinder.make_pattern(self._get_body(), variables)
+
+ def _get_body(self):
+ result = sourceutils.fix_indentation(self.info.extracted, 0)
+ if self.info.one_line:
+ result = '(%s)' % result
+ return result
+
+ def _find_temps(self):
+ return usefunction.find_temps(self.info.project,
+ self._get_body())
+
+ def get_checks(self):
+ if self.info.method and not self.info.make_global:
+ if _get_function_kind(self.info.scope) == 'method':
+ class_name = similarfinder._pydefined_to_str(
+ self.info.scope.parent.pyobject)
+ return {self._get_self_name(): 'type=' + class_name}
+ return {}
+
+ def _create_info_collector(self):
+ zero = self.info.scope.get_start() - 1
+ start_line = self.info.region_lines[0] - zero
+ end_line = self.info.region_lines[1] - zero
+ info_collector = _FunctionInformationCollector(start_line, end_line,
+ self.info.global_)
+ body = self.info.source[self.info.scope_region[0]:
+ self.info.scope_region[1]]
+ node = _parse_text(body)
+ ast.walk(node, info_collector)
+ return info_collector
+
+ def _get_function_definition(self):
+ args = self._find_function_arguments()
+ returns = self._find_function_returns()
+ result = []
+ if self.info.method and not self.info.make_global and \
+ _get_function_kind(self.info.scope) != 'method':
+ result.append('@staticmethod\n')
+ result.append('def %s:\n' % self._get_function_signature(args))
+ unindented_body = self._get_unindented_function_body(returns)
+ indents = sourceutils.get_indent(self.info.project)
+ function_body = sourceutils.indent_lines(unindented_body, indents)
+ result.append(function_body)
+ definition = ''.join(result)
+
+ return definition + '\n'
+
+ def _get_function_signature(self, args):
+ args = list(args)
+ prefix = ''
+ if self._extracting_method():
+ self_name = self._get_self_name()
+ if self_name is None:
+ raise RefactoringError('Extracting a method from a function '
+ 'with no self argument.')
+ if self_name in args:
+ args.remove(self_name)
+ args.insert(0, self_name)
+ return prefix + self.info.new_name + \
+ '(%s)' % self._get_comma_form(args)
+
+ def _extracting_method(self):
+ return self.info.method and not self.info.make_global and \
+ _get_function_kind(self.info.scope) == 'method'
+
+ def _get_self_name(self):
+ param_names = self.info.scope.pyobject.get_param_names()
+ if param_names:
+ return param_names[0]
+
+ def _get_function_call(self, args):
+ prefix = ''
+ if self.info.method and not self.info.make_global:
+ if _get_function_kind(self.info.scope) == 'method':
+ self_name = self._get_self_name()
+ if self_name in args:
+ args.remove(self_name)
+ prefix = self_name + '.'
+ else:
+ prefix = self.info.scope.parent.pyobject.get_name() + '.'
+ return prefix + '%s(%s)' % (self.info.new_name,
+ self._get_comma_form(args))
+
+ def _get_comma_form(self, names):
+ result = ''
+ if names:
+ result += names[0]
+ for name in names[1:]:
+ result += ', ' + name
+ return result
+
+ def _get_call(self):
+ if self.info.one_line:
+ args = self._find_function_arguments()
+ return self._get_function_call(args)
+ args = self._find_function_arguments()
+ returns = self._find_function_returns()
+ call_prefix = ''
+ if returns:
+ call_prefix = self._get_comma_form(returns) + ' = '
+ if self.info.returned:
+ call_prefix = 'return '
+ return call_prefix + self._get_function_call(args)
+
+ def _find_function_arguments(self):
+ # if not make_global, do not pass any global names; they are
+ # all visible.
+ if self.info.global_ and not self.info.make_global:
+ return ()
+ if not self.info.one_line:
+ result = (self.info_collector.prewritten &
+ self.info_collector.read)
+ result |= (self.info_collector.prewritten &
+ self.info_collector.postread &
+ (self.info_collector.maybe_written -
+ self.info_collector.written))
+ return list(result)
+ start = self.info.region[0]
+ if start == self.info.lines_region[0]:
+ start = start + re.search('\S', self.info.extracted).start()
+ function_definition = self.info.source[start:self.info.region[1]]
+ read = _VariableReadsAndWritesFinder.find_reads_for_one_liners(
+ function_definition)
+ return list(self.info_collector.prewritten.intersection(read))
+
+ def _find_function_returns(self):
+ if self.info.one_line or self.info.returned:
+ return []
+ written = self.info_collector.written | \
+ self.info_collector.maybe_written
+ return list(written & self.info_collector.postread)
+
+ def _get_unindented_function_body(self, returns):
+ if self.info.one_line:
+ return 'return ' + _join_lines(self.info.extracted)
+ extracted_body = self.info.extracted
+ unindented_body = sourceutils.fix_indentation(extracted_body, 0)
+ if returns:
+ unindented_body += '\nreturn %s' % self._get_comma_form(returns)
+ return unindented_body
+
+
+class _ExtractVariableParts(object):
+
+ def __init__(self, info):
+ self.info = info
+
+ def get_definition(self):
+ result = self.info.new_name + ' = ' + \
+ _join_lines(self.info.extracted) + '\n'
+ return result
+
+ def get_body_pattern(self):
+ return '(%s)' % self.info.extracted.strip()
+
+ def get_replacement_pattern(self):
+ return self.info.new_name
+
+ def get_checks(self):
+ return {}
+
+
+class _FunctionInformationCollector(object):
+
+ def __init__(self, start, end, is_global):
+ self.start = start
+ self.end = end
+ self.is_global = is_global
+ self.prewritten = set()
+ self.maybe_written = set()
+ self.written = set()
+ self.read = set()
+ self.postread = set()
+ self.postwritten = set()
+ self.host_function = True
+ self.conditional = False
+
+ def _read_variable(self, name, lineno):
+ if self.start <= lineno <= self.end:
+ if name not in self.written:
+ if not self.conditional or name not in self.maybe_written:
+ self.read.add(name)
+ if self.end < lineno:
+ if name not in self.postwritten:
+ self.postread.add(name)
+
+ def _written_variable(self, name, lineno):
+ if self.start <= lineno <= self.end:
+ if self.conditional:
+ self.maybe_written.add(name)
+ else:
+ self.written.add(name)
+ if self.start > lineno:
+ self.prewritten.add(name)
+ if self.end < lineno:
+ self.postwritten.add(name)
+
+ def _FunctionDef(self, node):
+ if not self.is_global and self.host_function:
+ self.host_function = False
+ for name in _get_argnames(node.args):
+ self._written_variable(name, node.lineno)
+ for child in node.body:
+ ast.walk(child, self)
+ else:
+ self._written_variable(node.name, node.lineno)
+ visitor = _VariableReadsAndWritesFinder()
+ for child in node.body:
+ ast.walk(child, visitor)
+ for name in visitor.read - visitor.written:
+ self._read_variable(name, node.lineno)
+
+ def _Name(self, node):
+ if isinstance(node.ctx, (ast.Store, ast.AugStore)):
+ self._written_variable(node.id, node.lineno)
+ if not isinstance(node.ctx, ast.Store):
+ self._read_variable(node.id, node.lineno)
+
+ def _Assign(self, node):
+ ast.walk(node.value, self)
+ for child in node.targets:
+ ast.walk(child, self)
+
+ def _ClassDef(self, node):
+ self._written_variable(node.name, node.lineno)
+
+ def _handle_conditional_node(self, node):
+ self.conditional = True
+ try:
+ for child in ast.get_child_nodes(node):
+ ast.walk(child, self)
+ finally:
+ self.conditional = False
+
+ def _If(self, node):
+ self._handle_conditional_node(node)
+
+ def _While(self, node):
+ self._handle_conditional_node(node)
+
+ def _For(self, node):
+ self.conditional = True
+ try:
+ # iter has to be checked before the target variables
+ ast.walk(node.iter, self)
+ ast.walk(node.target, self)
+
+ for child in node.body:
+ ast.walk(child, self)
+ for child in node.orelse:
+ ast.walk(child, self)
+ finally:
+ self.conditional = False
+
+
+def _get_argnames(arguments):
+ result = [node.id for node in arguments.args
+ if isinstance(node, ast.Name)]
+ if arguments.vararg:
+ result.append(arguments.vararg)
+ if arguments.kwarg:
+ result.append(arguments.kwarg)
+ return result
+
+
+class _VariableReadsAndWritesFinder(object):
+
+ def __init__(self):
+ self.written = set()
+ self.read = set()
+
+ def _Name(self, node):
+ if isinstance(node.ctx, (ast.Store, ast.AugStore)):
+ self.written.add(node.id)
+ if not isinstance(node, ast.Store):
+ self.read.add(node.id)
+
+ def _FunctionDef(self, node):
+ self.written.add(node.name)
+ visitor = _VariableReadsAndWritesFinder()
+ for child in ast.get_child_nodes(node):
+ ast.walk(child, visitor)
+ self.read.update(visitor.read - visitor.written)
+
+ def _Class(self, node):
+ self.written.add(node.name)
+
+ @staticmethod
+ def find_reads_and_writes(code):
+ if code.strip() == '':
+ return set(), set()
+ if isinstance(code, unicode):
+ code = code.encode('utf-8')
+ node = _parse_text(code)
+ visitor = _VariableReadsAndWritesFinder()
+ ast.walk(node, visitor)
+ return visitor.read, visitor.written
+
+ @staticmethod
+ def find_reads_for_one_liners(code):
+ if code.strip() == '':
+ return set(), set()
+ node = _parse_text(code)
+ visitor = _VariableReadsAndWritesFinder()
+ ast.walk(node, visitor)
+ return visitor.read
+
+
+class _UnmatchedBreakOrContinueFinder(object):
+
+ def __init__(self):
+ self.error = False
+ self.loop_count = 0
+
+ def _For(self, node):
+ self.loop_encountered(node)
+
+ def _While(self, node):
+ self.loop_encountered(node)
+
+ def loop_encountered(self, node):
+ self.loop_count += 1
+ for child in node.body:
+ ast.walk(child, self)
+ self.loop_count -= 1
+ if node.orelse:
+ ast.walk(node.orelse, self)
+
+ def _Break(self, node):
+ self.check_loop()
+
+ def _Continue(self, node):
+ self.check_loop()
+
+ def check_loop(self):
+ if self.loop_count < 1:
+ self.error = True
+
+ def _FunctionDef(self, node):
+ pass
+
+ def _ClassDef(self, node):
+ pass
+
+ @staticmethod
+ def has_errors(code):
+ if code.strip() == '':
+ return False
+ node = _parse_text(code)
+ visitor = _UnmatchedBreakOrContinueFinder()
+ ast.walk(node, visitor)
+ return visitor.error
+
+
+def _get_function_kind(scope):
+ return scope.pyobject.get_kind()
+
+
+def _parse_text(body):
+ body = sourceutils.fix_indentation(body, 0)
+ node = ast.parse(body)
+ return node
+
+
+def _join_lines(code):
+ lines = []
+ for line in code.splitlines():
+ if line.endswith('\\'):
+ lines.append(line[:-1].strip())
+ else:
+ lines.append(line.strip())
+ return ' '.join(lines)

Powered by Google App Engine
This is Rietveld 408576698