darkfi_sdk/crypto/merkle_node.rs

1/* This file is part of DarkFi (https://dark.fi)
2 *
3 * Copyright (C) 2020-2025 Dyne.org foundation
4 *
5 * This program is free software: you can redistribute it and/or modify
6 * it under the terms of the GNU Affero General Public License as
7 * published by the Free Software Foundation, either version 3 of the
8 * License, or (at your option) any later version.
9 *
10 * This program is distributed in the hope that it will be useful,
11 * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 * GNU Affero General Public License for more details.
14 *
15 * You should have received a copy of the GNU Affero General Public License
16 * along with this program.  If not, see <https://www.gnu.org/licenses/>.
17 */
18
19use core::{fmt, str::FromStr};
20use std::{io, iter};
21
22use bridgetree::{BridgeTree, Hashable, Level};
23use darkfi_serial::{SerialDecodable, SerialEncodable};
24use halo2_gadgets::sinsemilla::primitives::HashDomain;
25use lazy_static::lazy_static;
26use pasta_curves::{
27    group::ff::{PrimeField, PrimeFieldBits},
28    pallas,
29};
30use subtle::{Choice, ConditionallySelectable};
31
32#[cfg(feature = "async")]
33use darkfi_serial::async_trait;
34
35use crate::crypto::{
36    constants::{
37        sinsemilla::{i2lebsp_k, L_ORCHARD_MERKLE, MERKLE_CRH_PERSONALIZATION},
38        MERKLE_DEPTH,
39    },
40    util::FieldElemAsStr,
41};
42
43pub type MerkleTree = BridgeTree<MerkleNode, usize, { MERKLE_DEPTH }>;
44
45lazy_static! {
46    static ref UNCOMMITTED_ORCHARD: pallas::Base = pallas::Base::from(2);
47    static ref EMPTY_ROOTS: Vec<MerkleNode> = {
48        iter::empty()
49            .chain(Some(MerkleNode::empty_leaf()))
50            .chain((0..MERKLE_DEPTH).scan(MerkleNode::empty_leaf(), |state, l| {
51                *state = MerkleNode::combine(l.into(), state, state);
52                Some(*state)
53            }))
54            .collect()
55    };
56}
57
58/// The `MerkleNode` is represented as a base field element.
59#[repr(C)]
60#[derive(Debug, Clone, Copy, Ord, PartialOrd, Eq, PartialEq, SerialEncodable, SerialDecodable)]
61pub struct MerkleNode(pallas::Base);
62
63impl MerkleNode {
64    pub fn new(v: pallas::Base) -> Self {
65        Self(v)
66    }
67
68    /// Reference the raw inner base field element
69    pub fn inner(&self) -> pallas::Base {
70        self.0
71    }
72
73    /// Try to create a `MerkleNode` type from the given 32 bytes.
74    /// Returns `Some` if the bytes fit in the base field, and `None` if not.
75    pub fn from_bytes(bytes: [u8; 32]) -> Option<Self> {
76        let n = pallas::Base::from_repr(bytes);
77        match bool::from(n.is_some()) {
78            true => Some(Self(n.unwrap())),
79            false => None,
80        }
81    }
82
83    /// Convert the `MerkleNode` type into 32 raw bytes
84    pub fn to_bytes(&self) -> [u8; 32] {
85        self.0.to_repr()
86    }
87}
88
89impl From<pallas::Base> for MerkleNode {
90    fn from(x: pallas::Base) -> Self {
91        Self(x)
92    }
93}
94
95impl fmt::Display for MerkleNode {
96    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
97        write!(f, "{}", self.0.to_string())
98    }
99}
100
101impl FromStr for MerkleNode {
102    type Err = io::Error;
103
104    /// Tries to decode a base58 string into a `MerkleNode` type.
105    /// This string is the same string received by calling `MerkleNode::to_string()`.
106    fn from_str(s: &str) -> Result<Self, Self::Err> {
107        let bytes = match bs58::decode(s).into_vec() {
108            Ok(v) => v,
109            Err(e) => return Err(io::Error::other(e)),
110        };
111
112        if bytes.len() != 32 {
113            return Err(io::Error::other("Length of decoded bytes is not 32"))
114        }
115
116        if let Some(merkle_node) = Self::from_bytes(bytes.try_into().unwrap()) {
117            return Ok(merkle_node)
118        }
119
120        Err(io::Error::other("Invalid bytes for MerkleNode"))
121    }
122}
123
124impl ConditionallySelectable for MerkleNode {
125    fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
126        Self(pallas::Base::conditional_select(&a.0, &b.0, choice))
127    }
128}
129
130impl Hashable for MerkleNode {
131    fn empty_leaf() -> Self {
132        Self(*UNCOMMITTED_ORCHARD)
133    }
134
135    /// Implements `MerkleCRH^Orchard` as defined in
136    /// <https://zips.z.cash/protocol/protocol.pdf#orchardmerklecrh>
137    ///
138    /// The layer with 2^n nodes is called "layer n":
139    ///     - leaves are at layer MERKLE_DEPTH_ORCHARD = 32;
140    ///     - the root is at layer 0.
141    /// `l` is MERKLE_DEPTH_ORCHARD - layer - 1.
142    ///     - when hashing two leaves, we produce a node on the layer
143    ///       above the leaves, i.e. layer = 31, l = 0
144    ///     - when hashing to the final root, we produce the anchor
145    ///       with layer = 0, l = 31.
146    fn combine(altitude: Level, left: &Self, right: &Self) -> Self {
147        // MerkleCRH Sinsemilla hash domain.
148        let domain = HashDomain::new(MERKLE_CRH_PERSONALIZATION);
149
150        Self(
151            domain
152                .hash(
153                    iter::empty()
154                        .chain(i2lebsp_k(altitude.into()).iter().copied())
155                        .chain(left.inner().to_le_bits().iter().by_vals().take(L_ORCHARD_MERKLE))
156                        .chain(right.inner().to_le_bits().iter().by_vals().take(L_ORCHARD_MERKLE)),
157                )
158                .unwrap_or(pallas::Base::zero()),
159        )
160    }
161
162    fn empty_root(altitude: Level) -> Self {
163        EMPTY_ROOTS[<usize>::from(altitude)]
164    }
165}
166
#[cfg(test)]
mod tests {
    use super::*;

    use halo2_proofs::arithmetic::Field;
    use rand::rngs::OsRng;

    /// Records the root after each append-then-checkpoint step. It then
    /// checks that successive rewinds reproduce those roots in reverse order.
    #[test]
    fn bridgetree_checkpoints() {
        const MAX_CHECKPOINTS: usize = 100;
        let mut tree = MerkleTree::new(MAX_CHECKPOINTS);

        let roots: Vec<MerkleNode> = (0..MAX_CHECKPOINTS)
            .map(|checkpoint_id| {
                tree.append(MerkleNode::from(pallas::Base::random(&mut OsRng)));
                let current_root = tree.root(0).unwrap();
                tree.checkpoint(checkpoint_id);
                current_root
            })
            .collect();

        for expected in roots.into_iter().rev() {
            tree.rewind();
            assert_eq!(expected, tree.root(0).unwrap());
        }
    }
}