commit 072d7517572b56dcf30e1e0e5ed7269671307193
parent 1018b76098c392d846b775ac0c9e6a751cba1378
Author: MikoĊaj Lenczewski <mblenczewski@gmail.com>
Date: Fri, 3 Jan 2025 15:56:28 +0000
Switch to intrusive list for mcts_node children list
Diffstat:
6 files changed, 202 insertions(+), 103 deletions(-)
diff --git a/agents/hexes/include/hexes.h b/agents/hexes/include/hexes.h
@@ -31,12 +31,15 @@
#include <arpa/inet.h>
#include <math.h>
-#include <netdb.h>
-#include <sys/socket.h>
-#include <sys/types.h>
#include <time.h>
#include <unistd.h>
+#include <sys/types.h>
+#include <sys/socket.h>
+#include <netdb.h>
+
+#include <sys/mman.h>
+
typedef int32_t b32;
typedef unsigned char c8;
@@ -59,13 +62,19 @@ typedef double f64;
#define MIN(a, b) ((a) > (b) ? (a) : (b))
#define MAX(a, b) ((a) < (b) ? (b) : (a))
+/* NOTE: we define a relative pointer helper to allow fully relocatable data
+ * structures, for example to allow for trees that survive a memcpy()
+ */
+
#define RELPTR_NULL (0)
#define _RELPTR_MASK(ty_relptr) ((ty_relptr)1 << ((sizeof(ty_relptr) * 8) - 1))
+
#define _RELPTR_ENC(ty_relptr, ptroff) \
- ((ty_relptr)((ptroff) ^ _RELPTR_MASK(ty_relptr)))
+ ((ty_relptr) ((ptroff) ^ _RELPTR_MASK(ty_relptr)))
+
#define _RELPTR_DEC(ty_relptr, relptr) \
- ((ty_relptr)((relptr) ^ _RELPTR_MASK(ty_relptr)))
+ ((ty_relptr) ((relptr) ^ _RELPTR_MASK(ty_relptr)))
#define RELPTR_ABS2REL(ty_relptr, base, absptr) \
((absptr) \
@@ -73,8 +82,8 @@ typedef double f64;
: RELPTR_NULL)
#define RELPTR_REL2ABS(ty_absptr, ty_relptr, base, relptr) \
- ((relptr) \
- ? ((ty_absptr)((u8 *) base + _RELPTR_DEC(ty_relptr, relptr))) \
+ (((relptr) != RELPTR_NULL) \
+ ? ((ty_absptr) ((u8 *) base + _RELPTR_DEC(ty_relptr, relptr))) \
: NULL)
#define NANOSECS (1000000000ULL)
@@ -168,62 +177,130 @@ difftimespec(struct timespec *restrict lhs, struct timespec *restrict rhs, struc
#define IS_POW2(v) (((v) & ((v) - 1)) == 0)
#define IS_ALIGNED(v, align) (((v) & ((align) - 1)) == 0)
-#define ALIGN_PREV(v, align) ((v) & ~((v) - 1))
+#define ALIGN_PREV(v, align) ((v) & ~((align) - 1))
#define ALIGN_NEXT(v, align) ALIGN_PREV((v) + ((align) - 1), (align))
-struct mem_pool {
+struct arena {
void *ptr;
- size_t cap, len;
+ uint64_t cap, len;
};
-inline bool
-mem_pool_init(struct mem_pool *self, size_t align, size_t capacity)
+inline void
+arena_reset(struct arena *arena)
{
- assert(self);
+ arena->len = 0;
+}
+
+inline void *
+arena_alloc(struct arena *arena, size_t size, size_t align)
+{
+ assert(arena);
+ assert(size);
assert(align);
assert(IS_POW2(align));
- assert(capacity % align == 0);
- self->ptr = aligned_alloc(align, capacity);
- if (!self->ptr) return false;
+ size_t aligned_len = ALIGN_NEXT(arena->len, align);
+ if (aligned_len + size > arena->cap)
+ return NULL;
- self->cap = capacity;
- self->len = 0;
+ void *ptr = (void *) ((intptr_t) arena->ptr + aligned_len);
+ arena->len = aligned_len + size;
- return true;
+ return ptr;
}
-inline void
-mem_pool_free(struct mem_pool *self)
+#define ALLOC_ARRAY(arena, T, n) \
+ arena_alloc((arena), sizeof(T) * (n), alignof(T))
+
+#define ALLOC_SIZED(arena, T) ALLOC_ARRAY((arena), T, 1)
+
+typedef s64 list_node_relptr_t;
+
+struct list_node;
+
+#define FROM_NODE(node, T, member) \
+ ((T *) ((uintptr_t) node - offsetof(T, member)))
+
+struct list_node {
+ list_node_relptr_t prev, next;
+};
+
+inline list_node_relptr_t
+list_node_abs2rel(void const *restrict base, struct list_node *restrict absptr)
{
- assert(self);
+ return RELPTR_ABS2REL(list_node_relptr_t, base, absptr);
+}
- free(self->ptr);
+inline struct list_node *
+list_node_rel2abs(void const *base, list_node_relptr_t relptr)
+{
+ return RELPTR_REL2ABS(struct list_node *, list_node_relptr_t, base, relptr);
}
+#define list_node_iter(node) \
+ for (struct list_node *it = node; it; it = list_node_rel2abs(it, it->next))
+
+#define list_node_riter(node) \
+ for (struct list_node *it = node; it ; it = list_node_rel2abs(it, it->prev))
+
+struct list {
+ list_node_relptr_t head, tail;
+};
+
+#define list_iter(list) \
+ list_node_iter(list_node_rel2abs(list, list->head))
+
+#define list_riter(list) \
+ list_node_riter(list_node_rel2abs(list, list->tail))
+
inline void
-mem_pool_reset(struct mem_pool *self)
+list_push_head(struct list *restrict list, struct list_node *restrict node)
{
- assert(self);
+ if (list->tail == RELPTR_NULL)
+ list->tail = list_node_abs2rel(list, node);
- self->len = 0;
+ struct list_node *head = list_node_rel2abs(list, list->head);
+ if (head)
+ head->prev = list_node_abs2rel(head, node);
+
+ node->next = list_node_abs2rel(node, head);
+ list->head = list_node_abs2rel(list, node);
}
-inline void *
-mem_pool_alloc(struct mem_pool *self, size_t align, size_t size)
+inline void
+list_push_tail(struct list *restrict list, struct list_node *restrict node)
{
- assert(self);
- assert(align);
- assert(IS_POW2(align));
+ if (list->head == RELPTR_NULL)
+ list->head = list_node_abs2rel(list, node);
+
+ struct list_node *tail = list_node_rel2abs(list, list->tail);
+ if (tail)
+ tail->next = list_node_abs2rel(tail, node);
+
+ node->prev = list_node_abs2rel(node, tail);
+ list->tail = list_node_abs2rel(list, node);
+}
- size_t aligned_len = ALIGN_NEXT(self->len, align);
- if (aligned_len + size > self->cap)
+inline struct list_node *
+list_pop_head(struct list *list)
+{
+ if (list->head == RELPTR_NULL)
return NULL;
- void *ptr = (void *) ((intptr_t) self->ptr + aligned_len);
- self->len = aligned_len + size;
+ struct list_node *node = list_node_rel2abs(list, list->head);
+ list->head = list_node_abs2rel(list, list_node_rel2abs(node, node->next));
+ return node;
+}
- return ptr;
+inline struct list_node *
+list_pop_tail(struct list *list)
+{
+ if (list->tail == RELPTR_NULL)
+ return NULL;
+
+ struct list_node *node = list_node_rel2abs(list, list->tail);
+ list->tail = list_node_abs2rel(list, list_node_rel2abs(node, node->prev));
+ return node;
}
/* network definitions
@@ -266,7 +343,7 @@ struct segment {
};
inline segment_relptr_t
-segment_abs2rel(void const *base, struct segment *absptr)
+segment_abs2rel(void const *restrict base, struct segment *restrict absptr)
{
return RELPTR_ABS2REL(segment_relptr_t, base, absptr);
}
@@ -334,7 +411,7 @@ void
board_swap(struct board *self);
size_t
-board_available_moves(struct board const *self, struct move *buf);
+board_available_moves(struct board const *restrict self, struct move *restrict buf);
bool
board_winner(struct board *self, enum hex_player *out);
@@ -371,13 +448,14 @@ void
agent_random_swap(struct agent_random *self);
bool
-agent_random_next(struct agent_random *self, struct timespec timeout, u32 *out_x, u32 *out_y);
-
-#define MCTS_RESERVED_MEM (MiB)
+agent_random_next(struct agent_random *self, struct timespec timeout,
+ u32 *restrict out_x, u32 *restrict out_y);
typedef s64 mcts_node_relptr_t;
struct mcts_node {
+ alignas(64)
+
mcts_node_relptr_t parent;
enum hex_player player;
u8 x, y;
@@ -385,18 +463,14 @@ struct mcts_node {
s32 wins, rave_wins;
u32 plays, rave_plays;
- u16 children_cap, children_len;
- mcts_node_relptr_t children[];
-};
+ u32 children_cap, children_len;
-inline size_t
-mcts_node_sizeof(size_t children)
-{
- return sizeof(struct mcts_node) + children * sizeof(mcts_node_relptr_t);
-}
+ struct list children;
+ struct list_node list_node;
+};
inline mcts_node_relptr_t
-mcts_node_abs2rel(void *base, struct mcts_node *absptr)
+mcts_node_abs2rel(void *restrict base, struct mcts_node *restrict absptr)
{
return RELPTR_ABS2REL(mcts_node_relptr_t, base, absptr);
}
@@ -413,7 +487,7 @@ struct agent_mcts {
struct board shadow_board;
- struct mem_pool pool;
+ struct arena pool;
struct mcts_node *root;
};
@@ -431,13 +505,16 @@ void
agent_mcts_swap(struct agent_mcts *self);
bool
-agent_mcts_next(struct agent_mcts *self, struct timespec timeout, u32 *out_x, u32 *out_y);
+agent_mcts_next(struct agent_mcts *self, struct timespec timeout,
+ u32 *restrict out_x, u32 *restrict out_y);
enum agent_type {
AGENT_RANDOM,
AGENT_MCTS,
};
+#define HEXES_RESERVED_MEM (MiB)
+
struct agent {
enum agent_type type;
union {
@@ -460,6 +537,7 @@ void
agent_swap(struct agent *self);
bool
-agent_next(struct agent *self, struct timespec timeout, u32 *out_x, u32 *out_y);
+agent_next(struct agent *self, struct timespec timeout, u32 *restrict out_x,
+ u32 *restrict out_y);
#endif /* HEXES_H */
diff --git a/agents/hexes/src/agent.c b/agents/hexes/src/agent.c
@@ -55,7 +55,8 @@ agent_swap(struct agent *self)
}
bool
-agent_next(struct agent *self, struct timespec timeout, u32 *out_x, u32 *out_y)
+agent_next(struct agent *self, struct timespec timeout, u32 *restrict out_x,
+ u32 *restrict out_y)
{
assert(self);
assert(out_x);
diff --git a/agents/hexes/src/agent/mcts.c b/agents/hexes/src/agent/mcts.c
@@ -1,17 +1,14 @@
#include "hexes.h"
-extern inline size_t
-mcts_node_sizeof(size_t children);
-
extern inline mcts_node_relptr_t
-mcts_node_abs2rel(void *base, struct mcts_node *absptr);
+mcts_node_abs2rel(void *restrict base, struct mcts_node *restrict absptr);
extern inline struct mcts_node *
mcts_node_rel2abs(void *base, mcts_node_relptr_t relptr);
static void
mcts_node_init(struct mcts_node *self, struct mcts_node *parent,
- enum hex_player player, u8 x, u8 y, size_t children)
+ enum hex_player player, u8 x, u8 y, u32 children_cap)
{
assert(self);
@@ -23,27 +20,28 @@ mcts_node_init(struct mcts_node *self, struct mcts_node *parent,
self->wins = self->rave_wins = 0;
self->plays = self->rave_plays = 0;
- self->children_cap = children;
+ self->children_cap = children_cap;
self->children_len = 0;
+
+ self->children.head = self->children.tail = RELPTR_NULL;
+ self->list_node.prev = self->list_node.next = RELPTR_NULL;
}
static bool
-mcts_node_expand(struct mcts_node *self, struct mem_pool *pool, u8 x, u8 y)
+mcts_node_expand(struct mcts_node *self, struct arena *pool, u8 x, u8 y)
{
assert(self);
assert(pool);
- struct mcts_node *child = mem_pool_alloc(pool, alignof(struct mcts_node),
- mcts_node_sizeof(self->children_cap - 1));
-
+ struct mcts_node *child = ALLOC_SIZED(pool, struct mcts_node);
if (!child) {
dbglog(LOG_WARN, "Failed to allocate child node. Consider compacting memory pool\n");
return false;
}
mcts_node_init(child, self, hexopponent(self->player), x, y, self->children_cap - 1);
-
- self->children[self->children_len++] = mcts_node_abs2rel(self, child);
+ list_push_tail(&self->children, &child->list_node);
+ self->children_len++;
return true;
}
@@ -53,13 +51,13 @@ mcts_node_get_child(struct mcts_node *self, u8 x, u8 y)
{
assert(self);
- if (!self->children_len) return NULL;
-
- for (size_t i = 0; i < self->children_cap; i++) {
- struct mcts_node *child = mcts_node_rel2abs(self, self->children[i]);
- if (!child) continue;
+ struct list *list = &self->children;
+ list_iter(list) {
+ struct mcts_node *child = FROM_NODE(it, struct mcts_node, list_node);
+ assert(child);
- if (child->x == x && child->y == y) return child;
+ if (child->x == x && child->y == y)
+ return child;
}
return NULL;
@@ -118,11 +116,12 @@ mcts_node_best_child(struct mcts_node *self)
f32 max_score = -INFINITY;
struct mcts_node *best_child = NULL;
- for (size_t i = 0; i < self->children_cap; i++) {
- struct mcts_node *child = mcts_node_rel2abs(self, self->children[i]);
- if (!child) continue;
+ struct list *list = &self->children;
+ list_iter(list) {
+ struct mcts_node *child = FROM_NODE(it, struct mcts_node, list_node);
+ assert(child);
- dbglog(LOG_DEBUG, "Node: {parent=%p, children=%" PRIu8 ", x=%" PRIu32 ", y=%" PRIu32 "}\n",
+ dbglog(LOG_DEBUG, "Node: {parent=%p, children=%" PRIu32 ", x=%" PRIu32 ", y=%" PRIu32 "}\n",
mcts_node_rel2abs(child, child->parent), child->children_len, child->x, child->y);
f32 score = mcts_node_calc_score(child);
@@ -148,15 +147,20 @@ agent_mcts_init(struct agent_mcts *self, struct board const *board, struct threa
if (!board_init(&self->shadow_board, board->size)) return false;
size_t align = alignof(struct mcts_node);
- size_t cap = ((mem_limit_mib * MiB) - MCTS_RESERVED_MEM) & ~(align - 1);
+ self->pool.cap = ALIGN_PREV((mem_limit_mib * MiB) - HEXES_RESERVED_MEM, align);
- if (!mem_pool_init(&self->pool, align, cap)) {
+ int prot = PROT_READ | PROT_WRITE;
+ int flags = MAP_PRIVATE | MAP_ANONYMOUS;
+ self->pool.ptr = mmap(NULL, self->pool.cap, prot, flags, -1, 0);
+ madvise(self->pool.ptr, self->pool.cap, MADV_HUGEPAGE);
+
+ if (self->pool.ptr == MAP_FAILED) {
board_free(&self->shadow_board);
return false;
}
- size_t moves = board_available_moves(board, NULL);
- self->root = mem_pool_alloc(&self->pool, alignof(struct mcts_node), mcts_node_sizeof(moves));
+ u32 moves = board_available_moves(self->board, NULL);
+ self->root = ALLOC_SIZED(&self->pool, struct mcts_node);
mcts_node_init(self->root, NULL, hexopponent(player), 0, 0, moves);
return true;
@@ -167,7 +171,7 @@ agent_mcts_free(struct agent_mcts *self)
{
assert(self);
- mem_pool_free(&self->pool);
+ munmap(self->pool.ptr, self->pool.cap);
}
void
@@ -175,10 +179,10 @@ agent_mcts_play(struct agent_mcts *self, enum hex_player player, u32 x, u32 y)
{
assert(self);
- mem_pool_reset(&self->pool);
+ arena_reset(&self->pool);
- size_t moves = board_available_moves(self->board, NULL);
- self->root = mem_pool_alloc(&self->pool, alignof(struct mcts_node), mcts_node_sizeof(moves));
+ u32 moves = board_available_moves(self->board, NULL);
+ self->root = ALLOC_SIZED(&self->pool, struct mcts_node);
mcts_node_init(self->root, NULL, player, x, y, moves);
// TODO: implement tree reuse, if it improves play
@@ -207,10 +211,10 @@ agent_mcts_swap(struct agent_mcts *self)
struct mcts_node old_root = *self->root;
- mem_pool_reset(&self->pool);
+ arena_reset(&self->pool);
- size_t moves = board_available_moves(self->board, NULL);
- self->root = mem_pool_alloc(&self->pool, alignof(struct mcts_node), mcts_node_sizeof(moves));
+ u32 moves = board_available_moves(self->board, NULL);
+ self->root = ALLOC_SIZED(&self->pool, struct mcts_node);
mcts_node_init(self->root, NULL, hexopponent(old_root.player), old_root.x, old_root.y, moves);
}
@@ -218,7 +222,8 @@ static bool
mcts_search(struct agent_mcts *self, struct timespec timeout);
bool
-agent_mcts_next(struct agent_mcts *self, struct timespec timeout, u32 *out_x, u32 *out_y)
+agent_mcts_next(struct agent_mcts *self, struct timespec timeout,
+ u32 *restrict out_x, u32 *restrict out_y)
{
assert(self);
assert(out_x);
@@ -227,13 +232,14 @@ agent_mcts_next(struct agent_mcts *self, struct timespec timeout, u32 *out_x, u3
if (!mcts_search(self, timeout)) return false;
struct mcts_node *root = self->pool.ptr;
- assert(root->children_len);
+ assert(root->children.head != RELPTR_NULL);
u32 max_plays = 0;
struct mcts_node *best_child = NULL;
- for (size_t i = 0; i < root->children_cap; i++) {
- struct mcts_node *child = mcts_node_rel2abs(root, root->children[i]);
- if (!child) continue;
+ struct list *list = &root->children;
+ list_iter(list) {
+ struct mcts_node *child = FROM_NODE(it, struct mcts_node, list_node);
+ assert(child);
if (child->plays > max_plays) {
max_plays = child->plays;
@@ -330,9 +336,10 @@ mcts_round(struct agent_mcts *self, struct move *moves)
do {
s32 reward = winner == node->player ? +1 : -1;
- for (size_t i = 0; i < node->children_len; i++) {
- struct mcts_node *child = mcts_node_rel2abs(node, node->children[i]);
- if (!child) continue;
+ struct list *list = &node->children;
+ list_iter(list) {
+ struct mcts_node *child = FROM_NODE(it, struct mcts_node, list_node);
+ assert(child);
struct segment *segment = &self->shadow_board.segments[child->y * self->shadow_board.size + child->x];
if ((enum cell) child->player == segment->occupant) {
diff --git a/agents/hexes/src/agent/random.c b/agents/hexes/src/agent/random.c
@@ -58,7 +58,8 @@ agent_random_swap(struct agent_random *self)
}
bool
-agent_random_next(struct agent_random *self, struct timespec timeout, u32 *out_x, u32 *out_y)
+agent_random_next(struct agent_random *self, struct timespec timeout,
+ u32 *restrict out_x, u32 *restrict out_y)
{
assert(self);
assert(out_x);
diff --git a/agents/hexes/src/board.c b/agents/hexes/src/board.c
@@ -3,7 +3,7 @@
#define NEIGHBOUR_COUNT 6
extern inline segment_relptr_t
-segment_abs2rel(void const *base, struct segment *absptr);
+segment_abs2rel(void const *restrict base, struct segment *restrict absptr);
extern inline struct segment *
segment_rel2abs(void const *base, segment_relptr_t relptr);
diff --git a/agents/hexes/src/utils.c b/agents/hexes/src/utils.c
@@ -12,17 +12,29 @@ shuffle(void *arr, size_t size, size_t len);
extern inline void
difftimespec(struct timespec *restrict lhs, struct timespec *restrict rhs, struct timespec *restrict out);
-extern inline bool
-mem_pool_init(struct mem_pool *self, size_t align, size_t capacity);
+extern inline void
+arena_reset(struct arena *arena);
+
+extern inline void *
+arena_alloc(struct arena *arena, size_t size, size_t align);
+
+extern inline list_node_relptr_t
+list_node_abs2rel(void const *restrict base, struct list_node *restrict absptr);
+
+extern inline struct list_node *
+list_node_rel2abs(void const *base, list_node_relptr_t relptr);
extern inline void
-mem_pool_free(struct mem_pool *self);
+list_push_head(struct list *restrict list, struct list_node *restrict node);
extern inline void
-mem_pool_reset(struct mem_pool *self);
+list_push_tail(struct list *restrict list, struct list_node *restrict node);
-extern inline void *
-mem_pool_alloc(struct mem_pool *self, size_t align, size_t size);
+extern inline struct list_node *
+list_pop_head(struct list *list);
+
+extern inline struct list_node *
+list_pop_tail(struct list *list);
extern inline enum hex_player
hexopponent(enum hex_player player);