hex

hex.git
git clone git://git.lenczewski.org/hex.git
Log | Files | Refs

commit 072d7517572b56dcf30e1e0e5ed7269671307193
parent 1018b76098c392d846b775ac0c9e6a751cba1378
Author: Mikołaj Lenczewski <mblenczewski@gmail.com>
Date:   Fri,  3 Jan 2025 15:56:28 +0000

Switch to intrusive list for mcts_node children list

Diffstat:
M agents/hexes/include/hexes.h | 184 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++-----------------------
M agents/hexes/src/agent.c | 3 ++-
M agents/hexes/src/agent/mcts.c | 89 +++++++++++++++++++++++++++++++++++++++++++------------------------------------
M agents/hexes/src/agent/random.c | 3 ++-
M agents/hexes/src/board.c | 2 +-
M agents/hexes/src/utils.c | 24 ++++++++++++++++++------
6 files changed, 202 insertions(+), 103 deletions(-)

diff --git a/agents/hexes/include/hexes.h b/agents/hexes/include/hexes.h @@ -31,12 +31,15 @@ #include <arpa/inet.h> #include <math.h> -#include <netdb.h> -#include <sys/socket.h> -#include <sys/types.h> #include <time.h> #include <unistd.h> +#include <sys/types.h> +#include <sys/socket.h> +#include <netdb.h> + +#include <sys/mman.h> + typedef int32_t b32; typedef unsigned char c8; @@ -59,13 +62,19 @@ typedef double f64; #define MIN(a, b) ((a) > (b) ? (a) : (b)) #define MAX(a, b) ((a) < (b) ? (b) : (a)) +/* NOTE: we define a relative pointer helper to allow fully relocatable data + * structures, for example to allow for trees that survive a memcpy() + */ + #define RELPTR_NULL (0) #define _RELPTR_MASK(ty_relptr) ((ty_relptr)1 << ((sizeof(ty_relptr) * 8) - 1)) + #define _RELPTR_ENC(ty_relptr, ptroff) \ - ((ty_relptr)((ptroff) ^ _RELPTR_MASK(ty_relptr))) + ((ty_relptr) ((ptroff) ^ _RELPTR_MASK(ty_relptr))) + #define _RELPTR_DEC(ty_relptr, relptr) \ - ((ty_relptr)((relptr) ^ _RELPTR_MASK(ty_relptr))) + ((ty_relptr) ((relptr) ^ _RELPTR_MASK(ty_relptr))) #define RELPTR_ABS2REL(ty_relptr, base, absptr) \ ((absptr) \ @@ -73,8 +82,8 @@ typedef double f64; : RELPTR_NULL) #define RELPTR_REL2ABS(ty_absptr, ty_relptr, base, relptr) \ - ((relptr) \ - ? ((ty_absptr)((u8 *) base + _RELPTR_DEC(ty_relptr, relptr))) \ + (((relptr) != RELPTR_NULL) \ + ? 
((ty_absptr) ((u8 *) base + _RELPTR_DEC(ty_relptr, relptr))) \ : NULL) #define NANOSECS (1000000000ULL) @@ -168,62 +177,130 @@ difftimespec(struct timespec *restrict lhs, struct timespec *restrict rhs, struc #define IS_POW2(v) (((v) & ((v) - 1)) == 0) #define IS_ALIGNED(v, align) (((v) & ((align) - 1)) == 0) -#define ALIGN_PREV(v, align) ((v) & ~((v) - 1)) +#define ALIGN_PREV(v, align) ((v) & ~((align) - 1)) #define ALIGN_NEXT(v, align) ALIGN_PREV((v) + ((align) - 1), (align)) -struct mem_pool { +struct arena { void *ptr; - size_t cap, len; + uint64_t cap, len; }; -inline bool -mem_pool_init(struct mem_pool *self, size_t align, size_t capacity) +inline void +arena_reset(struct arena *arena) { - assert(self); + arena->len = 0; +} + +inline void * +arena_alloc(struct arena *arena, size_t size, size_t align) +{ + assert(arena); + assert(size); assert(align); assert(IS_POW2(align)); - assert(capacity % align == 0); - self->ptr = aligned_alloc(align, capacity); - if (!self->ptr) return false; + size_t aligned_len = ALIGN_NEXT(arena->len, align); + if (aligned_len + size > arena->cap) + return NULL; - self->cap = capacity; - self->len = 0; + void *ptr = (void *) ((intptr_t) arena->ptr + aligned_len); + arena->len = aligned_len + size; - return true; + return ptr; } -inline void -mem_pool_free(struct mem_pool *self) +#define ALLOC_ARRAY(arena, T, n) \ + arena_alloc((arena), sizeof(T) * (n), alignof(T)) + +#define ALLOC_SIZED(arena, T) ALLOC_ARRAY((arena), T, 1) + +typedef s64 list_node_relptr_t; + +struct list_node; + +#define FROM_NODE(node, T, member) \ + ((T *) ((uintptr_t) node - offsetof(T, member))) + +struct list_node { + list_node_relptr_t prev, next; +}; + +inline list_node_relptr_t +list_node_abs2rel(void const *restrict base, struct list_node *restrict absptr) { - assert(self); + return RELPTR_ABS2REL(list_node_relptr_t, base, absptr); +} - free(self->ptr); +inline struct list_node * +list_node_rel2abs(void const *base, list_node_relptr_t relptr) +{ + return 
RELPTR_REL2ABS(struct list_node *, list_node_relptr_t, base, relptr); } +#define list_node_iter(node) \ + for (struct list_node *it = node; it; it = list_node_rel2abs(it, it->next)) + +#define list_node_riter(node) \ + for (struct list_node *it = node; it ; it = list_node_rel2abs(it, it->prev)) + +struct list { + list_node_relptr_t head, tail; +}; + +#define list_iter(list) \ + list_node_iter(list_node_rel2abs(list, list->head)) + +#define list_riter(list) \ + list_node_riter(list_node_rel2abs(list, list->tail)) + inline void -mem_pool_reset(struct mem_pool *self) +list_push_head(struct list *restrict list, struct list_node *restrict node) { - assert(self); + if (list->tail == RELPTR_NULL) + list->tail = list_node_abs2rel(list, node); - self->len = 0; + struct list_node *head = list_node_rel2abs(list, list->head); + if (head) + head->prev = list_node_abs2rel(head, node); + + node->next = list_node_abs2rel(node, head); + list->head = list_node_abs2rel(list, node); } -inline void * -mem_pool_alloc(struct mem_pool *self, size_t align, size_t size) +inline void +list_push_tail(struct list *restrict list, struct list_node *restrict node) { - assert(self); - assert(align); - assert(IS_POW2(align)); + if (list->head == RELPTR_NULL) + list->head = list_node_abs2rel(list, node); + + struct list_node *tail = list_node_rel2abs(list, list->tail); + if (tail) + tail->next = list_node_abs2rel(tail, node); + + node->prev = list_node_abs2rel(node, tail); + list->tail = list_node_abs2rel(list, node); +} - size_t aligned_len = ALIGN_NEXT(self->len, align); - if (aligned_len + size > self->cap) +inline struct list_node * +list_pop_head(struct list *list) +{ + if (list->head == RELPTR_NULL) return NULL; - void *ptr = (void *) ((intptr_t) self->ptr + aligned_len); - self->len = aligned_len + size; + struct list_node *node = list_node_rel2abs(list, list->head); + list->head = list_node_abs2rel(list, list_node_rel2abs(node, node->next)); + return node; +} - return ptr; +inline struct 
list_node * +list_pop_tail(struct list *list) +{ + if (list->tail == RELPTR_NULL) + return NULL; + + struct list_node *node = list_node_rel2abs(list, list->tail); + list->tail = list_node_abs2rel(list, list_node_rel2abs(node, node->prev)); + return node; } /* network definitions @@ -266,7 +343,7 @@ struct segment { }; inline segment_relptr_t -segment_abs2rel(void const *base, struct segment *absptr) +segment_abs2rel(void const *restrict base, struct segment *restrict absptr) { return RELPTR_ABS2REL(segment_relptr_t, base, absptr); } @@ -334,7 +411,7 @@ void board_swap(struct board *self); size_t -board_available_moves(struct board const *self, struct move *buf); +board_available_moves(struct board const *restrict self, struct move *restrict buf); bool board_winner(struct board *self, enum hex_player *out); @@ -371,13 +448,14 @@ void agent_random_swap(struct agent_random *self); bool -agent_random_next(struct agent_random *self, struct timespec timeout, u32 *out_x, u32 *out_y); - -#define MCTS_RESERVED_MEM (MiB) +agent_random_next(struct agent_random *self, struct timespec timeout, + u32 *restrict out_x, u32 *restrict out_y); typedef s64 mcts_node_relptr_t; struct mcts_node { + alignas(64) + mcts_node_relptr_t parent; enum hex_player player; u8 x, y; @@ -385,18 +463,14 @@ struct mcts_node { s32 wins, rave_wins; u32 plays, rave_plays; - u16 children_cap, children_len; - mcts_node_relptr_t children[]; -}; + u32 children_cap, children_len; -inline size_t -mcts_node_sizeof(size_t children) -{ - return sizeof(struct mcts_node) + children * sizeof(mcts_node_relptr_t); -} + struct list children; + struct list_node list_node; +}; inline mcts_node_relptr_t -mcts_node_abs2rel(void *base, struct mcts_node *absptr) +mcts_node_abs2rel(void *restrict base, struct mcts_node *restrict absptr) { return RELPTR_ABS2REL(mcts_node_relptr_t, base, absptr); } @@ -413,7 +487,7 @@ struct agent_mcts { struct board shadow_board; - struct mem_pool pool; + struct arena pool; struct mcts_node 
*root; }; @@ -431,13 +505,16 @@ void agent_mcts_swap(struct agent_mcts *self); bool -agent_mcts_next(struct agent_mcts *self, struct timespec timeout, u32 *out_x, u32 *out_y); +agent_mcts_next(struct agent_mcts *self, struct timespec timeout, + u32 *restrict out_x, u32 *restrict out_y); enum agent_type { AGENT_RANDOM, AGENT_MCTS, }; +#define HEXES_RESERVED_MEM (MiB) + struct agent { enum agent_type type; union { @@ -460,6 +537,7 @@ void agent_swap(struct agent *self); bool -agent_next(struct agent *self, struct timespec timeout, u32 *out_x, u32 *out_y); +agent_next(struct agent *self, struct timespec timeout, u32 *restrict out_x, + u32 *restrict out_y); #endif /* HEXES_H */ diff --git a/agents/hexes/src/agent.c b/agents/hexes/src/agent.c @@ -55,7 +55,8 @@ agent_swap(struct agent *self) } bool -agent_next(struct agent *self, struct timespec timeout, u32 *out_x, u32 *out_y) +agent_next(struct agent *self, struct timespec timeout, u32 *restrict out_x, + u32 *restrict out_y) { assert(self); assert(out_x); diff --git a/agents/hexes/src/agent/mcts.c b/agents/hexes/src/agent/mcts.c @@ -1,17 +1,14 @@ #include "hexes.h" -extern inline size_t -mcts_node_sizeof(size_t children); - extern inline mcts_node_relptr_t -mcts_node_abs2rel(void *base, struct mcts_node *absptr); +mcts_node_abs2rel(void *restrict base, struct mcts_node *restrict absptr); extern inline struct mcts_node * mcts_node_rel2abs(void *base, mcts_node_relptr_t relptr); static void mcts_node_init(struct mcts_node *self, struct mcts_node *parent, - enum hex_player player, u8 x, u8 y, size_t children) + enum hex_player player, u8 x, u8 y, u32 children_cap) { assert(self); @@ -23,27 +20,28 @@ mcts_node_init(struct mcts_node *self, struct mcts_node *parent, self->wins = self->rave_wins = 0; self->plays = self->rave_plays = 0; - self->children_cap = children; + self->children_cap = children_cap; self->children_len = 0; + + self->children.head = self->children.tail = RELPTR_NULL; + self->list_node.prev = 
self->list_node.next = RELPTR_NULL; } static bool -mcts_node_expand(struct mcts_node *self, struct mem_pool *pool, u8 x, u8 y) +mcts_node_expand(struct mcts_node *self, struct arena *pool, u8 x, u8 y) { assert(self); assert(pool); - struct mcts_node *child = mem_pool_alloc(pool, alignof(struct mcts_node), - mcts_node_sizeof(self->children_cap - 1)); - + struct mcts_node *child = ALLOC_SIZED(pool, struct mcts_node); if (!child) { dbglog(LOG_WARN, "Failed to allocate child node. Consider compacting memory pool\n"); return false; } mcts_node_init(child, self, hexopponent(self->player), x, y, self->children_cap - 1); - - self->children[self->children_len++] = mcts_node_abs2rel(self, child); + list_push_tail(&self->children, &child->list_node); + self->children_len++; return true; } @@ -53,13 +51,13 @@ mcts_node_get_child(struct mcts_node *self, u8 x, u8 y) { assert(self); - if (!self->children_len) return NULL; - - for (size_t i = 0; i < self->children_cap; i++) { - struct mcts_node *child = mcts_node_rel2abs(self, self->children[i]); - if (!child) continue; + struct list *list = &self->children; + list_iter(list) { + struct mcts_node *child = FROM_NODE(it, struct mcts_node, list_node); + assert(child); - if (child->x == x && child->y == y) return child; + if (child->x == x && child->y == y) + return child; } return NULL; @@ -118,11 +116,12 @@ mcts_node_best_child(struct mcts_node *self) f32 max_score = -INFINITY; struct mcts_node *best_child = NULL; - for (size_t i = 0; i < self->children_cap; i++) { - struct mcts_node *child = mcts_node_rel2abs(self, self->children[i]); - if (!child) continue; + struct list *list = &self->children; + list_iter(list) { + struct mcts_node *child = FROM_NODE(it, struct mcts_node, list_node); + assert(child); - dbglog(LOG_DEBUG, "Node: {parent=%p, children=%" PRIu8 ", x=%" PRIu32 ", y=%" PRIu32 "}\n", + dbglog(LOG_DEBUG, "Node: {parent=%p, children=%" PRIu32 ", x=%" PRIu32 ", y=%" PRIu32 "}\n", mcts_node_rel2abs(child, child->parent), 
child->children_len, child->x, child->y); f32 score = mcts_node_calc_score(child); @@ -148,15 +147,20 @@ agent_mcts_init(struct agent_mcts *self, struct board const *board, struct threa if (!board_init(&self->shadow_board, board->size)) return false; size_t align = alignof(struct mcts_node); - size_t cap = ((mem_limit_mib * MiB) - MCTS_RESERVED_MEM) & ~(align - 1); + self->pool.cap = ALIGN_PREV((mem_limit_mib * MiB) - HEXES_RESERVED_MEM, align); - if (!mem_pool_init(&self->pool, align, cap)) { + int prot = PROT_READ | PROT_WRITE; + int flags = MAP_PRIVATE | MAP_ANONYMOUS; + self->pool.ptr = mmap(NULL, self->pool.cap, prot, flags, -1, 0); + madvise(self->pool.ptr, self->pool.cap, MADV_HUGEPAGE); + + if (self->pool.ptr == MAP_FAILED) { board_free(&self->shadow_board); return false; } - size_t moves = board_available_moves(board, NULL); - self->root = mem_pool_alloc(&self->pool, alignof(struct mcts_node), mcts_node_sizeof(moves)); + u32 moves = board_available_moves(self->board, NULL); + self->root = ALLOC_SIZED(&self->pool, struct mcts_node); mcts_node_init(self->root, NULL, hexopponent(player), 0, 0, moves); return true; @@ -167,7 +171,7 @@ agent_mcts_free(struct agent_mcts *self) { assert(self); - mem_pool_free(&self->pool); + munmap(self->pool.ptr, self->pool.cap); } void @@ -175,10 +179,10 @@ agent_mcts_play(struct agent_mcts *self, enum hex_player player, u32 x, u32 y) { assert(self); - mem_pool_reset(&self->pool); + arena_reset(&self->pool); - size_t moves = board_available_moves(self->board, NULL); - self->root = mem_pool_alloc(&self->pool, alignof(struct mcts_node), mcts_node_sizeof(moves)); + u32 moves = board_available_moves(self->board, NULL); + self->root = ALLOC_SIZED(&self->pool, struct mcts_node); mcts_node_init(self->root, NULL, player, x, y, moves); // TODO: implement tree reuse, if it improves play @@ -207,10 +211,10 @@ agent_mcts_swap(struct agent_mcts *self) struct mcts_node old_root = *self->root; - mem_pool_reset(&self->pool); + 
arena_reset(&self->pool); - size_t moves = board_available_moves(self->board, NULL); - self->root = mem_pool_alloc(&self->pool, alignof(struct mcts_node), mcts_node_sizeof(moves)); + u32 moves = board_available_moves(self->board, NULL); + self->root = ALLOC_SIZED(&self->pool, struct mcts_node); mcts_node_init(self->root, NULL, hexopponent(old_root.player), old_root.x, old_root.y, moves); } @@ -218,7 +222,8 @@ static bool mcts_search(struct agent_mcts *self, struct timespec timeout); bool -agent_mcts_next(struct agent_mcts *self, struct timespec timeout, u32 *out_x, u32 *out_y) +agent_mcts_next(struct agent_mcts *self, struct timespec timeout, + u32 *restrict out_x, u32 *restrict out_y) { assert(self); assert(out_x); @@ -227,13 +232,14 @@ agent_mcts_next(struct agent_mcts *self, struct timespec timeout, u32 *out_x, u3 if (!mcts_search(self, timeout)) return false; struct mcts_node *root = self->pool.ptr; - assert(root->children_len); + assert(root->children.head != RELPTR_NULL); u32 max_plays = 0; struct mcts_node *best_child = NULL; - for (size_t i = 0; i < root->children_cap; i++) { - struct mcts_node *child = mcts_node_rel2abs(root, root->children[i]); - if (!child) continue; + struct list *list = &root->children; + list_iter(list) { + struct mcts_node *child = FROM_NODE(it, struct mcts_node, list_node); + assert(child); if (child->plays > max_plays) { max_plays = child->plays; @@ -330,9 +336,10 @@ mcts_round(struct agent_mcts *self, struct move *moves) do { s32 reward = winner == node->player ? 
+1 : -1; - for (size_t i = 0; i < node->children_len; i++) { - struct mcts_node *child = mcts_node_rel2abs(node, node->children[i]); - if (!child) continue; + struct list *list = &node->children; + list_iter(list) { + struct mcts_node *child = FROM_NODE(it, struct mcts_node, list_node); + assert(child); struct segment *segment = &self->shadow_board.segments[child->y * self->shadow_board.size + child->x]; if ((enum cell) child->player == segment->occupant) { diff --git a/agents/hexes/src/agent/random.c b/agents/hexes/src/agent/random.c @@ -58,7 +58,8 @@ agent_random_swap(struct agent_random *self) } bool -agent_random_next(struct agent_random *self, struct timespec timeout, u32 *out_x, u32 *out_y) +agent_random_next(struct agent_random *self, struct timespec timeout, + u32 *restrict out_x, u32 *restrict out_y) { assert(self); assert(out_x); diff --git a/agents/hexes/src/board.c b/agents/hexes/src/board.c @@ -3,7 +3,7 @@ #define NEIGHBOUR_COUNT 6 extern inline segment_relptr_t -segment_abs2rel(void const *base, struct segment *absptr); +segment_abs2rel(void const *restrict base, struct segment *restrict absptr); extern inline struct segment * segment_rel2abs(void const *base, segment_relptr_t relptr); diff --git a/agents/hexes/src/utils.c b/agents/hexes/src/utils.c @@ -12,17 +12,29 @@ shuffle(void *arr, size_t size, size_t len); extern inline void difftimespec(struct timespec *restrict lhs, struct timespec *restrict rhs, struct timespec *restrict out); -extern inline bool -mem_pool_init(struct mem_pool *self, size_t align, size_t capacity); +extern inline void +arena_reset(struct arena *arena); + +extern inline void * +arena_alloc(struct arena *arena, size_t size, size_t align); + +extern inline list_node_relptr_t +list_node_abs2rel(void const *restrict base, struct list_node *restrict absptr); + +extern inline struct list_node * +list_node_rel2abs(void const *base, list_node_relptr_t relptr); extern inline void -mem_pool_free(struct mem_pool *self); 
+list_push_head(struct list *restrict list, struct list_node *restrict node); extern inline void -mem_pool_reset(struct mem_pool *self); +list_push_tail(struct list *restrict list, struct list_node *restrict node); -extern inline void * -mem_pool_alloc(struct mem_pool *self, size_t align, size_t size); +extern inline struct list_node * +list_pop_head(struct list *list); + +extern inline struct list_node * +list_pop_tail(struct list *list); extern inline enum hex_player hexopponent(enum hex_player player);