Skip to content

Instantly share code, notes, and snippets.

@DerekChia
Last active November 14, 2021 23:51
Show Gist options
  • Save DerekChia/24d297cbdfae9a58361244d5d2b75f9a to your computer and use it in GitHub Desktop.
Save DerekChia/24d297cbdfae9a58361244d5d2b75f9a to your computer and use it in GitHub Desktop.
def run(learning_rate=0.1, num_steps=30, plot_path='plot.png'):
    """Train the linear-regression graph with gradient descent and plot the fit.

    Obtains the data from ``generate_dataset()`` and the graph tensors from
    ``linear_regression()``, minimizes ``loss`` by feeding the same batch on
    every step, prints the loss after each step, and finally saves a scatter
    plot of the data with the model's predictions drawn as a red line.

    Calling ``run()`` with no arguments reproduces the original hard-coded
    behaviour (learning rate 0.1, 30 steps, output to ``plot.png``).

    Args:
        learning_rate: Step size passed to ``tf.train.GradientDescentOptimizer``.
        num_steps: Number of training steps to execute.
        plot_path: Destination handed to ``plt.savefig``.
    """
    x_batch, y_batch = generate_dataset()
    # x and y are used as feed keys below, so they are presumably the graph's
    # input placeholders -- confirm against linear_regression().
    x, y, y_pred, loss = linear_regression()

    optimizer = tf.train.GradientDescentOptimizer(learning_rate)
    train_op = optimizer.minimize(loss)

    with tf.Session() as session:
        # Variables must be initialized before the first training step.
        session.run(tf.global_variables_initializer())
        feed_dict = {x: x_batch, y: y_batch}

        for i in range(num_steps):
            session.run(train_op, feed_dict)
            # Fetched in a separate run so the printed loss reflects the
            # weights *after* this step's update.
            print(i, "loss:", session.run(loss, feed_dict))

        print('Predicting')
        # Must happen while the session is still open; only x is needed
        # to evaluate the predictions.
        y_pred_batch = session.run(y_pred, {x: x_batch})

    plt.scatter(x_batch, y_batch)
    plt.plot(x_batch, y_pred_batch, color='red')
    # Fixed axis window of [0, 2] on both axes, as in the original script.
    plt.xlim(0, 2)
    plt.ylim(0, 2)
    plt.savefig(plot_path)
# Run the training/plotting demo only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    run()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment