Skip to content

Commit

Permalink
update adaptive prefix trie
Browse files Browse the repository at this point in the history
  • Loading branch information
bastiscode committed Jan 31, 2024
1 parent dcbf45a commit 5184000
Show file tree
Hide file tree
Showing 7 changed files with 100 additions and 34 deletions.
3 changes: 2 additions & 1 deletion text-utils-prefix/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ edition = "2021"
[dependencies]
rayon = "1.8"
itertools = "0.12"
serde = { version = "1.0", features = ["derive"] }
serde-big-array = "0.5"

[dev-dependencies]
criterion = "0.5"
Expand All @@ -14,7 +16,6 @@ patricia_tree = "0.8.0"
rand = "0.8"
rand_distr = "0.4"
rand_chacha = "0.3"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"

[profile.release]
Expand Down
66 changes: 39 additions & 27 deletions text-utils-prefix/src/adaptive_radix_trie.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,18 @@ use std::{

use crate::{ContinuationSearch, PrefixSearch};

type Index<const N: usize> = [u8; N];
type Children<V, const N: usize> = [Option<Box<Node<V>>>; N];
type Index<const N: usize> = Box<[u8; N]>;
type Children<V, const N: usize> = Box<[Option<Box<Node<V>>>; N]>;

#[derive(Default, Debug)]
enum NodeType<V> {
#[default]
Empty,
Leaf(V),
N4(Index<4>, Children<V, 4>, usize),
N16(Index<16>, Children<V, 16>, usize),
N48(Box<Index<256>>, Children<V, 48>, usize),
N256(Children<V, 256>, usize),
N4(Index<4>, Children<V, 4>, u8),
N16(Index<16>, Children<V, 16>, u8),
N48(Index<256>, Children<V, 48>, u8),
N256(Children<V, 256>, u16),
}

#[derive(Debug)]
Expand Down Expand Up @@ -117,7 +117,11 @@ impl<V> Node<V> {
fn new_inner(prefix: Vec<u8>) -> Self {
Self {
prefix: prefix.into_boxed_slice(),
inner: NodeType::N4(std::array::from_fn(|_| 0), std::array::from_fn(|_| None), 0),
inner: NodeType::N4(
Box::new(std::array::from_fn(|_| 0)),
Box::new(std::array::from_fn(|_| None)),
0,
),
}
}

Expand Down Expand Up @@ -184,17 +188,17 @@ impl<V> Node<V> {
match &self.inner {
NodeType::Empty | NodeType::Leaf(_) => Box::new(empty()),
NodeType::N4(_, children, num_children) => Box::new(
children[..*num_children]
children[..*num_children as usize]
.iter()
.filter_map(|child| child.as_deref()),
),
NodeType::N16(_, children, num_children) => Box::new(
children[..*num_children]
children[..*num_children as usize]
.iter()
.filter_map(|child| child.as_deref()),
),
NodeType::N48(_, children, num_children) => Box::new(
children[..*num_children]
children[..*num_children as usize]
.iter()
.filter_map(|child| child.as_deref()),
),
Expand All @@ -214,8 +218,9 @@ impl<V> Node<V> {
NodeType::Empty | NodeType::Leaf(_) => unreachable!("should not happen"),
NodeType::N4(keys, children, num_children) => {
// also keep sorted order for n4 for easier upgrade
let idx = keys[..*num_children].binary_search(&key).unwrap_err();
if idx < *num_children {
let n = *num_children as usize;
let idx = keys[..n].binary_search(&key).unwrap_err();
if idx < n {
keys[idx..].rotate_right(1);
children[idx..].rotate_right(1);
}
Expand All @@ -224,8 +229,9 @@ impl<V> Node<V> {
*num_children += 1;
}
NodeType::N16(keys, children, num_children) => {
let idx = keys[..*num_children].binary_search(&key).unwrap_err();
if idx < *num_children {
let n = *num_children as usize;
let idx = keys[..n].binary_search(&key).unwrap_err();
if idx < n {
keys[idx..].rotate_right(1);
children[idx..].rotate_right(1);
}
Expand All @@ -234,8 +240,8 @@ impl<V> Node<V> {
*num_children += 1;
}
NodeType::N48(index, children, num_children) => {
index[key as usize] = *num_children as u8;
children[*num_children] = Some(Box::new(child));
index[key as usize] = *num_children;
children[*num_children as usize] = Some(Box::new(child));
*num_children += 1;
}
NodeType::N256(children, num_children) => {
Expand Down Expand Up @@ -291,14 +297,15 @@ impl<V> Node<V> {
NodeType::Empty | NodeType::Leaf(_) => None,
NodeType::N4(keys, children, num_children) => {
for i in 0..*num_children {
let i = i as usize;
if keys[i] == key {
return children[i].as_deref();
}
}
None
}
NodeType::N16(keys, children, num_children) => {
let idx = keys[..*num_children].binary_search(&key).ok()?;
let idx = keys[..*num_children as usize].binary_search(&key).ok()?;
children[idx].as_deref()
}
NodeType::N48(keys, children, _) => {
Expand All @@ -314,14 +321,15 @@ impl<V> Node<V> {
NodeType::Empty | NodeType::Leaf(_) => None,
NodeType::N4(keys, children, num_children) => {
for i in 0..*num_children {
let i = i as usize;
if keys[i] == key {
return children[i].as_deref_mut();
}
}
None
}
NodeType::N16(keys, children, num_children) => {
let idx = keys[..*num_children].binary_search(&key).ok()?;
let idx = keys[..*num_children as usize].binary_search(&key).ok()?;
children[idx].as_deref_mut()
}
NodeType::N48(keys, children, _) => children
Expand All @@ -336,9 +344,8 @@ impl<V> Node<V> {
NodeType::Empty | NodeType::Leaf(_) => {
unreachable!("should not happen")
}
NodeType::N256(_, num_children) => {
NodeType::N256(..) => {
// upgrade should only be called on non empty n256 nodes
assert!(*num_children < 256);
return;
}
NodeType::N4(keys, children, num_children) => {
Expand All @@ -349,15 +356,15 @@ impl<V> Node<V> {
assert_eq!(*num_children, 4);
// just move over because n4 is also sorted
NodeType::N16(
std::array::from_fn(|i| if i < 4 { keys[i] } else { 0 }),
std::array::from_fn(|i| {
Box::new(std::array::from_fn(|i| if i < 4 { keys[i] } else { 0 })),
Box::new(std::array::from_fn(|i| {
if i < 4 {
assert!(children[i].is_some());
std::mem::take(&mut children[i])
} else {
None
}
}),
})),
4,
)
}
Expand All @@ -373,14 +380,14 @@ impl<V> Node<V> {
}
NodeType::N48(
Box::new(index),
std::array::from_fn(|i| {
Box::new(std::array::from_fn(|i| {
if i < 16 {
assert!(children[i].is_some());
std::mem::take(&mut children[i])
} else {
None
}
}),
})),
16,
)
}
Expand All @@ -391,15 +398,15 @@ impl<V> Node<V> {
}
assert_eq!(*num_children, 48);
NodeType::N256(
std::array::from_fn(|i| {
Box::new(std::array::from_fn(|i| {
let idx = index[i];
if idx < 48 {
assert!(children[idx as usize].is_some());
std::mem::take(&mut children[idx as usize])
} else {
None
}
}),
})),
48,
)
}
Expand Down Expand Up @@ -673,12 +680,17 @@ impl<V> ContinuationSearch for AdaptiveRadixTrie<V> {

#[cfg(test)]
mod test {
use crate::adaptive_radix_trie::Node;
use crate::{adaptive_radix_trie::AdaptiveRadixTrie, PrefixSearch};
use std::fs;
use std::path::PathBuf;

#[test]
fn test_trie() {
println!(
"size of adaptive radix trie node: {}",
std::mem::size_of::<Node<i32>>()
);
let mut trie = AdaptiveRadixTrie::default();
assert_eq!(trie.get(b"hello"), None);
assert_eq!(trie.get(b""), None);
Expand Down
15 changes: 15 additions & 0 deletions text-utils-prefix/src/bin/test_art.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
use std::{fs, path::PathBuf};

use text_utils_prefix::adaptive_radix_trie::AdaptiveRadixTrie;

fn main() {
let dir = env!("CARGO_MANIFEST_DIR");
let index = fs::read_to_string(PathBuf::from(dir).join("resources/test/index.txt"))
.expect("failed to read file");
let n = 10_000_000;
let words: Vec<_> = index.lines().map(|s| s.as_bytes()).take(n).collect();

let trie: AdaptiveRadixTrie<_> = words.iter().enumerate().map(|(i, w)| (w, i)).collect();
let stats = trie.stats();
println!("{stats:#?}");
}
15 changes: 15 additions & 0 deletions text-utils-prefix/src/bin/test_patricia.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
use std::{fs, path::PathBuf};

use text_utils_prefix::patricia_trie::PatriciaTrie;

fn main() {
let dir = env!("CARGO_MANIFEST_DIR");
let index = fs::read_to_string(PathBuf::from(dir).join("resources/test/index.txt"))
.expect("failed to read file");
let n = 1_000_000;
let words: Vec<_> = index.lines().map(|s| s.as_bytes()).take(n).collect();

let trie: PatriciaTrie<_> = words.iter().enumerate().map(|(i, w)| (w, i)).collect();
let stats = trie.stats();
println!("{stats:#?}");
}
15 changes: 15 additions & 0 deletions text-utils-prefix/src/bin/test_trie.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
use std::{fs, path::PathBuf};

use text_utils_prefix::trie::Trie;

fn main() {
let dir = env!("CARGO_MANIFEST_DIR");
let index = fs::read_to_string(PathBuf::from(dir).join("resources/test/index.txt"))
.expect("failed to read file");
let n = 100_000;
let words: Vec<_> = index.lines().map(|s| s.as_bytes()).take(n).collect();

let trie: Trie<_> = words.iter().enumerate().map(|(i, w)| (w, i)).collect();
let stats = trie.stats();
println!("{stats:#?}");
}
12 changes: 10 additions & 2 deletions text-utils-prefix/src/patricia_trie.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ enum NodeType<V> {
#[default]
Empty,
Leaf(V),
Inner([Option<Box<Node<V>>>; 256]),
Inner(Box<[Option<Box<Node<V>>>; 256]>),
}

#[derive(Debug)]
Expand Down Expand Up @@ -105,7 +105,7 @@ impl<V> Node<V> {
fn new_inner(prefix: Vec<u8>) -> Self {
Self {
prefix: prefix.into_boxed_slice(),
inner: NodeType::Inner(std::array::from_fn(|_| None)),
inner: NodeType::Inner(Box::new(std::array::from_fn(|_| None))),
}
}

Expand Down Expand Up @@ -494,12 +494,20 @@ impl<V> ContinuationSearch for PatriciaTrie<V> {

#[cfg(test)]
mod test {
use crate::patricia_trie::Node;
use crate::{patricia_trie::PatriciaTrie, PrefixSearch};
use std::fs;
use std::path::PathBuf;

#[test]
fn test_trie() {
println!(
"size of patricia trie node: {}, box array: {}, box slice: {}, vec: {}",
std::mem::size_of::<Node<i32>>(),
std::mem::size_of::<Box<[usize; 256]>>(),
std::mem::size_of::<Box<[usize]>>(),
std::mem::size_of::<Vec<usize>>()
);
let mut trie = PatriciaTrie::default();
assert_eq!(trie.get(b"hello"), None);
assert_eq!(trie.get(b""), None);
Expand Down
8 changes: 4 additions & 4 deletions text-utils-prefix/src/trie.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,16 @@
use std::collections::HashMap;

use crate::PrefixSearch;

#[derive(Debug)]
struct Node<V> {
value: Option<V>,
children: [Option<Box<Node<V>>>; 256],
children: Box<[Option<Box<Node<V>>>; 256]>,
}

impl<V> Default for Node<V> {
fn default() -> Self {
Self {
value: None,
children: std::array::from_fn(|_| None),
children: Box::new(std::array::from_fn(|_| None)),
}
}
}
Expand Down Expand Up @@ -177,12 +175,14 @@ impl<V> PrefixSearch for Trie<V> {

#[cfg(test)]
mod test {
use crate::trie::Node;
use crate::{trie::Trie, PrefixSearch};
use std::fs;
use std::path::PathBuf;

#[test]
fn test_trie() {
println!("size of trie node: {}", std::mem::size_of::<Node<i32>>());
let mut trie = Trie::default();
assert_eq!(trie.get(b"hello"), None);
assert_eq!(trie.get(b""), None);
Expand Down

0 comments on commit 5184000

Please sign in to comment.