@tejus-gupta
Created October 30, 2018 15:59
import os
import sys
import math
import torch
import torch.nn as nn
from datetime import datetime
from tensorboardX import SummaryWriter
from ptsemseg.utils import get_logger
from ptsemseg.metrics import averageMeter

sys.path.append('/home/tejus/lane-seg-experiments/Segmentation/')

from CAN import CAN
from datasets.tusimple.tusimple_loader import tusimpleLoader
from datasets.tusimple.augmentations import *  # Compose, RandomRotate, RandomHorizontallyFlip
from metrics import runningScore

# Definitions
TRAIN_BATCH = 3
VAL_BATCH = 4
resume_training = False
checkpoint_dir = '/home/tejus/lane-seg-experiments/Segmentation/CAN_logger/context_and_LFE/best_val_model.pkl'
run_id = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
logdir = os.path.join('runs', run_id)
writer = SummaryWriter(log_dir=logdir)
print('RUNDIR: {}'.format(logdir))
logger = get_logger(logdir)
logger.info('Let the party begin | Dilated convolutions')
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
torch.manual_seed(102)
# Network definition
net = CAN()
# Note: the optimizer is built before the frontend is frozen below; since the
# frozen parameters never receive gradients, Adam simply skips them at step().
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr=0.0001)
# optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, net.parameters()), lr=0.0001, momentum=0.9)  # 0.00001
# loss_fn = nn.BCEWithLogitsLoss()
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=2, verbose=True, min_lr=0.000001)
loss_fn = nn.CrossEntropyLoss()
net.to(device)
if not resume_training:
    # Warm-start from the checkpoint's model weights only.
    checkpoint = torch.load(checkpoint_dir)
    net.load_state_dict(checkpoint["model_state"], strict=False)
else:
    # Resume the full training state: weights, optimizer, scheduler, and epoch.
    checkpoint = torch.load(checkpoint_dir)
    net.load_state_dict(checkpoint["model_state"], strict=False)
    optimizer.load_state_dict(checkpoint["optimizer_state"])
    scheduler.load_state_dict(checkpoint["scheduler_state"])
    start_iter = checkpoint["epoch"]
    logger.info(
        "Loaded checkpoint '{}' (epoch {})".format(
            checkpoint_dir, checkpoint["epoch"]
        )
    )
# Initialize the dilated (context) layers as identity mappings, as described in
# the context-aggregation paper. state_dict() returns references to the
# underlying storage, so the in-place fill_() and indexed assignments below
# modify the model directly; no load_state_dict() call is needed.
params = net.state_dict()
dilated_conv_layers = [36, 39, 42, 45, 48, 51, 54]
for layer_idx in dilated_conv_layers:
    w = params['features.' + str(layer_idx) + '.weight']
    b = params['features.' + str(layer_idx) + '.bias']
    w.fill_(0)
    for i in range(w.shape[0]):
        w[i, i, 1, 1] = 1  # 1 at the 3x3 kernel center, channel i -> channel i
    # print(w)
    b.fill_(0)
# torch.save(net.state_dict(), 'test_identity.wts')

# The final 1x1 layer (index 56) gets the same identity initialization.
layer_idx = 56
w = params['features.' + str(layer_idx) + '.weight']
w.fill_(0)
for i in range(w.shape[0]):
    w[i, i, 0, 0] = 1
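
# Optional sanity check (not part of the original gist; assumes net.features is
# an nn.Sequential, as the 'features.<idx>' state_dict keys suggest): with the
# identity initialization above, a dilated layer should return its input
# unchanged whenever its padding preserves the spatial size and its input and
# output channel counts match. Uncomment to verify:
#
#   x = torch.randn(1, net.features[36].in_channels, 16, 16, device=device)
#   assert torch.allclose(net.features[36](x), x, atol=1e-5)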
net.train()

### freeze weights of frontend network
# Freeze the first 32 parameter tensors (the frontend's conv weight/bias pairs).
# Note: setting requires_grad on state_dict() entries does not freeze anything,
# because state_dict() returns detached tensors; iterate net.parameters() instead.
i = 0
for p in net.parameters():
    p.requires_grad = False
    i += 1
    if i == 32:
        break
###
augmentations = Compose([RandomRotate(5), RandomHorizontallyFlip()])
train_dataset = tusimpleLoader('/home/tejus/Downloads/train_set/', split="train", augmentations=augmentations)
trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=TRAIN_BATCH, shuffle=True, num_workers=TRAIN_BATCH, pin_memory=True)
val_dataset = tusimpleLoader('/home/tejus/Downloads/train_set/', split="val", augmentations=None)
valloader = torch.utils.data.DataLoader(val_dataset, batch_size=VAL_BATCH, shuffle=False, num_workers=VAL_BATCH, pin_memory=True)
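# Optional sanity check (not part of the original gist): inspect one batch to
# confirm the loader returns tensors shaped as CrossEntropyLoss expects —
# images (N, C, H, W) float, labels (N, H, W) long. Uncomment to verify:
#
#   imgs, labels = next(iter(trainloader))
#   print(imgs.shape, imgs.dtype, labels.shape, labels.dtype)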
running_metrics_val = runningScore(2)
best_val_loss = math.inf
val_loss = 0
ctr = 0
best_iou = -100
val_loss_meter = averageMeter()
time_meter = averageMeter()
## compute val loss for pretrained weights
val_loss = 0
net.eval()
with torch.no_grad():
    for i, data in enumerate(valloader):
        print(i)
        if i > 10:  # debug run: evaluate only the first few batches
            break
        imgs, labels = data
        imgs, labels = imgs.to(device), labels.to(device)
        out = net(imgs)
        loss = loss_fn(out, labels)
        pred = out.data.max(1)[1]
        running_metrics_val.update(pred.cpu().numpy(), labels.cpu().numpy())
        val_loss_meter.update(loss.item())
        val_loss += loss.item()
print("val_loss = ", val_loss)
running_loss = 0
net.train()
for i, data in enumerate(trainloader):
    print(i)
    if i > 10:  # debug run: train only on the first few batches
        break
    imgs, labels = data
    imgs, labels = imgs.to(device), labels.to(device)
    optimizer.zero_grad()
    out = net(imgs)
    loss = loss_fn(out, labels)
    loss.backward()
    optimizer.step()
    running_loss += loss.item()
print("running loss = ", running_loss)
# Re-evaluate on the validation set after the training pass.
val_loss = 0
net.eval()
with torch.no_grad():
    for i, data in enumerate(valloader):
        print(i)
        if i > 10:  # debug run: evaluate only the first few batches
            break
        imgs, labels = data
        imgs, labels = imgs.to(device), labels.to(device)
        out = net(imgs)
        loss = loss_fn(out, labels)
        pred = out.data.max(1)[1]
        running_metrics_val.update(pred.cpu().numpy(), labels.cpu().numpy())
        val_loss_meter.update(loss.item())
        val_loss += loss.item()
print("val_loss = ", val_loss)