了解了迁移学习的功能之后,下面就利用VGG16这个耳熟能详的卷积神经网络实现对CIFAR-10的图像分类。VGG16是针对ImageNet数据集中1000类图像的分类模型,当时训练成本和训练时间都非常之高,现在,Keras框架中提供已经训练好的VGG16模型,这里可以直接采用“拿来主义”,利用VGG16完成对类似的数据集CIFAR-10的分类任务。
首先通过VGG16获取一个卷积模块,代码如下:
from keras. applications.vgg16 import VGG16
model_vgg=VGG16(include_top=False, weights='imagenet',input_shape=(ishape, ishape,3))
#include_top=False表示将vgg16顶层去掉,只保留网络结构
for layers in model_vgg. layers:
layers. trainable=False
#layers. trainable=False将不需要重新训练的权重“冷冻”起来
这里将卷积模块的每一层都给锁住,原因在于不希望仅仅识别几何特征的开始几个层被再次训练。
接下来利用Flatten()函数展开VGG最后的卷积层,并和全连接网络搭建的分类器组成一个完整的网络。
model=Flatten()(model_vgg.output)
model=Dense(1024,activation='relu',name='fc1')(model)
model=Dense(1024,activation='relu',name='fc2')(model)
model=Dropout(0.5)(model)
model=Dense(10,activation='softmax',name='prediction')(model)
model_vgg_cifar10_pretrain=Model(inputs=model_vgg.input, outputs=model, name='vgg16_pretrain')
model_vgg_cifar10_pretrain. summary()
sgd=SGD(lr=0.05,decay=1e-5)
model_vgg_cifar10_pretrain. compile(optimizer=sgd, loss='categorical_cros-sentropy',metrics=['accuracy'])
由于篇幅原因,这里只列出网络结构的最后部分,可以看到有大量参数不再参与训练,如图5-25所示。
图5-25 网络参数图
当然,同MNIST数据一样,还少不了数据载入和预处理部分。
(X_train, y_train),(X_test, y_test)=cifar10.load_data()
X_train=[cv2.resize(i,(ishape, ishape))for i in X_train]
X_test=[cv2.resize(i,(ishape, ishape))for i in X_test]
X_train=np.concatenate([arr[np.newaxis]for arr in X_train]).astype('float32')
X_test=np.concatenate([arr[np.newaxis]for arr in X_test]).astype('float32')
#预处理
print(X_train[0]. shape)
print(y_train[0])
X_train=X_train/255
X_test=X_test/255
np. where(X_train[0]!=0)
def train_y(y):(www.xing528.com)
y_one=np.zeros(10)
y_one[y]=1
return y_one
y_train_one=np.array([train_y(y_train[i])for i in range(len(y_train))])
y_test_one=np.array([train_y(y_test[i])for i in range(len(y_test))])
最后是模型训练。
model_vgg_cifar10_pretrain. fit(X_train, y_train_one, validation_data=(X_test, y_test_one),epochs=50,batch_size=128)
model_vgg_cifar10_pretrain. save('cifar10.h5')
训练完毕,读者可以自行准备一张彩色图像,利用以下代码进行测试。
model=load_model('cifar10.h5')
class MainPredictImg(object):
def__init__(self):
pass
def pre(self, filename):
pred_image=processimage.imread(filename)
pred_iamge=np.array(pred_iamge)
pred_iamge=scipy.misc.imresize(pred_iamge, size=(64,64))
pred_iamge=pred_iamge.reshape(-1,64,64,3)
prediction=model.predict(pred_iamge)#predict
labels=['airplane','automobile','bird','cat','deer','dog','frog','horse','
ship','truck']
Final_prediction=[result.argmax()for result in prediction][0]
Final_prediction=labels[Final_prediction]
a=0
for i in prediction[0]:
print labels[a]
print'Percent:{:. 30%}'.format(i)
a=a+1
return Final_prediction
Predict=MainPredictImg()
res=Predict.pre('airplant.jpg')
print'your picture is:——>',res
免责声明:以上内容源自网络,版权归原作者所有,如有侵犯您的原创版权请告知,我们将尽快删除相关内容。