Skip to content

Instantly share code, notes, and snippets.

@turtlesoupy
Created August 17, 2020 19:39
Show Gist options
  • Save turtlesoupy/6bf5700b465578a43681ca6637ea581b to your computer and use it in GitHub Desktop.
Save turtlesoupy/6bf5700b465578a43681ca6637ea581b to your computer and use it in GitHub Desktop.
import ipywidgets as widgets
# Integer sliders span [-slider_range, slider_range].
slider_range = 50
# Multiplier applied to each integer slider tick when building the edit
# direction in the change handler.
slider_step = 0.1
# Maps each slider widget to the component index it controls.
slider_to_idx = dict()
# Slider widgets, in creation order.
sliders = list()
# One image widget per sample in the current `tensors` batch
# (indexed as tensors[k, :, :, :], so a 4-D batch is assumed).
img_widgets = [
    image_widget_from_pil_image(show_image(tensors[k, :, :, :]))
    for k in range(tensors.size(0))
]
def on_value_change(change):
    """Regenerate every preview image from the current positions of all sliders.

    Registered via ``observe(..., names='value')`` on each slider. The payload
    in ``change`` is not consulted: the edit direction is rebuilt from the
    values of *all* sliders, so moving one slider keeps the others' offsets
    in effect.

    Args:
        change: traitlets change dict supplied by ``observe``; only used as
            the trigger.
    """
    # Any output or exception raised here is captured by the output2 widget.
    with output2:
        V, stdev, _ = estimator.get_components()
        # Per-component weight = slider ticks * stdev[component] * slider_step.
        # Index both the weight vector and stdev by the slider's registered
        # component number so the two can never drift apart.
        weights = np.zeros((1, n_components))
        for slider in sliders:
            idx = slider_to_idx[slider]
            weights[0, idx] = slider.value * stdev[idx] * slider_step
        # Project the weights through V to get a single offset row.
        # NOTE(review): assumes V is (n_components, w_dim) — confirm.
        direction = torch.tensor(
            np.matmul(weights, V), device=device, dtype=torch.float
        )
        # Tile the offset once per generator layer (num_layers rows).
        direction = direction.repeat(trainer.GAN.G.num_layers, 1)
        samples = generate_from_style_w(
            base_w_vectors + direction, trainer.GAN.GE, n
        )
        # Push each regenerated sample into its matching preview widget.
        for i, widget in enumerate(img_widgets):
            widget.value = utils.pil_to_bytes(show_image(samples[i, :, :, :]))
# Output widget that captures anything printed or raised by the change handler.
output2 = widgets.Output()
# One integer slider per component direction, capped at n_components: the
# change handler sizes its weight vector by n_components, so a slider mapped
# to a component index >= n_components would raise IndexError when moved.
for component_num in range(min(14, n_components)):
    int_range = widgets.IntSlider(
        min=-slider_range,
        max=slider_range,
        description=f"Dir {component_num}",
    )
    int_range.observe(on_value_change, names='value')
    slider_to_idx[int_range] = component_num
    sliders.append(int_range)
# Lay the sliders out as two side-by-side columns.
sliders_display = widgets.HBox((
    widgets.VBox(sliders[:len(sliders) // 2]),
    widgets.VBox(sliders[len(sliders) // 2:])
))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment