Skip to content

Instantly share code, notes, and snippets.

@Wsine
Last active June 24, 2022 06:59
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Wsine/3d26815bf958f77e435b38542b8dba78 to your computer and use it in GitHub Desktop.
Save Wsine/3d26815bf958f77e435b38542b8dba78 to your computer and use it in GitHub Desktop.
import unittest
import torch
import torchvision.models as models
class TestPytorchLayers(unittest.TestCase):
    """Smoke-test forward and backward passes of common torch/torchvision models.

    Each test builds a randomly initialised model, moves it to the best
    available device (MPS, then CUDA, then CPU), runs one forward pass on a
    random batch of size 1 and back-propagates a scalar loss. The goal is to
    surface operators the selected backend cannot execute; a test passes when
    neither pass raises and the sanity checks in the helpers hold.
    """

    # Number of classes produced by the default torchvision classification
    # heads exercised below; used to size the fake label.
    NUM_CLASSES = 1000

    def setUp(self):
        """Select the device for this test and seed the RNG for reproducibility."""
        # torch.backends.mps only exists on newer PyTorch builds, hence hasattr.
        if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():  # type: ignore
            self.device = torch.device('mps')
        elif torch.cuda.is_available():
            self.device = torch.device('cuda')
        else:
            self.device = torch.device('cpu')
        # Models and inputs are created on the CPU, so seeding the default
        # generator makes weights and inputs identical across re-runs.
        torch.manual_seed(0)
        print(f'Test method {self._testMethodName} with {self.device}')

    def forward_and_backward(self, model):
        """Run one forward and backward pass of an image classifier on ``self.device``.

        Args:
            model: a torchvision classification model expected to emit
                ``NUM_CLASSES`` logits per sample (possibly wrapped in a
                namedtuple with a ``logits`` field).
        """
        model = model.to(self.device)
        # Inception v3 is designed for 299x299 inputs; the others use 224x224.
        side = 299 if isinstance(model, models.Inception3) else 224
        # Generated on the CPU and then moved, so only the model's own ops
        # exercise the backend under test.
        inputs = torch.rand((1, 3, side, side)).to(self.device)
        output = model(inputs)
        # GoogLeNet (and Inception v3 with aux_logits=True) return a namedtuple
        # in training mode whose main head is ``logits``; in eval mode or without
        # aux heads they return a plain tensor, which passes through unchanged.
        logits = getattr(output, 'logits', output)
        # Without this check a wrongly shaped head would silently broadcast
        # against the label below instead of failing.
        self.assertEqual(tuple(logits.shape), (1, self.NUM_CLASSES))
        label = torch.rand(1, self.NUM_CLASSES).to(self.device)
        loss = (logits - label).sum()
        loss.backward()
        # Check on the CPU so the assertion itself does not depend on backend op support.
        self.assertTrue(torch.isfinite(loss.detach().cpu()).item())
        self.assertTrue(any(p.grad is not None for p in model.parameters()))

    def test_resnet18(self):
        self.forward_and_backward(models.resnet18())

    def test_alexnet(self):
        self.forward_and_backward(models.alexnet())

    def test_squeezenet(self):
        self.forward_and_backward(models.squeezenet1_0())

    def test_vgg16(self):
        self.forward_and_backward(models.vgg16())

    def test_densenet(self):
        self.forward_and_backward(models.densenet121())

    def test_inception(self):
        self.forward_and_backward(models.inception_v3(init_weights=False, aux_logits=False))

    def test_googlenet(self):
        self.forward_and_backward(models.googlenet(init_weights=False))

    def test_shufflenet(self):
        self.forward_and_backward(models.shufflenet_v2_x0_5())

    def test_mobilenet(self):
        self.forward_and_backward(models.mobilenet_v3_small())

    def test_resnext(self):
        self.forward_and_backward(models.resnext50_32x4d())

    def test_wide_resnet(self):
        self.forward_and_backward(models.wide_resnet50_2())

    def test_mnasnet(self):
        self.forward_and_backward(models.mnasnet0_5())

    def test_efficientnet(self):
        self.forward_and_backward(models.efficientnet_b0())

    def test_regnet(self):
        self.forward_and_backward(models.regnet_x_400mf())

    def test_vit(self):
        self.forward_and_backward(models.vit_b_16())

    def test_convnext(self):
        self.forward_and_backward(models.convnext_tiny())

    def rnn_forward_and_backward(self, model, seq_len=5, batch_size=3):
        """Run one forward and backward pass of a ``torch.nn`` RNN/LSTM/GRU.

        Input and initial-state shapes are derived from the module's own
        configuration (``input_size``, ``hidden_size``, ``num_layers``,
        ``bidirectional``, ``batch_first`` and, for LSTM, ``proj_size``)
        instead of being hard-coded, so any module configuration is accepted.

        Args:
            model: a ``torch.nn.RNN``, ``torch.nn.LSTM`` or ``torch.nn.GRU``.
            seq_len: number of time steps in the random input sequence.
            batch_size: number of sequences in the random batch.
        """
        model = model.to(self.device)
        if model.batch_first:
            input_shape = (batch_size, seq_len, model.input_size)
        else:
            input_shape = (seq_len, batch_size, model.input_size)
        # Initial states are (layers * directions, batch, features) regardless
        # of batch_first.
        num_states = model.num_layers * (2 if model.bidirectional else 1)
        # A projected LSTM (proj_size > 0) has proj_size-wide hidden states;
        # older PyTorch versions lack the attribute entirely.
        proj_size = getattr(model, 'proj_size', 0)
        hidden_features = proj_size if proj_size > 0 else model.hidden_size
        inputs = torch.randn(input_shape).to(self.device)
        h0 = torch.randn(num_states, batch_size, hidden_features).to(self.device)
        if isinstance(model, torch.nn.LSTM):
            # The LSTM cell state keeps hidden_size features even when projected.
            c0 = torch.randn(num_states, batch_size, model.hidden_size).to(self.device)
            output, _ = model(inputs, (h0, c0))
        else:
            output, _ = model(inputs, h0)
        output.sum().backward()

    def test_rnn(self):
        self.rnn_forward_and_backward(torch.nn.RNN(10, 20, 2))

    def test_lstm(self):
        self.rnn_forward_and_backward(torch.nn.LSTM(10, 20, 2))

    def test_gru(self):
        self.rnn_forward_and_backward(torch.nn.GRU(10, 20, 2))
if __name__ == "__main__":
    # Executed as a script: hand control to the stdlib runner, which discovers
    # every TestCase in this module, runs it and exits with the result status.
    unittest.main(module="__main__")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment