Skip to content

Instantly share code, notes, and snippets.

@suzyahyah
Created July 15, 2022 06:19
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save suzyahyah/7ba0884141b756134f21d0a225718cf7 to your computer and use it in GitHub Desktop.
Save suzyahyah/7ba0884141b756134f21d0a225718cf7 to your computer and use it in GitHub Desktop.
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import division
from collections import Counter
import pdb
from scipy.stats import multinomial
import numpy as np
class Corpus():
    """Collection of documents with label/vocabulary indices and count matrices.

    Two matrices are built from the documents at construction time:

    * ``word_label_m`` -- shape ``(len(vocab_ix), len(labels_ix))``; entry
      ``[w, l]`` counts the documents whose ``word_label`` dict assigns word
      ``w`` to label ``l``.
    * ``doc_label_m`` -- shape ``(len(docs), len(labels_ix))``; entry
      ``[d, l]`` is 1 when label ``l`` is in document ``d``'s
      ``original_labels``, else 0.

    Each document must expose ``original_labels`` (iterable of labels) and
    ``word_label`` (dict mapping word -> label). Any word or label they use
    that is missing from ``vocab`` / ``labels`` raises ``KeyError``.
    """

    def __init__(self, docs=None, labels=None, vocab=None):
        # ``None`` defaults avoid shared mutable ``[]`` defaults. A list passed
        # by the caller is still stored as-is (not copied).
        self.docs = docs if docs is not None else []
        self.labels_ix = self._generate_ix(labels if labels is not None else [])
        self.vocab_ix = self._generate_ix(vocab if vocab is not None else [])
        self.word_label_m = self._load_wl_matrix()
        self.doc_label_m = self._load_dl_matrix()

    def _generate_ix(self, li):
        """Map each item of ``li`` to its position; a later duplicate wins."""
        return {item: i for i, item in enumerate(li)}

    def _load_dl_matrix(self):
        """Return a 0/1 matrix marking each document's original labels."""
        mat = np.zeros((len(self.docs), len(self.labels_ix)))
        for i, doc in enumerate(self.docs):
            for label in doc.original_labels:
                mat[i, self.labels_ix[label]] = 1
        return mat

    def _load_wl_matrix(self):
        """Return word-by-label counts of the documents' word assignments."""
        mat = np.zeros((len(self.vocab_ix), len(self.labels_ix)))
        for doc in self.docs:
            for word, label in doc.word_label.items():
                mat[self.vocab_ix[word], self.labels_ix[label]] += 1
        return mat

    def get_counts_array(self):
        """Return the total count per label (column sums of ``word_label_m``)."""
        # Previously referenced the non-existent ``self.word_label_matrix``.
        return self.word_label_m.sum(axis=0)

    def get_counts(self, word=None, label=None):
        """Return counts from ``word_label_m``.

        Args:
            word: If None, return the scalar total for ``label``. Otherwise
                return the word's row of per-label counts. That row is a view
                into ``word_label_m``, so mutating it mutates the matrix.
                For a word not in the vocabulary, a fresh array of 0.1 per
                label is returned.
            label: Label whose total is returned when ``word`` is None;
                ignored when ``word`` is given.

        Raises:
            KeyError: ``word`` is None and ``label`` is not a known label.
        """
        if word is None:
            return self.word_label_m[:, self.labels_ix[label]].sum()
        if word in self.vocab_ix:
            return self.word_label_m[self.vocab_ix[word]]
        # Constant 0.1 per label for unseen words (presumably a smoothing
        # pseudo-count -- confirm with the sampler that consumes it).
        return np.ones(len(self.labels_ix)) * 0.1

    def set_word_label(self, word, label, val):
        """Add ``val`` to the count of ``word`` under ``label``.

        Despite the name this increments (pass a negative ``val`` to
        decrement); it does not overwrite the stored count.
        """
        self.word_label_m[self.vocab_ix[word], self.labels_ix[label]] += val
class Document():
    """A document whose words are each assigned one label.

    At construction each distinct word of ``bow`` is assigned a label drawn
    uniformly at random (via ``scipy.stats.multinomial``) from ``labels``.
    Assignments can then be changed with ``set_word_label`` and counted with
    ``get_counts`` / ``get_counts_array``. The counting methods treat the
    label ``"?"`` as "no label".

    NOTE(review): ``word_label`` is a dict keyed by word, so repeated words
    in ``bow`` share one assignment -- confirm this is intended.
    """

    def __init__(self, bow=None, labels=None):
        # ``None`` defaults: an absent bow is an empty document (previously a
        # TypeError), and ``labels`` no longer shares one mutable default list.
        self.bow = bow if bow is not None else []
        self.original_labels = labels if labels is not None else []
        self.predicted_labels = []
        self.word_label = self.init_word_label()

    def get_counts_array(self, labels_ix):
        """Return per-label word counts as an array indexed by ``labels_ix``.

        Words labelled ``"?"`` or with a label absent from ``labels_ix`` are
        not counted.
        """
        arr = np.zeros(len(labels_ix))
        for lbl, n in Counter(self.word_label.values()).items():
            if lbl == "?" or lbl not in labels_ix:
                continue
            arr[labels_ix[lbl]] = n
        return arr

    def get_counts(self, word=None, label=None):
        """Count label assignments in this document.

        * ``word`` and ``label`` given: 1 if ``word`` is assigned ``label``,
          else 0. Previously ``word`` was ignored here.
        * only ``label`` given: number of words assigned ``label``.
        * neither given: number of words whose label is not ``"?"``.

        Raises:
            ValueError: ``word`` given without ``label``. Previously this
                printed a message and returned None.
        """
        if word is not None:
            if label is None:
                raise ValueError("Label must be specified")
            return int(self.word_label.get(word) == label)
        if label is not None:
            return sum(1 for lbl in self.word_label.values() if lbl == label)
        return sum(1 for lbl in self.word_label.values() if lbl != "?")

    def set_word_label(self, word, label):
        """Assign ``label`` to ``word`` in this document."""
        self.word_label[word] = label

    def init_word_label(self):
        """Return a dict assigning each word of ``bow`` a uniformly random label.

        Raises:
            ValueError: the document has words but no labels to draw from.
        """
        labels = self.original_labels
        word_label = {}
        if not self.bow:
            return word_label
        if not labels:
            raise ValueError("Document has words but no labels to assign")
        p_label = [1.0 / len(labels)] * len(labels)
        for word in self.bow:
            # One multinomial draw yields a one-hot vector; argmax is the index.
            label_ind = multinomial.rvs(1, p_label).argmax()
            word_label[word] = labels[label_ind]
        return word_label
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment