Created April 4, 2016 08:42 — save coodoo/ca28e1369ff5f9c21632f0e065da3cc1 to your computer and use it in GitHub Desktop.
Added support for passing pre-trained weights into LookupTable so that an RNNLM can be trained with a large data set.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
local THNN = require 'nn.THNN'
-- Register nn.LookupTable as a subclass of nn.Module; `parent` is the base class.
local LookupTable, parent = torch.class('nn.LookupTable', 'nn.Module')
-- Serialization format version; backCompatibility() patches older deserialized instances.
LookupTable.__version = 4
--- Construct a lookup table (embedding layer).
-- @param nIndex number of rows (vocabulary size); may be omitted when odlWeights is given
-- @param nOutput embedding dimension; may be omitted when odlWeights is given
-- @param odlWeights optional pre-trained weight matrix, used as-is (no re-initialization)
function LookupTable:__init(nIndex, nOutput, odlWeights)
   parent.__init(self)
   if odlWeights then
      -- adopt the pre-trained embeddings directly and skip random init
      self.weight = odlWeights
      -- size gradWeight from the actual weights so callers need not repeat the dimensions
      self.gradWeight = torch.Tensor(odlWeights:size(1), odlWeights:size(2)):zero()
   else
      self.weight = torch.Tensor(nIndex, nOutput)
      self.gradWeight = torch.Tensor(nIndex, nOutput):zero()
      self:reset()
   end
end
--- Patch instances deserialized from older versions that lack newer fields.
-- Lazily allocates the temp buffers and defaults the frequency-scaling flag.
function LookupTable:backCompatibility()
   if not self._count then
      self._count = torch.IntTensor()
   end
   if not self._input then
      self._input = torch.LongTensor()
   end
   -- nil (old instances) becomes an explicit false; true is left untouched
   if not self.shouldScaleGradByFreq then
      self.shouldScaleGradByFreq = false
   end
end
--- Switch to update-in-place mode by discarding the gradient buffer.
-- Returns self to allow call chaining.
function LookupTable:accUpdateOnly()
   self.gradWeight = nil
   return self
end
--- Enable scaling of gradients by the inverse frequency of each index in the batch.
-- Returns self to allow call chaining.
function LookupTable:scaleGradByFreq()
   self.shouldScaleGradByFreq = true
   return self
end
--- Re-initialize the weights from a normal distribution N(0, stdv^2).
-- @param stdv optional standard deviation (defaults to 1)
function LookupTable:reset(stdv)
   local sigma = stdv or 1
   self.weight:normal(0, sigma)
end
--- Ensure the input is a contiguous tensor of the expected index type.
-- Sets self.copiedInput so accGradParameters knows whether _input holds a copy.
function LookupTable:makeInputContiguous(input)
   -- fast path: already contiguous and of the right type — use it as-is
   if input:isContiguous() and torch.type(input) == torch.type(self._input) then
      self.copiedInput = false
      return input
   end
   -- slow path: copy into the cached buffer (this also converts the type)
   self.copiedInput = true
   self._input:resize(input:size()):copy(input)
   return self._input
end
--- Forward pass: gather embedding rows for each index in the input.
-- @param input 1D tensor of indices, or 2D (batch x seq) tensor of indices
-- @return 2D (n x dim) or 3D (batch x seq x dim) tensor of embeddings
function LookupTable:updateOutput(input)
   self:backCompatibility()
   input = self:makeInputContiguous(input)
   local nDim = input:dim()
   if nDim == 2 then
      -- flatten for the row lookup, then restore the batch x seq x dim shape
      self.output:index(self.weight, 1, input:view(-1))
      self.output = self.output:view(input:size(1), input:size(2), self.weight:size(2))
   elseif nDim == 1 then
      self.output:index(self.weight, 1, input)
   else
      error("input must be a vector or matrix")
   end
   return self.output
end
--- Backward pass: accumulate scaled gradients into gradWeight via the THNN C kernel.
-- @param input the same tensor passed to updateOutput (1D or 2D indices)
-- @param gradOutput gradient w.r.t. the output of this module
-- @param scale optional scalar multiplier on the accumulated gradient (defaults to 1)
function LookupTable:accGradParameters(input, gradOutput, scale)
self:backCompatibility()
-- if makeInputContiguous copied during forward, reuse that converted copy here
input = self.copiedInput and self._input or input
if input:dim() == 2 then
input = input:view(-1)
elseif input:dim() ~= 1 then
error("input must be a vector or matrix")
end
-- Raw FFI call; the argument order must match the THNN C signature exactly.
-- _sorted/_indices are only allocated on CUDA (see :type); optionalTensor passes
-- nil through as a null pointer on CPU.
self.gradWeight.THNN.LookupTable_accGradParameters(
input:cdata(),
gradOutput:cdata(),
self.gradWeight:cdata(),
self._count:cdata(),
THNN.optionalTensor(self._sorted),
THNN.optionalTensor(self._indices),
self.shouldScaleGradByFreq or false,
scale or 1
)
end
--- Convert the module (weights, buffers) to the given tensor type.
-- Overrides nn.Module:type to keep the index/count buffers in the types the
-- CPU and CUDA kernels each expect.
-- @param type tensor type name, e.g. 'torch.FloatTensor' or 'torch.CudaTensor'
-- @param tensorCache shared cache forwarded to parent.type for tied tensors
function LookupTable:type(type, tensorCache)
parent.type(self, type, tensorCache)
if type == 'torch.CudaTensor' then
-- CUDA uses _sorted and _indices temporary tensors
self._sorted = self.weight.new()
self._indices = self.weight.new()
self._count = self.weight.new()
self._input = self.weight.new()
else
-- self._count and self._input should only be converted if using Cuda
self._count = torch.IntTensor()
self._input = torch.LongTensor()
end
return self
end
--- Release cached state; kept as a no-op here for API compatibility.
-- NOTE(review): this version does not clear _count/_input/_sorted/_indices before
-- serialization — presumably intentional in this fork, but verify against upstream
-- nn.LookupTable, which does clear its temp buffers here.
function LookupTable:clearState()
return self
end
-- we do not need to accumulate parameters when sharing
-- NOTE(review): accUpdateGradParameters is not defined in this file — presumably
-- inherited from nn.Module; confirm the alias binds the intended implementation.
LookupTable.sharedAccUpdateGradParameters = LookupTable.accUpdateGradParameters
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.