Created
July 29, 2022 15:06
-
-
Save Koziev/4781d2bcfe3ac95fc9494eebd53aa8c5 to your computer and use it in GitHub Desktop.
Эксперимент с визуализацией эмбеддингов токенов в rugpt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Эксперимент с визуализацией эмбеддингов токенов в rugpt. | |
""" | |
import os | |
import io | |
import collections | |
import torch | |
import transformers | |
from transformers import GPT2LMHeadModel, GPT2Tokenizer | |
from sklearn.cluster import KMeans | |
import scipy | |
import numpy as np | |
from sklearn.manifold import TSNE | |
import seaborn as sns | |
import matplotlib.pyplot as plt | |
import pandas as pd | |
if __name__ == '__main__': | |
proj_dir = os.path.expanduser('~/polygon/chatbot') | |
print('Loading GPT...') | |
device = 'cpu' | |
model_path = os.path.join(proj_dir, '/media/inkoziev/corpora/EmbeddingModels/rugpt3medium_based_on_gpt2') | |
tokenizer = GPT2Tokenizer.from_pretrained(model_path) | |
tokenizer.add_special_tokens({'bos_token': '<s>', 'eos_token': '</s>', 'pad_token': '<pad>'}) | |
model = GPT2LMHeadModel.from_pretrained(model_path) | |
model.to(device) | |
model.eval() | |
for name, layer in list(model.named_modules()): | |
if name == 'transformer.wte': | |
vectors = layer.weight.detach().cpu().numpy() | |
print('vectors.shape={}'.format(vectors.shape)) | |
break | |
print('Fitting t-SNE...', end='', flush=True) | |
tsne = TSNE(n_components=2, learning_rate='auto', init='random') | |
vectors2 = tsne.fit_transform(vectors) | |
print(' done') | |
male_animals = ['медведь', 'лев', 'тигр', 'леопард', 'морж', 'тюлень', 'барсук', 'енот', 'опоссум', 'скунс', 'лось', | |
'олень', 'волк', 'поросенок', 'кабан', 'конь', 'бык', 'баран', 'козел', 'козлик', 'вомбат', | |
'бобр', 'вепрь', 'суслик', 'зубр', 'гепард', 'жираф', 'дельфин', 'осьминог', 'мангуст', 'кальмар', | |
'еж', 'ёжик', 'ослик', 'буйвол', 'телёнок', 'шакал', 'мамонт', 'леопард', 'ишак', 'буйвол', 'бык', | |
'заяц', 'кролик', 'сурок', 'крот'] | |
female_animals = ['рысь', 'мышь', 'кошка', 'лиса', 'свинья', 'лошадь', 'корова', 'овца', 'коза', 'антилопа', 'львица', 'тигрица', | |
'бобриха', 'горилла', 'мартышка', 'панда', 'лисичка', 'мышка', 'крыса', 'росомаха', 'буйволица', | |
'пума', 'лошадка', 'кошечка', 'собачка', 'горилла', 'капибара', 'медведица', 'овца', 'коза'] | |
male_names = ['Костя', 'Валера', 'Петя', 'Иосиф', 'Георг', 'Николай', 'Пётр', 'Олег', 'Игорь', 'Влад', 'Ян', 'Гюнтер', 'Павел'] | |
female_names = ['Валя', 'Оля', 'Ольга', 'Настя', 'Марина', 'Инга', 'Анна', 'Аня', 'Света', 'Яна'] | |
substances = ['натрий', 'калий', 'олово', 'железо', 'уран', 'водород', 'кислород', 'гелий', 'марганец', 'никель', | |
'цезий', 'криптон', 'аргон', 'углерод', 'фтор', 'хлор', 'кремний'] | |
movement_verbs = ['бежать', 'бегать', 'ползти', 'перемещаться', 'лететь', 'приползти', 'улететь', 'мчаться', 'идти', | |
'шагать', 'прыгать', 'спускаться', 'преследовать', 'настигать', 'уходить'] | |
data = [] | |
markers = dict() | |
for words, category, marker in [(male_animals + female_animals, 'животное', 'o'), | |
(male_names+female_names, 'имя', 'x'), | |
(substances, 'вещество', '+'), | |
(movement_verbs, 'глаголы движения', 'V')]: | |
markers[category] = marker | |
for word in words: | |
tx = tokenizer.encode(word, add_special_tokens=False, return_tensors=None) | |
word_vectors = [] | |
for token_id in tx: | |
word_vectors.append(vectors2[token_id]) | |
v2 = np.average(word_vectors, axis=0) | |
data.append((v2[0], v2[1], category)) # 2 компоненты этого эмбеддинга будут "x" и "y" | |
df = pd.DataFrame(columns=['x', 'y', 'category'], data=data) | |
plt.figure(figsize=(16, 10)) | |
pict = sns.scatterplot(data=df, x="x", y="y", hue="category", markers=markers) | |
pict.set(title='t-SNE проекция векторов из слоя Token Embedding в rugpt-medium') | |
pict.figure.savefig(os.path.join(proj_dir, 'tmp', 'tsne_vizualisation_of_gpt_token_embeddings.png')) | |
print('All done :)') | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment