big rewrite of basic components
continuous-integration/drone/push Build is failing Details

- rewrite of Message structure
- rewrite of method of handling Messages
- rewirte of handle_data
- rewrite of JWT file structures
- build but does not actually work
- next step is to conver to a more generic password (also email) hashing
  library, the kind of password_hash
This commit is contained in:
ayham 2021-08-13 09:00:23 +03:00
parent 915483c465
commit 4f420925cf
Signed by: ayham
GPG Key ID: EAB7F5A9DF503678
54 changed files with 852 additions and 1872 deletions

View File

@ -24,7 +24,7 @@ client = []
tls_no_verify = ["tokio-rustls/dangerous_configuration"]
[dependencies]
argh = "*"
argh = "0.1.5"
chrono = "0.4"
tokio = { version = "1.6.1", features = [ "full" ] }
tokio-io = { version = "0.1.13" }
@ -32,23 +32,17 @@ tokio-rustls = { version = "0.22.0" }
tokio-util = { version = "0.6.7" }
tokio-postgres = { version = "0.7.2" }
webpki-roots = { version = "0.21" }
futures = "*"
bytes = "*"
futures = "0.3.16"
#postgres = { version = "0.4.0" }
postgres-types = { version = "0.2.1", features = ["derive"] }
log = "0.4"
fern = { version = "0.6.0", features = ["colored"] }
enum_primitive = "*"
os_type="2.2"
ring="*"
data-encoding="*"
bincode="*"
ring="0.16.20"
bincode="1.3.3"
serde = { version = "1.0", features = ["derive"] }
ct-logs="0.7"
either="*"
arrayref="*"
rust-crypto="0.2.36"
jsonwebtoken="*"
json="*"
bitflags="*"
rand="*"
jsonwebtoken="7.2.0"
json="0.12.4"
rand="0.8.4"

View File

@ -1,18 +1,11 @@
use data_encoding::HEXUPPER;
use ring::digest;
use std::io;
use crate::common::account::hash::hash;
use crate::common::message::assert_msg::assert_msg;
use crate::common::message::inst::CommandInst;
use crate::common::message::message::Message;
use crate::common::message::message_builder::message_builder;
use crate::common::message::message_type::MessageType;
use crate::common::misc::return_flags::ReturnFlags;
use crate::common::command::*;
use crate::common::message::*;
use crate::client::network::cmd::req_server_salt::req_server_salt;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio_rustls::client::TlsStream;
@ -44,86 +37,61 @@ pub async fn acc_auth(
username: &str,
email: &str,
password: &str,
) -> io::Result<String> {
) -> std::io::Result<String> {
/*
* get email salt
* */
let email_salt: [u8; digest::SHA512_OUTPUT_LEN] =
req_server_salt(socket, username, CommandInst::GetEmailSalt as i64).await?;
let email_salt = req_server_salt(socket, username, Command::GetEmailSalt).await?;
/*
* get password salt
* */
let password_salt: [u8; digest::SHA512_OUTPUT_LEN] =
req_server_salt(socket, username, CommandInst::GetPasswordSalt as i64).await?;
let password_salt = req_server_salt(socket, username, Command::GetPasswordSalt).await?;
/*
* hash the email
*/
let hashed_email = hash(&email.as_bytes().to_vec(), &email_salt.to_vec(), 175_000);
let hashed_email = hash(email, &email_salt, 175_000);
/*
* hash the password
*/
let hashed_password = hash(
&password.as_bytes().to_vec(),
&password_salt.to_vec(),
250_000,
);
let hashed_password = hash(password, &password_salt, 250_000);
/* generate message to be sent to the server */
let data = object! {
hashed_email: HEXUPPER.encode(&hashed_email),
hashed_password: HEXUPPER.encode(&hashed_password),
hashed_email: hashed_email,
hashed_password: hashed_password,
username: username
};
let message = message_builder(
MessageType::Command,
CommandInst::LoginMethod1 as i64,
3,
0,
0,
data.dump().as_bytes().to_vec(),
);
socket
.write_all(&bincode::serialize(&message).unwrap())
/* build message request */
Message::new()
.command(Command::LoginMethod1)
.data(data.to_string())
.send(socket)
.await?;
/* decode response */
let mut buf = Vec::with_capacity(4096);
socket.read_buf(&mut buf).await?;
let response: Message = bincode::deserialize(&buf).map_err(|_| {
io::Error::new(
io::ErrorKind::InvalidInput,
format!("{}", ReturnFlags::ClientAccUnauthorized),
)
})?;
let ret_msg = Message::receive(socket).await?;
if assert_msg(
&response,
MessageType::ServerReturn,
true,
1,
false,
0,
false,
0,
false,
0,
) && response.data.len() != 0
&& response.instruction == 1
{
/* decode response */
if !ret_msg.assert_command(Command::Success) || !ret_msg.assert_data() {
/* authorized */
return Ok(String::from_utf8(response.data).map_err(|_| {
Ok(ret_msg.get_data().map_err(|_| {
io::Error::new(
io::ErrorKind::InvalidData,
format!("{}", ReturnFlags::ClientAccInvalidSessionId),
)
})?);
format!(
"Failed authorizing account, {}, server returned invalid data.",
username))
})?)
} else {
return Err(io::Error::new(
Err(io::Error::new(
io::ErrorKind::ConnectionRefused,
format!("{}", ReturnFlags::ClientAccUnauthorized),
));
format!(
"Failed authorizing account, {}, server returned error: {}.",
username,
ret_msg.get_err()
),
))
}
}

View File

@ -1,20 +1,13 @@
use data_encoding::HEXUPPER;
use ring::digest;
use std::io;
use crate::common::command::*;
use crate::common::message::*;
use crate::client::account::hash_email::hash_email;
use crate::client::account::hash_pwd::hash_pwd;
use crate::common::message::assert_msg::assert_msg;
use crate::common::message::inst::CommandInst;
use crate::common::message::message::Message;
use crate::common::message::message_builder::message_builder;
use crate::common::message::message_type::MessageType;
use crate::common::misc::return_flags::ReturnFlags;
use crate::client::network::cmd::get_server_salt::get_server_salt;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio_rustls::client::TlsStream;
@ -46,68 +39,50 @@ pub async fn acc_create(
username: &str,
email: &str,
password: &str,
) -> io::Result<()> {
) -> std::io::Result<()> {
/*
* get two server salts for email, and password
* */
let email_server_salt: [u8; digest::SHA512_OUTPUT_LEN / 2] = get_server_salt(socket).await?;
let password_server_salt: [u8; digest::SHA512_OUTPUT_LEN / 2] = get_server_salt(socket).await?;
let email_server_salt = get_server_salt(socket).await?;
let password_server_salt = get_server_salt(socket).await?;
/*
* generate hashes for email, password
* */
let email_hash = hash_email(&email.as_bytes().to_vec(), email_server_salt);
let password_hash = hash_pwd(&password.as_bytes().to_vec(), password_server_salt);
let email_hash = hash_email(email, &email_server_salt);
let password_hash = hash_pwd(password, &password_server_salt);
/* generate message to be sent to the server */
let data = object! {
email_hash: HEXUPPER.encode(&email_hash.0),
email_client_salt: HEXUPPER.encode(&email_hash.1),
password_hash: HEXUPPER.encode(&password_hash.0),
password_client_salt: HEXUPPER.encode(&password_hash.1),
email_hash: email_hash.0,
email_client_salt: email_hash.1,
password_hash: password_hash.0,
password_client_salt: password_hash.1,
username: username
};
let message = message_builder(
MessageType::Command,
CommandInst::Register as i64,
5,
0,
0,
data.dump().as_bytes().to_vec(),
);
socket
.write_all(&bincode::serialize(&message).unwrap())
/* build message request */
Message::new()
.command(Command::Register)
.data(data.to_string())
.send(socket)
.await?;
/* decode response */
let mut buf = Vec::with_capacity(4096);
socket.read_buf(&mut buf).await?;
let response: Message = bincode::deserialize(&buf).map_err(|_| {
io::Error::new(
io::ErrorKind::InvalidInput,
format!("{}", ReturnFlags::ClientTlsReadError),
)
})?;
if !assert_msg(
&response,
MessageType::ServerReturn,
true,
1,
false,
0,
false,
0,
false,
0,
) && response.instruction == 1
{
let ret_msg = Message::receive(socket).await?;
/* assert received message */
if !ret_msg.assert_command(Command::Success) || !ret_msg.assert_data() {
/* created successfully */
return Ok(());
} else {
/* server rejected account creation */
return Err(io::Error::new(
io::ErrorKind::ConnectionRefused,
format!("{}", ReturnFlags::ClientAccCreationFailed),
format!(
"Failed creating account for user, {}, server reason: {}",
username,
ret_msg.get_err()
),
));
}
}

View File

@ -1,7 +1,5 @@
use ring::rand::SecureRandom;
use ring::{digest, rand};
use crate::common::account::hash::hash;
use crate::common::account::hash::*;
use crate::common::account::salt::*;
/// Generates a client email hash from a raw email.
///
@ -24,61 +22,10 @@ use crate::common::account::hash::hash;
/// println!("Client Email Salt: {}", HEXUPPER.encode(&enc.1));
/// ```
pub fn hash_email(
email: &Vec<u8>,
server_salt: [u8; digest::SHA512_OUTPUT_LEN / 2],
) -> (
[u8; digest::SHA512_OUTPUT_LEN],
[u8; digest::SHA512_OUTPUT_LEN],
) {
// client hash, full salt
let rng = rand::SystemRandom::new();
let mut client_salt = [0u8; digest::SHA512_OUTPUT_LEN / 2];
rng.fill(&mut client_salt).unwrap();
let salt = [server_salt, client_salt].concat();
let hash = hash(email, &salt, 175_000);
(hash, *array_ref!(salt, 0, digest::SHA512_OUTPUT_LEN))
}
#[cfg(test)]
mod test {
use super::*;
use data_encoding::HEXUPPER;
#[test]
fn test_account_hash_email_client() {
let email = "totallyrealemail@anemail.c0m";
/* generate server salt */
let rng = rand::SystemRandom::new();
let mut server_salt = [0u8; digest::SHA512_OUTPUT_LEN / 2];
rng.fill(&mut server_salt).unwrap();
/* ensure that hash_email_client() works */
let output = hash_email(&email.as_bytes().to_vec(), server_salt);
assert_ne!(output.0.len(), 0);
assert_ne!(output.1.len(), 0);
/* ensure that hash_email_client() doesn't generate same output
* with the same server salt.
* */
let mut enc0 = hash_email(&email.as_bytes().to_vec(), server_salt);
let mut enc1 = hash_email(&email.as_bytes().to_vec(), server_salt);
assert_ne!(HEXUPPER.encode(&enc0.0), HEXUPPER.encode(&enc1.0));
assert_ne!(HEXUPPER.encode(&enc0.1), HEXUPPER.encode(&enc1.1));
/* ensure that hash_email_client() generates a different output
* with different server salts
* */
// Generate new server salt.
let mut server_salt2 = [0u8; digest::SHA512_OUTPUT_LEN / 2];
rng.fill(&mut server_salt2).unwrap();
enc0 = hash_email(&email.as_bytes().to_vec(), server_salt);
enc1 = hash_email(&email.as_bytes().to_vec(), server_salt);
assert_ne!(HEXUPPER.encode(&enc0.0), HEXUPPER.encode(&enc1.0));
assert_ne!(HEXUPPER.encode(&enc0.1), HEXUPPER.encode(&enc1.1));
}
email: &str,
server_salt: &str) -> (String, String) {
let mut salt = gen_salt();
salt.push_str(server_salt);
let hash = hash(&email, &salt, 175_000);
(hash, salt.to_string())
}

View File

@ -1,7 +1,5 @@
use ring::rand::SecureRandom;
use ring::{digest, rand};
use crate::common::account::hash::hash;
use crate::common::account::hash::*;
use crate::common::account::salt::*;
/// Generates a client password hash from a raw password.
///
@ -22,63 +20,9 @@ use crate::common::account::hash::hash;
/// println!("Client Pass Hash: {}", HEXUPPER.encode(&enc.0));
/// println!("Client Pass Salt: {}", HEXUPPER.encode(&enc.1));
/// ```
pub fn hash_pwd(
pass: &Vec<u8>,
server_salt: [u8; digest::SHA512_OUTPUT_LEN / 2],
) -> (
[u8; digest::SHA512_OUTPUT_LEN],
[u8; digest::SHA512_OUTPUT_LEN],
) {
// client hash, full salt
let rng = rand::SystemRandom::new();
let mut client_salt = [0u8; digest::SHA512_OUTPUT_LEN / 2];
rng.fill(&mut client_salt).unwrap();
let salt = [server_salt, client_salt].concat();
let hash = hash(pass, &salt, 250_000);
(hash, *array_ref!(salt, 0, digest::SHA512_OUTPUT_LEN))
}
#[cfg(test)]
mod test {
use super::*;
use data_encoding::HEXUPPER;
#[test]
fn test_account_hash_pwd_client() {
let pass = "goodlilpassword";
/* generate server salt */
let rng = rand::SystemRandom::new();
let mut server_salt = [0u8; digest::SHA512_OUTPUT_LEN / 2];
rng.fill(&mut server_salt).unwrap();
/* ensure that hash_pwd_client() works */
let output = hash_pwd(&pass.as_bytes().to_vec(), server_salt);
assert_ne!(output.0.len(), 0);
assert_ne!(output.1.len(), 0);
/* ensure that hash_pwd_client() doesn't generate same output
* with the same server salt.
* */
let mut enc0 = hash_pwd(&pass.as_bytes().to_vec(), server_salt);
let mut enc1 = hash_pwd(&pass.as_bytes().to_vec(), server_salt);
assert_ne!(HEXUPPER.encode(&enc0.0), HEXUPPER.encode(&enc1.0));
assert_ne!(HEXUPPER.encode(&enc0.1), HEXUPPER.encode(&enc1.1));
/* ensure that hash_pwd_client() generates different output
* with different server salts.
* */
// Generate new server salt.
let mut server_salt2 = [0u8; digest::SHA512_OUTPUT_LEN / 2];
rng.fill(&mut server_salt2).unwrap();
enc0 = hash_pwd(&pass.as_bytes().to_vec(), server_salt);
enc1 = hash_pwd(&pass.as_bytes().to_vec(), server_salt);
assert_ne!(HEXUPPER.encode(&enc0.0), HEXUPPER.encode(&enc1.0));
assert_ne!(HEXUPPER.encode(&enc0.1), HEXUPPER.encode(&enc1.1));
}
pub fn hash_pwd(pass: &str, server_salt: &str) -> (String, String) {
let mut salt = gen_salt();
salt.push_str(server_salt);
let hash = hash(&pass, &salt, 250_000);
(hash, salt)
}

View File

@ -1,14 +1,9 @@
use std::io;
use crate::common::account::portfolio::Portfolio;
use crate::common::message::assert_msg::assert_msg;
use crate::common::message::inst::DataTransferInst;
use crate::common::message::message::Message;
use crate::common::message::message_builder::message_builder;
use crate::common::message::message_type::MessageType;
use crate::common::misc::return_flags::ReturnFlags;
use crate::common::command::*;
use crate::common::message::*;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio_rustls::client::TlsStream;
@ -32,7 +27,7 @@ use tokio_rustls::client::TlsStream;
pub async fn acc_retrieve_portfolio(
socket: &mut TlsStream<TcpStream>,
auth_jwt: String,
) -> io::Result<Portfolio> {
) -> std::io::Result<Portfolio> {
if auth_jwt.is_empty() == true {
return Err(io::Error::new(
io::ErrorKind::PermissionDenied,
@ -41,51 +36,30 @@ pub async fn acc_retrieve_portfolio(
}
/* build message request */
let message = message_builder(
MessageType::Command,
DataTransferInst::GetUserPortfolio as i64,
1,
0,
0,
bincode::serialize(&auth_jwt).unwrap(),
);
socket
.write_all(&bincode::serialize(&message).unwrap())
Message::new()
.command(Command::GetUserPortfolio)
.data(auth_jwt)
.send(socket)
.await?;
/* decode response */
let mut buf = Vec::with_capacity(4096);
socket.read_buf(&mut buf).await?;
let ret_msg: Message = Message::receive(socket).await?;
let response: Message = bincode::deserialize(&buf).map_err(|_| {
io::Error::new(
io::ErrorKind::InvalidInput,
format!("{}", ReturnFlags::ClientAccRetrievePortfolioError),
)
})?;
if assert_msg(
&response,
MessageType::DataTransfer,
true,
1,
false,
0,
false,
0,
false,
0,
) && response.instruction == 1
&& response.data.len() != 0
{
/* assert received message */
if !ret_msg.assert_command(Command::Success) || !ret_msg.assert_data() {
/* returned data */
let portfolio: Portfolio = bincode::deserialize(&response.data).unwrap();
return Ok(portfolio);
let portfolio: Portfolio = bincode::deserialize(&ret_msg.data).map_err(|_| {
io::Error::new(
io::ErrorKind::InvalidInput,
"Failed retrieving portfolio, received invalid data.",
)
})?;
Ok(portfolio)
} else {
/* could not get data */
return Err(io::Error::new(
Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("{}", ReturnFlags::ClientAccRetrievePortfolioError),
));
"Failed retrieving portfolio, received invalid message.",
))
}
}

View File

@ -1,15 +1,10 @@
use std::io;
use crate::common::command::*;
use crate::common::message::*;
use crate::common::account::transaction::Transaction;
use crate::common::message::assert_msg::assert_msg;
use crate::common::message::inst::DataTransferInst;
use crate::common::message::message::Message;
use crate::common::message::message_builder::message_builder;
use crate::common::message::message_type::MessageType;
use crate::common::misc::return_flags::ReturnFlags;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio_rustls::client::TlsStream;
@ -33,7 +28,7 @@ use tokio_rustls::client::TlsStream;
pub async fn acc_retrieve_transaction(
socket: &mut TlsStream<TcpStream>,
auth_jwt: String,
) -> io::Result<Vec<Transaction>> {
) -> std::io::Result<Vec<Transaction>> {
if auth_jwt.is_empty() == true {
return Err(io::Error::new(
io::ErrorKind::PermissionDenied,
@ -42,57 +37,30 @@ pub async fn acc_retrieve_transaction(
}
/* build message request */
let message = message_builder(
MessageType::DataTransfer,
DataTransferInst::GetUserTransactionHist as i64,
1,
0,
0,
bincode::serialize(&auth_jwt).unwrap(),
);
socket
.write_all(&bincode::serialize(&message).unwrap())
Message::new()
.command(Command::GetUserTransactionHist)
.data(auth_jwt)
.send(socket)
.await?;
/* decode response */
let mut buf = Vec::with_capacity(4096);
socket.read_buf(&mut buf).await?;
let ret_msg = Message::receive(socket).await?;
let response: Message = bincode::deserialize(&buf).map_err(|_| {
io::Error::new(
io::ErrorKind::InvalidInput,
format!("{}", ReturnFlags::ClientAccRetrieveTransactionError),
)
})?;
if assert_msg(
&response,
MessageType::ServerReturn,
true,
1,
false,
0,
false,
0,
false,
0,
) && response.data.len() != 0
&& response.instruction == 1
{
/* assert received message */
if !ret_msg.assert_command(Command::Success) || !ret_msg.assert_data() {
/* returned data*/
let transactions: Vec<Transaction> =
bincode::deserialize(&response.data).map_err(|_| {
io::Error::new(
io::ErrorKind::InvalidInput,
format!("{}", ReturnFlags::ClientAccRetrievePortfolioError),
)
})?;
let transactions: Vec<Transaction> = bincode::deserialize(&ret_msg.data).map_err(|_| {
io::Error::new(
io::ErrorKind::InvalidInput,
"Failed retrieving transaction, received invalid data.",
)
})?;
return Ok(transactions);
} else {
/* could not get data */
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("{}", ReturnFlags::ClientAccRetrieveTransactionError),
"Failed retrieving transaction, received invalid message.",
));
}
}

View File

@ -1,15 +1,8 @@
use ring::digest;
use std::io;
use crate::common::message::assert_msg::assert_msg;
use crate::common::message::inst::CommandInst;
use crate::common::message::message::Message;
use crate::common::message::message_builder::message_builder;
use crate::common::message::message_type::MessageType;
use crate::common::misc::return_flags::ReturnFlags;
use crate::common::command::*;
use crate::common::message::*;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio_rustls::client::TlsStream;
@ -30,49 +23,28 @@ use tokio_rustls::client::TlsStream;
/// ```
pub async fn get_server_salt(
socket: &mut TlsStream<TcpStream>,
) -> io::Result<[u8; digest::SHA512_OUTPUT_LEN / 2]> {
) -> std::io::Result<String> {
/*
* request to generate a salt from the server.
* */
let message = message_builder(
MessageType::Command,
CommandInst::GenHashSalt as i64,
0,
0,
0,
Vec::new(),
);
socket
.write_all(&bincode::serialize(&message).unwrap())
Message::new()
.command(Command::GenHashSalt)
.send(socket)
.await?;
let mut buf = Vec::with_capacity(4096);
socket.read_buf(&mut buf).await?;
let ret_msg = Message::receive(socket).await?;
let ret_msg: Message = bincode::deserialize(&buf).map_err(|_| {
io::Error::new(
io::ErrorKind::InvalidInput,
format!("{}", ReturnFlags::ClientGenSaltFailed),
)
})?;
if assert_msg(
&ret_msg,
MessageType::DataTransfer,
true,
1,
false,
0,
true,
1,
true,
digest::SHA512_OUTPUT_LEN / 2,
) {
Ok(*array_ref!(ret_msg.data, 0, digest::SHA512_OUTPUT_LEN / 2))
} else {
/* assert received message */
if !ret_msg.assert_command(Command::Success) || !ret_msg.assert_data() {
Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("{}", ReturnFlags::ClientReqSaltInvMsg),
"Failed getting generated server Salt, received an invalid message.",
))
} else {
ret_msg.get_data().map_err(|_| {
io::Error::new(
io::ErrorKind::InvalidData,
"Failed getting generated server Salt, received an invalid message.",
)})
}
}

View File

@ -1,13 +1,8 @@
use ring::digest;
use std::io;
use crate::common::message::inst::CommandInst;
use crate::common::message::message::Message;
use crate::common::message::message_builder::message_builder;
use crate::common::message::message_type::MessageType;
use crate::common::misc::return_flags::ReturnFlags;
use crate::common::command::*;
use crate::common::message::*;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio_rustls::client::TlsStream;
@ -32,64 +27,33 @@ use tokio_rustls::client::TlsStream;
pub async fn req_server_salt(
socket: &mut TlsStream<TcpStream>,
username: &str,
salt_type: i64,
) -> io::Result<[u8; digest::SHA512_OUTPUT_LEN]> {
salt_type: Command,
) -> std::io::Result<String> {
/* enforce salt_type to be either email or password */
assert_eq!(salt_type >= CommandInst::GetEmailSalt as i64, true);
assert_eq!(salt_type <= CommandInst::GetPasswordSalt as i64, true);
assert_eq!(
(salt_type == Command::GetEmailSalt) || (salt_type == Command::GetPasswordSalt),
true
);
/* generate message to send */
let message = message_builder(
MessageType::Command,
salt_type,
1,
0,
0,
username.as_bytes().to_vec(),
);
socket
.write_all(&bincode::serialize(&message).unwrap())
Message::new()
.command(salt_type)
.data(username)
.send(socket)
.await?;
let mut buf = Vec::with_capacity(4096);
socket.read_buf(&mut buf).await?;
let ret_msg = Message::receive(socket).await?;
let ret_msg: Message = bincode::deserialize(&buf).map_err(|_| {
io::Error::new(
io::ErrorKind::InvalidInput,
format!("{}", ReturnFlags::ClientReqSaltInvMsg),
)
})?;
match ret_msg.msgtype {
MessageType::Command => Err(io::Error::new(
/* assert received message */
if !ret_msg.assert_command(Command::Success) || !ret_msg.assert_data() {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("{}", ReturnFlags::ClientReqSaltInvMsg),
)),
MessageType::DataTransfer => {
if ret_msg.data.len() != digest::SHA512_OUTPUT_LEN {
Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("{}", ReturnFlags::ClientReqSaltInvMsgRetSize),
))
} else if ret_msg.instruction == salt_type {
Ok(*array_ref!(ret_msg.data, 0, digest::SHA512_OUTPUT_LEN))
} else {
Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("{}", ReturnFlags::ClientReqSaltInvMsgInst),
))
}
}
MessageType::ServerReturn => match ret_msg.instruction {
0 => Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("{}", ReturnFlags::ClientReqSaltRej),
)),
_ => Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("{}", ReturnFlags::ClientReqSaltInvMsg),
)),
},
"Recieved invalid Salt from server.",
));
}
Ok(ret_msg.get_data().map_err(|_| {
io::Error::new(io::ErrorKind::InvalidData,
format!("Could not get server salt, received invalid data."))
})?)
}

View File

@ -16,16 +16,16 @@ use std::num::NonZeroU32;
/// ```rust
/// let email_hash = hash("test@test.com", [0u8; 64], 124000);
/// ```
pub fn hash(val: &Vec<u8>, salt: &Vec<u8>, iter: u32) -> [u8; digest::SHA512_OUTPUT_LEN] {
pub fn hash(val: &str, salt: &str, iter: u32) -> String {
let iterations: NonZeroU32 = NonZeroU32::new(iter).unwrap();
let mut hash = [0u8; digest::SHA512_OUTPUT_LEN];
pbkdf2::derive(
pbkdf2::PBKDF2_HMAC_SHA512,
iterations,
&salt,
val,
salt.as_bytes(),
val.as_bytes(),
&mut hash,
);
hash
String::from_utf8(hash.into()).unwrap()
}

View File

@ -2,5 +2,5 @@ pub mod hash;
pub mod order;
pub mod portfolio;
pub mod position;
pub mod session;
pub mod salt;
pub mod transaction;

View File

@ -1,6 +1,6 @@
use serde::{Deserialize, Serialize};
use crate::common::account::position::Position;
pub use crate::common::account::position::Position;
#[derive(Serialize, Deserialize, PartialEq, Debug, Default)]
pub struct Portfolio {

View File

@ -0,0 +1,14 @@
use rand::distributions::Alphanumeric;
use rand::{thread_rng, Rng};
use ring::digest;
/// generates a random salt
/// always assumes salt is sha512 output length
/// for "security" ofcourse
pub fn gen_salt() -> String {
thread_rng()
.sample_iter(&Alphanumeric)
.take(digest::SHA512_OUTPUT_LEN)
.map(char::from)
.collect()
}

View File

@ -1,19 +0,0 @@
use chrono::{DateTime, Utc};
use std::net::Ipv4Addr;
#[derive(PartialEq, Debug)]
pub struct SessionID {
pub sess_id: String,
pub client_ip: Ipv4Addr,
pub expiry_date: DateTime<Utc>,
pub is_active: bool,
}
impl std::fmt::Display for SessionID {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"({}, {}, {}, {})",
self.sess_id, self.client_ip, self.expiry_date, self.is_active
)
}
}

View File

@ -0,0 +1,33 @@
use serde::{Deserialize, Serialize};
#[derive(Serialize, Deserialize, PartialEq, Debug, Clone, Copy)]
pub enum Command {
Default = 0,
Success = 1,
Failure = 2,
LoginMethod1 = 3,
LoginMethod2,
Register,
PurchaseAsset,
SellAsset,
GenHashSalt,
GetEmailSalt,
GetPasswordSalt,
GetAssetInfo,
GetAssetValue,
GetAssetValueCurrent,
GetUserInfo,
GetUserPortfolio,
GetUserTransactionHist,
}
impl std::fmt::Display for Command {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{:#?}", self)
}
}
impl Default for Command {
fn default() -> Self {
Command::Success
}
}

View File

@ -1,31 +0,0 @@
use serde::{Deserialize, Serialize};
#[derive(Serialize, Deserialize, Default, Eq, PartialEq, Clone, Debug)]
pub struct Company {
pub id: i64,
pub symbol: String,
pub isin: String,
pub company_name: String,
pub primary_exchange: String,
pub sector: String,
pub industry: String,
pub primary_sic_code: String,
pub employees: i64,
}
impl std::fmt::Display for Company {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"({}, {}, {}, {}, {}, {}, {}, {}, {})",
self.id,
self.symbol,
self.isin,
self.company_name,
self.primary_exchange,
self.sector,
self.industry,
self.primary_sic_code,
self.employees
)
}
}

View File

@ -1,2 +0,0 @@
pub mod company;
pub mod stock_val;

View File

@ -1,21 +0,0 @@
use postgres_types::{FromSql, ToSql};
use serde::{Deserialize, Serialize};
#[derive(Default, PartialEq, Debug, ToSql, FromSql, Serialize, Deserialize)]
pub struct StockVal {
pub id: i64,
pub isin: String,
pub time_epoch: i64,
pub ask_price: f64,
pub bid_price: f64,
pub volume: i64,
}
impl std::fmt::Display for StockVal {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"({}, {}, {}, {}, {}, {})",
self.id, self.isin, self.time_epoch, self.ask_price, self.bid_price, self.volume
)
}
}

View File

@ -6,7 +6,7 @@ use serde::{Deserialize, Serialize};
/// user_id - The user id that is authorized.
/// exp - The unix epoch at which this claim expires.
#[derive(Debug, Default, Serialize, Deserialize)]
pub struct JWTClaim {
pub struct Claim {
pub user_id: i64,
pub exp: u64,
}

View File

@ -0,0 +1,99 @@
use log::warn;
use serde::{Deserialize, Serialize};
pub use crate::common::command::Command;
use tokio::io::AsyncReadExt;
use tokio::io::AsyncWriteExt;
use tokio::net::TcpStream;
#[cfg(all(feature = "client", not(feature = "server")))]
use tokio_rustls::client::TlsStream;
#[cfg(all(feature = "server", not(feature = "client")))]
use tokio_rustls::server::TlsStream;
#[derive(Serialize, Deserialize, PartialEq, Debug, Default)]
pub struct Message {
pub command: Command,
pub data: Vec<u8>,
}
impl Message {
pub fn new() -> Message {
Message {
command: Command::Default,
data: Vec::new(),
}
}
pub fn command<'a>(&'a mut self, command: Command) -> &'a mut Message {
self.command = command;
self
}
pub fn data<'a, T>(&'a mut self, data: T) -> &'a mut Message
where
T: serde::Serialize,
{
self.data = bincode::serialize(&data).unwrap();
self
}
pub async fn send(&self, socket: &mut TlsStream<TcpStream>) -> std::io::Result<()> {
/* automagically log failed commands */
if self.assert_command(Command::Failure) {
warn!("Operation failed, sending error to client.");
warn!("Error: {}", self.get_err());
}
socket
.write_all(bincode::serialize(self).unwrap().as_slice())
.await
}
pub async fn receive<'a>(socket: &mut TlsStream<TcpStream>) -> std::io::Result<Message> {
let mut buf = Vec::with_capacity(4096);
socket.read_buf(&mut buf).await?;
Ok(bincode::deserialize(&buf).map_err(|_| {
std::io::Error::new(
std::io::ErrorKind::InvalidInput,
format!("Failed parsing recieved message."),
)
})?)
}
/// Use this function for more easier bincode conversion of received data.
pub fn get_data<'a, T>(&'a self) -> Result<T, String>
where
T: serde::Deserialize<'a>,
{
bincode::deserialize(&self.data).map_err(|_| format!("Failed getting message data."))
}
pub fn get_err(&self) -> &str {
// Tests should handle this function being called not in the
// correct context.
// i.e. tests should ensure that functions error out gracefully.
bincode::deserialize(&self.data).unwrap()
}
pub fn assert_command(&self, command: Command) -> bool {
if self.command != command {
false
} else {
true
}
}
pub fn assert_data(&self) -> bool {
if self.data.len() == 0 {
false
} else {
true
}
}
}
impl std::fmt::Display for Message {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "({}, {:#?})", self.command, self.data)
}
}

View File

@ -1,49 +0,0 @@
use crate::common::message::message::Message;
use crate::common::message::message_type::MessageType;
/// Asserts a recieved message meta information.
///
/// Takes in a message and meta information to check the message against.
/// Can be used to check some attributes.
/// For example, you can check for message to have an X amount of arguments, and
/// not check how many arguments are passed.
///
/// Arguments:
/// message - The mesage to assert against.
/// msg_type - MessageType expected.
/// check_arg_cnt - Whether to check for the argument account.
/// arg_cnt - The argument count expected.
/// check_dnum - Whether to check for the number data message.
/// msg_dnum - The number of the data message.
/// check_dmax - Whether to check for the max data message.
/// msg_dmax - The number of the max data message.
/// check_len - Whether to check for the data payload length.
/// data_len - The length of the data payload.
///
/// Returns: a boolean.
pub fn assert_msg(
message: &Message,
msg_type: MessageType,
check_arg_cnt: bool,
arg_cnt: usize,
check_dnum: bool,
msg_dnum: usize,
check_dmax: bool,
msg_dmax: usize,
check_len: bool,
data_len: usize,
) -> bool {
if message.msgtype != msg_type {
return false;
} else if check_arg_cnt && (message.argument_count != arg_cnt) {
return false;
} else if check_dnum && (message.data_message_number != msg_dnum) {
return false;
} else if check_dmax && (message.data_message_max != msg_dmax) {
return false;
} else if check_len && (message.data.len() != data_len) {
return false;
}
return true;
}

View File

@ -1,40 +0,0 @@
#[allow(dead_code)]
static INST_SWITCH_STATE: isize = 0;
#[derive(PartialEq, Debug)]
pub enum CommandInst {
LoginMethod1 = 1,
LoginMethod2 = 2,
Register = 3,
PurchaseAsset = 4,
SellAsset = 5,
GenHashSalt = 6,
GetEmailSalt = 7,
GetPasswordSalt = 8,
}
impl std::fmt::Display for CommandInst {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{:#?}", self)
}
}
#[allow(dead_code)]
static INST_COMMAND_MAX_ID: isize = CommandInst::GetPasswordSalt as isize;
#[derive(PartialEq, Debug)]
pub enum DataTransferInst {
GetAssetInfo = 6,
GetAssetValue = 7,
GetAssetValueCurrent = 8,
GetUserInfo = 9,
GetUserPortfolio = 10,
GetUserTransactionHist = 11,
}
impl std::fmt::Display for DataTransferInst {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{:#?}", self)
}
}
#[allow(dead_code)]
static INST_DATA_MAX_ID: isize = DataTransferInst::GetUserTransactionHist as isize;

View File

@ -1,27 +0,0 @@
use serde::{Deserialize, Serialize};
use crate::common::message::message_type::MessageType;
#[derive(Serialize, Deserialize, PartialEq, Debug, Default)]
pub struct Message {
pub msgtype: MessageType,
pub instruction: i64,
pub argument_count: usize,
pub data_message_number: usize,
pub data_message_max: usize,
pub data: Vec<u8>,
}
impl std::fmt::Display for Message {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"({}, {}, {}, {}, {}, {:#?})",
self.msgtype,
self.instruction,
self.argument_count,
self.data_message_number,
self.data_message_max,
self.data
)
}
}

View File

@ -1,20 +0,0 @@
use crate::common::message::message::Message;
use crate::common::message::message_type::MessageType;
pub fn message_builder(
msg_type: MessageType,
inst: i64,
arg_cnt: usize,
data_msg_num: usize,
data_msg_max: usize,
data: Vec<u8>,
) -> Message {
let mut message: Message = Message::default();
message.msgtype = msg_type;
message.instruction = inst;
message.argument_count = arg_cnt;
message.data_message_number = data_msg_num;
message.data_message_max = data_msg_max;
message.data = data;
message
}

View File

