-
-
Save adrianparvino/4d935737d4fc03064f6bf20183ec17b1 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import enum | |
from amaranth import * | |
from amaranth.lib import wiring, data | |
from amaranth.hdl.mem import Memory | |
from amaranth.lib.wiring import In, Out, Component | |
class Operator(enum.Enum):
    """ALU operation selector; each member picks one result in ``ALU``."""

    ADD = 0b000
    SUB = 0b001
    SIGNED_LT = 0b010
    UNSIGNED_LT = 0b011
    XOR = 0b100
    CONST_B = 0b101  # pass operand B straight through
    OR = 0b110
    AND = 0b111
class BranchType(enum.Enum):
    """Branch condition carried in a decoded micro-op.

    The conditional members use the same encodings as the RISC-V B-type
    funct3 field, so decode can copy funct3 directly; the two encodings that
    funct3 leaves unused (2 and 3) mark unconditional jumps and non-branches.
    """

    EQ = 0b000
    NE = 0b001
    UNCONDITIONAL = 0b010
    NEVER = 0b011
    LT = 0b100
    GE = 0b101
    LTU = 0b110
    GEU = 0b111
class UOp(data.Struct):
    """
    Control fields of one decoded micro-operation.

    Fields are listed in layout order:

    alu_op : Operator
        Operation the ALU performs.
    branch_type : BranchType
        Condition under which the instruction redirects the PC.
    wb_rd : 1 bit
        Presumably the register write-back enable; not consumed by the
        visible core yet (TODO confirm).
    a_as_pc : 1 bit
        When set, ALU operand A is the PC instead of register rs1.
    b_as_imm : 1 bit
        When set, ALU operand B is the immediate instead of register rs2.
    """
    alu_op: Operator
    branch_type: BranchType
    wb_rd: unsigned(1)
    a_as_pc: unsigned(1)
    b_as_imm: unsigned(1)
class DEResult(data.Struct):
    """
    Registered output of the decode stage, consumed by the execution unit.

    Fields are listed in layout order:

    rs1, rs2 : 5 bits each
        Source register indices (forced to 0 for formats without them).
    rd : 5 bits
        Destination register index.
    imm : 32 bits
        Sign-extended immediate produced by ``Immdecoder``.
    uop : UOp
        Micro-operation control fields.
    """
    rs1: unsigned(5)
    rs2: unsigned(5)
    rd: unsigned(5)
    imm: unsigned(32)
    uop: UOp
class EXMResult(data.Struct):
    """
    Registered output of the execution stage.

    Fields are listed in layout order:

    alu_out : 32 bits
        ALU result; also used as the redirect target for taken branches.
    branch_taken : 1 bit
        Set when the executed instruction redirects the PC.
    stall : 1 bit
        Not read or written by the visible core yet.
    """
    alu_out: unsigned(32)
    branch_taken: unsigned(1)
    stall: unsigned(1)
class ALU(Component):
    """
    Combinational 32-bit arithmetic/logic unit.

    Attributes
    ----------
    a: Signal, in
        Left-hand operand.
    b: Signal, in
        Right-hand operand.
    op: Operator, in
        Selects which result is driven onto ``out``.
    out: Signal, out
        The selected result, truncated to 32 bits; the comparison
        operators yield 0 or 1.
    """
    a: In(32)
    b: In(32)
    op: In(Operator)
    out: Out(32)

    def __init__(self):
        super().__init__()

    def elaborate(self, platform):
        m = Module()
        lhs, rhs = self.a, self.b

        # One candidate expression per operator; the switch picks exactly one.
        candidates = {
            Operator.ADD: lhs + rhs,
            Operator.SUB: lhs - rhs,
            Operator.SIGNED_LT: lhs.as_signed() < rhs.as_signed(),
            Operator.UNSIGNED_LT: lhs < rhs,
            Operator.XOR: lhs ^ rhs,
            Operator.CONST_B: rhs,
            Operator.OR: lhs | rhs,
            Operator.AND: lhs & rhs,
        }

        with m.Switch(self.op):
            for operator, result in candidates.items():
                with m.Case(operator):
                    m.d.comb += self.out.eq(result)
        return m
class Shifter(Component):
    """
    Combinational 32-bit right shifter (logical or arithmetic).

    Attributes
    ----------
    x: Signal, in
        The value to shift.
    y: Signal, in
        The shift amount. Only the low 5 bits (``y[0:5]``) are used, matching
        RV32I, where register shifts use rs2[4:0] and immediate shifts use
        shamt = imm[4:0].
    signed: Signal, in
        When set, shift arithmetically (sign-extend); otherwise logically.
    out: Signal, out
        The shifted value.
    """
    x: In(32)
    y: In(32)
    signed: In(1)
    out: Out(32)

    def __init__(self):
        super().__init__()

    def elaborate(self, platform):
        m = Module()
        # Mask the shift amount to 5 bits. Using all of `y` is wrong for RV32I:
        # an SRAI immediate carries bit 10 set to select the arithmetic form,
        # and a register operand may hold any value >= 32; both would
        # otherwise shift every bit out instead of shifting by y mod 32.
        shamt = self.y[:5]
        m.d.comb += self.out.eq(
            Mux(self.signed,
                self.x.as_signed() >> shamt,
                self.x >> shamt)
        )
        return m
class Immdecoder(Component):
    """
    Decodes immediate values.

    Classifies the instruction's format from its opcode bits and assembles
    the 32-bit immediate field by field for the I, S, B, U and J formats
    (anything not flagged as S, B, U or J is decoded as I-type).

    Attributes
    ----------
    instruction : Signal, in
        The instruction to decode.
    imm : Signal, out
        The immediate value.
    """
    instruction: In(32)
    imm: Out(32)

    def __init__(self):
        super().__init__()

    def elaborate(self, platform):
        m = Module()
        inst = self.instruction

        # Format flags from the opcode bits.
        j = inst[3]                      # set for JAL (1101111)
        s = inst[3:7] == 0b0100          # STORE (0100011)
        # BRANCH (1100011). The comparison must be parenthesized: Python's `&`
        # binds tighter than `==`, so `inst[6] & inst[2:5] == 0b000` parses as
        # `(inst[6] & inst[2:5]) == 0`, which is true for any opcode with
        # bit 6 or bit 2 clear and so flags OP/OP-IMM/LOAD/STORE/LUI as B.
        b = inst[6] & (inst[2:5] == 0b000)
        u = inst[4] & inst[2]            # LUI (0110111) / AUIPC (0010111)

        # imm[31]: always the instruction's sign bit.
        m.d.comb += self.imm[31].eq(inst[31])

        # imm[30:20]: U-type copies it, everything else sign-extends.
        with m.Switch(u):
            with m.Case(0):
                m.d.comb += self.imm[20:31].eq(inst[31].as_signed())
            with m.Case(1):
                m.d.comb += self.imm[20:31].eq(inst[20:31])  # U

        # imm[19:12]: U- and J-type copy it, everything else sign-extends.
        with m.Switch(u | j):
            with m.Case(0):
                m.d.comb += self.imm[12:20].eq(inst[31].as_signed())
            with m.Case(1):
                m.d.comb += self.imm[12:20].eq(inst[12:20])  # U | J

        # imm[11]: a different source per format.
        with m.Switch(Cat(b | j, b | u)):
            with m.Case(0):
                m.d.comb += self.imm[11].eq(inst[31])  # S | I
            with m.Case(1):
                m.d.comb += self.imm[11].eq(inst[20])  # J
            with m.Case(2):
                m.d.comb += self.imm[11].eq(0)  # U
            with m.Case(3):
                m.d.comb += self.imm[11].eq(inst[7])  # B

        # imm[10:5]: inst[30:25] for every format except U.
        with m.Switch(u):
            with m.Case(0):
                m.d.comb += self.imm[5:11].eq(inst[25:31])
            with m.Case(1):
                m.d.comb += self.imm[5:11].eq(0)  # U

        # imm[4:1]
        with m.Switch(Cat(u | j, s | b | u)):
            with m.Case(0):
                m.d.comb += self.imm[1:5].eq(inst[21:25])  # I
            with m.Case(1):
                m.d.comb += self.imm[1:5].eq(inst[21:25])  # J
            with m.Case(2):
                m.d.comb += self.imm[1:5].eq(inst[8:12])  # S | B
            with m.Case(3):
                m.d.comb += self.imm[1:5].eq(0)  # U

        # imm[0]
        with m.Switch(Cat(b | u | j, s)):
            with m.Case(0):
                m.d.comb += self.imm[0].eq(inst[20])  # I
            with m.Case(1):
                m.d.comb += self.imm[0].eq(0)  # B | U | J
            with m.Case(2):
                m.d.comb += self.imm[0].eq(inst[7])  # S
            with m.Case(3):
                m.d.comb += self.imm[0].eq(0)  # UNREACHABLE

        return m
class Furv(Component):
    """
    A RISC-V CPU

    The frontend (bus interface unit) drives ``pc`` and registers the
    returned ``instruction`` into a decode result; the execution unit reads
    the register file, runs the ALU and resolves branches, redirecting both
    fetch and its own PC to the ALU result when a branch is taken.

    Attributes
    ----------
    instruction : Signal, in
        The instruction to execute.
    pc : Signal, out
        The program counter.

    (Wishbone Interface)
    data_in : Signal, in
        The data read from memory.
    data_out : Signal, out
        The data written to memory.
    addr : Signal, out
        The word address of the memory access.
    sel : Signal, out
        The byte select for the memory access.
    mem : Signal, out
        The memory access is valid.
    mem_write : Signal, out
        The memory access is a write.
    ack : Signal, in
        The memory access is acknowledged.

    reset : Signal, in
        Drives ``ResetSignal()`` (the default domain's reset).

    NOTE(review): the data-bus ports are not yet driven by ``elaborate``
    and register write-back is not connected.
    """
    instruction: In(32)
    pc: Out(32)
    data_in: In(32)
    data_out: Out(32)
    addr: Out(32)
    sel: Out(4)
    mem: Out(1)
    mem_write: Out(1)
    ack: In(1)
    reset: In(1)

    def __init__(self, start_pc):
        ## Frontend: Bus Interface Unit
        # Instruction fetch
        self.biu_pc = Signal(32, reset=start_pc)
        # Counts up to 1 after reset / a taken branch; fetch is valid at 1.
        self.if_cycles = Signal(1)
        # Instruction decode
        self.de_result = Signal(DEResult, reset={"uop": {"branch_type": BranchType.NEVER}})
        self.de_valid = Signal()
        self.de_stall = Signal()  # not driven yet: always 0

        ## Backend: Execution Unit
        self.exm_pc = Signal(32, reset=start_pc)
        self.exm_result = Signal(EXMResult)
        self.exm_stall = Signal(1)  # not driven yet: always 0

        super().__init__()

    def elaborate(self, platform):
        m = Module()

        m.d.comb += ResetSignal().eq(self.reset)

        # Opcode bits [6:2] of the RV32I instructions this core decodes.
        OP, OP_IMM, LOAD, STORE = 0b01100, 0b00100, 0b00000, 0b01000
        LUI, AUIPC = 0b01101, 0b00101
        BRANCH, JAL, JALR = 0b11000, 0b11011, 0b11001

        # funct3 -> ALU operation shared by OP and OP-IMM (funct3 0 is
        # ADD/SUB and handled separately; unlisted funct3 values fall back
        # to ADD). TODO: shifts (funct3 1 and 5) are not routed to Shifter.
        funct3_ops = {
            2: Operator.SIGNED_LT,
            3: Operator.UNSIGNED_LT,
            4: Operator.XOR,
            6: Operator.OR,
            7: Operator.AND,
        }

        if_valid = self.if_cycles == 1

        ## Frontend: Bus Interface Unit
        # Instruction Fetch
        m.d.comb += self.pc.eq(self.biu_pc & ~0x3)
        with m.If(~self.de_stall):
            m.d.sync += self.biu_pc.eq(self.biu_pc + 4)
            with m.If(~if_valid):
                m.d.sync += self.if_cycles.eq(self.if_cycles + 1)
        with m.If(self.exm_result.branch_taken):
            # Redirect fetch to the branch target and restart the fetch counter.
            m.d.sync += [
                self.if_cycles.eq(0),
                self.biu_pc.eq(self.exm_result.alu_out),
            ]

        # Instruction Decode
        # A fetched word is squashed while a taken branch is redirecting fetch.
        de_valid = if_valid & ~self.exm_result.branch_taken

        m.submodules.immdecoder = immdecoder = Immdecoder()
        m.d.comb += immdecoder.instruction.eq(self.instruction)

        opcode = self.instruction[2:7]
        funct3 = self.instruction[12:15]
        funct7 = self.instruction[25:32]

        with m.If(~self.exm_stall):
            m.d.sync += [
                self.de_result.imm.eq(immdecoder.imm),
                self.de_result.rs1.eq(self.instruction[15:20]),
                # rs2 is the 5-bit field at instruction bits [24:20].
                self.de_result.rs2.eq(self.instruction[20:25]),
                self.de_result.rd.eq(self.instruction[7:12]),
                self.de_valid.eq(de_valid),
            ]

            # Formats without an rs1 field read register 0 instead.
            with m.Switch(opcode):
                with m.Case(JAL, LUI, AUIPC):
                    m.d.sync += self.de_result.rs1.eq(0)

            # Formats without an rs2 field read register 0 instead.
            with m.Switch(opcode):
                with m.Case(OP_IMM, LOAD, JAL, JALR, LUI, AUIPC):
                    m.d.sync += self.de_result.rs2.eq(0)

            # ALU operand A: rs1, or the PC for PC-relative computations
            # (branch target, JAL target, AUIPC). JALR's target is rs1 + imm.
            with m.Switch(opcode):
                with m.Case(OP, OP_IMM, LOAD, STORE, LUI, JALR):
                    m.d.sync += self.de_result.uop.a_as_pc.eq(0)
                with m.Case(BRANCH, JAL, AUIPC):
                    m.d.sync += self.de_result.uop.a_as_pc.eq(1)

            # ALU operand B: rs2 for register-register ops, else the immediate.
            with m.Switch(opcode):
                with m.Case(OP):
                    m.d.sync += self.de_result.uop.b_as_imm.eq(0)
                with m.Case(OP_IMM, LOAD, STORE, BRANCH, JAL, JALR, LUI, AUIPC):
                    m.d.sync += self.de_result.uop.b_as_imm.eq(1)

            # ALU operation. Default to ADD so address, target, LUI and AUIPC
            # computations never inherit the previous instruction's op.
            m.d.sync += self.de_result.uop.alu_op.eq(Operator.ADD)
            with m.Switch(opcode):
                with m.Case(OP, OP_IMM):
                    with m.Switch(funct3):
                        with m.Case(0):
                            # OP with non-zero funct7 is SUB; ADDI is always ADD.
                            with m.If((opcode == OP) & (funct7 != 0)):
                                m.d.sync += self.de_result.uop.alu_op.eq(Operator.SUB)
                        for f3, alu_op in funct3_ops.items():
                            with m.Case(f3):
                                m.d.sync += self.de_result.uop.alu_op.eq(alu_op)

            # Branch condition.
            with m.Switch(opcode):
                with m.Case(JAL, JALR):
                    m.d.sync += self.de_result.uop.branch_type.eq(BranchType.UNCONDITIONAL)
                with m.Case(BRANCH):
                    # BranchType's conditional members share funct3's encoding.
                    m.d.sync += self.de_result.uop.branch_type.eq(funct3)
                with m.Default():
                    m.d.sync += self.de_result.uop.branch_type.eq(BranchType.NEVER)

        ## Backend: Execution Unit
        # Register File
        m.submodules.rf = rf = Memory(width=32, depth=32)
        rf_rs1 = rf.read_port(domain="comb")
        rf_rs2 = rf.read_port(domain="comb")
        rf_wb = rf.write_port()
        # TODO: write-back data and enable are not connected yet.
        m.d.comb += [
            rf_rs1.addr.eq(self.de_result.rs1),
            rf_rs2.addr.eq(self.de_result.rs2),
            rf_wb.addr.eq(self.de_result.rd),
        ]

        input_a = Mux(self.de_result.uop.a_as_pc, self.exm_pc, rf_rs1.data)
        input_b = Mux(self.de_result.uop.b_as_imm, self.de_result.imm, rf_rs2.data)

        # ALU
        m.submodules.alu = alu = ALU()
        m.d.comb += [
            alu.a.eq(input_a),
            alu.b.eq(input_b),
            alu.op.eq(self.de_result.uop.alu_op),
        ]
        m.d.sync += self.exm_result.alu_out.eq(alu.out)

        # Shifter (output not consumed yet)
        m.submodules.shifter = shifter = Shifter()
        m.d.comb += [
            shifter.x.eq(input_a),
            shifter.y.eq(input_b),
        ]

        # Branch and PC
        # NOTE(review): the conditional cases compare the previously
        # registered alu_out against exm_pc rather than the rs1/rs2 operands;
        # this looks unfinished -- confirm the intended operands.
        with m.If(self.de_valid):
            m.d.sync += self.exm_pc.eq(self.exm_pc + 4)
            with m.Switch(self.de_result.uop.branch_type):
                with m.Case(BranchType.EQ):
                    m.d.sync += self.exm_result.branch_taken.eq(self.exm_result.alu_out == self.exm_pc)
                with m.Case(BranchType.NE):
                    m.d.sync += self.exm_result.branch_taken.eq(self.exm_result.alu_out != self.exm_pc)
                with m.Case(BranchType.UNCONDITIONAL):
                    m.d.sync += self.exm_result.branch_taken.eq(1)
                with m.Case(BranchType.NEVER):
                    m.d.sync += self.exm_result.branch_taken.eq(0)
                with m.Case(BranchType.LT):
                    m.d.sync += self.exm_result.branch_taken.eq(
                        self.exm_result.alu_out.as_signed() < self.exm_pc.as_signed())
                with m.Case(BranchType.GE):
                    m.d.sync += self.exm_result.branch_taken.eq(
                        self.exm_result.alu_out.as_signed() >= self.exm_pc.as_signed())
                with m.Case(BranchType.LTU):
                    m.d.sync += self.exm_result.branch_taken.eq(self.exm_result.alu_out < self.exm_pc)
                with m.Case(BranchType.GEU):
                    m.d.sync += self.exm_result.branch_taken.eq(self.exm_result.alu_out >= self.exm_pc)
                with m.Default():
                    m.d.sync += self.exm_result.branch_taken.eq(0)
        with m.If(self.exm_result.branch_taken):
            # Redirect to the registered target. These assignments come last so
            # they override the wrong-path instruction's PC increment and
            # branch decision (two taken branches cannot happen consecutively).
            m.d.sync += [
                self.exm_pc.eq(self.exm_result.alu_out),
                self.exm_result.branch_taken.eq(0),
            ]

        return m
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment