Skip to content

Instantly share code, notes, and snippets.

@FisherKK
Created July 10, 2018 22:42
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save FisherKK/fcd05b0eb3a3d12a680f03c68c5fdb40 to your computer and use it in GitHub Desktop.
Save FisherKK/fcd05b0eb3a3d12a680f03c68c5fdb40 to your computer and use it in GitHub Desktop.
# Visualise squared prediction errors of a deliberately mis-fitted linear
# model. The expected points follow y = 2x, while the model uses w=5.0, b=0.0.
# Each residual is drawn as a square whose side is |predicted - expected|, so
# its area is that sample's squared error.
# NOTE(review): `np`, `plt` and `predict` are not defined in this snippet; they
# are assumed to come from earlier in the file/notebook - confirm.
X = np.array([[0.0], [1.0], [2.0], [3.0]])
y = np.array([0.0, 2.0, 4.0, 6.0])
model_parameters = {'b': 0.0, 'w': np.array([5.0])}
y_predicted = [predict(x, model_parameters) for x in X]

plt.figure(figsize=(6, 6))
plt.yticks(np.arange(0, 16, 1))
plt.ylabel("y")
plt.xlabel("x")
plt.ylim(-1, 16)
plt.xlim(-1, 16)
# A single xticks call; an earlier xticks(0..3) call was always overridden.
plt.xticks(np.arange(0, 16, 1))
# `visible` is a bool; the old string "on" only worked because it is truthy.
plt.grid(True, linestyle='--', linewidth=1, alpha=0.3)

ax = plt.gca()
for side_name in ("top", "right", "bottom", "left"):
    ax.spines[side_name].set_visible(False)

plt.title("y = wx, [w=5.0]")
plt.scatter(X, y, edgecolor='black', linewidth=1,
            label="expected", s=80, zorder=2)
plt.scatter(X, y_predicted, edgecolor='black', linewidth=1,
            label="predicted", s=80, zorder=2)
plt.plot(X, y_predicted, zorder=1, c="orange")

# Squared-error squares as (x_left, y_bottom, side, fill_alpha).
# Each square starts at an expected point (x_i, y_i). Its side is that
# sample's residual: 3, 6 and 9 for x = 1, 2, 3, assuming predict computes
# w*x + b. The x=0 sample has zero error and gets no square. The fill alpha
# is hand-tuned to grow with the error so larger errors stand out.
error_squares = [
    (1.0, 2.0, 3.0, 0.05),
    (2.0, 4.0, 6.0, 0.2),
    (3.0, 6.0, 9.0, 0.4),
]
for square_idx, (x0, y0, side, fill_alpha) in enumerate(error_squares):
    x1, y1 = x0 + side, y0 + side
    # Dashed outline, drawn in the order left, right, top, bottom.
    edges = [
        ([x0, x0], [y0, y1]),
        ([x1, x1], [y0, y1]),
        ([x0, x1], [y1, y1]),
        ([x0, x1], [y0, y0]),
    ]
    for edge_idx, (xs, ys) in enumerate(edges):
        # Label only the very first edge. matplotlib's legend skips labels that
        # start with an underscore, so "error area" appears once instead of
        # twelve times.
        if square_idx == 0 and edge_idx == 0:
            edge_label = "error area"
        else:
            edge_label = "_nolegend_"
        plt.plot(xs, ys, linewidth=1, linestyle="--", color="red",
                 label=edge_label, alpha=0.8, zorder=1)
    plt.fill([x0, x0, x1, x1], [y0, y1, y1, y0], alpha=fill_alpha, color='red')

# Residual labels d0..d3 at hand-picked positions just right of each square.
# d0 marks the zero-error sample at x=0.
residual_label_positions = [(0.5, -0.25), (4.5, 2.75), (8.5, 4.75), (12.5, 10.25)]
for label_idx, (tx, ty) in enumerate(residual_label_positions):
    plt.text(tx, ty, f"d{label_idx}", color="red", fontsize=12)

plt.legend()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment