
@e96031413
Created July 16, 2021 12:14
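import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm


def extract_features(model, dataloader, device):
    # Hypothetical function name and signature: the gist shows only the function body.
    features = None
    labels = []
    model.eval()  # assumption: features are extracted in inference mode
    for img, target, _ in tqdm(dataloader):
        feat_list = []

        def hook(module, input, output):
            # Unlike ResNet18, MobileNetV2 does not declare self.avgpool, so the output of the
            # last convolutional block is passed through adaptive_avg_pool2d manually,
            # flattened to (batch, channels), and appended to feat_list.
            feat_list.append(
                nn.functional.adaptive_avg_pool2d(output, (1, 1)).reshape(output.shape[0], -1).detach()
            )

        images = img.to(device)
        # reshape(-1) keeps a flat list even when a batch holds a single sample
        labels.extend(target.reshape(-1).tolist())
        with torch.no_grad():
            # Hook the output of the last convolutional block (model.features[-1])
            handle = model.features[-1].register_forward_hook(hook)
            model(images)
            handle.remove()
        current_features = feat_list[0].cpu().numpy()
        # The first batch initializes the array; every later batch is appended to it.
        if features is not None:
            features = np.concatenate((features, current_features))
        else:
            features = current_features
    return features, labels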
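The hook technique can be checked in isolation with a minimal sketch. It assumes a hypothetical stand-in network, `TinyNet`, that mimics MobileNetV2's layout: a `features` Sequential, no `avgpool` attribute, and global pooling done functionally inside `forward()`, which is how torchvision's `mobilenet_v2` pools. The helper name `pooled_features` is also hypothetical. The sketch removes the hook in a `finally` block so it is detached even if the forward pass fails. Note that `adaptive_avg_pool2d(output, (1, 1))` is simply the mean over the spatial dimensions.

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in mimicking MobileNetV2's layout:
    a `features` Sequential followed by a classifier, with no `avgpool` attribute."""

    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 4, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(4, 8, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
        )
        self.classifier = nn.Linear(8, num_classes)

    def forward(self, x):
        x = self.features(x)
        # Pooling happens functionally, so there is no pooling module to hook
        x = nn.functional.adaptive_avg_pool2d(x, (1, 1)).flatten(1)
        return self.classifier(x)


def pooled_features(model: nn.Module, images: torch.Tensor) -> torch.Tensor:
    """Capture model.features[-1]'s output with a forward hook and pool it to (B, C)."""
    captured = []

    def hook(module, inputs, output):
        # Global average pooling: (B, C, H, W) -> (B, C, 1, 1) -> (B, C)
        captured.append(nn.functional.adaptive_avg_pool2d(output, (1, 1)).flatten(1).detach())

    handle = model.features[-1].register_forward_hook(hook)
    try:
        with torch.no_grad():
            model(images)
    finally:
        handle.remove()  # always detach the hook, even if forward raises
    return captured[0]


model = TinyNet().eval()
images = torch.randn(2, 3, 32, 32)
feats = pooled_features(model, images)
print(feats.shape)  # torch.Size([2, 8])
```

With torchvision's real `mobilenet_v2`, the same hook on `model.features[-1]` would yield 1280-dimensional vectors instead of the 8 used here.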
@melody871126
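Thank you for the explanation! I understand now! I originally thought the if/else only ran at the very last step; I overlooked that it runs on every batch read from the dataloader, updating the array each time.
Thank you.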
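A small sketch of the point raised in this comment: the `if`/`else` runs once per batch, so the first batch initializes `features` and each later batch is concatenated onto it. The function names and the toy batches below are hypothetical. The second function shows a common alternative that collects the per-batch arrays in a list and concatenates them once. This produces the same result while avoiding a re-copy of the growing array on every iteration.

```python
import numpy as np


def accumulate_incrementally(batches):
    """The gist's pattern: the first batch initializes the array and each later
    batch is concatenated onto it inside the loop."""
    features = None
    for current in batches:
        if features is not None:
            features = np.concatenate((features, current))
        else:
            features = current
    return features


def accumulate_once(batches):
    """Alternative: keep the per-batch arrays in a list and concatenate a single time."""
    parts = list(batches)
    return np.concatenate(parts) if parts else None


# Three hypothetical batches of 4-dimensional feature vectors (sizes 2, 3, 1)
batches = [np.full((n, 4), k, dtype=np.float32) for k, n in enumerate([2, 3, 1])]
a = accumulate_incrementally(batches)
b = accumulate_once(batches)
print(a.shape)  # (6, 4)
print(np.array_equal(a, b))  # True
```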
