andrewor14 / gist:048b5c1bd01b7fa23c53913856a8ef9f
Created September 5, 2025 20:15
Unsloth QAT full fine-tuning
from unsloth import FastLanguageModel
import torch
from torchao.quantization import Float8DynamicActivationInt4WeightConfig
from transformers import AutoModelForCausalLM, TextStreamer, TorchAoConfig

qat_scheme = "fp8-int4"              # QAT scheme: fp8 dynamic activations, int4 weights
save_output_path = "/tmp/unsloth_model"
max_seq_length = 2048
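The preview stops at the setup above. As a hedged sketch of how the imported pieces plausibly fit together (the gist itself is cut off here), one natural final step after QAT fine-tuning is to reload the checkpoint written to save_output_path with a matching fp8-int4 torchao config, so the fake-quantized training weights become genuinely quantized ones. Every argument below is an assumption, reusing the imports above.

# Hedged sketch, not the gist's actual code: convert the QAT-trained
# checkpoint to real fp8-activation/int4-weight quantization on load.
quant_config = TorchAoConfig(
    quant_type=Float8DynamicActivationInt4WeightConfig(),
)
quantized_model = AutoModelForCausalLM.from_pretrained(
    save_output_path,                  # checkpoint saved after QAT training
    torch_dtype=torch.bfloat16,
    quantization_config=quant_config,  # quantize weights as they are loaded
)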
andrewor14 / gist:b0364ac3cb8aa114e46b39d848fa5c8b
Created August 29, 2025 21:50
Unsloth QAT full finetuning test
# Based on https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.1_(8B)-Alpaca.ipynb
# but with `full_finetuning=True` and without `get_peft_model`
import os
from unsloth import FastLanguageModel
from transformers import TextStreamer
import torch
max_seq_length = 2048
model, tokenizer = FastLanguageModel.from_pretrained(
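    # The preview truncates here. The arguments below are a hedged
    # reconstruction from the Alpaca notebook cited above, not the gist's
    # actual code; the model name in particular is an assumption.
    model_name="unsloth/Meta-Llama-3.1-8B",
    max_seq_length=max_seq_length,
    dtype=None,              # auto-detect; bfloat16 on recent GPUs
    load_in_4bit=False,      # full fine-tuning keeps full-precision weights
    full_finetuning=True,    # per the comment above: no get_peft_model
)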
andrewor14 / gist:ab650350b69276cf585c008914aaa146
Last active August 29, 2025 20:30
Repro unsloth full finetuning
# Based on https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.1_(8B)-Alpaca.ipynb
# but with `full_finetuning=True` and without `get_peft_model`
# Output is at the bottom of the gist
import os
from unsloth import FastLanguageModel
from transformers import TextStreamer
import torch
max_seq_length = 2048
$ GRADIO_SERVER_NAME="0.0.0.0" python test_sayak.py
/home/andrewor/local/ao/torchao/utils.py:408: UserWarning: TORCH_VERSION_AT_LEAST_2_8 is deprecated and will be removed in torchao 0.14.0
warnings.warn(self.msg)
/home/andrewor/local/ao/torchao/utils.py:408: UserWarning: TORCH_VERSION_AT_LEAST_2_7 is deprecated and will be removed in torchao 0.14.0
warnings.warn(self.msg)
Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 65.28it/s]
Step 1: Applying QAT observers to the model...
/home/andrewor/local/ao/torchao/quantization/qat/utils.py:84: UserWarning: 'FakeQuantizeConfig' is deprecated and will be removed in a future release. Please use the following API instead:

batch_size: 16
batch_size_val: 8
checkpointer:
_component_: torchtune.training.FullModelHFCheckpointer
checkpoint_dir: /tmp/Llama-3.2-3B-Instruct/
checkpoint_files:
- model-00001-of-00002.safetensors
- model-00002-of-00002.safetensors
model_type: LLAMA3_2
output_dir: /home/andrewor/local/logs/tune/Llama3.2-3B_qat

[rank0]: Traceback (most recent call last):
[rank0]: File "/home/andrewor/local/torchtune/recipes/full_finetune_distributed.py", line 982, in <module>
[rank0]: sys.exit(recipe_main())
[rank0]: File "/home/andrewor/local/torchtune/torchtune/config/_parse.py", line 99, in wrapper
[rank0]: sys.exit(recipe_main(conf))
[rank0]: File "/home/andrewor/local/torchtune/recipes/full_finetune_distributed.py", line 977, in recipe_main
[rank0]: recipe.train()
[rank0]: File "/home/andrewor/local/torchtune/recipes/full_finetune_distributed.py", line 810, in train
[rank0]: logits = self._model(**batch)
[rank0]: File "/home/andrewor/local/pytorch/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
andrewor14 / compare_backward_numerics_pr_116092
Created March 1, 2024 16:39
Compare numerics: batch_norm_backward vs native_batch_norm_backward (#116092)
# Debug test failure in https://github.com/pytorch/pytorch/pull/116092 for:
# python test/test_decomp.py -k test_comprehensive_batch_norm_with_update_cuda_bfloat16
# Set up args (these are the exact tensors saved from the decomp test)
# All tensors in args16 are bfloat16
# All tensors in args64 are the same values in args16 upcast to float64
>>> args16
[tensor([[-0.5468750000],
[ 0.7812500000]], device='cuda:0', dtype=torch.bfloat16), tensor([[-1.5234375000],
[-4.1875000000]], device='cuda:0', dtype=torch.bfloat16,
requires_grad=True), tensor([8.8125000000], device='cuda:0', dtype=torch.bfloat16,
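The preview cuts off partway through the saved tensors. For orientation, here is a self-contained sketch of the same style of comparison under stated assumptions: it calls aten.native_batch_norm_backward on random placeholder tensors in bfloat16 and again with the values upcast to float64 (mirroring the args16/args64 setup), rather than using the gist's saved args or its batch_norm_backward counterpart.

# Hedged sketch: measure how far the bfloat16 backward drifts from a
# float64 reference of the same computation. Placeholder tensors, not
# the gist's saved args16/args64.
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
N, C = 2, 1
grad_out = torch.randn(N, C, device=device, dtype=torch.bfloat16)
x = torch.randn(N, C, device=device, dtype=torch.bfloat16)
weight = torch.randn(C, device=device, dtype=torch.bfloat16)
save_mean = x.mean(dim=0)
save_invstd = (x.var(dim=0, unbiased=False) + 1e-5).rsqrt()

def backward(dtype):
    args = [t.to(dtype) for t in (grad_out, x, weight, save_mean, save_invstd)]
    return torch.ops.aten.native_batch_norm_backward(
        args[0], args[1], args[2], None, None, args[3], args[4],
        True, 1e-5, [True, True, True],  # training=True, eps, output_mask
    )

for lo, hi in zip(backward(torch.bfloat16), backward(torch.float64)):
    print((lo.to(torch.float64) - hi).abs().max())  # gap per gradient output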
// Code generated by ColumnarBatchScan.scala when reading the column buffers
/* 001 */ public Object generate(Object[] references) {
/* 002 */ return new GeneratedIterator(references);
/* 003 */ }
/* 004 */
/* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 006 */ private Object[] references;
/* 007 */ private scala.collection.Iterator inmemorytablescan_input;
/* 008 */ private org.apache.spark.sql.execution.metric.SQLMetric inmemorytablescan_numOutputRows;
// Code generated by GenerateColumnarBatch.scala when building the column buffers
/* 001 */ import org.apache.spark.memory.MemoryMode;
/* 002 */ import org.apache.spark.sql.catalyst.InternalRow;
/* 003 */ import org.apache.spark.sql.execution.vectorized.ColumnarBatch;
/* 004 */
/* 005 */ public GeneratedColumnarBatchIterator generate(Object[] references) {
/* 006 */ return new GeneratedColumnarBatchIterator(references);
/* 007 */ }
/* 008 */
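For context, both dumps come from Spark's code generation for cached (in-memory columnar) data: the first reads the column buffers back, the second builds them. A hedged way to reproduce similar output yourself, using present-day PySpark rather than the Spark version these dumps came from (where the equivalent was df.queryExecution.debug.codegen() in Scala):

# Hedged sketch: cache a DataFrame so scans go through the in-memory
# columnar path, then print the generated Java. explain(mode="codegen")
# exists in PySpark 3.x, which postdates these dumps.
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.range(1000).selectExpr("id", "id % 7 AS key")
df.cache()
df.count()                                          # materialize column buffers
df.groupBy("key").count().explain(mode="codegen")   # dump generated code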

This page tries to prove that the following two formulations of the task memory-grant logic are equivalent, as suggested by @davies; a brute-force check is sketched after the snippet.

// === (1): The original code in Spark 1.6 before PR 10240

val maxToGrant = math.min(numBytes, math.max(0, maxMemoryPerTask - curMem))
val toGrant = math.min(maxToGrant, memoryFree)

if (curMem < minMemoryPerTask) {
  if (memoryFree >= math.min(maxToGrant, minMemoryPerTask - curMem)) {
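The snippet above cuts off before formulation (2) appears. As a sanity check on the claimed equivalence, here is a brute-force sketch in Python. Formulation (2) is assumed to be the condition that landed in Spark's ExecutionMemoryPool after PR 10240 (toGrant < numBytes && curMem + toGrant < minMemoryPerTask); that assumption is not confirmed by the truncated preview.

from itertools import product

def waits_v1(num_bytes, cur_mem, memory_free, max_per_task, min_per_task):
    # (1): original Spark 1.6 logic -- block unless there is enough free
    # memory to lift this task to its minimum share
    max_to_grant = min(num_bytes, max(0, max_per_task - cur_mem))
    if cur_mem < min_per_task:
        return memory_free < min(max_to_grant, min_per_task - cur_mem)
    return False

def waits_v2(num_bytes, cur_mem, memory_free, max_per_task, min_per_task):
    # (2): assumed post-PR-10240 condition (not shown in the preview)
    max_to_grant = min(num_bytes, max(0, max_per_task - cur_mem))
    to_grant = min(max_to_grant, memory_free)
    return to_grant < num_bytes and cur_mem + to_grant < min_per_task

for args in product(range(6), repeat=5):
    _, _, _, max_per_task, min_per_task = args
    # Spark invariant: minMemoryPerTask (poolSize / 2N) never exceeds
    # maxMemoryPerTask (maxPoolSize / N); the equivalence relies on it
    if min_per_task > max_per_task:
        continue
    assert waits_v1(*args) == waits_v2(*args), args
print("formulations agree on all tested inputs")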