| OLD | NEW |
| (Empty) |
| 1 # Copyright (C) 2007-2008 Twisted Matrix Laboratories | |
| 2 # See LICENSE for details | |
| 3 | |
| 4 """ | |
| 5 Test ssh/channel.py. | |
| 6 """ | |
| 7 from twisted.conch.ssh import channel | |
| 8 from twisted.trial import unittest | |
| 9 | |
| 10 | |
| 11 class MockTransport(object): | |
| 12 """ | |
| 13 A mock Transport. All we use is the getPeer() and getHost() methods. | |
| 14 Channels implement the ITransport interface, and their getPeer() and | |
| 15 getHost() methods return ('SSH', <transport's getPeer/Host value>) so | |
| 16 we need to implement these methods so they have something to draw | |
| 17 from. | |
| 18 """ | |
| 19 def getPeer(self): | |
| 20 return ('MockPeer',) | |
| 21 | |
| 22 def getHost(self): | |
| 23 return ('MockHost',) | |
| 24 | |
| 25 | |
| 26 class MockConnection(object): | |
| 27 """ | |
| 28 A mock for twisted.conch.ssh.connection.SSHConnection. Record the data | |
| 29 that channels send, and when they try to close the connection. | |
| 30 | |
| 31 @ivar data: a C{dict} mapping channel id #s to lists of data sent by that | |
| 32 channel. | |
| 33 @ivar extData: a C{dict} mapping channel id #s to lists of 2-tuples | |
| 34 (extended data type, data) sent by that channel. | |
| 35 @ivar closes: a C{dict} mapping channel id #s to True if that channel sent | |
| 36 a close message. | |
| 37 """ | |
| 38 transport = MockTransport() | |
| 39 | |
| 40 def __init__(self): | |
| 41 self.data = {} | |
| 42 self.extData = {} | |
| 43 self.closes = {} | |
| 44 | |
| 45 def logPrefix(self): | |
| 46 """ | |
| 47 Return our logging prefix. | |
| 48 """ | |
| 49 return "MockConnection" | |
| 50 | |
| 51 def sendData(self, channel, data): | |
| 52 """ | |
| 53 Record the sent data. | |
| 54 """ | |
| 55 self.data.setdefault(channel, []).append(data) | |
| 56 | |
| 57 def sendExtendedData(self, channel, type, data): | |
| 58 """ | |
| 59 Record the sent extended data. | |
| 60 """ | |
| 61 self.extData.setdefault(channel, []).append((type, data)) | |
| 62 | |
| 63 def sendClose(self, channel): | |
| 64 """ | |
| 65 Record that the channel sent a close message. | |
| 66 """ | |
| 67 self.closes[channel] = True | |
| 68 | |
| 69 | |
| 70 class ChannelTestCase(unittest.TestCase): | |
| 71 | |
| 72 def setUp(self): | |
| 73 """ | |
| 74 Initialize the channel. remoteMaxPacket is 10 so that data is able | |
| 75 to be sent (the default of 0 means no data is sent because no packets | |
| 76 are made). | |
| 77 """ | |
| 78 self.conn = MockConnection() | |
| 79 self.channel = channel.SSHChannel(conn=self.conn, | |
| 80 remoteMaxPacket=10) | |
| 81 self.channel.name = 'channel' | |
| 82 | |
| 83 def test_init(self): | |
| 84 """ | |
| 85 Test that SSHChannel initializes correctly. localWindowSize defaults | |
| 86 to 131072 (2**17) and localMaxPacket to 32768 (2**15) as reasonable | |
| 87 defaults (what OpenSSH uses for those variables). | |
| 88 | |
| 89 The values in the second set of assertions are meaningless; they serve | |
| 90 only to verify that the instance variables are assigned in the correct | |
| 91 order. | |
| 92 """ | |
| 93 c = channel.SSHChannel(conn=self.conn) | |
| 94 self.assertEquals(c.localWindowSize, 131072) | |
| 95 self.assertEquals(c.localWindowLeft, 131072) | |
| 96 self.assertEquals(c.localMaxPacket, 32768) | |
| 97 self.assertEquals(c.remoteWindowLeft, 0) | |
| 98 self.assertEquals(c.remoteMaxPacket, 0) | |
| 99 self.assertEquals(c.conn, self.conn) | |
| 100 self.assertEquals(c.data, None) | |
| 101 self.assertEquals(c.avatar, None) | |
| 102 | |
| 103 c2 = channel.SSHChannel(1, 2, 3, 4, 5, 6, 7) | |
| 104 self.assertEquals(c2.localWindowSize, 1) | |
| 105 self.assertEquals(c2.localWindowLeft, 1) | |
| 106 self.assertEquals(c2.localMaxPacket, 2) | |
| 107 self.assertEquals(c2.remoteWindowLeft, 3) | |
| 108 self.assertEquals(c2.remoteMaxPacket, 4) | |
| 109 self.assertEquals(c2.conn, 5) | |
| 110 self.assertEquals(c2.data, 6) | |
| 111 self.assertEquals(c2.avatar, 7) | |
| 112 | |
| 113 def test_str(self): | |
| 114 """ | |
| 115 Test that str(SSHChannel) works gives the channel name and local and | |
| 116 remote windows at a glance.. | |
| 117 """ | |
| 118 self.assertEquals(str(self.channel), '<SSHChannel channel (lw 131072 ' | |
| 119 'rw 0)>') | |
| 120 | |
| 121 def test_logPrefix(self): | |
| 122 """ | |
| 123 Test that SSHChannel.logPrefix gives the name of the channel, the | |
| 124 local channel ID and the underlying connection. | |
| 125 """ | |
| 126 self.assertEquals(self.channel.logPrefix(), 'SSHChannel channel ' | |
| 127 '(unknown) on MockConnection') | |
| 128 | |
| 129 def test_addWindowBytes(self): | |
| 130 """ | |
| 131 Test that addWindowBytes adds bytes to the window and resumes writing | |
| 132 if it was paused. | |
| 133 """ | |
| 134 cb = [False] | |
| 135 def stubStartWriting(): | |
| 136 cb[0] = True | |
| 137 self.channel.startWriting = stubStartWriting | |
| 138 self.channel.write('test') | |
| 139 self.channel.writeExtended(1, 'test') | |
| 140 self.channel.addWindowBytes(50) | |
| 141 self.assertEquals(self.channel.remoteWindowLeft, 50 - 4 - 4) | |
| 142 self.assertTrue(self.channel.areWriting) | |
| 143 self.assertTrue(cb[0]) | |
| 144 self.assertEquals(self.channel.buf, '') | |
| 145 self.assertEquals(self.conn.data[self.channel], ['test']) | |
| 146 self.assertEquals(self.channel.extBuf, []) | |
| 147 self.assertEquals(self.conn.extData[self.channel], [(1, 'test')]) | |
| 148 | |
| 149 cb[0] = False | |
| 150 self.channel.addWindowBytes(20) | |
| 151 self.assertFalse(cb[0]) | |
| 152 | |
| 153 self.channel.write('a'*80) | |
| 154 self.channel.loseConnection() | |
| 155 self.channel.addWindowBytes(20) | |
| 156 self.assertFalse(cb[0]) | |
| 157 | |
| 158 def test_requestReceived(self): | |
| 159 """ | |
| 160 Test that requestReceived handles requests by dispatching them to | |
| 161 request_* methods. | |
| 162 """ | |
| 163 self.channel.request_test_method = lambda data: data == '' | |
| 164 self.assertTrue(self.channel.requestReceived('test-method', '')) | |
| 165 self.assertFalse(self.channel.requestReceived('test-method', 'a')) | |
| 166 self.assertFalse(self.channel.requestReceived('bad-method', '')) | |
| 167 | |
| 168 def test_closeReceieved(self): | |
| 169 """ | |
| 170 Test that the default closeReceieved closes the connection. | |
| 171 """ | |
| 172 self.assertFalse(self.channel.closing) | |
| 173 self.channel.closeReceived() | |
| 174 self.assertTrue(self.channel.closing) | |
| 175 | |
| 176 def test_write(self): | |
| 177 """ | |
| 178 Test that write handles data correctly. Send data up to the size | |
| 179 of the remote window, splitting the data into packets of length | |
| 180 remoteMaxPacket. | |
| 181 """ | |
| 182 cb = [False] | |
| 183 def stubStopWriting(): | |
| 184 cb[0] = True | |
| 185 # no window to start with | |
| 186 self.channel.stopWriting = stubStopWriting | |
| 187 self.channel.write('d') | |
| 188 self.channel.write('a') | |
| 189 self.assertFalse(self.channel.areWriting) | |
| 190 self.assertTrue(cb[0]) | |
| 191 # regular write | |
| 192 self.channel.addWindowBytes(20) | |
| 193 self.channel.write('ta') | |
| 194 data = self.conn.data[self.channel] | |
| 195 self.assertEquals(data, ['da', 'ta']) | |
| 196 self.assertEquals(self.channel.remoteWindowLeft, 16) | |
| 197 # larger than max packet | |
| 198 self.channel.write('12345678901') | |
| 199 self.assertEquals(data, ['da', 'ta', '1234567890', '1']) | |
| 200 self.assertEquals(self.channel.remoteWindowLeft, 5) | |
| 201 # running out of window | |
| 202 cb[0] = False | |
| 203 self.channel.write('123456') | |
| 204 self.assertFalse(self.channel.areWriting) | |
| 205 self.assertTrue(cb[0]) | |
| 206 self.assertEquals(data, ['da', 'ta', '1234567890', '1', '12345']) | |
| 207 self.assertEquals(self.channel.buf, '6') | |
| 208 self.assertEquals(self.channel.remoteWindowLeft, 0) | |
| 209 | |
| 210 def test_writeExtended(self): | |
| 211 """ | |
| 212 Test that writeExtended handles data correctly. Send extended data | |
| 213 up to the size of the window, splitting the extended data into packets | |
| 214 of length remoteMaxPacket. | |
| 215 """ | |
| 216 cb = [False] | |
| 217 def stubStopWriting(): | |
| 218 cb[0] = True | |
| 219 # no window to start with | |
| 220 self.channel.stopWriting = stubStopWriting | |
| 221 self.channel.writeExtended(1, 'd') | |
| 222 self.channel.writeExtended(1, 'a') | |
| 223 self.channel.writeExtended(2, 't') | |
| 224 self.assertFalse(self.channel.areWriting) | |
| 225 self.assertTrue(cb[0]) | |
| 226 # regular write | |
| 227 self.channel.addWindowBytes(20) | |
| 228 self.channel.writeExtended(2, 'a') | |
| 229 data = self.conn.extData[self.channel] | |
| 230 self.assertEquals(data, [(1, 'da'), (2, 't'), (2, 'a')]) | |
| 231 self.assertEquals(self.channel.remoteWindowLeft, 16) | |
| 232 # larger than max packet | |
| 233 self.channel.writeExtended(3, '12345678901') | |
| 234 self.assertEquals(data, [(1, 'da'), (2, 't'), (2, 'a'), | |
| 235 (3, '1234567890'), (3, '1')]) | |
| 236 self.assertEquals(self.channel.remoteWindowLeft, 5) | |
| 237 # running out of window | |
| 238 cb[0] = False | |
| 239 self.channel.writeExtended(4, '123456') | |
| 240 self.assertFalse(self.channel.areWriting) | |
| 241 self.assertTrue(cb[0]) | |
| 242 self.assertEquals(data, [(1, 'da'), (2, 't'), (2, 'a'), | |
| 243 (3, '1234567890'), (3, '1'), (4, '12345')]) | |
| 244 self.assertEquals(self.channel.extBuf, [[4, '6']]) | |
| 245 self.assertEquals(self.channel.remoteWindowLeft, 0) | |
| 246 | |
| 247 def test_writeSequence(self): | |
| 248 """ | |
| 249 Test that writeSequence is equivalent to write(''.join(sequece)). | |
| 250 """ | |
| 251 self.channel.addWindowBytes(20) | |
| 252 self.channel.writeSequence(map(str, range(10))) | |
| 253 self.assertEquals(self.conn.data[self.channel], ['0123456789']) | |
| 254 | |
| 255 def test_loseConnection(self): | |
| 256 """ | |
| 257 Tesyt that loseConnection() doesn't close the channel until all | |
| 258 the data is sent. | |
| 259 """ | |
| 260 self.channel.write('data') | |
| 261 self.channel.writeExtended(1, 'datadata') | |
| 262 self.channel.loseConnection() | |
| 263 self.assertEquals(self.conn.closes.get(self.channel), None) | |
| 264 self.channel.addWindowBytes(4) # send regular data | |
| 265 self.assertEquals(self.conn.closes.get(self.channel), None) | |
| 266 self.channel.addWindowBytes(8) # send extended data | |
| 267 self.assertTrue(self.conn.closes.get(self.channel)) | |
| 268 | |
| 269 def test_getPeer(self): | |
| 270 """ | |
| 271 Test that getPeer() returns ('SSH', <connection transport peer>). | |
| 272 """ | |
| 273 self.assertEquals(self.channel.getPeer(), ('SSH', 'MockPeer')) | |
| 274 | |
| 275 def test_getHost(self): | |
| 276 """ | |
| 277 Test that getHost() returns ('SSH', <connection transport host>). | |
| 278 """ | |
| 279 self.assertEquals(self.channel.getHost(), ('SSH', 'MockHost')) | |
| OLD | NEW |