Created
January 12, 2022 13:08
-
-
Save e96031413/043e9185e27fc89dfd91b7d90fd0dd16 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import torch | |
import pickle | |
from torch.utils.data import Dataset, DataLoader | |
# Extract features
def extract(loader):
    """Collect the flattened ``avgpool`` activations of ``featureExtractor``.

    The model is switched to eval mode and run over every batch yielded by
    ``loader`` with gradients disabled. A forward hook on
    ``featureExtractor.avgpool`` captures that layer's output for each batch;
    the output is flattened from dimension 1 onward and converted to NumPy.

    Relies on the module-level names ``featureExtractor`` and ``tqdm``, which
    are not defined in the visible part of this file.

    Args:
        loader: Iterable yielding image batches that support ``.cuda()``.

    Returns:
        A NumPy array with the per-batch features concatenated along the
        first axis, or ``None`` when ``loader`` yields no batches.
    """
    featureExtractor.eval()
    captured = []

    def hook(_module, _inputs, output):
        # Keep a detached copy so later forward passes cannot alter it.
        captured.append(output.clone().detach())

    batch_features = []
    # Register the hook once for the whole run instead of once per batch.
    handle = featureExtractor.avgpool.register_forward_hook(hook)
    try:
        with torch.no_grad():
            for img in tqdm(loader):
                captured.clear()
                # NOTE(review): assumes the model already lives on a CUDA device.
                img = img.cuda()
                # Only the hooked activation is needed; the model output is discarded.
                featureExtractor(img)
                feat_tsne = torch.flatten(captured[0], 1)
                batch_features.append(feat_tsne.cpu().numpy())
    finally:
        # Always detach the hook, even if a batch fails part-way through.
        handle.remove()

    if not batch_features:
        return None
    # A single concatenation avoids recopying the accumulated array every batch.
    return np.concatenate(batch_features)
# Extract features from the training-set images
train_dataloader = loadTrainSet()
# Call extract() defined above; the builtin eval() rejects keyword arguments
# and would raise TypeError here.
features = extract(loader=train_dataloader)
with open('./extractedFeature.pickle', 'wb') as handle:
    # Pickle protocol 4 and above can store objects larger than 4 GB.
    pickle.dump(features, handle, protocol=pickle.HIGHEST_PROTOCOL)
# Load the pickle file back into a DataLoader
class FeatDataset(Dataset):
    """Dataset that serves the entries of ``data`` by index along its first axis."""

    def __init__(self, data):
        """Store ``data`` and record its first-dimension size as the length.

        Args:
            data: Indexable object exposing ``.shape``.
        """
        self.data = data
        # The length is read once at construction time.
        self.current_set_len = data.shape[0]

    def __len__(self):
        """Return the number of items recorded at construction."""
        return self.current_set_len

    def __getitem__(self, idx):
        """Return the entry of ``data`` at position ``idx``."""
        return self.data[idx]
# Read the pickled feature array back from disk.
with open('./extractedFeature.pickle', "rb") as fn:
    wholeData = pickle.load(fn)

# Wrap the array as a tensor and append two trailing singleton dimensions in place.
wholeData = torch.from_numpy(wholeData)
wholeData.unsqueeze_(-1)
wholeData.unsqueeze_(-1)

trainset_closeset = FeatDataset(data=wholeData)
# NOTE(review): batch_size is not defined in the visible code -- confirm it is
# set before this point.
dataloader = DataLoader(
    dataset=trainset_closeset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=1,
)

# Pull one shuffled batch and print its shape as a quick check.
data_sampler = iter(dataloader)
feaList = next(data_sampler)
print(feaList.shape)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment