Skip to content

Instantly share code, notes, and snippets.

@V0XNIHILI
Created June 14, 2024 15:31
Show Gist options
  • Save V0XNIHILI/aa1c70ec0b6e662c13c74536bd04e9ca to your computer and use it in GitHub Desktop.
Save V0XNIHILI/aa1c70ec0b6e662c13c74536bd04e9ca to your computer and use it in GitHub Desktop.
Argmax tree
// pairwise_max: combinational two-input signed comparator.
//
// Forwards the larger of the two signed operands together with the index tag
// that accompanies it. The index tags are opaque to this module; they are
// simply routed alongside the winning value.
//
// Parameters:
//   WIDTH       - bit width of the signed operands and of max_value
//   INDEX_WIDTH - bit width of the index tags and of max_index
//
// Tie-breaking: `a` wins only when it is strictly greater than `b`, so on
// equal operands the outputs carry `b` / `index_b`.
module pairwise_max #(
    parameter int WIDTH       = 8,
    parameter int INDEX_WIDTH = 4
)(
    input  signed [WIDTH-1:0]           a,
    input  signed [WIDTH-1:0]           b,
    input         [INDEX_WIDTH-1:0]     index_a,
    input         [INDEX_WIDTH-1:0]     index_b,
    output reg signed [WIDTH-1:0]       max_value,
    output reg    [INDEX_WIDTH-1:0]     max_index
);
    always_comb begin
        // Start from the `b` candidate; it is the result unless `a` is
        // strictly larger.
        max_value = b;
        max_index = index_b;

        if (a > b) begin
            max_value = a;
            max_index = index_a;
        end
    end
endmodule
// argmax_tree: recursive combinational argmax over N signed inputs.
//
// The input array is split into a lower part (data[0 .. HalfN-1]) and an
// upper part (data[HalfN .. N-1]); each part is reduced by a child
// argmax_tree and the two partial results are merged by one pairwise_max.
//
// Parameters:
//   WIDTH - bit width of each signed input and of `max`
//   N     - number of inputs. Any N >= 1 elaborates; a power of two gives a
//           perfectly balanced tree. N < 1 is reported at elaboration time.
//   LEVEL - position tag handed down to the children (2*LEVEL+1 for the
//           lower child, 2*LEVEL+2 for the upper child). It is not used for
//           any logic inside this module.
//   IndexWidth (derived) - width of `argmax`: $clog2(N), with a minimum of
//           one bit so that N == 1 still has a legal port.
//
// Ports:
//   data   - the N signed candidates
//   argmax - index into `data` of the selected maximum
//   max    - value of the selected maximum
//
// Tie-breaking follows pairwise_max, which selects its `b` side unless `a` is
// strictly greater; the upper part is always wired to `b`, so among equal
// maxima the highest index is reported.
module argmax_tree #(
    parameter int WIDTH = 8,
    parameter int N = 8,
    parameter int LEVEL = 0,
    localparam int IndexWidth = (N > 1) ? $clog2(N) : 1
)(
    input  signed [WIDTH-1:0]      data [N],
    output        [IndexWidth-1:0] argmax,
    output signed [WIDTH-1:0]      max
);
    // Sizes of the two sub-problems. For odd N the upper part gets the extra
    // element; for even N both parts are N/2.
    localparam int HalfN  = N / 2;
    localparam int RightN = N - HalfN;

    // Index widths actually produced by the children. These mirror the
    // IndexWidth expression above evaluated for HalfN / RightN, so the nets
    // connected to the child `argmax` ports have exactly matching widths.
    localparam int LeftIndexWidth  = (HalfN  > 1) ? $clog2(HalfN)  : 1;
    localparam int RightIndexWidth = (RightN > 1) ? $clog2(RightN) : 1;

    generate
        if (N < 1) begin : gen_invalid_n
            // Stop with a clear message instead of recursing without bound.
            $error("argmax_tree: N must be at least 1, got %0d", N);
        end else if (N == 1) begin : gen_single_input
            // Leaf: a single candidate is trivially the maximum at index 0.
            assign max    = data[0];
            assign argmax = '0;
        end else if (N == 2) begin : gen_single_pair
            // Base case: single pair comparison
            pairwise_max #(
                .WIDTH(WIDTH),
                .INDEX_WIDTH(1)
            ) u_pairwise_max (
                .a(data[0]),
                .b(data[1]),
                .index_a(1'b0),
                .index_b(1'b1),
                .max_value(max),
                .max_index(argmax)
            );
        end else begin : gen_split_merge
            // Partial results from the two children, sized to the child ports.
            wire [LeftIndexWidth-1:0]  argmax_left;
            wire [RightIndexWidth-1:0] argmax_right;
            wire signed [WIDTH-1:0]    max_left, max_right;

            // Slices of the input array handed to the children.
            wire signed [WIDTH-1:0] left_data  [HalfN];
            wire signed [WIDTH-1:0] right_data [RightN];

            for (genvar i = 0; i < HalfN; i++) begin : gen_left_slice
                assign left_data[i] = data[i];
            end
            for (genvar i = 0; i < RightN; i++) begin : gen_right_slice
                assign right_data[i] = data[i + HalfN];
            end

            // Recursive case: reduce each part independently.
            argmax_tree #(
                .WIDTH(WIDTH),
                .N(HalfN),
                .LEVEL(2*LEVEL + 1)
            ) u_argmax_tree_left (
                .data(left_data),
                .argmax(argmax_left),
                .max(max_left)
            );
            argmax_tree #(
                .WIDTH(WIDTH),
                .N(RightN),
                .LEVEL(2*LEVEL + 2)
            ) u_argmax_tree_right (
                .data(right_data),
                .argmax(argmax_right),
                .max(max_right)
            );

            // Convert the child-local indices into indices of this level's
            // `data` array. The upper part starts at offset HalfN. The largest
            // possible value is (RightN-1) + HalfN = N-1, which always fits in
            // IndexWidth bits, so the explicit casts lose no information.
            wire [IndexWidth-1:0] index_left  = IndexWidth'(argmax_left);
            wire [IndexWidth-1:0] index_right = IndexWidth'(argmax_right)
                                              + IndexWidth'(HalfN);

            // Final comparison between the results of the two parts.
            pairwise_max #(
                .WIDTH(WIDTH),
                .INDEX_WIDTH(IndexWidth)
            ) u_pairwise_max (
                .a(max_left),
                .b(max_right),
                .index_a(index_left),
                .index_b(index_right),
                .max_value(max),
                .max_index(argmax)
            );
        end
    endgenerate
endmodule
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment