# Source: GitHub Gist by @adrianparvino (gist 4d935737d4fc03064f6bf20183ec17b1),
# created May 25, 2024 21:22.
import enum
from amaranth import *
from amaranth.lib import wiring, data
from amaranth.hdl.mem import Memory
from amaranth.lib.wiring import In, Out, Component
@enum.unique
class Operator(enum.Enum):
    """
    Operation selector for the ALU.

    The integer value of each member is the encoding carried on the ALU's
    ``op`` input.
    """
    ADD = 0          # a + b
    SUB = 1          # a - b
    SIGNED_LT = 2    # a < b, both operands interpreted as signed
    UNSIGNED_LT = 3  # a < b, both operands interpreted as unsigned
    XOR = 4          # a ^ b
    CONST_B = 5      # pass b through unchanged
    OR = 6           # a | b
    AND = 7          # a & b
@enum.unique
class BranchType(enum.Enum):
    """
    Branch condition selector.

    The decoder in this file copies ``instruction[12:15]`` (funct3) directly
    into a ``BranchType`` field for conditional branches, so the values of
    the conditional members must stay equal to the RV32I branch funct3
    encodings. The two funct3 values that no branch uses (2 and 3) are
    reused for the unconditional / never-taken cases.
    """
    EQ = 0             # funct3 000: BEQ
    NE = 1             # funct3 001: BNE
    UNCONDITIONAL = 2  # always taken (used for jumps)
    NEVER = 3          # never taken (non-control-flow instructions)
    LT = 4             # funct3 100: BLT
    GE = 5             # funct3 101: BGE
    LTU = 6            # funct3 110: BLTU
    GEU = 7            # funct3 111: BGEU
class UOp(data.Struct):
    """
    UOp

    Decoded control bits handed from the decode stage to the execution unit.
    """
    # Operation the ALU performs.
    alu_op: Operator
    # Condition under which the instruction redirects the program counter.
    branch_type: BranchType
    # NOTE(review): presumably "write the result back to rd"; this bit is
    # neither assigned nor read anywhere in this file -- confirm intent.
    wb_rd: 1
    # 1: ALU input a is the execute-stage PC; 0: it is the rs1 register value.
    a_as_pc: 1
    # 1: ALU input b is the decoded immediate; 0: it is the rs2 register value.
    b_as_imm: 1
class DEResult(data.Struct):
    """
    DEResult

    Output of the decode stage: register indices, immediate and micro-op.
    """
    # Index of the first source register (used as a register-file read address).
    rs1: 5
    # Index of the second source register (used as a register-file read address).
    rs2: 5
    # Index of the destination register (used as the register-file write address).
    rd: 5
    # Immediate produced by the immediate decoder.
    imm: 32
    # Control bits for the execution unit.
    uop: UOp
class EXMResult(data.Struct):
    """
    EXMResult

    Registered output of the execution unit.
    """
    # Registered ALU output; also used as the new PC when a branch is taken.
    alu_out: 32
    # High when the front end and the execute-stage PC must be redirected to
    # alu_out.
    branch_taken: 1
    # NOTE(review): not assigned or read anywhere in this file -- confirm intent.
    stall: 1
class ALU(Component):
    """
    ALU

    Purely combinational: ``out`` is a function of ``a``, ``b`` and ``op``.
    Comparison operations produce 0 or 1.

    Attributes
    ----------
    a: Signal, in
        The first input.
    b: Signal, in
        The second input.
    op: Operator
        The operation to perform.
    out: Signal, out
        The result of the operation.
    """
    a: In(32)
    b: In(32)
    op: In(Operator)
    out: Out(32)

    def __init__(self):
        super().__init__()

    def elaborate(self, platform):
        m = Module()
        a, b = self.a, self.b

        # One result expression per operation. Every Operator member has an
        # entry, so the Switch below needs no default case.
        results = {
            Operator.ADD: a + b,
            Operator.SUB: a - b,
            Operator.SIGNED_LT: a.as_signed() < b.as_signed(),
            Operator.UNSIGNED_LT: a < b,
            Operator.XOR: a ^ b,
            Operator.CONST_B: b,
            Operator.OR: a | b,
            Operator.AND: a & b,
        }

        with m.Switch(self.op):
            for operation, value in results.items():
                with m.Case(operation):
                    m.d.comb += self.out.eq(value)

        return m
class Shifter(Component):
    """
    Shifter

    Combinational right shifter: arithmetic (sign-filling) when ``signed`` is
    high, logical (zero-filling) otherwise. Left shifts are not implemented.

    Attributes
    ----------
    x: Signal, in
        The value to shift.
    y: Signal, in
        The shift amount.
    signed: Signal, in
        Sign extension
    out: Signal, out
        The shifted value.
    """
    x: In(32)
    y: In(32)
    signed: In(1)
    out: Out(32)

    def __init__(self):
        super().__init__()

    def elaborate(self, platform):
        m = Module()

        # NOTE(review): all 32 bits of y are used as the shift amount. RV32I
        # shifts use only the low 5 bits -- confirm whether the caller is
        # expected to mask y before driving it.
        arithmetic = self.x.as_signed() >> self.y
        logical = self.x >> self.y
        m.d.comb += self.out.eq(Mux(self.signed, arithmetic, logical))

        return m
class Immdecoder(Component):
    """
    Decodes immediate values.

    The instruction format is inferred from a handful of opcode bits and the
    corresponding instruction bits are routed to ``imm``. An instruction that
    is not recognised as S-, B-, U- or J-format is decoded as I-format.

    Attributes
    ----------
    instruction : Signal, in
        The instruction to decode.
    imm : Signal, out
        The immediate value.
    """
    instruction: In(32)
    imm: Out(32)

    def __init__(self):
        super().__init__()

    def elaborate(self, platform):
        m = Module()
        insn = self.instruction

        # Format flags derived from the opcode field (instruction[0:7]).
        j = insn[3]
        s = insn[3:7] == 0b0100
        # The parentheses are required: in Python `&` binds tighter than
        # `==`, so `insn[6] & insn[2:5] == 0b000` would evaluate
        # `(insn[6] & insn[2:5]) == 0` and flag nearly every instruction as
        # B-format.
        b = insn[6] & (insn[2:5] == 0b000)
        u = insn[4] & insn[2]

        # imm[31]: always the instruction's top bit.
        m.d.comb += self.imm[31].eq(insn[31])

        # imm[20:31]: U-format takes the bits verbatim, everything else
        # sign-extends from instruction[31].
        with m.Switch(u):
            with m.Case(0):
                m.d.comb += self.imm[20:31].eq(insn[31].as_signed())
            with m.Case(1):
                m.d.comb += self.imm[20:31].eq(insn[20:31])  # U

        # imm[12:20]: U and J carry these bits in place, the rest sign-extend.
        with m.Switch(u | j):
            with m.Case(0):
                m.d.comb += self.imm[12:20].eq(insn[31].as_signed())
            with m.Case(1):
                m.d.comb += self.imm[12:20].eq(insn[12:20])  # U | J

        # imm[11]. Selector is {b | u, b | j} (Cat lists the low bit first).
        with m.Switch(Cat(b | j, b | u)):
            with m.Case(0):
                m.d.comb += self.imm[11].eq(insn[31])  # S | I
            with m.Case(1):
                m.d.comb += self.imm[11].eq(insn[20])  # J
            with m.Case(2):
                m.d.comb += self.imm[11].eq(0)  # U
            with m.Case(3):
                m.d.comb += self.imm[11].eq(insn[7])  # B

        # imm[5:11]: zero for U, instruction[25:31] for every other format.
        with m.Switch(u):
            with m.Case(0):
                m.d.comb += self.imm[5:11].eq(insn[25:31])
            with m.Case(1):
                m.d.comb += self.imm[5:11].eq(0)  # U

        # imm[1:5]. Selector is {s | b | u, u | j}.
        with m.Switch(Cat(u | j, s | b | u)):
            with m.Case(0, 1):
                m.d.comb += self.imm[1:5].eq(insn[21:25])  # I, J
            with m.Case(2):
                m.d.comb += self.imm[1:5].eq(insn[8:12])  # S | B
            with m.Case(3):
                m.d.comb += self.imm[1:5].eq(0)  # U

        # imm[0]. Selector is {s, b | u | j}.
        with m.Switch(Cat(b | u | j, s)):
            with m.Case(0):
                m.d.comb += self.imm[0].eq(insn[20])  # I
            with m.Case(1):
                m.d.comb += self.imm[0].eq(0)  # B | U | J
            with m.Case(2):
                m.d.comb += self.imm[0].eq(insn[7])  # S
            with m.Case(3):
                m.d.comb += self.imm[0].eq(0)  # UNREACHABLE

        return m
class Furv(Component):
    """
    A RISC-V CPU

    The core is organised as a front end (instruction fetch and decode) that
    fills ``de_result``, and a back end (register file, ALU, branch unit)
    that consumes it one cycle later.

    NOTE(review): in this revision ``data_out``, ``addr``, ``sel``, ``mem``
    and ``mem_write`` are never driven and ``data_in`` / ``ack`` are never
    read; the register-file write port is never enabled and the shifter
    output is unused. Loads, stores, shifts and register write-back are
    therefore not functional yet.

    Parameters
    ----------
    start_pc : int
        Reset value of both the fetch PC and the execute-stage PC.

    Attributes
    ----------
    instruction : Signal, in
        The instruction to execute.
    pc : Signal, out
        The program counter.
    (Wishbone Interface)
    data_in : Signal, in
        The data read from memory.
    data_out : Signal, out
        The data written to memory.
    addr : Signal, out
        The word address of the memory access.
    sel : Signal, out
        The byte select for the memory access.
    mem : Signal, out
        The memory access is valid.
    mem_write : Signal, out
        The memory access is a write.
    ack : Signal, in
        The memory access is acknowledged.
    reset : Signal, in
        Drives the reset signal of the default clock domain.
    """
    instruction: In(32)
    pc: Out(32)
    data_in: In(32)
    data_out: Out(32)
    addr: Out(32)
    sel: Out(4)
    mem: Out(1)
    mem_write: Out(1)
    ack: In(1)
    reset: In(1)

    def __init__(self, start_pc):
        ## Frontend: Bus Interface Unit
        # Instruction fetch
        self.biu_pc = Signal(32, reset=start_pc)
        # One-bit counter; the fetched instruction is treated as valid once
        # it reads 1.
        self.if_cycles = Signal(1)

        # Instruction decode. The decode register resets to "never branch"
        # so the back end does not redirect the PC before the first real
        # instruction arrives.
        self.de_result = Signal(
            DEResult, reset={"uop": {"branch_type": BranchType.NEVER}}
        )
        self.de_valid = Signal()
        self.de_stall = Signal()

        ## Backend: Execution Unit
        self.exm_pc = Signal(32, reset=start_pc)
        self.exm_result = Signal(EXMResult)
        self.exm_stall = Signal(1)

        super().__init__()

    def elaborate(self, platform):
        m = Module()
        m.d.comb += ResetSignal().eq(self.reset)

        # RV32I major opcodes as they appear in instruction[2:7].
        LOAD = 0b00000
        OP_IMM = 0b00100
        AUIPC = 0b00101
        STORE = 0b01000
        OP = 0b01100
        LUI = 0b01101
        BRANCH = 0b11000
        JALR = 0b11001
        JAL = 0b11011

        insn = self.instruction
        opcode = insn[2:7]
        funct3 = insn[12:15]
        de = self.de_result
        uop = de.uop
        exm = self.exm_result

        # NOTE(review): de_stall and exm_stall are not driven anywhere in
        # this class, so the blocks guarded by them currently always run.
        # The original nesting under those guards could not be recovered
        # from the source; the choices below do not change behaviour while
        # both stay 0.

        ## Frontend: Bus Interface Unit
        # Instruction Fetch
        if_valid = self.if_cycles == 1
        m.d.comb += self.pc.eq(self.biu_pc & ~0x3)
        with m.If(~self.de_stall):
            m.d.sync += self.biu_pc.eq(self.biu_pc + 4)
        with m.If(~if_valid):
            m.d.sync += self.if_cycles.eq(self.if_cycles + 1)
        # A taken branch overrides the sequential update above (the last
        # assignment wins) and restarts the fetch-valid counter.
        with m.If(exm.branch_taken):
            m.d.sync += [
                self.if_cycles.eq(0),
                self.biu_pc.eq(exm.alu_out),
            ]

        # Instruction Decode
        # An instruction is only valid if it has been fetched and is not
        # being squashed by a branch that was taken this cycle.
        de_valid = if_valid & ~exm.branch_taken

        m.submodules.immdecoder = immdecoder = Immdecoder()
        m.d.comb += immdecoder.instruction.eq(insn)
        m.d.sync += de.imm.eq(immdecoder.imm)

        with m.If(~self.exm_stall):
            m.d.sync += [
                de.rs1.eq(insn[15:20]),
                # rs2 occupies instruction bits 24..20; the slice must be
                # five bits wide.
                de.rs2.eq(insn[20:25]),
                de.rd.eq(insn[7:12]),
                self.de_valid.eq(de_valid),
            ]

            # Formats without an rs1 / rs2 field read register 0 instead.
            with m.Switch(opcode):
                with m.Case(JAL, LUI, AUIPC):
                    m.d.sync += de.rs1.eq(0)
            with m.Switch(opcode):
                with m.Case(OP_IMM, LOAD, JAL, JALR, LUI, AUIPC):
                    m.d.sync += de.rs2.eq(0)

            # ALU operand selection.
            # NOTE(review): AUIPC is listed with a = rs1 (forced to x0) and
            # JALR with a = PC. The RISC-V spec defines AUIPC as pc + imm
            # and the JALR target as rs1 + imm -- confirm whether these two
            # entries are intentional.
            with m.Switch(opcode):
                with m.Case(OP, OP_IMM, LOAD, STORE, LUI, AUIPC):
                    m.d.sync += uop.a_as_pc.eq(0)
                with m.Case(BRANCH, JAL, JALR):
                    m.d.sync += uop.a_as_pc.eq(1)
            with m.Switch(opcode):
                with m.Case(OP):
                    m.d.sync += uop.b_as_imm.eq(0)
                with m.Case(OP_IMM, LOAD, STORE, BRANCH, JAL, JALR, LUI, AUIPC):
                    m.d.sync += uop.b_as_imm.eq(1)

            # ALU operation. funct3 values shared by OP and OP-IMM:
            funct3_ops = {
                2: Operator.SIGNED_LT,
                3: Operator.UNSIGNED_LT,
                4: Operator.XOR,
                6: Operator.OR,
                7: Operator.AND,
            }
            with m.Switch(opcode):
                with m.Case(OP):
                    with m.Switch(funct3):
                        with m.Case(0):
                            # funct7 == 0 selects ADD, anything else SUB.
                            m.d.sync += uop.alu_op.eq(
                                Mux(insn[25:32] == 0, Operator.ADD, Operator.SUB)
                            )
                        for code, operation in funct3_ops.items():
                            with m.Case(code):
                                m.d.sync += uop.alu_op.eq(operation)
                        with m.Default():
                            m.d.sync += uop.alu_op.eq(Operator.ADD)
                with m.Case(OP_IMM):
                    with m.Switch(funct3):
                        with m.Case(0):
                            m.d.sync += uop.alu_op.eq(Operator.ADD)
                        for code, operation in funct3_ops.items():
                            with m.Case(code):
                                m.d.sync += uop.alu_op.eq(operation)
                        with m.Default():
                            m.d.sync += uop.alu_op.eq(Operator.ADD)
                with m.Default():
                    # Every other opcode uses the ALU to add a base and an
                    # immediate. Without this default alu_op would keep
                    # whatever the previous OP / OP-IMM instruction left in
                    # it.
                    m.d.sync += uop.alu_op.eq(Operator.ADD)

            # Branch type. For conditional branches funct3 is stored as-is;
            # BranchType's values match the funct3 encoding.
            with m.Switch(opcode):
                with m.Case(JAL, JALR):
                    m.d.sync += uop.branch_type.eq(BranchType.UNCONDITIONAL)
                with m.Case(BRANCH):
                    m.d.sync += uop.branch_type.eq(funct3)
                with m.Default():
                    m.d.sync += uop.branch_type.eq(BranchType.NEVER)

        ## Backend: Execution Unit
        # Register File
        m.submodules.rf = rf = Memory(width=32, depth=32)
        rf_rs1 = rf.read_port(domain="comb")
        rf_rs2 = rf.read_port(domain="comb")
        rf_wb = rf.write_port()
        # NOTE(review): rf_wb.data and rf_wb.en are not driven, so no
        # register is ever written in this revision.
        m.d.comb += [
            rf_rs1.addr.eq(de.rs1),
            rf_rs2.addr.eq(de.rs2),
            rf_wb.addr.eq(de.rd),
        ]

        rs1_value = rf_rs1.data
        rs2_value = rf_rs2.data
        input_a = Mux(uop.a_as_pc, self.exm_pc, rs1_value)
        input_b = Mux(uop.b_as_imm, de.imm, rs2_value)

        # ALU
        m.submodules.alu = alu = ALU()
        m.d.comb += [
            alu.a.eq(input_a),
            alu.b.eq(input_b),
            alu.op.eq(uop.alu_op),
        ]
        m.d.sync += exm.alu_out.eq(alu.out)

        # Shifter
        # NOTE(review): shifter.signed is not driven and shifter.out is not
        # read anywhere in this class.
        m.submodules.shifter = shifter = Shifter()
        m.d.comb += [
            shifter.x.eq(input_a),
            shifter.y.eq(input_b),
        ]

        # Branch and PC
        # Branch conditions compare the two source registers. For branches
        # the ALU itself is busy computing the target (PC + imm), which is
        # registered into alu_out in the same cycle as branch_taken.
        branch_conditions = {
            BranchType.EQ: rs1_value == rs2_value,
            BranchType.NE: rs1_value != rs2_value,
            BranchType.UNCONDITIONAL: 1,
            BranchType.NEVER: 0,
            BranchType.LT: rs1_value.as_signed() < rs2_value.as_signed(),
            BranchType.GE: rs1_value.as_signed() >= rs2_value.as_signed(),
            BranchType.LTU: rs1_value < rs2_value,
            BranchType.GEU: rs1_value >= rs2_value,
        }
        # NOTE(review): the branch evaluation is placed under de_valid so
        # that a squashed instruction left in de_result cannot redirect the
        # PC; the original nesting could not be recovered -- confirm.
        with m.If(self.de_valid):
            m.d.sync += self.exm_pc.eq(self.exm_pc + 4)
            with m.Switch(uop.branch_type):
                for branch_type, condition in branch_conditions.items():
                    with m.Case(branch_type):
                        m.d.sync += exm.branch_taken.eq(condition)
                with m.Default():
                    m.d.sync += exm.branch_taken.eq(0)

        # Redirect the execute-stage PC. This block comes last so that its
        # assignments override the ones above.
        with m.If(exm.branch_taken):
            m.d.sync += [
                self.exm_pc.eq(exm.alu_out),
                # CORRECTNESS: two taken branches cannot happen consecutively
                exm.branch_taken.eq(0),
            ]

        return m
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment