hex

hex.git
git clone git://git.lenczewski.org/hex.git
Log | Files | Refs

mcts.c (11604B)


      1 #include "hexes.h"
      2 
/* Out-of-line instantiations of the relative-pointer helpers.
 *
 * Under C99/C11 inline semantics, an `extern` declaration of a function that
 * is defined `inline` (presumably in "hexes.h" -- TODO confirm) makes this
 * translation unit emit the external definition, so calls the compiler does
 * not inline still resolve at link time.
 *
 * In this file, mcts_node_abs2rel(node, ptr) is used to store a node pointer
 * relative to `node`, and mcts_node_rel2abs(node, rel) converts it back. The
 * backpropagation loop in mcts_round() relies on rel2abs() yielding NULL for
 * a parent that was stored from a NULL pointer (the root).
 */
extern inline mcts_node_relptr_t
mcts_node_abs2rel(void *restrict base, struct mcts_node *restrict absptr);

extern inline struct mcts_node *
mcts_node_rel2abs(void *base, mcts_node_relptr_t relptr);
      8 
      9 static void
     10 mcts_node_init(struct mcts_node *self, struct mcts_node *parent,
     11 	       enum hex_player player, u8 x, u8 y, u32 children_cap)
     12 {
     13 	assert(self);
     14 
     15 	self->parent = mcts_node_abs2rel(self, parent);
     16 	self->player = player;
     17 	self->x = x;
     18 	self->y = y;
     19 
     20 	self->wins = self->rave_wins = 0;
     21 	self->plays = self->rave_plays = 0;
     22 
     23 	self->children_cap = children_cap;
     24 	self->children_len = 0;
     25 
     26 	self->children.head = self->children.tail = RELPTR_NULL;
     27 	self->list_node.prev = self->list_node.next = RELPTR_NULL;
     28 }
     29 
     30 static bool
     31 mcts_node_expand(struct mcts_node *self, struct arena *pool, u8 x, u8 y)
     32 {
     33 	assert(self);
     34 	assert(pool);
     35 
     36 	struct mcts_node *child = ALLOC_SIZED(pool, struct mcts_node);
     37 	if (!child) {
     38 		dbglog(LOG_WARN, "Failed to allocate child node. Consider compacting memory pool\n");
     39 		return false;
     40 	}
     41 
     42 	mcts_node_init(child, self, hexopponent(self->player), x, y, self->children_cap - 1);
     43 	list_push_tail(&self->children, &child->list_node);
     44 	self->children_len++;
     45 
     46 	return true;
     47 }
     48 
     49 static struct mcts_node *
     50 mcts_node_get_child(struct mcts_node *self, u8 x, u8 y)
     51 {
     52 	assert(self);
     53 
     54 	struct list *list = &self->children;
     55 	list_iter(list) {
     56 		struct mcts_node *child = FROM_NODE(it, struct mcts_node, list_node);
     57 		assert(child);
     58 
     59 		if (child->x == x && child->y == y)
     60 			return child;
     61 	}
     62 
     63 	return NULL;
     64 }
     65 
     66 static f32
     67 mcts_node_calc_score(struct mcts_node *self)
     68 {
     69 	assert(self);
     70 
     71 	/* MCTS-RAVE formula:
     72 	 * ((1 - beta(n, n')) * (w / n)) + (beta(n, n') * (w' / n')) + (c * sqrt(ln t / n))
     73 	 * ---
     74 	 *  n = number of won playouts for this node
     75 	 *  n' = number of won playouts for this node for a given move
     76 	 *  w = total number of playouts for this node
     77 	 *  w' = total number of playouts for this node for a given move
     78 	 *  c = exploration parameter (sqrt(2), or found experimentally)
     79 	 *  t = total number of playouts for parent node
     80 	 *  beta(n, n') = function close to 1 for small n, and close to 0 for large n
     81 	 */
     82 
     83 	/* if this node has not yet been played, return the default maximum value
     84 	 * so that it is picked during expansion
     85 	 */
     86 	if (!self->plays) return INFINITY;
     87 
     88 	s64 exploration_rounds = 3000;
     89 	f32 beta = MAX(0.0, (exploration_rounds - self->plays) / (f32) exploration_rounds);
     90 	assert(0.0 <= beta && beta <= 1.0);
     91 
     92 	dbglog(LOG_DEBUG, "beta: %lf, wins: %d, rave_wins: %d, plays: %u, rave_plays: %u\n",
     93 			  beta, self->wins, self->rave_wins, self->plays, self->rave_plays);
     94 
     95 	struct mcts_node *parent = mcts_node_rel2abs(self, self->parent);
     96 	assert(parent);
     97 
     98 	f32 exploration = M_SQRT2 * sqrtf(logf(parent->plays) / (f32) self->plays);
     99 
    100 	f32 exploitation = (1 - beta) * ((f32) self->wins / (f32) self->plays);
    101 	assert(-1.0 <= exploitation && exploitation <= 1.0);
    102 
    103 	f32 rave_exploitation = beta * ((f32) self->rave_wins / (f32) self->rave_plays);
    104 	assert(-1.0 <= rave_exploitation && rave_exploitation <= 1.0);
    105 
    106 	dbglog(LOG_DEBUG, "exploration: %f, exploitation: %f, rave_exploitation: %f\n",
    107 			exploration, exploitation, rave_exploitation);
    108 
    109 	return exploration + exploitation + rave_exploitation;
    110 }
    111 
    112 static struct mcts_node *
    113 mcts_node_best_child(struct mcts_node *self)
    114 {
    115 	assert(self);
    116 
    117 	f32 max_score = -INFINITY;
    118 	struct mcts_node *best_child = NULL;
    119 	struct list *list = &self->children;
    120 	list_iter(list) {
    121 		struct mcts_node *child = FROM_NODE(it, struct mcts_node, list_node);
    122 		assert(child);
    123 
    124 		dbglog(LOG_DEBUG, "Node: {parent=%p, children=%" PRIu32 ", x=%" PRIu32 ", y=%" PRIu32 "}\n",
    125 				mcts_node_rel2abs(child, child->parent), child->children_len, child->x, child->y);
    126 
    127 		f32 score = mcts_node_calc_score(child);
    128 
    129 		if (score > max_score) {
    130 			max_score = score;
    131 			best_child = child;
    132 		}
    133 	}
    134 
    135 	return best_child;
    136 }
    137 
    138 bool
    139 agent_mcts_init(struct agent_mcts *self, struct board const *board, struct threadpool *threadpool,
    140 		u32 mem_limit_mib, enum hex_player player)
    141 {
    142 	assert(self);
    143 
    144 	self->board = board;
    145 	self->threadpool = threadpool;
    146 
    147 	if (!board_init(&self->shadow_board, board->size)) return false;
    148 
    149 	size_t align = alignof(struct mcts_node);
    150 	self->pool.cap = ALIGN_PREV((mem_limit_mib * MiB) - HEXES_RESERVED_MEM, align);
    151 
    152 	int prot = PROT_READ | PROT_WRITE;
    153 	int flags = MAP_PRIVATE | MAP_ANONYMOUS;
    154 	self->pool.ptr = mmap(NULL, self->pool.cap, prot, flags, -1, 0);
    155 	madvise(self->pool.ptr, self->pool.cap, MADV_HUGEPAGE);
    156 
    157 	if (self->pool.ptr == MAP_FAILED) {
    158 		board_free(&self->shadow_board);
    159 		return false;
    160 	}
    161 
    162 	u32 moves = board_available_moves(self->board, NULL);
    163 	self->root = ALLOC_SIZED(&self->pool, struct mcts_node);
    164 	mcts_node_init(self->root, NULL, hexopponent(player), 0, 0, moves);
    165 
    166 	return true;
    167 }
    168 
    169 void
    170 agent_mcts_free(struct agent_mcts *self)
    171 {
    172 	assert(self);
    173 
    174 	munmap(self->pool.ptr, self->pool.cap);
    175 }
    176 
    177 void
    178 agent_mcts_play(struct agent_mcts *self, enum hex_player player, u32 x, u32 y)
    179 {
    180 	assert(self);
    181 
    182 	arena_reset(&self->pool);
    183 
    184 	u32 moves = board_available_moves(self->board, NULL);
    185 	self->root = ALLOC_SIZED(&self->pool, struct mcts_node);
    186 	mcts_node_init(self->root, NULL, player, x, y, moves);
    187 
    188 	// TODO: implement tree reuse, if it improves play
    189 	//
    190 	//       one possible issue is children containing stale board states,
    191 	//       leading to potentially invalid moves being generated.
    192 
    193 	// TODO: implement tree compaction
    194 	//       one possible issue is the fact that walking the tree to
    195 	//       compact it takes a significant amount of time and potentially
    196 	//       outweighs simply resetting the pool and performing a few more
    197 	//       rounds of MCTS.
    198 	//
    199 	//       another option would be to implement some form of cyclic
    200 	//       memory pool, to allow allocations from memory behind the
    201 	//       current root node, as well as in front of it, but this would
    202 	//       require tracking stale leaf nodes and reclaiming them (hence
    203 	//       transforms into a GC, and thus wastes the benefits of a
    204 	//       simple memory pool)
    205 }
    206 
    207 void
    208 agent_mcts_swap(struct agent_mcts *self)
    209 {
    210 	assert(self);
    211 
    212 	struct mcts_node old_root = *self->root;
    213 
    214 	arena_reset(&self->pool);
    215 
    216 	u32 moves = board_available_moves(self->board, NULL);
    217 	self->root = ALLOC_SIZED(&self->pool, struct mcts_node);
    218 	mcts_node_init(self->root, NULL, hexopponent(old_root.player), old_root.x, old_root.y, moves);
    219 }
    220 
    221 static bool
    222 mcts_search(struct agent_mcts *self, struct timespec timeout);
    223 
    224 bool
    225 agent_mcts_next(struct agent_mcts *self, struct timespec timeout,
    226 		u32 *restrict out_x, u32 *restrict out_y)
    227 {
    228 	assert(self);
    229 	assert(out_x);
    230 	assert(out_y);
    231 
    232 	if (!mcts_search(self, timeout)) return false;
    233 
    234 	struct mcts_node *root = self->pool.ptr;
    235 	assert(root->children.head != RELPTR_NULL);
    236 
    237 	u32 max_plays = 0;
    238 	struct mcts_node *best_child = NULL;
    239 	struct list *list = &root->children;
    240 	list_iter(list) {
    241 		struct mcts_node *child = FROM_NODE(it, struct mcts_node, list_node);
    242 		assert(child);
    243 
    244 		if (child->plays > max_plays) {
    245 			max_plays = child->plays;
    246 			best_child = child;
    247 		} else if (child->plays == max_plays && random() % 2) {
    248 			best_child = child;
    249 		}
    250 	}
    251 
    252 	assert(best_child);
    253 
    254 	*out_x = best_child->x;
    255 	*out_y = best_child->y;
    256 
    257 	return true;
    258 }
    259 
    260 static bool
    261 mcts_round(struct agent_mcts *self, struct move *moves)
    262 {
    263 	assert(self);
    264 
    265 	board_copy(self->board, &self->shadow_board);
    266 
    267 	dbglog(LOG_DEBUG, "Starting MCTS round\n");
    268 
    269 	/* selection: we walk the mcts tree, picking the child with the highest
    270 	 * mcts-rave score, until we hit a node with unexpanded children
    271 	 */
    272 	struct mcts_node *node = self->root;
    273 	while (node->children_len == node->children_cap) {
    274 		struct mcts_node *child = mcts_node_best_child(node);
    275 		if (!child) break;
    276 
    277 		if (!board_play(&self->shadow_board, child->player, child->x, child->y)) {
    278 			dbglog(LOG_WARN, "Failed to play move (%" PRIu32 ", %" PRIu32 ") to shadow board\n", child->x, child->y);
    279 			return false;
    280 		}
    281 
    282 		node = child;
    283 	}
    284 
    285 	dbglog(LOG_DEBUG, "Selected node {parent=%p, children=%" PRIu8 ", x=%" PRIu32 ", y=%" PRIu32 "} for expansion\n",
    286 			  mcts_node_rel2abs(node, node->parent), node->children_len, node->x, node->y);
    287 
    288 	size_t moves_len = board_available_moves(&self->shadow_board, moves);
    289 	shuffle(moves, sizeof *moves, moves_len);
    290 
    291 	/* expansion: we expand the chosen node, creating a new child for a
    292 	 * random move
    293 	 */
    294 	enum hex_player winner;
    295 	if (!board_winner(&self->shadow_board, &winner)) {
    296 		struct move move = moves[--moves_len];
    297 
    298 		if (!mcts_node_expand(node, &self->pool, move.x, move.y)) {
    299 			dbglog(LOG_WARN, "Failed to expand selected node\n");
    300 			return false;
    301 		}
    302 
    303 		struct mcts_node *child = mcts_node_get_child(node, move.x, move.y);
    304 		assert(child);
    305 
    306 		if (!board_play(&self->shadow_board, child->player, child->x, child->y)) {
    307 			dbglog(LOG_WARN, "Failed to play move (%" PRIu32 ", %" PRIu32 ") to shadow board\n", child->x, child->y);
    308 			return false;
    309 		}
    310 	}
    311 
    312 	dbglog(LOG_DEBUG, "Expanded node {parent=%p, children=%" PRIu8 ", x=%" PRIu32 ", y=%" PRIu32 "}\n",
    313 			  mcts_node_rel2abs(node, node->parent), node->children_len, node->x, node->y);
    314 
    315 	/* simulation: we simulate the game using a uniform random walk of the
    316 	 * game state space, until a winner is found
    317 	 */
    318 	enum hex_player player = node->player;
    319 	while (!board_winner(&self->shadow_board, &winner)) {
    320 		struct move move = moves[--moves_len];
    321 
    322 		if (!board_play(&self->shadow_board, player, move.x, move.y)) {
    323 			dbglog(LOG_WARN, "Failed to play move (%" PRIu32 ", %" PRIu32 ") to shadow board\n", move.x, move.y);
    324 			return false;
    325 		}
    326 
    327 		player = hexopponent(player);
    328 	}
    329 
    330 	dbglog(LOG_DEBUG, "Completed playouts for node {parent=%p, children=%" PRIu8 ", x=%" PRIu32 ", y=%" PRIu32 "}\n",
    331 			  mcts_node_rel2abs(node, node->parent), node->children_len, node->x, node->y);
    332 
    333 	/* backpropagation: we update the state information in the mcts tree
    334 	 * by walking backwards from the selected node
    335 	 */
    336 	do {
    337 		s32 reward = winner == node->player ? +1 : -1;
    338 
    339 		struct list *list = &node->children;
    340 		list_iter(list) {
    341 			struct mcts_node *child = FROM_NODE(it, struct mcts_node, list_node);
    342 			assert(child);
    343 
    344 			struct segment *segment = &self->shadow_board.segments[child->y * self->shadow_board.size + child->x];
    345 			if ((enum cell) child->player == segment->occupant) {
    346 				child->rave_plays += 1;
    347 				child->rave_wins += -reward;
    348 			}
    349 		}
    350 
    351 		node->plays += 1;
    352 		node->wins += reward;
    353 	} while ((node = mcts_node_rel2abs(node, node->parent)));
    354 
    355 	dbglog(LOG_DEBUG, "Completed backpropagation from selected node\n");
    356 
    357 	dbglog(LOG_DEBUG, "Completed MCTS round\n");
    358 
    359 	return true;
    360 }
    361 
    362 static bool
    363 mcts_search(struct agent_mcts *self, struct timespec timeout)
    364 {
    365 	assert(self);
    366 
    367 	struct move *moves = alloca(self->board->size * self->board->size * sizeof *moves);
    368 
    369 	struct timespec time;
    370 	clock_gettime(CLOCK_MONOTONIC, &time);
    371 
    372 	u64 end_nanos = TIMESPEC_TO_NANOS(time.tv_sec, time.tv_nsec)
    373 		      + TIMESPEC_TO_NANOS(timeout.tv_sec, timeout.tv_nsec);
    374 
    375 	dbglog(LOG_INFO, "Starting MCTS tree search with %" PRIu32 " second timeout\n", timeout.tv_sec);
    376 
    377 	size_t rounds = 0;
    378 	while (true) {
    379 		clock_gettime(CLOCK_MONOTONIC, &time);
    380 		if (end_nanos <= TIMESPEC_TO_NANOS(time.tv_sec, time.tv_nsec)) {
    381 			dbglog(LOG_DEBUG, "Search timeout elapsed\n");
    382 			break;
    383 		}
    384 
    385 		if (!mcts_round(self, moves)) {
    386 			dbglog(LOG_WARN, "Failed to perform MCTS round %zu\n", rounds + 1);
    387 			break;
    388 		}
    389 
    390 		rounds++;
    391 	}
    392 
    393 	dbglog(LOG_INFO, "Completed %zu rounds of MCTS\n", rounds);
    394 	dbglog(LOG_INFO, "MCTS node pool occupancy: %zu/%zu bytes allocated\n", self->pool.len, self->pool.cap);
    395 
    396 	return true;
    397 }