Bug #357 in Element-Research/rnn
-- Simple Sentence Similarity Example with Siamese Encoder
-- Author: Joost van Doorn (joostvandoorn.com)
require 'nn'
require 'rnn'
require 'io'
local pl = require 'pl'
local dl = require 'dataload'
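-- nn.Module.flatten is redefined below. It is a copy of the stock flatten from
-- torch/nn with one addition: a debug print reporting which parameter tensor has
-- a nil storage, to help trace where getParameters() goes wrong
-- (see Element-Research/rnn issue #357).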
Module = nn.Module
function Module.flatten(parameters)
   -- returns true if tensor occupies a contiguous region of memory (no holes)
   local function isCompact(tensor)
      local sortedStride, perm = torch.sort(
         torch.LongTensor(tensor:nDimension()):set(tensor:stride()), 1, true)
      local sortedSize = torch.LongTensor(tensor:nDimension()):set(
         tensor:size()):index(1, perm)
      local nRealDim = torch.clamp(sortedStride, 0, 1):sum()
      sortedStride = sortedStride:narrow(1, 1, nRealDim):clone()
      sortedSize = sortedSize:narrow(1, 1, nRealDim):clone()
      local t = tensor.new():set(tensor:storage(), 1,
                                 sortedSize:storage(),
                                 sortedStride:storage())
      return t:isContiguous()
   end

   if not parameters or #parameters == 0 then
      return torch.Tensor()
   end
   local Tensor = parameters[1].new
   local TmpTensor = Module._flattenTensorBuffer[torch.type(parameters[1])] or Tensor

   -- 1. construct the set of all unique storages referenced by parameter tensors
   local storages = {}
   local nParameters = 0
   local parameterMeta = {}
   for k = 1,#parameters do
      local param = parameters[k]
      local storage = parameters[k]:storage()
      if storage == nil then
         print("storage is nil in param:")
         print(k)
      end
      local storageKey = torch.pointer(storage)
      if not storages[storageKey] then
         storages[storageKey] = {storage, nParameters}
         nParameters = nParameters + storage:size()
      end
      parameterMeta[k] = {storageOffset = param:storageOffset() +
                                          storages[storageKey][2],
                          size = param:size(),
                          stride = param:stride()}
   end

   -- 2. construct a single tensor that will hold all the parameters
   local flatParameters = TmpTensor(nParameters):zero()

   -- 3. determine if there are elements in the storage that none of the
   --    parameter tensors reference ('holes')
   local tensorsCompact = true
   for k = 1,#parameters do
      local meta = parameterMeta[k]
      local tmp = TmpTensor():set(
         flatParameters:storage(), meta.storageOffset, meta.size, meta.stride)
      tmp:fill(1)
      tensorsCompact = tensorsCompact and isCompact(tmp)
   end

   local maskParameters = flatParameters:byte():clone()
   local compactOffsets = flatParameters:long():cumsum(1)
   local nUsedParameters = compactOffsets[-1]

   -- 4. copy storages into the flattened parameter tensor
   for _, storageAndOffset in pairs(storages) do
      local storage, offset = table.unpack(storageAndOffset)
      flatParameters[{{offset+1,offset+storage:size()}}]:copy(Tensor():set(storage))
   end

   -- 5. allow garbage collection
   storages = nil
   for k = 1,#parameters do
      parameters[k]:set(Tensor())
   end

   -- 6. compact the flattened parameters if there were holes
   if nUsedParameters ~= nParameters then
      assert(tensorsCompact,
             "Cannot gather tensors that are not compact")
      flatParameters = TmpTensor(nUsedParameters):copy(
         flatParameters:maskedSelect(maskParameters))
      for k = 1,#parameters do
         parameterMeta[k].storageOffset =
            compactOffsets[parameterMeta[k].storageOffset]
      end
   end

   if TmpTensor ~= Tensor then
      flatParameters = Tensor(flatParameters:nElement()):copy(flatParameters)
   end

   -- 7. fix up the parameter tensors to point at the flattened parameters
   for k = 1,#parameters do
      parameters[k]:set(flatParameters:storage(),
                        parameterMeta[k].storageOffset,
                        parameterMeta[k].size,
                        parameterMeta[k].stride)
   end
   return flatParameters
end
-- Options
local opt = {}
opt.hiddenSize = 14
opt.iterations = 1
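-- A single training iteration is used; this script is presumably a minimal bug
-- reproduction rather than a real training run.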
-- Load data
dl.downloadfile('.', 'http://alt.qcri.org/semeval2015/task1/data/uploads/tp_trial_data.txt', 'tp_trial_data.txt')
local tweets = {}
local tokens = {}
for line in io.lines('./tp_trial_data.txt') do
   tweets[#tweets+1] = line
   for w in line:gmatch("%S+") do tokens[#tokens+1] = w end
end
local vocab, ivocab, wordfreq = dl.buildVocab(tokens, 1)
local x = {}
local y = {}
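-- Each usable line is split on '|||': fields 2 and 3 hold the two tweets and
-- field 4 the "True"/"False" paraphrase label; the first line is skipped
-- (presumably a header).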
for i, tweet in pairs(tweets) do
   local d = stringx.split(tweet, "|||")
   if #d >= 4 and i > 1 then
      local j = #x+1
      x[j] = {}
      x[j][1] = dl.text2tensor(stringx.split(d[2]), vocab):view(-1, 1) -- Tweet 1
      x[j][2] = dl.text2tensor(stringx.split(d[3]), vocab):view(-1, 1) -- Tweet 2
      y[j] = torch.Tensor{stringx.strip(d[4]) == "True" and 1 or 0}:view(1, 1)
   end
end
-- Encoder/Embedding: a LookupTable embeds each token, a Sequencer-wrapped LSTM
-- processes the sequence, and Select(1, -1) keeps the last time step as the
-- sentence representation.
local encoder = nn.Sequential()
   :add(nn.LookupTable(#ivocab, opt.hiddenSize))
   :add(nn.Sequencer(nn.LSTM(opt.hiddenSize, opt.hiddenSize)))
   :add(nn.Select(1, -1))
-- local encoder = torch.load('test.model')
encoder = encoder:clone()
-- The siamese encoder
local siameseEncoder = nn.ParallelTable()
siameseEncoder:add(encoder)
-- Clone the encoder and share its weight and bias; gradWeight and gradBias must be shared as well.
siameseEncoder:add(encoder:clone('weight','bias', 'gradWeight','gradBias'))
parameters, gradParameters = siameseEncoder:getParameters()
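-- getParameters() flattens all parameters of the ParallelTable, including the
-- shared clone; if any parameter tensor has lost its storage, the patched
-- flatten above prints which one (and the call will likely error shortly after).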
-- The full similarity model
model = nn.Sequential()
model:add(siameseEncoder)
model:add(nn.CSubTable()) -- elementwise difference of the two sentence encodings
model:add(nn.Linear(opt.hiddenSize, 1))
local criterion = nn.MSECriterion()
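-- MSE against a 0/1 target treats paraphrase detection as regression; at test
-- time the output is rounded to recover the binary label (see the accuracy loop below).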
-- Train
local n = #x
local lr = 0.01
for iter=1,opt.iterations do
   local idx = iter % n + 1
   model:zeroGradParameters()
   local outputs = model:forward(x[idx])
   local err = criterion:forward(outputs, y[idx])
   print(string.format("Iteration %d ; Loss = %f ", iter, err))
   local gradOutputs = criterion:backward(outputs, y[idx])
   local gradInputs = model:backward(x[idx], gradOutputs)
   model:updateParameters(lr)
end
torch.save('test.t7', encoder)
local encoder = torch.load('test.t7')
parameters, gradParameters = encoder:clone('weight','bias', 'gradWeight','gradBias'):getParameters()
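-- Saving the encoder, reloading it, cloning it with shared parameters and then
-- calling getParameters() is the sequence that appears to trigger the flatten
-- problem reported in issue #357.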
-- Test on same data
print("Testing on training data")
local correct = 0
for idx=1,n do
   local outputs = model:forward(x[idx])
   if outputs:round():squeeze() == y[idx]:squeeze() then
      correct = correct + 1
   end
end
print(string.format("Accuracy: %f", correct/n))