上述工作完成后,利用一行代码即可实现网络的训练。
train_history=model.fit(x=x_Train_normalize,
y=y_Train_OneHot, validation_split=0.2,
shuffle=True, epochs=10,batch_size=200,verbose=1)
代码中,fit函数实现网络拟合,该函数的前两个参数分别为训练网络的样本和样本标签。validation_spli设置训练与验证数据比例,0.2表示80%数据用以训练,20%数据作为验证集用以测试。shuffle表示是否在训练过程中随机打乱输入样本的顺序,epochs=10表示执行10个训练周期,batch_size=200表示每一个训练周期网络一次读取200个样本数据进行训练。verbose参数用以控制训练过程的显示状态,verbose=0为不输出日志信息;verbose=1为输出进度条记录;verbose=2为每个epoch输出一行记录,该参数默认为1。代码运行后,会在迭代过程中显示训练状态,结果如图4-22所示。
图4-22 训练过程
上图显示的是网络迭代10次的训练效果,训练过程中显示了两组数字:一个是网络在训练数据上的损失(loss)和精度(acc)。另一个是网络在验证数据上的损失(val_loss)和精度(val_acc),最终在训练集中的准确率达到99%以上,在验证集上的准确率达到97%以上。
为更加直观地显示网络训练过程中损失函数和准确率等各项指标的变化,可采用以下代码绘制训练曲线(图4-23)。
#显示训练过程
import matplotlib. pyplot as plt
def show_train_history(train_history, train, validation):
plt. plot(train_history.history[train])(www.xing528.com)
plt. plot(train_history.history[validation])
plt. title('Train History')
plt. ylabel(train)
plt. xlabel('Epoch')
plt. legend(['train','validation'],loc='upper left')#显示左上角标签
plt. show()
show_train_history(train_history,'acc','val_acc')#画出准确率评估结果
show_train_history(train_history,'loss','val_loss')#画出误差执行结果
图4-23 训练历史曲线
在上述实验中可以看到train loss和validation loss两个性能指标,本例中是利用交叉熵损失函数计算得到。从字面理解二者分别代表训练误差和验证误差,前者指网络在训练数据集上计算得到的误差,后者代表模型在验证数据集上表现的误差。设置验证集可以及时评价网络的泛化能力,具体原因在下一节介绍。
免责声明:以上内容源自网络,版权归原作者所有,如有侵犯您的原创版权请告知,我们将尽快删除相关内容。