/* * Simple trie implementation for key-value mapping storage * * Copyright (c) 2020-2021 Ákos Uzonyi <uzonyi.akos@gmail.com> * All rights reserved. * * SPDX-License-Identifier: LGPL-2.1-or-later */ #ifdef HAVE_CONFIG_H # include "config.h" #endif #include <stdlib.h> #include <stdio.h> #include "trie.h" #include "macros.h" #include "xmalloc.h" static const uint8_t ptr_sz_lg = (sizeof(void *) == 8 ? 6 : 5); /** * Returns lg2 of node size in bits for the specific level of the trie. */ static uint8_t trie_get_node_size(struct trie *t, uint8_t depth) { /* Last level contains data and we allow it having a different size */ if (depth == t->max_depth) return t->data_block_key_bits + t->item_size_lg; /* Last level of the tree can be smaller */ if (depth == t->max_depth - 1) return (t->key_size - t->data_block_key_bits - 1) % t->node_key_bits + 1 + ptr_sz_lg; return t->node_key_bits + ptr_sz_lg; } /** * Provides starting offset of bits in key corresponding to the node index * at the specific level. */ static uint8_t trie_get_node_bit_offs(struct trie *t, uint8_t depth) { uint8_t offs; if (depth == t->max_depth) return 0; offs = t->data_block_key_bits; if (depth == t->max_depth - 1) return offs; /* data_block_size + remainder */ offs += trie_get_node_size(t, t->max_depth - 1) - ptr_sz_lg; offs += (t->max_depth - depth - 2) * t->node_key_bits; return offs; } struct trie * trie_create(uint8_t key_size, uint8_t item_size_lg, uint8_t node_key_bits, uint8_t data_block_key_bits, uint64_t empty_value) { if (item_size_lg > 6) return NULL; if (key_size > 64) return NULL; if (node_key_bits < 1) return NULL; if (data_block_key_bits < 1 || data_block_key_bits > key_size) return NULL; struct trie *t = malloc(sizeof(*t)); if (!t) return NULL; t->fill_value = t->empty_value = empty_value & MASK64_SAFE(BIT32(item_size_lg)); for (size_t i = 0; i < 6U - item_size_lg; i++) t->fill_value |= t->fill_value << BIT32(item_size_lg + i); t->data = NULL; t->item_size_lg = item_size_lg; t->node_key_bits = node_key_bits; t->data_block_key_bits = data_block_key_bits; t->key_size = key_size; t->max_depth = (key_size - data_block_key_bits + node_key_bits - 1) / t->node_key_bits; return t; } static void * trie_create_data_block(struct trie *t) { uint8_t sz = t->data_block_key_bits + t->item_size_lg; if (sz < 6) sz = 6; size_t count = BIT32(sz - 6); uint64_t *data_block = xcalloc(count, 8); for (size_t i = 0; i < count; i++) data_block[i] = t->fill_value; return data_block; } static uint64_t * trie_get_node(struct trie *t, uint64_t key, bool auto_create) { void **cur_node = &(t->data); if (t->key_size < 64 && key > MASK64(t->key_size)) return NULL; for (uint8_t cur_depth = 0; cur_depth <= t->max_depth; cur_depth++) { uint8_t offs = trie_get_node_bit_offs(t, cur_depth); uint8_t sz = trie_get_node_size(t, cur_depth); if (!*cur_node) { if (!auto_create) return NULL; if (cur_depth == t->max_depth) *cur_node = trie_create_data_block(t); else *cur_node = xcalloc(BIT64(sz), 1); } if (cur_depth == t->max_depth) break; size_t pos = (key >> offs) & MASK64(sz - ptr_sz_lg); cur_node = (((void **) (*cur_node)) + pos); } return (uint64_t *) (*cur_node); } static void trie_data_block_calc_pos(struct trie *t, uint64_t key, uint64_t *pos, uint64_t *mask, uint64_t *offs) { uint64_t key_mask; key_mask = MASK64(t->data_block_key_bits); *pos = (key & key_mask) >> (6 - t->item_size_lg); if (t->item_size_lg == 6) { *offs = 0; *mask = -1; return; } key_mask = MASK64(6 - t->item_size_lg); *offs = (key & key_mask) << t->item_size_lg; *mask = MASK64_SAFE(BIT32(t->item_size_lg)) << *offs; } bool trie_set(struct trie *t, uint64_t key, uint64_t val) { uint64_t *data = trie_get_node(t, key, true); if (!data) return false; uint64_t pos, mask, offs; trie_data_block_calc_pos(t, key, &pos, &mask, &offs); data[pos] &= ~mask; data[pos] |= (val << offs) & mask; return true; } static uint64_t trie_data_block_get(struct trie *t, uint64_t *data, uint64_t key) { if (!data) return t->empty_value; uint64_t pos, mask, offs; trie_data_block_calc_pos(t, key, &pos, &mask, &offs); return (data[pos] & mask) >> offs; } uint64_t trie_get(struct trie *b, uint64_t key) { return trie_data_block_get(b, trie_get_node(b, key, false), key); } static uint64_t trie_iterate_keys_node(struct trie *t, trie_iterate_fn fn, void *fn_data, void *node, uint64_t start, uint64_t end, uint8_t depth) { if (start > end || !node) return 0; if (t->key_size < 64) { uint64_t key_max = MASK64(t->key_size); if (end > key_max) end = key_max; } if (depth == t->max_depth) { for (uint64_t i = start; i <= end; i++) fn(fn_data, i, trie_data_block_get(t, (uint64_t *) node, i)); return end - start + 1; } uint8_t parent_node_bit_off = depth == 0 ? t->key_size : trie_get_node_bit_offs(t, depth - 1); uint64_t first_key_in_node = start & ~MASK64_SAFE(parent_node_bit_off); uint8_t node_bit_off = trie_get_node_bit_offs(t, depth); uint8_t node_key_bits = parent_node_bit_off - node_bit_off; uint64_t mask = MASK64_SAFE(node_key_bits); uint64_t start_index = (start >> node_bit_off) & mask; uint64_t end_index = (end >> node_bit_off) & mask; uint64_t child_key_count = BIT64(node_bit_off); uint64_t count = 0; for (uint64_t i = start_index; i <= end_index; i++) { uint64_t child_start = first_key_in_node + i * child_key_count; uint64_t child_end = first_key_in_node + (i + 1) * child_key_count - 1; if (child_start < start) child_start = start; if (child_end > end) child_end = end; count += trie_iterate_keys_node(t, fn, fn_data, ((void **) node)[i], child_start, child_end, depth + 1); } return count; } uint64_t trie_iterate_keys(struct trie *t, uint64_t start, uint64_t end, trie_iterate_fn fn, void *fn_data) { return trie_iterate_keys_node(t, fn, fn_data, t->data, start, end, 0); } static void trie_free_node(struct trie *t, void *node, uint8_t depth) { if (!node) return; if (depth >= t->max_depth) goto free_node; size_t sz = BIT64(trie_get_node_size(t, depth) - ptr_sz_lg); for (size_t i = 0; i < sz; i++) trie_free_node(t, ((void **) node)[i], depth + 1); free_node: free(node); } void trie_free(struct trie *t) { trie_free_node(t, t->data, 0); free(t); }