Skip to content

Instantly share code, notes, and snippets.

@Deniallugo
Created May 6, 2020 11:29
Show Gist options
  • Save Deniallugo/3d4b4ea683a812fc1360b88605e4d204 to your computer and use it in GitHub Desktop.
Save Deniallugo/3d4b4ea683a812fc1360b88605e4d204 to your computer and use it in GitHub Desktop.
import enum
class Phase(enum.Enum):
    """Stage of an epoch: ``train`` (value 1) or ``val`` (value 2)."""

    # auto() numbers members from 1 in definition order.
    train = enum.auto()
    val = enum.auto()
# NOTE(review): the functions below read module-level names that this file never
# defines (torch, F, np, os, tqdm, train_loader, test_loader, device, baseline,
# budget, writer, encode_onehot, num_classes, start_epoch); presumably the
# surrounding training script provides them -- confirm.


def _from_module(value, name):
    """Return *value*, or the module-level global *name* when *value* is None.

    Keeps the original call forms working: they read these names implicitly
    from module scope instead of receiving them as arguments.
    """
    return globals()[name] if value is None else value


def _phase_label(phase):
    """Return a short tag for *phase*: the member name ("train"/"val") for enum members."""
    return getattr(phase, "name", str(phase))


def prepare_data_loader(phase: "Phase", model):
    """Put *model* in the mode for *phase* and return that phase's data loader.

    Returns the module-level ``train_loader`` for ``Phase.train`` and
    ``test_loader`` for any other phase.
    """
    if phase == Phase.train:
        model.train()  # Set model to training mode
        return train_loader
    model.eval()  # Set model to evaluate mode
    return test_loader


def _forward(model, images):
    """Run *model* on *images* and return clamped ``(probs, confidence)``.

    The model must return a pair ``(logits, confidence_logits)``. Softmax
    probabilities and sigmoid confidences are clamped to ``[1e-12, 1 - 1e-12]``
    so that the later ``torch.log`` calls stay finite in every phase.
    """
    # Make sure we don't have any numerical instability.
    eps = 1e-12
    logits, confidence_logits = model(images)
    probs = torch.clamp(F.softmax(logits, dim=-1), 0.0 + eps, 1.0 - eps)
    confidence = torch.clamp(torch.sigmoid(confidence_logits), 0.0 + eps, 1.0 - eps)
    return probs, confidence


def _loss_log_probs(phase, probs, confidence, labels_onehot):
    """Return the log-probabilities that are fed to the loss for *phase*.

    When training outside baseline mode, predictions are interpolated toward
    the one-hot labels with weight ``1 - conf`` (the "hints"). Hints use the
    ground truth, so they are never applied in validation.
    """
    if phase == Phase.train and not baseline:
        # Randomly set about half of the confidences to 1 (i.e. no hints):
        # b ~ Bernoulli(u) with u ~ U(0, 1), so P(b = 1) = 1/2 per element.
        # empty_like keeps b on confidence's device and dtype.
        b = torch.bernoulli(torch.empty_like(confidence).uniform_(0, 1))
        conf = confidence * b + (1 - b)
        probs = probs * conf.expand_as(probs) + labels_onehot * (
            1 - conf.expand_as(labels_onehot)
        )
    return torch.log(probs)


def set_numerical_instability(phase: "Phase", model, images=None, labels_onehot=None):
    """Run the prediction step and return ``(log_probs, confidence)`` for the loss.

    Despite its name, this performs the whole step: the forward pass, clamping
    of probabilities and confidences, and label hints when training outside
    baseline mode.

    Args:
        phase: current ``Phase``; hints are applied only for ``Phase.train``.
        model: network called as ``model(images)``; it must return
            ``(logits, confidence_logits)``.
        images: input batch. Defaults to a module-level ``images``, which is
            what the original two-argument form read.
        labels_onehot: one-hot targets used for the hints. It has the same
            fallback and is only looked up when hints are actually applied.

    Returns:
        ``(log_probs, confidence)``.
    """
    probs, confidence = _forward(model, _from_module(images, "images"))
    if phase == Phase.train and not baseline:
        labels_onehot = _from_module(labels_onehot, "labels_onehot")
    return _loss_log_probs(phase, probs, confidence, labels_onehot), confidence


# FIXME find better name
def backward(xentropy_loss, lmbda, baseline, confidence_loss, *, optimizer=None, budget=None):
    """Backpropagate the combined loss, step the optimizer and return the new lambda.

    In baseline mode only ``xentropy_loss`` is optimized and *lmbda* is returned
    unchanged. Otherwise the loss is ``xentropy_loss + lmbda * confidence_loss``.
    Lambda is then divided by 1.01 when the confidence loss is below *budget*,
    and by 0.99 otherwise.

    Args:
        xentropy_loss: scalar cross-entropy loss tensor.
        lmbda: current weight of the confidence loss.
        baseline: if true, ignore the confidence loss entirely.
        confidence_loss: scalar confidence loss tensor.
        optimizer: optimizer to step; defaults to the module-level ``optimizer``.
        budget: target confidence loss; defaults to the module-level ``budget``.

    Returns:
        The (possibly adapted) lambda. Callers must keep it, because lambda is
        a plain float and cannot be updated in place.
    """
    optimizer = _from_module(optimizer, "optimizer")
    if baseline:
        total_loss = xentropy_loss
    else:
        total_loss = xentropy_loss + (lmbda * confidence_loss)
        budget = _from_module(budget, "budget")
        if budget > confidence_loss.item():
            lmbda = lmbda / 1.01
        else:
            lmbda = lmbda / 0.99
    total_loss.backward()
    optimizer.step()
    return lmbda


def find_conf(running_confidence):
    """Return ``(min, mean, max)`` over all collected confidence values.

    Accepts a list of scalars or of equally shaped arrays, such as the
    per-sample rows produced by extending a list with a batch's confidences.
    Returns three NaNs for an empty list instead of letting ``np.min`` raise.
    """
    values = np.asarray(running_confidence, dtype=float)
    if values.size == 0:
        nan = float("nan")
        return nan, nan, nan
    return values.min(), values.mean(), values.max()


def run_phase(model, loss, optimizer, scheduler, epoch, phase: "Phase", lmbda=0.1):
    """Run one epoch of *phase* and return ``(accuracy, loss, confidence_loss)``.

    Args:
        model, loss, optimizer, scheduler: the network, the criterion (called
            with log-probabilities and labels), its optimizer and its LR
            scheduler.
        epoch: epoch index, used for logging and ``scheduler.step``.
        phase: ``Phase.train`` or ``Phase.val``.
        lmbda: starting confidence-loss weight. 0.1 matches ``train_model``.
            Its adaptation during this epoch is not returned; ``train_model``
            carries it across epochs.

    Returns:
        ``(accuracy, loss, confidence_loss)``. Accuracy is a CPU scalar
        tensor; the losses are floats.
    """
    acc, epoch_loss, epoch_conf_loss, _ = _run_phase(
        model, loss, optimizer, scheduler, epoch, phase, lmbda
    )
    return acc, epoch_loss, epoch_conf_loss


def _run_phase(model, loss, optimizer, scheduler, epoch, phase, lmbda):
    """Run one epoch of *phase* and return ``(acc, loss, conf_loss, lmbda)``.

    Backprop, optimizer steps and lambda adaptation happen only for
    ``Phase.train``. For ``Phase.val``, confidence statistics are also printed.
    """
    is_train = phase == Phase.train
    dataloader = prepare_data_loader(phase, model)
    running_loss = 0.0
    running_acc = 0.0
    running_conf_loss = 0.0
    running_confidence = []

    # Iterate over data.
    for images, labels in tqdm(dataloader):
        images = images.to(device)
        labels = labels.to(device)
        labels_onehot = encode_onehot(labels, num_classes)
        if is_train:
            optimizer.zero_grad()

        # Forward pass, plus the backward pass when training.
        with torch.set_grad_enabled(is_train):
            probs, confidence = _forward(model, images)
            log_probs = _loss_log_probs(phase, probs, confidence, labels_onehot)
            xentropy_loss = loss(log_probs, labels)
            confidence_loss = torch.mean(-torch.log(confidence))
            if is_train:
                # Step the optimizer this function was given, and keep the
                # adapted lambda for the next batch.
                lmbda = backward(
                    xentropy_loss, lmbda, baseline, confidence_loss, optimizer=optimizer
                )

        # Score the un-hinted predictions: the hinted ones are pulled toward the
        # true label and would inflate training accuracy.
        pred_idx = probs.argmax(dim=1)
        running_loss += xentropy_loss.item()
        running_acc += (pred_idx == labels).float().mean()
        running_conf_loss += confidence_loss.item()
        if phase == Phase.val:
            running_confidence.extend(confidence.detach().cpu().numpy())

    num_batches = len(dataloader)
    epoch_loss = running_loss / num_batches
    epoch_acc = running_acc / num_batches
    epoch_conf_loss = running_conf_loss / num_batches
    print(
        f"\n {_phase_label(phase)} Loss: {epoch_loss:.4f} Confidence Loss: {epoch_conf_loss:.4f} Acc: {epoch_acc:.4f}",
        flush=True,
    )
    if phase == Phase.val:
        conf_min, conf_avg, conf_max = find_conf(running_confidence)
        print(
            f"conf_min: {conf_min:.3f}, conf_max: {conf_max:.3f}, conf_avg: {conf_avg:.3f}"
        )
    write_epoch(epoch, epoch_loss, epoch_conf_loss, epoch_acc, phase=phase)
    if is_train:
        # NOTE(review): passing the epoch to scheduler.step() is deprecated for
        # most PyTorch schedulers; kept as-is. Confirm against the scheduler in use.
        scheduler.step(epoch)
    return epoch_acc.cpu(), epoch_loss, epoch_conf_loss, lmbda


def write_epoch(epoch, loss, conf_loss, acc, phase=None):
    """Log one epoch's loss, confidence loss and accuracy via ``writer.add_scalar``.

    Tags are ``Loss/<phase>``, ``ConfLoss/<phase>`` and ``Accuracy/<phase>``.
    ``writer`` is presumably a TensorBoard ``SummaryWriter`` -- confirm.
    *phase* defaults to a module-level ``phase``, which is what the original read.
    """
    label = _phase_label(_from_module(phase, "phase"))
    writer.add_scalar(f"Loss/{label}", loss, epoch)
    writer.add_scalar(f"ConfLoss/{label}", conf_loss, epoch)
    writer.add_scalar(f"Accuracy/{label}", acc, epoch)


def save_torch(accuracy, log_loss, conf_loss, phase=None):
    """Save the metric histories to ``./accs_losses/<phase>_accs_losses_<budget>.pth``.

    The directory is created if needed, and an existing file is overwritten.
    *phase* defaults to a module-level ``phase``, which is what the original read.
    """
    data = {
        "accuracy": accuracy,
        "loss": log_loss,
        "confidence_loss": conf_loss,
    }
    label = _phase_label(_from_module(phase, "phase"))
    # exist_ok avoids the check-then-create race of isdir() + mkdir().
    os.makedirs("accs_losses", exist_ok=True)
    torch.save(data, f"./accs_losses/{label}_accs_losses_{budget}.pth")


def train_model(model, loss, optimizer, scheduler, num_epochs):
    """Alternate a training and a validation phase for *num_epochs* epochs.

    Epochs are numbered from the module-level ``start_epoch``. The confidence
    weight ``lmbda`` starts at 0.1, and its adaptation carries across batches
    and epochs. After every phase, that phase's metric history is written with
    ``save_torch``.

    Returns:
        A dict mapping each ``Phase`` to ``{"accuracy", "loss",
        "confidence_loss"}`` NumPy arrays, with one entry per epoch.
    """
    lmbda = 0.1
    phases = (Phase.train, Phase.val)
    # Separate histories per phase; names are chosen so the `loss` criterion
    # argument is never shadowed.
    history = {p: {"accuracy": [], "loss": [], "confidence_loss": []} for p in phases}
    last_epoch = start_epoch + num_epochs - 1
    for epoch in range(start_epoch, start_epoch + num_epochs):
        print("Epoch {}/{}:".format(epoch, last_epoch), flush=True)
        # Each epoch has a training and validation phase
        for phase in phases:
            acc, epoch_loss, epoch_conf_loss, lmbda = _run_phase(
                model, loss, optimizer, scheduler, epoch, phase, lmbda
            )
            record = history[phase]
            record["accuracy"].append(float(acc))
            record["loss"].append(epoch_loss)
            record["confidence_loss"].append(epoch_conf_loss)
            save_torch(
                np.array(record["accuracy"]),
                np.array(record["loss"]),
                np.array(record["confidence_loss"]),
                phase=phase,
            )
    return {
        p: {key: np.array(values) for key, values in metrics.items()}
        for p, metrics in history.items()
    }
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment