darkfi_sdk/crypto/
merkle_node.rs1use core::{fmt, str::FromStr};
20use std::{io, iter};
21
22use bridgetree::{BridgeTree, Hashable, Level};
23use darkfi_serial::{SerialDecodable, SerialEncodable};
24use halo2_gadgets::sinsemilla::primitives::HashDomain;
25use lazy_static::lazy_static;
26use pasta_curves::{
27 group::ff::{PrimeField, PrimeFieldBits},
28 pallas,
29};
30use subtle::{Choice, ConditionallySelectable};
31
32#[cfg(feature = "async")]
33use darkfi_serial::async_trait;
34
35use crate::crypto::{
36 constants::{
37 sinsemilla::{i2lebsp_k, L_ORCHARD_MERKLE, MERKLE_CRH_PERSONALIZATION},
38 MERKLE_DEPTH,
39 },
40 util::FieldElemAsStr,
41};
42
43pub type MerkleTree = BridgeTree<MerkleNode, usize, { MERKLE_DEPTH }>;
44
45lazy_static! {
46 static ref UNCOMMITTED_ORCHARD: pallas::Base = pallas::Base::from(2);
47 static ref EMPTY_ROOTS: Vec<MerkleNode> = {
48 iter::empty()
49 .chain(Some(MerkleNode::empty_leaf()))
50 .chain((0..MERKLE_DEPTH).scan(MerkleNode::empty_leaf(), |state, l| {
51 *state = MerkleNode::combine(l.into(), state, state);
52 Some(*state)
53 }))
54 .collect()
55 };
56}
57
/// A node of the incremental Merkle tree: a newtype wrapping a single
/// `pallas::Base` field element.
// NOTE(review): `repr(C)` pins the layout of the wrapper, presumably for
// FFI/WASM interop — confirm before changing it.
#[repr(C)]
#[derive(Debug, Clone, Copy, Ord, PartialOrd, Eq, PartialEq, SerialEncodable, SerialDecodable)]
pub struct MerkleNode(pallas::Base);
62
63impl MerkleNode {
64 pub fn new(v: pallas::Base) -> Self {
65 Self(v)
66 }
67
68 pub fn inner(&self) -> pallas::Base {
70 self.0
71 }
72
73 pub fn from_bytes(bytes: [u8; 32]) -> Option<Self> {
76 let n = pallas::Base::from_repr(bytes);
77 match bool::from(n.is_some()) {
78 true => Some(Self(n.unwrap())),
79 false => None,
80 }
81 }
82
83 pub fn to_bytes(&self) -> [u8; 32] {
85 self.0.to_repr()
86 }
87}
88
89impl From<pallas::Base> for MerkleNode {
90 fn from(x: pallas::Base) -> Self {
91 Self(x)
92 }
93}
94
95impl fmt::Display for MerkleNode {
96 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
97 write!(f, "{}", self.0.to_string())
98 }
99}
100
101impl FromStr for MerkleNode {
102 type Err = io::Error;
103
104 fn from_str(s: &str) -> Result<Self, Self::Err> {
107 let bytes = match bs58::decode(s).into_vec() {
108 Ok(v) => v,
109 Err(e) => return Err(io::Error::other(e)),
110 };
111
112 if bytes.len() != 32 {
113 return Err(io::Error::other("Length of decoded bytes is not 32"))
114 }
115
116 if let Some(merkle_node) = Self::from_bytes(bytes.try_into().unwrap()) {
117 return Ok(merkle_node)
118 }
119
120 Err(io::Error::other("Invalid bytes for MerkleNode"))
121 }
122}
123
124impl ConditionallySelectable for MerkleNode {
125 fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
126 Self(pallas::Base::conditional_select(&a.0, &b.0, choice))
127 }
128}
129
130impl Hashable for MerkleNode {
131 fn empty_leaf() -> Self {
132 Self(*UNCOMMITTED_ORCHARD)
133 }
134
135 fn combine(altitude: Level, left: &Self, right: &Self) -> Self {
147 let domain = HashDomain::new(MERKLE_CRH_PERSONALIZATION);
149
150 Self(
151 domain
152 .hash(
153 iter::empty()
154 .chain(i2lebsp_k(altitude.into()).iter().copied())
155 .chain(left.inner().to_le_bits().iter().by_vals().take(L_ORCHARD_MERKLE))
156 .chain(right.inner().to_le_bits().iter().by_vals().take(L_ORCHARD_MERKLE)),
157 )
158 .unwrap_or(pallas::Base::zero()),
159 )
160 }
161
162 fn empty_root(altitude: Level) -> Self {
163 EMPTY_ROOTS[<usize>::from(altitude)]
164 }
165}
166
#[cfg(test)]
mod tests {
    use super::*;

    use halo2_proofs::arithmetic::Field;
    use rand::rngs::OsRng;

    /// Appends random leaves and checkpoints after each one, recording the
    /// root seen just before each checkpoint. Rewinding one checkpoint at a
    /// time must then reproduce those roots in reverse order.
    #[test]
    fn bridgetree_checkpoints() {
        const MAX_CHECKPOINTS: usize = 100;

        let mut tree = MerkleTree::new(MAX_CHECKPOINTS);
        let roots: Vec<MerkleNode> = (0..MAX_CHECKPOINTS)
            .map(|checkpoint_id| {
                tree.append(MerkleNode::from(pallas::Base::random(&mut OsRng)));
                let root = tree.root(0).unwrap();
                tree.checkpoint(checkpoint_id);
                root
            })
            .collect();

        for expected in roots.iter().rev() {
            tree.rewind();
            assert_eq!(*expected, tree.root(0).unwrap());
        }
    }
}