1use std::collections::HashMap;
20
21use darkfi::{blockchain::HeaderHash, Error, Result};
22use darkfi_sdk::{
23 crypto::{
24 pasta_prelude::PrimeField,
25 smt::{PoseidonFp, SparseMerkleTree, StorageAdapter, SMT_FP_DEPTH},
26 MerkleTree,
27 },
28 error::{ContractError, ContractResult},
29 pasta::pallas,
30};
31use darkfi_serial::{deserialize, serialize};
32use log::error;
33use num_bigint::BigUint;
34use sled_overlay::{sled, SledDbOverlay, SledDbOverlayStateDiff};
35
/// Name of the sled tree that `CacheOverlay::insert_scanned_block` writes to,
/// keyed by big-endian block height.
pub const SLED_SCANNED_BLOCKS_TREE: &[u8] = b"_scanned_blocks";
/// Name of the sled tree holding serialized state inverse diffs,
/// keyed by big-endian block height.
pub const SLED_STATE_INVERSE_DIFF_TREE: &[u8] = b"_state_inverse_diff";
/// Name of the sled tree holding serialized `MerkleTree` records.
pub const SLED_MERKLE_TREES_TREE: &[u8] = b"_merkle_trees";
/// Name of the sled tree opened as `Cache::money_smt`; the test in this file
/// uses it as the backing tree of a `CacheSmtStorage`.
pub const SLED_MONEY_SMT_TREE: &[u8] = b"_money_smt";
40
/// Collection of handles to the sled database and the trees used as a cache.
///
/// All tree handles are opened from `sled_db` by [`Cache::new`].
#[derive(Clone)]
pub struct Cache {
    /// Handle to the underlying sled database the trees were opened from.
    pub sled_db: sled::Db,
    /// Tree opened under `SLED_SCANNED_BLOCKS_TREE`.
    pub scanned_blocks: sled::Tree,
    /// Tree opened under `SLED_STATE_INVERSE_DIFF_TREE`; maps a big-endian
    /// block height to a serialized `SledDbOverlayStateDiff`.
    pub state_inverse_diff: sled::Tree,
    /// Tree opened under `SLED_MERKLE_TREES_TREE`; maps caller-provided keys
    /// to serialized `MerkleTree` records.
    pub merkle_trees: sled::Tree,
    /// Tree opened under `SLED_MONEY_SMT_TREE`.
    pub money_smt: sled::Tree,
}
65
66impl Cache {
67 pub fn new(db: &sled::Db) -> Result<Self> {
69 let scanned_blocks = db.open_tree(SLED_SCANNED_BLOCKS_TREE)?;
70 let state_inverse_diff = db.open_tree(SLED_STATE_INVERSE_DIFF_TREE)?;
71 let merkle_trees = db.open_tree(SLED_MERKLE_TREES_TREE)?;
72 let money_smt = db.open_tree(SLED_MONEY_SMT_TREE)?;
73
74 Ok(Self {
75 sled_db: db.clone(),
76 scanned_blocks,
77 state_inverse_diff,
78 merkle_trees,
79 money_smt,
80 })
81 }
82
83 pub fn insert_merkle_trees(&self, trees: &[(&[u8], &MerkleTree)]) -> Result<()> {
87 let mut batch = sled::Batch::default();
88 for (key, tree) in trees {
89 batch.insert(*key, serialize(*tree));
90 }
91 self.merkle_trees.apply_batch(batch)?;
92 Ok(())
93 }
94
95 pub fn insert_state_inverse_diff(
99 &self,
100 height: &u32,
101 diff: &SledDbOverlayStateDiff,
102 ) -> Result<()> {
103 self.state_inverse_diff.insert(height.to_be_bytes(), serialize(diff))?;
104 Ok(())
105 }
106
107 pub fn get_state_inverse_diff(&self, height: &u32) -> Result<SledDbOverlayStateDiff> {
111 match self.state_inverse_diff.get(height.to_be_bytes())? {
112 Some(found) => Ok(deserialize(&found)?),
113 None => Err(Error::BlockStateInverseDiffNotFound(*height)),
114 }
115 }
116}
117
/// Newtype wrapper around a `SledDbOverlay`, exposing it as its public field.
pub struct CacheOverlay(pub SledDbOverlay);
120
121impl CacheOverlay {
122 pub fn new(cache: &Cache) -> Result<CacheOverlay> {
124 let protected_trees = vec![
126 SLED_SCANNED_BLOCKS_TREE,
127 SLED_STATE_INVERSE_DIFF_TREE,
128 SLED_MERKLE_TREES_TREE,
129 SLED_MONEY_SMT_TREE,
130 ];
131 let mut overlay = SledDbOverlay::new(&cache.sled_db, protected_trees);
132
133 overlay.open_tree(SLED_SCANNED_BLOCKS_TREE, true)?;
135 overlay.open_tree(SLED_STATE_INVERSE_DIFF_TREE, true)?;
136 overlay.open_tree(SLED_MERKLE_TREES_TREE, true)?;
137 overlay.open_tree(SLED_MONEY_SMT_TREE, true)?;
138
139 Ok(Self(overlay))
140 }
141
142 pub fn insert_scanned_block(&mut self, height: &u32, hash: &HeaderHash) -> Result<()> {
146 self.0.insert(
147 SLED_SCANNED_BLOCKS_TREE,
148 &height.to_be_bytes(),
149 &serialize(&hash.to_string()),
150 )?;
151 Ok(())
152 }
153}
154
/// Sparse Merkle tree of depth `SMT_FP_DEPTH` over `pallas::Base` values,
/// hashed with `PoseidonFp` and persisted through a [`CacheSmtStorage`].
pub type CacheSmt = SparseMerkleTree<
    'static,
    SMT_FP_DEPTH,
    { SMT_FP_DEPTH + 1 },
    pallas::Base,
    PoseidonFp,
    CacheSmtStorage,
>;
163
/// Sparse Merkle tree storage backed by a single tree of a [`CacheOverlay`].
pub struct CacheSmtStorage {
    /// Overlay that all reads and writes of this storage go through.
    pub overlay: CacheOverlay,
    /// Name of the overlay tree the records are stored in.
    tree: Vec<u8>,
}
168
169impl CacheSmtStorage {
170 pub fn new(overlay: CacheOverlay, tree: &[u8]) -> Self {
171 Self { overlay, tree: tree.to_vec() }
172 }
173
174 pub fn snapshot(&self) -> Result<HashMap<BigUint, pallas::Base>> {
175 let mut smt = HashMap::new();
176 for record in self.overlay.0.iter(&self.tree)? {
177 let (key, value) = record?;
178 let mut repr = [0; 32];
179 repr.copy_from_slice(&value);
180 let Some(value) = pallas::Base::from_repr(repr).into() else {
181 return Err(Error::ParseFailed(
182 "[cache::CacheSmtStorage::snapshot] Value conversion failed",
183 ))
184 };
185 smt.insert(BigUint::from_bytes_le(&key), value);
186 }
187 Ok(smt)
188 }
189}
190
191impl StorageAdapter for CacheSmtStorage {
192 type Value = pallas::Base;
193
194 fn put(&mut self, key: BigUint, value: pallas::Base) -> ContractResult {
195 if let Err(e) = self.overlay.0.insert(&self.tree, &key.to_bytes_le(), &value.to_repr()) {
196 error!(target: "cache::StorageAdapter::put", "Inserting key {key:?}, value {value:?} into DB failed: {e}");
197 return Err(ContractError::SmtPutFailed)
198 }
199 Ok(())
200 }
201
202 fn get(&self, key: &BigUint) -> Option<pallas::Base> {
203 let value = match self.overlay.0.get(&self.tree, &key.to_bytes_le()) {
204 Ok(v) => v,
205 Err(e) => {
206 error!(target: "cache::StorageAdapter::get", "Fetching key {key:?} from DB failed: {e}");
207 return None
208 }
209 };
210
211 let value = value?;
212
213 let mut repr = [0; 32];
214 repr.copy_from_slice(&value);
215
216 pallas::Base::from_repr(repr).into()
217 }
218
219 fn del(&mut self, key: &BigUint) -> ContractResult {
220 if let Err(e) = self.overlay.0.remove(&self.tree, &key.to_bytes_le()) {
221 error!(target: "cache::StorageAdapter::del", "Removing key {key:?} from DB failed: {e}");
222 return Err(ContractError::SmtDelFailed)
223 }
224 Ok(())
225 }
226}
227
#[cfg(test)]
mod tests {
    use darkfi::{zk::halo2::Field, Result};
    use darkfi_sdk::{
        crypto::smt::{gen_empty_nodes, util::FieldHasher, PoseidonFp, SparseMerkleTree},
        pasta::pallas,
    };
    use rand::rngs::OsRng;
    use sled_overlay::sled;

    use crate::cache::{Cache, CacheOverlay, CacheSmtStorage, SLED_MONEY_SMT_TREE};

    /// Exercise a small sparse Merkle tree backed by `CacheSmtStorage`:
    /// insert leaves, verify the root and a membership proof against
    /// hand-computed hashes, then apply and revert the overlay diff.
    #[test]
    fn test_cache_smt() -> Result<()> {
        // Set up a temporary sled database, the cache and an overlay over it
        let sled_db = sled::Config::new().temporary(true).open()?;
        let cache = Cache::new(&sled_db)?;
        let overlay = CacheOverlay::new(&cache)?;

        // Use a small tree of height 3 so the hashes can be written by hand
        const HEIGHT: usize = 3;
        let hasher = PoseidonFp::new();
        let empty_leaf = pallas::Base::ZERO;
        let empty_nodes = gen_empty_nodes::<{ HEIGHT + 1 }, _, _>(&hasher, empty_leaf);
        let store = CacheSmtStorage::new(overlay, SLED_MONEY_SMT_TREE);
        let mut smt = SparseMerkleTree::<HEIGHT, { HEIGHT + 1 }, _, _, _>::new(
            store,
            hasher.clone(),
            &empty_nodes,
        );

        // Nothing has been written to the actual sled tree yet
        assert!(cache.money_smt.is_empty());

        // Insert random values at positions 1, 2 and 3
        let leaves = vec![
            (pallas::Base::from(1), pallas::Base::random(&mut OsRng)),
            (pallas::Base::from(2), pallas::Base::random(&mut OsRng)),
            (pallas::Base::from(3), pallas::Base::random(&mut OsRng)),
        ];
        smt.insert_batch(leaves.clone()).unwrap();

        let hash1 = leaves[0].1;
        let hash2 = leaves[1].1;
        let hash3 = leaves[2].1;

        let hash = |l, r| hasher.hash([l, r]);

        // Rebuild the expected root by hand. No leaf was inserted at
        // position 0, so `empty_nodes[3]` (presumably the empty node of the
        // leaf level) is used as the sibling of `hash1`.
        let hash01 = hash(empty_nodes[3], hash1);
        let hash23 = hash(hash2, hash3);

        let hash0123 = hash(hash01, hash23);
        // The sibling of `hash0123` holds no leaves, so `empty_nodes[1]` is
        // used in its place.
        let root = hash(hash0123, empty_nodes[1]);
        assert_eq!(root, smt.root());

        // Membership proof for the leaf at position 3: the path lists the
        // siblings starting next to the root and ending next to the leaf.
        let pos = leaves[2].0;
        let path = smt.prove_membership(&pos);
        assert_eq!(path.path[0], empty_nodes[1]);
        assert_eq!(path.path[1], hash01);
        assert_eq!(path.path[2], hash2);

        // Folding the path from the leaf upwards reproduces each
        // intermediate hash and finally the root.
        assert_eq!(hash23, hash(path.path[2], hash3));
        assert_eq!(hash0123, hash(path.path[1], hash(path.path[2], hash3)));
        assert_eq!(root, hash(hash(path.path[1], hash(path.path[2], hash3)), path.path[0]));

        assert!(path.verify(&root, &hash3, &pos));

        // Grab the overlay diff, passing an empty slice to `diff`
        let diff = smt.store.overlay.0.diff(&[])?;

        // Applying the diff writes the SMT records to the actual sled tree
        smt.store.overlay.0.apply_diff(&diff)?;

        assert!(!cache.money_smt.is_empty());

        // Applying the inverse diff removes them again
        smt.store.overlay.0.apply_diff(&diff.inverse())?;

        assert!(cache.money_smt.is_empty());

        Ok(())
    }
}