Tensorflow數(shù)據(jù)讀取方法
掃描二維碼
隨時(shí)隨地手機(jī)看文章
轉(zhuǎn)展多處都沒有找到詳細(xì)介紹Tensorflow讀取文件的方法
引言
Tensorflow的數(shù)據(jù)讀取有三種方式:
Preloaded data: 預(yù)加載數(shù)據(jù)Feeding: Python產(chǎn)生數(shù)據(jù),再把數(shù)據(jù)喂給后端。Reading from file: 從文件中直接讀取
這三種有讀取方式有什么區(qū)別呢? 我們首先要知道TensorFlow(TF)是怎么樣工作的。
TF的核心是用C++寫的,這樣的好處是運(yùn)行快,缺點(diǎn)是調(diào)用不靈活。而Python恰好相反,所以結(jié)合兩種語言的優(yōu)勢。涉及計(jì)算的核心算子和運(yùn)行框架是用C++寫的,并提供API給Python。Python調(diào)用這些API,設(shè)計(jì)訓(xùn)練模型(Graph),再將設(shè)計(jì)好的Graph給后端去執(zhí)行。簡而言之,Python的角色是Design,C++是Run。
Preload與Feeding Preload
import?tensorflow?as?tf #?設(shè)計(jì)Graph x1?=?tf.constant([2,?3,?4]) x2?=?tf.constant([4,?0,?1]) y?=?tf.add(x1,?x2) #?打開一個(gè)session?-->?計(jì)算y with?tf.Session()?as?sess: ????print?sess.run(y)
在設(shè)計(jì)Graph的時(shí)候,x1和x2就被定義成了兩個(gè)有值的列表,在計(jì)算y的時(shí)候直接取x1和x2的值。
Feeding
import?tensorflow?as?tf #?設(shè)計(jì)Graph x1?=?tf.placeholder(tf.int16) x2?=?tf.placeholder(tf.int16) y?=?tf.add(x1,?x2) #?用Python產(chǎn)生數(shù)據(jù) li1?=?[2,?3,?4] li2?=?[4,?0,?1] #?打開一個(gè)session?-->?喂數(shù)據(jù)?-->?計(jì)算y with?tf.Session()?as?sess: ????print?sess.run(y,?feed_dict={x1:?li1,?x2:?li2})
在這里x1, x2只是占位符,沒有具體的值,那么運(yùn)行的時(shí)候去哪取值呢?這時(shí)候就要用到sess.run()
中的feed_dict
參數(shù),將Python產(chǎn)生的數(shù)據(jù)喂給后端,并計(jì)算y。
兩種方法的區(qū)別
Preload:
將數(shù)據(jù)直接內(nèi)嵌到Graph中,再把Graph傳入Session中運(yùn)行。當(dāng)數(shù)據(jù)量比較大時(shí),Graph的傳輸會(huì)遇到效率問題。
Feeding:
用占位符替代數(shù)據(jù),待運(yùn)行的時(shí)候填充數(shù)據(jù)。
Reading From File
前兩種方法很方便,但是遇到大型數(shù)據(jù)的時(shí)候就會(huì)很吃力,即使是Feeding,中間環(huán)節(jié)的增加也是不小的開銷,比如數(shù)據(jù)類型轉(zhuǎn)換等等。最優(yōu)的方案就是在Graph定義好文件讀取的方法,讓TF自己去從文件中讀取數(shù)據(jù),并解碼成可使用的樣本集。
在上圖中,首先由一個(gè)單線程把文件名堆入隊(duì)列,兩個(gè)Reader同時(shí)從隊(duì)列中取文件名并讀取數(shù)據(jù),Decoder將讀出的數(shù)據(jù)解碼后堆入樣本隊(duì)列,最后單個(gè)或批量取出樣本(圖中沒有展示樣本出列)。我們這里通過三段代碼逐步實(shí)現(xiàn)上圖的數(shù)據(jù)流,這里我們不使用隨機(jī),讓結(jié)果更清晰。
文件準(zhǔn)備
$?echo?-e?"Alpha1,A1nAlpha2,A2nAlpha3,A3"?>?A.csv $?echo?-e?"Bee1,B1nBee2,B2nBee3,B3"?>?B.csv $?echo?-e?"Sea1,C1nSea2,C2nSea3,C3"?>?C.csv $?cat?A.csv Alpha1,A1 Alpha2,A2 Alpha3,A3
單個(gè)Reader,單個(gè)樣本
import?tensorflow?as?tf #?生成一個(gè)先入先出隊(duì)列和一個(gè)QueueRunner filenames?=?['A.csv',?'B.csv',?'C.csv'] filename_queue?=?tf.train.string_input_producer(filenames,?shuffle=False) #?定義Reader reader?=?tf.TextLineReader() key,?value?=?reader.read(filename_queue) #?定義Decoder example,?label?=?tf.decode_csv(value,?record_defaults=[['null'],?['null']]) #?運(yùn)行Graph with?tf.Session()?as?sess: ????coord?=?tf.train.Coordinator()??#創(chuàng)建一個(gè)協(xié)調(diào)器,管理線程 ????threads?=?tf.train.start_queue_runners(coord=coord)??#啟動(dòng)QueueRunner,?此時(shí)文件名隊(duì)列已經(jīng)進(jìn)隊(duì)。 ????for?i?in?range(10): ????????print?example.eval()???#取樣本的時(shí)候,一個(gè)Reader先從文件名隊(duì)列中取出文件名,讀出數(shù)據(jù),Decoder解析后進(jìn)入樣本隊(duì)列。 ????coord.request_stop() ????coord.join(threads) #?outpt Alpha1 Alpha2 Alpha3 Bee1 Bee2 Bee3 Sea1 Sea2 Sea3 Alpha1
單個(gè)Reader,多個(gè)樣本
import?tensorflow?as?tf filenames?=?['A.csv',?'B.csv',?'C.csv'] filename_queue?=?tf.train.string_input_producer(filenames,?shuffle=False) reader?=?tf.TextLineReader() key,?value?=?reader.read(filename_queue) example,?label?=?tf.decode_csv(value,?record_defaults=[['null'],?['null']]) #?使用tf.train.batch()會(huì)多加了一個(gè)樣本隊(duì)列和一個(gè)QueueRunner。Decoder解后數(shù)據(jù)會(huì)進(jìn)入這個(gè)隊(duì)列,再批量出隊(duì)。 #?雖然這里只有一個(gè)Reader,但可以設(shè)置多線程,相應(yīng)增加線程數(shù)會(huì)提高讀取速度,但并不是線程越多越好。 example_batch,?label_batch?=?tf.train.batch( ??????[example,?label],?batch_size=5) with?tf.Session()?as?sess: ????coord?=?tf.train.Coordinator() ????threads?=?tf.train.start_queue_runners(coord=coord) ????for?i?in?range(10): ????????print?example_batch.eval() ????coord.request_stop() ????coord.join(threads) #?output #?['Alpha1'?'Alpha2'?'Alpha3'?'Bee1'?'Bee2'] #?['Bee3'?'Sea1'?'Sea2'?'Sea3'?'Alpha1'] #?['Alpha2'?'Alpha3'?'Bee1'?'Bee2'?'Bee3'] #?['Sea1'?'Sea2'?'Sea3'?'Alpha1'?'Alpha2'] #?['Alpha3'?'Bee1'?'Bee2'?'Bee3'?'Sea1'] #?['Sea2'?'Sea3'?'Alpha1'?'Alpha2'?'Alpha3'] #?['Bee1'?'Bee2'?'Bee3'?'Sea1'?'Sea2'] #?['Sea3'?'Alpha1'?'Alpha2'?'Alpha3'?'Bee1'] #?['Bee2'?'Bee3'?'Sea1'?'Sea2'?'Sea3'] #?['Alpha1'?'Alpha2'?'Alpha3'?'Bee1'?'Bee2']
多Reader,多個(gè)樣本
import?tensorflow?as?tf filenames?=?['A.csv',?'B.csv',?'C.csv'] filename_queue?=?tf.train.string_input_producer(filenames,?shuffle=False) reader?=?tf.TextLineReader() key,?value?=?reader.read(filename_queue) record_defaults?=?[['null'],?['null']] example_list?=?[tf.decode_csv(value,?record_defaults=record_defaults) ??????????????????for?_?in?range(2)]??#?Reader設(shè)置為2 #?使用tf.train.batch_join(),可以使用多個(gè)reader,并行讀取數(shù)據(jù)。每個(gè)Reader使用一個(gè)線程。 example_batch,?label_batch?=?tf.train.batch_join( ??????example_list,?batch_size=5) with?tf.Session()?as?sess: ????coord?=?tf.train.Coordinator() ????threads?=?tf.train.start_queue_runners(coord=coord) ????for?i?in?range(10): ????????print?example_batch.eval() ????coord.request_stop() ????coord.join(threads) ???? #?output #?['Alpha1'?'Alpha2'?'Alpha3'?'Bee1'?'Bee2'] #?['Bee3'?'Sea1'?'Sea2'?'Sea3'?'Alpha1'] #?['Alpha2'?'Alpha3'?'Bee1'?'Bee2'?'Bee3'] #?['Sea1'?'Sea2'?'Sea3'?'Alpha1'?'Alpha2'] #?['Alpha3'?'Bee1'?'Bee2'?'Bee3'?'Sea1'] #?['Sea2'?'Sea3'?'Alpha1'?'Alpha2'?'Alpha3'] #?['Bee1'?'Bee2'?'Bee3'?'Sea1'?'Sea2'] #?['Sea3'?'Alpha1'?'Alpha2'?'Alpha3'?'Bee1'] #?['Bee2'?'Bee3'?'Sea1'?'Sea2'?'Sea3'] #?['Alpha1'?'Alpha2'?'Alpha3'?'Bee1'?'Bee2']
tf.train.batch
與tf.train.shuffle_batch
函數(shù)是單個(gè)Reader讀取,但是可以多線程。tf.train.batch_join
與tf.train.shuffle_batch_join
可設(shè)置多Reader讀取,每個(gè)Reader使用一個(gè)線程。至于兩種方法的效率,單Reader時(shí),2個(gè)線程就達(dá)到了速度的極限。多Reader時(shí),2個(gè)Reader就達(dá)到了極限。所以并不是線程越多越快,甚至更多的線程反而會(huì)使效率下降。
迭代控制
filenames?=?['A.csv',?'B.csv',?'C.csv'] filename_queue?=?tf.train.string_input_producer(filenames,?shuffle=False,?num_epochs=3)??#?num_epoch:?設(shè)置迭代數(shù) reader?=?tf.TextLineReader() key,?value?=?reader.read(filename_queue) record_defaults?=?[['null'],?['null']] example_list?=?[tf.decode_csv(value,?record_defaults=record_defaults) ??????????????????for?_?in?range(2)] example_batch,?label_batch?=?tf.train.batch_join( ??????example_list,?batch_size=5) init_local_op?=?tf.initialize_local_variables() with?tf.Session()?as?sess: ????sess.run(init_local_op)???#?初始化本地變量? ????coord?=?tf.train.Coordinator() ????threads?=?tf.train.start_queue_runners(coord=coord) ????try: ????????while?not?coord.should_stop(): ????????????print?example_batch.eval() ????except?tf.errors.OutOfRangeError: ????????print('Epochs?Complete!') ????finally: ????????coord.request_stop() ????coord.join(threads) ????coord.request_stop() ????coord.join(threads) #?output #?['Alpha1'?'Alpha2'?'Alpha3'?'Bee1'?'Bee2'] #?['Bee3'?'Sea1'?'Sea2'?'Sea3'?'Alpha1'] #?['Alpha2'?'Alpha3'?'Bee1'?'Bee2'?'Bee3'] #?['Sea1'?'Sea2'?'Sea3'?'Alpha1'?'Alpha2'] #?['Alpha3'?'Bee1'?'Bee2'?'Bee3'?'Sea1'] #?Epochs?Complete!
在迭代控制中,記得添加tf.initialize_local_variables()
,官網(wǎng)教程沒有說明,但是如果不初始化,運(yùn)行就會(huì)報(bào)錯(cuò)。