Skip to content

Instantly share code, notes, and snippets.

@rmihir96
Last active March 21, 2019 01:38
Show Gist options
  • Save rmihir96/b3e4e3ca8af5d981884a1cdf2fd6fe77 to your computer and use it in GitHub Desktop.
Save rmihir96/b3e4e3ca8af5d981884a1cdf2fd6fe77 to your computer and use it in GitHub Desktop.
#!/usr/bin/env python
# coding: utf-8
# In[1]:
from fastai.vision import *
from fastai.metrics import *
import glob
from shutil import copyfile, copy, move
import pandas as pd
import numpy as np
# Notebook-only setup: inline matplotlib output and automatic module reloading.
# `get_ipython` is only defined inside an IPython/Jupyter session. When this
# exported notebook is run as a plain Python script (see the shebang), the
# magics are skipped instead of crashing with NameError.
try:
    ipython_shell = get_ipython()
except NameError:
    ipython_shell = None
if ipython_shell is not None:
    ipython_shell.run_line_magic('matplotlib', 'inline')
    ipython_shell.run_line_magic('reload_ext', 'autoreload')
    ipython_shell.run_line_magic('autoreload', '2')
# In[3]:
# Dataset root directory and its per-split sub-folders.
path = Path('whales')
train, test, valid = (path / split_name for split_name in ('train', 'test', 'valid'))
# In[4]:
# Show what is inside the dataset root folder.
path.ls()
# In[5]:
# Load the label CSV and preview its first rows.
train_labels_df = pd.read_csv(path / "train.csv")
train_labels_df.head()
# In[ ]:
# In[4]:
# Build the DataBunch with fastai's data block API.
# Labels are read as single-label categories (one whale Id per image;
# NOTE(review): assumes no row of train.csv lists several space-separated Ids).
# Do not pass `label_delim=' '` here: it produces a multi-label MultiCategoryList
# with one-hot float targets, which the `error_rate` metric used by the learner
# (argmax accuracy against integer class indices) cannot score.
data = (ImageList.from_csv(path, 'train.csv', folder='train')
        # Where to find the data? -> file names from train.csv, images in the 'train' folder
        .split_by_rand_pct()
        # How to split in train/valid? -> randomly, with the default 20% in valid
        .label_from_df()
        # How to label? -> the second column of the csv file, one category per image
        .transform(tfms=get_transforms(), size=128)
        # Data augmentation? -> default transforms, images resized to 128
        .add_test_folder()
        # Test set? -> unlabelled images from the 'test' folder
        .databunch())
# Finally -> use the defaults for conversion to databunch
# In[5]:
# Display the DataBunch summary.
data
# In[6]:
# Display sample training images (rows=3) together with their labels.
data.show_batch(figsize=(7, 6), rows=3)
# In[9]:
# Number of classes and the list of class names.
(data.c, data.classes)
# In[16]:
# NOTE(review): an earlier attempt redefined `error_rate` as
# `1 - accuracy(input, targs)` after flattening the targets with
# `targs.view(-1).long()`. It did not fix the metric error, so it is not used.
# In[7]:
# CNN classifier on a ResNet-34 backbone, reporting error rate on the validation set.
learn = cnn_learner(data, models.resnet34, metrics=[error_rate])
# Display the model architecture.
learn.model
# In[8]:
# Train for two epochs using the one-cycle learning-rate policy.
learn.fit_one_cycle(cyc_len=2)
# In[ ]:
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment