Skip to content

Instantly share code, notes, and snippets.

@krsnewwave
krsnewwave / create_kedro_project.sh
Last active October 25, 2022 16:49
Start a Kedro project
# (1) virtual environment
# Assumes the env already exists (e.g. created with `conda create -n kedro-env python`).
# `conda activate` errors out in a non-interactive script unless conda's shell
# hook has been loaded into the current shell first.
eval "$(conda shell.bash hook)"
conda activate kedro-env
# `python -m pip` makes sure the packages go into the activated env's interpreter.
python -m pip install kedro kedro-mlflow optuna kedro-viz
# (2) new project with starter
# Enter any project name at the prompt; this walkthrough uses 'tutorial'.
# Kedro creates the project in a folder named after it, which the next line enters.
kedro new --starter=pandas-iris
cd tutorial
# (3) fire up git. The starter already has a gitignore file and
@krsnewwave
krsnewwave / catalog.yaml
Created March 11, 2022 15:04
Kedro data catalog YAML with MLflow
# in <root>/conf/base/catalog.yaml
# Each top-level key is a dataset name; its nested keys configure how Kedro
# loads/saves it. The nesting must be preserved or the keys collide.
# NOTE(review): `CSVDataSet`/`ParquetDataSet` are the Kedro 0.17/0.18 spellings;
# with Kedro >= 0.19 and kedro-datasets they are `pandas.CSVDataset` /
# `pandas.ParquetDataset` — update if upgrading.
insurance:
  type: pandas.CSVDataSet
  filepath: data/01_raw/train.csv
  layer: raw
# Parquet-backed table stored under the primary data layer (03_primary).
model_input_table:
  type: pandas.ParquetDataSet
  filepath: data/03_primary/model_input_table.pq
@krsnewwave
krsnewwave / nodes_with_optuna.py
Created March 11, 2022 15:08
Random-forest hyperparameter optimization using Optuna, Kedro, and MLflow
# in <root>/src/<project>/pipelines/data_science/nodes.py
def rr_objective(X_train: pd.DataFrame, y_train: pd.Series,
X_test: pd.DataFrame, y_test: pd.Series,
trial: optuna.trial):
max_depth = trial.suggest_int("max_depth", 8, 64, log=True)
min_samples_split = trial.suggest_int("min_samples_split", 50, 1000, )
ccp_alpha = trial.suggest_float("ccp_alpha", 0.001, 0.03, log=True)
rr_clf = RandomForestClassifier(max_depth=max_depth,
min_samples_split=min_samples_split,
@krsnewwave
krsnewwave / pipeline.py
Created March 11, 2022 15:09
kedro pipeline
# in <root>/src/<project>/pipelines/data_science/pipeline.py
from kedro.pipeline import node, pipeline
from .nodes import split_data, fit_xgboost
def create_plot_roc_node():
return node(
func=plot_roc,
inputs=["clf", "X_test", "y_test"],
outputs="roc_graph",
@krsnewwave
krsnewwave / pipeline_registry.py
Created March 11, 2022 15:10
kedro pipeline registry
# in <root>/src/<project>/pipeline_registry.py
def register_pipelines() -> Dict[str, Pipeline]:
data_engineering_pipeline = de.create_pipeline()
xgb_pipe = ds.create_xgb_pipeline()
rr_pipe = ds.create_rr_pipeline()
logres_pipe = ds.create_logres_pipeline()
rr_ho_pipe = ds.create_rr_ho_pipeline()
return {
@krsnewwave
krsnewwave / pytorch_multitask_paintings.py
Created April 9, 2022 16:42
PyTorch model for multitask learning
# following https://towardsdatascience.com/multilabel-classification-with-pytorch-in-5-minutes-a4fa8993cbc7
class LightningResNetMultiLabel(pl.LightningModule):
def __init__(self, net, n_period, n_artists, criterion = F.cross_entropy, optimizer = None, scheduler = None, dropout_p = 0., lr=0.001, freeze_net=False):
super().__init__()
self.net = net
self.feature_extractor = nn.Sequential(*(list(self.net.children())[:-1]))
if freeze_net:
for param in self.net.parameters():
class LightningResNet(pl.LightningModule):
    """Lightning wrapper around a pretrained backbone with a resized ``fc`` head.

    The backbone's ``fc`` layer is replaced by a fresh ``nn.Linear`` that keeps
    the original input width and produces ``num_classes`` outputs.

    Args:
        net_pretrained: Backbone module exposing an ``fc`` submodule with an
            ``in_features`` attribute (e.g. a torchvision ResNet).
        device: Accepted but not used by this constructor.
        criterion: Loss callable, default ``F.cross_entropy``; not used by
            this constructor.
        num_classes: Output size of the replacement ``fc`` layer.
        optimizer: Accepted but not used by this constructor.
        scheduler: Accepted but not used by this constructor.

    NOTE(review): ``device``, ``criterion``, ``optimizer`` and ``scheduler``
    are never stored here. Presumably the full source uses them in methods
    not shown in this preview; confirm. Do not add ``self.device = device``:
    LightningModule exposes ``device`` as a read-only property.
    """

    def __init__(self, net_pretrained, device='cpu', criterion=F.cross_entropy,
                 num_classes=4, optimizer=None, scheduler=None):
        super().__init__()
        self.net = net_pretrained
        # Swap the head for one sized to `num_classes`, reusing the backbone's
        # feature width as its input size.
        num_ftrs = self.net.fc.in_features
        self.net.fc = nn.Linear(num_ftrs, num_classes)
class RecoSparseTrainDataset(Dataset):
    """Map-style dataset that serves one densified row of a sparse matrix.

    Only the requested row is converted to a dense array in ``__getitem__``,
    so the full matrix is never densified by this class.

    Args:
        sparse_mat: 2-D SciPy sparse matrix supporting row indexing and
            ``.toarray()`` (e.g. ``scipy.sparse.csr_matrix``). Presumably
            rows are users and columns are items; confirm against the caller.
    """

    def __init__(self, sparse_mat):
        self.sparse_mat = sparse_mat

    def __len__(self):
        """Return the number of rows in the matrix."""
        return self.sparse_mat.shape[0]

    def __getitem__(self, idx):
        """Return ``(dense_row, idx)`` for row ``idx``.

        For an integer ``idx`` the row comes back as a 1-D NumPy array of
        length ``shape[1]``. ``squeeze()`` drops the leading row axis; note
        that a single-column matrix would squeeze to a 0-d array. The index
        is returned too, presumably so the training step can map the row
        back to its user; confirm.
        """
        batch_matrix = self.sparse_mat[idx].toarray().squeeze()
        return batch_matrix, idx
class CDAE(pl.LightningModule):
    """Recommender auto-encoder with a per-user embedding (layer setup only).

    Presumably the Collaborative Denoising Auto-Encoder (Wu et al., 2016),
    judging by the name and the user-embedding + encoder/decoder layout;
    confirm. Only layer construction is visible in this snippet.

    Args:
        model_conf: Configuration mapping; must contain ``"hidden_dim"``.
        novelty_per_item: Not used by the visible code (see note below).
        num_users: Number of rows of the user embedding table.
        num_items: Input width of the encoder and output width of the decoder.
        remove_observed: Not used by the visible code (see note below).
    """

    def __init__(self, model_conf: Dict, novelty_per_item, num_users, num_items,
                 remove_observed=False):
        super().__init__()
        self.hidden_dim = model_conf["hidden_dim"]
        # The layers below read these sizes from `self`, so they must be
        # stored before the layers are built.
        self.num_users = num_users
        self.num_items = num_items
        # NOTE(review): the original snippet elides further setup here
        # ("other self. initializations"). Presumably it stores
        # `novelty_per_item` and `remove_observed`; restore from full source.
        self.user_embedding = nn.Embedding(self.num_users, self.hidden_dim)
        self.encoder = nn.Linear(self.num_items, self.hidden_dim)
        self.decoder = nn.Linear(self.hidden_dim, self.num_items)
# Requires Ray Tune and Comet ML: pip install "ray[tune]" comet_ml
from ray.tune.integration.comet import CometLoggerCallback
from functools import partial
from ray.tune.integration.pytorch_lightning import TuneReportCallback
def train_function(model_conf, novelty_per_item, epochs, patience,
train_loader, val_loader, checkpoint_dir=None):
model = CDAE(model_conf, novelty_per_item, num_users, num_items)
# fill up your metrics here