obal_variables_initializer(),
tf.local_variables_initializer()
))
# 初始化训练数据的迭代器。
sess.run(iterator.initializer)
# 循环进行训练,知道数据集完成输入,抛出OutOfRangeError错误
while True:
try:
sess.run(train_step)
except tf.errors.OutOfRangeError:
break
# 初始化测试数据的迭代器
sess.run(test_iterator.initializer)
# 获取预测结果
test_results = []
test_labels = []
while True:
try:
pred, label = sess.run([predictions, test_label_batch])
test_results.extend(pred)
test_labels.extend(label)
except tf.errors.OutOfRangeError:
break
# 计算准确率
correct = [float(y == y_) for (y, y_) in zip(test_results, test_labels)]
accuracy = sum(correct) / len(correct)
print("Test accuracy is: ", accuracy)
|