Skip to content

Instantly share code, notes, and snippets.

@dnicolodi
Last active December 24, 2021 20:11
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save dnicolodi/063d1898f4deb2ca64f7df0ef410333e to your computer and use it in GitHub Desktop.
Save dnicolodi/063d1898f4deb2ca64f7df0ef410333e to your computer and use it in GitHub Desktop.
import decimal
from functools import wraps

import click
from beancount import loader
from beancount.core import data
from beancount.parser import parser
from beangadgets import printer
from sklearn.feature_extraction.text import TfidfVectorizer, TfidfTransformer, CountVectorizer
from sklearn.svm import LinearSVC
# Transaction flags written onto entries completed by postings():
# CLEARED for the higher-scoring predictions, PENDING for the rest.
CLEARED = '*'
PENDING = '!'
def target(t):
    """Return the account to learn/predict for transaction ``t``.

    Gross approximation: the account we want to predict is the one that is
    neither a Checking nor a Cash account.  Postings without units are
    ignored.  When several postings qualify, the one with the largest
    ``units`` wins; ties go to the posting that appears first.

    Args:
        t: an object with a ``postings`` iterable whose items expose
            ``units`` and a colon-separated ``account`` name.

    Returns:
        The chosen account name, or ``None`` when no posting qualifies
        (e.g. a transfer between a Checking and a Cash account).
    """
    skipped_components = {'Checking', 'Cash'}
    candidates = [
        p for p in t.postings
        if p.units is not None
        and skipped_components.isdisjoint(p.account.split(':'))
    ]
    if not candidates:
        # Nothing to predict; indexing an empty list here used to raise
        # IndexError.
        return None
    # max() returns the first maximal element, which matches what a stable
    # descending sort followed by [0] would pick.
    return max(candidates, key=lambda p: p.units).account
def smart(f):
    """Decorate an ``extract``-style method so its result goes through postings().

    ``f`` is called as ``f(self, filepath, existing)`` and must return the
    extracted entries.  The wrapper passes them, together with
    ``existing``, to postings() to add the predicted balancing posting.

    Returns:
        The wrapped method.  It returns ``f``'s result untouched when
        ``existing`` is empty (there is nothing to learn from), otherwise
        a list of the processed entries.
    """
    @wraps(f)
    def wrap(self, filepath, existing):
        entries = f(self, filepath, existing)
        if not existing:
            return entries
        # Materialise the generator: it would always be truthy and cannot
        # be sorted or measured, unlike the list callers get from ``f``.
        return list(postings(entries, existing))
    return wrap
def _train(existing):
    """Fit the text classifier on the transactions found in ``existing``.

    Each transaction contributes its payee (or narration when there is no
    payee) as text and the account chosen by target() as label.

    Returns:
        A ``(vectorizer, classifier)`` pair, or ``None`` when fewer than
        two distinct labels are available (LinearSVC needs at least two
        classes) or when the texts yield no vocabulary.
    """
    texts = []
    labels = []
    for txn in data.filter_txns(existing):
        try:
            label = target(txn)
        except IndexError:
            # No candidate posting, e.g. a Checking <-> Cash transfer.
            label = None
        text = txn.payee or txn.narration
        if label is None or not text:
            continue
        texts.append(text)
        labels.append(label)
    if len(set(labels)) < 2:
        return None
    vectorizer = TfidfVectorizer()
    try:
        features = vectorizer.fit_transform(texts)
    except ValueError:
        # TfidfVectorizer found no usable token (empty vocabulary).
        return None
    return vectorizer, LinearSVC().fit(features, labels)


def _best_guess(classifier, features):
    """Return ``(account, score)`` for the single sample in ``features``."""
    scores = classifier.decision_function(features)
    if scores.ndim == 1:
        # Two-class model: one signed margin per sample, positive meaning
        # classes_[1].  Use its magnitude as the confidence.
        margin = float(scores[0])
        return str(classifier.classes_[1 if margin > 0 else 0]), abs(margin)
    row = scores[0]
    best = int(row.argmax())
    return str(classifier.classes_[best]), float(row[best])


def postings(entries, existing, *, min_score=0.2, cleared_score=0.5):
    """Yield ``entries``, completing transactions with a predicted account.

    A TF-IDF + linear SVM classifier is trained on ``existing``.  Every
    transaction in ``entries`` that has at most one posting and whose best
    prediction scores at least ``min_score`` gets:

    * an extra posting on the predicted account, with no amount;
    * a ``score`` metadata entry (Decimal, two decimal places);
    * the CLEARED flag when the score exceeds ``cleared_score``, PENDING
      otherwise.

    All other entries are yielded unchanged, and so is everything when
    ``existing`` is empty or too poor to train on.  Completed transactions
    are new objects; the input entries are not modified.

    Args:
        entries: iterable of directives to process.
        existing: directives used as training data.
        min_score: predictions scoring below this are discarded.
        cleared_score: predictions scoring above this are flagged CLEARED.

    Yields:
        The directives of ``entries``, in order.
    """
    if not existing:
        # This is a generator: a bare ``return entries`` would drop them all.
        yield from entries
        return
    model = _train(existing)
    if model is None:
        yield from entries
        return
    vectorizer, classifier = model

    for entry in entries:
        if not isinstance(entry, data.Transaction) or len(entry.postings) > 1:
            yield entry
            continue
        # Use the same text feature the classifier was trained on.
        text = entry.payee or entry.narration or ''
        account, score = _best_guess(classifier, vectorizer.transform([text]))
        if score < min_score:
            yield entry
            continue
        flag = CLEARED if score > cleared_score else PENDING
        print(entry.narration, account, score)
        # Copy meta and postings: _replace() alone would still share them
        # with the input entry.
        meta = dict(entry.meta or {})
        meta['score'] = decimal.Decimal(f'{score:.2f}')
        balancing = data.Posting(account, None, None, None, None, None)
        yield entry._replace(
            flag=flag,
            meta=meta,
            postings=[*entry.postings, balancing],
        )
@click.command()
@click.argument('ledger', type=click.Path(exists=True))
@click.argument('existing', type=click.Path(exists=True))
def main(ledger, existing):
    """Print LEDGER with predicted postings added, trained on EXISTING.

    LEDGER is read with parser.parse_file and EXISTING with
    loader.load_file.  Errors reported by either are written to stderr;
    processing continues with whatever entries were returned.
    """
    new_entries, parse_errors, _ = parser.parse_file(ledger)
    known_entries, load_errors, _ = loader.load_file(existing)
    # Report the errors instead of silently discarding them.
    for error in (*parse_errors, *load_errors):
        click.echo(getattr(error, 'message', error), err=True)
    completed = list(postings(new_entries, known_entries))
    printer.print_entries(completed)
if __name__ == "__main__":
    # Run the command only when executed as a script, not on import.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment