Last active
September 11, 2023 10:18
Karatsuba Multiplication
def karatsuba(x, y):
    """Function to multiply 2 numbers in a more efficient manner than the grade school algorithm"""
    if len(str(x)) == 1 or len(str(y)) == 1:
        return x * y
    else:
        n = max(len(str(x)), len(str(y)))
        nby2 = n // 2  # integer division; plain / would give a float in Python 3
        # split both numbers at the same power of 10
        a = x // 10**(nby2)
        b = x % 10**(nby2)
        c = y // 10**(nby2)
        d = y % 10**(nby2)
        ac = karatsuba(a, c)
        bd = karatsuba(b, d)
        ad_plus_bc = karatsuba(a + b, c + d) - ac - bd
        # this little trick, writing n as 2*nby2 takes care of both even and odd n
        prod = ac * 10**(2 * nby2) + (ad_plus_bc * 10**nby2) + bd
        return prod
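To see why the 2*nby2 comment holds, here is a small check of the identity the function relies on, using an odd-length input. It is only a sketch for non-negative integers; the helper name `split_at` is made up for this illustration and is not part of the gist.

```python
def split_at(x, m):
    """Split x into (high, low) so that x == high * 10**m + low."""
    return divmod(x, 10**m)

x, y = 12345, 678                        # n = 5 digits, so m = n // 2 = 2
m = max(len(str(x)), len(str(y))) // 2
a, b = split_at(x, m)                    # high and low parts of x
c, d = split_at(y, m)                    # high and low parts of y

ac, bd = a * c, b * d
ad_plus_bc = (a + b) * (c + d) - ac - bd  # Gauss trick: one product instead of two

# Both numbers were split at the same 10**m, so the powers are 2*m and m,
# whether or not n is even.
product = ac * 10**(2 * m) + ad_plus_bc * 10**m + bd
print(product == x * y)
```

The point is that the exponent in the final line comes from the split position m, not from n itself, so an odd n never needs special handling.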
Also, the Gauss trick mentioned in the video (I think) only works when both inputs are split at the same power of 10, which in practice means they are the same size or padded to the same size. That is why some implementations above fail for longer numbers: somewhere in the recursive calls, ad and bc end up with different powers of 10.
Here's an example where the Gauss trick fails:
x = 1346, y = 16
a = 13, b = 46
c = 1, d = 6
ac = 13, bd = 276
(a+b)(c+d) = (59)(7) = 413
(a+b)(c+d) - ac - bd = 124
so the Gauss-trick product comes out as:
term1 = 13 * 10**(2+1) = 13000
term2 = 276
term3 = 124 * 10**2 = 12400
Karatsuba product = 25676
actual product = 21536
This is because if we expand the Gauss trick, you will see ad and bc have different powers of 10 attached to them.
ad + bc = 100*(13*6) + 10*(1*46) = 8260, which, when added to the rest of the terms, gives the correct product.
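For what it's worth, the same example goes through with the Gauss trick if both numbers are split at the same power of 10 (so 16 becomes c = 0, d = 16). A quick sketch: `gauss_product` is a name I made up here, it does only one level of splitting, and it uses plain `*` for the sub-products.

```python
def gauss_product(x, y, mx, my):
    """One level of Karatsuba, splitting x at 10**mx and y at 10**my.

    Illustrative helper (not from the gist). The result is only correct
    when mx == my, because the Gauss trick lumps ad and bc together
    under a single power of 10.
    """
    a, b = divmod(x, 10**mx)
    c, d = divmod(y, 10**my)
    ac, bd = a * c, b * d
    ad_plus_bc = (a + b) * (c + d) - ac - bd
    return ac * 10**(mx + my) + ad_plus_bc * 10**mx + bd

print(gauss_product(1346, 16, 2, 1))  # different splits: reproduces the wrong answer
print(gauss_product(1346, 16, 2, 2))  # same split: matches 1346 * 16
```

So the trick itself is fine; what breaks it is splitting each number at its own midpoint.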
Here is my implementation btw:
from math import ceil

def abcd(num_s, numlen):
    if numlen < 2:
        return '0', num_s
    else:
        splitidx = ceil(numlen/2)
        return num_s[:splitidx], num_s[splitidx:]

def karatsuba(num1, num2):
    num1_s = str(num1)
    num2_s = str(num2)
    lennum1 = len(num1_s)
    lennum2 = len(num2_s)
    if lennum1==1 and lennum2 == 1:
        return int(num1) * int(num2)
    splitidx1 = lennum1//2
    splitidx2 = lennum2//2
    a,b = abcd(num1_s, lennum1)
    c,d = abcd(num2_s, lennum2)
    ac = karatsuba(a,c)
    bd = karatsuba(b,d)
    ad = karatsuba(a,d)
    bc = karatsuba(b,c)
    term1 = int(str(ac) + "0"*(splitidx1 + splitidx2))
    term2 = int(bd)
    term3 = int(str(ad) + "0"*splitidx1)
    term4 = int(str(bc) + "0"*splitidx2)
    final_sol = term1 + term2 + term3 + term4
    return final_sol
I verified it works using the code below:
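import numpy as np
# products stay below 1e18, so they fit in int64 without overflow
rand_ints = np.random.randint(1,int(1e9),(10000,2),dtype='int64')
true_product = rand_ints.prod(axis=1)
karatsuba_product = [karatsuba(val[0],val[1]) for val in rand_ints]
# np.all instead of np.alltrue, which was removed in NumPy 2.0
np.all(true_product == karatsuba_product)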
Correct me if I'm wrong, but using exponents in a multiplication algorithm feels like a cheat. Multiplication is a basic operation; the exponentiation calls will likely invoke a pre-implemented multiplication algorithm. Ideally, one should use string manipulation here.
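One way to avoid `10**k` entirely, along the lines suggested, is to do the shifts and splits on decimal strings. A rough sketch, assuming non-negative integers; `shift10`, `split_digits` and `karatsuba_str` are hypothetical names for this illustration, not taken from any of the implementations above.

```python
def shift10(n, k):
    """Multiply a non-negative int by 10**k by appending k zeros."""
    return int(str(n) + "0" * k)

def split_digits(n, m):
    """Return (high, low) with n == high * 10**m + low, using slicing only (m >= 1)."""
    s = str(n).rjust(m + 1, "0")   # pad so the high part is never empty
    return int(s[:-m]), int(s[-m:])

def karatsuba_str(x, y):
    """Karatsuba on non-negative ints without any ** calls."""
    if x < 10 or y < 10:
        return x * y               # base case: one factor is a single digit
    m = max(len(str(x)), len(str(y))) // 2
    # Split both numbers at the same position so the Gauss trick stays valid.
    a, b = split_digits(x, m)
    c, d = split_digits(y, m)
    ac = karatsuba_str(a, c)
    bd = karatsuba_str(b, d)
    ad_plus_bc = karatsuba_str(a + b, c + d) - ac - bd
    return shift10(ac, 2 * m) + shift10(ad_plus_bc, m) + bd

print(karatsuba_str(1346, 16) == 1346 * 16)
```

The str/int conversions have a cost of their own, so this is more about making the structure explicit than about speed.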