darkfi/dht/handler.rs

/* This file is part of DarkFi (https://dark.fi)
 *
 * Copyright (C) 2020-2025 Dyne.org foundation
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as
 * published by the Free Software Foundation, either version 3 of the
 * License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with this program.  If not, see <https://www.gnu.org/licenses/>.
 */

use async_trait::async_trait;
use futures::stream::FuturesUnordered;
use log::{debug, info, warn};
use num_bigint::BigUint;
use smol::{lock::Semaphore, stream::StreamExt};
use std::{
    collections::{HashMap, HashSet},
    marker::Sync,
    sync::Arc,
    time::Duration,
};

use super::{ChannelCacheItem, Dht, DhtNode, DhtRouterItem, DhtRouterPtr};
use crate::{
    geode::hash_to_string,
    net::{
        connector::Connector,
        session::{Session, SESSION_REFINE, SESSION_SEED},
        ChannelPtr, Message,
    },
    system::timeout::timeout,
    Error, Result,
};

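/// Handler of DHT operations. Implementors supply the node identity and the
/// wire-level requests (ping, FIND NODES); the provided methods build the
/// Kademlia-style behavior on top: bucket maintenance, iterative lookups,
/// announcements, and channel management.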
#[async_trait]
pub trait DhtHandler<N: DhtNode>: Sync {
    fn dht(&self) -> Arc<Dht<N>>;

    /// Get our own node
    async fn node(&self) -> N;

    /// Send a DHT ping request
    async fn ping(&self, channel: ChannelPtr) -> Result<N>;

    /// Triggered when we find a new node
    async fn on_new_node(&self, node: &N) -> Result<()>;

    /// Send a FIND NODES request to a peer to get nodes close to `key`
    async fn fetch_nodes(&self, node: &N, key: &blake3::Hash) -> Result<Vec<N>>;

    /// Announce a message for a key, adding ourselves to the router for that key
    async fn announce<M: Message>(
        &self,
        key: &blake3::Hash,
        message: &M,
        router: DhtRouterPtr<N>,
    ) -> Result<()>
    where
        N: 'async_trait,
    {
        let self_node = self.node().await;
        if self_node.addresses().is_empty() {
            // Without an external address, peers have no way to reach us
            return Err(Error::Custom(
                "Cannot announce: our node has no external address".to_string(),
            ));
        }

        self.add_to_router(router.clone(), key, vec![self_node.clone().into()]).await;
        let nodes = self.lookup_nodes(key).await?;
        info!(target: "dht::DhtHandler::announce()", "Announcing {} to {} nodes", hash_to_string(key), nodes.len());

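        // Best-effort fan-out: per-node failures are ignored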
        for node in nodes {
            if let Ok(channel) = self.get_channel(&node, None).await {
                let _ = channel.send(message).await;
                self.cleanup_channel(channel).await;
            }
        }

        Ok(())
    }

    /// Look up our own node id to bootstrap our DHT
    async fn bootstrap(&self) {
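        // Optimistically mark ourselves as bootstrapped; this is reverted
        // below if the self-lookup fails or returns no nodes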
        self.dht().set_bootstrapped(true).await;

        let self_node_id = self.node().await.id();
        debug!(target: "dht::DhtHandler::bootstrap()", "DHT bootstrapping {}", hash_to_string(&self_node_id));
        let nodes = self.lookup_nodes(&self_node_id).await;

        // `map_or(true, ...)` also covers the error case
        if nodes.map_or(true, |v| v.is_empty()) {
            self.dht().set_bootstrapped(false).await;
        }
    }

    /// Add a node to the correct bucket
    async fn add_node(&self, node: N)
    where
        N: 'async_trait,
    {
        let self_node = self.node().await;

        // Do not add ourselves to the buckets
        if node.id() == self_node.id() {
            return;
        }

        // Do not add this node if any of its external addresses matches one of ours
        let node_addresses = node.addresses();
        if self_node.addresses().iter().any(|addr| node_addresses.contains(addr)) {
            return;
        }

        // Do not add a node to the buckets if it does not have an address
        if node_addresses.is_empty() {
            return;
        }

        let bucket_index = self.dht().get_bucket_index(&self_node.id(), &node.id()).await;
        let buckets_lock = self.dht().buckets.clone();
        let mut buckets = buckets_lock.write().await;
        let bucket = &mut buckets[bucket_index];

        // Node is already in the bucket
        if bucket.nodes.iter().any(|n| n.id() == node.id()) {
            return;
        }

        // Bucket is full
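        // Kademlia favors long-lived nodes: the least recently seen node is
        // evicted only if it fails to answer a ping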
        if bucket.nodes.len() >= self.dht().settings.k {
            // Ping the least recently seen node
            if let Ok(channel) = self.get_channel(&bucket.nodes[0], None).await {
                let ping_res = self.ping(channel.clone()).await;
                self.cleanup_channel(channel).await;
                if ping_res.is_ok() {
                    // Ping was successful: move the least recently seen node to the tail
                    let n = bucket.nodes.remove(0);
                    bucket.nodes.push(n);
                    return;
                }
            }

            // The node was unreachable or did not answer the ping:
            // evict the least recently seen node and add the new node
            bucket.nodes.remove(0);
            bucket.nodes.push(node);
            return;
        }

        // Bucket is not full
        bucket.nodes.push(node);
    }

    /// Move a node to the tail of its bucket, marking it as the most recently
    /// seen node in that bucket.
    /// If the node is not in any bucket yet, it is added using `add_node`.
    async fn update_node(&self, node: &N) {
        let bucket_index = self.dht().get_bucket_index(&self.node().await.id(), &node.id()).await;
        let buckets_lock = self.dht().buckets.clone();
        let mut buckets = buckets_lock.write().await;
        let bucket = &mut buckets[bucket_index];

        match bucket.nodes.iter().position(|n| n.id() == node.id()) {
            // Known node: move it to the tail of the bucket
            Some(node_index) => {
                let n = bucket.nodes.remove(node_index);
                bucket.nodes.push(n);
            }
            // Unknown node: release the lock and add it
            None => {
                drop(buckets);
                self.add_node(node.clone()).await;
            }
        }
    }

    /// Wait to acquire a semaphore permit, then run `self.fetch_nodes`.
    /// This is meant to be used in `lookup_nodes`.
    async fn fetch_nodes_sp(
        &self,
        semaphore: Arc<Semaphore>,
        node: N,
        key: &blake3::Hash,
    ) -> (N, Result<Vec<N>>)
    where
        N: 'async_trait,
    {
        let _permit = semaphore.acquire().await;
        (node.clone(), self.fetch_nodes(&node, key).await)
    }

    /// Find the `k` nodes closest to a key,
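    /// using an iterative lookup: query the closest known candidates with up
    /// to `alpha` parallel requests (bounded by a semaphore), fold newly
    /// learned nodes into the candidate list, and stop early once the `k`
    /// closest responders cannot be improved by any remaining candidate.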
    async fn lookup_nodes(&self, key: &blake3::Hash) -> Result<Vec<N>> {
        info!(target: "dht::DhtHandler::lookup_nodes()", "Starting node lookup for key {}", hash_to_string(key));

        let self_node_id = self.node().await.id();
        let k = self.dht().settings.k;
        let a = self.dht().settings.alpha;
        let semaphore = Arc::new(Semaphore::new(self.dht().settings.concurrency));
        let mut futures = FuturesUnordered::new();

        // Nodes we did not send a request to (yet), sorted by distance from `key`
        let mut nodes_to_visit = self.dht().find_neighbors(key, k).await;
        // Nodes with a pending or completed request
        let mut visited_nodes = HashSet::<blake3::Hash>::new();
        // Nodes that responded to our request, sorted by distance from `key`
        let mut result = Vec::<N>::new();

        // Create the first `alpha` tasks
        for _ in 0..a {
            match nodes_to_visit.pop() {
                Some(node) => {
                    visited_nodes.insert(node.id());
                    futures.push(self.fetch_nodes_sp(semaphore.clone(), node, key));
                }
                None => {
                    break;
                }
            }
        }

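        // Process responses as they complete; each successful response can
        // refill the pipeline with up to `alpha` new requests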
        while let Some((queried_node, value_result)) = futures.next().await {
            match value_result {
                Ok(mut nodes) => {
                    info!(target: "dht::DhtHandler::lookup_nodes()", "Queried {}, got {} nodes", hash_to_string(&queried_node.id()), nodes.len());

                    // Remove ourselves and already known nodes from the new nodes
                    nodes.retain(|node| {
                        node.id() != self_node_id &&
                            !visited_nodes.contains(&node.id()) &&
                            !nodes_to_visit.iter().any(|n| n.id() == node.id())
                    });

                    // Add new nodes to our buckets
                    for node in nodes.clone() {
                        self.add_node(node).await;
                    }

                    // Add the new nodes to the list of nodes to visit
                    nodes_to_visit.extend(nodes);
                    self.dht().sort_by_distance(&mut nodes_to_visit, key);

                    // Add the responding node to the result set, keeping the
                    // set sorted by distance from `key`
                    result.push(queried_node);
                    self.dht().sort_by_distance(&mut result, key);

                    // Early termination check:
                    // Stop if our furthest visited node is closer than the closest node we will query,
                    // and we already have `k` or more nodes in the result set.
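                    // Distances are compared as big-endian integers over the
                    // raw distance bytes.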
                    if result.len() >= k {
                        if let Some(furthest) = result.last() {
                            if let Some(next_node) = nodes_to_visit.first() {
                                let furthest_dist = BigUint::from_bytes_be(
                                    &self.dht().distance(key, &furthest.id()),
                                );
                                let next_dist = BigUint::from_bytes_be(
                                    &self.dht().distance(key, &next_node.id()),
                                );
                                if furthest_dist < next_dist {
                                    info!(target: "dht::DhtHandler::lookup_nodes()", "Early termination of node lookup");
                                    break;
                                }
                            }
                        }
                    }

                    // Create the next `alpha` tasks
                    for _ in 0..a {
                        match nodes_to_visit.pop() {
                            Some(node) => {
                                visited_nodes.insert(node.id());
                                futures.push(self.fetch_nodes_sp(semaphore.clone(), node, key));
                            }
                            None => {
                                break;
                            }
                        }
                    }
                }
                Err(e) => {
                    warn!(target: "dht::DhtHandler::lookup_nodes()", "Error looking for nodes: {e}");
                }
            }
        }

        result.truncate(k);
        Ok(result)
    }

    /// Get a channel to `node` (reusing an existing one if possible, creating
    /// a new one otherwise) for an optional `topic`.
    /// Don't forget to call `cleanup_channel()` once you are done with it.
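    ///
    /// A minimal usage sketch (hypothetical `node`, `topic` and `message`):
    /// ```ignore
    /// let channel = self.get_channel(&node, Some(topic)).await?;
    /// channel.send(&message).await?;
    /// self.cleanup_channel(channel).await;
    /// ```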
    async fn get_channel(&self, node: &N, topic: Option<blake3::Hash>) -> Result<ChannelPtr> {
        let channel_cache_lock = self.dht().channel_cache.clone();
        let mut channel_cache = channel_cache_lock.write().await;

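        // The channel cache maps channel id -> (node, topic, usage count),
        // letting concurrent tasks share a single channel per peer.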
        // Get existing channels for this node, regardless of topic
        let channels: HashMap<u32, ChannelCacheItem<N>> = channel_cache
            .iter()
            .filter(|&(_, item)| item.node == *node)
            .map(|(&key, item)| (key, item.clone()))
            .collect();

        let (channel_id, topic, usage_count) =
            // If we already have a channel for this node and topic, use it
            if let Some((cid, cached)) = channels.iter().find(|&(_, c)| c.topic == topic) {
                (Some(*cid), cached.topic, cached.usage_count)
            }
            // If we have a topicless channel for this node, use it
            else if let Some((cid, cached)) = channels.iter().find(|&(_, c)| c.topic.is_none()) {
                (Some(*cid), topic, cached.usage_count)
            }
            // If we don't need any specific topic, use the first channel we have
            else if topic.is_none() {
                match channels.iter().next() {
                    Some((cid, cached)) => (Some(*cid), cached.topic, cached.usage_count),
                    _ => (None, topic, 0),
                }
            }
            // There is no existing channel we can use; we will create one below
            else {
                (None, topic, 0)
            };

        // If we found an existing channel we can use, try to use it
        if let Some(channel_id) = channel_id {
            if let Some(channel) = self.dht().p2p.get_channel(channel_id) {
                if channel.session_type_id() & (SESSION_SEED | SESSION_REFINE) != 0 {
                    return Err(Error::Custom(
                        "Could not get a channel (for DHT) as this is a seed or refine session"
                            .to_string(),
                    ));
                }

                if channel.is_stopped() {
                    channel.clone().start(self.dht().executor.clone());
                }

                channel_cache.insert(
                    channel_id,
                    ChannelCacheItem { node: node.clone(), topic, usage_count: usage_count + 1 },
                );
                return Ok(channel);
            }
        }

        drop(channel_cache);

        // No reusable channel found; connect to the node, trying each of its
        // addresses in turn
        for addr in node.addresses().clone() {
            let session_out = self.dht().p2p.session_outbound();
            let session_weak = Arc::downgrade(&session_out);

            let connector = Connector::new(self.dht().p2p.settings(), session_weak);
            let dur = Duration::from_secs(self.dht().settings.timeout);
            let Ok(connect_res) = timeout(dur, connector.connect(&addr)).await else {
                warn!(target: "dht::DhtHandler::get_channel()", "Timeout trying to connect to {addr}");
                continue; // Try the next address
            };
            let (_, channel) = match connect_res {
                Ok(v) => v,
                Err(e) => {
                    warn!(target: "dht::DhtHandler::get_channel()", "Error while connecting to {addr}: {e}");
                    continue;
                }
            };

            if channel.session_type_id() & (SESSION_SEED | SESSION_REFINE) != 0 {
                return Err(Error::Custom(
                    "Could not create a channel (for DHT) as this is a seed or refine session"
                        .to_string(),
                ));
            }

            if let Err(e) =
                session_out.register_channel(channel.clone(), self.dht().executor.clone()).await
            {
                channel.clone().stop().await;
                warn!(target: "dht::DhtHandler::get_channel()", "Error while registering channel {}: {e}", channel.info.id);
                continue;
            }

            let mut channel_cache = channel_cache_lock.write().await;
            channel_cache.insert(
                channel.info.id,
                ChannelCacheItem { node: node.clone(), topic, usage_count: 1 },
            );

            return Ok(channel);
        }

        Err(Error::Custom("Could not create channel".to_string()))
    }

    /// Decrement the channel usage count; if it reaches 0, set the topic to
    /// `None` so that this channel becomes available for other tasks.
    async fn cleanup_channel(&self, channel: ChannelPtr) {
        let channel_cache_lock = self.dht().channel_cache.clone();
        let mut channel_cache = channel_cache_lock.write().await;

        if let Some(cached) = channel_cache.get_mut(&channel.info.id) {
            if cached.usage_count > 0 {
                cached.usage_count -= 1;
            }

            // If the channel is no longer used by anything, remove the topic
            if cached.usage_count == 0 {
                cached.topic = None;
            }
        }
    }

    /// Add nodes as providers for a key
    async fn add_to_router(
        &self,
        router: DhtRouterPtr<N>,
        key: &blake3::Hash,
        router_items: Vec<DhtRouterItem<N>>,
    ) where
        N: 'async_trait,
    {
        // Only nodes with at least one external address can be providers
        let mut router_items = router_items;
        router_items.retain(|item| !item.node.addresses().is_empty());

        debug!(target: "dht::DhtHandler::add_to_router()", "Inserting {} nodes for key {}", router_items.len(), hash_to_string(key));

        let mut router_write = router.write().await;
        let key_r = router_write.get_mut(key);

        let router_cache_lock = self.dht().router_cache.clone();
        let mut router_cache = router_cache_lock.write().await;

        // Add to router (key -> set of providing nodes)
        if let Some(k) = key_r {
            // Avoid duplicates: drop matching entries before extending
            k.retain(|it| !router_items.contains(it));
            k.extend(router_items.clone());
        } else {
            let mut hs = HashSet::new();
            hs.extend(router_items.clone());
            router_write.insert(*key, hs);
        }

        // Add to router_cache (the reverse index: node id -> set of keys)
        for router_item in router_items {
            match router_cache.get_mut(&router_item.node.id()) {
                Some(k) => {
                    k.insert(*key);
                }
                None => {
                    let mut keys = HashSet::new();
                    keys.insert(*key);
                    router_cache.insert(router_item.node.id(), keys);
                }
            }
        }
    }
}