Skip to content

Instantly share code, notes, and snippets.

View thomwolf's full-sized avatar
🚂
training

Thomas Wolf thomwolf

🚂
training
View GitHub Profile
@thomwolf
thomwolf / gradient_accumulation.py
Last active April 19, 2026 07:30
PyTorch gradient accumulation training loop
model.zero_grad() # Reset gradients tensors
for i, (inputs, labels) in enumerate(training_set):
predictions = model(inputs) # Forward pass
loss = loss_function(predictions, labels) # Compute loss function
loss = loss / accumulation_steps # Normalize our loss (if averaged)
loss.backward() # Backward pass
if (i+1) % accumulation_steps == 0: # Wait for several backward steps
optimizer.step() # Now we can do an optimizer step
model.zero_grad() # Reset gradients tensors
if (i+1) % evaluation_steps == 0: # Evaluate the model when we...
@thomwolf
thomwolf / Hard_Sigmoid_LSTM.py
Created October 3, 2017 09:54
A PyTorch LSTM cell with a hard sigmoid recurrent activation
def LSTMCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
"""
A modified LSTM cell with hard sigmoid activation on the input, forget and output gates.
"""
hx, cx = hidden
gates = F.linear(input, w_ih, b_ih) + F.linear(hx, w_hh, b_hh)
ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
ingate = hard_sigmoid(ingate)
@thomwolf
thomwolf / gpt-2-wikitext-103.py
Last active December 3, 2025 07:57
A very small and self-contained gist to train a GPT-2 transformer model on wikitext-103
# Copyright (c) 2019-present, Thomas Wolf.
# All rights reserved. This source code is licensed under the MIT-style license.
""" A very small and self-contained gist to train a GPT-2 transformer model on wikitext-103 """
import os
from collections import namedtuple
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from ignite.engine import Engine, Events
@thomwolf
thomwolf / bayes_by_backprop.py
Created November 30, 2017 13:27 — forked from vvanirudh/bayes_by_backprop.py
Bayes by Backprop in PyTorch (introduced in the paper "Weight Uncertainty in Neural Networks", Blundell et al., 2015)
# Drawn from https://gist.github.com/rocknrollnerd/c5af642cf217971d93f499e8f70fcb72 (in Theano)
# This is implemented in PyTorch
# Author : Anirudh Vemula
import torch
import torch.nn as nn
from torch.autograd import Variable
import numpy as np
from sklearn.datasets import fetch_mldata
@thomwolf
thomwolf / top-k-top-p.py
Last active October 25, 2025 20:25
Sample the next token from a probability distribution using top-k and/or nucleus (top-p) sampling
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
""" Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
Args:
logits: logits distribution shape (vocabulary size)
top_k >0: keep only top k tokens with highest probability (top-k filtering).
top_p >0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
"""
assert logits.dim() == 1 # batch size 1 for now - could be updated for more but the code would be less clear
top_k = min(top_k, logits.size(-1)) # Safety check
@thomwolf
thomwolf / persona-chat.py
Last active May 23, 2025 06:32
Download and load persona-chat json dataset
import json
from pytorch_pretrained_bert import cached_path
# Remote location of the persona-chat dataset (a single JSON file).
url = "https://s3.amazonaws.com/datasets.huggingface.co/personachat/personachat_self_original.json"

# Resolve the URL to a local file path through cached_path, then parse
# the JSON straight from the open file handle.
personachat_file = cached_path(url)
with open(personachat_file, "r", encoding="utf-8") as f:
    dataset = json.load(f)
@thomwolf
thomwolf / fast_speech_text_speech.py
Last active January 14, 2025 12:13
speech to text to speech
""" To use: install LLM studio (or Ollama), clone OpenVoice, run this script in the OpenVoice directory
git clone https://github.com/myshell-ai/OpenVoice
cd OpenVoice
git clone https://huggingface.co/myshell-ai/OpenVoice
cp -r OpenVoice/* .
pip install whisper pynput pyaudio
"""
from openai import OpenAI
import time
@thomwolf
thomwolf / loading_wikipedia.py
Last active January 12, 2025 13:34
Load the full English Wikipedia dataset with the Hugging Face datasets library (formerly nlp)
import os; import psutil; import timeit
from datasets import load_dataset
mem_before = psutil.Process(os.getpid()).memory_info().rss >> 20
wiki = load_dataset("wikipedia", "20200501.en", split='train')
mem_after = psutil.Process(os.getpid()).memory_info().rss >> 20
print(f"RAM memory used: {(mem_after - mem_before)} MB")
s = """batch_size = 1000
for i in range(0, len(wiki), batch_size):
""" To use: install LLM studio (or Ollama), clone OpenVoice, run this script in the OpenVoice directory
git clone https://github.com/myshell-ai/OpenVoice
cd OpenVoice
git clone https://huggingface.co/myshell-ai/OpenVoice
cp -r OpenVoice/* .
pip install whisper pynput pyaudio
"""
from dataclasses import dataclass
from typing import Optional
import random
@thomwolf
thomwolf / AdamW.py
Created July 3, 2018 21:20
Implements Adam algorithm with weight decay fix in PyTorch (paper: https://arxiv.org/abs/1711.05101)
from torch.optim import Optimizer
class AdamW(Optimizer):
"""
Implements Adam algorithm with weight decay fix in PyTorch
Paper: Fixing Weight Decay Regularization in Adam by Ilya Loshchilov, Frank Hutter
https://arxiv.org/abs/1711.05101
"""
def __init__(self, params, lr, b1=0.9, b2=0.999, e=1e-8, l2=0,
vector_l2=False, max_grad_norm=-1, **kwargs):