darkfi_sdk/crypto/smt/
mod.rs

1/* This file is part of DarkFi (https://dark.fi)
2 *
3 * Copyright (C) 2020-2025 Dyne.org foundation
4 *
5 * Copyright (C) 2021 Webb Technologies Inc.
6 * Copyright (c) zkMove Authors
7 * SPDX-License-Identifier: Apache-2.0
8 *
9 * This program is free software: you can redistribute it and/or modify
10 * it under the terms of the GNU Affero General Public License as
11 * published by the Free Software Foundation, either version 3 of the
12 * License, or (at your option) any later version.
13 *
14 * This program is distributed in the hope that it will be useful,
15 * but WITHOUT ANY WARRANTY; without even the implied warranty of
16 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
17 * GNU Affero General Public License for more details.
18 *
19 * You should have received a copy of the GNU Affero General Public License
20 * along with this program.  If not, see <https://www.gnu.org/licenses/>.
21 */
22
23//! This file provides a native implementation of the Sparse Merkle tree data
24//! structure.
25//!
26//! A Sparse Merkle tree is a type of Merkle tree, but it is much easier to
27//! prove non-membership in a sparse Merkle tree than in an arbitrary Merkle
28//! tree. For an explanation of sparse Merkle trees, see:
29//! `<https://medium.com/@kelvinfichter/whats-a-sparse-merkle-tree-acda70aeb837>`
30//!
31//! In this file we define the `Path` and `SparseMerkleTree` structs.
32//! These depend on your choice of a prime field F, a field hasher over F
33//! (any hash function that maps F^2 to F will do, e.g. the poseidon hash
34//! function of width 3 where an input of zero is used for padding), and the
35//! height N of the sparse Merkle tree.
36//!
//! The path corresponding to a given leaf node is stored as an array of N
//! field elements. Each element is the sibling of a node lying on the path
//! from the leaf node to the root, ordered from the level just below the root
//! down to the leaf level.  For example, suppose
//! ```text
//!           a
//!         /   \
//!        b     c
//!       / \   / \
//!      d   e f   g
//! ```
//! is our Sparse Merkle tree, and `a` through `g` are field elements stored at
//! the nodes. Then the merkle proof path `e-b-a` from leaf `e` to root `a` is
//! stored as `[c, d]`: `c` is the sibling of `b` and `d` is the sibling of `e`.
50//!
51//! # Terminology
52//!
53//! * **level** - the depth in the tree. Type: `u32`
54//! * **location** - a `(level, position)` tuple
55//! * **position** - the leaf index, or equivalently the binary direction through the tree
56//!   with type `F`.
57//! * **index** - the internal index used in the DB which is `BigUint`. Leaf node indexes are
58//!   calculated as `leaf_idx = final_level_start_idx + position`.
59//! * **node** - either the leaf values or parent nodes `hash(left, right)`.
60
61use num_bigint::BigUint;
62use std::collections::HashMap;
63// Only used for the type aliases below
64use pasta_curves::pallas;
65
66use crate::error::ContractResult;
67use util::{FieldElement, FieldHasher};
68
69mod empty;
70pub use empty::EMPTY_NODES_FP;
71
72#[cfg(test)]
73mod test;
74
75pub mod util;
76pub use util::Poseidon;
77
78#[cfg(feature = "wasm")]
79pub mod wasmdb;
80
/// Depth of the SMT instantiated over the Pallas base field.
// Bit size for Fp (and Fq)
pub const SMT_FP_DEPTH: usize = 255;
/// Poseidon hasher over `pallas::Base`; used as the `FieldHasher<_, 2>` of the
/// aliases below.
pub type PoseidonFp = Poseidon<pallas::Base, 2>;
/// In-memory node storage holding `pallas::Base` values.
pub type MemoryStorageFp = MemoryStorage<pallas::Base>;
/// SMT of depth `SMT_FP_DEPTH` over `pallas::Base`, hashed with [`PoseidonFp`],
/// backed by [`MemoryStorageFp`] and borrowing a `'static` table of
/// `SMT_FP_DEPTH + 1` empty nodes.
pub type SmtMemoryFp = SparseMerkleTree<
    'static,
    SMT_FP_DEPTH,
    { SMT_FP_DEPTH + 1 },
    pallas::Base,
    PoseidonFp,
    MemoryStorageFp,
>;
/// Merkle path of `SMT_FP_DEPTH` sibling nodes matching [`SmtMemoryFp`].
pub type PathFp = Path<SMT_FP_DEPTH, pallas::Base, PoseidonFp>;
94
/// Pluggable storage backend for the SMT.
/// Has a minimal interface to put, get, and delete objects from the store.
///
/// Keys are the internal `BigUint` node indexes of the tree.
pub trait StorageAdapter {
    /// Type of the values kept in the store. The tree requires this to be its
    /// field element type.
    type Value;

    /// Stores `value` under `key`.
    fn put(&mut self, key: BigUint, value: Self::Value) -> ContractResult;
    /// Looks up the value stored under `key`. The tree treats `None` as
    /// "no node stored here" and falls back to the empty node of that level.
    fn get(&self, key: &BigUint) -> Option<Self::Value>;
    /// Deletes the entry stored under `key`.
    fn del(&mut self, key: &BigUint) -> ContractResult;
}
104
/// An in-memory storage, useful for unit tests and smaller trees.
///
/// Nodes live in a plain `HashMap` keyed by their `BigUint` index.
#[derive(Default, Clone)]
pub struct MemoryStorage<F: FieldElement> {
    /// Stored nodes, keyed by node index.
    pub tree: HashMap<BigUint, F>,
}
110
111impl<F: FieldElement> MemoryStorage<F> {
112    pub fn new() -> Self {
113        Self { tree: HashMap::new() }
114    }
115}
116
117impl<F: FieldElement> StorageAdapter for MemoryStorage<F> {
118    type Value = F;
119
120    fn put(&mut self, key: BigUint, value: F) -> ContractResult {
121        self.tree.insert(key, value);
122        Ok(())
123    }
124
125    fn get(&self, key: &BigUint) -> Option<F> {
126        self.tree.get(key).copied()
127    }
128
129    fn del(&mut self, key: &BigUint) -> ContractResult {
130        self.tree.remove(key);
131        Ok(())
132    }
133}
134
/// The Sparse Merkle Tree struct.
///
/// SMT stores its nodes in a pluggable key-value store and keeps a table of
/// empty hashes that it uses to represent the sparse areas of the tree.
///
/// The const param `N` is the depth of the tree. A tree with a depth of `N`
/// will have `N + 1` levels, which is why the const param `M` must be `N + 1`
/// (checked in `new`).
#[derive(Debug, Clone)]
pub struct SparseMerkleTree<
    'a,
    const N: usize,
    // M = N + 1
    const M: usize,
    F: FieldElement,
    H: FieldHasher<F, 2>,
    S: StorageAdapter<Value = F>,
> {
    /// Backing store mapping node indexes to field elements. Holds the leaves
    /// as well as every recomputed inner node; absent keys stand for empty
    /// nodes.
    store: S,
    /// The hasher used to build the Merkle tree.
    hasher: H,
    /// Table of empty-subtree hashes, indexed by level: the value returned for
    /// a node of that level when the store has no entry for it.
    empty_nodes: &'a [F; M],
}
159
160impl<
161        'a,
162        const N: usize,
163        const M: usize,
164        F: FieldElement,
165        H: FieldHasher<F, 2>,
166        S: StorageAdapter<Value = F>,
167    > SparseMerkleTree<'a, N, M, F, H, S>
168{
169    /// Creates a new SMT
170    pub fn new(store: S, hasher: H, empty_nodes: &'a [F; M]) -> Self {
171        assert_eq!(M, N + 1);
172        Self { store, hasher, empty_nodes }
173    }
174
175    /// Takes a batch of field elements, inserts these hashes into the tree,
176    /// and updates the Merkle root.
177    pub fn insert_batch(&mut self, leaves: Vec<(F, F)>) -> ContractResult {
178        if leaves.is_empty() {
179            return Ok(())
180        }
181
182        // Nodes that need recalculating
183        let mut dirty_idxs = Vec::new();
184        for (pos, leaf) in leaves {
185            let idx = util::leaf_pos_to_index::<N, _>(&pos);
186            self.put_node(idx.clone(), leaf)?;
187
188            // Mark node parent as dirty
189            let parent_idx = util::parent(&idx).unwrap();
190            dirty_idxs.push(parent_idx);
191        }
192
193        self.recompute_tree(&mut dirty_idxs)?;
194
195        Ok(())
196    }
197
198    pub fn remove_leaves(&mut self, leaves: Vec<(F, F)>) -> ContractResult {
199        if leaves.is_empty() {
200            return Ok(())
201        }
202
203        let mut dirty_idxs = Vec::new();
204        for (pos, _leaf) in leaves {
205            let idx = util::leaf_pos_to_index::<N, _>(&pos);
206            self.remove_node(&idx)?;
207
208            // Mark node parent as dirty
209            let parent_idx = util::parent(&idx).unwrap();
210            dirty_idxs.push(parent_idx);
211        }
212
213        self.recompute_tree(&mut dirty_idxs)?;
214
215        Ok(())
216    }
217
218    /// Returns the Merkle tree root.
219    pub fn root(&self) -> F {
220        self.get_node(&BigUint::from(0u32))
221    }
222
223    /// Recomputes the Merkle tree depth first from the bottom of the tree
224    fn recompute_tree(&mut self, dirty_idxs: &mut Vec<BigUint>) -> ContractResult {
225        for _ in 0..N + 1 {
226            let mut new_dirty_idxs = vec![];
227
228            for idx in &mut *dirty_idxs {
229                let left_idx = util::left_child(idx);
230                let right_idx = util::right_child(idx);
231                let left = self.get_node(&left_idx);
232                let right = self.get_node(&right_idx);
233                // Recalclate the node
234                let node = self.hasher.hash([left, right]);
235                self.put_node(idx.clone(), node)?;
236
237                // Add this node's parent to the update list
238                let parent_idx = match util::parent(idx) {
239                    Some(idx) => idx,
240                    // We are at the root node so no parents exist
241                    None => break,
242                };
243
244                new_dirty_idxs.push(parent_idx);
245            }
246
247            *dirty_idxs = new_dirty_idxs;
248        }
249
250        Ok(())
251    }
252
253    /// Give the path leading from the leaf at `index` up to the root. This is
254    /// a "proof" in the sense of "valid path in a Merkle tree", not a ZK argument.
255    pub fn prove_membership(&self, pos: &F) -> Path<N, F, H> {
256        let mut path = [F::ZERO; N];
257        let leaf_idx = util::leaf_pos_to_index::<N, _>(pos);
258
259        let mut current_idx = leaf_idx;
260        // Depth first from the bottom of the tree
261        for lvl in (0..N).rev() {
262            let sibling_idx = util::sibling(&current_idx).unwrap();
263            let sibling_node = self.get_node(&sibling_idx);
264            path[lvl] = sibling_node;
265
266            // Now move to the parent
267            current_idx = util::parent(&current_idx).unwrap();
268        }
269
270        Path { path, hasher: self.hasher.clone() }
271    }
272
273    /// Fast lookup for leaf. The SMT can be used as a generic container for
274    /// objects with very little overhead using this method.
275    pub fn get_leaf(&self, pos: &F) -> F {
276        let leaf_idx = util::leaf_pos_to_index::<N, _>(pos);
277        self.get_node(&leaf_idx)
278    }
279
280    fn get_node(&self, idx: &BigUint) -> F {
281        let lvl = util::log2(idx);
282        let empty_node = self.empty_nodes[lvl as usize];
283        self.store.get(idx).unwrap_or(empty_node)
284    }
285
286    fn put_node(&mut self, key: BigUint, value: F) -> ContractResult {
287        self.store.put(key, value)
288    }
289
290    fn remove_node(&mut self, key: &BigUint) -> ContractResult {
291        self.store.del(key)
292    }
293}
294
/// The path contains a sequence of sibling nodes that make up a Merkle proof.
/// Hashing a leaf together with these siblings, level by level, reconstructs
/// a root that can be compared against the expected one.
pub struct Path<const N: usize, F: FieldElement, H: FieldHasher<F, 2>> {
    /// Path from leaf to root. It is a list of sibling nodes.
    /// It does not contain the root node.
    /// Similar to other conventions here, the list starts higher in the tree
    /// and goes down. So when iterating we start from the end.
    pub path: [F; N],
    /// Hasher used to combine a node with its sibling during verification.
    hasher: H,
}
306
307impl<const N: usize, F: FieldElement, H: FieldHasher<F, 2>> Path<N, F, H> {
308    pub fn verify(&self, root: &F, leaf: &F, pos: &F) -> bool {
309        let pos = pos.as_biguint();
310        assert!(pos.bits() as usize <= N);
311
312        let mut current_node = *leaf;
313        for i in (0..N).rev() {
314            let sibling_node = self.path[i];
315
316            let is_right = pos.bit((N - 1 - i) as u64);
317            let (left, right) =
318                if is_right { (sibling_node, current_node) } else { (current_node, sibling_node) };
319            //println!("is_right: {}", is_right);
320            //println!("left: {:?}, right: {:?}", left, right);
321            //println!("current_node: {:?}", current_node);
322
323            current_node = self.hasher.hash([left, right]);
324        }
325
326        current_node == *root
327    }
328}
329
330/// A function to generate empty hashes with a given `default_leaf`.
331///
332/// Given a `FieldHasher`, generate a list of `N` hashes consisting of the
333/// `default_leaf` hashed with itself and repeated `N` times with the
334/// intermediate results. These are used to initialize the sparse portion
335/// of the SMT.
336///
337/// Ordering is depth-wise starting from root going down.
338pub fn gen_empty_nodes<const M: usize, F: FieldElement, H: FieldHasher<F, 2>>(
339    hasher: &H,
340    empty_leaf: F,
341) -> [F; M] {
342    let mut empty_nodes = [F::ZERO; M];
343    let mut empty_node = empty_leaf;
344
345    for item in empty_nodes.iter_mut().rev() {
346        *item = empty_node;
347        empty_node = hasher.hash([empty_node, empty_node]);
348    }
349
350    empty_nodes
351}