Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save fromLittleAcorns/12472c0cd9eedb28362dda028aa0b329 to your computer and use it in GitHub Desktop.
Save fromLittleAcorns/12472c0cd9eedb28362dda028aa0b329 to your computer and use it in GitHub Desktop.
Code to demonstrate a possible fastai bug in the data block API
import os,sys,inspect
from pathlib import Path
from fastai import *
from fastai.basic_train import *
from fastai.data_block import *
from fastai.basic_data import *
from fastai.train import *
from fastai.torch_core import *
from fastai.callbacks import *
import numpy as np
import pandas as pd
current_dir = Path(os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe()))))
# Create a random-number input dataframe to reproduce the issue.
# `inputs` is both the number of random feature columns (0..inputs-1) and the
# label of the extra validation-flag column appended after them.
inputs = 11
n_rows = 100
n_train = 60  # rows [0, n_train) are training rows; the remaining rows are validation rows
total_np = np.random.rand(n_rows, inputs).astype(np.float32)
total_df = pd.DataFrame(data=total_np)
# Validation flag as a single vectorised comparison. The previous pair of
# `.loc[0:60]` / `.loc[60:]` assignments overlapped (label slices include both
# endpoints, so row 60 was written twice) and left an object-dtype column; this
# yields the same 60/40 split as a clean bool column.
total_df[inputs] = np.arange(n_rows) >= n_train
# Column 0 becomes a constant dummy label. Replace the whole column instead of
# `.loc[:, 'my_label'] = ...`, which tries to store strings into the existing
# float32 column in place (deprecated incompatible-dtype setitem in recent pandas).
total_df = total_df.rename(columns={0: "my_label"})
total_df["my_label"] = "dummy_label"
bs = 20
# Feature columns: everything between the label column and the validation flag.
input_cols = list(total_df.columns[1:-1])
# Define column that represents the target (originally the tissue label; here a constant dummy)
target_col = total_df.columns[0]
# Define validation column (the boolean flag added above)
valid_col = total_df.columns[-1]
# Build the autoencoder dataset with fastai's data block API. The same feature
# columns (`input_cols`) serve as both inputs and FloatList targets, and the
# train/validation split comes from the `valid_col` flag column (presumably
# True marks validation rows, per fastai's `split_from_df` -- confirm).
dataAE = (
    ItemList.from_df(df=total_df, cols=input_cols)
    .split_from_df(col=valid_col)
    .label_from_df(cols=input_cols, label_cls=FloatList)
)
dataB_AE = dataAE.databunch(bs=bs)
# Show the first validation sample as a quick sanity check.
print(dataB_AE.valid_ds[0])
# Create (very) simple autoencoder
class Ae(nn.Module):
    """Very simple single-hidden-layer autoencoder.

    Maps an ``n_outer``-wide input through a ReLU bottleneck of width
    ``n_hidden`` and back to ``n_outer`` outputs (no output activation).

    Args:
        n_outer: Width of the input and of the reconstructed output.
        n_hidden: Width of the bottleneck layer. Defaults to 5, the value
            that was previously hard-coded.
    """

    def __init__(self, n_outer, n_hidden=5):
        super().__init__()
        self.hl1 = nn.Linear(n_outer, n_hidden)  # encoder: n_outer -> n_hidden
        self.hl2 = nn.Linear(n_hidden, n_outer)  # decoder: n_hidden -> n_outer

    def forward(self, x):
        """Return the reconstruction of ``x``; its last dimension must equal ``n_outer``."""
        x = F.relu(self.hl1(x))
        return self.hl2(x)
# The autoencoder's input/output width is the length of one sample's features.
n_outer = len(dataB_AE.valid_ds[0][0])
ae = Ae(n_outer=n_outer)

### Create learner
learner_ae = Learner(
    model=ae,
    data=dataB_AE,
    loss_func=nn.MSELoss(),  # reconstruction loss
    model_dir=current_dir,
)
# One short training cycle -- just enough to exercise the pipeline.
learner_ae.fit_one_cycle(cyc_len=1, max_lr=0.05)

# Exercise both inference paths: single-item predict, then a full validation pass.
# NOTE(review): confirm the order of the values returned by `Learner.predict`
# for the installed fastai version; the names below are kept from the original.
pred, y, prob = learner_ae.predict(dataB_AE.valid_ds[0][0])
# The validation result is not stored; in a script only an exception (if any) is visible.
learner_ae.validate(dataB_AE.valid_dl)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment