Skip to content

Instantly share code, notes, and snippets.

@Shikib
Last active May 6, 2020 23:34
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Shikib/ae258c9d6b4b611c0592ada7dc3b3bac to your computer and use it in GitHub Desktop.
Save Shikib/ae258c9d6b4b611c0592ada7dc3b3bac to your computer and use it in GitHub Desktop.
FED metrics code
import os
import json
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset, SequentialSampler, RandomSampler
from torch.utils.data.distributed import DistributedSampler
import math
from transformers import (WEIGHTS_NAME, AdamW, WarmupLinearSchedule,
BertConfig, BertForMaskedLM, BertTokenizer,
GPT2Config, GPT2LMHeadModel, GPT2Tokenizer,
OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer,
RobertaConfig, RobertaForMaskedLM, RobertaTokenizer,
DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer)
# Build the scoring model: the tokenizer comes from the local 'dialogpt'
# directory, the architecture from the stock 'gpt2' config, and the actual
# parameters from the fine-tuned checkpoint dialogpt/small_fs.pkl.
tokenizer = GPT2Tokenizer.from_pretrained('dialogpt')
model = GPT2LMHeadModel.from_pretrained('gpt2')

# NOTE: torch.load unpickles the file -- only point this at trusted checkpoints.
weights = torch.load("dialogpt/small_fs.pkl")
# Drop every "module." fragment from the parameter names (presumably left by a
# DataParallel-style wrapper when the checkpoint was saved -- confirm).
weights = dict(
    (param_name.replace("module.", ""), param)
    for param_name, param in weights.items()
)
# The checkpoint stores the output projection as "lm_head.decoder.weight";
# rename it to "lm_head.weight" so load_state_dict accepts it. A missing key
# raises KeyError here.
weights["lm_head.weight"] = weights.pop("lm_head.decoder.weight")

model.load_state_dict(weights)
model.to("cuda")
def score(text):
    """Return the language-model loss of *text* under the module-level ``model``.

    If *text* does not already start with the ``"<|endoftext|> "`` separator
    (token id 50256 in the GPT-2 vocabulary), one is prepended. The text is
    tokenized with the module-level ``tokenizer``, and the model is run with
    ``labels`` equal to the inputs. The loss the model reports is returned.
    Presumably this is the mean token-level cross-entropy over the whole
    sequence, where lower means more likely; confirm against the installed
    transformers version.

    Args:
        text: Conversation text, with turns separated by ``"<|endoftext|> "``
            markers.

    Returns:
        The loss as a Python float.
    """
    if not text.startswith("<|endoftext|> "):
        text = "<|endoftext|> " + text
    token_ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))
    # Follow the model's device instead of hard-coding CUDA, so the same code
    # also works when the model lives on the CPU.
    device = next(model.parameters()).device
    tensor_input = torch.tensor([token_ids], device=device)
    with torch.no_grad():
        outputs = model(tensor_input, labels=tensor_input)
    # The first output is the loss; the logits are not needed here.
    loss = outputs[0]
    return loss.item()
def load_convs():
path = "/home/shikib/dp_site/api/conv_data/"
conv_turns = []
conv_names = []
conv_fulls = []
conv_fullnames = []
for fn in os.listdir(path):
if not fn.startswith("conv") or len(open(path+fn).readlines()) == 0:
continue
for i in range(3):
if fn.split("_")[0] + "_" + str(i) in conv_names:
continue
data = json.loads(open(path + fn).readlines()[i])
past_turns = [e.split(": ")[1] for e in data["ctx"].split("\n")]
all_turns = past_turns + [ data["rsps"][0].split(": ")[1] ]
#all_turns = all_turns[-1:]
full_conv = " ".join(["<|endoftext|> " + turn.strip() for turn in all_turns])
conv_turns.append(full_conv)
conv_names.append(fn.split("_")[0] + "_" + str(i))
data = json.loads(open(path + fn).readlines()[3])
all_turns = [e.split(": ")[1] for e in data["ctx"].split("\n") if e]
full_conv = " ".join(["<|endoftext|> " + turn.strip() for turn in all_turns])
conv_fulls.append(full_conv)
conv_fullnames.append(fn.split("_")[0])
messages = {
"interesting": {
"positive": ["Wow that is really interesting.", "That's really interesting!", "Cool! That sounds super interesting."],
"negative": ["That's not very interesting.", "That's really boring.", "That was a really boring response."]
},
"engaging": {
"positive": ["Wow! That's really cool!", "Tell me more!", "I'm really interested in learning more about this."],
"negative": ["Let's change the topic.", "I don't really care. That's pretty boring.", "I want to talk about something else."]
},
"specific": {
"positive": ["That's good to know. Cool!", "I see, that's interesting.", "That's a good point."],
"negative": ["That's a very generic response.", "Not really relevant here.", "That's not really relevant here."]
},
"relevant": {
"positive": [],
"negative": ["That's not even related to what I said.", "Don't change the topic!", "Why are you changing the topic?"]
},
"correct": {
"positive": [],
"negative": ["You're not understanding me!", "I am so confused right now!", "I don't understand what you're saying."]
},
"semantically appropriate": {
"positive": ["That makes sense!", "You have a good point."],
"negative": ["That makes no sense!"]
},
"understandable": {
"positive": ["That makes sense!", "You have a good point."],
"negative": ["I don't understand at all!", "I'm so confused!", "That makes no sense!", "What does that even mean?"]
},
"fluent": {
"positive": ["That makes sense!", "You have a good point."],
"negative": ["Is that real English?", "I'm so confused right now!", "That makes no sense!"]
},
}
metrics = messages.keys()
metrics = []
for metric in metrics:
scores = []
utts = messages[metric]
print("Processing", metric)
pos = utts["positive"]
neg = utts["negative"]
for conv_turn in conv_turns:
orig_score = score(conv_turn + " <|endoftext|>")
high_score = 0
for m in pos:
hs = score(conv_turn + " <|endoftext|> " + m)
high_score += hs #- orig_score
high_score = high_score/max(len(pos), 1)
low_score = 0
for m in neg:
ls = score(conv_turn + " <|endoftext|> " + m)
low_score += ls #- orig_score
low_score = low_score/max(len(neg), 1)
scores.append(low_score - high_score)
#scores.append(-score(conv_turn))
score_map = {name:score for name,score in zip(conv_names, scores)}
assert len(conv_names) == len(scores)
json.dump(score_map, open("{0}_scores.json".format(metric), "w+"))
import pdb; pdb.set_trace()
messages = {
"coherent": {
"positive": [],
"negative": ["You're making no sense at all.", "You're changing the topic so much!", "You are so confusing."]
},
"error recovery": {
"positive": [],
"negative": ["I am so confused right now.", "You're really confusing.", "I don't understand what you're saying."]
},
"consistent": {
"positive": [],
"negative": ["That's not what you said earlier!", "Stop contradicting yourself!"],
},
"diverse": {
"positive": [],
"negative": ["Stop saying the same thing repeatedly.", "Why are you repeating yourself?", "Stop repeating yourself!"]
},
"depth": {
"positive": [],
"negative": ["Stop changing the topic so much.", "Don't change the topic!"],
},
"likeable": {
"positive": ["I like you!", "You're super polite and fun to talk to", "Great talking to you."],
"negative": ["You're not very nice.", "You're not very fun to talk to.", "I don't like you."]
},
"understand": {
"positive": [],
"negative": ["You're not understanding me!", "What are you trying to say?", "I don't understand what you're saying."]
},
"flexible": {
"positive": ["You're very easy to talk to!", "Wow you can talk about a lot of things!"],
"negative": ["I don't want to talk about that!", "Do you know how to talk about something else?"],
},
"informative": {
"positive": ["Thanks for all the information!", "Wow that's a lot of information.", "You know a lot of facts!"],
"negative": ["You're really boring.", "You don't really know much."],
},
"inquisitive": {
"positive": ["You ask a lot of questions!", "That's a lot of questions!"],
"negative": ["You don't ask many questions.", "You don't seem interested."],
},
}
import pdb; pdb.set_trace()
metrics = messages.keys()
metrics = []
for metric in metrics:
scores = []
utts = messages[metric]
print("Processing", metric)
pos = utts["positive"]
neg = utts["negative"]
for conv_turn in conv_fulls:
orig_score = score(conv_turn + " <|endoftext|>")
high_score = 0
for m in pos:
hs = score(conv_turn + " <|endoftext|> " + m)
high_score += hs
high_score = high_score/max(len(pos), 1)
low_score = 0
for m in neg:
ls = score(conv_turn + " <|endoftext|> " + m)
low_score += ls
low_score = low_score/max(len(neg), 1)
scores.append(low_score - high_score)
score_map = {name:score for name,score in zip(conv_fullnames, scores)}
json.dump(score_map, open("{0}_scores.json".format(metric), "w+"))
# Quick sanity check: print the loss of a reference sentence and a few slightly
# altered variants of it, then run load_convs().
a = [
    'There is a book on the desk',
    'There is a plane on the desk',
    'There is a dog the desk',
    'There is a book in the desk',
]
print(list(map(score, a)))
load_convs()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment