Skip to content

Instantly share code, notes, and snippets.

@VoVAllen
Last active December 3, 2018 17:11
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save VoVAllen/59b344a2e329222c51f68f86b5115a32 to your computer and use it in GitHub Desktop.
Save VoVAllen/59b344a2e329222c51f68f86b5115a32 to your computer and use it in GitHub Desktop.
Draw att
# This part is for Jupyter notebook settings (if you want to save, don't use this)
# %matplotlib inline
# %config InlineBackend.figure_format = 'svg'
# import numpy as np
# import matplotlib.pyplot as plt
# plt.rcParams["animation.html"] = "jshtml"
import matplotlib.animation as animation
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
from networkx.algorithms import bipartite
# Placeholder labels for the two columns of the diagram: ten identical
# left-hand labels and fifteen identical right-hand labels.
M = ['aasdjkk ' for _ in range(10)]
N = [' b' for _ in range(15)]
# Random demo weight matrix with one row per M label and one column per N label.
weight = np.random.randn(len(M), len(N))
def att(M, N, weight, ax):
    """Draw a two-column bipartite diagram whose edge widths follow ``weight``.

    The labels in ``M`` are drawn as red nodes in a left-hand column and the
    labels in ``N`` as blue nodes in a right-hand column. Every left node is
    connected to every right node, and the edge between ``M[i]`` and ``N[j]``
    is drawn with line width ``weight[i, j] * 1.5``.

    Args:
        M: Sequence of labels for the left-hand column.
        N: Sequence of labels for the right-hand column.
        weight: Array-like exposing ``.shape`` and ``weight[i, j]`` indexing,
            whose first two dimensions are ``(len(M), len(N))``.
        ax: Matplotlib axes to draw on.

    Raises:
        ValueError: If the shape of ``weight`` does not match
            ``(len(M), len(N))``.
    """
    in_nodes = len(M)
    out_nodes = len(N)
    # Validate both dimensions explicitly. ``assert a, b`` would only check
    # ``a`` and use ``b`` as the failure message.
    if tuple(weight.shape[:2]) != (in_nodes, out_nodes):
        raise ValueError(
            "weight must have shape (%d, %d), got %s"
            % (in_nodes, out_nodes, tuple(weight.shape))
        )

    # Nodes 0..in_nodes-1 form the left set and
    # in_nodes..in_nodes+out_nodes-1 form the right set.
    g = nx.complete_bipartite_graph(in_nodes, out_nodes)
    left_ids = range(in_nodes)
    right_ids = range(in_nodes, in_nodes + out_nodes)

    # The right column is 80% as tall as the left one and vertically centred
    # against it.
    height_in = 10
    height_out = height_in * 0.8
    offset = (height_in - height_out) / 2
    height_in_y = np.linspace(0, height_in, in_nodes)
    height_out_y = np.linspace(offset, offset + height_out, out_nodes)

    # Assign positions by node id rather than by iterating sets, so that
    # label i always lands on the i-th slot of its column.
    pos = {}
    pos.update((n, (1, y)) for n, y in zip(left_ids, height_in_y))  # left column at x=1
    pos.update((n, (3, y)) for n, y in zip(right_ids, height_out_y))  # right column at x=3

    ax.axis('off')
    ax.set_xlim(-1, 4)
    ax.set_title("ATT")

    nx.draw_networkx_nodes(g, pos, nodelist=left_ids, node_color='r', node_size=50, ax=ax)
    nx.draw_networkx_nodes(g, pos, nodelist=right_ids, node_color='b', node_size=50, ax=ax)

    # Draw all edges in a single call with one width per edge.
    # NOTE(review): widths are used as-is, so negative weights yield negative
    # line widths -- confirm whether callers pass non-negative weights.
    edges = list(g.edges())
    widths = [weight[u, v - in_nodes] * 1.5 for u, v in edges]
    nx.draw_networkx_edges(g, pos, edgelist=edges, width=widths, ax=ax)

    nx.draw_networkx_labels(
        g, pos, {i: label for i, label in enumerate(M)},
        horizontalalignment='right', font_size=8, ax=ax,
    )
    nx.draw_networkx_labels(
        g, pos, {i + in_nodes: label for i, label in enumerate(N)},
        horizontalalignment='left', font_size=8, ax=ax,
    )
# Create the figure and the single axes that every animation frame redraws on.
# Requires ``matplotlib.pyplot`` to be imported as ``plt`` at the top of the file.
fig = plt.figure(figsize=(2, 3), dpi=150)
fig.clf()
ax = fig.subplots()
# One random weight matrix per animation frame, shaped to match the label
# lists so the frames stay valid if M or N change length.
weight_list = [np.random.randn(len(M), len(N)) for _ in range(10)]
def weight_animate(i):
    """Redraw the diagram for animation frame ``i``.

    Clears the shared axes, prints the frame index, and draws the diagram
    using the ``i``-th matrix in ``weight_list``.
    """
    # Start from a blank, axis-free canvas for this frame.
    ax.cla()
    ax.axis("off")
    print(i)
    frame_weight = weight_list[i]
    att(M, N, frame_weight, ax)
# Build the animation with one frame per weight matrix and 500 ms between
# frames. The frame count is taken from weight_list so the two cannot drift
# apart and index past the end of the list.
ani = animation.FuncAnimation(
    fig,
    weight_animate,
    frames=len(weight_list),
    interval=500,
)
# Bare expression: displays the animation when this is the last line of a
# notebook cell.
ani
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment