Skip to content

Instantly share code, notes, and snippets.

@ozancaglayan
Created February 21, 2016 15:25
Show Gist options
  • Save ozancaglayan/a0da46bdb375b284a671 to your computer and use it in GitHub Desktop.
Save ozancaglayan/a0da46bdb375b284a671 to your computer and use it in GitHub Desktop.
local argparse = require "argparse"
local moses = require "moses"
-- Command-line interface: one or more input text files, an output directory
-- for the generated vocabularies, and a minimum word frequency.
local parser = argparse("build_dictionary", "example")
parser:argument("input", "Input text file"):args('+')
parser:option('-o --output', 'Output directory', '.')
parser:option('-m --minfreq', 'Filter out words occurring < m times.', 0)
local args = parser:parse()

-- Option values arrive as strings; convert once and reject non-numeric input
-- here rather than failing later inside the frequency comparison.
args.minfreq = tonumber(args.minfreq)
if args.minfreq == nil then
  parser:error("argument for '-m --minfreq' must be a number")
end
--- Count whitespace-separated tokens in a text file.
-- Tokens are maximal runs of non-space characters, so repeated spaces, tabs
-- and a trailing '\r' never produce empty or polluted tokens.
-- @tparam string path file to read
-- @treturn table map token -> number of occurrences
local function count_words(path)
  local word_freqs = {}
  for line in io.lines(path) do
    for word in line:gmatch('%S+') do
      word_freqs[word] = (word_freqs[word] or 0) + 1
    end
  end
  return word_freqs
end

--- Build the word -> index dictionary from token counts.
-- Indices 0 and 1 are reserved for '<eos>' and '<unk>'. Corpus words with
-- count >= minfreq get consecutive indices starting at 2, most frequent
-- first; equal counts are ordered alphabetically so the result is
-- deterministic for a given corpus.
-- @tparam table word_freqs map token -> count
-- @tparam number minfreq minimum count for a word to be kept
-- @treturn table map word -> index
-- @treturn number number of entries, including the reserved tokens
local function build_dict(word_freqs, minfreq)
  local dict = {['<eos>'] = 0, ['<unk>'] = 1}
  local n_words = 2

  local kept = {}
  for word, freq in pairs(word_freqs) do
    -- Corpora such as PTB already contain '<unk>'; it must keep its reserved
    -- index instead of being re-assigned and counted twice.
    if dict[word] == nil and freq >= minfreq then
      kept[#kept + 1] = word
    end
  end

  table.sort(kept, function(a, b)
    local fa, fb = word_freqs[a], word_freqs[b]
    if fa ~= fb then
      return fa > fb
    end
    return a < b
  end)

  for _, word in ipairs(kept) do
    dict[word] = n_words
    n_words = n_words + 1
  end
  return dict, n_words
end

for i = 1, #args.input do
  local path = args.input[i]
  print("Processing " .. path)

  local word_freqs = count_words(path)
  -- A tokenized corpus practically always contains standalone '.' and ','.
  if word_freqs['.'] == nil or word_freqs[','] == nil then
    print('Warning: Check that the input data is tokenized!')
  end

  local dict, n_words = build_dict(word_freqs, args.minfreq)

  -- Name the output after the input's basename so it lands in args.output.
  local fname = args.output .. '/vocab-' .. path:match('[^/]*$') .. '.th7'
  print('Dumping vocabulary file (' .. n_words .. ' tokens) into ' .. fname)
  torch.save(fname, dict)
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment