1use hashbrown::{HashMap, HashSet};
21use sled_overlay::{sled::Tree, SledDbOverlay};
22
23use super::{
24 bits::Bits,
25 node::{Node, Unit},
26 utils::{get_sorted_indices, slice_to_hash},
27 Hash, Proof, HASH_LEN, ROOT_KEY,
28};
29use crate::{ContractError, GenericResult};
30
31#[derive(Clone, Debug)]
32pub(crate) struct MemCache {
33 pub(crate) set: HashSet<Hash>,
34 pub(crate) map: HashMap<Hash, Vec<u8>>,
35}
36
37#[allow(dead_code)]
38impl MemCache {
39 pub(crate) fn new() -> Self {
40 Self { set: HashSet::new(), map: HashMap::with_capacity(1 << 12) }
41 }
42
43 pub(crate) fn clear(&mut self) {
44 self.set.clear();
45 self.map.clear();
46 }
47
48 pub(crate) fn contains(&self, key: &[u8]) -> bool {
49 !self.set.contains(key) && self.map.contains_key(key)
50 }
51
52 pub(crate) fn get(&self, key: &[u8]) -> Option<Vec<u8>> {
53 self.map.get(key).cloned()
54 }
55
56 pub(crate) fn put(&mut self, key: &[u8], value: Vec<u8>) {
57 self.map.insert(slice_to_hash(key), value);
58 self.set.remove(key);
59 }
60
61 pub(crate) fn del(&mut self, key: &[u8]) {
62 self.set.insert(slice_to_hash(key));
63 }
64}
65
66pub trait MonotreeStorageAdapter {
68 fn put(&mut self, key: &Hash, value: Vec<u8>) -> GenericResult<()>;
70 fn get(&self, key: &[u8]) -> GenericResult<Option<Vec<u8>>>;
72 fn del(&mut self, key: &Hash) -> GenericResult<()>;
74 fn init_batch(&mut self) -> GenericResult<()>;
76 fn finish_batch(&mut self) -> GenericResult<()>;
78}
79
80#[derive(Clone, Debug)]
82pub struct MemoryDb {
83 db: HashMap<Hash, Vec<u8>>,
84 batch: MemCache,
85 batch_on: bool,
86}
87
88#[allow(clippy::new_without_default)]
89impl MemoryDb {
90 pub fn new() -> Self {
91 Self { db: HashMap::new(), batch: MemCache::new(), batch_on: false }
92 }
93}
94
95impl MonotreeStorageAdapter for MemoryDb {
96 fn put(&mut self, key: &Hash, value: Vec<u8>) -> GenericResult<()> {
97 if self.batch_on {
98 self.batch.put(key, value);
99 } else {
100 self.db.insert(slice_to_hash(key), value);
101 }
102
103 Ok(())
104 }
105
106 fn get(&self, key: &[u8]) -> GenericResult<Option<Vec<u8>>> {
107 if self.batch_on && self.batch.contains(key) {
108 return Ok(self.batch.get(key));
109 }
110
111 match self.db.get(key) {
112 Some(v) => Ok(Some(v.to_owned())),
113 None => Ok(None),
114 }
115 }
116
117 fn del(&mut self, key: &Hash) -> GenericResult<()> {
118 if self.batch_on {
119 self.batch.del(key);
120 } else {
121 self.db.remove(key);
122 }
123
124 Ok(())
125 }
126
127 fn init_batch(&mut self) -> GenericResult<()> {
128 if !self.batch_on {
129 self.batch.clear();
130 self.batch_on = true;
131 }
132
133 Ok(())
134 }
135
136 fn finish_batch(&mut self) -> GenericResult<()> {
137 if self.batch_on {
138 for (key, value) in self.batch.map.drain() {
139 self.db.insert(key, value);
140 }
141 for key in self.batch.set.drain() {
142 self.db.remove(&key);
143 }
144 self.batch_on = false;
145 }
146
147 Ok(())
148 }
149}
150
151#[derive(Clone)]
153pub struct SledTreeDb {
154 tree: Tree,
155 batch: MemCache,
156 batch_on: bool,
157}
158
159impl SledTreeDb {
160 pub fn new(tree: &Tree) -> Self {
161 Self { tree: tree.clone(), batch: MemCache::new(), batch_on: false }
162 }
163}
164
165impl MonotreeStorageAdapter for SledTreeDb {
166 fn put(&mut self, key: &Hash, value: Vec<u8>) -> GenericResult<()> {
167 if self.batch_on {
168 self.batch.put(key, value);
169 } else if let Err(e) = self.tree.insert(slice_to_hash(key), value) {
170 return Err(ContractError::IoError(e.to_string()))
171 }
172
173 Ok(())
174 }
175
176 fn get(&self, key: &[u8]) -> GenericResult<Option<Vec<u8>>> {
177 if self.batch_on && self.batch.contains(key) {
178 return Ok(self.batch.get(key));
179 }
180
181 match self.tree.get(key) {
182 Ok(Some(v)) => Ok(Some(v.to_vec())),
183 Ok(None) => Ok(None),
184 Err(e) => Err(ContractError::IoError(e.to_string())),
185 }
186 }
187
188 fn del(&mut self, key: &Hash) -> GenericResult<()> {
189 if self.batch_on {
190 self.batch.del(key);
191 } else if let Err(e) = self.tree.remove(key) {
192 return Err(ContractError::IoError(e.to_string()));
193 }
194
195 Ok(())
196 }
197
198 fn init_batch(&mut self) -> GenericResult<()> {
199 if !self.batch_on {
200 self.batch.clear();
201 self.batch_on = true;
202 }
203
204 Ok(())
205 }
206
207 fn finish_batch(&mut self) -> GenericResult<()> {
208 if self.batch_on {
209 for (key, value) in self.batch.map.drain() {
210 if let Err(e) = self.tree.insert(key, value) {
211 return Err(ContractError::IoError(e.to_string()))
212 }
213 }
214 for key in self.batch.set.drain() {
215 if let Err(e) = self.tree.remove(key) {
216 return Err(ContractError::IoError(e.to_string()))
217 }
218 }
219 self.batch_on = false;
220 }
221
222 Ok(())
223 }
224}
225
226pub struct SledOverlayDb<'a> {
228 overlay: &'a mut SledDbOverlay,
229 tree: [u8; 32],
230 batch: MemCache,
231 batch_on: bool,
232}
233
234impl<'a> SledOverlayDb<'a> {
235 pub fn new(
236 overlay: &'a mut SledDbOverlay,
237 tree: &[u8; 32],
238 ) -> GenericResult<SledOverlayDb<'a>> {
239 if let Err(e) = overlay.open_tree(tree, false) {
240 return Err(ContractError::IoError(e.to_string()))
241 };
242 Ok(Self { overlay, tree: *tree, batch: MemCache::new(), batch_on: false })
243 }
244}
245
246impl MonotreeStorageAdapter for SledOverlayDb<'_> {
247 fn put(&mut self, key: &Hash, value: Vec<u8>) -> GenericResult<()> {
248 if self.batch_on {
249 self.batch.put(key, value);
250 } else if let Err(e) = self.overlay.insert(&self.tree, &slice_to_hash(key), &value) {
251 return Err(ContractError::IoError(e.to_string()))
252 }
253
254 Ok(())
255 }
256
257 fn get(&self, key: &[u8]) -> GenericResult<Option<Vec<u8>>> {
258 if self.batch_on && self.batch.contains(key) {
259 return Ok(self.batch.get(key));
260 }
261
262 match self.overlay.get(&self.tree, key) {
263 Ok(Some(v)) => Ok(Some(v.to_vec())),
264 Ok(None) => Ok(None),
265 Err(e) => Err(ContractError::IoError(e.to_string())),
266 }
267 }
268
269 fn del(&mut self, key: &Hash) -> GenericResult<()> {
270 if self.batch_on {
271 self.batch.del(key);
272 } else if let Err(e) = self.overlay.remove(&self.tree, key) {
273 return Err(ContractError::IoError(e.to_string()));
274 }
275
276 Ok(())
277 }
278
279 fn init_batch(&mut self) -> GenericResult<()> {
280 if !self.batch_on {
281 self.batch.clear();
282 self.batch_on = true;
283 }
284
285 Ok(())
286 }
287
288 fn finish_batch(&mut self) -> GenericResult<()> {
289 if self.batch_on {
290 for (key, value) in self.batch.map.drain() {
291 if let Err(e) = self.overlay.insert(&self.tree, &key, &value) {
292 return Err(ContractError::IoError(e.to_string()))
293 }
294 }
295 for key in self.batch.set.drain() {
296 if let Err(e) = self.overlay.remove(&self.tree, &key) {
297 return Err(ContractError::IoError(e.to_string()))
298 }
299 }
300 self.batch_on = false;
301 }
302
303 Ok(())
304 }
305}
306
307#[derive(Clone, Debug)]
312pub struct Monotree<D: MonotreeStorageAdapter> {
313 db: D,
314}
315
316impl<D: MonotreeStorageAdapter> Monotree<D> {
317 pub fn new(db: D) -> Self {
318 Self { db }
319 }
320
321 fn hash_digest(bytes: &[u8]) -> Hash {
322 let mut hasher = blake3::Hasher::new();
323 hasher.update(bytes);
324 let hash = hasher.finalize();
325 slice_to_hash(hash.as_bytes())
326 }
327
328 pub fn get_headroot(&self) -> GenericResult<Option<Hash>> {
330 let headroot = self.db.get(ROOT_KEY)?;
331 match headroot {
332 Some(root) => Ok(Some(slice_to_hash(&root))),
333 None => Ok(None),
334 }
335 }
336
337 pub fn set_headroot(&mut self, headroot: Option<&Hash>) {
339 if let Some(root) = headroot {
340 self.db.put(ROOT_KEY, root.to_vec()).expect("set_headroot(): hash");
341 }
342 }
343
344 pub fn prepare(&mut self) {
345 self.db.init_batch().expect("prepare(): failed to initialize batch");
346 }
347
348 pub fn commit(&mut self) {
349 self.db.finish_batch().expect("commit(): failed to initialize batch");
350 }
351
352 pub fn insert(
354 &mut self,
355 root: Option<&Hash>,
356 key: &Hash,
357 leaf: &Hash,
358 ) -> GenericResult<Option<Hash>> {
359 match root {
360 None => {
361 let (hash, bits) = (leaf, Bits::new(key));
362 self.put_node(Node::new(Some(Unit { hash, bits }), None))
363 }
364 Some(root) => self.put(root, Bits::new(key), leaf),
365 }
366 }
367
368 fn put_node(&mut self, node: Node) -> GenericResult<Option<Hash>> {
369 let bytes = node.to_bytes()?;
370 let hash = Self::hash_digest(&bytes);
371 self.db.put(&hash, bytes)?;
372 Ok(Some(hash))
373 }
374
375 fn put(&mut self, root: &[u8], bits: Bits, leaf: &[u8]) -> GenericResult<Option<Hash>> {
401 let bytes = self.db.get(root)?.expect("put(): bytes");
402 let (left, right) = Node::cells_from_bytes(&bytes, bits.first())?;
403 let unit = left.as_ref().expect("put(): left-unit");
404 let n = Bits::len_common_bits(&unit.bits, &bits);
405
406 match n {
407 0 => self.put_node(Node::new(left, Some(Unit { hash: leaf, bits }))),
408 n if n == bits.len() => {
409 self.put_node(Node::new(Some(Unit { hash: leaf, bits }), right))
410 }
411 n if n == unit.bits.len() => {
412 let hash =
413 &self.put(unit.hash, bits.drop(n), leaf)?.expect("put(): consume & pass-over");
414
415 self.put_node(Node::new(Some(Unit { hash, bits: unit.bits.to_owned() }), right))
416 }
417 _ => {
418 let hash = &self
419 .put_node(Node::new(
420 Some(Unit { hash: unit.hash, bits: unit.bits.drop(n) }),
421 Some(Unit { hash: leaf, bits: bits.drop(n) }),
422 ))?
423 .expect("put(): split-node");
424
425 self.put_node(Node::new(Some(Unit { hash, bits: unit.bits.take(n) }), right))
426 }
427 }
428 }
429
430 pub fn get(&mut self, root: Option<&Hash>, key: &Hash) -> GenericResult<Option<Hash>> {
432 match root {
433 None => Ok(None),
434 Some(root) => self.find_key(root, Bits::new(key)),
435 }
436 }
437
438 fn find_key(&mut self, root: &[u8], bits: Bits) -> GenericResult<Option<Hash>> {
439 let bytes = self.db.get(root)?.expect("find_key(): bytes");
440 let (cell, _) = Node::cells_from_bytes(&bytes, bits.first())?;
441 let unit = cell.as_ref().expect("find_key(): left-unit");
442 let n = Bits::len_common_bits(&unit.bits, &bits);
443 match n {
444 n if n == bits.len() => Ok(Some(slice_to_hash(unit.hash))),
445 n if n == unit.bits.len() => self.find_key(unit.hash, bits.drop(n)),
446 _ => Ok(None),
447 }
448 }
449
450 pub fn remove(&mut self, root: Option<&Hash>, key: &[u8]) -> GenericResult<Option<Hash>> {
452 match root {
453 None => Ok(None),
454 Some(root) => self.delete_key(root, Bits::new(key)),
455 }
456 }
457
458 fn delete_key(&mut self, root: &[u8], bits: Bits) -> GenericResult<Option<Hash>> {
459 let bytes = self.db.get(root)?.expect("delete_key(): bytes");
460 let (left, right) = Node::cells_from_bytes(&bytes, bits.first())?;
461 let unit = left.as_ref().expect("delete_key(): left-unit");
462 let n = Bits::len_common_bits(&unit.bits, &bits);
463
464 match n {
465 n if n == bits.len() => match right {
466 Some(_) => self.put_node(Node::new(None, right)),
467 None => Ok(None),
468 },
469 n if n == unit.bits.len() => {
470 let hash = self.delete_key(unit.hash, bits.drop(n))?;
471 match (hash, &right) {
472 (None, None) => Ok(None),
473 (None, Some(_)) => self.put_node(Node::new(None, right)),
474 (Some(ref hash), _) => {
475 let unit = unit.to_owned();
476 let left = Some(Unit { hash, ..unit });
477 self.put_node(Node::new(left, right))
478 }
479 }
480 }
481 _ => Ok(None),
482 }
483 }
484
485 pub fn inserts(
488 &mut self,
489 root: Option<&Hash>,
490 keys: &[Hash],
491 leaves: &[Hash],
492 ) -> GenericResult<Option<Hash>> {
493 let indices = get_sorted_indices(keys, false);
494 self.prepare();
495
496 let mut root = root.cloned();
497 for i in indices.iter() {
498 root = self.insert(root.as_ref(), &keys[*i], &leaves[*i])?;
499 }
500
501 self.commit();
502 Ok(root)
503 }
504
505 pub fn gets(&mut self, root: Option<&Hash>, keys: &[Hash]) -> GenericResult<Vec<Option<Hash>>> {
507 let mut leaves: Vec<Option<Hash>> = vec![];
508 for key in keys.iter() {
509 leaves.push(self.get(root, key)?);
510 }
511 Ok(leaves)
512 }
513
514 pub fn removes(&mut self, root: Option<&Hash>, keys: &[Hash]) -> GenericResult<Option<Hash>> {
517 let indices = get_sorted_indices(keys, false);
518 let mut root = root.cloned();
519 self.prepare();
520
521 for i in indices.iter() {
522 root = self.remove(root.as_ref(), &keys[*i])?;
523 }
524
525 self.commit();
526 Ok(root)
527 }
528
529 pub fn get_merkle_proof(
531 &mut self,
532 root: Option<&Hash>,
533 key: &[u8],
534 ) -> GenericResult<Option<Proof>> {
535 let mut proof: Proof = vec![];
536 match root {
537 None => Ok(None),
538 Some(root) => self.gen_proof(root, Bits::new(key), &mut proof),
539 }
540 }
541
542 fn gen_proof(
543 &mut self,
544 root: &[u8],
545 bits: Bits,
546 proof: &mut Proof,
547 ) -> GenericResult<Option<Proof>> {
548 let bytes = self.db.get(root)?.expect("gen_proof(): bytes");
549 let (cell, _) = Node::cells_from_bytes(&bytes, bits.first())?;
550 let unit = cell.as_ref().expect("gen_proof(): left-unit");
551 let n = Bits::len_common_bits(&unit.bits, &bits);
552
553 match n {
554 n if n == bits.len() => {
555 proof.push(self.encode_proof(&bytes, bits.first())?);
556 Ok(Some(proof.to_owned()))
557 }
558 n if n == unit.bits.len() => {
559 proof.push(self.encode_proof(&bytes, bits.first())?);
560 self.gen_proof(unit.hash, bits.drop(n), proof)
561 }
562 _ => Ok(None),
563 }
564 }
565
566 fn encode_proof(&self, bytes: &[u8], right: bool) -> GenericResult<(bool, Vec<u8>)> {
567 match Node::from_bytes(bytes)? {
568 Node::Soft(_) => Ok((false, bytes[HASH_LEN..].to_vec())),
569 Node::Hard(_, _) => {
570 if right {
571 Ok((true, [&bytes[..bytes.len() - HASH_LEN - 1], &[0x01]].concat()))
572 } else {
573 Ok((false, bytes[HASH_LEN..].to_vec()))
574 }
575 }
576 }
577 }
578}
579
580pub fn verify_proof(root: Option<&Hash>, leaf: &Hash, proof: Option<&Proof>) -> bool {
584 match proof {
585 None => false,
586 Some(proof) => {
587 let mut hash = leaf.to_owned();
588 proof.iter().rev().for_each(|(right, cut)| {
589 if *right {
590 let l = cut.len();
591 let o = [&cut[..l - 1], &hash[..], &cut[l - 1..]].concat();
592 hash = Monotree::<MemoryDb>::hash_digest(&o);
593 } else {
594 let o = [&hash[..], &cut[..]].concat();
595 hash = Monotree::<MemoryDb>::hash_digest(&o);
596 }
597 });
598 root.expect("verify_proof(): root") == &hash
599 }
600 }
601}