Created December 29, 2016 04:50
-
-
Save vgoklani/e33973d3202639e1f021bd745b8971c2 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*-
from __future__ import print_function

import io
import os
import random
import sys
import traceback
# Fixed seed so the shuffled train/test split in generate_output is
# reproducible from run to run.
random.seed(42)

# Folder holding the raw corpus files (movie_lines.txt, movie_conversations.txt).
input_directory = "data/cornell_movie_dialogs_corpus"
# Folder receiving the train/test encoder/decoder text files.
output_directory = "data/cornell_movie_dialogs_corpus/processed"

# Create the output folder at import time if it does not exist yet.
if not os.path.exists(output_directory):
    os.makedirs(output_directory)
def process_movie_lines(filename='movie_lines.txt', directory=None,
                        encoding='iso-8859-1'):
    """Build a mapping from line ID to utterance text.

    Each record has the form
    ``lineID +++$+++ characterID +++$+++ movieID +++$+++ name +++$+++ text``.
    Only the line ID and the text are kept. Blank lines are skipped.

    Args:
        filename: Name of the movie-lines file inside ``directory``.
        directory: Folder containing ``filename``. Defaults to the
            module-level ``input_directory``.
        encoding: Codec used to decode the file. The corpus is commonly
            read as ISO-8859-1. That codec accepts every byte value, so
            decoding cannot fail on stray non-UTF-8 bytes.

    Returns:
        dict mapping each line-ID string to its whitespace-stripped text.
        A record with an empty text field maps to ``""``.

    Raises:
        ValueError: if a non-blank line has fewer than four fields.
    """
    if directory is None:
        directory = input_directory

    line_id_text_mapping = {}
    path = os.path.join(directory, filename)
    # Iterate the file lazily instead of loading every line with readlines().
    with io.open(path, encoding=encoding) as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                continue
            # maxsplit=4 keeps the text intact even if it happens to contain
            # the separator itself.
            tokens = line.split(' +++$+++ ', 4)
            if len(tokens) < 4:
                raise ValueError("malformed movie line: %r" % (line,))
            # When the text field is empty, strip() also eats the space after
            # the last separator. The split then yields only four fields, the
            # last being "name +++$+++".
            text = tokens[4] if len(tokens) == 5 else ""
            line_id_text_mapping[tokens[0]] = text
    return line_id_text_mapping
def process_conversations(filename='movie_conversations.txt', directory=None,
                          encoding='iso-8859-1'):
    """Read the conversation index as ordered lists of line IDs.

    Each record has the form
    ``charID1 +++$+++ charID2 +++$+++ movieID +++$+++ ['L194', 'L195', ...]``.
    The bracketed list gives the conversation's line IDs in order. Blank
    lines are skipped.

    Args:
        filename: Name of the conversations file inside ``directory``.
        directory: Folder containing ``filename``. Defaults to the
            module-level ``input_directory``.
        encoding: Codec used to decode the file (same default as
            ``process_movie_lines``).

    Returns:
        list of conversations, each a list of line-ID strings. An empty
        bracket list yields an empty conversation rather than ``[""]``.

    Raises:
        ValueError: if a non-blank line has fewer than four fields.
    """
    if directory is None:
        directory = input_directory

    conversations = []
    path = os.path.join(directory, filename)
    with io.open(path, encoding=encoding) as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                continue
            fields = line.split(' +++$+++ ', 3)
            if len(fields) != 4:
                raise ValueError("malformed conversation line: %r" % (line,))
            # Drop the surrounding brackets and the quotes around each ID.
            id_list = fields[3].strip()[1:-1].replace("'", "")
            utterances = [u.strip() for u in id_list.split(",") if u.strip()]
            conversations.append(utterances)
    return conversations
def _encode_line(text):
    """Return *text* as UTF-8 bytes terminated by a newline.

    Values that are already byte strings (e.g. Python 2 ``str``) are
    written unchanged, so callers may supply either text or bytes.
    """
    if not isinstance(text, bytes):
        text = text.encode('utf-8')
    return text + b"\n"


def generate_output(line_id_text_mapping, conversations, split_ratio=0.8,
                    directory=None):
    """Write shuffled question/answer pairs as train/test text files.

    Every two consecutive utterances in a conversation form one
    (question, answer) pair. The pairs are shuffled with the module-level
    ``random`` state. The first ``int(len(pairs) * split_ratio)`` pairs form
    the training set and the rest the test set. Four files are written, one
    utterance per line:

    * ``train_encoder.txt`` / ``train_decoder.txt``: training questions /
      answers.
    * ``test_encoder.txt`` / ``test_decoder.txt``: test questions / answers.

    Line ``i`` of an encoder file pairs with line ``i`` of the matching
    decoder file.

    Args:
        line_id_text_mapping: dict from line ID to utterance text (or bytes).
        conversations: iterable of line-ID lists, in conversation order.
        split_ratio: fraction of pairs that go to the training files.
        directory: output folder. Defaults to the module-level
            ``output_directory``.

    A pair referencing an ID missing from ``line_id_text_mapping`` is
    reported on stderr and skipped.
    """
    if directory is None:
        directory = output_directory

    pairs = []
    for conversation in conversations:
        # Zipping with the tail yields each (utterance, next utterance) pair.
        for question_id, answer_id in zip(conversation, conversation[1:]):
            try:
                question = line_id_text_mapping[question_id]
                answer = line_id_text_mapping[answer_id]
            except KeyError as e:
                sys.stderr.write("\nerror -> missing line id %s" % (e,))
                continue
            pairs.append((question, answer))

    random.shuffle(pairs)
    cutoff = int(len(pairs) * split_ratio)
    splits = (("train", pairs[:cutoff]), ("test", pairs[cutoff:]))

    # Write order matches the original script:
    # train_encoder, train_decoder, test_encoder, test_decoder.
    for split_name, split_pairs in splits:
        for side, column in (("encoder", 0), ("decoder", 1)):
            path = os.path.join(directory, "%s_%s.txt" % (split_name, side))
            with open(path, 'wb') as f:
                for pair in split_pairs:
                    f.write(_encode_line(pair[column]))
def main():
    """Run the whole pipeline with the default input and output folders.

    The movie lines are parsed first, then the conversation index. The
    resulting pairs are written out by generate_output.
    """
    # Arguments are evaluated left to right, preserving the original call order.
    generate_output(process_movie_lines(), process_conversations())
if __name__ == '__main__':
    # __IPYTHON__ is only defined inside an IPython session. When it is
    # present, skip the pipeline so the functions can be used interactively.
    try:
        __IPYTHON__
    except NameError:
        main()
    else:
        print('\nrunning via ipython -> not running continuously')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.