Created
January 30, 2021 20:05
-
-
Save bnsh/393f90bccf5cc4b73114ccec2b9cb0b7 to your computer and use it in GitHub Desktop.
Training a model with Sparse parameters.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Image used to reproduce the sparse-parameter training experiment in test.py.
FROM pytorch/pytorch:latest
# MAINTAINER is deprecated; the maintainer is recorded as image metadata instead.
LABEL maintainer="Binesh Bannerjee <binesh_binesh@hotmail.com>"
# Trailing slash makes it explicit that /tmp is a directory destination.
COPY test.py /tmp/
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# None of these targets produce a file of the same name, so declare them phony;
# otherwise a stray file called e.g. "build" or "dense" would silently disable the rule.
.PHONY: all build run sparse1 sparse2 dense

# Default target: has no prerequisites and no recipe, so a bare `make` does nothing.
all:

# Build the Docker image from the Dockerfile in this directory.
build:
	docker build -t binesh/pytorch-sparse-test .

# Open an interactive shell inside a throwaway container.
run: build
	docker container run -it --rm binesh/pytorch-sparse-test /bin/bash

# Run test.py in each of its three modes inside a throwaway container.
sparse1: build
	docker container run -it --rm binesh/pytorch-sparse-test /opt/conda/bin/python3 /tmp/test.py --mode sparse1

sparse2: build
	docker container run -it --rm binesh/pytorch-sparse-test /opt/conda/bin/python3 /tmp/test.py --mode sparse2

dense: build
	docker container run -it --rm binesh/pytorch-sparse-test /opt/conda/bin/python3 /tmp/test.py --mode dense
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#! /usr/bin/env python3 | |
# vim: expandtab shiftwidth=4 tabstop=4 | |
""" | |
Hi, | |
so you run this like so: | |
python3 ./test.py --mode sparse1 # or | |
python3 ./test.py --mode sparse2 # or | |
python3 ./test.py --mode dense # or | |
(There's a Dockerfile as well, so hopefully this is easily replicable.) | |
What I want for my _real_ task, is to fix most of the elements of a linear layer to be zero, | |
and _only_ train some specified coefficients of the matrix. | |
I was thinking, what I'd do, is make a torch.sparse.FloatTensor, and train that. | |
That first attempt is SparseTest1, it just wraps the torch.sparse.FloatTensor in a torch.nn.Parameter. | |
When I run that, I get: | |
ValueError: Sparse params at indices [0]: SparseAdam requires dense parameter tensors | |
So, then I thought, OK, I'll make a 2 valued vector, make that a nn.Parameter, and then use _that_ | |
inside torch.sparse.FloatTensor as the "values". When I do that though, I get: | |
RuntimeError: Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling backward the first time. | |
So, then I thought, let me see if this is even feasible, and tried a DenseTest with plain old dense matrices.
That runs to completion. | |
So, my question is: How can I train through a SparseTensor? Or, if I want to fix certain elements of the weight matrix to zero, how can I do that? | |
Thanks! | |
Binesh | |
""" | |
import argparse | |
import torch | |
import torch.nn as nn | |
import torch.optim as optim | |
class SparseTest1(nn.Module):
    """First attempt: register a sparse 2x2 tensor directly as the parameter.

    The matrix has two stored entries, at (0, 0) and (1, 1), initialised from
    a standard normal. As described in the module docstring, handing this
    parameter to ``SparseAdam`` fails with
    "SparseAdam requires dense parameter tensors".
    """

    def __init__(self):
        super().__init__()
        # Row 0 of the index tensor holds row coordinates, row 1 holds column
        # coordinates, so the stored entries are (0, 0) and (1, 1).
        coords = torch.LongTensor([(0, 1), (0, 1)])
        entries = torch.randn(2)
        shape = torch.Size([2, 2])
        self.mat = nn.Parameter(torch.sparse.FloatTensor(coords, entries, shape))
        print(self.mat)

    def forward(self, val):
        """Return ``(mat @ val.T).T`` computed with a sparse-dense product."""
        product = torch.sparse.mm(self.mat, val.t())
        return product.t()
class SparseTest2(nn.Module):
    """Sparse 2x2 matrix whose two stored values are a dense trainable parameter.

    Only the entries at (0, 0) and (1, 1) exist; every other element of the
    matrix is structurally zero and is never trained.

    The sparse tensor is rebuilt from ``self.param`` every time ``mat`` is
    accessed. Building it once in ``__init__`` makes every forward pass share
    a single autograd graph, so the second ``backward()`` call fails with
    "Trying to backward through the graph a second time".
    """

    def __init__(self):
        super().__init__()
        # Dense 1-D parameter: this is what the optimizer actually updates.
        self.param = nn.Parameter(torch.randn(2))
        # Row 0 holds row coordinates, row 1 holds column coordinates.
        # Non-persistent buffer: follows the module across devices without
        # adding a new key to state_dict().
        self.register_buffer(
            "indices",
            torch.tensor([[0, 1], [0, 1]], dtype=torch.long),
            persistent=False,
        )
        print(self.mat)

    @property
    def mat(self):
        """Return a fresh sparse COO view of the current parameter values."""
        return torch.sparse_coo_tensor(self.indices, self.param, (2, 2))

    def forward(self, val):
        """Return ``(mat @ val.T).T`` computed with a sparse-dense product."""
        return torch.sparse.mm(self.mat, val.t()).t()
class DenseTest(nn.Module):
    """Baseline: an ordinary dense, fully trainable 2x2 weight matrix."""

    def __init__(self):
        super().__init__()
        initial = torch.randn(2, 2)
        self.mat = nn.Parameter(initial)

    def forward(self, val):
        """Return the matrix product ``val @ mat``."""
        return val @ self.mat
def main(argv=None):
    """Train the selected test model to reproduce a random diagonal map.

    Args:
        argv: Optional list of command-line arguments. When ``None``
            (the default) argparse reads ``sys.argv[1:]``.

    The required ``--mode`` option selects the model/optimizer pair:
    ``sparse1`` (SparseTest1 + SparseAdam), ``sparse2`` (SparseTest2 + Adam)
    or ``dense`` (DenseTest + Adam). The MSE loss is printed after each of
    the 32 training steps.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--mode", type=str, required=True, choices=("sparse1", "sparse2", "dense"))
    args = parser.parse_args(argv)

    # Target transform: a random diagonal matrix the model should learn.
    expect = torch.diag(torch.randn(2))

    # argparse's `choices` guarantees args.mode is one of these keys.
    setups = {
        "sparse1": (SparseTest1, optim.SparseAdam),
        "sparse2": (SparseTest2, optim.Adam),
        "dense": (DenseTest, optim.Adam),
    }
    model_cls, opt_cls = setups[args.mode]
    test = model_cls()
    opt = opt_cls(test.parameters(), lr=1e-2)

    crit = nn.MSELoss()
    for _ in range(32):
        # A fresh random batch each step; the target is the same batch
        # pushed through the fixed diagonal matrix.
        data = torch.randn(1024, 2)
        out = test(data)
        targ = torch.matmul(data, expect)
        opt.zero_grad()
        loss = crit(out, targ)
        loss.backward()
        opt.step()
        # .item() extracts the Python float explicitly instead of relying on
        # implicit tensor-to-float conversion inside %-formatting.
        print(f"{loss.item():.7f}")


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment