Last active
July 3, 2019 01:45
-
-
Save xunge/e62250203fc6fe10f63dfd5ccb95d8f5 to your computer and use it in GitHub Desktop.
定义单隐藏层前馈网络模型的训练样本X和Y、定义输入x,输出y,隐藏层参数分别定义为w1和b1,隐藏层的激活函数选取ReLU;输出层参数为w2,b2,输出层激活函数选取sigmoid。定义深度前馈网络模型输出out,定义损失函数为均方差损失函数loss,并且使用Adam算法的Optimizer
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
# 输入训练数据,这里是python的list, 也可以定义为numpy的ndarray | |
x_data = [[1., 0.], [0., 1.], [0., 0.], [1., 1.]] | |
x = tf.placeholder(tf.float32, shape=[None, 2]) # 定义占位符,占位符在运行图的时候必须feed数据 | |
y_data = [[1], [1], [0], [0]] # 训练数据的标签,注意维度 | |
y = tf.placeholder(tf.float32, shape=[None, 1]) | |
# 定义variables,在运行图的过程中会被按照优化目标改变和保存 | |
weights = {'w1': tf.Variable(tf.random_normal([2, 16])), | |
'w2': tf.Variable(tf.random_normal([16, 1]))} | |
bias = {'b1': tf.Variable(tf.zeros([1])), | |
'b2': tf.Variable(tf.zeros([1]))} # b1,b2初始为0,正态化初始也可 | |
# 定义神经网络计算图 | |
def nn(x, weights, bias): | |
d1 = tf.matmul(x, weights['w1']) + bias['b1'] | |
d1 = tf.nn.relu(d1) | |
d2 = tf.matmul(d1, weights['w2']) + bias['b2'] | |
d2 = tf.nn.sigmoid(d2) | |
return d2 | |
pred = nn(x, weights, bias) # 预测值 | |
cost = tf.reduce_mean(tf.square(y - pred)) # 损失函数 | |
learning_rate = 0.01 # 学习率取0.01 | |
# 定义tf.train用来训练 | |
# train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost) ## max_step: 20000, loss: 0.002638 | |
train_step = tf.train.AdamOptimizer(learning_rate).minimize(cost) ## max_step: 2000, loss: 0.000014 | |
init = tf.global_variables_initializer() # 初始化参数,图运行的一开始必须初始化所有变量 | |
# 运行图 | |
with tf.Session() as sess: | |
sess.run(init) | |
max_step = 500 | |
for i in range(max_step + 1): | |
sess.run(train_step, feed_dict={x: x_data, y: y_data}) | |
loss = sess.run(cost, feed_dict={x: x_data, y: y_data}) | |
if i % 100 == 0: | |
print('step: ' + str(i) + ' loss:' + "{:.6f}".format(loss)) # + ' accuracy:' + "{:.6f}".format(acc)) | |
print(sess.run(pred, feed_dict={x: x_data})) | |
print('end') |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment