wetstring/src/server.rs

252 lines
9.7 KiB
Rust

use crate::*;
use log::{debug, error, info, warn};
use mio::{event::Event, net::TcpListener, Events, Interest, Poll, Token};
use std::io::{Error as IoError, ErrorKind as IoErrorKind, Read as _, Write as _};
use std::net::SocketAddr;
use std::time::Duration;
/// A server that accepts TCP connections, parses incoming requests with a
/// `Parser`, and answers them with responses generated by a `Proc`.
pub struct Server<Parser: RequestParser, Proc: RequestProcessor> {
/// Connection slots indexed by token value; `None` marks a free slot.
/// Slot 0 is never occupied because token 0 is reserved for the listener.
/// The Vec is not fixed-size: it is grown on demand whenever a newly
/// allocated token lies past its end, so it never exceeds
/// `max_connections + 1` entries.
connections: Vec<Option<Connection<Parser>>>,
/// Maximum number of connection tokens that may be allocated. Once every
/// token is in use, further incoming connections are dropped.
max_connections: usize,
/// The mio poll instance that reports readiness for the listener and for
/// every registered connection.
poll: Poll,
/// The listening socket, registered with `poll` under token 0.
listener: TcpListener,
/// The ID of the next never-before-used token. Every non-zero token below
/// this value has already been allocated at least once, and the connections
/// Vec has been extended to cover it.
next_token_id: usize,
/// Tokens that were allocated, used for a past connection, and then freed.
/// These are handed out again before any new token ID is allocated.
freed_tokens: Vec<Token>,
/// Generates a response for each successfully parsed request.
request_processor: Proc,
}
impl<Parser, Proc> Server<Parser, Proc>
where
Parser: RequestParser,
Proc: RequestProcessor<Req = Parser::Req, Res = Parser::Res>,
{
/// Create a new server and create worker threads to process requests.
pub fn new(server_address: SocketAddr, request_processor: Proc) -> Self {
let mut listener = match TcpListener::bind(server_address) {
Ok(listener) => listener,
Err(error) => handle_server_bind_error(error, server_address),
};
info!("Server is listening on address {}", server_address);
// Register the server TCP connection with the poll object. The poll
// object will listen for incoming connections.
let poll = Poll::new().unwrap();
poll.registry()
.register(&mut listener, Token(0), Interest::READABLE)
.unwrap();
Self {
connections: Vec::new(),
max_connections: 100,
poll,
listener,
next_token_id: 1,
freed_tokens: Vec::new(),
request_processor,
}
}
/// Handle read and write events on all current connections
pub fn poll(&mut self) {
let mut events = Events::with_capacity(1024);
const TIMEOUT: Option<Duration> = Some(Duration::from_millis(1));
self.poll.poll(&mut events, TIMEOUT).unwrap();
for event in &events {
if event.is_readable() {
if event.token() == Token(0) {
self.accept_new_connections();
} else {
self.process_read_event(event);
}
} else if event.is_writable() {
self.process_write_event(event);
} else {
warn!("Received unreadable and unwritable event")
}
}
const MINIMUM_POLL_DURATION: Duration = Duration::from_millis(1);
std::thread::sleep(MINIMUM_POLL_DURATION);
}
/// Change the polling mode of a connection from readable to writable
fn set_outgoing_connection(&mut self, connection: Connection<Parser>) {
let slot = self.connections.get_mut(connection.token.0).unwrap();
*slot = Some(connection);
if let Some(ref mut connection) = slot {
self.poll
.registry()
.reregister(&mut connection.stream, connection.token, Interest::WRITABLE)
.unwrap();
} else {
unreachable!()
};
}
fn process_read_event(&mut self, event: &Event) {
let token = event.token();
let slot_index = token.0;
debug!("Read event for token {}", token.0);
let connection = match self.connections[slot_index].as_mut() {
Some(connection) => connection,
None => return,
};
if let ConnectionState::Incoming(ref mut parser) = connection.state {
loop {
let mut buffer = [0_u8; 1024];
match connection.stream.read(&mut buffer) {
Ok(0) => {
self.remove_connection(token);
return;
}
Ok(len) => parser.push_bytes(&buffer[..len]),
Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => break,
Err(e) => error!("Unexpected error: {}", e),
};
}
match parser.try_parse() {
RequestParseResult::Complete(request) => {
let mut connection =
std::mem::replace(&mut self.connections[slot_index], None).unwrap();
let response = self.request_processor.process_request(&request);
connection.state = ConnectionState::Outgoing(response);
self.set_outgoing_connection(connection)
}
RequestParseResult::Invalid(response) => {
connection.state = ConnectionState::Outgoing(response);
let connection =
std::mem::replace(&mut self.connections[slot_index], None).unwrap();
self.set_outgoing_connection(connection);
}
RequestParseResult::Incomplete => (),
};
} else {
warn!("Received read event for non-incoming connection")
}
}
fn process_write_event(&mut self, event: &Event) {
let token = event.token();
let mut connection = std::mem::replace(&mut self.connections[token.0], None).unwrap();
if let ConnectionState::Outgoing(response) = connection.state {
let bytes = response.to_bytes();
connection.stream.write_all(&bytes).unwrap();
} else {
warn!("Received write event for non-outgoing connection")
}
self.remove_connection(connection.token);
info!(
"Closed connection from {} (token {})",
connection.client_address, connection.token.0
);
}
/// Accept all pending incoming connections.
fn accept_new_connections(&mut self) {
loop {
match self.listener.accept() {
Ok((stream, address)) => {
// Get an unused token
let token = match self.get_unused_token() {
Some(token) => token,
None => {
warn!("Capacity reached, dropping connection from {}", address);
continue;
}
};
// Allocate sufficient capacity in the connections Vec
if self.connections.len() <= token.0 {
let difference = token.0 - self.connections.len() + 1;
(0..difference).for_each(|_| self.connections.push(None));
}
// Create a connection object and register it as Readable.
// The dance is required because I can't move the connection
// once it's been registered to the poll object.
let slot = self.connections.get_mut(token.0).unwrap();
*slot = Some(Connection::new(stream, address, token));
if let Some(ref mut connection) = slot {
self.poll
.registry()
.register(&mut connection.stream, token, Interest::READABLE)
.unwrap();
} else {
unreachable!()
};
info!("Accepted connection from {} (token {})", address, token.0);
}
Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => break,
Err(e) => error!("Unexpected error while accepting a connection: {}", e),
}
}
}
/// Returns an unused token if one exists, else None.
fn get_unused_token(&mut self) -> Option<Token> {
let freed_token = self.freed_tokens.pop();
if freed_token.is_some() {
return freed_token;
};
if self.next_token_id < self.max_connections + 1 {
// The +1 is because connection 0 is the server, so we need one more
let token_id = self.next_token_id;
self.next_token_id += 1;
Some(Token(token_id))
} else {
None
}
}
/// Drop a connection and reclaim the token associated with that connection.
fn remove_connection(&mut self, token: Token) {
if let Some(slot) = self.connections.get_mut(token.0) {
*slot = None;
self.freed_tokens.push(token);
} else {
warn!("Attempted to remove non-existent connection {}", token.0);
};
}
/// TODO: Methods to accept or refuse incoming connections, to be used for
/// gently shedding load before killing a server.
pub fn accept_incoming_connections(&mut self) {}
pub fn refuse_incoming_connections(&mut self) {}
/// Return the number of current connections.
pub fn num_current_connections(&self) {}
pub fn run(&mut self) -> ! {
loop {
self.poll()
}
}
}
#[rustfmt::skip]
fn handle_server_bind_error(error: IoError, server_address: SocketAddr) -> ! {
let port = server_address.port();
match error.kind() {
IoErrorKind::PermissionDenied => match port < 1024 {
true => error!("Could not bind the server to privileged port {} without admin permissions", port),
false => error!("Could not bind server to port {} due to insufficient permissions", port),
},
_ => error!("Could not bind server to port {}", port),
};
std::process::exit(1);
}