Skip to content

Instantly share code, notes, and snippets.

@qubvel
Created June 6, 2024 16:15
Show Gist options
  • Save qubvel/85cf17f516063d63219ef773a365e83d to your computer and use it in GitHub Desktop.
Compare image processors across branches
import os
import glob
import torch
from typing import Mapping

# Directory containing this script; saved preprocessor outputs live under
# outputs/ next to it.
file_dir = os.path.dirname(os.path.realpath(__file__))

# The two branches whose saved image-processor outputs are being compared.
branch_1 = "main"
branch_2 = "clean-up-do-reduce-labels"

# Glob patterns matching outputs/<org>/<model>/<branch>/*inputs.pt
branch_1_path = os.path.join(file_dir, f"outputs/*/*/{branch_1}/*inputs.pt")
branch_2_path = os.path.join(file_dir, f"outputs/*/*/{branch_2}/*inputs.pt")

# Reuse the patterns built above instead of rebuilding the identical
# f-strings inside the glob calls (the original duplicated them, leaving
# branch_1_path / branch_2_path unused).
branch_1_inputs_paths = sorted(glob.glob(branch_1_path))
branch_2_inputs_paths = sorted(glob.glob(branch_2_path))

# Sorted paths must pair up one-to-one across branches, and there must be
# something to compare at all.
assert len(branch_1_inputs_paths) == len(branch_2_inputs_paths)
assert len(branch_1_inputs_paths) > 0
def nested_compare(x1, x2, key=None):
    """Recursively assert that two nested structures are (numerically) equal.

    Supported node types: mappings (keys must match exactly), lists/tuples
    (lengths must match), ``torch.Tensor`` leaves (compared with
    ``torch.allclose`` at atol=1e-4) and plain int/float/str leaves
    (compared with ``==``).

    Args:
        x1, x2: the two structures to compare; assumed to have the same shape.
        key: nested-access path accumulated during recursion (e.g.
            ``"['pixel_values'][0]"``); used as the AssertionError message so
            the first mismatch is easy to locate.

    Raises:
        AssertionError: on any mismatch, with the offending key path.
        ValueError: when a leaf has an unsupported type.

    Side effect: prints a '.' per matching tensor as a progress indicator.
    """
    key = key or ""
    if isinstance(x1, Mapping):
        assert x1.keys() == x2.keys()
        for k in x1.keys():
            nested_compare(x1[k], x2[k], key=key + f"['{k}']")
    elif isinstance(x1, (list, tuple)):
        assert len(x1) == len(x2)
        # zip is safe here: lengths were just asserted equal.
        for i, (item1, item2) in enumerate(zip(x1, x2)):
            nested_compare(item1, item2, key=key + f"[{i}]")
    elif isinstance(x1, torch.Tensor):
        assert torch.allclose(x1, x2, atol=1e-4), key
        print(".", end="")
    elif isinstance(x1, (int, float, str)):
        assert x1 == x2, key
    else:
        raise ValueError((x1, x2, key))
# Walk the paired files and assert the two branches produced matching outputs.
for path_1, path_2 in zip(branch_1_inputs_paths, branch_2_inputs_paths):
    print()
    print("Testing", path_1, path_2)
    inputs_1 = torch.load(path_1)
    inputs_2 = torch.load(path_2)
    # Top-level keys must agree before descending into the values.
    assert inputs_1.keys() == inputs_2.keys()
    nested_compare(inputs_1, inputs_2)
import os
import subprocess

import numpy as np
import torch
from transformers import AutoImageProcessor, AutoConfig
from huggingface_hub import HfApi, ModelFilter

api = HfApi()

# Fixed seeds so both branches preprocess identical random inputs.
np.random.seed(42342)
torch.manual_seed(42342)

# Directory containing this script; outputs are saved next to it under
# outputs/<checkpoint>/<branch>/.
file_dir = os.path.dirname(os.path.realpath(__file__))

# Current git branch name. subprocess.check_output replaces os.popen: it
# raises CalledProcessError if git fails instead of silently yielding "".
branch_name = subprocess.check_output(
    ["git", "rev-parse", "--abbrev-ref", "HEAD"], text=True
).strip()

# Random sample image and binary segmentation map; non-square (224x540) so
# resize/transpose differences between branches are not masked.
image = np.random.randint(0, 256, (224, 540, 3), dtype=np.uint8)
segmentation_maps = np.random.randint(0, 2, (224, 540), dtype=np.uint8)
# Gather popular, non-gated checkpoints for each segmentation model family
# whose image processor supports do_reduce_labels.
relevant_checkpoints = []
for model_name in [
    "maskformer",
    "mask2former",
    "segformer",
    "beit",
    "oneformer",
]:
    model_filter = ModelFilter(model_name=model_name)
    models = api.list_models(filter=model_filter, sort="downloads", limit=50)
    for model in models:
        print(model.downloads, model.id, model.gated)
        # downloads may be None depending on which fields the hub API
        # returns; treat missing as 0 so the comparison cannot raise
        # TypeError. `not model.gated` already handles gated=None.
        if (model.downloads or 0) > 10 and not model.gated:
            relevant_checkpoints.append(model.id)
# For each checkpoint, load its image processor, run it on the fixed random
# sample, and save the result under outputs/<checkpoint>/<branch_name>/.
for checkpoint in relevant_checkpoints:
    print(checkpoint)
    save_dir = os.path.join(file_dir, "outputs", checkpoint, branch_name)
    os.makedirs(save_dir, exist_ok=True)
    try:
        image_processor = AutoImageProcessor.from_pretrained(checkpoint)
        # OneFormer-style processors need num_text wired up from the model
        # config before they can preprocess.
        if hasattr(image_processor, "num_text"):
            config = AutoConfig.from_pretrained(checkpoint)
            image_processor.num_text = config.num_queries - config.text_encoder_n_ctx
    except OSError as e:
        # Checkpoint has no loadable image-processor config — skip it.
        print(f"Error loading {checkpoint}: {e}")
        continue
    except ValueError as e:
        # Architectures without a registered image processor are expected;
        # anything else is a real failure.
        if "Unrecognized image processor" in str(e):
            continue
        else:
            # Bare raise preserves the original traceback (raise e resets it).
            raise
    # Preprocess the image
    inputs = image_processor(images=image, segmentation_maps=segmentation_maps, return_tensors="pt")
    torch.save(inputs, os.path.join(save_dir, "inputs.pt"))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment