Skip to content

Instantly share code, notes, and snippets.

@e96031413
Created November 18, 2021 03:53
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save e96031413/affbde87f08b60b22d13e61137a8174f to your computer and use it in GitHub Desktop.
Save e96031413/affbde87f08b60b22d13e61137a8174f to your computer and use it in GitHub Desktop.
def get_features_trained_weight(model, transform_dataset, df=None, *,
                                batch_size=1024, num_workers=4):
    """Extract flattened ``avgpool`` features for every image in a dataset.

    Runs ``model`` over ``CustomDataset(df, transform=transform_dataset)`` and
    captures the output of ``model.avgpool`` with a forward hook, flattening it
    to one row per image.

    Args:
        model: torch module exposing an ``avgpool`` submodule. A
            ``torch.nn.DataParallel`` wrapper is unwrapped first. The model is
            switched to eval mode and moved to the selected device in place.
        transform_dataset: transform passed to ``CustomDataset``.
        df: data source passed to ``CustomDataset``. Defaults to the
            module-level ``df_tsne`` (the previously hard-coded source).
        batch_size: DataLoader batch size.
        num_workers: number of DataLoader worker processes.

    Returns:
        ``(features, imgs, labels, image_paths)``:
        ``features`` is a NumPy array of flattened avgpool outputs, one row per
        image. ``imgs`` is the transformed input batches concatenated along the
        first axis. ``labels`` is a list with one entry per image. ``image_paths``
        is a list of the paths yielded by the loader. ``features`` and ``imgs``
        are ``None`` when the loader yields no batches.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    if isinstance(model, torch.nn.DataParallel):
        model = model.module
    model.eval()
    model.to(device)

    dataset = CustomDataset(df_tsne if df is None else df, transform=transform_dataset)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        collate_fn=collate_skip_empty,
        shuffle=False,  # keep features aligned with labels/paths in dataset order
        num_workers=num_workers,
    )

    # Per-batch NumPy chunks, concatenated once at the end. Calling
    # np.concatenate inside the loop would re-copy everything collected so far
    # on every batch (quadratic in the number of batches).
    feature_chunks = []
    img_chunks = []
    # Image labels and paths, kept for visualization later.
    labels = []
    image_paths = []

    # Filled by the forward hook with the avgpool output of the current batch.
    captured = []

    def _capture_avgpool(module, inputs, output):
        captured.append(output.detach().clone())

    print("Start extracting Feature")
    # Register the hook once and always remove it, even if a batch fails, so
    # the caller's model is not left with a dangling hook.
    handle = model.avgpool.register_forward_hook(_capture_avgpool)
    try:
        with torch.no_grad():
            for img, target, path, _ in tqdm(dataloader):
                batch_labels = target.squeeze().tolist()
                if target.dim() == 0 or target.size(0) == 1:
                    # For a 1-sample batch, squeeze() also drops the batch axis,
                    # so tolist() returns a bare scalar; re-wrap it as one label.
                    batch_labels = [batch_labels]
                labels.extend(batch_labels)
                image_paths.extend(path)

                captured.clear()
                # Call the module itself (not .forward()) as PyTorch recommends.
                model(img.to(device))
                feat = torch.flatten(captured[0], 1)

                # Take an owned copy so the stored chunks do not keep the
                # loader's batch tensors alive.
                img_chunks.append(img.cpu().numpy().copy())
                feature_chunks.append(feat.cpu().numpy())
    finally:
        handle.remove()

    if not feature_chunks:
        return None, None, labels, image_paths
    features = np.concatenate(feature_chunks)
    imgs = np.concatenate(img_chunks)
    return features, imgs, labels, image_paths
# Extract features with the trained weights and log them via
# writer.add_embedding (writer is presumably a torch.utils.tensorboard
# SummaryWriter created elsewhere; confirm). The writer is closed even if
# extraction or logging raises, so it is not left open on failure.
try:
    features, imgs, labels, image_path = get_features_trained_weight(model, transform_dataset)
    writer.add_embedding(features, metadata=labels, label_img=imgs)
finally:
    writer.close()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment