@NaxAlpha
Created May 24, 2023 23:29
Fine-tune Pythia model on Multimodal C4 dataset
# WIP: Fine-tune a Causal LM with images & text mixed on the MMC4 dataset
import os
import json
import random
from PIL import Image
from concurrent.futures import ThreadPoolExecutor

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cuda as cuda
from torch.utils.data import IterableDataset, DataLoader, get_worker_info

import timm
import timm.data
from transformers import AutoTokenizer, GPTNeoXForCausalLM
from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXAttention

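# Rough data flow of this script: MultiModalC4 streams MMC4 jsonl shards,
# interleaves each document's images with their matched text segments,
# replaces every image with a run of placeholder tokens, and packs/crops the
# result into fixed max_seq_len windows; MultiModalPythia then splices the
# projected vision features over those placeholder positions before computing
# the usual causal-LM loss.
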
def _attn_wrapper(self, query, key, value, attention_mask=None, head_mask=None):
    assert attention_mask is None and head_mask is None, "Not implemented"
    with cuda.sdp_kernel(enable_math=False, enable_flash=False):
        out = F.scaled_dot_product_attention(
            query.half(),
            key.half(),
            value.half(),
            is_causal=True,
        ).float()
    return out, None


# patch attention to save a lot of memory
GPTNeoXAttention._attn = _attn_wrapper
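# The patch above routes attention through F.scaled_dot_product_attention with
# only the memory-efficient backend enabled (math and flash kernels disabled),
# so the full (seq x seq) attention-probability tensor is never materialized.
# The half()/float() round-trip runs the kernel in fp16, presumably to further
# cut activation memory.
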
class MultiModalC4(IterableDataset):
    def __init__(
        self,
        dataset_path,
        tokenizer_name,
        image_model_name,
        image_tokens=49,
        max_seq_len=1024,
        image_token_id=1,
        cache_buffer_size=1000,
    ):
        self.path = dataset_path
        # jsonl files:
        shards = [
            os.path.join(self.path, f)
            for f in os.listdir(self.path)
            if f.endswith(".jsonl")
        ]
        self.shards = sorted(shards)
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        self.image_tokens = image_tokens
        self.image_token_id = image_token_id
        self.max_seq_len = max_seq_len
        self.cache_buffer_size = cache_buffer_size
        data_config = timm.data.resolve_model_data_config(image_model_name)
        self.transforms = timm.data.create_transform(**data_config, is_training=True)

    def _load_image(self, shard_id, image_info):
        base_path = os.path.join(self.path, f"{shard_id}")
        img_path = os.path.join(base_path, image_info["image_name"])
        idx = image_info["matched_text_index"]
        try:
            img = Image.open(img_path)
        except Exception:
            # missing or corrupt images are simply dropped
            img = None
        return img, idx

    def _merge_images_texts(self, texts, img_map):
        output = []
        for i, txt in enumerate(texts):
            if i in img_map:
                # always put the image before its matched text
                output.append(img_map[i])
            output.append(txt)
        # collapse consecutive text segments into a single string
        result = [output[0]]
        for x in output[1:]:
            if isinstance(x, str) and isinstance(result[-1], str):
                result[-1] += "\n" + x
            else:
                result.append(x)
        return result

    def _flatten_merged(self, merged):
        processed_images = []
        imgs_starts = []
        text_tokens = []
        for x in merged:
            if isinstance(x, str):
                text_tokens += self.tokenizer.encode(x)
            else:
                # reserve image_tokens slots plus one boundary token on each side
                img_token_placeholder = [self.image_token_id] * (self.image_tokens + 2)
                img = self.transforms(x.convert("RGB"))
                processed_images.append(img)
                imgs_starts.append(len(text_tokens) + 1)
                text_tokens += img_token_placeholder
        if processed_images:
            processed_images = torch.stack(processed_images)
            imgs_starts = torch.tensor(imgs_starts)
        else:
            processed_images = torch.zeros(0, 3, 224, 224)
            imgs_starts = torch.zeros(0, dtype=torch.long)
        return dict(
            text_tokens=text_tokens,
            processed_images=processed_images,
            imgs_starts=imgs_starts,
        )

    def _stream_shard(self, shard_id):
        with ThreadPoolExecutor() as executor:
            with open(self.shards[shard_id], "r") as f:
                for line in f:
                    obj = json.loads(line)
                    # load this document's images in parallel
                    imgs = executor.map(
                        self._load_image,
                        [shard_id] * len(obj["image_info"]),
                        obj["image_info"],
                    )
                    imgs, idxs = zip(*imgs)
                    img_map = {idx: img for idx, img in zip(idxs, imgs) if img}
                    merged = self._merge_images_texts(obj["text_list"], img_map)
                    yield self._flatten_merged(merged)

    def _stream_all(self, rnd):
        shard_ids = list(range(len(self.shards)))
        rnd.shuffle(shard_ids)
        for shard_id in shard_ids:
            yield from self._stream_shard(shard_id)

    def _buffered_stream(self):
        # shuffle with a bounded buffer; each DataLoader worker gets its own
        # RNG seed so workers do not emit identical streams
        wi = get_worker_info()
        seed = None if wi is None else wi.seed
        rnd = random.Random(seed)
        buffer = []
        for doc in self._stream_all(rnd):
            buffer.append(doc)
            if len(buffer) >= self.cache_buffer_size:
                idx = rnd.randint(0, len(buffer) - 1)
                yield buffer.pop(idx)
        yield from buffer

    def _find_trainable_seq_range(self, text_tokens):
        # slide forward until a max_seq_len window starts and ends outside an
        # image placeholder run, so no image is cut in half by the crop
        idx = 0
        buffer_found = False
        while idx + self.max_seq_len < len(text_tokens):
            if (
                text_tokens[idx] != self.image_token_id
                and text_tokens[idx + self.max_seq_len] != self.image_token_id
            ):
                buffer_found = True
                break
            idx += 1
        return buffer_found, idx

    def _find_images_on_crop(self, idx, imgs_starts):
        in_range = (imgs_starts >= idx) & (imgs_starts <= idx + self.max_seq_len)
        indices = torch.nonzero(in_range).squeeze(1).tolist()
        if not indices:
            return 0, 0
        return indices[0], indices[-1] + 1

    def _joined_docs(self):
        text_tokens = []
        imgs_starts = None
        processed_images = None
        for doc in self._buffered_stream():
            if processed_images is None:
                processed_images = doc["processed_images"]
                imgs_starts = doc["imgs_starts"]
            else:
                # separate packed documents with an EOS token
                text_tokens += [self.tokenizer.eos_token_id]
                processed_images = torch.cat(
                    (
                        processed_images,
                        doc["processed_images"],
                    )
                )
                new_starts = doc["imgs_starts"] + len(text_tokens)
                imgs_starts = torch.cat((imgs_starts, new_starts))
            text_tokens += doc["text_tokens"]
            if len(text_tokens) < self.max_seq_len + 1:
                continue
            buffer_found, idx = self._find_trainable_seq_range(text_tokens)
            im_idx1, im_idx2 = self._find_images_on_crop(idx, imgs_starts)
            if buffer_found:
                # send cropped buffer
                yield dict(
                    text_tokens=text_tokens[idx : idx + self.max_seq_len + 1],
                    imgs_starts=imgs_starts[im_idx1:im_idx2] - idx,
                    processed_images=processed_images[im_idx1:im_idx2],
                )
                idx += self.max_seq_len + 1
            # destroy the buffer till idx
            text_tokens = text_tokens[idx:]
            imgs_starts = imgs_starts[im_idx2:] - idx
            processed_images = processed_images[im_idx2:]

    def __iter__(self):
        yield from self._joined_docs()

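# Each sample yielded by MultiModalC4 is a dict of:
#   text_tokens      - max_seq_len + 1 token ids, with every image occupying
#                      image_tokens + 2 consecutive image_token_id slots
#   imgs_starts      - start offset of each image span, relative to the crop
#   processed_images - (num_images, 3, H, W) tensor ready for the vision tower
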
def mmc4_collate_fn(batch):
    text_tokens = []
    processed_images = []
    imgs_starts = []
    imgs_counts = []
    for doc in batch:
        text_tokens.append(doc["text_tokens"])
        processed_images.append(doc["processed_images"])
        imgs_starts.append(doc["imgs_starts"])
        imgs_counts.append(len(doc["imgs_starts"]))
    text_tokens = torch.tensor(text_tokens)
    processed_images = torch.cat(processed_images)
    imgs_starts = torch.cat(imgs_starts)
    imgs_counts = torch.tensor(imgs_counts)
    return dict(
        text_tokens=text_tokens,
        processed_images=processed_images,
        imgs_starts=imgs_starts,
        imgs_counts=imgs_counts,
    )

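# Design note: because documents carry ragged image counts, images are
# concatenated along one flat batch axis instead of being padded per sample;
# imgs_counts records how many of the flat images belong to each sample so
# the model's forward() can re-split them.
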
class MultiModalPythia(nn.Module):
    def __init__(self, transformer_model, image_model, image_token_id=1):
        super().__init__()
        self.transformer = GPTNeoXForCausalLM.from_pretrained(transformer_model)
        self.vision = timm.create_model(
            image_model,
            pretrained=True,
            num_classes=0,
        )
        # project vision features into the LM's embedding space
        vis_emb = self.vision.embed_dim[-1]
        lm_emb = self.transformer.config.hidden_size
        self.proj = nn.Linear(vis_emb, lm_emb)
        self.image_token_id = image_token_id

    def forward(self, text_tokens, processed_images, imgs_starts, imgs_counts):
        inp_txt = text_tokens[:, :-1]
        out_txt = text_tokens[:, 1:].clone()
        # mask image placeholder positions out of the LM loss
        out_txt[out_txt == self.image_token_id] = -100
        txt_emb = self.transformer.gpt_neox.embed_in(inp_txt)
        if processed_images.size(0) > 0:
            img_emb = self.vision.forward_features(processed_images)
            # (B, C, H, W) -> (B, H*W, C)
            img_emb = img_emb.view(*img_emb.shape[:2], -1).permute(0, 2, 1)
            img_emb = self.proj(img_emb)
            N = img_emb.shape[1]
            # per-sample counts -> flat offsets into the concatenated image batch
            offsets = [0] + imgs_counts.cumsum(0).tolist()
            for b, (i, j) in enumerate(zip(offsets[:-1], offsets[1:])):
                imgs = img_emb[i:j]
                starts = imgs_starts[i:j].long()
                for s, img in zip(starts, imgs):
                    # splice the projected image tokens into sample b's embeddings
                    txt_emb[b, s : s + N] = img
        logits = self.transformer(inputs_embeds=txt_emb).logits
        loss = F.cross_entropy(
            logits.view(-1, logits.shape[-1]),
            out_txt.reshape(-1),
            ignore_index=-100,
        )
        return loss, logits

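# A minimal, untested sketch of running one packed sample through the model
# (names `ds` and `model` as constructed in __main__ below):
#
#   sample = next(iter(ds))
#   batch = mmc4_collate_fn([sample])
#   batch = {k: v.to("cuda") for k, v in batch.items()}
#   with torch.no_grad():
#       loss, logits = model(**batch)
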
if __name__ == "__main__":
    from tqdm import tqdm
    from torch.optim import Adam

    ds = MultiModalC4(
        "../mmc4-ff",
        "EleutherAI/pythia-1b-deduped",
        "focalnet_large_fl4.ms_in22k",
        image_tokens=49,
        max_seq_len=1024,
        image_token_id=1,
    )
    max_images = 0
    # prog = tqdm(ds)
    # for i, x in enumerate(prog):
    #     t = x["text_tokens"]
    #     q = [-1] * len(t)
    #     for ims in x["imgs_starts"].long().tolist():
    #         q[ims - 1 : ims + 50] = [1] * 51
    #     assert torch.tensor(q == 1).sum() == torch.tensor(t == 1).sum()
    #     max_images = max(max_images, len(x["imgs_starts"]))
    #     prog.set_postfix(max_images=max_images)
    loader = DataLoader(
        dataset=ds,
        batch_size=4,
        num_workers=4,
        collate_fn=mmc4_collate_fn,
    )
    dev = "cuda"
    prog = tqdm(loader)
    model = MultiModalPythia(
        "EleutherAI/pythia-1b-deduped",
        "focalnet_large_fl4.ms_in22k",
        image_token_id=1,
    ).to(dev)
    opt = Adam(model.parameters(), lr=1e-5)
    for i, x in enumerate(prog):
        x = {k: v.to(dev) for k, v in x.items()}
        loss, logits = model(**x)
        opt.zero_grad()
        loss.backward()
        opt.step()
        prog.set_postfix(loss=loss.item())
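        # One might also checkpoint periodically inside this loop; a hedged
        # example (hypothetical filename, not part of the original gist):
        #   if i % 1000 == 0:
        #       torch.save(model.state_dict(), f"mmc4_pythia_step{i}.pt")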