wetstring works right now, uses fn(&Request) -> Response for server function
This commit is contained in:
parent
1e83969ebd
commit
ab92b5a802
|
@ -1 +1,2 @@
|
|||
/target
|
||||
Cargo.lock
|
|
@ -2,26 +2,6 @@
|
|||
# It is not intended for manual editing.
|
||||
version = 3
|
||||
|
||||
[[package]]
|
||||
name = "aho-corasick"
|
||||
version = "0.7.18"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1e37cfd5e7657ada45f742d6e99ca5788580b5c529dc78faf11ece6dc702656f"
|
||||
dependencies = [
|
||||
"memchr",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "atty"
|
||||
version = "0.2.14"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8"
|
||||
dependencies = [
|
||||
"hermit-abi",
|
||||
"libc",
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "autocfg"
|
||||
version = "1.0.1"
|
||||
|
@ -68,34 +48,6 @@ dependencies = [
|
|||
"lazy_static",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "env_logger"
|
||||
version = "0.8.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a19187fea3ac7e84da7dacf48de0c45d63c6a76f9490dae389aead16c243fce3"
|
||||
dependencies = [
|
||||
"atty",
|
||||
"humantime",
|
||||
"log",
|
||||
"regex",
|
||||
"termcolor",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hermit-abi"
|
||||
version = "0.1.19"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "62b467343b94ba476dcb2500d242dadbb39557df889310ac77c5d99100aaac33"
|
||||
dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "humantime"
|
||||
version = "2.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4"
|
||||
|
||||
[[package]]
|
||||
name = "lazy_static"
|
||||
version = "1.4.0"
|
||||
|
@ -117,12 +69,6 @@ dependencies = [
|
|||
"cfg-if",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "memchr"
|
||||
version = "2.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b16bd47d9e329435e309c58469fe0791c2d0d1ba96ec0954152a5ae2b04387dc"
|
||||
|
||||
[[package]]
|
||||
name = "memoffset"
|
||||
version = "0.6.4"
|
||||
|
@ -163,23 +109,6 @@ dependencies = [
|
|||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "regex"
|
||||
version = "1.5.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d07a8629359eb56f1e2fb1652bb04212c072a87ba68546a04065d525673ac461"
|
||||
dependencies = [
|
||||
"aho-corasick",
|
||||
"memchr",
|
||||
"regex-syntax",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "regex-syntax"
|
||||
version = "0.6.25"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f497285884f3fcff424ffc933e56d7cbca511def0c9831a7f9b5f6153e3cc89b"
|
||||
|
||||
[[package]]
|
||||
name = "scopeguard"
|
||||
version = "1.1.0"
|
||||
|
@ -187,24 +116,14 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||
checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd"
|
||||
|
||||
[[package]]
|
||||
name = "tcp_server"
|
||||
version = "0.1.0"
|
||||
name = "wetstring"
|
||||
version = "1.0.0"
|
||||
dependencies = [
|
||||
"crossbeam-deque",
|
||||
"env_logger",
|
||||
"log",
|
||||
"mio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "termcolor"
|
||||
version = "1.1.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2dfed899f0eb03f32ee8c6a0aabdb8a7949659e3466561fc0adf54e26d88c5f4"
|
||||
dependencies = [
|
||||
"winapi-util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "winapi"
|
||||
version = "0.3.9"
|
||||
|
@ -221,15 +140,6 @@ version = "0.4.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
|
||||
|
||||
[[package]]
|
||||
name = "winapi-util"
|
||||
version = "0.1.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "70ec6ce85bb158151cae5e5c87f95a8e97d2c0c4b001223f33a334e3ce5de178"
|
||||
dependencies = [
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "winapi-x86_64-pc-windows-gnu"
|
||||
version = "0.4.0"
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
[package]
|
||||
name = "tcp_server"
|
||||
version = "0.1.0"
|
||||
name = "wetstring"
|
||||
version = "1.0.0"
|
||||
authors = ["Ben Bridle <bridle.benjamin@gmail.com>"]
|
||||
edition = "2018"
|
||||
|
||||
|
@ -10,4 +10,3 @@ edition = "2018"
|
|||
mio = {version = "0.7.13", features = ["os-poll", "net"]}
|
||||
crossbeam-deque = "0.8.1"
|
||||
log = "0.4.14"
|
||||
env_logger = "0.8.4"
|
||||
|
|
|
@ -1,27 +1,30 @@
|
|||
use crate::{Request, Response};
|
||||
use crate::{Request, RequestParser};
|
||||
use mio::{net::TcpStream, Token};
|
||||
use std::net::SocketAddr;
|
||||
|
||||
pub struct Connection<Req: Request, Res: Response> {
|
||||
pub struct Connection<Req: Request> {
|
||||
pub stream: TcpStream,
|
||||
pub state: RequestState<Req, Res>,
|
||||
pub address: SocketAddr,
|
||||
pub state: ConnectionState<Req>,
|
||||
pub client_address: SocketAddr,
|
||||
pub token: Token,
|
||||
}
|
||||
|
||||
impl<Req: Request, Res: Response> Connection<Req, Res> {
|
||||
pub fn new(stream: TcpStream, address: SocketAddr, token: Token) -> Self {
|
||||
impl<Req: Request> Connection<Req> {
|
||||
pub fn new(stream: TcpStream, client_address: SocketAddr, token: Token) -> Self {
|
||||
Self {
|
||||
stream,
|
||||
address,
|
||||
client_address,
|
||||
token,
|
||||
state: RequestState::Incoming(Req::new()),
|
||||
state: ConnectionState::Incoming(Req::Parser::new(client_address)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub enum RequestState<Req: Request, Res: Response> {
|
||||
Incoming(Req),
|
||||
pub enum ConnectionState<Req: Request> {
|
||||
/// The request has not yet been fully received.
|
||||
Incoming(Req::Parser),
|
||||
/// The request is valid and needs to be processed.
|
||||
Processing(Req),
|
||||
Outgoing(Res),
|
||||
/// A response has been generated for the request and is waiting to be sent.
|
||||
Outgoing(Req::Response),
|
||||
}
|
||||
|
|
15
src/lib.rs
15
src/lib.rs
|
@ -1,8 +1,11 @@
|
|||
mod tcp_server;
|
||||
pub use tcp_server::TcpServer;
|
||||
|
||||
mod connection;
|
||||
pub(crate) use connection::{Connection, RequestState};
|
||||
|
||||
mod request_response;
|
||||
pub use request_response::{ProcessRequest, Request, RequestParseResult, Response};
|
||||
mod server;
|
||||
mod server_error;
|
||||
|
||||
pub(crate) use connection::{Connection, ConnectionState};
|
||||
pub use request_response::{
|
||||
Request, RequestParseResult, RequestParser, RequestProcessor, Response,
|
||||
};
|
||||
pub use server::Server;
|
||||
pub use server_error::ServerError;
|
||||
|
|
|
@ -1,22 +1,45 @@
|
|||
pub trait Request {
|
||||
type Response;
|
||||
use std::net::SocketAddr;
|
||||
|
||||
/// A trait for an object that can store the state of a partially parsed request.
|
||||
/// The incoming byte stream can be parsed incrementally, and the request can be converted
|
||||
/// to an error response if the incoming bytes are invalid.
|
||||
pub trait RequestParser<Req: Request> {
|
||||
fn new(client_address: SocketAddr) -> Self;
|
||||
|
||||
fn new() -> Self;
|
||||
fn push_bytes(&mut self, bytes: &[u8]);
|
||||
fn parse(&mut self) -> RequestParseResult<Self, Self::Response>
|
||||
where
|
||||
Self: Sized,
|
||||
Self::Response: Response;
|
||||
|
||||
fn try_parse(&mut self) -> RequestParseResult<Req>;
|
||||
}
|
||||
|
||||
/// A trait for objects that represent a valid network request.
|
||||
pub trait Request: Sized {
|
||||
type Response: Response;
|
||||
type Parser: RequestParser<Self>;
|
||||
|
||||
fn process(&self, request_processor: RequestProcessor<Self>) -> Self::Response {
|
||||
request_processor(&self)
|
||||
}
|
||||
}
|
||||
|
||||
/// A trait for objects that represent the response to a network request.
|
||||
pub trait Response {
|
||||
fn to_bytes(self) -> Vec<u8>;
|
||||
}
|
||||
|
||||
pub type ProcessRequest<Req, Res> = fn(request: &Req) -> Res;
|
||||
|
||||
pub enum RequestParseResult<Req: Request, Res: Response> {
|
||||
pub enum RequestParseResult<Req: Request> {
|
||||
/// The request is waiting for the client to send more data before the
|
||||
/// request can either be discarded as erroneous or parsed.
|
||||
Incomplete,
|
||||
/// Sufficient data has been received from the client, and the request
|
||||
/// has been successfully parsed.
|
||||
Complete(Req),
|
||||
Invalid(Res),
|
||||
/// An error has been encountered in the received data, and a response
|
||||
/// has been generated to be returned to the client.
|
||||
Invalid(Req::Response),
|
||||
}
|
||||
|
||||
/// A function that converts a Request into a Response. This one function will
|
||||
/// contain all of the logic for a server.
|
||||
pub type RequestProcessor<Req> = fn(request: &Req) -> <Req as Request>::Response;
|
||||
|
||||
// TODO: pub type InvalidRequestProcessor<Req, Res> = fn(invalid_request: &Req) -> Res;
|
||||
|
|
|
@ -0,0 +1,330 @@
|
|||
use crate::*;
|
||||
use crossbeam_deque::{Injector, Steal};
|
||||
use log::{debug, error, info, warn};
|
||||
use mio::{event::Event, net::TcpListener, Events, Interest, Poll, Token};
|
||||
use std::io::{Error as IoError, ErrorKind as IoErrorKind, Read as _, Write as _};
|
||||
use std::marker::Send;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::{mpsc, Arc};
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
const MINIMUM_POLL_DURATION: Duration = Duration::from_millis(1);
|
||||
|
||||
// TODO: Implement sleeping when inactive. If no read or write events have
|
||||
// taken place in the past duration, increase polling time. Make this
|
||||
// configurable, because a Gemini server will have different requirements to
|
||||
// the main Doctrine API. Default to no sleeping.
|
||||
|
||||
/// A server that processes incoming requests and generates responses
|
||||
pub struct Server<Req: Request> {
|
||||
/// The maximum number of concurrent inbound TCP connections that this
|
||||
/// server supports. When the number of concurrent connections equals
|
||||
/// this number, any new incoming connections will be dropped immediately
|
||||
/// with no response.
|
||||
/// TODO: Figure out a more elegant method of preventing overload below
|
||||
/// this level. Measure current CPU load, perhaps? Or measure median
|
||||
/// request time and throttle based on that?
|
||||
/// TODO: Consider checking the size of the connections vector every so
|
||||
/// often during quiet moments, and shrinking it if possible. This would
|
||||
/// prevent peak surges of traffic from allocating large amounts of RAM
|
||||
/// in perpetuity.
|
||||
max_connections: usize,
|
||||
|
||||
/// Hold all current inbound connections in a fixed-size Vec.
|
||||
connections: Vec<Option<Connection<Req>>>,
|
||||
|
||||
poll: Poll,
|
||||
listener: TcpListener,
|
||||
|
||||
/// The ID of the next token to be allocated. All tokens up to this value
|
||||
/// have already been allocated. This is also the size of the connections
|
||||
/// Vec.
|
||||
/// TODO: Could this be removed in favour of counting the length of the
|
||||
/// connections Vec?
|
||||
next_token_id: usize,
|
||||
/// A vector containing all tokens that are ready to be reused. These
|
||||
/// tokens have been allocated, used for a past connection, and then freed.
|
||||
freed_tokens: Vec<Token>,
|
||||
|
||||
/// All request processing threads.
|
||||
worker_threads: Vec<std::thread::JoinHandle<()>>,
|
||||
/// A queue of all the requests that are waiting to be sent to a worker
|
||||
/// thread to be processed into responses. Each worker thread has direct
|
||||
/// access to this queue.
|
||||
request_queue: Arc<Injector<Connection<Req>>>,
|
||||
/// The end of the one-way channel that connects all worker threads to
|
||||
/// this server.
|
||||
response_receiver: mpsc::Receiver<Connection<Req>>,
|
||||
}
|
||||
|
||||
impl<Req> Server<Req>
|
||||
where
|
||||
Req: Request + 'static + Send,
|
||||
Req::Response: Send,
|
||||
Req::Parser: Send,
|
||||
{
|
||||
/// Create a new server and create worker threads to process requests.
|
||||
pub fn new(
|
||||
server_address: SocketAddr,
|
||||
max_connections: usize,
|
||||
worker_count: usize,
|
||||
request_processor: RequestProcessor<Req>,
|
||||
) -> Self {
|
||||
let mut listener = match TcpListener::bind(server_address) {
|
||||
Ok(listener) => listener,
|
||||
Err(error) => handle_server_bind_error(error, server_address),
|
||||
};
|
||||
info!("Server is listening on address {}", server_address);
|
||||
|
||||
// Register the server TCP connection with the poll object. The poll
|
||||
// object will listen for incoming connections.
|
||||
let poll = Poll::new().unwrap();
|
||||
poll.registry()
|
||||
.register(&mut listener, Token(0), Interest::READABLE)
|
||||
.unwrap();
|
||||
|
||||
// Create a channel to connect worker threads to the main thread, so
|
||||
// that responses can be collected and returned to each client.
|
||||
let (response_sender, response_receiver) = mpsc::channel();
|
||||
|
||||
let mut worker_threads = Vec::new();
|
||||
let request_queue: Arc<Injector<Connection<Req>>> = Arc::new(Injector::new());
|
||||
|
||||
// Start a number of worker threads, which will be used to process
|
||||
// requests into responses.
|
||||
for _ in 0..worker_count {
|
||||
let request_queue = request_queue.clone();
|
||||
let response_sender = response_sender.clone();
|
||||
worker_threads.push(std::thread::spawn(move || loop {
|
||||
match request_queue.steal() {
|
||||
Steal::Success(mut connection) => {
|
||||
let request = match connection.state {
|
||||
ConnectionState::Processing(ref request) => request,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
let response = request_processor(request);
|
||||
connection.state = ConnectionState::Outgoing(response);
|
||||
response_sender.send(connection).unwrap()
|
||||
}
|
||||
Steal::Empty | Steal::Retry => (),
|
||||
}
|
||||
// TODO: Instead of sleeping for a fixed duration, keep a
|
||||
// record of how busy the server has been for the past while.
|
||||
// If the worker threads are mostly idle, sleep for longer.
|
||||
// If the worker threads are screaming along, don't sleep.
|
||||
std::thread::sleep(MINIMUM_POLL_DURATION);
|
||||
}))
|
||||
}
|
||||
match worker_count {
|
||||
1 => info!("{} worker thread has been created", worker_count),
|
||||
_ => info!("{} worker threads have been created", worker_count),
|
||||
}
|
||||
|
||||
Self {
|
||||
max_connections,
|
||||
connections: Vec::new(),
|
||||
poll,
|
||||
listener,
|
||||
|
||||
next_token_id: 1,
|
||||
freed_tokens: Vec::new(),
|
||||
|
||||
worker_threads,
|
||||
request_queue,
|
||||
response_receiver,
|
||||
}
|
||||
}
|
||||
|
||||
/// Poll for, and handle, incoming connections.
|
||||
pub fn poll(&mut self) {
|
||||
let poll_start = Instant::now();
|
||||
let mut events = Events::with_capacity(1024);
|
||||
const TIMEOUT: Option<Duration> = Some(Duration::from_millis(1));
|
||||
self.poll.poll(&mut events, TIMEOUT).unwrap();
|
||||
for event in &events {
|
||||
if event.is_readable() {
|
||||
if event.token() == Token(0) {
|
||||
self.accept_new_connections();
|
||||
} else {
|
||||
self.process_read_event(event);
|
||||
}
|
||||
} else if event.is_writable() {
|
||||
self.process_write_event(event);
|
||||
} else {
|
||||
warn!("Received unreadable and unwritable event")
|
||||
}
|
||||
}
|
||||
while let Ok(connection) = self.response_receiver.try_recv() {
|
||||
self.set_outgoing_connection(connection)
|
||||
}
|
||||
let elapsed = poll_start.elapsed();
|
||||
if elapsed < MINIMUM_POLL_DURATION {
|
||||
std::thread::sleep(MINIMUM_POLL_DURATION - elapsed);
|
||||
}
|
||||
}
|
||||
|
||||
fn set_outgoing_connection(&mut self, connection: Connection<Req>) {
|
||||
let slot = self.connections.get_mut(connection.token.0).unwrap();
|
||||
*slot = Some(connection);
|
||||
if let Some(ref mut connection) = slot {
|
||||
self.poll
|
||||
.registry()
|
||||
.reregister(&mut connection.stream, connection.token, Interest::WRITABLE)
|
||||
.unwrap();
|
||||
} else {
|
||||
unreachable!()
|
||||
};
|
||||
}
|
||||
|
||||
fn process_read_event(&mut self, event: &Event) {
|
||||
let token = event.token();
|
||||
let slot_index = token.0;
|
||||
debug!("Read event for token {}", token.0);
|
||||
let connection = match self.connections[slot_index].as_mut() {
|
||||
Some(connection) => connection,
|
||||
None => return,
|
||||
};
|
||||
if let ConnectionState::Incoming(ref mut parser) = connection.state {
|
||||
loop {
|
||||
let mut buffer = [0_u8; 1024];
|
||||
match connection.stream.read(&mut buffer) {
|
||||
Ok(0) => {
|
||||
self.remove_connection(token);
|
||||
return;
|
||||
}
|
||||
Ok(len) => parser.push_bytes(&buffer[..len]),
|
||||
Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => break,
|
||||
Err(e) => error!("Unexpected error: {}", e),
|
||||
};
|
||||
}
|
||||
|
||||
match parser.try_parse() {
|
||||
RequestParseResult::Complete(request) => {
|
||||
let mut connection =
|
||||
std::mem::replace(&mut self.connections[slot_index], None).unwrap();
|
||||
connection.state = ConnectionState::Processing(request);
|
||||
self.request_queue.push(connection);
|
||||
}
|
||||
RequestParseResult::Invalid(response) => {
|
||||
connection.state = ConnectionState::Outgoing(response);
|
||||
let connection =
|
||||
std::mem::replace(&mut self.connections[slot_index], None).unwrap();
|
||||
self.set_outgoing_connection(connection);
|
||||
}
|
||||
RequestParseResult::Incomplete => (),
|
||||
};
|
||||
} else {
|
||||
warn!("Received read event for non-incoming connection")
|
||||
}
|
||||
}
|
||||
|
||||
fn process_write_event(&mut self, event: &Event) {
|
||||
let token = event.token();
|
||||
let mut connection = std::mem::replace(&mut self.connections[token.0], None).unwrap();
|
||||
if let ConnectionState::Outgoing(response) = connection.state {
|
||||
let bytes = response.to_bytes();
|
||||
connection.stream.write_all(&bytes).unwrap();
|
||||
} else {
|
||||
warn!("Received write event for non-outgoing connection")
|
||||
}
|
||||
self.remove_connection(connection.token);
|
||||
info!(
|
||||
"Closed connection from {} (token {})",
|
||||
connection.client_address, connection.token.0
|
||||
);
|
||||
}
|
||||
|
||||
/// Accept all pending incoming connections.
|
||||
fn accept_new_connections(&mut self) {
|
||||
loop {
|
||||
match self.listener.accept() {
|
||||
Ok((stream, address)) => {
|
||||
// Get an unused token
|
||||
let token = match self.get_unused_token() {
|
||||
Some(token) => token,
|
||||
None => {
|
||||
warn!("Capacity reached, dropping connection from {}", address);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
// Allocate sufficient capacity in the connections Vec
|
||||
if self.connections.len() <= token.0 {
|
||||
let difference = token.0 - self.connections.len() + 1;
|
||||
(0..difference).for_each(|_| self.connections.push(None));
|
||||
}
|
||||
|
||||
// Create a connection object and register it as Readable.
|
||||
// The dance is required because I can't move the connection
|
||||
// once it's been registered to the poll object.
|
||||
let slot = self.connections.get_mut(token.0).unwrap();
|
||||
*slot = Some(Connection::new(stream, address, token));
|
||||
if let Some(ref mut connection) = slot {
|
||||
self.poll
|
||||
.registry()
|
||||
.register(&mut connection.stream, token, Interest::READABLE)
|
||||
.unwrap();
|
||||
} else {
|
||||
unreachable!()
|
||||
};
|
||||
info!("Accepted connection from {} (token {})", address, token.0);
|
||||
}
|
||||
Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => break,
|
||||
Err(e) => error!("Unexpected error while accepting a connection: {}", e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns an unused token if one exists, else None.
|
||||
fn get_unused_token(&mut self) -> Option<Token> {
|
||||
let freed_token = self.freed_tokens.pop();
|
||||
if freed_token.is_some() {
|
||||
return freed_token;
|
||||
};
|
||||
|
||||
if self.next_token_id < self.max_connections + 1 {
|
||||
// The +1 is because connection 0 is the server, so we need one more
|
||||
let token_id = self.next_token_id;
|
||||
self.next_token_id += 1;
|
||||
Some(Token(token_id))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Drop a connection and reclaim the token associated with that connection.
|
||||
fn remove_connection(&mut self, token: Token) {
|
||||
if let Some(slot) = self.connections.get_mut(token.0) {
|
||||
*slot = None;
|
||||
self.freed_tokens.push(token);
|
||||
} else {
|
||||
warn!("Attempted to remove non-existent connection {}", token.0);
|
||||
};
|
||||
}
|
||||
|
||||
/// TODO: Methods to accept or refuse incoming connections, to be used for
|
||||
/// gently shedding load before killing a server.
|
||||
pub fn accept_incoming_connections(&mut self) {}
|
||||
pub fn refuse_incoming_connections(&mut self) {}
|
||||
/// Return the number of current connections.
|
||||
pub fn num_current_connections(&self) {}
|
||||
|
||||
pub fn run(&mut self) -> ! {
|
||||
loop {
|
||||
self.poll()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[rustfmt::skip]
|
||||
fn handle_server_bind_error(error: IoError, server_address: SocketAddr) -> ! {
|
||||
let port = server_address.port();
|
||||
match error.kind() {
|
||||
IoErrorKind::PermissionDenied => match port < 1024 {
|
||||
true => error!("Could not bind the server to privileged port {} without admin permissions", port),
|
||||
false => error!("Could not bind server to port {} due to insufficient permissions", port),
|
||||
},
|
||||
_ => error!("Could not bind server to port {}", port),
|
||||
};
|
||||
std::process::exit(1);
|
||||
}
|
|
@ -0,0 +1,18 @@
|
|||
#[derive(Debug)]
|
||||
pub enum ServerError {
|
||||
PermissionDeniedForPrivilegedPort,
|
||||
PermissionDenied,
|
||||
Unknown,
|
||||
}
|
||||
|
||||
impl ServerError {
|
||||
pub fn description(&self) -> &str {
|
||||
match self {
|
||||
Self::PermissionDeniedForPrivilegedPort => {
|
||||
"Could not start server on a privileged port due to insufficient permissions"
|
||||
}
|
||||
Self::PermissionDenied => "Action failed due to insufficient permissions",
|
||||
Self::Unknown => "An unknown error has occurred",
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,240 +0,0 @@
|
|||
use crate::*;
|
||||
use crossbeam_deque::{Injector, Steal};
|
||||
use log::{error, info};
|
||||
use mio::{event::Event, net::TcpListener, Events, Interest, Poll, Token};
|
||||
use std::io::{Read as _, Write as _};
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::{mpsc, Arc};
|
||||
use std::thread;
|
||||
use std::time::Duration;
|
||||
|
||||
const MINIMUM_POLL_DURATION: Duration = Duration::from_millis(1);
|
||||
|
||||
// TODO: Implement sleeping when inactive. If no read or write events have
|
||||
// taken place in the past duration, increase polling time. Make this
|
||||
// configurable, because a Gemini server will have different requirements to
|
||||
// the main Doctrine API. Default to no sleeping.
|
||||
pub struct TcpServer<Req: Request, Res: Response> {
|
||||
max_connections: usize,
|
||||
connections: Vec<Option<Connection<Req, Res>>>,
|
||||
poll: Poll,
|
||||
listener: TcpListener,
|
||||
|
||||
next_token_value: usize,
|
||||
freed_tokens: Vec<Token>,
|
||||
|
||||
worker_threads: Vec<thread::JoinHandle<()>>,
|
||||
request_queue: Arc<Injector<Connection<Req, Res>>>,
|
||||
response_receiver: mpsc::Receiver<Connection<Req, Res>>,
|
||||
}
|
||||
|
||||
impl<
|
||||
Req: 'static + Request + std::marker::Send + request_response::Request<Response = Res>,
|
||||
Res: 'static + Response + std::marker::Send,
|
||||
> TcpServer<Req, Res>
|
||||
{
|
||||
pub fn new(
|
||||
address: SocketAddr,
|
||||
max_connections: usize,
|
||||
worker_count: usize,
|
||||
process_request: ProcessRequest<Req, Res>,
|
||||
) -> Self {
|
||||
let mut listener = TcpListener::bind(address).unwrap();
|
||||
info!("Server is listening at address {}", address);
|
||||
let poll = Poll::new().unwrap();
|
||||
poll.registry()
|
||||
.register(&mut listener, Token(0), Interest::READABLE)
|
||||
.unwrap();
|
||||
|
||||
let (response_sender, response_receiver) = mpsc::channel();
|
||||
|
||||
let mut new_server = Self {
|
||||
max_connections,
|
||||
connections: Vec::new(),
|
||||
poll,
|
||||
listener,
|
||||
|
||||
next_token_value: 1,
|
||||
freed_tokens: Vec::new(),
|
||||
|
||||
worker_threads: Vec::new(),
|
||||
request_queue: Arc::new(Injector::new()),
|
||||
response_receiver,
|
||||
};
|
||||
|
||||
// Start the worker threads
|
||||
for _ in 0..worker_count {
|
||||
let request_queue = new_server.request_queue.clone();
|
||||
let response_sender = response_sender.clone();
|
||||
new_server.worker_threads.push(thread::spawn(move || loop {
|
||||
match request_queue.steal() {
|
||||
Steal::Success(mut connection) => {
|
||||
let request = match connection.state {
|
||||
RequestState::Processing(ref request) => request,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
let response = process_request(request);
|
||||
connection.state = RequestState::Outgoing(response);
|
||||
response_sender.send(connection).unwrap()
|
||||
}
|
||||
Steal::Empty => (),
|
||||
Steal::Retry => (),
|
||||
}
|
||||
std::thread::sleep(MINIMUM_POLL_DURATION);
|
||||
}))
|
||||
}
|
||||
match worker_count {
|
||||
1 => info!("{} worker thread has been created", worker_count),
|
||||
_ => info!("{} worker threads have been created", worker_count),
|
||||
}
|
||||
return new_server;
|
||||
}
|
||||
|
||||
pub fn poll(&mut self) {
|
||||
let poll_start = std::time::Instant::now();
|
||||
let mut events = Events::with_capacity(1024);
|
||||
const TIMEOUT: Option<Duration> = Some(Duration::from_millis(1));
|
||||
self.poll.poll(&mut events, TIMEOUT).unwrap();
|
||||
for event in &events {
|
||||
if event.is_readable() {
|
||||
if event.token() == Token(0) {
|
||||
self.accept_new_connections();
|
||||
} else {
|
||||
self.process_read_event(event);
|
||||
}
|
||||
} else if event.is_writable() {
|
||||
self.process_write_event(event);
|
||||
} else {
|
||||
info!("Received unreadable and unwritable event")
|
||||
}
|
||||
}
|
||||
loop {
|
||||
match self.response_receiver.try_recv() {
|
||||
Ok(connection) => self.set_outgoing_connection(connection),
|
||||
Err(_) => break,
|
||||
}
|
||||
}
|
||||
let elapsed = poll_start.elapsed();
|
||||
if elapsed < MINIMUM_POLL_DURATION {
|
||||
std::thread::sleep(MINIMUM_POLL_DURATION - elapsed);
|
||||
}
|
||||
}
|
||||
|
||||
fn set_outgoing_connection(&mut self, connection: Connection<Req, Res>) {
|
||||
let slot = self.connections.get_mut(connection.token.0).unwrap();
|
||||
*slot = Some(connection);
|
||||
if let Some(ref mut connection) = slot {
|
||||
self.poll
|
||||
.registry()
|
||||
.reregister(&mut connection.stream, connection.token, Interest::WRITABLE)
|
||||
.unwrap();
|
||||
} else {
|
||||
unreachable!()
|
||||
};
|
||||
}
|
||||
|
||||
fn process_read_event(&mut self, event: &Event) {
|
||||
let token = event.token();
|
||||
let connection = self.connections[token.0].as_mut().unwrap();
|
||||
if let RequestState::Incoming(ref mut req) = connection.state {
|
||||
loop {
|
||||
let mut buffer = [0 as u8; 1024];
|
||||
match connection.stream.read(&mut buffer) {
|
||||
Ok(0) => {
|
||||
self.remove_connection(event.token()).unwrap();
|
||||
return;
|
||||
}
|
||||
Ok(len) => req.push_bytes(&buffer[..len]),
|
||||
Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => break,
|
||||
Err(e) => panic!("Unexpected error: {}", e),
|
||||
};
|
||||
}
|
||||
|
||||
match req.parse() {
|
||||
RequestParseResult::Complete(request) => {
|
||||
let mut connection =
|
||||
std::mem::replace(&mut self.connections[token.0], None).unwrap();
|
||||
connection.state = RequestState::Processing(request);
|
||||
self.request_queue.push(connection);
|
||||
}
|
||||
RequestParseResult::Invalid(response) => {
|
||||
connection.state = RequestState::Outgoing(response);
|
||||
let connection =
|
||||
std::mem::replace(&mut self.connections[token.0], None).unwrap();
|
||||
self.set_outgoing_connection(connection);
|
||||
}
|
||||
RequestParseResult::Incomplete => (),
|
||||
};
|
||||
} else {
|
||||
info!("Received read event for non-incoming connection")
|
||||
}
|
||||
}
|
||||
|
||||
fn process_write_event(&mut self, event: &Event) {
|
||||
let token = event.token();
|
||||
let mut connection = std::mem::replace(&mut self.connections[token.0], None).unwrap();
|
||||
if let RequestState::Outgoing(response) = connection.state {
|
||||
let bytes = response.to_bytes();
|
||||
connection.stream.write_all(&bytes).unwrap();
|
||||
} else {
|
||||
info!("Received write event for non-outgoing connection")
|
||||
}
|
||||
self.remove_connection(connection.token).unwrap();
|
||||
info!(
|
||||
"Closed connection from {} (token {})",
|
||||
connection.address, connection.token.0
|
||||
);
|
||||
}
|
||||
|
||||
fn accept_new_connections(&mut self) {
|
||||
loop {
|
||||
match self.listener.accept() {
|
||||
Ok((stream, address)) => {
|
||||
// Get an unused token
|
||||
let token = if let Some(token) = self.freed_tokens.pop() {
|
||||
token
|
||||
} else if self.next_token_value < self.max_connections {
|
||||
let token_value = self.next_token_value;
|
||||
self.next_token_value += 1;
|
||||
Token(token_value)
|
||||
} else {
|
||||
error!("Capacity reached, dropping connection from {}", address);
|
||||
continue;
|
||||
};
|
||||
|
||||
// Initialise the connection vec up to this point
|
||||
if self.connections.len() <= token.0 {
|
||||
let difference = token.0 - self.connections.len() + 1;
|
||||
(0..difference).for_each(|_| self.connections.push(None));
|
||||
}
|
||||
|
||||
// Create the connection object and register it as Readable
|
||||
let slot = self.connections.get_mut(token.0).unwrap();
|
||||
*slot = Some(Connection::new(stream, address, token));
|
||||
if let Some(ref mut connection) = slot {
|
||||
self.poll
|
||||
.registry()
|
||||
.register(&mut connection.stream, token, Interest::READABLE)
|
||||
.unwrap();
|
||||
} else {
|
||||
unreachable!()
|
||||
};
|
||||
|
||||
info!(
|
||||
"Accepted incoming connection from {} (token {})",
|
||||
address, token.0
|
||||
);
|
||||
}
|
||||
Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => break,
|
||||
Err(e) => panic!("Unexpected error while accepting a connection: {}", e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn remove_connection(&mut self, token: Token) -> Result<(), ()> {
|
||||
let slot = self.connections.get_mut(token.0).ok_or(())?;
|
||||
*slot = None;
|
||||
self.freed_tokens.push(token);
|
||||
Ok(())
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue