Skip to content

Instantly share code, notes, and snippets.

@ncammarata
Created November 18, 2019 19:17
Show Gist options
  • Save ncammarata/ef76576990a92552086a3df1636a765e to your computer and use it in GitHub Desktop.
Save ncammarata/ef76576990a92552086a3df1636a765e to your computer and use it in GitHub Desktop.
import scipy.ndimage as nd
from scipy.ndimage import gaussian_filter
import lucid.optvis.transform as transform
def mean_alpha():
    """Build an objective that rewards transparency of the rendered image.

    Despite the name, the objective's value is ``1.0 - mean(alpha)``. Alpha is
    taken from channel 3 onward of the "input" tensor. A more transparent
    image therefore scores higher. The objective is intended to be combined
    with other lucid objectives, e.g. multiplied into a neuron objective.

    Returns:
        A lucid ``objectives.Objective`` wrapping the transparency term.
    """
    def _transparency(T):
        # T resolves tensors by name inside the lucid rendering graph.
        alpha_channel = T("input")[..., 3:]
        return 1.0 - tf.reduce_mean(alpha_channel)

    return objectives.Objective(_transparency)
def pretty_alpha(img, padding=2, *, percent_bg=0.5, zoom=1.2):
    """Composite an RGBA image over a soft background, crop its border, and upscale.

    Transparent regions are filled with a Gaussian-blurred blend of the
    image's own RGB and white. The resulting background keeps a faint tint of
    the content instead of being flat white.

    Args:
        img: Array with RGB in channels 0-2 and alpha in channel 3 onward of
            the last axis. The slicing and blur sigma below assume a
            (H, W, C) layout.
        padding: Pixels cropped from each spatial border before zooming.
            Must be non-negative; 0 means no cropping.
        percent_bg: Weight of the image's RGB in the background blend; the
            remaining ``1 - percent_bg`` is white.
        zoom: Spatial upscale factor passed to ``scipy.ndimage.zoom``. The
            channel axis is not scaled.

    Returns:
        The composited RGB image, cropped by ``padding`` and zoomed.

    Raises:
        ValueError: If ``padding`` is negative.
    """
    if padding < 0:
        raise ValueError(f"padding must be non-negative, got {padding}")

    rgb = img[..., :3]
    alpha = img[..., 3:]

    white = np.ones(rgb.shape)
    bg = percent_bg * rgb + (1 - percent_bg) * white
    # NOTE(review): sigma 1 on the channel axis also mixes R/G/B together;
    # kept to preserve the original look — confirm it is intended.
    blurred_bg = gaussian_filter(bg, sigma=[5, 5, 1])
    composite = blurred_bg * (1 - alpha) + rgb * alpha

    # `x[p:-p]` is empty when p == 0, so use None as the end index in that case.
    end = -padding if padding else None
    cropped = composite[padding:end, padding:end, :]
    return nd.zoom(cropped, order=2, zoom=[zoom, zoom, 1])
def _neuron_transforms(size):
    """Build the augmentation pipeline applied to the image at each optimization step."""
    # Scale factors: a broad band around 1.0, plus a narrower shrink band that
    # is listed twice (weighting it more if random_scale samples uniformly).
    scales = [0.995 ** n for n in range(-5, 80)] + [
        0.998 ** n for n in 2 * list(range(20, 40))
    ]
    # Rotation angles: overlapping ranges plus extra zeros bias toward small rotations.
    angles = list(range(-20, 20)) + list(range(-10, 10)) + list(range(-5, 5)) + 5 * [0]
    return [
        transform.jitter(4),
        transform.jitter(4),
        transform.jitter(8),
        transform.jitter(8),
        transform.jitter(8),
        transform.random_scale(scales),
        transform.random_rotate(angles),
        transform.jitter(2),
        transform.crop_or_pad_to(size, size),
        transform.collapse_alpha_random(),
    ]


def render_neuron(model_name, layer, channel, *, size=90, thresholds=(1024,), model=None):
    """Render a transparent-background feature visualization for one neuron.

    Optimizes an RGBA image parameterization. The objective is
    ``objectives.neuron(layer, channel)`` minus a per-step alpha-blur
    penalty, multiplied by ``mean_alpha()``. The final image is then
    post-processed with ``pretty_alpha`` and displayed with ``show``.

    Args:
        model_name: Currently unused; kept for call compatibility. The model
            rendered is ``model`` if given, otherwise InceptionV1.
        layer: Layer name passed to ``objectives.neuron``.
        channel: Channel index passed to ``objectives.neuron``.
        size: Side length in pixels of the square image (default 90).
        thresholds: Optimization steps at which ``render_vis`` captures an
            image; only the last capture is used.
        model: Optional already-loaded model to reuse across calls. When
            None, a fresh InceptionV1 is created and its graph loaded.

    Returns:
        The image returned by ``pretty_alpha`` for the final capture.
    """
    if model is None:
        model = models.InceptionV1()
        # The GraphDef must be loaded before render_vis can import the model.
        model.load_graphdef()

    obj = objectives.neuron(layer, channel)
    obj -= 1e2 * objectives.blur_alpha_each_step()
    obj *= mean_alpha()

    images = render.render_vis(
        model,
        obj,
        param_f=lambda: param.image(size, alpha=True),
        transforms=_neuron_transforms(size),
        thresholds=list(thresholds),
    )
    # Last threshold's capture, first (only) item of the batch.
    result = pretty_alpha(images[-1][0])
    show(result)
    return result
@fredhohman
Copy link

Thanks @ncammarata! Everything works great. The only change I needed to run it as-is was to load the graph (likely just an omission in your gist): `model.load_graphdef()`

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment