勾配降下党青年局

万国のグラーディエントよ、降下せよ!

半年前に自分で作った機械学習によるオセロAIの復習をする3

 前回はネットワークを紹介しました。このネットワークは指し手(方策)と評価値(価値)を返します。このネットワークを利用して、モンテカルロ木探索という手法により、実際の指し手を計算していきます。
 モンテカルロ木探索というと、なんだかランダム性がありそうな予感がしてきますが、なんと決定的なアルゴリズムです。元々のモンテカルロ木探索は、でたらめに試行しまくってたまたま一番勝率が高かった手を選ぶという手法です。ランダムな試行でもゲームが終わる囲碁で流行した手法のようです。機械学習を利用する文脈ではネットワークに基づいて試行していくため、ランダム性はありません。じゃあ名前変えろよって思っちゃいません?方策価値ネットワーク木探索とかどうですか。
 ここからの説明は元のモンテカルロ木探索に関する知識がないとイメージしづらいかと思います。聖書を読んでください。
 オセロとか将棋とかは局面をノード、指し手をエッジとすると、木になります。現在の盤面から指し手を探索していって、木をどんどん大きくしていきます。木をおっきくすれば良い手が見つけやすくなりそうですね。
 モンテカルロ木探索の一回分の探索単位をプレイアウトと呼びます。1プレイアウトでは、根から初めてひたすら子ノードをもぐっていき、葉っぱにたどりつくまで探索します。深さ優先探索みたいな感じで、隣のノードを探索したりしません。探索経路はいっぱいありますが、3つの基準で決めます。

  1. 方策ネットワークに基づく指し手の確率(大きいほど、選ばれやすい)
  2. 価値ネットワークに基づく葉ノードの平均評価値(大きいほど、選ばれやすい)
  3. 訪問回数(大きいほど、選ばれにくい)

1は分かりやすいですね。ネットワークが良い手を予測してくれると仮定すれば、その手を探索すればいいわけです。2ではそのノードからたどり着く葉ノードの平均評価値(価値ネットワークの出力)に基づいています。読み進めていくと意外といい手だった、なんて場合に良い結果になります。1,2について、ネットワークは常に同じ値を示すので、それだけでは何回探索しても同じ経路をたどるだけです。そこで3の出番です。訪問回数が大きいほど、十分探索できたでしょ、という考えで探索経路に選ばれづらくします。こうすることで、良い手ほどいっぱい探索されて、そうでもない手はあまり探索されません。
具体的にはあるノードにたどり着いたとき、以下が一番高い子ノードが次の探索先になります。
 平均評価値+指し手の確率×\frac{\sqrt{ノードの訪問回数}}{子ノードの訪問回数+1}
葉にたどり着いたら、葉の訪問回数が規定値より大きければ、合法手分ノードを展開して、木を大きくします。
というわけでプレイアウトを複数回やることで、木を成長させて最終的には葉ノードの平均評価値が一番高いノード(根の子ノード)が指し手として選ばれます。指し手を決める部分では、価値ネットワークに基づく評価値のみを参照するんですね。指し手を予測する方策ネットワークの出力は探索のみに使われます。この辺直感に反しますね。
これが全体的な流れですが、プレイアウト中にノードが展開され、新しいノードにたどり着いたとき、その盤面をネットワークで推論することになります。1つ1つのノードに推論していくのは、GPUの並列処理ぱわーを生かせていません。そのため実際には1つ1つノードを推論するのではなく何回か探索して新しいノードがたまったら並列で推論する、という方法をとります。ただし何回探索しても同じノードにたどり着いてしまったら意味がありません。そこで推論前のノードはとりあえず評価値を0(Virtual Loss)としておきます。評価値が0なので探索対象になりづらくなり、別の経路が選ばれるようになります。推論した後は、評価値を正しい値に直します。

以下実装を貼り付けるだけ。Pythonでやってますが、強いAIを作るためにはコンパイラ言語を利用したり、並列処理が必要です。
ノードクラスには、ネットワークの出力や、訪問回数、子ノードの情報や展開処理などが含まれます。

class Node():
  def __init__(self,iniwin=0.5,node_test=False):
    self.move_count = 0
    self.sum_value = 0.0
    self.child_move = None
    self.child_move_count = None
    self.child_sum_value = None
    self.child_node = None
    self.policy = None
    self.value = None
    self.iniwin = 0.5
    self.node_test = node_test

  def creat_child_node(self, index):
    self.child_node[index] = Node()

  def expand_node(self,board):
    child_num = len(self.child_move)
    self.child_move_count = np.zeros(child_num,dtype=np.int32)
    self.child_sum_value = np.zeros(child_num,dtype=np.float32)
    self.child_node = [None]*child_num

  def update(self,policy,value):
    self.policy = [policy[move%10 + move//10*8 -9].item() for move in self.child_move]
    self.value = value[0]

  def leaf(self,board):
    self.child_move = board.legal_moves
    if len(self.child_move)==0:
      self.value = sigmoid(board.turn * board.board[BOARD].sum()/10)
      return True
    return False

  def eval(self,board,model):
    self.child_move = board.legal_moves
    if len(self.child_move)==0:
      self.value = sigmoid(board.turn * board.board[BOARD].sum()/10)
    else:
      feature = np.array([board.feature()],dtype=np.float32)
      policy, value = self.infer(feature,model)
      self.value = value[0,0]
      self.policy = [policy[0][move%10 + move//10*8 -9].item() for move in self.child_move]

  def choice(self,c_puct):
    if self.node_test:
        win_rate = self.child_sum_value / (self.child_move_count+1)
    else:
        win_rate = np.divide(self.child_sum_value,self.child_move_count,out=np.zeros(len(self.child_move), np.float32)+self.iniwin,where=self.child_move_count != 0)
    index = np.argmax(win_rate + (self.policy * (np.sqrt(np.float32(self.move_count)) / (self.child_move_count + 1))) * c_puct)
    if self.child_node[index] == None:
      self.creat_child_node(index)
    return  index

MCTSクラスでは、プレイアウト、バーチャルロスとバーチャルロスを元に戻すバックアップなどの処理があります。またオセロの場合、最終局面は全探索可能なので全探索して最善手を出します。
推論関数はonnxを利用しているのでちょっと特殊です。まあ指し手にソフトマックスを通して、評価にシグモイドを通しているだけです。

import time
def infer(feature,session,batch=1):
    if session == "random":
      return np.random.rand(batch,64) , np.random.rand(batch,1)
    io_binding = session.io_binding()
    io_binding.bind_cpu_input('input', feature)
    io_binding.bind_output('policy')
    io_binding.bind_output('value')
    session.run_with_iobinding(io_binding)
    policy,value =  io_binding.copy_outputs_to_cpu()
    return softmax(policy),sigmoid(value)

class MCTS():
  def __init__(self,board,session,halt=100,c_puct=1,threshold=3,batch_size=8,iniwin = 0.5,node_test = False,perfect=0,temp=0.1):
    self.original_node = Node(iniwin,node_test)
    self.original_board = board.copy()
    self.current_node = self.original_node
    self.current_board = board.copy()
    self.route = []
    self.indexes = []
    self.turns = [self.current_board.turn]
    self.moves = []
    self.threshold = threshold
    self.playout_number = 0
    self.session = session
    self.halt = halt
    self.c_puct = c_puct
    self.batch_size = batch_size
    self.current_batch = 0
    self.batch_route = []
    self.batch_feature = []
    self.times = 0
    self.iniwin = iniwin
    self.node_test = node_test
    self.perfect = perfect
    self.temp = temp

  def search(self,prob=False):
    self.times = 0
    while self.playout_number <= self.halt:
      self.route.append(self.current_node)
      if self.current_node.move_count == 0:
        if self.current_node.leaf(self.current_board):
          self.playout()
          continue
        else:
          self.virtual_loss()
          continue
      elif self.current_node.value == None:
        self.virtual_loss(False)
        continue
      elif self.current_node.move_count < self.threshold:
        self.playout()
        continue
      else:
        if self.current_node.child_node == None:
          self.current_node.expand_node(self.current_board)
        if len(self.current_node.child_move) == 0:
          self.playout()
          continue
        index = self.current_node.choice(self.c_puct)
        self.moves.append(self.current_node.child_move[index])
        start_time = time.perf_counter()
        self.current_board.move(self.current_node.child_move[index])
        end_time = time.perf_counter()
        self.times += end_time-start_time
        self.current_node = self.current_node.child_node[index]
        self.indexes.append(index)
        self.turns.append(self.current_board.turn)

    sorted_index = np.argsort(self.original_node.child_move_count)[::-1]
    winrate = self.original_node.child_sum_value[sorted_index[0]] / self.original_node.child_node[sorted_index[0]].move_count
    if prob:
        move = np.random.choice(self.original_node.child_move,p=softmax([self.original_node.child_move_count/self.halt/self.temp])[0])
        return move,winrate
    else:
        return np.array(self.original_node.child_move)[sorted_index] ,winrate

  def perfect_search(self,input_board):
      v = 0
      max = -1000
      turn = input_board.turn
      if len(input_board.legal_moves) == 0:
        return None,sigmoid(input_board.turn*input_board.board[BOARD].sum()/10)
      for move in input_board.legal_moves:
        board = input_board.copy()
        board.move(move)
        _,v = self.perfect_search(board)
        if turn != board.turn:
            v = 1 - v
        if v >= max:
            best_move = move
            max = v
      return best_move,max
  
  def move(self,num=0,re_eval=False,prob=False):
    if self.perfect >= 100 - np.count_nonzero(self.current_board.board):
        move,eval = self.perfect_search(self.current_board)
    else:
        if prob:
            move,eval = self.search(True)
        else:
            moves,eval = self.search()
            move = moves[min(num,len(moves)-1)]
        self.playout_number = 0
        index = self.original_node.child_move.index(move)
        if self.original_node.child_node[index] == None:
            self.original_node = Node(self.iniwin,self.node_test)
        else:
            self.original_node = self.original_node.child_node[index]
        self.current_node = self.original_node
    
    self.original_board.move(move)
    self.current_board = self.original_board.copy()
    if re_eval:
      return move,eval
    else:
      return move

  def move_enemy(self,move):
    self.original_board.move(move)
    self.current_board = self.original_board.copy()
    self.playout_number = 0
    if self.perfect >= 100 - np.count_nonzero(self.current_board.board):
      return;

    if self.original_node.child_node == None:
      self.original_node = Node(self.iniwin,self.node_test)
    else:
      index = self.original_node.child_move.index(move)
      if self.original_node.child_node[index] == None:
        self.original_node = Node(self.iniwin,self.node_test)
      else:
        self.original_node = self.original_node.child_node[index]
    self.current_node = self.original_node

    

  def playout(self):
    self.playout_number += 1
    results = None
    while len(self.route) > 0:
      node = self.route.pop()
      if results == None:
        results = node.value
        node.sum_value += results
      else:
        index = self.indexes.pop()
        node.sum_value += results
        node.child_move_count[index] += 1
        node.child_sum_value[index] += results
      turn = self.turns.pop()
      if len(self.turns) >= 1 and turn != self.turns[-1]:
        results = 1 - results
      node.move_count += 1
    self.current_node = self.original_node
    self.current_board = self.original_board.copy()
    self.route = []
    self.indexes = []
    self.moves = []
    self.turns = [self.current_board.turn]

  def virtual_loss(self,init=True):
    self.current_batch += 1
    self.batch_route += [(self.route.copy(),self.indexes.copy(),self.turns.copy())]
    if init:
      self.batch_feature += [self.current_board.feature()]
    results = None
    while len(self.route) > 0:
      node = self.route.pop()
      if results == None:
        results = 0
      else:
        index = self.indexes.pop()
        node.child_move_count[index] += 1
      turn = self.turns.pop()
      node.move_count += 1
    self.current_node = self.original_node
    self.current_board = self.original_board.copy()
    self.route = []
    self.indexes = []
    self.moves = []
    self.turns = [self.current_board.turn]
    if self.current_batch == self.batch_size:
      self.backup()
      self.current_batch = 0

  def backup(self):
    self.playout_number += self.current_batch
    policies, values = infer(np.array(self.batch_feature,dtype=np.float32),self.session,batch=len(self.batch_feature))
    i = 0
    for route,indexes,turns in self.batch_route:
      results = None
      while len(route) > 0:
        node = route.pop()
        if results == None:
          if node.value==None:
            results = values[i,0]
            node.update(policies[i],values[i])
            i+=1
          else:
            results = node.value
          node.sum_value += results
          
        else:
          index = indexes.pop()
          node.sum_value += results
          node.child_sum_value[index] += results
        turn = turns.pop()
        if len(turns) >= 1 and turn != turns[-1]:
          results = 1 - results
    self.batch_route = []
    self.batch_feature = []