hex

hex.git
git clone git://git.lenczewski.org/hex.git
Log | Files | Refs

commit aca3eae1effef4c0a09cbb347365168538bcafd3
parent 259e55f5690be7df58b9d7ff3f830488b82b53a6
Author: Mikołaj Lenczewski <mblenczewski@gmail.com>
Date:   Sun, 13 Aug 2023 23:28:33 +0000

Ported hexes agent to new server protocol

Diffstat:
A agents/hexes/.gitignore | 3 +++
A agents/hexes/Makefile | 48 ++++++++++++++++++++++++++++++++++++++++++++++++
A agents/hexes/include/hexes.h | 48 ++++++++++++++++++++++++++++++++++++++++++++++++
A agents/hexes/include/hexes/agent.h | 41 +++++++++++++++++++++++++++++++++++++++++
A agents/hexes/include/hexes/agent/mcts.h | 71 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
A agents/hexes/include/hexes/agent/random.h | 29 +++++++++++++++++++++++++++++
A agents/hexes/include/hexes/board.h | 94 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
A agents/hexes/include/hexes/log.h | 33 +++++++++++++++++++++++++++++++++
A agents/hexes/include/hexes/network.h | 22 ++++++++++++++++++++++
A agents/hexes/include/hexes/threadpool.h | 16 ++++++++++++++++
A agents/hexes/include/hexes/utils.h | 103 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
A agents/hexes/run.sh | 3 +++
A agents/hexes/src/agent.c | 70 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
A agents/hexes/src/agent/mcts.c | 373 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
A agents/hexes/src/agent/random.c | 76 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
A agents/hexes/src/board.c | 221 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
A agents/hexes/src/hexes.c | 248 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
A agents/hexes/src/log.c | 4 ++++
A agents/hexes/src/network.c | 100 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
A agents/hexes/src/threadpool.c | 24 ++++++++++++++++++++++++
A agents/hexes/src/utils.c | 28 ++++++++++++++++++++++++++++
M schedule.txt | 2 ++
22 files changed, 1657 insertions(+), 0 deletions(-)

diff --git a/agents/hexes/.gitignore b/agents/hexes/.gitignore @@ -0,0 +1,3 @@ +hexes +!include/hexes/ +obj/ diff --git a/agents/hexes/Makefile b/agents/hexes/Makefile @@ -0,0 +1,48 @@ +.PHONY: all build clean + +CC ?= cc +TAR ?= tar + +SRC := src +INC := include +DEPINC := ../../include +OBJ := obj + +WARN := -Wall -Wextra -Wpedantic -Werror + +CFLAGS := -std=c17 $(WARN) -Og -g -flto +CPPFLAGS := -I$(INC) -I$(DEPINC) +LDFLAGS := -lm -flto + +TARGET := hexes +SOURCES := $(SRC)/hexes.c \ + $(SRC)/agent.c \ + $(SRC)/agent/mcts.c \ + $(SRC)/agent/random.c \ + $(SRC)/board.c \ + $(SRC)/log.c \ + $(SRC)/network.c \ + $(SRC)/threadpool.c \ + $(SRC)/utils.c + +OBJECTS := $(SOURCES:$(SRC)/%.c=$(OBJ)/%.o) +OBJDEPS := $(OBJECTS:%.o=%.d) + +all: build + +build: $(TARGET) + +clean: + rm -rf $(TARGET) $(OBJ) + +$(TARGET): $(OBJECTS) + $(CC) -o $@ $^ $(LDFLAGS) + +$(OBJ)/%.o: $(SRC)/%.c | $(OBJ) + @mkdir -p $(dir $@) + $(CC) -MMD -o $@ -c $< $(CFLAGS) $(CPPFLAGS) + +-include $(OBJDEPS) + +$(OBJ): + mkdir -p $@ diff --git a/agents/hexes/include/hexes.h b/agents/hexes/include/hexes.h @@ -0,0 +1,48 @@ +#ifndef HEXES_H +#define HEXES_H + +#ifdef _XOPEN_SOURCE +#undef _XOPEN_SOURCE +#endif + +#define _XOPEN_SOURCE 700 + +#ifndef _GNU_SOURCE +#define _GNU_SOURCE 1 +#endif + +#ifndef _DEFAULT_SOURCE +#define _DEFAULT_SOURCE 1 +#endif + +#include "hex/types.h" +#include "hex/proto.h" + +#include <assert.h> +#include <errno.h> +#include <stdalign.h> +#include <stdarg.h> +#include <stdbool.h> +#include <stddef.h> +#include <stdint.h> +#include <stdio.h> +#include <stdlib.h> +#include <string.h> + +#include <arpa/inet.h> +#include <math.h> +#include <netdb.h> +#include <sys/socket.h> +#include <sys/types.h> +#include <time.h> +#include <unistd.h> + +struct opts { + u32 log_level, agent_type; +}; + +extern struct opts opts; + +#include "hexes/log.h" + +#endif /* HEXES_H */ diff --git a/agents/hexes/include/hexes/agent.h b/agents/hexes/include/hexes/agent.h @@ -0,0 +1,41 @@ +#ifndef 
HEXES_AGENT_H +#define HEXES_AGENT_H + +#include "hexes.h" + +#include "hexes/board.h" +#include "hexes/threadpool.h" + +#include "hexes/agent/random.h" +#include "hexes/agent/mcts.h" + +enum agent_type { + AGENT_RANDOM, + AGENT_MCTS, +}; + +struct agent { + enum agent_type type; + union { + struct agent_random random; + struct agent_mcts mcts; + } backend; +}; + +bool +agent_init(struct agent *self, enum agent_type type, struct board const *board, + struct threadpool *threadpool, u32 mem_limit_mib, enum hex_player player); + +void +agent_free(struct agent *self); + +void +agent_play(struct agent *self, enum hex_player player, u32 x, u32 y); + +void +agent_swap(struct agent *self); + +bool +agent_next(struct agent *self, struct timespec timeout, u32 *out_x, u32 *out_y); + +#endif /* HEXES_AGENT_H */ diff --git a/agents/hexes/include/hexes/agent/mcts.h b/agents/hexes/include/hexes/agent/mcts.h @@ -0,0 +1,71 @@ +#ifndef HEXES_AGENT_MCTS_H +#define HEXES_AGENT_MCTS_H + +#include "hexes.h" + +#include "hexes/board.h" +#include "hexes/threadpool.h" +#include "hexes/utils.h" + +#define RESERVED_MEM (MiB) + +typedef s64 mcts_node_relptr_t; + +struct mcts_node { + mcts_node_relptr_t parent; + enum hex_player player; + u8 x, y; + + s32 wins, rave_wins; + u32 plays, rave_plays; + + u16 children_cap, children_len; + mcts_node_relptr_t children[]; +}; + +inline size_t +mcts_node_sizeof(size_t children) +{ + return sizeof(struct mcts_node) + children * sizeof(mcts_node_relptr_t); +} + +inline mcts_node_relptr_t +mcts_node_abs2rel(void *base, struct mcts_node *absptr) +{ + return RELPTR_ABS2REL(mcts_node_relptr_t, base, absptr); +} + +inline struct mcts_node * +mcts_node_rel2abs(void *base, mcts_node_relptr_t relptr) +{ + return RELPTR_REL2ABS(struct mcts_node *, mcts_node_relptr_t, base, relptr); +} + +struct agent_mcts { + struct board const *board; + struct threadpool *threadpool; + + struct board shadow_board; + + struct mem_pool pool; + struct mcts_node *root; +}; + +bool 
+agent_mcts_init(struct agent_mcts *self, struct board const *board, struct threadpool *threadpool, + u32 mem_limit_mib, enum hex_player player); + +void +agent_mcts_free(struct agent_mcts *self); + +void +agent_mcts_play(struct agent_mcts *self, enum hex_player player, u32 x, u32 y); + +void +agent_mcts_swap(struct agent_mcts *self); + +bool +agent_mcts_next(struct agent_mcts *self, struct timespec timeout, u32 *out_x, u32 *out_y); + + +#endif /* HEXES_AGENT_MCTS_H */ diff --git a/agents/hexes/include/hexes/agent/random.h b/agents/hexes/include/hexes/agent/random.h @@ -0,0 +1,29 @@ +#ifndef HEXES_AGENT_RANDOM_H +#define HEXES_AGENT_RANDOM_H + +#include "hexes.h" + +#include "hexes/board.h" +#include "hexes/utils.h" + +struct agent_random { + size_t len; + struct move *moves; +}; + +bool +agent_random_init(struct agent_random *self, struct board const *board); + +void +agent_random_free(struct agent_random *self); + +void +agent_random_play(struct agent_random *self, enum hex_player player, u32 x, u32 y); + +void +agent_random_swap(struct agent_random *self); + +bool +agent_random_next(struct agent_random *self, struct timespec timeout, u32 *out_x, u32 *out_y); + +#endif /* HEXES_AGENT_RANDOM_H */ diff --git a/agents/hexes/include/hexes/board.h b/agents/hexes/include/hexes/board.h @@ -0,0 +1,94 @@ +#ifndef HEXES_BOARD_H +#define HEXES_BOARD_H + +#include "hexes.h" + +enum cell { + CELL_BLACK = HEX_PLAYER_BLACK, + CELL_WHITE = HEX_PLAYER_WHITE, + CELL_EMPTY, +}; + +typedef s16 segment_relptr_t; + +struct segment { + enum cell occupant; + u32 rank; + segment_relptr_t parent; +}; + +inline segment_relptr_t +segment_abs2rel(void const *base, struct segment *absptr) +{ + return RELPTR_ABS2REL(segment_relptr_t, base, absptr); +} + +inline struct segment * +segment_rel2abs(void const *base, segment_relptr_t relptr) +{ + return RELPTR_REL2ABS(struct segment *, segment_relptr_t, base, relptr); +} + +enum board_edges { + BLACK_SOURCE, + BLACK_SINK, + WHITE_SOURCE, + 
WHITE_SINK, + _BOARD_EDGE_COUNT, +}; + +struct board { + u32 size; + struct segment *segments; +}; + +inline struct segment * +board_black_source(struct board *self) +{ + return &self->segments[self->size * self->size + BLACK_SOURCE]; +} + +inline struct segment * +board_black_sink(struct board *self) +{ + return &self->segments[self->size * self->size + BLACK_SINK]; +} + +inline struct segment * +board_white_source(struct board *self) +{ + return &self->segments[self->size * self->size + WHITE_SOURCE]; +} + +inline struct segment * +board_white_sink(struct board *self) +{ + return &self->segments[self->size * self->size + WHITE_SINK]; +} + +struct move { + u8 x, y; +}; + +bool +board_init(struct board *self, u32 size); + +void +board_free(struct board *self); + +void +board_copy(struct board const *restrict self, struct board *restrict other); + +bool +board_play(struct board *self, enum hex_player player, u32 x, u32 y); + +void +board_swap(struct board *self); + +size_t +board_available_moves(struct board const *self, struct move *buf); + +bool +board_winner(struct board *self, enum hex_player *out); + +#endif /* HEXES_BOARD_H */ diff --git a/agents/hexes/include/hexes/log.h b/agents/hexes/include/hexes/log.h @@ -0,0 +1,33 @@ +#ifndef HEXES_LOG_H +#define HEXES_LOG_H + +#include "hexes.h" + +enum log_level { + LOG_ERROR, + LOG_WARN, + LOG_INFO, + LOG_DEBUG, +}; + +inline void +dbglog(enum log_level log_level, char const *fmt, ...) 
+{ + if (opts.log_level < log_level) return; + + switch (log_level) { + case LOG_ERROR: fputs("[ERROR]", stderr); break; + case LOG_WARN: fputs("[WARN] ", stderr); break; + case LOG_INFO: fputs("[INFO] ", stderr); break; + case LOG_DEBUG: fputs("[DEBUG]", stderr); break; + } + + fputc(' ', stderr); + + va_list ap; + va_start(ap, fmt); + vfprintf(stderr, fmt, ap); + va_end(ap); +} + +#endif /* HEXES_LOG_H */ diff --git a/agents/hexes/include/hexes/network.h b/agents/hexes/include/hexes/network.h @@ -0,0 +1,22 @@ +#ifndef HEXES_NETWORK_H +#define HEXES_NETWORK_H + +#include "hexes.h" + +struct network { + int sockfd; +}; + +bool +network_init(struct network *self, char const *host, char const *port); + +void +network_free(struct network *self); + +bool +network_send(struct network *self, struct hex_msg const *msg); + +bool +network_recv(struct network *self, struct hex_msg *out, enum hex_msg_type *expected, size_t len); + +#endif /* HEXES_NETWORK_H */ diff --git a/agents/hexes/include/hexes/threadpool.h b/agents/hexes/include/hexes/threadpool.h @@ -0,0 +1,16 @@ +#ifndef HEXES_THREADPOOL_H +#define HEXES_THREADPOOL_H + +#include "hexes.h" + +struct threadpool { + u32 threads; +}; + +bool +threadpool_init(struct threadpool *self, u32 threads); + +void +threadpool_free(struct threadpool *self); + +#endif /* HEXES_THREADPOOL_H */ diff --git a/agents/hexes/include/hexes/utils.h b/agents/hexes/include/hexes/utils.h @@ -0,0 +1,103 @@ +#ifndef HEXES_UTILS_H +#define HEXES_UTILS_H + +#include "hexes.h" + +inline void +swap(void *restrict lhs, void *restrict rhs, size_t size) +{ + assert(lhs); + assert(rhs); + assert(size); + + u8 tmp[size]; + + memcpy(tmp, rhs, size); + memcpy(rhs, lhs, size); + memcpy(lhs, tmp, size); +} + +inline void +shuffle(void *arr, size_t size, size_t len) +{ + assert(arr); + assert(size); + + for (size_t i = 0; i < len - 2; i++) { + size_t j = (i + random()) % len; + + swap((u8 *) arr + (i * size), + (u8 *) arr + (j * size), + size); + } +} + +inline 
void +difftimespec(struct timespec *restrict lhs, struct timespec *restrict rhs, struct timespec *restrict out) +{ + if (lhs->tv_sec <= rhs->tv_sec && lhs->tv_nsec < rhs->tv_nsec) { + out->tv_sec = 0; + out->tv_nsec = 0; + } else { + out->tv_sec = lhs->tv_sec - rhs->tv_sec - (lhs->tv_nsec < rhs->tv_nsec); + out->tv_nsec = lhs->tv_nsec - rhs->tv_nsec + (lhs->tv_nsec < rhs->tv_nsec) * NANOSECS; + } +} + +struct mem_pool { + void *ptr; + size_t cap, len; +}; + +inline bool +mem_pool_init(struct mem_pool *self, size_t align, size_t capacity) +{ + assert(self); + assert(align); + assert(align % 2 == 0); + assert(capacity % align == 0); + + self->ptr = aligned_alloc(align, capacity); + if (!self->ptr) return false; + + self->cap = capacity; + self->len = 0; + + return true; +} + +inline void +mem_pool_free(struct mem_pool *self) +{ + assert(self); + + free(self->ptr); +} + +inline void +mem_pool_reset(struct mem_pool *self) +{ + assert(self); + + self->len = 0; +} + +inline void * +mem_pool_alloc(struct mem_pool *self, size_t align, size_t size) +{ + assert(self); + assert(align); + assert(align % 2 == 0); + + size_t align_off = align - 1, align_mask = ~align_off; + size_t aligned_len = (self->len + align_off) & align_mask; + + if (aligned_len + size >= self->cap) return NULL; + + void *ptr = (u8 *) self->ptr + aligned_len; + self->len = aligned_len + size; + + return ptr; +} + +#endif /* HEXES_UTILS_H */ diff --git a/agents/hexes/run.sh b/agents/hexes/run.sh @@ -0,0 +1,3 @@ +#!/bin/sh + +exec $(dirname $0)/hexes $@ diff --git a/agents/hexes/src/agent.c b/agents/hexes/src/agent.c @@ -0,0 +1,70 @@ +#include "hexes/agent.h" + +bool +agent_init(struct agent *self, enum agent_type type, struct board const *board, + struct threadpool *threadpool, u32 mem_limit_mib, enum hex_player player) +{ + assert(self); + + self->type = type; + + (void) player; + + switch (type) { + case AGENT_RANDOM: + return agent_random_init(&self->backend.random, board); + + case AGENT_MCTS: + return 
agent_mcts_init(&self->backend.mcts, board, threadpool, mem_limit_mib, player); + } + + return false; +} + +void +agent_free(struct agent *self) +{ + assert(self); + + switch (self->type) { + case AGENT_RANDOM: agent_random_free(&self->backend.random); break; + case AGENT_MCTS: agent_mcts_free(&self->backend.mcts); break; + } +} + +void +agent_play(struct agent *self, enum hex_player player, u32 x, u32 y) +{ + assert(self); + + switch (self->type) { + case AGENT_RANDOM: agent_random_play(&self->backend.random, player, x, y); break; + case AGENT_MCTS: agent_mcts_play(&self->backend.mcts, player, x, y); break; + } +} + +void +agent_swap(struct agent *self) +{ + assert(self); + + switch (self->type) { + case AGENT_RANDOM: agent_random_swap(&self->backend.random); break; + case AGENT_MCTS: agent_mcts_swap(&self->backend.mcts); break; + } +} + +bool +agent_next(struct agent *self, struct timespec timeout, u32 *out_x, u32 *out_y) +{ + assert(self); + assert(out_x); + assert(out_y); + + switch (self->type) { + case AGENT_RANDOM: return agent_random_next(&self->backend.random, timeout, out_x, out_y); + case AGENT_MCTS: return agent_mcts_next(&self->backend.mcts, timeout, out_x, out_y); + } + + return false; +} diff --git a/agents/hexes/src/agent/mcts.c b/agents/hexes/src/agent/mcts.c @@ -0,0 +1,373 @@ +#include "hexes/agent/mcts.h" + +extern inline size_t +mcts_node_sizeof(size_t children); + +extern inline mcts_node_relptr_t +mcts_node_abs2rel(void *base, struct mcts_node *absptr); + +extern inline struct mcts_node * +mcts_node_rel2abs(void *base, mcts_node_relptr_t relptr); + +static void +mcts_node_init(struct mcts_node *self, struct mcts_node *parent, + enum hex_player player, u8 x, u8 y, size_t children) +{ + assert(self); + + self->parent = mcts_node_abs2rel(self, parent); + self->player = player; + self->x = x; + self->y = y; + + self->wins = self->rave_wins = 0; + self->plays = self->rave_plays = 0; + + self->children_cap = children; + self->children_len = 0; +} + 
+static bool +mcts_node_expand(struct mcts_node *self, struct mem_pool *pool, u8 x, u8 y) +{ + assert(self); + assert(pool); + + struct mcts_node *child = mem_pool_alloc(pool, alignof(struct mcts_node), + mcts_node_sizeof(self->children_cap - 1)); + + if (!child) { + dbglog(LOG_WARN, "Failed to allocate child node. Consider compacting memory pool\n"); + return false; + } + + mcts_node_init(child, self, hexopponent(self->player), x, y, self->children_cap - 1); + + self->children[self->children_len++] = mcts_node_abs2rel(self, child); + + return true; +} + +static struct mcts_node * +mcts_node_get_child(struct mcts_node *self, u8 x, u8 y) +{ + assert(self); + + if (!self->children_len) return NULL; + + for (size_t i = 0; i < self->children_cap; i++) { + struct mcts_node *child = mcts_node_rel2abs(self, self->children[i]); + if (!child) continue; + + if (child->x == x && child->y == y) return child; + } + + return NULL; +} + +static f32 +mcts_node_calc_score(struct mcts_node *self) +{ + assert(self); + + /* MCTS-RAVE formula: + * ((1 - beta(n, n')) * (w / n)) + (beta(n, n') * (w' / n')) + (c * sqrt(ln t / n)) + * --- + * n = number of won playouts for this node + * n' = number of won playouts for this node for a given move + * w = total number of playouts for this node + * w' = total number of playouts for this node for a given move + * c = exploration parameter (sqrt(2), or found experimentally) + * t = total number of playouts for parent node + * beta(n, n') = function close to 1 for small n, and close to 0 for large n + */ + + /* if this node has not yet been played, return the default maximum value + * so that it is picked during expansion + */ + if (!self->plays) return INFINITY; + + s64 exploration_rounds = 300; + f32 beta = MAX(0.0, (exploration_rounds - self->plays) / (f32) exploration_rounds); + assert(0.0 <= beta && beta <= 1.0); + + dbglog(LOG_DEBUG, "beta: %lf, wins: %d, rave_wins: %d, plays: %u, rave_plays: %u\n", + beta, self->wins, self->rave_wins, 
self->plays, self->rave_plays); + + f32 exploration = 0; // TODO: implement exploration parameter + + f32 exploitation = (1 - beta) * ((f32) self->wins / (f32) self->plays); + assert(-1.0 <= exploitation && exploitation <= 1.0); + + f32 rave_exploitation = beta * ((f32) self->rave_wins / (f32) self->rave_plays); + assert(-1.0 <= rave_exploitation && rave_exploitation <= 1.0); + + return exploration + exploitation + rave_exploitation; +} + +static struct mcts_node * +mcts_node_best_child(struct mcts_node *self) +{ + assert(self); + + f32 max_score = -INFINITY; + struct mcts_node *best_child = NULL; + for (size_t i = 0; i < self->children_cap; i++) { + struct mcts_node *child = mcts_node_rel2abs(self, self->children[i]); + if (!child) continue; + + dbglog(LOG_DEBUG, "Node: {parent=%p, children=%" PRIu8 ", x=%" PRIu32 ", y=%" PRIu32 "}\n", + mcts_node_rel2abs(child, child->parent), child->children_len, child->x, child->y); + + f32 score = mcts_node_calc_score(child); + + if (score > max_score) { + max_score = score; + best_child = child; + } + } + + return best_child; +} + +bool +agent_mcts_init(struct agent_mcts *self, struct board const *board, struct threadpool *threadpool, + u32 mem_limit_mib, enum hex_player player) +{ + assert(self); + + self->board = board; + self->threadpool = threadpool; + + if (!board_init(&self->shadow_board, board->size)) return false; + + size_t align = alignof(struct mcts_node); + size_t cap = ((mem_limit_mib * MiB) - RESERVED_MEM) & ~(align - 1); + + if (!mem_pool_init(&self->pool, align, cap)) { + board_free(&self->shadow_board); + return false; + } + + size_t moves = board_available_moves(board, NULL); + self->root = mem_pool_alloc(&self->pool, alignof(struct mcts_node), mcts_node_sizeof(moves)); + mcts_node_init(self->root, NULL, hexopponent(player), 0, 0, moves); + + return true; +} + +void +agent_mcts_free(struct agent_mcts *self) +{ + assert(self); + + mem_pool_free(&self->pool); +} + +void +agent_mcts_play(struct agent_mcts 
*self, enum hex_player player, u32 x, u32 y) +{ + assert(self); + + mem_pool_reset(&self->pool); + + size_t moves = board_available_moves(self->board, NULL); + self->root = mem_pool_alloc(&self->pool, alignof(struct mcts_node), mcts_node_sizeof(moves)); + mcts_node_init(self->root, NULL, player, x, y, moves); + + // TODO: implement tree compaction and tree reuse, if it improves play + // one possible issue is children containing stale board states, + // leading to potentially invalid moves being generated. another + // is the simple fact that walking the tree to compact it takes + // a significant amount of time and potentially outweighs simply + // resetting the pool and performing a few more rounds of MCTS +} + +void +agent_mcts_swap(struct agent_mcts *self) +{ + assert(self); + + struct mcts_node old_root = *self->root; + + mem_pool_reset(&self->pool); + + size_t moves = board_available_moves(self->board, NULL); + self->root = mem_pool_alloc(&self->pool, alignof(struct mcts_node), mcts_node_sizeof(moves)); + mcts_node_init(self->root, NULL, hexopponent(old_root.player), old_root.x, old_root.y, moves); +} + +static bool +mcts_search(struct agent_mcts *self, struct timespec timeout); + +bool +agent_mcts_next(struct agent_mcts *self, struct timespec timeout, u32 *out_x, u32 *out_y) +{ + assert(self); + assert(out_x); + assert(out_y); + + if (!mcts_search(self, timeout)) return false; + + struct mcts_node *root = self->pool.ptr; + assert(root->children_len); + + u32 max_plays = 0; + struct mcts_node *best_child = NULL; + for (size_t i = 0; i < root->children_cap; i++) { + struct mcts_node *child = mcts_node_rel2abs(root, root->children[i]); + if (!child) continue; + + if (child->plays > max_plays) { + max_plays = child->plays; + best_child = child; + } else if (child->plays == max_plays && random() % 2) { + best_child = child; + } + } + + assert(best_child); + + *out_x = best_child->x; + *out_y = best_child->y; + + return true; +} + +static bool +mcts_round(struct 
agent_mcts *self, struct move *moves) +{ + assert(self); + + board_copy(self->board, &self->shadow_board); + + dbglog(LOG_DEBUG, "Starting MCTS round\n"); + + /* selection: we walk the mcts tree, picking the child with the highest + * mcts-rave score, until we hit a node with unexpanded children + */ + struct mcts_node *node = self->root; + while (node->children_len == node->children_cap) { + struct mcts_node *child = mcts_node_best_child(node); + if (!child) break; + + if (!board_play(&self->shadow_board, child->player, child->x, child->y)) { + dbglog(LOG_WARN, "Failed to play move (%" PRIu32 ", %" PRIu32 ") to shadow board\n", child->x, child->y); + return false; + } + + node = child; + } + + dbglog(LOG_DEBUG, "Selected node {parent=%p, children=%" PRIu8 ", x=%" PRIu32 ", y=%" PRIu32 "} for expansion\n", + mcts_node_rel2abs(node, node->parent), node->children_len, node->x, node->y); + + size_t moves_len = board_available_moves(&self->shadow_board, moves); + shuffle(moves, sizeof *moves, moves_len); + + /* expansion: we expand the chosen node, creating a new child for a + * random move + */ + enum hex_player winner; + if (!board_winner(&self->shadow_board, &winner)) { + struct move move = moves[--moves_len]; + + if (!mcts_node_expand(node, &self->pool, move.x, move.y)) { + dbglog(LOG_WARN, "Failed to expand selected node\n"); + return false; + } + + struct mcts_node *child = mcts_node_get_child(node, move.x, move.y); + assert(child); + + if (!board_play(&self->shadow_board, child->player, child->x, child->y)) { + dbglog(LOG_WARN, "Failed to play move (%" PRIu32 ", %" PRIu32 ") to shadow board\n", child->x, child->y); + return false; + } + } + + dbglog(LOG_DEBUG, "Expanded node {parent=%p, children=%" PRIu8 ", x=%" PRIu32 ", y=%" PRIu32 "}\n", + mcts_node_rel2abs(node, node->parent), node->children_len, node->x, node->y); + + /* simulation: we simulate the game using a uniform random walk of the + * game state space, until a winner is found + */ + enum hex_player 
player = node->player; + while (!board_winner(&self->shadow_board, &winner)) { + struct move move = moves[--moves_len]; + + if (!board_play(&self->shadow_board, player, move.x, move.y)) { + dbglog(LOG_WARN, "Failed to play move (%" PRIu32 ", %" PRIu32 ") to shadow board\n", move.x, move.y); + return false; + } + + player = hexopponent(player); + } + + dbglog(LOG_DEBUG, "Completed playouts for node {parent=%p, children=%" PRIu8 ", x=%" PRIu32 ", y=%" PRIu32 "}\n", + mcts_node_rel2abs(node, node->parent), node->children_len, node->x, node->y); + + /* backpropagation: we update the state information in the mcts tree + * by walking backwards from the selected node + */ + do { + s32 reward = winner == node->player ? +1 : -1; + + for (size_t i = 0; i < node->children_len; i++) { + struct mcts_node *child = mcts_node_rel2abs(node, node->children[i]); + if (!child) continue; + + struct segment *segment = &self->shadow_board.segments[child->y * self->shadow_board.size + child->x]; + if ((enum cell) child->player == segment->occupant) { + child->rave_plays += 1; + child->rave_wins += -reward; + } + } + + node->plays += 1; + node->wins += reward; + } while ((node = mcts_node_rel2abs(node, node->parent))); + + dbglog(LOG_DEBUG, "Completed backpropagation from selected node\n"); + + dbglog(LOG_DEBUG, "Completed MCTS round\n"); + + return true; +} + +static bool +mcts_search(struct agent_mcts *self, struct timespec timeout) +{ + assert(self); + + struct move *moves = alloca(self->board->size * self->board->size * sizeof *moves); + + struct timespec time; + clock_gettime(CLOCK_MONOTONIC, &time); + + u64 end_nanos = TIMESPEC_TO_NANOS(time.tv_sec, time.tv_nsec) + + TIMESPEC_TO_NANOS(timeout.tv_sec, timeout.tv_nsec); + + dbglog(LOG_INFO, "Starting MCTS tree search with %" PRIu32 " second timeout\n", timeout.tv_sec); + + size_t rounds = 0; + while (true) { + clock_gettime(CLOCK_MONOTONIC, &time); + if (end_nanos <= TIMESPEC_TO_NANOS(time.tv_sec, time.tv_nsec)) { + dbglog(LOG_DEBUG, 
"Search timeout elapsed\n"); + break; + } + + if (!mcts_round(self, moves)) { + dbglog(LOG_WARN, "Failed to perform MCTS round %zu\n", rounds + 1); + break; + } + + rounds++; + } + + dbglog(LOG_INFO, "Completed %zu rounds of MCTS\n", rounds); + dbglog(LOG_INFO, "MCTS node pool occupancy: %zu/%zu bytes allocated\n", self->pool.len, self->pool.cap); + + return true; +} diff --git a/agents/hexes/src/agent/random.c b/agents/hexes/src/agent/random.c @@ -0,0 +1,76 @@ +#include "hexes/agent/random.h" + +bool +agent_random_init(struct agent_random *self, struct board const *board) +{ + assert(self); + assert(board); + + self->len = board->size * board->size; + + if (!(self->moves = malloc(self->len * sizeof *self->moves))) return false; + + for (u32 j = 0; j < board->size; j++) { + for (u32 i = 0; i < board->size; i++) { + size_t idx = j * board->size + i; + + self->moves[idx].x = i; + self->moves[idx].y = j; + } + } + + shuffle(self->moves, sizeof *self->moves, self->len); + + return true; +} + +void +agent_random_free(struct agent_random *self) +{ + assert(self); + + free(self->moves); +} + +void +agent_random_play(struct agent_random *self, enum hex_player player, u32 x, u32 y) +{ + assert(self); + + (void) player; + + for (size_t i = 0; i < self->len; i++) { + if (self->moves[i].x == x && self->moves[i].y == y) { + swap(&self->moves[i], &self->moves[--self->len], sizeof *self->moves); + break; + } + } +} + +void +agent_random_swap(struct agent_random *self) +{ + assert(self); + + /* NOTE: this would affect nothing, as we remove moves made by both players + */ + (void) self; +} + +bool +agent_random_next(struct agent_random *self, struct timespec timeout, u32 *out_x, u32 *out_y) +{ + assert(self); + assert(out_x); + assert(out_y); + + (void) timeout; + + if (!self->len) return false; + + size_t idx = --self->len; + *out_x = self->moves[idx].x; + *out_y = self->moves[idx].y; + + return true; +} diff --git a/agents/hexes/src/board.c b/agents/hexes/src/board.c @@ -0,0 
+1,221 @@ +#include "hexes/board.h" + +#define NEIGHBOUR_COUNT 6 + +extern inline segment_relptr_t +segment_abs2rel(void const *base, struct segment *absptr); + +extern inline struct segment * +segment_rel2abs(void const *base, segment_relptr_t relptr); + +static s8 neighbour_dx[NEIGHBOUR_COUNT] = { -1, -1, 0, 0, +1, +1, }; +static s8 neighbour_dy[NEIGHBOUR_COUNT] = { 0, +1, -1, +1, -1, 0, }; + +struct segment * +segment_root(struct segment *self) +{ + assert(self); + + struct segment *parent, *grandparent; + while ((parent = segment_rel2abs(self, self->parent))) { + if (!(grandparent = segment_rel2abs(parent, parent->parent))) + return parent; + + self->parent = segment_abs2rel(self, grandparent); + + self = grandparent; + } + + return self; +} + +bool +segment_merge(struct segment *self, struct segment *elem) +{ + assert(self); + assert(elem); + + struct segment *self_root = segment_root(self); + struct segment *elem_root = segment_root(elem); + + if (self_root == elem_root) return false; + + if (self_root->rank <= elem_root->rank) { + self_root->parent = segment_abs2rel(self_root, elem_root); + } else if (self_root->rank > elem_root->rank) { + elem_root->parent = segment_abs2rel(elem_root, self_root); + } + + /* disambiguate between self and element ranks */ + if (self_root->rank == elem_root->rank) elem_root->rank++; + + return true; +} + +extern inline struct segment * +board_black_source(struct board *self); + +extern inline struct segment * +board_black_sink(struct board *self); + +extern inline struct segment * +board_white_source(struct board *self); + +extern inline struct segment * +board_white_sink(struct board *self); + +bool +board_init(struct board *self, u32 size) +{ + assert(self); + + self->size = size; + + size_t segments = (size * size) + _BOARD_EDGE_COUNT; + if (!(self->segments = malloc(segments * sizeof *self->segments))) + return false; + + for (size_t i = 0; i < segments; i++) { + struct segment *segment = &self->segments[i]; + + 
segment->occupant = CELL_EMPTY;
+		segment->rank = 0;
+		segment->parent = RELPTR_NULL;
+	}
+
+	struct segment *black_source = board_black_source(self);
+	struct segment *black_sink = board_black_sink(self);
+	struct segment *white_source = board_white_source(self);
+	struct segment *white_sink = board_white_sink(self);
+
+	black_source->occupant = black_sink->occupant = CELL_BLACK;
+	white_source->occupant = white_sink->occupant = CELL_WHITE;
+
+	return true;
+}
+
+/* Releases the segment array owned by the board; the board structure itself
+ * is not freed.
+ */
+void
+board_free(struct board *self)
+{
+	assert(self);
+
+	free(self->segments);
+}
+
+/* Copies the state of every segment (the size * size cells plus the
+ * _BOARD_EDGE_COUNT edge segments) from self into other. Both boards must
+ * have the same size, and other must already own its segment array.
+ */
+void
+board_copy(struct board const *restrict self, struct board *restrict other)
+{
+	assert(self);
+	assert(other);
+
+	assert(self->size == other->size);
+
+	size_t segments = (self->size * self->size) + _BOARD_EDGE_COUNT;
+	memcpy(other->segments, self->segments, segments * sizeof *self->segments);
+}
+
+/* Places a stone for player at (x, y) and merges the resulting segment with
+ * the player's source/sink edge segment (when on the matching edge) and with
+ * any neighbouring segments of the same colour.
+ *
+ * Returns false, leaving the board untouched, if the coordinates lie outside
+ * of the board or the cell is already occupied; returns true otherwise.
+ */
+bool
+board_play(struct board *self, enum hex_player player, u32 x, u32 y)
+{
+	assert(self);
+
+	/* coordinates may come straight from a network message, so reject
+	 * anything that would index outside of the segment array
+	 */
+	if (x >= self->size || y >= self->size) return false;
+
+	struct segment *segment = &self->segments[y * self->size + x];
+
+	if (segment->occupant != CELL_EMPTY) return false;
+
+	segment->occupant = (enum cell) player;
+
+	/* handle connection to source/sink for given player at edge of board
+	 */
+	if (player == HEX_PLAYER_BLACK) {
+		if (x == 0)
+			segment_merge(board_black_source(self), segment);
+		else if (x == self->size - 1)
+			segment_merge(board_black_sink(self), segment);
+	} else if (player == HEX_PLAYER_WHITE) {
+		if (y == 0)
+			segment_merge(board_white_source(self), segment);
+		else if (y == self->size - 1)
+			segment_merge(board_white_sink(self), segment);
+	}
+
+	/* handle connecting to neighbouring segments with same occupant
+	 */
+	for (size_t i = 0; i < NEIGHBOUR_COUNT; i++) {
+		/* widen before adding so that a negative offset applied to a
+		 * zero coordinate yields -1 instead of relying on wraparound
+		 */
+		s64 px = (s64) x + neighbour_dx[i];
+		s64 py = (s64) y + neighbour_dy[i];
+
+		if (px < 0 || px >= (s64) self->size) continue;
+		if (py < 0 || py >= (s64) self->size) continue;
+
+		struct segment *neighbour = &self->segments[py * self->size + px];
+
+		if (segment->occupant == neighbour->occupant)
+			segment_merge(segment, neighbour);
+	}
+
+	return true;
+}
+
+/* Exchanges the colour of every stone on the board (black becomes white and
+ * vice versa) and rebuilds the connectivity information to match.
+ */
+void
+board_swap(struct board *self)
+{
+	assert(self);
+
+	size_t cells = self->size * self->size;
+
+	/* forget all existing connectivity, including that of the edge
+	 * segments, so that no stone stays linked to the source/sink of the
+	 * colour it had before the swap
+	 */
+	for (size_t i = 0; i < cells + _BOARD_EDGE_COUNT; i++) {
+		self->segments[i].rank = 0;
+		self->segments[i].parent = RELPTR_NULL;
+	}
+
+	/* recolour every stone before replaying any of them, so that a replayed
+	 * stone can only merge with neighbours that have their final colour
+	 */
+	for (size_t i = 0; i < cells; i++) {
+		struct segment *segment = &self->segments[i];
+
+		switch (segment->occupant) {
+		case CELL_BLACK: segment->occupant = CELL_WHITE; break;
+		case CELL_WHITE: segment->occupant = CELL_BLACK; break;
+		default: break;
+		}
+	}
+
+	/* replay each stone to re-establish its edge and neighbour connections
+	 */
+	for (u32 j = 0; j < self->size; j++) {
+		for (u32 i = 0; i < self->size; i++) {
+			struct segment *segment = &self->segments[j * self->size + i];
+
+			switch (segment->occupant) {
+			case CELL_BLACK:
+				segment->occupant = CELL_EMPTY;
+				board_play(self, HEX_PLAYER_BLACK, i, j);
+				break;
+
+			case CELL_WHITE:
+				segment->occupant = CELL_EMPTY;
+				board_play(self, HEX_PLAYER_WHITE, i, j);
+				break;
+
+			default: break;
+			}
+		}
+	}
+}
+
+/* Counts the empty cells on the board. If buf is non-NULL the coordinates of
+ * each empty cell are also written to it, ordered by y and then by x, so buf
+ * must have room for every empty cell. Passing NULL only performs the count.
+ */
+size_t
+board_available_moves(struct board const *self, struct move *buf)
+{
+	assert(self);
+
+	size_t idx = 0;
+	for (u32 j = 0; j < self->size; j++) {
+		for (u32 i = 0; i < self->size; i++) {
+			if (self->segments[j * self->size + i].occupant == CELL_EMPTY) {
+				if (buf) {
+					buf[idx].x = i;
+					buf[idx].y = j;
+				}
+
+				idx++;
+			}
+		}
+	}
+
+	return idx;
+}
+
+/* Reports whether either player has connected their source edge segment to
+ * their sink edge segment. On a win, writes the winner to out and returns
+ * true; otherwise returns false and leaves out untouched.
+ */
+bool
+board_winner(struct board *self, enum hex_player *out)
+{
+	assert(self);
+	assert(out);
+
+	if (segment_root(board_black_source(self)) == segment_root(board_black_sink(self))) {
+		*out = HEX_PLAYER_BLACK;
+		return true;
+	} else if (segment_root(board_white_source(self)) == segment_root(board_white_sink(self))) {
+		*out = HEX_PLAYER_WHITE;
+		return true;
+	}
+
+	return false;
+}
diff --git a/agents/hexes/src/hexes.c b/agents/hexes/src/hexes.c
@@ -0,0 +1,286 @@
+#include "hexes.h"
+#include "hexes/agent.h"
+#include "hexes/board.h"
+#include "hexes/log.h"
+#include "hexes/network.h"
+#include "hexes/threadpool.h"
+
+struct opts opts = {
+	.log_level = LOG_INFO,
+	.agent_type = AGENT_MCTS,
+};
+
+/* states of the client's main loop; main() dispatches each state to the
+ * handler of the matching name
+ */
+enum game_state {
+	GAME_START,
+	GAME_RECV,
+	GAME_SEND,
+	GAME_END,
+};
+
+/* all state for the single game played by this process */
+struct game {
+	struct network network;
+	struct threadpool threadpool;
+	struct board board;
+	struct agent agent;
+
+	size_t round, thread_limit, mem_limit_mib;
+	struct timespec timer;
+	enum hex_player player, opponent;
+
+	enum game_state state;
+	bool game_over;
+};
+
+static struct game game = {
+	.state = GAME_START,
+};
+
+static void
+start_handler(struct game *game);
+
+static void
+recv_handler(struct game *game);
+
+static void
+send_handler(struct game *game);
+
+static void
+end_handler(struct game *game);
+
+/* Connects to the server given as <host> <port> and runs the game state
+ * machine until end_handler() marks the game as over, then releases all
+ * resources.
+ */
+int
+main(int argc, char **argv)
+{
+	srandom(getpid());
+
+	if (argc < 3) {
+		dbglog(LOG_ERROR, "Usage: %s <host> <port>\n", argv[0]);
+		exit(EXIT_FAILURE);
+	}
+
+	char *host = argv[1], *port = argv[2];
+	if (!network_init(&game.network, host, port)) {
+		dbglog(LOG_ERROR, "Failed to initialise network (connecting to %s:%s)\n", host, port);
+		exit(EXIT_FAILURE);
+	}
+
+	while (!game.game_over) {
+		dbglog(LOG_INFO, "==============================\n");
+
+		switch (game.state) {
+		case GAME_START: start_handler(&game); break;
+		case GAME_RECV: recv_handler(&game); break;
+		case GAME_SEND: send_handler(&game); break;
+		case GAME_END: end_handler(&game); break;
+		}
+
+		/* NOTE: counts handler invocations, not moves played */
+		game.round++;
+	}
+
+	agent_free(&game.agent);
+	board_free(&game.board);
+	threadpool_free(&game.threadpool);
+
+	network_free(&game.network);
+
+	exit(EXIT_SUCCESS);
+}
+
+/* Waits for the server's start message, records the game parameters and
+ * initialises the threadpool, board and agent. The next state is GAME_SEND
+ * when playing black and GAME_RECV when playing white; any failure moves the
+ * game to GAME_END.
+ */
+static void
+start_handler(struct game *game)
+{
+	assert(game);
+
+	enum hex_msg_type expected[] = { HEX_MSG_START, };
+
+	struct hex_msg msg;
+	if (!network_recv(&game->network, &msg, expected, ARRLEN(expected))) {
+		dbglog(LOG_ERROR, "Failed to receive message from server\n");
+		goto error;
+	}
+
+	game->player = msg.data.start.player;
+	game->opponent = hexopponent(game->player);
+	game->timer.tv_sec = msg.data.start.game_secs;
+	game->thread_limit = msg.data.start.thread_limit;
+	game->mem_limit_mib = msg.data.start.mem_limit_mib;
+
+	/* size_t and time_t are not u32, so they need their own conversions */
+	dbglog(LOG_INFO, "Received game parameters: player: %s, board size: %" PRIu32 ", game secs: %lld, thread limit: %zu, mem limit (MiB): %zu\n",
+		hexplayerstr(game->player), msg.data.start.board_size, (long long) game->timer.tv_sec, game->thread_limit, game->mem_limit_mib);
+
+	/* one thread is presumably reserved for the caller (hence the - 1);
+	 * guard against a limit of zero wrapping around to a huge worker count
+	 */
+	u32 workers = (msg.data.start.thread_limit > 0) ? msg.data.start.thread_limit - 1 : 0;
+
+	if (!threadpool_init(&game->threadpool, workers)) {
+		dbglog(LOG_ERROR, "Failed to initialise threadpool\n");
+		goto error;
+	}
+
+	if (!board_init(&game->board, msg.data.start.board_size)) {
+		dbglog(LOG_ERROR, "Failed to initialise board\n");
+		goto error;
+	}
+
+	if (!agent_init(&game->agent, (enum agent_type) opts.agent_type, &game->board,
+			&game->threadpool, game->mem_limit_mib, game->player)) {
+		dbglog(LOG_ERROR, "Failed to initialise agent\n");
+		goto error;
+	}
+
+	switch (game->player) {
+	case HEX_PLAYER_BLACK: game->state = GAME_SEND; break;
+	case HEX_PLAYER_WHITE: game->state = GAME_RECV; break;
+	}
+
+	return;
+
+error:
+	game->state = GAME_END;
+}
+
+/* Waits for the opponent's move, a swap, or the end-of-game message and
+ * applies it to the board and agent. A move or a swap hands the turn to us
+ * (GAME_SEND); the end message or any failure moves the game to GAME_END.
+ */
+static void
+recv_handler(struct game *game)
+{
+	assert(game);
+
+	enum hex_msg_type expected[] = { HEX_MSG_MOVE, HEX_MSG_SWAP, HEX_MSG_END, };
+
+	struct hex_msg msg;
+	if (!network_recv(&game->network, &msg, expected, ARRLEN(expected))) {
+		dbglog(LOG_ERROR, "Failed to receive message from server\n");
+		goto error;
+	}
+
+	switch (msg.type) {
+	case HEX_MSG_MOVE: {
+		dbglog(LOG_INFO, "Received move {x=%" PRIu32 ", y=%" PRIu32 "} from opponent\n",
+			msg.data.move.board_x, msg.data.move.board_y);
+
+		if (!board_play(&game->board, game->opponent, msg.data.move.board_x,
+				msg.data.move.board_y)) {
+			dbglog(LOG_ERROR, "Failed to play received move on board\n");
+			goto error;
+		}
+
+		agent_play(&game->agent, game->opponent, msg.data.move.board_x, msg.data.move.board_y);
+
+		if (game->round == 1 && /* TODO: calculate when to attempt to swap board */ false) {
+			game->state = GAME_RECV;
+		} else {
+			game->state = GAME_SEND;
+		}
+	} break;
+
+	case HEX_MSG_SWAP: {
+		dbglog(LOG_INFO, "Received swap msg from opponent\n");
+
+		board_swap(&game->board);
+		agent_swap(&game->agent);
+
+		game->state = GAME_SEND;
+	} break;
+
+	case HEX_MSG_END: {
+		dbglog(LOG_INFO, "Player %s has won the game\n", hexplayerstr(msg.data.end.winner));
+
+		game->state = GAME_END;
+	} break;
+	}
+
+	return;
+
+error:
+	game->state = GAME_END;
+}
+
+/* Asks the agent for the next move within a share of the remaining game
+ * time, charges the time spent against the game clock, plays the move on
+ * the local board and sends it to the server. The next state is GAME_RECV,
+ * or GAME_END on any failure.
+ */
+static void
+send_handler(struct game *game)
+{
+	assert(game);
+
+	struct hex_msg msg = {
+		.type = HEX_MSG_MOVE,
+	};
+
+	size_t total_rounds = (game->board.size * game->board.size) / 2;
+
+	/* game->round advances on every handler invocation and can reach or
+	 * pass total_rounds; never divide by zero or a wrapped difference
+	 */
+	size_t remaining = (game->round < total_rounds) ? total_rounds - game->round : 1;
+
+	/* a clock that has run negative must not become a huge unsigned budget */
+	time_t budget = (game->timer.tv_sec > 0) ? game->timer.tv_sec : 0;
+
+	struct timespec timeout = {
+		.tv_sec = budget / (time_t) remaining,
+	}, start, end, diff, new_timer;
+
+	clock_gettime(CLOCK_MONOTONIC, &start);
+	if (!agent_next(&game->agent, timeout, &msg.data.move.board_x, &msg.data.move.board_y)) {
+		dbglog(LOG_ERROR, "Failed to generate next move\n");
+		goto error;
+	}
+	clock_gettime(CLOCK_MONOTONIC, &end);
+
+	difftimespec(&end, &start, &diff);
+	difftimespec(&game->timer, &diff, &new_timer);
+	game->timer = new_timer;
+
+	dbglog(LOG_INFO, "Generated move: {x=%" PRIu32 ", y=%" PRIu32 "}\n", msg.data.move.board_x, msg.data.move.board_y);
+
+	if (!board_play(&game->board, game->player, msg.data.move.board_x, msg.data.move.board_y)) {
+		dbglog(LOG_ERROR, "Failed to play generated move on board\n");
+		goto error;
+	}
+
+	agent_play(&game->agent, game->player, msg.data.move.board_x, msg.data.move.board_y);
+
+	if (!network_send(&game->network, &msg)) {
+		dbglog(LOG_ERROR, "Failed to send message to server\n");
+		goto error;
+	}
+
+	game->state = GAME_RECV;
+
+	return;
+
+error:
+	game->state = GAME_END;
+}
+
+/* Logs a farewell and sets game_over, which ends the loop in main(). */
+static void
+end_handler(struct game *game)
+{
+	assert(game);
+
+	dbglog(LOG_INFO, "Game over. Goodbye, World!\n");
+
+	game->game_over = true;
+}
diff --git a/agents/hexes/src/log.c b/agents/hexes/src/log.c
@@ -0,0 +1,4 @@
+#include "hexes/log.h"
+
+extern inline void
+dbglog(enum log_level, char const *, ...);
diff --git a/agents/hexes/src/network.c b/agents/hexes/src/network.c
@@ -0,0 +1,115 @@
+#include "hexes/network.h"
+
+/* Resolves host:port and connects a stream socket to the first address that
+ * accepts the connection, storing the descriptor in self->sockfd. Returns
+ * false if name resolution fails or no address could be connected to.
+ */
+bool
+network_init(struct network *self, char const *host, char const *port)
+{
+	assert(self);
+	assert(host);
+	assert(port);
+
+	struct addrinfo hints = {
+		.ai_family = AF_UNSPEC,
+		.ai_socktype = SOCK_STREAM,
+	}, *addrinfo, *ptr;
+
+	int res;
+	if ((res = getaddrinfo(host, port, &hints, &addrinfo))) {
+		dbglog(LOG_ERROR, "Failed to get address info: %s\n", gai_strerror(res));
+		return false;
+	}
+
+	for (ptr = addrinfo; ptr; ptr = ptr->ai_next) {
+		self->sockfd = socket(ptr->ai_family, ptr->ai_socktype, ptr->ai_protocol);
+		if (self->sockfd == -1) continue;
+		if (connect(self->sockfd, ptr->ai_addr, ptr->ai_addrlen) != -1) break;
+		close(self->sockfd);
+	}
+
+	freeaddrinfo(addrinfo);
+
+	if (!ptr) {
+		self->sockfd = -1; /* every candidate was closed; keep no stale fd */
+		dbglog(LOG_ERROR, "Failed to connect to %s:%s\n", host, port);
+		return false;
+	}
+
+	return true;
+}
+
+/* Closes the socket opened by network_init(). */
+void
+network_free(struct network *self)
+{
+	assert(self);
+
+	close(self->sockfd);
+}
+
+/* Serialises msg into a HEX_MSG_SZ byte buffer and writes all of it to the
+ * socket, looping over short writes. Returns false if serialisation fails
+ * or the socket reports an error or shutdown.
+ */
+bool
+network_send(struct network *self, struct hex_msg const *msg)
+{
+	assert(self);
+	assert(msg);
+
+	u8 buf[HEX_MSG_SZ];
+
+	if (!hex_msg_try_serialise(msg, buf)) return false;
+
+	size_t count = 0;
+	do {
+		ssize_t curr = send(self->sockfd, buf + count, ARRLEN(buf) - count, 0);
+		if (curr <= 0) return false; /* error or socket shutdown */
+		count += curr;
+	} while (count < ARRLEN(buf));
+
+	return true;
+}
+
+/* Reads exactly HEX_MSG_SZ bytes from the socket, deserialises them and
+ * copies the message to out if its type is one of the len entries of
+ * expected. Returns false on a socket error or shutdown, on a malformed
+ * message, or on a message of an unexpected type.
+ */
+bool
+network_recv(struct network *self, struct hex_msg *out, enum hex_msg_type *expected, size_t len)
+{
+	assert(self);
+	assert(out);
+	assert(expected);
+	assert(len);
+
+	u8 buf[HEX_MSG_SZ];
+
+	size_t count = 0;
+	do {
+		ssize_t curr = recv(self->sockfd, buf + count, ARRLEN(buf) - count, 0);
+		if (curr <= 0) return false; /* error or socket shutdown */
+		count += curr;
+	} while (count < ARRLEN(buf));
+
+	struct hex_msg msg;
+	if (!hex_msg_try_deserialise(buf, &msg)) return false;
+
+	for (size_t i = 0; i < len; i++) {
+		if (msg.type == expected[i]) {
+			*out = msg;
+			return true;
+		}
+	}
+
+	return false;
+}
+
+extern inline b32
+hex_msg_try_serialise(struct hex_msg const *msg, u8 out[static HEX_MSG_SZ]);
+
+extern inline b32
+hex_msg_try_deserialise(u8 buf[static HEX_MSG_SZ], struct hex_msg *out);
diff --git a/agents/hexes/src/threadpool.c b/agents/hexes/src/threadpool.c
@@ -0,0 +1,26 @@
+#include "hexes/threadpool.h"
+
+/* Placeholder: accepts the requested worker count but creates no threads. */
+bool
+threadpool_init(struct threadpool *self, u32 threads)
+{
+	assert(self);
+
+	(void) self;
+	(void) threads;
+
+	// TODO: implement me
+
+	return true;
+}
+
+/* Placeholder: there is nothing to release until the pool is implemented. */
+void
+threadpool_free(struct threadpool *self)
+{
+	assert(self);
+
+	(void) self;
+
+	// TODO: implement me
+}
diff --git a/agents/hexes/src/utils.c b/agents/hexes/src/utils.c
@@ -0,0 +1,28 @@
+#include "hexes/utils.h"
+
+extern inline void
+swap(void *restrict lhs, void *restrict rhs, size_t size);
+
+extern inline void
+shuffle(void *arr, size_t size, size_t len);
+
+extern inline void
+difftimespec(struct timespec *restrict lhs, struct timespec *restrict rhs, struct timespec *restrict out);
+
+extern inline bool
+mem_pool_init(struct mem_pool *self, size_t align, size_t capacity);
+
+extern inline void
+mem_pool_free(struct mem_pool *self);
+
+extern inline void
+mem_pool_reset(struct mem_pool *self);
+
+extern inline void *
+mem_pool_alloc(struct mem_pool *self, size_t align, size_t size);
+
+extern inline enum hex_player
+hexopponent(enum hex_player player);
+
+extern inline char const *
+hexplayerstr(enum hex_player val);
diff --git a/schedule.txt b/schedule.txt
@@ -9,3 +9,5 @@ agents/example_cpp_agent/run.sh,agents/example_python3_agent/run.sh
 # java agent requires at least 16 threads, even with the minimal JVM options,
 # which means that --threads 16 must be passed to tournament-host.py
 #agents/example_java_agent/run.sh,agents/example_python3_agent/run.sh
+
+agents/example_python3_agent/run.sh,agents/hexes/run.sh