Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
import numpy as np
import torch
import pickle
from torch.utils.data import Dataset, DataLoader
# Extract features
def extract(loader, model=None, device="cuda", progress=None):
    """Run ``model`` over every batch of ``loader`` and collect avgpool activations.

    A forward hook on ``model.avgpool`` captures that layer's output for each
    batch. The output is flattened from dim 1 onward and copied to host
    memory as a NumPy array (these are the features used for t-SNE
    visualisation). The per-batch arrays are concatenated along axis 0.

    Args:
        loader: Iterable yielding one image tensor per batch.
            NOTE(review): assumes batches are bare image tensors, not
            (image, label) tuples -- confirm against the loader.
        model: Module that has an ``avgpool`` submodule. Defaults to the
            module-level ``featureExtractor`` (previous behaviour).
        device: Device each batch is moved to before the forward pass.
            Defaults to ``"cuda"`` (previously hard-coded ``img.cuda()``).
        progress: Callable that wraps ``loader`` for progress reporting.
            Defaults to ``tqdm`` (previous behaviour).

    Returns:
        A 2-D NumPy array with one row per input sample, or ``None`` if
        ``loader`` yields no batches.

    Side effects:
        Switches ``model`` to eval mode and leaves it there.
    """
    if model is None:
        model = featureExtractor
    if progress is None:
        progress = tqdm

    model.eval()
    captured = []

    def hook(module, inputs, output):
        captured.append(output.detach().clone())

    # Register the hook once for the whole pass, and always remove it (even if
    # a forward pass raises) so the model is not left with a stale hook.
    handle = model.avgpool.register_forward_hook(hook)
    chunks = []
    try:
        with torch.no_grad():
            for img in progress(loader):
                captured.clear()
                # The model's own output is unused; only the hooked
                # avgpool activation is kept.
                model(img.to(device))
                feat_tsne = torch.flatten(captured[0], 1)
                chunks.append(feat_tsne.cpu().numpy())
    finally:
        handle.remove()

    # Concatenate once at the end. Concatenating per batch re-copied every
    # previously collected row on each iteration (quadratic).
    if not chunks:
        return None
    return np.concatenate(chunks)
# Extract features for the training-set images and save them to disk.
train_dataloader = loadTrainSet()
# Call the feature extractor defined above. The builtin ``eval`` accepts no
# keyword arguments, so ``eval(loader=...)`` always raised TypeError.
features = extract(loader=train_dataloader)
if features is None:
    # extract() returns None when the loader yields no batches; fail here
    # rather than pickling None and failing later when it is reloaded.
    raise RuntimeError("Training loader yielded no batches; nothing to save.")
with open('./extractedFeature.pickle', 'wb') as handle:
    # pickle.HIGHEST_PROTOCOL is >= 4 on Python 3.4+. Protocol 4 or later is
    # required to serialize objects larger than 4 GB.
    pickle.dump(features, handle, protocol=pickle.HIGHEST_PROTOCOL)
# Read the pickle file back into a DataLoader.
class FeatDataset(Dataset):
    """Map-style dataset over an in-memory array of precomputed features.

    Item access delegates directly to ``data``. The length is the size of
    ``data``'s first dimension, captured once at construction time.
    """

    def __init__(self, data):
        # ``data`` must expose ``.shape`` (e.g. a NumPy array or torch tensor).
        self.data, self.current_set_len = data, data.shape[0]

    def __len__(self):
        """Return the number of samples recorded at construction."""
        return self.current_set_len

    def __getitem__(self, idx):
        """Return ``data[idx]`` unchanged."""
        return self.data[idx]
with open('./extractedFeature.pickle', "rb") as fn:
    # Only load pickles this script wrote itself: pickle.load can execute
    # arbitrary code when given untrusted data.
    wholeData = pickle.load(fn)
# The extraction step above saves a 2-D (num_samples, feature_dim) array.
# Append two singleton dims in place: -> (num_samples, feature_dim, 1, 1).
wholeData = torch.from_numpy(wholeData)
wholeData.unsqueeze_(-1).unsqueeze_(-1)
trainset_closeset = FeatDataset(data=wholeData)
# num_workers=0: the features are already an in-memory tensor, so a worker
# process only adds start-up and IPC cost. A worker would also re-execute this
# module's unguarded top-level code (including feature extraction) on
# platforms that use the "spawn" start method (Windows, macOS).
# NOTE(review): ``batch_size`` must be defined elsewhere before this point.
dataloader = DataLoader(trainset_closeset, batch_size=batch_size, shuffle=True, num_workers=0)
data_sampler = iter(dataloader)
feaList = next(data_sampler)
print(feaList.shape)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment