OLD | NEW |
| 1 # Authors: |
| 2 # Trevor Perrin |
| 3 # Martin von Loewis - python 3 port |
| 4 # |
| 5 # See the LICENSE file for legal information regarding use of this file. |
| 6 |
1 """cryptomath module | 7 """cryptomath module |
2 | 8 |
3 This module has basic math/crypto code.""" | 9 This module has basic math/crypto code.""" |
4 | 10 from __future__ import print_function |
5 import os | 11 import os |
6 import math | 12 import math |
7 import base64 | 13 import base64 |
8 import binascii | 14 import binascii |
9 | 15 |
10 # The sha module is deprecated in Python 2.6 | 16 from .compat import * |
11 try: | |
12 import sha | |
13 except ImportError: | |
14 from hashlib import sha1 as sha | |
15 | |
16 # The md5 module is deprecated in Python 2.6 | |
17 try: | |
18 import md5 | |
19 except ImportError: | |
20 from hashlib import md5 | |
21 | |
22 from compat import * | |
23 | 17 |
24 | 18 |
25 # ************************************************************************** | 19 # ************************************************************************** |
26 # Load Optional Modules | 20 # Load Optional Modules |
27 # ************************************************************************** | 21 # ************************************************************************** |
28 | 22 |
29 # Try to load M2Crypto/OpenSSL | 23 # Try to load M2Crypto/OpenSSL |
30 try: | 24 try: |
31 from M2Crypto import m2 | 25 from M2Crypto import m2 |
32 m2cryptoLoaded = True | 26 m2cryptoLoaded = True |
33 | 27 |
34 except ImportError: | 28 except ImportError: |
35 m2cryptoLoaded = False | 29 m2cryptoLoaded = False |
36 | 30 |
37 | |
38 # Try to load cryptlib | |
39 try: | |
40 import cryptlib_py | |
41 try: | |
42 cryptlib_py.cryptInit() | |
43 except cryptlib_py.CryptException, e: | |
44 #If tlslite and cryptoIDlib are both present, | |
45 #they might each try to re-initialize this, | |
46 #so we're tolerant of that. | |
47 if e[0] != cryptlib_py.CRYPT_ERROR_INITED: | |
48 raise | |
49 cryptlibpyLoaded = True | |
50 | |
51 except ImportError: | |
52 cryptlibpyLoaded = False | |
53 | |
54 #Try to load GMPY | 31 #Try to load GMPY |
55 try: | 32 try: |
56 import gmpy | 33 import gmpy |
57 gmpyLoaded = True | 34 gmpyLoaded = True |
58 except ImportError: | 35 except ImportError: |
59 gmpyLoaded = False | 36 gmpyLoaded = False |
60 | 37 |
61 #Try to load pycrypto | 38 #Try to load pycrypto |
62 try: | 39 try: |
63 import Crypto.Cipher.AES | 40 import Crypto.Cipher.AES |
64 pycryptoLoaded = True | 41 pycryptoLoaded = True |
65 except ImportError: | 42 except ImportError: |
66 pycryptoLoaded = False | 43 pycryptoLoaded = False |
67 | 44 |
68 | 45 |
69 # ************************************************************************** | 46 # ************************************************************************** |
70 # PRNG Functions | 47 # PRNG Functions |
71 # ************************************************************************** | 48 # ************************************************************************** |
72 | 49 |
73 # Get os.urandom PRNG | 50 # Check that os.urandom works |
74 try: | 51 import zlib |
75 os.urandom(1) | 52 length = len(zlib.compress(os.urandom(1000))) |
76 def getRandomBytes(howMany): | 53 assert(length > 900) |
77 return stringToBytes(os.urandom(howMany)) | |
78 prngName = "os.urandom" | |
79 | 54 |
80 except: | 55 def getRandomBytes(howMany): |
81 # Else get cryptlib PRNG | 56 b = bytearray(os.urandom(howMany)) |
82 if cryptlibpyLoaded: | 57 assert(len(b) == howMany) |
83 def getRandomBytes(howMany): | 58 return b |
84 randomKey = cryptlib_py.cryptCreateContext(cryptlib_py.CRYPT_UNUSED, | |
85 cryptlib_py.CRYPT_ALGO_AE
S) | |
86 cryptlib_py.cryptSetAttribute(randomKey, | |
87 cryptlib_py.CRYPT_CTXINFO_MODE, | |
88 cryptlib_py.CRYPT_MODE_OFB) | |
89 cryptlib_py.cryptGenerateKey(randomKey) | |
90 bytes = createByteArrayZeros(howMany) | |
91 cryptlib_py.cryptEncrypt(randomKey, bytes) | |
92 return bytes | |
93 prngName = "cryptlib" | |
94 | 59 |
95 else: | 60 prngName = "os.urandom" |
96 #Else get UNIX /dev/urandom PRNG | 61 |
97 try: | 62 # ************************************************************************** |
98 devRandomFile = open("/dev/urandom", "rb") | 63 # Simple hash functions |
99 def getRandomBytes(howMany): | 64 # ************************************************************************** |
100 return stringToBytes(devRandomFile.read(howMany)) | 65 |
101 prngName = "/dev/urandom" | 66 import hmac |
102 except IOError: | 67 import hashlib |
103 #Else get Win32 CryptoAPI PRNG | 68 |
104 try: | 69 def MD5(b): |
105 import win32prng | 70 return bytearray(hashlib.md5(compat26Str(b)).digest()) |
106 def getRandomBytes(howMany): | 71 |
107 s = win32prng.getRandomBytes(howMany) | 72 def SHA1(b): |
108 if len(s) != howMany: | 73 return bytearray(hashlib.sha1(compat26Str(b)).digest()) |
109 raise AssertionError() | 74 |
110 return stringToBytes(s) | 75 def HMAC_MD5(k, b): |
111 prngName ="CryptoAPI" | 76 k = compatHMAC(k) |
112 except ImportError: | 77 b = compatHMAC(b) |
113 #Else no PRNG :-( | 78 return bytearray(hmac.new(k, b, hashlib.md5).digest()) |
114 def getRandomBytes(howMany): | 79 |
115 raise NotImplementedError("No Random Number Generator "\ | 80 def HMAC_SHA1(k, b): |
116 "available.") | 81 k = compatHMAC(k) |
117 prngName = "None" | 82 b = compatHMAC(b) |
| 83 return bytearray(hmac.new(k, b, hashlib.sha1).digest()) |
| 84 |
118 | 85 |
119 # ************************************************************************** | 86 # ************************************************************************** |
120 # Converter Functions | 87 # Converter Functions |
121 # ************************************************************************** | 88 # ************************************************************************** |
122 | 89 |
123 def bytesToNumber(bytes): | 90 def bytesToNumber(b): |
124 total = 0L | 91 total = 0 |
125 multiplier = 1L | 92 multiplier = 1 |
126 for count in range(len(bytes)-1, -1, -1): | 93 for count in range(len(b)-1, -1, -1): |
127 byte = bytes[count] | 94 byte = b[count] |
128 total += multiplier * byte | 95 total += multiplier * byte |
129 multiplier *= 256 | 96 multiplier *= 256 |
130 return total | 97 # Force-cast to long to appease PyCrypto. |
| 98 # https://github.com/trevp/tlslite/issues/15 |
| 99 return long(total) |
131 | 100 |
132 def numberToBytes(n, howManyBytes=None): | 101 def numberToByteArray(n, howManyBytes=None): |
| 102 """Convert an integer into a bytearray, zero-pad to howManyBytes. |
| 103 |
| 104 The returned bytearray may be smaller than howManyBytes, but will |
| 105 not be larger. The returned bytearray will contain a big-endian |
| 106 encoding of the input integer (n). |
| 107 """ |
133 if howManyBytes == None: | 108 if howManyBytes == None: |
134 howManyBytes = numBytes(n) | 109 howManyBytes = numBytes(n) |
135 bytes = createByteArrayZeros(howManyBytes) | 110 b = bytearray(howManyBytes) |
136 for count in range(howManyBytes-1, -1, -1): | 111 for count in range(howManyBytes-1, -1, -1): |
137 bytes[count] = int(n % 256) | 112 b[count] = int(n % 256) |
138 n >>= 8 | 113 n >>= 8 |
139 return bytes | 114 return b |
140 | |
141 def bytesToBase64(bytes): | |
142 s = bytesToString(bytes) | |
143 return stringToBase64(s) | |
144 | |
145 def base64ToBytes(s): | |
146 s = base64ToString(s) | |
147 return stringToBytes(s) | |
148 | |
149 def numberToBase64(n): | |
150 bytes = numberToBytes(n) | |
151 return bytesToBase64(bytes) | |
152 | |
153 def base64ToNumber(s): | |
154 bytes = base64ToBytes(s) | |
155 return bytesToNumber(bytes) | |
156 | |
157 def stringToNumber(s): | |
158 bytes = stringToBytes(s) | |
159 return bytesToNumber(bytes) | |
160 | |
161 def numberToString(s): | |
162 bytes = numberToBytes(s) | |
163 return bytesToString(bytes) | |
164 | |
165 def base64ToString(s): | |
166 try: | |
167 return base64.decodestring(s) | |
168 except binascii.Error, e: | |
169 raise SyntaxError(e) | |
170 except binascii.Incomplete, e: | |
171 raise SyntaxError(e) | |
172 | |
173 def stringToBase64(s): | |
174 return base64.encodestring(s).replace("\n", "") | |
175 | 115 |
176 def mpiToNumber(mpi): #mpi is an openssl-format bignum string | 116 def mpiToNumber(mpi): #mpi is an openssl-format bignum string |
177 if (ord(mpi[4]) & 0x80) !=0: #Make sure this is a positive number | 117 if (ord(mpi[4]) & 0x80) !=0: #Make sure this is a positive number |
178 raise AssertionError() | 118 raise AssertionError() |
179 bytes = stringToBytes(mpi[4:]) | 119 b = bytearray(mpi[4:]) |
180 return bytesToNumber(bytes) | 120 return bytesToNumber(b) |
181 | 121 |
182 def numberToMPI(n): | 122 def numberToMPI(n): |
183 bytes = numberToBytes(n) | 123 b = numberToByteArray(n) |
184 ext = 0 | 124 ext = 0 |
185 #If the high-order bit is going to be set, | 125 #If the high-order bit is going to be set, |
186 #add an extra byte of zeros | 126 #add an extra byte of zeros |
187 if (numBits(n) & 0x7)==0: | 127 if (numBits(n) & 0x7)==0: |
188 ext = 1 | 128 ext = 1 |
189 length = numBytes(n) + ext | 129 length = numBytes(n) + ext |
190 bytes = concatArrays(createByteArrayZeros(4+ext), bytes) | 130 b = bytearray(4+ext) + b |
191 bytes[0] = (length >> 24) & 0xFF | 131 b[0] = (length >> 24) & 0xFF |
192 bytes[1] = (length >> 16) & 0xFF | 132 b[1] = (length >> 16) & 0xFF |
193 bytes[2] = (length >> 8) & 0xFF | 133 b[2] = (length >> 8) & 0xFF |
194 bytes[3] = length & 0xFF | 134 b[3] = length & 0xFF |
195 return bytesToString(bytes) | 135 return bytes(b) |
196 | |
197 | 136 |
198 | 137 |
199 # ************************************************************************** | 138 # ************************************************************************** |
200 # Misc. Utility Functions | 139 # Misc. Utility Functions |
201 # ************************************************************************** | 140 # ************************************************************************** |
202 | 141 |
| 142 def numBits(n): |
| 143 if n==0: |
| 144 return 0 |
| 145 s = "%x" % n |
| 146 return ((len(s)-1)*4) + \ |
| 147 {'0':0, '1':1, '2':2, '3':2, |
| 148 '4':3, '5':3, '6':3, '7':3, |
| 149 '8':4, '9':4, 'a':4, 'b':4, |
| 150 'c':4, 'd':4, 'e':4, 'f':4, |
| 151 }[s[0]] |
| 152 return int(math.floor(math.log(n, 2))+1) |
| 153 |
203 def numBytes(n): | 154 def numBytes(n): |
204 if n==0: | 155 if n==0: |
205 return 0 | 156 return 0 |
206 bits = numBits(n) | 157 bits = numBits(n) |
207 return int(math.ceil(bits / 8.0)) | 158 return int(math.ceil(bits / 8.0)) |
208 | 159 |
209 def hashAndBase64(s): | |
210 return stringToBase64(sha.sha(s).digest()) | |
211 | |
212 def getBase64Nonce(numChars=22): #defaults to an 132 bit nonce | |
213 bytes = getRandomBytes(numChars) | |
214 bytesStr = "".join([chr(b) for b in bytes]) | |
215 return stringToBase64(bytesStr)[:numChars] | |
216 | |
217 | |
218 # ************************************************************************** | 160 # ************************************************************************** |
219 # Big Number Math | 161 # Big Number Math |
220 # ************************************************************************** | 162 # ************************************************************************** |
221 | 163 |
222 def getRandomNumber(low, high): | 164 def getRandomNumber(low, high): |
223 if low >= high: | 165 if low >= high: |
224 raise AssertionError() | 166 raise AssertionError() |
225 howManyBits = numBits(high) | 167 howManyBits = numBits(high) |
226 howManyBytes = numBytes(high) | 168 howManyBytes = numBytes(high) |
227 lastBits = howManyBits % 8 | 169 lastBits = howManyBits % 8 |
228 while 1: | 170 while 1: |
229 bytes = getRandomBytes(howManyBytes) | 171 bytes = getRandomBytes(howManyBytes) |
230 if lastBits: | 172 if lastBits: |
231 bytes[0] = bytes[0] % (1 << lastBits) | 173 bytes[0] = bytes[0] % (1 << lastBits) |
232 n = bytesToNumber(bytes) | 174 n = bytesToNumber(bytes) |
233 if n >= low and n < high: | 175 if n >= low and n < high: |
234 return n | 176 return n |
235 | 177 |
236 def gcd(a,b): | 178 def gcd(a,b): |
237 a, b = max(a,b), min(a,b) | 179 a, b = max(a,b), min(a,b) |
238 while b: | 180 while b: |
239 a, b = b, a % b | 181 a, b = b, a % b |
240 return a | 182 return a |
241 | 183 |
242 def lcm(a, b): | 184 def lcm(a, b): |
243 #This will break when python division changes, but we can't use // cause | 185 return (a * b) // gcd(a, b) |
244 #of Jython | |
245 return (a * b) / gcd(a, b) | |
246 | 186 |
247 #Returns inverse of a mod b, zero if none | 187 #Returns inverse of a mod b, zero if none |
248 #Uses Extended Euclidean Algorithm | 188 #Uses Extended Euclidean Algorithm |
249 def invMod(a, b): | 189 def invMod(a, b): |
250 c, d = a, b | 190 c, d = a, b |
251 uc, ud = 1, 0 | 191 uc, ud = 1, 0 |
252 while c != 0: | 192 while c != 0: |
253 #This will break when python division changes, but we can't use // | 193 q = d // c |
254 #cause of Jython | |
255 q = d / c | |
256 c, d = d-(q*c), c | 194 c, d = d-(q*c), c |
257 uc, ud = ud - (q * uc), uc | 195 uc, ud = ud - (q * uc), uc |
258 if d == 1: | 196 if d == 1: |
259 return ud % b | 197 return ud % b |
260 return 0 | 198 return 0 |
261 | 199 |
262 | 200 |
263 if gmpyLoaded: | 201 if gmpyLoaded: |
264 def powMod(base, power, modulus): | 202 def powMod(base, power, modulus): |
265 base = gmpy.mpz(base) | 203 base = gmpy.mpz(base) |
266 power = gmpy.mpz(power) | 204 power = gmpy.mpz(power) |
267 modulus = gmpy.mpz(modulus) | 205 modulus = gmpy.mpz(modulus) |
268 result = pow(base, power, modulus) | 206 result = pow(base, power, modulus) |
269 return long(result) | 207 return long(result) |
270 | 208 |
271 else: | 209 else: |
272 #Copied from Bryan G. Olson's post to comp.lang.python | |
273 #Does left-to-right instead of pow()'s right-to-left, | |
274 #thus about 30% faster than the python built-in with small bases | |
275 def powMod(base, power, modulus): | 210 def powMod(base, power, modulus): |
276 nBitScan = 5 | 211 if power < 0: |
277 | 212 result = pow(base, power*-1, modulus) |
278 """ Return base**power mod modulus, using multi bit scanning | 213 result = invMod(result, modulus) |
279 with nBitScan bits at a time.""" | 214 return result |
280 | 215 else: |
281 #TREV - Added support for negative exponents | 216 return pow(base, power, modulus) |
282 negativeResult = False | |
283 if (power < 0): | |
284 power *= -1 | |
285 negativeResult = True | |
286 | |
287 exp2 = 2**nBitScan | |
288 mask = exp2 - 1 | |
289 | |
290 # Break power into a list of digits of nBitScan bits. | |
291 # The list is recursive so easy to read in reverse direction. | |
292 nibbles = None | |
293 while power: | |
294 nibbles = int(power & mask), nibbles | |
295 power = power >> nBitScan | |
296 | |
297 # Make a table of powers of base up to 2**nBitScan - 1 | |
298 lowPowers = [1] | |
299 for i in xrange(1, exp2): | |
300 lowPowers.append((lowPowers[i-1] * base) % modulus) | |
301 | |
302 # To exponentiate by the first nibble, look it up in the table | |
303 nib, nibbles = nibbles | |
304 prod = lowPowers[nib] | |
305 | |
306 # For the rest, square nBitScan times, then multiply by | |
307 # base^nibble | |
308 while nibbles: | |
309 nib, nibbles = nibbles | |
310 for i in xrange(nBitScan): | |
311 prod = (prod * prod) % modulus | |
312 if nib: prod = (prod * lowPowers[nib]) % modulus | |
313 | |
314 #TREV - Added support for negative exponents | |
315 if negativeResult: | |
316 prodInv = invMod(prod, modulus) | |
317 #Check to make sure the inverse is correct | |
318 if (prod * prodInv) % modulus != 1: | |
319 raise AssertionError() | |
320 return prodInv | |
321 return prod | |
322 | |
323 | 217 |
324 #Pre-calculate a sieve of the ~100 primes < 1000: | 218 #Pre-calculate a sieve of the ~100 primes < 1000: |
325 def makeSieve(n): | 219 def makeSieve(n): |
326 sieve = range(n) | 220 sieve = list(range(n)) |
327 for count in range(2, int(math.sqrt(n))): | 221 for count in range(2, int(math.sqrt(n))): |
328 if sieve[count] == 0: | 222 if sieve[count] == 0: |
329 continue | 223 continue |
330 x = sieve[count] * 2 | 224 x = sieve[count] * 2 |
331 while x < len(sieve): | 225 while x < len(sieve): |
332 sieve[x] = 0 | 226 sieve[x] = 0 |
333 x += sieve[count] | 227 x += sieve[count] |
334 sieve = [x for x in sieve[2:] if x] | 228 sieve = [x for x in sieve[2:] if x] |
335 return sieve | 229 return sieve |
336 | 230 |
337 sieve = makeSieve(1000) | 231 sieve = makeSieve(1000) |
338 | 232 |
339 def isPrime(n, iterations=5, display=False): | 233 def isPrime(n, iterations=5, display=False): |
340 #Trial division with sieve | 234 #Trial division with sieve |
341 for x in sieve: | 235 for x in sieve: |
342 if x >= n: return True | 236 if x >= n: return True |
343 if n % x == 0: return False | 237 if n % x == 0: return False |
344 #Passed trial division, proceed to Rabin-Miller | 238 #Passed trial division, proceed to Rabin-Miller |
345 #Rabin-Miller implemented per Ferguson & Schneier | 239 #Rabin-Miller implemented per Ferguson & Schneier |
346 #Compute s, t for Rabin-Miller | 240 #Compute s, t for Rabin-Miller |
347 if display: print "*", | 241 if display: print("*", end=' ') |
348 s, t = n-1, 0 | 242 s, t = n-1, 0 |
349 while s % 2 == 0: | 243 while s % 2 == 0: |
350 s, t = s/2, t+1 | 244 s, t = s//2, t+1 |
351 #Repeat Rabin-Miller x times | 245 #Repeat Rabin-Miller x times |
352 a = 2 #Use 2 as a base for first iteration speedup, per HAC | 246 a = 2 #Use 2 as a base for first iteration speedup, per HAC |
353 for count in range(iterations): | 247 for count in range(iterations): |
354 v = powMod(a, s, n) | 248 v = powMod(a, s, n) |
355 if v==1: | 249 if v==1: |
356 continue | 250 continue |
357 i = 0 | 251 i = 0 |
358 while v != n-1: | 252 while v != n-1: |
359 if i == t-1: | 253 if i == t-1: |
360 return False | 254 return False |
361 else: | 255 else: |
362 v, i = powMod(v, 2, n), i+1 | 256 v, i = powMod(v, 2, n), i+1 |
363 a = getRandomNumber(2, n) | 257 a = getRandomNumber(2, n) |
364 return True | 258 return True |
365 | 259 |
366 def getRandomPrime(bits, display=False): | 260 def getRandomPrime(bits, display=False): |
367 if bits < 10: | 261 if bits < 10: |
368 raise AssertionError() | 262 raise AssertionError() |
369 #The 1.5 ensures the 2 MSBs are set | 263 #The 1.5 ensures the 2 MSBs are set |
370 #Thus, when used for p,q in RSA, n will have its MSB set | 264 #Thus, when used for p,q in RSA, n will have its MSB set |
371 # | 265 # |
372 #Since 30 is lcm(2,3,5), we'll set our test numbers to | 266 #Since 30 is lcm(2,3,5), we'll set our test numbers to |
373 #29 % 30 and keep them there | 267 #29 % 30 and keep them there |
374 low = (2L ** (bits-1)) * 3/2 | 268 low = ((2 ** (bits-1)) * 3) // 2 |
375 high = 2L ** bits - 30 | 269 high = 2 ** bits - 30 |
376 p = getRandomNumber(low, high) | 270 p = getRandomNumber(low, high) |
377 p += 29 - (p % 30) | 271 p += 29 - (p % 30) |
378 while 1: | 272 while 1: |
379 if display: print ".", | 273 if display: print(".", end=' ') |
380 p += 30 | 274 p += 30 |
381 if p >= high: | 275 if p >= high: |
382 p = getRandomNumber(low, high) | 276 p = getRandomNumber(low, high) |
383 p += 29 - (p % 30) | 277 p += 29 - (p % 30) |
384 if isPrime(p, display=display): | 278 if isPrime(p, display=display): |
385 return p | 279 return p |
386 | 280 |
387 #Unused at the moment... | 281 #Unused at the moment... |
388 def getRandomSafePrime(bits, display=False): | 282 def getRandomSafePrime(bits, display=False): |
389 if bits < 10: | 283 if bits < 10: |
390 raise AssertionError() | 284 raise AssertionError() |
391 #The 1.5 ensures the 2 MSBs are set | 285 #The 1.5 ensures the 2 MSBs are set |
392 #Thus, when used for p,q in RSA, n will have its MSB set | 286 #Thus, when used for p,q in RSA, n will have its MSB set |
393 # | 287 # |
394 #Since 30 is lcm(2,3,5), we'll set our test numbers to | 288 #Since 30 is lcm(2,3,5), we'll set our test numbers to |
395 #29 % 30 and keep them there | 289 #29 % 30 and keep them there |
396 low = (2 ** (bits-2)) * 3/2 | 290 low = (2 ** (bits-2)) * 3//2 |
397 high = (2 ** (bits-1)) - 30 | 291 high = (2 ** (bits-1)) - 30 |
398 q = getRandomNumber(low, high) | 292 q = getRandomNumber(low, high) |
399 q += 29 - (q % 30) | 293 q += 29 - (q % 30) |
400 while 1: | 294 while 1: |
401 if display: print ".", | 295 if display: print(".", end=' ') |
402 q += 30 | 296 q += 30 |
403 if (q >= high): | 297 if (q >= high): |
404 q = getRandomNumber(low, high) | 298 q = getRandomNumber(low, high) |
405 q += 29 - (q % 30) | 299 q += 29 - (q % 30) |
406 #Ideas from Tom Wu's SRP code | 300 #Ideas from Tom Wu's SRP code |
407 #Do trial division on p and q before Rabin-Miller | 301 #Do trial division on p and q before Rabin-Miller |
408 if isPrime(q, 0, display=display): | 302 if isPrime(q, 0, display=display): |
409 p = (2 * q) + 1 | 303 p = (2 * q) + 1 |
410 if isPrime(p, display=display): | 304 if isPrime(p, display=display): |
411 if isPrime(q, display=display): | 305 if isPrime(q, display=display): |
412 return p | 306 return p |
OLD | NEW |