331 lines
13 KiB
Rust
331 lines
13 KiB
Rust
use crate::*;
|
|
use crossbeam_deque::{Injector, Steal};
|
|
use log::{debug, error, info, warn};
|
|
use mio::{event::Event, net::TcpListener, Events, Interest, Poll, Token};
|
|
use std::io::{Error as IoError, ErrorKind as IoErrorKind, Read as _, Write as _};
|
|
use std::marker::Send;
|
|
use std::net::SocketAddr;
|
|
use std::sync::{mpsc, Arc};
|
|
use std::time::{Duration, Instant};
|
|
|
|
/// Duration that worker threads sleep for when pausing between checks of the
/// request queue, and the minimum wall-clock duration of a single
/// `Server::poll` call (shorter calls are padded with a sleep).
const MINIMUM_POLL_DURATION: Duration = Duration::from_millis(1);
|
|
|
|
// TODO: Implement sleeping when inactive. If no read or write events have
// taken place in the past duration, increase polling time. Make this
// configurable, because a Gemini server will have different requirements to
// the main Doctrine API. Default to no sleeping.
|
|
|
|
/// A server that processes incoming requests and generates responses
|
|
pub struct Server<Req: Request> {
|
|
/// The maximum number of concurrent inbound TCP connections that this
|
|
/// server supports. When the number of concurrent connections equals
|
|
/// this number, any new incoming connections will be dropped immediately
|
|
/// with no response.
|
|
/// TODO: Figure out a more elegant method of preventing overload below
|
|
/// this level. Measure current CPU load, perhaps? Or measure median
|
|
/// request time and throttle based on that?
|
|
/// TODO: Consider checking the size of the connections vector every so
|
|
/// often during quiet moments, and shrinking it if possible. This would
|
|
/// prevent peak surges of traffic from allocating large amounts of RAM
|
|
/// in perpetuity.
|
|
max_connections: usize,
|
|
|
|
/// Hold all current inbound connections in a fixed-size Vec.
|
|
connections: Vec<Option<Connection<Req>>>,
|
|
|
|
poll: Poll,
|
|
listener: TcpListener,
|
|
|
|
/// The ID of the next token to be allocated. All tokens up to this value
|
|
/// have already been allocated. This is also the size of the connections
|
|
/// Vec.
|
|
/// TODO: Could this be removed in favour of counting the length of the
|
|
/// connections Vec?
|
|
next_token_id: usize,
|
|
/// A vector containing all tokens that are ready to be reused. These
|
|
/// tokens have been allocated, used for a past connection, and then freed.
|
|
freed_tokens: Vec<Token>,
|
|
|
|
/// All request processing threads.
|
|
worker_threads: Vec<std::thread::JoinHandle<()>>,
|
|
/// A queue of all the requests that are waiting to be sent to a worker
|
|
/// thread to be processed into responses. Each worker thread has direct
|
|
/// access to this queue.
|
|
request_queue: Arc<Injector<Connection<Req>>>,
|
|
/// The end of the one-way channel that connects all worker threads to
|
|
/// this server.
|
|
response_receiver: mpsc::Receiver<Connection<Req>>,
|
|
}
|
|
|
|
impl<Req> Server<Req>
|
|
where
|
|
Req: Request + 'static + Send,
|
|
Req::Response: Send,
|
|
Req::Parser: Send,
|
|
{
|
|
/// Create a new server and create worker threads to process requests.
|
|
pub fn new(
|
|
server_address: SocketAddr,
|
|
max_connections: usize,
|
|
worker_count: usize,
|
|
request_processor: RequestProcessor<Req>,
|
|
) -> Self {
|
|
let mut listener = match TcpListener::bind(server_address) {
|
|
Ok(listener) => listener,
|
|
Err(error) => handle_server_bind_error(error, server_address),
|
|
};
|
|
info!("Server is listening on address {}", server_address);
|
|
|
|
// Register the server TCP connection with the poll object. The poll
|
|
// object will listen for incoming connections.
|
|
let poll = Poll::new().unwrap();
|
|
poll.registry()
|
|
.register(&mut listener, Token(0), Interest::READABLE)
|
|
.unwrap();
|
|
|
|
// Create a channel to connect worker threads to the main thread, so
|
|
// that responses can be collected and returned to each client.
|
|
let (response_sender, response_receiver) = mpsc::channel();
|
|
|
|
let mut worker_threads = Vec::new();
|
|
let request_queue: Arc<Injector<Connection<Req>>> = Arc::new(Injector::new());
|
|
|
|
// Start a number of worker threads, which will be used to process
|
|
// requests into responses.
|
|
for _ in 0..worker_count {
|
|
let request_queue = request_queue.clone();
|
|
let response_sender = response_sender.clone();
|
|
worker_threads.push(std::thread::spawn(move || loop {
|
|
match request_queue.steal() {
|
|
Steal::Success(mut connection) => {
|
|
let request = match connection.state {
|
|
ConnectionState::Processing(ref request) => request,
|
|
_ => unreachable!(),
|
|
};
|
|
let response = request_processor(request);
|
|
connection.state = ConnectionState::Outgoing(response);
|
|
response_sender.send(connection).unwrap()
|
|
}
|
|
Steal::Empty | Steal::Retry => (),
|
|
}
|
|
// TODO: Instead of sleeping for a fixed duration, keep a
|
|
// record of how busy the server has been for the past while.
|
|
// If the worker threads are mostly idle, sleep for longer.
|
|
// If the worker threads are screaming along, don't sleep.
|
|
std::thread::sleep(MINIMUM_POLL_DURATION);
|
|
}))
|
|
}
|
|
match worker_count {
|
|
1 => info!("{} worker thread has been created", worker_count),
|
|
_ => info!("{} worker threads have been created", worker_count),
|
|
}
|
|
|
|
Self {
|
|
max_connections,
|
|
connections: Vec::new(),
|
|
poll,
|
|
listener,
|
|
|
|
next_token_id: 1,
|
|
freed_tokens: Vec::new(),
|
|
|
|
worker_threads,
|
|
request_queue,
|
|
response_receiver,
|
|
}
|
|
}
|
|
|
|
/// Poll for, and handle, incoming connections.
|
|
pub fn poll(&mut self) {
|
|
let poll_start = Instant::now();
|
|
let mut events = Events::with_capacity(1024);
|
|
const TIMEOUT: Option<Duration> = Some(Duration::from_millis(1));
|
|
self.poll.poll(&mut events, TIMEOUT).unwrap();
|
|
for event in &events {
|
|
if event.is_readable() {
|
|
if event.token() == Token(0) {
|
|
self.accept_new_connections();
|
|
} else {
|
|
self.process_read_event(event);
|
|
}
|
|
} else if event.is_writable() {
|
|
self.process_write_event(event);
|
|
} else {
|
|
warn!("Received unreadable and unwritable event")
|
|
}
|
|
}
|
|
while let Ok(connection) = self.response_receiver.try_recv() {
|
|
self.set_outgoing_connection(connection)
|
|
}
|
|
let elapsed = poll_start.elapsed();
|
|
if elapsed < MINIMUM_POLL_DURATION {
|
|
std::thread::sleep(MINIMUM_POLL_DURATION - elapsed);
|
|
}
|
|
}
|
|
|
|
fn set_outgoing_connection(&mut self, connection: Connection<Req>) {
|
|
let slot = self.connections.get_mut(connection.token.0).unwrap();
|
|
*slot = Some(connection);
|
|
if let Some(ref mut connection) = slot {
|
|
self.poll
|
|
.registry()
|
|
.reregister(&mut connection.stream, connection.token, Interest::WRITABLE)
|
|
.unwrap();
|
|
} else {
|
|
unreachable!()
|
|
};
|
|
}
|
|
|
|
fn process_read_event(&mut self, event: &Event) {
|
|
let token = event.token();
|
|
let slot_index = token.0;
|
|
debug!("Read event for token {}", token.0);
|
|
let connection = match self.connections[slot_index].as_mut() {
|
|
Some(connection) => connection,
|
|
None => return,
|
|
};
|
|
if let ConnectionState::Incoming(ref mut parser) = connection.state {
|
|
loop {
|
|
let mut buffer = [0_u8; 1024];
|
|
match connection.stream.read(&mut buffer) {
|
|
Ok(0) => {
|
|
self.remove_connection(token);
|
|
return;
|
|
}
|
|
Ok(len) => parser.push_bytes(&buffer[..len]),
|
|
Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => break,
|
|
Err(e) => error!("Unexpected error: {}", e),
|
|
};
|
|
}
|
|
|
|
match parser.try_parse() {
|
|
RequestParseResult::Complete(request) => {
|
|
let mut connection =
|
|
std::mem::replace(&mut self.connections[slot_index], None).unwrap();
|
|
connection.state = ConnectionState::Processing(request);
|
|
self.request_queue.push(connection);
|
|
}
|
|
RequestParseResult::Invalid(response) => {
|
|
connection.state = ConnectionState::Outgoing(response);
|
|
let connection =
|
|
std::mem::replace(&mut self.connections[slot_index], None).unwrap();
|
|
self.set_outgoing_connection(connection);
|
|
}
|
|
RequestParseResult::Incomplete => (),
|
|
};
|
|
} else {
|
|
warn!("Received read event for non-incoming connection")
|
|
}
|
|
}
|
|
|
|
fn process_write_event(&mut self, event: &Event) {
|
|
let token = event.token();
|
|
let mut connection = std::mem::replace(&mut self.connections[token.0], None).unwrap();
|
|
if let ConnectionState::Outgoing(response) = connection.state {
|
|
let bytes = response.to_bytes();
|
|
connection.stream.write_all(&bytes).unwrap();
|
|
} else {
|
|
warn!("Received write event for non-outgoing connection")
|
|
}
|
|
self.remove_connection(connection.token);
|
|
info!(
|
|
"Closed connection from {} (token {})",
|
|
connection.client_address, connection.token.0
|
|
);
|
|
}
|
|
|
|
/// Accept all pending incoming connections.
|
|
fn accept_new_connections(&mut self) {
|
|
loop {
|
|
match self.listener.accept() {
|
|
Ok((stream, address)) => {
|
|
// Get an unused token
|
|
let token = match self.get_unused_token() {
|
|
Some(token) => token,
|
|
None => {
|
|
warn!("Capacity reached, dropping connection from {}", address);
|
|
continue;
|
|
}
|
|
};
|
|
|
|
// Allocate sufficient capacity in the connections Vec
|
|
if self.connections.len() <= token.0 {
|
|
let difference = token.0 - self.connections.len() + 1;
|
|
(0..difference).for_each(|_| self.connections.push(None));
|
|
}
|
|
|
|
// Create a connection object and register it as Readable.
|
|
// The dance is required because I can't move the connection
|
|
// once it's been registered to the poll object.
|
|
let slot = self.connections.get_mut(token.0).unwrap();
|
|
*slot = Some(Connection::new(stream, address, token));
|
|
if let Some(ref mut connection) = slot {
|
|
self.poll
|
|
.registry()
|
|
.register(&mut connection.stream, token, Interest::READABLE)
|
|
.unwrap();
|
|
} else {
|
|
unreachable!()
|
|
};
|
|
info!("Accepted connection from {} (token {})", address, token.0);
|
|
}
|
|
Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => break,
|
|
Err(e) => error!("Unexpected error while accepting a connection: {}", e),
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Returns an unused token if one exists, else None.
|
|
fn get_unused_token(&mut self) -> Option<Token> {
|
|
let freed_token = self.freed_tokens.pop();
|
|
if freed_token.is_some() {
|
|
return freed_token;
|
|
};
|
|
|
|
if self.next_token_id < self.max_connections + 1 {
|
|
// The +1 is because connection 0 is the server, so we need one more
|
|
let token_id = self.next_token_id;
|
|
self.next_token_id += 1;
|
|
Some(Token(token_id))
|
|
} else {
|
|
None
|
|
}
|
|
}
|
|
|
|
/// Drop a connection and reclaim the token associated with that connection.
|
|
fn remove_connection(&mut self, token: Token) {
|
|
if let Some(slot) = self.connections.get_mut(token.0) {
|
|
*slot = None;
|
|
self.freed_tokens.push(token);
|
|
} else {
|
|
warn!("Attempted to remove non-existent connection {}", token.0);
|
|
};
|
|
}
|
|
|
|
/// TODO: Methods to accept or refuse incoming connections, to be used for
|
|
/// gently shedding load before killing a server.
|
|
pub fn accept_incoming_connections(&mut self) {}
|
|
/// TODO: Stop accepting incoming connections, to be used for gently
/// shedding load before killing a server. Not yet implemented: this method
/// currently does nothing.
pub fn refuse_incoming_connections(&mut self) {
    // Intentionally empty until load shedding is implemented.
}
|
|
/// Return the number of current connections.
|
|
pub fn num_current_connections(&self) {}
|
|
|
|
pub fn run(&mut self) -> ! {
|
|
loop {
|
|
self.poll()
|
|
}
|
|
}
|
|
}
|
|
|
|
#[rustfmt::skip]
|
|
fn handle_server_bind_error(error: IoError, server_address: SocketAddr) -> ! {
|
|
let port = server_address.port();
|
|
match error.kind() {
|
|
IoErrorKind::PermissionDenied => match port < 1024 {
|
|
true => error!("Could not bind the server to privileged port {} without admin permissions", port),
|
|
false => error!("Could not bind server to port {} due to insufficient permissions", port),
|
|
},
|
|
_ => error!("Could not bind server to port {}", port),
|
|
};
|
|
std::process::exit(1);
|
|
}
|