MCTS 튜토리얼 with TicTaeToe Example

Monte Carlo Tree Search 방법을 알아보고 TicTaeToe 문제를 풀어보자.

참조문헌:코드, 한글 설명, 영문설명.

다음과 같은 스크립트를 짜면 5*5 보드판에서 5개를 먼저 연속으로 두는 사람이 이기도록 2개의 player 가 game을 하는 형태로 초기화를 하고 MCTS를 사용하여 게임을 오토 플레이 할 수 있다.

In [1]:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import numpy as np
from mctspy.tree.nodes import TwoPlayersGameMonteCarloTreeSearchNode
from mctspy.tree.search import MonteCarloTreeSearch
from mctspy.games.examples.connect4 import Connect4GameState

# define inital state
state = np.zeros((5, 5))
board_state = Connect4GameState(
    state=state, next_to_move=np.random.choice([-1, 1]), win=5)

# link pieces to icons
pieces = {0: " ", 1: "X", -1: "O"}

# print a single row of the board
def stringify(row):
    return " " + " | ".join(map(lambda x: pieces[int(x)], row)) + " "

# display the whole board
def display(board):
    board = board.copy().T[::-1]
    for row in board[:-1]:
        print(stringify(row))
        print("-"*(len(row)*4-1))
    print(stringify(board[-1]))
    print()

display(board_state.board)
# keep playing until game terminates
while board_state.game_result is None:
    # calculate best move
    root = TwoPlayersGameMonteCarloTreeSearchNode(state=board_state)
    mcts = MonteCarloTreeSearch(root)
    best_node = mcts.best_action(total_simulation_seconds=1)

    # update board
    board_state = best_node.state
    # display board
    display(board_state.board)

# print result
print(pieces[board_state.game_result])
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
   |   |   |   |   
-------------------
   |   |   |   |   
-------------------
   |   |   |   |   
-------------------
   |   |   |   |   
-------------------
   |   |   |   |   

   |   |   |   |   
-------------------
   |   |   |   |   
-------------------
   |   |   |   |   
-------------------
   |   |   | X |   
-------------------
   |   |   |   |   

   |   |   |   |   
-------------------
   |   |   |   |   
-------------------
   |   |   |   |   
-------------------
   |   |   | X |   
-------------------
   |   |   |   | O 

   |   |   |   |   
-------------------
   |   |   | X |   
-------------------
   |   |   |   |   
-------------------
   |   |   | X |   
-------------------
   |   |   |   | O 

   |   |   |   |   
-------------------
   |   |   | X |   
-------------------
   |   |   |   |   
-------------------
   |   |   | X |   
-------------------
   |   |   | O | O 

   |   |   |   |   
-------------------
   |   |   | X |   
-------------------
   |   |   |   |   
-------------------
   |   |   | X |   
-------------------
   | X |   | O | O 

   |   |   |   |   
-------------------
   |   |   | X |   
-------------------
   |   |   |   | O 
-------------------
   |   |   | X |   
-------------------
   | X |   | O | O 

   | X |   |   |   
-------------------
   |   |   | X |   
-------------------
   |   |   |   | O 
-------------------
   |   |   | X |   
-------------------
   | X |   | O | O 

   | X |   |   | O 
-------------------
   |   |   | X |   
-------------------
   |   |   |   | O 
-------------------
   |   |   | X |   
-------------------
   | X |   | O | O 

   | X |   |   | O 
-------------------
   |   |   | X |   
-------------------
   |   |   |   | O 
-------------------
 X |   |   | X |   
-------------------
   | X |   | O | O 

   | X |   |   | O 
-------------------
   |   |   | X |   
-------------------
   |   |   |   | O 
-------------------
 X |   |   | X | O 
-------------------
   | X |   | O | O 

   | X |   |   | O 
-------------------
   |   |   | X | X 
-------------------
   |   |   |   | O 
-------------------
 X |   |   | X | O 
-------------------
   | X |   | O | O 

   | X |   |   | O 
-------------------
   |   |   | X | X 
-------------------
   | O |   |   | O 
-------------------
 X |   |   | X | O 
-------------------
   | X |   | O | O 

   | X |   |   | O 
-------------------
   |   |   | X | X 
-------------------
   | O | X |   | O 
-------------------
 X |   |   | X | O 
-------------------
   | X |   | O | O 

   | X |   |   | O 
-------------------
   |   | O | X | X 
-------------------
   | O | X |   | O 
-------------------
 X |   |   | X | O 
-------------------
   | X |   | O | O 

   | X |   |   | O 
-------------------
   |   | O | X | X 
-------------------
   | O | X |   | O 
-------------------
 X |   |   | X | O 
-------------------
 X | X |   | O | O 

 O | X |   |   | O 
-------------------
   |   | O | X | X 
-------------------
   | O | X |   | O 
-------------------
 X |   |   | X | O 
-------------------
 X | X |   | O | O 

 O | X |   | X | O 
-------------------
   |   | O | X | X 
-------------------
   | O | X |   | O 
-------------------
 X |   |   | X | O 
-------------------
 X | X |   | O | O 

 O | X |   | X | O 
-------------------
   |   | O | X | X 
-------------------
   | O | X | O | O 
-------------------
 X |   |   | X | O 
-------------------
 X | X |   | O | O 

 O | X | X | X | O 
-------------------
   |   | O | X | X 
-------------------
   | O | X | O | O 
-------------------
 X |   |   | X | O 
-------------------
 X | X |   | O | O 

 O | X | X | X | O 
-------------------
   |   | O | X | X 
-------------------
   | O | X | O | O 
-------------------
 X |   | O | X | O 
-------------------
 X | X |   | O | O 

 O | X | X | X | O 
-------------------
   |   | O | X | X 
-------------------
   | O | X | O | O 
-------------------
 X |   | O | X | O 
-------------------
 X | X | X | O | O 

 O | X | X | X | O 
-------------------
   | O | O | X | X 
-------------------
   | O | X | O | O 
-------------------
 X |   | O | X | O 
-------------------
 X | X | X | O | O 

 O | X | X | X | O 
-------------------
   | O | O | X | X 
-------------------
   | O | X | O | O 
-------------------
 X | X | O | X | O 
-------------------
 X | X | X | O | O 

 O | X | X | X | O 
-------------------
 O | O | O | X | X 
-------------------
   | O | X | O | O 
-------------------
 X | X | O | X | O 
-------------------
 X | X | X | O | O 

 O | X | X | X | O 
-------------------
 O | O | O | X | X 
-------------------
 X | O | X | O | O 
-------------------
 X | X | O | X | O 
-------------------
 X | X | X | O | O 

 

사실 틱택토 게임은 경우의 수가 작은편이라서 양쪽 플레이어가 모든 시뮬레이션을 알고있다면 무조껀 무승부가 나올 수 밖에 없는 게임이다. 따라서 두 플레이어가 MCTS를 사용하여 경기를 하면 결과는 항상 무승부가 나온다. 어떻게 알고리즘이 동작하는지 보기 위해서 제일 먼저 호출되는 best_action함수로 들어가보자.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
    def best_action(self, simulations_number=None, total_simulation_seconds=None):
        """

        Parameters
        ----------
        simulations_number : int
            number of simulations performed to get the best action

        total_simulation_seconds : float
            Amount of time the algorithm has to run. Specified in seconds

        Returns
        -------

        """

        if simulations_number is None :
            assert(total_simulation_seconds is not None)
            end_time = time.time() + total_simulation_seconds
            while time.time() < end_time:
                v = self._tree_policy()
                reward = v.rollout()
                v.backpropagate(reward)
        else :
            for _ in range(0, simulations_number):
                v = self._tree_policy()
                reward = v.rollout()
                v.backpropagate(reward)
        # to select best child go for exploitation only
        return self.root.best_child(c_param=0.)

MCTS 알고리즘은 주어진 시간/횟수동안 Slection -> Expansion -> Simulation -> Backpropagation을 반복한다.

MCTS

_tree_policy 함수에서 Selection -> Expansion 과정, rollout 함수에서 Simulation 과정,

마지막으로 backpropagate 함수에서 Backpropagation을 한다.


Selection & Expansion

