Skip to content

Instantly share code, notes, and snippets.

@smrmkt
Created September 13, 2017 05:25
Show Gist options
  • Save smrmkt/eabb61031ab757bff766e998cd9fce3c to your computer and use it in GitHub Desktop.
Save smrmkt/eabb61031ab757bff766e998cd9fce3c to your computer and use it in GitHub Desktop.
seq2seq calc source data
import math
import csv
import random
# Seed the module-level RNG so every run draws the same sequence of examples.
random.seed(10)

# Each generated example combines `n_numbers` integers, none above `largest`.
n_numbers = 3
largest = 10

# Vocabulary: the ten digits, the two operators and the blank separator.
character_set = list('0123456789') + ['+', '*', ' ']
# Vocabulary size; also the width of a one-hot encoded character.
data_dim = len(character_set)

# Maximum sequence lengths.
# The input figure follows the original derivation n_numbers*2 + 2 (two digits
# per operand plus two operators).
# NOTE(review): that count ignores the spaces placed around operators in the
# generated strings -- confirm how the consumer tokenizes its input.
input_seq_length = 8
# Longest expected answer, in characters.
output_seq_length = n_numbers + 1
def get_sum_pairs(n_examples, n_operands=None, max_value=None):
    """Generate random arithmetic expressions and their results.

    Despite the name, each example is either a sum or a product: one operator
    ('+' or '*') is picked at random per example and applied to all operands.

    Args:
        n_examples: Number of (expression, result) pairs to generate.
        n_operands: How many integers each expression combines. Defaults to
            the module-level ``n_numbers``.
        max_value: Inclusive upper bound for each operand (the lower bound is
            1). Defaults to the module-level ``largest``.

    Returns:
        A tuple ``(inputs, labels)`` of two equally long lists of strings:
        ``inputs[i]`` is an expression such as ``"3 + 10 + 7"`` and
        ``labels[i]`` is its result rendered as a decimal string.

    Side effects:
        Draws from the module-level ``random`` generator and prints the sizes
        of the two lists to stdout.
    """
    if n_operands is None:
        n_operands = n_numbers
    if max_value is None:
        max_value = largest

    inputs, labels = [], []
    for _ in range(n_examples):
        # Draw the operands first and the operator second: this order fixes
        # the sequence of RNG calls and therefore the data for a given seed.
        operands = [random.randint(1, max_value) for _ in range(n_operands)]
        op = random.choice(['+', '*'])
        if op == '+':
            result = sum(operands)
        else:
            result = 1
            for value in operands:
                result *= value
        inputs.append((" %s " % op).join(str(value) for value in operands))
        labels.append(str(result))

    # A single formatted string prints identically on Python 2 and Python 3.
    print("%d %d" % (len(inputs), len(labels)))
    return inputs, labels
def write_to_file(lines, path):
    """Append each entry of ``lines`` to the file at ``path``, one per line.

    Args:
        lines: Iterable of entries. Each entry is passed through ``"".join``,
            so it may be a string (written unchanged) or a sequence of string
            fragments (concatenated without a separator).
        path: Destination file. It is created if missing; existing content is
            kept and new lines are added after it, so repeated calls with the
            same path accumulate output.
    """
    # 'a' only: the previous 'ar' is not a valid mode string and raises
    # ValueError on Python 3.
    with open(path, 'a') as f:
        f.writelines("".join(line) + '\n' for line in lines)
# Build the training split first, then the development split; both draw from
# the same RNG stream, so this order determines their contents.
trainX, trainY = get_sum_pairs(10000)
devX, devY = get_sum_pairs(3000)

# Write expressions to *.source and their results to *.target.
for _lines, _path in (
    (trainX, 'train.source'),
    (trainY, 'train.target'),
    (devX, 'dev.source'),
    (devY, 'dev.target'),
):
    write_to_file(_lines, _path)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment