Skip to content

Instantly share code, notes, and snippets.

@sshh12
Created June 20, 2018 19:18
Show Gist options
  • Save sshh12/d362dba3f72498cd9604de63a659b870 to your computer and use it in GitHub Desktop.
Save sshh12/d362dba3f72498cd9604de63a659b870 to your computer and use it in GitHub Desktop.
from textblob.classifiers import DecisionTreeClassifier as TextClassifier
from datetime import datetime
import matplotlib.pyplot as plt
import csv
# Number of distinct months expected in ChaseActivity.csv; each label's
# series has exactly this many slots and the plot gets one tick per month.
num_months = 13

# Train the description -> category classifier from train.csv, which textblob
# parses as CSV (presumably "description,label" rows — confirm).
with open('train.csv') as training_file:
    clf = TextClassifier(training_file, format="csv")
# Per-category monthly totals: activity[label][i] is the summed (negated)
# amount for month months[i]. months records each 'Mon YYYY' string in the
# order it first appears in the export.
activity = {}
months = []
with open('ChaseActivity.csv', newline='') as csvfile:
    reader = csv.DictReader(csvfile)
    for row in reader:
        date = datetime.strptime(row['Posting Date'], '%m/%d/%Y')
        desc = row['Description']
        amt = row['Amount']
        month = date.strftime('%b %Y')
        label = clf.classify(desc)
        if month not in months:
            # Each series only has num_months slots; fail with a clear
            # message instead of an IndexError further down.
            if len(months) == num_months:
                raise ValueError(
                    'ChaseActivity.csv spans more than {} months; '
                    'increase num_months'.format(num_months))
            months.append(month)
        if label not in activity:
            activity[label] = [0] * num_months
        # Accumulate into the zero-initialised slot. (The previous
        # "reset to 0" check compared a month string against a list of
        # floats, was always true, and discarded all but the last
        # transaction of each month.) Amounts are negated, presumably because
        # the export records purchases as negative values — confirm.
        activity[label][months.index(month)] -= float(amt)
        print(desc)
        print(label)
        print()
# The plot's x-axis assumes exactly num_months distinct months were read.
# Raise explicitly: an assert would be stripped under `python -O`.
if len(months) != num_months:
    raise ValueError(
        'expected {} months of activity, found {}'.format(
            num_months, len(months)))

# Categories left out of the spending plot. Filtering (rather than
# list.remove) tolerates a classifier that lacks one of these labels and
# builds a new list instead of mutating the one clf.labels() returned.
EXCLUDED_LABELS = {'Misc', 'Investments', 'Transfer'}
labels = [name for name in clf.labels() if name not in EXCLUDED_LABELS]
def flip(list_):
    """Return the items of ``list_`` as a new list in reverse order.

    Used to reorder per-month series and month labels: the CSV lists
    them backwards relative to the desired left-to-right plot order.
    """
    return [*reversed(list_)]
def _range(n):
return list(range(num_months))
# Stacked area chart: one band per kept category, one x position per month.
# Series and month labels are flipped because the CSV is in reverse order.
# A category the classifier knows but never assigned to any transaction has
# no entry in activity, so it is drawn as all zeros instead of raising KeyError.
month_positions = _range(num_months)
plt.stackplot(
    month_positions,
    [flip(activity.get(label, [0] * num_months)) for label in labels],
)
plt.legend(labels)
plt.ylabel('$')
plt.xlabel('Month')
plt.xticks(month_positions, flip(months))
plt.show()
#print(clf.pseudocode())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment