Created
June 14, 2024 15:31
-
-
Save V0XNIHILI/aa1c70ec0b6e662c13c74536bd04e9ca to your computer and use it in GitHub Desktop.
Argmax tree
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// pairwise_max: purely combinational two-input maximum that also forwards
// the index tag of the winning input.
//
// Parameters:
//   WIDTH       - bit width of the signed values being compared.
//   INDEX_WIDTH - bit width of the index tag that accompanies each value.
//
// Ports:
//   a, b             - signed candidates.
//   index_a, index_b - index tag associated with a and b respectively.
//   max_value        - the selected (larger) value.
//   max_index        - the index tag that belongs to the selected value.
//
// Tie-break: a is selected only when it is strictly greater than b, so equal
// inputs produce b / index_b.
module pairwise_max #(
    parameter int WIDTH = 8,
    parameter int INDEX_WIDTH = 4
)(
    input signed [WIDTH-1:0] a,
    input signed [WIDTH-1:0] b,
    input [INDEX_WIDTH-1:0] index_a,
    input [INDEX_WIDTH-1:0] index_b,
    output reg signed [WIDTH-1:0] max_value,
    output reg [INDEX_WIDTH-1:0] max_index
);
    always_comb begin
        // Start from the "b wins" result; this is also what ties resolve to.
        max_value = b;
        max_index = index_b;

        // Override only when a is strictly larger (signed comparison).
        if (a > b) begin
            max_value = a;
            max_index = index_a;
        end
    end
endmodule
// argmax_tree: combinational, recursively instantiated comparison tree that
// returns the largest signed element of `data` together with its index.
//
// Parameters:
//   WIDTH      - bit width of each signed element.
//   N          - number of elements; must be a power of two and at least 2
//                (checked at elaboration time).
//   LEVEL      - not used by any logic in this module; it is only forwarded
//                to the sub-trees as 2*LEVEL+1 (left) and 2*LEVEL+2 (right).
//                NOTE(review): presumably a node id for debugging - confirm.
//   IndexWidth - derived: $clog2(N), the width of `argmax`.
//
// Ports:
//   data   - N signed input elements.
//   argmax - index of the selected maximum element.
//   max    - value of the selected maximum element.
//
// Structure: for N == 2 a single pairwise_max compares data[0] and data[1].
// Otherwise the input is split into a lower and an upper half, each half is
// reduced by its own argmax_tree, and a final pairwise_max merges the two
// results.
//
// Tie-break: every merge feeds the lower half to input `a` and the upper half
// to input `b` of pairwise_max, which selects `b` unless `a` is strictly
// greater. Among equal maxima the highest index is therefore reported.
module argmax_tree #(
    parameter int WIDTH = 8,
    parameter int N = 8, // Must be a power of two (>= 2)
    parameter int LEVEL = 0,
    localparam int IndexWidth = $clog2(N)
)(
    input signed [WIDTH-1:0] data [N],
    output [IndexWidth-1:0] argmax,
    output signed [WIDTH-1:0] max
);
    localparam int HalfN = N / 2;

    generate
        if (N < 2 || (N & (N - 1)) != 0) begin: gen_invalid_n
            // Reject unsupported sizes up front. Without this guard the
            // recursion below would reach N == 1 / HalfN == 0 and never
            // terminate (zero-sized arrays, endless instantiation).
            $fatal(1, "argmax_tree: N must be a power of two >= 2, got %0d", N);
        end else if (N == 2) begin: gen_single_pair
            // Base case: single pair comparison. IndexWidth is 1 here, so the
            // index tags are sized 1-bit literals.
            pairwise_max #(
                .WIDTH(WIDTH),
                .INDEX_WIDTH(1)
            ) u_pairwise_max (
                .a(data[0]),
                .b(data[1]),
                .index_a(1'b0),
                .index_b(1'b1),
                .max_value(max),
                .max_index(argmax)
            );
        end else begin: gen_split_merge
            // Each half holds HalfN elements, so a sub-tree index is exactly
            // one bit narrower than this level's index (N >= 4 here, hence
            // IndexWidth >= 2).
            wire [IndexWidth-2:0] argmax_left, argmax_right;
            wire signed [WIDTH-1:0] max_left, max_right;

            // Recursive case: split the input into lower and upper halves.
            wire signed [WIDTH-1:0] left_data [HalfN];
            wire signed [WIDTH-1:0] right_data [HalfN];
            for (genvar i = 0; i < HalfN; i++) begin: gen_split
                assign left_data[i] = data[i];
                assign right_data[i] = data[i + HalfN];
            end

            // Reduce each half with its own sub-tree.
            argmax_tree #(
                .WIDTH(WIDTH),
                .N(HalfN),
                .LEVEL(2*LEVEL + 1)
            ) u_argmax_tree_left (
                .data(left_data),
                .argmax(argmax_left),
                .max(max_left)
            );
            argmax_tree #(
                .WIDTH(WIDTH),
                .N(HalfN),
                .LEVEL(2*LEVEL + 2)
            ) u_argmax_tree_right (
                .data(right_data),
                .argmax(argmax_right),
                .max(max_right)
            );

            // Final comparison between the results of the two halves.
            // Because HalfN is a power of two, "local index + HalfN" is the
            // same as prepending a 1 bit; the lower half prepends a 0 bit.
            // This keeps every connection exactly IndexWidth bits wide.
            pairwise_max #(
                .WIDTH(WIDTH),
                .INDEX_WIDTH(IndexWidth)
            ) u_pairwise_max (
                .a(max_left),
                .b(max_right),
                .index_a({1'b0, argmax_left}),
                .index_b({1'b1, argmax_right}),
                .max_value(max),
                .max_index(argmax)
            );
        end
    endgenerate
endmodule
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment