1use std::collections::HashMap;
20
21use darkfi::{blockchain::HeaderHash, Error, Result};
22use darkfi_sdk::{
23 crypto::{
24 pasta_prelude::PrimeField,
25 smt::{PoseidonFp, SparseMerkleTree, StorageAdapter, SMT_FP_DEPTH},
26 MerkleTree, SecretKey,
27 },
28 error::{ContractError, ContractResult},
29 pasta::pallas,
30};
31use darkfi_serial::{deserialize, serialize};
32use num_bigint::BigUint;
33use sled_overlay::{sled, SledDbOverlay, SledDbOverlayStateDiff};
34use tracing::error;
35
/// Name of the sled tree holding scanned block records, keyed by the
/// big-endian block height.
pub const SLED_SCANNED_BLOCKS_TREE: &[u8] = b"_scanned_blocks";
/// Name of the sled tree holding serialized inverse state diffs, keyed by the
/// big-endian block height.
pub const SLED_STATE_INVERSE_DIFF_TREE: &[u8] = b"_state_inverse_diff";
/// Name of the sled tree holding serialized Merkle trees under caller-chosen keys.
pub const SLED_MERKLE_TREES_TREE: &[u8] = b"_merkle_trees";
/// Name of the sled tree holding sparse Merkle tree nodes
/// (written through `CacheSmtStorage`).
pub const SLED_MONEY_SMT_TREE: &[u8] = b"_money_smt";
40
41#[derive(Clone)]
43pub struct Cache {
44 pub sled_db: sled::Db,
46 pub scanned_blocks: sled::Tree,
50 pub state_inverse_diff: sled::Tree,
54 pub merkle_trees: sled::Tree,
58 pub money_smt: sled::Tree,
63 }
65
66impl Cache {
67 pub fn new(db: &sled::Db) -> Result<Self> {
69 let scanned_blocks = db.open_tree(SLED_SCANNED_BLOCKS_TREE)?;
70 let state_inverse_diff = db.open_tree(SLED_STATE_INVERSE_DIFF_TREE)?;
71 let merkle_trees = db.open_tree(SLED_MERKLE_TREES_TREE)?;
72 let money_smt = db.open_tree(SLED_MONEY_SMT_TREE)?;
73
74 Ok(Self {
75 sled_db: db.clone(),
76 scanned_blocks,
77 state_inverse_diff,
78 merkle_trees,
79 money_smt,
80 })
81 }
82
83 pub fn insert_merkle_trees(&self, trees: &[(&[u8], &MerkleTree)]) -> Result<()> {
87 let mut batch = sled::Batch::default();
88 for (key, tree) in trees {
89 batch.insert(*key, serialize(*tree));
90 }
91 self.merkle_trees.apply_batch(batch)?;
92 Ok(())
93 }
94
95 pub fn insert_state_inverse_diff(
99 &self,
100 height: &u32,
101 diff: &SledDbOverlayStateDiff,
102 ) -> Result<()> {
103 self.state_inverse_diff.insert(height.to_be_bytes(), serialize(diff))?;
104 Ok(())
105 }
106
107 pub fn get_state_inverse_diff(&self, height: &u32) -> Result<SledDbOverlayStateDiff> {
111 match self.state_inverse_diff.get(height.to_be_bytes())? {
112 Some(found) => Ok(deserialize(&found)?),
113 None => Err(Error::BlockStateInverseDiffNotFound(*height)),
114 }
115 }
116}
117
118pub struct CacheOverlay(pub SledDbOverlay);
120
121impl CacheOverlay {
122 pub fn new(cache: &Cache) -> Result<CacheOverlay> {
124 let protected_trees = vec![
126 SLED_SCANNED_BLOCKS_TREE,
127 SLED_STATE_INVERSE_DIFF_TREE,
128 SLED_MERKLE_TREES_TREE,
129 SLED_MONEY_SMT_TREE,
130 ];
131 let mut overlay = SledDbOverlay::new(&cache.sled_db, protected_trees);
132
133 overlay.open_tree(SLED_SCANNED_BLOCKS_TREE, true)?;
135 overlay.open_tree(SLED_STATE_INVERSE_DIFF_TREE, true)?;
136 overlay.open_tree(SLED_MERKLE_TREES_TREE, true)?;
137 overlay.open_tree(SLED_MONEY_SMT_TREE, true)?;
138
139 Ok(Self(overlay))
140 }
141
142 pub fn insert_scanned_block(
147 &mut self,
148 height: &u32,
149 hash: &HeaderHash,
150 signing_key: &Option<SecretKey>,
151 ) -> Result<()> {
152 let block_signing_key = match signing_key {
153 Some(key) => key.to_string(),
154 None => String::from("-"),
155 };
156 self.0.insert(
157 SLED_SCANNED_BLOCKS_TREE,
158 &height.to_be_bytes(),
159 &serialize(&(hash.to_string(), block_signing_key)),
160 )?;
161 Ok(())
162 }
163}
164
/// Sparse Merkle tree over `pallas::Base` leaves hashed with [`PoseidonFp`],
/// with its nodes persisted through [`CacheSmtStorage`].
///
/// NOTE(review): the second const parameter (`SMT_FP_DEPTH + 1`) presumably is
/// the number of levels including the leaf level — confirm against darkfi_sdk.
pub type CacheSmt = SparseMerkleTree<
    'static,
    SMT_FP_DEPTH,
    { SMT_FP_DEPTH + 1 },
    pallas::Base,
    PoseidonFp,
    CacheSmtStorage,
>;
173
174pub struct CacheSmtStorage {
175 pub overlay: CacheOverlay,
176 tree: Vec<u8>,
177}
178
179impl CacheSmtStorage {
180 pub fn new(overlay: CacheOverlay, tree: &[u8]) -> Self {
181 Self { overlay, tree: tree.to_vec() }
182 }
183
184 pub fn snapshot(&self) -> Result<HashMap<BigUint, pallas::Base>> {
185 let mut smt = HashMap::new();
186 for record in self.overlay.0.iter(&self.tree)? {
187 let (key, value) = record?;
188 let mut repr = [0; 32];
189 repr.copy_from_slice(&value);
190 let Some(value) = pallas::Base::from_repr(repr).into() else {
191 return Err(Error::ParseFailed(
192 "[cache::CacheSmtStorage::snapshot] Value conversion failed",
193 ))
194 };
195 smt.insert(BigUint::from_bytes_le(&key), value);
196 }
197 Ok(smt)
198 }
199}
200
201impl StorageAdapter for CacheSmtStorage {
202 type Value = pallas::Base;
203
204 fn put(&mut self, key: BigUint, value: pallas::Base) -> ContractResult {
205 if let Err(e) = self.overlay.0.insert(&self.tree, &key.to_bytes_le(), &value.to_repr()) {
206 error!(target: "cache::StorageAdapter::put", "Inserting key {key:?}, value {value:?} into DB failed: {e}");
207 return Err(ContractError::SmtPutFailed)
208 }
209 Ok(())
210 }
211
212 fn get(&self, key: &BigUint) -> Option<pallas::Base> {
213 let value = match self.overlay.0.get(&self.tree, &key.to_bytes_le()) {
214 Ok(v) => v,
215 Err(e) => {
216 error!(target: "cache::StorageAdapter::get", "Fetching key {key:?} from DB failed: {e}");
217 return None
218 }
219 };
220
221 let value = value?;
222
223 let mut repr = [0; 32];
224 repr.copy_from_slice(&value);
225
226 pallas::Base::from_repr(repr).into()
227 }
228
229 fn del(&mut self, key: &BigUint) -> ContractResult {
230 if let Err(e) = self.overlay.0.remove(&self.tree, &key.to_bytes_le()) {
231 error!(target: "cache::StorageAdapter::del", "Removing key {key:?} from DB failed: {e}");
232 return Err(ContractError::SmtDelFailed)
233 }
234 Ok(())
235 }
236}
237
#[cfg(test)]
mod tests {
    use darkfi::{zk::halo2::Field, Result};
    use darkfi_sdk::{
        crypto::smt::{gen_empty_nodes, util::FieldHasher, PoseidonFp, SparseMerkleTree},
        pasta::pallas,
    };
    use rand::rngs::OsRng;
    use sled_overlay::sled;

    use crate::cache::{Cache, CacheOverlay, CacheSmtStorage, SLED_MONEY_SMT_TREE};

    /// Builds a small SMT on top of the cache overlay, checks its root and a
    /// membership proof against hand-computed hashes, then verifies that
    /// applying the overlay diff writes the money SMT tree and applying the
    /// inverse diff empties it again.
    #[test]
    fn test_cache_smt() -> Result<()> {
        // Temporary sled database so the test leaves nothing on disk.
        let sled_db = sled::Config::new().temporary(true).open()?;
        let cache = Cache::new(&sled_db)?;
        let overlay = CacheOverlay::new(&cache)?;

        // Small tree (height 3) so the expected hashes can be written by hand.
        const HEIGHT: usize = 3;
        let hasher = PoseidonFp::new();
        let empty_leaf = pallas::Base::ZERO;
        let empty_nodes = gen_empty_nodes::<{ HEIGHT + 1 }, _, _>(&hasher, empty_leaf);
        let store = CacheSmtStorage::new(overlay, SLED_MONEY_SMT_TREE);
        let mut smt = SparseMerkleTree::<HEIGHT, { HEIGHT + 1 }, _, _, _>::new(
            store,
            hasher.clone(),
            &empty_nodes,
        );

        // Writes so far only live in the overlay, not in the underlying tree.
        assert!(cache.money_smt.is_empty());

        // Fill positions 1..=3; position 0 stays empty.
        let leaves = vec![
            (pallas::Base::from(1), pallas::Base::random(&mut OsRng)),
            (pallas::Base::from(2), pallas::Base::random(&mut OsRng)),
            (pallas::Base::from(3), pallas::Base::random(&mut OsRng)),
        ];
        smt.insert_batch(leaves.clone()).unwrap();

        let hash1 = leaves[0].1;
        let hash2 = leaves[1].1;
        let hash3 = leaves[2].1;

        let hash = |l, r| hasher.hash([l, r]);

        // Recompute the root manually. NOTE(review): empty_nodes[i] is presumably
        // the root of an empty subtree at depth i (empty_nodes[3] = empty leaf).
        let hash01 = hash(empty_nodes[3], hash1);
        let hash23 = hash(hash2, hash3);

        let hash0123 = hash(hash01, hash23);
        let root = hash(hash0123, empty_nodes[1]);
        assert_eq!(root, smt.root());

        // The membership path for position 3 lists siblings from the top level down.
        let pos = leaves[2].0;
        let path = smt.prove_membership(&pos);
        assert_eq!(path.path[0], empty_nodes[1]);
        assert_eq!(path.path[1], hash01);
        assert_eq!(path.path[2], hash2);

        // Folding the path back up from hash3 reproduces every intermediate node.
        assert_eq!(hash23, hash(path.path[2], hash3));
        assert_eq!(hash0123, hash(path.path[1], hash(path.path[2], hash3)));
        assert_eq!(root, hash(hash(path.path[1], hash(path.path[2], hash3)), path.path[0]));

        assert!(path.verify(&root, &hash3, &pos));

        // Diff of the overlay's pending changes (no previous diffs to account for).
        let diff = smt.store.overlay.0.diff(&[])?;

        // Applying the diff persists the SMT nodes into the cache's sled tree.
        smt.store.overlay.0.apply_diff(&diff)?;

        assert!(!cache.money_smt.is_empty());

        // Applying the inverse diff reverts those writes.
        smt.store.overlay.0.apply_diff(&diff.inverse())?;

        assert!(cache.money_smt.is_empty());

        Ok(())
    }
}