darkfi/net/
message_publisher.rs1use std::{any::Any, collections::HashMap, sync::Arc, time::Duration};
20
21use async_trait::async_trait;
22use futures::stream::{FuturesUnordered, StreamExt};
23use log::{debug, error};
24use rand::{rngs::OsRng, Rng};
25use smol::{io::AsyncReadExt, lock::Mutex};
26
27use super::message::Message;
28use crate::{
29 net::{metering::MeteringQueue, transport::PtStream},
30 system::{msleep, timeout::timeout},
31 Error, Result,
32};
33use darkfi_serial::{AsyncDecodable, VarInt};
34
35pub type MessageSubscriptionId = u64;
37type MessageResult<M> = Result<Arc<M>>;
38
39type DispatcherSubscriptionsMap<M> =
41 Mutex<HashMap<MessageSubscriptionId, smol::channel::Sender<(MessageResult<M>, Option<u64>)>>>;
42
43#[derive(Debug)]
53struct MessageDispatcher<M: Message> {
54 subs: DispatcherSubscriptionsMap<M>,
55 metering_queue: Mutex<MeteringQueue>,
56}
57
58impl<M: Message> MessageDispatcher<M> {
59 fn new() -> Self {
61 Self {
62 subs: Mutex::new(HashMap::new()),
63 metering_queue: Mutex::new(MeteringQueue::new(M::METERING_CONFIGURATION)),
64 }
65 }
66
67 fn random_id() -> MessageSubscriptionId {
69 OsRng.gen()
70 }
71
72 pub async fn subscribe(self: Arc<Self>) -> MessageSubscription<M> {
75 let (sender, recv_queue) = smol::channel::unbounded();
76 let mut id = Self::random_id();
78 let mut subs = self.subs.lock().await;
79 loop {
80 if subs.contains_key(&id) {
81 id = Self::random_id();
82 continue
83 }
84
85 subs.insert(id, sender);
86 break
87 }
88
89 drop(subs);
90 MessageSubscription { id, recv_queue, parent: self }
91 }
92
93 async fn unsubscribe(&self, sub_id: MessageSubscriptionId) {
96 self.subs.lock().await.remove(&sub_id);
97 }
98
99 async fn _trigger_all(&self, message: MessageResult<M>) {
102 let mut subs = self.subs.lock().await;
103
104 let msg_result_type = if message.is_ok() { "Ok" } else { "Err" };
105 debug!(
106 target: "net::message_publisher::_trigger_all()", "START msg={msg_result_type}({}), subs={}",
107 M::NAME, subs.len()
108 );
109
110 let mut queue = self.metering_queue.lock().await;
112 queue.push(&M::METERING_SCORE);
113 let sleep_time = queue.sleep_time();
114 drop(queue);
115
116 let mut futures = FuturesUnordered::new();
117 let mut garbage_ids = vec![];
118
119 for (sub_id, sub) in &*subs {
121 let sub_id = *sub_id;
122 let sub = sub.clone();
123 let message = message.clone();
124 futures.push(async move {
125 match sub.send((message, sleep_time)).await {
126 Ok(res) => Ok((sub_id, res)),
127 Err(err) => Err((sub_id, err)),
128 }
129 });
130 }
131
132 while let Some(r) = futures.next().await {
134 if let Err((sub_id, _err)) = r {
135 garbage_ids.push(sub_id);
136 }
137 }
138
139 for sub_id in garbage_ids {
141 subs.remove(&sub_id);
142 }
143
144 debug!(
145 target: "net::message_publisher::_trigger_all()", "END msg={msg_result_type}({}), subs={}",
146 M::NAME, subs.len(),
147 );
148 }
149}
150
151#[derive(Debug)]
154pub struct MessageSubscription<M: Message> {
155 id: MessageSubscriptionId,
156 recv_queue: smol::channel::Receiver<(MessageResult<M>, Option<u64>)>,
157 parent: Arc<MessageDispatcher<M>>,
158}
159
160impl<M: Message> MessageSubscription<M> {
161 pub async fn receive(&self) -> MessageResult<M> {
165 let (message, sleep_time) = match self.recv_queue.recv().await {
166 Ok(pair) => pair,
167 Err(e) => panic!("MessageSubscription::receive(): recv_queue failed! {e}"),
168 };
169
170 if message.is_ok() {
172 if let Some(sleep_time) = sleep_time {
173 msleep(sleep_time).await;
174 }
175 }
176
177 message
178 }
179
180 pub async fn receive_with_timeout(&self, seconds: u64) -> MessageResult<M> {
182 let dur = Duration::from_secs(seconds);
183 let Ok(res) = timeout(dur, self.recv_queue.recv()).await else {
184 return Err(Error::ConnectTimeout)
185 };
186
187 let (message, sleep_time) = match res {
188 Ok(pair) => pair,
189 Err(e) => {
190 panic!("MessageSubscription::receive_with_timeout(): recv_queue failed! {e}")
191 }
192 };
193
194 if message.is_ok() {
196 if let Some(sleep_time) = sleep_time {
197 msleep(sleep_time).await;
198 }
199 }
200
201 message
202 }
203
204 pub async fn clean(&self) -> Result<()> {
206 loop {
207 match self.recv_queue.try_recv() {
208 Ok(_) => continue,
209 Err(smol::channel::TryRecvError::Empty) => return Ok(()),
210 Err(e) => panic!("MessageSubscription::receive(): recv_queue failed! {e}"),
211 }
212 }
213 }
214
215 pub async fn unsubscribe(&self) {
217 self.parent.unsubscribe(self.id).await
218 }
219}
220
221#[async_trait]
223trait MessageDispatcherInterface: Send + Sync {
224 async fn trigger(
225 &self,
226 stream: &mut smol::io::ReadHalf<Box<dyn PtStream + 'static>>,
227 ) -> Result<()>;
228
229 async fn trigger_error(&self, err: Error);
230
231 async fn metering_score(&self) -> u64;
232
233 fn as_any(self: Arc<Self>) -> Arc<dyn Any + Send + Sync>;
234}
235
236#[async_trait]
238impl<M: Message> MessageDispatcherInterface for MessageDispatcher<M> {
239 async fn trigger(
246 &self,
247 stream: &mut smol::io::ReadHalf<Box<dyn PtStream + 'static>>,
248 ) -> Result<()> {
249 let length = match VarInt::decode_async(stream).await {
251 Ok(int) => int.0,
252 Err(err) => {
253 error!(
254 target: "net::message_publisher::trigger()",
255 "Unable to decode VarInt. Dropping...: {err}"
256 );
257 return Err(Error::MessageInvalid)
258 }
259 };
260
261 if M::MAX_BYTES > 0 && length > M::MAX_BYTES {
263 error!(
264 target: "net::message_publisher::trigger()",
265 "Message length ({length}) exceeds configured limit ({}). Dropping...",
266 M::MAX_BYTES
267 );
268 return Err(Error::MessageInvalid)
269 }
270
271 let mut take = stream.take(length);
273 let message = match M::decode_async(&mut take).await {
274 Ok(payload) => Ok(Arc::new(payload)),
275 Err(err) => {
276 error!(
277 target: "net::message_publisher::trigger()",
278 "Unable to decode data. Dropping...: {err}"
279 );
280 return Err(Error::MessageInvalid)
281 }
282 };
283
284 self._trigger_all(message).await;
286 Ok(())
287 }
288
289 async fn trigger_error(&self, err: Error) {
291 self._trigger_all(Err(err)).await;
292 }
293
294 async fn metering_score(&self) -> u64 {
297 let mut lock = self.metering_queue.lock().await;
298 lock.clean();
299 lock.total()
300 }
301
302 fn as_any(self: Arc<Self>) -> Arc<dyn Any + Send + Sync> {
304 self
305 }
306}
307
308#[derive(Default)]
316pub struct MessageSubsystem {
317 dispatchers: Mutex<HashMap<&'static str, Arc<dyn MessageDispatcherInterface>>>,
318 metering_limit: Mutex<u64>,
319}
320
321impl MessageSubsystem {
322 pub fn new() -> Self {
324 Self { dispatchers: Mutex::new(HashMap::new()), metering_limit: Mutex::new(0) }
325 }
326
327 pub async fn add_dispatch<M: Message>(&self) {
329 let mut lock = self.dispatchers.lock().await;
331
332 *self.metering_limit.lock().await += M::METERING_CONFIGURATION.threshold;
334
335 lock.insert(M::NAME, Arc::new(MessageDispatcher::<M>::new()));
337 }
338
339 pub async fn subscribe<M: Message>(&self) -> Result<MessageSubscription<M>> {
343 let dispatcher = self.dispatchers.lock().await.get(M::NAME).cloned();
344
345 let sub = match dispatcher {
346 Some(dispatcher) => {
347 let dispatcher: Arc<MessageDispatcher<M>> = dispatcher
348 .as_any()
349 .downcast::<MessageDispatcher<M>>()
350 .expect("Multiple messages registered with different names");
351
352 dispatcher.subscribe().await
353 }
354
355 None => {
356 return Err(Error::NetworkOperationFailed)
358 }
359 };
360
361 Ok(sub)
362 }
363
364 pub async fn notify(
367 &self,
368 command: &str,
369 reader: &mut smol::io::ReadHalf<Box<dyn PtStream + 'static>>,
370 ) -> Result<()> {
371 let mut found = false;
374 let mut total_score = 0;
375 for (name, dispatcher) in self.dispatchers.lock().await.iter() {
376 if name == &command {
378 dispatcher.trigger(reader).await?;
379 found = true;
380 }
381
382 total_score += dispatcher.metering_score().await;
384 }
385
386 if !found {
388 return Err(Error::MissingDispatcher)
389 }
390
391 if total_score > *self.metering_limit.lock().await {
393 return Err(Error::MeteringLimitExceeded)
394 }
395
396 Ok(())
397 }
398
399 pub async fn trigger_error(&self, err: Error) {
401 let mut futures = FuturesUnordered::new();
402
403 let dispatchers = self.dispatchers.lock().await;
404
405 for dispatcher in dispatchers.values() {
406 let dispatcher = dispatcher.clone();
407 let error = err.clone();
408 futures.push(async move { dispatcher.trigger_error(error).await });
409 }
410
411 drop(dispatchers);
412
413 while let Some(_r) = futures.next().await {}
414 }
415}