Skip to content

Instantly share code, notes, and snippets.

@jbush001
Last active May 13, 2019 14:25
Show Gist options
  • Save jbush001/59a82882a0b3b60dc3dcdf8f1088a138 to your computer and use it in GitHub Desktop.
Save jbush001/59a82882a0b3b60dc3dcdf8f1088a138 to your computer and use it in GitHub Desktop.
#!/usr/bin/env python3
#
# Copyright 2019 Jeff Bush
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
We are adding together all of the partial products. Each one is shifted
left by 2 bits (instead of one) relative to the previous one because they
are Booth encoded.

            xxxxxxxx
          xxxxxxxx
        xxxxxxxx
    + xxxxxxxx
    -----------------

This works by maintaining an array of wire names for each bit position
(corresponding to the columns of the diagram above). Each layer combines
the bits in a bit position with the use of full adders (whenever there are
three or more wires) and half adders (if there are only two wires). When a
full adder is used, there will be a carry signal put into the next higher
bit position array. After a fixed number of layers, at most two wires are
expected to remain in each column; these form two addends that are summed
with an ordinary adder to produce the result.
"""
import math
# Width, in bits, of the final result.
NUM_OUTPUT_DIGITS = 64

# Each operand is half as wide as the result.
INPUT_DIGITS = NUM_OUTPUT_DIGITS // 2

# One partial product per two operand bits (each partial product is shifted
# left by two bit positions relative to the previous one).
NUM_PARTIAL_PRODUCTS = INPUT_DIGITS // 2

# Every partial product is one bit wider than an operand.
BITS_PER_PP = INPUT_DIGITS + 1

# floor(log2(NUM_OUTPUT_DIGITS)). Computed with integer arithmetic because
# int(math.log(n, 2)) truncates a floating point quotient and can come out
# one too small when the quotient rounds just below the exact exponent.
NUM_LAYERS = NUM_OUTPUT_DIGITS.bit_length() - 1

# Name of every net created while building the tree, in creation order. They
# are declared as 'logic' when the module is printed.
all_wires = []

# Accumulated text of the adder instances and flip flop blocks. It is printed
# after the net declarations.
rtl_source_code = ''
def idx2chr(i):
    """Convert a zero based index into a lowercase alphabetic suffix.

    Indices 0 through 25 map to 'a' through 'z'. Larger indices continue
    spreadsheet-column style ('aa', 'ab', ... 'az', 'ba', ...) so the result
    always consists only of letters and stays unique per index. This keeps
    generated instance and net names legal identifiers even when a column
    needs more than 26 of them.

    Args:
        i: int
            Non-negative index.

    Returns:
        string: The alphabetic suffix for this index.

    Raises:
        ValueError: If i is negative.
    """
    if i < 0:
        raise ValueError('index must be non-negative, got {}'.format(i))

    letters = []
    remaining = i + 1   # Bijective base 26: there is no zero digit.
    while remaining > 0:
        remaining, digit = divmod(remaining - 1, 26)
        letters.append(chr(digit + ord('a')))

    # Digits were produced least significant first.
    return ''.join(reversed(letters))
def new_layer():
    """Reset the adder instance counters at the start of a layer.

    There is one counter per bit position for full adders and one for half
    adders. They are (re)created as module level lists of zeros, with
    NUM_OUTPUT_DIGITS entries each.

    Returns:
        Nothing
    """
    global next_full_adder_id, next_half_adder_id

    next_full_adder_id = [0] * NUM_OUTPUT_DIGITS
    next_half_adder_id = [0] * NUM_OUTPUT_DIGITS
def emit_full_adder(layer, place, sum, carry_out, a, b, c):
    """Append a full adder instance to the generated RTL.

    A full adder combines three one bit inputs into a sum bit and a carry
    bit. The instance name is built from the layer, the place and a letter
    taken from next_full_adder_id[place], which counts the full adders
    already emitted for this place in the current layer and is incremented
    here.

    Args:
        layer: int
            Index of layer number (0 being first). Used to build a unique,
            descriptive instance name.
        place: int
            Which place in the value this adder operates on (0 being least
            significant).
        sum: string
            Name of the wire that receives the sum.
        carry_out: string
            Name of the wire that receives the carry out value.
        a, b, c: string
            Names of the wires to be added.

    Returns:
        Nothing
    """
    global rtl_source_code

    instance_suffix = idx2chr(next_full_adder_id[place])
    ports = (sum, carry_out, a, b, c)
    instantiation = ' full_adder fa{}_{}{}({}, {}, {}, {}, {});\n'.format(
        layer, place, instance_suffix, *ports)

    rtl_source_code = rtl_source_code + instantiation
    next_full_adder_id[place] = next_full_adder_id[place] + 1
def emit_half_adder(layer, place, sum, carry_out, a, b):
    """Append a half adder instance to the generated RTL.

    This mirrors the full adder emitter, except the adder takes two inputs
    rather than three. The instance name is built from the layer, the place
    and a letter taken from next_half_adder_id[place], which is incremented
    here.

    Args:
        layer: int
            Index of layer number (0 being first). Used to build a unique,
            descriptive instance name.
        place: int
            Which place in the value this adder operates on (0 being least
            significant).
        sum: string
            Name of the wire that receives the sum.
        carry_out: string
            Name of the wire that receives the carry out value.
        a, b: string
            Names of the wires to be added.

    Returns:
        Nothing
    """
    global rtl_source_code

    instance_suffix = idx2chr(next_half_adder_id[place])
    ports = (sum, carry_out, a, b)
    instantiation = ' half_adder ha{}_{}{}({}, {}, {}, {});\n'.format(
        layer, place, instance_suffix, *ports)

    rtl_source_code = rtl_source_code + instantiation
    next_half_adder_id[place] = next_half_adder_id[place] + 1
def layer_reduction(layer_num, in_digits):
    """Emit full and half adders to combine the signals in each place.

    Within a place, wires are consumed from the front of the list: three at
    a time into a full adder while at least three remain, then two into a
    half adder if exactly two remain. A single leftover wire is passed
    through unchanged. Each adder's sum stays in the same place and its carry
    moves to the next higher place.

    Every sum and carry net that is created is also recorded in all_wires.

    Args:
        layer_num: int
            Index of layer (0 being first). Used in the generated net and
            instance names.
        in_digits: list[list[string]]
            A list of lists. The outer list corresponds to the places
            (digits). Each inner list is the list of signals in that place
            to be added. It is not modified.

    Returns:
        list[list[string]] The reduced form of in_digits. This is a new list
        that has the names of the output signals from the adders (and may,
        in some cases, pass the previous signal through).
    """
    out_digits = [[] for _ in range(NUM_OUTPUT_DIGITS)]
    for place, in_wires in enumerate(in_digits):
        # Consume a copy so the caller's lists are left intact.
        pending = list(in_wires)

        # Sums and carries are created in pairs, so one counter names both.
        adder_index = 0
        while len(pending) >= 2:
            suffix = idx2chr(adder_index)
            adder_index += 1

            sum_wire = 's{}_{}{}'.format(layer_num, place, suffix)
            carry_wire = 'c{}_{}{}'.format(layer_num, place, suffix)
            all_wires.append(sum_wire)
            all_wires.append(carry_wire)

            # Use the layer_num argument for the instance name rather than
            # whatever module level variable the caller happens to loop with.
            if len(pending) >= 3:
                emit_full_adder(layer_num, place, sum_wire, carry_wire,
                                *pending[:3])
                del pending[:3]
            else:
                emit_half_adder(layer_num, place, sum_wire, carry_wire,
                                *pending[:2])
                del pending[:2]

            out_digits[place].append(sum_wire)

            # A carry out of the most significant place has nowhere to go
            # and is dropped (the net is still declared and driven).
            if place + 1 < NUM_OUTPUT_DIGITS:
                out_digits[place + 1].append(carry_wire)

        # At most one wire is left; pass it through to the next layer.
        out_digits[place].extend(pending)

    return out_digits
def build_layer_inputs():
    """Create the initial set of inputs to the tree.

    This includes all of the partial products, as well as sign extension
    values.

    Returns:
        list[list[string]] One inner list per bit position (NUM_OUTPUT_DIGITS
        of them, index 0 being least significant). Each inner list holds the
        names of the signals to be added in that position.
    """
    digits = [[] for x in range(NUM_OUTPUT_DIGITS)]
    for pp_idx in range(NUM_PARTIAL_PRODUCTS):
        # Partial product bits. Each partial product starts two columns
        # further left than the previous one.
        for digit_pos in range(BITS_PER_PP):
            shifted_index = digit_pos + pp_idx * 2
            digits[shifted_index].append('pp[{}][{}]'
                                         .format(pp_idx, digit_pos))

        # When a value is negative, we invert the values and add one.
        # This puts neg[pp_idx] in the lowest column of the partial product
        # (presumably supplying the "add one"; no inversion is done here).
        digits[pp_idx * 2].append('neg[{}]'.format(pp_idx))

    # Sign extension. Rather than fully sign extending, we can do a trick:
    # prepend ~S S S to the first partial product
    # prepend 1 ~S to intermediate partial products
    # the second to last partial product has ~S prepended
    #
    # NOTE(review): pp_idx below is the value left over from the loop above
    # (NUM_PARTIAL_PRODUCTS - 1), while the comment talks about the first
    # partial product. Two entries also go to column BITS_PER_PP + 2 and none
    # to column BITS_PER_PP. Confirm this is intended.
    digits[BITS_PER_PP + 1].append('neg[{}]'.format(pp_idx))
    digits[BITS_PER_PP + 2].append('neg[{}]'.format(pp_idx))
    digits[BITS_PER_PP + 2].append('!neg[{}]'.format(pp_idx))

    # NOTE(review): these columns advance by one per partial product, whereas
    # the partial products themselves are shifted by two columns each.
    # Confirm against the intended sign extension layout.
    for pp_idx in range(1, NUM_PARTIAL_PRODUCTS - 2):
        digits[BITS_PER_PP + pp_idx + 1].append('!neg[{}]'.format(pp_idx))
        digits[BITS_PER_PP + pp_idx + 2].append('1\'b1')

    # NOTE(review): the comment above says ~S, but the non-inverted neg
    # signal is appended here. Confirm.
    digits[BITS_PER_PP - 1].append('neg[{}]'.format(NUM_PARTIAL_PRODUCTS - 2))
    return digits
def insert_layer_flops(layer_num, digits):
    """Insert flip flops for each signal coming out of a layer.

    One clocked always block is appended to the generated RTL. It contains
    a non-blocking assignment from every incoming signal to a newly named
    flop output. Each flop output name is also recorded in all_wires.

    Args:
        layer_num: int
            Index of layer, with 0 being first. This is used to create
            unique net names.
        digits: list[list[string]]
            Each element in the outer list corresponds to the place (digit).
            Each inner list is all signals in that place to be added. It is
            not modified.

    Returns:
        list[list[string]] A new list with the same structure as digits, but
        with the original signal names replaced with the names of the flip
        flop outputs.
    """
    global rtl_source_code

    flopped_digits = []
    assignments = []
    for place, wires in enumerate(digits):
        # Build fresh inner lists rather than overwriting the caller's.
        flop_outputs = []
        for slot, wire in enumerate(wires):
            # Name the flop after the layer_num argument, not a module level
            # loop variable.
            flop_output = 'd{}_{}{}'.format(layer_num, place, idx2chr(slot))
            assignments.append(' {} <= {};\n'.format(flop_output, wire))
            all_wires.append(flop_output)
            flop_outputs.append(flop_output)

        flopped_digits.append(flop_outputs)

    rtl_source_code += ('\n always @(posedge clk)\n'
                        + ' begin\n'
                        + ''.join(assignments)
                        + ' end\n')
    return flopped_digits
# Module header. The adders reference neg[...] (see build_layer_inputs), so
# it has to be a port. The result is NUM_OUTPUT_DIGITS bits wide, i.e. its
# top index is NUM_OUTPUT_DIGITS - 1: carries out of the most significant
# column are discarded by the tree, so there is no meaningful extra bit.
print('''// This file autogenerated by make_wallace_tree.py
module wallace_tree(
input clk,
input logic[{}:0][{}:0] pp,
input logic[{}:0] neg,
output[{}:0] result);
'''.format(NUM_PARTIAL_PRODUCTS - 1, BITS_PER_PP - 1,
           NUM_PARTIAL_PRODUCTS - 1, NUM_OUTPUT_DIGITS - 1))

# Build the tree. Keep this loop at module scope with these variable names:
# module level variables are visible as globals to the functions above.
wires = build_layer_inputs()
for layer in range(NUM_LAYERS):
    new_layer()
    rtl_source_code += '\n // layer {}\n'.format(layer)
    wires = layer_reduction(layer, wires)

    # Register the outputs of every second layer.
    if layer % 2 == 1:
        wires = insert_layer_flops(layer, wires)

# Declare wires
DECLS_PER_LINE = 15
for index in range(0, len(all_wires), DECLS_PER_LINE):
    # Slicing clamps at the end of the list, so the last (possibly shorter)
    # group needs no special case.
    decl_group = all_wires[index:index + DECLS_PER_LINE]
    print(' logic ' + ', '.join(decl_group) + ';')

print(rtl_source_code)

# Final addition. Only the first two wires of each column are used; a column
# with fewer wires is padded with constant zero bits. The concatenation lists
# the most significant column first.
# NOTE(review): a column that still holds more than two wires after
# NUM_LAYERS layers would have its extra wires ignored here.
ZERO_BIT = '1\'b0'
addend1_bits = [column[0] if column else ZERO_BIT
                for column in reversed(wires)]
addend2_bits = [column[1] if len(column) > 1 else ZERO_BIT
                for column in reversed(wires)]

# Declare the addends explicitly. An undeclared name on the left of a
# continuous assignment becomes an implicit one bit net, which would
# truncate the concatenation.
print(' logic[{}:0] addend1, addend2;'.format(NUM_OUTPUT_DIGITS - 1))
print(' assign addend1 = {' + ', '.join(addend1_bits) + '};')
print(' assign addend2 = {' + ', '.join(addend2_bits) + '};')
print(' assign result = addend1 + addend2;')
print('endmodule')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment