1use std::{any::Any, collections::HashMap, sync::Arc, time::Duration};
20
21use async_trait::async_trait;
22use futures::stream::{FuturesUnordered, StreamExt};
23use log::{debug, error};
24use rand::{rngs::OsRng, Rng};
25use smol::{io::AsyncReadExt, lock::Mutex};
26
27use super::message::Message;
28use crate::{
29 net::{metering::MeteringQueue, transport::PtStream},
30 system::{msleep, timeout::timeout},
31 Error, Result,
32};
33use darkfi_serial::{AsyncDecodable, VarInt};
34
35pub type MessageSubscriptionId = u64;
37type MessageResult<M> = Result<Arc<M>>;
38
39type DispatcherSubscriptionsMap<M> =
41 Mutex<HashMap<MessageSubscriptionId, smol::channel::Sender<(MessageResult<M>, Option<u64>)>>>;
42
43#[derive(Debug)]
53struct MessageDispatcher<M: Message> {
54 subs: DispatcherSubscriptionsMap<M>,
55 metering_queue: Mutex<MeteringQueue>,
56}
57
58impl<M: Message> MessageDispatcher<M> {
59 fn new() -> Self {
61 Self {
62 subs: Mutex::new(HashMap::new()),
63 metering_queue: Mutex::new(MeteringQueue::new(M::METERING_CONFIGURATION)),
64 }
65 }
66
67 fn random_id() -> MessageSubscriptionId {
69 OsRng.gen()
70 }
71
72 pub async fn subscribe(self: Arc<Self>) -> MessageSubscription<M> {
75 let (sender, recv_queue) = smol::channel::unbounded();
76 let mut id = Self::random_id();
78 let mut subs = self.subs.lock().await;
79 loop {
80 if subs.contains_key(&id) {
81 id = Self::random_id();
82 continue
83 }
84
85 subs.insert(id, sender);
86 break
87 }
88
89 drop(subs);
90 MessageSubscription { id, recv_queue, parent: self }
91 }
92
93 async fn unsubscribe(&self, sub_id: MessageSubscriptionId) {
96 self.subs.lock().await.remove(&sub_id);
97 }
98
99 async fn _trigger_all(&self, message: MessageResult<M>) {
102 let mut subs = self.subs.lock().await;
103
104 let msg_result_type = if message.is_ok() { "Ok" } else { "Err" };
105 debug!(
106 target: "net::message_publisher::_trigger_all()", "START msg={}({}), subs={}",
107 msg_result_type,
108 M::NAME, subs.len(),
109 );
110
111 let mut queue = self.metering_queue.lock().await;
113 queue.push(&M::METERING_SCORE);
114 let sleep_time = queue.sleep_time();
115 drop(queue);
116
117 let mut futures = FuturesUnordered::new();
118 let mut garbage_ids = vec![];
119
120 for (sub_id, sub) in &*subs {
122 let sub_id = *sub_id;
123 let sub = sub.clone();
124 let message = message.clone();
125 futures.push(async move {
126 match sub.send((message, sleep_time)).await {
127 Ok(res) => Ok((sub_id, res)),
128 Err(err) => Err((sub_id, err)),
129 }
130 });
131 }
132
133 while let Some(r) = futures.next().await {
135 if let Err((sub_id, _err)) = r {
136 garbage_ids.push(sub_id);
137 }
138 }
139
140 for sub_id in garbage_ids {
142 subs.remove(&sub_id);
143 }
144
145 debug!(
146 target: "net::message_publisher::_trigger_all()", "END msg={}({}), subs={}",
147 msg_result_type,
148 M::NAME, subs.len(),
149 );
150 }
151}
152
153#[derive(Debug)]
156pub struct MessageSubscription<M: Message> {
157 id: MessageSubscriptionId,
158 recv_queue: smol::channel::Receiver<(MessageResult<M>, Option<u64>)>,
159 parent: Arc<MessageDispatcher<M>>,
160}
161
162impl<M: Message> MessageSubscription<M> {
163 pub async fn receive(&self) -> MessageResult<M> {
167 let (message, sleep_time) = match self.recv_queue.recv().await {
168 Ok(pair) => pair,
169 Err(e) => panic!("MessageSubscription::receive(): recv_queue failed! {}", e),
170 };
171
172 if message.is_ok() {
174 if let Some(sleep_time) = sleep_time {
175 msleep(sleep_time).await;
176 }
177 }
178
179 message
180 }
181
182 pub async fn receive_with_timeout(&self, seconds: u64) -> MessageResult<M> {
184 let dur = Duration::from_secs(seconds);
185 let Ok(res) = timeout(dur, self.recv_queue.recv()).await else {
186 return Err(Error::ConnectTimeout)
187 };
188
189 let (message, sleep_time) = match res {
190 Ok(pair) => pair,
191 Err(e) => {
192 panic!("MessageSubscription::receive_with_timeout(): recv_queue failed! {}", e)
193 }
194 };
195
196 if message.is_ok() {
198 if let Some(sleep_time) = sleep_time {
199 msleep(sleep_time).await;
200 }
201 }
202
203 message
204 }
205
206 pub async fn clean(&self) -> Result<()> {
208 loop {
209 match self.recv_queue.try_recv() {
210 Ok(_) => continue,
211 Err(smol::channel::TryRecvError::Empty) => return Ok(()),
212 Err(e) => panic!("MessageSubscription::receive(): recv_queue failed! {}", e),
213 }
214 }
215 }
216
217 pub async fn unsubscribe(&self) {
219 self.parent.unsubscribe(self.id).await
220 }
221}
222
223#[async_trait]
225trait MessageDispatcherInterface: Send + Sync {
226 async fn trigger(
227 &self,
228 stream: &mut smol::io::ReadHalf<Box<dyn PtStream + 'static>>,
229 ) -> Result<()>;
230
231 async fn trigger_error(&self, err: Error);
232
233 async fn metering_score(&self) -> u64;
234
235 fn as_any(self: Arc<Self>) -> Arc<dyn Any + Send + Sync>;
236}
237
238#[async_trait]
240impl<M: Message> MessageDispatcherInterface for MessageDispatcher<M> {
241 async fn trigger(
248 &self,
249 stream: &mut smol::io::ReadHalf<Box<dyn PtStream + 'static>>,
250 ) -> Result<()> {
251 let length = match VarInt::decode_async(stream).await {
253 Ok(int) => int.0,
254 Err(err) => {
255 error!(
256 target: "net::message_publisher::trigger()",
257 "Unable to decode VarInt. Dropping...: {}",
258 err,
259 );
260 return Err(Error::MessageInvalid)
261 }
262 };
263
264 if M::MAX_BYTES > 0 && length > M::MAX_BYTES {
266 error!(
267 target: "net::message_publisher::trigger()",
268 "Message length ({}) exceeds configured limit ({}). Dropping...",
269 length, M::MAX_BYTES,
270 );
271 return Err(Error::MessageInvalid)
272 }
273
274 let mut take = stream.take(length);
276 let message = match M::decode_async(&mut take).await {
277 Ok(payload) => Ok(Arc::new(payload)),
278 Err(err) => {
279 error!(
280 target: "net::message_publisher::trigger()",
281 "Unable to decode data. Dropping...: {}",
282 err,
283 );
284 return Err(Error::MessageInvalid)
285 }
286 };
287
288 self._trigger_all(message).await;
290 Ok(())
291 }
292
293 async fn trigger_error(&self, err: Error) {
295 self._trigger_all(Err(err)).await;
296 }
297
298 async fn metering_score(&self) -> u64 {
301 let mut lock = self.metering_queue.lock().await;
302 lock.clean();
303 lock.total()
304 }
305
306 fn as_any(self: Arc<Self>) -> Arc<dyn Any + Send + Sync> {
308 self
309 }
310}
311
312#[derive(Default)]
320pub struct MessageSubsystem {
321 dispatchers: Mutex<HashMap<&'static str, Arc<dyn MessageDispatcherInterface>>>,
322 metering_limit: Mutex<u64>,
323}
324
325impl MessageSubsystem {
326 pub fn new() -> Self {
328 Self { dispatchers: Mutex::new(HashMap::new()), metering_limit: Mutex::new(0) }
329 }
330
331 pub async fn add_dispatch<M: Message>(&self) {
333 let mut lock = self.dispatchers.lock().await;
335
336 *self.metering_limit.lock().await += M::METERING_CONFIGURATION.threshold;
338
339 lock.insert(M::NAME, Arc::new(MessageDispatcher::<M>::new()));
341 }
342
343 pub async fn subscribe<M: Message>(&self) -> Result<MessageSubscription<M>> {
347 let dispatcher = self.dispatchers.lock().await.get(M::NAME).cloned();
348
349 let sub = match dispatcher {
350 Some(dispatcher) => {
351 let dispatcher: Arc<MessageDispatcher<M>> = dispatcher
352 .as_any()
353 .downcast::<MessageDispatcher<M>>()
354 .expect("Multiple messages registered with different names");
355
356 dispatcher.subscribe().await
357 }
358
359 None => {
360 return Err(Error::NetworkOperationFailed)
362 }
363 };
364
365 Ok(sub)
366 }
367
368 pub async fn notify(
371 &self,
372 command: &str,
373 reader: &mut smol::io::ReadHalf<Box<dyn PtStream + 'static>>,
374 ) -> Result<()> {
375 let mut found = false;
378 let mut total_score = 0;
379 for (name, dispatcher) in self.dispatchers.lock().await.iter() {
380 if name == &command {
382 dispatcher.trigger(reader).await?;
383 found = true;
384 }
385
386 total_score += dispatcher.metering_score().await;
388 }
389
390 if !found {
392 return Err(Error::MissingDispatcher)
393 }
394
395 if total_score > *self.metering_limit.lock().await {
397 return Err(Error::MeteringLimitExceeded)
398 }
399
400 Ok(())
401 }
402
403 pub async fn trigger_error(&self, err: Error) {
405 let mut futures = FuturesUnordered::new();
406
407 let dispatchers = self.dispatchers.lock().await;
408
409 for dispatcher in dispatchers.values() {
410 let dispatcher = dispatcher.clone();
411 let error = err.clone();
412 futures.push(async move { dispatcher.trigger_error(error).await });
413 }
414
415 drop(dispatchers);
416
417 while let Some(_r) = futures.next().await {}
418 }
419}