@e96031413
Created July 16, 2021 12:14
# Excerpt from inside the feature-extraction function: features (initially None),
# labels (initially an empty list), model, device and dataloader are set up before this loop.
for i, (img, target, _) in enumerate(tqdm(dataloader)):
    feat_list = []

    def hook(module, input, output):
        # Unlike ResNet18, MobileNetV2 does not declare self.avgpool(), so the output of the
        # model's last convolutional layer is pooled manually with adaptive_avg_pool2d
        # and the result is appended to feat_list.
        feat_list.append(nn.functional.adaptive_avg_pool2d(output, (1, 1)).reshape(output.shape[0], -1).detach())

    images = img.to(device)
    target = target.squeeze().tolist()
    for element in target:
        labels.append(element)

    with torch.no_grad():
        handle = model.features[-1].register_forward_hook(hook)  # output of the model's last convolutional layer (model.features[-1])
        output = model.forward(images)
        feat = torch.flatten(feat_list[0], 1)
        handle.remove()

    current_features = feat.cpu().numpy()
    if features is not None:
        features = np.concatenate((features, current_features))
    else:
        features = current_features
return features, labels
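The hook-and-pool idea in the snippet above can be tried in isolation. The following is a minimal sketch, not the gist author's code: `TinyNet` and `extract_pooled` are hypothetical names, and the tiny network only stands in for a backbone that, like torchvision's MobileNetV2, exposes a `features` container and has no `avgpool` submodule.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyNet(nn.Module):
    """Hypothetical stand-in for a backbone with a `features` container."""

    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(8, 16, kernel_size=3, padding=1),
        )
        self.classifier = nn.Linear(16, 4)

    def forward(self, x):
        x = self.features(x)
        # Pooling happens functionally here, so there is no avgpool module to hook.
        x = F.adaptive_avg_pool2d(x, (1, 1)).flatten(1)
        return self.classifier(x)


def extract_pooled(model, images):
    """Return globally average-pooled activations of the last layer in model.features."""
    captured = []

    def hook(module, inputs, output):
        # Pool the (N, C, H, W) activation down to (N, C) inside the hook.
        pooled = F.adaptive_avg_pool2d(output, (1, 1)).flatten(1)
        captured.append(pooled.detach())

    handle = model.features[-1].register_forward_hook(hook)
    try:
        with torch.no_grad():
            model(images)
    finally:
        # Always remove the hook so it does not fire on later forward passes.
        handle.remove()
    return captured[0]


model = TinyNet().eval()
feats = extract_pooled(model, torch.randn(5, 3, 32, 32))
print(feats.shape)  # torch.Size([5, 16])
```

Removing the hook in a `finally` block is a design choice of this sketch: it keeps the model clean even if the forward pass raises.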
@melody871126

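Hello, I have a question about the following piece of code, which I don't quite understand. features starts out as None, so how can the `features is not None` case ever occur? And when features is not None, why are features and current_features concatenated? Thank you.

if features is not None:
    features = np.concatenate((features, current_features))
else:
    features = current_features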

@e96031413
Author

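This code could actually be simplified a bit further.

The code keeps reading batches of images from the dataloader, so by the time the second batch is processed, features is no longer None.

features and current_features are concatenated so that, once the function finishes, it can return features for the t-SNE package to process.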
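The reply does not say what the simplification would look like. One possibility, shown here only as an assumption about what could be meant, is to collect the per-batch arrays in a list and concatenate once at the end. The function names and the random batches below are made up for illustration.

```python
import numpy as np


def collect_with_guard(batches):
    # The pattern from the gist: start from None and grow the array batch by batch.
    features = None
    for current_features in batches:
        if features is not None:
            features = np.concatenate((features, current_features))
        else:
            features = current_features
    return features


def collect_with_list(batches):
    # Possible simplification: gather the chunks, then concatenate a single time.
    chunks = []
    for current_features in batches:
        chunks.append(current_features)
    return np.concatenate(chunks)


# Three fake batches of 4 samples with 8-dimensional features (made-up sizes).
batches = [np.random.rand(4, 8) for _ in range(3)]
same = np.array_equal(collect_with_guard(batches), collect_with_list(batches))
print(same)  # True
```

The list version needs no None check and avoids re-copying the accumulated array on every batch, but `np.concatenate` raises a ValueError on an empty list, so it assumes the dataloader yields at least one batch.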

@melody871126

melody871126 commented Jan 17, 2022

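OK, thank you. One more question: in `for element in target: labels.append(element)`, elements are appended to labels, but I don't see anywhere in the code where features gets updated. The only occurrence is features = None, and then at the end `if features is not None` runs directly. How is features updated along the way?
And if the code keeps reading images from the dataloader and updating features, why is there still a `features is not None` case?
Thanks for your reply!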

@e96031413
Author

Hello, the following code should be a bit clearer:

import numpy as np
import torch
from sklearn.manifold import TSNE  # assumed: scikit-learn's TSNE; the gist does not show its imports
from tqdm import tqdm

# CustomDataset, df_tsne and collate_skip_empty are not shown in this gist.


def get_features(model, transform_dataset):
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Unwrap DataParallel so that model.avgpool can be accessed directly.
    if isinstance(model, torch.nn.DataParallel):
        model = model.module

    model.eval()
    model.to(device)

    dataset = CustomDataset(df_tsne, transform=transform_dataset)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1024, collate_fn=collate_skip_empty, shuffle=False, num_workers=4)

    features = None  # replaced by a NumPy array after the first batch
    labels = []
    print("Start extracting Feature")
    for i, (img, target) in enumerate(tqdm(dataloader)):
        feat_list = []

        def hook(module, input, output):
            # Capture the output of the hooked layer for this batch.
            feat_list.append(output.clone().detach())

        images = img.to(device)
        target = target.squeeze().tolist()

        for element in target:
            labels.append(element)

        with torch.no_grad():
            handle = model.avgpool.register_forward_hook(hook)
            output = model.forward(images)
            feat = torch.flatten(feat_list[0], 1)
            handle.remove()

        current_features = feat.cpu().numpy()
        if features is not None:
            # Later batches: append the new rows to what has been collected so far.
            features = np.concatenate((features, current_features))
        else:
            # First batch: nothing has been collected yet, so start with this batch.
            features = current_features

    return features, labels


features, labels = get_features(model, transform_dataset)
tsne = TSNE(n_components=2).fit_transform(features)

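The `features is not None` check keeps the program from being interrupted by an exception: on the first batch features is still None, and np.concatenate cannot join None with an array.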
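To see what the check protects against, here is a small sketch that makes the same `np.concatenate` call without the guard while features is still None. The helper name `concat_without_guard` and the array shape are made up for illustration.

```python
import numpy as np

current_features = np.ones((2, 3), dtype=np.float32)


def concat_without_guard(features, current_features):
    # Same call as in the loop, but without the `is not None` check.
    return np.concatenate((features, current_features))


try:
    concat_without_guard(None, current_features)
    failed = False
except (ValueError, TypeError) as exc:
    # NumPy refuses to concatenate None with a 2-D array.
    failed = True
    print(type(exc).__name__, exc)

print(failed)  # True
```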

@melody871126

Thank you for your reply. I have read through this code, but I still can't see which step updates "features" after features = None.

@e96031413
Author

features = None

current_features = feat.cpu().numpy()
if features is not None:
    features = np.concatenate((features, current_features))
else:
    features = current_features

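Hello, in this snippet I have temporarily removed everything else so that only the code related to features remains.

  1. The initial state is features = None.
  2. The first time features is updated, it is still None, so the else branch runs. After that update, features is no longer None.
  3. So the next time features is updated, execution takes the `if features is not None` branch, where numpy concatenates features and current_features.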
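The three steps can be traced with a few fake batches. This is an illustrative sketch only: `accumulate`, the `history` list and the batch sizes are made up, and plain NumPy arrays stand in for the features extracted from each batch.

```python
import numpy as np


def accumulate(batches):
    features = None  # step 1: initial state
    history = []     # records which branch ran and the resulting shape
    for current_features in batches:
        if features is not None:
            # step 3: every later iteration concatenates onto the existing array
            features = np.concatenate((features, current_features))
            history.append(("if", features.shape))
        else:
            # step 2: the first iteration takes the else branch
            features = current_features
            history.append(("else", features.shape))
    return features, history


# Fake batches of 4, 4 and 2 samples with 8-dimensional features.
batches = [np.zeros((4, 8)), np.ones((4, 8)), np.full((2, 8), 2.0)]
features, history = accumulate(batches)
for step, (branch, shape) in enumerate(history, start=1):
    print(step, branch, shape)
# 1 else (4, 8)
# 2 if (8, 8)
# 3 if (10, 8)
```

The if/else sits inside the loop, so it runs once per batch rather than once at the end.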

@melody871126

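Thank you for the explanation, I understand now! I had assumed the if/else only ran at the very last step, and I overlooked that the loop keeps reading data from the dataloader and updating features.
Thank you.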
