FP-Growth in Python
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class FPTree():
    """Frequent-pattern miner built around an FP-tree.

    Typical use: ``split_data`` (optional) -> ``create_tree`` -> ``mine``.

    The code expects the enclosing module to provide ``pd`` (used as the
    pandas namespace) and ``combinations`` (used as
    ``itertools.combinations``).

    Parameters:
        min_support: minimum number of transactions an item or pattern must
            occur in to be kept. The threshold is inclusive.
        min_length: patterns shorter than this are left out of the result
            returned by ``mine``.
        max_length: longest multi-item pattern to generate, or ``None`` for
            no limit. Single-item patterns are always generated.
    """

    def __init__(self, min_support=2, min_length=1, max_length=None):
        self.min_support = min_support
        self.min_length = min_length
        self.max_length = max_length

    # class for a tree node with a name, count, parent and children
    # taken from : https://adataanalyst.com/machine-learning/fp-growth-algorithm-python-3/
    class treeNode:
        """One node of the FP-tree."""

        def __init__(self, nameValue, numOccur, parentNode):
            self.name = nameValue       # item label stored in this node
            self.count = numOccur       # transactions that pass through this node
            self.nodeLink = None        # next node carrying the same item
            self.parent = parentNode    # None only for the root
            self.children = {}          # item label -> child treeNode

        def inc(self, numOccur):
            """Increase the node count by ``numOccur``."""
            self.count += numOccur

        def disp(self, ind=1):
            """Print the subtree rooted here as indented text (debug aid)."""
            print(' '*ind, self.name, ' ', self.count)
            for child in self.children.values():
                child.disp(ind+1)

    def split_data(self, data, delimiter=";"):
        """Split each string in ``data`` on ``delimiter``.

        Returns a list of lists of items, one inner list per input string.
        """
        return [item.split(delimiter) for item in data]

    def get_counts(self, data):
        """Count in how many transactions each item occurs.

        An item is counted at most once per transaction. Items occurring in
        fewer than ``min_support`` transactions are dropped.

        Returns a pandas Series indexed by item, as produced by
        ``Series.value_counts`` (most frequent first).
        """
        item_list = []
        for transaction in data:
            # dict.fromkeys de-duplicates while keeping first-seen order, so
            # the result does not depend on set iteration order.
            item_list.extend(dict.fromkeys(transaction))
        if item_list:
            series = pd.Series(item_list)
        else:
            # explicit dtype avoids the ambiguous default for an empty Series
            series = pd.Series([], dtype=object)
        counts = series.value_counts()
        # Inclusive threshold: an item seen exactly min_support times is kept.
        return counts[counts >= self.min_support]

    def create_header_table(self, db):
        """Build the header table for ``db``.

        Returns ``(headerTable, counts)`` where ``headerTable`` maps each
        frequent item to ``[support, first_node_or_None]`` and ``counts`` is
        the Series returned by ``get_counts``.
        """
        counts = self.get_counts(db)
        headerTable = {}
        for item, support in counts.items():
            if support >= self.min_support:
                headerTable[item] = [support, None]
        return headerTable, counts

    def updateTree(self, items, inTree, headerTable, count):
        """Insert the ordered ``items`` as a path below ``inTree``.

        Existing nodes along the path have ``count`` added to them; missing
        nodes are created and appended to the item's node-link chain in
        ``headerTable``. An empty ``items`` sequence is a no-op.
        """
        node = inTree
        # Iterative walk: avoids per-item list slicing and recursion depth
        # limits on long transactions.
        for item in items:
            child = node.children.get(item)
            if child is not None:
                child.inc(count)
            else:
                child = self.treeNode(item, count, node)
                node.children[item] = child
                if headerTable[item][1] is None:
                    # first node for this item: it becomes the chain head
                    headerTable[item][1] = child
                else:
                    self.updateHeader(headerTable[item][1], child)
            node = child

    def updateHeader(self, nodeToTest, targetNode):
        """Append ``targetNode`` to the node-link chain starting at ``nodeToTest``."""
        while nodeToTest.nodeLink is not None:
            nodeToTest = nodeToTest.nodeLink
        nodeToTest.nodeLink = targetNode

    def create_tree(self, db):
        """Build the FP-tree for ``db`` (an iterable of item sequences).

        ``db`` is iterated twice, so it must be re-iterable (e.g. a list).
        Stores the results on ``self.fpTree``, ``self.headerTable`` and
        ``self.counts`` and returns ``(fpTree, headerTable)``.
        """
        headerTable, counts = self.create_header_table(db)
        # One global rank per item. Every transaction must be inserted in the
        # same item order, including among items with equal support;
        # otherwise the same item set ends up on different paths and pattern
        # support is under-counted.
        rank = {item: position for position, item in enumerate(counts.index)}
        fpTree = self.treeNode('Null Set', 1, None)
        for tranSet in db:
            frequent = [item for item in dict.fromkeys(tranSet) if item in headerTable]
            if frequent:
                orderedItems = sorted(frequent, key=rank.__getitem__)
                self.updateTree(orderedItems, fpTree, headerTable, 1)
        self.fpTree = fpTree
        self.headerTable = headerTable
        self.counts = counts
        return fpTree, headerTable

    def ascendTree(self, leafNode, prefixPath):
        """Append to ``prefixPath`` the names from ``leafNode`` up to, but not
        including, the root (the node whose parent is ``None``)."""
        node = leafNode
        while node.parent is not None:
            prefixPath.append(node.name)
            node = node.parent

    def findPrefixPath(self, basePat, treeNode):
        """Collect the conditional pattern base along a node-link chain.

        ``treeNode`` is the first node of the chain (or ``None``);
        ``basePat`` is accepted for interface compatibility and not used.

        Returns a dict mapping ``frozenset`` of ancestor items to the summed
        count of the chain nodes having exactly those ancestors. Nodes
        directly below the root contribute nothing.
        """
        condPats = {}
        node = treeNode
        while node is not None:
            prefixPath = []
            self.ascendTree(node, prefixPath)
            if len(prefixPath) > 1:
                # prefixPath[0] is the node's own item; the rest are ancestors
                key = frozenset(prefixPath[1:])
                # accumulate instead of overwrite in case two nodes share a key
                condPats[key] = condPats.get(key, 0) + node.count
            node = node.nodeLink
        return condPats

    def format_results(self, frequent_patterns):
        """Group patterns by length.

        Returns ``{length: {pattern_tuple: support}}`` containing only
        patterns of at least ``min_length`` items, inserted in sorted
        pattern order.
        """
        return_dict = {}
        for key in sorted(frequent_patterns.keys()):
            key_len = len(key)
            # exclude patterns shorter than the specified min length
            if key_len >= self.min_length:
                return_dict.setdefault(key_len, {})[key] = frequent_patterns[key]
        return return_dict

    def mine(self):
        """Mine frequent patterns from the tree built by ``create_tree``.

        Returns the grouped dict produced by ``format_results``. Pattern keys
        are sorted tuples of items, so items must be mutually orderable.

        Raises:
            AttributeError: if ``create_tree`` has not been called yet.
        """
        if not hasattr(self, "counts") or not hasattr(self, "headerTable"):
            raise AttributeError("call create_tree(db) before mine()")

        frequent_patterns = {}
        # every item that survived the support filter is a pattern on its own
        for item, support in self.counts.sort_index().items():
            frequent_patterns[(item,)] = support

        # walk the items from least to most frequent
        for item in self.counts.sort_values().index:
            cond_db = self.findPrefixPath(item, self.headerTable[item][1])
            totals = {}
            for pattern, support in cond_db.items():
                prefix_items = list(pattern)
                # Subsets of size k yield patterns of k + 1 items once `item`
                # is added, hence the exclusive upper bound.
                limit = len(prefix_items) + 1
                if self.max_length is not None:
                    limit = min(limit, self.max_length)
                for size in range(1, limit):
                    for combo in combinations(prefix_items, size):
                        key = tuple(sorted(combo))
                        totals[key] = totals.get(key, 0) + support
            # keep the subsets that are frequent together with `item`
            for term, support in totals.items():
                if support >= self.min_support:
                    frequent_patterns[tuple(sorted([item] + list(term)))] = support
        return self.format_results(frequent_patterns)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment