@JoostvDoorn
Created October 22, 2016 19:03
Bug #357 in Element-Research/rnn
-- Simple Sentence Similarity Example with Siamese Encoder
-- Author: Joost van Doorn (joostvandoorn.com)
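-- This script reproduces the issue end-to-end: it downloads the SemEval-2015
-- Task 1 trial data (tweet pairs with a "True"/"False" paraphrase judgement),
-- builds a siamese LSTM encoder whose two branches share parameters, trains it
-- briefly with MSE, then saves, reloads and re-flattens the encoder parameters.
-- Module.flatten below appears to be nn's stock implementation with extra debug
-- output added around the nil-storage case.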
require 'nn'
require 'rnn'
require 'io'
local pl = require 'pl'
local dl = require 'dataload'
Module = nn.Module
function Module.flatten(parameters)
   -- returns true if tensor occupies a contiguous region of memory (no holes)
   local function isCompact(tensor)
      local sortedStride, perm = torch.sort(
         torch.LongTensor(tensor:nDimension()):set(tensor:stride()), 1, true)
      local sortedSize = torch.LongTensor(tensor:nDimension()):set(
         tensor:size()):index(1, perm)
      local nRealDim = torch.clamp(sortedStride, 0, 1):sum()
      sortedStride = sortedStride:narrow(1, 1, nRealDim):clone()
      sortedSize = sortedSize:narrow(1, 1, nRealDim):clone()
      local t = tensor.new():set(tensor:storage(), 1,
                                 sortedSize:storage(),
                                 sortedStride:storage())
      return t:isContiguous()
   end
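   -- e.g. a view narrowed in its innermost dimension skips elements of its
   -- storage and is not compact, whereas a plain transpose of a contiguous
   -- tensor still is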
   if not parameters or #parameters == 0 then
      return torch.Tensor()
   end
   local Tensor = parameters[1].new
   local TmpTensor = Module._flattenTensorBuffer[torch.type(parameters[1])] or Tensor
   -- 1. construct the set of all unique storages referenced by parameter tensors
   local storages = {}
   local nParameters = 0
   local parameterMeta = {}
   for k = 1,#parameters do
      local param = parameters[k]
      local storage = parameters[k]:storage()
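      -- debug output added for this bug report: if a parameter tensor has no
      -- underlying storage, report which parameter before torch.pointer() below fails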
      if storage == nil then
         print("storage is nil in param:")
         print(k)
      end
      local storageKey = torch.pointer(storage)
      if not storages[storageKey] then
         storages[storageKey] = {storage, nParameters}
         nParameters = nParameters + storage:size()
      end
      parameterMeta[k] = {storageOffset = param:storageOffset() +
                                          storages[storageKey][2],
                          size = param:size(),
                          stride = param:stride()}
   end
   -- 2. construct a single tensor that will hold all the parameters
   local flatParameters = TmpTensor(nParameters):zero()
   -- 3. determine if there are elements in the storage that none of the
   -- parameter tensors reference ('holes')
   local tensorsCompact = true
   for k = 1,#parameters do
      local meta = parameterMeta[k]
      local tmp = TmpTensor():set(
         flatParameters:storage(), meta.storageOffset, meta.size, meta.stride)
      tmp:fill(1)
      tensorsCompact = tensorsCompact and isCompact(tmp)
   end
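   -- every element referenced by at least one parameter is now 1 in flatParameters;
   -- the mask records those positions and the cumulative sum maps each original
   -- storage offset to its index in the compacted (hole-free) layout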
   local maskParameters = flatParameters:byte():clone()
   local compactOffsets = flatParameters:long():cumsum(1)
   local nUsedParameters = compactOffsets[-1]
   -- 4. copy storages into the flattened parameter tensor
   for _, storageAndOffset in pairs(storages) do
      local storage, offset = table.unpack(storageAndOffset)
      flatParameters[{{offset+1,offset+storage:size()}}]:copy(Tensor():set(storage))
   end
   -- 5. allow garbage collection
   storages = nil
   for k = 1,#parameters do
      parameters[k]:set(Tensor())
   end
   -- 6. compact the flattened parameters if there were holes
   if nUsedParameters ~= nParameters then
      assert(tensorsCompact,
             "Cannot gather tensors that are not compact")
      flatParameters = TmpTensor(nUsedParameters):copy(
         flatParameters:maskedSelect(maskParameters))
      for k = 1,#parameters do
         parameterMeta[k].storageOffset =
            compactOffsets[parameterMeta[k].storageOffset]
      end
   end
   if TmpTensor ~= Tensor then
      flatParameters = Tensor(flatParameters:nElement()):copy(flatParameters)
   end
   -- 7. fix up the parameter tensors to point at the flattened parameters
   for k = 1,#parameters do
      parameters[k]:set(flatParameters:storage(),
                        parameterMeta[k].storageOffset,
                        parameterMeta[k].size,
                        parameterMeta[k].stride)
   end
   return flatParameters
end
-- Options
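-- hiddenSize is used both as the word-embedding dimension and the LSTM hidden
-- size; iterations is the number of training updates performed below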
local opt = {}
opt.hiddenSize = 14
opt.iterations = 1
-- Load data
dl.downloadfile('.', 'http://alt.qcri.org/semeval2015/task1/data/uploads/tp_trial_data.txt', 'tp_trial_data.txt')
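-- each line of the trial file consists of "|||"-separated fields; fields 2 and 3
-- hold the two tweets and field 4 the "True"/"False" paraphrase judgement, as
-- parsed in the loop further below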
local tweets = {}
local tokens = {}
for line in io.lines('./tp_trial_data.txt') do
   tweets[#tweets+1] = line
   for w in line:gmatch("%S+") do tokens[#tokens+1] = w end
end
local vocab, ivocab, wordfreq = dl.buildVocab(tokens, 1)
local x = {}
local y = {}
for i, tweet in pairs(tweets) do
   local d = stringx.split(tweet, "|||")
   if #d >= 4 and i > 1 then
      local i = #x+1
      x[i] = {}
      x[i][1] = dl.text2tensor(stringx.split(d[2]), vocab):view(-1, 1) -- Tweet 1
      x[i][2] = dl.text2tensor(stringx.split(d[3]), vocab):view(-1, 1) -- Tweet 2
      y[i] = torch.Tensor{stringx.strip(d[4]) == "True" and 1 or 0}:view(1, 1)
   end
end
-- Encoder/Embedding
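-- LookupTable maps word indices to embeddings, Sequencer(LSTM) runs the LSTM
-- over the sequence, and Select(1, -1) keeps the output of the last time step
-- as the fixed-size sentence encoding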
local encoder = nn.Sequential()
   :add(nn.LookupTable(#ivocab, opt.hiddenSize))
   :add(nn.Sequencer(nn.LSTM(opt.hiddenSize, opt.hiddenSize)))
   :add(nn.Select(1, -1))
-- local encoder = torch.load('test.model')
encoder = encoder:clone()
-- The siamese model
local siameseEncoder = nn.ParallelTable()
siameseEncoder:add(encoder)
siameseEncoder:add(encoder:clone('weight', 'bias', 'gradWeight', 'gradBias')) -- clone the encoder and share its weight and bias; gradWeight and gradBias must be shared as well
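-- getParameters flattens the (shared) parameter storages of both branches into a
-- single tensor, which exercises the Module.flatten implementation above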
parameters, gradParameters = siameseEncoder:getParameters()
-- The full similarity model
model = nn.Sequential()
model:add(siameseEncoder)
model:add(nn.CSubTable())
model:add(nn.Linear(opt.hiddenSize, 1))
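-- regress the scalar output of the Linear layer onto the 0/1 paraphrase label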
local criterion = nn.MSECriterion()
-- Train
local n = #x
local lr = 0.01
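-- plain SGD: one (cyclically chosen) example per update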
for iter=1,opt.iterations do
   local idx = iter % n + 1
   model:zeroGradParameters()
   local outputs = model:forward(x[idx])
   local err = criterion:forward(outputs, y[idx])
   print(string.format("Iteration %d ; Loss = %f ", iter, err))
   local gradOutputs = criterion:backward(outputs, y[idx])
   local gradInputs = model:backward(x[idx], gradOutputs)
   model:updateParameters(lr)
end
torch.save('test.t7', encoder)
local encoder = torch.load('test.t7')
parameters, gradParameters = encoder:clone('weight','bias', 'gradWeight','gradBias'):getParameters()
-- Test on same data
print("Testing on training data")
local correct = 0
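-- round the regression output to a hard 0/1 prediction and compare it with the label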
for idx=1,n do
   local outputs = model:forward(x[idx])
   if outputs:round():squeeze() == y[idx]:squeeze() then
      correct = correct + 1
   end
end
print(string.format("Accuracy: %f", correct/n))
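-- A minimal sketch of the same parameter-sharing pattern without the data
-- download, useful for isolating the getParameters/flatten behaviour on its own;
-- the module sizes and the toy* variable names are illustrative, not taken from
-- the data above.
local toyEncoder = nn.Sequential()
   :add(nn.LookupTable(10, 4))
   :add(nn.Sequencer(nn.LSTM(4, 4)))
   :add(nn.Select(1, -1))
local toySiamese = nn.ParallelTable()
   :add(toyEncoder)
   :add(toyEncoder:clone('weight', 'bias', 'gradWeight', 'gradBias'))
-- flatten the shared parameters; this goes through Module.flatten defined above
local toyParams, toyGradParams = toySiamese:getParameters()
print(string.format("flattened %d parameters", toyParams:nElement()))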