darkfi_sdk/monotree/
tree.rs

1/* This file is part of DarkFi (https://dark.fi)
2 *
3 * Copyright (C) 2020-2025 Dyne.org foundation
4 * Copyright (C) 2021 MONOLOG (Taeho Francis Lim and Jongwhan Lee) MIT License
5 *
6 * This program is free software: you can redistribute it and/or modify
7 * it under the terms of the GNU Affero General Public License as
8 * published by the Free Software Foundation, either version 3 of the
9 * License, or (at your option) any later version.
10 *
11 * This program is distributed in the hope that it will be useful,
12 * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 * GNU Affero General Public License for more details.
15 *
16 * You should have received a copy of the GNU Affero General Public License
17 * along with this program.  If not, see <https://www.gnu.org/licenses/>.
18 */
19
20use hashbrown::{HashMap, HashSet};
21
22use super::{
23    bits::Bits,
24    node::{Node, Unit},
25    utils::{get_sorted_indices, slice_to_hash},
26    Hash, Proof, HASH_LEN, ROOT_KEY,
27};
28use crate::GenericResult;
29
/// In-memory staging area used while a write batch is open.
///
/// Writes are collected in `map` and deletions are recorded as keys in `set`;
/// both collections are drained into the committed store by
/// `MemoryDb::finish_batch()`.
#[derive(Clone, Debug)]
pub(crate) struct MemCache {
    /// Keys marked for deletion by `delete()`.
    pub(crate) set: HashSet<Hash>,
    /// Pending key/value writes recorded by `put()`.
    pub(crate) map: HashMap<Hash, Vec<u8>>,
}
35
36#[allow(dead_code)]
37impl MemCache {
38    pub(crate) fn new() -> Self {
39        Self { set: HashSet::new(), map: HashMap::with_capacity(1 << 12) }
40    }
41
42    pub(crate) fn clear(&mut self) {
43        self.set.clear();
44        self.map.clear();
45    }
46
47    pub(crate) fn contains(&self, key: &[u8]) -> bool {
48        !self.set.contains(key) && self.map.contains_key(key)
49    }
50
51    pub(crate) fn get(&self, key: &[u8]) -> Option<Vec<u8>> {
52        self.map.get(key).cloned()
53    }
54
55    pub(crate) fn put(&mut self, key: &[u8], value: Vec<u8>) {
56        self.map.insert(slice_to_hash(key), value);
57        self.set.remove(key);
58    }
59
60    pub(crate) fn delete(&mut self, key: &[u8]) {
61        self.set.insert(slice_to_hash(key));
62    }
63}
64
/// In-memory key/value store with an optional write batch.
///
/// When `batch_on` is set, writes and deletes go to `batch` and are only
/// applied to `db` by `finish_batch()`.
#[derive(Clone, Debug)]
pub struct MemoryDb {
    /// Committed key/value pairs.
    db: HashMap<Hash, Vec<u8>>,
    /// Staged writes and deletions of the currently open batch.
    batch: MemCache,
    /// Whether a batch is currently open.
    batch_on: bool,
}
71
72#[allow(dead_code)]
73impl MemoryDb {
74    fn new() -> Self {
75        Self { db: HashMap::new(), batch: MemCache::new(), batch_on: false }
76    }
77
78    fn get(&self, key: &[u8]) -> GenericResult<Option<Vec<u8>>> {
79        if self.batch_on && self.batch.contains(key) {
80            return Ok(self.batch.get(key));
81        }
82
83        match self.db.get(key) {
84            Some(v) => Ok(Some(v.to_owned())),
85            None => Ok(None),
86        }
87    }
88
89    fn put(&mut self, key: &[u8], value: Vec<u8>) -> GenericResult<()> {
90        if self.batch_on {
91            self.batch.put(key, value);
92        } else {
93            self.db.insert(slice_to_hash(key), value);
94        }
95        Ok(())
96    }
97
98    fn delete(&mut self, key: &[u8]) -> GenericResult<()> {
99        if self.batch_on {
100            self.batch.delete(key);
101        } else {
102            self.db.remove(key);
103        }
104        Ok(())
105    }
106
107    fn init_batch(&mut self) -> GenericResult<()> {
108        if !self.batch_on {
109            self.batch.clear();
110            self.batch_on = true;
111        }
112        Ok(())
113    }
114
115    fn finish_batch(&mut self) -> GenericResult<()> {
116        if self.batch_on {
117            for (key, value) in self.batch.map.drain() {
118                self.db.insert(key, value);
119            }
120            for key in self.batch.set.drain() {
121                self.db.remove(&key);
122            }
123            self.batch_on = false;
124        }
125        Ok(())
126    }
127}
128
/// A structure for `monotree`
///
/// Wraps the in-memory store that the tree operations read from and write to.
#[derive(Clone, Debug)]
pub struct Monotree {
    /// Backing store: serialized nodes keyed by their hash, plus the head root
    /// stored under `ROOT_KEY`.
    db: MemoryDb,
}
134
135impl Default for Monotree {
136    fn default() -> Self {
137        Self::new()
138    }
139}
140
141impl Monotree {
142    pub fn new() -> Self {
143        Self { db: MemoryDb::new() }
144    }
145
146    fn hash_digest(bytes: &[u8]) -> Hash {
147        let mut hasher = blake3::Hasher::new();
148        hasher.update(bytes);
149        let hash = hasher.finalize();
150        slice_to_hash(hash.as_bytes())
151    }
152
153    /// Retrieves the latest state (root) from the database.
154    pub fn get_headroot(&self) -> GenericResult<Option<Hash>> {
155        let headroot = self.db.get(ROOT_KEY)?;
156        match headroot {
157            Some(root) => Ok(Some(slice_to_hash(&root))),
158            None => Ok(None),
159        }
160    }
161
162    /// Sets the latest state (root) to the database.
163    pub fn set_headroot(&mut self, headroot: Option<&Hash>) {
164        if let Some(root) = headroot {
165            self.db.put(ROOT_KEY, root.to_vec()).expect("set_headroot(): hash");
166        }
167    }
168
169    pub fn prepare(&mut self) {
170        self.db.init_batch().expect("prepare(): failed to initialize batch");
171    }
172
173    pub fn commit(&mut self) {
174        self.db.finish_batch().expect("commit(): failed to initialize batch");
175    }
176
177    /// Insert key-leaf entry into the tree. Returns a new root hash.
178    pub fn insert(
179        &mut self,
180        root: Option<&Hash>,
181        key: &Hash,
182        leaf: &Hash,
183    ) -> GenericResult<Option<Hash>> {
184        match root {
185            None => {
186                let (hash, bits) = (leaf, Bits::new(key));
187                self.put_node(Node::new(Some(Unit { hash, bits }), None))
188            }
189            Some(root) => self.put(root, Bits::new(key), leaf),
190        }
191    }
192
193    fn put_node(&mut self, node: Node) -> GenericResult<Option<Hash>> {
194        let bytes = node.to_bytes()?;
195        let hash = Self::hash_digest(&bytes);
196        self.db.put(&hash, bytes)?;
197        Ok(Some(hash))
198    }
199
    /// Recursively inserts a key (in the form of `Bits`) and a leaf into the tree
    /// rooted at `root`, returning the hash of the rewritten node.
    ///
    /// Optimisation in `monotree` is mainly to compress the path as much as possible
    /// while reducing the number of db accesses using the most intuitive model.
    /// As a result, compared to the standard Sparse Merkle Tree this reduces the
    /// number of DB accesses from `N` to `log2(N)` in both reads and writes.
    ///
    /// Every `put()` call makes at least one `put_node()` call, each of which
    /// triggers a single hash digest + a single DB write.
    /// Compressing the path reduces the number of `put()` calls, which in turn reduces
    /// the number of hash function calls as well as the number of DB writes.
    ///
    /// There are four modes when putting the entries and each of them is processed in a
    /// recursive `put()` call.
    /// The number in parentheses refers to the minimum of DB accesses and hash fn calls required.
    ///
    /// * set-aside (1)
    ///   Putting the leaf to the next node in the current depth.
    /// * replacement (1)
    ///   Replaces the existing node on the path with the new leaf.
    /// * consume & pass-over (2+)
    ///   Consuming the path on the way, then passing the rest of the work to the child node.
    /// * split-node (2)
    ///   Immediately split the node into two at the longest common prefix,
    ///   then unwind the recursive stack from there returning the resulting hashes.
    ///
    /// # Panics
    /// Panics if no bytes are stored under `root`, if the decoded node has no
    /// left unit, or if a nested `put()`/`put_node()` returns `None`.
    fn put(&mut self, root: &[u8], bits: Bits, leaf: &[u8]) -> GenericResult<Option<Hash>> {
        let bytes = self.db.get(root)?.expect("put(): bytes");
        // NOTE(review): `cells_from_bytes` presumably orders the cells so that
        // `left` is the one on this key's side — confirm against `node.rs`.
        let (left, right) = Node::cells_from_bytes(&bytes, bits.first())?;
        let unit = left.as_ref().expect("put(): left-unit");
        // Number of leading bits shared by the stored path and the key.
        let n = Bits::len_common_bits(&unit.bits, &bits);

        match n {
            // set-aside: nothing in common, the new leaf becomes the other cell.
            0 => self.put_node(Node::new(left, Some(Unit { hash: leaf, bits }))),
            // replacement: the whole remaining key matches, overwrite the leaf.
            n if n == bits.len() => {
                self.put_node(Node::new(Some(Unit { hash: leaf, bits }), right))
            }
            // consume & pass-over: the stored path is a prefix of the key, so
            // recurse into the child with the rest of the key.
            n if n == unit.bits.len() => {
                let hash =
                    &self.put(unit.hash, bits.drop(n), leaf)?.expect("put(): consume & pass-over");

                self.put_node(Node::new(Some(Unit { hash, bits: unit.bits.to_owned() }), right))
            }
            // split-node: partial overlap; store a new child holding both
            // suffixes, then point the shared prefix at it.
            _ => {
                let hash = &self
                    .put_node(Node::new(
                        Some(Unit { hash: unit.hash, bits: unit.bits.drop(n) }),
                        Some(Unit { hash: leaf, bits: bits.drop(n) }),
                    ))?
                    .expect("put(): split-node");

                self.put_node(Node::new(Some(Unit { hash, bits: unit.bits.take(n) }), right))
            }
        }
    }
254
255    /// Get a leaf hash for the given root and key.
256    pub fn get(&mut self, root: Option<&Hash>, key: &Hash) -> GenericResult<Option<Hash>> {
257        match root {
258            None => Ok(None),
259            Some(root) => self.find_key(root, Bits::new(key)),
260        }
261    }
262
263    fn find_key(&mut self, root: &[u8], bits: Bits) -> GenericResult<Option<Hash>> {
264        let bytes = self.db.get(root)?.expect("find_key(): bytes");
265        let (cell, _) = Node::cells_from_bytes(&bytes, bits.first())?;
266        let unit = cell.as_ref().expect("find_key(): left-unit");
267        let n = Bits::len_common_bits(&unit.bits, &bits);
268        match n {
269            n if n == bits.len() => Ok(Some(slice_to_hash(unit.hash))),
270            n if n == unit.bits.len() => self.find_key(unit.hash, bits.drop(n)),
271            _ => Ok(None),
272        }
273    }
274
275    /// Remove the given key and its corresponding leaf from the tree. Returns a new root hash.
276    pub fn remove(&mut self, root: Option<&Hash>, key: &[u8]) -> GenericResult<Option<Hash>> {
277        match root {
278            None => Ok(None),
279            Some(root) => self.delete_key(root, Bits::new(key)),
280        }
281    }
282
    /// Recursively removes the entry addressed by `bits` from the subtree at
    /// `root` and returns the hash of the rewritten node, or `None` when
    /// nothing remains at this level.
    ///
    /// NOTE(review): a key that is not on this path also yields `None` (last
    /// match arm), the same value returned when the subtree became empty, so
    /// callers cannot tell the two cases apart — confirm this is intended.
    ///
    /// # Panics
    /// Panics if no bytes are stored under `root` or the decoded node has no
    /// left unit.
    fn delete_key(&mut self, root: &[u8], bits: Bits) -> GenericResult<Option<Hash>> {
        let bytes = self.db.get(root)?.expect("delete_key(): bytes");
        let (left, right) = Node::cells_from_bytes(&bytes, bits.first())?;
        let unit = left.as_ref().expect("delete_key(): left-unit");
        // Number of leading bits shared by the stored path and the key.
        let n = Bits::len_common_bits(&unit.bits, &bits);

        match n {
            // The whole remaining key matches: drop this unit and keep only
            // the other cell, if there is one.
            n if n == bits.len() => match right {
                Some(_) => self.put_node(Node::new(None, right)),
                None => Ok(None),
            },
            // The stored path is a prefix of the key: delete in the child and
            // rebuild this node from whatever is left.
            n if n == unit.bits.len() => {
                let hash = self.delete_key(unit.hash, bits.drop(n))?;
                match (hash, &right) {
                    // Child vanished and there is no other cell: node vanishes too.
                    (None, None) => Ok(None),
                    // Child vanished: only the other cell survives.
                    (None, Some(_)) => self.put_node(Node::new(None, right)),
                    // Child was rewritten: point this unit at its new hash.
                    (Some(ref hash), _) => {
                        let unit = unit.to_owned();
                        let left = Some(Unit { hash, ..unit });
                        self.put_node(Node::new(left, right))
                    }
                }
            }
            _ => Ok(None),
        }
    }
309
310    /// This method is indented to use the `insert()` method in batch mode.
311    /// Note that `inserts()` forces the batch to commit.
312    pub fn inserts(
313        &mut self,
314        root: Option<&Hash>,
315        keys: &[Hash],
316        leaves: &[Hash],
317    ) -> GenericResult<Option<Hash>> {
318        let indices = get_sorted_indices(keys, false);
319        self.prepare();
320
321        let mut root = root.cloned();
322        for i in indices.iter() {
323            root = self.insert(root.as_ref(), &keys[*i], &leaves[*i])?;
324        }
325
326        self.commit();
327        Ok(root)
328    }
329
330    /// This method is intended to use the `get()` method in batch mode.
331    pub fn gets(&mut self, root: Option<&Hash>, keys: &[Hash]) -> GenericResult<Vec<Option<Hash>>> {
332        let mut leaves: Vec<Option<Hash>> = vec![];
333        for key in keys.iter() {
334            leaves.push(self.get(root, key)?);
335        }
336        Ok(leaves)
337    }
338
339    /// This method is intended to use the `remove()` method in batch mode.
340    /// Note that `removes()` forces the batch to commit.
341    pub fn removes(&mut self, root: Option<&Hash>, keys: &[Hash]) -> GenericResult<Option<Hash>> {
342        let indices = get_sorted_indices(keys, false);
343        let mut root = root.cloned();
344        self.prepare();
345
346        for i in indices.iter() {
347            root = self.remove(root.as_ref(), &keys[*i])?;
348        }
349
350        self.commit();
351        Ok(root)
352    }
353
354    /// Generate a Merkle proof for the given root and key.
355    pub fn get_merkle_proof(
356        &mut self,
357        root: Option<&Hash>,
358        key: &[u8],
359    ) -> GenericResult<Option<Proof>> {
360        let mut proof: Proof = vec![];
361        match root {
362            None => Ok(None),
363            Some(root) => self.gen_proof(root, Bits::new(key), &mut proof),
364        }
365    }
366
367    fn gen_proof(
368        &mut self,
369        root: &[u8],
370        bits: Bits,
371        proof: &mut Proof,
372    ) -> GenericResult<Option<Proof>> {
373        let bytes = self.db.get(root)?.expect("gen_proof(): bytes");
374        let (cell, _) = Node::cells_from_bytes(&bytes, bits.first())?;
375        let unit = cell.as_ref().expect("gen_proof(): left-unit");
376        let n = Bits::len_common_bits(&unit.bits, &bits);
377
378        match n {
379            n if n == bits.len() => {
380                proof.push(self.encode_proof(&bytes, bits.first())?);
381                Ok(Some(proof.to_owned()))
382            }
383            n if n == unit.bits.len() => {
384                proof.push(self.encode_proof(&bytes, bits.first())?);
385                self.gen_proof(unit.hash, bits.drop(n), proof)
386            }
387            _ => Ok(None),
388        }
389    }
390
391    fn encode_proof(&self, bytes: &[u8], right: bool) -> GenericResult<(bool, Vec<u8>)> {
392        match Node::from_bytes(bytes)? {
393            Node::Soft(_) => Ok((false, bytes[HASH_LEN..].to_vec())),
394            Node::Hard(_, _) => {
395                if right {
396                    Ok((true, [&bytes[..bytes.len() - HASH_LEN - 1], &[0x01]].concat()))
397                } else {
398                    Ok((false, bytes[HASH_LEN..].to_vec()))
399                }
400            }
401        }
402    }
403}
404
405/// Verify a MerkleProof with the given root and leaf.
406pub fn verify_proof(root: Option<&Hash>, leaf: &Hash, proof: Option<&Proof>) -> bool {
407    match proof {
408        None => false,
409        Some(proof) => {
410            let mut hash = leaf.to_owned();
411            proof.iter().rev().for_each(|(right, cut)| {
412                if *right {
413                    let l = cut.len();
414                    let o = [&cut[..l - 1], &hash[..], &cut[l - 1..]].concat();
415                    hash = Monotree::hash_digest(&o);
416                } else {
417                    let o = [&hash[..], &cut[..]].concat();
418                    hash = Monotree::hash_digest(&o);
419                }
420            });
421            root.expect("verify_proof(): root") == &hash
422        }
423    }
424}