darkfi_sdk/monotree/
tree.rs

1/* This file is part of DarkFi (https://dark.fi)
2 *
3 * Copyright (C) 2020-2025 Dyne.org foundation
4 * Copyright (C) 2021 MONOLOG (Taeho Francis Lim and Jongwhan Lee) MIT License
5 *
6 * This program is free software: you can redistribute it and/or modify
7 * it under the terms of the GNU Affero General Public License as
8 * published by the Free Software Foundation, either version 3 of the
9 * License, or (at your option) any later version.
10 *
11 * This program is distributed in the hope that it will be useful,
12 * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 * GNU Affero General Public License for more details.
15 *
16 * You should have received a copy of the GNU Affero General Public License
17 * along with this program.  If not, see <https://www.gnu.org/licenses/>.
18 */
19
20use hashbrown::{HashMap, HashSet};
21
22use super::{
23    bits::Bits,
24    node::{Node, Unit},
25    utils::{get_sorted_indices, slice_to_hash},
26    Hash, Proof, HASH_LEN, ROOT_KEY,
27};
28use crate::GenericResult;
29
30#[derive(Clone, Debug)]
31pub(crate) struct MemCache {
32    pub(crate) set: HashSet<Hash>,
33    pub(crate) map: HashMap<Hash, Vec<u8>>,
34}
35
36#[allow(dead_code)]
37impl MemCache {
38    pub(crate) fn new() -> Self {
39        Self { set: HashSet::new(), map: HashMap::with_capacity(1 << 12) }
40    }
41
42    pub(crate) fn clear(&mut self) {
43        self.set.clear();
44        self.map.clear();
45    }
46
47    pub(crate) fn contains(&self, key: &[u8]) -> bool {
48        !self.set.contains(key) && self.map.contains_key(key)
49    }
50
51    pub(crate) fn get(&self, key: &[u8]) -> Option<Vec<u8>> {
52        self.map.get(key).cloned()
53    }
54
55    pub(crate) fn put(&mut self, key: &[u8], value: Vec<u8>) {
56        self.map.insert(slice_to_hash(key), value);
57        self.set.remove(key);
58    }
59
60    pub(crate) fn delete(&mut self, key: &[u8]) {
61        self.set.insert(slice_to_hash(key));
62    }
63}
64
65#[derive(Clone, Debug)]
66pub struct MemoryDb {
67    db: HashMap<Hash, Vec<u8>>,
68    batch: MemCache,
69    batch_on: bool,
70}
71
72#[allow(dead_code)]
73impl MemoryDb {
74    fn new() -> Self {
75        Self { db: HashMap::new(), batch: MemCache::new(), batch_on: false }
76    }
77
78    fn get(&self, key: &[u8]) -> GenericResult<Option<Vec<u8>>> {
79        if self.batch_on && self.batch.contains(key) {
80            return Ok(self.batch.get(key));
81        }
82
83        match self.db.get(key) {
84            Some(v) => Ok(Some(v.to_owned())),
85            None => Ok(None),
86        }
87    }
88
89    fn put(&mut self, key: &[u8], value: Vec<u8>) -> GenericResult<()> {
90        if self.batch_on {
91            self.batch.put(key, value);
92        } else {
93            self.db.insert(slice_to_hash(key), value);
94        }
95        Ok(())
96    }
97
98    fn delete(&mut self, key: &[u8]) -> GenericResult<()> {
99        if self.batch_on {
100            self.batch.delete(key);
101        } else {
102            self.db.remove(key);
103        }
104        Ok(())
105    }
106
107    fn init_batch(&mut self) -> GenericResult<()> {
108        if !self.batch_on {
109            self.batch.clear();
110            self.batch_on = true;
111        }
112        Ok(())
113    }
114
115    fn finish_batch(&mut self) -> GenericResult<()> {
116        if self.batch_on {
117            for (key, value) in self.batch.map.drain() {
118                self.db.insert(key, value);
119            }
120            for key in self.batch.set.drain() {
121                self.db.remove(&key);
122            }
123            self.batch_on = false;
124        }
125        Ok(())
126    }
127}
128
/// Handle to a `monotree` Merkle tree whose nodes are kept in an in-memory
/// [`MemoryDb`].
#[derive(Debug, Clone)]
pub struct Monotree {
    /// Node storage; also holds the head root under `ROOT_KEY`.
    db: MemoryDb,
}
134
135impl Default for Monotree {
136    fn default() -> Self {
137        Self::new()
138    }
139}
140
141impl Monotree {
142    pub fn new() -> Self {
143        Self { db: MemoryDb::new() }
144    }
145
146    fn hash_digest(bytes: &[u8]) -> Hash {
147        let mut hasher = blake3::Hasher::new();
148        hasher.update(bytes);
149        let hash = hasher.finalize();
150        slice_to_hash(hash.as_bytes())
151    }
152
153    /// Retrieves the latest state (root) from the database.
154    pub fn get_headroot(&self) -> GenericResult<Option<Hash>> {
155        let headroot = self.db.get(ROOT_KEY)?;
156        match headroot {
157            Some(root) => Ok(Some(slice_to_hash(&root))),
158            None => Ok(None),
159        }
160    }
161
162    /// Sets the latest state (root) to the database.
163    pub fn set_headroot(&mut self, headroot: Option<&Hash>) {
164        if let Some(root) = headroot {
165            self.db.put(ROOT_KEY, root.to_vec()).expect("set_headroot(): hash");
166        }
167    }
168
169    pub fn prepare(&mut self) {
170        self.db.init_batch().expect("prepare(): failed to initialize batch");
171    }
172
173    pub fn commit(&mut self) {
174        self.db.finish_batch().expect("commit(): failed to initialize batch");
175    }
176
177    /// Insert key-leaf entry into the tree. Returns a new root hash.
178    pub fn insert(
179        &mut self,
180        root: Option<&Hash>,
181        key: &Hash,
182        leaf: &Hash,
183    ) -> GenericResult<Option<Hash>> {
184        match root {
185            None => {
186                let (hash, bits) = (leaf, Bits::new(key));
187                self.put_node(Node::new(Some(Unit { hash, bits }), None))
188            }
189            Some(root) => self.put(root, Bits::new(key), leaf),
190        }
191    }
192
193    fn put_node(&mut self, node: Node) -> GenericResult<Option<Hash>> {
194        let bytes = node.to_bytes()?;
195        let hash = Self::hash_digest(&bytes);
196        self.db.put(&hash, bytes)?;
197        Ok(Some(hash))
198    }
199
200    /// Recursively insert a bytes (in forms of Bits) and a leaf into the tree.
201    ///
202    /// Optimisation in `monotree` is mainly to compress the path as much as possible
203    /// while reducing the number of db accesses using the most intuitive model.
204    /// As a result, compared to the standard Sparse Merkle Tree this reduces the
205    /// number of DB accesses from `N` to `log2(N)` in both reads and writes.
206    ///
207    /// Whenever invoked a `put()` call, at least, more than one `put_node()` called,
208    /// which triggers a single hash digest + a single DB write.
209    /// Compressing the path reduces the number of `put()` calls, which yields reducing
210    /// the number of hash function calls as well as the number of DB writes.
211    ///
212    /// There are four modes when putting the entries and each of them is processed in a
213    /// recursive `put()` call.
214    /// The number in parenthesis refers to the minimum of DB access and hash fn calls required.
215    ///
216    /// * set-aside (1)
217    ///   Putting the leaf to the next node in the current depth.
218    /// * replacement (1)
219    ///   Replaces the existing node on the path with the new leaf.
220    /// * consume & pass-over (2+)
221    ///   Consuming the path on the way, then pass the rest of work to their child node.
222    /// * split-node (2)
223    ///   Immediately split node into two with the longest common prefix,
224    ///   then wind the recursive stack from there returning resulting hashes.
225    fn put(&mut self, root: &[u8], bits: Bits, leaf: &[u8]) -> GenericResult<Option<Hash>> {
226        let bytes = self.db.get(root)?.expect("bytes");
227        let (lc, rc) = Node::cells_from_bytes(&bytes, bits.first())?;
228        let unit = lc.as_ref().expect("put(): left-unit");
229        let n = Bits::len_common_bits(&unit.bits, &bits);
230
231        match n {
232            0 => self.put_node(Node::new(lc, Some(Unit { hash: leaf, bits }))),
233            n if n == bits.len() => self.put_node(Node::new(Some(Unit { hash: leaf, bits }), rc)),
234            n if n == unit.bits.len() => {
235                let hash = &self.put(unit.hash, bits.shift(n, false), leaf)?.expect("put(): hash");
236
237                let unit = unit.to_owned();
238                self.put_node(Node::new(Some(Unit { hash, ..unit }), rc))
239            }
240            _ => {
241                let bits = bits.shift(n, false);
242                let ru = Unit { hash: leaf, bits };
243
244                let (cloned, unit) = (unit.bits.clone(), unit.to_owned());
245                let (hash, bits) = (unit.hash, unit.bits.shift(n, false));
246                let lu = Unit { hash, bits };
247
248                // ENFORCE DETERMINISTIC ORDERING
249                let (left, right) = if lu.bits < ru.bits { (lu, ru) } else { (ru, lu) };
250
251                let hash =
252                    &self.put_node(Node::new(Some(left), Some(right)))?.expect("put(): hash");
253                let bits = cloned.shift(n, true);
254                self.put_node(Node::new(Some(Unit { hash, bits }), rc))
255            }
256        }
257    }
258
259    /// Get a leaf hash for the given root and key.
260    pub fn get(&mut self, root: Option<&Hash>, key: &Hash) -> GenericResult<Option<Hash>> {
261        match root {
262            None => Ok(None),
263            Some(root) => self.find_key(root, Bits::new(key)),
264        }
265    }
266
267    fn find_key(&mut self, root: &[u8], bits: Bits) -> GenericResult<Option<Hash>> {
268        let bytes = self.db.get(root)?.expect("bytes");
269        let (cell, _) = Node::cells_from_bytes(&bytes, bits.first())?;
270        let unit = cell.as_ref().expect("find_key(): left-unit");
271        let n = Bits::len_common_bits(&unit.bits, &bits);
272        match n {
273            n if n == bits.len() => Ok(Some(slice_to_hash(unit.hash))),
274            n if n == unit.bits.len() => self.find_key(unit.hash, bits.shift(n, false)),
275            _ => Ok(None),
276        }
277    }
278
279    /// Remove the given key and its corresponding leaf from the tree. Returns a new root hash.
280    pub fn remove(&mut self, root: Option<&Hash>, key: &[u8]) -> GenericResult<Option<Hash>> {
281        match root {
282            None => Ok(None),
283            Some(root) => self.delete_key(root, Bits::new(key)),
284        }
285    }
286
287    fn delete_key(&mut self, root: &[u8], bits: Bits) -> GenericResult<Option<Hash>> {
288        let bytes = self.db.get(root)?.expect("bytes");
289        let (lc, rc) = Node::cells_from_bytes(&bytes, bits.first())?;
290        let unit = lc.as_ref().expect("delete_key(): left-unit");
291        let n = Bits::len_common_bits(&unit.bits, &bits);
292
293        match n {
294            n if n == bits.len() => match rc {
295                Some(_) => self.put_node(Node::new(None, rc)),
296                None => Ok(None),
297            },
298            n if n == unit.bits.len() => {
299                let hash = self.delete_key(unit.hash, bits.shift(n, false))?;
300                match (hash, &rc) {
301                    (None, None) => Ok(None),
302                    (None, Some(_)) => self.put_node(Node::new(None, rc)),
303                    (Some(ref hash), _) => {
304                        let unit = unit.to_owned();
305                        let lc = Some(Unit { hash, ..unit });
306                        self.put_node(Node::new(lc, rc))
307                    }
308                }
309            }
310            _ => Ok(None),
311        }
312    }
313
314    /// This method is indented to use the `insert()` method in batch mode.
315    /// Note that `inserts()` forces the batch to commit.
316    pub fn inserts(
317        &mut self,
318        root: Option<&Hash>,
319        keys: &[Hash],
320        leaves: &[Hash],
321    ) -> GenericResult<Option<Hash>> {
322        let indices = get_sorted_indices(keys, false);
323        self.prepare();
324
325        let mut root = root.cloned();
326        for i in indices.iter() {
327            root = self.insert(root.as_ref(), &keys[*i], &leaves[*i])?;
328        }
329
330        self.commit();
331        Ok(root)
332    }
333
334    /// This method is intended to use the `get()` method in batch mode.
335    pub fn gets(&mut self, root: Option<&Hash>, keys: &[Hash]) -> GenericResult<Vec<Option<Hash>>> {
336        let mut leaves: Vec<Option<Hash>> = vec![];
337        for key in keys.iter() {
338            leaves.push(self.get(root, key)?);
339        }
340        Ok(leaves)
341    }
342
343    /// This method is intended to use the `remove()` method in batch mode.
344    /// Note that `removes()` forces the batch to commit.
345    pub fn removes(&mut self, root: Option<&Hash>, keys: &[Hash]) -> GenericResult<Option<Hash>> {
346        let indices = get_sorted_indices(keys, false);
347        let mut root = root.cloned();
348        self.prepare();
349
350        for i in indices.iter() {
351            root = self.remove(root.as_ref(), &keys[*i])?;
352        }
353
354        self.commit();
355        Ok(root)
356    }
357
358    /// Generate a Merkle proof for the given root and key.
359    pub fn get_merkle_proof(
360        &mut self,
361        root: Option<&Hash>,
362        key: &[u8],
363    ) -> GenericResult<Option<Proof>> {
364        let mut proof: Proof = vec![];
365        match root {
366            None => Ok(None),
367            Some(root) => self.gen_proof(root, Bits::new(key), &mut proof),
368        }
369    }
370
371    fn gen_proof(
372        &mut self,
373        root: &[u8],
374        bits: Bits,
375        proof: &mut Proof,
376    ) -> GenericResult<Option<Proof>> {
377        let bytes = self.db.get(root)?.expect("bytes");
378        let (cell, _) = Node::cells_from_bytes(&bytes, bits.first())?;
379        let unit = cell.as_ref().expect("gen_proof(): left-unit");
380        let n = Bits::len_common_bits(&unit.bits, &bits);
381
382        match n {
383            n if n == bits.len() => {
384                proof.push(self.encode_proof(&bytes, bits.first())?);
385                Ok(Some(proof.to_owned()))
386            }
387            n if n == unit.bits.len() => {
388                proof.push(self.encode_proof(&bytes, bits.first())?);
389                self.gen_proof(unit.hash, bits.shift(n, false), proof)
390            }
391            _ => Ok(None),
392        }
393    }
394
395    fn encode_proof(&self, bytes: &[u8], right: bool) -> GenericResult<(bool, Vec<u8>)> {
396        match Node::from_bytes(bytes)? {
397            Node::Soft(_) => Ok((false, bytes[HASH_LEN..].to_vec())),
398            Node::Hard(_, _) => {
399                if right {
400                    Ok((true, [&bytes[..bytes.len() - HASH_LEN - 1], &[0x01]].concat()))
401                } else {
402                    Ok((false, bytes[HASH_LEN..].to_vec()))
403                }
404            }
405        }
406    }
407}
408
409/// Verify a MerkleProof with the given root and leaf.
410pub fn verify_proof(root: Option<&Hash>, leaf: &Hash, proof: Option<&Proof>) -> bool {
411    match proof {
412        None => false,
413        Some(proof) => {
414            let mut hash = leaf.to_owned();
415            proof.iter().rev().for_each(|(right, cut)| {
416                if *right {
417                    let l = cut.len();
418                    let o = [&cut[..l - 1], &hash[..], &cut[l - 1..]].concat();
419                    hash = Monotree::hash_digest(&o);
420                } else {
421                    let o = [&hash[..], &cut[..]].concat();
422                    hash = Monotree::hash_digest(&o);
423                }
424            });
425            root.expect("verify_proof(): root") == &hash
426        }
427    }
428}