Last active
February 3, 2021 09:06
-
-
Save Muhammad4hmed/8730fe5e984d80d9083d0cafc8d69144 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class Model(nn.Module):
    """Image classifier built from a timm backbone, global average pooling and a linear head.

    ``forward`` returns both the class logits and the un-pooled backbone
    output so callers can derive class activation maps from the features.

    Args:
        num_classes: Width of the final linear layer (defaults to 5).
        model_name: timm architecture name. ``None`` falls back to the
            module-level ``TIMM_MODEL`` constant.
        pretrained: Forwarded to ``timm.create_model``.
    """

    def __init__(self, num_classes=5, model_name=None, pretrained=True):
        super().__init__()
        if model_name is None:
            model_name = TIMM_MODEL
        backbone = timm.create_model(model_name, pretrained=pretrained)
        # Requires an architecture that exposes its classifier as `.head`.
        n_features = backbone.head.in_features
        # Keep every child module except the last two
        # (presumably the backbone's own pooling and head -- confirm per architecture).
        self.backbone = nn.Sequential(*backbone.children())[:-2]
        self.classifier = nn.Linear(n_features, num_classes)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))

    def forward_features(self, x):
        """Return the raw backbone output for ``x``."""
        return self.backbone(x)

    def forward(self, x):
        """Return ``(logits, feats)`` where ``feats`` is the un-pooled backbone output."""
        feats = self.forward_features(x)
        pooled = self.pool(feats).view(x.size(0), -1)
        logits = self.classifier(pooled)
        return logits, feats
def rand_bbox(size, lam):
    """Sample a random box whose sides are ``sqrt(1 - lam)`` of the image sides.

    Args:
        size: Batch shape; only ``size[2]`` and ``size[3]`` are read.
        lam: Mixing ratio in ``[0, 1]``; the box shrinks as ``lam`` grows.

    Returns:
        ``(bbx1, bby1, bbx2, bby2)``. The ``x`` pair is clipped to
        ``[0, size[2]]`` and the ``y`` pair to ``[0, size[3]]``, so the box
        may be smaller than requested when its centre is near a border.
    """
    W = size[2]
    H = size[3]
    cut_rat = np.sqrt(1. - lam)
    # `np.int` was removed in NumPy 1.24; the builtin truncates identically.
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)

    # Box centre is drawn uniformly over the image.
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)
    return bbx1, bby1, bbx2, bby2
def _classifier_head(model):
    """Return the linear layer whose weights are used to build the activation maps."""
    # Original lookup path first, so wrappers exposing `.model.fc` keep working.
    head = getattr(getattr(model, "model", None), "fc", None)
    if head is None:
        head = getattr(model, "classifier", None)
    if head is None:
        raise AttributeError(
            "get_spm needs a model exposing `.model.fc` or `.classifier`"
        )
    return head


def get_spm(input, target, model):
    """Build a normalised class activation map per sample for its target class.

    The model is run without gradients. Its classifier weights are applied as
    a 1x1 convolution over the ReLU-ed feature map, the channel of each
    sample's target class is upsampled to the input's spatial size, shifted so
    its minimum is 0 and divided by its sum.

    Args:
        input: Image batch; ``model(input)`` must return ``(logits, feature_map)``.
        target: Integer class index per sample.
        model: Network exposing its final linear layer as ``.model.fc`` or
            ``.classifier``.

    Returns:
        ``(outmaps, clslogit)``: ``outmaps`` has shape ``(batch, H, W)`` with
        ``H, W`` taken from ``input``; ``clslogit`` holds the softmax
        probability of each sample's target class. A map that is constant
        divides by zero and comes back as NaN.
    """
    # Follow the batch resolution rather than assuming 512x512 inputs.
    imgsize = tuple(input.shape[-2:])
    bs = input.size(0)
    with torch.no_grad():
        _, fms = model(input)
        clsw = _classifier_head(model)
        weight = clsw.weight.data
        bias = clsw.bias.data if clsw.bias is not None else None
        # (classes, channels) -> (classes, channels, 1, 1) so it can act as a 1x1 conv.
        weight = weight.view(weight.size(0), weight.size(1), 1, 1)
        fms = F.relu(fms)

        # flatten(1) instead of squeeze(): keeps the batch axis when bs == 1.
        poolfea = F.adaptive_avg_pool2d(fms, (1, 1)).flatten(1)
        probs = F.softmax(clsw(poolfea), dim=1)
        batch_idx = torch.arange(bs, device=probs.device)
        tgt = target.to(device=probs.device, dtype=torch.long)
        clslogit = probs[batch_idx, tgt]

        out = F.conv2d(fms, weight, bias=bias)
        outmaps = out[batch_idx, tgt].unsqueeze(1)
        outmaps = F.interpolate(outmaps, imgsize, mode='bilinear', align_corners=False)
        outmaps = outmaps.squeeze(1)

        # Per-sample: shift to a zero minimum, then normalise to unit sum.
        flat = outmaps.flatten(1)
        flat = flat - flat.min(dim=1, keepdim=True)[0]
        flat = flat / flat.sum(dim=1, keepdim=True)
        outmaps = flat.view_as(outmaps)
    return outmaps, clslogit
def snapmix(input, target, alpha, model=None):
    """Apply SnapMix to a batch: paste a resized patch and weight labels by activation mass.

    Two boxes are sampled independently. The content of the second box, taken
    from a shuffled copy of the batch, is resized and pasted into the first
    box of ``input``. Label weights come from the share of each sample's
    activation map (from ``get_spm``) that falls inside the respective box.

    Args:
        input: Image batch. It is modified **in place** when both boxes have
            a non-zero area.
        target: Class index per sample, on the same device as ``input``.
        alpha: Parameter of the symmetric Beta distribution the box ratios
            are drawn from.
        model: Network passed on to ``get_spm``; required despite the
            ``None`` default.

    Returns:
        ``(input, target, target_b, lam_a, lam_b)`` where ``target_b`` are the
        labels of the shuffled batch and ``lam_a`` / ``lam_b`` are per-sample
        weights living on ``input.device``.
    """
    bs = input.size(0)
    device = input.device
    # Defaults for the degenerate case of an empty box: keep the original label only.
    lam_a = torch.ones(bs, device=device)
    lam_b = 1 - lam_a

    wfmaps, _ = get_spm(input, target, model)
    lam = np.random.beta(alpha, alpha)
    lam1 = np.random.beta(alpha, alpha)
    # Permute on CPU and then move, so the permutation still comes from the CPU generator.
    rand_index = torch.randperm(bs).to(device)
    wfmaps_b = wfmaps[rand_index, :, :]
    target_b = target[rand_index]
    same_label = target == target_b

    # Plain ints: F.interpolate does not reliably accept NumPy integer sizes.
    bbx1, bby1, bbx2, bby2 = (int(v) for v in rand_bbox(input.size(), lam))
    bbx1_1, bby1_1, bbx2_1, bby2_1 = (int(v) for v in rand_bbox(input.size(), lam1))

    area = (bby2 - bby1) * (bbx2 - bbx1)
    area1 = (bby2_1 - bby1_1) * (bbx2_1 - bbx1_1)

    if area1 > 0 and area > 0:
        ncont = input[rand_index, :, bbx1_1:bbx2_1, bby1_1:bby2_1].clone()
        ncont = F.interpolate(ncont, size=(bbx2 - bbx1, bby2 - bby1), mode='bilinear', align_corners=True)
        input[:, :, bbx1:bbx2, bby1:bby2] = ncont

        # Activation mass left in the destination image / brought in by the patch.
        lam_a = 1 - wfmaps[:, bbx1:bbx2, bby1:bby2].sum(2).sum(1) / (wfmaps.sum(2).sum(1) + 1e-8)
        lam_b = wfmaps_b[:, bbx1_1:bbx2_1, bby1_1:bby2_1].sum(2).sum(1) / (wfmaps_b.sum(2).sum(1) + 1e-8)

        # When both images share a label, each weight receives the other's share too.
        tmp = lam_a.clone()
        lam_a[same_label] += lam_b[same_label]
        lam_b[same_label] += tmp[same_label]

        # Area-based ratio used wherever the activation maps produced NaN.
        lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (input.size()[-1] * input.size()[-2]))
        lam_a[torch.isnan(lam_a)] = lam
        lam_b[torch.isnan(lam_b)] = 1 - lam

    return input, target, target_b, lam_a.to(device), lam_b.to(device)
class SnapMixLoss(nn.Module):
    """Combine the losses for two label sets using SnapMix weights."""

    def forward(self, criterion, outputs, ya, yb, lam_a, lam_b):
        """Return ``mean(criterion(outputs, ya) * lam_a + criterion(outputs, yb) * lam_b)``.

        The weights act per sample only when ``criterion`` returns unreduced
        losses; a scalar loss is simply broadcast against the weight vectors.
        """
        first_term = criterion(outputs, ya) * lam_a
        second_term = criterion(outputs, yb) * lam_b
        blended = first_term + second_term
        return torch.mean(blended)
# Requires the FMix sources from this Kaggle dataset:
# https://www.kaggle.com/khyeh0719/image-fmix
package_path = '../input/image-fmix/FMix-master'

import sys

# Put the FMix checkout on the module search path.
sys.path.append(package_path)

from pylab import rcParams

# Default figure size as (width, height).
rcParams['figure.figsize'] = (20, 40)
def rand_bbox(size, lam):
    """Return a random box with sides scaled by ``sqrt(1 - lam)``, clipped to the image.

    Args:
        size: Batch shape; ``size[2]`` bounds the ``x`` coordinates and
            ``size[3]`` the ``y`` coordinates.
        lam: Mixing ratio in ``[0, 1]``.

    Returns:
        ``(bbx1, bby1, bbx2, bby2)`` with ``bbx1 <= bbx2`` and ``bby1 <= bby2``.
    """
    W = size[2]
    H = size[3]
    cut_rat = np.sqrt(1. - lam)
    # Builtin `int` replaces `np.int`, which NumPy removed in 1.24.
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)

    # Uniformly sampled box centre.
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    half_w = cut_w // 2
    half_h = cut_h // 2
    bbx1 = np.clip(cx - half_w, 0, W)
    bby1 = np.clip(cy - half_h, 0, H)
    bbx2 = np.clip(cx + half_w, 0, W)
    bby2 = np.clip(cy + half_h, 0, H)
    return bbx1, bby1, bbx2, bby2
def cutmix(data, target, alpha):
    """Apply CutMix: overwrite one random box of every image with another image's box.

    Args:
        data: Image batch.
        target: Label per sample.
        alpha: Parameter of the symmetric Beta distribution for the mixing
            ratio; the draw is then clipped to ``[0.3, 0.4]``.

    Returns:
        ``(mixed, (target, shuffled_target, lam))``. ``lam`` is the fraction
        of each image left untouched, recomputed from the final box. Tensors
        are moved to CUDA when it is available and otherwise stay on
        ``data``'s device; images are cast to float and labels to long.
    """
    indices = torch.randperm(data.size(0))
    shuffled_target = target[indices]

    lam = np.clip(np.random.beta(alpha, alpha), 0.3, 0.4)
    bbx1, bby1, bbx2, bby2 = rand_bbox(data.size(), lam)

    new_data = data.clone()
    # rand_bbox bounds the x pair by size[2] and the y pair by size[3], so x
    # must index axis 2; the transposed order only works for square images.
    new_data[:, :, bbx1:bbx2, bby1:bby2] = data[indices, :, bbx1:bbx2, bby1:bby2]

    # Adjust lambda to exactly match the pixel ratio of the (clipped) box.
    lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (data.size()[-1] * data.size()[-2]))

    # Same destination as before on CUDA machines, but no crash on CPU-only ones.
    out_device = 'cuda' if torch.cuda.is_available() else data.device
    targets = (
        target.to(out_device, dtype=torch.long),
        shuffled_target.to(out_device, dtype=torch.long),
        lam,
    )
    return new_data.to(out_device, dtype=torch.float), targets
def fmix(data, targets, alpha, decay_power, shape, max_soft=0.0, reformulate=False):
    """Apply FMix: blend each image with a shuffled partner through a sampled mask.

    Args:
        data: Image batch.
        targets: Label per sample.
        alpha, decay_power, shape, max_soft, reformulate: Passed unchanged to
            ``sample_mask``. ``shape`` presumably has to match the spatial
            size of ``data`` so the mask broadcasts -- confirm against FMix.

    Returns:
        ``(mixed, (targets, shuffled_targets, lam))`` with ``lam`` as returned
        by ``sample_mask``. Tensors are moved to CUDA when it is available and
        otherwise stay on ``data``'s device; images are cast to float and
        labels to long.
    """
    lam, mask = sample_mask(alpha, decay_power, shape, max_soft, reformulate)

    indices = torch.randperm(data.size(0))
    shuffled_data = data[indices]
    shuffled_targets = targets[indices]

    # Put the mask on the batch's own device instead of relying on a global `device`.
    mask_device = data.device
    x1 = torch.from_numpy(mask).to(mask_device) * data
    x2 = torch.from_numpy(1 - mask).to(mask_device) * shuffled_data

    # Same destination as before on CUDA machines, but no crash on CPU-only ones.
    out_device = 'cuda' if torch.cuda.is_available() else data.device
    targets = (
        targets.to(out_device, dtype=torch.long),
        shuffled_targets.to(out_device, dtype=torch.long),
        lam,
    )
    return (x1 + x2).to(out_device, dtype=torch.float), targets
def train_model(data_loader, model, optimizer, scheduler, device):
    """
    Train the model for one epoch with randomly chosen mix augmentations.

    Per batch: SnapMix with probability 0.5; otherwise CutMix, FMix or no
    augmentation with probabilities 0.25 / 0.25 / 0.5 of the remainder.

    :param data_loader: pytorch dataloader yielding dicts with "image" and
        "targets" entries
    :param model: pytorch model whose forward returns (logits, features)
    :param optimizer: optimizer, for e.g. adam, sgd, etc
    :param scheduler: learning rate scheduler stepped after every optimizer
        step; may be None. NOTE(review): the source listing lost its
        indentation -- confirm per-batch (not per-epoch) stepping is intended.
    :param device: cuda/cpu
    """
    # put the model in train mode
    model.train()

    # the loss modules do not depend on the batch, so build them once
    snapmix_criterion = SnapMixLoss().to(device)
    criterion = TaylorCrossEntropyLoss()
    loss_fn = criterion.to(device)

    # go over every batch of data in data loader
    for data in data_loader:
        # move inputs/targets to cuda/cpu device
        inputs = data["image"].to(device, dtype=torch.float)
        targets = data["targets"].to(device, dtype=torch.long)

        # zero grad the optimizer
        optimizer.zero_grad()

        # No forward pass on the raw batch here: its output was always
        # overwritten by the forward pass on the augmented batch below.
        if np.random.rand() > 0.5:
            X, ya, yb, lam_a, lam_b = snapmix(inputs, targets, 5, model)
            outputs, _ = model(X)
            loss = snapmix_criterion(criterion, outputs, ya, yb, lam_a, lam_b)
        else:
            mix_decision = np.random.rand()
            if mix_decision < 0.25:
                inputs, targets = cutmix(inputs, targets, 1.)
            elif mix_decision < 0.5:
                # the mask has to match the batch's spatial size
                inputs, targets = fmix(
                    inputs, targets, alpha=1., decay_power=5.,
                    shape=tuple(inputs.shape[-2:]),
                )
            outputs, _ = model(inputs)
            if mix_decision < 0.50:
                # targets is (labels, shuffled_labels, lam) after cutmix/fmix
                loss = loss_fn(outputs, targets[0]) * targets[2] + loss_fn(outputs, targets[1]) * (1. - targets[2])
            else:
                loss = loss_fn(outputs, targets)

        # backward step the loss
        loss.backward()
        # step optimizer
        optimizer.step()
        # step the scheduler right after the optimizer, if one was given
        if scheduler is not None:
            scheduler.step()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment