| 1 | # |
|---|
| 2 | # number.py : Number-theoretic functions |
|---|
| 3 | # |
|---|
| 4 | # Part of the Python Cryptography Toolkit |
|---|
| 5 | # |
|---|
| 6 | # Distribute and use freely; there are no restrictions on further |
|---|
| 7 | # dissemination and usage except those imposed by the laws of your |
|---|
| 8 | # country of residence. This software is provided "as is" without |
|---|
| 9 | # warranty of fitness for use or suitability for any purpose, express |
|---|
| 10 | # or implied. Use at your own risk or not at all. |
|---|
| 11 | # |
|---|
| 12 | |
|---|
__revision__ = "$Id: number.py,v 1.13 2003/04/04 18:21:07 akuchling Exp $"

# Arbitrary-precision integer type: ``long`` on Python 2.  Python 3 folded
# ``long`` into ``int``, so referencing ``long`` there raises NameError.
try:
    bignum = long
except NameError:
    bignum = int

# Optional C accelerator.  When it cannot be imported, _fastmath is None and
# isPrime() below skips it in favour of its pure-Python test.
try:
    from Crypto.PublicKey import _fastmath
except ImportError:
    _fastmath = None
|---|
| 20 | |
|---|
| 21 | # Commented out and replaced with faster versions below |
|---|
| 22 | ## def long2str(n): |
|---|
| 23 | ## s='' |
|---|
| 24 | ## while n>0: |
|---|
| 25 | ## s=chr(n & 255)+s |
|---|
| 26 | ## n=n>>8 |
|---|
| 27 | ## return s |
|---|
| 28 | |
|---|
| 29 | ## import types |
|---|
| 30 | ## def str2long(s): |
|---|
| 31 | ## if type(s)!=types.StringType: return s # Integers will be left alone |
|---|
| 32 | ## return reduce(lambda x,y : x*256+ord(y), s, 0L) |
|---|
| 33 | |
|---|
def size(N):
    """size(N:long) : int
    Return how many bits are needed to write N in binary.

    Zero and negative inputs report 0 bits.
    """
    count = 0
    threshold = 1
    # Double the threshold until it exceeds N; each doubling accounts for
    # one more bit of N.
    while threshold <= N:
        threshold <<= 1
        count += 1
    return count
