Skip to content

Instantly share code, notes, and snippets.

@e96031413
Created July 16, 2021 12:14
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save e96031413/5e00f174fc6fffded710b920b23e4380 to your computer and use it in GitHub Desktop.
Save e96031413/5e00f174fc6fffded710b920b23e4380 to your computer and use it in GitHub Desktop.
for i, (img, target,_) in enumerate(tqdm(dataloader)):
feat_list = []
def hook(module, input, output):
# 由於MobileNetv2不像ResNet18有宣告self.avgpool(),因此我的作法是將模型卷積層的最後一層的輸出手動進行adaptive_avg_pool2d
# 接著將它加到feature_list中
feat_list.append(nn.functional.adaptive_avg_pool2d(output.clone(), (1, 1)).reshape(output.clone().shape[0], -1).detach())
images = img.to(device)
target = target.squeeze().tolist()
for element in target:
labels.append(element)
with torch.no_grad():
handle=model.features[-1].register_forward_hook(hook) # 模型卷積層最後一層的輸出(model.features[-1])
output = model.forward(images)
feat = torch.flatten(feat_list[0], 1)
handle.remove()
current_features = feat.cpu().numpy()
if features is not None:
features = np.concatenate((features, current_features))
else:
features = current_features
return features, labels
@melody871126
Copy link

謝謝您的答覆,這個code我看過了,但仍無法看出來features = None之後,哪個步驟有更新到"features"?

@e96031413
Copy link
Author

features = None

current_features = feat.cpu().numpy()
if features is not None:
    features = np.concatenate((features, current_features))
else:
    features = current_features

您好,這段程式碼當中,我把其他程式碼先暫時去掉,只看features相關的程式碼

  1. 初始狀態為features = None
  2. 第一次更新features的時候,由於features = None,所以會執行else部分的程式碼,這時候features被更新後,就不是None了
  3. 因此下一次更新features的時候,會從if features is not None開始執行,這邊會透過numpy把features和current_features進行concat

@melody871126
Copy link

謝謝您的解說!我理解了!原本以為是最後一個步驟才執行if else,是我忽略了他會不斷讀取dataloader裡的資料去做更新。
謝謝您。

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment