Play with German Electra (Word similarity comparison)
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sat Jul 18 21:37:41 2020

@author: philipp

Convert the raw TensorFlow checkpoint first (download
convert_electra_original_tf_checkpoint_to_pytorch.py from the transformers repo):

python convert_electra_original_tf_checkpoint_to_pytorch.py --tf_checkpoint_path=/media/data/48_BERT/12_Training_Output/01_Electra/BASE_100k_CC_1/Electra_models_Electra_german_CC_model.ckpt-95000 --config_file=/media/data/48_BERT/12_Training_Output/01_Electra/BASE_100k_CC_1/config.json --pytorch_dump_path=/media/data/48_BERT/12_Training_Output/01_Electra/BASE_100k_CC_1/pytorch_model.bin --discriminator_or_generator='discriminator'
"""
import torch
from transformers import (BertModel, BertTokenizer, ElectraForMaskedLM,
                          ElectraForPreTraining, ElectraModel, ElectraTokenizer)


def get_similarity(model, tokenizer, word, synonym, model_name):
    inputs = tokenizer(word, return_tensors="pt")["input_ids"]
    inputs2 = tokenizer(synonym, return_tensors="pt")["input_ids"]
    print(synonym, "<->", word, end="\t")
    cos = torch.nn.CosineSimilarity(dim=0, eps=1e-6)
    outputs = model(inputs)
    outputs2 = model(inputs2)
    if model_name == 'ELECTRA':
        # Hidden states are the second element of the ELECTRA output tuple.
        # Average layers 8-11 and take position 1, i.e. the embedding of the
        # first word piece (position 0 is [CLS]).
        a = torch.mean(torch.stack([outputs[1][i][0][1] for i in range(8, 12)]), dim=0)
        a2 = torch.mean(torch.stack([outputs2[1][i][0][1] for i in range(8, 12)]), dim=0)
    if model_name == 'BERT':
        # For BertModel the hidden states are the third element of the output tuple.
        a = torch.mean(torch.stack([outputs[2][i][0][1] for i in range(8, 12)]), dim=0)
        a2 = torch.mean(torch.stack([outputs2[2][i][0][1] for i in range(8, 12)]), dim=0)
    similarity = cos(a, a2)
    print(float(similarity))
    # 3 == [CLS] + one word piece + [SEP]; anything longer means the word was
    # split into multiple tokens (see the note below the script).
    print(len(inputs[0]))


if __name__ == '__main__':
    tokenizer_electra = ElectraTokenizer.from_pretrained('/media/data/48_BERT/12_Training_Output/01_Electra/BASE_100k_CC_1/')
    model_electra = ElectraForMaskedLM.from_pretrained('/media/data/48_BERT/12_Training_Output/01_Electra/BASE_100k_CC_1/')
    #model_electra = ElectraForPreTraining.from_pretrained('/media/data/48_BERT/12_Training_Output/01_Electra/BASE_100k_CC_1/')#, output_hidden_states=True)
    #model_electra = ElectraModel.from_pretrained('/media/data/48_BERT/12_Training_Output/01_Electra/BASE_100k_CC_1/')
    model_electra.config.output_hidden_states = True

    tokenizer_bert = BertTokenizer.from_pretrained('bert-base-german-cased')
    model_bert = BertModel.from_pretrained('bert-base-german-cased')
    model_bert.config.output_hidden_states = True

    #synonyms = ["Ruhrgebiet", "Dortmund"]
    #synonyms = ["Rhein", "Köln", "Dom"]
    synonyms = ["PKW", "Fahrzeug", "Kraftfahrzeug", "Autobahn", "selbst", "eigen"]
    #synonyms = ["Bergwerk", "Grube", "Mine", "Stollen", "Pütt", "Honorarnote", "Kostennote"]
    #word = "Zeche"
    word = "Auto"

    for synonym in synonyms:
        print('BERT', end="\t")
        get_similarity(model_bert, tokenizer_bert, word, synonym, "BERT")
        print('ELECTRA', end="\t")
        get_similarity(model_electra, tokenizer_electra, word, synonym, "ELECTRA")
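As the commented-out ElectraForPreTraining line hints, output_hidden_states can also be passed directly to from_pretrained instead of being set on model.config afterwards. A minimal sketch of that variant, using the same checkpoint path as above:

model_electra = ElectraForMaskedLM.from_pretrained(
    '/media/data/48_BERT/12_Training_Output/01_Electra/BASE_100k_CC_1/',
    output_hidden_states=True,  # forwarded to the model config at load time
)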
This only works with words that are in the vocab file. So if len(inputs[0]) is anything other than 3, the word is built from multiple tokens and this script will not be useful.
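A small pre-check along these lines would let the script skip such words automatically (a sketch; the helper name is_single_token is made up for illustration, and it reuses the tokenizers already loaded in the script):

def is_single_token(tokenizer, word):
    # length 3 == [CLS] + one word piece + [SEP], i.e. the word is one vocab entry
    return len(tokenizer(word, return_tensors="pt")["input_ids"][0]) == 3

# keep only synonyms that map to a single vocab entry
synonyms = [s for s in synonyms if is_single_token(tokenizer_bert, s)]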