PaperTrader/src/libtrader/common/message.rs

100 lines
2.8 KiB
Rust

use log::warn;
use serde::{Deserialize, Serialize};
pub use crate::common::command::*;
use tokio::io::AsyncReadExt;
use tokio::io::AsyncWriteExt;
use tokio::net::TcpStream;
#[cfg(all(feature = "client", not(feature = "server")))]
use tokio_rustls::client::TlsStream;
#[cfg(all(feature = "server", not(feature = "client")))]
use tokio_rustls::server::TlsStream;
/// A unit of communication exchanged over the TLS connection: a [`Command`]
/// plus an opaque payload that is bincode-encoded by [`Message::data`] and
/// decoded by [`Message::get_data`].
///
/// The whole struct is itself bincode-serialized when it goes on the wire.
#[derive(Debug, Default, PartialEq, Serialize, Deserialize)]
pub struct Message {
    /// The operation requested (or the outcome being reported).
    pub command: Command,
    /// Bincode-encoded payload; empty when the command carries no data.
    pub data: Vec<u8>,
}
impl Message {
pub fn new() -> Message {
Message {
command: Command::Default,
data: Vec::new(),
}
}
pub fn command<'a>(&'a mut self, command: Command) -> &'a mut Message {
self.command = command;
self
}
pub fn data<'a, T>(&'a mut self, data: T) -> &'a mut Message
where
T: serde::Serialize,
{
self.data = bincode::serialize(&data).unwrap();
self
}
pub async fn send(&self, socket: &mut TlsStream<TcpStream>) -> std::io::Result<()> {
/* automagically log failed commands */
if self.assert_command(Command::Failure) {
warn!("Operation failed, sending error to client.");
warn!("Error: {}", self.get_err());
}
socket
.write_all(bincode::serialize(self).unwrap().as_slice())
.await
}
pub async fn receive<'a>(socket: &mut TlsStream<TcpStream>) -> std::io::Result<Message> {
let mut buf = Vec::with_capacity(4096);
socket.read_buf(&mut buf).await?;
Ok(bincode::deserialize(&buf).map_err(|_| {
std::io::Error::new(
std::io::ErrorKind::InvalidInput,
format!("Failed parsing recieved message."),
)
})?)
}
/// Use this function for more easier bincode conversion of received data.
pub fn get_data<'a, T>(&'a self) -> Result<T, String>
where
T: serde::Deserialize<'a>,
{
bincode::deserialize(&self.data).map_err(|_| format!("Failed getting message data."))
}
pub fn get_err(&self) -> &str {
// Tests should handle this function being called not in the
// correct context.
// i.e. tests should ensure that functions error out gracefully.
bincode::deserialize(&self.data).unwrap()
}
pub fn assert_command(&self, command: Command) -> bool {
if self.command != command {
false
} else {
true
}
}
pub fn assert_data(&self) -> bool {
if self.data.len() == 0 {
false
} else {
true
}
}
}
impl std::fmt::Display for Message {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "({}, {:#?})", self.command, self.data)
}
}