Play with German Electra (Word similarity comparison)
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sat Jul 18 21:37:41 2020
@author: philipp
"""
"""
CONVERT RAW TENSORFLOW OUTPUT FIRST: (Download this file from transformers repo)
python convert_electra_original_tf_checkpoint_to_pytorch.py --tf_checkpoint_path=/media/data/48_BERT/12_Training_Output/01_Electra/BASE_100k_CC_1/Electra_models_Electra_german_CC_model.ckpt-95000 --config_file=/media/data/48_BERT/12_Training_Output/01_Electra/BASE_100k_CC_1/config.json --pytorch_dump_path=/media/data/48_BERT/12_Training_Output/01_Electra/BASE_100k_CC_1/pytorch_model.bin --discriminator_or_generator='discriminator'
"""
from transformers import ElectraTokenizer, ElectraModel, ElectraForPreTraining, ElectraForMaskedLM, BertTokenizer, BertModel
import torch
def get_similarity(model, tokenizer, word, synonym, model_name):
    inputs = tokenizer(word, return_tensors="pt")["input_ids"]
    inputs2 = tokenizer(synonym, return_tensors="pt")["input_ids"]
    print(synonym, "<->", word, end="\t")
    cos = torch.nn.CosineSimilarity(dim=0, eps=1e-6)
    outputs = model(inputs)
    outputs2 = model(inputs2)
    if model_name == 'ELECTRA':
        # ElectraForMaskedLM returns the hidden states at tuple index 1.
        # Average hidden-state indices 8-11 (upper layers) and take the token
        # at position 1, i.e. the first word piece after [CLS].
        a = torch.mean(torch.stack([outputs[1][i][0][1] for i in range(8, 12)]), dim=0)
        a2 = torch.mean(torch.stack([outputs2[1][i][0][1] for i in range(8, 12)]), dim=0)
    if model_name == 'BERT':
        # BertModel returns the hidden states at tuple index 2;
        # same layer averaging and token selection as above.
        a = torch.mean(torch.stack([outputs[2][i][0][1] for i in range(8, 12)]), dim=0)
        a2 = torch.mean(torch.stack([outputs2[2][i][0][1] for i in range(8, 12)]), dim=0)
    similarity = cos(a, a2)
    print(float(similarity))
    print(len(inputs[0]))
if __name__ == '__main__':
    tokenizer_electra = ElectraTokenizer.from_pretrained('/media/data/48_BERT/12_Training_Output/01_Electra/BASE_100k_CC_1/')
    model_electra = ElectraForMaskedLM.from_pretrained('/media/data/48_BERT/12_Training_Output/01_Electra/BASE_100k_CC_1/')
    #model = ElectraForPreTraining.from_pretrained('/media/data/48_BERT/12_Training_Output/01_Electra/BASE_100k_CC_1/')#, output_hidden_states=True)
    #model = ElectraModel.from_pretrained('/media/data/48_BERT/12_Training_Output/01_Electra/BASE_100k_CC_1/')
    model_electra.config.output_hidden_states = True

    tokenizer_bert = BertTokenizer.from_pretrained('bert-base-german-cased')
    model_bert = BertModel.from_pretrained('bert-base-german-cased')
    model_bert.config.output_hidden_states = True

    #synonyms = ["Ruhrgebiet", "Dortmund"]
    #synonyms = ["Rhein", "Köln", "Dom"]
    synonyms = ["PKW", "Fahrzeug", "Kraftfahrzeug", "Autobahn", "selbst", "eigen"]
    #synonyms = ["Bergwerk", "Grube", "Mine", "Stollen", "Pütt", "Honorarnote", "Kostennote"]
    #word = "Zeche"
    word = "Auto"

    for synonym in synonyms:
        print('BERT', end="\t")
        get_similarity(model_bert, tokenizer_bert, word, synonym, "BERT")
        print('ELECTRA', end="\t")
        get_similarity(model_electra, tokenizer_electra, word, synonym, "ELECTRA")
This only works with words that are in the vocab file. If len(inputs[0]) is anything other than 3 (i.e. [CLS], one word piece, [SEP]), the word was split into multiple subword tokens, and the single-token comparison in this script is not meaningful.
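A minimal sketch of how one could filter such words out before comparing (not part of the original gist; is_single_token is a hypothetical helper, and the tokenizer call mirrors the one in get_similarity above):

def is_single_token(tokenizer, word):
    # Hypothetical helper: True if the word maps to exactly one word piece,
    # i.e. 3 ids in total ([CLS] + one word piece + [SEP]).
    ids = tokenizer(word, return_tensors="pt")["input_ids"][0]
    return len(ids) == 3

for synonym in synonyms:
    # Skip words that either tokenizer splits into multiple word pieces.
    if not (is_single_token(tokenizer_bert, synonym) and is_single_token(tokenizer_electra, synonym)):
        print(synonym, "is split into multiple word pieces, skipping")
        continue
    get_similarity(model_bert, tokenizer_bert, word, synonym, "BERT")
    get_similarity(model_electra, tokenizer_electra, word, synonym, "ELECTRA")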
