Index: third_party/tlslite/tlslite/utils/cryptomath.py |
diff --git a/third_party/tlslite/tlslite/utils/cryptomath.py b/third_party/tlslite/tlslite/utils/cryptomath.py |
index 86da25e52498d44358a720d0f17c1f170d2bfbe0..30354b2c56cebe688c9e80fb056979268dee3c92 100644 |
--- a/third_party/tlslite/tlslite/utils/cryptomath.py |
+++ b/third_party/tlslite/tlslite/utils/cryptomath.py |
@@ -1,25 +1,19 @@ |
+# Authors: |
+# Trevor Perrin |
+# Martin von Loewis - python 3 port |
+# |
+# See the LICENSE file for legal information regarding use of this file. |
+ |
"""cryptomath module |
This module has basic math/crypto code.""" |
- |
+from __future__ import print_function |
import os |
import math |
import base64 |
import binascii |
-# The sha module is deprecated in Python 2.6 |
-try: |
- import sha |
-except ImportError: |
- from hashlib import sha1 as sha |
- |
-# The md5 module is deprecated in Python 2.6 |
-try: |
- import md5 |
-except ImportError: |
- from hashlib import md5 |
- |
-from compat import * |
+from .compat import * |
# ************************************************************************** |
@@ -34,23 +28,6 @@ try: |
except ImportError: |
m2cryptoLoaded = False |
- |
-# Try to load cryptlib |
-try: |
- import cryptlib_py |
- try: |
- cryptlib_py.cryptInit() |
- except cryptlib_py.CryptException, e: |
- #If tlslite and cryptoIDlib are both present, |
- #they might each try to re-initialize this, |
- #so we're tolerant of that. |
- if e[0] != cryptlib_py.CRYPT_ERROR_INITED: |
- raise |
- cryptlibpyLoaded = True |
- |
-except ImportError: |
- cryptlibpyLoaded = False |
- |
#Try to load GMPY |
try: |
import gmpy |
@@ -70,151 +47,116 @@ except ImportError: |
# PRNG Functions |
# ************************************************************************** |
-# Get os.urandom PRNG |
-try: |
- os.urandom(1) |
- def getRandomBytes(howMany): |
- return stringToBytes(os.urandom(howMany)) |
- prngName = "os.urandom" |
- |
-except: |
- # Else get cryptlib PRNG |
- if cryptlibpyLoaded: |
- def getRandomBytes(howMany): |
- randomKey = cryptlib_py.cryptCreateContext(cryptlib_py.CRYPT_UNUSED, |
- cryptlib_py.CRYPT_ALGO_AES) |
- cryptlib_py.cryptSetAttribute(randomKey, |
- cryptlib_py.CRYPT_CTXINFO_MODE, |
- cryptlib_py.CRYPT_MODE_OFB) |
- cryptlib_py.cryptGenerateKey(randomKey) |
- bytes = createByteArrayZeros(howMany) |
- cryptlib_py.cryptEncrypt(randomKey, bytes) |
- return bytes |
- prngName = "cryptlib" |
- |
- else: |
- #Else get UNIX /dev/urandom PRNG |
- try: |
- devRandomFile = open("/dev/urandom", "rb") |
- def getRandomBytes(howMany): |
- return stringToBytes(devRandomFile.read(howMany)) |
- prngName = "/dev/urandom" |
- except IOError: |
- #Else get Win32 CryptoAPI PRNG |
- try: |
- import win32prng |
- def getRandomBytes(howMany): |
- s = win32prng.getRandomBytes(howMany) |
- if len(s) != howMany: |
- raise AssertionError() |
- return stringToBytes(s) |
- prngName ="CryptoAPI" |
- except ImportError: |
- #Else no PRNG :-( |
- def getRandomBytes(howMany): |
- raise NotImplementedError("No Random Number Generator "\ |
- "available.") |
- prngName = "None" |
+# Check that os.urandom works |
+import zlib |
+length = len(zlib.compress(os.urandom(1000))) |
+assert(length > 900) |
+ |
+def getRandomBytes(howMany): |
+ b = bytearray(os.urandom(howMany)) |
+ assert(len(b) == howMany) |
+ return b |
+ |
+prngName = "os.urandom" |
# ************************************************************************** |
-# Converter Functions |
+# Simple hash functions |
# ************************************************************************** |
-def bytesToNumber(bytes): |
- total = 0L |
- multiplier = 1L |
- for count in range(len(bytes)-1, -1, -1): |
- byte = bytes[count] |
- total += multiplier * byte |
- multiplier *= 256 |
- return total |
+import hmac |
+import hashlib |
-def numberToBytes(n, howManyBytes=None): |
- if howManyBytes == None: |
- howManyBytes = numBytes(n) |
- bytes = createByteArrayZeros(howManyBytes) |
- for count in range(howManyBytes-1, -1, -1): |
- bytes[count] = int(n % 256) |
- n >>= 8 |
- return bytes |
+def MD5(b): |
+ return bytearray(hashlib.md5(compat26Str(b)).digest()) |
-def bytesToBase64(bytes): |
- s = bytesToString(bytes) |
- return stringToBase64(s) |
+def SHA1(b): |
+ return bytearray(hashlib.sha1(compat26Str(b)).digest()) |
-def base64ToBytes(s): |
- s = base64ToString(s) |
- return stringToBytes(s) |
+def HMAC_MD5(k, b): |
+ k = compatHMAC(k) |
+ b = compatHMAC(b) |
+ return bytearray(hmac.new(k, b, hashlib.md5).digest()) |
-def numberToBase64(n): |
- bytes = numberToBytes(n) |
- return bytesToBase64(bytes) |
+def HMAC_SHA1(k, b): |
+ k = compatHMAC(k) |
+ b = compatHMAC(b) |
+ return bytearray(hmac.new(k, b, hashlib.sha1).digest()) |
-def base64ToNumber(s): |
- bytes = base64ToBytes(s) |
- return bytesToNumber(bytes) |
-def stringToNumber(s): |
- bytes = stringToBytes(s) |
- return bytesToNumber(bytes) |
+# ************************************************************************** |
+# Converter Functions |
+# ************************************************************************** |
-def numberToString(s): |
- bytes = numberToBytes(s) |
- return bytesToString(bytes) |
+def bytesToNumber(b): |
+ total = 0 |
+ multiplier = 1 |
+ for count in range(len(b)-1, -1, -1): |
+ byte = b[count] |
+ total += multiplier * byte |
+ multiplier *= 256 |
+ # Force-cast to long to appease PyCrypto. |
+ # https://github.com/trevp/tlslite/issues/15 |
+ return long(total) |
-def base64ToString(s): |
- try: |
- return base64.decodestring(s) |
- except binascii.Error, e: |
- raise SyntaxError(e) |
- except binascii.Incomplete, e: |
- raise SyntaxError(e) |
+def numberToByteArray(n, howManyBytes=None): |
+ """Convert an integer into a bytearray, zero-pad to howManyBytes. |
-def stringToBase64(s): |
- return base64.encodestring(s).replace("\n", "") |
+ The returned bytearray may be smaller than howManyBytes, but will |
+ not be larger. The returned bytearray will contain a big-endian |
+ encoding of the input integer (n). |
+ """ |
+ if howManyBytes == None: |
+ howManyBytes = numBytes(n) |
+ b = bytearray(howManyBytes) |
+ for count in range(howManyBytes-1, -1, -1): |
+ b[count] = int(n % 256) |
+ n >>= 8 |
+ return b |
def mpiToNumber(mpi): #mpi is an openssl-format bignum string |
if (ord(mpi[4]) & 0x80) !=0: #Make sure this is a positive number |
raise AssertionError() |
- bytes = stringToBytes(mpi[4:]) |
- return bytesToNumber(bytes) |
+ b = bytearray(mpi[4:]) |
+ return bytesToNumber(b) |
def numberToMPI(n): |
- bytes = numberToBytes(n) |
+ b = numberToByteArray(n) |
ext = 0 |
#If the high-order bit is going to be set, |
#add an extra byte of zeros |
if (numBits(n) & 0x7)==0: |
ext = 1 |
length = numBytes(n) + ext |
- bytes = concatArrays(createByteArrayZeros(4+ext), bytes) |
- bytes[0] = (length >> 24) & 0xFF |
- bytes[1] = (length >> 16) & 0xFF |
- bytes[2] = (length >> 8) & 0xFF |
- bytes[3] = length & 0xFF |
- return bytesToString(bytes) |
- |
+ b = bytearray(4+ext) + b |
+ b[0] = (length >> 24) & 0xFF |
+ b[1] = (length >> 16) & 0xFF |
+ b[2] = (length >> 8) & 0xFF |
+ b[3] = length & 0xFF |
+ return bytes(b) |
# ************************************************************************** |
# Misc. Utility Functions |
# ************************************************************************** |
+def numBits(n): |
+ if n==0: |
+ return 0 |
+ s = "%x" % n |
+ return ((len(s)-1)*4) + \ |
+ {'0':0, '1':1, '2':2, '3':2, |
+ '4':3, '5':3, '6':3, '7':3, |
+ '8':4, '9':4, 'a':4, 'b':4, |
+ 'c':4, 'd':4, 'e':4, 'f':4, |
+ }[s[0]] |
+ return int(math.floor(math.log(n, 2))+1) |
+ |
def numBytes(n): |
if n==0: |
return 0 |
bits = numBits(n) |
return int(math.ceil(bits / 8.0)) |
-def hashAndBase64(s): |
- return stringToBase64(sha.sha(s).digest()) |
- |
-def getBase64Nonce(numChars=22): #defaults to an 132 bit nonce |
- bytes = getRandomBytes(numChars) |
- bytesStr = "".join([chr(b) for b in bytes]) |
- return stringToBase64(bytesStr)[:numChars] |
- |
- |
# ************************************************************************** |
# Big Number Math |
# ************************************************************************** |
@@ -240,9 +182,7 @@ def gcd(a,b): |
return a |
def lcm(a, b): |
- #This will break when python division changes, but we can't use // cause |
- #of Jython |
- return (a * b) / gcd(a, b) |
+ return (a * b) // gcd(a, b) |
#Returns inverse of a mod b, zero if none |
#Uses Extended Euclidean Algorithm |
@@ -250,9 +190,7 @@ def invMod(a, b): |
c, d = a, b |
uc, ud = 1, 0 |
while c != 0: |
- #This will break when python division changes, but we can't use // |
- #cause of Jython |
- q = d / c |
+ q = d // c |
c, d = d-(q*c), c |
uc, ud = ud - (q * uc), uc |
if d == 1: |
@@ -269,61 +207,17 @@ if gmpyLoaded: |
return long(result) |
else: |
- #Copied from Bryan G. Olson's post to comp.lang.python |
- #Does left-to-right instead of pow()'s right-to-left, |
- #thus about 30% faster than the python built-in with small bases |
def powMod(base, power, modulus): |
- nBitScan = 5 |
- |
- """ Return base**power mod modulus, using multi bit scanning |
- with nBitScan bits at a time.""" |
- |
- #TREV - Added support for negative exponents |
- negativeResult = False |
- if (power < 0): |
- power *= -1 |
- negativeResult = True |
- |
- exp2 = 2**nBitScan |
- mask = exp2 - 1 |
- |
- # Break power into a list of digits of nBitScan bits. |
- # The list is recursive so easy to read in reverse direction. |
- nibbles = None |
- while power: |
- nibbles = int(power & mask), nibbles |
- power = power >> nBitScan |
- |
- # Make a table of powers of base up to 2**nBitScan - 1 |
- lowPowers = [1] |
- for i in xrange(1, exp2): |
- lowPowers.append((lowPowers[i-1] * base) % modulus) |
- |
- # To exponentiate by the first nibble, look it up in the table |
- nib, nibbles = nibbles |
- prod = lowPowers[nib] |
- |
- # For the rest, square nBitScan times, then multiply by |
- # base^nibble |
- while nibbles: |
- nib, nibbles = nibbles |
- for i in xrange(nBitScan): |
- prod = (prod * prod) % modulus |
- if nib: prod = (prod * lowPowers[nib]) % modulus |
- |
- #TREV - Added support for negative exponents |
- if negativeResult: |
- prodInv = invMod(prod, modulus) |
- #Check to make sure the inverse is correct |
- if (prod * prodInv) % modulus != 1: |
- raise AssertionError() |
- return prodInv |
- return prod |
- |
+ if power < 0: |
+ result = pow(base, power*-1, modulus) |
+ result = invMod(result, modulus) |
+ return result |
+ else: |
+ return pow(base, power, modulus) |
#Pre-calculate a sieve of the ~100 primes < 1000: |
def makeSieve(n): |
- sieve = range(n) |
+ sieve = list(range(n)) |
for count in range(2, int(math.sqrt(n))): |
if sieve[count] == 0: |
continue |
@@ -344,10 +238,10 @@ def isPrime(n, iterations=5, display=False): |
#Passed trial division, proceed to Rabin-Miller |
#Rabin-Miller implemented per Ferguson & Schneier |
#Compute s, t for Rabin-Miller |
- if display: print "*", |
+ if display: print("*", end=' ') |
s, t = n-1, 0 |
while s % 2 == 0: |
- s, t = s/2, t+1 |
+ s, t = s//2, t+1 |
#Repeat Rabin-Miller x times |
a = 2 #Use 2 as a base for first iteration speedup, per HAC |
for count in range(iterations): |
@@ -371,12 +265,12 @@ def getRandomPrime(bits, display=False): |
# |
#Since 30 is lcm(2,3,5), we'll set our test numbers to |
#29 % 30 and keep them there |
- low = (2L ** (bits-1)) * 3/2 |
- high = 2L ** bits - 30 |
+ low = ((2 ** (bits-1)) * 3) // 2 |
+ high = 2 ** bits - 30 |
p = getRandomNumber(low, high) |
p += 29 - (p % 30) |
while 1: |
- if display: print ".", |
+ if display: print(".", end=' ') |
p += 30 |
if p >= high: |
p = getRandomNumber(low, high) |
@@ -393,12 +287,12 @@ def getRandomSafePrime(bits, display=False): |
# |
#Since 30 is lcm(2,3,5), we'll set our test numbers to |
#29 % 30 and keep them there |
- low = (2 ** (bits-2)) * 3/2 |
+ low = (2 ** (bits-2)) * 3//2 |
high = (2 ** (bits-1)) - 30 |
q = getRandomNumber(low, high) |
q += 29 - (q % 30) |
while 1: |
- if display: print ".", |
+ if display: print(".", end=' ') |
q += 30 |
if (q >= high): |
q = getRandomNumber(low, high) |