minimax

minimax.git
git clone git://git.lenczewski.org/minimax.git
Log | Files | Refs

minimax.c (15024B)


      1 #ifdef _XOPEN_SOURCE
      2 # undef _XOPEN_SOURCE
      3 #endif
      4 
      5 #define _XOPEN_SOURCE 700
      6 
      7 #ifndef _GNU_SOURCE
      8 # define _GNU_SOURCE 1
      9 #endif
     10 
     11 #ifndef _BSD_SOURCE
     12 # define _BSD_SOURCE 1
     13 #endif
     14 
     15 #ifndef _DEFAULT_SOURCE
     16 # define _DEFAULT_SOURCE 1
     17 #endif
     18 
     19 #include <assert.h>
     20 #include <stdalign.h>
     21 #include <stddef.h>
     22 #include <stdint.h>
     23 #include <stdio.h>
     24 #include <stdlib.h>
     25 #include <string.h>
     26 
     27 #include <math.h>
     28 
     29 #if defined(__linux__)
     30 # include <alloca.h>
     31 # include <sys/mman.h>
     32 # define alloca_impl alloca
     33 #elif defined(_WIN32)
     34 # include <malloc.h>
     35 # define alloca_impl _alloca
     36 #endif
     37 
     38 /* configuration
     39  * ===========================================================================
     40  */
     41 
     42 #define KiB 1024ULL
     43 #define MiB (1024 * KiB)
     44 #define GiB (1024 * MiB)
     45 
     46 #define ARENA_SIZE 1 * GiB
     47 
     48 // length of one side of the game board
     49 #define BOARD_SIZE 3
     50 
     51 // maximum depth at which to search using minimax
     52 // NOTE: this is reduced if the memory arena would not allow for this depth
     53 #define MAX_SEARCH_DEPTH 11
     54 
     55 // test minimax algorithm on first move of game
     56 #define MINIMAX_DEBUG 0
     57 
     58 /* utils
     59  * ===========================================================================
     60  */
     61 
     62 #define true 1
     63 #define false 0
     64 
     65 #define MAX(a, b) (((a) > (b)) ? (a) : (b))
     66 #define MIN(a, b) (((a) < (b)) ? (a) : (b))
     67 
     68 struct arena {
     69 	void *ptr;
     70 	uint64_t cap, len;
     71 };
     72 
     73 static inline void
     74 arena_reset(struct arena *arena)
     75 {
     76 	arena->len = 0;
     77 }
     78 
     79 #define ALIGN_PREV(v, align) ((v) & ~((align) - 1))
     80 #define ALIGN_NEXT(v, align) ALIGN_PREV((v) + ((align) - 1), (align))
     81 
     82 static inline void *
     83 arena_alloc(struct arena *arena, size_t size, size_t align)
     84 {
     85 	uint64_t aligned_len = ALIGN_NEXT(arena->len, align);
     86 	if (aligned_len + size > arena->cap)
     87 		return NULL;
     88 
     89 	void *ptr = (void *) ((uintptr_t) arena->ptr + aligned_len);
     90 	arena->len = aligned_len + size;
     91 
     92 	return ptr;
     93 }
     94 
     95 #define ALLOC_ARRAY(arena, T, n) arena_alloc((arena), sizeof(T) * (n), alignof(T))
     96 #define ALLOC_SIZED(arena, T) ALLOC_ARRAY((arena), T, 1)
     97 
     98 struct list_node {
     99 	struct list_node *next;
    100 };
    101 
    102 #define FROM_NODE(node, T, member) \
    103 	((T *) ((uintptr_t) node - offsetof(T, member)))
    104 
    105 #define list_node_iter(node) \
    106 	for (struct list_node *it = node; it; it = it->next)
    107 
    108 struct list {
    109 	struct list_node *head;
    110 };
    111 
    112 #define list_iter(list) \
    113 	list_node_iter(list->head)
    114 
    115 static inline void
    116 list_push(struct list *restrict list, struct list_node *restrict node)
    117 {
    118 	node->next = list->head;
    119 	list->head = node;
    120 }
    121 
    122 /* tic-tac-toe definitions
    123  * ===========================================================================
    124  */
    125 
    126 enum player {
    127 	EMPTY = 0,
    128 	BLACK = +1,
    129 	WHITE = -1,
    130 };
    131 
    132 static inline enum player
    133 get_opponent(enum player player)
    134 {
    135 	switch (player) {
    136 	case EMPTY: return EMPTY;
    137 	case BLACK: return WHITE;
    138 	case WHITE: return BLACK;
    139 	}
    140 }
    141 
    142 static inline char
    143 get_player_char(enum player player)
    144 {
    145 	switch (player) {
    146 	case EMPTY: return '.';
    147 	case BLACK: return 'B';
    148 	case WHITE: return 'W';
    149 	}
    150 }
    151 
    152 static inline char const *
    153 get_player_str(enum player player)
    154 {
    155 	switch (player) {
    156 	case EMPTY: return "EMPTY";
    157 	case BLACK: return "BLACK";
    158 	case WHITE: return "WHITE";
    159 	}
    160 }
    161 
    162 struct cell {
    163 	int8_t state;
    164 };
    165 
    166 struct board {
    167 	struct cell cells[BOARD_SIZE * BOARD_SIZE];
    168 };
    169 
    170 static inline struct cell *
    171 get_cell_ptr(struct board *board, size_t i, size_t j)
    172 {
    173 	return &board->cells[j * BOARD_SIZE + i];
    174 }
    175 
    176 struct move {
    177 	uint8_t i, j;
    178 };
    179 
    180 static inline int
    181 try_play_move(struct board *board, struct move move, enum player player)
    182 {
    183 	if (BOARD_SIZE <= move.i || BOARD_SIZE <= move.j)
    184 		return false;
    185 
    186 	struct cell *cell = get_cell_ptr(board, move.i, move.j);
    187 	if (cell->state != EMPTY)
    188 		return false;
    189 
    190 	cell->state = player;
    191 
    192 	return true;
    193 }
    194 
    195 static inline void
    196 undo_move(struct board *board, struct move move)
    197 {
    198 	struct cell *cell = get_cell_ptr(board, move.i, move.j);
    199 	cell->state = EMPTY;
    200 }
    201 
    202 static size_t
    203 get_available_moves(struct board *restrict board, struct move *restrict buf)
    204 {
    205 	size_t moves = 0;
    206 	for (size_t j = 0; j < BOARD_SIZE; j++) {
    207 		for (size_t i = 0; i < BOARD_SIZE; i++) {
    208 			if (get_cell_ptr(board, i, j)->state == EMPTY) {
    209 				if (buf)
    210 					*buf++ = (struct move) {i, j};
    211 
    212 				moves++;
    213 			}
    214 		}
    215 	}
    216 
    217 	return moves;
    218 }
    219 
    220 static enum player
    221 get_winner(struct board *board)
    222 {
    223 	for (size_t j = 0; j < BOARD_SIZE; j++) { /* check rows */
    224 		enum player winner = get_cell_ptr(board, 0, j)->state;
    225 
    226 		int contiguous = true;
    227 		for (size_t i = 0; i < BOARD_SIZE; i++) {
    228 			if (get_cell_ptr(board, i, j)->state != winner)
    229 				contiguous = false;
    230 		}
    231 
    232 		if (contiguous)
    233 			return winner;
    234 	}
    235 
    236 	for (size_t i = 0; i < BOARD_SIZE; i++) { /* check cols */
    237 		enum player winner = get_cell_ptr(board, i, 0)->state;
    238 
    239 		int contiguous = true;
    240 		for (size_t j = 0; j < BOARD_SIZE; j++) {
    241 			if (get_cell_ptr(board, i, j)->state != winner)
    242 				contiguous = false;
    243 		}
    244 
    245 		if (contiguous)
    246 			return winner;
    247 	}
    248 
    249 	{
    250 		enum player winner = get_cell_ptr(board, 0, 0)->state;
    251 
    252 		int contiguous = true;
    253 		for (size_t k = 0; k < BOARD_SIZE; k++) { /* check (k,k) diagonal */
    254 			if (get_cell_ptr(board, k, k)->state != winner)
    255 				contiguous = false;
    256 		}
    257 
    258 		if (contiguous)
    259 			return winner;
    260 	}
    261 
    262 	{
    263 		enum player winner = get_cell_ptr(board, 0, BOARD_SIZE - 1)->state;
    264 
    265 		int contiguous = true;
    266 		for (size_t k = 0; k < BOARD_SIZE; k++) { /* check (k,k) diagonal */
    267 			if (get_cell_ptr(board, k, BOARD_SIZE - 1 - k)->state != winner)
    268 				contiguous = false;
    269 		}
    270 
    271 		if (contiguous)
    272 			return winner;
    273 	}
    274 
    275 	return EMPTY;
    276 }
    277 
    278 static int
    279 game_over(struct board *board, enum player *winner)
    280 {
    281 	size_t moves = get_available_moves(board, NULL);
    282 
    283 	if ((*winner = get_winner(board)) || !moves) /* winner, or a draw */
    284 		return true;
    285 
    286 	return false;
    287 }
    288 
    289 static inline void
    290 draw_board_row_divider(void)
    291 {
    292 	printf("-");
    293 	for (size_t i = 0; i < BOARD_SIZE; i++)
    294 		printf("----");
    295 	printf("\n");
    296 }
    297 
    298 static void
    299 print_board(struct board *board)
    300 {
    301 	draw_board_row_divider();
    302 
    303 	for (size_t j = 0; j < BOARD_SIZE; j++) {
    304 		printf("|");
    305 		for (size_t i = 0; i < BOARD_SIZE; i++) {
    306 			enum player cell = get_cell_ptr(board, i, j)->state;
    307 			printf(" %c |", get_player_char(cell));
    308 		}
    309 		printf("\n");
    310 
    311 		draw_board_row_divider();
    312 	}
    313 }
    314 
    315 /* minimax definitions
    316  * ===========================================================================
    317  */
    318 
    319 struct minimax {
    320 	struct list possible_moves;
    321 };
    322 
    323 struct minimax_node {
    324 	// NOTE: we can use this to avoid false-aliasing of cachelines
    325 	alignas(64)
    326 
    327 	struct move move;
    328 	int8_t player;
    329 
    330 	struct list children;
    331 	struct list_node list_node;
    332 };
    333 
    334 static inline void
    335 minimax_node_init(struct minimax_node *node, struct move move, enum player player)
    336 {
    337 	node->move = move;
    338 	node->player = player;
    339 	node->children.head = NULL;
    340 	node->list_node.next = NULL;
    341 }
    342 
    343 #define IS_TERMINAL_NODE(node) ((node)->children.head == NULL)
    344 
    345 static inline size_t
    346 nodecount(size_t available_moves, size_t max_depth)
    347 {
    348 	size_t res = available_moves;
    349 
    350 	while (max_depth-- && --available_moves)
    351 		res = res * available_moves;
    352 
    353 	return res;
    354 }
    355 
    356 static size_t
    357 max_depth_for_node_cap(size_t available_moves, size_t cap)
    358 {
    359 	size_t depth, nodes;
    360 	for (depth = 0; depth < available_moves; depth++)
    361 		if ((nodes = nodecount(available_moves, depth)) > cap)
    362 			return depth;
    363 
    364 	return depth;
    365 }
    366 
    367 #if 0
    368 
    369 static void
    370 dump_tree(FILE *fp, struct minimax_node const *node, size_t depth, size_t max_depth)
    371 {
    372 #define LEADER(fp, depth) \
    373 	for (size_t i = 0; i < depth; i++) fprintf(fp, "  ");
    374 
    375 	LEADER(fp, depth);
    376 	fprintf(fp, "node: {move: {i: %u, j: %u}, player: %s}\n",
    377 			node->move.i, node->move.j, get_player_str(node->player));
    378 
    379 	if (depth + 1 == max_depth) {
    380 		LEADER(fp, depth + 1);
    381 		fprintf(fp, "...\n");
    382 		return;
    383 	}
    384 
    385 	struct list const *list = &node->children;
    386 	list_iter(list) {
    387 		struct minimax_node *child = FROM_NODE(it, struct minimax_node, list_node);
    388 		dump_tree(fp, child, depth + 1, max_depth);
    389 	}
    390 #undef LEADER
    391 }
    392 
    393 static void
    394 minimax_dump_trees(struct minimax *minimax, FILE *fp, size_t max_depth)
    395 {
    396 	struct list *list = &minimax->possible_moves;
    397 	list_iter(list) {
    398 		struct minimax_node *tree = FROM_NODE(it, struct minimax_node, list_node);
    399 		dump_tree(fp, tree, 0, max_depth);
    400 	}
    401 }
    402 
    403 #endif
    404 
    405 static struct minimax_node *
    406 create_tree(struct board *board, struct move move, enum player player,
    407 	    struct arena *arena, size_t depth, size_t max_depth)
    408 {
    409 	struct minimax_node *node = ALLOC_SIZED(arena, struct minimax_node);
    410 	assert(node);
    411 
    412 	minimax_node_init(node, move, player);
    413 
    414 	int res = try_play_move(board, node->move, node->player);
    415 	assert(res);
    416 
    417 	if (depth == max_depth)
    418 		goto end;
    419 
    420 	size_t available_moves = get_available_moves(board, NULL);
    421 	if (!available_moves)
    422 		goto end;
    423 
    424 	struct move *moves = alloca_impl(available_moves * sizeof *moves);
    425 	get_available_moves(board, moves);
    426 
    427 	for (size_t i = 0; i < available_moves; i++) {
    428 		struct minimax_node *child = create_tree(board, moves[i],
    429 							 get_opponent(player),
    430 							 arena, depth + 1, max_depth);
    431 
    432 		list_push(&node->children, &child->list_node);
    433 	}
    434 
    435 end:
    436 	undo_move(board, node->move);
    437 
    438 	return node;
    439 }
    440 
    441 static float
    442 heuristic(struct minimax_node *node, struct board *board, enum player maximising_player)
    443 {
    444 	enum player winner;
    445 	if ((winner = get_winner(board))) { /* this is a winning move */
    446 		return (winner == maximising_player) ? +INFINITY : -INFINITY;
    447 	}
    448 
    449 	size_t moves = get_available_moves(board, NULL);
    450 	if (!moves) /* no winner and no moves left to make, draw */
    451 		return 0;
    452 
    453 	// TODO: is there a better heuristic than assuming a loss?
    454 	return (node->player == maximising_player) ? -INFINITY : +INFINITY;
    455 }
    456 
    457 static float
    458 minimax(struct minimax_node *node, struct board *board, enum player maximising_player,
    459 	size_t max_depth)
    460 {
    461 	float result;
    462 
    463 	int res = try_play_move(board, node->move, node->player);
    464 	assert(res);
    465 
    466 	if (!max_depth || IS_TERMINAL_NODE(node)) {
    467 		result = heuristic(node, board, maximising_player);
    468 		goto end;
    469 	}
    470 
    471 	if (node->player == maximising_player) {
    472 		float max_value = -INFINITY;
    473 
    474 		struct list *list = &node->children;
    475 		list_iter(list) {
    476 			struct minimax_node *child = FROM_NODE(it, struct minimax_node, list_node);
    477 			float value = minimax(child, board, maximising_player, max_depth - 1);
    478 			max_value = MAX(max_value, value);
    479 		}
    480 
    481 		result = max_value;
    482 	} else /* node->player == minimising_player */ {
    483 		float min_value = +INFINITY;
    484 
    485 		struct list *list = &node->children;
    486 		list_iter(list) {
    487 			struct minimax_node *child = FROM_NODE(it, struct minimax_node, list_node);
    488 			float value = minimax(child, board, maximising_player, max_depth - 1);
    489 			min_value = MIN(min_value, value);
    490 		}
    491 
    492 		result = min_value;
    493 	}
    494 
    495 end:
    496 	undo_move(board, node->move);
    497 
    498 	return result;
    499 }
    500 
    501 static struct move
    502 minimax_get_best_move(struct minimax *self, struct board *board, enum player player,
    503 		      float *out, struct arena *arena)
    504 {
    505 	size_t available_moves = get_available_moves(board, NULL);
    506 	assert(available_moves);
    507 
    508 	size_t max_nodes = arena->cap / sizeof(struct minimax_node);
    509 	size_t max_search_depth = MIN(MAX_SEARCH_DEPTH,
    510 				      max_depth_for_node_cap(available_moves, max_nodes));
    511 
    512 #if MINIMAX_DEBUG
    513 	fprintf(stderr, "Maximum search depth: %zu, using %zu/%zu nodes\n",
    514 			max_search_depth, nodecount(available_moves, max_search_depth), max_nodes);
    515 #endif
    516 
    517 	struct move *moves = alloca_impl(available_moves * sizeof *moves);
    518 	get_available_moves(board, moves);
    519 
    520 	self->possible_moves.head = NULL;
    521 
    522 	float best_value = -INFINITY;
    523 	struct move best_move;
    524 
    525 	for (size_t i = 0; i < available_moves; i++) {
    526 		arena_reset(arena);
    527 
    528 #if MINIMAX_DEBUG
    529 		fprintf(stderr, "starting minimax tree for root node: {i: %u, j: %u} with depth: %zu\n",
    530 				moves[i].i, moves[i].j, max_search_depth);
    531 #endif
    532 
    533 		// NOTE: we limit the depth of the minimax tree to avoid
    534 		//       running out of memory, but if we could employ
    535 		//       memoization of similar game states we could greatly
    536 		//       reduce the need for this limit
    537 		struct minimax_node *tree = create_tree(board, moves[i], player,
    538 							arena, 0, max_search_depth);
    539 		assert(tree);
    540 
    541 #if MINIMAX_DEBUG
    542 		fprintf(stderr, "created tree with %zu/%zu nodes\n",
    543 				arena->len / sizeof(struct minimax_node),
    544 				arena->cap / sizeof(struct minimax_node));
    545 #endif
    546 
    547 
    548 		list_push(&self->possible_moves, &tree->list_node);
    549 
    550 		float value = minimax(tree, board, player, max_search_depth);
    551 		if (value >= best_value) {
    552 			best_value = value;
    553 			best_move = tree->move;
    554 		}
    555 	}
    556 
    557 	*out = best_value;
    558 
    559 #if MINIMAX_DEBUG
    560 		fprintf(stderr, "Best move for %s: {i: %u, j: %u}, value: %f\n",
    561 				get_player_str(player), best_move.i, best_move.j, best_value);
    562 #endif
    563 
    564 	return best_move;
    565 }
    566 
    567 /* main definition
    568  * ===========================================================================
    569  */
    570 
    571 static size_t
    572 get_input_coord(void)
    573 {
    574 	char buf[64];
    575 	while (true) {
    576 		char *line = fgets(buf, sizeof buf, stdin);
    577 
    578 		size_t coord;
    579 		if (sscanf(line, "%zu", &coord)) {
    580 			return coord;
    581 		} else {
    582 			printf("invalid coordinate. try again.\n");
    583 		}
    584 	}
    585 }
    586 
    587 int
    588 main(void)
    589 {
    590 	// NOTE: we use an arena to avoid expensive system allocators, and so
    591 	//       that we can easily free all nodes without iterating the tree
    592 	struct arena arena;
    593 
    594 #if defined(__linux__)
    595 	int prot = PROT_READ|PROT_WRITE;
    596 	int flags = MAP_PRIVATE|MAP_ANONYMOUS;
    597 	arena.ptr = mmap(NULL, ARENA_SIZE, prot, flags, -1, 0);
    598 	arena.cap = ARENA_SIZE;
    599 	arena.len = 0;
    600 
    601 	assert(arena.ptr != MAP_FAILED);
    602 
    603 	madvise(arena.ptr, arena.cap, MADV_HUGEPAGE);
    604 #elif defined(_WIN32)
    605 	arena.ptr = malloc(ARENA_SIZE);
    606 	arena.cap = ARENA_SIZE;
    607 	arean.len = 0;
    608 
    609 	assert(arena.ptr);
    610 #endif
    611 
    612 	size_t max_nodes = arena.cap / sizeof(struct minimax_node);
    613 	fprintf(stderr, "Arena cap: %zu bytes, node size: %zu bytes, max node count: %zu\n",
    614 			arena.cap, sizeof(struct minimax_node), max_nodes);
    615 
    616 	struct minimax minimax;
    617 	memset(&minimax, 0, sizeof minimax);
    618 
    619 	struct board board;
    620 	memset(&board, 0, sizeof board);
    621 
    622 #if MINIMAX_DEBUG
    623 	float value;
    624 	struct move move = minimax_get_best_move(&minimax, &board, BLACK, &value, &arena);
    625 
    626 	int res = try_play_move(&board, move, BLACK);
    627 	assert(res);
    628 
    629 	return 0;
    630 
    631 #endif
    632 
    633 	enum player winner;
    634 	while (true) {
    635 		print_board(&board);
    636 
    637 		/* player move */
    638 		struct move player_move;
    639 		while (true) {
    640 			printf("user x coord (between 0 and %d):\n", BOARD_SIZE - 1);
    641 			size_t x = get_input_coord();
    642 
    643 			printf("user y coord (between 0 and %d):\n", BOARD_SIZE - 1);
    644 			size_t y = get_input_coord();
    645 
    646 			player_move.i = x;
    647 			player_move.j = y;
    648 
    649 			if (!try_play_move(&board, player_move, BLACK)) {
    650 				printf("move already made! try again.\n");
    651 				continue;
    652 			}
    653 
    654 			break;
    655 		}
    656 
    657 		if (game_over(&board, &winner))
    658 			break;
    659 
    660 		/* ai move */
    661 		float value;
    662 		struct move best_move = minimax_get_best_move(&minimax, &board,
    663 							      WHITE, &value,
    664 							      &arena);
    665 
    666 		int res = try_play_move(&board, best_move, WHITE);
    667 		assert(res);
    668 
    669 		if (game_over(&board, &winner))
    670 			break;
    671 	}
    672 
    673 	print_board(&board);
    674 
    675 	if (winner == BLACK) {
    676 		printf("Player won the game.\n");
    677 	} else if (winner == WHITE) {
    678 		printf("Player lost the game.\n");
    679 	} else {
    680 		printf("Player drew the game.\n");
    681 	}
    682 
    683 	return 0;
    684 }