Skip to content

Instantly share code, notes, and snippets.

Last active Feb 4, 2019
What would you like to do?
FP-Growth in Python
class FPTree:
    """Frequent-itemset miner built on an FP-tree.

    Typical use: ``create_tree(db)`` on a list of transactions, then
    ``mine()``.  Items must be hashable and mutually comparable (results
    are returned as sorted tuples).

    Args:
        min_support: minimum number of transactions an item / itemset must
            occur in to be reported (inclusive).
        min_length: patterns shorter than this are dropped from the result
            of ``mine()``.
        max_length: longest pattern ``mine()`` will build; ``None`` means
            unbounded.  Single-item patterns are always built.
    """

    def __init__(self, min_support=2, min_length=1, max_length=None):
        self.min_support = min_support
        self.min_length = min_length
        self.max_length = max_length

    # NOTE(review): this node class originally carried a "taken from" link
    # that was lost; restore the attribution if it is known.
    class treeNode:
        """FP-tree node holding an item name, a count, parent and children."""

        def __init__(self, nameValue, numOccur, parentNode):
            self.name = nameValue
            self.count = numOccur
            self.nodeLink = None      # next node in the tree holding the same item
            self.parent = parentNode  # None only for the root
            self.children = {}        # item -> treeNode

        def inc(self, numOccur):
            """Increase this node's count by ``numOccur``."""
            self.count += numOccur

        def disp(self, ind=1):
            """Print this subtree as indented text (debugging aid)."""
            print(' ' * ind, self.name, ' ', self.count)
            for child in self.children.values():
                child.disp(ind + 1)

    def split_data(self, data, delimiter=";"):
        """Split each delimited string in ``data`` into a list of items."""
        return [item.split(delimiter) for item in data]

    def get_counts(self, data):
        """Count, per item, how many transactions contain it.

        Each item is counted at most once per transaction.  The result is a
        pandas Series indexed by item, sorted by descending count, keeping
        only items whose count is >= ``min_support``.
        """
        import pandas as pd

        # set(): an item repeated inside one transaction still counts once.
        item_list = [item for transaction in data for item in set(transaction)]
        counts = pd.Series(item_list, dtype=object).value_counts()
        if self.min_support > 1:
            counts = counts[counts >= self.min_support]
        return counts

    def create_header_table(self, db):
        """Build the header table ``{item: [support, first_node_or_None]}``.

        Returns:
            (headerTable, counts) where ``counts`` is ``get_counts(db)``.
            The table's insertion order follows ``counts`` (descending
            support).
        """
        counts = self.get_counts(db)
        # Same inclusive threshold as get_counts, so every item in `counts`
        # has a header entry (create_tree relies on that).
        headerTable = {
            item: [support, None]
            for item, support in counts.items()
            if support >= self.min_support
        }
        return headerTable, counts

    def updateTree(self, items, inTree, headerTable, count):
        """Insert the ordered item list ``items`` below ``inTree``.

        Existing nodes on the path get ``count`` added.  New nodes are
        created with ``count`` and appended to their item's node-link chain
        in ``headerTable``.  An empty ``items`` is a no-op.
        """
        node = inTree
        for item in items:
            child = node.children.get(item)
            if child is not None:
                child.inc(count)
            else:
                child = self.treeNode(item, count, node)
                node.children[item] = child
                head = headerTable[item][1]
                if head is None:
                    headerTable[item][1] = child
                else:
                    self.updateHeader(head, child)
            node = child

    def updateHeader(self, nodeToTest, targetNode):
        """Append ``targetNode`` to the end of the node-link chain at ``nodeToTest``."""
        # Iterative walk: recursion over a long linked list could hit the
        # interpreter's recursion limit.
        while nodeToTest.nodeLink is not None:
            nodeToTest = nodeToTest.nodeLink
        nodeToTest.nodeLink = targetNode

    def create_tree(self, db):
        """Build the FP-tree for ``db`` and store it on the instance.

        ``db`` is iterated twice (once in ``get_counts`` and once here), so
        it must be re-iterable, e.g. a list of transactions.

        Sets ``self.fpTree``, ``self.headerTable`` and ``self.counts``.

        Returns:
            (fpTree, headerTable)
        """
        headerTable, counts = self.create_header_table(db)
        # One global order for all transactions: descending support, with
        # ties broken by a fixed position.  Ordering ties per transaction
        # would put the same itemset on paths in different orders and
        # corrupt the counts mine() reports.
        rank = {item: position for position, item in enumerate(headerTable)}

        fpTree = self.treeNode('Null Set', 1, None)
        for tranSet in db:
            frequent = {item for item in tranSet if item in rank}
            if frequent:
                orderedItems = sorted(frequent, key=rank.__getitem__)
                self.updateTree(orderedItems, fpTree, headerTable, 1)

        self.fpTree = fpTree
        self.headerTable = headerTable
        self.counts = counts
        return fpTree, headerTable

    def ascendTree(self, leafNode, prefixPath):
        """Append item names from ``leafNode`` up to (excluding) the root."""
        node = leafNode
        while node.parent is not None:
            prefixPath.append(node.name)
            node = node.parent

    def findPrefixPath(self, basePat, treeNode):
        """Return the conditional pattern base for one item.

        Follows the node-link chain starting at ``treeNode``.  Returns
        ``{frozenset(prefix items): count}``, where the prefix excludes the
        item itself.  ``basePat`` is not used; it is kept for interface
        compatibility.
        """
        condPats = {}
        while treeNode is not None:
            prefixPath = []
            self.ascendTree(treeNode, prefixPath)
            if len(prefixPath) > 1:
                key = frozenset(prefixPath[1:])
                # Accumulate rather than overwrite in case two nodes share
                # the same prefix set.
                condPats[key] = condPats.get(key, 0) + treeNode.count
            treeNode = treeNode.nodeLink
        return condPats

    def format_results(self, frequent_patterns):
        """Group patterns by length: ``{length: {pattern_tuple: support}}``.

        Patterns shorter than ``min_length`` are dropped.  Keys are visited
        in sorted order.
        """
        return_dict = {}
        for key in sorted(frequent_patterns):
            if len(key) >= self.min_length:
                return_dict.setdefault(len(key), {})[key] = frequent_patterns[key]
        return return_dict

    def mine(self):
        """Return all frequent patterns as ``{length: {pattern_tuple: support}}``.

        Must be called after ``create_tree`` (it reads ``self.headerTable``).

        For each item, every subset of each conditional pattern (capped by
        ``max_length``) is enumerated, and counts are summed per subset.
        Cost therefore grows exponentially with path length; set
        ``max_length`` to bound it.
        """
        from itertools import combinations

        # Single-item patterns come straight from the header table.
        frequent_patterns = {
            (item,): entry[0] for item, entry in self.headerTable.items()
        }

        for item, entry in self.headerTable.items():
            cond_db = self.findPrefixPath(item, entry[1])
            temp_dict = {}
            for pattern, support in cond_db.items():
                cdb_list = list(pattern)
                if self.max_length is None:
                    max_pattern_length = len(cdb_list) + 1
                else:
                    max_pattern_length = self.max_length
                # A subset of size i, plus `item`, gives a pattern of length
                # i + 1 <= max_pattern_length.
                for i in range(1, max_pattern_length):
                    for combo in combinations(cdb_list, i):
                        key = tuple(sorted(combo))
                        temp_dict[key] = temp_dict.get(key, 0) + support
            for term, support in temp_dict.items():
                if support >= self.min_support:
                    frequent_patterns[tuple(sorted([item] + list(term)))] = support

        return self.format_results(frequent_patterns)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment