OLD | NEW |
| (Empty) |
1 # Copyright (c) 2001-2004 Twisted Matrix Laboratories. | |
2 # See LICENSE for details. | |
3 | |
4 # | |
5 | |
6 from twisted.conch.error import ConchError | |
7 from twisted.conch.ssh import channel, connection | |
8 from twisted.internet import defer, protocol, reactor | |
9 from twisted.python import log | |
10 from twisted.spread import banana | |
11 | |
12 import os, stat, pickle | |
13 import types # this is for evil | |
14 | |
15 class SSHUnixClientFactory(protocol.ClientFactory): | |
16 # noisy = 1 | |
17 | |
18 def __init__(self, d, options, userAuthObject): | |
19 self.d = d | |
20 self.options = options | |
21 self.userAuthObject = userAuthObject | |
22 | |
23 def clientConnectionLost(self, connector, reason): | |
24 if self.options['reconnect']: | |
25 connector.connect() | |
26 #log.err(reason) | |
27 if not self.d: return | |
28 d = self.d | |
29 self.d = None | |
30 d.errback(reason) | |
31 | |
32 | |
33 def clientConnectionFailed(self, connector, reason): | |
34 #try: | |
35 # os.unlink(connector.transport.addr) | |
36 #except: | |
37 # pass | |
38 #log.err(reason) | |
39 if not self.d: return | |
40 d = self.d | |
41 self.d = None | |
42 d.errback(reason) | |
43 #reactor.connectTCP(options['host'], options['port'], SSHClientFactory()) | |
44 | |
45 def startedConnecting(self, connector): | |
46 fd = connector.transport.fileno() | |
47 stats = os.fstat(fd) | |
48 try: | |
49 filestats = os.stat(connector.transport.addr) | |
50 except: | |
51 connector.stopConnecting() | |
52 return | |
53 if stat.S_IMODE(filestats[0]) != 0600: | |
54 log.msg("socket mode is not 0600: %s" % oct(stat.S_IMODE(stats[0]))) | |
55 elif filestats[4] != os.getuid(): | |
56 log.msg("socket not owned by us: %s" % stats[4]) | |
57 elif filestats[5] != os.getgid(): | |
58 log.msg("socket not owned by our group: %s" % stats[5]) | |
59 # XXX reenable this when i can fix it for cygwin | |
60 #elif filestats[-3:] != stats[-3:]: | |
61 # log.msg("socket doesn't have same create times") | |
62 else: | |
63 log.msg('conecting OK') | |
64 return | |
65 connector.stopConnecting() | |
66 | |
67 def buildProtocol(self, addr): | |
68 # here comes the EVIL | |
69 obj = self.userAuthObject.instance | |
70 bases = [] | |
71 for base in obj.__class__.__bases__: | |
72 if base == connection.SSHConnection: | |
73 bases.append(SSHUnixClientProtocol) | |
74 else: | |
75 bases.append(base) | |
76 newClass = types.ClassType(obj.__class__.__name__, tuple(bases), obj.__c
lass__.__dict__) | |
77 obj.__class__ = newClass | |
78 SSHUnixClientProtocol.__init__(obj) | |
79 log.msg('returning %s' % obj) | |
80 if self.d: | |
81 d = self.d | |
82 self.d = None | |
83 d.callback(None) | |
84 return obj | |
85 | |
86 class SSHUnixServerFactory(protocol.Factory): | |
87 def __init__(self, conn): | |
88 self.conn = conn | |
89 | |
90 def buildProtocol(self, addr): | |
91 return SSHUnixServerProtocol(self.conn) | |
92 | |
93 class SSHUnixProtocol(banana.Banana): | |
94 | |
95 knownDialects = ['none'] | |
96 | |
97 def __init__(self): | |
98 banana.Banana.__init__(self) | |
99 self.deferredQueue = [] | |
100 self.deferreds = {} | |
101 self.deferredID = 0 | |
102 | |
103 def connectionMade(self): | |
104 log.msg('connection made %s' % self) | |
105 banana.Banana.connectionMade(self) | |
106 | |
107 def expressionReceived(self, lst): | |
108 vocabName = lst[0] | |
109 fn = "msg_%s" % vocabName | |
110 func = getattr(self, fn) | |
111 func(lst[1:]) | |
112 | |
113 def sendMessage(self, vocabName, *tup): | |
114 self.sendEncoded([vocabName] + list(tup)) | |
115 | |
116 def returnDeferredLocal(self): | |
117 d = defer.Deferred() | |
118 self.deferredQueue.append(d) | |
119 return d | |
120 | |
121 def returnDeferredWire(self, d): | |
122 di = self.deferredID | |
123 self.deferredID += 1 | |
124 self.sendMessage('returnDeferred', di) | |
125 d.addCallback(self._cbDeferred, di) | |
126 d.addErrback(self._ebDeferred, di) | |
127 | |
128 def _cbDeferred(self, result, di): | |
129 self.sendMessage('callbackDeferred', di, pickle.dumps(result)) | |
130 | |
131 def _ebDeferred(self, reason, di): | |
132 self.sendMessage('errbackDeferred', di, pickle.dumps(reason)) | |
133 | |
134 def msg_returnDeferred(self, lst): | |
135 deferredID = lst[0] | |
136 self.deferreds[deferredID] = self.deferredQueue.pop(0) | |
137 | |
138 def msg_callbackDeferred(self, lst): | |
139 deferredID, result = lst | |
140 d = self.deferreds[deferredID] | |
141 del self.deferreds[deferredID] | |
142 d.callback(pickle.loads(result)) | |
143 | |
144 def msg_errbackDeferred(self, lst): | |
145 deferredID, result = lst | |
146 d = self.deferreds[deferredID] | |
147 del self.deferreds[deferredID] | |
148 d.errback(pickle.loads(result)) | |
149 | |
150 class SSHUnixClientProtocol(SSHUnixProtocol): | |
151 | |
152 def __init__(self): | |
153 SSHUnixProtocol.__init__(self) | |
154 self.isClient = 1 | |
155 self.channelQueue = [] | |
156 self.channels = {} | |
157 | |
158 def logPrefix(self): | |
159 return "SSHUnixClientProtocol (%i) on %s" % (id(self), self.transport.lo
gPrefix()) | |
160 | |
161 def connectionReady(self): | |
162 log.msg('connection ready') | |
163 self.serviceStarted() | |
164 | |
165 def connectionLost(self, reason): | |
166 self.serviceStopped() | |
167 | |
168 def requestRemoteForwarding(self, remotePort, hostport): | |
169 self.sendMessage('requestRemoteForwarding', remotePort, hostport) | |
170 | |
171 def cancelRemoteForwarding(self, remotePort): | |
172 self.sendMessage('cancelRemoteForwarding', remotePort) | |
173 | |
174 def sendGlobalRequest(self, request, data, wantReply = 0): | |
175 self.sendMessage('sendGlobalRequest', request, data, wantReply) | |
176 if wantReply: | |
177 return self.returnDeferredLocal() | |
178 | |
179 def openChannel(self, channel, extra = ''): | |
180 self.channelQueue.append(channel) | |
181 channel.conn = self | |
182 self.sendMessage('openChannel', channel.name, | |
183 channel.localWindowSize, | |
184 channel.localMaxPacket, extra) | |
185 | |
186 def sendRequest(self, channel, requestType, data, wantReply = 0): | |
187 self.sendMessage('sendRequest', channel.id, requestType, data, wantReply
) | |
188 if wantReply: | |
189 return self.returnDeferredLocal() | |
190 | |
191 def adjustWindow(self, channel, bytesToAdd): | |
192 self.sendMessage('adjustWindow', channel.id, bytesToAdd) | |
193 | |
194 def sendData(self, channel, data): | |
195 self.sendMessage('sendData', channel.id, data) | |
196 | |
197 def sendExtendedData(self, channel, dataType, data): | |
198 self.sendMessage('sendExtendedData', channel.id, data) | |
199 | |
200 def sendEOF(self, channel): | |
201 self.sendMessage('sendEOF', channel.id) | |
202 | |
203 def sendClose(self, channel): | |
204 self.sendMessage('sendClose', channel.id) | |
205 | |
206 def msg_channelID(self, lst): | |
207 channelID = lst[0] | |
208 self.channels[channelID] = self.channelQueue.pop(0) | |
209 self.channels[channelID].id = channelID | |
210 | |
211 def msg_channelOpen(self, lst): | |
212 channelID, remoteWindow, remoteMax, specificData = lst | |
213 channel = self.channels[channelID] | |
214 channel.remoteWindowLeft = remoteWindow | |
215 channel.remoteMaxPacket = remoteMax | |
216 channel.channelOpen(specificData) | |
217 | |
218 def msg_openFailed(self, lst): | |
219 channelID, reason = lst | |
220 self.channels[channelID].openFailed(pickle.loads(reason)) | |
221 del self.channels[channelID] | |
222 | |
223 def msg_addWindowBytes(self, lst): | |
224 channelID, bytes = lst | |
225 self.channels[channelID].addWindowBytes(bytes) | |
226 | |
227 def msg_requestReceived(self, lst): | |
228 channelID, requestType, data = lst | |
229 d = defer.maybeDeferred(self.channels[channelID].requestReceived, reques
tType, data) | |
230 self.returnDeferredWire(d) | |
231 | |
232 def msg_dataReceived(self, lst): | |
233 channelID, data = lst | |
234 self.channels[channelID].dataReceived(data) | |
235 | |
236 def msg_extReceived(self, lst): | |
237 channelID, dataType, data = lst | |
238 self.channels[channelID].extReceived(dataType, data) | |
239 | |
240 def msg_eofReceived(self, lst): | |
241 channelID = lst[0] | |
242 self.channels[channelID].eofReceived() | |
243 | |
244 def msg_closeReceived(self, lst): | |
245 channelID = lst[0] | |
246 channel = self.channels[channelID] | |
247 channel.remoteClosed = 1 | |
248 channel.closeReceived() | |
249 | |
250 def msg_closed(self, lst): | |
251 channelID = lst[0] | |
252 channel = self.channels[channelID] | |
253 self.channelClosed(channel) | |
254 | |
255 def channelClosed(self, channel): | |
256 channel.localClosed = channel.remoteClosed = 1 | |
257 del self.channels[channel.id] | |
258 log.callWithLogger(channel, channel.closed) | |
259 | |
260 # just in case the user doesn't override | |
261 | |
262 def serviceStarted(self): | |
263 pass | |
264 | |
265 def serviceStopped(self): | |
266 pass | |
267 | |
268 class SSHUnixServerProtocol(SSHUnixProtocol): | |
269 | |
270 def __init__(self, conn): | |
271 SSHUnixProtocol.__init__(self) | |
272 self.isClient = 0 | |
273 self.conn = conn | |
274 | |
275 def connectionLost(self, reason): | |
276 for channel in self.conn.channels.values(): | |
277 if isinstance(channel, SSHUnixChannel) and channel.unix == self: | |
278 log.msg('forcibly closing %s' % channel) | |
279 try: | |
280 self.conn.sendClose(channel) | |
281 except: | |
282 pass | |
283 | |
284 def haveChannel(self, channelID): | |
285 return self.conn.channels.has_key(channelID) | |
286 | |
287 def getChannel(self, channelID): | |
288 channel = self.conn.channels[channelID] | |
289 if not isinstance(channel, SSHUnixChannel): | |
290 raise ConchError('nice try bub') | |
291 return channel | |
292 | |
293 def msg_requestRemoteForwarding(self, lst): | |
294 remotePort, hostport = lst | |
295 hostport = tuple(hostport) | |
296 self.conn.requestRemoteForwarding(remotePort, hostport) | |
297 | |
298 def msg_cancelRemoteForwarding(self, lst): | |
299 [remotePort] = lst | |
300 self.conn.cancelRemoteForwarding(remotePort) | |
301 | |
302 def msg_sendGlobalRequest(self, lst): | |
303 requestName, data, wantReply = lst | |
304 d = self.conn.sendGlobalRequest(requestName, data, wantReply) | |
305 if wantReply: | |
306 self.returnDeferredWire(d) | |
307 | |
308 def msg_openChannel(self, lst): | |
309 name, windowSize, maxPacket, extra = lst | |
310 channel = SSHUnixChannel(self, name, windowSize, maxPacket) | |
311 self.conn.openChannel(channel, extra) | |
312 self.sendMessage('channelID', channel.id) | |
313 | |
314 def msg_sendRequest(self, lst): | |
315 cn, requestType, data, wantReply = lst | |
316 if not self.haveChannel(cn): | |
317 if wantReply: | |
318 self.returnDeferredWire(defer.fail(ConchError("no channel"))) | |
319 channel = self.getChannel(cn) | |
320 d = self.conn.sendRequest(channel, requestType, data, wantReply) | |
321 if wantReply: | |
322 self.returnDeferredWire(d) | |
323 | |
324 def msg_adjustWindow(self, lst): | |
325 cn, bytesToAdd = lst | |
326 if not self.haveChannel(cn): return | |
327 channel = self.getChannel(cn) | |
328 self.conn.adjustWindow(channel, bytesToAdd) | |
329 | |
330 def msg_sendData(self, lst): | |
331 cn, data = lst | |
332 if not self.haveChannel(cn): return | |
333 channel = self.getChannel(cn) | |
334 self.conn.sendData(channel, data) | |
335 | |
336 def msg_sendExtended(self, lst): | |
337 cn, dataType, data = lst | |
338 if not self.haveChannel(cn): return | |
339 channel = self.getChannel(cn) | |
340 self.conn.sendExtendedData(channel, dataType, data) | |
341 | |
342 def msg_sendEOF(self, lst): | |
343 (cn, ) = lst | |
344 if not self.haveChannel(cn): return | |
345 channel = self.getChannel(cn) | |
346 self.conn.sendEOF(channel) | |
347 | |
348 def msg_sendClose(self, lst): | |
349 (cn, ) = lst | |
350 if not self.haveChannel(cn): return | |
351 channel = self.getChannel(cn) | |
352 self.conn.sendClose(channel) | |
353 | |
354 class SSHUnixChannel(channel.SSHChannel): | |
355 def __init__(self, unix, name, windowSize, maxPacket): | |
356 channel.SSHChannel.__init__(self, windowSize, maxPacket, conn = unix.con
n) | |
357 self.unix = unix | |
358 self.name = name | |
359 | |
360 def channelOpen(self, specificData): | |
361 self.unix.sendMessage('channelOpen', self.id, self.remoteWindowLeft, | |
362 self.remoteMaxPacket, specificData) | |
363 | |
364 def openFailed(self, reason): | |
365 self.unix.sendMessage('openFailed', self.id, pickle.dumps(reason)) | |
366 | |
367 def addWindowBytes(self, bytes): | |
368 self.unix.sendMessage('addWindowBytes', self.id, bytes) | |
369 | |
370 def dataReceived(self, data): | |
371 self.unix.sendMessage('dataReceived', self.id, data) | |
372 | |
373 def requestReceived(self, reqType, data): | |
374 self.unix.sendMessage('requestReceived', self.id, reqType, data) | |
375 return self.unix.returnDeferredLocal() | |
376 | |
377 def extReceived(self, dataType, data): | |
378 self.unix.sendMessage('extReceived', self.id, dataType, data) | |
379 | |
380 def eofReceived(self): | |
381 self.unix.sendMessage('eofReceived', self.id) | |
382 | |
383 def closeReceived(self): | |
384 self.unix.sendMessage('closeReceived', self.id) | |
385 | |
386 def closed(self): | |
387 self.unix.sendMessage('closed', self.id) | |
388 | |
389 def connect(host, port, options, verifyHostKey, userAuthObject): | |
390 if options['nocache']: | |
391 return defer.fail(ConchError('not using connection caching')) | |
392 d = defer.Deferred() | |
393 filename = os.path.expanduser("~/.conch-%s-%s-%i" % (userAuthObject.user, ho
st, port)) | |
394 factory = SSHUnixClientFactory(d, options, userAuthObject) | |
395 reactor.connectUNIX(filename, factory, timeout=2, checkPID=1) | |
396 return d | |
OLD | NEW |