This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Demonstrative changes to this file: https://github.com/Lightning-AI/lit-gpt/blob/main/finetune/lora.py | |
from lightning.pytorch.loggers import WandbLogger | |
def setup( | |
data_dir: Path = Path("data/alpaca"), | |
checkpoint_dir: Path = Path("checkpoints/stabilityai/stablelm-base-alpha-3b"), | |
out_dir: Path = Path("out/lora/alpaca"), | |
precision: Optional[str] = None, |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
import wandb | |
from wandb.keras import ( | |
# WandBMetricsLogger, | |
WandbModelCheckpoint, | |
# WandbGradientLogger, | |
# ModelLogger, | |
# FLOPsLogger, | |
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
import wandb | |
from wandb.keras import WandBMetricsLogger | |
with wandb.init(project="mnist", job_type="dev-wandb-metrics-logger"): | |
fashion_mnist = tf.keras.datasets.fashion_mnist | |
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import wandb | |
import tensorflow as tf | |
from keras_cv import bounding_box | |
class WandbTablesBuilder: | |
""" | |
Utility class that contains useful methods to create W&B Tables, | |
and log it to W&B. | |
""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Imports | |
import tensorflow as tf | |
from tensorflow.keras.layers import * | |
from tensorflow.keras.models import * | |
import numpy as np | |
import wandb | |
from wandb.keras import WandbCallback |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class Net(nn.Module): | |
def __init__(self): | |
super(Net, self).__init__() | |
self.conv1 = nn.Conv2d(3, 32, 3, 1) | |
torch.nn.init.kaiming_uniform_(self.conv1.weight, mode='fan_in', nonlinearity='relu') | |
self.conv2 = nn.Conv2d(32, 32, 3, 1) | |
torch.nn.init.kaiming_uniform_(self.conv2.weight, mode='fan_in', nonlinearity='relu') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class NetforExplode(nn.Module): | |
def __init__(self): | |
super(NetforExplode, self).__init__() | |
self.conv1 = nn.Conv2d(1, 32, 3, 1) | |
self.conv1.weight.data.fill_(100) | |
self.conv1.bias.data.fill_(-100) | |
self.conv2 = nn.Conv2d(32, 64, 3, 1) | |
self.conv2.weight.data.fill_(100) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Run a learning-rate range test with the project's LRFinder, sweeping up to
# end_lr=10 over num_iter=100 iterations; logwandb=True presumably enables the
# per-iteration W&B logging of lr/loss done inside range_test.
# NOTE(review): upstream torch_lr_finder.LRFinder has the signature
# (model, optimizer, criterion, device=None). If the local LRFinder follows
# it, `device` is being passed positionally as `criterion` here. Confirm
# against the local LRFinder definition before relying on this call.
lr_finder = LRFinder(net, optimizer, device)
lr_finder.range_test(trainloader, end_lr=10, num_iter=100, logwandb=True)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Optionally log the current learning rate and the loss for this iteration
# to Weights & Biases.
# `get_last_lr()` is the documented accessor for the learning rate most
# recently set by the scheduler. `get_lr()` is meant to be called from within
# `step()`. Called from outside, PyTorch's built-in chainable schedulers can
# return a value different from the one currently applied, and they emit a
# UserWarning.
# NOTE(review): assumes `lr_schedule` is a torch LR scheduler with a single
# (or first-of-interest) param group, and that `loss` is a loggable scalar.
# Confirm against the enclosing loop.
if logwandb:
    wandb.log({'lr': lr_schedule.get_last_lr()[0], 'loss': loss})
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Build the model and optimizer, open a W&B run, and train/evaluate for
# 10 epochs.
# The run is used as a context manager so that it is always finished when
# training ends or raises. Without this, a notebook session can be left with
# a dangling active run.
net = Net().to(device)
optimizer = optim.Adam(net.parameters())
with wandb.init(project='pytorchw_b'):
    # Register the model with W&B; log='all' requests both gradient and
    # parameter logging (see wandb.watch documentation).
    wandb.watch(net, log='all')
    for epoch in range(10):
        train(net, device, trainloader, optimizer, epoch)
        test(net, device, testloader, classes)
Newer Older