use crate::*; use log::{debug, error, info, warn}; use mio::{event::Event, net::TcpListener, Events, Interest, Poll, Token}; use std::io::{Error as IoError, ErrorKind as IoErrorKind, Read as _, Write as _}; use std::net::SocketAddr; use std::time::Duration; /// A server that processes incoming requests and generates responses pub struct Server { /// Hold all current inbound connections in a fixed-size Vec. connections: Vec>>, max_connections: usize, poll: Poll, listener: TcpListener, /// The ID of the next token to be allocated. All tokens up to this value /// have already been allocated. This is also the size of the connections /// Vec. next_token_id: usize, /// A vector containing all tokens that are ready to be reused. These /// tokens have been allocated, used for a past connection, and then freed. freed_tokens: Vec, request_processor: Proc, } impl Server where Parser: RequestParser, Proc: RequestProcessor, { /// Create a new server and create worker threads to process requests. pub fn new(server_address: SocketAddr, request_processor: Proc) -> Self { let mut listener = match TcpListener::bind(server_address) { Ok(listener) => listener, Err(error) => handle_server_bind_error(error, server_address), }; info!("Server is listening on address {}", server_address); // Register the server TCP connection with the poll object. The poll // object will listen for incoming connections. let poll = Poll::new().unwrap(); poll.registry() .register(&mut listener, Token(0), Interest::READABLE) .unwrap(); Self { connections: Vec::new(), max_connections: 100, poll, listener, next_token_id: 1, freed_tokens: Vec::new(), request_processor, } } /// Handle read and write events on all current connections pub fn poll(&mut self) { let mut events = Events::with_capacity(1024); const TIMEOUT: Option = Some(Duration::from_millis(1)); self.poll.poll(&mut events, TIMEOUT).unwrap(); for event in &events { if event.is_readable() { if event.token() == Token(0) { self.accept_new_connections(); } else { self.process_read_event(event); } } else if event.is_writable() { self.process_write_event(event); } else { warn!("Received unreadable and unwritable event") } } const MINIMUM_POLL_DURATION: Duration = Duration::from_millis(1); std::thread::sleep(MINIMUM_POLL_DURATION); } /// Change the polling mode of a connection from readable to writable fn set_outgoing_connection(&mut self, connection: Connection) { let slot = self.connections.get_mut(connection.token.0).unwrap(); *slot = Some(connection); if let Some(ref mut connection) = slot { self.poll .registry() .reregister(&mut connection.stream, connection.token, Interest::WRITABLE) .unwrap(); } else { unreachable!() }; } fn process_read_event(&mut self, event: &Event) { let token = event.token(); let slot_index = token.0; debug!("Read event for token {}", token.0); let connection = match self.connections[slot_index].as_mut() { Some(connection) => connection, None => return, }; if let ConnectionState::Incoming(ref mut parser) = connection.state { loop { let mut buffer = [0_u8; 1024]; match connection.stream.read(&mut buffer) { Ok(0) => { self.remove_connection(token); return; } Ok(len) => parser.push_bytes(&buffer[..len]), Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => break, Err(e) => error!("Unexpected error: {}", e), }; } match parser.try_parse() { RequestParseResult::Complete(request) => { let mut connection = std::mem::replace(&mut self.connections[slot_index], None).unwrap(); let response = self.request_processor.process_request(&request); connection.state = ConnectionState::Outgoing(response); self.set_outgoing_connection(connection) } RequestParseResult::Invalid(response) => { connection.state = ConnectionState::Outgoing(response); let connection = std::mem::replace(&mut self.connections[slot_index], None).unwrap(); self.set_outgoing_connection(connection); } RequestParseResult::Incomplete => (), }; } else { warn!("Received read event for non-incoming connection") } } fn process_write_event(&mut self, event: &Event) { let token = event.token(); let mut connection = std::mem::replace(&mut self.connections[token.0], None).unwrap(); if let ConnectionState::Outgoing(response) = connection.state { let bytes = response.to_bytes(); connection.stream.write_all(&bytes).unwrap(); } else { warn!("Received write event for non-outgoing connection") } self.remove_connection(connection.token); info!( "Closed connection from {} (token {})", connection.client_address, connection.token.0 ); } /// Accept all pending incoming connections. fn accept_new_connections(&mut self) { loop { match self.listener.accept() { Ok((stream, address)) => { // Get an unused token let token = match self.get_unused_token() { Some(token) => token, None => { warn!("Capacity reached, dropping connection from {}", address); continue; } }; // Allocate sufficient capacity in the connections Vec if self.connections.len() <= token.0 { let difference = token.0 - self.connections.len() + 1; (0..difference).for_each(|_| self.connections.push(None)); } // Create a connection object and register it as Readable. // The dance is required because I can't move the connection // once it's been registered to the poll object. let slot = self.connections.get_mut(token.0).unwrap(); *slot = Some(Connection::new(stream, address, token)); if let Some(ref mut connection) = slot { self.poll .registry() .register(&mut connection.stream, token, Interest::READABLE) .unwrap(); } else { unreachable!() }; info!("Accepted connection from {} (token {})", address, token.0); } Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => break, Err(e) => error!("Unexpected error while accepting a connection: {}", e), } } } /// Returns an unused token if one exists, else None. fn get_unused_token(&mut self) -> Option { let freed_token = self.freed_tokens.pop(); if freed_token.is_some() { return freed_token; }; if self.next_token_id < self.max_connections + 1 { // The +1 is because connection 0 is the server, so we need one more let token_id = self.next_token_id; self.next_token_id += 1; Some(Token(token_id)) } else { None } } /// Drop a connection and reclaim the token associated with that connection. fn remove_connection(&mut self, token: Token) { if let Some(slot) = self.connections.get_mut(token.0) { *slot = None; self.freed_tokens.push(token); } else { warn!("Attempted to remove non-existent connection {}", token.0); }; } /// TODO: Methods to accept or refuse incoming connections, to be used for /// gently shedding load before killing a server. pub fn accept_incoming_connections(&mut self) {} pub fn refuse_incoming_connections(&mut self) {} /// Return the number of current connections. pub fn num_current_connections(&self) {} pub fn run(&mut self) -> ! { loop { self.poll() } } } #[rustfmt::skip] fn handle_server_bind_error(error: IoError, server_address: SocketAddr) -> ! { let port = server_address.port(); match error.kind() { IoErrorKind::PermissionDenied => match port < 1024 { true => error!("Could not bind the server to privileged port {} without admin permissions", port), false => error!("Could not bind server to port {} due to insufficient permissions", port), }, _ => error!("Could not bind server to port {}", port), }; std::process::exit(1); }