// drk/walletdb.rs

/* This file is part of DarkFi (https://dark.fi)
 *
 * Copyright (C) 2020-2025 Dyne.org foundation
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as
 * published by the Free Software Foundation, either version 3 of the
 * License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with this program.  If not, see <https://www.gnu.org/licenses/>.
 */
18
19use std::{
20    path::PathBuf,
21    sync::{Arc, Mutex},
22};
23
24use log::{debug, error};
25use rusqlite::{
26    types::{ToSql, Value},
27    Connection,
28};
29
30use crate::error::{WalletDbError, WalletDbResult};
31
32pub type WalletPtr = Arc<WalletDb>;
33
34/// Structure representing base wallet database operations.
35pub struct WalletDb {
36    /// Connection to the SQLite database.
37    pub conn: Mutex<Connection>,
38}
39
40impl WalletDb {
41    /// Create a new wallet database handler. If `path` is `None`, create it in memory.
42    pub fn new(path: Option<PathBuf>, password: Option<&str>) -> WalletDbResult<WalletPtr> {
43        let Ok(conn) = (match path.clone() {
44            Some(p) => Connection::open(p),
45            None => Connection::open_in_memory(),
46        }) else {
47            return Err(WalletDbError::ConnectionFailed);
48        };
49
50        if let Some(password) = password {
51            if let Err(e) = conn.pragma_update(None, "key", password) {
52                error!(target: "walletdb::new", "[WalletDb] Pragma update failed: {e}");
53                return Err(WalletDbError::PragmaUpdateError);
54            };
55        }
56        if let Err(e) = conn.pragma_update(None, "foreign_keys", "ON") {
57            error!(target: "walletdb::new", "[WalletDb] Pragma update failed: {e}");
58            return Err(WalletDbError::PragmaUpdateError);
59        };
60
61        debug!(target: "walletdb::new", "[WalletDb] Opened Sqlite connection at \"{path:?}\"");
62        Ok(Arc::new(Self { conn: Mutex::new(conn) }))
63    }
64
65    /// This function executes a given SQL query that contains multiple SQL statements,
66    /// that don't contain any parameters.
67    pub fn exec_batch_sql(&self, query: &str) -> WalletDbResult<()> {
68        debug!(target: "walletdb::exec_batch_sql", "[WalletDb] Executing batch SQL query:\n{query}");
69        let Ok(conn) = self.conn.lock() else { return Err(WalletDbError::FailedToAquireLock) };
70        if let Err(e) = conn.execute_batch(query) {
71            error!(target: "walletdb::exec_batch_sql", "[WalletDb] Query failed: {e}");
72            return Err(WalletDbError::QueryExecutionFailed)
73        };
74
75        Ok(())
76    }
77
78    /// This function executes a given SQL query, but isn't able to return anything.
79    /// Therefore it's best to use it for initializing a table or similar things.
80    pub fn exec_sql(&self, query: &str, params: &[&dyn ToSql]) -> WalletDbResult<()> {
81        debug!(target: "walletdb::exec_sql", "[WalletDb] Executing SQL query:\n{query}");
82        let Ok(conn) = self.conn.lock() else { return Err(WalletDbError::FailedToAquireLock) };
83
84        // If no params are provided, execute directly
85        if params.is_empty() {
86            if let Err(e) = conn.execute(query, ()) {
87                error!(target: "walletdb::exec_sql", "[WalletDb] Query failed: {e}");
88                return Err(WalletDbError::QueryExecutionFailed)
89            };
90            return Ok(())
91        }
92
93        // First we prepare the query
94        let Ok(mut stmt) = conn.prepare(query) else {
95            return Err(WalletDbError::QueryPreparationFailed)
96        };
97
98        // Execute the query using provided params
99        if let Err(e) = stmt.execute(params) {
100            error!(target: "walletdb::exec_sql", "[WalletDb] Query failed: {e}");
101            return Err(WalletDbError::QueryExecutionFailed)
102        };
103
104        // Finalize query and drop connection lock
105        if let Err(e) = stmt.finalize() {
106            error!(target: "walletdb::exec_sql", "[WalletDb] Query finalization failed: {e}");
107            return Err(WalletDbError::QueryFinalizationFailed)
108        };
109        drop(conn);
110
111        Ok(())
112    }
113
114    /// Generate a new statement for provided query and bind the provided params,
115    /// returning the raw SQL query as a string.
116    pub fn create_prepared_statement(
117        &self,
118        query: &str,
119        params: &[&dyn ToSql],
120    ) -> WalletDbResult<String> {
121        debug!(target: "walletdb::create_prepared_statement", "[WalletDb] Preparing statement for SQL query:\n{query}");
122        let Ok(conn) = self.conn.lock() else { return Err(WalletDbError::FailedToAquireLock) };
123
124        // First we prepare the query
125        let Ok(mut stmt) = conn.prepare(query) else {
126            return Err(WalletDbError::QueryPreparationFailed)
127        };
128
129        // Bind all provided params
130        for (index, param) in params.iter().enumerate() {
131            if stmt.raw_bind_parameter(index + 1, param).is_err() {
132                return Err(WalletDbError::QueryPreparationFailed)
133            };
134        }
135
136        // Grab the raw SQL
137        let query = stmt.expanded_sql().unwrap();
138
139        // Drop statement and the connection lock
140        drop(stmt);
141        drop(conn);
142
143        Ok(query)
144    }
145
146    /// Generate a `SELECT` query for provided table from selected column names and
147    /// provided `WHERE` clauses. Named parameters are supported in the `WHERE` clauses,
148    /// assuming they follow the normal formatting ":{column_name}".
149    fn generate_select_query(
150        &self,
151        table: &str,
152        col_names: &[&str],
153        params: &[(&str, &dyn ToSql)],
154    ) -> String {
155        let mut query = if col_names.is_empty() {
156            format!("SELECT * FROM {table}")
157        } else {
158            format!("SELECT {} FROM {table}", col_names.join(", "))
159        };
160        if params.is_empty() {
161            return query
162        }
163
164        let mut where_str = Vec::with_capacity(params.len());
165        for (k, _) in params {
166            let col = &k[1..];
167            where_str.push(format!("{col} = {k}"));
168        }
169        query.push_str(&format!(" WHERE {}", where_str.join(" AND ")));
170
171        query
172    }
173
174    /// Query provided table from selected column names and provided `WHERE` clauses,
175    /// for a single row.
176    pub fn query_single(
177        &self,
178        table: &str,
179        col_names: &[&str],
180        params: &[(&str, &dyn ToSql)],
181    ) -> WalletDbResult<Vec<Value>> {
182        // Generate `SELECT` query
183        let query = self.generate_select_query(table, col_names, params);
184        debug!(target: "walletdb::query_single", "[WalletDb] Executing SQL query:\n{query}");
185
186        // First we prepare the query
187        let Ok(conn) = self.conn.lock() else { return Err(WalletDbError::FailedToAquireLock) };
188
189        let Ok(mut stmt) = conn.prepare(&query) else {
190            return Err(WalletDbError::QueryPreparationFailed)
191        };
192
193        // Execute the query using provided params
194        let Ok(mut rows) = stmt.query(params) else {
195            return Err(WalletDbError::QueryExecutionFailed)
196        };
197
198        // Check if row exists
199        let Ok(next) = rows.next() else { return Err(WalletDbError::QueryExecutionFailed) };
200        let row = match next {
201            Some(row_result) => row_result,
202            None => return Err(WalletDbError::RowNotFound),
203        };
204
205        // Grab returned values
206        let mut result = vec![];
207        if col_names.is_empty() {
208            let mut idx = 0;
209            loop {
210                let Ok(value) = row.get(idx) else { break };
211                result.push(value);
212                idx += 1;
213            }
214        } else {
215            for col in col_names {
216                let Ok(value) = row.get(*col) else {
217                    return Err(WalletDbError::ParseColumnValueError)
218                };
219                result.push(value);
220            }
221        }
222
223        Ok(result)
224    }
225
226    /// Query provided table from selected column names and provided `WHERE` clauses,
227    /// for multiple rows.
228    pub fn query_multiple(
229        &self,
230        table: &str,
231        col_names: &[&str],
232        params: &[(&str, &dyn ToSql)],
233    ) -> WalletDbResult<Vec<Vec<Value>>> {
234        // Generate `SELECT` query
235        let query = self.generate_select_query(table, col_names, params);
236        debug!(target: "walletdb::query_multiple", "[WalletDb] Executing SQL query:\n{query}");
237
238        // First we prepare the query
239        let Ok(conn) = self.conn.lock() else { return Err(WalletDbError::FailedToAquireLock) };
240        let Ok(mut stmt) = conn.prepare(&query) else {
241            return Err(WalletDbError::QueryPreparationFailed)
242        };
243
244        // Execute the query using provided converted params
245        let Ok(mut rows) = stmt.query(params) else {
246            return Err(WalletDbError::QueryExecutionFailed)
247        };
248
249        // Loop over returned rows and parse them
250        let mut result = vec![];
251        loop {
252            // Check if an error occured
253            let row = match rows.next() {
254                Ok(r) => r,
255                Err(_) => return Err(WalletDbError::QueryExecutionFailed),
256            };
257
258            // Check if no row was returned
259            let row = match row {
260                Some(r) => r,
261                None => break,
262            };
263
264            // Grab row returned values
265            let mut row_values = vec![];
266            if col_names.is_empty() {
267                let mut idx = 0;
268                loop {
269                    let Ok(value) = row.get(idx) else { break };
270                    row_values.push(value);
271                    idx += 1;
272                }
273            } else {
274                for col in col_names {
275                    let Ok(value) = row.get(*col) else {
276                        return Err(WalletDbError::ParseColumnValueError)
277                    };
278                    row_values.push(value);
279                }
280            }
281            result.push(row_values);
282        }
283
284        Ok(result)
285    }
286
287    /// Query provided table using provided query for multiple rows.
288    pub fn query_custom(
289        &self,
290        query: &str,
291        params: &[&dyn ToSql],
292    ) -> WalletDbResult<Vec<Vec<Value>>> {
293        debug!(target: "walletdb::query_custom", "[WalletDb] Executing SQL query:\n{query}");
294
295        // First we prepare the query
296        let Ok(conn) = self.conn.lock() else { return Err(WalletDbError::FailedToAquireLock) };
297        let Ok(mut stmt) = conn.prepare(query) else {
298            return Err(WalletDbError::QueryPreparationFailed)
299        };
300
301        // Execute the query using provided converted params
302        let Ok(mut rows) = stmt.query(params) else {
303            return Err(WalletDbError::QueryExecutionFailed)
304        };
305
306        // Loop over returned rows and parse them
307        let mut result = vec![];
308        loop {
309            // Check if an error occured
310            let row = match rows.next() {
311                Ok(r) => r,
312                Err(_) => return Err(WalletDbError::QueryExecutionFailed),
313            };
314
315            // Check if no row was returned
316            let row = match row {
317                Some(r) => r,
318                None => break,
319            };
320
321            // Grab row returned values
322            let mut row_values = vec![];
323            let mut idx = 0;
324            loop {
325                let Ok(value) = row.get(idx) else { break };
326                row_values.push(value);
327                idx += 1;
328            }
329            result.push(row_values);
330        }
331
332        Ok(result)
333    }
334}
335
/// Variant of `rusqlite::named_params!` whose parameter names are arbitrary expressions
/// (`expr`) rather than literals. Each name is formatted with `Display` and given the
/// ":" named-parameter prefix.
///
/// Expands to a `&[(&str, &dyn rusqlite::types::ToSql)]` slice. The prefixed names live in
/// temporary `String`s, so the expansion is meant to be passed directly as an argument
/// (e.g. to `WalletDb::query_multiple`) rather than bound to a `let` and used later.
#[macro_export]
macro_rules! convert_named_params {
    () => {
        &[] as &[(&str, &dyn rusqlite::types::ToSql)]
    };
    ($( ($name:expr, $value:expr) ),+ $(,)?) => {
        &[
            $(
                (
                    format!(":{}", $name).as_str(),
                    &$value as &dyn rusqlite::types::ToSql,
                )
            ),+
        ] as &[(&str, &dyn rusqlite::types::ToSql)]
    };
}
347
#[cfg(test)]
mod tests {
    use rusqlite::types::Value;

    use crate::{error::WalletDbError, walletdb::WalletDb};

    /// A keyed in-memory database can be written with a batch and read back.
    #[test]
    fn test_mem_wallet() {
        let wallet = WalletDb::new(None, Some("foobar")).unwrap();
        wallet
            .exec_batch_sql(
                "CREATE TABLE mista ( numba INTEGER ); INSERT INTO mista ( numba ) VALUES ( 42 );",
            )
            .unwrap();

        let ret = wallet.query_single("mista", &["numba"], &[]).unwrap();
        assert_eq!(ret, vec![Value::Integer(42)]);

        let ret = wallet.query_custom("SELECT numba FROM mista;", &[]).unwrap();
        assert_eq!(ret, vec![vec![Value::Integer(42)]]);
    }

    /// Single-row lookups by column list, by `SELECT *`, and with named `WHERE` params.
    #[test]
    fn test_query_single() {
        let wallet = WalletDb::new(None, None).unwrap();
        wallet
            .exec_batch_sql("CREATE TABLE mista ( why INTEGER, are TEXT, you INTEGER, gae BLOB );")
            .unwrap();

        let why = 42;
        let are = "are".to_string();
        let you = 69;
        let gae = vec![42u8; 32];

        wallet
            .exec_sql(
                "INSERT INTO mista ( why, are, you, gae ) VALUES (?1, ?2, ?3, ?4);",
                rusqlite::params![why, are, you, gae],
            )
            .unwrap();

        let expected = vec![
            Value::Integer(why),
            Value::Text(are.clone()),
            Value::Integer(you),
            Value::Blob(gae.clone()),
        ];

        let ret = wallet.query_single("mista", &["why", "are", "you", "gae"], &[]).unwrap();
        assert_eq!(ret, expected);
        let ret = wallet.query_single("mista", &[], &[]).unwrap();
        assert_eq!(ret, expected);
        let ret = wallet.query_custom("SELECT why, are, you, gae FROM mista;", &[]).unwrap();
        assert_eq!(ret, vec![expected]);

        let ret = wallet
            .query_single(
                "mista",
                &["gae"],
                rusqlite::named_params! {":why": why, ":are": are, ":you": you},
            )
            .unwrap();
        assert_eq!(ret, vec![Value::Blob(gae.clone())]);
        let ret = wallet
            .query_custom(
                "SELECT gae FROM mista WHERE why = ?1 AND are = ?2 AND you = ?3;",
                rusqlite::params![why, are, you],
            )
            .unwrap();
        assert_eq!(ret, vec![vec![Value::Blob(gae)]]);
    }

    /// `query_single` reports `RowNotFound` instead of an empty result on an empty table.
    #[test]
    fn test_query_single_row_not_found() {
        let wallet = WalletDb::new(None, None).unwrap();
        wallet.exec_batch_sql("CREATE TABLE mista ( numba INTEGER );").unwrap();

        let ret = wallet.query_single("mista", &["numba"], &[]);
        assert!(matches!(ret, Err(WalletDbError::RowNotFound)));
    }

    /// Bound positional params are expanded inline by `create_prepared_statement`.
    #[test]
    fn test_create_prepared_statement() {
        let wallet = WalletDb::new(None, None).unwrap();
        wallet.exec_batch_sql("CREATE TABLE mista ( why INTEGER, are TEXT );").unwrap();

        let query = wallet
            .create_prepared_statement(
                "INSERT INTO mista ( why, are ) VALUES ( ?1, ?2 )",
                rusqlite::params![42, "are"],
            )
            .unwrap();
        assert_eq!(query, "INSERT INTO mista ( why, are ) VALUES ( 42, 'are' )");
    }

    /// Multi-row lookups, with `SELECT *` and with `convert_named_params!` filters.
    #[test]
    fn test_query_multi() {
        let wallet = WalletDb::new(None, None).unwrap();
        wallet
            .exec_batch_sql("CREATE TABLE mista ( why INTEGER, are TEXT, you INTEGER, gae BLOB );")
            .unwrap();

        let why = 42;
        let are = "are".to_string();
        let you = 69;
        let gae = vec![42u8; 32];

        for _ in 0..2 {
            wallet
                .exec_sql(
                    "INSERT INTO mista ( why, are, you, gae ) VALUES (?1, ?2, ?3, ?4);",
                    rusqlite::params![why, are, you, gae],
                )
                .unwrap();
        }

        let full_row = vec![
            Value::Integer(why),
            Value::Text(are.clone()),
            Value::Integer(you),
            Value::Blob(gae.clone()),
        ];

        let ret = wallet.query_multiple("mista", &[], &[]).unwrap();
        assert_eq!(ret, vec![full_row.clone(), full_row.clone()]);
        let ret = wallet.query_custom("SELECT * FROM mista;", &[]).unwrap();
        assert_eq!(ret, vec![full_row.clone(), full_row]);

        let gae_row = vec![Value::Blob(gae)];

        let ret = wallet
            .query_multiple(
                "mista",
                &["gae"],
                convert_named_params! {("why", why), ("are", are), ("you", you)},
            )
            .unwrap();
        assert_eq!(ret, vec![gae_row.clone(), gae_row.clone()]);
        let ret = wallet
            .query_custom(
                "SELECT gae FROM mista WHERE why = ?1 AND are = ?2 AND you = ?3;",
                rusqlite::params![why, are, you],
            )
            .unwrap();
        assert_eq!(ret, vec![gae_row.clone(), gae_row]);
    }
}