Keller Jordan (@KellerJordan)
KellerJordan / dawnbench_dpage.py
Created December 20, 2023 08:47
standard PyTorch version of the final training script from David Page's How to Train Your ResNet
# dawnbench_dpage.py
# This script aims for exact equivalence to the final training procedure described in David Page's post
# https://myrtle.ai/learn/how-to-train-your-resnet-8-bag-of-tricks/, while also being highly readable and
# adhering to typical PyTorch conventions.
#
# We ran the following test for equivalence. We executed the final (10-epoch) training configuration provided
# by David's code release https://github.com/davidcpage/cifar10-fast/blob/master/bag_of_tricks.ipynb a total
# of n=400 times, and we executed this script n=300 times. We observed that the original notebook code yielded
# a mean accuracy of 94.10%, and our script yielded a mean accuracy of 94.09%. We calculated the statistical
# significance of this difference, finding it to be insignificant (p=0.44).
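#
# A minimal sketch of how a comparison like the one described above could be computed. The text does not
# say which test was used, so this assumes a two-sided, unpooled (Welch-style) comparison with a normal
# approximation, which is reasonable at n=400 and n=300. The function name and its arguments (lists of
# per-run accuracies) are hypothetical; the per-run accuracies themselves are not reproduced here.

```python
import math
from statistics import NormalDist, mean, variance


def two_sample_p_value(accs_a, accs_b):
    """Two-sided p-value for a difference in mean accuracy between two sets of runs.

    Assumption: large-sample normal approximation with unpooled variances.
    This is an illustrative choice, not necessarily the test used for p=0.44.
    """
    n_a, n_b = len(accs_a), len(accs_b)
    # Standard error of the difference in means (sample variances, n - 1 denominator).
    std_err = math.sqrt(variance(accs_a) / n_a + variance(accs_b) / n_b)
    diff = mean(accs_a) - mean(accs_b)
    if std_err == 0.0:
        # Degenerate case: both samples are constant.
        return 1.0 if diff == 0.0 else 0.0
    z = diff / std_err
    # Two-sided tail probability under the standard normal distribution.
    return 2.0 * (1.0 - NormalDist().cdf(abs(z)))
```

# Usage would look like two_sample_p_value(notebook_accs, script_accs), where both arguments are
# hypothetical lists holding one test accuracy per training run.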
"""
Results from 5 seeds:
95.28, 95.35, 95.17, 95.26, 95.28
"""
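# A small sketch summarizing the five seed results quoted in the docstring above as a mean and a sample
# standard deviation. The variable names are illustrative and not part of the original script.

```python
from statistics import mean, stdev

# Accuracies (%) copied from the "Results from 5 seeds" docstring.
seed_accuracies = [95.28, 95.35, 95.17, 95.26, 95.28]

seed_mean = mean(seed_accuracies)
seed_std = stdev(seed_accuracies)  # sample standard deviation (n - 1 denominator)

print(f"mean={seed_mean:.2f}% std={seed_std:.2f} n={len(seed_accuracies)}")
```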
#############################################
#           Setup/Hyperparameters           #
#############################################
import os
"""
BatchNorm-free variant of airbench94
90.6% mean accuracy in ~6 seconds on an H100
Changes relative to airbench94:
- removed BatchNorms and added conv biases
- reduced batch size 1024 -> 384
- reduced weight decay 0.015 -> 0.001
- reduced lr 11.5 -> 10.0
- increased epochs 9.9 -> 11
"""
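# A minimal sketch of the hyperparameter changes listed in the docstring above, written as overrides on
# top of the airbench94 values it quotes. The dictionary names and keys are hypothetical and may not
# match the actual script. The removal of BatchNorm layers and the added conv biases are architectural
# changes and are not represented here.

```python
# airbench94 values, taken from the left-hand side of each "->" in the docstring.
airbench94_hyp = {
    "batch_size": 1024,
    "weight_decay": 0.015,
    "lr": 11.5,
    "epochs": 9.9,
}

# BatchNorm-free variant values, taken from the right-hand side of each "->".
bn_free_overrides = {
    "batch_size": 384,
    "weight_decay": 0.001,
    "lr": 10.0,
    "epochs": 11,
}

# Later keys win, so the overrides replace the baseline values.
bn_free_hyp = {**airbench94_hyp, **bn_free_overrides}

for name, new_value in bn_free_hyp.items():
    print(f"{name}: {airbench94_hyp[name]} -> {new_value}")
```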