@calpt
Last active May 14, 2024 12:21
Convert multilingual LAION CLIP checkpoints from OpenCLIP to Hugging Face Transformers

OpenCLIP -> HF Transformers Conversion

Setup

pip install open_clip_torch transformers
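Optionally, you can verify the install and look up the (model name, pretrained tag) pairs that OpenCLIP ships. This is just a quick sanity check, not part of the conversion itself:

# Optional check: list OpenCLIP's pretrained pairs, e.g. to confirm the
# multilingual XLM-Roberta checkpoints used below are available.
import open_clip

for model_name, pretrained in open_clip.list_pretrained():
    if "xlm-roberta" in model_name:
        print(model_name, pretrained)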

Convert

# LAION XLM-Roberta Base model
python convert.py xlm-roberta-base-ViT-B-32 laion5b_s13b_b90k

# LAION XLM-Roberta Large model
python convert.py xlm-roberta-large-ViT-H-14 frozen_laion5b_s13b_b90k
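Each run writes the converted checkpoint to models/<model_name> (see out_path in convert.py below); that directory is what you pass to from_pretrained in the next step.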

Use converted checkpoints

from modeling_clip import OpenCLIPVisionTextDualEncoderModel

model = OpenCLIPVisionTextDualEncoderModel.from_pretrained("...")
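For a fuller end-to-end sketch, the snippet below scores a few captions against an image. The checkpoint path is a hypothetical local output of convert.py, and the preprocessing choices (default CLIPImageProcessor settings, the xlm-roberta-base tokenizer) are assumptions that should be adjusted to match the original OpenCLIP preprocessing:

# Illustrative sketch only; the path, image file, and preprocessing are assumptions.
import torch
from PIL import Image
from transformers import AutoTokenizer, CLIPImageProcessor

from modeling_clip import OpenCLIPVisionTextDualEncoderModel

model = OpenCLIPVisionTextDualEncoderModel.from_pretrained("models/xlm-roberta-base-ViT-B-32")
tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")
image_processor = CLIPImageProcessor()

image = Image.open("cat.jpg")
texts = ["a photo of a cat", "ein Foto einer Katze", "a photo of a dog"]

text_inputs = tokenizer(texts, padding=True, return_tensors="pt")
pixel_values = image_processor(images=image, return_tensors="pt").pixel_values

with torch.no_grad():
    outputs = model(
        input_ids=text_inputs.input_ids,
        attention_mask=text_inputs.attention_mask,
        pixel_values=pixel_values,
    )
# Higher logits mean a better image-text match.
print(outputs.logits_per_image.softmax(dim=-1))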

Pre-converted checkpoints

convert.py

import dataclasses
import os
import re
import sys

import open_clip
from open_clip.hf_model import HFTextEncoder
from open_clip.model import CLIPVisionCfg
from transformers import CLIPVisionConfig, VisionTextDualEncoderConfig

from modeling_clip import OpenCLIPVisionTextDualEncoderModel
VISION_CONFIG_MAP = {
    "layers": "num_hidden_layers",
    "width": "hidden_size",
    "patch_size": "patch_size",
    "image_size": "image_size",
}
STATE_DICT_PATTERNS = [
    # Vision
    (r"visual\.class_embedding", "vision_model.vision_model.embeddings.class_embedding"),
    (r"visual\.positional_embedding", "vision_model.vision_model.embeddings.position_embedding.weight"),
    (r"visual\.conv1\.(\w+)", "vision_model.vision_model.embeddings.patch_embedding.{0}"),
    (r"visual\.ln_pre\.(\w+)", "vision_model.vision_model.pre_layrnorm.{0}"),
    (r"visual\.ln_post\.(\w+)", "vision_model.vision_model.post_layernorm.{0}"),
    (
        r"visual\.transformer\.resblocks\.(\w+)\.ln_1\.(\w+)",
        "vision_model.vision_model.encoder.layers.{0}.layer_norm1.{1}",
    ),
    (
        r"visual\.transformer\.resblocks\.(\w+)\.ln_2\.(\w+)",
        "vision_model.vision_model.encoder.layers.{0}.layer_norm2.{1}",
    ),
    (
        r"visual\.transformer\.resblocks\.(\w+)\.attn\.out_proj\.(\w+)",
        "vision_model.vision_model.encoder.layers.{0}.self_attn.out_proj.{1}",
    ),
    (
        r"visual\.transformer\.resblocks\.(\w+)\.mlp\.c_fc\.(\w+)",
        "vision_model.vision_model.encoder.layers.{0}.mlp.fc1.{1}",
    ),
    (
        r"visual\.transformer\.resblocks\.(\w+)\.mlp\.c_proj\.(\w+)",
        "vision_model.vision_model.encoder.layers.{0}.mlp.fc2.{1}",
    ),
    # Text
    (r"text\.transformer\.(.+)", "text_model.{0}"),
    (r"text\.proj\.(.+)", "text_projection.{0}"),
]
def convert_vision_config(config: CLIPVisionCfg):
    config = dataclasses.asdict(config)
    new_config = {
        "hidden_act": "gelu",
    }
    for key, value in config.items():
        if key in VISION_CONFIG_MAP:
            new_config[VISION_CONFIG_MAP[key]] = value
        elif key == "head_width":
            new_config["num_attention_heads"] = config["width"] // value
        elif key == "mlp_ratio":
            new_config["intermediate_size"] = int(config["width"] * value)
        elif not key.startswith("timm") and value:
            print(f"WARNING: Unknown key '{key}' in vision config.")
    return CLIPVisionConfig(**new_config)
def convert_state_dict(state_dict):
    new_state_dict = {}
    for k, v in state_dict.items():
        found = False
        # special handling of vision attention blocks
        if match := re.match(r"visual\.transformer\.resblocks\.(\w+)\.attn\.in_proj_(\w+)", k):
            # chunk weights into three
            chunks = v.chunk(3, dim=0)
            for proj_name, proj_v in zip(["q_proj", "k_proj", "v_proj"], chunks):
                new_k = f"vision_model.vision_model.encoder.layers.{match.group(1)}.self_attn.{proj_name}.{match.group(2)}"
                print(k, "--->", new_k)
                new_state_dict[new_k] = proj_v
            found = True
        # transpose visual projection
        elif k == "visual.proj":
            new_k = "visual_projection.weight"
            print(k, "--->", new_k)
            new_state_dict[new_k] = v.t()
            found = True
        else:
            for pattern, replacement in STATE_DICT_PATTERNS:
                if match := re.match(pattern, k):
                    new_k = replacement.format(*match.groups())
                    print(k, "--->", new_k)
                    new_state_dict[new_k] = v
                    found = True
                    break
        if not found:
            new_state_dict[k] = v
    return new_state_dict
if __name__ == "__main__":
    model_name = sys.argv[1]
    pretrained = sys.argv[2]

    openclip_config = open_clip.get_model_config(model_name)
    openclip_model = open_clip.create_model(model_name, pretrained=pretrained)
    if not isinstance(openclip_model.text, HFTextEncoder):
        raise ValueError("Only HFTextEncoder is supported.")
    if openclip_config["text_cfg"]["pooler_type"] != "mean_pooler":
        raise ValueError("Only mean_pooler is supported.")

    text_config = openclip_model.text.config
    vision_config = convert_vision_config(CLIPVisionCfg(**openclip_config["vision_cfg"]))
    config = VisionTextDualEncoderConfig.from_vision_text_configs(
        vision_config,
        text_config,
        projection_dim=openclip_config["embed_dim"],
    )
    state_dict = convert_state_dict(openclip_model.state_dict())
    model, loading_info = OpenCLIPVisionTextDualEncoderModel.from_pretrained(
        None, config=config, state_dict=state_dict, output_loading_info=True
    )
    print(loading_info)

    out_path = os.path.join("models", model_name)
    model.save_pretrained(out_path)
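As an optional sanity check (a sketch, not part of the original script), the image features of the source and converted models can be compared on a random input. Appended to the end of the __main__ block above, it should print True if the vision tower was mapped correctly (both LAION checkpoints above use a 224x224 input resolution):

    # Optional: compare image features from OpenCLIP and the converted model.
    import torch

    openclip_model.eval()
    model.eval()
    with torch.no_grad():
        pixels = torch.randn(1, 3, 224, 224)
        reference = openclip_model.encode_image(pixels)
        converted = model.get_image_features(pixel_values=pixels)
    print(torch.allclose(reference, converted, atol=1e-4))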
modeling_clip.py

from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
from transformers import PreTrainedModel, VisionTextDualEncoderConfig, VisionTextDualEncoderModel
from transformers.models.vision_text_dual_encoder.modeling_vision_text_dual_encoder import clip_loss, CLIPOutput
class MeanPooler(nn.Module):
    """Mean pooling"""

    def forward(self, x, attention_mask):
        masked_output = x.last_hidden_state * attention_mask.unsqueeze(-1)
        return masked_output.sum(dim=1) / attention_mask.sum(-1, keepdim=True)
class OpenCLIPVisionTextDualEncoderModel(VisionTextDualEncoderModel):
    def __init__(
        self,
        config: Optional[VisionTextDualEncoderConfig] = None,
        vision_model: Optional[PreTrainedModel] = None,
        text_model: Optional[PreTrainedModel] = None,
        add_text_model_pooling_layer: bool = False,
    ):
        super().__init__(config, vision_model, text_model)
        # Remove text pooling layer
        if not add_text_model_pooling_layer:
            self.text_model.pooler = None
        # Add mean pooling
        self.pooler = MeanPooler()
        # Overwrite text projection
        hidden_size = (self.text_embed_dim + self.projection_dim) // 2
        self.text_projection = nn.Sequential(
            nn.Linear(self.text_embed_dim, hidden_size, bias=False),
            nn.GELU(),
            nn.Linear(hidden_size, self.projection_dim, bias=False),
        )

    def get_text_features(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        token_type_ids=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        pooled_output = self.pooler(text_outputs, attention_mask)
        text_features = self.text_projection(pooled_output)
        return text_features
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        return_loss: Optional[bool] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], CLIPOutput]:
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        image_embeds = vision_outputs[1]  # pooler_output
        image_embeds = self.visual_projection(image_embeds)

        pooled_output = self.pooler(text_outputs, attention_mask)
        text_embeds = self.text_projection(pooled_output)

        # normalized features
        image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
        text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)

        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
        logits_per_image = logits_per_text.T

        loss = None
        if return_loss:
            loss = clip_loss(logits_per_text)

        if not return_dict:
            output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
            return ((loss,) + output) if loss is not None else output

        return CLIPOutput(
            loss=loss,
            logits_per_image=logits_per_image,
            logits_per_text=logits_per_text,
            text_embeds=text_embeds,
            image_embeds=image_embeds,
            text_model_output=text_outputs,
            vision_model_output=vision_outputs,
        )
@aleablu commented Mar 11, 2024

Hey there, amazing work, thank you so much! I managed to convert the model and it works perfectly, but now I'm facing some issues when trying to use Hugging Face's optimum to export the model to ONNX, because this vision-text-dual-encoder architecture is not supported yet. I also checked with the older version, but no luck. I'm following this guide.

Does anybody have any ideas on how to make this happen? Thanks in advance!

@damian0815

So I tried this, but it doesn't seem to work with the ViT-H model architecture; basically, both of these ValueErrors are raised:

    if not isinstance(openclip_model.text, HFTextEncoder):
        raise ValueError("Only HFTextEncoder is supported.")
    if openclip_config["text_cfg"]["pooler_type"] != "mean_pooler":
        raise ValueError("Only mean_pooler is supported.")

Any hints?
