Training a DQN agent to play Connect Four

Open in Colab
View on GitHub

The development of a zero-knowledge application called zk Connect Four provided a fitting context for implementing a reinforcement learning agent.

Reinforcement learning is a machine learning paradigm in which an agent learns from its environment through trial and error. That is, it is a computational approach to learning from actions.

The goal is to learn the most appropriate action in each game state so as to maximise a reward. To this end, a DQN agent and a double DQN (DDQN) agent are implemented in this notebook using PyTorch to learn an action-value approximator for the Connect Four game.


Setup

Fundamental concepts

The main actors in the reinforcement learning paradigm are the environment and the agent. The environment is the context in which the agent lives and with which it interacts.

In each iteration, the agent receives an observation of the current state of the game, decides the best action to take, and acts on the environment.

In this way, a change in the state of the environment is generated and a reward is obtained from this transition.

The objective of this exercise is to obtain a policy capable of maximising this reward, thus obtaining the best action for each transition and approaching optimal play in Connect Four.

Below we introduce the main classes that underpin this paradigm and that will allow us to establish the actors and the context of the game.

Environment

The environment implementation consists of a class whose main methods and tasks are the following (a minimal skeleton is sketched after the list):

  • get_valid_actions: provides a list of all valid actions given the current state, i.e., the board at that moment in the game.

  • step: performs a transition with the action received from the agent (the index of a column on the board), checks whether the game has been won or tied, and returns a reward depending on the result.

  • switch_turn: changes the game turn between the two players.

  • reset: initializes the state and all game variables and returns the initial state.
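A minimal sketch of how such an environment class might look is shown below; the exact signatures, board dimensions and reward values are assumptions for illustration, not the notebook's actual implementation.

import numpy as np

class ConnectFourEnv:
    """Hypothetical minimal Connect Four environment (6x7 board, players 1 and 2)."""

    def __init__(self, rewards=None):
        self.params = {'rewards': rewards or {'win': 1.0, 'loss': -1.0, 'tie': 0.0}}
        self.reset()

    def reset(self):
        # Empty board, player 1 to move.
        self.board = np.zeros((6, 7), dtype=np.int8)
        self.turn = 1
        return self.board.copy()

    def get_valid_actions(self):
        # A column is valid while its top cell is still empty.
        return [c for c in range(7) if self.board[0, c] == 0]

    def switch_turn(self):
        self.turn = 2 if self.turn == 1 else 1

    def step(self, action):
        # Drop a piece in the lowest empty row of the chosen column.
        row = max(r for r in range(6) if self.board[r, action] == 0)
        self.board[row, action] = self.turn
        if self._is_win(row, action):
            return self.board.copy(), self.params['rewards']['win'], True
        if not self.get_valid_actions():
            return self.board.copy(), self.params['rewards']['tie'], True
        return self.board.copy(), 0.0, False

    def _is_win(self, row, col):
        # Check the four line directions through the last move.
        player = self.board[row, col]
        for dr, dc in ((0, 1), (1, 0), (1, 1), (1, -1)):
            count = 1
            for sign in (1, -1):
                r, c = row + sign * dr, col + sign * dc
                while 0 <= r < 6 and 0 <= c < 7 and self.board[r, c] == player:
                    count += 1
                    r, c = r + sign * dr, c + sign * dc
            if count >= 4:
                return True
        return False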

DQN and DDQN agents

DQN and DDQN are essentially algorithms that learn an approximator for the action-value function Q.

In our case, we try to learn the strategies of both players of the Connect Four game, so rather than an agent trained on a single strategy, this is closer to a multi-agent reinforcement learning (MARL) setting.

This gives rise to an alternating Markov game, i.e., a zero-sum game in which one player's gain is the other player's loss.

The class used to instantiate the agents implements the methods described below.

In each iteration, the agent is asked for an action according to an epsilon-greedy strategy, which addresses the exploration-exploitation dilemma.

valid_actions = env.get_valid_actions()

action = agent.act(state, valid_actions,
                    num_steps, enforce_valid_action)

Sometimes our model will predict the action to take, and other times an action will be sampled at random.

As can be seen in the graph below, in each iteration the probability of exploring new actions rather than exploiting learnt ones decays exponentially at a rate governed by eps_decay.

# epsilon decay graph
import numpy as np
import matplotlib.pyplot as plt

# num episodes times avg steps per episode
num_steps = 50000 * 25
eps_start = 0.9
eps_end = 0.05
eps_decay = num_steps / 7.5
steps = np.arange(num_steps)
epsilon = eps_end + (eps_start - eps_end) * np.exp(-1 * steps / eps_decay)
plt.rcParams["figure.figsize"] = (20, 4)
plt.plot(steps, epsilon)
plt.title('Epsilon decay')
plt.xlabel('Step')
plt.ylabel('Epsilon')
plt.show()

This strategy is implemented in the method act of the DQN agent below.

If the sample value is greater than the epsilon threshold, the best action given the last observation of the environment will be approximated (exploitation of learnt actions).

Otherwise, an action will be chosen uniformly at random from among the valid actions (exploration).

Through enforce_valid_action, the option is given either to guarantee that the chosen action is valid, or to push the model to learn which columns of the board can still receive a piece on each turn, i.e., to learn this rule of the game.
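A sketch of how such an epsilon-greedy act method could look is shown below. The attribute names (self.eps_start, self.eps_end, self.eps_decay, self.net.policy, self.action_space) are assumptions based on the surrounding text, and the state is assumed to already be a two-channel tensor with a batch dimension.

import math
import random
import torch

def act(self, state, valid_actions, num_steps, enforce_valid_action=True):
    # (method of the DQN agent class)
    # Exponentially decayed epsilon threshold, as in the decay plot above.
    eps_threshold = self.eps_end + (self.eps_start - self.eps_end) * \
        math.exp(-1.0 * num_steps / self.eps_decay)

    if random.random() > eps_threshold:
        # Exploitation: query the policy network for Q values.
        with torch.no_grad():
            q_values = self.net.policy(state)  # shape: (1, num_columns)
            if enforce_valid_action:
                # Mask out full columns so only legal moves can be chosen.
                mask = torch.full_like(q_values, float('-inf'))
                mask[0, valid_actions] = 0.0
                q_values = q_values + mask
            return int(q_values.argmax(dim=1).item())

    # Exploration: sample uniformly among valid actions, or among all columns
    # if the model is expected to learn the legality rule by itself.
    if enforce_valid_action:
        return random.choice(valid_actions)
    return random.randrange(self.action_space)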

The data used to train the agent is generated from scratch during each training run. It is stored in an object of the ExperienceReplay class and randomly sampled with the aim of optimizing a strategy or action plan (policy).

The observations in each iteration are appended as named tuples (collections.namedtuple) to a deque (double-ended queue) with limited capacity, where observations can be inserted and deleted from both ends, thus building up a playing experience.

Each observation includes a state (the game board at that precise moment), an action (the index of the chosen column), a next state (the game board after applying the last action, or None if it is a final state, i.e., if a winning action has been taken) and a reward.
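A minimal sketch of such a replay memory, using the Transition fields described above (the notebook's actual ExperienceReplay class may differ in its details):

import random
from collections import deque, namedtuple

# One observation: board at t, chosen column, board at t+1 (None if final) and reward.
Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward'))

class ExperienceReplay:
    def __init__(self, capacity):
        # Bounded double-ended queue: the oldest observations are dropped once capacity is reached.
        self.memory = deque([], maxlen=capacity)

    def push(self, *args):
        self.memory.append(Transition(*args))

    def sample(self, batch_size):
        # Uniform random sampling breaks the correlation between consecutive transitions.
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)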

Once more observations than batch_size have been stored in memory, we proceed to optimize the model.

In our example, the model is optimized only once per episode, since the values of some transitions depend on previous transitions.

For example, when a player wins the game the player gets a positive reward, so for the state transition in this iteration a tuple is generated as follows:

transition = Transition(state, action, None, reward)

The transition consists of the state at t, the action taken, the state at t+1 (None, as it is a final state) and the reward obtained.

However, when the observation of the previous iteration was generated, this outcome was not yet known, and therefore its reward was not negative for having lost the game.

For this reason, once each training game ends, the rewards are reconciled before being saved to memory.

Therefore, the transition at t-1 preceding the transition above would become:

transition_t_minus_1 = Transition(state, action, next_state, reward + env.params['rewards']['loss'])

The loss reward is added to the already assigned reward rather than assigned directly, since a prolongation reward can also be configured as a game hyperparameter.
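A sketch of how this end-of-game reconciliation could be performed before pushing the episode's transitions to memory; resolve_episode_rewards and episode_transitions are hypothetical names, and Transition is the named tuple from the replay-memory sketch above.

from collections import namedtuple

Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward'))

def resolve_episode_rewards(episode_transitions, env):
    # If the game ended with a winning move (next_state is None, per the text above),
    # the opponent's previous move led to a lost game: add the loss reward on top of
    # whatever reward (e.g. a prolongation reward) that transition already had.
    last = episode_transitions[-1]
    if last.next_state is None and len(episode_transitions) > 1:
        prev = episode_transitions[-2]
        episode_transitions[-2] = Transition(
            prev.state,
            prev.action,
            prev.next_state,
            prev.reward + env.params['rewards']['loss'],
        )
    return episode_transitions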

The optimization process is shown below:

First, a number of observations are sampled from memory in the optimize method, and the state and action values of the observations are concatenated into separate arrays.

Furthermore, both the states at t and at t+1 gain an extra dimension, since a board transformation is applied with the get_two_channels function.

The idea is that each board with values 0, 1 or 2 is turned into two binary matrices, one per player.

These two matrices are the two input channels per observation fed to our convolutional neural network.
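A sketch of such a transformation, assuming the board encodes empty cells as 0 and the players as 1 and 2 (the notebook's get_two_channels may differ):

import torch

def get_two_channels(boards):
    # Turn boards with values {0, 1, 2} into two binary channels, one per player.
    boards = torch.as_tensor(boards)
    player_one = (boards == 1).float()  # channel 0: player 1 pieces
    player_two = (boards == 2).float()  # channel 1: player 2 pieces
    return torch.stack((player_one, player_two), dim=-3)

With this layout, a batch of boards of shape (512, 6, 7) becomes a tensor of shape (512, 2, 6, 7), matching the input shape used in the layer summary further below.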

A logical array is then generated that will be used as a mask to assign Q values to non-final states, while final states will have a value of zero.

When a DQN agent is used, only the target model is used to obtain the Q values, taking the maximum value over all actions.

With a DDQN agent, on the other hand, both our policy model and our "fixed" target model are used to obtain the temporal-difference targets for Q that we want to approach, which mitigates a possible overestimation of these values.

Subsequently, the anticipated long-term return is calculated: the discount factor gamma is applied to the next-state value, which is subtracted from the reward of that transition (rather than added, as in the Bellman equation), as suggested in this Kaggle notebook.

In this way we adapt our optimization for the learning of both strategies to the alternating zero-sum game at hand.
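A sketch of how these target values could be computed for both variants, subtracting the discounted next-state value as described above; the tensor names and shapes are assumptions.

import torch

@torch.no_grad()
def compute_td_targets(policy_net, target_net, rewards, non_final_next_states,
                       non_final_mask, gamma, double):
    # rewards: (batch,), non_final_mask: boolean (batch,),
    # non_final_next_states: two-channel boards for the non-final states only.
    next_values = torch.zeros(rewards.size(0), device=rewards.device)

    if double:
        # DDQN: the policy network picks the action, the target network evaluates it.
        best_actions = policy_net(non_final_next_states).argmax(dim=1, keepdim=True)
        next_values[non_final_mask] = target_net(
            non_final_next_states).gather(1, best_actions).squeeze(1)
    else:
        # DQN: the target network both picks and evaluates the best action.
        next_values[non_final_mask] = target_net(non_final_next_states).max(dim=1).values

    # Subtract (instead of add) the discounted value: the next state belongs to the opponent.
    return rewards - gamma * next_values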

It is at this moment that the parameters of our policy are optimized. To do this, the loss is calculated (in our case with the torch.nn.HuberLoss criterion) to minimize the temporal-difference error δ, the gradients are zeroed and then computed via backpropagation, gradient clipping is applied if the training has been configured that way, and the parameters are updated based on the computed gradients.
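A sketch of that optimization step, assuming q_values predicted by the policy for the sampled actions and the td_targets computed as above; clip_grad_value is an assumed hyperparameter name.

import torch

def optimization_step(policy_net, optimizer, q_values, td_targets, clip_grad_value=None):
    # Huber loss on the temporal-difference error.
    criterion = torch.nn.HuberLoss()
    loss = criterion(q_values, td_targets)

    optimizer.zero_grad()  # reset gradients from the previous step
    loss.backward()        # backpropagate the temporal-difference error
    if clip_grad_value is not None:
        # Clip gradients in place if the training is configured that way.
        torch.nn.utils.clip_grad_value_(policy_net.parameters(), clip_grad_value)
    optimizer.step()       # update the policy parameters
    return loss.item()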

Once the parameters of the policy have been optimized during training, the parameters of our target are due to be updated. Two strategies are implemented for this: soft update and hard update.

The former updates the parameters gradually in each episode, taking only a fraction of the policy parameters according to the TAU value, a training hyperparameter.

For example, if TAU is 0.001, target will take 0.1% of the parameters of policy and keep the rest (99.9%).

With a hard update, on the other hand, all parameters are copied at once whenever a configurable number of iterations is reached.
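Sketches of both update strategies (TAU and the update interval are hyperparameters; the notebook's implementation may differ):

def soft_update(policy_net, target_net, tau):
    # Blend a fraction tau of the policy weights into the target weights.
    target_state = target_net.state_dict()
    policy_state = policy_net.state_dict()
    for key in target_state:
        target_state[key] = tau * policy_state[key] + (1.0 - tau) * target_state[key]
    target_net.load_state_dict(target_state)

def hard_update(policy_net, target_net, step, update_every):
    # Copy all parameters at once every `update_every` iterations.
    if step % update_every == 0:
        target_net.load_state_dict(policy_net.state_dict())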

Finally, three methods are implemented: one to save transitions to memory, one to save the current policy parameters as checkpoints during training, and one to load the weights of a saved model at the beginning of a training run.

Note that the optimizer is instantiated once the weights of the previously saved model have been loaded.

Modeling

The model chosen to learn the Connect Four game consists of two convolutional layers with ReLU activation and two fully connected layers (the first also with ReLU activation) at the output.

In the same ConnectFourNet object, two modules are instantiated: one for policy and another for target. The latter is copied from the former once its parameters (weights and biases) have been initialized.

In this way we can save the parameters of both models during training in the same state dictionary.
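A sketch of how the two modules could be created and synchronized inside the same nn.Module, consistent with the structure printed below (the layer sizes are taken from that printout; the rest is an assumption):

import copy
import torch.nn as nn

class ConnectFourNet(nn.Module):
    def __init__(self, out_features=7):
        super().__init__()
        self.policy = nn.Sequential(
            nn.Conv2d(2, 16, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(16 * 6 * 7, 128),
            nn.ReLU(),
            nn.Linear(128, out_features),
        )
        # The target starts as an exact copy of the initialized policy; from then on
        # it is only updated via soft/hard updates, never by backpropagation.
        self.target = copy.deepcopy(self.policy)
        for param in self.target.parameters():
            param.requires_grad = False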

ConnectFourNet(
  (policy): Sequential(
    (0): Conv2d(2, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU()
    (4): Flatten(start_dim=1, end_dim=-1)
    (5): Linear(in_features=672, out_features=128, bias=True)
    (6): ReLU()
    (7): Linear(in_features=128, out_features=7, bias=True)
  )
  (target): Sequential(
    (0): Conv2d(2, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU()
    (4): Flatten(start_dim=1, end_dim=-1)
    (5): Linear(in_features=672, out_features=128, bias=True)
    (6): ReLU()
    (7): Linear(in_features=128, out_features=7, bias=True)
  )
)
ConnectFourNet(out_features=7).layer_summary([512, 2, 6, 7])
Conv2d output shape:          torch.Size([512, 16, 6, 7])
Conv2d output shape:          torch.Size([512, 16, 6, 7])
Flatten output shape:         torch.Size([512, 672])
Linear output shape:          torch.Size([512, 128])
Linear output shape:          torch.Size([512, 7])

Training

As previously noted, the training of the game policy combines the strategies of both players. To do this, we experiment with the use of the DQN and DDQN agents.

DQN

First, we configure the parameters not only of the training but also of the evaluation, the environment and the agent, and with these parameters we instantiate the module with the neural networks, the agent and the environment.

Once all the entities and parameters that are needed have been instantiated, we can start training the agent to optimize the policy.

The train function periodically reports iteration metrics. Some values are averages over the current and previous iterations, which makes it easier to recognize trends during training.

In each episode, the policy model is evaluated by playing against a random agent.

# train policy
dqn_args = train(
    agent=dqn_agent,
    env=env,
    params_train=params_train,
    params_eval=params_eval,
    checkpoints_dir_path=CHECKPOINTS_DIR_PATH,
    device=device
)
Training policy in ConnectFourNet.
  Episode     Step   Train rewards (avg)  steps (avg)   running loss (avg)  Eval reward (mean std)  win rate(%)   Eps      LR     Time
    999      20467       0.0 (-0.04)       19 (19.84)    0.0464 (0.0459)         0.46  0.7269          60.0     0.8018   0.01    0.02s
    1999     40734       0.0 (-0.04)       18 (17.0)     0.0502 (0.0453)         0.65  0.5723          70.0     0.7157   0.01    0.02s
    2999     60847       -1.0 (-0.08)      32 (23.28)     0.055 (0.048)          0.77  0.4659          79.0      0.64    0.01    0.03s
    3999     80435        0.0 (0.0)        23 (21.64)    0.0513 (0.0467)         0.79  0.4312          80.0     0.5746   0.01    0.03s
    4999     100228      0.0 (-0.08)       28 (20.48)    0.0469 (0.0487)         0.81  0.4403          83.0     0.5159   0.01    0.02s
    5999     119836      0.0 (-0.04)       12 (18.96)    0.0389 (0.0475)         0.93  0.2551          93.0     0.4641   0.01    0.04s
    6999     139392      0.0 (-0.04)       7 (18.04)     0.0413 (0.0471)         0.89  0.3434          90.0     0.4183   0.01    0.03s
    7999     159565      0.0 (-0.04)       31 (21.44)    0.0488 (0.0489)         0.94  0.2375          94.0     0.3763   0.01    0.05s
    8999     179803      0.0 (-0.08)       24 (20.24)    0.0483 (0.0486)         0.93  0.2551          93.0     0.339    0.01    0.04s
    9999     199816       0.0 (0.0)        26 (21.68)    0.0533 (0.0458)         0.89  0.3434          90.0     0.3063   0.01    0.05s
  10999     220056      0.0 (-0.04)       24 (19.76)    0.0542 (0.0464)         0.87  0.3648          88.0     0.277    0.01    0.06s
  11999     240844      0.0 (-0.08)       24 (21.76)    0.0496 (0.0464)         0.89  0.3129          89.0     0.2504   0.01    0.04s
  12999     261341       0.0 (0.0)        11 (21.28)    0.0445 (0.0462)         0.92  0.2713          92.0     0.2272   0.01    0.04s
  13999     282521      0.0 (-0.04)       22 (20.56)     0.0517 (0.049)           0.9  0.3            90.0     0.206    0.01    0.04s
  14999     304168       0.0 (0.0)        28 (23.28)    0.0407 (0.0456)         0.91  0.2862          91.0     0.187    0.01    0.05s
  15999     325292      0.0 (-0.08)        7 (22.2)     0.0489 (0.0465)         0.93  0.2551          93.0     0.1707   0.01    0.05s
  16999     346287      0.0 (-0.04)       16 (22.32)    0.0503 (0.0474)         0.89  0.3129          89.0     0.1564   0.01    0.05s
  17999     367640       0.0 (0.0)        20 (20.0)     0.0466 (0.0471)         0.92  0.2713          92.0     0.1436   0.01    0.04s
  18999     388744      0.0 (-0.04)       22 (20.68)    0.0484 (0.0458)         0.92  0.2713          92.0     0.1325   0.01    0.05s
  19999     409492       0.0 (0.0)        21 (19.24)    0.0464 (0.0458)         0.93  0.2551          93.0     0.1228   0.01    0.04s
Model saved to exports/checkpoints/ConnectFourNet_2023_10_29_T_19_33_58.chkpt at step 409492 with epsilon threshold 0.1228
  20999     430888      0.0 (-0.04)       19 (23.44)    0.0566 (0.0453)         0.87  0.3363          87.0     0.1141   0.01    0.05s
  21999     452183       0.0 (0.0)        25 (18.68)    0.0498 (0.0447)         0.87  0.3363          87.0     0.1064   0.01    0.04s
  22999     473468      0.0 (-0.04)       15 (22.16)    0.0454 (0.0471)         0.95  0.2179          95.0     0.0996   0.01    0.05s
  23999     494580      0.0 (-0.04)       34 (20.04)    0.0515 (0.0476)         0.95  0.2179          95.0     0.0937   0.01    0.04s
  24999     516624      0.0 (-0.04)       26 (22.6)     0.0508 (0.0452)         0.94  0.2375          94.0     0.0883   0.01    0.05s
  25999     538052       0.0 (0.0)        23 (21.6)     0.0535 (0.0453)         0.87  0.3363          87.0     0.0837   0.01    0.04s
  26999     560029      0.0 (-0.04)       30 (24.28)    0.0434 (0.0474)         0.92  0.2713          92.0     0.0795   0.01    0.04s
  27999     581599       0.0 (0.0)        27 (21.32)    0.0423 (0.0452)         0.84  0.3929          85.0     0.0759   0.01    0.05s
  28999     604226      0.0 (-0.04)       26 (25.16)    0.0406 (0.0452)         0.83  0.3756          83.0     0.0726   0.01    0.05s
  29999     626930       0.0 (0.0)        21 (21.96)    0.0535 (0.0442)         0.93  0.2551          93.0     0.0698   0.01    0.04s
  30999     649048      0.0 (-0.04)       32 (20.72)    0.0378 (0.0451)         0.87  0.3648          88.0     0.0673   0.01    0.06s
  31999     670741       0.0 (0.0)        23 (21.4)     0.0382 (0.0449)           0.9  0.3            90.0     0.0652   0.01    0.05s
  32999     693199       0.0 (0.0)        13 (20.52)    0.0399 (0.0448)         0.94  0.2375          94.0     0.0633   0.01    0.05s
  33999     715469       0.0 (0.0)        18 (21.52)    0.0429 (0.0475)         0.91  0.2862          91.0     0.0616   0.01    0.05s
  34999     737025       0.0 (0.0)        26 (21.68)    0.0523 (0.0458)          0.86  0.347          86.0     0.0602   0.01    0.05s
  35999     759430      0.0 (-0.04)       21 (24.16)    0.0389 (0.0452)         0.93  0.2551          93.0     0.0589   0.01    0.05s
  36999     782127       0.0 (0.0)        22 (18.28)     0.053 (0.0465)         0.92  0.2713          92.0     0.0578   0.01    0.05s
  37999     804640       0.0 (0.0)        12 (22.24)    0.0516 (0.0458)         0.92  0.2713          92.0     0.0568   0.01    0.05s
  38999     827212      0.0 (-0.08)       18 (22.08)    0.0368 (0.0437)         0.92  0.2713          92.0     0.0559   0.01    0.05s
  39999     850050       0.0 (0.0)        21 (20.36)     0.038 (0.0458)          0.88  0.325          88.0     0.0552   0.01    0.05s
Model saved to exports/checkpoints/ConnectFourNet_2023_10_29_T_20_09_30.chkpt at step 850050 with epsilon threshold 0.0552
  40999     872524       0.0 (0.0)        31 (22.92)     0.0479 (0.045)         0.95  0.2179          95.0     0.0545  0.002    0.05s
  41999     894679      0.0 (-0.04)       23 (21.92)     0.0392 (0.045)         0.89  0.3129          89.0     0.054   0.002    0.07s
  42999     916012       0.0 (0.0)        19 (23.56)    0.0451 (0.0452)         0.91  0.2862          91.0     0.0535  0.002    0.07s
  43999     938584       0.0 (0.0)        27 (24.32)    0.0395 (0.0445)         0.94  0.2375          94.0     0.053   0.002    0.06s
  44999     961498       0.0 (0.0)        42 (24.92)    0.0521 (0.0439)          0.86  0.347          86.0     0.0527  0.002    0.06s
  45999     984876       0.0 (0.0)        12 (20.08)    0.0415 (0.0442)          0.88  0.325          88.0     0.0523  0.002    0.06s
  46999    1007943       0.0 (0.0)        23 (22.36)    0.0515 (0.0459)          0.86  0.347          86.0     0.052   0.002    0.05s
  47999    1031514       0.0 (0.0)        21 (20.76)    0.0529 (0.0457)         0.87  0.3363          87.0     0.0517  0.002    0.06s
  48999    1055374       0.0 (0.0)        20 (21.8)     0.0342 (0.0456)         0.94  0.2375          94.0     0.0515  0.002    0.06s
  49999    1078772       0.0 (0.0)        26 (24.56)    0.0461 (0.0448)         0.89  0.3129          89.0     0.0513  0.002    0.07s
Model saved to exports/checkpoints/ConnectFourNet_2023_10_29_T_20_28_55.chkpt at step 1078772 with epsilon threshold 0.0513

The rewards for each episode are usually zero, since the winning reward is added to the losing reward (-1 + 1). However, the average shown in parentheses is sometimes different from zero. This is because in some episodes the model chooses an action that is not valid and obtains a reward as if it had lost the game, but no winning reward is assigned to the other player since the next transition never occurs.

As we can see, this average tends to decrease, i.e., the model tends to make fewer errors when choosing actions that are legal.

# plot training results
plot(
    *dqn_args,
    figures_dir_path=FIGURES_DIR_PATH,
)
Model training and evaluation figure saved to exports/figures/ConnectFourNet_2023_10_29_T_20_28_55.png

The graphs above confirm what was mentioned earlier. Furthermore, the number of iterations per game tends to increase slightly during training and evaluation.

Longer games are to be expected in training, given that the model keeps playing against itself as it learns.

In evaluation, however, where the model plays against random decision making, a larger reduction in the number of turns it takes to win would be expected, even though a slight reduction is observed as the iterations pass.

In any case, the model tends to win almost all the games, as we can see in Evaluation rates.

One option to incentivize the model to make the games shorter could be to assign a negative reward to params_env['rewards']['prolongation'].

Besides, since the agent is learning a strategy or action policy in a non-stationary scenario that depends on specific actions, the Running loss graph is only used to choose the training batch size hyperparameter (batch_size) and to verify that the gradients neither grow nor shrink dramatically (exploding or vanishing gradients).

Below, for illustration purposes only, we check how our policy acts on examples of possible legal boards in key game situations.

First we see if the policy model is able to complete some lines of four to win a game.

# set eval mode
dqn_net.policy.eval()
titles = [
    "Horizontal P2", "Horizontal P1", "Vertical P2", "Vertical P1", "Diagonal P2",
    "Diagonal P1", "Anti diagonal P2", "Anti diagonal P1"
]
 
finishes_boards = [board for i, (board, _) in enumerate(
    finishes_boards_solutions) if i % 2 != 0]
html = get_html(
    policy=dqn_net.policy,
    observations=finishes_boards,
    titles=titles,
    device=device
)
display(HTML(html))

[Board visualizations at t and the policy's chosen move at t+1 for each winning scenario: Horizontal P2, Horizontal P1, Vertical P2, Vertical P1, Diagonal P2, Diagonal P1, Anti diagonal P2, Anti diagonal P1]

We can see that this model is capable of choosing winning actions horizontally, vertically, diagonally and anti-diagonally, although not for both players in every line type; each player succeeds in different ones.

Something similar happens when our approximator is presented with the opportunity to block a line of four consecutive counters on the turn just before the opponent's winning transition could happen.

blocks_boards = [board for i, (board, _) in enumerate(
    blocks_boards_solutions) if i % 2 != 0]
html = get_html(
    policy=dqn_net.policy,
    observations=blocks_boards,
    titles=titles,
    device=device
)
display(HTML(html))

[Board visualizations at t and the policy's blocking move at t+1 for each scenario: Horizontal P2, Horizontal P1, Vertical P2, Vertical P1, Diagonal P2, Diagonal P1, Anti diagonal P2, Anti diagonal P1]

DDQN

Finally, a double DQN agent was trained on the same exercise, using similar training hyperparameters.

# net
ddqn_net = ConnectFourNet(out_features=params_env['action_space']).to(device)
 
# agent
params_agent['double'] = True
ddqn_agent = DQNAgent(
    net=ddqn_net,
    params=params_agent,
    device=device,
    load_model_path=load_model_path
)
 
# train policy
ddqn_args = train(
    agent=ddqn_agent,
    env=env,
    params_train=params_train,
    params_eval=params_eval,
    checkpoints_dir_path=CHECKPOINTS_DIR_PATH,
    device=device
)
Training policy in ConnectFourNet.
  Episode     Step   Train rewards (avg)  steps (avg)   running loss (avg)  Eval reward (mean std)  win rate(%)   Eps      LR     Time
    999      20530       -1.0 (-0.08)      15 (20.2)     0.0469 (0.0457)         0.44  0.7658          61.0     0.8015   0.01    0.02s
    1999     40849       0.0 (-0.08)       23 (21.96)     0.0536 (0.044)          0.63  0.627          71.0     0.7152   0.01    0.04s
    2999     60557       0.0 (-0.08)       18 (19.4)      0.041 (0.0459)         0.77  0.4659          79.0     0.641    0.01    0.03s
    3999     79788        0.0 (0.0)        24 (19.88)     0.0379 (0.047)         0.74  0.5024          77.0     0.5766   0.01    0.03s
    4999     99764        0.0 (0.0)        27 (21.72)    0.0397 (0.0461)         0.77  0.4659          79.0     0.5172   0.01    0.04s
    5999     118905       0.0 (0.0)        18 (18.0)     0.0575 (0.0478)         0.82  0.3842          82.0     0.4665   0.01    0.04s
    6999     138750       0.0 (0.0)        25 (19.12)    0.0433 (0.0478)          0.86  0.347          86.0     0.4197   0.01    0.03s
    7999     158652       0.0 (0.0)        7 (19.64)     0.0455 (0.0477)         0.79  0.4312          80.0     0.3781   0.01    0.04s
    8999     179199      0.0 (-0.08)       22 (19.48)    0.0327 (0.0484)         0.79  0.4073          79.0      0.34    0.01    0.03s
    9999     199498       0.0 (0.0)        18 (18.16)     0.0466 (0.048)         0.87  0.3363          87.0     0.3068   0.01    0.03s
  10999     220203       0.0 (0.0)        25 (20.56)    0.0366 (0.0459)         0.85  0.3571          85.0     0.2768   0.01    0.02s
  11999     241541       0.0 (0.0)        15 (20.44)    0.0436 (0.0447)         0.89  0.3129          89.0     0.2495   0.01    0.04s
  12999     262846      0.0 (-0.04)       16 (20.32)    0.0506 (0.0469)         0.89  0.3129          89.0     0.2256   0.01    0.06s
  13999     284491       0.0 (0.0)        36 (22.92)    0.0419 (0.0466)         0.84  0.3666          84.0     0.2042   0.01    0.04s
  14999     305991      0.0 (-0.08)       21 (22.44)    0.0566 (0.0452)         0.92  0.2713          92.0     0.1855   0.01    0.05s
  15999     327637      0.0 (-0.08)       34 (24.84)     0.0477 (0.046)         0.92  0.3059          93.0     0.169    0.01    0.04s
  16999     349786       0.0 (0.0)        20 (23.08)    0.0491 (0.0466)         0.81  0.3923          81.0     0.1542   0.01    0.04s
  17999     372238      0.0 (-0.04)       24 (22.0)     0.0486 (0.0472)         0.84  0.3666          84.0     0.1411   0.01    0.05s
  18999     394402       0.0 (0.0)        21 (20.84)     0.05 (0.0436)          0.91  0.2862          91.0     0.1297   0.01    0.05s
  19999     416747      0.0 (-0.04)       24 (21.16)    0.0405 (0.0454)         0.87  0.3363          87.0     0.1197   0.01    0.04s
Model saved to exports/checkpoints/ConnectFourNet_2023_10_29_T_20_59_16.chkpt at step 416747 with epsilon threshold 0.1197
  20999     439501      0.0 (-0.08)       17 (23.36)    0.0449 (0.0444)         0.93  0.2551          93.0     0.1108   0.01    0.04s
  21999     462133       0.0 (0.0)        11 (21.0)     0.0531 (0.0469)          0.88  0.325          88.0     0.1031   0.01    0.06s
  22999     484962       0.0 (0.0)        31 (24.04)    0.0435 (0.0465)          0.88  0.325          88.0     0.0963   0.01    0.04s
  23999     507774       0.0 (0.0)        25 (21.08)    0.0389 (0.0459)         0.89  0.3129          89.0     0.0904   0.01    0.06s
  24999     531175       0.0 (0.0)        20 (23.6)     0.0453 (0.0427)         0.89  0.3129          89.0     0.0851   0.01    0.05s
  25999     554451       0.0 (0.0)        13 (22.84)    0.0498 (0.0445)         0.91  0.2862          91.0     0.0805   0.01    0.06s
  26999     577973       0.0 (0.0)        26 (23.0)     0.0417 (0.0443)         0.87  0.3363          87.0     0.0765   0.01    0.05s
  27999     601554       0.0 (0.0)        23 (23.6)     0.0493 (0.0436)         0.84  0.3666          84.0     0.073    0.01    0.05s
  28999     624682      0.0 (-0.08)       22 (19.16)    0.0485 (0.0456)         0.93  0.2551          93.0      0.07    0.01    0.04s
  29999     648382       0.0 (0.0)        23 (22.48)    0.0519 (0.0447)         0.91  0.2862          91.0     0.0674   0.01    0.05s
  30999     671673       0.0 (0.0)        21 (22.6)      0.05 (0.0445)          0.85  0.3571          85.0     0.0651   0.01    0.05s
  31999     695409      0.0 (-0.04)       20 (24.52)    0.0426 (0.0436)         0.82  0.3842          82.0     0.0631   0.01    0.06s
  32999     719257      0.0 (-0.04)       27 (21.44)    0.0576 (0.0437)         0.79  0.4073          79.0     0.0614   0.01    0.06s
  33999     742359       0.0 (0.0)        18 (22.04)    0.0499 (0.0421)         0.87  0.3363          87.0     0.0599   0.01    0.05s
  34999     765618      -1.0 (-0.04)      32 (21.84)    0.0401 (0.0437)         0.92  0.2713          92.0     0.0586   0.01    0.08s
  35999     789207      0.0 (-0.08)       29 (22.24)    0.0469 (0.0438)         0.87  0.3363          87.0     0.0575   0.01    0.06s
  36999     812560       0.0 (0.0)        31 (24.28)    0.0492 (0.0434)          0.75  0.433          75.0     0.0565   0.01    0.05s
  37999     836424      0.0 (-0.04)       24 (26.56)    0.0431 (0.0427)         0.82  0.3842          82.0     0.0556   0.01    0.09s
  38999     859950       0.0 (0.0)        25 (23.56)    0.0395 (0.0432)         0.87  0.3363          87.0     0.0549   0.01    0.05s
  39999     883688       0.0 (0.0)        24 (24.08)    0.0397 (0.0441)         0.83  0.3756          83.0     0.0542   0.01    0.06s
Model saved to exports/checkpoints/ConnectFourNet_2023_10_29_T_21_36_49.chkpt at step 883688 with epsilon threshold 0.0542
  40999     907411       0.0 (0.0)        26 (23.44)    0.0444 (0.0457)         0.91  0.2862          91.0     0.0537  0.002    0.05s
  41999     932271       0.0 (0.0)        28 (25.68)    0.0358 (0.0416)         0.89  0.3129          89.0     0.0532  0.002    0.07s
  42999     957047       0.0 (0.0)        30 (21.28)    0.0382 (0.0402)         0.83  0.3756          83.0     0.0527  0.002    0.07s
  43999     981954       0.0 (0.0)        30 (25.92)    0.0426 (0.0436)         0.85  0.3571          85.0     0.0523  0.002    0.07s
  44999    1007640      0.0 (-0.04)       20 (26.88)     0.0403 (0.043)         0.92  0.2713          92.0     0.052   0.002    0.08s
  45999    1032640       0.0 (0.0)        9 (22.64)     0.0459 (0.0426)         0.91  0.2862          91.0     0.0517  0.002    0.08s
  46999    1057268       0.0 (0.0)        25 (25.92)    0.0411 (0.0443)           0.9  0.3            90.0     0.0515  0.002    0.07s
  47999    1082648       0.0 (0.0)        28 (25.84)    0.0459 (0.0419)         0.89  0.3129          89.0     0.0513  0.002    0.09s
  48999    1107489       0.0 (0.0)        23 (28.04)    0.0403 (0.0428)         0.87  0.3363          87.0     0.0511  0.002    0.07s
  49999    1132553       0.0 (0.0)        24 (25.6)      0.0395 (0.041)         0.84  0.3666          84.0     0.051   0.002    0.07s
Model saved to exports/checkpoints/ConnectFourNet_2023_10_29_T_21_57_41.chkpt at step 1132553 with epsilon threshold 0.051

As we can see below, the training results are similar to those of the DQN agent; however, the model's results against random decision making are slightly worse.

# plot training results
plot(
    *ddqn_args,
    figures_dir_path=FIGURES_DIR_PATH,
)
Model training and evaluation figure saved to exports/figures/ConnectFourNet_2023_10_29_T_21_57_41.png

Again, just for the sake of illustration, we present our model with some observations of key game scenarios to get an idea of how it performs.

# set eval mode
ddqn_net.policy.eval()
html = get_html(
    policy=ddqn_net.policy,
    observations=finishes_boards,
    titles=titles,
    device=device
)
display(HTML(html))

[Board visualizations at t and the DDQN policy's chosen move at t+1 for each winning scenario: Horizontal P2, Horizontal P1, Vertical P2, Vertical P1, Diagonal P2, Diagonal P1, Anti diagonal P2, Anti diagonal P1]

As we see, it is still capable of resolving the best action on the majority of the boards both to win the game and to block the opposing player from winning it:

html = get_html(
    policy=ddqn_net.policy,
    observations=blocks_boards,
    titles=titles,
    device=device
)
display(HTML(html))

[Board visualizations at t and the DDQN policy's blocking move at t+1 for each scenario: Horizontal P2, Horizontal P1, Vertical P2, Vertical P1, Diagonal P2, Diagonal P1, Anti diagonal P2, Anti diagonal P1]

It is possible that the models do not perform as well in prolonged games since during training they receive more observations of boards with fewer counters.

There are many training parameters that could be modified to try to obtain better generalization from the model: the rewards, the epsilon decay rate, the observation batch size, the memory capacity, the learning rate, or even implementing another neural network architecture or modifying and/or adding convolutional or fully connected layers.

In addition, more modern reinforcement learning algorithms such as proximal policy optimization (PPO) could be implemented, or an observation memory with prioritized sampling could be used.

TorchRL, a reinforcement learning library for PyTorch, implements these and many other options.

Export

As noted in the introduction to this Jupyter notebook, this model is used in a decentralised zero-knowledge web application.

For this reason, the torch.onnx module is used to convert the native PyTorch computation graph to an ONNX graph so that, using the ONNX Runtime Web JavaScript library, this machine learning model can be deployed in our web application.

# export onnx version of the policy trained by the DQN agent
_, _, _, model_id = dqn_args
export_onnx(
    policy=dqn_agent.net.policy,
    policies_dir_path=POLICIES_DIR_PATH,
    model_id=model_id,
    device=device
)
Model in ONNX format saved to exports/policies/ConnectFourNet_2023_10_29_T_20_28_55.onnx
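As a quick sanity check, the exported graph can be loaded with the Python onnxruntime package and run on a dummy input; the file path below is the one reported above, and the expected output shape follows from out_features=7.

import numpy as np
import onnxruntime as ort

# Hypothetical check: run the exported policy on a random two-channel board.
session = ort.InferenceSession("exports/policies/ConnectFourNet_2023_10_29_T_20_28_55.onnx")
input_name = session.get_inputs()[0].name  # avoids hard-coding the exported input name
dummy_board = np.random.randint(0, 2, size=(1, 2, 6, 7)).astype(np.float32)
onnx_q_values = session.run(None, {input_name: dummy_board})[0]
print(onnx_q_values.shape)  # expected: (1, 7), one Q value per column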


