This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# A resnet block consisting of N resnet modules; the first layer is pooled.
def resnet_block(x, filters, num_layers, pool_first_layer=True):
    """Stack ``num_layers`` resnet modules on top of ``x``.

    Only the first module may downsample. It receives ``pool=True`` when
    ``pool_first_layer`` is truthy, and every later module receives
    ``pool=False``.

    Args:
        x: Input passed to the first ``resnet_module`` call.
        filters: Filter count forwarded unchanged to every module.
        num_layers: Number of modules to stack; no module is applied when
            this is zero or negative.
        pool_first_layer: Whether the first module should pool.

    Returns:
        The output of the last module, or ``x`` unchanged when no module
        was applied.
    """
    for layer_index in range(num_layers):
        # bool(...) keeps the flag a strict True/False even when the caller
        # passes a non-bool truthy/falsy value for pool_first_layer.
        should_pool = bool(pool_first_layer) and layer_index == 0
        x = resnet_module(x, filters=filters, pool=should_pool)
    return x
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# A single resnet module consisting of 1 x 1 conv - 3 x 3 conv and 1 x 1 conv | |
def resnet_identity_module(x, filters, pool=False): | |
res = x | |
stride = 1 | |
if pool: | |
stride = 2 | |
res = Conv2D(filters, kernel_size=1, strides=2, padding="same")(res) | |
x = BatchNormalization()(x) | |
x = Activation("relu")(x) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# A single resnet module consisting of 1 x 1 conv - 3 x 3 conv and 1 x 1 conv | |
def resnet_module(x, filters, pool=False): | |
res = x | |
stride = 1 | |
if pool: | |
stride = 2 | |
res = Conv2D(filters, kernel_size=1, strides=2, padding="same")(res) | |
res = BatchNormalization()(res) | |
x = Conv2D(int(filters / 4), kernel_size=1, strides=stride, padding="same")(x) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import keras | |
from keras.layers import * | |
from keras.models import Model | |
# A single resnet module consisting of 1 x 1 conv - 3 x 3 conv and 1 x 1 conv | |
def resnet_module(x, filters, pool=False): | |
res = x | |
stride = 1 | |
if pool: | |
stride = 2 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import keras | |
from keras.layers import * | |
from keras.models import Model | |
# A single resnet module consisting of 1 x 1 conv - 3 x 3 conv and 1 x 1 conv | |
def resnet_identity_module(x, filters, pool=False): | |
res = x | |
stride = 1 | |
if pool: | |
stride = 2 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class Unit(nn.Module): | |
def __init__(self, in_channels, out_channels): | |
super(Unit, self).__init__() | |
self.conv = nn.Conv2d(in_channels=in_channels, kernel_size=3, out_channels=out_channels, stride=1, padding=1) | |
self.bn = nn.BatchNorm2d(num_features=out_channels) | |
self.relu = nn.ReLU() | |
def forward(self, input): | |
output = self.conv(input) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Test-time preprocessing: convert each image to a tensor, then normalize
# every one of the three channels with mean 0.5 and std 0.5.
test_transformations = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# Load the CIFAR-10 test split (train=False) rooted at ./data, with
# download enabled and the test-time transformations applied.
test_set = CIFAR10(
    root="./data",
    train=False,
    transform=test_transformations,
    download=True,
)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from torch.optim import Adam | |
# Check if gpu support is available | |
cuda_avail = torch.cuda.is_available() | |
# Create model, optimizer and loss function | |
model = SimpleNet(num_classes=10) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Create a learning rate adjustment function that divides the learning rate by 10 every 30 epochs | |
def adjust_learning_rate(epoch): | |
lr = 0.001 | |
if epoch > 180: | |
lr = lr / 1000000 | |
elif epoch > 150: | |
lr = lr / 100000 | |
elif epoch > 120: |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def save_models(epoch):
    """Save the module-level ``model``'s state dict as a per-epoch checkpoint.

    The state dict is written with ``torch.save`` to the relative path
    ``cifar10model_<epoch>.model``, which resolves against the current
    working directory. A confirmation message is printed afterwards.

    Args:
        epoch: Value interpolated into the checkpoint file name.
    """
    torch.save(model.state_dict(), "cifar10model_{}.model".format(epoch))
    # Fixed typo in the user-facing message ("Chekcpoint" -> "Checkpoint").
    print("Checkpoint saved")
OlderNewer