wetstring/src/tcp_server.rs

241 lines
9.2 KiB
Rust

use crate::*;
use crossbeam_deque::{Injector, Steal};
use log::{error, info};
use mio::{event::Event, net::TcpListener, Events, Interest, Poll, Token};
use std::io::{Read as _, Write as _};
use std::net::SocketAddr;
use std::sync::{mpsc, Arc};
use std::thread;
use std::time::Duration;
/// Shortest time a single call to [`TcpServer::poll`] takes, and how long a worker thread
/// sleeps when it finds the request queue empty.
const MINIMUM_POLL_DURATION: Duration = Duration::from_micros(1_000);
// TODO: Implement sleeping when inactive. If no read or write events have
// taken place within a recent period, increase the polling time. Make this
// configurable, because a Gemini server will have different requirements from
// the main Doctrine API. Default to no sleeping.
/// A TCP server driven by a mio [`Poll`] event loop.
///
/// The thread that calls `poll` accepts connections, reads and parses requests and writes
/// responses. Fully parsed requests are handed to a pool of worker threads, which compute the
/// responses and send the connections back over a channel.
pub struct TcpServer<Req: Request, Res: Response> {
    // Limit on token values handed out to new connections; compared against
    // `next_token_value` whenever a never-used token is needed.
    max_connections: usize,
    // Connection slots indexed by token value. A slot is `None` when no connection uses the
    // token, and also while its connection is away in the worker queue. Slot 0 (the
    // listener's token) is never filled.
    connections: Vec<Option<Connection<Req, Res>>>,
    // Event loop; the listener is registered under `Token(0)`, every connection under its
    // own token.
    poll: Poll,
    // Listening socket that new connections are accepted from.
    listener: TcpListener,
    // Smallest token value never handed out yet; starts at 1 because `Token(0)` is the
    // listener's.
    next_token_value: usize,
    // Tokens released by closed connections; reused before a new value is taken from
    // `next_token_value`.
    freed_tokens: Vec<Token>,
    // Handles of the worker threads spawned in `new`.
    // NOTE(review): these are never joined in the visible code.
    worker_threads: Vec<thread::JoinHandle<()>>,
    // Connections holding a fully parsed request, waiting for a worker to process them.
    request_queue: Arc<Injector<Connection<Req, Res>>>,
    // Receives connections back from the workers, each now carrying its response.
    response_receiver: mpsc::Receiver<Connection<Req, Res>>,
}
impl<
Req: 'static + Request + std::marker::Send + request_response::Request<Response = Res>,
Res: 'static + Response + std::marker::Send,
> TcpServer<Req, Res>
{
pub fn new(
address: SocketAddr,
max_connections: usize,
worker_count: usize,
process_request: ProcessRequest<Req, Res>,
) -> Self {
let mut listener = TcpListener::bind(address).unwrap();
info!("Server is listening at address {}", address);
let poll = Poll::new().unwrap();
poll.registry()
.register(&mut listener, Token(0), Interest::READABLE)
.unwrap();
let (response_sender, response_receiver) = mpsc::channel();
let mut new_server = Self {
max_connections,
connections: Vec::new(),
poll,
listener,
next_token_value: 1,
freed_tokens: Vec::new(),
worker_threads: Vec::new(),
request_queue: Arc::new(Injector::new()),
response_receiver,
};
// Start the worker threads
for _ in 0..worker_count {
let request_queue = new_server.request_queue.clone();
let response_sender = response_sender.clone();
new_server.worker_threads.push(thread::spawn(move || loop {
match request_queue.steal() {
Steal::Success(mut connection) => {
let request = match connection.state {
RequestState::Processing(ref request) => request,
_ => unreachable!(),
};
let response = process_request(request);
connection.state = RequestState::Outgoing(response);
response_sender.send(connection).unwrap()
}
Steal::Empty => (),
Steal::Retry => (),
}
std::thread::sleep(MINIMUM_POLL_DURATION);
}))
}
match worker_count {
1 => info!("{} worker thread has been created", worker_count),
_ => info!("{} worker threads have been created", worker_count),
}
return new_server;
}
pub fn poll(&mut self) {
let poll_start = std::time::Instant::now();
let mut events = Events::with_capacity(1024);
const TIMEOUT: Option<Duration> = Some(Duration::from_millis(1));
self.poll.poll(&mut events, TIMEOUT).unwrap();
for event in &events {
if event.is_readable() {
if event.token() == Token(0) {
self.accept_new_connections();
} else {
self.process_read_event(event);
}
} else if event.is_writable() {
self.process_write_event(event);
} else {
info!("Received unreadable and unwritable event")
}
}
loop {
match self.response_receiver.try_recv() {
Ok(connection) => self.set_outgoing_connection(connection),
Err(_) => break,
}
}
let elapsed = poll_start.elapsed();
if elapsed < MINIMUM_POLL_DURATION {
std::thread::sleep(MINIMUM_POLL_DURATION - elapsed);
}
}
fn set_outgoing_connection(&mut self, connection: Connection<Req, Res>) {
let slot = self.connections.get_mut(connection.token.0).unwrap();
*slot = Some(connection);
if let Some(ref mut connection) = slot {
self.poll
.registry()
.reregister(&mut connection.stream, connection.token, Interest::WRITABLE)
.unwrap();
} else {
unreachable!()
};
}
fn process_read_event(&mut self, event: &Event) {
let token = event.token();
let connection = self.connections[token.0].as_mut().unwrap();
if let RequestState::Incoming(ref mut req) = connection.state {
loop {
let mut buffer = [0 as u8; 1024];
match connection.stream.read(&mut buffer) {
Ok(0) => {
self.remove_connection(event.token()).unwrap();
return;
}
Ok(len) => req.push_bytes(&buffer[..len]),
Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => break,
Err(e) => panic!("Unexpected error: {}", e),
};
}
match req.parse() {
RequestParseResult::Complete(request) => {
let mut connection =
std::mem::replace(&mut self.connections[token.0], None).unwrap();
connection.state = RequestState::Processing(request);
self.request_queue.push(connection);
}
RequestParseResult::Invalid(response) => {
connection.state = RequestState::Outgoing(response);
let connection =
std::mem::replace(&mut self.connections[token.0], None).unwrap();
self.set_outgoing_connection(connection);
}
RequestParseResult::Incomplete => (),
};
} else {
info!("Received read event for non-incoming connection")
}
}
fn process_write_event(&mut self, event: &Event) {
let token = event.token();
let mut connection = std::mem::replace(&mut self.connections[token.0], None).unwrap();
if let RequestState::Outgoing(response) = connection.state {
let bytes = response.to_bytes();
connection.stream.write_all(&bytes).unwrap();
} else {
info!("Received write event for non-outgoing connection")
}
self.remove_connection(connection.token).unwrap();
info!(
"Closed connection from {} (token {})",
connection.address, connection.token.0
);
}
fn accept_new_connections(&mut self) {
loop {
match self.listener.accept() {
Ok((stream, address)) => {
// Get an unused token
let token = if let Some(token) = self.freed_tokens.pop() {
token
} else if self.next_token_value < self.max_connections {
let token_value = self.next_token_value;
self.next_token_value += 1;
Token(token_value)
} else {
error!("Capacity reached, dropping connection from {}", address);
continue;
};
// Initialise the connection vec up to this point
if self.connections.len() <= token.0 {
let difference = token.0 - self.connections.len() + 1;
(0..difference).for_each(|_| self.connections.push(None));
}
// Create the connection object and register it as Readable
let slot = self.connections.get_mut(token.0).unwrap();
*slot = Some(Connection::new(stream, address, token));
if let Some(ref mut connection) = slot {
self.poll
.registry()
.register(&mut connection.stream, token, Interest::READABLE)
.unwrap();
} else {
unreachable!()
};
info!(
"Accepted incoming connection from {} (token {})",
address, token.0
);
}
Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => break,
Err(e) => panic!("Unexpected error while accepting a connection: {}", e),
}
}
}
fn remove_connection(&mut self, token: Token) -> Result<(), ()> {
let slot = self.connections.get_mut(token.0).ok_or(())?;
*slot = None;
self.freed_tokens.push(token);
Ok(())
}
}