在使用TensorFlow進(jìn)行異步計(jì)算時,隊(duì)列是一種強(qiáng)大的機(jī)制。
正如TensorFlow中的其他組件一樣,隊(duì)列就是TensorFlow圖中的節(jié)點(diǎn)。這是一種有狀態(tài)的節(jié)點(diǎn),就像變量一樣:其他節(jié)點(diǎn)可以修改它的內(nèi)容。具體來說,其他節(jié)點(diǎn)可以把新元素插入到隊(duì)列后端(rear),也可以把隊(duì)列前端(front)的元素刪除。
為了感受一下隊(duì)列,讓我們來看一個簡單的例子。我們先創(chuàng)建一個“先入先出”的隊(duì)列(FIFOQueue),并將其內(nèi)部所有元素初始化為零。然后,我們構(gòu)建一個TensorFlow圖,它從隊(duì)列前端取走一個元素,加上1之后,放回隊(duì)列的后端。慢慢地,隊(duì)列的元素的值就會增加。
Enqueue
、 EnqueueMany
和Dequeue
都是特殊的節(jié)點(diǎn)。他們需要獲取隊(duì)列指針,而非普通的值,如此才能修改隊(duì)列內(nèi)容。我們建議您將它們看作隊(duì)列的方法。事實(shí)上,在Python API中,它們就是隊(duì)列對象的方法(例如q.enqueue(...)
)。
現(xiàn)在你已經(jīng)對隊(duì)列有了一定的了解,讓我們深入到細(xì)節(jié)...
隊(duì)列,如FIFOQueue
和RandomShuffleQueue
,在TensorFlow的張量異步計(jì)算時都非常重要。
例如,一個典型的輸入結(jié)構(gòu):是使用一個RandomShuffleQueue
來作為模型訓(xùn)練的輸入:
這種結(jié)構(gòu)具有許多優(yōu)點(diǎn),正如在Reading data how to中強(qiáng)調(diào)的,同時,Reading data how to也概括地描述了如何簡化輸入管道的構(gòu)造過程。
TensorFlow的Session
對象是可以支持多線程的,因此多個線程可以很方便地使用同一個會話(Session)并且并行地執(zhí)行操作。然而,在Python程序?qū)崿F(xiàn)這樣的并行運(yùn)算卻并不容易。所有線程都必須能被同步終止,異常必須能被正確捕獲并報告,回話終止的時候, 隊(duì)列必須能被正確地關(guān)閉。
所幸TensorFlow提供了兩個類來幫助多線程的實(shí)現(xiàn):tf.Coordinator和
tf.QueueRunner。從設(shè)計(jì)上這兩個類必須被一起使用。Coordinator
類可以用來同時停止多個工作線程并且向那個在等待所有工作線程終止的程序報告異常。QueueRunner
類用來協(xié)調(diào)多個工作線程同時將多個張量推入同一個隊(duì)列中。
Coordinator類用來幫助多個線程協(xié)同工作,多個線程同步終止。 其主要方法有:
should_stop()
:如果線程應(yīng)該停止則返回True。request_stop(<exception>)
: 請求該線程停止。join(<list of threads>)
:等待被指定的線程終止。首先創(chuàng)建一個Coordinator
對象,然后建立一些使用Coordinator
對象的線程。這些線程通常一直循環(huán)運(yùn)行,一直到should_stop()
返回True時停止。
任何線程都可以決定計(jì)算什么時候應(yīng)該停止。它只需要調(diào)用request_stop()
,同時其他線程的should_stop()
將會返回True
,然后都停下來。
# 線程體:循環(huán)執(zhí)行,直到`Coordinator`收到了停止請求。
# 如果某些條件為真,請求`Coordinator`去停止其他線程。
def MyLoop(coord):
while not coord.should_stop():
...do something...
if ...some condition...:
coord.request_stop()
# Main code: create a coordinator.
coord = Coordinator()
# Create 10 threads that run 'MyLoop()'
threads = [threading.Thread(target=MyLoop, args=(coord)) for i in xrange(10)]
# Start the threads and wait for all of them to stop.
for t in threads: t.start()
coord.join(threads)
顯然,Coordinator可以管理線程去做不同的事情。上面的代碼只是一個簡單的例子,在設(shè)計(jì)實(shí)現(xiàn)的時候不必完全照搬。Coordinator還支持捕捉和報告異常, 具體可以參考Coordinator class的文檔。
QueueRunner
類會創(chuàng)建一組線程, 這些線程可以重復(fù)的執(zhí)行Enquene操作, 他們使用同一個Coordinator來處理線程同步終止。此外,一個QueueRunner會運(yùn)行一個closer thread,當(dāng)Coordinator收到異常報告時,這個closer thread會自動關(guān)閉隊(duì)列。
您可以使用一個queue runner,來實(shí)現(xiàn)上述結(jié)構(gòu)。 首先建立一個TensorFlow圖表,這個圖表使用隊(duì)列來輸入樣本。增加處理樣本并將樣本推入隊(duì)列中的操作。增加training操作來移除隊(duì)列中的樣本。
example = ...ops to create one example...
# Create a queue, and an op that enqueues examples one at a time in the queue.
queue = tf.RandomShuffleQueue(...)
enqueue_op = queue.enqueue(example)
# Create a training graph that starts by dequeuing a batch of examples.
inputs = queue.dequeue_many(batch_size)
train_op = ...use 'inputs' to build the training part of the graph...
在Python的訓(xùn)練程序中,創(chuàng)建一個QueueRunner
來運(yùn)行幾個線程, 這幾個線程處理樣本,并且將樣本推入隊(duì)列。創(chuàng)建一個Coordinator
,讓queue runner使用Coordinator
來啟動這些線程,創(chuàng)建一個訓(xùn)練的循環(huán), 并且使用Coordinator
來控制QueueRunner
的線程們的終止。
# Create a queue runner that will run 4 threads in parallel to enqueue
# examples.
qr = tf.train.QueueRunner(queue, [enqueue_op] * 4)
# Launch the graph.
sess = tf.Session()
# Create a coordinator, launch the queue runner threads.
coord = tf.train.Coordinator()
enqueue_threads = qr.create_threads(sess, coord=coord, start=True)
# Run the training loop, controlling termination with the coordinator.
for step in xrange(1000000):
if coord.should_stop():
break
sess.run(train_op)
# When done, ask the threads to stop.
coord.request_stop()
# And wait for them to actually do it.
coord.join(threads)
通過queue runners啟動的線程不僅僅只處理推送樣本到隊(duì)列。他們還捕捉和處理由隊(duì)列產(chǎn)生的異常,包括OutOfRangeError
異常,這個異常是用于報告隊(duì)列被關(guān)閉。
使用Coordinator
的訓(xùn)練程序在主循環(huán)中必須同時捕捉和報告異常。
下面是對上面訓(xùn)練循環(huán)的改進(jìn)版本。
try:
for step in xrange(1000000):
if coord.should_stop():
break
sess.run(train_op)
except Exception, e:
# Report exceptions to the coordinator.
coord.request_stop(e)
# Terminate as usual. It is innocuous to request stop twice.
coord.request_stop()
coord.join(threads)
原文地址:Threading and Queues 翻譯:zhangkom 校對:volvet