Define the training samples X and Y for a single-hidden-layer feedforward network, together with input and output placeholders x and y. The hidden layer's parameters are w1 and b1, with ReLU as its activation; the output layer's parameters are w2 and b2, with sigmoid as its activation. The model's output is pred, the loss function cost is the mean squared error, and training uses the Adam optimizer. The training data is the XOR truth table.
import tensorflow as tf
# Input training data: a Python list here, but a NumPy ndarray also works
x_data = [[1., 0.], [0., 1.], [0., 0.], [1., 1.]]
x = tf.placeholder(tf.float32, shape=[None, 2])  # placeholders must be fed with data when the graph is run
y_data = [[1], [1], [0], [0]]  # training labels; note the dimensions
y = tf.placeholder(tf.float32, shape=[None, 1])
# Define the variables; they are updated and saved during the graph run according to the optimization objective
weights = {'w1': tf.Variable(tf.random_normal([2, 16])),
           'w2': tf.Variable(tf.random_normal([16, 1]))}
bias = {'b1': tf.Variable(tf.zeros([1])),
        'b2': tf.Variable(tf.zeros([1]))}  # b1 and b2 start at 0; random normal initialization also works
# Define the network's computation graph
def nn(x, weights, bias):
    d1 = tf.matmul(x, weights['w1']) + bias['b1']   # hidden layer: affine transform
    d1 = tf.nn.relu(d1)                             # hidden activation: ReLU
    d2 = tf.matmul(d1, weights['w2']) + bias['b2']  # output layer: affine transform
    d2 = tf.nn.sigmoid(d2)                          # output activation: sigmoid
    return d2
pred = nn(x, weights, bias)  # predicted values
cost = tf.reduce_mean(tf.square(y - pred))  # mean squared error loss
learning_rate = 0.01  # learning rate
# Define the tf.train optimizer used for training
# train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost) ## max_step: 20000, loss: 0.002638
train_step = tf.train.AdamOptimizer(learning_rate).minimize(cost) ## max_step: 2000, loss: 0.000014
init = tf.global_variables_initializer()  # all variables must be initialized before the graph runs
# Run the graph
with tf.Session() as sess:
    sess.run(init)
    max_step = 500
    for i in range(max_step + 1):
        sess.run(train_step, feed_dict={x: x_data, y: y_data})
        loss = sess.run(cost, feed_dict={x: x_data, y: y_data})
        if i % 100 == 0:
            print('step: ' + str(i) + ' loss:' + "{:.6f}".format(loss))  # + ' accuracy:' + "{:.6f}".format(acc))
    print(sess.run(pred, feed_dict={x: x_data}))
    print('end')
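
The print statement in the loop references a commented-out accuracy term but never defines acc. A minimal sketch of how that accuracy could be computed, assuming the sigmoid output is thresholded at 0.5 (the names correct, accuracy, and acc are illustrative, not part of the original gist):

# Build the accuracy op next to the cost op, before the session starts
correct = tf.equal(tf.round(pred), y)                    # rounded prediction matches the label
accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))  # fraction of correct predictions

# Then, inside the training loop, evaluate it alongside the loss:
# acc = sess.run(accuracy, feed_dict={x: x_data, y: y_data})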