Skip to content

Instantly share code, notes, and snippets.

@drscotthawley
Created November 20, 2018 06:57
Show Gist options
  • Star 5 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save drscotthawley/2288e92f23b02e7fd5352708fc6cd125 to your computer and use it in GitHub Desktop.
Save drscotthawley/2288e92f23b02e7fd5352708fc6cd125 to your computer and use it in GitHub Desktop.
FastAICustomModelExample
Display the source blob
Display the rendered blob
Raw
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
# coding: utf-8
# Minimal example of training a custom PyTorch model with fastai: a mixture of [@eslavich's forum post](https://forums.fast.ai/t/learner-layer-groups-parameter/30212) and the fast.ai Lesson 5 notebook (lesson5-sgd-mnist.ipynb).
# In[ ]:
# Notebook-only setup: the autoreload extension (mode 2) and inline matplotlib
# output.  `get_ipython` is injected by IPython/Jupyter and does not exist
# under a plain `python` interpreter, so look it up defensively and skip the
# magics there instead of dying with a NameError on the first line.
try:
    _ipython = get_ipython()
except NameError:
    _ipython = None
if _ipython is not None:
    _ipython.run_line_magic('reload_ext', 'autoreload')
    _ipython.run_line_magic('autoreload', '2')
    _ipython.run_line_magic('matplotlib', 'inline')
# In[2]:
import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.data.dataset import TensorDataset
from fastai import *
from fastai.vision import *
# In[21]:
class SimpleModel(nn.Module):
    """Tiny regression network: two fully connected layers applied in sequence.

    No activation sits between the two layers, so the model as a whole is
    still an affine function of its input.

    Args:
        in_features: Size of each input sample (default 1).
        hidden_features: Width of the intermediate layer (default 5).
        out_features: Size of each output sample (default 1).
    """

    def __init__(self, in_features: int = 1, hidden_features: int = 5,
                 out_features: int = 1):
        super().__init__()
        # The attribute names `linear1` / `linear2` determine the state-dict
        # keys, so they are kept unchanged.
        self.linear1 = nn.Linear(in_features, hidden_features)
        self.linear2 = nn.Linear(hidden_features, out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``linear2(linear1(x))``; the last dim of ``x`` must be ``in_features``."""
        return self.linear2(self.linear1(x))
def generate_data(size: int, slope: float = 2.0):
    """Draw a noise-free synthetic regression set following ``y = slope * x``.

    Args:
        size: Number of samples to draw.
        slope: Factor relating targets to inputs.  The default of 2.0 is the
            relationship this example was written around.

    Returns:
        A tuple ``(x, y)`` of ``float32`` arrays, each of shape ``(size, 1)``.
        ``x`` is sampled with ``np.random.uniform`` using its default bounds.
    """
    x = np.random.uniform(size=(size, 1))
    # Targets are computed from the freshly drawn samples before either array
    # is cast down to float32.
    y = x * slope
    return x.astype(np.float32), y.astype(np.float32)
# Draw the training split first and the validation split second, turning
# each NumPy array into a tensor while it is unpacked.
x_train, y_train = (torch.tensor(arr) for arr in generate_data(10000))
x_valid, y_valid = (torch.tensor(arr) for arr in generate_data(1000))
n, c = x_train.shape
# Notebook-style inspection of the input shape and the target range.
x_train.shape, y_train.min(), y_train.max()
# In[22]:
# Mini-batch size handed to DataBunch.create.
bs = 50
# Pair inputs with targets for each split.
train_ds, valid_ds = (
    TensorDataset(inputs, targets)
    for inputs, targets in ((x_train, y_train), (x_valid, y_valid))
)
data = DataBunch.create(train_ds, valid_ds, bs=bs)
# In[23]:
# Pull one mini-batch from the training loader to inspect its shapes.
x, y = next(iter(data.train_dl))
x.shape, y.shape
# In[24]:
# Put the model on whatever device the batch lives on rather than hard-coding
# `.cuda()`, which raises on machines without a CUDA device.
model = SimpleModel().to(x.device)
# In[25]:
# Sanity check: run the batch through the model and look at the output shape.
model(x).shape
# In[26]:
# Mean-squared-error loss for the regression targets.
loss_func = nn.MSELoss()
# The Learner is given a freshly constructed SimpleModel, not the `model`
# instance that was built earlier for the forward-pass shape check.
learn = Learner(data, SimpleModel(), loss_func=loss_func)
# In[27]:
# Run fastai's learning-rate finder, then plot what the recorder captured
# (used to eyeball a learning rate for the training call below).
learn.lr_find()
learn.recorder.plot()
# In[28]:
# One-cycle training with positional arguments (1, 1e-1) -- presumably one
# epoch at a peak learning rate of 0.1; confirm against the installed fastai
# version's `fit_one_cycle` signature.
learn.fit_one_cycle(1, 1e-1)
# In[29]:
# Plot the recorded learning-rate schedule; show_moms=True asks for the
# momentum schedule as well.
learn.recorder.plot_lr(show_moms=True)
# In[30]:
# Plot the losses recorded during training.
learn.recorder.plot_losses()
# In[ ]:
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment