Last active
July 28, 2022 15:26
-
-
Save fg91/fa3bbafce24b62abb55c22139248e31d to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class FilterVisualizer():
    """Synthesize an image that maximizes the mean activation of one filter.

    Starting from a small random image, the pixels are optimized by gradient
    ascent on the chosen filter's mean activation; the image is then scaled
    up and the optimization repeated, ``upscaling_steps`` times in total.

    Relies on names that must be in scope in the enclosing module:
    ``vgg16``, ``set_trainable``, ``tfms_from_model``, ``V``,
    ``SaveFeatures``, ``torch``, ``np``, ``cv2`` and ``plt``.
    """

    def __init__(self, size=56, upscaling_steps=12, upscaling_factor=1.2):
        """Store the schedule parameters and load the frozen model.

        Args:
            size: side length in pixels of the initial square image.
            upscaling_steps: number of optimize-then-upscale rounds.
            upscaling_factor: multiplier applied to the side length after
                each round (the result is truncated to an int).
        """
        self.size = size
        self.upscaling_steps = upscaling_steps
        self.upscaling_factor = upscaling_factor
        # The model is moved to the GPU and put in eval mode; only the input
        # image is optimized, so the model is marked non-trainable.
        self.model = vgg16(pre=True).cuda().eval()
        set_trainable(self.model, False)

    def visualize(self, layer, filter, lr=0.1, opt_steps=20, blur=None):
        """Optimize an image for ``filter`` of ``layer`` and save it to disk.

        The last optimized image (before the final upscale/blur) is stored in
        ``self.output`` and written out via :meth:`save`.

        Args:
            layer: index into ``list(self.model.children())`` selecting the
                module whose output is observed.
            filter: channel index into the observed activations.
            lr: learning rate for the Adam optimizer.
            opt_steps: number of optimizer steps per round.
            blur: optional kernel side length for ``cv2.blur`` applied after
                each upscale; ``None`` disables blurring.
        """
        sz = self.size
        # Random noise image with channel values of roughly 0.59-0.70.
        img = np.uint8(np.random.uniform(150, 180, (sz, sz, 3))) / 255
        # SaveFeatures is defined elsewhere; per the original comment it
        # registers a hook -- presumably a forward hook that stores the
        # module output in `.features` (TODO confirm).
        activations = SaveFeatures(list(self.model.children())[layer])
        try:
            for _ in range(self.upscaling_steps):
                # Transforms depend on the current size, so rebuild each round.
                _, val_tfms = tfms_from_model(vgg16, sz)
                # Add a batch axis and make the image itself the optimized
                # parameter.
                img_var = V(val_tfms(img)[None], requires_grad=True)
                optimizer = torch.optim.Adam(
                    [img_var], lr=lr, weight_decay=1e-6
                )
                for _ in range(opt_steps):
                    optimizer.zero_grad()
                    # The forward pass is run only to refresh
                    # `activations.features`.
                    self.model(img_var)
                    # Negated so that minimizing the loss maximizes the mean
                    # activation of the chosen filter.
                    loss = -activations.features[0, filter].mean()
                    loss.backward()
                    optimizer.step()
                # Back to a denormalized height x width x channel numpy image.
                img = val_tfms.denorm(
                    img_var.data.cpu().numpy()[0].transpose(1, 2, 0)
                )
                # Recorded before upscaling, so the saved result has the size
                # used in the final optimization round.
                self.output = img
                sz = int(self.upscaling_factor * sz)
                img = cv2.resize(img, (sz, sz), interpolation=cv2.INTER_CUBIC)
                if blur is not None:
                    # Blurring reduces high-frequency patterns.
                    img = cv2.blur(img, (blur, blur))
            self.save(layer, filter)
        finally:
            # Always release the hook, even if optimization or saving fails,
            # so it does not stay attached to the shared model.
            activations.close()

    def save(self, layer, filter):
        """Write ``self.output``, clipped to [0, 1], as a JPEG.

        The file is created in the current working directory and named
        ``layer_<layer>_filter_<filter>.jpg``.
        """
        plt.imsave(
            "layer_" + str(layer) + "_filter_" + str(filter) + ".jpg",
            np.clip(self.output, 0, 1),
        )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
I hope you can provide complete code that can be run directly, including the models, demo data, and any other files that generate_image.py depends on.