drk/
walletdb.rs

1/* This file is part of DarkFi (https://dark.fi)
2 *
3 * Copyright (C) 2020-2025 Dyne.org foundation
4 *
5 * This program is free software: you can redistribute it and/or modify
6 * it under the terms of the GNU Affero General Public License as
7 * published by the Free Software Foundation, either version 3 of the
8 * License, or (at your option) any later version.
9 *
10 * This program is distributed in the hope that it will be useful,
11 * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 * GNU Affero General Public License for more details.
14 *
15 * You should have received a copy of the GNU Affero General Public License
16 * along with this program.  If not, see <https://www.gnu.org/licenses/>.
17 */
18
19use std::{
20    path::PathBuf,
21    sync::{Arc, Mutex},
22};
23
24use darkfi_sdk::{
25    crypto::{
26        pasta_prelude::PrimeField,
27        smt::{PoseidonFp, SparseMerkleTree, StorageAdapter, SMT_FP_DEPTH},
28    },
29    error::{ContractError, ContractResult},
30    pasta::pallas,
31};
32use log::{debug, error};
33use num_bigint::BigUint;
34use rusqlite::{
35    types::{ToSql, Value},
36    Connection,
37};
38
39use crate::error::{WalletDbError, WalletDbResult};
40
41pub type WalletPtr = Arc<WalletDb>;
42
43/// Structure representing base wallet database operations.
44pub struct WalletDb {
45    /// Connection to the SQLite database.
46    pub conn: Mutex<Connection>,
47    /// Inverse queries cache, in case we want to rollback
48    /// executed queries, stored as raw SQL strings.
49    inverse_cache: Mutex<Vec<String>>,
50}
51
52impl WalletDb {
53    /// Create a new wallet database handler. If `path` is `None`, create it in memory.
54    pub fn new(path: Option<PathBuf>, password: Option<&str>) -> WalletDbResult<WalletPtr> {
55        let Ok(conn) = (match path.clone() {
56            Some(p) => Connection::open(p),
57            None => Connection::open_in_memory(),
58        }) else {
59            return Err(WalletDbError::ConnectionFailed);
60        };
61
62        if let Some(password) = password {
63            if let Err(e) = conn.pragma_update(None, "key", password) {
64                error!(target: "walletdb::new", "[WalletDb] Pragma update failed: {e}");
65                return Err(WalletDbError::PragmaUpdateError);
66            };
67        }
68        if let Err(e) = conn.pragma_update(None, "foreign_keys", "ON") {
69            error!(target: "walletdb::new", "[WalletDb] Pragma update failed: {e}");
70            return Err(WalletDbError::PragmaUpdateError);
71        };
72
73        debug!(target: "walletdb::new", "[WalletDb] Opened Sqlite connection at \"{path:?}\"");
74        Ok(Arc::new(Self { conn: Mutex::new(conn), inverse_cache: Mutex::new(vec![]) }))
75    }
76
77    /// This function executes a given SQL query that contains multiple SQL statements,
78    /// that don't contain any parameters.
79    pub fn exec_batch_sql(&self, query: &str) -> WalletDbResult<()> {
80        debug!(target: "walletdb::exec_batch_sql", "[WalletDb] Executing batch SQL query:\n{query}");
81        let Ok(conn) = self.conn.lock() else { return Err(WalletDbError::FailedToAquireLock) };
82        if let Err(e) = conn.execute_batch(query) {
83            error!(target: "walletdb::exec_batch_sql", "[WalletDb] Query failed: {e}");
84            return Err(WalletDbError::QueryExecutionFailed)
85        };
86
87        Ok(())
88    }
89
90    /// This function executes a given SQL query, but isn't able to return anything.
91    /// Therefore it's best to use it for initializing a table or similar things.
92    pub fn exec_sql(&self, query: &str, params: &[&dyn ToSql]) -> WalletDbResult<()> {
93        debug!(target: "walletdb::exec_sql", "[WalletDb] Executing SQL query:\n{query}");
94        let Ok(conn) = self.conn.lock() else { return Err(WalletDbError::FailedToAquireLock) };
95
96        // If no params are provided, execute directly
97        if params.is_empty() {
98            if let Err(e) = conn.execute(query, ()) {
99                error!(target: "walletdb::exec_sql", "[WalletDb] Query failed: {e}");
100                return Err(WalletDbError::QueryExecutionFailed)
101            };
102            return Ok(())
103        }
104
105        // First we prepare the query
106        let Ok(mut stmt) = conn.prepare(query) else {
107            return Err(WalletDbError::QueryPreparationFailed)
108        };
109
110        // Execute the query using provided params
111        if let Err(e) = stmt.execute(params) {
112            error!(target: "walletdb::exec_sql", "[WalletDb] Query failed: {e}");
113            return Err(WalletDbError::QueryExecutionFailed)
114        };
115
116        // Finalize query and drop connection lock
117        if let Err(e) = stmt.finalize() {
118            error!(target: "walletdb::exec_sql", "[WalletDb] Query finalization failed: {e}");
119            return Err(WalletDbError::QueryFinalizationFailed)
120        };
121        drop(conn);
122
123        Ok(())
124    }
125
126    /// Generate a new statement for provided query and bind the provided params,
127    /// returning the raw SQL query as a string.
128    pub fn create_prepared_statement(
129        &self,
130        query: &str,
131        params: &[&dyn ToSql],
132    ) -> WalletDbResult<String> {
133        debug!(target: "walletdb::create_prepared_statement", "[WalletDb] Preparing statement for SQL query:\n{query}");
134        let Ok(conn) = self.conn.lock() else { return Err(WalletDbError::FailedToAquireLock) };
135
136        // First we prepare the query
137        let Ok(mut stmt) = conn.prepare(query) else {
138            return Err(WalletDbError::QueryPreparationFailed)
139        };
140
141        // Bind all provided params
142        for (index, param) in params.iter().enumerate() {
143            if stmt.raw_bind_parameter(index + 1, param).is_err() {
144                return Err(WalletDbError::QueryPreparationFailed)
145            };
146        }
147
148        // Grab the raw SQL
149        let query = stmt.expanded_sql().unwrap();
150
151        // Drop statement and the connection lock
152        drop(stmt);
153        drop(conn);
154
155        Ok(query)
156    }
157
158    /// Generate a `SELECT` query for provided table from selected column names and
159    /// provided `WHERE` clauses. Named parameters are supported in the `WHERE` clauses,
160    /// assuming they follow the normal formatting ":{column_name}".
161    fn generate_select_query(
162        &self,
163        table: &str,
164        col_names: &[&str],
165        params: &[(&str, &dyn ToSql)],
166    ) -> String {
167        let mut query = if col_names.is_empty() {
168            format!("SELECT * FROM {}", table)
169        } else {
170            format!("SELECT {} FROM {}", col_names.join(", "), table)
171        };
172        if params.is_empty() {
173            return query
174        }
175
176        let mut where_str = Vec::with_capacity(params.len());
177        for (k, _) in params {
178            let col = &k[1..];
179            where_str.push(format!("{col} = {k}"));
180        }
181        query.push_str(&format!(" WHERE {}", where_str.join(" AND ")));
182
183        query
184    }
185
186    /// Query provided table from selected column names and provided `WHERE` clauses,
187    /// for a single row.
188    pub fn query_single(
189        &self,
190        table: &str,
191        col_names: &[&str],
192        params: &[(&str, &dyn ToSql)],
193    ) -> WalletDbResult<Vec<Value>> {
194        // Generate `SELECT` query
195        let query = self.generate_select_query(table, col_names, params);
196        debug!(target: "walletdb::query_single", "[WalletDb] Executing SQL query:\n{query}");
197
198        // First we prepare the query
199        let Ok(conn) = self.conn.lock() else { return Err(WalletDbError::FailedToAquireLock) };
200
201        let Ok(mut stmt) = conn.prepare(&query) else {
202            return Err(WalletDbError::QueryPreparationFailed)
203        };
204
205        // Execute the query using provided params
206        let Ok(mut rows) = stmt.query(params) else {
207            return Err(WalletDbError::QueryExecutionFailed)
208        };
209
210        // Check if row exists
211        let Ok(next) = rows.next() else { return Err(WalletDbError::QueryExecutionFailed) };
212        let row = match next {
213            Some(row_result) => row_result,
214            None => return Err(WalletDbError::RowNotFound),
215        };
216
217        // Grab returned values
218        let mut result = vec![];
219        if col_names.is_empty() {
220            let mut idx = 0;
221            loop {
222                let Ok(value) = row.get(idx) else { break };
223                result.push(value);
224                idx += 1;
225            }
226        } else {
227            for col in col_names {
228                let Ok(value) = row.get(*col) else {
229                    return Err(WalletDbError::ParseColumnValueError)
230                };
231                result.push(value);
232            }
233        }
234
235        Ok(result)
236    }
237
238    /// Query provided table from selected column names and provided `WHERE` clauses,
239    /// for multiple rows.
240    pub fn query_multiple(
241        &self,
242        table: &str,
243        col_names: &[&str],
244        params: &[(&str, &dyn ToSql)],
245    ) -> WalletDbResult<Vec<Vec<Value>>> {
246        // Generate `SELECT` query
247        let query = self.generate_select_query(table, col_names, params);
248        debug!(target: "walletdb::query_multiple", "[WalletDb] Executing SQL query:\n{query}");
249
250        // First we prepare the query
251        let Ok(conn) = self.conn.lock() else { return Err(WalletDbError::FailedToAquireLock) };
252        let Ok(mut stmt) = conn.prepare(&query) else {
253            return Err(WalletDbError::QueryPreparationFailed)
254        };
255
256        // Execute the query using provided converted params
257        let Ok(mut rows) = stmt.query(params) else {
258            return Err(WalletDbError::QueryExecutionFailed)
259        };
260
261        // Loop over returned rows and parse them
262        let mut result = vec![];
263        loop {
264            // Check if an error occured
265            let row = match rows.next() {
266                Ok(r) => r,
267                Err(_) => return Err(WalletDbError::QueryExecutionFailed),
268            };
269
270            // Check if no row was returned
271            let row = match row {
272                Some(r) => r,
273                None => break,
274            };
275
276            // Grab row returned values
277            let mut row_values = vec![];
278            if col_names.is_empty() {
279                let mut idx = 0;
280                loop {
281                    let Ok(value) = row.get(idx) else { break };
282                    row_values.push(value);
283                    idx += 1;
284                }
285            } else {
286                for col in col_names {
287                    let Ok(value) = row.get(*col) else {
288                        return Err(WalletDbError::ParseColumnValueError)
289                    };
290                    row_values.push(value);
291                }
292            }
293            result.push(row_values);
294        }
295
296        Ok(result)
297    }
298
299    /// Query provided table using provided query for multiple rows.
300    pub fn query_custom(
301        &self,
302        query: &str,
303        params: &[&dyn ToSql],
304    ) -> WalletDbResult<Vec<Vec<Value>>> {
305        debug!(target: "walletdb::query_custom", "[WalletDb] Executing SQL query:\n{query}");
306
307        // First we prepare the query
308        let Ok(conn) = self.conn.lock() else { return Err(WalletDbError::FailedToAquireLock) };
309        let Ok(mut stmt) = conn.prepare(query) else {
310            return Err(WalletDbError::QueryPreparationFailed)
311        };
312
313        // Execute the query using provided converted params
314        let Ok(mut rows) = stmt.query(params) else {
315            return Err(WalletDbError::QueryExecutionFailed)
316        };
317
318        // Loop over returned rows and parse them
319        let mut result = vec![];
320        loop {
321            // Check if an error occured
322            let row = match rows.next() {
323                Ok(r) => r,
324                Err(_) => return Err(WalletDbError::QueryExecutionFailed),
325            };
326
327            // Check if no row was returned
328            let row = match row {
329                Some(r) => r,
330                None => break,
331            };
332
333            // Grab row returned values
334            let mut row_values = vec![];
335            let mut idx = 0;
336            loop {
337                let Ok(value) = row.get(idx) else { break };
338                row_values.push(value);
339                idx += 1;
340            }
341            result.push(row_values);
342        }
343
344        Ok(result)
345    }
346
347    /// Auxiliary function to store provided inverse query into our cache.
348    pub fn cache_inverse(&self, query: String) -> WalletDbResult<()> {
349        debug!(target: "walletdb::cache_inverse", "[WalletDb] Storing query:\n{query}");
350        let Ok(mut cache) = self.inverse_cache.lock() else {
351            return Err(WalletDbError::FailedToAquireLock)
352        };
353
354        // Push the query into the cache
355        cache.push(query);
356
357        // Drop cache lock
358        drop(cache);
359
360        Ok(())
361    }
362
363    /// Auxiliary function to retrieve cached inverse queries into a single SQL execution block.
364    /// The final query will contain the queries in reverse order, and cache is cleared afterwards.
365    pub fn grab_inverse_cache_block(&self) -> WalletDbResult<String> {
366        // Grab cache lock
367        debug!(target: "walletdb::grab_inverse_block", "[WalletDb] Grabbing cached inverse queries");
368        let Ok(cache) = self.inverse_cache.lock() else {
369            return Err(WalletDbError::FailedToAquireLock)
370        };
371
372        // Build the full SQL block query
373        let mut inverse_batch = String::from("BEGIN;");
374        for query in cache.iter().rev() {
375            inverse_batch += query;
376        }
377        inverse_batch += "END;";
378
379        // Drop the lock
380        drop(cache);
381
382        Ok(inverse_batch)
383    }
384
385    /// Auxiliary function to clear inverse queries cache.
386    pub fn clear_inverse_cache(&self) -> WalletDbResult<()> {
387        // Grab cache lock
388        let Ok(mut cache) = self.inverse_cache.lock() else {
389            return Err(WalletDbError::FailedToAquireLock)
390        };
391
392        // Clear cache
393        *cache = vec![];
394
395        // Drop the lock
396        drop(cache);
397
398        Ok(())
399    }
400}
401
/// Build a named-parameter slice for [`WalletDb::query_single`] /
/// [`WalletDb::query_multiple`], similar to `rusqlite::named_params!`.
///
/// Unlike the rusqlite macro, each parameter name is an arbitrary expression
/// (e.g. a runtime column name) and the `:` named-parameter prefix is added
/// automatically, so `(col, val)` becomes `(":col", &val)`.
///
/// The prefixed names are `format!` temporaries, so the produced slice must be
/// consumed within the same statement (e.g. passed directly as an argument).
#[macro_export]
macro_rules! convert_named_params {
    () => {
        &[] as &[(&str, &dyn rusqlite::types::ToSql)]
    };
    ($(($name:expr, $value:expr)),+ $(,)?) => {
        &[
            $((
                format!(":{}", $name).as_str(),
                &$value as &dyn rusqlite::types::ToSql
            )),+
        ] as &[(&str, &dyn rusqlite::types::ToSql)]
    };
}
413
414/// Wallet SMT definition
415pub type WalletSmt<'a> = SparseMerkleTree<
416    'static,
417    SMT_FP_DEPTH,
418    { SMT_FP_DEPTH + 1 },
419    pallas::Base,
420    PoseidonFp,
421    WalletStorage<'a>,
422>;
423
424/// An SMT adapter for wallet SQLite database storage.
425pub struct WalletStorage<'a> {
426    wallet: &'a WalletPtr,
427    table: &'a str,
428    key_col: &'a str,
429    value_col: &'a str,
430}
431
432impl<'a> WalletStorage<'a> {
433    pub fn new(
434        wallet: &'a WalletPtr,
435        table: &'a str,
436        key_col: &'a str,
437        value_col: &'a str,
438    ) -> Self {
439        Self { wallet, table, key_col, value_col }
440    }
441}
442
443impl StorageAdapter for WalletStorage<'_> {
444    type Value = pallas::Base;
445
446    fn put(&mut self, key: BigUint, value: pallas::Base) -> ContractResult {
447        // Check if record already exists to create the corresponding query,
448        // its param and its inverse.
449        let (query, params, inverse) = match self.get(&key) {
450            Some(v) => {
451                // Create an SQL `UPDATE` query
452                let q = format!(
453                    "UPDATE {} SET {} = ?1 WHERE {} = ?2;",
454                    self.table, self.value_col, self.key_col
455                );
456
457                // Create its inverse query
458                let i = match self.wallet.create_prepared_statement(
459                    &format!(
460                        "UPDATE {} SET {} = ?1 WHERE {} = ?2;",
461                        self.table, self.value_col, self.key_col
462                    ),
463                    rusqlite::params![v.to_repr(), key.to_bytes_le()],
464                ) {
465                    Ok(i) => i,
466                    Err(e) => {
467                        error!(target: "walletdb::StorageAdapter::put", "Creating inverse query for key {key:?} failed: {e:?}");
468                        return Err(ContractError::SmtPutFailed)
469                    }
470                };
471
472                (q, rusqlite::params![value.to_repr(), key.to_bytes_le()], i)
473            }
474            None => {
475                // Create an SQL `INSERT` query
476                let q = format!(
477                    "INSERT INTO {} ({}, {}) VALUES (?1, ?2);",
478                    self.table, self.key_col, self.value_col
479                );
480
481                // Create its inverse query
482                let i = match self.wallet.create_prepared_statement(
483                    &format!("DELETE FROM {} WHERE {} = ?1;", self.table, self.key_col),
484                    rusqlite::params![key.to_bytes_le()],
485                ) {
486                    Ok(i) => i,
487                    Err(e) => {
488                        error!(target: "walletdb::StorageAdapter::put", "Creating inverse query for key {key:?} failed: {e:?}");
489                        return Err(ContractError::SmtPutFailed)
490                    }
491                };
492
493                (q, rusqlite::params![key.to_bytes_le(), value.to_repr()], i)
494            }
495        };
496
497        // Execute the query
498        if let Err(e) = self.wallet.exec_sql(&query, params) {
499            error!(target: "walletdb::StorageAdapter::put", "Inserting key {key:?}, value {value:?} into DB failed: {e:?}");
500            return Err(ContractError::SmtPutFailed)
501        }
502
503        // Store its inverse
504        if let Err(e) = self.wallet.cache_inverse(inverse) {
505            error!(target: "walletdb::StorageAdapter::put", "Inserting inverse query into cache failed: {e:?}");
506            return Err(ContractError::SmtPutFailed)
507        }
508
509        Ok(())
510    }
511
512    fn get(&self, key: &BigUint) -> Option<pallas::Base> {
513        let row = match self.wallet.query_single(
514            self.table,
515            &[self.value_col],
516            convert_named_params! {(self.key_col, key.to_bytes_le())},
517        ) {
518            Ok(r) => r,
519            Err(WalletDbError::RowNotFound) => return None,
520            Err(e) => {
521                error!(target: "walletdb::StorageAdapter::get", "Fetching key {key:?} from DB failed: {e:?}");
522                return None
523            }
524        };
525
526        let Value::Blob(ref value_bytes) = row[0] else {
527            error!(target: "walletdb::StorageAdapter::get", "Parsing key {key:?} value bytes");
528            return None
529        };
530
531        let mut repr = [0; 32];
532        repr.copy_from_slice(value_bytes);
533
534        pallas::Base::from_repr(repr).into()
535    }
536
537    fn del(&mut self, key: &BigUint) -> ContractResult {
538        // Check if record already exists to create the corresponding query,
539        // its param and its inverse.
540        let (query, params, inverse) = match self.get(key) {
541            Some(value) => {
542                // Create an SQL `DELETE` query
543                let q = format!("DELETE FROM {} WHERE {} = ?1;", self.table, self.key_col);
544
545                // Create its inverse query
546                let i = match self.wallet.create_prepared_statement(
547                    &format!(
548                        "INSERT INTO {} ({}, {}) VALUES (?1, ?2);",
549                        self.table, self.key_col, self.value_col
550                    ),
551                    rusqlite::params![key.to_bytes_le(), value.to_repr()],
552                ) {
553                    Ok(i) => i,
554                    Err(e) => {
555                        error!(target: "walletdb::StorageAdapter::del", "Creating inverse query for key {key:?} failed: {e:?}");
556                        return Err(ContractError::SmtDelFailed)
557                    }
558                };
559
560                (q, rusqlite::params![key.to_bytes_le()], i)
561            }
562            None => {
563                // If record doesn't exist do nothing
564                return Ok(())
565            }
566        };
567
568        // Execute the query
569        if let Err(e) = self.wallet.exec_sql(&query, params) {
570            error!(target: "walletdb::StorageAdapter::del", "Removing key {key:?} from DB failed: {e:?}");
571            return Err(ContractError::SmtDelFailed)
572        }
573
574        // Store its inverse
575        if let Err(e) = self.wallet.cache_inverse(inverse) {
576            error!(target: "walletdb::StorageAdapter::del", "Inserting inverse query into cache failed: {e:?}");
577            return Err(ContractError::SmtDelFailed)
578        }
579
580        Ok(())
581    }
582}
583
#[cfg(test)]
mod tests {
    use darkfi::zk::halo2::Field;
    use darkfi_sdk::{
        crypto::smt::{gen_empty_nodes, util::FieldHasher, PoseidonFp, SparseMerkleTree},
        pasta::pallas,
    };
    use rand::rngs::OsRng;
    use rusqlite::types::Value;

    use crate::walletdb::{WalletDb, WalletStorage};

    /// Open an in-memory database with a password, create a table holding a
    /// single row, and read it back through `query_single` and `query_custom`.
    #[test]
    fn test_mem_wallet() {
        let wallet = WalletDb::new(None, Some("foobar")).unwrap();
        wallet
            .exec_batch_sql(
                "CREATE TABLE mista ( numba INTEGER ); INSERT INTO mista ( numba ) VALUES ( 42 );",
            )
            .unwrap();

        let ret = wallet.query_single("mista", &["numba"], &[]).unwrap();
        assert_eq!(ret.len(), 1);
        // -1 stands for "not an integer" and makes the following assertion fail
        let numba: i64 = if let Value::Integer(numba) = ret[0] { numba } else { -1 };
        assert_eq!(numba, 42);

        let ret = wallet.query_custom("SELECT numba FROM mista;", &[]).unwrap();
        assert_eq!(ret.len(), 1);
        assert_eq!(ret[0].len(), 1);
        let numba: i64 = if let Value::Integer(numba) = ret[0][0] { numba } else { -1 };
        assert_eq!(numba, 42);
    }

    /// Insert one row with integer, text and blob columns, then read it back
    /// both without a `WHERE` clause and filtered by named / positional params.
    #[test]
    fn test_query_single() {
        let wallet = WalletDb::new(None, None).unwrap();
        wallet
            .exec_batch_sql("CREATE TABLE mista ( why INTEGER, are TEXT, you INTEGER, gae BLOB );")
            .unwrap();

        let why = 42;
        let are = "are".to_string();
        let you = 69;
        let gae = vec![42u8; 32];

        wallet
            .exec_sql(
                "INSERT INTO mista ( why, are, you, gae ) VALUES (?1, ?2, ?3, ?4);",
                rusqlite::params![why, are, you, gae],
            )
            .unwrap();

        let ret = wallet.query_single("mista", &["why", "are", "you", "gae"], &[]).unwrap();
        assert_eq!(ret.len(), 4);
        assert_eq!(ret[0], Value::Integer(why));
        assert_eq!(ret[1], Value::Text(are.clone()));
        assert_eq!(ret[2], Value::Integer(you));
        assert_eq!(ret[3], Value::Blob(gae.clone()));
        let ret = wallet.query_custom("SELECT why, are, you, gae FROM mista;", &[]).unwrap();
        assert_eq!(ret.len(), 1);
        assert_eq!(ret[0].len(), 4);
        assert_eq!(ret[0][0], Value::Integer(why));
        assert_eq!(ret[0][1], Value::Text(are.clone()));
        assert_eq!(ret[0][2], Value::Integer(you));
        assert_eq!(ret[0][3], Value::Blob(gae.clone()));

        // Named params already carry the ":" prefix expected by `query_single`
        let ret = wallet
            .query_single(
                "mista",
                &["gae"],
                rusqlite::named_params! {":why": why, ":are": are, ":you": you},
            )
            .unwrap();
        assert_eq!(ret.len(), 1);
        assert_eq!(ret[0], Value::Blob(gae.clone()));
        let ret = wallet
            .query_custom(
                "SELECT gae FROM mista WHERE why = ?1 AND are = ?2 AND you = ?3;",
                rusqlite::params![why, are, you],
            )
            .unwrap();
        assert_eq!(ret.len(), 1);
        assert_eq!(ret[0].len(), 1);
        assert_eq!(ret[0][0], Value::Blob(gae));
    }

    /// Insert the same row twice, then check that `query_multiple` and
    /// `query_custom` return both rows, with and without filtering.
    #[test]
    fn test_query_multi() {
        let wallet = WalletDb::new(None, None).unwrap();
        wallet
            .exec_batch_sql("CREATE TABLE mista ( why INTEGER, are TEXT, you INTEGER, gae BLOB );")
            .unwrap();

        let why = 42;
        let are = "are".to_string();
        let you = 69;
        let gae = vec![42u8; 32];

        wallet
            .exec_sql(
                "INSERT INTO mista ( why, are, you, gae ) VALUES (?1, ?2, ?3, ?4);",
                rusqlite::params![why, are, you, gae],
            )
            .unwrap();
        wallet
            .exec_sql(
                "INSERT INTO mista ( why, are, you, gae ) VALUES (?1, ?2, ?3, ?4);",
                rusqlite::params![why, are, you, gae],
            )
            .unwrap();

        // Empty column list selects every column
        let ret = wallet.query_multiple("mista", &[], &[]).unwrap();
        assert_eq!(ret.len(), 2);
        for row in ret {
            assert_eq!(row.len(), 4);
            assert_eq!(row[0], Value::Integer(why));
            assert_eq!(row[1], Value::Text(are.clone()));
            assert_eq!(row[2], Value::Integer(you));
            assert_eq!(row[3], Value::Blob(gae.clone()));
        }
        let ret = wallet.query_custom("SELECT * FROM mista;", &[]).unwrap();
        assert_eq!(ret.len(), 2);
        for row in ret {
            assert_eq!(row.len(), 4);
            assert_eq!(row[0], Value::Integer(why));
            assert_eq!(row[1], Value::Text(are.clone()));
            assert_eq!(row[2], Value::Integer(you));
            assert_eq!(row[3], Value::Blob(gae.clone()));
        }

        // `convert_named_params!` adds the ":" prefix to plain column names
        let ret = wallet
            .query_multiple(
                "mista",
                &["gae"],
                convert_named_params! {("why", why), ("are", are), ("you", you)},
            )
            .unwrap();
        assert_eq!(ret.len(), 2);
        for row in ret {
            assert_eq!(row.len(), 1);
            assert_eq!(row[0], Value::Blob(gae.clone()));
        }
        let ret = wallet
            .query_custom(
                "SELECT gae FROM mista WHERE why = ?1 AND are = ?2 AND you = ?3;",
                rusqlite::params![why, are, you],
            )
            .unwrap();
        assert_eq!(ret.len(), 2);
        for row in ret {
            assert_eq!(row.len(), 1);
            assert_eq!(row[0], Value::Blob(gae.clone()));
        }
    }

