switch to quiche

This commit is contained in:
Ezra Barrow 2023-08-18 11:23:29 -05:00
parent bd70f012a3
commit 117af24977
No known key found for this signature in database
GPG Key ID: 5EF8BA3CE9180419
6 changed files with 763 additions and 156 deletions

51
Cargo.lock generated
View File

@ -100,6 +100,15 @@ version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]]
name = "cmake"
version = "0.1.50"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a31c789563b815f77f4250caee12365734369f942439b7defd71e18a48197130"
dependencies = [
"cc",
]
[[package]]
name = "color-eyre"
version = "0.6.2"
@ -246,6 +255,12 @@ version = "0.2.147"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b4668fb0ea861c1df094127ac5f1da3409a82116a4ba74fca2e58ef927159bb3"
[[package]]
name = "libm"
version = "0.2.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f7012b1bbb0719e1097c47611d3898568c546d597c2e74d66f6087edd5233ff4"
[[package]]
name = "lock_api"
version = "0.4.10"
@ -319,6 +334,12 @@ dependencies = [
"memchr",
]
[[package]]
name = "octets"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3a74f2cda724d43a0a63140af89836d4e7db6138ef67c9f96d3a0f0150d05000"
[[package]]
name = "once_cell"
version = "1.18.0"
@ -391,6 +412,24 @@ dependencies = [
"unicode-ident",
]
[[package]]
name = "quiche"
version = "0.17.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1d9e4fa8718d45fd25dd89c196e128d6c3527b5b1735db47eb4bdb9ba3e4cc1c"
dependencies = [
"cmake",
"lazy_static",
"libc",
"libm",
"log",
"octets",
"ring",
"slab",
"smallvec",
"winapi",
]
[[package]]
name = "quinn"
version = "0.10.2"
@ -681,10 +720,13 @@ dependencies = [
"backoff",
"base64",
"color-eyre",
"quiche",
"quinn",
"rcgen",
"ring",
"rustls",
"serde",
"temp-dir",
"tokio",
"tokio-tun",
"toml",
@ -695,6 +737,9 @@ name = "smallvec"
version = "1.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "62bb4feee49fdd9f707ef802e22365a35de4b7b299de4763d44bfea899442ff9"
dependencies = [
"serde",
]
[[package]]
name = "socket2"
@ -729,6 +774,12 @@ dependencies = [
"unicode-ident",
]
[[package]]
name = "temp-dir"
version = "0.1.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "af547b166dd1ea4b472165569fc456cfb6818116f854690b0ff205e636523dab"
[[package]]
name = "thiserror"
version = "1.0.44"

View File

@ -10,10 +10,13 @@ anyhow = { version = "1.0.72", features = ["backtrace"] }
backoff = { version = "0.4.0", features = ["tokio"] }
base64 = "0.21.2"
color-eyre = "0.6.2"
quiche = "0.17.2"
quinn = { version = "0.10.2", features = [] }
rcgen = "0.11.1"
ring = "0.16.20"
rustls = { version = "0.21.6", features = ["dangerous_configuration", "quic"] }
serde = { version = "1.0.183", features = ["derive"] }
temp-dir = "0.1.11"
tokio = { version = "1.31.0", features = ["full"] }
tokio-tun = "0.9.0"
toml = "0.7.6"

View File

@ -1,9 +1,11 @@
[server]
endpoint = "127.0.0.1:9092"
server_name = "alphamethyl.barr0w.net"
[client]
endpoint = "alphamethyl.barr0w.net:9092"
endpoint = "10.177.1.7:9092"
server_name = "alphamethyl.barr0w.net"
[interface]
address = "192.168.255.1"
netmask = "255.255.255.255"
netmask = "255.255.255.252"

View File

@ -2,30 +2,58 @@ use anyhow::Context;
use serde::{Serialize, Deserialize};
use std::fs;
use std::io::Write;
use std::net::Ipv4Addr;
use std::net::{Ipv4Addr, SocketAddr, ToSocketAddrs};
//TODO: use Arc<str> instead of String for all this stuff that needs to be cloned and last forever
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Server {
pub endpoint: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub server_name: Option<String>,
}
impl Server {
pub fn new() -> Self {
Self {
endpoint: String::from("127.0.0.1:9092"),
server_name: Some(String::from("alphamethyl.barr0w.net")),
}
}
/// Returns the socketaddr our endpoint should bind to
pub fn endpoint(&self) -> anyhow::Result<SocketAddr> {
self.endpoint.to_socket_addrs()?.next().context("bad server socketaddr")
}
/// Returns the server name for SNI
pub fn server_name(&self) -> anyhow::Result<Option<&String>> {
Ok(self.server_name.as_ref())
}
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Client {
pub endpoint: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub server_name: Option<String>,
}
impl Client {
pub fn new() -> Self {
Self {
endpoint: String::from("alphamethyl.barr0w.net:9092"),
server_name: Some(String::from("alphamethyl.barr0w.net")),
}
}
/// Returns the socketaddr that our udp socket should bind to
pub fn local_bind_addr(&self) -> anyhow::Result<SocketAddr> {
Ok("0.0.0.0:0".parse().unwrap())
}
/// Returns the socketaddr of the server's endpoint
pub fn endpoint(&self) -> anyhow::Result<SocketAddr> {
self.endpoint.to_socket_addrs()?.next().context("bad server socketaddr")
}
/// Returns the server name for SNI
pub fn server_name(&self) -> anyhow::Result<Option<&String>> {
Ok(self.server_name.as_ref())
}
}
#[derive(Serialize, Deserialize, Debug, Clone)]

View File

@ -15,19 +15,24 @@
* --------------------
*/
use anyhow::Context;
use std::{io, time::Duration};
use tokio::time::sleep;
use std::{io, sync::Arc, time::Duration};
use tokio::{
io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
sync::mpsc::error::{TryRecvError, SendError},
time::sleep,
};
use tokio_tun::Tun;
mod encoder;
use encoder::Encoder;
mod quic;
use quic::QuicModem;
pub mod config;
use config::Configuration;
// const STREAM_ID: u64 = 0b000;
fn is_timeout<T>(res: &anyhow::Result<T>) -> bool {
let e = match res {
Ok(_) => return false,
@ -66,7 +71,7 @@ macro_rules! handle_timeout {
}
pub async fn client_main() -> anyhow::Result<()> {
let c = Configuration::load_config("config.toml")?;
let c = Arc::new(Configuration::load_config("config.toml")?);
let tun = Tun::builder()
.name("sleepy")
.tap(false)
@ -76,32 +81,122 @@ pub async fn client_main() -> anyhow::Result<()> {
.up()
.try_build()?;
let mut tun = Encoder::new(tun);
let mut quic = QuicModem::new_client(&c.client()?.endpoint)?;
let (mut to_tun, mut from_quic) = tokio::sync::mpsc::unbounded_channel::<Vec<u8>>();
let (mut to_quic, mut from_tun) = tokio::sync::mpsc::channel::<Vec<u8>>(8);
let (mut duplex_tun, mut duplex_quic) = tokio::io::duplex(65535);
let client_loop = tokio::spawn(quic::client_loop(c.clone(), {
let mut currently_sending: Option<(Vec<u8>, usize)> = None;
let mut rx_buffer = [0u8; 65535];
move |conn| {
let from_tun = &mut from_tun;
let to_tun = &mut to_tun;
let currently_sending = &mut currently_sending;
if !conn.is_established() {
return Ok(());
}
for stream_id in conn.readable() {
eprintln!("reading from {stream_id}");
while let Ok((read, fin)) = conn.stream_recv(stream_id, &mut rx_buffer) {
eprintln!("{} recieved {} bytes", conn.trace_id(), read);
let stream_buf = &rx_buffer[..read];
eprintln!(
"{} stream {} has {} bytes (fin? {})",
conn.trace_id(),
stream_id,
stream_buf.len(),
fin
);
if stream_id != 3 {
continue;
}
let vec = Vec::from(stream_buf);
match to_tun.send(vec) {
Ok(()) => {},
Err(SendError(vec)) => {
conn.close(true, 0x0, b"exiting")?;
return Ok(());
}
}
}
}
match conn.stream_writable(2, 1350) {
Ok(false) => {
return Ok(());
},
Err(quiche::Error::InvalidStreamState(_)) => {},
Err(e) => return Err(e.into()),
_ => {}
}
loop {
if let Some((buf, mut pos)) = currently_sending.take() {
eprintln!(" retrying packet send");
match conn.stream_send(2, &buf[pos..], false) {
Ok(written) => pos += written,
Err(quiche::Error::Done) => break,
Err(e) => {
return Err(e).context("stream send failed");
}
}
eprintln!(" sent {pos} bytes out of {}", buf.len());
if pos < buf.len() {
*currently_sending = Some((buf, pos));
}
} else {
let buf = match from_tun.try_recv() {
Ok(v) => v,
Err(TryRecvError::Empty) => break,
Err(TryRecvError::Disconnected) => {
conn.close(true, 0x0, b"exiting")?;
return Ok(());
}
};
eprintln!("sending packet on stream 2");
let written = match conn.stream_send(2, &buf, false) {
Ok(v) => v,
Err(quiche::Error::Done) => 0,
Err(e) => {
return Err(e).context("stream send failed");
}
};
eprintln!(" sent {written} bytes out of {}", buf.len());
if written < buf.len() {
*currently_sending = Some((buf, written));
break;
}
}
}
Ok(())
}
}));
let mut backoff: u64 = 1;
loop {
eprintln!("[client] attempting to connect...");
let mut conn = handle_timeout!(quic.connect().await, {
eprintln!("[client] timed out, waiting {} seconds...", backoff);
sleep(Duration::from_secs(backoff)).await;
backoff *= 2;
continue;
})?;
backoff = 1;
eprintln!("[client] connected: addr={}", conn.remote_address());
handle_timeout!(tokio::io::copy_bidirectional(&mut tun, &mut conn)
.await
.context("io error"))?;
let mut buf = [0u8; 65535];
tokio::select! {
res = client_loop => {
res?
}
res = async {
loop {
tokio::select! {
res = tun.read(&mut buf) => {
eprintln!("sending packet");
let len = res?;
let vec = Vec::from(&buf[..len]);
to_quic.send(vec).await?;
},
Some(buf) = from_quic.recv() => {
eprintln!("recieved packet");
tun.write_all(&buf).await?;
},
}
}
} => res
}
}
pub async fn server_main() -> anyhow::Result<()> {
let c = Configuration::load_config("config.toml")?;
let c = Arc::new(Configuration::load_config("config2.toml")?);
let tun = Tun::builder()
.name("sleepy")
.name("sleepy2")
.tap(false)
.packet_info(true)
.address(c.interface.address)
@ -109,20 +204,112 @@ pub async fn server_main() -> anyhow::Result<()> {
.up()
.try_build()?;
let mut tun = Encoder::new(tun);
let mut quic = QuicModem::new_server(&c.server()?.endpoint)?;
let (mut to_tun, mut from_quic) = tokio::sync::mpsc::unbounded_channel::<Vec<u8>>();
let (mut to_quic, mut from_tun) = tokio::sync::mpsc::channel::<Vec<u8>>(8);
let server_loop = tokio::spawn(quic::server_loop(c.clone(), {
let mut currently_sending: Option<(Vec<u8>, usize)> = None;
let mut rx_buffer = [0u8; 65535];
move |client| {
let from_tun = &mut from_tun;
let to_tun = &mut to_tun;
let mut currently_sending = &mut currently_sending;
if !client.conn.is_established() {
return Ok(());
}
for stream_id in client.conn.readable() {
while let Ok((read, fin)) = client.conn.stream_recv(stream_id, &mut rx_buffer) {
eprintln!("{} recieved {} bytes", client.conn.trace_id(), read);
let stream_buf = &rx_buffer[..read];
eprintln!(
"{} stream {} has {} bytes (fin? {})",
client.conn.trace_id(),
stream_id,
stream_buf.len(),
fin
);
if stream_id != 2 {
continue;
}
let vec = Vec::from(stream_buf);
match to_tun.send(vec) {
Ok(()) => {},
Err(SendError(vec)) => {
client.conn.close(true, 0x0, b"exiting")?;
return Ok(());
}
}
}
}
match client.conn.stream_writable(3, 1350) {
Ok(false) => {
return Ok(());
},
Err(quiche::Error::InvalidStreamState(_)) => {},
Err(e) => return Err(e.into()),
_ => {}
}
loop {
if let Some((buf, mut pos)) = currently_sending.take() {
eprintln!(" retrying packet send");
match client.conn.stream_send(3, &buf[pos..], false) {
Ok(written) => pos += written,
Err(quiche::Error::Done) => break,
Err(e) => {
return Err(e).context("stream send failed");
}
}
eprintln!(" sent {pos} bytes out of {}", buf.len());
if pos < buf.len() {
*currently_sending = Some((buf, pos));
}
} else {
let buf = match from_tun.try_recv() {
Ok(v) => v,
Err(TryRecvError::Empty) => break,
Err(TryRecvError::Disconnected) => {
client.conn.close(true, 0x0, b"exiting")?;
return Ok(());
}
};
eprintln!("sending packet on stream 3");
let written = match client.conn.stream_send(3, &buf, false) {
Ok(v) => v,
Err(quiche::Error::Done) => 0,
Err(e) => {
return Err(e).context("stream send failed");
}
};
eprintln!(" sent {written} bytes out of {}", buf.len());
if written < buf.len() {
*currently_sending = Some((buf, written));
break;
}
}
}
Ok(())
}
}));
loop {
println!("[server] waiting for connection...");
let mut conn = handle_timeout!(quic.accept().await)?;
println!(
"[server] connection accepted: addr={}",
conn.remote_address()
);
handle_timeout!(tokio::io::copy_bidirectional(&mut tun, &mut conn)
.await
.context("io error"))?;
let mut buf = [0u8; 65535];
tokio::select! {
res = server_loop => {
res?
}
res = async {
loop {
tokio::select! {
res = tun.read(&mut buf) => {
eprintln!("sending packet");
let len = res?;
let vec = Vec::from(&buf[..len]);
to_quic.send(vec).await?;
},
Some(buf) = from_quic.recv() => {
eprintln!("recieved packet");
tun.write_all(&buf).await?;
},
}
}
} => res
}
}

View File

@ -14,146 +14,482 @@
* -Ezra Barrow
* --------------------
*/
/*
let sock: net::UdpSocket;
let conn: quiche::Connection;
let clients: Map<ConnId, Connection>;
loop {
'read: loop {
let rx_buffer;
let (len, from) = socket.recv_from(&mut rx_buffer)?;
let header = quiche::Header::from_slice(&mut rx_buffer[..len])?;
let conn_id = generate(header);
let client = if !clients.contains(conn_id) {
let conn = quiche::accept();
// send conn to accept queue?
conn
} else {
clients.get(conn_id)
}
let recv_info = quiche::RecvInfo {
to: socket.local_addr(),
from,
}
conn.recv(&mut rx_buffer[..len]?;
let len = conn.recv(&mut buf[..len], recv_info)?;
}
'write: loop {
let tx_buffer;
let (len, send_info) = conn.send(&mut tx_buffer)?;
socket.send_to(&tx_buffer[..len], send_info.to)?;
}
}
let data_len = conn.stream_send(stream_id, b"data", some_boolean_idk)?;
let mut tx_buffer = [u8; 1350];
let (tx_len, send_info) = conn.send(&mut tx_buffer)?;
sock.send_to(&tx_buffer[..tx_len], &send_info.to)?;
*/
/*
let sock: net::UdpSocket;
let conn: quiche::Connection;
loop {
'read: loop {
let rx_buffer;
let (len, from) = socket.recv_from(&mut rx_buffer)?;
let recv_info = quiche::RecvInfo {
to: socket.local_addr(),
from,
}
let len = conn.recv(&mut buf[..len], recv_info)?;
}
'write: loop {
let tx_buffer;
let (len, send_info) = conn.send(&mut tx_buffer)?;
socket.send_to(&tx_buffer[..len], send_info.to)?;
}
}
let data_len = conn.stream_send(stream_id, b"data", some_boolean_idk)?;
let mut tx_buffer = [u8; 1350];
let (tx_len, send_info) = conn.send(&mut tx_buffer)?;
sock.send_to(&tx_buffer[..tx_len], &send_info.to)?;
*/
use std::{
collections::HashMap,
future::Future,
net::{SocketAddr, ToSocketAddrs},
sync::Arc,
sync::{Arc, Mutex, RwLock},
time::Duration,
};
use anyhow::Context;
use ring::rand::*;
use temp_dir::TempDir;
use tokio::{
io::{AsyncRead, AsyncWrite},
net::{TcpListener, TcpStream},
fs,
io::{AsyncRead, AsyncWrite, Interest},
net::{TcpListener, TcpStream, UdpSocket},
task,
};
struct SkipServerVerification;
impl SkipServerVerification {
fn new() -> Arc<Self> {
Arc::new(Self)
}
}
impl rustls::client::ServerCertVerifier for SkipServerVerification {
fn verify_server_cert(
&self,
end_entity: &rustls::Certificate,
intermediates: &[rustls::Certificate],
server_name: &rustls::ServerName,
scts: &mut dyn Iterator<Item = &[u8]>,
ocsp_response: &[u8],
now: std::time::SystemTime,
) -> Result<rustls::client::ServerCertVerified, rustls::Error> {
Ok(rustls::client::ServerCertVerified::assertion())
}
}
fn configure_server() -> quinn::ServerConfig {
async fn configure_server() -> anyhow::Result<quiche::Config> {
let mut config = quiche::Config::new(quiche::PROTOCOL_VERSION)?;
config.set_application_protos(&[b"sleepytunny"]); //change this to h3 eventually
config.set_max_idle_timeout(5000);
config.set_max_recv_udp_payload_size(1350);
config.set_max_send_udp_payload_size(1350);
config.set_initial_max_streams_bidi(100);
config.set_initial_max_streams_uni(100);
config.set_initial_max_data(10_000_000);
config.set_initial_max_stream_data_bidi_local(1_000_000);
config.set_initial_max_stream_data_bidi_remote(1_000_000);
config.set_initial_max_stream_data_uni(1_000_000);
config.set_disable_active_migration(true);
let mut cert_dir = TempDir::new()?;
let cert = rcgen::generate_simple_self_signed(vec!["localhost".into()]).unwrap();
let cert_der = cert.serialize_der().unwrap();
let priv_key = cert.serialize_private_key_der();
let priv_key = rustls::PrivateKey(priv_key);
let cert_chain = vec![rustls::Certificate(cert_der.clone())];
let mut server_config = quinn::ServerConfig::with_single_cert(cert_chain, priv_key).unwrap();
let transport_config = Arc::get_mut(&mut server_config.transport).unwrap();
transport_config
.max_concurrent_uni_streams(0u8.into())
// .max_idle_timeout(Some(Duration::from_secs(60).try_into().unwrap()))
.keep_alive_interval(Some(Duration::from_secs(3)));
server_config
let cert_pem = cert_dir.child("cert.pem");
let priv_key_pem = cert_dir.child("key.pem");
fs::write(&cert_pem, cert.serialize_pem().unwrap()).await?;
fs::write(&priv_key_pem, cert.serialize_private_key_pem()).await?;
config.load_priv_key_from_pem_file(priv_key_pem.to_str().unwrap())?;
config.load_cert_chain_from_pem_file(cert_pem.to_str().unwrap())?;
config.verify_peer(false);
Ok(config)
}
fn configure_client() -> quinn::ClientConfig {
let crypto = rustls::ClientConfig::builder()
.with_safe_defaults()
.with_custom_certificate_verifier(SkipServerVerification::new())
.with_no_client_auth();
quinn::ClientConfig::new(Arc::new(crypto))
async fn configure_client() -> anyhow::Result<quiche::Config> {
let mut config = quiche::Config::new(quiche::PROTOCOL_VERSION)?;
config.set_application_protos(&[b"sleepytunny"]); //change this to h3 eventually
config.set_max_idle_timeout(5000);
config.set_max_recv_udp_payload_size(1350);
config.set_max_send_udp_payload_size(1350);
config.set_initial_max_streams_bidi(100);
config.set_initial_max_streams_uni(100);
config.set_initial_max_data(10_000_000);
config.set_initial_max_stream_data_bidi_local(1_000_000);
config.set_initial_max_stream_data_bidi_remote(1_000_000);
config.set_initial_max_stream_data_uni(1_000_000);
config.set_disable_active_migration(true);
config.verify_peer(false);
Ok(config)
}
pub struct QuicModem {
endpoint: quinn::Endpoint,
addr: SocketAddr,
fn mint_token(header: &quiche::Header, from: &SocketAddr) -> Vec<u8> {
let mut token = Vec::with_capacity(6 + 16 + header.dcid.len());
token.extend_from_slice(b"quiche");
let addr = match from.ip() {
std::net::IpAddr::V4(a) => a.octets().to_vec(),
std::net::IpAddr::V6(a) => a.octets().to_vec(),
};
token.extend_from_slice(&addr);
token.extend_from_slice(&header.dcid);
token
}
impl QuicModem {
pub fn new_server(addr: impl ToSocketAddrs) -> anyhow::Result<Self> {
let addr = addr
.to_socket_addrs()?
.next()
.expect("bad server socketaddr");
let endpoint = quinn::Endpoint::server(configure_server(), addr)?;
eprintln!("initialized server");
Ok(Self { endpoint, addr })
fn validate_token<'a>(src: &SocketAddr, token: &'a [u8]) -> Option<quiche::ConnectionId<'a>> {
if token.len() < 6 {
return None;
}
pub fn new_client(addr: impl ToSocketAddrs) -> anyhow::Result<Self> {
let addr = addr
.to_socket_addrs()?
.next()
.expect("bad server socketaddr");
let mut endpoint = quinn::Endpoint::client("0.0.0.0:0".parse().unwrap())?;
endpoint.set_default_client_config(configure_client());
eprintln!("initialized client");
Ok(Self { endpoint, addr })
if &token[..6] != b"quiche" {
return None;
}
pub async fn accept(&mut self) -> anyhow::Result<QuicBiStream> {
let conn = self
.endpoint
.accept()
.await
.context("endpoint closed")?
.await?;
Ok(QuicBiStream::new(&conn, conn.open_bi().await?))
}
pub async fn connect(&mut self) -> anyhow::Result<QuicBiStream> {
let conn = self.endpoint.connect(self.addr, "localhost")?.await?;
Ok(QuicBiStream::new(&conn, conn.accept_bi().await?))
let token = &token[6..];
let addr = match src.ip() {
std::net::IpAddr::V4(a) => a.octets().to_vec(),
std::net::IpAddr::V6(a) => a.octets().to_vec(),
};
if token.len() < addr.len() || &token[..addr.len()] != &addr[..] {
return None;
}
Some(quiche::ConnectionId::from_ref(&token[addr.len()..]))
}
pub struct QuicBiStream {
conn: quinn::Connection,
send: quinn::SendStream,
recv: quinn::RecvStream,
struct PartialSend {
body: Vec<u8>,
written: usize,
}
impl QuicBiStream {
fn new(conn: &quinn::Connection, (send, recv): (quinn::SendStream, quinn::RecvStream)) -> Self {
pub struct ConnectedClient {
pub conn: quiche::Connection,
partial_sends: HashMap<u64, PartialSend>,
}
impl ConnectedClient {
fn new(conn: quiche::Connection) -> Self {
Self {
conn: conn.clone(),
send,
recv,
conn,
partial_sends: HashMap::new(),
}
}
pub fn remote_address(&self) -> SocketAddr {
self.conn.remote_address()
pub fn stream_send(&mut self, stream_id: u64, buf: &[u8]) {
if let Some(partial) = self.partial_sends.get_mut(&stream_id) {
partial.body.extend_from_slice(buf);
} else {
let written = match self.conn.stream_send(stream_id, buf, false) {
Ok(v) => v,
Err(quiche::Error::Done) => 0,
Err(e) => {
eprintln!("{} stream send failed {:?}", self.conn.trace_id(), e);
return;
}
};
if written < buf.len() {
let partial = PartialSend {
body: Vec::from(&buf[written..]),
written: 0,
};
self.partial_sends.insert(stream_id, partial);
}
}
}
fn handle_writable(&mut self, stream_id: u64) {
let conn = &mut self.conn;
// eprintln!("{} stream {} is writable", conn.trace_id(), stream_id);
let resp = match self.partial_sends.get_mut(&stream_id) {
Some(r) => r,
None => return,
};
let body = &resp.body[resp.written..];
let written = match conn.stream_send(stream_id, body, false) {
Ok(v) => v,
Err(quiche::Error::Done) => 0,
Err(e) => {
self.partial_sends.remove(&stream_id);
eprintln!("{} stream send failed {:?}", conn.trace_id(), e);
return;
}
};
resp.written += written;
if resp.written == resp.body.len() {
self.partial_sends.remove(&stream_id);
}
}
}
impl AsyncRead for QuicBiStream {
fn poll_read(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> std::task::Poll<std::io::Result<()>> {
unsafe { self.map_unchecked_mut(|s| &mut s.recv).poll_read(cx, buf) }
}
}
impl AsyncWrite for QuicBiStream {
fn poll_write(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> std::task::Poll<Result<usize, std::io::Error>> {
unsafe { self.map_unchecked_mut(|s| &mut s.send).poll_write(cx, buf) }
}
fn poll_flush(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), std::io::Error>> {
unsafe { self.map_unchecked_mut(|s| &mut s.send).poll_flush(cx) }
}
fn poll_shutdown(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), std::io::Error>> {
unsafe { self.map_unchecked_mut(|s| &mut s.send).poll_shutdown(cx) }
pub async fn server_loop(
app_config: Arc<crate::config::Configuration>,
mut handle_incoming: impl FnMut(&mut ConnectedClient) -> anyhow::Result<()>,
) -> anyhow::Result<()> {
let local_addr = app_config.server()?.endpoint()?;
let socket = UdpSocket::bind(&local_addr).await?;
let mut config = configure_server().await?;
let conn_id_seed =
ring::hmac::Key::generate(ring::hmac::HMAC_SHA256, &SystemRandom::new()).unwrap();
let mut client_map: HashMap<quiche::ConnectionId<'static>, ConnectedClient> = HashMap::new();
let mut rx_buffer = [0u8; 65535];
let mut tx_buffer = [0u8; 1350];
loop {
'read: loop {
let (len, from) = match socket.try_recv_from(&mut rx_buffer) {
Ok(v) => v,
Err(e) => {
if e.kind() == std::io::ErrorKind::WouldBlock {
break 'read;
} else {
return Err(e).context("try_recv_from failed");
}
}
};
let pkt_buf = &mut rx_buffer[..len];
let header = match quiche::Header::from_slice(pkt_buf, quiche::MAX_CONN_ID_LEN) {
Ok(v) => v,
Err(e) => {
eprintln!("parsing packet header failed: {e:?}");
continue 'read;
}
};
eprintln!("got packet {header:?}");
let conn_id = ring::hmac::sign(&conn_id_seed, &header.dcid);
let conn_id = &conn_id.as_ref()[..quiche::MAX_CONN_ID_LEN];
let conn_id = conn_id.to_vec().into();
let mut client = if !client_map.contains_key(&header.dcid)
&& !client_map.contains_key(&conn_id)
{
if header.ty != quiche::Type::Initial {
eprintln!("packet is not initial");
continue 'read;
}
if !quiche::version_is_supported(header.version) {
eprintln!("doing version negotiation");
let len = quiche::negotiate_version(&header.scid, &header.dcid, &mut tx_buffer)
.context("making version negotiation packet failed")?;
let out = &tx_buffer[..len];
match socket.try_send_to(out, from) {
Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
eprintln!("version negotiation send_to() would block");
break 'read; // this break wasnt labeled in the example
}
res => res.context("version negotiation send_to failed")?,
};
continue 'read;
}
let mut scid = [0; quiche::MAX_CONN_ID_LEN];
scid.copy_from_slice(&conn_id);
let scid = quiche::ConnectionId::from_ref(&scid);
let token = header
.token
.as_ref()
.expect("token is always present in initial packets");
if token.is_empty() {
eprintln!("client didnt send token, doing stateless retry");
let new_token = mint_token(&header, &from);
let len = quiche::retry(
&header.scid,
&header.dcid,
&scid,
&new_token,
header.version,
&mut tx_buffer,
)
.context("makin retry packet failed")?;
let out = &tx_buffer[..len];
match socket.try_send_to(out, from) {
Ok(l) => {
eprintln!("wrote {l} bytes of the {len} byte retry packet");
continue 'read;
}
Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
eprintln!("retry send_to() would block");
break 'read; // this break wasnt labeled in the example
}
Err(e) => return Err(e).context("version negotiation send_to failed"),
};
} else {
let odcid = validate_token(&from, token);
// The token was not valid, meaning the retry failed,
// so drop the packet
if odcid.is_none() {
eprintln!("Invalid address validation token");
continue 'read;
};
if scid.len() != header.dcid.len() {
eprintln!("Invalid destination connection ID");
continue 'read;
}
// Reuse the source connection id we sent in the retry packet
// instead of changing it again
let scid = header.dcid.clone();
eprintln!("New connection: dcid={:?} scid={:?}", header.dcid, scid);
let conn = quiche::accept(&scid, odcid.as_ref(), local_addr, from, &mut config)
.context("accepting connection failed")?;
client_map.insert(scid.clone(), ConnectedClient::new(conn));
client_map.get_mut(&scid).unwrap()
}
} else {
match client_map.get_mut(&header.dcid) {
Some(v) => v,
None => client_map.get_mut(&conn_id).unwrap(),
}
};
let recv_info = quiche::RecvInfo {
to: local_addr.clone(),
from,
};
let read = match client.conn.recv(pkt_buf, recv_info) {
Ok(v) => v,
Err(e) => {
eprintln!("{} recv failed: {:?}", client.conn.trace_id(), e);
continue 'read;
}
};
eprintln!("{} processed {} bytes", client.conn.trace_id(), read);
}
for client in client_map.values_mut() {
if client.conn.is_in_early_data() || client.conn.is_established() {
for stream_id in client.conn.writable() {
client.handle_writable(stream_id);
}
handle_incoming(client)?;
}
}
for client in client_map.values_mut() {
'write: loop {
let (write, send_info) = match client.conn.send(&mut tx_buffer) {
Ok(v) => v,
Err(quiche::Error::Done) => {
// eprintln!("{} done writing", client.conn.trace_id());
break 'write;
}
Err(e) => {
eprintln!("{} send failed: {:?}", client.conn.trace_id(), e);
client.conn.close(false, 0x1, b"fail").ok();
break 'write;
}
};
match socket.send_to(&tx_buffer[..write], send_info.to).await {
Ok(written) => eprintln!("sent {written} bytes of {write} to {}", send_info.to),
Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
eprintln!("send_to() would block");
break 'write;
}
Err(e) => return Err(e).context("send_to() failed"),
};
}
}
// Garbage collect closed connections
client_map.retain(|_, ref mut client| {
if client.conn.is_closed() {
eprintln!(
"{} connection garbage collected {:?}",
client.conn.trace_id(),
client.conn.stats()
);
}
!client.conn.is_closed()
});
}
}
pub async fn client_loop(
app_config: Arc<crate::config::Configuration>,
mut f: impl FnMut(&mut quiche::Connection) -> anyhow::Result<()>,
) -> anyhow::Result<()> {
let mut qconfig = configure_client().await?;
let socket = UdpSocket::bind(app_config.client()?.local_bind_addr()?).await?;
let mut scid = [0u8; quiche::MAX_CONN_ID_LEN];
SystemRandom::new().fill(&mut scid[..]).unwrap();
let scid = quiche::ConnectionId::from_ref(&scid);
let local_addr = socket.local_addr().unwrap();
let mut rx_buffer = [0u8; 65535];
let mut tx_buffer = [0u8; 1350];
let mut conn = quiche::connect(
app_config.client()?.server_name()?.map(String::as_str),
&scid,
local_addr,
app_config.client()?.endpoint()?,
&mut qconfig,
)?;
// tokio::time::sleep(Duration::from_millis(200)).await;
let (write, send_info) = conn
.send(&mut tx_buffer[..])
.context("initial send failed")?;
let written = socket.send_to(&tx_buffer[..write], send_info.to).await?;
eprintln!("wrote {written} bytes of {write} byte initiation");
loop {
'read: loop {
let (len, from) = match socket.try_recv_from(&mut rx_buffer) {
Ok(v) => v,
Err(e) => {
if e.kind() == std::io::ErrorKind::WouldBlock {
break 'read;
} else {
return Err(e).context("try_recv_from failed");
}
}
};
let recv_info = quiche::RecvInfo {
to: local_addr.clone(),
from,
};
eprintln!("recieving {len} bytes from {from}");
conn.recv(&mut rx_buffer[..len], recv_info)
.context("recv failed")?;
}
if conn.is_closed() {
eprintln!("connection closed");
return Ok(());
}
if conn.is_established() {
f(&mut conn)?;
}
'write: loop {
let (write, send_info) = match conn.send(&mut tx_buffer) {
Ok(v) => v,
Err(quiche::Error::Done) => {
// eprint!("{} done writing\r", conn.trace_id());
break 'write;
}
Err(e) => {
conn.close(false, 0x1, b"fail").ok();
return Err(e).context("send failed");
}
};
eprintln!("sending {write} bytes");
let v = socket
.send_to(&tx_buffer[..write], send_info.to)
.await
.context("send_to failed")?;
eprintln!("sent {v} bytes of {write}");
}
}
}