Last active: June 24, 2022 06:59
-
-
Save Wsine/3d26815bf958f77e435b38542b8dba78 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import unittest

import torch
import torchvision.models as models
class TestPytorchLayers(unittest.TestCase): | |
def setUp(self): | |
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): # type: ignore | |
self.device = torch.device('mps') | |
elif torch.cuda.is_available(): | |
self.device = torch.device('cuda') | |
else: | |
self.device = torch.device('cpu') | |
print(f'Test method {self._testMethodName} with {self.device}') | |
def forward_and_backward(self, model): | |
model = model.to(self.device) | |
if isinstance(model, models.Inception3): | |
input = torch.rand((1, 3, 299, 299)).to(self.device) | |
else: | |
input = torch.rand((1, 3, 224, 224)).to(self.device) | |
output = model(input) | |
if isinstance(model, models.GoogLeNet): | |
output = output.logits | |
label = torch.rand(1, 1000).to(self.device) | |
loss = (output - label).sum() | |
loss.backward() | |
def test_resnet18(self): | |
self.forward_and_backward(models.resnet18()) | |
def test_alexnet(self): | |
self.forward_and_backward(models.alexnet()) | |
def test_squeezenet(self): | |
self.forward_and_backward(models.squeezenet1_0()) | |
def test_vgg16(self): | |
self.forward_and_backward(models.vgg16()) | |
def test_densenet(self): | |
self.forward_and_backward(models.densenet121()) | |
def test_inception(self): | |
self.forward_and_backward(models.inception_v3(init_weights=False, aux_logits=False)) | |
def test_googlenet(self): | |
self.forward_and_backward(models.googlenet(init_weights=False)) | |
def test_shufflenet(self): | |
self.forward_and_backward(models.shufflenet_v2_x0_5()) | |
def test_mobilenet(self): | |
self.forward_and_backward(models.mobilenet_v3_small()) | |
def test_resnext(self): | |
self.forward_and_backward(models.resnext50_32x4d()) | |
def test_wide_resnet(self): | |
self.forward_and_backward(models.wide_resnet50_2()) | |
def test_mnasnet(self): | |
self.forward_and_backward(models.mnasnet0_5()) | |
def test_efficientnet(self): | |
self.forward_and_backward(models.efficientnet_b0()) | |
def test_regnet(self): | |
self.forward_and_backward(models.regnet_x_400mf()) | |
def test_vit(self): | |
self.forward_and_backward(models.vit_b_16()) | |
def test_convnext(self): | |
self.forward_and_backward(models.convnext_tiny()) | |
def rnn_forward_and_backward(self, model): | |
model = model.to(self.device) | |
input = torch.randn(5, 3, 10).to(self.device) | |
h0 = torch.randn(2, 3, 20).to(self.device) | |
if isinstance(model, torch.nn.LSTM): | |
c0 = torch.randn(2, 3, 20).to(self.device) | |
output, _ = model(input, (h0, c0)) | |
else: | |
output, _ = model(input, h0) | |
output.sum().backward() | |
def test_rnn(self): | |
self.rnn_forward_and_backward(torch.nn.RNN(10, 20, 2)) | |
def test_lstm(self): | |
self.rnn_forward_and_backward(torch.nn.LSTM(10, 20, 2)) | |
def test_gru(self): | |
self.rnn_forward_and_backward(torch.nn.GRU(10, 20, 2)) | |
if __name__ == "__main__":
    # Executed as a script (not imported): let unittest discover and run the
    # TestCase classes defined in this module.
    unittest.main()
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.