Skip to content

Instantly share code, notes, and snippets.

@paxbun
Last active July 7, 2021 18:40
Show Gist options
  • Save paxbun/f29e3793ae5e7fd88baa1827fc3294dd to your computer and use it in GitHub Desktop.
Save paxbun/f29e3793ae5e7fd88baa1827fc3294dd to your computer and use it in GitHub Desktop.
// Sums four float32 bit patterns with a two-level adder tree:
//   res = (args[0] + args[1]) + (args[2] + args[3])
// Each addition is delegated to float32_add (defined outside this listing;
// ports lhs, rhs, res), so rounding and special-value handling are whatever
// that module implements. This module has no clock; it only wires instances.
module float32_add_4
(
    input  [31:0] args [3:0],  // the four addends
    output [31:0] res          // their sum
);
    // First level of the tree: two independent pairwise sums.
    wire [31:0] sum_lo;  // args[0] + args[1]
    wire [31:0] sum_hi;  // args[2] + args[3]

    float32_add adder1 (
        .lhs(args[0]),
        .rhs(args[1]),
        .res(sum_lo)
    );
    float32_add adder2 (
        .lhs(args[2]),
        .rhs(args[3]),
        .res(sum_hi)
    );

    // Second level: combine the two partial sums into the result.
    float32_add adder3 (
        .lhs(sum_lo),
        .rhs(sum_hi),
        .res(res)
    );
endmodule
// Dot product of two 4-element float32 vectors:
//   res = lhs[0]*rhs[0] + lhs[1]*rhs[1] + lhs[2]*rhs[2] + lhs[3]*rhs[3]
// The four lane products come from parallel float32_mul instances (defined
// outside this listing) and are reduced by float32_add_4, which sums them
// pairwise as (p0 + p1) + (p2 + p3).
module float32_dot_4
(
    input  [31:0] lhs [3:0],  // first operand vector
    input  [31:0] rhs [3:0],  // second operand vector
    output [31:0] res         // scalar dot product
);
    // Element-wise products, one per lane.
    wire [31:0] products [3:0];

    genvar lane;
    for (lane = 0; lane < 4; lane = lane + 1) begin
        float32_mul multiplier (.lhs(lhs[lane]), .rhs(rhs[lane]), .res(products[lane]));
    end

    // Reduce the four products to a single sum.
    float32_add_4 sum (.args(products), .res(res));
endmodule
// 4x4 float32 matrix product with unpacked-array ports:
//   res[r][c] = sum over k of lhs[r][k] * rhs[k][c]
// One float32_dot_4 is instantiated per output element (16 in total). For
// each element, row r of lhs and column c of rhs are gathered into local
// 4-element arrays that feed the dot-product unit.
// NOTE(review): a second module also named float32_mat4x4_mul (packed 512-bit
// ports) appears in this same listing; the two cannot be compiled into one
// library. Presumably they come from separate source files — confirm.
module float32_mat4x4_mul
(
    input  [31:0] lhs [3:0][3:0],  // indexed [row][col]
    input  [31:0] rhs [3:0][3:0],  // indexed [row][col]
    output [31:0] res [3:0][3:0]   // indexed [row][col]
);
    genvar row, col, k;
    for (row = 0; row < 4; row = row + 1) begin
        for (col = 0; col < 4; col = col + 1) begin
            // Operands for output element (row, col).
            wire [31:0] row_vec [3:0];  // lhs[row][0..3]
            wire [31:0] col_vec [3:0];  // rhs[0..3][col]

            for (k = 0; k < 4; k = k + 1) begin
                assign row_vec[k] = lhs[row][k];
                assign col_vec[k] = rhs[k][col];
            end

            float32_dot_4 product (.lhs(row_vec), .rhs(col_vec), .res(res[row][col]));
        end
    end
endmodule
// Adds four float32 bit patterns supplied on separate ports, using a balanced
// adder tree: res = (arg0 + arg1) + (arg2 + arg3).
// Same computation as float32_add_4, but without unpacked-array ports (a
// SystemVerilog feature). Additions are performed by float32_add, which is
// defined outside this listing.
module float32_4_add
(
    input  [31:0] arg0,
    input  [31:0] arg1,
    input  [31:0] arg2,
    input  [31:0] arg3,
    output [31:0] res
);
    wire [31:0] sum01;  // arg0 + arg1
    wire [31:0] sum23;  // arg2 + arg3

    float32_add adder1 (.lhs(arg0),  .rhs(arg1),  .res(sum01));
    float32_add adder2 (.lhs(arg2),  .rhs(arg3),  .res(sum23));
    float32_add adder3 (.lhs(sum01), .rhs(sum23), .res(res));
endmodule
// Dot product of two 4-element float32 vectors given as individual ports:
//   res = lhs_0*rhs_0 + lhs_1*rhs_1 + lhs_2*rhs_2 + lhs_3*rhs_3
// Products come from float32_mul (defined outside this listing) and are summed
// by float32_4_add as (p0 + p1) + (p2 + p3).
module float32_4_dot
(
    input  [31:0] lhs_0, lhs_1, lhs_2, lhs_3,
    input  [31:0] rhs_0, rhs_1, rhs_2, rhs_3,
    output [31:0] res
);
    // Per-lane products lhs_n * rhs_n.
    wire [31:0] prod0, prod1, prod2, prod3;

    float32_mul multiplier0 (.lhs(lhs_0), .rhs(rhs_0), .res(prod0));
    float32_mul multiplier1 (.lhs(lhs_1), .rhs(rhs_1), .res(prod1));
    float32_mul multiplier2 (.lhs(lhs_2), .rhs(rhs_2), .res(prod2));
    float32_mul multiplier3 (.lhs(lhs_3), .rhs(rhs_3), .res(prod3));

    float32_4_add sum
    (
        .arg0(prod0),
        .arg1(prod1),
        .arg2(prod2),
        .arg3(prod3),
        .res (res)
    );
endmodule
// 4x4 float32 matrix product with flattened (packed) 512-bit ports:
//   res(r, c) = sum over k of lhs(r, k) * rhs(k, c)
// Each matrix is stored row-major on its bus: element (r, c) occupies bits
// [(r*4 + c)*32 +: 32], i.e. element (0, 0) is bits [31:0] and element
// (3, 3) is bits [511:480]. One float32_4_dot computes each output element.
// NOTE(review): a second module also named float32_mat4x4_mul (unpacked-array
// ports) appears in this same listing; the two cannot be compiled into one
// library. Presumably they come from separate source files — confirm.
module float32_mat4x4_mul
(
    input  [511:0] lhs, rhs,
    output [511:0] res
);
    genvar r, c;
    generate
        for (r = 0; r < 4; r = r + 1) begin
            for (c = 0; c < 4; c = c + 1) begin
                // Row r of lhs dotted with column c of rhs.
                float32_4_dot dot
                (
                    .lhs_0(lhs[(r*4 + 0)*32 +: 32]),
                    .lhs_1(lhs[(r*4 + 1)*32 +: 32]),
                    .lhs_2(lhs[(r*4 + 2)*32 +: 32]),
                    .lhs_3(lhs[(r*4 + 3)*32 +: 32]),
                    .rhs_0(rhs[(0*4 + c)*32 +: 32]),
                    .rhs_1(rhs[(1*4 + c)*32 +: 32]),
                    .rhs_2(rhs[(2*4 + c)*32 +: 32]),
                    .rhs_3(rhs[(3*4 + c)*32 +: 32]),
                    .res  (res[(r*4 + c)*32 +: 32])
                );
            end
        end
    endgenerate
endmodule
// Testbench for float32_mat4x4_mul (unpacked-array interface).
// Each iteration:
//   1. fills two random 4x4 matrices (mat1 entries in [0, 100], mat2 entries
//      in [0, 1]) and converts them to float32 bit patterns for the DUT,
//   2. computes the expected product in double precision,
//   3. waits 5 time units for the DUT outputs to settle (the DUT has no clock),
//   4. prints the expected ("desired") and DUT ("actual") matrices.
// NUM_ITERATIONS: 0 (the default) loops forever, as the original testbench
// did; a positive value calls $finish after that many iterations so the
// simulation terminates on its own in batch runs.
module tb_float32_mat4x4 #(parameter int NUM_ITERATIONS = 0);
    real mat1_v  [3:0][3:0];       // lhs operand values
    real mat2_v  [3:0][3:0];       // rhs operand values
    real matr_v  [3:0][3:0];       // expected product (reference model)
    reg  [31:0] mat1 [3:0][3:0];   // lhs as float32 bit patterns
    reg  [31:0] mat2 [3:0][3:0];   // rhs as float32 bit patterns
    // DUT result. Declared as a net because it is driven only by the
    // instance's output port.
    wire [31:0] matr [3:0][3:0];
    real matr_av [3:0][3:0];       // DUT result converted back to real
    int  iteration = 0;            // completed iterations

    float32_mat4x4_mul multiplier
    (
        .lhs(mat1),
        .rhs(mat2),
        .res(matr)
    );

    always begin
        // Stimulus: random operands, quantized to float32 for the DUT.
        for (int i = 0; i < 4; ++i) begin
            for (int j = 0; j < 4; ++j) begin
                mat1_v[i][j] = $urandom_range(0, 10000000) / 100000.0;
                mat2_v[i][j] = $urandom_range(0, 1000) / 1000.0;
                mat1[i][j] = $shortrealtobits(mat1_v[i][j]);
                mat2[i][j] = $shortrealtobits(mat2_v[i][j]);
                // Feed the reference model the float32-rounded operands the
                // DUT actually receives, not the unrounded real values.
                mat1_v[i][j] = $bitstoshortreal(mat1[i][j]);
                mat2_v[i][j] = $bitstoshortreal(mat2[i][j]);
            end
        end

        // Reference model: plain triple loop in double precision.
        for (int i = 0; i < 4; ++i) begin
            for (int j = 0; j < 4; ++j) begin
                matr_v[i][j] = 0.0;
                for (int k = 0; k < 4; ++k) begin
                    matr_v[i][j] += mat1_v[i][k] * mat2_v[k][j];
                end
            end
        end

        #5;  // let the DUT outputs settle

        $display("desired:");
        for (int i = 0; i < 4; ++i) begin
            $display("[ %.6f %.6f %.6f %.6f ]", matr_v[i][0], matr_v[i][1], matr_v[i][2], matr_v[i][3]);
        end
        $display("");

        $display("actual:");
        for (int i = 0; i < 4; ++i) begin
            for (int j = 0; j < 4; ++j) begin
                matr_av[i][j] = $bitstoshortreal(matr[i][j]);
            end
            $display("[ %.6f %.6f %.6f %.6f ]", matr_av[i][0], matr_av[i][1], matr_av[i][2], matr_av[i][3]);
        end
        $display("");

        #5;

        // Optional termination; NUM_ITERATIONS == 0 keeps looping forever.
        iteration++;
        if (NUM_ITERATIONS > 0 && iteration >= NUM_ITERATIONS)
            $finish;
    end
endmodule
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment