Skip to content

Instantly share code, notes, and snippets.

@dragonmux

dragonmux/mul.py Secret

Created May 26, 2022 14:07
Show Gist options
  • Save dragonmux/c2c19abc1aa36d3b4227d1279c7d2ade to your computer and use it in GitHub Desktop.
Save dragonmux/c2c19abc1aa36d3b4227d1279c7d2ade to your computer and use it in GitHub Desktop.
Karatsuba method multiplication
from amaranth import *
class Mul(Elaboratable):
    """Sequential 16x16-bit unsigned multiplier with a saturating 16-bit result.

    Each operand is split into 8-bit halves (the same split the Karatsuba
    algorithm starts from), but all four 8x8 partial products are computed,
    one per FSM state, schoolbook style.  Any product that does not fit in
    16 bits is reported as 0xFFFF.

    Handshake:
      * assert ``start`` while the FSM is in IDLE to begin a multiplication;
      * ``done`` goes high once ``result`` holds the answer and stays high
        until ``ack`` is asserted, after which the FSM returns to IDLE.

    ``a`` and ``b`` are read directly in several FSM states rather than
    latched, so they must be held stable from ``start`` until ``done``.

    Attributes:
        a, b: 16-bit unsigned operands (inputs).
        result: 16-bit saturated product (output, valid while ``done`` is high).
        start: begin a multiplication (input).
        done: result-valid flag (output).
        ack: acknowledge ``done`` and return to IDLE (input).
    """

    def __init__(self):
        self.a = Signal(16)
        self.b = Signal(16)
        self.result = Signal(16)
        self.start = Signal()
        self.done = Signal()
        self.ack = Signal()

    def elaborate(self, platform):
        m = Module()
        a, b, result = self.a, self.b, self.result

        # With a = aH:aL and b = bH:bL (8-bit halves):
        #   a * b = (aH*bH << 16) + ((aL*bH + aH*bL) << 8) + aL*bL
        a_lo, a_hi = a[0:8], a[8:16]
        b_lo, b_hi = b[0:8], b[8:16]

        # Most recent 8x8 partial product (always fits in 16 bits).
        intermediate = Signal.like(result)
        # Running sum of the two cross products; bit 8 is the carry out of
        # adding their low bytes.
        partial = Signal.like(result)

        # Final sum of low x low and the shifted cross-product sum.  Adding two
        # 16-bit unsigned values yields a 17-bit value in Amaranth, so bit 16
        # is the carry that means the product does not fit in 16 bits.
        total = intermediate + Cat(Const(0, 8), partial[0:8])

        def saturate():
            """Report overflow: clamp the result to all-ones and finish."""
            m.d.sync += [
                result.eq(0xFFFF),
                self.done.eq(1),
            ]
            m.next = 'DONE'

        with m.FSM():
            with m.State('IDLE'):
                with m.If(self.start):
                    m.next = 'CALC_UPPER'

            # High x high lands entirely in bits [16:32] of the 32-bit product.
            with m.State('CALC_UPPER'):
                m.d.sync += intermediate.eq(a_hi * b_hi)
                m.next = 'CHECK_HIGH_OVERFLOW'

            with m.State('CHECK_HIGH_OVERFLOW'):
                # Any set bit here is an immediate overflow.
                with m.If(intermediate.any()):
                    saturate()
                # Otherwise compute the first cross product (bits [8:24]).
                with m.Else():
                    m.d.sync += intermediate.eq(a_lo * b_hi)
                    m.next = 'CHECK_B_OVERFLOW'

            with m.State('CHECK_B_OVERFLOW'):
                # A non-zero upper byte would land in bits [16:24]: overflow.
                with m.If(intermediate[8:16].any()):
                    saturate()
                # Keep the low byte and compute the other cross product.
                with m.Else():
                    m.d.sync += [
                        partial.eq(intermediate[0:8]),
                        intermediate.eq(a_hi * b_lo),
                    ]
                    m.next = 'CHECK_A_OVERFLOW'

            with m.State('CHECK_A_OVERFLOW'):
                # Same overflow test for the second cross product.
                with m.If(intermediate[8:16].any()):
                    saturate()
                # Accumulate its low byte and compute low x low (bits [0:16]).
                with m.Else():
                    m.d.sync += [
                        partial.eq(partial + intermediate[0:8]),
                        intermediate.eq(a_lo * b_lo),
                    ]
                    m.next = 'COMBINE'

            with m.State('COMBINE'):
                # Two independent ways to overflow here:
                #  * the cross-product sum carried into bit 8 (bit 16 once shifted);
                #  * adding the shifted sum to low x low carries out of 16 bits,
                #    e.g. 0x01FF * 0x00FF: 0xFE01 + 0xFF00 = 0x1FD01.
                with m.If(partial[8] | total[16]):
                    saturate()
                with m.Else():
                    m.d.sync += [
                        result.eq(total[0:16]),
                        self.done.eq(1),
                    ]
                    m.next = 'DONE'

            # Hold the result and `done` until the consumer acknowledges.
            with m.State('DONE'):
                with m.If(self.ack):
                    m.d.sync += self.done.eq(0)
                    m.next = 'IDLE'

        return m
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment