來源 | MyEncyclopedia
上一篇我們從原理層面解析了AlphaGo Zero如何改進MCTS算法,通過不斷自我對弈,最終實現從零棋力開始訓練直至能夠打敗任何高手。在本篇中,我們在已有的N子棋OpenAI Gym 環境中用Pytorch實現一個簡化版的AlphaGo Zero算法。本篇所有代碼在 github.com/MyEncyclopedia/ConnectNGym 中,其中部分參考了SongXiaoJun 的 github.com junxiaosong/AlphaZero_Gomoku。
AlphaGo Zero MCTS 樹節點
上一篇中,我們知道AlphaGo Zero 的MCTS樹搜索是基於傳統MCTS 的UCT (UCB for Tree)的改進版PUCT(Polynomial Upper Confidence Trees)。局面節點的PUCT值由兩部分組成,分別是代表Exploitation的action value Q值,和代表Exploration的U值。
U值計算由這些參數決定:係數,節點先驗機率P(s, a) ,父節點訪問次數,本節點的訪問次數。具體公式如下
因此在實現過程中,對於一個樹節點來說,需要保存其Q值、節點訪問次數 visit_num和先驗機率 prior。其中,prior在節點初始化後不變,Q值和 visit_num 隨著遊戲MCTS模擬進程而改變。此外,節點保存了 parent和 children變量,用於維護父子關係。c_puct為class variable,作為全局參數。
c_puct: ClassVar[int] = 5 # class-wise global param c_puct, exploration weight factor.
_parent: TreeNode_children: Dict[int, TreeNode] # map from action to TreeNode_visit_num: int_Q: float # Q value of the node, which is the mean action value._prior: float
和上面的計算公式相對應,下列代碼根據節點狀態計算PUCT(s, a)。
def get_puct(self) -> float:"""Computes AlphaGo Zero PUCT (polynomial upper confidence trees) of the node.
:return: Node PUCT value."""U = (TreeNode.c_puct * self._prior * np.sqrt(self._parent._visit_num) / (1 + self._visit_num))return self._Q + U
AlphaGo Zero MCTS在playout時遇到已經被展開的節點,會根據selection規則選擇子節點,該規則本質上是在所有子節點中選擇最大的PUCT值的節點。
def select(self) -> Tuple[Pos, TreeNode]:"""Selects an action(Pos) having max UCB value.
:return: Action and corresponding node"""return max(self._children.items, key=lambda act_node: act_node[1].get_puct)
新的葉節點一旦在playout時產生,關聯的 v 值會一路向上更新至根節點,具體新節點的v值將在下一節中解釋。
def propagate_to_root(self, leaf_value: float):"""Updates current node with observed leaf_value and propagates to root node.
:param leaf_value::return:"""if self._parent:self._parent.propagate_to_root(-leaf_value)self._update(leaf_value)
def _update(self, leaf_value: float):"""Updates the node by newly observed leaf_value.
:param leaf_value::return:"""self._visit_num += 1# new Q is updated towards deviation from existing Qself._Q += 0.5 * (leaf_value - self._Q)
AlphaGo Zero MCTS Player 實現
AlphaGo Zero MCTS 在訓練階段分為如下幾個步驟。遊戲初始局面下,整個局面樹的建立由子節點的不斷被探索而豐富起來。AlphaGo Zero對弈一次即產生了一次完整的遊戲開始到結束的動作系列。在對弈過程中的某一遊戲局面,需要採樣海量的playout,又稱MCTS模擬,以此來決定此局面的下一步動作。一次playout可視為在真實遊戲狀態樹的一種特定採樣,playout可能會產生遊戲結局,生成真實的v值;也可能explore 到新的葉子節點,此時v值依賴策略價值網絡的輸出,目的是利用訓練的神經網絡來產生高質量的遊戲對戰局面。每次playout會從當前給定局面遞歸向下,向下的過程中會遇到下面三種節點情況。
海量的playout模擬後,建立了遊戲狀態樹的節點信息。但至此,AI玩家只是收集了信息,還仍未給定局面落子,而落子的決定由Play規則產生。下圖展示了給定局面(Current節點)下,MCST模擬進行的多次playout探索後生成的局面樹,play規則根據這些節點信息,產生Current 節點的動作分布 ,確定下一步落子。
MCTS Playout和Play關係
Play 給定局面
對於當前需要做落子決定的某遊戲局面,根據如下play公式生成落子分布 ,子局面的落子機率正比於其訪問次數的某次方。其中,某次方的倒數稱為溫度參數(Temperature)。
def _next_step_play_act_probs(self, game: ConnectNGame) -> Tuple[List[Pos], ActionProbs]:"""For the given game status, run playouts number of times specified by self._playout_num.Returns the action distribution according to AlphaGo Zero MCTS play formula.
:param game::return: actions and their probability"""
for n in range(self._playout_num):self._playout(copy.deepcopy(game))
act_visits = [(act, node._visit_num) for act, node in self._current_root._children.items]acts, visits = zip(*act_visits)act_probs = softmax(1.0 / MCTSAlphaGoZeroPlayer.temperature * np.log(np.array(visits) + 1e-10))
return acts, act_probs
在訓練模式時,考慮到偏向exploration的目的,在落子分布的基礎上增加了 Dirichlet 分布。
def get_action(self, board: PyGameBoard) -> Pos:"""Method defined in BaseAgent.
:param board::return: next move for the given game board."""return self._get_action(copy.deepcopy(board.connect_n_game))[0]
def _get_action(self, game: ConnectNGame) -> Tuple[MoveWithProb]:epsilon = 0.25avail_pos = game.get_avail_posmove_probs: ActionProbs = np.zeros(game.board_size * game.board_size)assert len(avail_pos) > 0
# the pi defined in AlphaGo Zero paperacts, act_probs = self._next_step_play_act_probs(game)move_probs[list(acts)] = act_probsif self._is_training:# add Dirichlet Noise when training in favour of explorationp_ = (1-epsilon) * act_probs + epsilon * np.random.dirichlet(0.3 * np.ones(len(act_probs)))move = np.random.choice(acts, p=p_)assert move in game.get_avail_poselse:move = np.random.choice(acts, p=act_probs)
self.resetreturn move, move_probs
一次完整的對弈
一次完整的AI對弈就是從初始局面疊代play直至遊戲結束,對弈生成的數據是一系列的 。
如下圖 s0 到 s5 是某次井字棋的對弈。最終結局是先手黑棋玩家贏,即對於黑棋玩家 z = +1。需要注意的是:z = +1 是對於所有黑棋面臨的局面,即s0, s2, s4,而對應的其餘白棋玩家來說 z = -1。
一局完整對弈
以下代碼展示如何在AI對弈時收集數據
def self_play_one_game(self, game: ConnectNGame) -> List[Tuple[NetGameState, ActionProbs, NDArray[(Any), np.float]]]:"""
:param game::return:Sequence of (s, pi, z) of a complete game play. The number of list is the game play length."""
states: List[NetGameState] = []probs: List[ActionProbs] = []current_players: List[np.float] = []
while not game.game_over:move, move_probs = self._get_action(game)states.append(convert_game_state(game))probs.append(move_probs)current_players.append(game.current_player)game.move(move)
current_player_z = np.zeros(len(current_players))current_player_z[np.array(current_players) == game.game_result] = 1.0current_player_z[np.array(current_players) == -game.game_result] = -1.0self.reset
return list(zip(states, probs, current_player_z))
Playout 代碼實現
def _playout(self, game: ConnectNGame):"""From current game status, run a sequence down to a leaf node, either because game ends or unexplored node.Get the leaf value of the leaf node, either the actual reward of game or action value returned by policy net.And propagate upwards to root node.
:param game:"""player_id = game.current_player
node = self._current_rootwhile True:if node.is_leaf:breakact, node = node.selectgame.move(act)
# now game state is a leaf node in the tree, either a terminal node or an unexplored nodeact_and_probs: Iterator[MoveWithProb]act_and_probs, leaf_value = self._policy_value_net.policy_value_fn(game)
if not game.game_over:# case where encountering an unexplored leaf node, update leaf_value estimated by policy net to rootfor act, prob in act_and_probs:game.move(act)child_node = node.expand(act, prob)game.undoelse:# case where game ends, update actual leaf_value to rootif game.game_result == ConnectNGame.RESULT_TIE:leaf_value = ConnectNGame.RESULT_TIEelse:leaf_value = 1 if game.game_result == player_id else -1leaf_value = float(leaf_value)
# Update leaf_value and propagate up to root nodenode.propagate_to_root(-leaf_value)
編碼遊戲局面
為了將信息有效的傳遞給策略神經網絡,必須從當前玩家的角度編碼遊戲局面。局面不僅要反映棋盤上黑白棋子的位置,也需要考慮最後一個落子的位置以及是否為當前玩家棋局。因此,我們將某局面按照當前玩家來編碼,返回類型為4個棋盤大小組成的ndarray,即shape [4, board_size, board_size],其中
第一個數組編碼當前玩家的棋子位置
第二個數組編碼對手玩家棋子位置
第三個表示最後落子位置
第四個全1表示此局面為先手(黑棋)局面,全0表示白棋局面
例如之前遊戲對弈中的前四步:
s1->s2 後局面s2的編碼:當前玩家為黑棋玩家,編碼局面s2 返回如下ndarray,數組[0] 為s2黑子位置,[1]為白子位置,[2]表示最後一個落子(1, 1) ,[3] 全1表示當前是黑棋落子的局面。
編碼黑棋玩家局面 s2
s2->s3 後局面s3的編碼:當前玩家為白棋玩家,編碼返回如下,數組[0] 為s3白子位置,[1]為黑子位置,[2]表示最後一個落子(1, 0) ,[3] 全0表示當前是白棋落子的局面。
編碼白棋玩家局面 s3
具體代碼實現如下。
def convert_game_state(game: ConnectNGame) -> NetGameState:"""Converts game state to type NetGameState as ndarray.
:param game::return:Of shape 4 * board_size * board_size.[0] is current player positions.[1] is opponent positions.[2] is last move location.[3] all 1 meaning move by black player, all 0 meaning move by white."""state_matrix = np.zeros((4, game.board_size, game.board_size))
if game.action_stack:actions = np.array(game.action_stack)move_curr = actions[::2]move_oppo = actions[1::2]for move in move_curr:state_matrix[0][move] = 1.0for move in move_oppo:state_matrix[1][move] = 1.0# indicate the last move locationstate_matrix[2][actions[-1]] = 1.0if len(game.action_stack) % 2 == 0:state_matrix[3][:, :] = 1.0 # indicate the colour to playreturn state_matrix[:, ::-1, :]
策略價值網絡訓練
策略價值網絡是一個共享參數 的雙頭網絡,給定上面的遊戲局面編碼會產生預估的p和v。
結合真實遊戲對弈後產生三元組數據 ,按照論文中的loss 來訓練神經網絡。
下面代碼為Pytorch backward部分。
self.optimizer.zero_gradfor param_group in self.optimizer.param_groups:param_group['lr'] = lr
log_act_probs, value = self.policy_value_net(state_batch)# loss = (z - v)^2 - pi*T * log(p) + c||theta||^2value_loss = F.mse_loss(value.view(-1), value_batch)policy_loss = -torch.mean(torch.sum(probs_batch * log_act_probs, 1))loss = value_loss + policy_lossloss.backwardself.optimizer.stepentropy = -torch.mean(torch.sum(torch.exp(log_act_probs) * log_act_probs, 1))return loss.item, entropy.item
參考資料