Skip to content

Instantly share code, notes, and snippets.

@geektoni
Last active April 19, 2019 13:44
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save geektoni/9048940fafa151dcb72918472e414c2b to your computer and use it in GitHub Desktop.
Save geektoni/9048940fafa151dcb72918472e414c2b to your computer and use it in GitHub Desktop.
import numpy as np
from matplotlib import animation
from matplotlib import pyplot as plt
from shogun import csv_file, features, labels, machine, parameter_observer
# Load the 2-D binary toy dataset: train/test splits of features and labels.
f_feats_train = csv_file("classifier_binary_2d_linear_features_train.dat")
f_feats_test = csv_file("classifier_binary_2d_linear_features_test.dat")
f_labels_train = csv_file("classifier_binary_2d_linear_labels_train.dat")
f_labels_test = csv_file("classifier_binary_2d_linear_labels_test.dat")
features_train = features(f_feats_train)
features_test = features(f_feats_test)
labels_train = labels(f_labels_train)
labels_test = labels(f_labels_test)

# Train an averaged perceptron with a ParameterObserverLogger subscribed, so
# the observations it records can be replayed frame by frame below.
perceptron = machine("AveragedPerceptron", labels=labels_train, learn_rate=1.0, max_iterations=1000)
observer = parameter_observer("ParameterObserverLogger")
perceptron.subscribe(observer)
perceptron.train(features_train)
labels_predict = perceptron.apply(features_test)

# Number of animation frames: one per recorded observation. animate(i) calls
# observer.get_observation(i), so frame indices must stay below this count.
# (Previously f_size was used for the frame count but never defined.)
# NOTE(review): relies on ParameterObserver.get_num_observations() — confirm
# against the installed shogun version.
f_size = observer.get_num_observations()

# Figure with a single axes and an initially empty line artist that the
# animation callbacks update.
fig = plt.figure()
ax = plt.axes()
line, = ax.plot([], [])
# Draw the static background for the animation.
def init():
    """Plot every training point and reset the boundary line to empty.

    Used as FuncAnimation's ``init_func``: it is called before the first
    frame and, because blitting is enabled, must return an iterable of the
    artists that later frames will update.

    Returns:
        A 1-tuple containing the (now empty) boundary line artist.
    """
    # The previous version looped over range(f_size), which is the frame
    # count, not the number of training points, and issued one plot call per
    # point. Fetch the feature matrix once and plot all points in one call.
    # NOTE(review): assumes shogun's dense feature matrix is laid out as
    # (num_features, num_vectors) with num_features == 2, i.e. row 0 holds
    # the x-coordinates and row 1 the y-coordinates — confirm.
    points = features_train.get("feature_matrix")
    ax.plot(points[0], points[1], 'ro')
    line.set_data([], [])
    return line,
def _decision_boundary(weights, xs, bias=0.0):
    """Return the y-values of the line ``w[0]*x + w[1]*y + bias = 0`` at *xs*.

    Args:
        weights: A 2-component weight vector (anything ``np.asarray`` accepts).
        xs: The x-coordinates at which to evaluate the boundary.
        bias: The classifier's bias term; defaults to 0.

    Returns:
        A float array shaped like *xs*. When ``w[1] == 0`` the boundary is
        vertical and cannot be written as y(x), so an all-NaN array is
        returned; matplotlib draws nothing for NaN points.

    Raises:
        ValueError: If *weights* does not have exactly two components.
    """
    w = np.asarray(weights, dtype=float).ravel()
    xs = np.asarray(xs, dtype=float)
    if w.size != 2:
        raise ValueError(f"expected a 2-component weight vector, got {w.size}")
    if w[1] == 0:
        return np.full_like(xs, np.nan)
    return -(w[0] * xs + bias) / w[1]


# Draw the decision boundary recorded in observation i.
def animate(i):
    """Update the line artist to the decision boundary of observation *i*.

    Args:
        i: Frame index supplied by FuncAnimation, used directly as the
            observation index for ``observer.get_observation``.

    Returns:
        A 1-tuple containing the updated line artist (required for blitting).
    """
    obs = observer.get_observation(i)
    weights = obs.get("weights")
    x = np.linspace(-10, 10, 1000)
    # The previous version used `weights * x`, which multiplies a
    # 2-component vector by a 1000-point grid. Those shapes do not broadcast,
    # so it raised ValueError; compute the actual separating line instead.
    # NOTE(review): the observation is not known to carry a bias, so the
    # line is drawn through the origin (bias=0) — confirm.
    line.set_data(x, _decision_boundary(weights, x))
    return line,
# Build the animation: init() draws the background, then animate(i) is called
# for each of the f_size frames with a 10 ms delay between frames; blitting
# redraws only the artists those callbacks return.
anim = animation.FuncAnimation(
    fig,
    animate,
    frames=f_size,
    init_func=init,
    interval=10,
    blit=True,
)
# Write the result to an MP4 at 30 fps. extra_args are handed to the external
# movie writer (ffmpeg by default) to request the libx264 codec.
anim.save(
    'basic_animation.mp4',
    fps=30,
    extra_args=['-vcodec', 'libx264'],
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment