合理主義的グルメブログ

学生起業家の日常をツラツラと書いています。主に食事情報です。

tensorflowで並列処理をする

A3CやPPOなどの深層強化学習アルゴリズムを実装していると,
並列処理が必ず話題になります.

実際に調べると,以下のサイトのようなthreadingモジュールとtf.train.Coordinatorを使ったものが,結構出てきます.

これの罠にハマったのでメモっておきます.

まず,以下のプログラムを見てください.

# スレッドを作成する
with tf.device('/cpu:0'):
    brain = Brain() # NNのクラス
    threads = []    # 並列して走るスレッド

    # 学習するスレッドを準備
    print(N_WORKERS)
    for i in range(N_WORKERS):
        threads.append(Worker_thread(
            thread_name='train_thread{}'.format(i+1),
            thread_type='train',
            brain=brain
        ))

    # 学習後にテストで走るスレッドを準備
    threads.append(Worker_thread(
        thread_name='test_thread',
        thread_type='test',
        brain=brain
    ))

coord = tf.train.Coordinator()  # tensorflowでマルチスレッドにするための準備
SESS.run(tf.global_variables_initializer()) # 変数を初期化

running_threads = []
print('Start Trainning...')
for i in range(len(threads)):
    job = lambda: threads[i].run()
    t = threading.Thread(target=job, daemon=True)
    t.start()
    running_threads.append(t)

for worker in threads:
    job = lambda: worker.run()
    t = threading.Thread(target=job, daemon=True)
    t.start()
    running_threads.append(t)

一見良さげに見えますが,これだとすべてのスレッドがtest_thread扱いになります.
これを直すには,以下のように書くといいです.

# スレッドを作成する
with tf.device('/cpu:0'):
    brain = Brain() # NNのクラス
    threads = []    # 並列して走るスレッド

    # 学習するスレッドを準備
    print(N_WORKERS)
    for i in range(N_WORKERS):
        threads.append(Worker_thread(
            thread_name='train_thread{}'.format(i+1),
            thread_type='train',
            brain=brain
        ))

    # 学習後にテストで走るスレッドを準備
    threads.append(Worker_thread(
        thread_name='test_thread',
        thread_type='test',
        brain=brain
    ))

# Tensorflowでマルチスレッドを実行
coord = tf.train.Coordinator()  # tensorflowでマルチスレッドにするための準備
SESS.run(tf.global_variables_initializer()) # 変数を初期化

running_threads = []
print('Start Trainning...')

def work(worker):
    worker.run()

for worker in threads:
    t = threading.Thread(target=work, args=(worker,), daemon=True)
    t.start()
    running_threads.append(t)

worker変数に入れてから,lambda内で使うと,うまくいきません.
これはおそらく,Pythonの参照関係の問題からです.

詳しい理由は調べてないので,分かる人がいたら,コメントなどで教えてください.