OLD | NEW |
| (Empty) |
1 # Copyright (C) 2007-2008 Twisted Matrix Laboratories | |
2 # See LICENSE for details | |
3 | |
4 """ | |
5 Test ssh/channel.py. | |
6 """ | |
7 from twisted.conch.ssh import channel | |
8 from twisted.trial import unittest | |
9 | |
10 | |
11 class MockTransport(object): | |
12 """ | |
13 A mock Transport. All we use is the getPeer() and getHost() methods. | |
14 Channels implement the ITransport interface, and their getPeer() and | |
15 getHost() methods return ('SSH', <transport's getPeer/Host value>) so | |
16 we need to implement these methods so they have something to draw | |
17 from. | |
18 """ | |
19 def getPeer(self): | |
20 return ('MockPeer',) | |
21 | |
22 def getHost(self): | |
23 return ('MockHost',) | |
24 | |
25 | |
26 class MockConnection(object): | |
27 """ | |
28 A mock for twisted.conch.ssh.connection.SSHConnection. Record the data | |
29 that channels send, and when they try to close the connection. | |
30 | |
31 @ivar data: a C{dict} mapping channel id #s to lists of data sent by that | |
32 channel. | |
33 @ivar extData: a C{dict} mapping channel id #s to lists of 2-tuples | |
34 (extended data type, data) sent by that channel. | |
35 @ivar closes: a C{dict} mapping channel id #s to True if that channel sent | |
36 a close message. | |
37 """ | |
38 transport = MockTransport() | |
39 | |
40 def __init__(self): | |
41 self.data = {} | |
42 self.extData = {} | |
43 self.closes = {} | |
44 | |
45 def logPrefix(self): | |
46 """ | |
47 Return our logging prefix. | |
48 """ | |
49 return "MockConnection" | |
50 | |
51 def sendData(self, channel, data): | |
52 """ | |
53 Record the sent data. | |
54 """ | |
55 self.data.setdefault(channel, []).append(data) | |
56 | |
57 def sendExtendedData(self, channel, type, data): | |
58 """ | |
59 Record the sent extended data. | |
60 """ | |
61 self.extData.setdefault(channel, []).append((type, data)) | |
62 | |
63 def sendClose(self, channel): | |
64 """ | |
65 Record that the channel sent a close message. | |
66 """ | |
67 self.closes[channel] = True | |
68 | |
69 | |
70 class ChannelTestCase(unittest.TestCase): | |
71 | |
72 def setUp(self): | |
73 """ | |
74 Initialize the channel. remoteMaxPacket is 10 so that data is able | |
75 to be sent (the default of 0 means no data is sent because no packets | |
76 are made). | |
77 """ | |
78 self.conn = MockConnection() | |
79 self.channel = channel.SSHChannel(conn=self.conn, | |
80 remoteMaxPacket=10) | |
81 self.channel.name = 'channel' | |
82 | |
83 def test_init(self): | |
84 """ | |
85 Test that SSHChannel initializes correctly. localWindowSize defaults | |
86 to 131072 (2**17) and localMaxPacket to 32768 (2**15) as reasonable | |
87 defaults (what OpenSSH uses for those variables). | |
88 | |
89 The values in the second set of assertions are meaningless; they serve | |
90 only to verify that the instance variables are assigned in the correct | |
91 order. | |
92 """ | |
93 c = channel.SSHChannel(conn=self.conn) | |
94 self.assertEquals(c.localWindowSize, 131072) | |
95 self.assertEquals(c.localWindowLeft, 131072) | |
96 self.assertEquals(c.localMaxPacket, 32768) | |
97 self.assertEquals(c.remoteWindowLeft, 0) | |
98 self.assertEquals(c.remoteMaxPacket, 0) | |
99 self.assertEquals(c.conn, self.conn) | |
100 self.assertEquals(c.data, None) | |
101 self.assertEquals(c.avatar, None) | |
102 | |
103 c2 = channel.SSHChannel(1, 2, 3, 4, 5, 6, 7) | |
104 self.assertEquals(c2.localWindowSize, 1) | |
105 self.assertEquals(c2.localWindowLeft, 1) | |
106 self.assertEquals(c2.localMaxPacket, 2) | |
107 self.assertEquals(c2.remoteWindowLeft, 3) | |
108 self.assertEquals(c2.remoteMaxPacket, 4) | |
109 self.assertEquals(c2.conn, 5) | |
110 self.assertEquals(c2.data, 6) | |
111 self.assertEquals(c2.avatar, 7) | |
112 | |
113 def test_str(self): | |
114 """ | |
115 Test that str(SSHChannel) works gives the channel name and local and | |
116 remote windows at a glance.. | |
117 """ | |
118 self.assertEquals(str(self.channel), '<SSHChannel channel (lw 131072 ' | |
119 'rw 0)>') | |
120 | |
121 def test_logPrefix(self): | |
122 """ | |
123 Test that SSHChannel.logPrefix gives the name of the channel, the | |
124 local channel ID and the underlying connection. | |
125 """ | |
126 self.assertEquals(self.channel.logPrefix(), 'SSHChannel channel ' | |
127 '(unknown) on MockConnection') | |
128 | |
129 def test_addWindowBytes(self): | |
130 """ | |
131 Test that addWindowBytes adds bytes to the window and resumes writing | |
132 if it was paused. | |
133 """ | |
134 cb = [False] | |
135 def stubStartWriting(): | |
136 cb[0] = True | |
137 self.channel.startWriting = stubStartWriting | |
138 self.channel.write('test') | |
139 self.channel.writeExtended(1, 'test') | |
140 self.channel.addWindowBytes(50) | |
141 self.assertEquals(self.channel.remoteWindowLeft, 50 - 4 - 4) | |
142 self.assertTrue(self.channel.areWriting) | |
143 self.assertTrue(cb[0]) | |
144 self.assertEquals(self.channel.buf, '') | |
145 self.assertEquals(self.conn.data[self.channel], ['test']) | |
146 self.assertEquals(self.channel.extBuf, []) | |
147 self.assertEquals(self.conn.extData[self.channel], [(1, 'test')]) | |
148 | |
149 cb[0] = False | |
150 self.channel.addWindowBytes(20) | |
151 self.assertFalse(cb[0]) | |
152 | |
153 self.channel.write('a'*80) | |
154 self.channel.loseConnection() | |
155 self.channel.addWindowBytes(20) | |
156 self.assertFalse(cb[0]) | |
157 | |
158 def test_requestReceived(self): | |
159 """ | |
160 Test that requestReceived handles requests by dispatching them to | |
161 request_* methods. | |
162 """ | |
163 self.channel.request_test_method = lambda data: data == '' | |
164 self.assertTrue(self.channel.requestReceived('test-method', '')) | |
165 self.assertFalse(self.channel.requestReceived('test-method', 'a')) | |
166 self.assertFalse(self.channel.requestReceived('bad-method', '')) | |
167 | |
168 def test_closeReceieved(self): | |
169 """ | |
170 Test that the default closeReceieved closes the connection. | |
171 """ | |
172 self.assertFalse(self.channel.closing) | |
173 self.channel.closeReceived() | |
174 self.assertTrue(self.channel.closing) | |
175 | |
176 def test_write(self): | |
177 """ | |
178 Test that write handles data correctly. Send data up to the size | |
179 of the remote window, splitting the data into packets of length | |
180 remoteMaxPacket. | |
181 """ | |
182 cb = [False] | |
183 def stubStopWriting(): | |
184 cb[0] = True | |
185 # no window to start with | |
186 self.channel.stopWriting = stubStopWriting | |
187 self.channel.write('d') | |
188 self.channel.write('a') | |
189 self.assertFalse(self.channel.areWriting) | |
190 self.assertTrue(cb[0]) | |
191 # regular write | |
192 self.channel.addWindowBytes(20) | |
193 self.channel.write('ta') | |
194 data = self.conn.data[self.channel] | |
195 self.assertEquals(data, ['da', 'ta']) | |
196 self.assertEquals(self.channel.remoteWindowLeft, 16) | |
197 # larger than max packet | |
198 self.channel.write('12345678901') | |
199 self.assertEquals(data, ['da', 'ta', '1234567890', '1']) | |
200 self.assertEquals(self.channel.remoteWindowLeft, 5) | |
201 # running out of window | |
202 cb[0] = False | |
203 self.channel.write('123456') | |
204 self.assertFalse(self.channel.areWriting) | |
205 self.assertTrue(cb[0]) | |
206 self.assertEquals(data, ['da', 'ta', '1234567890', '1', '12345']) | |
207 self.assertEquals(self.channel.buf, '6') | |
208 self.assertEquals(self.channel.remoteWindowLeft, 0) | |
209 | |
210 def test_writeExtended(self): | |
211 """ | |
212 Test that writeExtended handles data correctly. Send extended data | |
213 up to the size of the window, splitting the extended data into packets | |
214 of length remoteMaxPacket. | |
215 """ | |
216 cb = [False] | |
217 def stubStopWriting(): | |
218 cb[0] = True | |
219 # no window to start with | |
220 self.channel.stopWriting = stubStopWriting | |
221 self.channel.writeExtended(1, 'd') | |
222 self.channel.writeExtended(1, 'a') | |
223 self.channel.writeExtended(2, 't') | |
224 self.assertFalse(self.channel.areWriting) | |
225 self.assertTrue(cb[0]) | |
226 # regular write | |
227 self.channel.addWindowBytes(20) | |
228 self.channel.writeExtended(2, 'a') | |
229 data = self.conn.extData[self.channel] | |
230 self.assertEquals(data, [(1, 'da'), (2, 't'), (2, 'a')]) | |
231 self.assertEquals(self.channel.remoteWindowLeft, 16) | |
232 # larger than max packet | |
233 self.channel.writeExtended(3, '12345678901') | |
234 self.assertEquals(data, [(1, 'da'), (2, 't'), (2, 'a'), | |
235 (3, '1234567890'), (3, '1')]) | |
236 self.assertEquals(self.channel.remoteWindowLeft, 5) | |
237 # running out of window | |
238 cb[0] = False | |
239 self.channel.writeExtended(4, '123456') | |
240 self.assertFalse(self.channel.areWriting) | |
241 self.assertTrue(cb[0]) | |
242 self.assertEquals(data, [(1, 'da'), (2, 't'), (2, 'a'), | |
243 (3, '1234567890'), (3, '1'), (4, '12345')]) | |
244 self.assertEquals(self.channel.extBuf, [[4, '6']]) | |
245 self.assertEquals(self.channel.remoteWindowLeft, 0) | |
246 | |
247 def test_writeSequence(self): | |
248 """ | |
249 Test that writeSequence is equivalent to write(''.join(sequece)). | |
250 """ | |
251 self.channel.addWindowBytes(20) | |
252 self.channel.writeSequence(map(str, range(10))) | |
253 self.assertEquals(self.conn.data[self.channel], ['0123456789']) | |
254 | |
255 def test_loseConnection(self): | |
256 """ | |
257 Tesyt that loseConnection() doesn't close the channel until all | |
258 the data is sent. | |
259 """ | |
260 self.channel.write('data') | |
261 self.channel.writeExtended(1, 'datadata') | |
262 self.channel.loseConnection() | |
263 self.assertEquals(self.conn.closes.get(self.channel), None) | |
264 self.channel.addWindowBytes(4) # send regular data | |
265 self.assertEquals(self.conn.closes.get(self.channel), None) | |
266 self.channel.addWindowBytes(8) # send extended data | |
267 self.assertTrue(self.conn.closes.get(self.channel)) | |
268 | |
269 def test_getPeer(self): | |
270 """ | |
271 Test that getPeer() returns ('SSH', <connection transport peer>). | |
272 """ | |
273 self.assertEquals(self.channel.getPeer(), ('SSH', 'MockPeer')) | |
274 | |
275 def test_getHost(self): | |
276 """ | |
277 Test that getHost() returns ('SSH', <connection transport host>). | |
278 """ | |
279 self.assertEquals(self.channel.getHost(), ('SSH', 'MockHost')) | |
OLD | NEW |