最后可以把测试数据输入网络,测试网络对测试图片的识别准确率。
test_loss, test_acc=model.evaluate(x_Test_normalize, y_Test_OneHot)
print('accuracy',test_acc)输出结果为:
运行结果表明,采用上述方法设计的网络对10000张测试图片的准确率达到97.85%,说明只有215张图片识别错误。具体哪些图片识别错误,可以利用如下代码进行分析。
#定义绘图函数
def plot_image_labels_prediction2(image, lables, prediction, idx, num=10):
fig=plt.gcf()
fig. set_size_inches(12,14)
if num>25:num=25
for i in range(0,num):
ax=plt.subplot(5,5,i+1)
ax. imshow(image[idx[i]],cmap='binary')
title="lable="+str(lables[idx[i]])(www.xing528.com)
if len(prediction)>0:
title+=",predict="+str(prediction[idx[i]])
ax. set_title(title, fontsize=10)
ax. set_xticks([]);ax.set_yticks([])
plt. show()
prediction=model.predict_classes(x_Test)#进行预测
index=np.arange(0,10000)
index_dif=index[prediction!=y_test_label]#查找分类错误的元素索引
plot_image_labels_prediction2(x_test_image, y_test_label, prediction, idx=index_dif, num=10)
图4-24 预测错误的图片
图4-24中列出了10张识别错误的图片,如第一张是把4识别成了9,最后一张是把2识别成了1。可以看出,一些图片依赖人眼也难以正确识别,可见这些错误是情有可原的。相对于传统模式识别方法复杂且识别率不高,本节的神经网络算法不足百行代码就实现了手写数字的识别,由此可见神经网络的强大威力。
免责声明:以上内容源自网络,版权归原作者所有,如有侵犯您的原创版权请告知,我们将尽快删除相关内容。