@ -1,19 +0,0 @@
use serde::{Deserialize, Serialize};
#[derive(Serialize, Deserialize, PartialEq, Debug)]
pub enum MessageType {
Command = 0,
DataTransfer = 1,
ServerReturn = 2,
}
impl std::fmt::Display for MessageType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{:#?}", self)
}
}
impl Default for MessageType {
fn default() -> Self {
MessageType::Command
}
}

View File

@ -1,5 +0,0 @@
pub mod assert_msg;
pub mod inst;
pub mod message;
pub mod message_builder;
pub mod message_type;

View File

@ -1,2 +0,0 @@
pub mod return_flags;
pub mod servers_pool;

View File

@ -1,75 +0,0 @@
use serde::{Deserialize, Serialize};
#[derive(PartialEq, Debug, Serialize, Deserialize)]
pub enum ReturnFlags {
LibtraderInitClientConnect = 1,
LibtraderInitLogFailed = 2,
LibtraderInitFailed = 3,
CommonGenLogDirCreationFailed = 4,
CommonTlsBadConfig = 5,
CommonGetCompanyFailed = 6,
CommonGetStockFailed = 7,
ServerDbConnectFailed = 8,
ServerDbWriteFailed = 9,
ServerDbUserHashNotFound = 10,
ServerDbUserSaltNotFound = 11,
ServerDbCreateTransactionFailed = 12,
ServerDbCreatePositionFailed = 13,
ServerDbCreateStockFailed = 14,
ServerDbCreateCompanyFailed = 15,
ServerDbSearchStockNotFound = 16,
ServerDbSearchCompanyNotFound = 17,
ServerRegisterInvMsg = 18,
ServerLoginInvMsg = 19,
ServerPurchaseAssetInvMsg = 20,
ServerAccUnauthorized = 21,
ServerAccUserExists = 22,
ServerGetAssetDataInvMsg = 23,
ServerGetAssetInfoInvMsg = 24,
ServerGetUserIdNotFound = 25,
ServerRetrieveTransactionFailed = 26,
ServerRetrieveTransactionInvMsg = 27,
ServerRetrievePortfolioFailed = 28,
ServerRetrievePortfolioInvMsg = 29,
ServerCreateJwtTokenFailed = 30,
ServerTlsConnWriteFailed = 31,
ServerTlsConnProcessFailed = 32,
ServerTlsConnReadPlainFailed = 33,
ServerTlsServerAcceptFailed = 34,
ServerHandleDataRcvdInvMsg = 35,
ClientAccRetrievePortfolioError = 36,
ClientAccRetrieveTransactionError = 37,
ClientAccCreationFailed = 38,
ClientAccInvalidSessionId = 39,
ClientAccUnauthorized = 40,
ClientReqSaltFailed = 41,
ClientReqSaltInvMsg = 42,
ClientReqSaltInvMsgRetSize = 43,
ClientReqSaltInvMsgInst = 44,
ClientReqSaltRej = 45,
ClientGenSaltFailed = 46,
ClientTlsReadError = 47,
ClientWaitAndReadBranched = 48,
}
impl std::fmt::Display for ReturnFlags {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{:#?}", self)
}
}

View File

@ -1,2 +0,0 @@
pub static SERVERS_IPS: [&str; 2] = ["", ""];
pub static SERVERS_CERTS: [&str; 2] = ["", ""];

View File

@ -1,5 +1,4 @@
pub mod account;
pub mod generic;
pub mod command;
pub mod jwt;
pub mod message;
pub mod misc;
pub mod sessions;

View File

@ -1 +0,0 @@
pub mod jwt_claim;

View File

@ -1,7 +1,5 @@
/* Server crates */
#[cfg(all(feature = "server", not(feature = "client")))]
extern crate arrayref;
#[cfg(all(feature = "server", not(feature = "client")))]
extern crate json;
#[cfg(all(feature = "server", not(feature = "client")))]
extern crate tokio;
@ -9,14 +7,8 @@ extern crate tokio;
/* Client crates */
#[cfg(all(feature = "client", not(feature = "server")))]
#[macro_use]
extern crate arrayref;
#[cfg(all(feature = "client", not(feature = "server")))]
#[macro_use]
extern crate json;
#[cfg(all(feature = "server", feature = "client"))]
#[macro_use]
extern crate arrayref;
#[cfg(all(feature = "server", feature = "client"))]
#[macro_use]
extern crate json;

View File

@ -1,14 +1,10 @@
use log::warn;
use std::num::NonZeroU32;
use data_encoding::HEXUPPER;
use ring::pbkdf2;
use crate::common::message::assert_msg::assert_msg;
use crate::common::message::inst::CommandInst;
use crate::common::message::message::Message;
use crate::common::message::message_builder::message_builder;
use crate::common::message::message_type::MessageType;
use crate::common::command::*;
use crate::common::message::*;
use crate::server::db::cmd::get_user_hash::get_user_hash;
use crate::server::db::cmd::get_user_id::get_user_id;
@ -22,26 +18,13 @@ use tokio_rustls::server::TlsStream;
pub async fn acc_auth(
sql_conn: &tokio_postgres::Client,
tls_connection: &mut TlsStream<TcpStream>,
socket: &mut TlsStream<TcpStream>,
message: &Message,
) -> std::io::Result<()> {
/* assert recieved message */
if !assert_msg(
message,
MessageType::Command,
true,
3,
false,
0,
false,
0,
false,
0,
) && message.instruction == CommandInst::LoginMethod1 as i64
&& message.data.len() != 0
{
if !message.assert_command(Command::LoginMethod1) || !message.assert_data() {
warn!("LOGIN_INVALID_MESSAGE");
return tls_connection.shutdown().await;
return socket.shutdown().await;
}
/*
@ -51,53 +34,72 @@ pub async fn acc_auth(
let stringified_data = std::str::from_utf8(&message.data).unwrap();
let data = json::parse(&stringified_data).unwrap();
/* get email, password, and username hashes */
let email_hash = HEXUPPER
.decode(data["hashed_email"].as_str().unwrap().as_bytes())
.unwrap();
let password_hash = HEXUPPER
.decode(data["hashed_password"].as_str().unwrap().as_bytes())
.unwrap();
let email_client_hash = data["hashed_email"].as_str().unwrap();
let password_client_hash = data["hashed_password"].as_str().unwrap();
let username = data["username"].as_str().unwrap();
/*
* Get server salts
* TODO: FIND A SMALLER CODE FOR THIS NONSENSE GARBAGE CODE
* */
let email_salt = HEXUPPER
.decode(
get_user_salt(sql_conn, username, true, true)
.await
.unwrap()
.as_bytes(),
)
.unwrap();
let password_salt = HEXUPPER
.decode(
get_user_salt(sql_conn, username, false, true)
.await
.unwrap()
.as_bytes(),
)
.unwrap();
let email_salt = match get_user_salt(sql_conn, username, Command::GetEmailSalt, true).await {
Ok(val) => val,
Err(_) => {
return Message::new()
.command(Command::Failure)
.data(format!(
"Failed authorizing user, {}, does not exist!",
username
))
.send(socket)
.await;
}
};
let password_salt =
match get_user_salt(sql_conn, username, Command::GetPasswordSalt, true).await {
Ok(val) => val,
Err(_) => {
return Message::new()
.command(Command::Failure)
.data(format!(
"Failed authorizing user, {}, does not exist!",
username
))
.send(socket)
.await;
}
};
/*
* Get server hashes
* */
let email_db = HEXUPPER
.decode(
get_user_hash(sql_conn, username, true)
.await
.unwrap()
.as_bytes(),
)
.unwrap();
let password_db = HEXUPPER
.decode(
get_user_hash(sql_conn, username, false)
.await
.unwrap()
.as_bytes(),
)
.unwrap();
let email_server_hash = match get_user_hash(sql_conn, username, true).await {
Ok(val) => val,
Err(_) => {
return Message::new()
.command(Command::Failure)
.data(format!(
"Failed authorizing user, {}, does not exist!",
username
))
.send(socket)
.await;
}
};
let password_server_hash = match get_user_hash(sql_conn, username, false).await {
Ok(val) => val,
Err(_) => {
return Message::new()
.command(Command::Failure)
.data(format!(
"Failed authorizing user, {}, does not exist!",
username
))
.send(socket)
.await;
}
};
/*
* Verify creds
@ -105,111 +107,64 @@ pub async fn acc_auth(
let email_ret = pbkdf2::verify(
pbkdf2::PBKDF2_HMAC_SHA512,
NonZeroU32::new(350_000).unwrap(),
&email_salt,
&email_hash,
&email_db,
email_salt.as_bytes(),
email_client_hash.as_bytes(),
email_server_hash.as_bytes(),
);
match email_ret.is_ok() {
true => {}
false => {
let server_response = message_builder(
MessageType::ServerReturn,
0,
0,
0,
0,
bincode::serialize(&"Email Incorrect").unwrap(),
);
match tls_connection
.write_all(&bincode::serialize(&server_response).unwrap())
.await
{
_ => return Ok(()),
};
}
};
if email_ret.is_err() {
return Message::new()
.command(Command::Failure)
.data("Email Incorrect")
.send(socket)
.await;
}
let pass_ret = pbkdf2::verify(
pbkdf2::PBKDF2_HMAC_SHA512,
NonZeroU32::new(500_000).unwrap(),
&password_salt,
&password_hash,
&password_db,
password_salt.as_bytes(),
password_client_hash.as_bytes(),
password_server_hash.as_bytes(),
);
match pass_ret.is_ok() {
true => {}
false => {
let server_response = message_builder(
MessageType::ServerReturn,
0,
0,
0,
0,
bincode::serialize(&"Password Incorrect").unwrap(),
);
match tls_connection
.write_all(&bincode::serialize(&server_response).unwrap())
.await
{
_ => return Ok(()),
};
}
};
if pass_ret.is_err() {
return Message::new()
.command(Command::Failure)
.data("Password Incorrect")
.send(socket)
.await;
}
/*
* Generate JWT token
* */
/* get user id*/
let user_id = get_user_id(sql_conn, username).await?;
let user_id = match get_user_id(sql_conn, username).await {
Ok(val) => val,
Err(_) => {
return Message::new()
.command(Command::Failure)
.data(format!(
"Failed authorizing user, {}, does not exist!",
username
))
.send(socket)
.await;
}
};
/* gen the actual token */
use std::time::{Duration, SystemTime, UNIX_EPOCH};
let beginning_of_time = SystemTime::now() + Duration::from_secs(4 * 60 * 60);
let jwt_token = create_jwt_token(
user_id,
beginning_of_time
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs(),
);
/*
* server failed to generate JWT token.
* inform client about issue
* */
if jwt_token.is_err() {
let server_response = message_builder(
MessageType::ServerReturn,
0,
0,
0,
0,
bincode::serialize(&"Login failed, try again later.").unwrap(),
);
match tls_connection
.write_all(&bincode::serialize(&server_response).unwrap())
.await
{
// We already failed,
// we don't care if client doesn't recieve
_ => return Ok(()),
};
}
/*
* Send the JWT token
* */
let message = message_builder(
MessageType::ServerReturn,
1,
1,
0,
0,
jwt_token.unwrap().as_bytes().to_vec(),
);
match tls_connection
.write_all(bincode::serialize(&message).unwrap().as_slice())
Message::new()
.command(Command::Success)
.data(create_jwt_token(
user_id,
beginning_of_time
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs(),
))
.send(socket)
.await
{
_ => Ok(()), // Don't care if client doesn't receive
}
}

View File

@ -1,14 +1,9 @@
use data_encoding::HEXUPPER;
use log::warn;
use crate::common::message::assert_msg::assert_msg;
use crate::common::message::inst::CommandInst;
use crate::common::message::message_builder::message_builder;
use crate::common::message::message_type::MessageType;
use crate::common::command::*;
use crate::common::message::*;
use crate::common::account::portfolio::Portfolio;
use crate::common::message::message::Message;
use crate::common::misc::return_flags::ReturnFlags;
use crate::server::account::hash_email::hash_email;
use crate::server::account::hash_pwd::hash_pwd;
@ -20,80 +15,42 @@ use tokio_rustls::server::TlsStream;
pub async fn acc_create(
sql_conn: &tokio_postgres::Client,
tls_connection: &mut TlsStream<TcpStream>,
socket: &mut TlsStream<TcpStream>,
message: &Message,
) -> std::io::Result<()> {
/* assert recieved message */
if !assert_msg(
message,
MessageType::Command,
true,
5,
false,
0,
false,
0,
false,
0,
) && message.instruction == CommandInst::Register as i64
&& message.data.len() != 0
{
if !message.assert_command(Command::Register) || !message.assert_data() {
warn!("REGISTER_INVALID_MESSAGE");
return tls_connection.shutdown().await;
return socket.shutdown().await; // just get off of my lawn
}
/*
* Parse account data
* */
/* get json data */
let stringified_data = std::str::from_utf8(&message.data).unwrap().to_string();
let stringified_data: String = bincode::deserialize(&message.data).unwrap();
let data = json::parse(&stringified_data).unwrap();
/* get email, password salts and client hashes */
let email_hash = HEXUPPER
.decode(data["email_hash"].as_str().unwrap().to_string().as_bytes())
.unwrap();
let email_client_salt = HEXUPPER
.decode(
data["email_client_salt"]
.as_str()
.unwrap()
.to_string()
.as_bytes(),
)
.unwrap();
let password_hash = HEXUPPER
.decode(
data["password_hash"]
.as_str()
.unwrap()
.to_string()
.as_bytes(),
)
.unwrap();
let password_client_salt = HEXUPPER
.decode(
data["password_client_salt"]
.as_str()
.unwrap()
.to_string()
.as_bytes(),
)
.unwrap();
let email_hash = data["email_hash"].as_str().unwrap(); // TODO: fix this
let email_client_salt = data["email_client_salt"].as_str().unwrap();
let password_hash = data["password_hash"].as_str().unwrap();
let password_client_salt = data["password_client_salt"].as_str().unwrap();
/* get username */
let username: String = data["username"].as_str().unwrap().to_string();
let username = data["username"].as_str().unwrap();
/* expect all received values to be non-None */
/* generate account struct */
let mut account: Account = Account {
username: username,
username: username.to_string(),
email_hash: "".to_string(),
server_email_salt: "".to_string(),
client_email_salt: HEXUPPER.encode(&email_client_salt),
client_email_salt: email_client_salt.to_string(),
pass_hash: "".to_string(),
server_pass_salt: "".to_string(),
client_pass_salt: HEXUPPER.encode(&password_client_salt),
client_pass_salt: password_client_salt.to_string(),
is_pass: true,
portfolio: Portfolio::default(),
@ -117,34 +74,24 @@ pub async fn acc_create(
* Inform cient that user already exists
* Note: figure out if this is a security? issue
*/
let server_response = message_builder(
MessageType::ServerReturn,
0,
0,
0,
0,
bincode::serialize(&format!("{:#?}", ReturnFlags::ServerAccUserExists)).unwrap(),
);
match tls_connection
.write_all(&bincode::serialize(&server_response).unwrap())
.await
{
// Don't care if user didn't recieve a reply
_ => return Ok(()),
};
return Message::new()
.command(Command::Failure)
.data("Username already exists")
.send(socket)
.await;
}
/*
* Hash the email and password.
* */
/* hash the email */
let email_server_hash = hash_email(&email_hash);
account.email_hash = HEXUPPER.encode(&email_server_hash.0);
account.server_email_salt = HEXUPPER.encode(&email_server_hash.1);
let email_server_hash = hash_email(email_hash);
account.email_hash = email_server_hash.0;
account.server_email_salt = email_server_hash.1;
/* hash the password */
let password_server_hash = hash_pwd(&password_hash);
account.pass_hash = HEXUPPER.encode(&password_server_hash.0);
account.server_pass_salt = HEXUPPER.encode(&password_server_hash.1);
let password_server_hash = hash_pwd(password_hash);
account.pass_hash = password_server_hash.0;
account.server_pass_salt = password_server_hash.1;
/*
* Write the account to the database.
@ -160,23 +107,18 @@ pub async fn acc_create(
/*
* Send to client SQL result
*/
let server_response = message_builder(
MessageType::ServerReturn,
if creation_result.is_ok() { 1 } else { 0 },
0,
0,
0,
if creation_result.is_ok() {
Vec::new()
} else {
bincode::serialize(&format!("{:#?}", creation_result)).unwrap()
},
);
match tls_connection
.write_all(&bincode::serialize(&server_response).unwrap())
.await
{
// Don't care if user didn't recieve a reply
_ => Ok(()),
match creation_result {
Ok(_) => Message::new().command(Command::Success).send(socket).await,
Err(_) => {
Message::new()
.command(Command::Failure)
.data(format!(
"Failed creating an account, server error, {:#?}, \
please try again later.",
creation_result
))
.send(socket)
.await
}
}
}

View File

@ -1,31 +0,0 @@
use ring::{digest, pbkdf2};
use std::num::NonZeroU32;
/// A generic hashing abstraction function.
///
/// Useful for quickly swapping the current hashing system.
///
/// Arguments:
/// val - The value to be hashed.
/// salt - The whole salt to be used.
/// iter - The number of iteration to use.
///
/// Returns: u8 array of size 64 bytes.
///
/// Example:
/// ```rust
/// let email_hash = hash("test@test.com", [0u8; 64], 124000);
/// ```
pub fn hash(val: &Vec<u8>, salt: &Vec<u8>, iter: u32) -> [u8; digest::SHA512_OUTPUT_LEN] {
let iterations: NonZeroU32 = NonZeroU32::new(iter).unwrap();
let mut hash = [0u8; digest::SHA512_OUTPUT_LEN];
pbkdf2::derive(
pbkdf2::PBKDF2_HMAC_SHA512,
iterations,
&salt,
val,
&mut hash,
);
hash
}

View File

@ -1,7 +1,5 @@
use ring::rand::SecureRandom;
use ring::{digest, rand};
use crate::common::account::hash::hash;
use crate::common::account::hash::*;
use crate::common::account::salt::*;
/// Generates a storable server email hash from a client hashed email.
///
@ -20,42 +18,8 @@ use crate::common::account::hash::hash;
/// println!("Server Email Hash: {}", HEXUPPER.encode(&enc.0));
/// println!("Server Email Salt: {}", HEXUPPER.encode(&enc.1));
/// ```
pub fn hash_email(
hashed_email: &Vec<u8>,
) -> (
[u8; digest::SHA512_OUTPUT_LEN],
[u8; digest::SHA512_OUTPUT_LEN],
) {
let rng = rand::SystemRandom::new();
let mut salt = [0u8; digest::SHA512_OUTPUT_LEN];
rng.fill(&mut salt).unwrap();
let hash = hash(hashed_email, &salt.to_vec(), 350_000);
pub fn hash_email(hashed_email: &str) -> (String, String) {
let salt = gen_salt();
let hash = hash(&hashed_email, &salt, 350_000);
(hash, salt)
}
#[cfg(test)]
mod test {
use super::*;
use data_encoding::HEXUPPER;
#[test]
fn test_account_hash_email_server() {
let email = "totallyrealemail@anemail.c0m";
/* ensure that hash_email_server() works */
let output = hash_email(&email.as_bytes().to_vec());
assert_ne!(output.0.len(), 0);
assert_ne!(output.1.len(), 0);
/* ensure that hash_email_server() generates different output
* each time it is run.
* */
// Generate new server salt.
let enc0 = hash_email(&email.as_bytes().to_vec());
let enc1 = hash_email(&email.as_bytes().to_vec());
assert_ne!(HEXUPPER.encode(&enc0.0), HEXUPPER.encode(&enc1.0));
assert_ne!(HEXUPPER.encode(&enc0.1), HEXUPPER.encode(&enc1.1));
}
}

View File

@ -1,7 +1,5 @@
use ring::rand::SecureRandom;
use ring::{digest, rand};
use crate::common::account::hash::hash;
use crate::common::account::hash::*;
use crate::common::account::salt::*;
/// Generates a storable server password hash from a client hashed password.
///
@ -20,43 +18,8 @@ use crate::common::account::hash::hash;
/// println!("Server Hash: {}", HEXUPPER.encode(&enc.0));
/// println!("Server Salt: {}", HEXUPPER.encode(&enc.1));
/// ```
pub fn hash_pwd(
hashed_pass: &Vec<u8>,
) -> (
[u8; digest::SHA512_OUTPUT_LEN],
[u8; digest::SHA512_OUTPUT_LEN],
) {
// sever hash, server salt
let rng = rand::SystemRandom::new();
let mut salt = [0u8; digest::SHA512_OUTPUT_LEN];
rng.fill(&mut salt).unwrap();
let hash = hash(hashed_pass, &salt.to_vec(), 500_000);
pub fn hash_pwd(hashed_pass: &str) -> (String, String) {
let salt = gen_salt();
let hash = hash(&hashed_pass, &salt, 500_000);
(hash, salt)
}
#[cfg(test)]
mod test {
use super::*;
use data_encoding::HEXUPPER;
#[test]
fn test_account_hash_pwd_server() {
let pass = "goodlilpassword";
/* ensure that hash_pwd_server() works */
let output = hash_pwd(&pass.as_bytes().to_vec());
assert_ne!(output.0.len(), 0);
assert_ne!(output.1.len(), 0);
/* ensure that hash_pwd_server() generates different output
* each time it is run.
* */
// Generate new server salt.
let enc0 = hash_pwd(&pass.as_bytes().to_vec());
let enc1 = hash_pwd(&pass.as_bytes().to_vec());
assert_ne!(HEXUPPER.encode(&enc0.0), HEXUPPER.encode(&enc1.0));
assert_ne!(HEXUPPER.encode(&enc0.1), HEXUPPER.encode(&enc1.1));
}
}

View File

@ -1,7 +1,4 @@
pub mod authorization;
pub mod creation;
pub mod hash;
pub mod hash_email;
pub mod hash_pwd;
pub mod retrieval_portfolio;
pub mod retrieval_transaction;

View File

@ -1,121 +0,0 @@
use log::warn;
use crate::common::account::portfolio::Portfolio;
use crate::common::account::position::Position;
use crate::common::message::assert_msg::assert_msg;
use crate::common::message::inst::DataTransferInst;
use crate::common::message::message::Message;
use crate::common::message::message_builder::message_builder;
use crate::common::message::message_type::MessageType;
use crate::server::db::initializer::db_connect;
use crate::server::network::jwt_wrapper::verify_jwt_token;
use tokio::io::AsyncWriteExt;
use tokio::net::TcpStream;
use tokio_rustls::server::TlsStream;
pub async fn acc_retrieve_portfolio(
tls_connection: &mut TlsStream<TcpStream>,
message: &Message,
) -> std::io::Result<()> {
/* assert recieved message */
if !assert_msg(
message,
MessageType::Command,
true,
1,
false,
0,
false,
0,
false,
0,
) && message.instruction == DataTransferInst::GetUserPortfolio as i64
&& message.data.len() != 0
{
warn!("RETRIEVE_PORTFOLIO_INVALID_MESSAGE");
return tls_connection.shutdown().await;
}
/* verify JWT token */
let token = match verify_jwt_token(bincode::deserialize(&message.data).unwrap()) {
Ok(token) => token,
Err(_) => {
warn!("ACC_RETRIEVE_PORTFOLIO_UNAUTH_TOKEN");
let server_response = message_builder(
MessageType::ServerReturn,
0,
0,
0,
0,
bincode::serialize(&"Password Incorrect").unwrap(),
);
match tls_connection
.write_all(&bincode::serialize(&server_response).unwrap())
.await
{
_ => {
// TODO: do we shutdown connection or do we let handle_data caller do it's
// thing
tls_connection.shutdown().await.unwrap();
return Ok(());
}
};
}
};
/* connect to SQL database using user ```postfolio_schema_user``` */
let sql_conn = db_connect(
std::env::var("DB_PORTFOLIO_USER").unwrap(),
std::env::var("DB_PORTFOLIO_PASS").unwrap(),
)
.await?;
/* get userId's portfolio positions */
let mut portfolio: Portfolio = Portfolio::default();
// get position data from the portfolio_schema.positions table.
for row in sql_conn
.query(
"SELECT * FROM portfolio_schema.positions WHERE user_id = $1",
&[&token.user_id],
)
.await
.unwrap()
{
let mut pos: Position = Position::default();
pos.stock_symbol = row.get(2);
pos.stock_open_amount = row.get(3);
pos.stock_open_price = row.get(4);
pos.stock_open_cost = row.get(5);
pos.stock_close_amount = row.get(6);
pos.stock_close_price = row.get(7);
pos.open_epoch = row.get(8);
pos.close_epoch = row.get(9);
pos.is_open = row.get(10);
pos.is_buy = row.get(11);
portfolio.open_positions.push(pos);
}
/* build a message */
let message = message_builder(
MessageType::DataTransfer,
1,
1,
0,
0,
bincode::serialize(&portfolio).unwrap(),
);
match tls_connection
.write_all(&bincode::serialize(&message).unwrap())
.await
{
Ok(()) => Ok(()),
Err(err) => {
// Log issue in writing to client
warn!("Could not write to cient! With Error: {}\nIgnoring...", err);
Ok(())
}
}
}

View File

@ -1,90 +0,0 @@
use log::warn;
use crate::common::account::transaction::Transaction;
use crate::common::message::assert_msg::assert_msg;
use crate::common::message::inst::DataTransferInst;
use crate::common::message::message::Message;
use crate::common::message::message_builder::message_builder;
use crate::common::message::message_type::MessageType;
use crate::server::network::jwt_wrapper::verify_jwt_token;
use tokio::io::AsyncWriteExt;
use tokio::net::TcpStream;
use tokio_rustls::server::TlsStream;
pub async fn acc_retrieve_transaction(
sql_conn: &tokio_postgres::Client,
tls_connection: &mut TlsStream<TcpStream>,
message: &Message,
) -> std::io::Result<()> {
/* assert recieved message */
if !assert_msg(
message,
MessageType::DataTransfer,
true,
1,
false,
0,
false,
0,
false,
0,
) && message.instruction == DataTransferInst::GetUserTransactionHist as i64
&& message.data.len() != 0
{
warn!("RETRIEVE_TRANSACTION_INVALID_MESSAGE");
return tls_connection.shutdown().await;
}
/* verify JWT token */
let token = match verify_jwt_token(bincode::deserialize(&message.data).unwrap()) {
Ok(token) => token,
Err(_) => {
warn!("ACC_RETRIEVE_TRANSACTION_UNAUTH_TOKEN");
tls_connection.shutdown().await.unwrap();
// Unauth aren't big deal,
// maybe later we can make more sophisticated DOS attack detection
return Ok(());
}
};
/* get userId's transactions */
let mut transactions: Vec<Transaction> = Vec::new();
for row in sql_conn
.query(
"SELECT * FROM accounts_schema.transactions WHERE user_id = $1",
&[&token.user_id],
)
.await
.unwrap()
{
let mut transaction = Transaction::default();
transaction.stock_symbol = row.get(2);
transaction.shares_size = row.get(3);
transaction.shares_cost = row.get(4);
transaction.is_buy = row.get(5);
transactions.push(transaction);
}
/* build message to be send */
let message = message_builder(
MessageType::ServerReturn,
1,
1,
0,
0,
bincode::serialize(&transactions).unwrap(),
);
match tls_connection
.write_all(&bincode::serialize(&message).unwrap())
.await
{
Ok(()) => Ok(()),
Err(err) => {
warn!("Could not write to client! Error: {}", err);
Ok(())
}
}
}

View File

@ -1,51 +1,50 @@
use crate::common::generic::company::Company;
use crate::common::misc::return_flags::ReturnFlags;
/// Creates a company on the postgres SQL database.
///
/// Takes in a company and writes an entry in public.companies.
/// Should be used in Async contexts.
///
/// Arguments:
/// sql_conn - The SQL connection to use.
/// company - The company to create.
///
/// Returns: the company, a string containing reason of failure on error.
///
/// Example:
/// ```rust
/// match create_company(company) {
/// Ok(()) => info!("created company"),
/// Err(err) => error!("Failed to create company with error: {}", err),
/// }
/// ```
pub async fn create_company(
sql_conn: &mut tokio_postgres::Client,
company: Company,
) -> Result<Company, ReturnFlags> {
/*
* Creates a company entry in database in public.companies.
*/
// Insert argument company into public.companies database table.
match sql_conn
.execute(
"INSERT INTO public.companies VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)",
&[
&company.id,
&company.symbol,
&company.isin,
&company.company_name,
&company.primary_exchange,
&company.sector,
&company.industry,
&company.primary_sic_code,
&company.employees,
],
)
.await
{
Ok(_row) => Ok(company),
Err(_) => Err(ReturnFlags::ServerDbCreateCompanyFailed),
}
}
//use crate::common::generic::company::Company;
//
///// Creates a company on the postgres SQL database.
/////
///// Takes in a company and writes an entry in public.companies.
///// Should be used in Async contexts.
/////
///// Arguments:
///// sql_conn - The SQL connection to use.
///// company - The company to create.
/////
///// Returns: the company, a string containing reason of failure on error.
/////
///// Example:
///// ```rust
///// match create_company(company) {
///// Ok(()) => info!("created company"),
///// Err(err) => error!("Failed to create company with error: {}", err),
///// }
///// ```
//pub async fn create_company(
// sql_conn: &mut tokio_postgres::Client,
// company: Company,
//) -> Result<Company, String> {
// /*
// * Creates a company entry in database in public.companies.
// */
//
// // Insert argument company into public.companies database table.
// match sql_conn
// .execute(
// "INSERT INTO public.companies VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)",
// &[
// &company.id,
// &company.symbol,
// &company.isin,
// &company.company_name,
// &company.primary_exchange,
// &company.sector,
// &company.industry,
// &company.primary_sic_code,
// &company.employees,
// ],
// )
// .await
// {
// Ok(_row) => Ok(company),
// Err(_) => Err(format!("Failed creating company, {}.", company.symbol))
// }
//}

View File

@ -1,5 +1,4 @@
use crate::common::account::position::Position;
use crate::common::misc::return_flags::ReturnFlags;
/// Creates a position on the postgre SQL database
///
@ -22,7 +21,7 @@ pub async fn create_position(
sql_conn: &mut tokio_postgres::Client,
user_id: i64,
position: Position,
) -> Result<(), ReturnFlags> {
) -> Result<(), String> {
/*
* Creates a position entry in database in portfolio_schema.positions.
* */
@ -36,6 +35,6 @@ pub async fn create_position(
&position.stock_open_cost, &position.stock_close_amount, &position.stock_close_price,
&position.open_epoch, &position.close_epoch, &position.is_buy, &position.is_open]).await {
Ok(_rows) => Ok(()),
Err(_) => Err(ReturnFlags::ServerDbCreatePositionFailed),
Err(_) => Err("Failed to create user position.".to_string())
}
}

View File

@ -1,5 +1,3 @@
use crate::common::misc::return_flags::ReturnFlags;
/// Creates a stock on the postgres SQL database.
///
/// Takes in a stock name and creates a table in the ```asset_schema``` schema
@ -21,7 +19,7 @@ use crate::common::misc::return_flags::ReturnFlags;
pub async fn create_stock(
sql_conn: &mut tokio_postgres::Client,
stock_name: &str,
) -> Result<(), ReturnFlags> {
) -> Result<(), String> {
/*
* Creates a stock table in database in assets schema.
*/
@ -46,6 +44,6 @@ pub async fn create_stock(
.await
{
Ok(_rows) => Ok(()),
Err(_) => Err(ReturnFlags::ServerDbCreateStockFailed),
Err(_) => Err(format!("Failed creating stock, {}", stock_name)),
}
}

View File

@ -1,5 +1,4 @@
use crate::common::account::transaction::Transaction;
use crate::common::misc::return_flags::ReturnFlags;
/// Creates a transaction on the postgre SQL database
///
@ -22,7 +21,7 @@ pub async fn create_transaction(
sql_conn: &mut tokio_postgres::Client,
user_id: i64,
transaction: &Transaction,
) -> Result<(), ReturnFlags> {
) -> Result<(), String> {
/*
* Creates a transaction entry in database in accounts_schema.transactions.
* */
@ -44,6 +43,6 @@ pub async fn create_transaction(
.await
{
Ok(_rows) => Ok(()),
Err(_) => Err(ReturnFlags::ServerDbCreateTransactionFailed),
Err(_) => Err("Failed to create user transaction.".to_string()),
}
}

View File

@ -1,54 +1,53 @@
use crate::common::generic::company::Company;
use crate::common::misc::return_flags::ReturnFlags;
/// Returns a company from the postgres SQL database.
///
/// Takes in a company symbol and returns a company.
/// Shuold be used in Async contexts
///
/// Arguments:
/// sql_conn - The SQL connection to use.
/// search_symbol - The specific company symbol to find.
///
/// Returns: a reference to the found company on success, and a string containing the reason of
/// failure on error.
///
/// Example:
/// ```rust
/// match get_company_from_db("AAPL".to_string()) {
/// Ok(found_company) => info!("we found it! {:?}", found_company),
/// Err(err) => error!("we must found the sacred company! err: {}", err),
/// }
/// ```
pub async fn get_company_from_db(
sql_conn: &mut tokio_postgres::Client,
searched_symbol: &str,
) -> Result<Company, ReturnFlags> {
/*
* Returns company entry from database
*/
// Connect to database.
match sql_conn
.query(
"SELECT * FROM public.companies WHERE symbol=$1",
&[&searched_symbol],
)
.await
{
Ok(row) => {
let mut found_company: Company = Company::default();
found_company.id = row[0].get(0);
found_company.symbol = row[0].get(1);
found_company.isin = row[0].get(2);
found_company.company_name = row[0].get(3);
found_company.primary_exchange = row[0].get(4);
found_company.sector = row[0].get(5);
found_company.industry = row[0].get(6);
found_company.primary_sic_code = row[0].get(7);
found_company.employees = row[0].get(8);
return Ok(found_company);
}
Err(_) => Err(ReturnFlags::ServerDbSearchCompanyNotFound),
}
}
//use crate::common::generic::company::Company;
//
///// Returns a company from the postgres SQL database.
/////
///// Takes in a company symbol and returns a company.
///// Shuold be used in Async contexts
/////
///// Arguments:
///// sql_conn - The SQL connection to use.
///// search_symbol - The specific company symbol to find.
/////
///// Returns: a reference to the found company on success, and a string containing the reason of
///// failure on error.
/////
///// Example:
///// ```rust
///// match get_company_from_db("AAPL".to_string()) {
///// Ok(found_company) => info!("we found it! {:?}", found_company),
///// Err(err) => error!("we must found the sacred company! err: {}", err),
///// }
///// ```
//pub async fn get_company_from_db(
// sql_conn: &mut tokio_postgres::Client,
// searched_symbol: &str,
//) -> Result<Company, String> {
// /*
// * Returns company entry from database
// */
// // Connect to database.
// match sql_conn
// .query(
// "SELECT * FROM public.companies WHERE symbol=$1",
// &[&searched_symbol],
// )
// .await
// {
// Ok(row) => {
// let mut found_company: Company = Company::default();
// found_company.id = row[0].get(0);
// found_company.symbol = row[0].get(1);
// found_company.isin = row[0].get(2);
// found_company.company_name = row[0].get(3);
// found_company.primary_exchange = row[0].get(4);
// found_company.sector = row[0].get(5);
// found_company.industry = row[0].get(6);
// found_company.primary_sic_code = row[0].get(7);
// found_company.employees = row[0].get(8);
//
// return Ok(found_company);
// }
// Err(_) => Err(format!("Failed getting company, {}, not found.", searched_symbol))
// }
//}

View File

@ -1,179 +1,178 @@
use crate::common::generic::stock_val::StockVal;
use crate::common::misc::return_flags::ReturnFlags;
/// Returns the whole stock data from the postgres SQL database.
///
/// Takes in a stock symbol and returns the whole data entries of the searched stock.
/// Should be used in Async contexts.
///
/// Arguments:
/// sql_conn - The SQL connection to use.
/// searched_symbol - The name of the stock table.
///
/// Returns: a Vec<StockVal> on success, and a string containing the reason of failure on error.
///
/// Example:
/// ```rust
/// match get_stock_from_db("AAPL".into()) {
/// Ok(vals) => {
/// /* do something with the values */
/// },
/// Err(err) => panic!("failed to get the stock value, reason: {}", err)
/// };
/// ```
pub async fn get_stock_from_db(
sql_conn: &mut tokio_postgres::Client,
searched_symbol: &str,
) -> Result<Vec<StockVal>, ReturnFlags> {
/*
* Returns all stock values from database.
*/
// Query database for table.
let mut stocks: Vec<StockVal> = Vec::new();
match sql_conn
.query(
format!("SELECT * FROM asset_schema.{}", searched_symbol).as_str(),
&[],
)
.await
{
Ok(all_rows) => {
for row in all_rows {
let mut val: StockVal = StockVal::default();
val.id = row.get(0);
val.isin = row.get(1);
val.time_epoch = row.get(2);
val.ask_price = row.get(3);
val.bid_price = row.get(4);
val.volume = row.get(5);
stocks.push(val);
}
Ok(stocks)
}
Err(_) => Err(ReturnFlags::ServerDbSearchStockNotFound),
}
}
/// Returns stock data since an unix epoch from the postgres SQL database.
///
/// Takes in a stock symbol and returns the data entries after a specified epoch of the searched stock.
/// Should be used in Async contexts.
///
/// Arguments:
/// sql_conn - The SQL connection to use.
/// searched_symbol - The name of the stock table.
/// time_epoch - The time from which the stock data retrieved.
///
/// Returns: a Vec<StockVal> on success, and a string containing the reason of failure on error.
///
/// Example:
/// ```rust
/// match get_stock_from_db_since_epoch("AAPL".into(), 123456) {
/// Ok(vals) => {
/// /* do something with the filtered values */
/// },
/// Err(err) => panic!("failed to get the stock value, reason: {}", err)
/// };
/// ```
pub async fn get_stock_from_db_since_epoch(
sql_conn: &mut tokio_postgres::Client,
searched_symbol: &str,
time_epoch: i64,
) -> Result<Vec<StockVal>, ReturnFlags> {
/*
* Returns all stock values from database since a time epoch.
*/
// Query database for table.
let mut stocks: Vec<StockVal> = Vec::new();
match sql_conn
.query(
format!(
"SELECT * FROM asset_schema.{} WHERE time_epoch >= {}",
searched_symbol, time_epoch
)
.as_str(),
&[],
)
.await
{
Ok(all_rows) => {
for row in all_rows {
let mut val: StockVal = StockVal::default();
val.id = row.get(0);
val.isin = row.get(1);
val.time_epoch = row.get(2);
val.ask_price = row.get(3);
val.bid_price = row.get(4);
val.volume = row.get(5);
stocks.push(val);
}
Ok(stocks)
}
Err(_) => Err(ReturnFlags::ServerDbSearchStockNotFound),
}
}
/// Returns stock data between two unix epochs from the postgres SQL database.
///
/// Takes in a stock symbol and returns the data entries between two specified unix epochs of the searched
/// stock.
/// Should be used in Async contexts.
///
/// Arguments:
/// sql_conn - The SQL connection to use.
/// searched_symbol - The name of the stock table.
/// first_time_epoch - The time from which the stock data is first retrieved.
/// second_time_epoch - The time from which the stock data ends.
///
/// Returns: a Vec<StockVal> on success, and a string containing the reason of failure on error.
///
/// Example:
/// ```rust
/// match get_stock_from_db_between_epochs("AAPL".into(), 123456, 123459) {
/// Ok(vals) => {
/// /* do something with the filtered values */
/// },
/// Err(err) => panic!("failed to get the stock value, reason: {}", err)
/// };
/// ```
pub async fn get_stock_from_db_between_epochs(
sql_conn: &mut tokio_postgres::Client,
searched_symbol: &str,
first_time_epoch: i64,
second_time_epoch: i64,
) -> Result<Vec<StockVal>, ReturnFlags> {
/*
* Returns all stock values from database between two time epochs.
*/
// Query database for table.
let mut stocks: Vec<StockVal> = Vec::new();
match sql_conn
.query(
format!(
"SELECT * FROM asset_schema.{} WHERE time_epoch >= {} AND time_epoch <= {}",
searched_symbol, first_time_epoch, second_time_epoch
)
.as_str(),
&[],
)
.await
{
Ok(all_rows) => {
for row in all_rows {
let mut val: StockVal = StockVal::default();
val.id = row.get(0);
val.isin = row.get(1);
val.time_epoch = row.get(2);
val.ask_price = row.get(3);
val.bid_price = row.get(4);
val.volume = row.get(5);
stocks.push(val);
}
Ok(stocks)
}
Err(_) => Err(ReturnFlags::ServerDbSearchStockNotFound),
}
}
//use crate::common::generic::stock_val::StockVal;
//
///// Returns the whole stock data from the postgres SQL database.
/////
///// Takes in a stock symbol and returns the whole data entries of the searched stock.
///// Should be used in Async contexts.
/////
///// Arguments:
///// sql_conn - The SQL connection to use.
///// searched_symbol - The name of the stock table.
/////
///// Returns: a Vec<StockVal> on success, and a string containing the reason of failure on error.
/////
///// Example:
///// ```rust
///// match get_stock_from_db("AAPL".into()) {
///// Ok(vals) => {
///// /* do something with the values */
///// },
///// Err(err) => panic!("failed to get the stock value, reason: {}", err)
///// };
///// ```
//pub async fn get_stock_from_db(
// sql_conn: &mut tokio_postgres::Client,
// searched_symbol: &str,
//) -> Result<Vec<StockVal>, String> {
// /*
// * Returns all stock values from database.
// */
//
// // Query database for table.
// let mut stocks: Vec<StockVal> = Vec::new();
// match sql_conn
// .query(
// format!("SELECT * FROM asset_schema.{}", searched_symbol).as_str(),
// &[],
// )
// .await
// {
// Ok(all_rows) => {
// for row in all_rows {
// let mut val: StockVal = StockVal::default();
// val.id = row.get(0);
// val.isin = row.get(1);
// val.time_epoch = row.get(2);
// val.ask_price = row.get(3);
// val.bid_price = row.get(4);
// val.volume = row.get(5);
// stocks.push(val);
// }
// Ok(stocks)
// }
// Err(_) => Err(format!("Failed to retrieve stock, {}, could not be found.", searched_symbol))
// }
//}
//
///// Returns stock data since an unix epoch from the postgres SQL database.
/////
///// Takes in a stock symbol and returns the data entries after a specified epoch of the searched stock.
///// Should be used in Async contexts.
/////
///// Arguments:
///// sql_conn - The SQL connection to use.
///// searched_symbol - The name of the stock table.
///// time_epoch - The time from which the stock data retrieved.
/////
///// Returns: a Vec<StockVal> on success, and a string containing the reason of failure on error.
/////
///// Example:
///// ```rust
///// match get_stock_from_db_since_epoch("AAPL".into(), 123456) {
///// Ok(vals) => {
///// /* do something with the filtered values */
///// },
///// Err(err) => panic!("failed to get the stock value, reason: {}", err)
///// };
///// ```
//pub async fn get_stock_from_db_since_epoch(
// sql_conn: &mut tokio_postgres::Client,
// searched_symbol: &str,
// time_epoch: i64,
//) -> Result<Vec<StockVal>, String> {
// /*
// * Returns all stock values from database since a time epoch.
// */
//
// // Query database for table.
// let mut stocks: Vec<StockVal> = Vec::new();
// match sql_conn
// .query(
// format!(
// "SELECT * FROM asset_schema.{} WHERE time_epoch >= {}",
// searched_symbol, time_epoch
// )
// .as_str(),
// &[],
// )
// .await
// {
// Ok(all_rows) => {
// for row in all_rows {
// let mut val: StockVal = StockVal::default();
// val.id = row.get(0);
// val.isin = row.get(1);
// val.time_epoch = row.get(2);
// val.ask_price = row.get(3);
// val.bid_price = row.get(4);
// val.volume = row.get(5);
// stocks.push(val);
// }
// Ok(stocks)
// }
// Err(_) => Err(format!("Failed to retrieve stock, {}, could not be found.", searched_symbol))
// }
//}
//
///// Returns stock data between two unix epochs from the postgres SQL database.
/////
///// Takes in a stock symbol and returns the data entries between two specified unix epochs of the searched
///// stock.
///// Should be used in Async contexts.
/////
///// Arguments:
///// sql_conn - The SQL connection to use.
///// searched_symbol - The name of the stock table.
///// first_time_epoch - The time from which the stock data is first retrieved.
///// second_time_epoch - The time from which the stock data ends.
/////
///// Returns: a Vec<StockVal> on success, and a string containing the reason of failure on error.
/////
///// Example:
///// ```rust
///// match get_stock_from_db_between_epochs("AAPL".into(), 123456, 123459) {
///// Ok(vals) => {
///// /* do something with the filtered values */
///// },
///// Err(err) => panic!("failed to get the stock value, reason: {}", err)
///// };
///// ```
//pub async fn get_stock_from_db_between_epochs(
// sql_conn: &mut tokio_postgres::Client,
// searched_symbol: &str,
// first_time_epoch: i64,
// second_time_epoch: i64,
//) -> Result<Vec<StockVal>, String> {
// /*
// * Returns all stock values from database between two time epochs.
// */
//
// // Query database for table.
// let mut stocks: Vec<StockVal> = Vec::new();
// match sql_conn
// .query(
// format!(
// "SELECT * FROM asset_schema.{} WHERE time_epoch >= {} AND time_epoch <= {}",
// searched_symbol, first_time_epoch, second_time_epoch
// )
// .as_str(),
// &[],
// )
// .await
// {
// Ok(all_rows) => {
// for row in all_rows {
// let mut val: StockVal = StockVal::default();
// val.id = row.get(0);
// val.isin = row.get(1);
// val.time_epoch = row.get(2);
// val.ask_price = row.get(3);
// val.bid_price = row.get(4);
// val.volume = row.get(5);
// stocks.push(val);
// }
// Ok(stocks)
// }
// Err(_) => Err(format!("Failed to retrieve stock, {}, could not be found.", searched_symbol))
// }
//}

View File

@ -1,12 +1,10 @@
use crate::common::misc::return_flags::ReturnFlags;
use crate::server::db::cmd::user_exists::user_exists;
pub async fn get_user_hash(
sql_conn: &tokio_postgres::Client,
username: &str,
is_email: bool,
) -> Result<String, ReturnFlags> {
) -> Result<String, String> {
/* check that user exists*/
if user_exists(sql_conn, username).await {
if is_email {
@ -24,5 +22,5 @@ pub async fn get_user_hash(
}
}
Err(ReturnFlags::ServerDbUserHashNotFound)
Err("Failed to get username hash of non-existing user.".to_string())
}

View File

@ -1,11 +1,6 @@
use crate::server::db::cmd::user_exists::user_exists;
use crate::common::misc::return_flags::ReturnFlags;
pub async fn get_user_id(
sql_conn: &tokio_postgres::Client,
username: &str,
) -> std::io::Result<i64> {
pub async fn get_user_id(sql_conn: &tokio_postgres::Client, username: &str) -> Result<i64, String> {
/* check that user exists */
if user_exists(sql_conn, username).await {
for row in sql_conn
@ -19,8 +14,5 @@ pub async fn get_user_id(
return Ok(row.get(0));
}
}
Err(std::io::Error::new(
std::io::ErrorKind::NotFound,
format!("{}", ReturnFlags::ServerGetUserIdNotFound),
))
Err(format!("Failed getting ID of username, {}", username))
}

View File

@ -1,45 +1,39 @@
use crate::common::misc::return_flags::ReturnFlags;
use crate::common::command::*;
use crate::server::db::cmd::user_exists::user_exists;
pub async fn get_user_salt(
sql_conn: &tokio_postgres::Client,
username: &str,
is_email: bool,
is_server: bool,
) -> Result<String, ReturnFlags> {
/* check that user exists*/
command: Command,
is_server_owned: bool,
) -> Result<String, String> {
/* check that user exists */
if user_exists(sql_conn, username).await {
if is_server {
if is_email {
for row in
&sql_conn.query("SELECT username, server_email_salt FROM accounts_schema.accounts WHERE username LIKE $1",
&[&username]).await.unwrap() {
return Ok(row.get(1));
}
} else {
for row in
&sql_conn.query("SELECT username, server_pass_salt FROM accounts_schema.accounts WHERE username LIKE $1",
&[&username]).await.unwrap() {
return Ok(row.get(1));
}
}
} else {
if is_email {
for row in
&sql_conn.query("SELECT username, client_email_salt FROM accounts_schema.accounts WHERE username LIKE $1",
&[&username]).await.unwrap() {
return Ok(row.get(1));
}
} else {
for row in
&sql_conn.query("SELECT username, client_pass_salt FROM accounts_schema.accounts WHERE username LIKE $1",
&[&username]).await.unwrap() {
return Ok(row.get(1));
}
let query_variable: String = match command {
Command::GetEmailSalt if is_server_owned => "server_email_salt",
Command::GetPasswordSalt if is_server_owned => "server_pass_salt",
Command::GetEmailSalt if !is_server_owned => "client_email_salt",
Command::GetPasswordSalt if !is_server_owned => "client_pass_salt",
_ => {
return Err(format!(
"Could not get salt for user, {}, requested salt type, {}, is invalid.",
username, command
))
}
}
.into();
for row in &sql_conn
.query(
"SELECT username, $1 FROM accounts_schema.accounts WHERE username LIKE $2",
&[&query_variable, &username],
)
.await
.unwrap()
{
return Ok(row.get(1));
}
}
Err(ReturnFlags::ServerDbUserSaltNotFound)
Err("Failed to retrieve salt of non-existing user.".to_string())
}

View File

@ -1,15 +0,0 @@
use std::collections::HashMap;
use crate::common::generic::company::Company;
use crate::common::generic::stock_val::StockVal;
#[derive(PartialEq, Debug)]
pub struct GlobalState {
pub companies: HashMap<String, Company>, // symbol, company
pub stock_vals: HashMap<String, StockVal>, // symbol, stockval
}
impl std::fmt::Display for GlobalState {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "({:#?}, {:#?})", self.companies, self.stock_vals)
}
}

View File

@ -1,2 +1 @@
pub mod account;
pub mod global_state;

View File

@ -1,16 +1,11 @@
use data_encoding::HEXUPPER;
use crate::common::message::inst::{CommandInst, DataTransferInst};
use crate::common::message::message::Message;
use crate::common::message::message_builder::message_builder;
use crate::common::message::message_type::MessageType;
use crate::common::account::salt::*;
use crate::common::command::*;
use crate::common::message::*;
use crate::server::account::authorization::acc_auth;
use crate::server::account::creation::acc_create;
use crate::server::account::retrieval_portfolio::acc_retrieve_portfolio;
use crate::server::account::retrieval_transaction::acc_retrieve_transaction;
use crate::server::db::cmd::get_user_salt::get_user_salt;
use tokio::io::AsyncWriteExt;
use tokio::net::TcpStream;
use tokio_rustls::server::TlsStream;
@ -18,114 +13,52 @@ pub async fn handle_data(
sql_conn: &tokio_postgres::Client,
socket: &mut TlsStream<TcpStream>,
buf: &[u8],
) -> std::io::Result<()> {
) -> Result<(), String> {
/* decode incoming message */
let client_msg: Message = bincode::deserialize(&buf).map_err(|err| {
std::io::Error::new(
std::io::ErrorKind::InvalidInput,
format!("HANDLE_DATA_RCVD_INVALID_MSG: {}", err),
)
})?;
//println!("This is a message: {}", client_msg);
let client_msg: Message = bincode::deserialize(&buf)
.map_err(|err| format!("HANDLE_DATA_RCVD_INVALID_MSG: {}", err))?;
/* handle individual client instructions */
match client_msg.instruction {
_ if client_msg.instruction == CommandInst::GenHashSalt as i64 => {
use ring::rand::SecureRandom;
use ring::{digest, rand};
let rng = rand::SystemRandom::new();
let mut salt = [0u8; digest::SHA512_OUTPUT_LEN / 2];
rng.fill(&mut salt).unwrap();
let server_response: Message = message_builder(
MessageType::DataTransfer,
CommandInst::GenHashSalt as i64,
1,
0,
1,
salt.to_vec(),
);
socket
.write_all(bincode::serialize(&server_response).unwrap().as_slice())
// a command which is executed is assumed to return an IO result.
// handle_data() would then log if writing to client failed.
let cmd_client_write_result: std::io::Result<()> = match client_msg.command {
Command::GenHashSalt => {
Message::new()
.command(Command::Success)
.data(gen_salt())
.send(socket)
.await
}
_ if client_msg.instruction == CommandInst::GetEmailSalt as i64 => {
use crate::server::db::cmd::get_user_salt::get_user_salt;
match get_user_salt(
sql_conn,
String::from_utf8(client_msg.data).unwrap().as_str(),
true,
false,
)
.await
{
Ok(salt) => {
let server_response: Message = message_builder(
MessageType::DataTransfer,
CommandInst::GetEmailSalt as i64,
1,
0,
1,
HEXUPPER.decode(salt.as_bytes()).unwrap(),
);
socket
.write_all(bincode::serialize(&server_response).unwrap().as_slice())
.await
}
Err(_) => {
let server_response =
message_builder(MessageType::ServerReturn, 0, 0, 0, 0, Vec::new());
socket
.write_all(bincode::serialize(&server_response).unwrap().as_slice())
.await
}
Command::GetEmailSalt | Command::GetPasswordSalt => {
let salt =
get_user_salt(sql_conn, client_msg.get_data()?, client_msg.command, false).await;
let mut response = Message::new();
if salt.is_ok() {
response
.command(Command::Success)
.data(salt.unwrap())
.send(socket)
.await
} else {
response.command(Command::Failure).send(socket).await
}
}
_ if client_msg.instruction == CommandInst::GetPasswordSalt as i64 => {
use crate::server::db::cmd::get_user_salt::get_user_salt;
match get_user_salt(
sql_conn,
String::from_utf8(client_msg.data).unwrap().as_str(),
false,
false,
)
.await
{
Ok(salt) => {
let server_response: Message = message_builder(
MessageType::DataTransfer,
CommandInst::GetPasswordSalt as i64,
1,
0,
1,
HEXUPPER.decode(salt.as_bytes()).unwrap(),
);
socket
.write_all(bincode::serialize(&server_response).unwrap().as_slice())
.await
}
Err(_) => {
let server_response =
message_builder(MessageType::ServerReturn, 0, 0, 0, 0, Vec::new());
Command::Register => acc_create(sql_conn, socket, &client_msg).await,
Command::LoginMethod1 => acc_auth(sql_conn, socket, &client_msg).await,
_ => {
Message::new()
.command(Command::Failure)
.data("Could not handle an unknown command!")
.send(socket)
.await
}
};
socket
.write_all(bincode::serialize(&server_response).unwrap().as_slice())
.await
}
}
}
_ if client_msg.instruction == CommandInst::Register as i64 => {
acc_create(sql_conn, socket, &client_msg).await
}
_ if client_msg.instruction == CommandInst::LoginMethod1 as i64 => {
acc_auth(sql_conn, socket, &client_msg).await
}
_ if client_msg.instruction == DataTransferInst::GetUserPortfolio as i64 => {
acc_retrieve_portfolio(socket, &client_msg).await
}
_ if client_msg.instruction == DataTransferInst::GetUserTransactionHist as i64 => {
acc_retrieve_transaction(sql_conn, socket, &client_msg).await
}
_ => Ok(()),
match cmd_client_write_result {
Ok(_) => Ok(()),
Err(e) => Err(format!(
"Failed running command, writing to socket failed. Error: {}",
e
)),
}
}

View File

@ -1,5 +1,4 @@
use crate::common::misc::return_flags::ReturnFlags;
use crate::common::sessions::jwt_claim::JWTClaim;
use crate::common::jwt;
use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Header, Validation};
pub static JWT_SECRET: &'static str = "seecreet";
@ -19,11 +18,11 @@ pub static JWT_SECRET: &'static str = "seecreet";
/// ```rust
/// let token = create_jwt_token(auth_user_id, unix_expiry_epoch).unwrap();
/// ```
pub fn create_jwt_token(user_id: i64, exp: u64) -> Result<String, ReturnFlags> {
pub fn create_jwt_token(user_id: i64, exp: u64) -> Result<String, String> {
let mut header = Header::default();
header.alg = Algorithm::HS512;
let claim = JWTClaim {
let claim = jwt::Claim {
user_id: user_id,
exp: exp,
};
@ -33,7 +32,7 @@ pub fn create_jwt_token(user_id: i64, exp: u64) -> Result<String, ReturnFlags> {
&EncodingKey::from_secret(JWT_SECRET.as_bytes()),
) {
Ok(token) => Ok(token),
Err(_) => Err(ReturnFlags::ServerCreateJwtTokenFailed),
Err(_) => Err("Failed to create JWT token.".to_string()),
}
}
@ -50,55 +49,15 @@ pub fn create_jwt_token(user_id: i64, exp: u64) -> Result<String, ReturnFlags> {
/// ```rust
/// assert_eq!(verify_jwt_token(token).unwrap(), true);
/// ```
pub fn verify_jwt_token(token: String) -> Result<JWTClaim, ()> {
pub fn verify_jwt_token(token: String) -> Result<jwt::Claim, String> {
let mut validation = Validation::new(Algorithm::HS512);
validation.leeway = 25;
match decode::<JWTClaim>(
match decode::<jwt::Claim>(
&token,
&DecodingKey::from_secret(JWT_SECRET.as_bytes()),
&validation,
) {
Ok(data) => Ok(data.claims),
Err(_) => Err(()),
}
}
#[cfg(test)]
mod test {
use super::*;
#[test]
fn test_create_jwt_token() {
use std::time::{Duration, SystemTime, UNIX_EPOCH};
let start = SystemTime::now() + Duration::from_secs(4 * 60 * 60);
match create_jwt_token(1i64, start.duration_since(UNIX_EPOCH).unwrap().as_secs()) {
Ok(token) => {
let claims = verify_jwt_token(token).unwrap();
assert_eq!(claims.user_id, 1i64);
assert_eq!(
claims.exp,
start.duration_since(UNIX_EPOCH).unwrap().as_secs()
);
}
Err(_) => panic!("TEST_CREATE_JWT_TOKEN_FAILED"),
}
}
#[test]
fn test_verify_jwt_token() {
use std::time::{Duration, SystemTime, UNIX_EPOCH};
let start = SystemTime::now() + Duration::from_secs(4 * 60 * 60);
let token =
create_jwt_token(1i64, start.duration_since(UNIX_EPOCH).unwrap().as_secs()).unwrap();
match verify_jwt_token(token) {
Ok(claims) => {
assert_eq!(claims.user_id, 1i64);
assert_eq!(
claims.exp,
start.duration_since(UNIX_EPOCH).unwrap().as_secs()
);
}
Err(_) => panic!("TEST_VERIFY_JWT_TOKEN_FAILED"),
}
Err(_) => Err("Invalid JWT token.".to_string()),
}
}

View File

@ -19,6 +19,11 @@
## In progress
- create correct modules
* [ ] fix weird file naming
* [ ] fix namespace naming
* [ ] remove unneeded MessageType
* [ ] make server return coded
## Done
@ -145,3 +150,16 @@
- server move network code to somewhere plausable.
* [x] move assert_msg to message namespace
- investigate server logging doesn't include location of log method caller
- data implmentation stuff
* [ ] implement buy & sell
* [ ] impl on client
* [ ] impl on server
* [ ] assets data retrieval
* [ ] ret data client
* [ ] ret data on server
* [ ] transaction data retrieval
* [ ] split data into multiple writes
- add testing suite
* [ ] configure ci to run those tests
* [ ] add tests for all server functions
* [ ] add tests for all client functions