|---|
| 43 | |
|---|
def getRandomNumber(N, randfunc):
    """getRandomNumber(N:int, randfunc:callable):long
    Return a random N-bit number; bit N-1 (the top bit) is always set.

    randfunc(k) is expected to return k random bytes.  N//8 whole bytes are
    drawn first; when N is not a multiple of 8, one more byte is drawn and
    only its top N%8 bits are used, placed above the whole bytes.

    Raises ValueError if N < 1.
    """
    if N < 1:
        raise ValueError("getRandomNumber() requires N >= 1, got %r" % (N,))
    # '//' keeps this an integer byte count on Python 3 as well.
    S = randfunc(N // 8)
    value = bytes_to_long(S)
    odd_bits = N % 8
    if odd_bits != 0:
        # Equivalent to prepending chr(char) to S before conversion, but done
        # arithmetically so it does not mix str and bytes on Python 3.
        char = ord(randfunc(1)) >> (8 - odd_bits)
        value |= char << (8 * len(S))
    value |= 1 << (N - 1)  # Ensure high bit is set
    # Internal sanity check on the construction above, not input validation.
    assert size(value) >= N
    return value
|---|
| 57 | |
|---|
def GCD(x, y):
    """GCD(x:long, y:long): long
    Return the greatest common divisor of x and y.

    The result is never negative, and GCD(0, 0) is 0.
    """
    # Euclid's algorithm on absolute values: replace the pair (a, b) with
    # (b, a mod b) until the second element reaches zero.
    a, b = abs(y), abs(x)
    while b > 0:
        a, b = b, a % b
    return a
|---|
| 66 | |
|---|
def inverse(u, v):
    """inverse(u:long, v:long):long
    Return the inverse of u mod v.

    Uses the extended Euclidean algorithm.  When GCD(u, v) == 1 and v > 1,
    the result x satisfies (u * x) % v == 1.  No error is raised when u and
    v share a factor; the value returned then is not an inverse, so callers
    should check GCD(u, v) == 1 beforehand.
    """
    u3, v3 = int(u), int(v)
    u1, v1 = 1, 0
    while v3 > 0:
        # Floor division: plain '/' would give a float on Python 3 and lose
        # precision for large moduli.
        q = u3 // v3
        u1, v1 = v1, u1 - v1 * q
        u3, v3 = v3, u3 - v3 * q
    # The Bezout coefficient may come out negative; shift it up by v.
    while u1 < 0:
        u1 = u1 + v
    return u1
|---|
| 80 | |
|---|
| 81 | # Given a number of bits to generate and a random generation function, |
|---|
| 82 | # find a prime number of the appropriate size. |
|---|
| 83 | |
|---|
def getPrime(N, randfunc):
    """getPrime(N:int, randfunc:callable):long
    Return a random N-bit prime number.

    Starts from a random odd N-bit number and steps upward by 2 until
    isPrime() accepts it.  If the search reaches 2**N, the candidate would
    have N+1 bits, so the search restarts from a fresh random number.

    Raises ValueError if N < 2 (there is no 1-bit prime).
    """
    if N < 2:
        raise ValueError("getPrime() requires N >= 2, got %r" % (N,))
    limit = 1 << N
    number = getRandomNumber(N, randfunc) | 1
    while not isPrime(number):
        number = number + 2
        if number >= limit:
            # Overflowed into N+1 bits; draw a new starting point.
            number = getRandomNumber(N, randfunc) | 1
    return number
|---|
| 93 | |
|---|
def isPrime(N):
    """isPrime(N:long):bool
    Return true if N is prime.

    The result is an int: 1 for prime, 0 otherwise.  Values below 2,
    including 0 and negative numbers, are never prime.
    1. N is trial-divided by the small primes in ``sieve``.
    2. If that is inconclusive and the optional ``_fastmath`` accelerator
       is loaded, its isPrime() decides.
    3. Otherwise a Rabin-Miller style test runs with the first 7 sieve
       primes as fixed bases.  With fixed bases, a composite could in
       principle pass.
    """
    # Previously only N == 1 was rejected, so a negative N with no small
    # factor (e.g. -257) fell through and was reported as prime.
    if N < 2:
        return 0
    if N in sieve:
        return 1
    for i in sieve:
        if (N % i) == 0:
            return 0

    # Use the accelerator if available
    if _fastmath is not None:
        return _fastmath.isPrime(N)

    # Compute the highest power of two below N.  N is odd here (N > 2 and
    # not divisible by 2), so this is also the top bit of N-1.  It is used
    # as a mask to walk N-1 from its most significant bit down.
    N1 = N - 1
    n = 1
    while n < N:
        n = n << 1
    n = n >> 1

    # Rabin-Miller style test: compute a**(N-1) % N by left-to-right
    # square-and-multiply, checking each squaring for a nontrivial square
    # root of 1.
    for c in sieve[:7]:
        a = c
        d = 1
        t = n
        while t:  # Iterate over the bits in N1
            x = (d * d) % N
            if x == 1 and d != 1 and d != N1:
                return 0  # Square root of 1 found
            if N1 & t:
                d = (x * a) % N
            else:
                d = x
            t = t >> 1
        if d != 1:
            return 0  # a**(N-1) % N != 1: composite (Fermat witness)
    return 1
|---|
| 132 | |
|---|
| 133 | # Small primes used for checking primality; these are all the primes |
|---|
| 134 | # less than 256. This should be enough to eliminate most of the odd |
|---|
| 135 | # numbers before needing to do a Rabin-Miller test at all. |
|---|
| 136 | |
|---|
sieve = [
    # below 100
    2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47,
    53, 59, 61, 67, 71, 73, 79, 83, 89, 97,
    # 100 - 199
    101, 103, 107, 109, 113, 127, 131, 137, 139, 149,
    151, 157, 163, 167, 173, 179, 181, 191, 193, 197, 199,
    # 200 - 255
    211, 223, 227, 229, 233, 239, 241, 251,
]
|---|
| 141 | |
|---|
| 142 | # Improved conversion functions contributed by Barry Warsaw, after |
|---|
| 143 | # careful benchmarking |
|---|
| 144 | |
|---|
| 145 | import struct |
|---|
| 146 | |
|---|
def long_to_bytes(n, blocksize=0):
    """long_to_bytes(n:long, blocksize:int) : string
    Convert a long integer to a big-endian byte string.

    Leading zero bytes are stripped.  n == 0, and any negative n, yields a
    single zero byte.  If optional blocksize is given and greater than
    zero, pad the front of the byte string with binary zeros so that the
    length is a multiple of blocksize.

    Returns ``str`` on Python 2 and ``bytes`` on Python 3.
    """
    n = int(n)
    pack = struct.pack
    # Collect 32-bit big-endian words least-significant first, then reverse
    # once.  Prepending to a string on every iteration is quadratic.
    words = []
    while n > 0:
        words.append(pack('>I', n & 0xffffffff))
        n = n >> 32
    words.reverse()
    # Byte literals keep this working on Python 3, where struct.pack returns
    # bytes and indexing bytes yields ints.
    s = b''.join(words).lstrip(b'\x00')
    if not s:
        # Only happens when n <= 0
        s = b'\x00'
    if blocksize > 0 and len(s) % blocksize:
        s = (blocksize - len(s) % blocksize) * b'\x00' + s
    return s
|---|
| 176 | |
|---|
def bytes_to_long(s):
    """bytes_to_long(string) : long
    Convert a big-endian byte string to a non-negative integer.

    An empty string converts to 0.
    This is (essentially) the inverse of long_to_bytes().
    """
    acc = 0
    unpack = struct.unpack
    length = len(s)
    if length % 4:
        # Left-pad with zero bytes to a whole number of 32-bit words.  A byte
        # literal works for both Python 2 str and Python 3 bytes input.
        extra = (4 - length % 4)
        s = b'\x00' * extra + s
        length = length + extra
    for i in range(0, length, 4):
        acc = (acc << 32) + unpack('>I', s[i:i+4])[0]
    return acc
|---|
| 193 | |
|---|
| 194 | # For backwards compatibility... |
|---|
| 195 | import warnings |
|---|
def long2str(n, blocksize=0):
    """long2str(n:long, blocksize:int) : string
    Backwards-compatible alias for long_to_bytes(); warns on every call.
    """
    # stacklevel=2 attributes the warning to the caller's line rather than
    # this wrapper, so users can find the call they need to update.
    warnings.warn("long2str() has been replaced by long_to_bytes()",
                  stacklevel=2)
    return long_to_bytes(n, blocksize)
|---|
def str2long(s):
    """str2long(string) : long
    Backwards-compatible alias for bytes_to_long(); warns on every call.
    """
    # stacklevel=2 attributes the warning to the caller's line rather than
    # this wrapper, so users can find the call they need to update.
    warnings.warn("str2long() has been replaced by bytes_to_long()",
                  stacklevel=2)
    return bytes_to_long(s)
|---|