# Gist qubvel/85cf17f516063d63219ef773a365e83d (created June 6, 2024):
# "Compare image processors across branches".
# NOTE: GitHub flagged this file as containing bidirectional Unicode text that may be
# interpreted differently than it appears; review it in an editor that reveals hidden
# Unicode characters.
#
# Script 1: compare the image-processor outputs saved for two git branches.
import os
import glob
from typing import Mapping

import torch

# Directory containing this script; the saved outputs live under <file_dir>/outputs.
file_dir = os.path.dirname(os.path.realpath(__file__))

# The two branches whose saved processor outputs are compared.
branch_1 = "main"
branch_2 = "clean-up-do-reduce-labels"

# Glob patterns for the saved files: outputs/<org>/<model>/<branch>/*inputs.pt.
# NOTE(review): checkpoints without an "<org>/" prefix, or branch names containing "/",
# do not match this fixed two-level pattern — confirm that is acceptable.
branch_1_path = os.path.join(file_dir, f"outputs/*/*/{branch_1}/*inputs.pt")
branch_2_path = os.path.join(file_dir, f"outputs/*/*/{branch_2}/*inputs.pt")

# Sorted so that, when both branches hold the same checkpoints, the i-th file of each
# list belongs to the same checkpoint.
branch_1_inputs_paths = sorted(glob.glob(branch_1_path))
branch_2_inputs_paths = sorted(glob.glob(branch_2_path))

# Fail early with a readable message: nothing to compare, or unequal file sets.
assert len(branch_1_inputs_paths) > 0, f"No saved inputs found for branch {branch_1!r}: {branch_1_path}"
assert len(branch_1_inputs_paths) == len(branch_2_inputs_paths), (
    f"Found {len(branch_1_inputs_paths)} files for {branch_1!r} "
    f"but {len(branch_2_inputs_paths)} for {branch_2!r}"
)
def nested_compare(x1, x2, key=None):
    """Recursively assert that two nested structures hold equal values.

    Mappings, lists/tuples, tensors, scalars and ``None`` are walked in lockstep.
    On the first difference an ``AssertionError`` is raised whose message contains
    the path to the offending element (e.g. ``['pixel_values'][0]``). One ``.`` is
    printed for every tensor pair that matches.

    Args:
        x1: Value from the first branch.
        x2: Value from the second branch.
        key: Path prefix used in error messages; ``None`` means the root.

    Raises:
        AssertionError: If types, keys, lengths, shapes, dtypes or values differ.
        ValueError: If ``x1`` has a type this function does not know how to compare.
    """
    key = key or ""
    if isinstance(x1, Mapping):
        assert isinstance(x2, Mapping), f"{key}: {type(x1).__name__} vs {type(x2).__name__}"
        assert x1.keys() == x2.keys(), f"{key}: keys differ {set(x1) ^ set(x2)}"
        for k in x1.keys():
            nested_compare(x1[k], x2[k], key=key + f"['{k}']")
    elif isinstance(x1, (list, tuple)):
        assert isinstance(x2, (list, tuple)), f"{key}: {type(x1).__name__} vs {type(x2).__name__}"
        assert len(x1) == len(x2), f"{key}: length {len(x1)} != {len(x2)}"
        for i, (item_1, item_2) in enumerate(zip(x1, x2)):
            nested_compare(item_1, item_2, key=key + f"[{i}]")
    elif isinstance(x1, torch.Tensor):
        assert isinstance(x2, torch.Tensor), f"{key}: Tensor vs {type(x2).__name__}"
        # torch.allclose broadcasts its arguments, so e.g. shapes (1, 3) and (3,) would
        # compare as "close"; check the shape explicitly first.
        assert x1.shape == x2.shape, f"{key}: shape {tuple(x1.shape)} != {tuple(x2.shape)}"
        # allclose raises a RuntimeError on mismatched dtypes; report it as a difference.
        assert x1.dtype == x2.dtype, f"{key}: dtype {x1.dtype} != {x2.dtype}"
        assert torch.allclose(x1, x2, atol=1e-4), key
        print(".", end="")
    elif isinstance(x1, (int, float, str)):
        assert x1 == x2, key
    elif x1 is None:
        assert x2 is None, key
    else:
        raise ValueError((x1, x2, key))
