【例11.3】基于H2O的手写数字识别。
(1)加载H2O包
>install.packages("h2o",repos=(c("http://s3.amazonaws.com/h2o-release/h2o/rel-kahan/5/R",getOption("repos"))))
>library(h2o)
载入需要的R:rjson、statmod和tools。
(2)启动H2O
启动H2O获取连接对象′localH2O′:
>localH2O=h2o.init(ip="localhost",port=54321,startH2O=TRUE,Xmx=′1g′)
为了停止H2O,需执行:
<h2o.shutdown(localH2O)
H2O启动后,就可以使用http://localhost:54321。
(3)数据准备
下载训练集:http://www.pjreddie.com/media/files/mnist_train.csv;
下载测试集:http://www.pjreddie.com/media/files/mnist_test.csv。
(4)建模
训练模型要很长一段时间,最后一行有相应的进度条可查看。
>model<-h2o.deeplearning(x=2:785, #输入变量个数
y=1, #响应变量个数
data=train_h2o,
activation="Tanh",
balance_classes=TRUE,
hidden=c(100,100,100), ##3层隐藏层
epochs=100)
输出模型结果:
<model
IP Address:localhost
Port:54321
Parsed Data Key:mnist_train.hex
Deep Learning Model Key:DeepLearning_9c7831f93efb58b38c3fa08cb17d4e4e(www.xing528.com)
Training classification error:0
Training mean square error:Inf
Validation classification error:0
Validation square error:Inf
Confusion matrix:
Reported onmnist_train.hex
(5)模型评估
>yhat_train<-h2o.predict(model,train_h2o)$predict
>yhat_train<-as.factor(as.matrix(yhat_train))
>yhat_test<-h2o.predict(model,test_h2o)$predict
>yhat_test<-as.factor(as.matrix(yhat_test))
查看前100条预测与实际的数据相比较
<y_test[1:100]
[1]721041495906901597349665407401313472712117423
512446355604195789374
[67]6430702917329776278473613693141769
Levels:0123456789
>yhat_test[1:100]
[1]721041894906901597349665407401313472712117423
512446355604195789374
[67]6430702917329776278473613693141769
Levels:0123456789
查看并保存结果:
>library(caret)
>res[1,1]<-round(h2o.confusionMatrix(yhat_train,y_train)$overall[1],4)
>res[1,2]<-round(h2o.confusionMatrix(yhat_test,y_test)$overall[1],4)
>print(res)
免责声明:以上内容源自网络,版权归原作者所有,如有侵犯您的原创版权请告知,我们将尽快删除相关内容。