Tensorflow數(shù)據(jù)讀取方法
轉(zhuǎn)展多處都沒有找到詳細介紹Tensorflow讀取文件的方法
引言
Tensorflow的數(shù)據(jù)讀取有三種方式:
Preloaded data: 預(yù)加載數(shù)據(jù)Feeding: Python產(chǎn)生數(shù)據(jù),再把數(shù)據(jù)喂給后端。Reading from file: 從文件中直接讀取
這三種有讀取方式有什么區(qū)別呢? 我們首先要知道TensorFlow(TF)是怎么樣工作的。
TF的核心是用C++寫的,這樣的好處是運行快,缺點是調(diào)用不靈活。而Python恰好相反,所以結(jié)合兩種語言的優(yōu)勢。涉及計算的核心算子和運行框架是用C++寫的,并提供API給Python。Python調(diào)用這些API,設(shè)計訓(xùn)練模型(Graph),再將設(shè)計好的Graph給后端去執(zhí)行。簡而言之,Python的角色是Design,C++是Run。
Preload與Feeding Preload
import?tensorflow?as?tf #?設(shè)計Graph x1?=?tf.constant([2,?3,?4]) x2?=?tf.constant([4,?0,?1]) y?=?tf.add(x1,?x2) #?打開一個session?-->?計算y with?tf.Session()?as?sess: ????print?sess.run(y)
在設(shè)計Graph的時候,x1和x2就被定義成了兩個有值的列表,在計算y的時候直接取x1和x2的值。
Feeding
import?tensorflow?as?tf #?設(shè)計Graph x1?=?tf.placeholder(tf.int16) x2?=?tf.placeholder(tf.int16) y?=?tf.add(x1,?x2) #?用Python產(chǎn)生數(shù)據(jù) li1?=?[2,?3,?4] li2?=?[4,?0,?1] #?打開一個session?-->?喂數(shù)據(jù)?-->?計算y with?tf.Session()?as?sess: ????print?sess.run(y,?feed_dict={x1:?li1,?x2:?li2})
在這里x1, x2只是占位符,沒有具體的值,那么運行的時候去哪取值呢?這時候就要用到sess.run()
中的feed_dict
參數(shù),將Python產(chǎn)生的數(shù)據(jù)喂給后端,并計算y。
兩種方法的區(qū)別
Preload:
將數(shù)據(jù)直接內(nèi)嵌到Graph中,再把Graph傳入Session中運行。當(dāng)數(shù)據(jù)量比較大時,Graph的傳輸會遇到效率問題。
Feeding:
用占位符替代數(shù)據(jù),待運行的時候填充數(shù)據(jù)。
Reading From File
前兩種方法很方便,但是遇到大型數(shù)據(jù)的時候就會很吃力,即使是Feeding,中間環(huán)節(jié)的增加也是不小的開銷,比如數(shù)據(jù)類型轉(zhuǎn)換等等。最優(yōu)的方案就是在Graph定義好文件讀取的方法,讓TF自己去從文件中讀取數(shù)據(jù),并解碼成可使用的樣本集。
在上圖中,首先由一個單線程把文件名堆入隊列,兩個Reader同時從隊列中取文件名并讀取數(shù)據(jù),Decoder將讀出的數(shù)據(jù)解碼后堆入樣本隊列,最后單個或批量取出樣本(圖中沒有展示樣本出列)。我們這里通過三段代碼逐步實現(xiàn)上圖的數(shù)據(jù)流,這里我們不使用隨機,讓結(jié)果更清晰。
文件準備
$?echo?-e?"Alpha1,A1nAlpha2,A2nAlpha3,A3"?>?A.csv $?echo?-e?"Bee1,B1nBee2,B2nBee3,B3"?>?B.csv $?echo?-e?"Sea1,C1nSea2,C2nSea3,C3"?>?C.csv $?cat?A.csv Alpha1,A1 Alpha2,A2 Alpha3,A3
單個Reader,單個樣本
import?tensorflow?as?tf #?生成一個先入先出隊列和一個QueueRunner filenames?=?['A.csv',?'B.csv',?'C.csv'] filename_queue?=?tf.train.string_input_producer(filenames,?shuffle=False) #?定義Reader reader?=?tf.TextLineReader() key,?value?=?reader.read(filename_queue) #?定義Decoder example,?label?=?tf.decode_csv(value,?record_defaults=[['null'],?['null']]) #?運行Graph with?tf.Session()?as?sess: ????coord?=?tf.train.Coordinator()??#創(chuàng)建一個協(xié)調(diào)器,管理線程 ????threads?=?tf.train.start_queue_runners(coord=coord)??#啟動QueueRunner,?此時文件名隊列已經(jīng)進隊。 ????for?i?in?range(10): ????????print?example.eval()???#取樣本的時候,一個Reader先從文件名隊列中取出文件名,讀出數(shù)據(jù),Decoder解析后進入樣本隊列。 ????coord.request_stop() ????coord.join(threads) #?outpt Alpha1 Alpha2 Alpha3 Bee1 Bee2 Bee3 Sea1 Sea2 Sea3 Alpha1
單個Reader,多個樣本
import?tensorflow?as?tf filenames?=?['A.csv',?'B.csv',?'C.csv'] filename_queue?=?tf.train.string_input_producer(filenames,?shuffle=False) reader?=?tf.TextLineReader() key,?value?=?reader.read(filename_queue) example,?label?=?tf.decode_csv(value,?record_defaults=[['null'],?['null']]) #?使用tf.train.batch()會多加了一個樣本隊列和一個QueueRunner。Decoder解后數(shù)據(jù)會進入這個隊列,再批量出隊。 #?雖然這里只有一個Reader,但可以設(shè)置多線程,相應(yīng)增加線程數(shù)會提高讀取速度,但并不是線程越多越好。 example_batch,?label_batch?=?tf.train.batch( ??????[example,?label],?batch_size=5) with?tf.Session()?as?sess: ????coord?=?tf.train.Coordinator() ????threads?=?tf.train.start_queue_runners(coord=coord) ????for?i?in?range(10): ????????print?example_batch.eval() ????coord.request_stop() ????coord.join(threads) #?output #?['Alpha1'?'Alpha2'?'Alpha3'?'Bee1'?'Bee2'] #?['Bee3'?'Sea1'?'Sea2'?'Sea3'?'Alpha1'] #?['Alpha2'?'Alpha3'?'Bee1'?'Bee2'?'Bee3'] #?['Sea1'?'Sea2'?'Sea3'?'Alpha1'?'Alpha2'] #?['Alpha3'?'Bee1'?'Bee2'?'Bee3'?'Sea1'] #?['Sea2'?'Sea3'?'Alpha1'?'Alpha2'?'Alpha3'] #?['Bee1'?'Bee2'?'Bee3'?'Sea1'?'Sea2'] #?['Sea3'?'Alpha1'?'Alpha2'?'Alpha3'?'Bee1'] #?['Bee2'?'Bee3'?'Sea1'?'Sea2'?'Sea3'] #?['Alpha1'?'Alpha2'?'Alpha3'?'Bee1'?'Bee2']
多Reader,多個樣本
import?tensorflow?as?tf filenames?=?['A.csv',?'B.csv',?'C.csv'] filename_queue?=?tf.train.string_input_producer(filenames,?shuffle=False) reader?=?tf.TextLineReader() key,?value?=?reader.read(filename_queue) record_defaults?=?[['null'],?['null']] example_list?=?[tf.decode_csv(value,?record_defaults=record_defaults) ??????????????????for?_?in?range(2)]??#?Reader設(shè)置為2 #?使用tf.train.batch_join(),可以使用多個reader,并行讀取數(shù)據(jù)。每個Reader使用一個線程。 example_batch,?label_batch?=?tf.train.batch_join( ??????example_list,?batch_size=5) with?tf.Session()?as?sess: ????coord?=?tf.train.Coordinator() ????threads?=?tf.train.start_queue_runners(coord=coord) ????for?i?in?range(10): ????????print?example_batch.eval() ????coord.request_stop() ????coord.join(threads) ???? #?output #?['Alpha1'?'Alpha2'?'Alpha3'?'Bee1'?'Bee2'] #?['Bee3'?'Sea1'?'Sea2'?'Sea3'?'Alpha1'] #?['Alpha2'?'Alpha3'?'Bee1'?'Bee2'?'Bee3'] #?['Sea1'?'Sea2'?'Sea3'?'Alpha1'?'Alpha2'] #?['Alpha3'?'Bee1'?'Bee2'?'Bee3'?'Sea1'] #?['Sea2'?'Sea3'?'Alpha1'?'Alpha2'?'Alpha3'] #?['Bee1'?'Bee2'?'Bee3'?'Sea1'?'Sea2'] #?['Sea3'?'Alpha1'?'Alpha2'?'Alpha3'?'Bee1'] #?['Bee2'?'Bee3'?'Sea1'?'Sea2'?'Sea3'] #?['Alpha1'?'Alpha2'?'Alpha3'?'Bee1'?'Bee2']
tf.train.batch
與tf.train.shuffle_batch
函數(shù)是單個Reader讀取,但是可以多線程。tf.train.batch_join
與tf.train.shuffle_batch_join
可設(shè)置多Reader讀取,每個Reader使用一個線程。至于兩種方法的效率,單Reader時,2個線程就達到了速度的極限。多Reader時,2個Reader就達到了極限。所以并不是線程越多越快,甚至更多的線程反而會使效率下降。
迭代控制
filenames?=?['A.csv',?'B.csv',?'C.csv'] filename_queue?=?tf.train.string_input_producer(filenames,?shuffle=False,?num_epochs=3)??#?num_epoch:?設(shè)置迭代數(shù) reader?=?tf.TextLineReader() key,?value?=?reader.read(filename_queue) record_defaults?=?[['null'],?['null']] example_list?=?[tf.decode_csv(value,?record_defaults=record_defaults) ??????????????????for?_?in?range(2)] example_batch,?label_batch?=?tf.train.batch_join( ??????example_list,?batch_size=5) init_local_op?=?tf.initialize_local_variables() with?tf.Session()?as?sess: ????sess.run(init_local_op)???#?初始化本地變量? ????coord?=?tf.train.Coordinator() ????threads?=?tf.train.start_queue_runners(coord=coord) ????try: ????????while?not?coord.should_stop(): ????????????print?example_batch.eval() ????except?tf.errors.OutOfRangeError: ????????print('Epochs?Complete!') ????finally: ????????coord.request_stop() ????coord.join(threads) ????coord.request_stop() ????coord.join(threads) #?output #?['Alpha1'?'Alpha2'?'Alpha3'?'Bee1'?'Bee2'] #?['Bee3'?'Sea1'?'Sea2'?'Sea3'?'Alpha1'] #?['Alpha2'?'Alpha3'?'Bee1'?'Bee2'?'Bee3'] #?['Sea1'?'Sea2'?'Sea3'?'Alpha1'?'Alpha2'] #?['Alpha3'?'Bee1'?'Bee2'?'Bee3'?'Sea1'] #?Epochs?Complete!
在迭代控制中,記得添加tf.initialize_local_variables()
,官網(wǎng)教程沒有說明,但是如果不初始化,運行就會報錯。