Created
March 5, 2024 00:40
-
-
Save mikesparr/98749f7ecaf09904304e1653f34bba00 to your computer and use it in GitHub Desktop.
Experiment with Word2Vec embedding of words in Python for study of GenAI and NLP solutions
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import string

import matplotlib.pyplot as plt
import nltk
import pandas as pd
from gensim.models import Word2Vec as w2v
from nltk import word_tokenize
from nltk.corpus import stopwords
from sklearn.decomposition import PCA
# constants
PATH = 'data/shakespeare.txt'      # corpus file read by the loading step below
sw = stopwords.words('english')    # English stopword list from NLTK
plt.style.use('ggplot')

# One-time NLTK resource downloads; uncomment on first run if the tokenizer
# or stopword data is not installed yet.
#nltk.download('punkt')
#nltk.download('stopwords')
# import data
# Build the punctuation-stripping table once instead of once per line.
_punct_table = str.maketrans('', '', string.punctuation)

# `lines` ends up as a list of token lists, one per line of the corpus file.
lines = []
# Decode explicitly so the result does not depend on the platform's default
# encoding. NOTE(review): assumes the corpus file is UTF-8 -- confirm.
with open(PATH, 'r', encoding='utf-8') as f:
    for raw_line in f:
        # remove newline chars, make all characters lower case,
        # remove punctuation, then tokenize
        cleaned = raw_line.rstrip('\n').lower().translate(_punct_table)
        lines.append(word_tokenize(cleaned))
def remove_stopwords(lines, sw=None):
    '''
    Remove stopwords from every tokenized line in a list of lines.

    A line that would become empty (it contains nothing but stopwords, or it
    was empty to begin with) is returned unchanged instead of being emptied.

    params:
        lines (Array / List) : The list of tokenized lines (each a list of
                               words) you want to remove the stopwords from
        sw (Iterable)        : The stopwords you want to remove. Any iterable
                               of strings (list, set, ...). When omitted, the
                               NLTK English stopword list is used.

    returns:
        A new list with one entry per input line, in the same order.

    example:
        lines = remove_stopwords(lines = lines, sw = sw)
    '''
    if sw is None:
        # Resolved at call time so the function does not depend on a
        # module-level list existing when it is defined.
        sw = stopwords.words('english')

    # Set membership is O(1); the caller may hand us a plain list.
    stop_set = set(sw)

    res = []
    for line in lines:
        kept = [word for word in line if word not in stop_set]
        # Keep the original line when filtering would leave nothing.
        res.append(kept if kept else line)
    return res
# Strip stopwords from the tokenized corpus before training.
filtered_lines = remove_stopwords(lines=lines, sw=sw)

# Train Word2Vec on the filtered lines:
#   sg=1        -> skip-gram training (rather than CBOW)
#   min_count=3 -> ignore words that occur fewer than 3 times
#   window=7    -> context window of up to 7 words
w = w2v(
    filtered_lines,
    min_count=3,
    sg=1,
    window=7,
)

# Sanity check: nearest neighbours of a common word in the corpus.
print(w.wv.most_similar('thou'))

# One row per vocabulary word, one column per embedding dimension.
vocab = list(w.wv.key_to_index)
emb_df = pd.DataFrame(
    [w.wv.get_vector(word) for word in vocab],
    index=vocab,
)
print(emb_df.shape)
# Bare expression: only displays in an interactive session / notebook and
# has no visible effect when run as a script.
emb_df.head()
# Optional (disabled): 2-D PCA projection of the embeddings as a scatter plot.
# pca = PCA(n_components=2, random_state=7)
# pca_mdl = pca.fit_transform(emb_df)
# emb_df_PCA = (
#     pd.DataFrame(
#         pca_mdl,
#         columns=['x','y'],
#         index = emb_df.index
#     )
# )
# plt.clf()
# fig = plt.figure(figsize=(6,4))
# plt.scatter(
#     x = emb_df_PCA['x'],
#     y = emb_df_PCA['y'],
#     s = 0.4,
#     color = 'maroon',
#     alpha = 0.5
# )
# plt.xlabel('PCA-1')
# plt.ylabel('PCA-2')
# plt.title('PCA Visualization')
# plt.plot()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment