sleepytunny/src/encoder.rs

122 lines
4.0 KiB
Rust

/*
* --------------------
* THIS FILE IS LICENSED UNDER THE FOLLOWING TERMS
*
* this code may not be used for any purpose. be gay, do crime
*
* THE FOLLOWING MESSAGE IS NOT A LICENSE
*
* <barrow@tilde.team> wrote this file.
* by reading this text, you are reading "TRANS RIGHTS".
* this file and the content within it is the gay agenda.
* if we meet some day, and you think this stuff is worth it,
* you can buy me a beer, tea, or something stronger.
* -Ezra Barrow
* --------------------
*/
use base64::encoded_len;
use base64::engine::general_purpose::STANDARD_NO_PAD as BASE64_ENGINE;
use base64::Engine;
use std::io;
use std::{
pin::Pin,
task::{ready, Poll},
};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
/// Initial sizing hint, in bytes, for the scratch buffers allocated by [`Encoder::new`].
const INITIAL_BUF_SIZE: usize = 1 << 10;

/// Base64 adapter around a bidirectional async stream.
///
/// Reading from an `Encoder` reads raw bytes from the wrapped stream and yields
/// their unpadded standard base64 encoding; writing to it decodes unpadded
/// standard base64 and forwards the raw bytes to the wrapped stream.
pub struct Encoder<T: AsyncRead + AsyncWrite> {
    /// Read half of the wrapped stream.
    inner_read: tokio::io::ReadHalf<T>,
    /// Write half of the wrapped stream.
    inner_write: tokio::io::WriteHalf<T>,
    /// Scratch space for raw bytes read from `inner_read` before they are encoded.
    // NOTE(review): the author suspected that if both peers use the same
    // INITIAL_BUF_SIZE, one of these buffers could be fixed at exactly that
    // size — unverified.
    read_buffer: Vec<u8>,
    /// Scratch space for decoded bytes about to be written to `inner_write`.
    write_buffer: Vec<u8>,
}
impl<T> Encoder<T>
where
    T: AsyncRead + AsyncWrite,
{
    /// Wraps `inner`, splitting it with `tokio::io::split` into independently
    /// pollable read and write halves.
    ///
    /// Both scratch buffers start empty. Their capacities are derived from
    /// `INITIAL_BUF_SIZE` and are only initial hints: the buffers grow on demand.
    pub fn new(inner: T) -> Self {
        let (inner_read, inner_write) = tokio::io::split(inner);
        // 3 raw bytes per 4 base64 characters.
        let read_capacity = INITIAL_BUF_SIZE / 4 * 3;
        // 4 base64 characters per 3 raw bytes, plus slack for a trailing group.
        let write_capacity = INITIAL_BUF_SIZE / 3 * 4 + 4;
        Self {
            read_buffer: Vec::with_capacity(read_capacity),
            write_buffer: Vec::with_capacity(write_capacity),
            inner_read,
            inner_write,
        }
    }
}
impl<T: AsyncRead + AsyncWrite> AsyncRead for Encoder<T> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> std::task::Poll<std::io::Result<()>> {
let remaining_octets = base64::decoded_len_estimate(buf.remaining());
// dbg!(
// buf.filled().len(),
// buf.remaining(),
// buf.remaining().div(4).mul(3),
// remaining_octets
// );
let (inner, mut bytebuf) = unsafe {
let s = self.get_unchecked_mut();
let inner = Pin::new_unchecked(&mut s.inner_read);
let buffer = &mut s.read_buffer;
if remaining_octets > buffer.len() {
buffer.resize(remaining_octets, 0);
}
(inner, ReadBuf::new(&mut buffer[..remaining_octets]))
};
ready!(inner.poll_read(cx, &mut bytebuf))?;
let base64_len = BASE64_ENGINE
.encode_slice(bytebuf.filled(), buf.initialize_unfilled())
.expect("read buffer too small somehow !");
buf.advance(base64_len);
Poll::Ready(Ok(()))
}
}
impl<T: AsyncRead + AsyncWrite> AsyncWrite for Encoder<T> {
    /// Decodes `buf` as unpadded standard base64 and writes the resulting raw
    /// bytes to the inner stream.
    ///
    /// Returns how many base64 characters of `buf` were consumed. If the inner
    /// stream accepts every decoded byte, that is `buf.len()`. If it accepts a
    /// multiple of 3 bytes, it is the exact prefix of `buf` that encodes them, so
    /// the caller can resubmit the remainder.
    ///
    /// # Errors
    ///
    /// * `InvalidData` if `buf` is not valid unpadded base64 on its own. Each
    ///   call is decoded independently, so a write must not split a 4-character
    ///   group.
    /// * `Other` if the inner stream accepted a byte count that does not end on
    ///   a 3-byte boundary. No prefix of `buf` corresponds to such a count.
    /// * Any error returned by the inner stream.
    fn poll_write(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, std::io::Error>> {
        // `Encoder` is `Unpin` (its halves are `Arc`-backed), so plain `get_mut` suffices.
        let this = self.get_mut();
        let decoded = &mut this.write_buffer;
        decoded.clear();
        BASE64_ENGINE
            .decode_vec(buf, decoded)
            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;

        let written =
            ready!(Pin::new(&mut this.inner_write).poll_write(cx, decoded.as_slice()))?;
        let decoded_len = decoded.len();

        if written >= decoded_len {
            // Everything decoded reached the inner stream: all of `buf` is consumed.
            return Poll::Ready(Ok(buf.len()));
        }
        if written % 3 != 0 {
            // A partial group cannot be mapped back to a whole prefix of `buf`.
            return Poll::Ready(Err(io::Error::new(
                io::ErrorKind::Other,
                "inner stream accepted a partial base64 group",
            )));
        }
        // Every 3 raw bytes correspond to exactly 4 leading characters of `buf`.
        let consumed = encoded_len(written, false)
            .expect("partial write is smaller than the decoded input, so it cannot overflow");
        Poll::Ready(Ok(consumed))
    }

    /// Reports success immediately without flushing the inner stream.
    ///
    /// NOTE(review): forwarding to `inner_write.poll_flush` was disabled by the
    /// original author ("this errors out for some reason"); the cause is not
    /// visible here. Any data buffered inside the wrapped stream is not flushed
    /// by this call. Confirm that this is acceptable for the wrapped type.
    fn poll_flush(
        self: std::pin::Pin<&mut Self>,
        _cx: &mut std::task::Context<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Forwards shutdown to the write half of the inner stream.
    fn poll_shutdown(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        Pin::new(&mut self.get_mut().inner_write).poll_shutdown(cx)
    }
}