commit d7c4a85236076592c53b4b623e36b49475e775d5
parent 80a86b247739696fac607fb85a09f3a0ac50c4b4
Author: MikoĊaj Lenczewski <mblenczewski@gmail.com>
Date:   Sat, 12 Apr 2025 23:16:16 +0000
Add trie implementation.
Diffstat:
| M | list.h |  |  | 2 | +- | 
| A | trie.c |  |  | 65 | +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | 
| A | trie.h |  |  | 271 | +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | 
3 files changed, 337 insertions(+), 1 deletion(-)
diff --git a/list.h b/list.h
@@ -31,7 +31,7 @@ struct list_node {
  * point to a valid instance in our next and prev fields). this makes checking
  * whether a list is empty slightly convoluted unfortunately.
  */
-#define LIST_INIT(list) { .prev = &(list), .next = &(list), }
+#define LIST_INIT(list) ((struct list_node) { .prev = &(list), .next = &(list), })
 #define LIST_EMTPY(list) (LIST_HEAD(list) == (list) && LIST_TAIL(list) == (list))
 
 
diff --git a/trie.c b/trie.c
@@ -0,0 +1,65 @@
+#define HEADER_IMPL
+#include "trie.h"
+
+int
+main(void)
+{
+	char buf[8192];
+	struct arena arena = { .ptr = buf, .cap = sizeof buf, .len = 0, };
+
+	printf("trie_init()\n");
+	struct trie_node *root = trie_init(&arena);
+	assert(root);
+	trie_display(root, 0);
+
+	printf("trie_insert(\"foo\")\n");
+	struct trie_node *foo = trie_insert(&arena, root, "foo", 3);
+	trie_display(root, 0);
+
+	printf("trie_insert(\"for\")\n");
+	struct trie_node *_for = trie_insert(&arena, root, "for", 3);
+	trie_display(root, 0);
+
+	printf("trie_insert(\"foobar\")\n");
+	struct trie_node *foobar = trie_insert(&arena, root, "foobar", 6);
+	trie_display(root, 0);
+
+	printf("trie_insert(\"bar\")\n");
+	struct trie_node *bar = trie_insert(&arena, root, "bar", 3);
+	trie_display(root, 0);
+
+	printf("trie_insert(\"forsen\")\n");
+	struct trie_node *forsen = trie_insert(&arena, root, "forsen", 6);
+	trie_display(root, 0);
+
+	printf("trie_insert(\"forseen\")\n");
+	struct trie_node *forseen = trie_insert(&arena, root, "forseen", 7);
+	trie_display(root, 0);
+
+	printf("trie_insert(\"forsee\")\n");
+	struct trie_node *forsee = trie_insert(&arena, root, "forsee", 6);
+	trie_display(root, 0);
+
+	struct trie_node const *search_result;
+
+	printf("trie_search(\"foo\")\n");
+	search_result = trie_search(root, "foo", 3);
+	assert(search_result == foo);
+
+	printf("trie_search(\"bar\")\n");
+	search_result = trie_search(root, "bar", 3);
+	assert(search_result == bar);
+
+	printf("trie_search(\"baz\")\n");
+	search_result = trie_search(root, "baz", 3);
+	assert(!search_result);
+
+	printf("trie_search(\"forsee\")\n");
+	search_result = trie_search(root, "forsee", 6);
+	assert(search_result == forsee);
+
+	printf("arena stats: %zu/%zu bytes used, %zu nodes allocated\n",
+			arena.len, arena.cap, arena.len / sizeof *root);
+
+	return 0;
+}
diff --git a/trie.h b/trie.h
@@ -0,0 +1,271 @@
+#ifndef TRIE_H
+#define TRIE_H
+
+#include <assert.h>
+#include <stddef.h>
+
+#include "arena.h"
+#include "list.h"
+
+/* a trie is a data structure that holds a trie of string prefixes. we can walk
+ * this trie to search for a particular string. each node optionally contains a
+ * string prefix, and can have zero or more child nodes.
+ */
+struct trie_node {
+	struct trie_node *parent;
+	struct list_node children;
+
+	struct {
+		char const *begin, *end;
+	} prefix;
+
+	struct list_node list_node;
+};
+
+/* tries can have a bounded number of children (i.e. fixed radix tries), or can
+ * be adaptive to the number of characters in the input alphabet (i.e. adaptive
+ * radix tries).
+ */
+#ifndef TRIE_ADAPTIVE_RADIX
+
+/* the maximum number of children each trie node can have is known as its
+ * "radix", and affects the maximum depth of the trie (therefore the worst case
+ * search time). a value of 2 reults in essentially a binary search tree of
+ * prefixes. higher values result in shallower and wider trees.
+ */
+#ifndef TRIE_MAX_RADIX
+# define TRIE_MAX_RADIX 2
+#endif /* TRIE_MAX_RADIX */
+
+#endif /* TRIE_ADAPTIVE_RADIX */
+
+inline struct trie_node *
+trie_init(struct arena *arena)
+{
+	struct trie_node *res = ARENA_ALLOC_SIZED(arena, struct trie_node);
+	if (!res)
+		return NULL;
+
+	res->parent = NULL;
+	res->children = LIST_INIT(res->children);
+	res->prefix.begin = res->prefix.end = NULL;
+	res->list_node.prev = res->list_node.next = NULL;
+
+	return res;
+}
+
+inline char const *
+_trie_split_prefix(struct trie_node const *trie, char const *key, size_t len)
+{
+	assert(trie->prefix.begin && trie->prefix.end);
+
+	char const *cur = trie->prefix.begin, *key_end = key + len;
+	while (cur < trie->prefix.end && key < key_end) {
+		if (*cur != *key)
+			break;
+
+		cur++;
+		key++;
+	}
+
+	return cur;
+}
+
+/* if prefix == NULL:
+ *  is group node, check all children
+ *  if failed to add to any children, add new child
+ * else compare key against prefix:
+ *  if key == prefix:
+ *   return node
+ *  else if key < prefix:
+ *   prepend parent node with prefix: key
+ *   set current node prefix to: (prefix - key)
+ *  else if key > prefix:
+ *   append child node with prefix: (key - prefix)
+ *  else if key != prefix:
+ *   prepend parent node with prefix: NULL
+ *   append child node to parent node with prefix: key
+ */
+inline struct trie_node *
+trie_insert(struct arena *arena, struct trie_node *trie, char const *key, size_t len)
+{
+	assert(key);
+	assert(len);
+
+	struct trie_node *res = NULL;
+
+	if (!trie->prefix.begin && !trie->prefix.end) {
+		struct trie_node *it;
+		LIST_ENTRY_ITER(&trie->children, it, list_node) {
+			if ((res = trie_insert(arena, it, key, len)))
+				return res;
+		}
+
+		if (!(res = trie_init(arena)))
+			return NULL;
+
+		res->parent = trie;
+		res->prefix.begin = key;
+		res->prefix.end = key + len;
+		list_push_tail(&trie->children, &res->list_node);
+
+		return res;
+	}
+
+	char const *split = _trie_split_prefix(trie, key, len);
+	size_t split_len = split - trie->prefix.begin;
+
+#if 0
+	printf("node prefix: \"%.*s\", prefix len: %zu\n",
+			(int) (trie->prefix.end - trie->prefix.begin),
+			trie->prefix.begin,
+			trie->prefix.end - trie->prefix.begin);
+
+	printf("before split: \"%.*s\", after split: \"%.*s\", split len: %zu\n",
+			(int) (split - trie->prefix.begin), trie->prefix.begin,
+			(int) (trie->prefix.end - split), split,
+			split_len);
+
+	printf("key: \"%.*s\", key len: %zu\n", (int) len, key, len);
+#endif
+
+	if (split == trie->prefix.end && split_len == len) { /* key == prefix */
+		return trie;
+	}
+
+	if (split == trie->prefix.end && split_len < len) { /* key > prefix */
+		struct trie_node *it;
+		LIST_ENTRY_ITER(&trie->children, it, list_node) {
+			if ((res = trie_insert(arena, it,
+					       key + split_len,
+					       len - split_len)))
+				return res;
+		}
+
+		if (!(res = trie_init(arena)))
+			return NULL;
+
+		res->parent = trie;
+		res->prefix.begin = key + split_len;
+		res->prefix.end = key + len;
+		list_push_tail(&trie->children, &res->list_node);
+
+		return res;
+	}
+
+	if (split == trie->prefix.begin) { /* key != prefix */
+		return NULL;
+	}
+
+	/* key < prefix */
+
+	list_node_unlink(&trie->list_node);
+
+	struct trie_node *new_parent;
+	if (!(new_parent = trie_init(arena)))
+		return NULL;
+
+	new_parent->parent = trie->parent;
+	new_parent->prefix.begin = key;
+	new_parent->prefix.end = key + split_len;
+	list_push_tail(&trie->parent->children, &new_parent->list_node);
+
+	/* trim old parent */
+	trie->parent = new_parent;
+	trie->prefix.begin = trie->prefix.begin + split_len;
+	trie->list_node.prev = trie->list_node.next = NULL;
+	list_push_tail(&new_parent->children, &trie->list_node);
+
+	if (split_len == len) /* no remainder */
+		return new_parent;
+
+	/* handle remainder of key */
+	if (!(res = trie_init(arena)))
+		return NULL;
+
+	res->parent = new_parent;
+	res->prefix.begin = key + split_len;
+	res->prefix.end = key + len;
+	list_push_tail(&new_parent->children, &res->list_node);
+
+	return res;
+}
+
+inline struct trie_node const *
+trie_search(struct trie_node const *trie, char const *key, size_t len)
+{
+	struct trie_node const *res = NULL;
+
+	if (!trie->prefix.begin && !trie->prefix.end) {
+		struct trie_node *it;
+		LIST_ENTRY_ITER(&trie->children, it, list_node) {
+			if ((res = trie_search(it, key, len)))
+				return res;
+		}
+
+		return NULL;
+	}
+
+	char const *split = _trie_split_prefix(trie, key, len);
+	size_t split_len = split - trie->prefix.begin;
+
+	if (split == trie->prefix.end && split_len == len)
+		return trie;
+
+	if (split == trie->prefix.end && split_len < len) {
+		struct trie_node *it;
+		LIST_ENTRY_ITER(&trie->children, it, list_node) {
+			if ((res = trie_search(it, key + split_len, len - split_len)))
+				return res;
+		}
+
+		return NULL;
+	}
+
+	if (split == trie->prefix.begin)
+		return NULL;
+
+	struct trie_node *it;
+	LIST_ENTRY_ITER(&trie->children, it, list_node) {
+		if ((res = trie_search(it, key + split_len, len - split_len)))
+			return res;
+	}
+
+	return NULL;
+}
+
+inline void
+trie_display(struct trie_node const *trie, size_t indent)
+{
+	for (size_t i = 0; i < indent; i++)
+		printf("  ");
+
+	printf("{ %p, prefix: \"%.*s\", }\n",
+			trie, (int) (trie->prefix.end - trie->prefix.begin), trie->prefix.begin);
+
+	struct trie_node *it;
+	LIST_ENTRY_ITER(&trie->children, it, list_node) {
+		trie_display(it, indent + 1);
+	}
+}
+
+#endif /* TRIE_H */
+
+#ifdef HEADER_IMPL
+
+extern inline struct trie_node *
+trie_init(struct arena *arena);
+
+extern inline char const *
+_trie_split_prefix(struct trie_node const *trie, char const *key, size_t len);
+
+extern inline struct trie_node *
+trie_insert(struct arena *arena, struct trie_node *trie, char const *key, size_t len);
+
+extern inline struct trie_node const *
+trie_search(struct trie_node const *trie, char const *key, size_t len);
+
+extern inline void
+trie_display(struct trie_node const *trie, size_t indent);
+
+#endif /* HEADER_IMPL */