-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathapp.py
More file actions
208 lines (172 loc) · 7.83 KB
/
app.py
File metadata and controls
208 lines (172 loc) · 7.83 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
import streamlit as st
import os
import numpy as np
import torch
import time
from Games.ConnectFour.ConnectFour import ConnectFour
from Games.ConnectFour.ConnectFourNN import ResNet as ConnectFourResNet
from Games.TicTacToe.TicTacToe import TicTacToe
from Games.TicTacToe.TicTacToeNN import ResNet as TicTacToeResNet
from Alpha_MCTS import Alpha_MCTS
# Page-level configuration; Streamlit requires this to be the first st.* call in the script.
st.set_page_config(page_title="AlphaZero UI", layout="centered", page_icon="🎮")

# Custom styling: tall, rounded board buttons with a hover lift effect, and a
# gradient-filled ".main-header" class used for the page title.
_CUSTOM_CSS = """
<style>
div.stButton > button {
height: 80px;
font-size: 30px;
font-weight: bold;
border-radius: 12px;
transition: all 0.3s ease;
}
div.stButton > button:hover {
transform: translateY(-2px);
box-shadow: 0 4px 12px rgba(0,0,0,0.15);
}
.main-header {
text-align: center;
margin-bottom: 2rem;
font-size: 3rem;
font-weight: 800;
background: -webkit-linear-gradient(45deg, #FF4B4B, #FF9090);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
}
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
@st.cache_resource
def load_model(game_name):
    """Build the game object and its ResNet, loading saved weights when available.

    The result is cached by ``st.cache_resource`` keyed on ``game_name``, so the
    network is constructed once per server process rather than on every rerun.

    Args:
        game_name: ``"ConnectFour"`` selects Connect Four; any other value falls
            back to Tic Tac Toe.

    Returns:
        ``(game, model)`` with ``model`` switched to eval mode. If no checkpoint
        can be loaded, a warning is shown and the freshly initialised (untrained)
        network is returned.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # NOTE(review): the (9, 128) arguments are presumably (num_res_blocks,
    # num_hidden) - confirm against the ResNet constructors.
    if game_name == "ConnectFour":
        game = ConnectFour()
        model = ConnectFourResNet(game, 9, 128, device)
    else:
        game = TicTacToe()
        model = TicTacToeResNet(game, 9, 128, device)
    model.eval()

    relative_path = os.path.join("Games", game_name, "models_n_optimizers", "model.pt")
    # Resolve the checkpoint next to this script first, so the app finds its
    # weights no matter which directory `streamlit run` was launched from.
    # Fall back to the working directory (the previous behaviour) otherwise.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    model_path = os.path.join(script_dir, relative_path)
    if not os.path.exists(model_path):
        model_path = os.path.join(os.getcwd(), relative_path)

    if os.path.exists(model_path):
        try:
            # torch.load unpickles the file: only point this at trusted checkpoints.
            model.load_state_dict(torch.load(model_path, map_location=device))
        except Exception as e:
            # Best effort: keep the app usable with an untrained network.
            st.warning(f"Failed to load model from {model_path}: {e}")
    else:
        st.warning(f"Model path does not exist: {model_path}")
    return game, model
def init_state(game_name):
    """Reset the per-session state for a fresh game of *game_name*.

    The human (player 1) moves first. ``board_state`` is left as ``None``;
    callers are expected to populate it via ``game.initialise_state()``.
    """
    fresh_values = {
        "game_name": game_name,
        "board_state": None,
        "player": 1,
        "game_over": False,
        "winner": None,
    }
    for state_key, state_value in fresh_values.items():
        st.session_state[state_key] = state_value
st.sidebar.title("AlphaZero Play")
game_selection = st.sidebar.selectbox("Select Game", ["ConnectFour", "TicTacToe"])

# --- Search hyperparameters exposed to the player ---
st.sidebar.markdown("### Hyperparameters")
no_of_searches = st.sidebar.slider(
    "Number of MCTS Searches",
    min_value=10,
    max_value=20000,
    value=600,
    step=10,
    help="More searches = stronger but slower AI.",
)
exploration_constant = st.sidebar.slider(
    "Exploration Constant (C)",
    min_value=0.1,
    max_value=5.0,
    value=1.0,
    step=0.1,
    help="Higher values favor exploration.",
)
temperature = st.sidebar.slider(
    "Temperature",
    min_value=0.1,
    max_value=2.0,
    value=1.0,
    step=0.1,
    help="Controls exploration during policy evaluation.",
)
adversarial = st.sidebar.checkbox("Adversarial (Zero-Sum)", value=True)
root_randomness = st.sidebar.checkbox("Root Randomness (Dirichlet Noise)", value=False)

# Configuration dict handed to Alpha_MCTS.
mcts_args = dict(
    ADVERSARIAL=adversarial,
    ROOT_RANDOMNESS=root_randomness,
    TEMPERATURE=temperature,
    NO_OF_SEARCHES=no_of_searches,
    EXPLORATION_CONSTANT=exploration_constant,
)
if root_randomness:
    # The Dirichlet sliders are only shown (and their values only passed on)
    # when root noise is enabled. Keyword arguments evaluate left to right,
    # so the epsilon slider is still rendered before the alpha slider.
    mcts_args.update(
        DIRICHLET_EPSILON=st.sidebar.slider("Dirichlet Epsilon", 0.0, 1.0, 0.25),
        DIRICHLET_ALPHA=st.sidebar.slider("Dirichlet Alpha", 0.01, 1.0, 0.3),
    )
# Start a fresh session state on first load, or whenever the selected game changes.
if st.session_state.get("game_name") != game_selection:
    init_state(game_selection)

# Cached network plus a search object rebuilt each run from the current sidebar settings.
game, model = load_model(game_selection)
mcts = Alpha_MCTS(game, mcts_args, model)

# Lazily create the starting board after a (re)initialisation.
if st.session_state.board_state is None:
    st.session_state.board_state = game.initialise_state()
st.sidebar.markdown("---")

# Reset: wipe the session state, lay out a new board and immediately re-render.
reset_requested = st.sidebar.button("Reset Game")
if reset_requested:
    init_state(game_selection)
    st.session_state.board_state = game.initialise_state()
    st.rerun()

_RULES_TEXT = (
    "You are Player 1 (playing first).\n"
    "AlphaZero is Player -1.\n\n"
    "For **Tic Tac Toe**: Click on a cell to place your X.\n"
    "\n"
    "For **Connect Four**: Click the ⬇️ button above a column to drop your piece."
)
st.sidebar.markdown("### Rules")
st.sidebar.info(_RULES_TEXT)
st.markdown(f"<h1 class='main-header'>{game_selection} vs AlphaZero</h1>", unsafe_allow_html=True)

# Winner value -> (Streamlit alert function, message). Any other value
# (0 for a draw) falls through to the informational banner.
_OUTCOME_BANNERS = {
    1: (st.success, "You Won! Amazing job playing against AlphaZero!"),
    -1: (st.error, "AlphaZero Won! Better luck next time!"),
}
if st.session_state.game_over:
    show_banner, banner_text = _OUTCOME_BANNERS.get(
        st.session_state.winner, (st.info, "It's a Draw! Well played.")
    )
    show_banner(banner_text)
def trigger_rerun():
    """Pause briefly, then restart the script so the updated session state is drawn.

    Per the Streamlit docs, ``st.rerun`` ends the current script run immediately,
    so code after a call to this function does not execute.
    """
    pause_seconds = 0.1
    time.sleep(pause_seconds)
    st.rerun()
def _apply_move(action):
    """Play *action* for the side to move and update the game status in session state.

    A copy of the current board is passed to ``game.make_move`` and the result is
    stored as the new ``board_state``. If the move ends the game, ``game_over`` is
    set and ``winner`` becomes the mover when ``know_terminal_value`` reports 1,
    otherwise 0 (draw). If the game continues, the turn passes to the opponent.
    """
    mover = st.session_state.player
    st.session_state.board_state = game.make_move(
        st.session_state.board_state.copy(), action, mover
    )
    is_terminal, value = game.know_terminal_value(st.session_state.board_state, action)
    if is_terminal:
        st.session_state.game_over = True
        st.session_state.winner = mover if value == 1 else 0
    else:
        st.session_state.player = game.get_opponent(mover)


def _flat_valid_moves(board):
    """Return the legal-move mask for *board* as a 1-D sequence indexed by action."""
    valid_moves = game.get_valid_moves(board)
    # Flatten multi-dimensional masks so they line up with flat action indices
    # such as ``row * 3 + col`` used by the Tic Tac Toe buttons.
    if isinstance(valid_moves, np.ndarray) and valid_moves.ndim > 1:
        valid_moves = valid_moves.reshape(-1)
    return valid_moves


def check_ai_move():
    """If it is AlphaZero's turn (player -1) in a live game, search, play its move and rerun."""
    if st.session_state.game_over or st.session_state.player != -1:
        return
    with st.spinner(f"AlphaZero is thinking ({no_of_searches} searches)..."):
        # The search is given the board from the mover's perspective.
        neutral_state = game.change_perspective(
            st.session_state.board_state, st.session_state.player
        )
        probs = np.asarray(mcts.search(neutral_state), dtype=float).reshape(-1)
        # Defensive: never let the AI pick an illegal action, even if the search
        # leaves probability mass on one. Only mask when the mask lines up with
        # the action vector and some legal move has positive probability;
        # otherwise keep the raw argmax as before.
        legal = np.asarray(_flat_valid_moves(st.session_state.board_state), dtype=float).reshape(-1)
        if legal.shape == probs.shape and (probs * legal).max() > 0:
            probs = probs * legal
        action = int(np.argmax(probs))
        _apply_move(action)
    trigger_rerun()


def make_move(action):
    """Play the human's (player 1) *action*, then rerun.

    Clicks made out of turn, after the game has ended, or on an illegal action
    are ignored.
    """
    if st.session_state.game_over or st.session_state.player != 1:
        return
    valid_moves = _flat_valid_moves(st.session_state.board_state)
    if valid_moves[action] == 1:
        _apply_move(action)
        trigger_rerun()
def _render_tictactoe(board):
    """Draw the 3x3 grid as clickable buttons in the middle three of five columns."""
    for r in range(3):
        # The two outer columns are padding that centres the playable cells.
        row_slots = st.columns([1, 1, 1, 1, 1])
        for c in range(3):
            cell = board[r, c]
            if cell == 1:
                label = "❌"
            elif cell == -1:
                label = "⭕"
            else:
                label = " "
            cell_action = r * 3 + c
            with row_slots[c + 1]:
                clicked = st.button(
                    label,
                    key=f"btn_{cell_action}",
                    disabled=st.session_state.game_over or cell != 0 or st.session_state.player != 1,
                    use_container_width=True,
                )
                if clicked:
                    make_move(cell_action)


def _render_connect_four(board):
    """Draw one drop button per column above a 6x7 grid of coloured discs."""
    drop_slots = st.columns(7)
    legal = game.get_valid_moves(board)
    for c in range(7):
        with drop_slots[c]:
            clicked = st.button(
                "⬇️",
                key=f"drop_{c}",
                disabled=st.session_state.game_over or legal[c] == 0 or st.session_state.player != 1,
                use_container_width=True,
            )
            if clicked:
                make_move(c)
    st.markdown("---")
    discs = {1: "🔴", -1: "🟡", 0: "⚫"}
    for r in range(6):
        disc_slots = st.columns(7)
        for c in range(7):
            disc = discs[board[r, c]]
            with disc_slots[c]:
                st.markdown(f"<div style='text-align:center; font-size:40px;'>{disc}</div>", unsafe_allow_html=True)


container = st.container()
with container:
    if game_selection == "TicTacToe":
        _render_tictactoe(st.session_state.board_state)
    elif game_selection == "ConnectFour":
        _render_connect_four(st.session_state.board_state)
# After the board is drawn, hand the turn to AlphaZero while the game is still live.
ai_to_move = not st.session_state.game_over and st.session_state.player == -1
if ai_to_move:
    check_ai_move()