darkfi_sdk/monotree/
tree.rs

1/* This file is part of DarkFi (https://dark.fi)
2 *
3 * Copyright (C) 2020-2025 Dyne.org foundation
4 * Copyright (C) 2021 MONOLOG (Taeho Francis Lim and Jongwhan Lee) MIT License
5 *
6 * This program is free software: you can redistribute it and/or modify
7 * it under the terms of the GNU Affero General Public License as
8 * published by the Free Software Foundation, either version 3 of the
9 * License, or (at your option) any later version.
10 *
11 * This program is distributed in the hope that it will be useful,
12 * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 * GNU Affero General Public License for more details.
15 *
16 * You should have received a copy of the GNU Affero General Public License
17 * along with this program.  If not, see <https://www.gnu.org/licenses/>.
18 */
19
20use hashbrown::{HashMap, HashSet};
21use sled_overlay::{sled::Tree, SledDbOverlay};
22
23use super::{
24    bits::Bits,
25    node::{Node, Unit},
26    utils::{get_sorted_indices, slice_to_hash},
27    Hash, Proof, HASH_LEN, ROOT_KEY,
28};
29use crate::{ContractError, GenericResult};
30
/// In-memory staging area used by the storage adapters while a batch is open.
#[derive(Clone, Debug)]
pub(crate) struct MemCache {
    /// Keys marked for deletion by `del()`.
    pub(crate) set: HashSet<Hash>,
    /// Key/value pairs staged by `put()`.
    pub(crate) map: HashMap<Hash, Vec<u8>>,
}
36
37#[allow(dead_code)]
38impl MemCache {
39    pub(crate) fn new() -> Self {
40        Self { set: HashSet::new(), map: HashMap::with_capacity(1 << 12) }
41    }
42
43    pub(crate) fn clear(&mut self) {
44        self.set.clear();
45        self.map.clear();
46    }
47
48    pub(crate) fn contains(&self, key: &[u8]) -> bool {
49        !self.set.contains(key) && self.map.contains_key(key)
50    }
51
52    pub(crate) fn get(&self, key: &[u8]) -> Option<Vec<u8>> {
53        self.map.get(key).cloned()
54    }
55
56    pub(crate) fn put(&mut self, key: &[u8], value: Vec<u8>) {
57        self.map.insert(slice_to_hash(key), value);
58        self.set.remove(key);
59    }
60
61    pub(crate) fn del(&mut self, key: &[u8]) {
62        self.set.insert(slice_to_hash(key));
63    }
64}
65
/// Trait for implementing Monotree's storage system
///
/// `Monotree` reads and writes its encoded nodes and the head root
/// exclusively through these methods.
pub trait MonotreeStorageAdapter {
    /// Store `value` under `key`.
    fn put(&mut self, key: &Hash, value: Vec<u8>) -> GenericResult<()>;
    /// Look up `key`. The implementations in this file return `Ok(None)`
    /// when the key is not present.
    fn get(&self, key: &[u8]) -> GenericResult<Option<Vec<u8>>>;
    /// Delete the entry stored under `key`.
    fn del(&mut self, key: &Hash) -> GenericResult<()>;
    /// Initialize a batch. The implementations in this file stage
    /// subsequent `put`/`del` calls in memory until `finish_batch`.
    fn init_batch(&mut self) -> GenericResult<()>;
    /// Finalize and write a batch
    fn finish_batch(&mut self) -> GenericResult<()>;
}
79
/// In-memory storage for Monotree
#[derive(Clone, Debug)]
pub struct MemoryDb {
    /// Committed key/value pairs.
    db: HashMap<Hash, Vec<u8>>,
    /// Changes staged while a batch is open.
    batch: MemCache,
    /// `true` between `init_batch()` and `finish_batch()`.
    batch_on: bool,
}
87
88#[allow(clippy::new_without_default)]
89impl MemoryDb {
90    pub fn new() -> Self {
91        Self { db: HashMap::new(), batch: MemCache::new(), batch_on: false }
92    }
93}
94
95impl MonotreeStorageAdapter for MemoryDb {
96    fn put(&mut self, key: &Hash, value: Vec<u8>) -> GenericResult<()> {
97        if self.batch_on {
98            self.batch.put(key, value);
99        } else {
100            self.db.insert(slice_to_hash(key), value);
101        }
102
103        Ok(())
104    }
105
106    fn get(&self, key: &[u8]) -> GenericResult<Option<Vec<u8>>> {
107        if self.batch_on && self.batch.contains(key) {
108            return Ok(self.batch.get(key));
109        }
110
111        match self.db.get(key) {
112            Some(v) => Ok(Some(v.to_owned())),
113            None => Ok(None),
114        }
115    }
116
117    fn del(&mut self, key: &Hash) -> GenericResult<()> {
118        if self.batch_on {
119            self.batch.del(key);
120        } else {
121            self.db.remove(key);
122        }
123
124        Ok(())
125    }
126
127    fn init_batch(&mut self) -> GenericResult<()> {
128        if !self.batch_on {
129            self.batch.clear();
130            self.batch_on = true;
131        }
132
133        Ok(())
134    }
135
136    fn finish_batch(&mut self) -> GenericResult<()> {
137        if self.batch_on {
138            for (key, value) in self.batch.map.drain() {
139                self.db.insert(key, value);
140            }
141            for key in self.batch.set.drain() {
142                self.db.remove(&key);
143            }
144            self.batch_on = false;
145        }
146
147        Ok(())
148    }
149}
150
/// sled-tree based storage for Monotree
#[derive(Clone)]
pub struct SledTreeDb {
    /// Handle to the backing sled tree.
    tree: Tree,
    /// Changes staged while a batch is open.
    batch: MemCache,
    /// `true` between `init_batch()` and `finish_batch()`.
    batch_on: bool,
}
158
159impl SledTreeDb {
160    pub fn new(tree: &Tree) -> Self {
161        Self { tree: tree.clone(), batch: MemCache::new(), batch_on: false }
162    }
163}
164
165impl MonotreeStorageAdapter for SledTreeDb {
166    fn put(&mut self, key: &Hash, value: Vec<u8>) -> GenericResult<()> {
167        if self.batch_on {
168            self.batch.put(key, value);
169        } else if let Err(e) = self.tree.insert(slice_to_hash(key), value) {
170            return Err(ContractError::IoError(e.to_string()))
171        }
172
173        Ok(())
174    }
175
176    fn get(&self, key: &[u8]) -> GenericResult<Option<Vec<u8>>> {
177        if self.batch_on && self.batch.contains(key) {
178            return Ok(self.batch.get(key));
179        }
180
181        match self.tree.get(key) {
182            Ok(Some(v)) => Ok(Some(v.to_vec())),
183            Ok(None) => Ok(None),
184            Err(e) => Err(ContractError::IoError(e.to_string())),
185        }
186    }
187
188    fn del(&mut self, key: &Hash) -> GenericResult<()> {
189        if self.batch_on {
190            self.batch.del(key);
191        } else if let Err(e) = self.tree.remove(key) {
192            return Err(ContractError::IoError(e.to_string()));
193        }
194
195        Ok(())
196    }
197
198    fn init_batch(&mut self) -> GenericResult<()> {
199        if !self.batch_on {
200            self.batch.clear();
201            self.batch_on = true;
202        }
203
204        Ok(())
205    }
206
207    fn finish_batch(&mut self) -> GenericResult<()> {
208        if self.batch_on {
209            for (key, value) in self.batch.map.drain() {
210                if let Err(e) = self.tree.insert(key, value) {
211                    return Err(ContractError::IoError(e.to_string()))
212                }
213            }
214            for key in self.batch.set.drain() {
215                if let Err(e) = self.tree.remove(key) {
216                    return Err(ContractError::IoError(e.to_string()))
217                }
218            }
219            self.batch_on = false;
220        }
221
222        Ok(())
223    }
224}
225
/// sled-overlay based storage for Monotree
pub struct SledOverlayDb<'a> {
    /// Overlay that all reads and writes go through.
    overlay: &'a mut SledDbOverlay,
    /// Key of the overlay tree holding the Monotree data.
    tree: [u8; 32],
    /// Changes staged while a batch is open.
    batch: MemCache,
    /// `true` between `init_batch()` and `finish_batch()`.
    batch_on: bool,
}
233
234impl<'a> SledOverlayDb<'a> {
235    pub fn new(
236        overlay: &'a mut SledDbOverlay,
237        tree: &[u8; 32],
238    ) -> GenericResult<SledOverlayDb<'a>> {
239        if let Err(e) = overlay.open_tree(tree, false) {
240            return Err(ContractError::IoError(e.to_string()))
241        };
242        Ok(Self { overlay, tree: *tree, batch: MemCache::new(), batch_on: false })
243    }
244}
245
246impl MonotreeStorageAdapter for SledOverlayDb<'_> {
247    fn put(&mut self, key: &Hash, value: Vec<u8>) -> GenericResult<()> {
248        if self.batch_on {
249            self.batch.put(key, value);
250        } else if let Err(e) = self.overlay.insert(&self.tree, &slice_to_hash(key), &value) {
251            return Err(ContractError::IoError(e.to_string()))
252        }
253
254        Ok(())
255    }
256
257    fn get(&self, key: &[u8]) -> GenericResult<Option<Vec<u8>>> {
258        if self.batch_on && self.batch.contains(key) {
259            return Ok(self.batch.get(key));
260        }
261
262        match self.overlay.get(&self.tree, key) {
263            Ok(Some(v)) => Ok(Some(v.to_vec())),
264            Ok(None) => Ok(None),
265            Err(e) => Err(ContractError::IoError(e.to_string())),
266        }
267    }
268
269    fn del(&mut self, key: &Hash) -> GenericResult<()> {
270        if self.batch_on {
271            self.batch.del(key);
272        } else if let Err(e) = self.overlay.remove(&self.tree, key) {
273            return Err(ContractError::IoError(e.to_string()));
274        }
275
276        Ok(())
277    }
278
279    fn init_batch(&mut self) -> GenericResult<()> {
280        if !self.batch_on {
281            self.batch.clear();
282            self.batch_on = true;
283        }
284
285        Ok(())
286    }
287
288    fn finish_batch(&mut self) -> GenericResult<()> {
289        if self.batch_on {
290            for (key, value) in self.batch.map.drain() {
291                if let Err(e) = self.overlay.insert(&self.tree, &key, &value) {
292                    return Err(ContractError::IoError(e.to_string()))
293                }
294            }
295            for key in self.batch.set.drain() {
296                if let Err(e) = self.overlay.remove(&self.tree, &key) {
297                    return Err(ContractError::IoError(e.to_string()))
298                }
299            }
300            self.batch_on = false;
301        }
302
303        Ok(())
304    }
305}
306
/// A structure for `monotree`
///
/// To use this, first create a `MonotreeStorageAdapter` implementor
/// and then pass it to `Monotree::new()`.
#[derive(Clone, Debug)]
pub struct Monotree<D: MonotreeStorageAdapter> {
    /// Storage backend that holds the encoded tree nodes and the head root.
    db: D,
}
315
316impl<D: MonotreeStorageAdapter> Monotree<D> {
317    pub fn new(db: D) -> Self {
318        Self { db }
319    }
320
321    fn hash_digest(bytes: &[u8]) -> Hash {
322        let mut hasher = blake3::Hasher::new();
323        hasher.update(bytes);
324        let hash = hasher.finalize();
325        slice_to_hash(hash.as_bytes())
326    }
327
328    /// Retrieves the latest state (root) from the database.
329    pub fn get_headroot(&self) -> GenericResult<Option<Hash>> {
330        let headroot = self.db.get(ROOT_KEY)?;
331        match headroot {
332            Some(root) => Ok(Some(slice_to_hash(&root))),
333            None => Ok(None),
334        }
335    }
336
337    /// Sets the latest state (root) to the database.
338    pub fn set_headroot(&mut self, headroot: Option<&Hash>) {
339        if let Some(root) = headroot {
340            self.db.put(ROOT_KEY, root.to_vec()).expect("set_headroot(): hash");
341        }
342    }
343
    /// Opens a batch on the storage adapter via `init_batch()`.
    ///
    /// # Panics
    ///
    /// Panics if the adapter returns an error.
    pub fn prepare(&mut self) {
        self.db.init_batch().expect("prepare(): failed to initialize batch");
    }
347
348    pub fn commit(&mut self) {
349        self.db.finish_batch().expect("commit(): failed to initialize batch");
350    }
351
352    /// Insert key-leaf entry into the tree. Returns a new root hash.
353    pub fn insert(
354        &mut self,
355        root: Option<&Hash>,
356        key: &Hash,
357        leaf: &Hash,
358    ) -> GenericResult<Option<Hash>> {
359        match root {
360            None => {
361                let (hash, bits) = (leaf, Bits::new(key));
362                self.put_node(Node::new(Some(Unit { hash, bits }), None))
363            }
364            Some(root) => self.put(root, Bits::new(key), leaf),
365        }
366    }
367
368    fn put_node(&mut self, node: Node) -> GenericResult<Option<Hash>> {
369        let bytes = node.to_bytes()?;
370        let hash = Self::hash_digest(&bytes);
371        self.db.put(&hash, bytes)?;
372        Ok(Some(hash))
373    }
374
    /// Recursively inserts a key (in the form of `Bits`) and its leaf into the
    /// subtree whose node is stored under `root`. Returns the hash of the
    /// rewritten node.
    ///
    /// Optimisation in `monotree` is mainly to compress the path as much as possible
    /// while reducing the number of db accesses using the most intuitive model.
    /// As a result, compared to the standard Sparse Merkle Tree this reduces the
    /// number of DB accesses from `N` to `log2(N)` in both reads and writes.
    ///
    /// Every `put()` call results in at least one `put_node()` call, and each
    /// `put_node()` costs a single hash digest plus a single DB write.
    /// Compressing the path reduces the number of `put()` calls, which in turn
    /// reduces the number of hash function calls and DB writes.
    ///
    /// There are four modes when putting an entry, and each of them is handled by
    /// one arm of the `match` below.
    /// The number in parentheses is the minimum number of DB accesses and hash fn
    /// calls required.
    ///
    /// * set-aside (1)
    ///   Putting the leaf to the next node in the current depth.
    /// * replacement (1)
    ///   Replaces the existing node on the path with the new leaf.
    /// * consume & pass-over (2+)
    ///   Consuming the path on the way, then passing the rest of the work to the
    ///   child node.
    /// * split-node (2)
    ///   Immediately splits the node in two at the longest common prefix,
    ///   then unwinds the recursive stack from there returning the resulting hashes.
    ///
    /// # Panics
    ///
    /// Panics if no node is stored under `root`, or if the decoded node has no
    /// cell in the first position returned by `Node::cells_from_bytes`.
    fn put(&mut self, root: &[u8], bits: Bits, leaf: &[u8]) -> GenericResult<Option<Hash>> {
        let bytes = self.db.get(root)?.expect("put(): bytes");
        let (left, right) = Node::cells_from_bytes(&bytes, bits.first())?;
        let unit = left.as_ref().expect("put(): left-unit");
        // Number of leading bits the existing cell and the new key have in common.
        let n = Bits::len_common_bits(&unit.bits, &bits);

        match n {
            // set-aside: nothing in common, the leaf becomes the second cell.
            0 => self.put_node(Node::new(left, Some(Unit { hash: leaf, bits }))),
            // replacement: the whole key matched, overwrite the cell's hash.
            n if n == bits.len() => {
                self.put_node(Node::new(Some(Unit { hash: leaf, bits }), right))
            }
            // consume & pass-over: the cell's path is a prefix of the key, so
            // recurse into the child with the remaining bits.
            n if n == unit.bits.len() => {
                let hash =
                    &self.put(unit.hash, bits.drop(n), leaf)?.expect("put(): consume & pass-over");

                self.put_node(Node::new(Some(Unit { hash, bits: unit.bits.to_owned() }), right))
            }
            // split-node: paths diverge after `n` bits; create a child holding
            // both remainders and point the shortened cell at it.
            _ => {
                let hash = &self
                    .put_node(Node::new(
                        Some(Unit { hash: unit.hash, bits: unit.bits.drop(n) }),
                        Some(Unit { hash: leaf, bits: bits.drop(n) }),
                    ))?
                    .expect("put(): split-node");

                self.put_node(Node::new(Some(Unit { hash, bits: unit.bits.take(n) }), right))
            }
        }
    }
429
430    /// Get a leaf hash for the given root and key.
431    pub fn get(&mut self, root: Option<&Hash>, key: &Hash) -> GenericResult<Option<Hash>> {
432        match root {
433            None => Ok(None),
434            Some(root) => self.find_key(root, Bits::new(key)),
435        }
436    }
437
438    fn find_key(&mut self, root: &[u8], bits: Bits) -> GenericResult<Option<Hash>> {
439        let bytes = self.db.get(root)?.expect("find_key(): bytes");
440        let (cell, _) = Node::cells_from_bytes(&bytes, bits.first())?;
441        let unit = cell.as_ref().expect("find_key(): left-unit");
442        let n = Bits::len_common_bits(&unit.bits, &bits);
443        match n {
444            n if n == bits.len() => Ok(Some(slice_to_hash(unit.hash))),
445            n if n == unit.bits.len() => self.find_key(unit.hash, bits.drop(n)),
446            _ => Ok(None),
447        }
448    }
449
450    /// Remove the given key and its corresponding leaf from the tree. Returns a new root hash.
451    pub fn remove(&mut self, root: Option<&Hash>, key: &[u8]) -> GenericResult<Option<Hash>> {
452        match root {
453            None => Ok(None),
454            Some(root) => self.delete_key(root, Bits::new(key)),
455        }
456    }
457
    /// Recursively removes the entry addressed by `bits` from the subtree whose
    /// node is stored under `root`.
    ///
    /// Returns the hash of the rewritten node, or `None` when no node remains
    /// at this level.
    ///
    /// NOTE(review): the final catch-all arm also returns `None` when the key's
    /// path diverges from the stored cell (i.e. the key is not in this subtree).
    /// Callers cannot tell that apart from "subtree became empty" — confirm that
    /// removing an absent key is never expected to preserve the existing root.
    ///
    /// # Panics
    ///
    /// Panics if no node is stored under `root`, or if the decoded node has no
    /// cell in the first position returned by `Node::cells_from_bytes`.
    fn delete_key(&mut self, root: &[u8], bits: Bits) -> GenericResult<Option<Hash>> {
        let bytes = self.db.get(root)?.expect("delete_key(): bytes");
        let (left, right) = Node::cells_from_bytes(&bytes, bits.first())?;
        let unit = left.as_ref().expect("delete_key(): left-unit");
        // Number of leading bits the stored cell and the key have in common.
        let n = Bits::len_common_bits(&unit.bits, &bits);

        match n {
            // The whole key matched this cell: drop it and keep the other cell, if any.
            n if n == bits.len() => match right {
                Some(_) => self.put_node(Node::new(None, right)),
                None => Ok(None),
            },
            // The cell's path is a prefix of the key: delete inside the child,
            // then rebuild this node around whatever the child turned into.
            n if n == unit.bits.len() => {
                let hash = self.delete_key(unit.hash, bits.drop(n))?;
                match (hash, &right) {
                    (None, None) => Ok(None),
                    (None, Some(_)) => self.put_node(Node::new(None, right)),
                    (Some(ref hash), _) => {
                        let unit = unit.to_owned();
                        let left = Some(Unit { hash, ..unit });
                        self.put_node(Node::new(left, right))
                    }
                }
            }
            // Paths diverge before either is exhausted.
            _ => Ok(None),
        }
    }
484
485    /// This method is indented to use the `insert()` method in batch mode.
486    /// Note that `inserts()` forces the batch to commit.
487    pub fn inserts(
488        &mut self,
489        root: Option<&Hash>,
490        keys: &[Hash],
491        leaves: &[Hash],
492    ) -> GenericResult<Option<Hash>> {
493        let indices = get_sorted_indices(keys, false);
494        self.prepare();
495
496        let mut root = root.cloned();
497        for i in indices.iter() {
498            root = self.insert(root.as_ref(), &keys[*i], &leaves[*i])?;
499        }
500
501        self.commit();
502        Ok(root)
503    }
504
505    /// This method is intended to use the `get()` method in batch mode.
506    pub fn gets(&mut self, root: Option<&Hash>, keys: &[Hash]) -> GenericResult<Vec<Option<Hash>>> {
507        let mut leaves: Vec<Option<Hash>> = vec![];
508        for key in keys.iter() {
509            leaves.push(self.get(root, key)?);
510        }
511        Ok(leaves)
512    }
513
514    /// This method is intended to use the `remove()` method in batch mode.
515    /// Note that `removes()` forces the batch to commit.
516    pub fn removes(&mut self, root: Option<&Hash>, keys: &[Hash]) -> GenericResult<Option<Hash>> {
517        let indices = get_sorted_indices(keys, false);
518        let mut root = root.cloned();
519        self.prepare();
520
521        for i in indices.iter() {
522            root = self.remove(root.as_ref(), &keys[*i])?;
523        }
524
525        self.commit();
526        Ok(root)
527    }
528
529    /// Generate a Merkle proof for the given root and key.
530    pub fn get_merkle_proof(
531        &mut self,
532        root: Option<&Hash>,
533        key: &[u8],
534    ) -> GenericResult<Option<Proof>> {
535        let mut proof: Proof = vec![];
536        match root {
537            None => Ok(None),
538            Some(root) => self.gen_proof(root, Bits::new(key), &mut proof),
539        }
540    }
541
542    fn gen_proof(
543        &mut self,
544        root: &[u8],
545        bits: Bits,
546        proof: &mut Proof,
547    ) -> GenericResult<Option<Proof>> {
548        let bytes = self.db.get(root)?.expect("gen_proof(): bytes");
549        let (cell, _) = Node::cells_from_bytes(&bytes, bits.first())?;
550        let unit = cell.as_ref().expect("gen_proof(): left-unit");
551        let n = Bits::len_common_bits(&unit.bits, &bits);
552
553        match n {
554            n if n == bits.len() => {
555                proof.push(self.encode_proof(&bytes, bits.first())?);
556                Ok(Some(proof.to_owned()))
557            }
558            n if n == unit.bits.len() => {
559                proof.push(self.encode_proof(&bytes, bits.first())?);
560                self.gen_proof(unit.hash, bits.drop(n), proof)
561            }
562            _ => Ok(None),
563        }
564    }
565
566    fn encode_proof(&self, bytes: &[u8], right: bool) -> GenericResult<(bool, Vec<u8>)> {
567        match Node::from_bytes(bytes)? {
568            Node::Soft(_) => Ok((false, bytes[HASH_LEN..].to_vec())),
569            Node::Hard(_, _) => {
570                if right {
571                    Ok((true, [&bytes[..bytes.len() - HASH_LEN - 1], &[0x01]].concat()))
572                } else {
573                    Ok((false, bytes[HASH_LEN..].to_vec()))
574                }
575            }
576        }
577    }
578}
579
580/// Verify a MerkleProof with the given root and leaf.
581///
582/// NOTE: We use `Monotree::<MemoryDb>` to `hash_digest()` but it doesn't matter.
583pub fn verify_proof(root: Option<&Hash>, leaf: &Hash, proof: Option<&Proof>) -> bool {
584    match proof {
585        None => false,
586        Some(proof) => {
587            let mut hash = leaf.to_owned();
588            proof.iter().rev().for_each(|(right, cut)| {
589                if *right {
590                    let l = cut.len();
591                    let o = [&cut[..l - 1], &hash[..], &cut[l - 1..]].concat();
592                    hash = Monotree::<MemoryDb>::hash_digest(&o);
593                } else {
594                    let o = [&hash[..], &cut[..]].concat();
595                    hash = Monotree::<MemoryDb>::hash_digest(&o);
596                }
597            });
598            root.expect("verify_proof(): root") == &hash
599        }
600    }
601}