#!/usr/bin/python3

from graphviz import Digraph



def read_log_file(file):
    """Return the contents of *file* as a list of lines.

    Each element keeps its trailing newline, exactly as produced by iterating
    a text-mode file object. An empty file yields an empty list.
    """
    with open(file, 'r') as handle:
        return list(handle)


def process_line(line):
    """Split one logged-instruction line into (address, mnemonic, operands).

    Input format example (fields separated by TAB characters):

        :: 0x4b5b0060:<TAB>rep stosb<TAB>byte ptr es:[edi], al

    The first field is the ``:: <address>:`` prefix, the second the mnemonic,
    the third the operand text.

    :param line: a log line starting with ``::``; a trailing ``\\n`` or
        ``\\r\\n`` (as returned by ``readlines``) is ignored.
    :returns: ``(address, mnemonic, operands)`` as strings. ``operands`` is
        ``''`` for instructions logged without any (e.g. ``ret``).
    :raises IndexError: if the line has no mnemonic field, or the address
        prefix does not contain at least two ``:`` characters.
    """
    # Drop the line ending first so it does not end up glued to the last
    # field (and from there into graph labels).
    # maxsplit=2 keeps any further tabs inside the operand text.
    fields = line.rstrip('\r\n').split('\t', 2)
    # ':: 0x4b5b0060:' splits on ':' into ['', '', ' 0x4b5b0060', ''].
    address = fields[0].split(':')[2].lstrip()
    mnemonic = fields[1]
    # Operand-less instructions have no third field.
    operands = fields[2] if len(fields) > 2 else ''

    return address, mnemonic, operands


def save_edges(edges, filename='hexdump_trace_edges.txt'):
    """Write edge hit counts to a tab-separated text file.

    Each output line is ``<source>\\t<target>\\t<count>``, written in the
    iteration order of *edges*. The file is overwritten if it exists.

    :param edges: mapping of ``(source_address, target_address)`` -> count.
    :param filename: destination path. Defaults to the name this script has
        always used, so existing callers are unaffected.
    """
    with open(filename, 'w') as f:
        # One batched write instead of a write() call per edge.
        f.writelines(
            "{0}\t{1}\t{2}\n".format(edge[0], edge[1], count)
            for edge, count in edges.items()
        )


def _node_attrs(mnemonic, branch_mnemonics):
    """Return the graphviz node attributes used to draw one instruction."""
    if mnemonic in branch_mnemonics:
        return dict(style="filled", color="blue", fillcolor="black", fontcolor="white")
    if mnemonic == "call":
        return dict(style="filled", color="red", fillcolor="purple", fontcolor="white")
    if mnemonic == "int":
        return dict(color="red", shape="diamond")
    return dict(shape="box")


def draw(log_file='hexdump.log', gv_file='hexdump_trace.gv'):
    """Render the logged instruction trace as a control-flow graph.

    Reads *log_file* and keeps only lines starting with ``::`` (logged
    instructions). It creates one node per instruction address, styled by
    mnemonic, and one edge per distinct (previous address, current address)
    transition. Edge hit counts are written with ``save_edges``. The graph
    is rendered as PNG via graphviz and opened with ``Digraph.view()``.

    :param log_file: trace log to read (default ``hexdump.log``).
    :param gv_file: graphviz source file name (default ``hexdump_trace.gv``).
    """
    g = Digraph('G', filename=gv_file, format='png')

    # NOTE(review): not an exhaustive list of x86 branch mnemonics (e.g.
    # je/jne/jz/jnz are absent); extend it if those should be highlighted too.
    branch_mnemonics = {'jg', 'jge', 'jle', 'jmp', 'js', 'loop'}
    edges = {}             # (from_address, to_address) -> times taken
    emitted_labels = {}    # address -> label of the last node statement emitted
    prev_address = None    # address of the previous instruction in the trace

    for line in read_log_file(log_file):
        if not line.startswith('::'):    # skip lines that are not logged instructions
            continue

        address, mnemonic, operand = process_line(line)
        label = address + '\t' + mnemonic + '\t' + operand

        # Loops revisit the same address many times. Re-emitting an identical
        # node statement only bloats the .gv source, so emit again only when
        # what is logged at this address has changed. Later statements still
        # override earlier ones, as before.
        if emitted_labels.get(address) != label:
            emitted_labels[address] = label
            g.node(address, label=label, **_node_attrs(mnemonic, branch_mnemonics))

        if prev_address is not None:
            edge = (prev_address, address)
            if edge in edges:
                edges[edge] += 1
            else:    # first time this transition is seen: draw it once
                edges[edge] = 1
                g.edge(prev_address, address)

        prev_address = address

    save_edges(edges)

    g.view()


if __name__ == '__main__':
    # Script entry point: build the trace graph, save edge counts and open
    # the rendered image.
    draw()