1use async_trait::async_trait;
20use futures::stream::FuturesUnordered;
21use log::{debug, info, warn};
22use num_bigint::BigUint;
23use smol::{lock::Semaphore, stream::StreamExt};
24use std::{
25 collections::{HashMap, HashSet},
26 marker::Sync,
27 sync::Arc,
28 time::Duration,
29};
30
31use super::{ChannelCacheItem, Dht, DhtNode, DhtRouterItem, DhtRouterPtr};
32use crate::{
33 geode::hash_to_string,
34 net::{
35 connector::Connector,
36 session::{Session, SESSION_REFINE, SESSION_SEED},
37 ChannelPtr, Message,
38 },
39 system::timeout::timeout,
40 Error, Result,
41};
42
43#[async_trait]
44pub trait DhtHandler<N: DhtNode>: Sync {
45 fn dht(&self) -> Arc<Dht<N>>;
46
47 async fn node(&self) -> N;
49
50 async fn ping(&self, channel: ChannelPtr) -> Result<N>;
52
53 async fn on_new_node(&self, node: &N) -> Result<()>;
55
56 async fn fetch_nodes(&self, node: &N, key: &blake3::Hash) -> Result<Vec<N>>;
58
59 async fn announce<M: Message>(
61 &self,
62 key: &blake3::Hash,
63 message: &M,
64 router: DhtRouterPtr<N>,
65 ) -> Result<()>
66 where
67 N: 'async_trait,
68 {
69 let self_node = self.node().await;
70 if self_node.addresses().is_empty() {
71 return Err(().into()); }
73
74 self.add_to_router(router.clone(), key, vec![self_node.clone().into()]).await;
75 let nodes = self.lookup_nodes(key).await?;
76 info!(target: "dht::DhtHandler::announce()", "Announcing {} to {} nodes", hash_to_string(key), nodes.len());
77
78 for node in nodes {
79 let channel_res = self.get_channel(&node, None).await;
80 if let Ok(channel) = channel_res {
81 let _ = channel.send(message).await;
82 self.cleanup_channel(channel).await;
83 }
84 }
85
86 Ok(())
87 }
88
89 async fn bootstrap(&self) {
91 self.dht().set_bootstrapped(true).await;
92
93 let self_node_id = self.node().await.id();
94 debug!(target: "dht::DhtHandler::bootstrap()", "DHT bootstrapping {}", hash_to_string(&self_node_id));
95 let nodes = self.lookup_nodes(&self_node_id).await;
96
97 if nodes.is_err() || nodes.map_or(true, |v| v.is_empty()) {
98 self.dht().set_bootstrapped(false).await;
99 }
100 }
101
102 async fn add_node(&self, node: N)
104 where
105 N: 'async_trait,
106 {
107 let self_node = self.node().await;
108
109 if node.id() == self_node.id() {
111 return;
112 }
113
114 let node_addresses = node.addresses();
116 if self_node.addresses().iter().any(|addr| node_addresses.contains(addr)) {
117 return;
118 }
119
120 if node.addresses().is_empty() {
122 return;
123 }
124
125 let bucket_index = self.dht().get_bucket_index(&self.node().await.id(), &node.id()).await;
126 let buckets_lock = self.dht().buckets.clone();
127 let mut buckets = buckets_lock.write().await;
128 let bucket = &mut buckets[bucket_index];
129
130 if bucket.nodes.iter().any(|n| n.id() == node.id()) {
132 return;
133 }
134
135 if bucket.nodes.len() >= self.dht().settings.k {
137 if let Ok(channel) = self.get_channel(&bucket.nodes[0], None).await {
139 let ping_res = self.ping(channel.clone()).await;
140 self.cleanup_channel(channel).await;
141 if ping_res.is_ok() {
142 let n = bucket.nodes.remove(0);
144 bucket.nodes.push(n);
145 return;
146 }
147 }
148
149 bucket.nodes.remove(0);
151 bucket.nodes.push(node);
152 return;
153 }
154
155 bucket.nodes.push(node);
157 }
158
159 async fn update_node(&self, node: &N) {
163 let bucket_index = self.dht().get_bucket_index(&self.node().await.id(), &node.id()).await;
164 let buckets_lock = self.dht().buckets.clone();
165 let mut buckets = buckets_lock.write().await;
166 let bucket = &mut buckets[bucket_index];
167
168 let node_index = bucket.nodes.iter().position(|n| n.id() == node.id());
169 if node_index.is_none() {
170 drop(buckets);
171 self.add_node(node.clone()).await;
172 return;
173 }
174
175 let n = bucket.nodes.remove(node_index.unwrap());
176 bucket.nodes.push(n);
177 }
178
179 async fn fetch_nodes_sp(
182 &self,
183 semaphore: Arc<Semaphore>,
184 node: N,
185 key: &blake3::Hash,
186 ) -> (N, Result<Vec<N>>)
187 where
188 N: 'async_trait,
189 {
190 let _permit = semaphore.acquire().await;
191 (node.clone(), self.fetch_nodes(&node, key).await)
192 }
193
194 async fn lookup_nodes(&self, key: &blake3::Hash) -> Result<Vec<N>> {
196 info!(target: "dht::DhtHandler::lookup_nodes()", "Starting node lookup for key {}", bs58::encode(key.as_bytes()).into_string());
197
198 let self_node_id = self.node().await.id();
199 let k = self.dht().settings.k;
200 let a = self.dht().settings.alpha;
201 let semaphore = Arc::new(Semaphore::new(self.dht().settings.concurrency));
202 let mut futures = FuturesUnordered::new();
203
204 let mut nodes_to_visit = self.dht().find_neighbors(key, k).await;
206 let mut visited_nodes = HashSet::<blake3::Hash>::new();
208 let mut result = Vec::<N>::new();
210
211 for _ in 0..a {
213 match nodes_to_visit.pop() {
214 Some(node) => {
215 visited_nodes.insert(node.id());
216 futures.push(self.fetch_nodes_sp(semaphore.clone(), node, key));
217 }
218 None => {
219 break;
220 }
221 }
222 }
223
224 while let Some((queried_node, value_result)) = futures.next().await {
225 match value_result {
226 Ok(mut nodes) => {
227 info!(target: "dht::DhtHandler::lookup_nodes()", "Queried {}, got {} nodes", hash_to_string(&queried_node.id()), nodes.len());
228
229 nodes.retain(|node| {
231 node.id() != self_node_id &&
232 !visited_nodes.contains(&node.id()) &&
233 !nodes_to_visit.iter().any(|n| n.id() == node.id())
234 });
235
236 for node in nodes.clone() {
238 self.add_node(node).await;
239 }
240
241 nodes_to_visit.extend(nodes.clone());
243 self.dht().sort_by_distance(&mut nodes_to_visit, key);
244
245 result.push(queried_node.clone());
247 self.dht().sort_by_distance(&mut result, key);
248
249 if result.len() >= k {
253 if let Some(furthest) = result.last() {
254 if let Some(next_node) = nodes_to_visit.first() {
255 let furthest_dist = BigUint::from_bytes_be(
256 &self.dht().distance(key, &furthest.id()),
257 );
258 let next_dist = BigUint::from_bytes_be(
259 &self.dht().distance(key, &next_node.id()),
260 );
261 if furthest_dist < next_dist {
262 info!(target: "dht::DhtHandler::lookup_nodes()", "Early termination for lookup nodes");
263 break;
264 }
265 }
266 }
267 }
268
269 for _ in 0..a {
271 match nodes_to_visit.pop() {
272 Some(node) => {
273 visited_nodes.insert(node.id());
274 futures.push(self.fetch_nodes_sp(semaphore.clone(), node, key));
275 }
276 None => {
277 break;
278 }
279 }
280 }
281 }
282 Err(e) => {
283 warn!(target: "dht::DhtHandler::lookup_nodes", "Error looking for nodes: {e}");
284 }
285 }
286 }
287
288 result.truncate(k);
289 return Ok(result.to_vec())
290 }
291
292 async fn get_channel(&self, node: &N, topic: Option<blake3::Hash>) -> Result<ChannelPtr> {
295 let channel_cache_lock = self.dht().channel_cache.clone();
296 let mut channel_cache = channel_cache_lock.write().await;
297
298 let channels: HashMap<u32, ChannelCacheItem<N>> = channel_cache
300 .iter()
301 .filter(|&(_, item)| item.node == *node)
302 .map(|(&key, item)| (key, item.clone()))
303 .collect();
304
305 let (channel_id, topic, usage_count) =
306 if let Some((cid, cached)) = channels.iter().find(|&(_, c)| c.topic == topic) {
308 (Some(*cid), cached.topic, cached.usage_count)
309 }
310 else if let Some((cid, cached)) = channels.iter().find(|&(_, c)| c.topic.is_none()) {
312 (Some(*cid), topic, cached.usage_count)
313 }
314 else if topic.is_none() {
316 match channels.iter().next() {
317 Some((cid, cached)) => (Some(*cid), cached.topic, cached.usage_count),
318 _ => (None, topic, 0),
319 }
320 }
321 else {
323 (None, topic, 0)
324 };
325
326 if let Some(channel_id) = channel_id {
328 if let Some(channel) = self.dht().p2p.get_channel(channel_id) {
329 if channel.session_type_id() & (SESSION_SEED | SESSION_REFINE) != 0 {
330 return Err(Error::Custom(
331 "Could not get a channel (for DHT) as this is a seed or refine session"
332 .to_string(),
333 ));
334 }
335
336 if channel.is_stopped() {
337 channel.clone().start(self.dht().executor.clone());
338 }
339
340 channel_cache.insert(
341 channel_id,
342 ChannelCacheItem { node: node.clone(), topic, usage_count: usage_count + 1 },
343 );
344 return Ok(channel);
345 }
346 }
347
348 drop(channel_cache);
349
350 for addr in node.addresses().clone() {
352 let session_out = self.dht().p2p.session_outbound();
353 let session_weak = Arc::downgrade(&self.dht().p2p.session_outbound());
354
355 let connector = Connector::new(self.dht().p2p.settings(), session_weak);
356 let dur = Duration::from_secs(self.dht().settings.timeout);
357 let Ok(connect_res) = timeout(dur, connector.connect(&addr)).await else {
358 warn!(target: "dht::DhtHandler::get_channel()", "Timeout trying to connect to {addr}");
359 return Err(Error::ConnectTimeout);
360 };
361 if connect_res.is_err() {
362 warn!(target: "dht::DhtHandler::get_channel()", "Error while connecting to {addr}: {}", connect_res.unwrap_err());
363 continue;
364 }
365 let (_, channel) = connect_res.unwrap();
366
367 if channel.session_type_id() & (SESSION_SEED | SESSION_REFINE) != 0 {
368 return Err(Error::Custom(
369 "Could not create a channel (for DHT) as this is a seed or refine session"
370 .to_string(),
371 ));
372 }
373
374 let register_res =
375 session_out.register_channel(channel.clone(), self.dht().executor.clone()).await;
376 if register_res.is_err() {
377 channel.clone().stop().await;
378 warn!(target: "dht::DhtHandler::get_channel()", "Error while registering channel {}: {}", channel.info.id, register_res.unwrap_err());
379 continue;
380 }
381
382 let mut channel_cache = channel_cache_lock.write().await;
383 channel_cache.insert(
384 channel.info.id,
385 ChannelCacheItem { node: node.clone(), topic, usage_count: 1 },
386 );
387
388 return Ok(channel)
389 }
390
391 Err(Error::Custom("Could not create channel".to_string()))
392 }
393
394 async fn cleanup_channel(&self, channel: ChannelPtr) {
397 let channel_cache_lock = self.dht().channel_cache.clone();
398 let mut channel_cache = channel_cache_lock.write().await;
399
400 if let Some(cached) = channel_cache.get_mut(&channel.info.id) {
401 if cached.usage_count > 0 {
402 cached.usage_count -= 1;
403 }
404
405 if cached.usage_count == 0 {
407 cached.topic = None;
408 }
409 }
410 }
411
412 async fn add_to_router(
414 &self,
415 router: DhtRouterPtr<N>,
416 key: &blake3::Hash,
417 router_items: Vec<DhtRouterItem<N>>,
418 ) where
419 N: 'async_trait,
420 {
421 let mut router_items = router_items.clone();
422 router_items.retain(|item| !item.node.addresses().is_empty());
423
424 debug!(target: "dht::DhtHandler::add_to_router()", "Inserting {} nodes to key {}", router_items.len(), bs58::encode(key.as_bytes()).into_string());
425
426 let mut router_write = router.write().await;
427 let key_r = router_write.get_mut(key);
428
429 let router_cache_lock = self.dht().router_cache.clone();
430 let mut router_cache = router_cache_lock.write().await;
431
432 if let Some(k) = key_r {
434 k.retain(|it| !router_items.contains(it));
435 k.extend(router_items.clone());
436 } else {
437 let mut hs = HashSet::new();
438 hs.extend(router_items.clone());
439 router_write.insert(*key, hs);
440 }
441
442 for router_item in router_items {
444 let keys = router_cache.get_mut(&router_item.node.id());
445 if let Some(k) = keys {
446 k.insert(*key);
447 } else {
448 let mut keys = HashSet::new();
449 keys.insert(*key);
450 router_cache.insert(router_item.node.id(), keys);
451 }
452 }
453 }
454}