OLD | NEW |
| (Empty) |
1 from zope.interface import implements | |
2 | |
3 from twisted.cred.checkers import ICredentialsChecker | |
4 from twisted.cred.credentials import IUsernamePassword | |
5 from twisted.cred.error import UnauthorizedLogin | |
6 from twisted.cred.portal import IRealm, Portal | |
7 | |
8 from twisted.conch.error import ConchError | |
9 from twisted.conch.ssh import userauth | |
10 from twisted.conch.ssh.common import NS | |
11 from twisted.conch.ssh.transport import SSHServerTransport | |
12 | |
13 from twisted.internet import defer | |
14 | |
15 from twisted.trial import unittest | |
16 | |
17 | |
18 | |
19 class FakeTransport(SSHServerTransport): | |
20 """ | |
21 L{userauth.SSHUserAuthServer} expects an SSH transport which has a factory | |
22 attribute which has a portal attribute. Because the portal is important for | |
23 testing authentication, we need to be able to provide an interesting portal | |
24 object to the C{SSHUserAuthServer}. | |
25 | |
26 In addition, we want to be able to capture any packets sent over the | |
27 transport. | |
28 """ | |
29 | |
30 | |
31 class Service(object): | |
32 name = 'nancy' | |
33 | |
34 def serviceStarted(self): | |
35 pass | |
36 | |
37 | |
38 class Factory(object): | |
39 def _makeService(self): | |
40 return FakeTransport.Service() | |
41 | |
42 def getService(self, transport, nextService): | |
43 # This has to return a callable. | |
44 return self._makeService | |
45 | |
46 | |
47 def __init__(self, portal): | |
48 self.factory = self.Factory() | |
49 self.factory.portal = portal | |
50 self.packets = [] | |
51 | |
52 | |
53 def sendPacket(self, messageType, message): | |
54 self.packets.append((messageType, message)) | |
55 | |
56 | |
57 def isEncrypted(self, direction): | |
58 """ | |
59 Pretend that this transport encrypts traffic in both directions. The | |
60 SSHUserAuthServer disables password authentication if the transport | |
61 isn't encrypted. | |
62 """ | |
63 return True | |
64 | |
65 | |
66 | |
67 class Realm(object): | |
68 """ | |
69 A mock realm for testing L{userauth.SSHUserAuthServer}. | |
70 | |
71 This realm is not actually used in the course of testing, so it returns the | |
72 simplest thing that could possibly work. | |
73 """ | |
74 | |
75 implements(IRealm) | |
76 | |
77 def requestAvatar(self, avatarId, mind, *interfaces): | |
78 return defer.succeed((interfaces[0], None, lambda: None)) | |
79 | |
80 | |
81 | |
82 class MockChecker(object): | |
83 """ | |
84 A very simple username/password checker which authenticates anyone whose | |
85 password matches their username and rejects all others. | |
86 """ | |
87 | |
88 credentialInterfaces = (IUsernamePassword,) | |
89 implements(ICredentialsChecker) | |
90 | |
91 | |
92 def requestAvatarId(self, creds): | |
93 if creds.username == creds.password: | |
94 return defer.succeed(creds.username) | |
95 return defer.fail(UnauthorizedLogin("Invalid username/password pair")) | |
96 | |
97 | |
98 | |
99 class TestSSHUserAuthServer(unittest.TestCase): | |
100 """Tests for SSHUserAuthServer.""" | |
101 | |
102 def setUp(self): | |
103 self.realm = Realm() | |
104 portal = Portal(self.realm) | |
105 portal.registerChecker(MockChecker()) | |
106 self.authServer = userauth.SSHUserAuthServer() | |
107 self.authServer.transport = FakeTransport(portal) | |
108 self.authServer.serviceStarted() | |
109 | |
110 | |
111 def tearDown(self): | |
112 self.authServer.serviceStopped() | |
113 self.authServer = None | |
114 | |
115 | |
116 def test_successfulAuthentication(self): | |
117 """ | |
118 When provided with correct authentication information, the server | |
119 should respond by sending a MSG_USERAUTH_SUCCESS message with no other | |
120 data. | |
121 | |
122 See RFC 4252, Section 5.1. | |
123 """ | |
124 packet = NS('foo') + NS('none') + NS('password') + chr(0) + NS('foo') | |
125 d = self.authServer.ssh_USERAUTH_REQUEST(packet) | |
126 | |
127 def check(ignored): | |
128 # Check that the server reports the failure, including 'password' | |
129 # as a valid authentication type. | |
130 self.assertEqual( | |
131 self.authServer.transport.packets, | |
132 [(userauth.MSG_USERAUTH_SUCCESS, '')]) | |
133 return d.addCallback(check) | |
134 | |
135 | |
136 def test_failedAuthentication(self): | |
137 """ | |
138 When provided with invalid authentication details, the server should | |
139 respond by sending a MSG_USERAUTH_FAILURE message which states whether | |
140 the authentication was partially successful, and provides other, open | |
141 options for authentication. | |
142 | |
143 See RFC 4252, Section 5.1. | |
144 """ | |
145 # packet = username, next_service, authentication type, FALSE, password | |
146 packet = NS('foo') + NS('none') + NS('password') + chr(0) + NS('bar') | |
147 d = self.authServer.ssh_USERAUTH_REQUEST(packet) | |
148 | |
149 def check(ignored): | |
150 # Check that the server reports the failure, including 'password' | |
151 # as a valid authentication type. | |
152 self.assertEqual( | |
153 self.authServer.transport.packets, | |
154 [(userauth.MSG_USERAUTH_FAILURE, NS('password') + chr(0))]) | |
155 return d.addCallback(check) | |
156 | |
157 | |
158 def test_requestRaisesConchError(self): | |
159 """ | |
160 ssh_USERAUTH_REQUEST should raise a ConchError if tryAuth returns | |
161 None. Added to catch a bug noticed by pyflakes. This is a whitebox | |
162 test. | |
163 """ | |
164 def mockTryAuth(kind, user, data): | |
165 return None | |
166 | |
167 def mockEbBadAuth(reason): | |
168 reason.trap(ConchError) | |
169 | |
170 self.patch(self.authServer, 'tryAuth', mockTryAuth) | |
171 self.patch(self.authServer, '_ebBadAuth', mockEbBadAuth) | |
172 | |
173 packet = NS('user') + NS('none') + NS('public-key') + NS('data') | |
174 # If an error other than ConchError is raised, this will trigger an | |
175 # exception. | |
176 return self.authServer.ssh_USERAUTH_REQUEST(packet) | |
OLD | NEW |