Skip to content

Instantly share code, notes, and snippets.

@dsyme
Created February 15, 2022 15:20
Show Gist options
  • Save dsyme/667db9244eae546887bc22c2a2c523ce to your computer and use it in GitHub Desktop.
Save dsyme/667db9244eae546887bc22c2a2c523ce to your computer and use it in GitHub Desktop.
import argparse
import torch
import numpy as np
import pprint
import time
import os
import glob
import pprint
import sys
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from util import sgd_rev, sgd_fwd, run, runtime, create_path, get_time_stamp, load_text, save_text, sgd_rev_train, sgd_fwd_train, cross_entropy, weight_init_kaiming, weight_init_bias, weight_init_conv2d, bias_init_conv2d, conv2d, maxpool2d, batchnorm2d, dropout, adaptiveavgpool2d, argmax, days_hours_mins_secs_str
# Command-line interface. Parsing happens at import time and reads sys.argv.
parser = argparse.ArgumentParser(description='train', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
# --mode is mandatory and must be followed by one of the listed values.
# It deliberately does NOT use nargs='?': with that setting a bare "--mode"
# (no value) stores const=None, which satisfies required=True and bypasses
# the `choices` check, leaving opt.mode as None.
parser.add_argument('--mode', choices=['run', 'plot'], type=str, required=True)
parser.add_argument('--dir', type=str, default=None)
parser.add_argument('--lr', type=float, default=0.0002)
parser.add_argument('--lr_decay', type=float, default=0.0001)
parser.add_argument('--threshold', type=float, default=0.00001)
parser.add_argument('--n', type=int, default=None)
parser.add_argument('--runs', type=int, default=1)
parser.add_argument('--batch_size', type=int, default=128)
parser.add_argument('--num_workers', type=int, default=8)
parser.add_argument('--valid_every', type=int, default=500)
# Passed verbatim to torch.device() further down.
parser.add_argument('--device', type=str, default='cpu')
parser.add_argument('--model', type=str, choices=['logreg', 'mlp', 'cnn', 'cnn2', 'cnn4', 'cnn4b', 'vgg16', 'resnet18', 'resnet50'], default='logreg')
parser.add_argument('--optimizer', type=str, choices=['sgd', 'sgdn', 'adam'], default='sgd')
parser.add_argument('--momentum', type=float, default=0.2)
# Boolean switches: False unless the flag is present on the command line.
parser.add_argument('--skiprev', action='store_true')
parser.add_argument('--skipfwd', action='store_true')
opt = parser.parse_args()
# Echo the raw command-line arguments (everything after the script name).
print('Arguments:\n{}\n'.format(' '.join(sys.argv[1:])))
# Dump the parsed options, including argparse defaults, as a dict.
print('Config:')
pprint.pprint(vars(opt), depth=2, width=50)
print()
# Torch device built from --device; logreg_params() moves its tensors onto it.
device = torch.device(opt.device)
# Wall-clock timestamp (seconds since the epoch) captured at startup.
time_start = time.time()
def logreg_params():
    """Create the parameter dictionary for the logistic-regression model.

    Returns:
        A dict with two entries, both moved onto the module-level ``device``:
        ``'w1'`` from ``weight_init_kaiming(28*28, 10)`` and ``'b1'`` from
        ``weight_init_bias(10)``.
    """
    # NOTE(review): the initialisers live in util; presumably w1 is a
    # (784, 10) matrix and b1 a length-10 vector -- confirm there.
    weight = weight_init_kaiming(28*28, 10).to(device)
    bias = weight_init_bias(10).to(device)
    return {'w1': weight, 'b1': bias}
def logreg_eval(params, x):
    """Apply the single affine layer of the logistic-regression model.

    Args:
        params: Mapping providing the weight under ``'w1'`` and the bias
            under ``'b1'``.
        x: Input batch; it is viewed as rows of ``28*28`` values, so its
            element count must be a multiple of 784.

    Returns:
        ``x.view(-1, 784) @ w1 + b1``, one output row per input row.
    """
    # Flatten each sample to a 784-long row before the matrix product.
    flat = x.view(-1, 28*28)
    return flat.matmul(params['w1']) + params['b1']
def logreg_loss(params, x, target):
    """Compute the loss and the number of correct predictions for a batch.

    Args:
        params: Parameter mapping accepted by ``logreg_eval``.
        x: Input batch forwarded to ``logreg_eval``.
        target: Targets, compared element-wise against ``argmax`` of the
            model output and passed to ``cross_entropy``.

    Returns:
        A tuple ``(loss, num_correct)``: ``loss`` is whatever
        ``cross_entropy`` returns for the model output and ``target``;
        ``num_correct`` is the count of positions where the prediction
        equals ``target``, extracted with ``.item()``.
    """
    # NOTE(review): cross_entropy and argmax come from util; their exact
    # reduction / axis conventions are not visible here.
    outputs = logreg_eval(params, x)
    loss = cross_entropy(outputs, target)
    matches = argmax(outputs) == target
    return loss, matches.sum().item()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment