Skip to content

Instantly share code, notes, and snippets.

@vgoklani
Created December 29, 2016 04:50
Show Gist options
  • Save vgoklani/e33973d3202639e1f021bd745b8971c2 to your computer and use it in GitHub Desktop.
Save vgoklani/e33973d3202639e1f021bd745b8971c2 to your computer and use it in GitHub Desktop.
# -*- coding: utf-8 -*-
from __future__ import print_function
import os, sys, traceback, random
# Seed the global RNG so the shuffle done before the train/test split is
# repeatable from run to run.
random.seed(42)

# Where the raw corpus files are read from and where the generated files go.
input_directory = "data/cornell_movie_dialogs_corpus"
output_directory = "data/cornell_movie_dialogs_corpus/processed"

# Make sure the output directory is there before anything tries to write to it.
if not os.path.exists(output_directory):
    os.makedirs(output_directory)
def process_movie_lines(filename='movie_lines.txt', input_dir=None):
    """Build a mapping from line id to utterance text.

    Every non-blank line of the file is expected to hold five fields joined
    by ``' +++$+++ '``: line id, character id, movie id, character name and
    the utterance text.  Lines are stripped before they are split, so a
    record whose text is empty loses the space that ends its last separator
    and therefore yields only four fields; such a record is mapped to ``""``.

    Args:
        filename: Name of the movie-lines file inside the input directory.
        input_dir: Directory that holds the file.  When ``None`` (the
            default) the module-level ``input_directory`` is used.

    Returns:
        A dict mapping each line id to its text.

    Raises:
        ValueError: If a non-blank line has fewer than four fields.
    """
    separator = ' +++$+++ '
    base_dir = input_directory if input_dir is None else input_dir

    line_id_text_mapping = {}
    with open(os.path.join(base_dir, filename)) as f:
        # Iterate the file directly instead of materialising every line first.
        for line_number, raw_line in enumerate(f, 1):
            line = raw_line.strip()
            if not line:
                # A blank line (e.g. at the end of the file) carries no record.
                continue

            # maxsplit=4 keeps a separator that happens to occur inside the
            # utterance as part of the text instead of breaking the record.
            fields = line.split(separator, 4)
            if len(fields) < 4:
                raise ValueError(
                    "%s:%d: expected at least 4 fields separated by %r, got %d"
                    % (filename, line_number, separator, len(fields))
                )

            text = fields[4] if len(fields) == 5 else ""
            line_id_text_mapping[fields[0]] = text
    return line_id_text_mapping
def process_conversations(filename='movie_conversations.txt', input_dir=None):
    """Read the conversations file into lists of line ids.

    Every non-blank line is expected to hold four fields joined by
    ``' +++$+++ '``: first character id, second character id, movie id and
    the utterance list.  The utterance list is parsed by dropping its first
    and last character (presumably the surrounding brackets), removing
    single quotes and splitting on commas.

    Args:
        filename: Name of the conversations file inside the input directory.
        input_dir: Directory that holds the file.  When ``None`` (the
            default) the module-level ``input_directory`` is used.

    Returns:
        A list with one entry per conversation; each entry is the list of
        line ids of that conversation, in file order.

    Raises:
        ValueError: If a non-blank line does not have exactly four fields.
    """
    separator = ' +++$+++ '
    base_dir = input_directory if input_dir is None else input_dir

    conversations = []
    with open(os.path.join(base_dir, filename)) as f:
        for line_number, raw_line in enumerate(f, 1):
            line = raw_line.strip()
            if not line:
                # A blank line (e.g. at the end of the file) carries no record.
                continue

            fields = line.split(separator)
            if len(fields) != 4:
                raise ValueError(
                    "%s:%d: expected 4 fields separated by %r, got %d"
                    % (filename, line_number, separator, len(fields))
                )

            # Only the last field is needed; the ids and movie are ignored.
            id_list = fields[3][1:-1].replace("'", "").strip()
            # Drop empty pieces so an empty list does not become [''].
            utterances = [
                piece.strip() for piece in id_list.split(",") if piece.strip()
            ]
            conversations.append(utterances)
    return conversations
def _write_lines(path, lines):
    """Write every byte string in *lines* to *path*, one per line."""
    with open(path, 'wb') as f:
        # The file is binary, so the newline has to be bytes as well.
        f.writelines(line + b"\n" for line in lines)


def generate_output(line_id_text_mapping, conversations, split_ratio=0.8,
                    output_dir=None):
    """Write shuffled question/answer pairs as train and test files.

    Each pair of consecutive line ids in a conversation becomes one
    (question, answer) pair, with the texts looked up in
    ``line_id_text_mapping`` and encoded as UTF-8.  The pairs are shuffled
    with the global ``random`` state, split at ``split_ratio`` and written
    to ``train_encoder.txt``, ``train_decoder.txt``, ``test_encoder.txt``
    and ``test_decoder.txt``.  Line *i* of an encoder file is the question
    whose answer is line *i* of the matching decoder file.

    Args:
        line_id_text_mapping: dict mapping line id to utterance text.
        conversations: Iterable of lists of line ids.
        split_ratio: Fraction of the pairs that goes into the train files.
        output_dir: Directory to write into; created if missing.  When
            ``None`` (the default) the module-level ``output_directory``
            is used.

    A pair whose line id is missing from the mapping, or whose text cannot
    be encoded, is reported on stderr and skipped.
    """
    target_dir = output_directory if output_dir is None else output_dir
    if not os.path.isdir(target_dir):
        os.makedirs(target_dir)

    pairs = []
    for conversation in conversations:
        # Pair every utterance with the one that follows it.
        for question_id, answer_id in zip(conversation, conversation[1:]):
            try:
                question = line_id_text_mapping[question_id].encode('utf-8')
                answer = line_id_text_mapping[answer_id].encode('utf-8')
            except (KeyError, UnicodeError) as e:
                # Best effort: log the problem and move on to the next pair.
                sys.stderr.write("\nerror -> %s" % e)
                sys.stderr.write('\n\t' + traceback.format_exc())
                continue
            pairs.append((question, answer))

    random.shuffle(pairs)
    cutoff = int(len(pairs) * split_ratio)
    train, test = pairs[:cutoff], pairs[cutoff:]

    outputs = (
        ("train_encoder.txt", [question for question, _ in train]),
        ("train_decoder.txt", [answer for _, answer in train]),
        ("test_encoder.txt", [question for question, _ in test]),
        ("test_decoder.txt", [answer for _, answer in test]),
    )
    for name, lines in outputs:
        _write_lines(os.path.join(target_dir, name), lines)
def main():
    """Run the whole pipeline with every default setting.

    Parses the movie lines and the conversations, then hands both to
    generate_output to write the train/test files.
    """
    text_by_line_id = process_movie_lines()
    dialogues = process_conversations()
    generate_output(text_by_line_id, dialogues)
if __name__ == '__main__':
    # Probe for the __IPYTHON__ name: looking it up raises NameError when it
    # is not defined, which this script treats as "not running under IPython".
    try:
        __IPYTHON__
    except NameError:
        running_in_ipython = False
    else:
        running_in_ipython = True

    # main() is called outside the except clause on purpose, so that an error
    # raised by main() is not reported as chained to the probe's NameError.
    if running_in_ipython:
        print('\nrunning via ipython -> not running continously')
    else:
        main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment