1use std::{
20 path::PathBuf,
21 sync::{Arc, Mutex},
22};
23
24use log::{debug, error};
25use rusqlite::{
26 types::{ToSql, Value},
27 Connection,
28};
29
30use crate::error::{WalletDbError, WalletDbResult};
31
32pub type WalletPtr = Arc<WalletDb>;
33
34pub struct WalletDb {
36 pub conn: Mutex<Connection>,
38}
39
40impl WalletDb {
41 pub fn new(path: Option<PathBuf>, password: Option<&str>) -> WalletDbResult<WalletPtr> {
43 let Ok(conn) = (match path.clone() {
44 Some(p) => Connection::open(p),
45 None => Connection::open_in_memory(),
46 }) else {
47 return Err(WalletDbError::ConnectionFailed);
48 };
49
50 if let Some(password) = password {
51 if let Err(e) = conn.pragma_update(None, "key", password) {
52 error!(target: "walletdb::new", "[WalletDb] Pragma update failed: {e}");
53 return Err(WalletDbError::PragmaUpdateError);
54 };
55 }
56 if let Err(e) = conn.pragma_update(None, "foreign_keys", "ON") {
57 error!(target: "walletdb::new", "[WalletDb] Pragma update failed: {e}");
58 return Err(WalletDbError::PragmaUpdateError);
59 };
60
61 debug!(target: "walletdb::new", "[WalletDb] Opened Sqlite connection at \"{path:?}\"");
62 Ok(Arc::new(Self { conn: Mutex::new(conn) }))
63 }
64
65 pub fn exec_batch_sql(&self, query: &str) -> WalletDbResult<()> {
68 debug!(target: "walletdb::exec_batch_sql", "[WalletDb] Executing batch SQL query:\n{query}");
69 let Ok(conn) = self.conn.lock() else { return Err(WalletDbError::FailedToAquireLock) };
70 if let Err(e) = conn.execute_batch(query) {
71 error!(target: "walletdb::exec_batch_sql", "[WalletDb] Query failed: {e}");
72 return Err(WalletDbError::QueryExecutionFailed)
73 };
74
75 Ok(())
76 }
77
78 pub fn exec_sql(&self, query: &str, params: &[&dyn ToSql]) -> WalletDbResult<()> {
81 debug!(target: "walletdb::exec_sql", "[WalletDb] Executing SQL query:\n{query}");
82 let Ok(conn) = self.conn.lock() else { return Err(WalletDbError::FailedToAquireLock) };
83
84 if params.is_empty() {
86 if let Err(e) = conn.execute(query, ()) {
87 error!(target: "walletdb::exec_sql", "[WalletDb] Query failed: {e}");
88 return Err(WalletDbError::QueryExecutionFailed)
89 };
90 return Ok(())
91 }
92
93 let Ok(mut stmt) = conn.prepare(query) else {
95 return Err(WalletDbError::QueryPreparationFailed)
96 };
97
98 if let Err(e) = stmt.execute(params) {
100 error!(target: "walletdb::exec_sql", "[WalletDb] Query failed: {e}");
101 return Err(WalletDbError::QueryExecutionFailed)
102 };
103
104 if let Err(e) = stmt.finalize() {
106 error!(target: "walletdb::exec_sql", "[WalletDb] Query finalization failed: {e}");
107 return Err(WalletDbError::QueryFinalizationFailed)
108 };
109 drop(conn);
110
111 Ok(())
112 }
113
114 pub fn create_prepared_statement(
117 &self,
118 query: &str,
119 params: &[&dyn ToSql],
120 ) -> WalletDbResult<String> {
121 debug!(target: "walletdb::create_prepared_statement", "[WalletDb] Preparing statement for SQL query:\n{query}");
122 let Ok(conn) = self.conn.lock() else { return Err(WalletDbError::FailedToAquireLock) };
123
124 let Ok(mut stmt) = conn.prepare(query) else {
126 return Err(WalletDbError::QueryPreparationFailed)
127 };
128
129 for (index, param) in params.iter().enumerate() {
131 if stmt.raw_bind_parameter(index + 1, param).is_err() {
132 return Err(WalletDbError::QueryPreparationFailed)
133 };
134 }
135
136 let query = stmt.expanded_sql().unwrap();
138
139 drop(stmt);
141 drop(conn);
142
143 Ok(query)
144 }
145
146 fn generate_select_query(
150 &self,
151 table: &str,
152 col_names: &[&str],
153 params: &[(&str, &dyn ToSql)],
154 ) -> String {
155 let mut query = if col_names.is_empty() {
156 format!("SELECT * FROM {table}")
157 } else {
158 format!("SELECT {} FROM {table}", col_names.join(", "))
159 };
160 if params.is_empty() {
161 return query
162 }
163
164 let mut where_str = Vec::with_capacity(params.len());
165 for (k, _) in params {
166 let col = &k[1..];
167 where_str.push(format!("{col} = {k}"));
168 }
169 query.push_str(&format!(" WHERE {}", where_str.join(" AND ")));
170
171 query
172 }
173
174 pub fn query_single(
177 &self,
178 table: &str,
179 col_names: &[&str],
180 params: &[(&str, &dyn ToSql)],
181 ) -> WalletDbResult<Vec<Value>> {
182 let query = self.generate_select_query(table, col_names, params);
184 debug!(target: "walletdb::query_single", "[WalletDb] Executing SQL query:\n{query}");
185
186 let Ok(conn) = self.conn.lock() else { return Err(WalletDbError::FailedToAquireLock) };
188
189 let Ok(mut stmt) = conn.prepare(&query) else {
190 return Err(WalletDbError::QueryPreparationFailed)
191 };
192
193 let Ok(mut rows) = stmt.query(params) else {
195 return Err(WalletDbError::QueryExecutionFailed)
196 };
197
198 let Ok(next) = rows.next() else { return Err(WalletDbError::QueryExecutionFailed) };
200 let row = match next {
201 Some(row_result) => row_result,
202 None => return Err(WalletDbError::RowNotFound),
203 };
204
205 let mut result = vec![];
207 if col_names.is_empty() {
208 let mut idx = 0;
209 loop {
210 let Ok(value) = row.get(idx) else { break };
211 result.push(value);
212 idx += 1;
213 }
214 } else {
215 for col in col_names {
216 let Ok(value) = row.get(*col) else {
217 return Err(WalletDbError::ParseColumnValueError)
218 };
219 result.push(value);
220 }
221 }
222
223 Ok(result)
224 }
225
226 pub fn query_multiple(
229 &self,
230 table: &str,
231 col_names: &[&str],
232 params: &[(&str, &dyn ToSql)],
233 ) -> WalletDbResult<Vec<Vec<Value>>> {
234 let query = self.generate_select_query(table, col_names, params);
236 debug!(target: "walletdb::query_multiple", "[WalletDb] Executing SQL query:\n{query}");
237
238 let Ok(conn) = self.conn.lock() else { return Err(WalletDbError::FailedToAquireLock) };
240 let Ok(mut stmt) = conn.prepare(&query) else {
241 return Err(WalletDbError::QueryPreparationFailed)
242 };
243
244 let Ok(mut rows) = stmt.query(params) else {
246 return Err(WalletDbError::QueryExecutionFailed)
247 };
248
249 let mut result = vec![];
251 loop {
252 let row = match rows.next() {
254 Ok(r) => r,
255 Err(_) => return Err(WalletDbError::QueryExecutionFailed),
256 };
257
258 let row = match row {
260 Some(r) => r,
261 None => break,
262 };
263
264 let mut row_values = vec![];
266 if col_names.is_empty() {
267 let mut idx = 0;
268 loop {
269 let Ok(value) = row.get(idx) else { break };
270 row_values.push(value);
271 idx += 1;
272 }
273 } else {
274 for col in col_names {
275 let Ok(value) = row.get(*col) else {
276 return Err(WalletDbError::ParseColumnValueError)
277 };
278 row_values.push(value);
279 }
280 }
281 result.push(row_values);
282 }
283
284 Ok(result)
285 }
286
287 pub fn query_custom(
289 &self,
290 query: &str,
291 params: &[&dyn ToSql],
292 ) -> WalletDbResult<Vec<Vec<Value>>> {
293 debug!(target: "walletdb::query_custom", "[WalletDb] Executing SQL query:\n{query}");
294
295 let Ok(conn) = self.conn.lock() else { return Err(WalletDbError::FailedToAquireLock) };
297 let Ok(mut stmt) = conn.prepare(query) else {
298 return Err(WalletDbError::QueryPreparationFailed)
299 };
300
301 let Ok(mut rows) = stmt.query(params) else {
303 return Err(WalletDbError::QueryExecutionFailed)
304 };
305
306 let mut result = vec![];
308 loop {
309 let row = match rows.next() {
311 Ok(r) => r,
312 Err(_) => return Err(WalletDbError::QueryExecutionFailed),
313 };
314
315 let row = match row {
317 Some(r) => r,
318 None => break,
319 };
320
321 let mut row_values = vec![];
323 let mut idx = 0;
324 loop {
325 let Ok(value) = row.get(idx) else { break };
326 row_values.push(value);
327 idx += 1;
328 }
329 result.push(row_values);
330 }
331
332 Ok(result)
333 }
334}
335
/// Build a `&[(&str, &dyn ToSql)]` slice of named SQL parameters.
///
/// Each `(name, value)` pair becomes `(":name", &value)`, so callers write the
/// bare column name and the macro adds SQLite's `:` prefix. With no arguments,
/// an empty slice of the same type is produced.
///
/// The `":name"` strings are `format!` temporaries, so the resulting slice must
/// be consumed within the same statement (e.g. passed directly as a function
/// argument) rather than bound to a variable for later use.
#[macro_export]
macro_rules! convert_named_params {
    () => {
        &[] as &[(&str, &dyn rusqlite::types::ToSql)]
    };
    ($(($name:expr, $value:expr)),+ $(,)?) => {
        &[
            $((
                format!(":{}", $name).as_str(),
                &$value as &dyn rusqlite::types::ToSql,
            )),+
        ] as &[(&str, &dyn rusqlite::types::ToSql)]
    };
}
347
#[cfg(test)]
mod tests {
    use rusqlite::types::Value;

    use crate::walletdb::WalletDb;

    /// Schema shared by the multi-column tests.
    const CREATE_MISTA: &str =
        "CREATE TABLE mista ( why INTEGER, are TEXT, you INTEGER, gae BLOB );";
    /// Positional insert matching `CREATE_MISTA`.
    const INSERT_MISTA: &str = "INSERT INTO mista ( why, are, you, gae ) VALUES (?1, ?2, ?3, ?4);";
    /// Custom query selecting the blob column filtered by the other three.
    const SELECT_GAE_WHERE: &str =
        "SELECT gae FROM mista WHERE why = ?1 AND are = ?2 AND you = ?3;";

    #[test]
    fn test_mem_wallet() {
        let wallet = WalletDb::new(None, Some("foobar")).unwrap();
        wallet
            .exec_batch_sql(
                "CREATE TABLE mista ( numba INTEGER ); INSERT INTO mista ( numba ) VALUES ( 42 );",
            )
            .unwrap();

        let single = wallet.query_single("mista", &["numba"], &[]).unwrap();
        assert_eq!(single, vec![Value::Integer(42)]);

        let custom = wallet.query_custom("SELECT numba FROM mista;", &[]).unwrap();
        assert_eq!(custom, vec![vec![Value::Integer(42)]]);
    }

    #[test]
    fn test_query_single() {
        let wallet = WalletDb::new(None, None).unwrap();
        wallet.exec_batch_sql(CREATE_MISTA).unwrap();

        let (why, you): (i64, i64) = (42, 69);
        let are = "are".to_string();
        let gae = vec![42u8; 32];
        wallet.exec_sql(INSERT_MISTA, rusqlite::params![why, are, you, gae]).unwrap();

        let full_row = vec![
            Value::Integer(why),
            Value::Text(are.clone()),
            Value::Integer(you),
            Value::Blob(gae.clone()),
        ];

        let single = wallet.query_single("mista", &["why", "are", "you", "gae"], &[]).unwrap();
        assert_eq!(single, full_row);
        let custom = wallet.query_custom("SELECT why, are, you, gae FROM mista;", &[]).unwrap();
        assert_eq!(custom, vec![full_row.clone()]);

        let filtered = wallet
            .query_single(
                "mista",
                &["gae"],
                rusqlite::named_params! {":why": why, ":are": are, ":you": you},
            )
            .unwrap();
        assert_eq!(filtered, vec![Value::Blob(gae.clone())]);
        let custom =
            wallet.query_custom(SELECT_GAE_WHERE, rusqlite::params![why, are, you]).unwrap();
        assert_eq!(custom, vec![vec![Value::Blob(gae)]]);
    }

    #[test]
    fn test_query_multi() {
        let wallet = WalletDb::new(None, None).unwrap();
        wallet.exec_batch_sql(CREATE_MISTA).unwrap();

        let (why, you): (i64, i64) = (42, 69);
        let are = "are".to_string();
        let gae = vec![42u8; 32];
        // Store the same row twice so every multi-row query returns two hits.
        for _ in 0..2 {
            wallet.exec_sql(INSERT_MISTA, rusqlite::params![why, are, you, gae]).unwrap();
        }

        let full_row = vec![
            Value::Integer(why),
            Value::Text(are.clone()),
            Value::Integer(you),
            Value::Blob(gae.clone()),
        ];
        let blob_only = vec![Value::Blob(gae.clone())];

        let every = wallet.query_multiple("mista", &[], &[]).unwrap();
        assert_eq!(every, vec![full_row.clone(), full_row.clone()]);
        let every = wallet.query_custom("SELECT * FROM mista;", &[]).unwrap();
        assert_eq!(every, vec![full_row.clone(), full_row]);

        let filtered = wallet
            .query_multiple(
                "mista",
                &["gae"],
                convert_named_params! {("why", why), ("are", are), ("you", you)},
            )
            .unwrap();
        assert_eq!(filtered, vec![blob_only.clone(), blob_only.clone()]);
        let filtered =
            wallet.query_custom(SELECT_GAE_WHERE, rusqlite::params![why, are, you]).unwrap();
        assert_eq!(filtered, vec![blob_only.clone(), blob_only]);
    }
}