OLD | NEW |
| (Empty) |
1 # -*- test-case-name: twisted.conch.test.test_filetransfer -*- | |
2 # Copyright (c) 2001-2008 Twisted Matrix Laboratories. | |
3 # See LICENSE file for details. | |
4 | |
5 | |
6 import os | |
7 import struct | |
8 import sys | |
9 | |
10 from twisted.trial import unittest | |
11 try: | |
12 from twisted.conch import unix | |
13 unix # shut up pyflakes | |
14 except ImportError: | |
15 unix = None | |
16 try: | |
17 del sys.modules['twisted.conch.unix'] # remove the bad import | |
18 except KeyError: | |
19 # In Python 2.4, the bad import has already been cleaned up for us. | |
20 # Hooray. | |
21 pass | |
22 | |
23 from twisted.conch import avatar | |
24 from twisted.conch.ssh import common, connection, filetransfer, session | |
25 from twisted.internet import defer | |
26 from twisted.protocols import loopback | |
27 from twisted.python import components | |
28 | |
29 | |
30 class TestAvatar(avatar.ConchUser): | |
31 def __init__(self): | |
32 avatar.ConchUser.__init__(self) | |
33 self.channelLookup['session'] = session.SSHSession | |
34 self.subsystemLookup['sftp'] = filetransfer.FileTransferServer | |
35 | |
36 def _runAsUser(self, f, *args, **kw): | |
37 try: | |
38 f = iter(f) | |
39 except TypeError: | |
40 f = [(f, args, kw)] | |
41 for i in f: | |
42 func = i[0] | |
43 args = len(i)>1 and i[1] or () | |
44 kw = len(i)>2 and i[2] or {} | |
45 r = func(*args, **kw) | |
46 return r | |
47 | |
48 | |
49 class FileTransferTestAvatar(TestAvatar): | |
50 | |
51 def __init__(self, homeDir): | |
52 TestAvatar.__init__(self) | |
53 self.homeDir = homeDir | |
54 | |
55 def getHomeDir(self): | |
56 return os.path.join(os.getcwd(), self.homeDir) | |
57 | |
58 | |
59 class ConchSessionForTestAvatar: | |
60 | |
61 def __init__(self, avatar): | |
62 self.avatar = avatar | |
63 | |
64 if unix: | |
65 if not hasattr(unix, 'SFTPServerForUnixConchUser'): | |
66 # unix should either be a fully working module, or None. I'm not sure | |
67 # how this happens, but on win32 it does. Try to cope. --spiv. | |
68 import warnings | |
69 warnings.warn(("twisted.conch.unix imported %r, " | |
70 "but doesn't define SFTPServerForUnixConchUser'") | |
71 % (unix,)) | |
72 unix = None | |
73 else: | |
74 class FileTransferForTestAvatar(unix.SFTPServerForUnixConchUser): | |
75 | |
76 def gotVersion(self, version, otherExt): | |
77 return {'conchTest' : 'ext data'} | |
78 | |
79 def extendedRequest(self, extName, extData): | |
80 if extName == 'testExtendedRequest': | |
81 return 'bar' | |
82 raise NotImplementedError | |
83 | |
84 components.registerAdapter(FileTransferForTestAvatar, | |
85 TestAvatar, | |
86 filetransfer.ISFTPServer) | |
87 | |
88 class SFTPTestBase(unittest.TestCase): | |
89 | |
90 def setUp(self): | |
91 self.testDir = self.mktemp() | |
92 # Give the testDir another level so we can safely "cd .." from it in | |
93 # tests. | |
94 self.testDir = os.path.join(self.testDir, 'extra') | |
95 os.makedirs(os.path.join(self.testDir, 'testDirectory')) | |
96 | |
97 f = file(os.path.join(self.testDir, 'testfile1'),'w') | |
98 f.write('a'*10+'b'*10) | |
99 f.write(file('/dev/urandom').read(1024*64)) # random data | |
100 os.chmod(os.path.join(self.testDir, 'testfile1'), 0644) | |
101 file(os.path.join(self.testDir, 'testRemoveFile'), 'w').write('a') | |
102 file(os.path.join(self.testDir, 'testRenameFile'), 'w').write('a') | |
103 file(os.path.join(self.testDir, '.testHiddenFile'), 'w').write('a') | |
104 | |
105 | |
106 class TestOurServerOurClient(SFTPTestBase): | |
107 | |
108 if not unix: | |
109 skip = "can't run on non-posix computers" | |
110 | |
111 def setUp(self): | |
112 SFTPTestBase.setUp(self) | |
113 | |
114 self.avatar = FileTransferTestAvatar(self.testDir) | |
115 self.server = filetransfer.FileTransferServer(avatar=self.avatar) | |
116 clientTransport = loopback.LoopbackRelay(self.server) | |
117 | |
118 self.client = filetransfer.FileTransferClient() | |
119 self._serverVersion = None | |
120 self._extData = None | |
121 def _(serverVersion, extData): | |
122 self._serverVersion = serverVersion | |
123 self._extData = extData | |
124 self.client.gotServerVersion = _ | |
125 serverTransport = loopback.LoopbackRelay(self.client) | |
126 self.client.makeConnection(clientTransport) | |
127 self.server.makeConnection(serverTransport) | |
128 | |
129 self.clientTransport = clientTransport | |
130 self.serverTransport = serverTransport | |
131 | |
132 self._emptyBuffers() | |
133 | |
134 | |
135 def _emptyBuffers(self): | |
136 while self.serverTransport.buffer or self.clientTransport.buffer: | |
137 self.serverTransport.clearBuffer() | |
138 self.clientTransport.clearBuffer() | |
139 | |
140 | |
141 def testServerVersion(self): | |
142 self.failUnlessEqual(self._serverVersion, 3) | |
143 self.failUnlessEqual(self._extData, {'conchTest' : 'ext data'}) | |
144 | |
145 def testOpenFileIO(self): | |
146 d = self.client.openFile("testfile1", filetransfer.FXF_READ | | |
147 filetransfer.FXF_WRITE, {}) | |
148 self._emptyBuffers() | |
149 | |
150 def _fileOpened(openFile): | |
151 self.failUnlessEqual(openFile, filetransfer.ISFTPFile(openFile)) | |
152 d = _readChunk(openFile) | |
153 d.addCallback(_writeChunk, openFile) | |
154 return d | |
155 | |
156 def _readChunk(openFile): | |
157 d = openFile.readChunk(0, 20) | |
158 self._emptyBuffers() | |
159 d.addCallback(self.failUnlessEqual, 'a'*10 + 'b'*10) | |
160 return d | |
161 | |
162 def _writeChunk(_, openFile): | |
163 d = openFile.writeChunk(20, 'c'*10) | |
164 self._emptyBuffers() | |
165 d.addCallback(_readChunk2, openFile) | |
166 return d | |
167 | |
168 def _readChunk2(_, openFile): | |
169 d = openFile.readChunk(0, 30) | |
170 self._emptyBuffers() | |
171 d.addCallback(self.failUnlessEqual, 'a'*10 + 'b'*10 + 'c'*10) | |
172 return d | |
173 | |
174 d.addCallback(_fileOpened) | |
175 return d | |
176 | |
177 def testClosedFileGetAttrs(self): | |
178 d = self.client.openFile("testfile1", filetransfer.FXF_READ | | |
179 filetransfer.FXF_WRITE, {}) | |
180 self._emptyBuffers() | |
181 | |
182 def _getAttrs(_, openFile): | |
183 d = openFile.getAttrs() | |
184 self._emptyBuffers() | |
185 return d | |
186 | |
187 def _err(f): | |
188 self.flushLoggedErrors() | |
189 return f | |
190 | |
191 def _close(openFile): | |
192 d = openFile.close() | |
193 self._emptyBuffers() | |
194 d.addCallback(_getAttrs, openFile) | |
195 d.addErrback(_err) | |
196 return self.assertFailure(d, filetransfer.SFTPError) | |
197 | |
198 d.addCallback(_close) | |
199 return d | |
200 | |
201 def testOpenFileAttributes(self): | |
202 d = self.client.openFile("testfile1", filetransfer.FXF_READ | | |
203 filetransfer.FXF_WRITE, {}) | |
204 self._emptyBuffers() | |
205 | |
206 def _getAttrs(openFile): | |
207 d = openFile.getAttrs() | |
208 self._emptyBuffers() | |
209 d.addCallback(_getAttrs2) | |
210 return d | |
211 | |
212 def _getAttrs2(attrs1): | |
213 d = self.client.getAttrs('testfile1') | |
214 self._emptyBuffers() | |
215 d.addCallback(self.failUnlessEqual, attrs1) | |
216 return d | |
217 | |
218 return d.addCallback(_getAttrs) | |
219 | |
220 | |
221 def testOpenFileSetAttrs(self): | |
222 # XXX test setAttrs | |
223 # Ok, how about this for a start? It caught a bug :) -- spiv. | |
224 d = self.client.openFile("testfile1", filetransfer.FXF_READ | | |
225 filetransfer.FXF_WRITE, {}) | |
226 self._emptyBuffers() | |
227 | |
228 def _getAttrs(openFile): | |
229 d = openFile.getAttrs() | |
230 self._emptyBuffers() | |
231 d.addCallback(_setAttrs) | |
232 return d | |
233 | |
234 def _setAttrs(attrs): | |
235 attrs['atime'] = 0 | |
236 d = self.client.setAttrs('testfile1', attrs) | |
237 self._emptyBuffers() | |
238 d.addCallback(_getAttrs2) | |
239 d.addCallback(self.failUnlessEqual, attrs) | |
240 return d | |
241 | |
242 def _getAttrs2(_): | |
243 d = self.client.getAttrs('testfile1') | |
244 self._emptyBuffers() | |
245 return d | |
246 | |
247 d.addCallback(_getAttrs) | |
248 return d | |
249 | |
250 | |
251 def test_openFileExtendedAttributes(self): | |
252 """ | |
253 Check that L{filetransfer.FileTransferClient.openFile} can send | |
254 extended attributes, that should be extracted server side. By default, | |
255 they are ignored, so we just verify they are correctly parsed. | |
256 """ | |
257 savedAttributes = {} | |
258 def openFile(filename, flags, attrs): | |
259 savedAttributes.update(attrs) | |
260 self.server.client.openFile = openFile | |
261 | |
262 d = self.client.openFile("testfile1", filetransfer.FXF_READ | | |
263 filetransfer.FXF_WRITE, {"ext_foo": "bar"}) | |
264 self._emptyBuffers() | |
265 | |
266 def check(ign): | |
267 self.assertEquals(savedAttributes, {"ext_foo": "bar"}) | |
268 | |
269 return d.addCallback(check) | |
270 | |
271 | |
272 def testRemoveFile(self): | |
273 d = self.client.getAttrs("testRemoveFile") | |
274 self._emptyBuffers() | |
275 def _removeFile(ignored): | |
276 d = self.client.removeFile("testRemoveFile") | |
277 self._emptyBuffers() | |
278 return d | |
279 d.addCallback(_removeFile) | |
280 d.addCallback(_removeFile) | |
281 return self.assertFailure(d, filetransfer.SFTPError) | |
282 | |
283 def testRenameFile(self): | |
284 d = self.client.getAttrs("testRenameFile") | |
285 self._emptyBuffers() | |
286 def _rename(attrs): | |
287 d = self.client.renameFile("testRenameFile", "testRenamedFile") | |
288 self._emptyBuffers() | |
289 d.addCallback(_testRenamed, attrs) | |
290 return d | |
291 def _testRenamed(_, attrs): | |
292 d = self.client.getAttrs("testRenamedFile") | |
293 self._emptyBuffers() | |
294 d.addCallback(self.failUnlessEqual, attrs) | |
295 return d.addCallback(_rename) | |
296 | |
297 def testDirectoryBad(self): | |
298 d = self.client.getAttrs("testMakeDirectory") | |
299 self._emptyBuffers() | |
300 return self.assertFailure(d, filetransfer.SFTPError) | |
301 | |
302 def testDirectoryCreation(self): | |
303 d = self.client.makeDirectory("testMakeDirectory", {}) | |
304 self._emptyBuffers() | |
305 | |
306 def _getAttrs(_): | |
307 d = self.client.getAttrs("testMakeDirectory") | |
308 self._emptyBuffers() | |
309 return d | |
310 | |
311 # XXX not until version 4/5 | |
312 # self.failUnlessEqual(filetransfer.FILEXFER_TYPE_DIRECTORY&attrs['type'
], | |
313 # filetransfer.FILEXFER_TYPE_DIRECTORY) | |
314 | |
315 def _removeDirectory(_): | |
316 d = self.client.removeDirectory("testMakeDirectory") | |
317 self._emptyBuffers() | |
318 return d | |
319 | |
320 d.addCallback(_getAttrs) | |
321 d.addCallback(_removeDirectory) | |
322 d.addCallback(_getAttrs) | |
323 return self.assertFailure(d, filetransfer.SFTPError) | |
324 | |
325 def testOpenDirectory(self): | |
326 d = self.client.openDirectory('') | |
327 self._emptyBuffers() | |
328 files = [] | |
329 | |
330 def _getFiles(openDir): | |
331 def append(f): | |
332 files.append(f) | |
333 return openDir | |
334 d = defer.maybeDeferred(openDir.next) | |
335 self._emptyBuffers() | |
336 d.addCallback(append) | |
337 d.addCallback(_getFiles) | |
338 d.addErrback(_close, openDir) | |
339 return d | |
340 | |
341 def _checkFiles(ignored): | |
342 fs = list(zip(*files)[0]) | |
343 fs.sort() | |
344 self.failUnlessEqual(fs, | |
345 ['.testHiddenFile', 'testDirectory', | |
346 'testRemoveFile', 'testRenameFile', | |
347 'testfile1']) | |
348 | |
349 def _close(_, openDir): | |
350 d = openDir.close() | |
351 self._emptyBuffers() | |
352 return d | |
353 | |
354 d.addCallback(_getFiles) | |
355 d.addCallback(_checkFiles) | |
356 return d | |
357 | |
358 def testLinkDoesntExist(self): | |
359 d = self.client.getAttrs('testLink') | |
360 self._emptyBuffers() | |
361 return self.assertFailure(d, filetransfer.SFTPError) | |
362 | |
363 def testLinkSharesAttrs(self): | |
364 d = self.client.makeLink('testLink', 'testfile1') | |
365 self._emptyBuffers() | |
366 def _getFirstAttrs(_): | |
367 d = self.client.getAttrs('testLink', 1) | |
368 self._emptyBuffers() | |
369 return d | |
370 def _getSecondAttrs(firstAttrs): | |
371 d = self.client.getAttrs('testfile1') | |
372 self._emptyBuffers() | |
373 d.addCallback(self.assertEqual, firstAttrs) | |
374 return d | |
375 d.addCallback(_getFirstAttrs) | |
376 return d.addCallback(_getSecondAttrs) | |
377 | |
378 def testLinkPath(self): | |
379 d = self.client.makeLink('testLink', 'testfile1') | |
380 self._emptyBuffers() | |
381 def _readLink(_): | |
382 d = self.client.readLink('testLink') | |
383 self._emptyBuffers() | |
384 d.addCallback(self.failUnlessEqual, | |
385 os.path.join(os.getcwd(), self.testDir, 'testfile1')) | |
386 return d | |
387 def _realPath(_): | |
388 d = self.client.realPath('testLink') | |
389 self._emptyBuffers() | |
390 d.addCallback(self.failUnlessEqual, | |
391 os.path.join(os.getcwd(), self.testDir, 'testfile1')) | |
392 return d | |
393 d.addCallback(_readLink) | |
394 d.addCallback(_realPath) | |
395 return d | |
396 | |
397 def testExtendedRequest(self): | |
398 d = self.client.extendedRequest('testExtendedRequest', 'foo') | |
399 self._emptyBuffers() | |
400 d.addCallback(self.failUnlessEqual, 'bar') | |
401 d.addCallback(self._cbTestExtendedRequest) | |
402 return d | |
403 | |
404 def _cbTestExtendedRequest(self, ignored): | |
405 d = self.client.extendedRequest('testBadRequest', '') | |
406 self._emptyBuffers() | |
407 return self.assertFailure(d, NotImplementedError) | |
408 | |
409 | |
410 class FakeConn: | |
411 def sendClose(self, channel): | |
412 pass | |
413 | |
414 | |
415 class TestFileTransferClose(unittest.TestCase): | |
416 | |
417 if not unix: | |
418 skip = "can't run on non-posix computers" | |
419 | |
420 def setUp(self): | |
421 self.avatar = TestAvatar() | |
422 | |
423 def buildServerConnection(self): | |
424 # make a server connection | |
425 conn = connection.SSHConnection() | |
426 # server connections have a 'self.transport.avatar'. | |
427 class DummyTransport: | |
428 def __init__(self): | |
429 self.transport = self | |
430 def sendPacket(self, kind, data): | |
431 pass | |
432 def logPrefix(self): | |
433 return 'dummy transport' | |
434 conn.transport = DummyTransport() | |
435 conn.transport.avatar = self.avatar | |
436 return conn | |
437 | |
438 def interceptConnectionLost(self, sftpServer): | |
439 self.connectionLostFired = False | |
440 origConnectionLost = sftpServer.connectionLost | |
441 def connectionLost(reason): | |
442 self.connectionLostFired = True | |
443 origConnectionLost(reason) | |
444 sftpServer.connectionLost = connectionLost | |
445 | |
446 def assertSFTPConnectionLost(self): | |
447 self.assertTrue(self.connectionLostFired, | |
448 "sftpServer's connectionLost was not called") | |
449 | |
450 def test_sessionClose(self): | |
451 """ | |
452 Closing a session should notify an SFTP subsystem launched by that | |
453 session. | |
454 """ | |
455 # make a session | |
456 testSession = session.SSHSession(conn=FakeConn(), avatar=self.avatar) | |
457 | |
458 # start an SFTP subsystem on the session | |
459 testSession.request_subsystem(common.NS('sftp')) | |
460 sftpServer = testSession.client.transport.proto | |
461 | |
462 # intercept connectionLost so we can check that it's called | |
463 self.interceptConnectionLost(sftpServer) | |
464 | |
465 # close session | |
466 testSession.closeReceived() | |
467 | |
468 self.assertSFTPConnectionLost() | |
469 | |
470 def test_clientClosesChannelOnConnnection(self): | |
471 """ | |
472 A client sending CHANNEL_CLOSE should trigger closeReceived on the | |
473 associated channel instance. | |
474 """ | |
475 conn = self.buildServerConnection() | |
476 | |
477 # somehow get a session | |
478 packet = common.NS('session') + struct.pack('>L', 0) * 3 | |
479 conn.ssh_CHANNEL_OPEN(packet) | |
480 sessionChannel = conn.channels[0] | |
481 | |
482 sessionChannel.request_subsystem(common.NS('sftp')) | |
483 sftpServer = sessionChannel.client.transport.proto | |
484 self.interceptConnectionLost(sftpServer) | |
485 | |
486 # intercept closeReceived | |
487 self.interceptConnectionLost(sftpServer) | |
488 | |
489 # close the connection | |
490 conn.ssh_CHANNEL_CLOSE(struct.pack('>L', 0)) | |
491 | |
492 self.assertSFTPConnectionLost() | |
493 | |
494 | |
495 def test_stopConnectionServiceClosesChannel(self): | |
496 """ | |
497 Closing an SSH connection should close all sessions within it. | |
498 """ | |
499 conn = self.buildServerConnection() | |
500 | |
501 # somehow get a session | |
502 packet = common.NS('session') + struct.pack('>L', 0) * 3 | |
503 conn.ssh_CHANNEL_OPEN(packet) | |
504 sessionChannel = conn.channels[0] | |
505 | |
506 sessionChannel.request_subsystem(common.NS('sftp')) | |
507 sftpServer = sessionChannel.client.transport.proto | |
508 self.interceptConnectionLost(sftpServer) | |
509 | |
510 # close the connection | |
511 conn.serviceStopped() | |
512 | |
513 self.assertSFTPConnectionLost() | |
OLD | NEW |