1use hashbrown::{HashMap, HashSet};
21
22use super::{
23 bits::Bits,
24 node::{Node, Unit},
25 utils::{get_sorted_indices, slice_to_hash},
26 Hash, Proof, HASH_LEN, ROOT_KEY,
27};
28use crate::GenericResult;
29
30#[derive(Clone, Debug)]
31pub(crate) struct MemCache {
32 pub(crate) set: HashSet<Hash>,
33 pub(crate) map: HashMap<Hash, Vec<u8>>,
34}
35
36#[allow(dead_code)]
37impl MemCache {
38 pub(crate) fn new() -> Self {
39 Self { set: HashSet::new(), map: HashMap::with_capacity(1 << 12) }
40 }
41
42 pub(crate) fn clear(&mut self) {
43 self.set.clear();
44 self.map.clear();
45 }
46
47 pub(crate) fn contains(&self, key: &[u8]) -> bool {
48 !self.set.contains(key) && self.map.contains_key(key)
49 }
50
51 pub(crate) fn get(&self, key: &[u8]) -> Option<Vec<u8>> {
52 self.map.get(key).cloned()
53 }
54
55 pub(crate) fn put(&mut self, key: &[u8], value: Vec<u8>) {
56 self.map.insert(slice_to_hash(key), value);
57 self.set.remove(key);
58 }
59
60 pub(crate) fn delete(&mut self, key: &[u8]) {
61 self.set.insert(slice_to_hash(key));
62 }
63}
64
65#[derive(Clone, Debug)]
66pub struct MemoryDb {
67 db: HashMap<Hash, Vec<u8>>,
68 batch: MemCache,
69 batch_on: bool,
70}
71
72#[allow(dead_code)]
73impl MemoryDb {
74 fn new() -> Self {
75 Self { db: HashMap::new(), batch: MemCache::new(), batch_on: false }
76 }
77
78 fn get(&self, key: &[u8]) -> GenericResult<Option<Vec<u8>>> {
79 if self.batch_on && self.batch.contains(key) {
80 return Ok(self.batch.get(key));
81 }
82
83 match self.db.get(key) {
84 Some(v) => Ok(Some(v.to_owned())),
85 None => Ok(None),
86 }
87 }
88
89 fn put(&mut self, key: &[u8], value: Vec<u8>) -> GenericResult<()> {
90 if self.batch_on {
91 self.batch.put(key, value);
92 } else {
93 self.db.insert(slice_to_hash(key), value);
94 }
95 Ok(())
96 }
97
98 fn delete(&mut self, key: &[u8]) -> GenericResult<()> {
99 if self.batch_on {
100 self.batch.delete(key);
101 } else {
102 self.db.remove(key);
103 }
104 Ok(())
105 }
106
107 fn init_batch(&mut self) -> GenericResult<()> {
108 if !self.batch_on {
109 self.batch.clear();
110 self.batch_on = true;
111 }
112 Ok(())
113 }
114
115 fn finish_batch(&mut self) -> GenericResult<()> {
116 if self.batch_on {
117 for (key, value) in self.batch.map.drain() {
118 self.db.insert(key, value);
119 }
120 for key in self.batch.set.drain() {
121 self.db.remove(&key);
122 }
123 self.batch_on = false;
124 }
125 Ok(())
126 }
127}
128
129#[derive(Clone, Debug)]
131pub struct Monotree {
132 db: MemoryDb,
133}
134
135impl Default for Monotree {
136 fn default() -> Self {
137 Self::new()
138 }
139}
140
141impl Monotree {
142 pub fn new() -> Self {
143 Self { db: MemoryDb::new() }
144 }
145
146 fn hash_digest(bytes: &[u8]) -> Hash {
147 let mut hasher = blake3::Hasher::new();
148 hasher.update(bytes);
149 let hash = hasher.finalize();
150 slice_to_hash(hash.as_bytes())
151 }
152
153 pub fn get_headroot(&self) -> GenericResult<Option<Hash>> {
155 let headroot = self.db.get(ROOT_KEY)?;
156 match headroot {
157 Some(root) => Ok(Some(slice_to_hash(&root))),
158 None => Ok(None),
159 }
160 }
161
162 pub fn set_headroot(&mut self, headroot: Option<&Hash>) {
164 if let Some(root) = headroot {
165 self.db.put(ROOT_KEY, root.to_vec()).expect("set_headroot(): hash");
166 }
167 }
168
169 pub fn prepare(&mut self) {
170 self.db.init_batch().expect("prepare(): failed to initialize batch");
171 }
172
173 pub fn commit(&mut self) {
174 self.db.finish_batch().expect("commit(): failed to initialize batch");
175 }
176
177 pub fn insert(
179 &mut self,
180 root: Option<&Hash>,
181 key: &Hash,
182 leaf: &Hash,
183 ) -> GenericResult<Option<Hash>> {
184 match root {
185 None => {
186 let (hash, bits) = (leaf, Bits::new(key));
187 self.put_node(Node::new(Some(Unit { hash, bits }), None))
188 }
189 Some(root) => self.put(root, Bits::new(key), leaf),
190 }
191 }
192
193 fn put_node(&mut self, node: Node) -> GenericResult<Option<Hash>> {
194 let bytes = node.to_bytes()?;
195 let hash = Self::hash_digest(&bytes);
196 self.db.put(&hash, bytes)?;
197 Ok(Some(hash))
198 }
199
200 fn put(&mut self, root: &[u8], bits: Bits, leaf: &[u8]) -> GenericResult<Option<Hash>> {
226 let bytes = self.db.get(root)?.expect("put(): bytes");
227 let (left, right) = Node::cells_from_bytes(&bytes, bits.first())?;
228 let unit = left.as_ref().expect("put(): left-unit");
229 let n = Bits::len_common_bits(&unit.bits, &bits);
230
231 match n {
232 0 => self.put_node(Node::new(left, Some(Unit { hash: leaf, bits }))),
233 n if n == bits.len() => {
234 self.put_node(Node::new(Some(Unit { hash: leaf, bits }), right))
235 }
236 n if n == unit.bits.len() => {
237 let hash =
238 &self.put(unit.hash, bits.drop(n), leaf)?.expect("put(): consume & pass-over");
239
240 self.put_node(Node::new(Some(Unit { hash, bits: unit.bits.to_owned() }), right))
241 }
242 _ => {
243 let hash = &self
244 .put_node(Node::new(
245 Some(Unit { hash: unit.hash, bits: unit.bits.drop(n) }),
246 Some(Unit { hash: leaf, bits: bits.drop(n) }),
247 ))?
248 .expect("put(): split-node");
249
250 self.put_node(Node::new(Some(Unit { hash, bits: unit.bits.take(n) }), right))
251 }
252 }
253 }
254
255 pub fn get(&mut self, root: Option<&Hash>, key: &Hash) -> GenericResult<Option<Hash>> {
257 match root {
258 None => Ok(None),
259 Some(root) => self.find_key(root, Bits::new(key)),
260 }
261 }
262
263 fn find_key(&mut self, root: &[u8], bits: Bits) -> GenericResult<Option<Hash>> {
264 let bytes = self.db.get(root)?.expect("find_key(): bytes");
265 let (cell, _) = Node::cells_from_bytes(&bytes, bits.first())?;
266 let unit = cell.as_ref().expect("find_key(): left-unit");
267 let n = Bits::len_common_bits(&unit.bits, &bits);
268 match n {
269 n if n == bits.len() => Ok(Some(slice_to_hash(unit.hash))),
270 n if n == unit.bits.len() => self.find_key(unit.hash, bits.drop(n)),
271 _ => Ok(None),
272 }
273 }
274
275 pub fn remove(&mut self, root: Option<&Hash>, key: &[u8]) -> GenericResult<Option<Hash>> {
277 match root {
278 None => Ok(None),
279 Some(root) => self.delete_key(root, Bits::new(key)),
280 }
281 }
282
283 fn delete_key(&mut self, root: &[u8], bits: Bits) -> GenericResult<Option<Hash>> {
284 let bytes = self.db.get(root)?.expect("delete_key(): bytes");
285 let (left, right) = Node::cells_from_bytes(&bytes, bits.first())?;
286 let unit = left.as_ref().expect("delete_key(): left-unit");
287 let n = Bits::len_common_bits(&unit.bits, &bits);
288
289 match n {
290 n if n == bits.len() => match right {
291 Some(_) => self.put_node(Node::new(None, right)),
292 None => Ok(None),
293 },
294 n if n == unit.bits.len() => {
295 let hash = self.delete_key(unit.hash, bits.drop(n))?;
296 match (hash, &right) {
297 (None, None) => Ok(None),
298 (None, Some(_)) => self.put_node(Node::new(None, right)),
299 (Some(ref hash), _) => {
300 let unit = unit.to_owned();
301 let left = Some(Unit { hash, ..unit });
302 self.put_node(Node::new(left, right))
303 }
304 }
305 }
306 _ => Ok(None),
307 }
308 }
309
310 pub fn inserts(
313 &mut self,
314 root: Option<&Hash>,
315 keys: &[Hash],
316 leaves: &[Hash],
317 ) -> GenericResult<Option<Hash>> {
318 let indices = get_sorted_indices(keys, false);
319 self.prepare();
320
321 let mut root = root.cloned();
322 for i in indices.iter() {
323 root = self.insert(root.as_ref(), &keys[*i], &leaves[*i])?;
324 }
325
326 self.commit();
327 Ok(root)
328 }
329
330 pub fn gets(&mut self, root: Option<&Hash>, keys: &[Hash]) -> GenericResult<Vec<Option<Hash>>> {
332 let mut leaves: Vec<Option<Hash>> = vec![];
333 for key in keys.iter() {
334 leaves.push(self.get(root, key)?);
335 }
336 Ok(leaves)
337 }
338
339 pub fn removes(&mut self, root: Option<&Hash>, keys: &[Hash]) -> GenericResult<Option<Hash>> {
342 let indices = get_sorted_indices(keys, false);
343 let mut root = root.cloned();
344 self.prepare();
345
346 for i in indices.iter() {
347 root = self.remove(root.as_ref(), &keys[*i])?;
348 }
349
350 self.commit();
351 Ok(root)
352 }
353
354 pub fn get_merkle_proof(
356 &mut self,
357 root: Option<&Hash>,
358 key: &[u8],
359 ) -> GenericResult<Option<Proof>> {
360 let mut proof: Proof = vec![];
361 match root {
362 None => Ok(None),
363 Some(root) => self.gen_proof(root, Bits::new(key), &mut proof),
364 }
365 }
366
367 fn gen_proof(
368 &mut self,
369 root: &[u8],
370 bits: Bits,
371 proof: &mut Proof,
372 ) -> GenericResult<Option<Proof>> {
373 let bytes = self.db.get(root)?.expect("gen_proof(): bytes");
374 let (cell, _) = Node::cells_from_bytes(&bytes, bits.first())?;
375 let unit = cell.as_ref().expect("gen_proof(): left-unit");
376 let n = Bits::len_common_bits(&unit.bits, &bits);
377
378 match n {
379 n if n == bits.len() => {
380 proof.push(self.encode_proof(&bytes, bits.first())?);
381 Ok(Some(proof.to_owned()))
382 }
383 n if n == unit.bits.len() => {
384 proof.push(self.encode_proof(&bytes, bits.first())?);
385 self.gen_proof(unit.hash, bits.drop(n), proof)
386 }
387 _ => Ok(None),
388 }
389 }
390
391 fn encode_proof(&self, bytes: &[u8], right: bool) -> GenericResult<(bool, Vec<u8>)> {
392 match Node::from_bytes(bytes)? {
393 Node::Soft(_) => Ok((false, bytes[HASH_LEN..].to_vec())),
394 Node::Hard(_, _) => {
395 if right {
396 Ok((true, [&bytes[..bytes.len() - HASH_LEN - 1], &[0x01]].concat()))
397 } else {
398 Ok((false, bytes[HASH_LEN..].to_vec()))
399 }
400 }
401 }
402 }
403}
404
405pub fn verify_proof(root: Option<&Hash>, leaf: &Hash, proof: Option<&Proof>) -> bool {
407 match proof {
408 None => false,
409 Some(proof) => {
410 let mut hash = leaf.to_owned();
411 proof.iter().rev().for_each(|(right, cut)| {
412 if *right {
413 let l = cut.len();
414 let o = [&cut[..l - 1], &hash[..], &cut[l - 1..]].concat();
415 hash = Monotree::hash_digest(&o);
416 } else {
417 let o = [&hash[..], &cut[..]].concat();
418 hash = Monotree::hash_digest(&o);
419 }
420 });
421 root.expect("verify_proof(): root") == &hash
422 }
423 }
424}