# a sample up block
def make_conv_bn_relu(in_channels, out_channels, kernel_size=3, stride=1, padding=1):
    return [
        nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=False),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True)
    ]

self.up4 = nn.Sequential(
    *make_conv_bn_relu(128, 64, kernel_size=3, stride=1, padding=1),
    *make_conv_bn_relu(64, 64, kernel_size=3, stride=1, padding=1)
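The embed cuts the snippet off before the Sequential is closed. As a reference, here is a minimal sketch of a complete decoder stage built from make_conv_bn_relu above; the nn.Upsample placement and the shape check are assumptions, not the original code:

import torch
import torch.nn as nn

# hypothetical decoder stage: upsample by 2, then two conv-bn-relu groups (128 -> 64 -> 64)
up4 = nn.Sequential(
    nn.Upsample(scale_factor=2, mode='nearest'),
    *make_conv_bn_relu(128, 64, kernel_size=3, stride=1, padding=1),
    *make_conv_bn_relu(64, 64, kernel_size=3, stride=1, padding=1)
)

x = torch.randn(1, 128, 16, 16)
print(up4(x).shape)  # torch.Size([1, 64, 32, 32])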
def construct(A, B, C):
    """
    Given matrices A, B, C, construct a 3D tensor
    A : i, r
    B : j, r
    C : k, r
    """
    X_tilde = 0
    r = A.shape[1]
    for i in range(r):
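The loop body is cut off by the embed. The docstring describes a rank-r outer-product (CP/Kruskal) reconstruction, so a plausible completion, shown here as a hedged NumPy sketch rather than the original body, is:

import numpy as np

def construct(A, B, C):
    # X_tilde[i, j, k] = sum_r A[i, r] * B[j, r] * C[k, r]
    r = A.shape[1]
    X_tilde = 0
    for i in range(r):
        # outer product of the i-th columns of A, B and C
        X_tilde = X_tilde + np.einsum('i,j,k->ijk', A[:, i], B[:, i], C[:, i])
    return X_tilde

# example: reconstruct a 4 x 5 x 6 tensor from rank-3 factors
A, B, C = np.random.rand(4, 3), np.random.rand(5, 3), np.random.rand(6, 3)
print(construct(A, B, C).shape)  # (4, 5, 6)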
# a sample down block
def make_conv_bn_relu(in_channels, out_channels, kernel_size=3, stride=1, padding=1):
    return [
        nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=False),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True)
    ]

self.down1 = nn.Sequential(
    *make_conv_bn_relu(in_channels, 64, kernel_size=3, stride=1, padding=1),
    *make_conv_bn_relu(64, 64, kernel_size=3, stride=1, padding=1),
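As with the up block, the embed cuts the definition off. A minimal sketch of how such down blocks are usually chained in a U-Net-style encoder, reusing make_conv_bn_relu from the snippet above; the pooling placement and channel widths are assumptions:

import torch
import torch.nn as nn

class TinyEncoder(nn.Module):
    # hypothetical two-stage encoder: each stage is two conv-bn-relu groups,
    # with 2x2 max pooling between stages to halve the spatial resolution
    def __init__(self, in_channels=3):
        super().__init__()
        self.down1 = nn.Sequential(
            *make_conv_bn_relu(in_channels, 64),
            *make_conv_bn_relu(64, 64),
        )
        self.down2 = nn.Sequential(
            *make_conv_bn_relu(64, 128),
            *make_conv_bn_relu(128, 128),
        )
        self.pool = nn.MaxPool2d(2)

    def forward(self, x):
        d1 = self.down1(x)               # kept as a skip connection for the decoder
        d2 = self.down2(self.pool(d1))
        return d1, d2

enc = TinyEncoder()
d1, d2 = enc(torch.randn(1, 3, 64, 64))
print(d1.shape, d2.shape)  # torch.Size([1, 64, 64, 64]) torch.Size([1, 128, 32, 32])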
from fastai.vision import *
import math

__all__ = ['MeshNet', 'VolumetricUnet', 'conv_relu_bn_drop', 'res3dmodel', 'get_total_params',
           'VolumetricResidualUnet', 'model_dict', 'experiment_model_dict', 'one_by_one_conv',
           'model_split_dict']

####################
##  GET MODELS  ##
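get_total_params is one of the exported helpers, but its body is not shown in the embed. A sketch of what a helper with that name usually does, assumed rather than taken from the file:

def get_total_params(model):
    # assumed behavior: report the total and trainable parameter counts of a model
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total, trainable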
from fastai.vision.all import *
from torch.distributions import Beta
from copy import deepcopy

__all__ = ["ELR", "ELR_plusA", "ELR_plusB"]

class ELR(Callback):
    '''
    The selected values are β = 0.7 and λ = 3 for symmetric noise, β = 0.9 and λ = 1 for
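The docstring (cut off above) refers to the hyperparameters of Early-Learning Regularization (Liu et al., 2020), and the callback body is not shown. As a reference for what ELR computes, here is a hedged plain-PyTorch sketch of the regularizer itself, not the fastai Callback wiring of the original class: each training example keeps a running target updated by temporal ensembling, and the penalty λ·log(1 − ⟨p, t⟩) is added to the cross-entropy loss.

import torch
import torch.nn.functional as F

class ELRPenalty:
    # hedged sketch: `num_samples` indexes the training set so each example
    # keeps its own running target t_i = beta * t_i + (1 - beta) * p_i
    def __init__(self, num_samples, num_classes, beta=0.7, lam=3.0):
        self.beta, self.lam = beta, lam
        self.targets = torch.zeros(num_samples, num_classes)

    def __call__(self, logits, idxs):
        probs = F.softmax(logits, dim=1)
        with torch.no_grad():
            # temporal ensembling of per-example targets (kept on CPU here)
            self.targets[idxs] = self.beta * self.targets[idxs] + (1 - self.beta) * probs.detach().cpu()
        t = self.targets[idxs].to(logits.device)
        inner = (probs * t).sum(dim=1).clamp(max=1 - 1e-4)
        return self.lam * torch.log(1 - inner).mean()

# usage: loss = F.cross_entropy(logits, labels) + elr_penalty(logits, batch_indices)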
import os
import torch
import torch.distributed as dist
from torch.multiprocessing import Process
from torchvision import datasets, transforms
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import random
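The script body is not included in the embed, but the imports point to a hand-rolled torch.distributed setup. A minimal, self-contained sketch of the usual pattern; the address, port, backend and worker function are assumptions:

import os
import torch.distributed as dist
from torch.multiprocessing import Process

def run(rank, world_size):
    # placeholder worker: real training code would build the model, wrap it in
    # DistributedDataParallel and iterate over a sharded dataset
    print(f"hello from rank {rank}/{world_size}")

def init_process(rank, world_size, fn, backend="gloo"):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group(backend, rank=rank, world_size=world_size)
    fn(rank, world_size)

if __name__ == "__main__":
    world_size = 2
    processes = [Process(target=init_process, args=(r, world_size, run)) for r in range(world_size)]
    for p in processes: p.start()
    for p in processes: p.join()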
@call_parse
def main(
    size: Param("Image resolution", int)=224,
    bs: Param("Batch Size", int)=128,
    epochs: Param("Number of epochs for training", int)=1,
    lr: Param("Learning rate for training", float)=5e-5):

    WANDB = True
    # start wandb
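fastcore's @call_parse turns each Param-annotated keyword argument into a command-line flag, so the script can be launched with the defaults overridden; the script name below is hypothetical:

# assuming the file is saved as train.py:
#   python train.py --size 256 --bs 64 --epochs 5 --lr 1e-4
# running with no flags uses the defaults declared above (224, 128, 1, 5e-5)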
import wandb
from fastai.callback.wandb import WandbCallback
from fastai.distributed import *

torch.backends.cudnn.benchmark = True

from zero_optimizer import ZeroRedundancyOptimizer

@patch
def after_batch(self: WandbCallback):
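The body of the patched after_batch is cut off. fastcore's @patch replaces a method on an existing class in place; the body below is a placeholder assumption (throttling per-batch logging), not the original gist's logic:

import wandb
from fastcore.basics import patch
from fastai.callback.wandb import WandbCallback

@patch
def after_batch(self: WandbCallback):
    # placeholder body, NOT the original: log the training loss only every
    # 10 iterations instead of on every batch
    if self.training and self.iter % 10 == 0:
        wandb.log({'train_loss': float(self.learn.loss)})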
from fastai.vision.all import *
from torch.cuda.amp import autocast, GradScaler
from torch.cuda.amp.grad_scaler import _refresh_per_optimizer_state
from sam import SAM

class FastaiSched:
    def __init__(self, optimizer, max_lr):
        self.optimizer = optimizer
        self.lr_sched = combine_scheds([0.1, 0.9], [SchedLin(1e-8, max_lr), SchedCos(max_lr, 1e-8)])
        self.update(0)
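The update method is cut off by the embed. The schedule built above is a 10% linear warmup from 1e-8 to max_lr followed by a cosine decay back to 1e-8; a plausible body for the method (an assumption, not the original) evaluates the combined schedule at the current training fraction and writes the result into the optimizer's parameter groups:

    # hedged completion sketch for FastaiSched.update;
    # `pct` is the fraction of training completed, in [0, 1]
    def update(self, pct):
        lr = self.lr_sched(pct)
        for group in self.optimizer.param_groups:
            group['lr'] = lr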
from fastai.vision.all import *

__all__ = ["EMA", "SWA"]

class EMA(Callback):
    "https://fastai.github.io/timmdocs/training_modelEMA"
    order,run_valid = 5,False
    def __init__(self, decay=0.9999):
        super().__init__()
        self.decay = decay
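The rest of the callback is cut off. Model EMA (the timm recipe the docstring links to) keeps a shadow copy of the weights updated as ema = decay * ema + (1 - decay) * model after every optimizer step. A hedged, framework-free sketch of that update, not the original fastai Callback:

from copy import deepcopy
import torch

class ModelEMA:
    # minimal sketch: shadow copy of the model, updated after each training step
    def __init__(self, model, decay=0.9999):
        self.ema = deepcopy(model).eval()
        self.decay = decay
        for p in self.ema.parameters():
            p.requires_grad_(False)

    @torch.no_grad()
    def update(self, model):
        for ema_v, model_v in zip(self.ema.state_dict().values(), model.state_dict().values()):
            if ema_v.dtype.is_floating_point:
                ema_v.mul_(self.decay).add_(model_v, alpha=1 - self.decay)
            else:
                ema_v.copy_(model_v)  # integer buffers, e.g. BatchNorm's num_batches_tracked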