snoop2head commited on
Commit
0de5849
1 Parent(s): 482e39a

initial commit

Browse files
Files changed (3) hide show
  1. ai.py +39 -0
  2. app.py +399 -0
  3. requirements.txt +6 -0
ai.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ from transformers import GPT2LMHeadModel
4
+
5
+
6
+ def load_model(model_name: str = "snoop2head/Gomoku-GPT2") -> GPT2LMHeadModel:
7
+ gpt2 = GPT2LMHeadModel.from_pretrained(model_name)
8
+ return gpt2
9
+
10
+
11
+ BOS_TOKEN_ID = 401
12
+ PAD_TOKEN_ID = 402
13
+ EOS_TOKEN_ID = 403
14
+
15
+
16
+ def generate_gpt2(model: GPT2LMHeadModel, input_ids: torch.LongTensor) -> list:
17
+ """
18
+ input_ids: [batch_size, seq_len] torch.LongTensor
19
+ output_ids: [seq_len] list
20
+ """
21
+ output_ids = model.generate(
22
+ input_ids,
23
+ max_length=128,
24
+ num_beams=5,
25
+ temperature=0.7,
26
+ pad_token_id=PAD_TOKEN_ID,
27
+ eos_token_id=EOS_TOKEN_ID,
28
+ )
29
+ return output_ids.squeeze().tolist()
30
+
31
+
32
+ def change_to_1d_coordinate(board: np.ndarray, x: int, y: int) -> int:
33
+ """change 2d coordinate to 1d coordinate"""
34
+ return x * board.shape[1] + y
35
+
36
+
37
+ def change_to_2d_coordinate(board: np.ndarray, coordinate: int) -> tuple:
38
+ """change 1d coordinate to 2d coordinate"""
39
+ return (coordinate // board.shape[1], coordinate % board.shape[1])
app.py ADDED
@@ -0,0 +1,399 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ - This is a simple gomoku game built with Streamlit by TeddyHuang-00 (huang_nan_2019@pku.edu.cn).
3
+ - For Gomoku-GPT2, please refer to Young-Jin Ahn (young_ahn@yonsei.ac.kr).
4
+
5
+ Shared under MIT license
6
+ """
7
+
8
+ import time
9
+ from copy import deepcopy
10
+ from uuid import uuid4
11
+
12
+ import torch
13
+ import numpy as np
14
+ import streamlit as st
15
+ from scipy.signal import convolve
16
+ from streamlit import session_state
17
+ from streamlit_server_state import server_state, server_state_lock
18
+
19
+ from ai import (
20
+ BOS_TOKEN_ID,
21
+ generate_gpt2,
22
+ load_model,
23
+ )
24
+
25
+
26
+ # Utils
27
+ class Room:
28
+ def __init__(self, room_id) -> None:
29
+ self.ROOM_ID = room_id
30
+ self.BOARD = np.zeros(shape=(20, 20), dtype=int)
31
+ self.PLAYER = _BLACK
32
+ self.TURN = self.PLAYER
33
+ self.HISTORY = (0, 0)
34
+ self.WINNER = _BLANK
35
+ self.TIME = time.time()
36
+ self.COORDINATE_1D = [BOS_TOKEN_ID]
37
+
38
+
39
+ gpt2 = load_model()
40
+
41
+
42
+ _BLANK = 0
43
+ _BLACK = 1
44
+ _WHITE = -1
45
+ _PLAYER_SYMBOL = {
46
+ _WHITE: "⚪",
47
+ _BLANK: "➕",
48
+ _BLACK: "⚫",
49
+ }
50
+ _PLAYER_COLOR = {
51
+ _WHITE: "Gomoku-GPT",
52
+ _BLANK: "Blank",
53
+ _BLACK: "YOU HUMAN",
54
+ }
55
+ _HORIZONTAL = np.array(
56
+ [
57
+ [0, 0, 0, 0, 0],
58
+ [0, 0, 0, 0, 0],
59
+ [1, 1, 1, 1, 1],
60
+ [0, 0, 0, 0, 0],
61
+ [0, 0, 0, 0, 0],
62
+ ]
63
+ )
64
+ _VERTICAL = np.array(
65
+ [
66
+ [0, 0, 1, 0, 0],
67
+ [0, 0, 1, 0, 0],
68
+ [0, 0, 1, 0, 0],
69
+ [0, 0, 1, 0, 0],
70
+ [0, 0, 1, 0, 0],
71
+ ]
72
+ )
73
+ _DIAGONAL_UP_LEFT = np.array(
74
+ [
75
+ [1, 0, 0, 0, 0],
76
+ [0, 1, 0, 0, 0],
77
+ [0, 0, 1, 0, 0],
78
+ [0, 0, 0, 1, 0],
79
+ [0, 0, 0, 0, 1],
80
+ ]
81
+ )
82
+ _DIAGONAL_UP_RIGHT = np.array(
83
+ [
84
+ [0, 0, 0, 0, 1],
85
+ [0, 0, 0, 1, 0],
86
+ [0, 0, 1, 0, 0],
87
+ [0, 1, 0, 0, 0],
88
+ [1, 0, 0, 0, 0],
89
+ ]
90
+ )
91
+
92
+ _ROOM_COLOR = {
93
+ True: _BLACK,
94
+ False: _WHITE,
95
+ }
96
+
97
+ # Initialize the game
98
+ if "ROOM" not in session_state:
99
+ session_state.ROOM = Room("local")
100
+ if "OWNER" not in session_state:
101
+ session_state.OWNER = False
102
+
103
+ # Check server health
104
+ if "ROOMS" not in server_state:
105
+ with server_state_lock["ROOMS"]:
106
+ server_state.ROOMS = {}
107
+
108
+ # # Layout
109
+ # Main
110
+ TITLE = st.empty()
111
+ ROUND_INFO = st.empty()
112
+ BOARD_PLATE = [
113
+ [cell.empty() for cell in st.columns([1 for _ in range(20)])] for _ in range(20)
114
+ ]
115
+ WAIT_FOR_OPPONENT = st.empty()
116
+
117
+ # Sidebar
118
+ SCORE_TAG = st.sidebar.empty()
119
+ SCORE_PLATE = st.sidebar.columns(2)
120
+ PLAY_MODE_INFO = st.sidebar.container()
121
+ MULTIPLAYER_TAG = st.sidebar.empty()
122
+ with st.sidebar.container():
123
+ ANOTHER_ROUND = st.empty()
124
+ RESTART = st.empty()
125
+ EXIT = st.empty()
126
+ GAME_INFO = st.sidebar.container()
127
+
128
+
129
+ # Draw the board
130
+ def gomoku():
131
+ """
132
+ Draw the board.
133
+
134
+ Handle the main logic.
135
+ """
136
+
137
+ # Restart the game
138
+ def restart() -> None:
139
+ """
140
+ Restart the game.
141
+ """
142
+ session_state.ROOM = Room(session_state.ROOM.ROOM_ID)
143
+
144
+ # Continue new round
145
+ def another_round() -> None:
146
+ """
147
+ Continue new round.
148
+ """
149
+ session_state.ROOM = deepcopy(session_state.ROOM)
150
+ session_state.ROOM.BOARD = np.zeros(shape=(20, 20), dtype=int)
151
+ session_state.ROOM.PLAYER = -session_state.ROOM.PLAYER
152
+ session_state.ROOM.TURN = session_state.ROOM.PLAYER
153
+ session_state.ROOM.WINNER = _BLANK
154
+ session_state.ROOM.COORDINATE_1D = [BOS_TOKEN_ID]
155
+
156
+ # Room status sync
157
+ def sync_room() -> bool:
158
+ room_id = session_state.ROOM.ROOM_ID
159
+ if room_id not in server_state.ROOMS.keys():
160
+ session_state.ROOM = Room("local")
161
+ return False
162
+ elif server_state.ROOMS[room_id].TIME == session_state.ROOM.TIME:
163
+ return False
164
+ elif server_state.ROOMS[room_id].TIME < session_state.ROOM.TIME:
165
+ # Only acquire the lock when writing to the server state
166
+ with server_state_lock["ROOMS"]:
167
+ server_rooms = server_state.ROOMS
168
+ server_rooms[room_id] = session_state.ROOM
169
+ server_state.ROOMS = server_rooms
170
+ return True
171
+ else:
172
+ session_state.ROOM = server_state.ROOMS[room_id]
173
+ return True
174
+
175
+ # Check if winner emerge from move
176
+ def check_win() -> int:
177
+ """
178
+ Use convolution to check if any player wins.
179
+ """
180
+ vertical = convolve(
181
+ session_state.ROOM.BOARD,
182
+ _VERTICAL,
183
+ mode="same",
184
+ )
185
+ horizontal = convolve(
186
+ session_state.ROOM.BOARD,
187
+ _HORIZONTAL,
188
+ mode="same",
189
+ )
190
+ diagonal_up_left = convolve(
191
+ session_state.ROOM.BOARD,
192
+ _DIAGONAL_UP_LEFT,
193
+ mode="same",
194
+ )
195
+ diagonal_up_right = convolve(
196
+ session_state.ROOM.BOARD,
197
+ _DIAGONAL_UP_RIGHT,
198
+ mode="same",
199
+ )
200
+ if (
201
+ np.max(
202
+ [
203
+ np.max(vertical),
204
+ np.max(horizontal),
205
+ np.max(diagonal_up_left),
206
+ np.max(diagonal_up_right),
207
+ ]
208
+ )
209
+ == 5 * _BLACK
210
+ ):
211
+ winner = _BLACK
212
+ elif (
213
+ np.min(
214
+ [
215
+ np.min(vertical),
216
+ np.min(horizontal),
217
+ np.min(diagonal_up_left),
218
+ np.min(diagonal_up_right),
219
+ ]
220
+ )
221
+ == 5 * _WHITE
222
+ ):
223
+ winner = _WHITE
224
+ else:
225
+ winner = _BLANK
226
+ return winner
227
+
228
+ # Triggers the board response on click
229
+ def handle_click(x, y):
230
+ """
231
+ Controls whether to pass on / continue current board / may start new round
232
+ """
233
+ if session_state.ROOM.BOARD[x][y] != _BLANK:
234
+ pass
235
+ elif (
236
+ session_state.ROOM.ROOM_ID in server_state.ROOMS.keys()
237
+ and _ROOM_COLOR[session_state.OWNER]
238
+ != server_state.ROOMS[session_state.ROOM.ROOM_ID].TURN
239
+ ):
240
+ sync_room()
241
+
242
+ # normal play situation
243
+ elif session_state.ROOM.WINNER == _BLANK:
244
+ session_state.ROOM = deepcopy(session_state.ROOM)
245
+
246
+ session_state.ROOM.BOARD[x][y] = session_state.ROOM.TURN
247
+ session_state.ROOM.COORDINATE_1D.append(x * 20 + y)
248
+
249
+ session_state.ROOM.TURN = -session_state.ROOM.TURN
250
+ session_state.ROOM.WINNER = check_win()
251
+ session_state.ROOM.HISTORY = (
252
+ session_state.ROOM.HISTORY[0]
253
+ + int(session_state.ROOM.WINNER == _WHITE),
254
+ session_state.ROOM.HISTORY[1]
255
+ + int(session_state.ROOM.WINNER == _BLACK),
256
+ )
257
+ session_state.ROOM.TIME = time.time()
258
+
259
+ # Draw board
260
+ def draw_board(response: bool):
261
+ """construct each buttons for all cells of the board"""
262
+
263
+ if response and session_state.ROOM.TURN == 1: # human turn
264
+ print("Your turn")
265
+ # construction of clickable buttons
266
+ for i, row in enumerate(session_state.ROOM.BOARD):
267
+ for j, cell in enumerate(row):
268
+ BOARD_PLATE[i][j].button(
269
+ _PLAYER_SYMBOL[cell],
270
+ key=f"{i}:{j}",
271
+ on_click=handle_click,
272
+ args=(i, j),
273
+ )
274
+
275
+ elif response and session_state.ROOM.TURN == -1: # AI turn
276
+ print("AI's turn")
277
+ gpt_predictions = generate_gpt2(
278
+ gpt2,
279
+ torch.tensor(session_state.ROOM.COORDINATE_1D).unsqueeze(0),
280
+ )
281
+ print(gpt_predictions)
282
+ gpt_response = gpt_predictions[len(session_state.ROOM.COORDINATE_1D)]
283
+ gpt_i, gpt_j = gpt_response // 20, gpt_response % 20
284
+ print(gpt_i, gpt_j)
285
+ session_state.ROOM.BOARD[gpt_i][gpt_j] = session_state.ROOM.TURN
286
+ session_state.ROOM.COORDINATE_1D.append(gpt_i * 20 + gpt_j)
287
+
288
+ # construction of clickable buttons
289
+ for i, row in enumerate(session_state.ROOM.BOARD):
290
+ for j, cell in enumerate(row):
291
+ if (
292
+ i * 20 + j
293
+ in gpt_predictions[: len(session_state.ROOM.COORDINATE_1D)]
294
+ ):
295
+ # disable click for GPT choices
296
+ BOARD_PLATE[i][j].button(
297
+ _PLAYER_SYMBOL[cell],
298
+ key=f"{i}:{j}",
299
+ on_click=False,
300
+ args=(i, j),
301
+ )
302
+ else:
303
+ # enable click for other cells available for human choices
304
+ BOARD_PLATE[i][j].button(
305
+ _PLAYER_SYMBOL[cell],
306
+ key=f"{i}:{j}",
307
+ on_click=handle_click,
308
+ args=(i, j),
309
+ )
310
+
311
+ # change turn
312
+ session_state.ROOM.TURN = -session_state.ROOM.TURN
313
+ session_state.ROOM.WINNER = check_win()
314
+ session_state.ROOM.HISTORY = (
315
+ session_state.ROOM.HISTORY[0]
316
+ + int(session_state.ROOM.WINNER == _WHITE),
317
+ session_state.ROOM.HISTORY[1]
318
+ + int(session_state.ROOM.WINNER == _BLACK),
319
+ )
320
+ session_state.ROOM.TIME = time.time()
321
+
322
+ if not response or session_state.ROOM.WINNER != _BLANK:
323
+ print("Game over")
324
+ for i, row in enumerate(session_state.ROOM.BOARD):
325
+ for j, cell in enumerate(row):
326
+ BOARD_PLATE[i][j].write(
327
+ _PLAYER_SYMBOL[cell],
328
+ key=f"{i}:{j}",
329
+ )
330
+
331
+ # Game process control
332
+ def game_control():
333
+ if session_state.ROOM.WINNER != _BLANK:
334
+ draw_board(False)
335
+ else:
336
+ draw_board(True)
337
+ if session_state.ROOM.WINNER != _BLANK or 0 not in session_state.ROOM.BOARD:
338
+ ANOTHER_ROUND.button(
339
+ "Play Next round!",
340
+ on_click=another_round,
341
+ help="Clear board and swap first player",
342
+ )
343
+ if session_state.ROOM.ROOM_ID == "local" or session_state.OWNER:
344
+ RESTART.button(
345
+ "Reset",
346
+ on_click=restart,
347
+ help="Clear the board as well as the scores",
348
+ )
349
+
350
+ # Infos
351
+ def draw_info() -> None:
352
+ # Text information
353
+ TITLE.subheader("**🤖 Do you wanna have a bad time?**")
354
+ PLAY_MODE_INFO.write("---\n\n**You are Black, AI is White.**")
355
+ GAME_INFO.markdown(
356
+ """
357
+ ---
358
+
359
+ ## Freestyle Gomoku game.
360
+
361
+
362
+ <a href="https://en.wikipedia.org/wiki/Gomoku#Freestyle_Gomoku" style="color:#FFFFFF">Freestyle Gomoku</a>
363
+
364
+ - no restrictions
365
+ - no regrets
366
+ - swap players after one round is over
367
+
368
+ ##### Design by <a href="https://github.com/TeddyHuang-00" style="color:#FFFFFF">TeddyHuang-00</a> • <a href="https://github.com/TeddyHuang-00/streamlit-gomoku" style="color:#FFFFFF">Github repo</a>
369
+ ##### Gomoku-GPT by <a href="https://github.com/snoop2head" style="color:#FFFFFF">snoop2head</a> • <a href="https://github.com/snoop2head/" style="color:#FFFFFF">Github repo</a>
370
+
371
+ """,
372
+ unsafe_allow_html=True,
373
+ )
374
+ # History scores
375
+ SCORE_TAG.subheader("Scores")
376
+ SCORE_PLATE[0].metric("Gomoku-GPT", session_state.ROOM.HISTORY[0])
377
+ SCORE_PLATE[1].metric("Black", session_state.ROOM.HISTORY[1])
378
+
379
+ # Additional information
380
+ if session_state.ROOM.WINNER != _BLANK:
381
+ st.balloons()
382
+ ROUND_INFO.write(
383
+ f"#### **{_PLAYER_COLOR[session_state.ROOM.WINNER]} wins!**\n**Click buttons on the left for more plays.**"
384
+ )
385
+
386
+ elif 0 not in session_state.ROOM.BOARD:
387
+ ROUND_INFO.write("#### **Tie**")
388
+ else:
389
+ ROUND_INFO.write(
390
+ f"#### **{_PLAYER_SYMBOL[session_state.ROOM.TURN]} {_PLAYER_COLOR[session_state.ROOM.TURN]}'s turn...**"
391
+ )
392
+
393
+ # The main game loop
394
+ game_control()
395
+ draw_info()
396
+
397
+
398
+ if __name__ == "__main__":
399
+ gomoku()
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ numpy==1.23.5
2
+ scipy==1.10.1
3
+ torch==2.0.0
4
+ streamlit==0.89.0
5
+ streamlit_server_state==0.6.1
6
+ transformers==4.27.4