This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import time | |
import torch | |
import requests | |
import transformers | |
import pandas as pd | |
from PIL import Image | |
from transformers import AutoProcessor, CLIPModel, CLIPTextModel, CLIPVisionModel | |
# Print the installed transformers version so results can be tied to a specific release.
print("Transformers version:", transformers.__version__)
# Iteration count; presumably the number of timed repetitions per measurement
# (the loop that uses it is not visible here) — TODO confirm.
N_ITERATIONS = 100
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import glob | |
import torch | |
from typing import Mapping | |
# Absolute directory that contains this script (symlinks resolved via realpath).
file_dir = os.path.dirname(os.path.realpath(__file__))
# Names of two git branches referenced by this script; presumably the pair
# being compared — how they are used is not visible here, TODO confirm.
branch_1 = "main"
branch_2 = "clean-up-do-reduce-labels"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from collections import OrderedDict | |
from typing import List | |
# Ellipsis (`...`) placeholders: replace both with real values before use.
# Checkpoint weight file paths, sorted in descending order by score.
checkpoints_weights_paths: List[str] = ...
# Model the checkpoint weights are loaded into; presumably its architecture
# must match the checkpoints — TODO confirm.
model: torch.nn.Module = ...
def average_weights(state_dicts: List[dict]): | |
everage_dict = OrderedDict() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class TestDataset(Dataset): | |
def __init__(self, base_path, ids, transform): | |
self.base_path = base_path | |
self.ids = ids | |
self.transform = transform | |
def __getitem__(self, i): | |
sample = self.get_sample(i) | |
if self.transform is not None: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pickle | |
from keras.layers import Conv2D, BatchNormalization, Dense | |
# NOTE!
# Intended for Python 3.6+, because it relies on dict keys keeping insertion order
# (a CPython 3.6 implementation detail, guaranteed by the language from 3.7).
def get_name(name):
    """Return *name* with its last '/'-separated component removed.

    Everything before the final '/' is returned. If *name* contains no
    '/', the result is the empty string. For example,
    ``get_name("conv2d_1/kernel:0")`` returns ``"conv2d_1"`` (that input
    shape is assumed from Keras-style weight names — confirm with callers).

    Args:
        name: A '/'-separated, path-like string.

    Returns:
        The prefix of *name* up to, but not including, its last '/'.
    """
    # rpartition splits on the last '/' only; its head is '' when no '/'
    # is present, which matches '/'.join(name.split('/')[:-1]) exactly
    # without building an intermediate list.
    return name.rpartition('/')[0]