This is the code that Copilot nearly entirely auto-generates for a function called `finetune`, given just the name and a 1-line docstring
# I added all the imports by hand
import os
import torch
from torchvision.datasets import ImageFolder
from torchvision import transforms,models
from torch import nn,optim
# For all functions including this one, I wrote the name and docstring, and sometimes also the param names
def finetune(folder, model):
    """fine tune pytorch model using images from folder and report results on validation set"""
    if not os.path.exists(folder): raise ValueError(f"{folder} does not exist")
    if not os.path.exists(model): raise ValueError(f"{model} does not exist")
    if not os.path.exists(f'{folder}/val'): raise ValueError(f"{folder}/val does not exist")
    if not os.path.exists(f'{folder}/train'): raise ValueError(f"{folder}/train does not exist")
    # build model
    model = getattr(models,model)()
    model.load_state_dict(torch.load(model))
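    # NB: the line above can't work: `model` started as the architecture name, but it has just been
    # rebound to the Module itself, so torch.load(model) gets a Module rather than a weights path
    # (and torch.save at the bottom of the function has the same problem)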
    model = model.to(device)
    # load data
    train_set = ImageFolder(folder,'train',transform=data_transforms['train'])
    val_set = ImageFolder(folder,'val',transform=data_transforms['val'])
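    # NB: these two calls don't work either: ImageFolder takes (root, transform), so passing
    # 'train'/'val' positionally collides with the transform keyword; it should be the path to
    # each split, e.g. f'{folder}/train'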
    train_loader = torch.utils.data.DataLoader(train_set,batch_size=4,shuffle=True,num_workers=2)
    val_loader = torch.utils.data.DataLoader(val_set,batch_size=4,shuffle=True,num_workers=2)
    # fine-tune
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(),lr=0.001,momentum=0.9)
    for epoch in range(1,10):
        train(model,train_loader,criterion,optimizer,log_interval=100)
        validate(model,val_loader,criterion)
    # save model
    torch.save(model.state_dict(),model)
# `data_transforms` and `device` are both used above but not defined, so I told copilot to define them
# I had to type "data_transforms =" to get copilot to fill out the definition
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
}
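# NB: those mean/std values are the standard ImageNet normalization statistics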
# I had to type "device =" to get copilot to fill out the definition | |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | |
# I had to type "def train(" to get copilot to fill out the definition | |
# NB: this function doesn't quite work, since it uses `epoch`, but that didn't get passed in | |
def train(model, train_loader, criterion, optimizer, log_interval=100): | |
"train model" | |
model.train() | |
for batch_idx, (data, target) in enumerate(train_loader): | |
data, target = data.to(device), target.to(device) | |
optimizer.zero_grad() | |
output = model(data) | |
loss = criterion(output, target) | |
loss.backward() | |
optimizer.step() | |
if batch_idx % log_interval == 0: | |
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( | |
epoch, batch_idx * len(data), len(train_loader.dataset), | |
100. * batch_idx / len(train_loader), loss.item())) | |
# I had to type "def validate(" to get copilot to fill out the definition | |
def validate(model, val_loader, criterion): | |
"validate model" | |
model.eval() | |
val_loss = 0 | |
correct = 0 | |
with torch.no_grad(): | |
for data, target in val_loader: | |
data, target = data.to(device), target.to(device) | |
output = model(data) | |
val_loss += criterion(output, target).item() | |
pred = output.max(1, keepdim=True)[1] | |
correct += pred.eq(target.view_as(pred)).sum().item() | |
val_loss /= len(val_loader.dataset) | |
print('\nValidation set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( | |
val_loss, correct, len(val_loader.dataset), | |
100. * correct / len(val_loader.dataset))) |
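# For reference, here's a minimal corrected sketch of the same function (my rewrite, not Copilot's
# output; the name `finetune_fixed` and the separate `arch`/`weights_path` params are mine). It keeps
# the architecture name and the weights path apart, points ImageFolder at each split directory, and
# passes `epoch` through to `train`, assuming train() gains an `epoch` parameter. A real fine-tune
# would typically also swap the model's final layer to match the dataset's class count, which
# neither version does.
def finetune_fixed(folder, arch, weights_path):
    """fine tune the torchvision model named `arch` on folder/{train,val} and report validation results"""
    for split in ('train','val'):
        if not os.path.exists(f'{folder}/{split}'): raise ValueError(f"{folder}/{split} does not exist")
    # build the model from its architecture name, then load weights from a separate path
    model = getattr(models, arch)()
    model.load_state_dict(torch.load(weights_path))
    model = model.to(device)
    # point ImageFolder at each split directory instead of passing the split name positionally
    train_set = ImageFolder(f'{folder}/train', transform=data_transforms['train'])
    val_set = ImageFolder(f'{folder}/val', transform=data_transforms['val'])
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=4, shuffle=True, num_workers=2)
    val_loader = torch.utils.data.DataLoader(val_set, batch_size=4, shuffle=False, num_workers=2)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
    for epoch in range(1, 10):
        # assumes the signature fix: def train(model, train_loader, criterion, optimizer, epoch, log_interval=100)
        train(model, train_loader, criterion, optimizer, epoch, log_interval=100)
        validate(model, val_loader, criterion)
    torch.save(model.state_dict(), weights_path)
# e.g. finetune_fixed('/path/to/data', 'resnet18', 'resnet18_weights.pth')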