Last active
May 6, 2020 23:34
-
-
Save Shikib/ae258c9d6b4b611c0592ada7dc3b3bac to your computer and use it in GitHub Desktop.
FED metrics code
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import json | |
import numpy as np | |
import torch | |
from torch.utils.data import DataLoader, Dataset, SequentialSampler, RandomSampler | |
from torch.utils.data.distributed import DistributedSampler | |
import math | |
from transformers import (WEIGHTS_NAME, AdamW, WarmupLinearSchedule, | |
BertConfig, BertForMaskedLM, BertTokenizer, | |
GPT2Config, GPT2LMHeadModel, GPT2Tokenizer, | |
OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer, | |
RobertaConfig, RobertaForMaskedLM, RobertaTokenizer, | |
DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer) | |
# Build DialoGPT-small: take the GPT-2 architecture from the 'gpt2' checkpoint and
# overwrite its parameters with the fine-tuned state dict in dialogpt/small_fs.pkl.
tokenizer = GPT2Tokenizer.from_pretrained('dialogpt')
model = GPT2LMHeadModel.from_pretrained('gpt2')
# NOTE: torch.load unpickles arbitrary objects -- only point it at a trusted checkpoint.
# map_location="cpu" makes loading independent of the device the file was saved from;
# the model is moved to its final device below.
weights = torch.load("dialogpt/small_fs.pkl", map_location="cpu")
# Drop the leading "module." (presumably added by DataParallel/DDP when the checkpoint
# was saved). Only a leading prefix is stripped, never an occurrence inside a key.
weights = {
    (k[len("module."):] if k.startswith("module.") else k): v
    for k, v in weights.items()
}
# The checkpoint stores the output projection as "lm_head.decoder.weight", while
# GPT2LMHeadModel expects "lm_head.weight".
weights["lm_head.weight"] = weights["lm_head.decoder.weight"]
weights.pop("lm_head.decoder.weight", None)
model.load_state_dict(weights)
# Disable dropout explicitly so score() is deterministic.
model.eval()
model.to("cuda" if torch.cuda.is_available() else "cpu")
def score(text):
    """Return the DialoGPT language-model loss of *text*.

    If *text* does not already start with "<|endoftext|> " (the separator used
    between dialog turns in this script), that prefix is prepended. The input
    ids are fed to the module-level ``model`` as both input and labels, and the
    returned loss is converted to a Python float. Lower values mean the model
    finds the text more likely.

    Args:
        text: Raw text, optionally already starting with "<|endoftext|> ".

    Returns:
        The loss as a float.
    """
    if not text.startswith("<|endoftext|> "):
        text = "<|endoftext|> " + text
    # <|endoftext|> is token id 50256 in the GPT-2 vocabulary.
    token_ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))
    # Create the batch (size 1) on whatever device the model's weights live on.
    device = next(model.parameters()).device
    tensor_input = torch.tensor([token_ids], device=device)
    with torch.no_grad():
        outputs = model(tensor_input, labels=tensor_input)
    # When labels are supplied, the first output is the loss.
    return outputs[0].item()
# Follow-up utterances appended after a single response (turn-level metrics).
# For each metric the score is:
#   mean score(conv + negative follow-up) - mean score(conv + positive follow-up)
# An empty list contributes 0 to its side of the difference.
TURN_MESSAGES = {
    "interesting": {
        "positive": ["Wow that is really interesting.", "That's really interesting!", "Cool! That sounds super interesting."],
        "negative": ["That's not very interesting.", "That's really boring.", "That was a really boring response."]
    },
    "engaging": {
        "positive": ["Wow! That's really cool!", "Tell me more!", "I'm really interested in learning more about this."],
        "negative": ["Let's change the topic.", "I don't really care. That's pretty boring.", "I want to talk about something else."]
    },
    "specific": {
        "positive": ["That's good to know. Cool!", "I see, that's interesting.", "That's a good point."],
        "negative": ["That's a very generic response.", "Not really relevant here.", "That's not really relevant here."]
    },
    "relevant": {
        "positive": [],
        "negative": ["That's not even related to what I said.", "Don't change the topic!", "Why are you changing the topic?"]
    },
    "correct": {
        "positive": [],
        "negative": ["You're not understanding me!", "I am so confused right now!", "I don't understand what you're saying."]
    },
    "semantically appropriate": {
        "positive": ["That makes sense!", "You have a good point."],
        "negative": ["That makes no sense!"]
    },
    "understandable": {
        "positive": ["That makes sense!", "You have a good point."],
        "negative": ["I don't understand at all!", "I'm so confused!", "That makes no sense!", "What does that even mean?"]
    },
    "fluent": {
        "positive": ["That makes sense!", "You have a good point."],
        "negative": ["Is that real English?", "I'm so confused right now!", "That makes no sense!"]
    },
}

