OLD | NEW |
| (Empty) |
1 # Copyright (c) 2001-2007 Twisted Matrix Laboratories. | |
2 # See LICENSE for details. | |
3 | |
4 """ | |
5 Now with 30% more starch. | |
6 """ | |
7 | |
8 | |
9 import hmac | |
10 from zope.interface import implements, Interface | |
11 | |
12 from twisted.trial import unittest | |
13 from twisted.cred import portal, checkers, credentials, error | |
14 from twisted.python import components | |
15 from twisted.internet import defer | |
16 from twisted.internet.defer import deferredGenerator as dG, waitForDeferred as w
FD | |
17 | |
18 try: | |
19 from crypt import crypt | |
20 except ImportError: | |
21 crypt = None | |
22 | |
23 try: | |
24 from twisted.cred.pamauth import callIntoPAM | |
25 except ImportError: | |
26 pamauth = None | |
27 else: | |
28 from twisted.cred import pamauth | |
29 | |
30 class ITestable(Interface): | |
31 pass | |
32 | |
33 class TestAvatar: | |
34 def __init__(self, name): | |
35 self.name = name | |
36 self.loggedIn = False | |
37 self.loggedOut = False | |
38 | |
39 def login(self): | |
40 assert not self.loggedIn | |
41 self.loggedIn = True | |
42 | |
43 def logout(self): | |
44 self.loggedOut = True | |
45 | |
46 class Testable(components.Adapter): | |
47 implements(ITestable) | |
48 | |
49 # components.Interface(TestAvatar).adaptWith(Testable, ITestable) | |
50 | |
51 components.registerAdapter(Testable, TestAvatar, ITestable) | |
52 | |
53 class IDerivedCredentials(credentials.IUsernamePassword): | |
54 pass | |
55 | |
56 class DerivedCredentials(object): | |
57 implements(IDerivedCredentials, ITestable) | |
58 | |
59 def __init__(self, username, password): | |
60 self.username = username | |
61 self.password = password | |
62 | |
63 def checkPassword(self, password): | |
64 return password == self.password | |
65 | |
66 | |
67 class TestRealm: | |
68 implements(portal.IRealm) | |
69 def __init__(self): | |
70 self.avatars = {} | |
71 | |
72 def requestAvatar(self, avatarId, mind, *interfaces): | |
73 if self.avatars.has_key(avatarId): | |
74 avatar = self.avatars[avatarId] | |
75 else: | |
76 avatar = TestAvatar(avatarId) | |
77 self.avatars[avatarId] = avatar | |
78 avatar.login() | |
79 return (interfaces[0], interfaces[0](avatar), | |
80 avatar.logout) | |
81 | |
82 class NewCredTest(unittest.TestCase): | |
83 def setUp(self): | |
84 r = self.realm = TestRealm() | |
85 p = self.portal = portal.Portal(r) | |
86 up = self.checker = checkers.InMemoryUsernamePasswordDatabaseDontUse() | |
87 up.addUser("bob", "hello") | |
88 p.registerChecker(up) | |
89 | |
90 def testListCheckers(self): | |
91 expected = [credentials.IUsernamePassword, credentials.IUsernameHashedPa
ssword] | |
92 got = self.portal.listCredentialsInterfaces() | |
93 expected.sort() | |
94 got.sort() | |
95 self.assertEquals(got, expected) | |
96 | |
97 def testBasicLogin(self): | |
98 l = []; f = [] | |
99 self.portal.login(credentials.UsernamePassword("bob", "hello"), | |
100 self, ITestable).addCallback( | |
101 l.append).addErrback(f.append) | |
102 if f: | |
103 raise f[0] | |
104 # print l[0].getBriefTraceback() | |
105 iface, impl, logout = l[0] | |
106 # whitebox | |
107 self.assertEquals(iface, ITestable) | |
108 self.failUnless(iface.providedBy(impl), | |
109 "%s does not implement %s" % (impl, iface)) | |
110 # greybox | |
111 self.failUnless(impl.original.loggedIn) | |
112 self.failUnless(not impl.original.loggedOut) | |
113 logout() | |
114 self.failUnless(impl.original.loggedOut) | |
115 | |
116 def test_derivedInterface(self): | |
117 """ | |
118 Login with credentials implementing an interface inheriting from an | |
119 interface registered with a checker (but not itself registered). | |
120 """ | |
121 l = [] | |
122 f = [] | |
123 self.portal.login(DerivedCredentials("bob", "hello"), self, ITestable | |
124 ).addCallback(l.append | |
125 ).addErrback(f.append) | |
126 if f: | |
127 raise f[0] | |
128 iface, impl, logout = l[0] | |
129 # whitebox | |
130 self.assertEquals(iface, ITestable) | |
131 self.failUnless(iface.providedBy(impl), | |
132 "%s does not implement %s" % (impl, iface)) | |
133 # greybox | |
134 self.failUnless(impl.original.loggedIn) | |
135 self.failUnless(not impl.original.loggedOut) | |
136 logout() | |
137 self.failUnless(impl.original.loggedOut) | |
138 | |
139 def testFailedLogin(self): | |
140 l = [] | |
141 self.portal.login(credentials.UsernamePassword("bob", "h3llo"), | |
142 self, ITestable).addErrback( | |
143 lambda x: x.trap(error.UnauthorizedLogin)).addCallback(l.append) | |
144 self.failUnless(l) | |
145 self.failUnlessEqual(error.UnauthorizedLogin, l[0]) | |
146 | |
147 def testFailedLoginName(self): | |
148 l = [] | |
149 self.portal.login(credentials.UsernamePassword("jay", "hello"), | |
150 self, ITestable).addErrback( | |
151 lambda x: x.trap(error.UnauthorizedLogin)).addCallback(l.append) | |
152 self.failUnless(l) | |
153 self.failUnlessEqual(error.UnauthorizedLogin, l[0]) | |
154 | |
155 | |
156 class CramMD5CredentialsTestCase(unittest.TestCase): | |
157 def testIdempotentChallenge(self): | |
158 c = credentials.CramMD5Credentials() | |
159 chal = c.getChallenge() | |
160 self.assertEquals(chal, c.getChallenge()) | |
161 | |
162 def testCheckPassword(self): | |
163 c = credentials.CramMD5Credentials() | |
164 chal = c.getChallenge() | |
165 c.response = hmac.HMAC('secret', chal).hexdigest() | |
166 self.failUnless(c.checkPassword('secret')) | |
167 | |
168 def testWrongPassword(self): | |
169 c = credentials.CramMD5Credentials() | |
170 self.failIf(c.checkPassword('secret')) | |
171 | |
172 class OnDiskDatabaseTestCase(unittest.TestCase): | |
173 users = [ | |
174 ('user1', 'pass1'), | |
175 ('user2', 'pass2'), | |
176 ('user3', 'pass3'), | |
177 ] | |
178 | |
179 | |
180 def testUserLookup(self): | |
181 dbfile = self.mktemp() | |
182 db = checkers.FilePasswordDB(dbfile) | |
183 f = file(dbfile, 'w') | |
184 for (u, p) in self.users: | |
185 f.write('%s:%s\n' % (u, p)) | |
186 f.close() | |
187 | |
188 for (u, p) in self.users: | |
189 self.failUnlessRaises(KeyError, db.getUser, u.upper()) | |
190 self.assertEquals(db.getUser(u), (u, p)) | |
191 | |
192 def testCaseInSensitivity(self): | |
193 dbfile = self.mktemp() | |
194 db = checkers.FilePasswordDB(dbfile, caseSensitive=0) | |
195 f = file(dbfile, 'w') | |
196 for (u, p) in self.users: | |
197 f.write('%s:%s\n' % (u, p)) | |
198 f.close() | |
199 | |
200 for (u, p) in self.users: | |
201 self.assertEquals(db.getUser(u.upper()), (u, p)) | |
202 | |
203 def testRequestAvatarId(self): | |
204 dbfile = self.mktemp() | |
205 db = checkers.FilePasswordDB(dbfile, caseSensitive=0) | |
206 f = file(dbfile, 'w') | |
207 for (u, p) in self.users: | |
208 f.write('%s:%s\n' % (u, p)) | |
209 f.close() | |
210 creds = [credentials.UsernamePassword(u, p) for u, p in self.users] | |
211 d = defer.gatherResults( | |
212 [defer.maybeDeferred(db.requestAvatarId, c) for c in creds]) | |
213 d.addCallback(self.assertEquals, [u for u, p in self.users]) | |
214 return d | |
215 | |
216 def testRequestAvatarId_hashed(self): | |
217 dbfile = self.mktemp() | |
218 db = checkers.FilePasswordDB(dbfile, caseSensitive=0) | |
219 f = file(dbfile, 'w') | |
220 for (u, p) in self.users: | |
221 f.write('%s:%s\n' % (u, p)) | |
222 f.close() | |
223 creds = [credentials.UsernameHashedPassword(u, p) for u, p in self.users
] | |
224 d = defer.gatherResults( | |
225 [defer.maybeDeferred(db.requestAvatarId, c) for c in creds]) | |
226 d.addCallback(self.assertEquals, [u for u, p in self.users]) | |
227 return d | |
228 | |
229 | |
230 | |
231 class HashedPasswordOnDiskDatabaseTestCase(unittest.TestCase): | |
232 users = [ | |
233 ('user1', 'pass1'), | |
234 ('user2', 'pass2'), | |
235 ('user3', 'pass3'), | |
236 ] | |
237 | |
238 | |
239 def hash(self, u, p, s): | |
240 return crypt(p, s) | |
241 | |
242 def setUp(self): | |
243 dbfile = self.mktemp() | |
244 self.db = checkers.FilePasswordDB(dbfile, hash=self.hash) | |
245 f = file(dbfile, 'w') | |
246 for (u, p) in self.users: | |
247 f.write('%s:%s\n' % (u, crypt(p, u[:2]))) | |
248 f.close() | |
249 r = TestRealm() | |
250 self.port = portal.Portal(r) | |
251 self.port.registerChecker(self.db) | |
252 | |
253 def testGoodCredentials(self): | |
254 goodCreds = [credentials.UsernamePassword(u, p) for u, p in self.users] | |
255 d = defer.gatherResults([self.db.requestAvatarId(c) for c in goodCreds]) | |
256 d.addCallback(self.assertEquals, [u for u, p in self.users]) | |
257 return d | |
258 | |
259 def testGoodCredentials_login(self): | |
260 goodCreds = [credentials.UsernamePassword(u, p) for u, p in self.users] | |
261 d = defer.gatherResults([self.port.login(c, None, ITestable) | |
262 for c in goodCreds]) | |
263 d.addCallback(lambda x: [a.original.name for i, a, l in x]) | |
264 d.addCallback(self.assertEquals, [u for u, p in self.users]) | |
265 return d | |
266 | |
267 def testBadCredentials(self): | |
268 badCreds = [credentials.UsernamePassword(u, 'wrong password') | |
269 for u, p in self.users] | |
270 d = defer.DeferredList([self.port.login(c, None, ITestable) | |
271 for c in badCreds], consumeErrors=True) | |
272 d.addCallback(self._assertFailures, error.UnauthorizedLogin) | |
273 return d | |
274 | |
275 def testHashedCredentials(self): | |
276 hashedCreds = [credentials.UsernameHashedPassword(u, crypt(p, u[:2])) | |
277 for u, p in self.users] | |
278 d = defer.DeferredList([self.port.login(c, None, ITestable) | |
279 for c in hashedCreds], consumeErrors=True) | |
280 d.addCallback(self._assertFailures, error.UnhandledCredentials) | |
281 return d | |
282 | |
283 def _assertFailures(self, failures, *expectedFailures): | |
284 for flag, failure in failures: | |
285 self.failUnlessEqual(flag, defer.FAILURE) | |
286 failure.trap(*expectedFailures) | |
287 return None | |
288 | |
289 if crypt is None: | |
290 skip = "crypt module not available" | |
291 | |
292 class PluggableAuthenticationModulesTest(unittest.TestCase): | |
293 | |
294 def setUp(self): | |
295 """ | |
296 Replace L{pamauth.callIntoPAM} with a dummy implementation with | |
297 easily-controlled behavior. | |
298 """ | |
299 self._oldCallIntoPAM = pamauth.callIntoPAM | |
300 pamauth.callIntoPAM = self.callIntoPAM | |
301 | |
302 | |
303 def tearDown(self): | |
304 """ | |
305 Restore the original value of L{pamauth.callIntoPAM}. | |
306 """ | |
307 pamauth.callIntoPAM = self._oldCallIntoPAM | |
308 | |
309 | |
310 def callIntoPAM(self, service, user, conv): | |
311 if service != 'Twisted': | |
312 raise error.UnauthorizedLogin('bad service: %s' % service) | |
313 if user != 'testuser': | |
314 raise error.UnauthorizedLogin('bad username: %s' % user) | |
315 questions = [ | |
316 (1, "Password"), | |
317 (2, "Message w/ Input"), | |
318 (3, "Message w/o Input"), | |
319 ] | |
320 replies = conv(questions) | |
321 if replies != [ | |
322 ("password", 0), | |
323 ("entry", 0), | |
324 ("", 0) | |
325 ]: | |
326 raise error.UnauthorizedLogin('bad conversion: %s' % repr(replie
s)) | |
327 return 1 | |
328 | |
329 def _makeConv(self, d): | |
330 def conv(questions): | |
331 return defer.succeed([(d[t], 0) for t, q in questions]) | |
332 return conv | |
333 | |
334 def testRequestAvatarId(self): | |
335 db = checkers.PluggableAuthenticationModulesChecker() | |
336 conv = self._makeConv({1:'password', 2:'entry', 3:''}) | |
337 creds = credentials.PluggableAuthenticationModules('testuser', | |
338 conv) | |
339 d = db.requestAvatarId(creds) | |
340 d.addCallback(self.assertEquals, 'testuser') | |
341 return d | |
342 | |
343 def testBadCredentials(self): | |
344 db = checkers.PluggableAuthenticationModulesChecker() | |
345 conv = self._makeConv({1:'', 2:'', 3:''}) | |
346 creds = credentials.PluggableAuthenticationModules('testuser', | |
347 conv) | |
348 d = db.requestAvatarId(creds) | |
349 self.assertFailure(d, error.UnauthorizedLogin) | |
350 return d | |
351 | |
352 def testBadUsername(self): | |
353 db = checkers.PluggableAuthenticationModulesChecker() | |
354 conv = self._makeConv({1:'password', 2:'entry', 3:''}) | |
355 creds = credentials.PluggableAuthenticationModules('baduser', | |
356 conv) | |
357 d = db.requestAvatarId(creds) | |
358 self.assertFailure(d, error.UnauthorizedLogin) | |
359 return d | |
360 | |
361 if not pamauth: | |
362 skip = "Can't run without PyPAM" | |
363 | |
364 class CheckersMixin: | |
365 def testPositive(self): | |
366 for chk in self.getCheckers(): | |
367 for (cred, avatarId) in self.getGoodCredentials(): | |
368 r = wFD(chk.requestAvatarId(cred)) | |
369 yield r | |
370 self.assertEquals(r.getResult(), avatarId) | |
371 testPositive = dG(testPositive) | |
372 | |
373 def testNegative(self): | |
374 for chk in self.getCheckers(): | |
375 for cred in self.getBadCredentials(): | |
376 r = wFD(chk.requestAvatarId(cred)) | |
377 yield r | |
378 self.assertRaises(error.UnauthorizedLogin, r.getResult) | |
379 testNegative = dG(testNegative) | |
380 | |
381 class HashlessFilePasswordDBMixin: | |
382 credClass = credentials.UsernamePassword | |
383 diskHash = None | |
384 networkHash = staticmethod(lambda x: x) | |
385 | |
386 _validCredentials = [ | |
387 ('user1', 'password1'), | |
388 ('user2', 'password2'), | |
389 ('user3', 'password3')] | |
390 | |
391 def getGoodCredentials(self): | |
392 for u, p in self._validCredentials: | |
393 yield self.credClass(u, self.networkHash(p)), u | |
394 | |
395 def getBadCredentials(self): | |
396 for u, p in [('user1', 'password3'), | |
397 ('user2', 'password1'), | |
398 ('bloof', 'blarf')]: | |
399 yield self.credClass(u, self.networkHash(p)) | |
400 | |
401 def getCheckers(self): | |
402 diskHash = self.diskHash or (lambda x: x) | |
403 hashCheck = self.diskHash and (lambda username, password, stored: self.d
iskHash(password)) | |
404 | |
405 for cache in True, False: | |
406 fn = self.mktemp() | |
407 fObj = file(fn, 'w') | |
408 for u, p in self._validCredentials: | |
409 fObj.write('%s:%s\n' % (u, diskHash(p))) | |
410 fObj.close() | |
411 yield checkers.FilePasswordDB(fn, cache=cache, hash=hashCheck) | |
412 | |
413 fn = self.mktemp() | |
414 fObj = file(fn, 'w') | |
415 for u, p in self._validCredentials: | |
416 fObj.write('%s dingle dongle %s\n' % (diskHash(p), u)) | |
417 fObj.close() | |
418 yield checkers.FilePasswordDB(fn, ' ', 3, 0, cache=cache, hash=hashC
heck) | |
419 | |
420 fn = self.mktemp() | |
421 fObj = file(fn, 'w') | |
422 for u, p in self._validCredentials: | |
423 fObj.write('zip,zap,%s,zup,%s\n' % (u.title(), diskHash(p))) | |
424 fObj.close() | |
425 yield checkers.FilePasswordDB(fn, ',', 2, 4, False, cache=cache, has
h=hashCheck) | |
426 | |
427 class LocallyHashedFilePasswordDBMixin(HashlessFilePasswordDBMixin): | |
428 diskHash = staticmethod(lambda x: x.encode('hex')) | |
429 | |
430 class NetworkHashedFilePasswordDBMixin(HashlessFilePasswordDBMixin): | |
431 networkHash = staticmethod(lambda x: x.encode('hex')) | |
432 class credClass(credentials.UsernameHashedPassword): | |
433 def checkPassword(self, password): | |
434 return self.hashed.decode('hex') == password | |
435 | |
436 class HashlessFilePasswordDBCheckerTestCase(HashlessFilePasswordDBMixin, Checker
sMixin, unittest.TestCase): | |
437 pass | |
438 | |
439 class LocallyHashedFilePasswordDBCheckerTestCase(LocallyHashedFilePasswordDBMixi
n, CheckersMixin, unittest.TestCase): | |
440 pass | |
441 | |
442 class NetworkHashedFilePasswordDBCheckerTestCase(NetworkHashedFilePasswordDBMixi
n, CheckersMixin, unittest.TestCase): | |
443 pass | |
444 | |
OLD | NEW |