@calpt
Last active May 12, 2024 06:02
Convert multilingual LAION CLIP checkpoints from OpenCLIP to Hugging Face Transformers

OpenCLIP -> HF Transformers Conversion

Setup

pip install open_clip_torch transformers

Convert

# LAION XLM-Roberta Base model
python convert.py xlm-roberta-base-ViT-B-32 laion5b_s13b_b90k

# LAION XLM-Roberta Large model
python convert.py xlm-roberta-large-ViT-H-14 frozen_laion5b_s13b_b90k

Use converted checkpoints

from modeling_clip import OpenCLIPVisionTextDualEncoderModel

model = OpenCLIPVisionTextDualEncoderModel.from_pretrained("...")
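
For example, here is a minimal sketch of using a locally converted checkpoint end to end (assuming it was saved to models/xlm-roberta-base-ViT-B-32 by convert.py and pairing it with the xlm-roberta-base tokenizer; these names are illustrative, adapt them as needed):

import torch
from transformers import AutoTokenizer
from modeling_clip import OpenCLIPVisionTextDualEncoderModel

# Illustrative paths/names, not part of the original gist.
model = OpenCLIPVisionTextDualEncoderModel.from_pretrained("models/xlm-roberta-base-ViT-B-32")
tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")

text_inputs = tokenizer(["a photo of a cat", "ein Foto einer Katze"], padding=True, return_tensors="pt")
text_embeds = model.get_text_features(**text_inputs)

pixel_values = torch.randn(1, 3, 224, 224)  # replace with preprocessed image pixels
image_embeds = model.get_image_features(pixel_values=pixel_values)

# cosine similarities between normalized embeddings
text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
print(text_embeds @ image_embeds.t())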

Pre-converted checkpoints

Pre-converted checkpoints are available on the Hugging Face Hub, e.g. calpt/CLIP-ViT-B-32-xlm-roberta-base-laion5B-s13B-b90k.

convert.py

import dataclasses
import os
import re
import sys
import open_clip
from open_clip.hf_model import HFTextEncoder
from open_clip.model import CLIPVisionCfg
from transformers import CLIPVisionConfig, VisionTextDualEncoderConfig
from modeling_clip import OpenCLIPVisionTextDualEncoderModel

VISION_CONFIG_MAP = {
    "layers": "num_hidden_layers",
    "width": "hidden_size",
    "patch_size": "patch_size",
    "image_size": "image_size",
}

STATE_DICT_PATTERNS = [
    # Vision
    (r"visual\.class_embedding", "vision_model.vision_model.embeddings.class_embedding"),
    (r"visual\.positional_embedding", "vision_model.vision_model.embeddings.position_embedding.weight"),
    (r"visual\.conv1\.(\w+)", "vision_model.vision_model.embeddings.patch_embedding.{0}"),
    (r"visual\.ln_pre\.(\w+)", "vision_model.vision_model.pre_layrnorm.{0}"),
    (r"visual\.ln_post\.(\w+)", "vision_model.vision_model.post_layernorm.{0}"),
    (
        r"visual\.transformer\.resblocks\.(\w+)\.ln_1\.(\w+)",
        "vision_model.vision_model.encoder.layers.{0}.layer_norm1.{1}",
    ),
    (
        r"visual\.transformer\.resblocks\.(\w+)\.ln_2\.(\w+)",
        "vision_model.vision_model.encoder.layers.{0}.layer_norm2.{1}",
    ),
    (
        r"visual\.transformer\.resblocks\.(\w+)\.attn\.out_proj\.(\w+)",
        "vision_model.vision_model.encoder.layers.{0}.self_attn.out_proj.{1}",
    ),
    (
        r"visual\.transformer\.resblocks\.(\w+)\.mlp\.c_fc\.(\w+)",
        "vision_model.vision_model.encoder.layers.{0}.mlp.fc1.{1}",
    ),
    (
        r"visual\.transformer\.resblocks\.(\w+)\.mlp\.c_proj\.(\w+)",
        "vision_model.vision_model.encoder.layers.{0}.mlp.fc2.{1}",
    ),
    # Text
    (r"text\.transformer\.(.+)", "text_model.{0}"),
    (r"text\.proj\.(.+)", "text_projection.{0}"),
]


def convert_vision_config(config: CLIPVisionCfg):
    config = dataclasses.asdict(config)
    new_config = {
        "hidden_act": "gelu",
    }
    for key, value in config.items():
        if key in VISION_CONFIG_MAP:
            new_config[VISION_CONFIG_MAP[key]] = value
        elif key == "head_width":
            new_config["num_attention_heads"] = config["width"] // value
        elif key == "mlp_ratio":
            new_config["intermediate_size"] = int(config["width"] * value)
        elif not key.startswith("timm") and value:
            print(f"WARNING: Unknown key '{key}' in vision config.")
    return CLIPVisionConfig(**new_config)


def convert_state_dict(state_dict):
    new_state_dict = {}
    for k, v in state_dict.items():
        found = False
        # special handling of vision attention blocks
        if match := re.match(r"visual\.transformer\.resblocks\.(\w+)\.attn\.in_proj_(\w+)", k):
            # chunk weights into three
            chunks = v.chunk(3, dim=0)
            for proj_name, proj_v in zip(["q_proj", "k_proj", "v_proj"], chunks):
                new_k = f"vision_model.vision_model.encoder.layers.{match.group(1)}.self_attn.{proj_name}.{match.group(2)}"
                print(k, "--->", new_k)
                new_state_dict[new_k] = proj_v
            found = True
        # transpose visual projection
        elif k == "visual.proj":
            new_k = "visual_projection.weight"
            print(k, "--->", new_k)
            new_state_dict[new_k] = v.t()
            found = True
        else:
            for pattern, replacement in STATE_DICT_PATTERNS:
                if match := re.match(pattern, k):
                    new_k = replacement.format(*match.groups())
                    print(k, "--->", new_k)
                    new_state_dict[new_k] = v
                    found = True
                    break
        if not found:
            new_state_dict[k] = v
    return new_state_dict


if __name__ == "__main__":
    model_name = sys.argv[1]
    pretrained = sys.argv[2]

    openclip_config = open_clip.get_model_config(model_name)
    openclip_model = open_clip.create_model(model_name, pretrained=pretrained)
    if not isinstance(openclip_model.text, HFTextEncoder):
        raise ValueError("Only HFTextEncoder is supported.")
    if openclip_config["text_cfg"]["pooler_type"] != "mean_pooler":
        raise ValueError("Only mean_pooler is supported.")

    text_config = openclip_model.text.config
    vision_config = convert_vision_config(CLIPVisionCfg(**openclip_config["vision_cfg"]))
    config = VisionTextDualEncoderConfig.from_vision_text_configs(
        vision_config,
        text_config,
        projection_dim=openclip_config["embed_dim"],
    )

    state_dict = convert_state_dict(openclip_model.state_dict())
    model, loading_info = OpenCLIPVisionTextDualEncoderModel.from_pretrained(
        None, config=config, state_dict=state_dict, output_loading_info=True
    )
    print(loading_info)

    out_path = os.path.join("models", model_name)
    model.save_pretrained(out_path)

modeling_clip.py

from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
from transformers import PreTrainedModel, VisionTextDualEncoderConfig, VisionTextDualEncoderModel
from transformers.models.vision_text_dual_encoder.modeling_vision_text_dual_encoder import clip_loss, CLIPOutput


class MeanPooler(nn.Module):
    """Mean pooling"""

    def forward(self, x, attention_mask):
        masked_output = x.last_hidden_state * attention_mask.unsqueeze(-1)
        return masked_output.sum(dim=1) / attention_mask.sum(-1, keepdim=True)


class OpenCLIPVisionTextDualEncoderModel(VisionTextDualEncoderModel):
    def __init__(
        self,
        config: Optional[VisionTextDualEncoderConfig] = None,
        vision_model: Optional[PreTrainedModel] = None,
        text_model: Optional[PreTrainedModel] = None,
        add_text_model_pooling_layer: bool = False,
    ):
        super().__init__(config, vision_model, text_model)
        # Remove text pooling layer
        if not add_text_model_pooling_layer:
            self.text_model.pooler = None
        # Add mean pooling
        self.pooler = MeanPooler()
        # Overwrite text projection
        hidden_size = (self.text_embed_dim + self.projection_dim) // 2
        self.text_projection = nn.Sequential(
            nn.Linear(self.text_embed_dim, hidden_size, bias=False),
            nn.GELU(),
            nn.Linear(hidden_size, self.projection_dim, bias=False),
        )

    def get_text_features(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        token_type_ids=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        pooled_output = self.pooler(text_outputs, attention_mask)
        text_features = self.text_projection(pooled_output)
        return text_features

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        return_loss: Optional[bool] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], CLIPOutput]:
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        image_embeds = vision_outputs[1]  # pooler_output
        image_embeds = self.visual_projection(image_embeds)

        pooled_output = self.pooler(text_outputs, attention_mask)
        text_embeds = self.text_projection(pooled_output)

        # normalized features
        image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
        text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)

        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
        logits_per_image = logits_per_text.T

        loss = None
        if return_loss:
            loss = clip_loss(logits_per_text)

        if not return_dict:
            output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
            return ((loss,) + output) if loss is not None else output

        return CLIPOutput(
            loss=loss,
            logits_per_image=logits_per_image,
            logits_per_text=logits_per_text,
            text_embeds=text_embeds,
            image_embeds=image_embeds,
            text_model_output=text_outputs,
            vision_model_output=vision_outputs,
        )
@versae

versae commented Sep 18, 2023

Hey! Thank you so much for this code. I've been trying to re-create a standalone OpenCLIP model that can be used solely via HF. But after some tinkering, I think there's something not completely correct with the conversion script. For example:

import torch
import open_clip
from modeling_clip import OpenCLIPVisionTextDualEncoderModel

pixel_values = torch.randn(1, 3, 224, 224)

openclip_model, _, _ = open_clip.create_model_and_transforms('xlm-roberta-base-ViT-B-32', pretrained='laion5b_s13b_b90k')

model_id = "calpt/CLIP-ViT-B-32-xlm-roberta-base-laion5B-s13B-b90k"
model = OpenCLIPVisionTextDualEncoderModel.from_pretrained(model_id)

v1 = model.get_image_features(pixel_values)
v2 = openclip_model.encode_image(pixel_values)
torch.allclose(v1, v2, atol=1e-4)
# False

Normalizing doesn't help either. And while vector values seem somewhat similar, the differences are large enough to make me think that something might be wrong.

Any help would be much appreciated.

@calpt
Author

calpt commented Sep 24, 2023

Hey @versae, thanks for bringing this up! When checking the models again, I realized there's actually a small mistake in the vision config of the converted HF model. The activation function should be set to "gelu" to match the OpenCLIP model, which is currently not the case for the converted models.

You might fix this by loading the model as follows:

config = VisionTextDualEncoderConfig.from_pretrained(model_id)
config.vision_config.hidden_act = "gelu"
model = OpenCLIPVisionTextDualEncoderModel.from_pretrained(model_id, config=config)

Hope this solves the issue! I'll update the script and the uploaded checkpoints.
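
Putting the two snippets above together, a quick sanity check of the fix could look like this (a sketch reusing the model id, random test input, and tolerance from the comments above):

import torch
import open_clip
from transformers import VisionTextDualEncoderConfig
from modeling_clip import OpenCLIPVisionTextDualEncoderModel

model_id = "calpt/CLIP-ViT-B-32-xlm-roberta-base-laion5B-s13B-b90k"
config = VisionTextDualEncoderConfig.from_pretrained(model_id)
config.vision_config.hidden_act = "gelu"  # the fix described above
model = OpenCLIPVisionTextDualEncoderModel.from_pretrained(model_id, config=config)

openclip_model, _, _ = open_clip.create_model_and_transforms(
    "xlm-roberta-base-ViT-B-32", pretrained="laion5b_s13b_b90k"
)

pixel_values = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    v1 = model.get_image_features(pixel_values)
    v2 = openclip_model.encode_image(pixel_values)
print(torch.allclose(v1, v2, atol=1e-4))  # expected: True once the activation matches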

@versae

versae commented Sep 25, 2023

Amazing! Just tested and it works perfectly. Thanks, @calpt!

@calpt
Author

calpt commented Sep 25, 2023

Glad to hear!

As a general side note: the updated checkpoints no longer require manually copy-pasting the modeling code from the file above; they can be loaded directly by passing trust_remote_code=True to the regular AutoModel class, i.e.:

from transformers import AutoModel, AutoFeatureExtractor, AutoTokenizer

model = AutoModel.from_pretrained("calpt/CLIP-ViT-B-32-xlm-roberta-base-laion5B-s13B-b90k", trust_remote_code=True)
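
The matching tokenizer and feature extractor can be loaded the same way; a minimal sketch continuing the snippet above (assuming the tokenizer and preprocessor configs are present in the Hub repo) might be:

import torch

tokenizer = AutoTokenizer.from_pretrained("calpt/CLIP-ViT-B-32-xlm-roberta-base-laion5B-s13B-b90k")
feature_extractor = AutoFeatureExtractor.from_pretrained("calpt/CLIP-ViT-B-32-xlm-roberta-base-laion5B-s13B-b90k")

text_inputs = tokenizer(["a photo of a cat"], padding=True, return_tensors="pt")
text_embeds = model.get_text_features(**text_inputs)

pixel_values = torch.randn(1, 3, 224, 224)  # replace with feature_extractor(images=...)["pixel_values"]
image_embeds = model.get_image_features(pixel_values=pixel_values)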

@versae

versae commented Sep 26, 2023

Great! That is exactly what I was trying to achieve. Thanks!

@aleablu

aleablu commented Mar 11, 2024

Hey there, amazing work, thank you so much! I managed to convert the model and it works perfectly, but now I'm facing some issues when trying to leverage Hugging Face's optimum to export the model to ONNX, because this vision-text-dual-encoder architecture is not supported yet. I also checked with the older version, but no luck. I'm following this guide.

Does anybody have ideas on how to make it happen? Thanks in advance!

@damian0815

So I tried this, but it doesn't seem to work with the ViT-H model architecture; basically, both of these ValueErrors are raised:

    if not isinstance(openclip_model.text, HFTextEncoder):
        raise ValueError("Only HFTextEncoder is supported.")
    if openclip_config["text_cfg"]["pooler_type"] != "mean_pooler":
        raise ValueError("Only mean_pooler is supported.")

Any hints?
