Created
June 20, 2018 19:18
-
-
Save sshh12/d362dba3f72498cd9604de63a659b870 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from textblob.classifiers import DecisionTreeClassifier as TextClassifier | |
from datetime import datetime | |
import matplotlib.pyplot as plt | |
import csv | |
# Number of distinct calendar months the activity export is expected to span.
# Each per-category series below is pre-sized to this length, and the plot
# uses it for its x positions.
num_months = 13

# Train a text classifier on labelled examples read from train.csv.
with open('train.csv', 'r') as fp:
    clf = TextClassifier(fp, format="csv")

# label -> list of per-month totals (negated amounts), indexed in parallel
# with `months`.
activity = {}
# 'Mon YYYY' strings (strftime '%b %Y') in the order they first appear in
# the activity CSV.
months = []
# 'Mon YYYY' -> position in `months`; avoids an O(n) months.index() per row.
month_index = {}

with open('ChaseActivity.csv', newline='') as csvfile:
    reader = csv.DictReader(csvfile)
    for row in reader:
        date = datetime.strptime(row['Posting Date'], '%m/%d/%Y')
        desc = row['Description']
        amt = row['Amount']
        month = date.strftime('%b %Y')
        label = clf.classify(desc)
        if month not in month_index:
            # Fail with a clear message rather than an IndexError on the
            # fixed-length per-label lists below.
            if len(months) == num_months:
                raise ValueError(
                    'ChaseActivity.csv spans more than {} months; '
                    'adjust num_months'.format(num_months))
            month_index[month] = len(months)
            months.append(month)
        if label not in activity:
            # Every month bucket starts at 0, so rows only ever add to it.
            activity[label] = [0] * num_months
        # Amount is negated, presumably so that debits exported as negative
        # numbers plot as positive spending. Confirm against the export.
        activity[label][month_index[month]] += -float(amt)
        print(desc)
        print(label)
        print()

# The plot pairs num_months x positions with `months` as tick labels, so the
# two must agree exactly.
if len(months) != num_months:
    raise ValueError(
        'Expected {} months in ChaseActivity.csv, found {}'.format(
            num_months, len(months)))
# Categories to chart: every label the classifier knows, minus a few that are
# deliberately excluded (presumably because they are not spending; confirm).
labels = clf.labels()
for excluded in ('Misc', 'Investments', 'Transfer'):
    # list.remove raises ValueError if the classifier never learned this label.
    labels.remove(excluded)
def flip(list_):
    """Return a new list holding the elements of *list_* in reverse order.

    Used because the activity CSV was "backwards" (presumably newest-first),
    so the per-month series and month names are reversed before plotting.
    """
    backwards = reversed(list_)
    return [item for item in backwards]
def _range(n): | |
return list(range(num_months)) | |
# Plot per-category monthly totals as a stacked area chart. Both the series
# and the month tick labels are flipped to undo the CSV's reversed order.
x_positions = _range(num_months)
# A label the classifier knows but never assigned to any transaction has no
# entry in `activity`; plot it as all zeros instead of raising KeyError.
series = [flip(activity.get(label, [0] * num_months)) for label in labels]
plt.stackplot(x_positions, series)
plt.legend(labels)
plt.ylabel('$')
plt.xlabel('Month')
plt.xticks(x_positions, flip(months))
plt.show()
# To inspect the learned decision tree, use: print(clf.pseudocode())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment