Created
July 15, 2022 06:19
-
-
Save suzyahyah/7ba0884141b756134f21d0a225718cf7 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
# -*- coding: utf-8 -*- | |
from __future__ import division | |
from collections import Counter | |
import pdb | |
from scipy.stats import multinomial | |
import numpy as np | |
class Corpus():
    """Corpus-level word/label statistics built from a collection of documents.

    Each document must expose ``original_labels`` (an iterable of label
    names) and ``word_label`` (a mapping from word to its currently assigned
    label), as ``Document`` in this file does.

    Attributes:
        docs: the documents passed in (the list object itself, not a copy).
        labels_ix: maps each label name to its column index.
        vocab_ix: maps each vocabulary word to its row index.
        word_label_m: ``(len(vocab), len(labels))`` array; entry ``[w, l]``
            is the number of documents whose ``word_label`` assigns label
            ``l`` to word ``w``.
        doc_label_m: ``(len(docs), len(labels))`` 0/1 array; entry ``[d, l]``
            is 1 when label ``l`` is in document ``d``'s ``original_labels``.
    """

    def __init__(self, docs=None, labels=None, vocab=None):
        # None defaults instead of ``[]``: a list default is created once at
        # definition time and would be shared by every Corpus built without
        # that argument (and ``self.docs`` would alias it).
        self.docs = [] if docs is None else docs
        self.labels_ix = self._generate_ix([] if labels is None else labels)
        self.vocab_ix = self._generate_ix([] if vocab is None else vocab)
        self.word_label_m = self._load_wl_matrix()
        self.doc_label_m = self._load_dl_matrix()

    def _generate_ix(self, li):
        """Map each item of ``li`` to its position (a later duplicate wins)."""
        return {item: i for i, item in enumerate(li)}

    def _load_dl_matrix(self):
        """Build the 0/1 document-by-label matrix from ``original_labels``.

        Raises:
            KeyError: if a document lists a label missing from ``labels_ix``.
        """
        mat = np.zeros((len(self.docs), len(self.labels_ix)))
        for i, doc in enumerate(self.docs):
            for label in doc.original_labels:
                mat[i, self.labels_ix[label]] = 1
        return mat

    def _load_wl_matrix(self):
        """Count, per (word, label), the documents assigning that label to the word.

        Each distinct word of a document contributes exactly once, because
        ``word_label`` holds a single label per word key.

        Raises:
            KeyError: if a word is missing from ``vocab_ix`` or its assigned
                label is missing from ``labels_ix``.
        """
        mat = np.zeros((len(self.vocab_ix), len(self.labels_ix)))
        for doc in self.docs:
            for w, label in doc.word_label.items():
                mat[self.vocab_ix[w], self.labels_ix[label]] += 1
        return mat

    def get_counts_array(self):
        """Return the total count per label (column sums of ``word_label_m``)."""
        return self.word_label_m.sum(axis=0)

    def get_counts(self, word=None, label=None):
        """Return label counts from ``word_label_m``.

        Args:
            word: if None, return the total count for ``label``. Otherwise
                return the row of per-label counts for ``word``; ``label`` is
                ignored in that case.
            label: label name; required when ``word`` is None.

        Returns:
            A scalar when ``word`` is None; otherwise a 1-D array of length
            ``len(labels_ix)``. For a word not in the vocabulary, the array is
            filled with 0.1 (presumably a smoothing pseudo-count — confirm).

        Raises:
            KeyError: if ``word`` is None and ``label`` is not a known label.
        """
        if word is None:
            return self.word_label_m[:, self.labels_ix[label]].sum()
        if word in self.vocab_ix:
            return self.word_label_m[self.vocab_ix[word]]
        return np.ones(len(self.labels_ix)) * 0.1

    def set_word_label(self, word, label, val):
        """Add ``val`` to the ``(word, label)`` count in ``word_label_m``.

        Despite the name, this increments (pass a negative ``val`` to
        decrement) rather than assigning.
        """
        self.word_label_m[self.vocab_ix[word], self.labels_ix[label]] += val
class Document():
    """A bag of words in which each word is assigned one of the document's labels.

    Attributes:
        bow: the words passed in; iterated once at construction.
        original_labels: the candidate labels for this document.
        predicted_labels: initialised to an empty list; not read elsewhere in
            this class.
        word_label: dict mapping each distinct word in ``bow`` to its
            currently assigned label. The label ``"?"`` is excluded by
            ``get_counts_array`` and by ``get_counts()`` with no arguments
            (presumably a placeholder for "unassigned").
    """

    def __init__(self, bow=None, labels=None):
        # None instead of a ``[]`` default, which would be one list shared
        # by every Document created without ``labels``.
        self.bow = bow
        self.original_labels = [] if labels is None else labels
        self.predicted_labels = []
        self.word_label = self.init_word_label()

    def get_counts_array(self, labels_ix):
        """Return per-label word counts as an array laid out by ``labels_ix``.

        Args:
            labels_ix: mapping from label name to array index.

        Returns:
            1-D float array of length ``len(labels_ix)``. Words labelled
            ``"?"`` and labels absent from ``labels_ix`` are not counted.
        """
        arr = np.zeros(len(labels_ix))
        for label, n in Counter(self.word_label.values()).items():
            if label != "?" and label in labels_ix:
                arr[labels_ix[label]] = n
        return arr

    def get_counts(self, word=None, label=None):
        """Count word-label assignments in this document.

        - ``word`` and ``label`` given: 1 if ``word`` is currently assigned
          ``label``, else 0.
        - only ``label`` given: number of words assigned ``label``.
        - neither given: number of words whose label is not ``"?"``.
        - only ``word`` given: prints an error message and returns None.
        """
        if word is not None:
            if label is None:
                print("Label must be specified")
                return None
            # word_label holds one label per distinct word, so this is 0 or 1;
            # summing it over documents yields the corpus (word, label) count.
            return int(self.word_label.get(word) == label)
        if label is not None:
            return sum(1 for assigned in self.word_label.values() if assigned == label)
        return sum(1 for assigned in self.word_label.values() if assigned != "?")

    def set_word_label(self, word, label):
        """Assign ``label`` to ``word`` (adding ``word`` if not present)."""
        self.word_label[word] = label

    def init_word_label(self):
        """Assign every word in ``bow`` a label drawn uniformly from ``original_labels``.

        Each word occurrence (repeats included) consumes one draw from
        ``scipy.stats.multinomial``; a repeated word keeps its last draw.

        Returns:
            dict mapping each distinct word to its sampled label; empty when
            ``bow`` is None or empty.

        Raises:
            ValueError: if ``bow`` contains words but there are no labels.
        """
        word_label = {}
        if self.bow is None:
            return word_label
        n_labels = len(self.original_labels)
        # Uniform probabilities, computed once rather than per word.
        p_label = [1.0 / n_labels] * n_labels if n_labels else []
        for word in self.bow:
            if not n_labels:
                raise ValueError("Document has words but no labels to sample from")
            label_ind = multinomial.rvs(1, p_label).argmax()
            word_label[word] = self.original_labels[label_ind]
        return word_label
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment