OLD | NEW |
(Empty) | |
| 1 import os |
| 2 |
| 3 from Cython.Compiler import CmdLine |
| 4 from Cython.TestUtils import TransformTest |
| 5 from Cython.Compiler.ParseTreeTransforms import * |
| 6 from Cython.Compiler.Nodes import * |
| 7 from Cython.Compiler import Main, Symtab |
| 8 |
| 9 |
| 10 class TestNormalizeTree(TransformTest): |
| 11 def test_parserbehaviour_is_what_we_coded_for(self): |
| 12 t = self.fragment(u"if x: y").root |
| 13 self.assertLines(u""" |
| 14 (root): StatListNode |
| 15 stats[0]: IfStatNode |
| 16 if_clauses[0]: IfClauseNode |
| 17 condition: NameNode |
| 18 body: ExprStatNode |
| 19 expr: NameNode |
| 20 """, self.treetypes(t)) |
| 21 |
| 22 def test_wrap_singlestat(self): |
| 23 t = self.run_pipeline([NormalizeTree(None)], u"if x: y") |
| 24 self.assertLines(u""" |
| 25 (root): StatListNode |
| 26 stats[0]: IfStatNode |
| 27 if_clauses[0]: IfClauseNode |
| 28 condition: NameNode |
| 29 body: StatListNode |
| 30 stats[0]: ExprStatNode |
| 31 expr: NameNode |
| 32 """, self.treetypes(t)) |
| 33 |
| 34 def test_wrap_multistat(self): |
| 35 t = self.run_pipeline([NormalizeTree(None)], u""" |
| 36 if z: |
| 37 x |
| 38 y |
| 39 """) |
| 40 self.assertLines(u""" |
| 41 (root): StatListNode |
| 42 stats[0]: IfStatNode |
| 43 if_clauses[0]: IfClauseNode |
| 44 condition: NameNode |
| 45 body: StatListNode |
| 46 stats[0]: ExprStatNode |
| 47 expr: NameNode |
| 48 stats[1]: ExprStatNode |
| 49 expr: NameNode |
| 50 """, self.treetypes(t)) |
| 51 |
| 52 def test_statinexpr(self): |
| 53 t = self.run_pipeline([NormalizeTree(None)], u""" |
| 54 a, b = x, y |
| 55 """) |
| 56 self.assertLines(u""" |
| 57 (root): StatListNode |
| 58 stats[0]: SingleAssignmentNode |
| 59 lhs: TupleNode |
| 60 args[0]: NameNode |
| 61 args[1]: NameNode |
| 62 rhs: TupleNode |
| 63 args[0]: NameNode |
| 64 args[1]: NameNode |
| 65 """, self.treetypes(t)) |
| 66 |
| 67 def test_wrap_offagain(self): |
| 68 t = self.run_pipeline([NormalizeTree(None)], u""" |
| 69 x |
| 70 y |
| 71 if z: |
| 72 x |
| 73 """) |
| 74 self.assertLines(u""" |
| 75 (root): StatListNode |
| 76 stats[0]: ExprStatNode |
| 77 expr: NameNode |
| 78 stats[1]: ExprStatNode |
| 79 expr: NameNode |
| 80 stats[2]: IfStatNode |
| 81 if_clauses[0]: IfClauseNode |
| 82 condition: NameNode |
| 83 body: StatListNode |
| 84 stats[0]: ExprStatNode |
| 85 expr: NameNode |
| 86 """, self.treetypes(t)) |
| 87 |
| 88 |
| 89 def test_pass_eliminated(self): |
| 90 t = self.run_pipeline([NormalizeTree(None)], u"pass") |
| 91 self.assert_(len(t.stats) == 0) |
| 92 |
| 93 class TestWithTransform(object): # (TransformTest): # Disabled! |
| 94 |
| 95 def test_simplified(self): |
| 96 t = self.run_pipeline([WithTransform(None)], u""" |
| 97 with x: |
| 98 y = z ** 3 |
| 99 """) |
| 100 |
| 101 self.assertCode(u""" |
| 102 |
| 103 $0_0 = x |
| 104 $0_2 = $0_0.__exit__ |
| 105 $0_0.__enter__() |
| 106 $0_1 = True |
| 107 try: |
| 108 try: |
| 109 $1_0 = None |
| 110 y = z ** 3 |
| 111 except: |
| 112 $0_1 = False |
| 113 if (not $0_2($1_0)): |
| 114 raise |
| 115 finally: |
| 116 if $0_1: |
| 117 $0_2(None, None, None) |
| 118 |
| 119 """, t) |
| 120 |
| 121 def test_basic(self): |
| 122 t = self.run_pipeline([WithTransform(None)], u""" |
| 123 with x as y: |
| 124 y = z ** 3 |
| 125 """) |
| 126 self.assertCode(u""" |
| 127 |
| 128 $0_0 = x |
| 129 $0_2 = $0_0.__exit__ |
| 130 $0_3 = $0_0.__enter__() |
| 131 $0_1 = True |
| 132 try: |
| 133 try: |
| 134 $1_0 = None |
| 135 y = $0_3 |
| 136 y = z ** 3 |
| 137 except: |
| 138 $0_1 = False |
| 139 if (not $0_2($1_0)): |
| 140 raise |
| 141 finally: |
| 142 if $0_1: |
| 143 $0_2(None, None, None) |
| 144 |
| 145 """, t) |
| 146 |
| 147 |
| 148 class TestInterpretCompilerDirectives(TransformTest): |
| 149 """ |
| 150 This class tests the parallel directives AST-rewriting and importing. |
| 151 """ |
| 152 |
| 153 # Test the parallel directives (c)importing |
| 154 |
| 155 import_code = u""" |
| 156 cimport cython.parallel |
| 157 cimport cython.parallel as par |
| 158 from cython cimport parallel as par2 |
| 159 from cython cimport parallel |
| 160 |
| 161 from cython.parallel cimport threadid as tid |
| 162 from cython.parallel cimport threadavailable as tavail |
| 163 from cython.parallel cimport prange |
| 164 """ |
| 165 |
| 166 expected_directives_dict = { |
| 167 u'cython.parallel': u'cython.parallel', |
| 168 u'par': u'cython.parallel', |
| 169 u'par2': u'cython.parallel', |
| 170 u'parallel': u'cython.parallel', |
| 171 |
| 172 u"tid": u"cython.parallel.threadid", |
| 173 u"tavail": u"cython.parallel.threadavailable", |
| 174 u"prange": u"cython.parallel.prange", |
| 175 } |
| 176 |
| 177 |
| 178 def setUp(self): |
| 179 super(TestInterpretCompilerDirectives, self).setUp() |
| 180 |
| 181 compilation_options = Main.CompilationOptions(Main.default_options) |
| 182 ctx = compilation_options.create_context() |
| 183 |
| 184 transform = InterpretCompilerDirectives(ctx, ctx.compiler_directives) |
| 185 transform.module_scope = Symtab.ModuleScope('__main__', None, ctx) |
| 186 self.pipeline = [transform] |
| 187 |
| 188 self.debug_exception_on_error = DebugFlags.debug_exception_on_error |
| 189 |
| 190 def tearDown(self): |
| 191 DebugFlags.debug_exception_on_error = self.debug_exception_on_error |
| 192 |
| 193 def test_parallel_directives_cimports(self): |
| 194 self.run_pipeline(self.pipeline, self.import_code) |
| 195 parallel_directives = self.pipeline[0].parallel_directives |
| 196 self.assertEqual(parallel_directives, self.expected_directives_dict) |
| 197 |
| 198 def test_parallel_directives_imports(self): |
| 199 self.run_pipeline(self.pipeline, |
| 200 self.import_code.replace(u'cimport', u'import')) |
| 201 parallel_directives = self.pipeline[0].parallel_directives |
| 202 self.assertEqual(parallel_directives, self.expected_directives_dict) |
| 203 |
| 204 |
| 205 # TODO: Re-enable once they're more robust. |
| 206 if sys.version_info[:2] >= (2, 5) and False: |
| 207 from Cython.Debugger import DebugWriter |
| 208 from Cython.Debugger.Tests.TestLibCython import DebuggerTestCase |
| 209 else: |
| 210 # skip test, don't let it inherit unittest.TestCase |
| 211 DebuggerTestCase = object |
| 212 |
| 213 class TestDebugTransform(DebuggerTestCase): |
| 214 |
| 215 def elem_hasattrs(self, elem, attrs): |
| 216 # we shall supporteth python 2.3 ! |
| 217 return all([attr in elem.attrib for attr in attrs]) |
| 218 |
| 219 def test_debug_info(self): |
| 220 try: |
| 221 assert os.path.exists(self.debug_dest) |
| 222 |
| 223 t = DebugWriter.etree.parse(self.debug_dest) |
| 224 # the xpath of the standard ElementTree is primitive, don't use |
| 225 # anything fancy |
| 226 L = list(t.find('/Module/Globals')) |
| 227 # assertTrue is retarded, use the normal assert statement |
| 228 assert L |
| 229 xml_globals = dict( |
| 230 [(e.attrib['name'], e.attrib['type']) for e in L]) |
| 231 self.assertEqual(len(L), len(xml_globals)) |
| 232 |
| 233 L = list(t.find('/Module/Functions')) |
| 234 assert L |
| 235 xml_funcs = dict([(e.attrib['qualified_name'], e) for e in L]) |
| 236 self.assertEqual(len(L), len(xml_funcs)) |
| 237 |
| 238 # test globals |
| 239 self.assertEqual('CObject', xml_globals.get('c_var')) |
| 240 self.assertEqual('PythonObject', xml_globals.get('python_var')) |
| 241 |
| 242 # test functions |
| 243 funcnames = ('codefile.spam', 'codefile.ham', 'codefile.eggs', |
| 244 'codefile.closure', 'codefile.inner') |
| 245 required_xml_attrs = 'name', 'cname', 'qualified_name' |
| 246 assert all([f in xml_funcs for f in funcnames]) |
| 247 spam, ham, eggs = [xml_funcs[funcname] for funcname in funcnames] |
| 248 |
| 249 self.assertEqual(spam.attrib['name'], 'spam') |
| 250 self.assertNotEqual('spam', spam.attrib['cname']) |
| 251 assert self.elem_hasattrs(spam, required_xml_attrs) |
| 252 |
| 253 # test locals of functions |
| 254 spam_locals = list(spam.find('Locals')) |
| 255 assert spam_locals |
| 256 spam_locals.sort(key=lambda e: e.attrib['name']) |
| 257 names = [e.attrib['name'] for e in spam_locals] |
| 258 self.assertEqual(list('abcd'), names) |
| 259 assert self.elem_hasattrs(spam_locals[0], required_xml_attrs) |
| 260 |
| 261 # test arguments of functions |
| 262 spam_arguments = list(spam.find('Arguments')) |
| 263 assert spam_arguments |
| 264 self.assertEqual(1, len(list(spam_arguments))) |
| 265 |
| 266 # test step-into functions |
| 267 step_into = spam.find('StepIntoFunctions') |
| 268 spam_stepinto = [x.attrib['name'] for x in step_into] |
| 269 assert spam_stepinto |
| 270 self.assertEqual(2, len(spam_stepinto)) |
| 271 assert 'puts' in spam_stepinto |
| 272 assert 'some_c_function' in spam_stepinto |
| 273 except: |
| 274 f = open(self.debug_dest) |
| 275 try: |
| 276 print(f.read()) |
| 277 finally: |
| 278 f.close() |
| 279 raise |
| 280 |
| 281 |
| 282 |
| 283 if __name__ == "__main__": |
| 284 import unittest |
| 285 unittest.main() |
OLD | NEW |