Caffe實(shí)踐C++源碼解讀:走入Solver
在理解如何使caffe運(yùn)行之后,我們要理解它是如何運(yùn)行的,即了解Solver類的Solve()函數(shù)做了什么,對(duì)于Solver類中如何初始化網(wǎng)絡(luò)以及其他參數(shù),有興趣的可以深入研究。
源碼中Solver()函數(shù)是有參數(shù)形式的
??//?The?main?entry?of?the?solver?function.?In?default,?iter?will?be?zero.?Pass ??//?in?a?non-zero?iter?number?to?resume?training?for?a?pre-trained?net. ??virtual?void?Solve(const?char*?resume_file?=?NULL); ??inline?void?Solve(const?string?resume_file)?{?Solve(resume_file.c_str());?}
各位一看就明白了吧。再看Solve函數(shù)的定義
templatevoid?Solver::Solve(const?char*?resume_file)?{ ??CHECK(Caffe::root_solver()); ??LOG(INFO)?<<?"Solving?"?<<?net_->name(); ??LOG(INFO)?<<?"Learning?Rate?Policy:?"?<<?param_.lr_policy(); ??//?Initialize?to?false?every?time?we?start?solving. ??requested_early_exit_?=?false; ??if?(resume_file)?{ ????LOG(INFO)?<<?"Restoring?previous?solver?status?from?"?<<?resume_file; ????Restore(resume_file); ??}
傳入?yún)?shù)resume_file是用于繼續(xù)中斷的訓(xùn)練的,既然caffe.cpp中的train()函數(shù)中已經(jīng)進(jìn)行了此操作,此處就不需要再傳入resume_file參數(shù)了。
然后Solver就直接進(jìn)入訓(xùn)練模式了,即Step函數(shù),傳入?yún)?shù)為循環(huán)的次數(shù),此參數(shù)在solver.txt文件中定義的max_iter和resume_file加載的iter_參數(shù)的差。
??//?For?a?network?that?is?trained?by?the?solver,?no?bottom?or?top?vecs ??//?should?be?given,?and?we?will?just?provide?dummy?vecs. ??int?start_iter?=?iter_; ??Step(param_.max_iter()?-?iter_);
在進(jìn)入Step函數(shù)之前,我們繼續(xù)往下看,訓(xùn)練完成后caffe會(huì)保存當(dāng)前模型
??//?If?we?haven't?already,?save?a?snapshot?after?optimization,?unless ??//?overridden?by?setting?snapshot_after_train?:=?false ??if?(param_.snapshot_after_train() ??????&&?(!param_.snapshot()?||?iter_?%?param_.snapshot()?!=?0))?{ ????Snapshot(); ??}
如果solver.txt中提供了test網(wǎng)絡(luò),那么會(huì)在訓(xùn)練完成后進(jìn)行一次測試
??//?After?the?optimization?is?done,?run?an?additional?train?and?test?pass?to ??//?display?the?train?and?test?loss/outputs?if?appropriate?(based?on?the ??//?display?and?test_interval?settings,?respectively).??Unlike?in?the?rest?of ??//?training,?for?the?train?net?we?only?run?a?forward?pass?as?we've?already ??//?updated?the?parameters?"max_iter"?times?--?this?final?pass?is?only?done?to ??//?display?the?loss,?which?is?computed?in?the?forward?pass. ??if?(param_.display()?&&?iter_?%?param_.display()?==?0)?{ ????int?average_loss?=?this->param_.average_loss(); ????Dtype?loss; ????net_->Forward(&loss); ????UpdateSmoothedLoss(loss,?start_iter,?average_loss); ????LOG(INFO)?<<?"Iteration?"?<<?iter_?<<?",?loss?=?"?<<?smoothed_loss_; ??} ??if?(param_.test_interval()?&&?iter_?%?param_.test_interval()?==?0)?{ ????TestAll(); ??}
在Step函數(shù)中通過while循環(huán)迭代訓(xùn)練,并且如果設(shè)置有測試網(wǎng)絡(luò),在設(shè)置條件滿足時(shí),每次循環(huán)會(huì)先對(duì)當(dāng)前網(wǎng)絡(luò)進(jìn)行測試
??while?(iter_?<?stop_iter)?{ ????//?zero-init?the?params ????net_->ClearParamDiffs(); ????if?(param_.test_interval()?&&?iter_?%?param_.test_interval()?==?0 ????????&&?(iter_?>?0?||?param_.test_initialization()))?{ ??????if?(Caffe::root_solver())?{ ????????TestAll(); ??????}
測試完成后,如何沒有終止訓(xùn)練,將繼續(xù)訓(xùn)練,此處的iter_size默認(rèn)值是1,主要作用是SGD中參數(shù)更新頻率,即訓(xùn)練iter_size后更新網(wǎng)絡(luò),此時(shí)訓(xùn)練的總樣本數(shù)為train.txt中定義的batch_size * iter_size。
????for?(int?i?=?0;?i?<?param_.iter_size();?++i)?{ ??????loss?+=?net_->ForwardBackward(); ????}
之后調(diào)用ApplyUpdate();更新權(quán)值和偏置,更新方法后續(xù)再聊。
Step中的測試與caffe.cpp中的test類似,主要是檢測當(dāng)前網(wǎng)絡(luò)訓(xùn)練狀態(tài),可以根據(jù)任務(wù)狀態(tài)提前終止訓(xùn)練,比如測試的損失函數(shù)達(dá)到一定范圍。