@anirudhjayaraman
Last active September 11, 2023 10:18
Karatsuba Multiplication
def karatsuba(x, y):
    """Function to multiply 2 numbers in a more efficient manner than the grade school algorithm"""
    if len(str(x)) == 1 or len(str(y)) == 1:
        return x * y
    else:
        n = max(len(str(x)), len(str(y)))
        # NOTE: written for Python 2, where / on ints floors; in Python 3 each / below must be //
        nby2 = n / 2
        a = x / 10**(nby2)
        b = x % 10**(nby2)
        c = y / 10**(nby2)
        d = y % 10**(nby2)
        ac = karatsuba(a, c)
        bd = karatsuba(b, d)
        ad_plus_bc = karatsuba(a + b, c + d) - ac - bd
        # this little trick, writing n as 2*nby2, takes care of both even and odd n
        prod = ac * 10**(2*nby2) + (ad_plus_bc * 10**nby2) + bd
        return prod
@rajashit14

Here is the right code:

def karatsuba(x, y):
    if len(str(x)) == 1 or len(str(y)) == 1:
        return x * y
    else:
        n = max(len(str(x)), len(str(y))) // 2
        a = x // 10**n
        b = x % 10**n
        c = y // 10**n
        d = y % 10**n
        ac = karatsuba(a, c)
        bd = karatsuba(b, d)
        ad_plus_bc = karatsuba(a + b, c + d) - ac - bd
        return ac * 10**(2*n) + ad_plus_bc * 10**n + bd

x = int(input())
y = int(input())
print(karatsuba(x, y))

@mkhuzaima

mkhuzaima commented Sep 22, 2021

I think at lines 9 and 11, there should be floor division (//) instead of simple division (/). They should be like:

    a = x // 10**(nby2)
    c = y // 10**(nby2)
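
A minimal sketch of the gist's function with floor division applied throughout, assuming Python 3 and non-negative integer inputs. Besides the two lines quoted above, `nby2` itself also needs `//`, otherwise it becomes a float and the powers of ten stop being integers. The name `karatsuba_py3` is made up here so it does not clash with the versions in this thread.

```python
def karatsuba_py3(x, y):
    """Karatsuba product of two non-negative ints (Python 3 sketch of the gist)."""
    # Base case: a single-digit operand is multiplied directly.
    if len(str(x)) == 1 or len(str(y)) == 1:
        return x * y
    n = max(len(str(x)), len(str(y)))
    nby2 = n // 2                      # floor division keeps the exponent an int
    # Split BOTH numbers at the same power of ten.
    a, b = x // 10**nby2, x % 10**nby2
    c, d = y // 10**nby2, y % 10**nby2
    ac = karatsuba_py3(a, c)
    bd = karatsuba_py3(b, d)
    # Gauss trick: one extra product yields ad + bc.
    ad_plus_bc = karatsuba_py3(a + b, c + d) - ac - bd
    return ac * 10**(2 * nby2) + ad_plus_bc * 10**nby2 + bd


print(karatsuba_py3(1346, 16) == 1346 * 16)
```

Because both operands are cut at the same position `nby2`, the shorter number simply gets a high part of 0 when the lengths differ.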

@rupeshknn

Correct me if I'm wrong, but using exponents in a multiplication algorithm feels like a cheat. Multiplication is a basic operation; the exponentiation calls will likely invoke a pre-implemented multiplication algorithm. Ideally, one should use string manipulation here.
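
As a rough illustration of the string-manipulation idea, here is a sketch of the two operations the algorithm needs from powers of ten: splitting off the last m digits and shifting left by m digits. The helper names `split_low` and `shift` are hypothetical, and the sketch assumes non-negative integers.

```python
def split_low(num, m):
    """Return (high, low) where low holds the last m decimal digits of num."""
    if m == 0:
        return num, 0
    s = str(num)
    high = s[:-m] or "0"   # empty when num has m digits or fewer
    low = s[-m:]
    return int(high), int(low)


def shift(num, m):
    """Multiply num by 10**m by appending m zeros to its decimal string."""
    return int(str(num) + "0" * m)


print(split_low(1346, 2), split_low(16, 2), shift(13, 3))
```

These two helpers could stand in for the `//`, `%` and `10**k` expressions in the integer-based versions, at the cost of converting between `int` and `str` on every call.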

@rupeshknn

rupeshknn commented Sep 11, 2023

Also, I think the Gauss trick mentioned in the video only works when the inputs are the same size. That is why some implementations above fail for longer numbers: somewhere in the recursive calls, ad and bc end up with different powers of 10.

Here's an example where the Gauss trick fails:

x= 1346, y=16
a = 13, b=46
c=1, d=6

ac = 13, bd = 276
(a+b)(c+d) = (59)(7) = 413
(a+b)(c+d) - ac - bd = 124

so the product should be:
term1 =  13*10**(2+1) = 13000
term2 = 276
term3 = 12400
Karatsuba product = 25676
actual product = 21536

This is because, once you expand the Gauss trick, ad and bc have different powers of 10 attached to them: x is split at 10**2 but y is split at 10**1.
The correctly weighted cross terms are 100*ad + 10*bc = 100*(13*6) + 10*(1*46) = 8260, which, when added to term1 and term2, gives the correct product (13000 + 8260 + 276 = 21536).
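
For comparison, here is the same pair worked through with both numbers split at the same power of ten (m = 2, an assumption of this sketch rather than something from the example above). In that case ad and bc carry the same weight and the three-product formula goes through.

```python
x, y = 1346, 16
m = 2                                   # split BOTH numbers at 10**2

a, b = divmod(x, 10**m)                 # a = 13, b = 46
c, d = divmod(y, 10**m)                 # c = 0,  d = 16

ac = a * c                              # 0
bd = b * d                              # 736
cross = (a + b) * (c + d) - ac - bd     # 59*16 - 0 - 736 = 208 = ad + bc

product = ac * 10**(2 * m) + cross * 10**m + bd
print(product, product == x * y)        # 21536 True
```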

Here is my implementation btw:

from math import ceil

def abcd(num_s, numlen):
    """Split a digit string into (high, low); the low part gets numlen//2 digits."""
    if numlen < 2:
        return '0', num_s
    else:
        splitidx = ceil(numlen/2)
        return num_s[:splitidx], num_s[splitidx:]

def karatsuba(num1, num2):
    # Work on digit strings so that shifting is done by appending zeros
    num1_s = str(num1)
    num2_s = str(num2)
    lennum1 = len(num1_s)
    lennum2 = len(num2_s)
    if lennum1 == 1 and lennum2 == 1:
        return int(num1) * int(num2)
    # Number of low-order digits in each number (each is split at its own position)
    splitidx1 = lennum1//2
    splitidx2 = lennum2//2
    a, b = abcd(num1_s, lennum1)
    c, d = abcd(num2_s, lennum2)

    # Four recursive products instead of three: no Gauss trick is used here
    ac = karatsuba(a, c)
    bd = karatsuba(b, d)
    ad = karatsuba(a, d)
    bc = karatsuba(b, c)
    # Appending k zeros multiplies by 10**k
    term1 = int(str(ac) + "0"*(splitidx1 + splitidx2))
    term2 = int(bd)
    term3 = int(str(ad) + "0"*splitidx1)
    term4 = int(str(bc) + "0"*splitidx2)

    final_sol = term1 + term2 + term3 + term4

    return final_sol

I verified it works using the code below:

import numpy as np

rand_ints = np.random.randint(1,int(1e9),(10000,2),dtype='int64')
true_product = rand_ints.prod(axis=1)
karatsuba_product = [karatsuba(val[0],val[1]) for val in rand_ints]
np.all(true_product == karatsuba_product)
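
A similar check can be sketched with plain Python integers, which avoids the int64 range limit and lets the operands have many more digits. The helper name `check_multiplier` and its parameters are made up for this sketch; `operator.mul` is used below only as a stand-in so the block runs on its own, and one would pass a `karatsuba` function from this thread instead.

```python
import operator
import random


def check_multiplier(multiply, trials=1000, max_digits=40, seed=0):
    """Return the first (x, y) where multiply(x, y) != x * y, or None if all agree."""
    rng = random.Random(seed)            # fixed seed keeps the run reproducible
    for _ in range(trials):
        # Draw operands with independently chosen digit counts (may be 0).
        x = rng.randrange(10 ** rng.randint(1, max_digits))
        y = rng.randrange(10 ** rng.randint(1, max_digits))
        if multiply(x, y) != x * y:
            return x, y
    return None


# Stand-in multiplier; replace with the function under test.
print(check_multiplier(operator.mul))
```

Returning the failing pair rather than a bare boolean makes it easier to reproduce a mismatch by hand.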
