Training a DQN agent to play Connect Four
The development of a zero-knowledge application called zk Connect Four provided the right context for implementing a reinforcement learning agent.
Reinforcement learning is a machine learning paradigm in which an agent learns from its environment through trial and error. That is, it is a computational approach to learning from actions.
The goal is to learn which actions are most appropriate in each game state in order to maximise a reward. To this end, a DQN agent and a double DQN (DDQN) agent are implemented in this notebook to optimize an approximator for the Connect Four game using PyTorch.
Setup
Fundamental concepts
The main actors in the reinforcement learning paradigm are the environment and the agent. The environment is the context in which the agent lives and with which it interacts.
In each iteration, the agent receives an observation of the current state of the game, decides the best action to take, and acts on the environment.
In this way, a change in the state of the environment is generated and a reward is obtained from this transition.
The objective of this exercise is to obtain a policy that is capable of maximising this reward, thus obtaining the best action for each transition and promoting an optimal approach to the Connect Four game.
Below we introduce the main classes that underpin this paradigm and that will allow us to establish the actors and the context of the game.
Environment
The environment implementation consists of a class whose main methods and tasks are the following:
- get_valid_actions: provides a list of all valid actions given the current state, i.e., the board at that moment in the game.
- step: performs a transition with the action received from the agent (an index of a column on the board), checks whether the game has been won or tied, and returns a reward depending on the result.
- switch_turn: changes the game turn between both players.
- reset: initializes the state and all game variables and returns an initial state.
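To make this concrete, here is a minimal sketch of such an environment, assuming a 6x7 NumPy board where 0 marks an empty cell and 1 and 2 mark each player's pieces. The method names follow the list above, while the reward keys and the (state, reward, done) return signature of step are assumptions, not the notebook's exact implementation.

import numpy as np

class ConnectFourEnv:
    """Minimal Connect Four environment sketch (6 rows x 7 columns)."""

    def __init__(self, params):
        # e.g. params = {'rewards': {'win': 1, 'loss': -1, 'tie': 0, 'prolongation': 0}}
        self.params = params
        self.reset()

    def reset(self):
        # Empty board, player 1 starts; returns the initial state.
        self.board = np.zeros((6, 7), dtype=np.int8)
        self.turn = 1
        return self.board.copy()

    def get_valid_actions(self):
        # A column is valid while its top cell is still empty.
        return [col for col in range(7) if self.board[0, col] == 0]

    def switch_turn(self):
        self.turn = 2 if self.turn == 1 else 1

    def step(self, action):
        # Drop a piece of the current player in the chosen column.
        row = max(r for r in range(6) if self.board[r, action] == 0)
        self.board[row, action] = self.turn
        if self._is_winning_move(row, action):
            return self.board.copy(), self.params['rewards']['win'], True
        if not self.get_valid_actions():
            return self.board.copy(), self.params['rewards']['tie'], True
        return self.board.copy(), self.params['rewards']['prolongation'], False

    def _is_winning_move(self, row, col):
        # Count consecutive pieces through the last move in the four line directions.
        player = self.board[row, col]
        for dr, dc in ((0, 1), (1, 0), (1, 1), (1, -1)):
            count = 1
            for sign in (1, -1):
                r, c = row + sign * dr, col + sign * dc
                while 0 <= r < 6 and 0 <= c < 7 and self.board[r, c] == player:
                    count += 1
                    r, c = r + sign * dr, c + sign * dc
            if count >= 4:
                return True
        return False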
DQN and DDQN agents
DQN and DDQN are essentially algorithms that learn an approximator for an action-value function Q.
In our case we will try to learn the strategies of both Connect Four players, so rather than an agent trained on a single strategy, this is a multi-agent reinforcement learning (MARL) setting.
This presents an alternating Markov game, i.e., a zero-sum game in which one player's gain is the other's loss.
The class used to instantiate the agents implements the methods described below.
In each iteration, the agent is asked to provide an action according to an epsilon-greedy strategy, which addresses the exploration-exploitation dilemma.
valid_actions = env.get_valid_actions()
action = agent.act(state, valid_actions,
                   num_steps, enforce_valid_action)
Sometimes our model will predict the action to take, and other times an action will be sampled at random.
As can be seen in the graph below, the probability of exploring new actions rather than exploiting learnt ones decays exponentially over the iterations, governed by the decay rate eps_decay.
# epsilon decay graph
import numpy as np
import matplotlib.pyplot as plt

# num episodes times avg steps per episode
num_steps = 50000 * 25
eps_start = 0.9
eps_end = 0.05
eps_decay = num_steps / 7.5

steps = np.arange(num_steps)
epsilon = eps_end + (eps_start - eps_end) * np.exp(-1 * steps / eps_decay)

plt.rcParams["figure.figsize"] = (20, 4)
plt.plot(steps, epsilon)
plt.title('Epsilon decay')
plt.xlabel('Step')
plt.ylabel('Epsilon')
plt.show()

This strategy is implemented in the act method of the DQN agent below.
If a uniformly sampled value is greater than the epsilon threshold, the best action given the last observation of the environment is approximated (exploitation of learnt actions).
Otherwise, an action is chosen uniformly at random from among the valid actions (exploration).
The enforce_valid_action option either guarantees that the chosen action is valid, or urges the model to learn which columns of the board can accept a new piece on each turn, i.e., to learn this rule of the game.
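The same logic can be sketched as a standalone function (this is not the agent's exact act method; state_tensor is assumed to be the two-channel board described later, and the epsilon parameters come from the agent's configuration):

import math
import random
import torch

def epsilon_greedy_action(policy_net, state_tensor, valid_actions, num_steps,
                          eps_start, eps_end, eps_decay,
                          enforce_valid_action=True, device='cpu'):
    # Exponentially decaying threshold, as plotted above.
    eps_threshold = eps_end + (eps_start - eps_end) * math.exp(-num_steps / eps_decay)
    if random.random() > eps_threshold:
        # Exploitation: let the policy network approximate the best action.
        with torch.no_grad():
            q_values = policy_net(state_tensor.unsqueeze(0).to(device)).squeeze(0)
        if enforce_valid_action:
            # Restrict the argmax to the legal columns only.
            valid_q = q_values[valid_actions]
            return valid_actions[int(valid_q.argmax().item())]
        return int(q_values.argmax().item())
    # Exploration: sample uniformly among the valid actions.
    return random.choice(valid_actions)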
The data used to train the agent is generated from scratch during each training run. It is stored in an object of the ExperienceReplay class and sampled at random with the aim of optimizing a strategy or action plan (policy).
The observations of each iteration are appended as named tuples (collections.namedtuple) to a deque (double-ended queue) with limited capacity, where observations can be inserted and removed from both ends, thus building up a gaming experience.
Each observation includes a state (the game board at that precise moment), an action (the index of the chosen column), a next state (the game board after applying the action, or None if it is a final state, i.e., if a winning action has been taken) and a reward.
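A sketch of what ExperienceReplay might look like, following the description above (the class name comes from the notebook; push and sample are assumed method names):

import random
from collections import deque, namedtuple

# One observation per iteration: board at t, chosen column, board at t+1 (None if final), reward.
Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward'))

class ExperienceReplay:
    """Fixed-capacity memory of transitions with uniform random sampling."""

    def __init__(self, capacity):
        self.memory = deque([], maxlen=capacity)

    def push(self, *args):
        self.memory.append(Transition(*args))

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)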
Once we have collected and memorized more observations than our batch_size, we proceed to optimize the model.
In our example, optimization happens only once per episode, since the values of some transitions depend on previous transitions within the same game.
For example, when a player wins the game the player gets a positive reward, so for the state transition in this iteration a tuple is generated as follows:
transition = Transition(state, action, None, reward)
The transition would consist of the state in t, the action taken, the state in t+1 (as it is a final state, it is assigned None) and the reward obtained.
What happens is that when the observation of the previous iteration was generated, this outcome was not yet known, so the losing player's reward was not yet negative.
For this reason, once each training game ends, these rewards are resolved before being saved in memory.
The transition at t-1 preceding the one above would therefore be:
transition_t-1 = Transition(state, action, next_state, reward + env.params['rewards']['loss'])
The loss reward is added to the already assigned reward, rather than assigned directly, because a prolongation reward can be configured in the game as a hyperparameter.
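A possible way to resolve those rewards at the end of each training game, assuming the episode's transitions are kept in a list before being pushed to memory (resolve_final_rewards is a hypothetical helper, not the notebook's code):

def resolve_final_rewards(episode_transitions, env):
    # If the last transition is a win (final state), the previous transition
    # belongs to the loser: add the loss reward on top of its assigned reward.
    last = episode_transitions[-1]
    if last.next_state is None and last.reward == env.params['rewards']['win']:
        prev = episode_transitions[-2]
        episode_transitions[-2] = prev._replace(
            reward=prev.reward + env.params['rewards']['loss'])
    return episode_transitions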
The optimization process is shown below:
First, a number of observations are sampled from memory in the optimize method, and all state and action values of the observations are concatenated into separate arrays.
In addition, the states at t and t+1 gain a new dimension, since a board transformation is applied using the get_two_channels function.
The idea is that from each board with values 0, 1 or 2, two matrices with binary values are obtained, one for each player.
These two matrices will be the two input channels per observation that our convolutional neural net will get as input.
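A possible implementation of that transformation (a sketch; the original get_two_channels may differ in details such as dtype or batch handling):

import numpy as np
import torch

def get_two_channels(board):
    # Two binary matrices, one per player, stacked as input channels: shape (2, 6, 7).
    board = np.asarray(board)
    player_one = (board == 1).astype(np.float32)
    player_two = (board == 2).astype(np.float32)
    return torch.from_numpy(np.stack((player_one, player_two)))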
A logical array is then generated that will be used as a mask to assign Q values to non-final states, while final states will have a value of zero.
When a DQN agent is used, only the target model is used to obtain the Q values, taking the maximum value over all actions.
With a DDQN agent, both our policy model and our "fixed" target model are used to obtain the temporal-difference Q values that we want to approach, in order to mitigate a possible overestimation of these values.
Subsequently, the long-term return that can be anticipated is calculated. For this, a discount factor gamma is applied and the discounted value is subtracted from the reward for that transition (rather than added as in the Bellman equation, as suggested in this Kaggle notebook).
In this way we adapt our optimization for the learning of both strategies to the alternating zero-sum game at hand.
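A sketch of how those target values could be computed for a sampled batch (names are assumptions: rewards has shape (batch,), non_final_next_states holds the two-channel boards of the non-final states, and non_final_mask is the logical array mentioned above):

import torch

def compute_expected_values(rewards, non_final_next_states, non_final_mask,
                            policy_net, target_net, gamma, double=False):
    # Q values of next states; final states keep a value of zero through the mask.
    next_q = torch.zeros_like(rewards)
    with torch.no_grad():
        if double:
            # DDQN: the policy network selects the action, the target network evaluates it.
            best_actions = policy_net(non_final_next_states).argmax(dim=1, keepdim=True)
            next_q[non_final_mask] = target_net(non_final_next_states).gather(
                1, best_actions).squeeze(1)
        else:
            # DQN: the target network alone provides the maximum Q value over actions.
            next_q[non_final_mask] = target_net(non_final_next_states).max(dim=1).values
    # Alternating zero-sum game: the discounted value of the opponent's best reply
    # is subtracted from the reward instead of added as in the usual Bellman update.
    return rewards - gamma * next_q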
It is at this moment that the parameters of our policy are optimized. To do this, the loss is calculated (in our case with the torch.nn.HuberLoss criterion) to minimize the temporal-difference error δ, the gradients are reset to zero and then computed via backpropagation, gradient clipping is applied if the training has been configured that way, and the parameters are updated based on the computed gradients.
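The parameter update itself could be sketched as follows (state_action_values are the Q values predicted by the policy network for the sampled actions; the clipping value is an optional training hyperparameter):

import torch

def optimization_step(policy_net, optimizer, state_action_values,
                      expected_values, clip_grad_value=None):
    # One gradient step minimizing the temporal-difference error.
    criterion = torch.nn.HuberLoss()
    loss = criterion(state_action_values, expected_values)
    optimizer.zero_grad()          # reset gradients
    loss.backward()                # backpropagation
    if clip_grad_value is not None:
        # Optional gradient clipping, if the training was configured that way.
        torch.nn.utils.clip_grad_value_(policy_net.parameters(), clip_grad_value)
    optimizer.step()               # update the parameters
    return loss.item()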
Once the policy parameters have been optimized during training, the target parameters are updated. Two strategies are implemented for this: soft update and hard update.
The former gradually updates the parameters in each episode, taking only a fraction of the policy parameters given by the TAU hyperparameter.
For example, if TAU is 0.001, target will take 0.1% of the policy parameters and keep the remaining 99.9%.
With a hard update, by contrast, all parameters are copied at once every time a configurable number of iterations is reached.
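Both strategies can be sketched like this (TAU and the hard-update period are training hyperparameters; the function names are illustrative):

def soft_update(policy_net, target_net, tau):
    # Blend a fraction tau of the policy parameters into the target parameters.
    target_state = target_net.state_dict()
    policy_state = policy_net.state_dict()
    for key in policy_state:
        target_state[key] = tau * policy_state[key] + (1.0 - tau) * target_state[key]
    target_net.load_state_dict(target_state)

def hard_update(policy_net, target_net, step, period):
    # Copy all policy parameters into the target every `period` steps.
    if step % period == 0:
        target_net.load_state_dict(policy_net.state_dict())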
Finally, three methods are implemented: to save transitions in memory, to save the current policy parameters as checkpoints during training, and to load the weights of a saved model at the beginning of a training run.
Note that the optimizer is instantiated once the weights of the previously saved model have been loaded.
Modeling
The model chosen to learn the Connect Four game consists of two convolutional layers with ReLU activation and two fully connected layers (the first also with ReLU activation) at the output.
Within the same ConnectFourNet class, two modules are instantiated: one for policy and another for target. The latter is copied from the former once its parameters (weights and biases) have been initialized.
In this way we can save the parameters of both models during training in the same state dictionary.
ConnectFourNet(
(policy): Sequential(
(0): Conv2d(2, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU()
(2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(3): ReLU()
(4): Flatten(start_dim=1, end_dim=-1)
(5): Linear(in_features=672, out_features=128, bias=True)
(6): ReLU()
(7): Linear(in_features=128, out_features=7, bias=True)
)
(target): Sequential(
(0): Conv2d(2, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU()
(2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(3): ReLU()
(4): Flatten(start_dim=1, end_dim=-1)
(5): Linear(in_features=672, out_features=128, bias=True)
(6): ReLU()
(7): Linear(in_features=128, out_features=7, bias=True)
)
)
ConnectFourNet(out_features=7).layer_summary([512, 2, 6, 7])
Conv2d output shape: torch.Size([512, 16, 6, 7])
Conv2d output shape: torch.Size([512, 16, 6, 7])
Flatten output shape: torch.Size([512, 672])
Linear output shape: torch.Size([512, 128])
Linear output shape: torch.Size([512, 7])
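For reference, a sketch of how such a module could be defined to produce the architecture printed above (the layer_summary helper of the original code is omitted, and the forward method is an assumption):

import copy
import torch.nn as nn

class ConnectFourNet(nn.Module):
    """Two convolutional layers plus two fully connected layers; policy and target copies."""

    def __init__(self, out_features):
        super().__init__()
        self.policy = nn.Sequential(
            nn.Conv2d(2, 16, kernel_size=3, stride=1, padding=1), nn.ReLU(),
            nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1), nn.ReLU(),
            nn.Flatten(),
            nn.Linear(16 * 6 * 7, 128), nn.ReLU(),
            nn.Linear(128, out_features),
        )
        # The target starts as an exact copy of the initialized policy parameters.
        self.target = copy.deepcopy(self.policy)

    def forward(self, x):
        return self.policy(x)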
Training
As previously noted, the training of the game policy combines the strategies of both players. To do this, we experiment with the use of the DQN and DDQN agents.
DQN
First of all, we configure all the parameters, not only for training but also for evaluation, the environment and the agent, and with these parameters we instantiate the module with the neural networks, the agent and the environment.
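The instantiation presumably mirrors the DDQN case shown further below, with the double option disabled (a sketch based on that later cell; setting params_agent['double'] = False is an assumption):

# net
dqn_net = ConnectFourNet(out_features=params_env['action_space']).to(device)

# agent
params_agent['double'] = False
dqn_agent = DQNAgent(
    net=dqn_net,
    params=params_agent,
    device=device,
    load_model_path=load_model_path
)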
Once all the entities and parameters that are needed have been instantiated, we can start training the agent to optimize the policy.
The train function periodically prints iteration metrics. Some values are averages over the current and previous iterations, in order to make trends during training easier to spot.
In each episode, the policy model is evaluated by playing against a random agent.
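Stripped of metrics, evaluation and checkpointing, the core of each training episode can be sketched roughly as follows (memorize, memory and update_target are assumed names; Transition and resolve_final_rewards refer to the sketches above):

def train_sketch(agent, env, num_episodes, batch_size):
    # Heavily simplified episode loop; metrics, evaluation and checkpoints are omitted.
    num_steps = 0
    for episode in range(num_episodes):
        state = env.reset()
        episode_transitions = []
        done = False
        while not done:
            valid_actions = env.get_valid_actions()
            action = agent.act(state, valid_actions, num_steps, True)
            next_state, reward, done = env.step(action)
            # Final states are stored with next_state = None.
            episode_transitions.append(
                Transition(state, action, None if done else next_state, reward))
            state = next_state
            env.switch_turn()
            num_steps += 1
        # Resolve the loser's reward, then push the whole game into memory.
        for transition in resolve_final_rewards(episode_transitions, env):
            agent.memorize(transition)
        # Optimize once per episode once enough observations have been memorized,
        # then update the target parameters (soft or hard update).
        if len(agent.memory) > batch_size:
            agent.optimize(batch_size)
            agent.update_target()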
# train policy
dqn_args = train(
agent=dqn_agent,
env=env,
params_train=params_train,
params_eval=params_eval,
checkpoints_dir_path=CHECKPOINTS_DIR_PATH,
device=device
)
Training policy in ConnectFourNet.
Episode Step Train rewards (avg) steps (avg) running loss (avg) Eval reward (mean std) win rate(%) Eps LR Time
999 20467 0.0 (-0.04) 19 (19.84) 0.0464 (0.0459) 0.46 0.7269 60.0 0.8018 0.01 0.02s
1999 40734 0.0 (-0.04) 18 (17.0) 0.0502 (0.0453) 0.65 0.5723 70.0 0.7157 0.01 0.02s
2999 60847 -1.0 (-0.08) 32 (23.28) 0.055 (0.048) 0.77 0.4659 79.0 0.64 0.01 0.03s
3999 80435 0.0 (0.0) 23 (21.64) 0.0513 (0.0467) 0.79 0.4312 80.0 0.5746 0.01 0.03s
4999 100228 0.0 (-0.08) 28 (20.48) 0.0469 (0.0487) 0.81 0.4403 83.0 0.5159 0.01 0.02s
5999 119836 0.0 (-0.04) 12 (18.96) 0.0389 (0.0475) 0.93 0.2551 93.0 0.4641 0.01 0.04s
6999 139392 0.0 (-0.04) 7 (18.04) 0.0413 (0.0471) 0.89 0.3434 90.0 0.4183 0.01 0.03s
7999 159565 0.0 (-0.04) 31 (21.44) 0.0488 (0.0489) 0.94 0.2375 94.0 0.3763 0.01 0.05s
8999 179803 0.0 (-0.08) 24 (20.24) 0.0483 (0.0486) 0.93 0.2551 93.0 0.339 0.01 0.04s
9999 199816 0.0 (0.0) 26 (21.68) 0.0533 (0.0458) 0.89 0.3434 90.0 0.3063 0.01 0.05s
10999 220056 0.0 (-0.04) 24 (19.76) 0.0542 (0.0464) 0.87 0.3648 88.0 0.277 0.01 0.06s
11999 240844 0.0 (-0.08) 24 (21.76) 0.0496 (0.0464) 0.89 0.3129 89.0 0.2504 0.01 0.04s
12999 261341 0.0 (0.0) 11 (21.28) 0.0445 (0.0462) 0.92 0.2713 92.0 0.2272 0.01 0.04s
13999 282521 0.0 (-0.04) 22 (20.56) 0.0517 (0.049) 0.9 0.3 90.0 0.206 0.01 0.04s
14999 304168 0.0 (0.0) 28 (23.28) 0.0407 (0.0456) 0.91 0.2862 91.0 0.187 0.01 0.05s
15999 325292 0.0 (-0.08) 7 (22.2) 0.0489 (0.0465) 0.93 0.2551 93.0 0.1707 0.01 0.05s
16999 346287 0.0 (-0.04) 16 (22.32) 0.0503 (0.0474) 0.89 0.3129 89.0 0.1564 0.01 0.05s
17999 367640 0.0 (0.0) 20 (20.0) 0.0466 (0.0471) 0.92 0.2713 92.0 0.1436 0.01 0.04s
18999 388744 0.0 (-0.04) 22 (20.68) 0.0484 (0.0458) 0.92 0.2713 92.0 0.1325 0.01 0.05s
19999 409492 0.0 (0.0) 21 (19.24) 0.0464 (0.0458) 0.93 0.2551 93.0 0.1228 0.01 0.04s
Model saved to exports/checkpoints/ConnectFourNet_2023_10_29_T_19_33_58.chkpt at step 409492 with epsilon threshold 0.1228
20999 430888 0.0 (-0.04) 19 (23.44) 0.0566 (0.0453) 0.87 0.3363 87.0 0.1141 0.01 0.05s
21999 452183 0.0 (0.0) 25 (18.68) 0.0498 (0.0447) 0.87 0.3363 87.0 0.1064 0.01 0.04s
22999 473468 0.0 (-0.04) 15 (22.16) 0.0454 (0.0471) 0.95 0.2179 95.0 0.0996 0.01 0.05s
23999 494580 0.0 (-0.04) 34 (20.04) 0.0515 (0.0476) 0.95 0.2179 95.0 0.0937 0.01 0.04s
24999 516624 0.0 (-0.04) 26 (22.6) 0.0508 (0.0452) 0.94 0.2375 94.0 0.0883 0.01 0.05s
25999 538052 0.0 (0.0) 23 (21.6) 0.0535 (0.0453) 0.87 0.3363 87.0 0.0837 0.01 0.04s
26999 560029 0.0 (-0.04) 30 (24.28) 0.0434 (0.0474) 0.92 0.2713 92.0 0.0795 0.01 0.04s
27999 581599 0.0 (0.0) 27 (21.32) 0.0423 (0.0452) 0.84 0.3929 85.0 0.0759 0.01 0.05s
28999 604226 0.0 (-0.04) 26 (25.16) 0.0406 (0.0452) 0.83 0.3756 83.0 0.0726 0.01 0.05s
29999 626930 0.0 (0.0) 21 (21.96) 0.0535 (0.0442) 0.93 0.2551 93.0 0.0698 0.01 0.04s
30999 649048 0.0 (-0.04) 32 (20.72) 0.0378 (0.0451) 0.87 0.3648 88.0 0.0673 0.01 0.06s
31999 670741 0.0 (0.0) 23 (21.4) 0.0382 (0.0449) 0.9 0.3 90.0 0.0652 0.01 0.05s
32999 693199 0.0 (0.0) 13 (20.52) 0.0399 (0.0448) 0.94 0.2375 94.0 0.0633 0.01 0.05s
33999 715469 0.0 (0.0) 18 (21.52) 0.0429 (0.0475) 0.91 0.2862 91.0 0.0616 0.01 0.05s
34999 737025 0.0 (0.0) 26 (21.68) 0.0523 (0.0458) 0.86 0.347 86.0 0.0602 0.01 0.05s
35999 759430 0.0 (-0.04) 21 (24.16) 0.0389 (0.0452) 0.93 0.2551 93.0 0.0589 0.01 0.05s
36999 782127 0.0 (0.0) 22 (18.28) 0.053 (0.0465) 0.92 0.2713 92.0 0.0578 0.01 0.05s
37999 804640 0.0 (0.0) 12 (22.24) 0.0516 (0.0458) 0.92 0.2713 92.0 0.0568 0.01 0.05s
38999 827212 0.0 (-0.08) 18 (22.08) 0.0368 (0.0437) 0.92 0.2713 92.0 0.0559 0.01 0.05s
39999 850050 0.0 (0.0) 21 (20.36) 0.038 (0.0458) 0.88 0.325 88.0 0.0552 0.01 0.05s
Model saved to exports/checkpoints/ConnectFourNet_2023_10_29_T_20_09_30.chkpt at step 850050 with epsilon threshold 0.0552
40999 872524 0.0 (0.0) 31 (22.92) 0.0479 (0.045) 0.95 0.2179 95.0 0.0545 0.002 0.05s
41999 894679 0.0 (-0.04) 23 (21.92) 0.0392 (0.045) 0.89 0.3129 89.0 0.054 0.002 0.07s
42999 916012 0.0 (0.0) 19 (23.56) 0.0451 (0.0452) 0.91 0.2862 91.0 0.0535 0.002 0.07s
43999 938584 0.0 (0.0) 27 (24.32) 0.0395 (0.0445) 0.94 0.2375 94.0 0.053 0.002 0.06s
44999 961498 0.0 (0.0) 42 (24.92) 0.0521 (0.0439) 0.86 0.347 86.0 0.0527 0.002 0.06s
45999 984876 0.0 (0.0) 12 (20.08) 0.0415 (0.0442) 0.88 0.325 88.0 0.0523 0.002 0.06s
46999 1007943 0.0 (0.0) 23 (22.36) 0.0515 (0.0459) 0.86 0.347 86.0 0.052 0.002 0.05s
47999 1031514 0.0 (0.0) 21 (20.76) 0.0529 (0.0457) 0.87 0.3363 87.0 0.0517 0.002 0.06s
48999 1055374 0.0 (0.0) 20 (21.8) 0.0342 (0.0456) 0.94 0.2375 94.0 0.0515 0.002 0.06s
49999 1078772 0.0 (0.0) 26 (24.56) 0.0461 (0.0448) 0.89 0.3129 89.0 0.0513 0.002 0.07s
Model saved to exports/checkpoints/ConnectFourNet_2023_10_29_T_20_28_55.chkpt at step 1078772 with epsilon threshold 0.0513
The reward for each episode is usually zero, since the winning reward is added to the losing reward (-1 + 1); however, the average shown in parentheses is sometimes different from zero. This happens when the model chooses an invalid action and receives a reward as if it had lost the game, while no winning reward is assigned to the other player because the next transition never occurs.
As we can see, this average tends to decrease, i.e., the model tends to make fewer mistakes when choosing legal actions.
# plot training results
plot(
*dqn_args,
figures_dir_path=FIGURES_DIR_PATH,
)
Model training and evaluation figure saved to exports/figures/ConnectFourNet_2023_10_29_T_20_28_55.png
The previous graphs confirm what we mentioned above. They also show that the number of iterations per game tends to increase slightly during both training and evaluation.
Longer games are to be expected in training, given that the model keeps playing against itself as it learns.
In evaluation, however, where the model plays against random decision making, one would expect the number of turns needed to win to decrease more sharply over the iterations than the slight reduction observed.
In any case, the model tends to win almost all the games, as we can see in Evaluation rates.
An option to incentivize the model to make games shorter could be to assign a negative reward to params_env['rewards']['prolongation'].
Besides, since the agent is learning a strategy or action policy for a non-stationary scenario that depends on specific actions, the Running loss graph is only used to choose the training batch size hyperparameter (batch_size) and to verify that the gradients neither grow nor shrink dramatically (exploding or vanishing gradients).
We can check below how our policy acts with examples of possible legal boards in key game situations, for illustration purposes only.
First we see if the policy model is able to complete some lines of four to win a game.
# set eval mode
dqn_net.policy.eval()
titles = [
"Horizontal P2", "Horizontal P1", "Vertical P2", "Vertical P1", "Diagonal P2",
"Diagonal P1", "Anti diagonal P2", "Anti diagonal P1"
]
finishes_boards = [board for i, (board, _) in enumerate(
finishes_boards_solutions) if i % 2 != 0]
html = get_html(
policy=dqn_net.policy,
observations=finishes_boards,
titles=titles,
device=device
)
display(HTML(html))

[Board renderings at t and t+1 for each winning scenario: Horizontal P2, Horizontal P1, Vertical P2, Vertical P1, Diagonal P2, Diagonal P1, Anti-diagonal P2, Anti-diagonal P1]
We can see that this model is capable of choosing winning actions horizontally, vertically, diagonally and anti-diagonally; not for both players in every line type, but it does solve different types for each player.
Something similar happens when our approximator is given the opportunity to block a line of four, on the turn just before the opponent's winning transition could happen.
blocks_boards = [board for i, (board, _) in enumerate(
blocks_boards_solutions) if i % 2 != 0]
html = get_html(
policy=dqn_net.policy,
observations=blocks_boards,
titles=titles,
device=device
)
display(HTML(html))

[Board renderings at t and t+1 for each blocking scenario: Horizontal P2, Horizontal P1, Vertical P2, Vertical P1, Diagonal P2, Diagonal P1, Anti-diagonal P2, Anti-diagonal P1]
DDQN
Finally, a double DQN agent was trained on the same exercise, using similar training hyperparameters.
# net
ddqn_net = ConnectFourNet(out_features=params_env['action_space']).to(device)
# agent
params_agent['double'] = True
ddqn_agent = DQNAgent(
net=ddqn_net,
params=params_agent,
device=device,
load_model_path=load_model_path
)
# train policy
ddqn_args = train(
agent=ddqn_agent,
env=env,
params_train=params_train,
params_eval=params_eval,
checkpoints_dir_path=CHECKPOINTS_DIR_PATH,
device=device
)
Training policy in ConnectFourNet.
Episode Step Train rewards (avg) steps (avg) running loss (avg) Eval reward (mean std) win rate(%) Eps LR Time
999 20530 -1.0 (-0.08) 15 (20.2) 0.0469 (0.0457) 0.44 0.7658 61.0 0.8015 0.01 0.02s
1999 40849 0.0 (-0.08) 23 (21.96) 0.0536 (0.044) 0.63 0.627 71.0 0.7152 0.01 0.04s
2999 60557 0.0 (-0.08) 18 (19.4) 0.041 (0.0459) 0.77 0.4659 79.0 0.641 0.01 0.03s
3999 79788 0.0 (0.0) 24 (19.88) 0.0379 (0.047) 0.74 0.5024 77.0 0.5766 0.01 0.03s
4999 99764 0.0 (0.0) 27 (21.72) 0.0397 (0.0461) 0.77 0.4659 79.0 0.5172 0.01 0.04s
5999 118905 0.0 (0.0) 18 (18.0) 0.0575 (0.0478) 0.82 0.3842 82.0 0.4665 0.01 0.04s
6999 138750 0.0 (0.0) 25 (19.12) 0.0433 (0.0478) 0.86 0.347 86.0 0.4197 0.01 0.03s
7999 158652 0.0 (0.0) 7 (19.64) 0.0455 (0.0477) 0.79 0.4312 80.0 0.3781 0.01 0.04s
8999 179199 0.0 (-0.08) 22 (19.48) 0.0327 (0.0484) 0.79 0.4073 79.0 0.34 0.01 0.03s
9999 199498 0.0 (0.0) 18 (18.16) 0.0466 (0.048) 0.87 0.3363 87.0 0.3068 0.01 0.03s
10999 220203 0.0 (0.0) 25 (20.56) 0.0366 (0.0459) 0.85 0.3571 85.0 0.2768 0.01 0.02s
11999 241541 0.0 (0.0) 15 (20.44) 0.0436 (0.0447) 0.89 0.3129 89.0 0.2495 0.01 0.04s
12999 262846 0.0 (-0.04) 16 (20.32) 0.0506 (0.0469) 0.89 0.3129 89.0 0.2256 0.01 0.06s
13999 284491 0.0 (0.0) 36 (22.92) 0.0419 (0.0466) 0.84 0.3666 84.0 0.2042 0.01 0.04s
14999 305991 0.0 (-0.08) 21 (22.44) 0.0566 (0.0452) 0.92 0.2713 92.0 0.1855 0.01 0.05s
15999 327637 0.0 (-0.08) 34 (24.84) 0.0477 (0.046) 0.92 0.3059 93.0 0.169 0.01 0.04s
16999 349786 0.0 (0.0) 20 (23.08) 0.0491 (0.0466) 0.81 0.3923 81.0 0.1542 0.01 0.04s
17999 372238 0.0 (-0.04) 24 (22.0) 0.0486 (0.0472) 0.84 0.3666 84.0 0.1411 0.01 0.05s
18999 394402 0.0 (0.0) 21 (20.84) 0.05 (0.0436) 0.91 0.2862 91.0 0.1297 0.01 0.05s
19999 416747 0.0 (-0.04) 24 (21.16) 0.0405 (0.0454) 0.87 0.3363 87.0 0.1197 0.01 0.04s
Model saved to exports/checkpoints/ConnectFourNet_2023_10_29_T_20_59_16.chkpt at step 416747 with epsilon threshold 0.1197
20999 439501 0.0 (-0.08) 17 (23.36) 0.0449 (0.0444) 0.93 0.2551 93.0 0.1108 0.01 0.04s
21999 462133 0.0 (0.0) 11 (21.0) 0.0531 (0.0469) 0.88 0.325 88.0 0.1031 0.01 0.06s
22999 484962 0.0 (0.0) 31 (24.04) 0.0435 (0.0465) 0.88 0.325 88.0 0.0963 0.01 0.04s
23999 507774 0.0 (0.0) 25 (21.08) 0.0389 (0.0459) 0.89 0.3129 89.0 0.0904 0.01 0.06s
24999 531175 0.0 (0.0) 20 (23.6) 0.0453 (0.0427) 0.89 0.3129 89.0 0.0851 0.01 0.05s
25999 554451 0.0 (0.0) 13 (22.84) 0.0498 (0.0445) 0.91 0.2862 91.0 0.0805 0.01 0.06s
26999 577973 0.0 (0.0) 26 (23.0) 0.0417 (0.0443) 0.87 0.3363 87.0 0.0765 0.01 0.05s
27999 601554 0.0 (0.0) 23 (23.6) 0.0493 (0.0436) 0.84 0.3666 84.0 0.073 0.01 0.05s
28999 624682 0.0 (-0.08) 22 (19.16) 0.0485 (0.0456) 0.93 0.2551 93.0 0.07 0.01 0.04s
29999 648382 0.0 (0.0) 23 (22.48) 0.0519 (0.0447) 0.91 0.2862 91.0 0.0674 0.01 0.05s
30999 671673 0.0 (0.0) 21 (22.6) 0.05 (0.0445) 0.85 0.3571 85.0 0.0651 0.01 0.05s
31999 695409 0.0 (-0.04) 20 (24.52) 0.0426 (0.0436) 0.82 0.3842 82.0 0.0631 0.01 0.06s
32999 719257 0.0 (-0.04) 27 (21.44) 0.0576 (0.0437) 0.79 0.4073 79.0 0.0614 0.01 0.06s
33999 742359 0.0 (0.0) 18 (22.04) 0.0499 (0.0421) 0.87 0.3363 87.0 0.0599 0.01 0.05s
34999 765618 -1.0 (-0.04) 32 (21.84) 0.0401 (0.0437) 0.92 0.2713 92.0 0.0586 0.01 0.08s
35999 789207 0.0 (-0.08) 29 (22.24) 0.0469 (0.0438) 0.87 0.3363 87.0 0.0575 0.01 0.06s
36999 812560 0.0 (0.0) 31 (24.28) 0.0492 (0.0434) 0.75 0.433 75.0 0.0565 0.01 0.05s
37999 836424 0.0 (-0.04) 24 (26.56) 0.0431 (0.0427) 0.82 0.3842 82.0 0.0556 0.01 0.09s
38999 859950 0.0 (0.0) 25 (23.56) 0.0395 (0.0432) 0.87 0.3363 87.0 0.0549 0.01 0.05s
39999 883688 0.0 (0.0) 24 (24.08) 0.0397 (0.0441) 0.83 0.3756 83.0 0.0542 0.01 0.06s
Model saved to exports/checkpoints/ConnectFourNet_2023_10_29_T_21_36_49.chkpt at step 883688 with epsilon threshold 0.0542
40999 907411 0.0 (0.0) 26 (23.44) 0.0444 (0.0457) 0.91 0.2862 91.0 0.0537 0.002 0.05s
41999 932271 0.0 (0.0) 28 (25.68) 0.0358 (0.0416) 0.89 0.3129 89.0 0.0532 0.002 0.07s
42999 957047 0.0 (0.0) 30 (21.28) 0.0382 (0.0402) 0.83 0.3756 83.0 0.0527 0.002 0.07s
43999 981954 0.0 (0.0) 30 (25.92) 0.0426 (0.0436) 0.85 0.3571 85.0 0.0523 0.002 0.07s
44999 1007640 0.0 (-0.04) 20 (26.88) 0.0403 (0.043) 0.92 0.2713 92.0 0.052 0.002 0.08s
45999 1032640 0.0 (0.0) 9 (22.64) 0.0459 (0.0426) 0.91 0.2862 91.0 0.0517 0.002 0.08s
46999 1057268 0.0 (0.0) 25 (25.92) 0.0411 (0.0443) 0.9 0.3 90.0 0.0515 0.002 0.07s
47999 1082648 0.0 (0.0) 28 (25.84) 0.0459 (0.0419) 0.89 0.3129 89.0 0.0513 0.002 0.09s
48999 1107489 0.0 (0.0) 23 (28.04) 0.0403 (0.0428) 0.87 0.3363 87.0 0.0511 0.002 0.07s
49999 1132553 0.0 (0.0) 24 (25.6) 0.0395 (0.041) 0.84 0.3666 84.0 0.051 0.002 0.07s
Model saved to exports/checkpoints/ConnectFourNet_2023_10_29_T_21_57_41.chkpt at step 1132553 with epsilon threshold 0.051
As we can see below, the training results are similar to those of the DQN agent; however, the model's results against random decision making are slightly worse.
# plot training results
plot(
*ddqn_args,
figures_dir_path=FIGURES_DIR_PATH,
)
Model training and evaluation figure saved to exports/figures/ConnectFourNet_2023_10_29_T_21_57_41.png
Again, just for the sake of illustration, we pose some observations of key scenarios in the game to our model to get an idea of how it performs.
# set eval mode
ddqn_net.policy.eval()
html = get_html(
policy=ddqn_net.policy,
observations=finishes_boards,
titles=titles,
device=device
)
display(HTML(html))

[Board renderings at t and t+1 for each winning scenario: Horizontal P2, Horizontal P1, Vertical P2, Vertical P1, Diagonal P2, Diagonal P1, Anti-diagonal P2, Anti-diagonal P1]
As we can see, it is still capable of choosing the best action on the majority of the boards, both to win the game and to block the opposing player from winning it:
html = get_html(
policy=ddqn_net.policy,
observations=blocks_boards,
titles=titles,
device=device
)
display(HTML(html))

[Board renderings at t and t+1 for each blocking scenario: Horizontal P2, Horizontal P1, Vertical P2, Vertical P1, Diagonal P2, Diagonal P1, Anti-diagonal P2, Anti-diagonal P1]
It is possible that the models do not perform as well in prolonged games since during training they receive more observations of boards with fewer counters.
There are many training parameters that could be modified to obtain better generalization: the rewards, the epsilon decay rate, the observation batch size, the memory capacity, the learning rate, or even another neural network architecture with modified or additional convolutional or fully connected layers.
In addition, more modern reinforcement learning algorithms such as proximal policy optimization (PPO) could be implemented, or an observation memory with prioritized sampling (prioritized experience replay) could be used.
TorchRL, a reinforcement learning library for PyTorch, implements these and many other options.
Export
As noted in the introduction to this Jupyter notebook, this model is used in a decentralised zero-knowledge web application.
For this reason, the torch.onnx module is used to convert the native PyTorch computation graph into an ONNX graph so that, using the ONNX Runtime Web JavaScript library, this machine learning model can be deployed in our web application.
# export ONNX version of the policy trained by the DQN agent
_, _, _, model_id = dqn_args
export_onnx(
policy=dqn_agent.net.policy,
policies_dir_path=POLICIES_DIR_PATH,
model_id=model_id,
device=device
)
Model in ONNX format saved to exports/policies/ConnectFourNet_2023_10_29_T_20_28_55.onnx
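Internally, export_onnx presumably relies on torch.onnx.export along these lines (a sketch; the input/output names, dummy input and file path handling are illustrative):

import torch

def export_onnx_sketch(policy, onnx_path, device):
    # Export the policy network with a dummy two-channel 6x7 board as input.
    policy.eval()
    dummy_input = torch.zeros(1, 2, 6, 7, device=device)
    torch.onnx.export(
        policy,
        dummy_input,
        onnx_path,
        input_names=['board'],
        output_names=['q_values'],
    )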