Skip to content

Instantly share code, notes, and snippets.

@Vchekryzhov
Created October 26, 2022 19:52
Show Gist options
  • Save Vchekryzhov/4474d8b0aa1bc252064c01a0319459fa to your computer and use it in GitHub Desktop.
Save Vchekryzhov/4474d8b0aa1bc252064c01a0319459fa to your computer and use it in GitHub Desktop.
Image embedding module
from torchvision.io import read_image
from torchvision.models import resnet50, ResNet50_Weights
import torch
import glob
import pickle
from tqdm import tqdm
from PIL import Image
def pil_loader(path):
    """Open the image file at ``path`` and return it converted to RGB.

    Image preprocessing step: some images in the dataset are not stored in
    RGB format, so every image is converted to RGB before further use.
    """
    with open(path, 'rb') as stream:
        # Convert while the file is still open so the pixel data can be read.
        rgb_image = Image.open(stream).convert('RGB')
    return rgb_image
# Script configuration: whether to reuse embeddings saved by an earlier run,
# and the pickle file they are stored in.
use_precomputed_embeddings = True
emb_filename = 'fashion_images_embs.pickle'

# Load and initialise the pretrained ResNet network, switched to evaluation
# mode, together with the preprocessing transforms bundled with its weights.
weights = ResNet50_Weights.DEFAULT
model = resnet50(weights=weights).eval()
preprocess = weights.transforms()
# Either load previously saved embeddings from ``emb_filename`` or compute
# them for every ``images/*.jpg`` file and save them to ``emb_filename``.
# Both branches define ``img_names`` (list of image paths) and
# ``img_emb_tensors`` (one model output vector per image).
if use_precomputed_embeddings:
    # SECURITY NOTE: pickle.load can execute arbitrary code stored in the
    # file. Only load an embeddings file that was produced by this script.
    with open(emb_filename, 'rb') as fIn:
        img_names, img_emb_tensors = pickle.load(fIn)
    print("Images:", len(img_names))
else:
    img_names = glob.glob('images/*.jpg')
    img_emb = []
    # Inference only: disabling autograd avoids recording a graph for every
    # forward pass, so no .detach() is needed afterwards.
    with torch.no_grad():
        for image in tqdm(img_names):
            # Feature extraction for each image in the dataset.
            # (Original author's note: took about an hour on CPU.)
            batch = preprocess(pil_loader(image)).unsqueeze(0)
            img_emb.append(model(batch).squeeze(0))
    # Stack the per-image vectors directly instead of round-tripping through
    # a list of numpy arrays, which torch.tensor() converts very slowly.
    # With no images found, keep the original result (an empty tensor).
    img_emb_tensors = torch.stack(img_emb) if img_emb else torch.tensor(img_emb)
    with open(emb_filename, 'wb') as handle:
        # Persist names and embeddings to a file (this is the "database").
        pickle.dump(
            [img_names, img_emb_tensors],
            handle,
            protocol=pickle.HIGHEST_PROTOCOL,
        )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment