Created November 18, 2019 19:17
Save ncammarata/ef76576990a92552086a3df1636a765e to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import scipy.ndimage as nd | |
from scipy.ndimage import gaussian_filter | |
import lucid.optvis.transform as transform | |
def mean_alpha():
    """Return a lucid Objective that scores how transparent the input image is.

    The score is ``1.0 - mean(input[..., 3:])``: it is highest when the
    channels from index 3 onward are all zero. This assumes the "input"
    tensor is RGBA with alpha at channel 3, as produced by
    ``param.image(..., alpha=True)`` (TODO confirm for other
    parameterizations).
    """
    def _transparency(T):
        alpha_channels = T("input")[..., 3:]
        return 1.0 - tf.reduce_mean(alpha_channels)

    return objectives.Objective(_transparency)
def pretty_alpha(img, padding=2, *, percent_bg=0.5, blur_sigma=5, zoom=1.2):
    """Composite an RGBA image over a blurred, lightened copy of itself, then crop and upscale.

    The transparent regions of ``img`` are filled with a softened version of
    the image itself rather than a flat color. That fill is the RGB mixed
    toward white and then Gaussian-blurred. The result is cropped by
    ``padding`` pixels on each spatial edge and upsampled.

    Args:
        img: Rank-3 array ``(H, W, C)``. Channels ``0:3`` are RGB and channels
            ``3:`` are used as alpha (presumably a single alpha channel in
            ``[0, 1]``; values are assumed to be in ``[0, 1]`` since white is
            1.0 -- TODO confirm for other inputs).
        padding: Pixels to crop from each spatial edge before zooming. ``0``
            means no crop.
        percent_bg: Weight of the original RGB in the background mix; the
            remaining ``1 - percent_bg`` is white. Default 0.5.
        blur_sigma: Gaussian sigma along both spatial axes of the background
            (the channel axis always uses sigma 1, as in the original tuning).
        zoom: Spatial upsampling factor applied with an order-2 spline.

    Returns:
        The RGB composite as an ``(round((H - 2*padding) * zoom),
        round((W - 2*padding) * zoom), 3)`` array.

    Raises:
        ValueError: If ``padding`` is negative.
    """
    if padding < 0:
        raise ValueError(f"padding must be non-negative, got {padding}")

    rgb = img[..., :3]
    alpha = img[..., 3:]
    white = np.ones(rgb.shape)

    # Lighten the image toward white, then blur it to get a soft backdrop.
    background = percent_bg * rgb + (1 - percent_bg) * white
    blurred_bg = gaussian_filter(background, sigma=[blur_sigma, blur_sigma, 1])

    # Standard alpha-over compositing of the RGB onto the soft backdrop.
    composite = blurred_bg * (1 - alpha) + rgb * alpha

    # Slicing with [0:-0] would yield an empty array, so only crop when asked.
    if padding:
        composite = composite[padding:-padding, padding:-padding, :]

    return nd.zoom(composite, order=2, zoom=[zoom, zoom, 1])
def render_neuron(model_name, layer, channel, size=90):
    """Optimize and display a transparent-background visualization of one neuron.

    The objective is built in three steps:
    - start from ``objectives.neuron(layer, channel)``;
    - subtract ``1e2 * objectives.blur_alpha_each_step()``;
    - multiply by ``mean_alpha()``.

    It then optimizes a ``size`` x ``size`` RGBA image with
    ``render.render_vis`` under a fixed stack of jitter / scale / rotate /
    crop / alpha-collapse transforms.

    Args:
        model_name: Currently ignored -- the model is always
            ``models.InceptionV1``. Kept so existing call sites keep working.
        layer: Layer name forwarded to ``objectives.neuron``.
        channel: Channel index forwarded to ``objectives.neuron``.
        size: Side length in pixels of the optimized image. The default 90 is
            the value that was previously hard-coded.

    Returns:
        The image produced by ``pretty_alpha`` from the first element of the
        last entry returned by ``render_vis``. As a side effect it is also
        displayed via ``show``.
    """
    model = models.InceptionV1()
    # lucid models need their GraphDef loaded before rendering; without this
    # call the original gist did not run as-is.
    model.load_graphdef()

    param_f = lambda: param.image(size, alpha=True)

    obj = objectives.neuron(layer, channel)
    obj -= 1e2 * objectives.blur_alpha_each_step()
    obj *= mean_alpha()

    # Order matters: transforms are applied in sequence. Repeated values in the
    # scale/rotate candidate lists presumably act as sampling weights (TODO
    # confirm against lucid's random_scale / random_rotate).
    transforms = [
        transform.jitter(4),
        transform.jitter(4),
        transform.jitter(8),
        transform.jitter(8),
        transform.jitter(8),
        transform.random_scale(
            [0.995 ** n for n in range(-5, 80)]
            + [0.998 ** n for n in 2 * list(range(20, 40))]
        ),
        transform.random_rotate(
            list(range(-20, 20))
            + list(range(-10, 10))
            + list(range(-5, 5))
            + 5 * [0]
        ),
        transform.jitter(2),
        transform.crop_or_pad_to(size, size),
        transform.collapse_alpha_random(),
    ]

    images = render.render_vis(
        model, obj, param_f=param_f, transforms=transforms, thresholds=[1024]
    )[-1]
    result = pretty_alpha(images[0])
    show(result)
    return result
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.
Thanks @ncammarata! Everything works great. The only modification I needed to make to run it as-is was to load the graph (likely an omission in your gist!):
model.load_graphdef()