# Compare each pair of saved files (paired by sorted position).
for branch_1_input_path, branch_2_input_path in zip(branch_1_inputs_paths, branch_2_inputs_paths):
    print()
    print("Testing", branch_1_input_path, branch_2_input_path)

    # Positional pairing is only valid if both files belong to the same checkpoint:
    # .../outputs/<org>/<model>/<branch>/inputs.pt -> .../outputs/<org>/<model>.
    checkpoint_dir_1 = os.path.dirname(os.path.dirname(branch_1_input_path))
    checkpoint_dir_2 = os.path.dirname(os.path.dirname(branch_2_input_path))
    assert checkpoint_dir_1 == checkpoint_dir_2, (checkpoint_dir_1, checkpoint_dir_2)

    # The saved objects are whatever the image processor returned (presumably a
    # transformers BatchFeature), not plain tensors, so full unpickling is needed;
    # newer torch versions default to weights_only=True and would reject them.
    # Only load files you generated yourself.
    branch_1_inputs = torch.load(branch_1_input_path, weights_only=False)
    branch_2_inputs = torch.load(branch_2_input_path, weights_only=False)

    assert branch_1_inputs.keys() == branch_2_inputs.keys()
    nested_compare(branch_1_inputs, branch_2_inputs)
# Script 2: run each segmentation image processor on a fixed random sample and save the
# outputs under outputs/<checkpoint>/<current git branch>/inputs.pt.
# NOTE: GitHub flagged this file as containing bidirectional Unicode text; review it in an
# editor that reveals hidden Unicode characters.
import os
import subprocess

import numpy as np
import torch
from transformers import AutoImageProcessor, AutoConfig
from huggingface_hub import HfApi, ModelFilter

api = HfApi()

# Fixed seeds so that every branch processes exactly the same random sample.
np.random.seed(42342)
torch.manual_seed(42342)

# Directory containing this script; outputs are written under <file_dir>/outputs.
file_dir = os.path.dirname(os.path.realpath(__file__))

# Branch checked out in the current working directory; outputs are stored per branch.
# check_output raises if git fails (e.g. not inside a repository) instead of silently
# returning "" and writing the outputs one directory level too high.
# NOTE(review): a detached HEAD yields "HEAD", and a branch name containing "/" creates
# extra directory levels that the comparison script's glob will not match.
branch_name = subprocess.check_output(["git", "rev-parse", "--abbrev-ref", "HEAD"]).decode().strip()

# Input sample: a random uint8 RGB image and a binary segmentation map with the same
# height and width.
image = np.random.randint(0, 256, (224, 540, 3), dtype=np.uint8)
segmentation_maps = np.random.randint(0, 2, (224, 540), dtype=np.uint8)
# Collect the most-downloaded, non-gated Hub checkpoints for each model family.
relevant_checkpoints = []
for model_name in ["maskformer", "mask2former", "segformer", "beit", "oneformer"]:
    # NOTE(review): `ModelFilter` is deprecated in recent huggingface_hub releases;
    # `api.list_models(search=model_name, ...)` is the suggested replacement.
    model_filter = ModelFilter(model_name=model_name)
    models = api.list_models(filter=model_filter, sort="downloads", limit=50)
    for model in models:
        print(model.downloads, model.id, model.gated)
        # `downloads` may be None in a listing; treat that as zero downloads.
        downloads = model.downloads or 0
        # Skip ids already collected so a checkpoint is never processed twice.
        if downloads > 10 and not model.gated and model.id not in relevant_checkpoints:
            relevant_checkpoints.append(model.id)
# Run each checkpoint's image processor on the shared sample and save the result.
for checkpoint in relevant_checkpoints:
    print(checkpoint)
    save_dir = os.path.join(file_dir, "outputs", checkpoint, branch_name)
    try:
        image_processor = AutoImageProcessor.from_pretrained(checkpoint)
        # Processors exposing `num_text` (presumably OneFormer) need it derived from
        # the model config before they can be called.
        if hasattr(image_processor, "num_text"):
            config = AutoConfig.from_pretrained(checkpoint)
            image_processor.num_text = config.num_queries - config.text_encoder_n_ctx
    except OSError as e:
        print(f"Error loading {checkpoint}: {e}")
        continue
    except ValueError as e:
        if "Unrecognized image processor" in str(e):
            print(f"Skipping {checkpoint}: {e}")
            continue
        raise

    # Preprocess the image (and segmentation map) and save the processor output.
    inputs = image_processor(images=image, segmentation_maps=segmentation_maps, return_tensors="pt")
    # Create the directory only when there is something to save, so skipped
    # checkpoints do not leave empty per-branch directories behind.
    os.makedirs(save_dir, exist_ok=True)
    torch.save(inputs, os.path.join(save_dir, "inputs.pt"))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment