Formación de un agente DQN en el juego de Conecta Cuatro
Durante el desarrollo de una aplicación de conocimiento cero llamada zk Connect Four se dio el contexto adecuado para la implementación de un agente de aprendizaje por refuerzo.
El aprendizaje por refuerzo es un paradigma de aprendizaje automático en el que un agente aprende de su entorno a través de prueba y error. Es decir, se trata de una aproximación computacional al aprendizaje a partir de acciones.
El objetivo es aprender cuáles son las acciones más adecuadas en cada estado del juego para maximizar una recompensa. Para ello, en este cuaderno se implementan un agente DQN y un agente DQN doble (DDQN) para optimizar un aproximador en el juego Conecta Cuatro utilizando PyTorch.
En este artículo
Configuración
Conceptos fundamentales
Los actores principales en el paradigma de aprendizaje por refuerzo son el entorno y el agente. El entorno es el contexto en el que el agente vive y con el que interactúa.
En cada iteración, el agente recibe una observación del estado actual del juego, decide la mejor acción a tomar y actúa sobre el entorno.
De esta manera, se genera un cambio en el estado del entorno y se obtiene una recompensa a partir de esta transición.
El objetivo de este ejercicio es obtener una estrategia, plan o política de acción (policy) que sea capaz de maximizar esta recompensa, obteniendo así la mejor acción para cada transición y propiciando una aproximación óptima al juego de Conecta Cuatro.
A continuación se introducen las principales clases que fundamentan este paradigma y que nos permitirán dar conformidad a los actores y al contexto del juego.
Entorno
La implementación del entorno consiste en una clase cuyos métodos y acometidos principales son los siguientes:
-
get_valid_actions: provee de una lista con todas las acciones válidas dado el estado actual, es decir, el tablero en ese momento del juego. -
step: acomete una transición con la acción recibida por parte del agente (un índice de una columna del tablero), comprueba si se ha ganado o empatado y devuelve una recompensa dependiendo del resultado. -
switch_turn: cambia el turno de juego entre ambos jugadores. -
reset: inicializa el estado y todas las variables del juego y devuelve un estado inicial.
Agentes DQN y DDQN
DQN y DDQN son en esencia unos algoritmos que consisten en aprender un aproximador para una función Q de acción-valor.
En nuestro caso, se intentará aprender las estrategias de ambos jugadores del juego Conecta Cuatro, por lo que en vez de tratarse de un agente que solamente es entrenado en una estrategia, se trataría de una modalidad de multiagente de aprendizaje por refuerzo (MARL).
Entonces, se presenta un escenario de juego alterno de Markov, o un juego de suma cero donde la ganancia de un jugador es proporcional a la pérdida del otro.
La implementación de la clase que se utiliza para instanciar los agentes consiste de los siguientes métodos:
En cada iteración, se insta al agente a que nos proporcione una acción de acuerdo a una estrategia epsilon-greedy para tratar el dilema exploración-explotación.
valid_actions = env.get_valid_actions()
action = agent.act(state, valid_actions,
num_steps, enforce_valid_action)
En ocasiones será nuestro modelo el que prediga la acción a tomar y en otras se muestreará una acción.
Como se puede ver en el gráfico de abajo, en cada iteración la probabilidad de que se exploren nuevas acciones o de que se exploten las acciones aprendidas decaerá exponencialmente en base al ritmo de decadencia eps_decay.
# epilson decay graph
# num episodes times avg steps per episode
num_steps = 50000 * 25
eps_start = 0.9
eps_end = 0.05
eps_decay = num_steps / 7.5
steps = np.arange(num_steps)
epsilon = eps_end + (eps_start - eps_end) * np.exp(-1 * steps / eps_decay)
plt.rcParams["figure.figsize"] = (20, 4)
plt.plot(steps, epsilon)
plt.title('Epsilon decay')
plt.xlabel('Step')
plt.ylabel('Epsilon')
plt.show()En el método act del agente DQN, a continuación, se implementa esta estrategia.
Si el valor de la muestra es superior al umbral de epsilon, se aproximará la mejor acción dada la última observación del entorno (explotación de las acciones aprendidas).
En caso contrario, se elegirá una acción (exploración) de entre las acciones válidas de manera uniforme.
Se da la opción, mediante
enforce_valid_action, de asegurar que la acción elegida sea válida, o de que se inste al modelo a que aprenda qué columnas del tablero pueden optar en cada turno a una nueva ficha, es decir, a que aprenda esta regla del juego.
Los datos que se utilizan para entrenar el agente se generan desde cero durante cada entrenamiento. Estos datos se almacenan en un objeto de clase ExperienceReplay y se muestrean aleatoriamente con el objetivo de optimizar una estrategia o plan de acción (policy).
Las observaciones en cada iteración se adjuntan como tuplas nombradas (itertools' named_tuple) a una lista secuencia (deque, o cola de doble extremo) con capacidad limitada, donde las observaciones se pueden insertar y eliminar desde ambos extremos, creando así una experiencia de juego.
Cada observación incluye un estado (el tablero del juego en ese preciso instante), una acción (el índice de la columna que ha sido elegida), un siguiente estado (el tablero del juego tras considerar la última acción, o None si es un estado final, es decir si se ha tomado una acción ganadora) y una recompensa.
Una vez se hayan obtenido y memorizado más observaciones que el número de batch_size que utilicemos, procederemos a optimizar el modelo.
En nuestro ejemplo, solamente se optimiza una vez por episodio dado que los valores de las transiciones dependerán en algunas ocasiones de transiciones anteriores.
Por ejemplo, cuando un jugador gana la partida obtiene una recompensa positiva, por lo que para la transición del estado en esta iteración se genera una tupla de la siguiente manera:
transition = Transition(state, action, None, reward)
La transición consistiría en el estado en t, la acción tomada, el estado en t+1 (al ser un estado final se le asigna None) y la recompensa obtenida.
Lo que sucede es que durante la generación de la observación en la iteración anterior, no sabíamos todavía de este desenlace y por lo tanto la recompensa no fue negativa por haber perdido el juego.
Por esta razón, una vez termina cada partida del entrenamiento se resuelve esta relación de recompensas para guardarlas posteriormente en memoria.
Por lo tanto, el transición en t-1 de la transición de arriba sería:
transition_t-1 = Transition(state, action, next_state, reward + env.params['rewards']['loss'])
Se suma a la recompensa asignada y no se asigna directamente la recompensa de pérdida dado que es posible configurar una recompensa por prolongación en el juego como hiperparámetro.
El proceso de optimización se muestra a continuación:
En primer lugar, se muestrean en el método optimize un número de observaciones en memoria, y se concatenan todos los valores de estado y acción de las observaciones en matrices independientes.
Además, tanto los estados en t como en t+1 obtienen una nueva dimensión dado que se aplica una transformación del tablero mediante la función get_two_channels.
La idea es que por cada tablero con valores 0, 1 o 2 se obtengan dos matrices con valores binarios, una para cada jugador.
Estas dos matrices serán los dos canales de entrada por observación que obtendrá nuestra red neuronal convolucional.
Después se genera una matriz lógica que se utilizará como máscara para asignar los valores Q a los estados que no sean finales, mientras que los finales tendrán un valor de cero.
En el caso de que se utilice un agente DQN, se utiliza solamente el modelo de target para obtener los valores de Q, obteniendo simplemente los valores de máxima magnitud de entre todas las acciones.
Mientras que con un agente DDQN, para intentar solventar una posible sobreestimación de estos valores, se utilizan tanto nuestro modelo policy como nuestro modelo "fijo" target para obtener los valores de diferencia temporal de Q a los que nos queremos aproximar.
Posteriormente, se calcula la ganancia a largo plazo que se puede anticipar. Para ello, se aplica un factor de descuento gamma y se resta de la recompensa para esa transición (en lugar de sumarse como en la ecuación de Bellman tal y como se sugiere en este cuaderno de Kaggle).
De esta manera adaptamos nuestra optimización para el aprendizaje de ambas estrategias al juego alterno de suma cero que nos ocupa.
Es en este momento cuando se optimizan los parámetros de nuestra policy. Para ello se calcula la pérdida (en nuestro caso se utiliza el criterio torch.nn.HuberLoss) para minimizar el error de la diferencia temporal δ, se asignan valores nulos a los gradientes para posteriormente ser computados mediante backpropagation, se aplica un clipeado de los mismos en caso de que así se haya configurado el entrenamiento y se actualizan los parámetros en base al gradiente calculado.
Una vez optimizados los parámetros de la policy durante el entrenamiento, se plantea si actualizar los parámetros de nuestro target. Para ello se implementan dos estrategias: soft update y hard update.
La primera actualizará los parámetros de forma gradual en cada episodio solamente obteniendo una parte porcentual de los parámetros de policy en base al valor TAU considerado como hiperparámetro en el entrenamiento.
Por ejemplo, si TAU es 0.001, target pasará a utilizar el 0.1% de los parámetros de policy y mantendrá el 99,9%.
Por el contrario, mediante hard update se copiarán todos los parámetros de golpe si se alcanza un número de iteraciones a configurar.
Por último, se implementan tres métodos para guardar transiciones en memoria, guardar parámetros actuales de policy para configurar checkpoints durante el entrenamiento y para cargar los pesos de un modelo guardado al principio de un entrenamiento.
Nótese que el optimizador se instancia una vez se hayan podido cargar los pesos del modelo de uno previamente guardado.
Modelado
El modelo elegido para aprender el juego de Conecta Cuatro consiste de dos capas convolucionales con activación ReLU y dos capas completamente conectadas (la primera también con activación ReLU) en la salida.
En el mismo objeto de clase ConnectFourNet se instancian dos módulos, uno para policy y otro para target. Este último se copia del primero una vez se hayan inicializado sus parámetros (pesos y sesgo).
De esta manera podemos guardar los parámetros de ambos modelos durante el entrenamiento en el mismo diccionario de estados.
ConnectFourNet(
(policy): Sequential(
(0): Conv2d(2, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU()
(2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(3): ReLU()
(4): Flatten(start_dim=1, end_dim=-1)
(5): Linear(in_features=672, out_features=128, bias=True)
(6): ReLU()
(7): Linear(in_features=128, out_features=7, bias=True)
)
(target): Sequential(
(0): Conv2d(2, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU()
(2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(3): ReLU()
(4): Flatten(start_dim=1, end_dim=-1)
(5): Linear(in_features=672, out_features=128, bias=True)
(6): ReLU()
(7): Linear(in_features=128, out_features=7, bias=True)
)
)
ConnectFourNet(out_features=7).layer_summary([512, 2, 6, 7])Conv2d output shape: torch.Size([512, 16, 6, 7])
Conv2d output shape: torch.Size([512, 16, 6, 7])
Flatten output shape: torch.Size([512, 672])
Linear output shape: torch.Size([512, 128])
Linear output shape: torch.Size([512, 7])
Entrenamiento
Como se ha apuntado previamente, el entrenamiento de la política de juego aúna las estrategias de ambos jugadores. Para ello se experimenta con la utilización de los agentes DQN y DDQN.
DQN
En primer lugar configuramos todos los parámetros no solo del entrenamiento sino también de la evaluación, entorno y agente, y con estos parámetros se instancian el módulo con las redes neuronales, el agente y el entorno.
Una vez instanciadas todas las entidades y parámetros que se necesitan podemos iniciar el entrenamiento del agente para optimizar la policy.
La función train muestra métricas de iteraciones de manera periódica. Algunos valores corresponderán a medias obtenidas a partir de valores actuales y de iteraciones previas para poder reconocer las tendencias durante el entrenamiento.
En cada episodio se evalúa el modelo de policy jugando contra un agente random.
# train policy
dqn_args = train(
agent=dqn_agent,
env=env,
params_train=params_train,
params_eval=params_eval,
checkpoints_dir_path=CHECKPOINTS_DIR_PATH,
device=device
)Training policy in ConnectFourNet.
Episode Step Train rewards (avg) steps (avg) running loss (avg) Eval reward (mean std) win rate(%) Eps LR Time
999 20467 0.0 (-0.04) 19 (19.84) 0.0464 (0.0459) 0.46 0.7269 60.0 0.8018 0.01 0.02s
1999 40734 0.0 (-0.04) 18 (17.0) 0.0502 (0.0453) 0.65 0.5723 70.0 0.7157 0.01 0.02s
2999 60847 -1.0 (-0.08) 32 (23.28) 0.055 (0.048) 0.77 0.4659 79.0 0.64 0.01 0.03s
3999 80435 0.0 (0.0) 23 (21.64) 0.0513 (0.0467) 0.79 0.4312 80.0 0.5746 0.01 0.03s
4999 100228 0.0 (-0.08) 28 (20.48) 0.0469 (0.0487) 0.81 0.4403 83.0 0.5159 0.01 0.02s
5999 119836 0.0 (-0.04) 12 (18.96) 0.0389 (0.0475) 0.93 0.2551 93.0 0.4641 0.01 0.04s
6999 139392 0.0 (-0.04) 7 (18.04) 0.0413 (0.0471) 0.89 0.3434 90.0 0.4183 0.01 0.03s
7999 159565 0.0 (-0.04) 31 (21.44) 0.0488 (0.0489) 0.94 0.2375 94.0 0.3763 0.01 0.05s
8999 179803 0.0 (-0.08) 24 (20.24) 0.0483 (0.0486) 0.93 0.2551 93.0 0.339 0.01 0.04s
9999 199816 0.0 (0.0) 26 (21.68) 0.0533 (0.0458) 0.89 0.3434 90.0 0.3063 0.01 0.05s
10999 220056 0.0 (-0.04) 24 (19.76) 0.0542 (0.0464) 0.87 0.3648 88.0 0.277 0.01 0.06s
11999 240844 0.0 (-0.08) 24 (21.76) 0.0496 (0.0464) 0.89 0.3129 89.0 0.2504 0.01 0.04s
12999 261341 0.0 (0.0) 11 (21.28) 0.0445 (0.0462) 0.92 0.2713 92.0 0.2272 0.01 0.04s
13999 282521 0.0 (-0.04) 22 (20.56) 0.0517 (0.049) 0.9 0.3 90.0 0.206 0.01 0.04s
14999 304168 0.0 (0.0) 28 (23.28) 0.0407 (0.0456) 0.91 0.2862 91.0 0.187 0.01 0.05s
15999 325292 0.0 (-0.08) 7 (22.2) 0.0489 (0.0465) 0.93 0.2551 93.0 0.1707 0.01 0.05s
16999 346287 0.0 (-0.04) 16 (22.32) 0.0503 (0.0474) 0.89 0.3129 89.0 0.1564 0.01 0.05s
17999 367640 0.0 (0.0) 20 (20.0) 0.0466 (0.0471) 0.92 0.2713 92.0 0.1436 0.01 0.04s
18999 388744 0.0 (-0.04) 22 (20.68) 0.0484 (0.0458) 0.92 0.2713 92.0 0.1325 0.01 0.05s
19999 409492 0.0 (0.0) 21 (19.24) 0.0464 (0.0458) 0.93 0.2551 93.0 0.1228 0.01 0.04s
Model saved to exports/checkpoints/ConnectFourNet_2023_10_29_T_19_33_58.chkpt at step 409492 with epsilon threshold 0.1228
20999 430888 0.0 (-0.04) 19 (23.44) 0.0566 (0.0453) 0.87 0.3363 87.0 0.1141 0.01 0.05s
21999 452183 0.0 (0.0) 25 (18.68) 0.0498 (0.0447) 0.87 0.3363 87.0 0.1064 0.01 0.04s
22999 473468 0.0 (-0.04) 15 (22.16) 0.0454 (0.0471) 0.95 0.2179 95.0 0.0996 0.01 0.05s
23999 494580 0.0 (-0.04) 34 (20.04) 0.0515 (0.0476) 0.95 0.2179 95.0 0.0937 0.01 0.04s
24999 516624 0.0 (-0.04) 26 (22.6) 0.0508 (0.0452) 0.94 0.2375 94.0 0.0883 0.01 0.05s
25999 538052 0.0 (0.0) 23 (21.6) 0.0535 (0.0453) 0.87 0.3363 87.0 0.0837 0.01 0.04s
26999 560029 0.0 (-0.04) 30 (24.28) 0.0434 (0.0474) 0.92 0.2713 92.0 0.0795 0.01 0.04s
27999 581599 0.0 (0.0) 27 (21.32) 0.0423 (0.0452) 0.84 0.3929 85.0 0.0759 0.01 0.05s
28999 604226 0.0 (-0.04) 26 (25.16) 0.0406 (0.0452) 0.83 0.3756 83.0 0.0726 0.01 0.05s
29999 626930 0.0 (0.0) 21 (21.96) 0.0535 (0.0442) 0.93 0.2551 93.0 0.0698 0.01 0.04s
30999 649048 0.0 (-0.04) 32 (20.72) 0.0378 (0.0451) 0.87 0.3648 88.0 0.0673 0.01 0.06s
31999 670741 0.0 (0.0) 23 (21.4) 0.0382 (0.0449) 0.9 0.3 90.0 0.0652 0.01 0.05s
32999 693199 0.0 (0.0) 13 (20.52) 0.0399 (0.0448) 0.94 0.2375 94.0 0.0633 0.01 0.05s
33999 715469 0.0 (0.0) 18 (21.52) 0.0429 (0.0475) 0.91 0.2862 91.0 0.0616 0.01 0.05s
34999 737025 0.0 (0.0) 26 (21.68) 0.0523 (0.0458) 0.86 0.347 86.0 0.0602 0.01 0.05s
35999 759430 0.0 (-0.04) 21 (24.16) 0.0389 (0.0452) 0.93 0.2551 93.0 0.0589 0.01 0.05s
36999 782127 0.0 (0.0) 22 (18.28) 0.053 (0.0465) 0.92 0.2713 92.0 0.0578 0.01 0.05s
37999 804640 0.0 (0.0) 12 (22.24) 0.0516 (0.0458) 0.92 0.2713 92.0 0.0568 0.01 0.05s
38999 827212 0.0 (-0.08) 18 (22.08) 0.0368 (0.0437) 0.92 0.2713 92.0 0.0559 0.01 0.05s
39999 850050 0.0 (0.0) 21 (20.36) 0.038 (0.0458) 0.88 0.325 88.0 0.0552 0.01 0.05s
Model saved to exports/checkpoints/ConnectFourNet_2023_10_29_T_20_09_30.chkpt at step 850050 with epsilon threshold 0.0552
40999 872524 0.0 (0.0) 31 (22.92) 0.0479 (0.045) 0.95 0.2179 95.0 0.0545 0.002 0.05s
41999 894679 0.0 (-0.04) 23 (21.92) 0.0392 (0.045) 0.89 0.3129 89.0 0.054 0.002 0.07s
42999 916012 0.0 (0.0) 19 (23.56) 0.0451 (0.0452) 0.91 0.2862 91.0 0.0535 0.002 0.07s
43999 938584 0.0 (0.0) 27 (24.32) 0.0395 (0.0445) 0.94 0.2375 94.0 0.053 0.002 0.06s
44999 961498 0.0 (0.0) 42 (24.92) 0.0521 (0.0439) 0.86 0.347 86.0 0.0527 0.002 0.06s
45999 984876 0.0 (0.0) 12 (20.08) 0.0415 (0.0442) 0.88 0.325 88.0 0.0523 0.002 0.06s
46999 1007943 0.0 (0.0) 23 (22.36) 0.0515 (0.0459) 0.86 0.347 86.0 0.052 0.002 0.05s
47999 1031514 0.0 (0.0) 21 (20.76) 0.0529 (0.0457) 0.87 0.3363 87.0 0.0517 0.002 0.06s
48999 1055374 0.0 (0.0) 20 (21.8) 0.0342 (0.0456) 0.94 0.2375 94.0 0.0515 0.002 0.06s
49999 1078772 0.0 (0.0) 26 (24.56) 0.0461 (0.0448) 0.89 0.3129 89.0 0.0513 0.002 0.07s
Model saved to exports/checkpoints/ConnectFourNet_2023_10_29_T_20_28_55.chkpt at step 1078772 with epsilon threshold 0.0513
Las recompensas para cada episodio suelen ser cero dado que se suma la puntuación ganadora a la perdedora (-1 + 1), en cambio la media que aparece entre paréntesis en ocasiones es diferente a cero. Esto es debido a que en algunos episodios el modelo elige una acción que no es válida y obtiene una recompensa como si hubiese perdido el juego, pero no se le asigna una recompensa ganadora al otro jugador puesto que no se da la transición siguiente.
Como podemos ver, esta media tiende a disminuir, es decir, el modelo tiende a cometer menos errores a la hora de elegir acciones que sean legales.
# plot training results
plot(
*dqn_args,
figures_dir_path=FIGURES_DIR_PATH,
)Model training and evaluation figure saved to exports/figures/ConnectFourNet_2023_10_29_T_20_28_55.png
En los gráficos anteriores podemos comprobar lo que comentamos previamente. Además, se observa cómo la cantidad de iteraciones tiende a aumentar ligeramente durante el entrenamiento y la evaluación.
Cabía esperar que en el entrenamiento se alargasen las partidas dado que a medida que el modelo aprende sigue jugando contra sí mismo.
En cambio, en lo que respecta a la evaluación, y considerando que el modelo juega contra una toma de decisiones aleatorias, cabría esperar que, pese a que hay una ligera reducción en la cantidad de turnos que necesita para ganar con el paso de las iteraciones, la reducción fuese mayor.
De cualquier manera, el modelo tiende a ganar casi la totalidad de las partidas, como podemos ver en Evaluation rates.
Una opción a probar para incentivar al modelo a intentar que las partidas sean más cortas podría ser asignar una recompensa negativa a
params_env['rewards']['prolongation'].
Por otra parte, al estar aprendiendo una estrategia o política de acción para un escenario no estacionario, el cual depende de acciones puntuales, solamente se considera el gráfico de Running loss para la elección del hiperparámetro del tamaño del lote de entrenamiento (batch_size) y para comprobar que los gradientes ni aumentan ni disminuyen considerable o exponencialmente (exploding gradients y vanishing gradients).
A continuación podemos comprobar cómo actúa nuestra policy con ejemplos de posibles tableros legales en situaciones clave del juego, solamente por motivos de ilustración.
En primer lugar vemos si es capaz de completar algunas líneas de cuatro para ganar un juego.
# set eval mode
dqn_net.policy.eval()
titles = [
"Horizontal P2", "Horizontal P1", "Vertical P2", "Vertical P1", "Diagonal P2",
"Diagonal P1", "Anti diagonal P2", "Anti diagonal P1"
]
finishes_boards = [board for i, (board, _) in enumerate(
finishes_boards_solutions) if i % 2 != 0]
html = get_html(
policy=dqn_net.policy,
observations=finishes_boards,
titles=titles,
device=device
)
display(HTML(html))Horizontal P2
t
t+1
Horizontal P1
t
t+1
Vertical P2
t
t+1
Vertical P1
t
t+1
Diagonal P2
t
t+1
Diagonal P1
t
t+1
Anti-Diagonal P2
t
t+1
Anti-Diagonal P1
t
t+1
Podemos ver como este modelo es capaz de elegir acciones para ganar el juego ya sea horizontal, vertical, diagonal o anti diagonalmente, aunque no para ambos jugadores en todas las modalidades de líneas, sí para ambos jugadores en distintas modalidades.
De manera similar sucede cuando a nuestro aproximador se le presenta la ocasión de bloquear líneas de cuatro fichas consecutivas para ganar un juego justo en el turno anterior de que puedan suceder estas transiciones ganadoras.
blocks_boards = [board for i, (board, _) in enumerate(
blocks_boards_solutions) if i % 2 != 0]
html = get_html(
policy=dqn_net.policy,
observations=blocks_boards,
titles=titles,
device=device
)
display(HTML(html))Horizontal P2
t
t+1
Horizontal P1
t
t+1
Vertical P2
t
t+1
Vertical P1
t
t+1
Diagonal P2
t
t+1
Diagonal P1
t
t+1
Anti diagonal P2
t
t+1
Anti diagonal P1
t
t+1
DDQN
Por último se ha procedido al entrenamiento de un agente DQN doble en el mismo ejercicio. Se han utilizado hiperparámetros de entrenamiento similares.
# net
ddqn_net = ConnectFourNet(out_features=params_env['action_space']).to(device)
# agent
params_agent['double'] = True
ddqn_agent = DQNAgent(
net=ddqn_net,
params=params_agent,
device=device,
load_model_path=load_model_path
)
# train policy
ddqn_args = train(
agent=ddqn_agent,
env=env,
params_train=params_train,
params_eval=params_eval,
checkpoints_dir_path=CHECKPOINTS_DIR_PATH,
device=device
)Training policy in ConnectFourNet.
Episode Step Train rewards (avg) steps (avg) running loss (avg) Eval reward (mean std) win rate(%) Eps LR Time
999 20530 -1.0 (-0.08) 15 (20.2) 0.0469 (0.0457) 0.44 0.7658 61.0 0.8015 0.01 0.02s
1999 40849 0.0 (-0.08) 23 (21.96) 0.0536 (0.044) 0.63 0.627 71.0 0.7152 0.01 0.04s
2999 60557 0.0 (-0.08) 18 (19.4) 0.041 (0.0459) 0.77 0.4659 79.0 0.641 0.01 0.03s
3999 79788 0.0 (0.0) 24 (19.88) 0.0379 (0.047) 0.74 0.5024 77.0 0.5766 0.01 0.03s
4999 99764 0.0 (0.0) 27 (21.72) 0.0397 (0.0461) 0.77 0.4659 79.0 0.5172 0.01 0.04s
5999 118905 0.0 (0.0) 18 (18.0) 0.0575 (0.0478) 0.82 0.3842 82.0 0.4665 0.01 0.04s
6999 138750 0.0 (0.0) 25 (19.12) 0.0433 (0.0478) 0.86 0.347 86.0 0.4197 0.01 0.03s
7999 158652 0.0 (0.0) 7 (19.64) 0.0455 (0.0477) 0.79 0.4312 80.0 0.3781 0.01 0.04s
8999 179199 0.0 (-0.08) 22 (19.48) 0.0327 (0.0484) 0.79 0.4073 79.0 0.34 0.01 0.03s
9999 199498 0.0 (0.0) 18 (18.16) 0.0466 (0.048) 0.87 0.3363 87.0 0.3068 0.01 0.03s
10999 220203 0.0 (0.0) 25 (20.56) 0.0366 (0.0459) 0.85 0.3571 85.0 0.2768 0.01 0.02s
11999 241541 0.0 (0.0) 15 (20.44) 0.0436 (0.0447) 0.89 0.3129 89.0 0.2495 0.01 0.04s
12999 262846 0.0 (-0.04) 16 (20.32) 0.0506 (0.0469) 0.89 0.3129 89.0 0.2256 0.01 0.06s
13999 284491 0.0 (0.0) 36 (22.92) 0.0419 (0.0466) 0.84 0.3666 84.0 0.2042 0.01 0.04s
14999 305991 0.0 (-0.08) 21 (22.44) 0.0566 (0.0452) 0.92 0.2713 92.0 0.1855 0.01 0.05s
15999 327637 0.0 (-0.08) 34 (24.84) 0.0477 (0.046) 0.92 0.3059 93.0 0.169 0.01 0.04s
16999 349786 0.0 (0.0) 20 (23.08) 0.0491 (0.0466) 0.81 0.3923 81.0 0.1542 0.01 0.04s
17999 372238 0.0 (-0.04) 24 (22.0) 0.0486 (0.0472) 0.84 0.3666 84.0 0.1411 0.01 0.05s
18999 394402 0.0 (0.0) 21 (20.84) 0.05 (0.0436) 0.91 0.2862 91.0 0.1297 0.01 0.05s
19999 416747 0.0 (-0.04) 24 (21.16) 0.0405 (0.0454) 0.87 0.3363 87.0 0.1197 0.01 0.04s
Model saved to exports/checkpoints/ConnectFourNet_2023_10_29_T_20_59_16.chkpt at step 416747 with epsilon threshold 0.1197
20999 439501 0.0 (-0.08) 17 (23.36) 0.0449 (0.0444) 0.93 0.2551 93.0 0.1108 0.01 0.04s
21999 462133 0.0 (0.0) 11 (21.0) 0.0531 (0.0469) 0.88 0.325 88.0 0.1031 0.01 0.06s
22999 484962 0.0 (0.0) 31 (24.04) 0.0435 (0.0465) 0.88 0.325 88.0 0.0963 0.01 0.04s
23999 507774 0.0 (0.0) 25 (21.08) 0.0389 (0.0459) 0.89 0.3129 89.0 0.0904 0.01 0.06s
24999 531175 0.0 (0.0) 20 (23.6) 0.0453 (0.0427) 0.89 0.3129 89.0 0.0851 0.01 0.05s
25999 554451 0.0 (0.0) 13 (22.84) 0.0498 (0.0445) 0.91 0.2862 91.0 0.0805 0.01 0.06s
26999 577973 0.0 (0.0) 26 (23.0) 0.0417 (0.0443) 0.87 0.3363 87.0 0.0765 0.01 0.05s
27999 601554 0.0 (0.0) 23 (23.6) 0.0493 (0.0436) 0.84 0.3666 84.0 0.073 0.01 0.05s
28999 624682 0.0 (-0.08) 22 (19.16) 0.0485 (0.0456) 0.93 0.2551 93.0 0.07 0.01 0.04s
29999 648382 0.0 (0.0) 23 (22.48) 0.0519 (0.0447) 0.91 0.2862 91.0 0.0674 0.01 0.05s
30999 671673 0.0 (0.0) 21 (22.6) 0.05 (0.0445) 0.85 0.3571 85.0 0.0651 0.01 0.05s
31999 695409 0.0 (-0.04) 20 (24.52) 0.0426 (0.0436) 0.82 0.3842 82.0 0.0631 0.01 0.06s
32999 719257 0.0 (-0.04) 27 (21.44) 0.0576 (0.0437) 0.79 0.4073 79.0 0.0614 0.01 0.06s
33999 742359 0.0 (0.0) 18 (22.04) 0.0499 (0.0421) 0.87 0.3363 87.0 0.0599 0.01 0.05s
34999 765618 -1.0 (-0.04) 32 (21.84) 0.0401 (0.0437) 0.92 0.2713 92.0 0.0586 0.01 0.08s
35999 789207 0.0 (-0.08) 29 (22.24) 0.0469 (0.0438) 0.87 0.3363 87.0 0.0575 0.01 0.06s
36999 812560 0.0 (0.0) 31 (24.28) 0.0492 (0.0434) 0.75 0.433 75.0 0.0565 0.01 0.05s
37999 836424 0.0 (-0.04) 24 (26.56) 0.0431 (0.0427) 0.82 0.3842 82.0 0.0556 0.01 0.09s
38999 859950 0.0 (0.0) 25 (23.56) 0.0395 (0.0432) 0.87 0.3363 87.0 0.0549 0.01 0.05s
39999 883688 0.0 (0.0) 24 (24.08) 0.0397 (0.0441) 0.83 0.3756 83.0 0.0542 0.01 0.06s
Model saved to exports/checkpoints/ConnectFourNet_2023_10_29_T_21_36_49.chkpt at step 883688 with epsilon threshold 0.0542
40999 907411 0.0 (0.0) 26 (23.44) 0.0444 (0.0457) 0.91 0.2862 91.0 0.0537 0.002 0.05s
41999 932271 0.0 (0.0) 28 (25.68) 0.0358 (0.0416) 0.89 0.3129 89.0 0.0532 0.002 0.07s
42999 957047 0.0 (0.0) 30 (21.28) 0.0382 (0.0402) 0.83 0.3756 83.0 0.0527 0.002 0.07s
43999 981954 0.0 (0.0) 30 (25.92) 0.0426 (0.0436) 0.85 0.3571 85.0 0.0523 0.002 0.07s
44999 1007640 0.0 (-0.04) 20 (26.88) 0.0403 (0.043) 0.92 0.2713 92.0 0.052 0.002 0.08s
45999 1032640 0.0 (0.0) 9 (22.64) 0.0459 (0.0426) 0.91 0.2862 91.0 0.0517 0.002 0.08s
46999 1057268 0.0 (0.0) 25 (25.92) 0.0411 (0.0443) 0.9 0.3 90.0 0.0515 0.002 0.07s
47999 1082648 0.0 (0.0) 28 (25.84) 0.0459 (0.0419) 0.89 0.3129 89.0 0.0513 0.002 0.09s
48999 1107489 0.0 (0.0) 23 (28.04) 0.0403 (0.0428) 0.87 0.3363 87.0 0.0511 0.002 0.07s
49999 1132553 0.0 (0.0) 24 (25.6) 0.0395 (0.041) 0.84 0.3666 84.0 0.051 0.002 0.07s
Model saved to exports/checkpoints/ConnectFourNet_2023_10_29_T_21_57_41.chkpt at step 1132553 with epsilon threshold 0.051
Como podemos ver a continuación, los resultados del entrenamiento son similares al agente DQN, en cambio, los resultados del modelo frente a una toma de decisiones aleatoria son ligeramente peores.
# plot training results
plot(
*ddqn_args,
figures_dir_path=FIGURES_DIR_PATH,
)Model training and evaluation figure saved to exports/figures/ConnectFourNet_2023_10_29_T_21_57_41.png
De nuevo, solo por motivos de ilustración, le planteamos a nuestro modelo unas observaciones de escenarios clave en el juego para tener una idea de cómo se desenvuelve.
# set eval mode
ddqn_net.policy.eval()
html = get_html(
policy=ddqn_net.policy,
observations=finishes_boards,
titles=titles,
device=device
)
display(HTML(html))Horizontal P2
t
t+1
Horizontal P1
t
t+1
Vertical P2
t
t+1
Vertical P1
t
t+1
Diagonal P2
t
t+1
Diagonal P1
t
t+1
Anti diagonal P2
t
t+1
Anti diagonal P1
t
t+1
Como vemos, sigue siendo capaz de resolver la mejor acción en la mayoría de tableros tanto para ganar la partida como para bloquear al jugador contrario a que la gane:
html = get_html(
policy=ddqn_net.policy,
observations=blocks_boards,
titles=titles,
device=device
)
display(HTML(html))Horizontal P2
t
t+1
Horizontal P1
t
t+1
Vertical P2
t
t+1
Vertical P1
t
t+1
Diagonal P2
t
t+1
Diagonal P1
t
t+1
Anti diagonal P2
t
t+1
Anti diagonal P1
t
t+1
Cabe la posibilidad que los modelos no rindan tan bien en partidas prolongadas del juego dado que durante el entrenamiento reciben más observaciones de tableros con menos fichas.
Hay muchos parámetros del entrenamiento a modificar e intentar una mejor generalización del modelo. Por ejemplo, las recompensas, la tasa de decadencia de epsilon, el tamaño del lote de observaciones, la capacidad de la memoria, la tasa de aprendizaje o incluso implementar otra arquitectura de red neuronal o modificar y/o incluir capas convolucionales o las completamente conectadas.
Además, se pueden implementar algoritmos de aprendizaje por refuerzo más modernos como la optimización de políticas proximales (PPO), o hacer uso de una memoria de observaciones con muestreado priorizado.
La librería TorchRL, que es una librería de aprendizaje por refuerzo para PyTorch, implementa estas y otras muchas opciones.
Exportación
Como se apunta en la introducción de este cuaderno de Jupyter, este modelo se utiliza en una aplicación web descentralizada de conocimiento cero.
Por esta razón, se hace uso del módulo torch.onnx para convertir el gráfico de cálculo nativo de PyTorch a un gráfico ONNX y así mediante la librería de JavaScript ONNX Runtime Web se pueda desplegar este modelo de aprendizaje automático en nuestra aplicación web.
# export onnx version of the policy trained by the DDQN
_, _, _, model_id = dqn_args
export_onnx(
policy=dqn_agent.net.policy,
policies_dir_path=POLICIES_DIR_PATH,
model_id=model_id,
device=device
)Model in ONNX format saved to exports/policies/ConnectFourNet_2023_10_29_T_20_28_55.onnx