| OLD | NEW |
| (Empty) |
| 1 # Copyright 2015 The Chromium Authors. All rights reserved. | |
| 2 # Use of this source code is governed by a BSD-style license that can be | |
| 3 # found in the LICENSE file. | |
| 4 | |
| 5 import functools | |
| 6 import os | |
| 7 import sys | |
| 8 | |
| 9 from catapult_base import refactor | |
| 10 | |
| 11 | |
| 12 def Run(sources, target, files_to_update): | |
| 13 """Move modules and update imports. | |
| 14 | |
| 15 Args: | |
| 16 sources: List of source module or package paths. | |
| 17 target: Destination module or package path. | |
| 18 files_to_update: Modules whose imports we should check for changes. | |
| 19 """ | |
| 20 # TODO(dtu): Support moving classes and functions. | |
| 21 moves = tuple(_Move(source, target) for source in sources) | |
| 22 | |
| 23 # Update imports and references. | |
| 24 refactor.Transform(functools.partial(_Update, moves), files_to_update) | |
| 25 | |
| 26 # Move files. | |
| 27 for move in moves: | |
| 28 os.rename(move.source_path, move.target_path) | |
| 29 | |
| 30 | |
| 31 def _Update(moves, module): | |
| 32 for import_statement in module.FindAll(refactor.Import): | |
| 33 for move in moves: | |
| 34 try: | |
| 35 if move.UpdateImportAndReferences(module, import_statement): | |
| 36 break | |
| 37 except NotImplementedError as e: | |
| 38 print >> sys.stderr, 'Error updating %s: %s' % (module.file_path, e) | |
| 39 | |
| 40 | |
| 41 class _Move(object): | |
| 42 def __init__(self, source, target): | |
| 43 self._source_path = os.path.realpath(source) | |
| 44 self._target_path = os.path.realpath(target) | |
| 45 | |
| 46 if os.path.isdir(self._target_path): | |
| 47 self._target_path = os.path.join( | |
| 48 self._target_path, os.path.basename(self._source_path)) | |
| 49 | |
| 50 @property | |
| 51 def source_path(self): | |
| 52 return self._source_path | |
| 53 | |
| 54 @property | |
| 55 def target_path(self): | |
| 56 return self._target_path | |
| 57 | |
| 58 @property | |
| 59 def source_module_path(self): | |
| 60 return _ModulePath(self._source_path) | |
| 61 | |
| 62 @property | |
| 63 def target_module_path(self): | |
| 64 return _ModulePath(self._target_path) | |
| 65 | |
| 66 def UpdateImportAndReferences(self, module, import_statement): | |
| 67 """Update an import statement in a module and all its references.. | |
| 68 | |
| 69 Args: | |
| 70 module: The refactor.Module to update. | |
| 71 import_statement: The refactor.Import to update. | |
| 72 | |
| 73 Returns: | |
| 74 True if the import statement was updated, or False if the import statement | |
| 75 needed no updating. | |
| 76 """ | |
| 77 statement_path_parts = import_statement.path.split('.') | |
| 78 source_path_parts = self.source_module_path.split('.') | |
| 79 if source_path_parts != statement_path_parts[:len(source_path_parts)]: | |
| 80 return False | |
| 81 | |
| 82 # Update import statement. | |
| 83 old_name_parts = import_statement.name.split('.') | |
| 84 new_name_parts = ([self.target_module_path] + | |
| 85 statement_path_parts[len(source_path_parts):]) | |
| 86 import_statement.path = '.'.join(new_name_parts) | |
| 87 new_name = import_statement.name | |
| 88 | |
| 89 # Update references. | |
| 90 for reference in module.FindAll(refactor.Reference): | |
| 91 reference_parts = reference.value.split('.') | |
| 92 if old_name_parts != reference_parts[:len(old_name_parts)]: | |
| 93 continue | |
| 94 | |
| 95 new_reference_parts = [new_name] + reference_parts[len(old_name_parts):] | |
| 96 reference.value = '.'.join(new_reference_parts) | |
| 97 | |
| 98 return True | |
| 99 | |
| 100 | |
| 101 def _BaseDir(module_path): | |
| 102 if not os.path.isdir(module_path): | |
| 103 module_path = os.path.dirname(module_path) | |
| 104 | |
| 105 while '__init__.py' in os.listdir(module_path): | |
| 106 module_path = os.path.dirname(module_path) | |
| 107 | |
| 108 return module_path | |
| 109 | |
| 110 | |
| 111 def _ModulePath(module_path): | |
| 112 if os.path.split(module_path)[1] == '__init__.py': | |
| 113 module_path = os.path.dirname(module_path) | |
| 114 rel_path = os.path.relpath(module_path, _BaseDir(module_path)) | |
| 115 return os.path.splitext(rel_path)[0].replace(os.sep, '.') | |
| OLD | NEW |