Finetune CLIP on a 'webdataset'-formatted dataset
import os
import webdataset as wds
import io
from PIL import Image
from clip.clip import load, tokenize
from torchvision import transforms as T
from torch.utils.data import DataLoader
from torch.optim import Adam
import torch.nn.functional as F
import wandb # Quit early if user doesn't have wandb installed.
import argparse
import time
import torch
from glob import glob
import clip
# argument parsing
parser = argparse.ArgumentParser()
parser.add_argument('--model_name', type=str,
                    help='name of the CLIP model to load, or path to a .pt checkpoint to resume from')
parser.add_argument('--image_text_path', type=str, required=True,
                    help='glob pattern matching the webdataset .tar shards of images and captions')
parser.add_argument('--clip_output_file_name', type=str, default="clip",
                    help='base name for saved checkpoints (".pt" is appended)')
parser.add_argument('--wandb_name', default='clip_finetuning',
                    help='Name W&B will use when saving results.\ne.g. `--wandb_name "coco2017-full-sparse"`')
train_group = parser.add_argument_group('Training settings')
train_group.add_argument('--epochs', default=40,
                         type=int, help='Number of epochs')
train_group.add_argument('--text_seq_len', default=77,
                         type=int, help='Text sequence length')
train_group.add_argument('--save_every_n_steps', default=1000,
                         type=int, help='Save a checkpoint every n steps')
train_group.add_argument('--batch_size', default=32,
                         type=int, help='Batch size')
train_group.add_argument('--ga_steps', default=1, type=int,
                         help='Number of steps to accumulate gradients across per each iteration')
train_group.add_argument('--learning_rate', default=1e-7,
                         type=float, help='Learning rate')
train_group.add_argument('--clip_grad_norm', default=0.5,
                         type=float, help='Clip gradient norm')
train_group.add_argument('--warmup_steps', default=10000, type=int)
args = parser.parse_args()
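# Example invocation (illustrative only; the script filename and shard paths are
# placeholders -- adjust them to your own data layout):
#
#   python finetune_clip_wds.py \
#       --model_name ViT-B/32 \
#       --image_text_path "/data/my_dataset/shard-*.tar" \
#       --batch_size 32 --epochs 40 --learning_rate 1e-7 --warmup_steps 10000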
# helpers
def exists(val):
    return val is not None


def default(val, d):
    return val if exists(val) else d


def get_trainable_params(model):
    return [params for params in model.parameters() if params.requires_grad]
def create_clip_img_transform(image_width):
    clip_mean = [0.48145466, 0.4578275, 0.40821073]
    clip_std = [0.26862954, 0.26130258, 0.27577711]
    transform = T.Compose([
        # T.ToPILImage(),
        T.Resize(336, interpolation=T.InterpolationMode.LANCZOS),
        T.Resize(image_width, interpolation=T.InterpolationMode.LANCZOS),
        T.CenterCrop((image_width, image_width)),  # CLIP's visual encoder expects square inputs
        # T.RandomResizedCrop(size=(image_width, image_width), scale=(1.0, 1.0), ratio=(1.0, 1.0), interpolation=T.InterpolationMode.BILINEAR),
        T.ToTensor(),
        T.Normalize(mean=clip_mean, std=clip_std)
    ])
    return transform
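

# Optional sanity check: with the center crop above, the transform should yield a
# square (3, image_width, image_width) tensor regardless of the input aspect ratio.
_check = create_clip_img_transform(224)(Image.new("RGB", (640, 480)))
assert _check.shape == (3, 224, 224)
del _check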
def create_webdataset(
        urls,
        image_transform,
        enable_text=True,
        enable_image=True,
        image_key='jpg',
        caption_key='txt',
        enable_metadata=False,
        cache_path=None,):
    dataset = wds.WebDataset(urls, cache_dir=cache_path, cache_size=10**10, handler=wds.handlers.warn_and_continue)
    tokenizer = lambda text: clip.tokenize([text], truncate=True)[0]

    def filter_dataset(item):
        if enable_text and caption_key not in item:
            return False
        if enable_image and image_key not in item:
            return False
        if enable_metadata and "json" not in item:
            return False
        return True

    filtered_dataset = dataset.select(filter_dataset)

    def preprocess_dataset(item):
        output = {}
        if enable_image:
            image_data = item[image_key]
            image = Image.open(io.BytesIO(image_data)).convert("RGB")  # force 3 channels (handles grayscale/RGBA)
            image_tensor = image_transform(image)
            output["image_filename"] = item["__key__"]
            output["image_tensor"] = image_tensor
        if enable_text:
            text = item[caption_key]
            caption = text.decode("utf-8")
            tokenized_text = tokenizer(caption)
            output["text_tokens"] = tokenized_text
            output["text"] = caption
        if enable_metadata:
            metadata_file = item["json"]
            metadata = metadata_file.decode("utf-8")
            output["metadata"] = metadata
        return output

    transformed_dataset = filtered_dataset.map(preprocess_dataset, handler=wds.handlers.warn_and_continue)
    return transformed_dataset
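

# The loader above assumes WebDataset-style tar shards in which every sample's
# files share a basename, e.g. "0000123.jpg" + "0000123.txt" (plus an optional
# "0000123.json" when enable_metadata=True). A quick way to inspect one
# preprocessed sample (the shard path below is a placeholder):
#
#   ds = create_webdataset(["shard-00000.tar"], create_clip_img_transform(224))
#   sample = next(iter(ds))
#   print(sample["image_filename"], sample["image_tensor"].shape, sample["text"])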
CLIP_OUTPUT_FILE_NAME = args.clip_output_file_name + ".pt"
CLIP_FINAL_OUTPUT_FILE_NAME = args.clip_output_file_name + "-final.pt"
WARMUP_STEPS = int(args.warmup_steps)  # enables learning rate warmup.
EPOCHS = args.epochs
BATCH_SIZE = args.batch_size
TEXT_SEQ_LEN = args.text_seq_len
LEARNING_RATE = args.learning_rate if WARMUP_STEPS == 0 else 1e-12  # start tiny; warmup ramps it up
print(f"Starting with learning rate: {LEARNING_RATE}")
GRAD_CLIP_NORM = args.clip_grad_norm
ACCUM_STEPS = args.ga_steps
SAVE_EVERY_N_STEPS = args.save_every_n_steps
MODEL_NAME = args.model_name
truncate_captions = True
input_resolution = 224
IMAGE_SIZE = 224
# create dataset and dataloader from the webdataset shards
is_shuffle = True
DATASET_SIZE = int(1e9)  # nominal length for the dataset, to avoid warnings from DataLoader
WEBDATASET_PATH = glob(args.image_text_path)
dataset = create_webdataset(WEBDATASET_PATH, create_clip_img_transform(IMAGE_SIZE), True, True, image_key='jpg', caption_key='txt', enable_metadata=False, cache_path=None)
dl = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=24, pin_memory=True, prefetch_factor=2, drop_last=True)
# Load CLIP
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if ".pt" in MODEL_NAME:
    assert os.path.isfile(MODEL_NAME), "checkpoint does not exist"
    print(f"Resuming training from {MODEL_NAME}.")
clip_model, _ = load(MODEL_NAME, device=device)
clip_model.train()
# clip_model.eval()  # TODO experimenting with https://github.com/openai/CLIP/issues/150
input_res = clip_model.visual.input_resolution  # 224
clip_transform = create_clip_img_transform(input_res)
# optimizer
opt = Adam(get_trainable_params(clip_model), lr=LEARNING_RATE,
           betas=(0.9, 0.98), eps=1e-06, weight_decay=0.)
model_config = dict(
    batch_size=BATCH_SIZE,
    learning_rate=LEARNING_RATE,
    clip_grad_norm=GRAD_CLIP_NORM,
    ga_steps=ACCUM_STEPS,
    model_name=MODEL_NAME,
    save_every_n_steps=SAVE_EVERY_N_STEPS,
    clip_output_file_name=CLIP_OUTPUT_FILE_NAME,
    clip_final_output_file_name=CLIP_FINAL_OUTPUT_FILE_NAME,
    wandb_name=args.wandb_name,
    text_seq_len=TEXT_SEQ_LEN,
    image_width=input_res,
    truncate_captions=truncate_captions,
    device=device,
)
run = wandb.init(
    project=args.wandb_name,  # 'clip_finetuning' by default
    config=model_config,
)


def save_model(path):
    save_obj = clip_model.state_dict()
    torch.save(save_obj, path)


if WARMUP_STEPS > 0:
    print("Warmup steps:", WARMUP_STEPS)
save_model(f'./{CLIP_OUTPUT_FILE_NAME}')  # save an initial checkpoint before training
# training
print("Training started.")
steps = 0
t = time.time()  # Get initial time.
for epoch in range(0, EPOCHS):
    print(f"Epoch {epoch}")
    try:
        for i, item in enumerate(dl):
            texts = item["text_tokens"]
            images = item["image_tensor"]
            if i % 10 == 0:
                t = time.time()  # reset the throughput timer every 10 steps
            texts, images = map(lambda x: x.to(device), (texts, images))
            logits_per_image, logits_per_text = clip_model(images, texts)
            labels = torch.arange(BATCH_SIZE, device=device)
            # symmetric contrastive loss: average the image->text and text->image terms
            text_loss = F.cross_entropy(logits_per_image, labels)
            image_loss = F.cross_entropy(logits_per_text, labels)
            loss = (text_loss + image_loss) / 2
            loss.backward()
            torch.nn.utils.clip_grad_norm_(clip_model.parameters(), GRAD_CLIP_NORM)
            opt.step()
            opt.zero_grad()
            log = {}
            lr = opt.param_groups[0]['lr']
            # linear learning-rate warmup toward 1e-6
            if WARMUP_STEPS > 0 and lr < 1e-6:
                lr = 1e-6 * (steps / WARMUP_STEPS)
                for param_group in opt.param_groups:
                    param_group['lr'] = lr
                print(f"Warmup step {steps}/{WARMUP_STEPS}")
                print(f"Learning rate: {lr}")
            if i % 10 == 0:
                print(f'epoch - {epoch},', f'step - {i},', f'loss - {loss.item()}', f'text_loss - {text_loss.item()}',
                      f'image_loss - {image_loss.item()}')
                log = {
                    **log,
                    'epoch': epoch,
                    'iter': i,
                    'loss': loss.item(),
                    'text_loss': text_loss.item(),
                    'image_loss': image_loss.item(),
                    'lr': lr
                }
            if i % 10 == 9:
                sample_per_sec = BATCH_SIZE * 10 / (time.time() - t)
                log["sample_per_sec"] = sample_per_sec
                print(epoch, i, f'sample_per_sec - {sample_per_sec}')
            if i % SAVE_EVERY_N_STEPS == 0:
                save_model(f'./{CLIP_OUTPUT_FILE_NAME}')
            steps += 1
            wandb.log(log)
        # save the trained model to wandb as an artifact at every epoch's end
        model_artifact = wandb.Artifact(
            'finetuned-clip', type='model', metadata=dict(model_config))
        model_artifact.add_file(CLIP_OUTPUT_FILE_NAME)
        run.log_artifact(model_artifact)
    except KeyError as e:
        print(e)
        break
save_model(f'./{CLIP_FINAL_OUTPUT_FILE_NAME}')
model_artifact = wandb.Artifact(
    'finetuned-clip', type='model', metadata=dict(model_config))
model_artifact.add_file(CLIP_FINAL_OUTPUT_FILE_NAME)
run.log_artifact(model_artifact)
wandb.finish()
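
# Example follow-up (run separately after training): the saved state_dict can be
# reloaded through clip.load by passing the checkpoint path, the same way the
# resume branch above does. The image path and captions below are placeholders.
#
#   import clip, torch
#   from PIL import Image
#   model, preprocess = clip.load("clip-final.pt", device="cuda")
#   model.eval()
#   image = preprocess(Image.open("some_image.jpg")).unsqueeze(0).cuda()
#   tokens = clip.tokenize(["a photo of a dog", "a photo of a cat"]).cuda()
#   with torch.no_grad():
#       probs = model(image, tokens)[0].softmax(dim=-1)
#   print(probs)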