    /// Build an SMT backed by `WalletStorage`, check its root and a membership
    /// proof against hand-computed hashes, then roll every database change back
    /// using the cached inverse queries.
    #[test]
    fn test_sqlite_smt() {
        // Setup SQLite database
        let table = &"smt";
        let key_col = &"smt_key";
        let value_col = &"smt_value";
        let wallet = WalletDb::new(None, None).unwrap();
        // NOTE(review): the declared key type "BLOB INTEGER" is unusual; keys are
        // written as little-endian byte blobs by `WalletStorage`.
        wallet.exec_batch_sql(&format!("CREATE TABLE {table} ( {key_col} BLOB INTEGER PRIMARY KEY NOT NULL, {value_col} BLOB NOT NULL);")).unwrap();

        // Setup SMT
        const HEIGHT: usize = 3;
        let hasher = PoseidonFp::new();
        let empty_leaf = pallas::Base::ZERO;
        let empty_nodes = gen_empty_nodes::<{ HEIGHT + 1 }, _, _>(&hasher, empty_leaf);
        let store = WalletStorage::new(&wallet, table, key_col, value_col);
        let mut smt = SparseMerkleTree::<HEIGHT, { HEIGHT + 1 }, _, _, _>::new(
            store,
            hasher.clone(),
            &empty_nodes,
        );

        // Verify database is empty
        let rows = wallet.query_multiple(table, &[key_col], &[]).unwrap();
        assert!(rows.is_empty());

        let leaves = vec![
            (pallas::Base::from(1), pallas::Base::random(&mut OsRng)),
            (pallas::Base::from(2), pallas::Base::random(&mut OsRng)),
            (pallas::Base::from(3), pallas::Base::random(&mut OsRng)),
        ];
        smt.insert_batch(leaves.clone()).unwrap();

        let hash1 = leaves[0].1;
        let hash2 = leaves[1].1;
        let hash3 = leaves[2].1;

        let hash = |l, r| hasher.hash([l, r]);

        // Recompute the root by hand: leaf 0 is empty (empty_nodes[3]) and the
        // whole right half of the tree is empty (empty_nodes[1]).
        let hash01 = hash(empty_nodes[3], hash1);
        let hash23 = hash(hash2, hash3);

        let hash0123 = hash(hash01, hash23);
        let root = hash(hash0123, empty_nodes[1]);
        assert_eq!(root, smt.root());

        // Now try to construct a membership proof for leaf 3
        let pos = leaves[2].0;
        let path = smt.prove_membership(&pos);
        assert_eq!(path.path[0], empty_nodes[1]);
        assert_eq!(path.path[1], hash01);
        assert_eq!(path.path[2], hash2);

        assert_eq!(hash23, hash(path.path[2], hash3));
        assert_eq!(hash0123, hash(path.path[1], hash(path.path[2], hash3)));
        assert_eq!(root, hash(hash(path.path[1], hash(path.path[2], hash3)), path.path[0]));

        assert!(path.verify(&root, &hash3, &pos));

        // Verify database contains keys
        let rows = wallet.query_multiple(table, &[key_col], &[]).unwrap();
        assert!(!rows.is_empty());

        // We are now going to rollback the wallet changes
        let rollback_query = wallet.grab_inverse_cache_block().unwrap();
        wallet.exec_batch_sql(&rollback_query).unwrap();

        // Clear cache (grabbing the rollback block does not clear it)
        wallet.clear_inverse_cache().unwrap();

        // Verify database is empty again
        let rows = wallet.query_multiple(table, &[key_col], &[]).unwrap();
        assert!(rows.is_empty());
    }
}