# Follow-up utterances appended after a whole dialog (dialog-level metrics);
# scored the same way as TURN_MESSAGES.
DIALOG_MESSAGES = {
    "coherent": {
        "positive": [],
        "negative": ["You're making no sense at all.", "You're changing the topic so much!", "You are so confusing."]
    },
    "error recovery": {
        "positive": [],
        "negative": ["I am so confused right now.", "You're really confusing.", "I don't understand what you're saying."]
    },
    "consistent": {
        "positive": [],
        "negative": ["That's not what you said earlier!", "Stop contradicting yourself!"],
    },
    "diverse": {
        "positive": [],
        "negative": ["Stop saying the same thing repeatedly.", "Why are you repeating yourself?", "Stop repeating yourself!"]
    },
    "depth": {
        "positive": [],
        "negative": ["Stop changing the topic so much.", "Don't change the topic!"],
    },
    "likeable": {
        "positive": ["I like you!", "You're super polite and fun to talk to", "Great talking to you."],
        "negative": ["You're not very nice.", "You're not very fun to talk to.", "I don't like you."]
    },
    "understand": {
        "positive": [],
        "negative": ["You're not understanding me!", "What are you trying to say?", "I don't understand what you're saying."]
    },
    "flexible": {
        "positive": ["You're very easy to talk to!", "Wow you can talk about a lot of things!"],
        "negative": ["I don't want to talk about that!", "Do you know how to talk about something else?"],
    },
    "informative": {
        "positive": ["Thanks for all the information!", "Wow that's a lot of information.", "You know a lot of facts!"],
        "negative": ["You're really boring.", "You don't really know much."],
    },
    "inquisitive": {
        "positive": ["You ask a lot of questions!", "That's a lot of questions!"],
        "negative": ["You don't ask many questions.", "You don't seem interested."],
    },
}


def _utterance(line):
    """Return the text after the first ": " in *line* (presumably "Speaker: text").

    Splitting at most once keeps utterances that themselves contain ": " intact.
    """
    return line.split(": ", 1)[1]


def _join_turns(turns):
    """Join stripped turns into one string, each preceded by "<|endoftext|> "."""
    return " ".join("<|endoftext|> " + turn.strip() for turn in turns)


def _read_conv_dir(path):
    """Parse every non-empty file in *path* whose name starts with "conv".

    Each file is read as JSON lines (assumed format -- confirm against the data):
    lines 0-2 each hold a "ctx" (newline-separated utterance lines) plus "rsps",
    whose first entry is the response; line 3 holds the full dialog in "ctx".
    Files are visited in sorted order so the result is deterministic.

    Returns:
        (conv_turns, conv_names, conv_fulls, conv_fullnames). Turn names are
        "<prefix>_<i>", where <prefix> is the file name up to the first "_".
        A turn name already produced by an earlier file is skipped. Dialog
        entries are appended for every file, so a dialog name may repeat.
    """
    conv_turns, conv_names = [], []
    conv_fulls, conv_fullnames = [], []
    seen_turn_names = set()
    for fn in sorted(os.listdir(path)):
        if not fn.startswith("conv"):
            continue
        # Read the file once and close it, instead of reopening it per line.
        with open(os.path.join(path, fn), encoding="utf-8") as f:
            lines = f.readlines()
        if not lines:
            continue
        prefix = fn.split("_")[0]
        for i in range(3):
            name = prefix + "_" + str(i)
            if name in seen_turn_names:
                continue
            data = json.loads(lines[i])
            # Skip blank context lines (e.g. a trailing "\n"), which have no ": ".
            past_turns = [_utterance(e) for e in data["ctx"].split("\n") if e]
            all_turns = past_turns + [_utterance(data["rsps"][0])]
            conv_turns.append(_join_turns(all_turns))
            conv_names.append(name)
            seen_turn_names.add(name)
        data = json.loads(lines[3])
        all_turns = [_utterance(e) for e in data["ctx"].split("\n") if e]
        conv_fulls.append(_join_turns(all_turns))
        conv_fullnames.append(prefix)
    return conv_turns, conv_names, conv_fulls, conv_fullnames


def _mean_followup_loss(conv, followups):
    """Average score() of *conv* continued by each follow-up; 0 if there are none."""
    total = sum(score(conv + " <|endoftext|> " + m) for m in followups)
    return total / max(len(followups), 1)


def _write_metric_scores(messages, metric_names, convs, names):
    """Score every conversation for each metric and write "<metric>_scores.json".

    Args:
        messages: Mapping of metric -> {"positive": [...], "negative": [...]}.
        metric_names: Iterable of keys of *messages* to compute.
        convs: Conversation strings to score.
        names: Names parallel to *convs*, used as the JSON keys.

    A higher score means the negative follow-ups have a higher loss (are less
    likely under the model) than the positive ones.
    """
    for metric in metric_names:
        utts = messages[metric]
        print("Processing", metric)
        scores = []
        for conv in convs:
            high_score = _mean_followup_loss(conv, utts["positive"])
            low_score = _mean_followup_loss(conv, utts["negative"])
            scores.append(low_score - high_score)
        assert len(names) == len(scores)
        # dict() keeps the last score when a name repeats.
        with open("{0}_scores.json".format(metric), "w+") as f:
            json.dump(dict(zip(names, scores)), f)


def load_convs(path="/home/shikib/dp_site/api/conv_data/", turn_metrics=(),
               dialog_metrics=(), debug=False):
    """Load the conversation files under *path* and optionally score FED metrics.

    Args:
        path: Directory containing the "conv*" JSON-lines files.
        turn_metrics: Iterable of TURN_MESSAGES keys to score on single turns.
            Pass TURN_MESSAGES itself to score all of them. The default () scores
            nothing, matching the previous hard-disabled behavior.
        dialog_metrics: Iterable of DIALOG_MESSAGES keys to score on whole dialogs.
            The default () scores nothing.
        debug: If true, drop into pdb between the turn- and dialog-level passes
            (replaces the previously unconditional breakpoints).

    Each scored metric writes "<metric>_scores.json" to the current directory.

    Returns:
        (conv_turns, conv_names, conv_fulls, conv_fullnames).
    """
    conv_turns, conv_names, conv_fulls, conv_fullnames = _read_conv_dir(path)
    _write_metric_scores(TURN_MESSAGES, turn_metrics, conv_turns, conv_names)
    if debug:
        import pdb
        pdb.set_trace()
    _write_metric_scores(DIALOG_MESSAGES, dialog_metrics, conv_fulls, conv_fullnames)
    return conv_turns, conv_names, conv_fulls, conv_fullnames
# Run the demo and the data pass only when executed as a script. Importing this
# module (e.g. to reuse score()) no longer scans the hard-coded data directory.
if __name__ == "__main__":
    # Presumably a sanity check that more plausible sentences get a lower loss.
    a = ['There is a book on the desk',
         'There is a plane on the desk',
         'There is a dog the desk',
         'There is a book in the desk']
    print([score(i) for i in a])
    load_convs()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment