mcts.c (11604B)
1 #include "hexes.h" 2 3 extern inline mcts_node_relptr_t 4 mcts_node_abs2rel(void *restrict base, struct mcts_node *restrict absptr); 5 6 extern inline struct mcts_node * 7 mcts_node_rel2abs(void *base, mcts_node_relptr_t relptr); 8 9 static void 10 mcts_node_init(struct mcts_node *self, struct mcts_node *parent, 11 enum hex_player player, u8 x, u8 y, u32 children_cap) 12 { 13 assert(self); 14 15 self->parent = mcts_node_abs2rel(self, parent); 16 self->player = player; 17 self->x = x; 18 self->y = y; 19 20 self->wins = self->rave_wins = 0; 21 self->plays = self->rave_plays = 0; 22 23 self->children_cap = children_cap; 24 self->children_len = 0; 25 26 self->children.head = self->children.tail = RELPTR_NULL; 27 self->list_node.prev = self->list_node.next = RELPTR_NULL; 28 } 29 30 static bool 31 mcts_node_expand(struct mcts_node *self, struct arena *pool, u8 x, u8 y) 32 { 33 assert(self); 34 assert(pool); 35 36 struct mcts_node *child = ALLOC_SIZED(pool, struct mcts_node); 37 if (!child) { 38 dbglog(LOG_WARN, "Failed to allocate child node. Consider compacting memory pool\n"); 39 return false; 40 } 41 42 mcts_node_init(child, self, hexopponent(self->player), x, y, self->children_cap - 1); 43 list_push_tail(&self->children, &child->list_node); 44 self->children_len++; 45 46 return true; 47 } 48 49 static struct mcts_node * 50 mcts_node_get_child(struct mcts_node *self, u8 x, u8 y) 51 { 52 assert(self); 53 54 struct list *list = &self->children; 55 list_iter(list) { 56 struct mcts_node *child = FROM_NODE(it, struct mcts_node, list_node); 57 assert(child); 58 59 if (child->x == x && child->y == y) 60 return child; 61 } 62 63 return NULL; 64 } 65 66 static f32 67 mcts_node_calc_score(struct mcts_node *self) 68 { 69 assert(self); 70 71 /* MCTS-RAVE formula: 72 * ((1 - beta(n, n')) * (w / n)) + (beta(n, n') * (w' / n')) + (c * sqrt(ln t / n)) 73 * --- 74 * n = number of won playouts for this node 75 * n' = number of won playouts for this node for a given move 76 * w = total number of playouts for this node 77 * w' = total number of playouts for this node for a given move 78 * c = exploration parameter (sqrt(2), or found experimentally) 79 * t = total number of playouts for parent node 80 * beta(n, n') = function close to 1 for small n, and close to 0 for large n 81 */ 82 83 /* if this node has not yet been played, return the default maximum value 84 * so that it is picked during expansion 85 */ 86 if (!self->plays) return INFINITY; 87 88 s64 exploration_rounds = 3000; 89 f32 beta = MAX(0.0, (exploration_rounds - self->plays) / (f32) exploration_rounds); 90 assert(0.0 <= beta && beta <= 1.0); 91 92 dbglog(LOG_DEBUG, "beta: %lf, wins: %d, rave_wins: %d, plays: %u, rave_plays: %u\n", 93 beta, self->wins, self->rave_wins, self->plays, self->rave_plays); 94 95 struct mcts_node *parent = mcts_node_rel2abs(self, self->parent); 96 assert(parent); 97 98 f32 exploration = M_SQRT2 * sqrtf(logf(parent->plays) / (f32) self->plays); 99 100 f32 exploitation = (1 - beta) * ((f32) self->wins / (f32) self->plays); 101 assert(-1.0 <= exploitation && exploitation <= 1.0); 102 103 f32 rave_exploitation = beta * ((f32) self->rave_wins / (f32) self->rave_plays); 104 assert(-1.0 <= rave_exploitation && rave_exploitation <= 1.0); 105 106 dbglog(LOG_DEBUG, "exploration: %f, exploitation: %f, rave_exploitation: %f\n", 107 exploration, exploitation, rave_exploitation); 108 109 return exploration + exploitation + rave_exploitation; 110 } 111 112 static struct mcts_node * 113 mcts_node_best_child(struct mcts_node *self) 114 { 115 assert(self); 116 117 f32 max_score = -INFINITY; 118 struct mcts_node *best_child = NULL; 119 struct list *list = &self->children; 120 list_iter(list) { 121 struct mcts_node *child = FROM_NODE(it, struct mcts_node, list_node); 122 assert(child); 123 124 dbglog(LOG_DEBUG, "Node: {parent=%p, children=%" PRIu32 ", x=%" PRIu32 ", y=%" PRIu32 "}\n", 125 mcts_node_rel2abs(child, child->parent), child->children_len, child->x, child->y); 126 127 f32 score = mcts_node_calc_score(child); 128 129 if (score > max_score) { 130 max_score = score; 131 best_child = child; 132 } 133 } 134 135 return best_child; 136 } 137 138 bool 139 agent_mcts_init(struct agent_mcts *self, struct board const *board, struct threadpool *threadpool, 140 u32 mem_limit_mib, enum hex_player player) 141 { 142 assert(self); 143 144 self->board = board; 145 self->threadpool = threadpool; 146 147 if (!board_init(&self->shadow_board, board->size)) return false; 148 149 size_t align = alignof(struct mcts_node); 150 self->pool.cap = ALIGN_PREV((mem_limit_mib * MiB) - HEXES_RESERVED_MEM, align); 151 152 int prot = PROT_READ | PROT_WRITE; 153 int flags = MAP_PRIVATE | MAP_ANONYMOUS; 154 self->pool.ptr = mmap(NULL, self->pool.cap, prot, flags, -1, 0); 155 madvise(self->pool.ptr, self->pool.cap, MADV_HUGEPAGE); 156 157 if (self->pool.ptr == MAP_FAILED) { 158 board_free(&self->shadow_board); 159 return false; 160 } 161 162 u32 moves = board_available_moves(self->board, NULL); 163 self->root = ALLOC_SIZED(&self->pool, struct mcts_node); 164 mcts_node_init(self->root, NULL, hexopponent(player), 0, 0, moves); 165 166 return true; 167 } 168 169 void 170 agent_mcts_free(struct agent_mcts *self) 171 { 172 assert(self); 173 174 munmap(self->pool.ptr, self->pool.cap); 175 } 176 177 void 178 agent_mcts_play(struct agent_mcts *self, enum hex_player player, u32 x, u32 y) 179 { 180 assert(self); 181 182 arena_reset(&self->pool); 183 184 u32 moves = board_available_moves(self->board, NULL); 185 self->root = ALLOC_SIZED(&self->pool, struct mcts_node); 186 mcts_node_init(self->root, NULL, player, x, y, moves); 187 188 // TODO: implement tree reuse, if it improves play 189 // 190 // one possible issue is children containing stale board states, 191 // leading to potentially invalid moves being generated. 192 193 // TODO: implement tree compaction 194 // one possible issue is the fact that walking the tree to 195 // compact it takes a significant amount of time and potentially 196 // outweighs simply resetting the pool and performing a few more 197 // rounds of MCTS. 198 // 199 // another option would be to implement some form of cyclic 200 // memory pool, to allow allocations from memory behind the 201 // current root node, as well as in front of it, but this would 202 // require tracking stale leaf nodes and reclaiming them (hence 203 // transforms into a GC, and thus wastes the benefits of a 204 // simple memory pool) 205 } 206 207 void 208 agent_mcts_swap(struct agent_mcts *self) 209 { 210 assert(self); 211 212 struct mcts_node old_root = *self->root; 213 214 arena_reset(&self->pool); 215 216 u32 moves = board_available_moves(self->board, NULL); 217 self->root = ALLOC_SIZED(&self->pool, struct mcts_node); 218 mcts_node_init(self->root, NULL, hexopponent(old_root.player), old_root.x, old_root.y, moves); 219 } 220 221 static bool 222 mcts_search(struct agent_mcts *self, struct timespec timeout); 223 224 bool 225 agent_mcts_next(struct agent_mcts *self, struct timespec timeout, 226 u32 *restrict out_x, u32 *restrict out_y) 227 { 228 assert(self); 229 assert(out_x); 230 assert(out_y); 231 232 if (!mcts_search(self, timeout)) return false; 233 234 struct mcts_node *root = self->pool.ptr; 235 assert(root->children.head != RELPTR_NULL); 236 237 u32 max_plays = 0; 238 struct mcts_node *best_child = NULL; 239 struct list *list = &root->children; 240 list_iter(list) { 241 struct mcts_node *child = FROM_NODE(it, struct mcts_node, list_node); 242 assert(child); 243 244 if (child->plays > max_plays) { 245 max_plays = child->plays; 246 best_child = child; 247 } else if (child->plays == max_plays && random() % 2) { 248 best_child = child; 249 } 250 } 251 252 assert(best_child); 253 254 *out_x = best_child->x; 255 *out_y = best_child->y; 256 257 return true; 258 } 259 260 static bool 261 mcts_round(struct agent_mcts *self, struct move *moves) 262 { 263 assert(self); 264 265 board_copy(self->board, &self->shadow_board); 266 267 dbglog(LOG_DEBUG, "Starting MCTS round\n"); 268 269 /* selection: we walk the mcts tree, picking the child with the highest 270 * mcts-rave score, until we hit a node with unexpanded children 271 */ 272 struct mcts_node *node = self->root; 273 while (node->children_len == node->children_cap) { 274 struct mcts_node *child = mcts_node_best_child(node); 275 if (!child) break; 276 277 if (!board_play(&self->shadow_board, child->player, child->x, child->y)) { 278 dbglog(LOG_WARN, "Failed to play move (%" PRIu32 ", %" PRIu32 ") to shadow board\n", child->x, child->y); 279 return false; 280 } 281 282 node = child; 283 } 284 285 dbglog(LOG_DEBUG, "Selected node {parent=%p, children=%" PRIu8 ", x=%" PRIu32 ", y=%" PRIu32 "} for expansion\n", 286 mcts_node_rel2abs(node, node->parent), node->children_len, node->x, node->y); 287 288 size_t moves_len = board_available_moves(&self->shadow_board, moves); 289 shuffle(moves, sizeof *moves, moves_len); 290 291 /* expansion: we expand the chosen node, creating a new child for a 292 * random move 293 */ 294 enum hex_player winner; 295 if (!board_winner(&self->shadow_board, &winner)) { 296 struct move move = moves[--moves_len]; 297 298 if (!mcts_node_expand(node, &self->pool, move.x, move.y)) { 299 dbglog(LOG_WARN, "Failed to expand selected node\n"); 300 return false; 301 } 302 303 struct mcts_node *child = mcts_node_get_child(node, move.x, move.y); 304 assert(child); 305 306 if (!board_play(&self->shadow_board, child->player, child->x, child->y)) { 307 dbglog(LOG_WARN, "Failed to play move (%" PRIu32 ", %" PRIu32 ") to shadow board\n", child->x, child->y); 308 return false; 309 } 310 } 311 312 dbglog(LOG_DEBUG, "Expanded node {parent=%p, children=%" PRIu8 ", x=%" PRIu32 ", y=%" PRIu32 "}\n", 313 mcts_node_rel2abs(node, node->parent), node->children_len, node->x, node->y); 314 315 /* simulation: we simulate the game using a uniform random walk of the 316 * game state space, until a winner is found 317 */ 318 enum hex_player player = node->player; 319 while (!board_winner(&self->shadow_board, &winner)) { 320 struct move move = moves[--moves_len]; 321 322 if (!board_play(&self->shadow_board, player, move.x, move.y)) { 323 dbglog(LOG_WARN, "Failed to play move (%" PRIu32 ", %" PRIu32 ") to shadow board\n", move.x, move.y); 324 return false; 325 } 326 327 player = hexopponent(player); 328 } 329 330 dbglog(LOG_DEBUG, "Completed playouts for node {parent=%p, children=%" PRIu8 ", x=%" PRIu32 ", y=%" PRIu32 "}\n", 331 mcts_node_rel2abs(node, node->parent), node->children_len, node->x, node->y); 332 333 /* backpropagation: we update the state information in the mcts tree 334 * by walking backwards from the selected node 335 */ 336 do { 337 s32 reward = winner == node->player ? +1 : -1; 338 339 struct list *list = &node->children; 340 list_iter(list) { 341 struct mcts_node *child = FROM_NODE(it, struct mcts_node, list_node); 342 assert(child); 343 344 struct segment *segment = &self->shadow_board.segments[child->y * self->shadow_board.size + child->x]; 345 if ((enum cell) child->player == segment->occupant) { 346 child->rave_plays += 1; 347 child->rave_wins += -reward; 348 } 349 } 350 351 node->plays += 1; 352 node->wins += reward; 353 } while ((node = mcts_node_rel2abs(node, node->parent))); 354 355 dbglog(LOG_DEBUG, "Completed backpropagation from selected node\n"); 356 357 dbglog(LOG_DEBUG, "Completed MCTS round\n"); 358 359 return true; 360 } 361 362 static bool 363 mcts_search(struct agent_mcts *self, struct timespec timeout) 364 { 365 assert(self); 366 367 struct move *moves = alloca(self->board->size * self->board->size * sizeof *moves); 368 369 struct timespec time; 370 clock_gettime(CLOCK_MONOTONIC, &time); 371 372 u64 end_nanos = TIMESPEC_TO_NANOS(time.tv_sec, time.tv_nsec) 373 + TIMESPEC_TO_NANOS(timeout.tv_sec, timeout.tv_nsec); 374 375 dbglog(LOG_INFO, "Starting MCTS tree search with %" PRIu32 " second timeout\n", timeout.tv_sec); 376 377 size_t rounds = 0; 378 while (true) { 379 clock_gettime(CLOCK_MONOTONIC, &time); 380 if (end_nanos <= TIMESPEC_TO_NANOS(time.tv_sec, time.tv_nsec)) { 381 dbglog(LOG_DEBUG, "Search timeout elapsed\n"); 382 break; 383 } 384 385 if (!mcts_round(self, moves)) { 386 dbglog(LOG_WARN, "Failed to perform MCTS round %zu\n", rounds + 1); 387 break; 388 } 389 390 rounds++; 391 } 392 393 dbglog(LOG_INFO, "Completed %zu rounds of MCTS\n", rounds); 394 dbglog(LOG_INFO, "MCTS node pool occupancy: %zu/%zu bytes allocated\n", self->pool.len, self->pool.cap); 395 396 return true; 397 }