Selection -> Expansion 과정이 어떻게 이루어지는 지 알아보기 위해서 _tree_policy 함수로 들어가보자.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
def _tree_policy(self):
        """
        selects node to run rollout/playout for

        Returns
        -------

        """
        current_node = self.root
        while not current_node.is_terminal_node():
            if not current_node.is_fully_expanded():
                return current_node.expand()
            else:
                current_node = current_node.best_child()
        return current_node

Selection은 exploration과 exploitation을 고려하면서 Leaf Node로 가는 과정이다.

1
2
    def is_fully_expanded(self):
        return len(self.untried_actions) == 0

Leaf 노드에 대한 조건은 위와같이 시도하지 않은 행동의 존재 여부로 판별한다.

위의 함수를 통해서 현 노드에서 expansion 할지 selection할 지를 판단한다

현 노드에서 아직 모든 행동을 해보지 않았다면 현 노드를 Leaf 노드로 두고 램던하게 행동하여 이동한 state로 expansion한다.

현 노드에서 모든 행동을 해봤다면 주어진 정보로부터 Terminal노드까지 최적 노드를 골라 Selection 한다.

TicTaeToe 환경은 search space가 작아서 모든 행동을 다 해봤을 때만 최적 노드로 Selection하고 큰 그 다음에 노드에 대해서 expansion하지만 큰 경우에는 이 부분의 수정이 필요하다.

1
2
3
4
5
6
7
8
    def expand(self):
        action = self.untried_actions.pop()
        next_state = self.state.move(action)
        child_node = TwoPlayersGameMonteCarloTreeSearchNode(
            next_state, parent=self
        )
        self.children.append(child_node)
        return child_node

expansion은 위와같이 가능한 action중 하나를 가져와서 다음 노드로 이동하도록 정의된다.

1
2
3
4
5
6
7
8
9
10
    def best_child(self, c_param=1.4):
        choices_weights = [
            (c.q / c.n) + c_param * np.sqrt((2 * np.log(self.n) / c.n))
            for c in self.children]
        return self.children[np.argmax(choices_weights)]
            
    def q(self):
        wins = self._results[self.parent.state.next_to_move]
        loses = self._results[-1 * self.parent.state.next_to_move]
        return wins - loses

selection은 tree policy를 사용하여 child 노드를 선택한다.

주로 기준으로 UCT(Upper Confidence Boundary of Tree) 를 사용하는데

best_child를 위와 같은 UCB1(Upper Confidence Bound 1)이라는 함수를 통하여 선택한다.

첫번째 항은 같은 방문횟수에 대해서 많이 이기고 적게 지는 노드를 장려하므로 exploration 을 위한 항이고,

두번째항은 부모노드의 방문횟수를 고정했을 때 자식노드의 방문횟수가 적은 노드를 장려하므로 exploration 위한 항이다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
    x = 1
    o = -1
    
    def is_terminal_node(self):
        return self.state.is_game_over()
    
    def is_game_over(self):
        return self.game_result is not None
    
    @property
    def game_result(self):
        # check if game is over
        for i in range(self.board_size - self.win + 1):
            rowsum = np.sum(self.board[i:i+self.win], 0)
            colsum = np.sum(self.board[:,i:i+self.win], 1)
            if rowsum.max() == self.win or colsum.max() == self.win:
                return self.x
            if rowsum.min() == -self.win or colsum.min() == -self.win:
                return self.o
        for i in range(self.board_size - self.win + 1):
            for j in range(self.board_size - self.win + 1):
                sub = self.board[i:i+self.win,j:j+self.win]
                diag_sum_tl = sub.trace()
                diag_sum_tr = sub[::-1].trace()
                if diag_sum_tl == self.win or diag_sum_tr == self.win:
                    return self.x
                if diag_sum_tl == -self.win or diag_sum_tr == -self.win:
                    return self.o

Selction과정 중 Terminal 조건에 걸리면 현 노드를 반환한다.

게임이 종료에 대한 판단은 위와 같이 연속적으로 놓인 패의 개수를 확인하여 판단하며 누가 이겼는지 결과를 반환한다.

Simulation

시뮬레이션 과정을 살펴보기 위해서 rollout 함수로 들어가보자.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
    def rollout(self):
        current_rollout_state = self.state
        while not current_rollout_state.is_game_over():
            possible_moves = current_rollout_state.get_legal_actions()
            action = self.rollout_policy(possible_moves)
            current_rollout_state = current_rollout_state.move(action)
        return current_rollout_state.game_result
    
    def get_legal_actions(self):
        # indices = np.where(np.count_nonzero(self.board,axis=1) != self.board_size)[0]
        # actions = [TicTacToeMove(i, np.count_nonzero(self.board[i,:]), self.next_to_move) for i in indices]
        indices = np.where(self.board == 0)
        actions = [TicTacToeMove(coords[0], coords[1], self.next_to_move) for coords in list(zip(indices[0], indices[1]))]
        return actions
    
    def rollout_policy(self, possible_moves):
        return possible_moves[np.random.randint(len(possible_moves))]

만약 현 노드가 게임종료된 노드이면 시뮬레이션 할 필요없이 종료 결과를 반환한다.

종료된 상태가 아니라면 현 노드에서 가능한 모든 행동중 하나를 임의로 뽑아서 종료시점까지 rollout하며 이동하고 결과를 반환한다.

오픈소스에서는 get_legal_actions을 찾을 때 Seach space를 줄이기 위해서 column에 대한 traving 조건을 항상 오름차순으로 되도록 했다.

그렇게 하면 column방향의 탐색이 좁아져서 성능이 저하되므로 수정했고 환경에서 move 조건을 아래와 같이 수정해야 했다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
    def is_move_legal(self, move):
        # check if correct player moves
        if move.value != self.next_to_move:
            return False

        # check if inside the board on x-axis
        x_in_range = (0 <= move.x_coordinate < self.board_size)
        if not x_in_range:
            return False

        # check if inside the board on y-axis
        y_in_range = (0 <= move.y_coordinate < self.board_size)
        if not y_in_range:
            return False

        # finally check if board field not occupied yet
        not_occ = self.board[move.x_coordinate, move.y_coordinate] == 0
        # not_occ = not_occ or (move.y_coordinate == 0 or self.board[move.x_coordinate, move.y_coordinate-1] != 0)
        return not_occ

위와 같이 수정해야 탐색영역을 다양한 방식으로 탐색하고 5*5 판에서도 항상 무승부가 나도록 된다.

Back Propagation

1
2
3
4
5
6
    def backpropagate(self, result):
        self._number_of_visits += 1.
        self._results[result] += 1.
        if self.parent:
            self.parent.backpropagate(result)

시뮬레이션 결과에 대해서 back propation 을 통해서 얻은 정보들이 업데이트 된다.

UCT를 사용하기 위한 visitation count 와 win count를 시뮬레이션에서 방문했던 모든 노드를 역순으로 돌면서 업데이트 한다.

1
    return self.root.best_child(c_param=0.)

최종적으로 모든 정보가 업데이트 되고나면 c_param=0 으로 설정하여 selection을 위한 exploration 부분을 떼어네고 최적 행동만 선택하여 경로를 반환한다.

고찰

MCTS 는 Tree를 만들어 나가면서 동시에 최선의 수를 찾아가는 방법이다.

환경에 대한 모델이 없이도 방문에 대한 frequecy와 얻어지는 정보를 기반으로 전체 노드를 업데이트 할 수 있는 방법이다.

하지만 시뮬레이션을 여러번 해봐야 정보를 얻을 수 있기 때문에서 실제 문제에 적용이 어려울 수 있다.

MCTS를 문제에 적용하기 위해서는 언제 시뮬레이션을 할 지와 어떻게 시뮬레이션을 할지가 중요하다.

정석이 Leaf 노드에 도달하면 그 노드에 대해서 expansion을 하고 k번 시뮬레이션을 하는 것이지만,

이것을 변경하여 멀티 쓰레드를 사용하면 효과적으로 병렬 시뮬레이션을 적용할 수도 있다.

어디서 쓰레드를 분기하느냐에 따라서 다르게 적용 할 수 있다.

Leaf 노드부터, Root 노드로 부터, 그리고 전제 노드로 부터 3가지 경우가 있다.

다음에는 MCTS를 사용하여 자율주행에 어떻게 적용하는지 알아보자.

Tags:

Categories:

Updated:

Leave a comment