OLD | NEW |
(Empty) | |
| 1 """Astroid hooks for various builtins.""" |
| 2 |
| 3 import sys |
| 4 from functools import partial |
| 5 from textwrap import dedent |
| 6 |
| 7 import six |
| 8 from astroid import (MANAGER, UseInferenceDefault, |
| 9 inference_tip, YES, InferenceError, UnresolvableName) |
| 10 from astroid import nodes |
| 11 from astroid.builder import AstroidBuilder |
| 12 |
| 13 |
| 14 def _extend_str(class_node, rvalue): |
| 15 """function to extend builtin str/unicode class""" |
| 16 # TODO(cpopa): this approach will make astroid to believe |
| 17 # that some arguments can be passed by keyword, but |
| 18 # unfortunately, strings and bytes don't accept keyword arguments. |
| 19 code = dedent(''' |
| 20 class whatever(object): |
| 21 def join(self, iterable): |
| 22 return {rvalue} |
| 23 def replace(self, old, new, count=None): |
| 24 return {rvalue} |
| 25 def format(self, *args, **kwargs): |
| 26 return {rvalue} |
| 27 def encode(self, encoding='ascii', errors=None): |
| 28 return '' |
| 29 def decode(self, encoding='ascii', errors=None): |
| 30 return u'' |
| 31 def capitalize(self): |
| 32 return {rvalue} |
| 33 def title(self): |
| 34 return {rvalue} |
| 35 def lower(self): |
| 36 return {rvalue} |
| 37 def upper(self): |
| 38 return {rvalue} |
| 39 def swapcase(self): |
| 40 return {rvalue} |
| 41 def index(self, sub, start=None, end=None): |
| 42 return 0 |
| 43 def find(self, sub, start=None, end=None): |
| 44 return 0 |
| 45 def count(self, sub, start=None, end=None): |
| 46 return 0 |
| 47 def strip(self, chars=None): |
| 48 return {rvalue} |
| 49 def lstrip(self, chars=None): |
| 50 return {rvalue} |
| 51 def rstrip(self, chars=None): |
| 52 return {rvalue} |
| 53 def rjust(self, width, fillchar=None): |
| 54 return {rvalue} |
| 55 def center(self, width, fillchar=None): |
| 56 return {rvalue} |
| 57 def ljust(self, width, fillchar=None): |
| 58 return {rvalue} |
| 59 ''') |
| 60 code = code.format(rvalue=rvalue) |
| 61 fake = AstroidBuilder(MANAGER).string_build(code)['whatever'] |
| 62 for method in fake.mymethods(): |
| 63 class_node.locals[method.name] = [method] |
| 64 method.parent = class_node |
| 65 |
| 66 def extend_builtins(class_transforms): |
| 67 from astroid.bases import BUILTINS |
| 68 builtin_ast = MANAGER.astroid_cache[BUILTINS] |
| 69 for class_name, transform in class_transforms.items(): |
| 70 transform(builtin_ast[class_name]) |
| 71 |
| 72 if sys.version_info > (3, 0): |
| 73 extend_builtins({'bytes': partial(_extend_str, rvalue="b''"), |
| 74 'str': partial(_extend_str, rvalue="''")}) |
| 75 else: |
| 76 extend_builtins({'str': partial(_extend_str, rvalue="''"), |
| 77 'unicode': partial(_extend_str, rvalue="u''")}) |
| 78 |
| 79 |
| 80 def register_builtin_transform(transform, builtin_name): |
| 81 """Register a new transform function for the given *builtin_name*. |
| 82 |
| 83 The transform function must accept two parameters, a node and |
| 84 an optional context. |
| 85 """ |
| 86 def _transform_wrapper(node, context=None): |
| 87 result = transform(node, context=context) |
| 88 if result: |
| 89 result.parent = node |
| 90 result.lineno = node.lineno |
| 91 result.col_offset = node.col_offset |
| 92 return iter([result]) |
| 93 |
| 94 MANAGER.register_transform(nodes.CallFunc, |
| 95 inference_tip(_transform_wrapper), |
| 96 lambda n: (isinstance(n.func, nodes.Name) and |
| 97 n.func.name == builtin_name)) |
| 98 |
| 99 |
| 100 def _generic_inference(node, context, node_type, transform): |
| 101 args = node.args |
| 102 if not args: |
| 103 return node_type() |
| 104 if len(node.args) > 1: |
| 105 raise UseInferenceDefault() |
| 106 |
| 107 arg, = args |
| 108 transformed = transform(arg) |
| 109 if not transformed: |
| 110 try: |
| 111 infered = next(arg.infer(context=context)) |
| 112 except (InferenceError, StopIteration): |
| 113 raise UseInferenceDefault() |
| 114 if infered is YES: |
| 115 raise UseInferenceDefault() |
| 116 transformed = transform(infered) |
| 117 if not transformed or transformed is YES: |
| 118 raise UseInferenceDefault() |
| 119 return transformed |
| 120 |
| 121 |
| 122 def _generic_transform(arg, klass, iterables, build_elts): |
| 123 if isinstance(arg, klass): |
| 124 return arg |
| 125 elif isinstance(arg, iterables): |
| 126 if not all(isinstance(elt, nodes.Const) |
| 127 for elt in arg.elts): |
| 128 # TODO(cpopa): Don't support heterogenous elements. |
| 129 # Not yet, though. |
| 130 raise UseInferenceDefault() |
| 131 elts = [elt.value for elt in arg.elts] |
| 132 elif isinstance(arg, nodes.Dict): |
| 133 if not all(isinstance(elt[0], nodes.Const) |
| 134 for elt in arg.items): |
| 135 raise UseInferenceDefault() |
| 136 elts = [item[0].value for item in arg.items] |
| 137 elif (isinstance(arg, nodes.Const) and |
| 138 isinstance(arg.value, (six.string_types, six.binary_type))): |
| 139 elts = arg.value |
| 140 else: |
| 141 return |
| 142 return klass(elts=build_elts(elts)) |
| 143 |
| 144 |
| 145 def _infer_builtin(node, context, |
| 146 klass=None, iterables=None, |
| 147 build_elts=None): |
| 148 transform_func = partial( |
| 149 _generic_transform, |
| 150 klass=klass, |
| 151 iterables=iterables, |
| 152 build_elts=build_elts) |
| 153 |
| 154 return _generic_inference(node, context, klass, transform_func) |
| 155 |
| 156 # pylint: disable=invalid-name |
| 157 infer_tuple = partial( |
| 158 _infer_builtin, |
| 159 klass=nodes.Tuple, |
| 160 iterables=(nodes.List, nodes.Set), |
| 161 build_elts=tuple) |
| 162 |
| 163 infer_list = partial( |
| 164 _infer_builtin, |
| 165 klass=nodes.List, |
| 166 iterables=(nodes.Tuple, nodes.Set), |
| 167 build_elts=list) |
| 168 |
| 169 infer_set = partial( |
| 170 _infer_builtin, |
| 171 klass=nodes.Set, |
| 172 iterables=(nodes.List, nodes.Tuple), |
| 173 build_elts=set) |
| 174 |
| 175 |
| 176 def _get_elts(arg, context): |
| 177 is_iterable = lambda n: isinstance(n, |
| 178 (nodes.List, nodes.Tuple, nodes.Set)) |
| 179 try: |
| 180 infered = next(arg.infer(context)) |
| 181 except (InferenceError, UnresolvableName): |
| 182 raise UseInferenceDefault() |
| 183 if isinstance(infered, nodes.Dict): |
| 184 items = infered.items |
| 185 elif is_iterable(infered): |
| 186 items = [] |
| 187 for elt in infered.elts: |
| 188 # If an item is not a pair of two items, |
| 189 # then fallback to the default inference. |
| 190 # Also, take in consideration only hashable items, |
| 191 # tuples and consts. We are choosing Names as well. |
| 192 if not is_iterable(elt): |
| 193 raise UseInferenceDefault() |
| 194 if len(elt.elts) != 2: |
| 195 raise UseInferenceDefault() |
| 196 if not isinstance(elt.elts[0], |
| 197 (nodes.Tuple, nodes.Const, nodes.Name)): |
| 198 raise UseInferenceDefault() |
| 199 items.append(tuple(elt.elts)) |
| 200 else: |
| 201 raise UseInferenceDefault() |
| 202 return items |
| 203 |
| 204 def infer_dict(node, context=None): |
| 205 """Try to infer a dict call to a Dict node. |
| 206 |
| 207 The function treats the following cases: |
| 208 |
| 209 * dict() |
| 210 * dict(mapping) |
| 211 * dict(iterable) |
| 212 * dict(iterable, **kwargs) |
| 213 * dict(mapping, **kwargs) |
| 214 * dict(**kwargs) |
| 215 |
| 216 If a case can't be infered, we'll fallback to default inference. |
| 217 """ |
| 218 has_keywords = lambda args: all(isinstance(arg, nodes.Keyword) |
| 219 for arg in args) |
| 220 if not node.args and not node.kwargs: |
| 221 # dict() |
| 222 return nodes.Dict() |
| 223 elif has_keywords(node.args) and node.args: |
| 224 # dict(a=1, b=2, c=4) |
| 225 items = [(nodes.Const(arg.arg), arg.value) for arg in node.args] |
| 226 elif (len(node.args) >= 2 and |
| 227 has_keywords(node.args[1:])): |
| 228 # dict(some_iterable, b=2, c=4) |
| 229 elts = _get_elts(node.args[0], context) |
| 230 keys = [(nodes.Const(arg.arg), arg.value) for arg in node.args[1:]] |
| 231 items = elts + keys |
| 232 elif len(node.args) == 1: |
| 233 items = _get_elts(node.args[0], context) |
| 234 else: |
| 235 raise UseInferenceDefault() |
| 236 |
| 237 empty = nodes.Dict() |
| 238 empty.items = items |
| 239 return empty |
| 240 |
| 241 # Builtins inference |
| 242 register_builtin_transform(infer_tuple, 'tuple') |
| 243 register_builtin_transform(infer_set, 'set') |
| 244 register_builtin_transform(infer_list, 'list') |
| 245 register_builtin_transform(infer_dict, 'dict') |
OLD | NEW |