1use hashbrown::{HashMap, HashSet};
21
22use super::{
23 bits::Bits,
24 node::{Node, Unit},
25 utils::{get_sorted_indices, slice_to_hash},
26 Hash, Proof, HASH_LEN, ROOT_KEY,
27};
28use crate::GenericResult;
29
30#[derive(Clone, Debug)]
31pub(crate) struct MemCache {
32 pub(crate) set: HashSet<Hash>,
33 pub(crate) map: HashMap<Hash, Vec<u8>>,
34}
35
36#[allow(dead_code)]
37impl MemCache {
38 pub(crate) fn new() -> Self {
39 Self { set: HashSet::new(), map: HashMap::with_capacity(1 << 12) }
40 }
41
42 pub(crate) fn clear(&mut self) {
43 self.set.clear();
44 self.map.clear();
45 }
46
47 pub(crate) fn contains(&self, key: &[u8]) -> bool {
48 !self.set.contains(key) && self.map.contains_key(key)
49 }
50
51 pub(crate) fn get(&self, key: &[u8]) -> Option<Vec<u8>> {
52 self.map.get(key).cloned()
53 }
54
55 pub(crate) fn put(&mut self, key: &[u8], value: Vec<u8>) {
56 self.map.insert(slice_to_hash(key), value);
57 self.set.remove(key);
58 }
59
60 pub(crate) fn delete(&mut self, key: &[u8]) {
61 self.set.insert(slice_to_hash(key));
62 }
63}
64
65#[derive(Clone, Debug)]
66pub struct MemoryDb {
67 db: HashMap<Hash, Vec<u8>>,
68 batch: MemCache,
69 batch_on: bool,
70}
71
72#[allow(dead_code)]
73impl MemoryDb {
74 fn new() -> Self {
75 Self { db: HashMap::new(), batch: MemCache::new(), batch_on: false }
76 }
77
78 fn get(&self, key: &[u8]) -> GenericResult<Option<Vec<u8>>> {
79 if self.batch_on && self.batch.contains(key) {
80 return Ok(self.batch.get(key));
81 }
82
83 match self.db.get(key) {
84 Some(v) => Ok(Some(v.to_owned())),
85 None => Ok(None),
86 }
87 }
88
89 fn put(&mut self, key: &[u8], value: Vec<u8>) -> GenericResult<()> {
90 if self.batch_on {
91 self.batch.put(key, value);
92 } else {
93 self.db.insert(slice_to_hash(key), value);
94 }
95 Ok(())
96 }
97
98 fn delete(&mut self, key: &[u8]) -> GenericResult<()> {
99 if self.batch_on {
100 self.batch.delete(key);
101 } else {
102 self.db.remove(key);
103 }
104 Ok(())
105 }
106
107 fn init_batch(&mut self) -> GenericResult<()> {
108 if !self.batch_on {
109 self.batch.clear();
110 self.batch_on = true;
111 }
112 Ok(())
113 }
114
115 fn finish_batch(&mut self) -> GenericResult<()> {
116 if self.batch_on {
117 for (key, value) in self.batch.map.drain() {
118 self.db.insert(key, value);
119 }
120 for key in self.batch.set.drain() {
121 self.db.remove(&key);
122 }
123 self.batch_on = false;
124 }
125 Ok(())
126 }
127}
128
129#[derive(Clone, Debug)]
131pub struct Monotree {
132 db: MemoryDb,
133}
134
135impl Default for Monotree {
136 fn default() -> Self {
137 Self::new()
138 }
139}
140
141impl Monotree {
142 pub fn new() -> Self {
143 Self { db: MemoryDb::new() }
144 }
145
146 fn hash_digest(bytes: &[u8]) -> Hash {
147 let mut hasher = blake3::Hasher::new();
148 hasher.update(bytes);
149 let hash = hasher.finalize();
150 slice_to_hash(hash.as_bytes())
151 }
152
153 pub fn get_headroot(&self) -> GenericResult<Option<Hash>> {
155 let headroot = self.db.get(ROOT_KEY)?;
156 match headroot {
157 Some(root) => Ok(Some(slice_to_hash(&root))),
158 None => Ok(None),
159 }
160 }
161
162 pub fn set_headroot(&mut self, headroot: Option<&Hash>) {
164 if let Some(root) = headroot {
165 self.db.put(ROOT_KEY, root.to_vec()).expect("set_headroot(): hash");
166 }
167 }
168
169 pub fn prepare(&mut self) {
170 self.db.init_batch().expect("prepare(): failed to initialize batch");
171 }
172
173 pub fn commit(&mut self) {
174 self.db.finish_batch().expect("commit(): failed to initialize batch");
175 }
176
177 pub fn insert(
179 &mut self,
180 root: Option<&Hash>,
181 key: &Hash,
182 leaf: &Hash,
183 ) -> GenericResult<Option<Hash>> {
184 match root {
185 None => {
186 let (hash, bits) = (leaf, Bits::new(key));
187 self.put_node(Node::new(Some(Unit { hash, bits }), None))
188 }
189 Some(root) => self.put(root, Bits::new(key), leaf),
190 }
191 }
192
193 fn put_node(&mut self, node: Node) -> GenericResult<Option<Hash>> {
194 let bytes = node.to_bytes()?;
195 let hash = Self::hash_digest(&bytes);
196 self.db.put(&hash, bytes)?;
197 Ok(Some(hash))
198 }
199
200 fn put(&mut self, root: &[u8], bits: Bits, leaf: &[u8]) -> GenericResult<Option<Hash>> {
226 let bytes = self.db.get(root)?.expect("bytes");
227 let (lc, rc) = Node::cells_from_bytes(&bytes, bits.first())?;
228 let unit = lc.as_ref().expect("put(): left-unit");
229 let n = Bits::len_common_bits(&unit.bits, &bits);
230
231 match n {
232 0 => self.put_node(Node::new(lc, Some(Unit { hash: leaf, bits }))),
233 n if n == bits.len() => self.put_node(Node::new(Some(Unit { hash: leaf, bits }), rc)),
234 n if n == unit.bits.len() => {
235 let hash = &self.put(unit.hash, bits.shift(n, false), leaf)?.expect("put(): hash");
236
237 let unit = unit.to_owned();
238 self.put_node(Node::new(Some(Unit { hash, ..unit }), rc))
239 }
240 _ => {
241 let bits = bits.shift(n, false);
242 let ru = Unit { hash: leaf, bits };
243
244 let (cloned, unit) = (unit.bits.clone(), unit.to_owned());
245 let (hash, bits) = (unit.hash, unit.bits.shift(n, false));
246 let lu = Unit { hash, bits };
247
248 let (left, right) = if lu.bits < ru.bits { (lu, ru) } else { (ru, lu) };
250
251 let hash =
252 &self.put_node(Node::new(Some(left), Some(right)))?.expect("put(): hash");
253 let bits = cloned.shift(n, true);
254 self.put_node(Node::new(Some(Unit { hash, bits }), rc))
255 }
256 }
257 }
258
259 pub fn get(&mut self, root: Option<&Hash>, key: &Hash) -> GenericResult<Option<Hash>> {
261 match root {
262 None => Ok(None),
263 Some(root) => self.find_key(root, Bits::new(key)),
264 }
265 }
266
267 fn find_key(&mut self, root: &[u8], bits: Bits) -> GenericResult<Option<Hash>> {
268 let bytes = self.db.get(root)?.expect("bytes");
269 let (cell, _) = Node::cells_from_bytes(&bytes, bits.first())?;
270 let unit = cell.as_ref().expect("find_key(): left-unit");
271 let n = Bits::len_common_bits(&unit.bits, &bits);
272 match n {
273 n if n == bits.len() => Ok(Some(slice_to_hash(unit.hash))),
274 n if n == unit.bits.len() => self.find_key(unit.hash, bits.shift(n, false)),
275 _ => Ok(None),
276 }
277 }
278
279 pub fn remove(&mut self, root: Option<&Hash>, key: &[u8]) -> GenericResult<Option<Hash>> {
281 match root {
282 None => Ok(None),
283 Some(root) => self.delete_key(root, Bits::new(key)),
284 }
285 }
286
287 fn delete_key(&mut self, root: &[u8], bits: Bits) -> GenericResult<Option<Hash>> {
288 let bytes = self.db.get(root)?.expect("bytes");
289 let (lc, rc) = Node::cells_from_bytes(&bytes, bits.first())?;
290 let unit = lc.as_ref().expect("delete_key(): left-unit");
291 let n = Bits::len_common_bits(&unit.bits, &bits);
292
293 match n {
294 n if n == bits.len() => match rc {
295 Some(_) => self.put_node(Node::new(None, rc)),
296 None => Ok(None),
297 },
298 n if n == unit.bits.len() => {
299 let hash = self.delete_key(unit.hash, bits.shift(n, false))?;
300 match (hash, &rc) {
301 (None, None) => Ok(None),
302 (None, Some(_)) => self.put_node(Node::new(None, rc)),
303 (Some(ref hash), _) => {
304 let unit = unit.to_owned();
305 let lc = Some(Unit { hash, ..unit });
306 self.put_node(Node::new(lc, rc))
307 }
308 }
309 }
310 _ => Ok(None),
311 }
312 }
313
314 pub fn inserts(
317 &mut self,
318 root: Option<&Hash>,
319 keys: &[Hash],
320 leaves: &[Hash],
321 ) -> GenericResult<Option<Hash>> {
322 let indices = get_sorted_indices(keys, false);
323 self.prepare();
324
325 let mut root = root.cloned();
326 for i in indices.iter() {
327 root = self.insert(root.as_ref(), &keys[*i], &leaves[*i])?;
328 }
329
330 self.commit();
331 Ok(root)
332 }
333
334 pub fn gets(&mut self, root: Option<&Hash>, keys: &[Hash]) -> GenericResult<Vec<Option<Hash>>> {
336 let mut leaves: Vec<Option<Hash>> = vec![];
337 for key in keys.iter() {
338 leaves.push(self.get(root, key)?);
339 }
340 Ok(leaves)
341 }
342
343 pub fn removes(&mut self, root: Option<&Hash>, keys: &[Hash]) -> GenericResult<Option<Hash>> {
346 let indices = get_sorted_indices(keys, false);
347 let mut root = root.cloned();
348 self.prepare();
349
350 for i in indices.iter() {
351 root = self.remove(root.as_ref(), &keys[*i])?;
352 }
353
354 self.commit();
355 Ok(root)
356 }
357
358 pub fn get_merkle_proof(
360 &mut self,
361 root: Option<&Hash>,
362 key: &[u8],
363 ) -> GenericResult<Option<Proof>> {
364 let mut proof: Proof = vec![];
365 match root {
366 None => Ok(None),
367 Some(root) => self.gen_proof(root, Bits::new(key), &mut proof),
368 }
369 }
370
371 fn gen_proof(
372 &mut self,
373 root: &[u8],
374 bits: Bits,
375 proof: &mut Proof,
376 ) -> GenericResult<Option<Proof>> {
377 let bytes = self.db.get(root)?.expect("bytes");
378 let (cell, _) = Node::cells_from_bytes(&bytes, bits.first())?;
379 let unit = cell.as_ref().expect("gen_proof(): left-unit");
380 let n = Bits::len_common_bits(&unit.bits, &bits);
381
382 match n {
383 n if n == bits.len() => {
384 proof.push(self.encode_proof(&bytes, bits.first())?);
385 Ok(Some(proof.to_owned()))
386 }
387 n if n == unit.bits.len() => {
388 proof.push(self.encode_proof(&bytes, bits.first())?);
389 self.gen_proof(unit.hash, bits.shift(n, false), proof)
390 }
391 _ => Ok(None),
392 }
393 }
394
395 fn encode_proof(&self, bytes: &[u8], right: bool) -> GenericResult<(bool, Vec<u8>)> {
396 match Node::from_bytes(bytes)? {
397 Node::Soft(_) => Ok((false, bytes[HASH_LEN..].to_vec())),
398 Node::Hard(_, _) => {
399 if right {
400 Ok((true, [&bytes[..bytes.len() - HASH_LEN - 1], &[0x01]].concat()))
401 } else {
402 Ok((false, bytes[HASH_LEN..].to_vec()))
403 }
404 }
405 }
406 }
407}
408
409pub fn verify_proof(root: Option<&Hash>, leaf: &Hash, proof: Option<&Proof>) -> bool {
411 match proof {
412 None => false,
413 Some(proof) => {
414 let mut hash = leaf.to_owned();
415 proof.iter().rev().for_each(|(right, cut)| {
416 if *right {
417 let l = cut.len();
418 let o = [&cut[..l - 1], &hash[..], &cut[l - 1..]].concat();
419 hash = Monotree::hash_digest(&o);
420 } else {
421 let o = [&hash[..], &cut[..]].concat();
422 hash = Monotree::hash_digest(&o);
423 }
424 });
425 root.expect("verify_proof(): root") == &hash
426 }
427 }
428}